Coverage for mpcforces_extractor\api\routes\nodes.py: 43%
35 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-28 21:26 +0100
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel

from mpcforces_extractor.api.config import ITEMS_PER_PAGE
from mpcforces_extractor.api.db.database import MPCDatabase
from mpcforces_extractor.api.db.database import NodeDBModel
from mpcforces_extractor.api.dependencies import get_db
# Router collecting the node endpoints defined in this module (see the
# ``@router.post`` decorators below).
router = APIRouter()
class FilterDataModel(BaseModel):
    """Request body carrying the ID filter for node queries.

    Every entry of ``ids`` is a string holding either a single ID such as
    ``"42"`` or an ID range such as ``"1-10"``.
    """

    # Kept as strings (not ints) so range expressions can be sent verbatim.
    ids: List[str]
# Route to get nodes with pagination, sorting, and filtering
@router.post("", response_model=List[NodeDBModel])
async def get_nodes(
    page: int = Query(1, ge=1),  # 1-based page number
    *,
    sort_column: str = Query("id", alias="sortColumn"),
    # Sorting direction: 1 (asc) or -1 (desc).
    # NOTE(review): ge=-1/le=1 also admits 0 — confirm how db.get_nodes treats it.
    sort_direction: int = Query(1, ge=-1, le=1, alias="sortDirection"),
    filter_data: FilterDataModel,
    db: MPCDatabase = Depends(get_db),  # Dependency for DB session
    subcase_id: Optional[int] = Query(None, alias="subcaseId"),
) -> List[NodeDBModel]:
    """
    Get one page of nodes, optionally sorted and filtered by node IDs.

    Args:
        page: 1-based page number; each page holds ``ITEMS_PER_PAGE`` nodes.
        sort_column: Column name forwarded to ``db.get_nodes`` for sorting.
        sort_direction: Sort direction forwarded to ``db.get_nodes``.
        filter_data: Request body with ID / ID-range strings, expanded by
            ``expand_filter_string``.
        db: Database handle injected via ``get_db``.
        subcase_id: Optional subcase ID forwarded to ``db.get_nodes``.

    Returns:
        The nodes returned by ``db.get_nodes`` for the requested page.

    Raises:
        HTTPException: 400 if the ID filter is malformed; 404 if no nodes
            are found.
    """
    # Translate the 1-based page number into a row offset.
    offset = (page - 1) * ITEMS_PER_PAGE

    # A malformed filter string raises ValueError in the helper; left unhandled
    # it would become a 500, so report it to the client as a bad request.
    try:
        node_ids = expand_filter_string(filter_data)
    except ValueError as exc:
        raise HTTPException(
            status_code=400, detail=f"Invalid node ID filter: {exc}"
        ) from exc

    # Fetch the requested page with sorting and filtering applied.
    nodes = await db.get_nodes(
        offset=offset,
        limit=ITEMS_PER_PAGE,
        sort_column=sort_column,
        sort_direction=sort_direction,
        # presumably an empty list means "no ID filter" — confirm in db.get_nodes
        node_ids=node_ids,
        subcase_id=subcase_id,
    )

    if not nodes:
        raise HTTPException(status_code=404, detail="No nodes found")

    return nodes
@router.post("/all", response_model=List[NodeDBModel])
async def get_all_nodes(
    filter_data: FilterDataModel, db: MPCDatabase = Depends(get_db)
) -> List[NodeDBModel]:
    """
    Get all nodes matching the ID filter, without pagination.

    Args:
        filter_data: Request body with ID / ID-range strings, expanded by
            ``expand_filter_string``.
        db: Database handle injected via ``get_db``.

    Returns:
        The nodes returned by ``db.get_all_nodes`` for the expanded IDs.

    Raises:
        HTTPException: 400 if the ID filter is malformed; 404 if no nodes
            are found.
    """
    # Map a malformed filter (ValueError from the helper) to a 400 instead of
    # letting it escape as an unhandled server error.
    try:
        node_ids = expand_filter_string(filter_data)
    except ValueError as exc:
        raise HTTPException(
            status_code=400, detail=f"Invalid node ID filter: {exc}"
        ) from exc

    nodes = await db.get_all_nodes(node_ids)

    if not nodes:
        raise HTTPException(status_code=404, detail="No nodes found")

    return nodes
def expand_filter_string(filter_data: FilterDataModel) -> List[int]:
    """
    HELPER METHOD
    Expand the ID filter entries into a flat list of integer node IDs.

    Each entry of ``filter_data.ids`` is either a single ID (``"7"``) or an
    inclusive range (``"1-3"`` -> 1, 2, 3). Entries are expanded in the order
    given and duplicates are kept. Blank entries are skipped, and a reversed
    range such as ``"3-1"`` is treated like its ascending form ``"1-3"``.

    Args:
        filter_data: The filter payload; a falsy value (e.g. ``None``) yields
            an empty list.

    Returns:
        The expanded node IDs; empty when there is nothing to filter on.

    Raises:
        ValueError: If an entry is neither an integer nor a
            ``"<start>-<end>"`` range of integers.
    """
    node_ids: List[int] = []

    if not filter_data:
        return node_ids

    for raw_part in filter_data.ids:
        part = raw_part.strip()
        if not part:
            # Tolerate empty entries (e.g. produced by a trailing comma).
            continue

        if "-" in part:
            # Split on the first '-' only; anything extra (e.g. "1-2-3") then
            # fails int() below and is reported with a clear message.
            start_text, _, end_text = part.partition("-")
            try:
                start, end = int(start_text), int(end_text)
            except ValueError as exc:
                raise ValueError(
                    f"Invalid ID range {raw_part!r}; expected '<start>-<end>'"
                ) from exc
            # Normalize reversed ranges so "3-1" does not silently expand to [].
            low, high = min(start, end), max(start, end)
            node_ids.extend(range(low, high + 1))
        else:
            try:
                node_ids.append(int(part))
            except ValueError as exc:
                raise ValueError(
                    f"Invalid ID {raw_part!r}; expected an integer"
                ) from exc

    return node_ids