Coverage for mpcforces_extractor\api\routes\nodes.py: 43%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-28 21:26 +0100

1from typing import List 

2from fastapi import APIRouter, Depends, HTTPException, Query 

3from pydantic import BaseModel 

4from mpcforces_extractor.api.db.database import NodeDBModel 

5from mpcforces_extractor.api.dependencies import get_db 

6from mpcforces_extractor.api.config import ITEMS_PER_PAGE 

7from mpcforces_extractor.api.db.database import MPCDatabase 

8 

# Router holding the node endpoints. The routes below register relative paths
# ("" and "/all"); the URL prefix presumably comes from wherever this router is
# included (not visible in this module).
router = APIRouter()

10 

11 

class FilterDataModel(BaseModel):
    """
    Model for filter data.

    Request body accepted by the node routes to restrict results to specific
    node IDs.
    """

    # Entries are strings rather than ints so that a single entry can express
    # either one ID ("5") or an inclusive range ("1-3"); expand_filter_string
    # converts them into integer IDs.
    ids: List[str]  # List of strings to handle IDs and ranges

18 

19 

# Route to get nodes with pagination, sorting, and filtering
@router.post("", response_model=List[NodeDBModel])
async def get_nodes(
    page: int = Query(1, ge=1),  # Pagination (1-based page number)
    *,
    sort_column: str = Query("id", alias="sortColumn"),  # Sorting column
    sort_direction: int = Query(
        1, ge=-1, le=1, alias="sortDirection"
    ),  # Sorting direction: 1 (asc) or -1 (desc)
    filter_data: FilterDataModel,
    db: MPCDatabase = Depends(get_db),  # Dependency for DB session
    subcase_id: int = Query(None, alias="subcaseId"),
) -> List[NodeDBModel]:
    """
    Get one page of nodes, sorted and optionally filtered by node IDs.

    :param page: 1-based page number; each page holds ``ITEMS_PER_PAGE`` nodes.
    :param sort_column: Column name forwarded to ``db.get_nodes`` for sorting.
    :param sort_direction: Forwarded to ``db.get_nodes``. Intended values are
        1 (ascending) and -1 (descending).
    :param filter_data: Request body whose ``ids`` entries are single IDs or
        ``a-b`` ranges, expanded before querying.
    :param db: Database handle provided by the ``get_db`` dependency.
    :param subcase_id: Forwarded to ``db.get_nodes``; ``None`` when the
        ``subcaseId`` query parameter is omitted.
    :return: The nodes on the requested page.
    :raises HTTPException: 400 if the ID filter cannot be parsed; 404 if the
        query returns no nodes.
    """
    # NOTE(review): the validator (ge=-1, le=1) also accepts 0; how
    # db.get_nodes treats 0 is not visible here — confirm or tighten.

    # Calculate offset based on the current page
    offset = (page - 1) * ITEMS_PER_PAGE

    # Expand ranges such as "1-3" into explicit IDs. Malformed input is a client
    # error, so report it as 400 instead of letting ValueError become a 500.
    try:
        node_ids = expand_filter_string(filter_data)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err

    # Fetch nodes from the database with the calculated offset, limit, sorting, and filtering
    nodes = await db.get_nodes(
        offset=offset,
        limit=ITEMS_PER_PAGE,
        sort_column=sort_column,
        sort_direction=sort_direction,
        node_ids=node_ids,
        subcase_id=subcase_id,
    )

    # Handle case when no nodes are found
    if not nodes:
        raise HTTPException(status_code=404, detail="No nodes found")

    return nodes

57 

58 

@router.post("/all", response_model=List[NodeDBModel])
async def get_all_nodes(
    filter_data: FilterDataModel, db: MPCDatabase = Depends(get_db)
) -> List[NodeDBModel]:
    """
    Get all nodes matching the ID filter, without pagination.

    :param filter_data: Request body whose ``ids`` entries are single IDs or
        ``a-b`` ranges, expanded before querying. How an empty ID list is
        treated is up to ``db.get_all_nodes`` (presumably "no filter"; verify).
    :param db: Database handle provided by the ``get_db`` dependency.
    :return: The matching nodes.
    :raises HTTPException: 400 if the ID filter cannot be parsed; 404 if the
        query returns no nodes.
    """
    # Malformed filter input is a client error: answer 400, not a 500 from an
    # uncaught ValueError.
    try:
        node_ids = expand_filter_string(filter_data)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err

    nodes = await db.get_all_nodes(node_ids)

    if not nodes:
        raise HTTPException(status_code=404, detail="No nodes found")

    return nodes

72 

73 

def expand_filter_string(filter_data: FilterDataModel) -> List[int]:
    """
    HELPER METHOD
    Expand the ID filter in ``filter_data`` into a flat list of node IDs.

    Each entry of ``filter_data.ids`` may be a single ID (``'5'``), an
    inclusive range (``'1-3'``), or several of these separated by commas
    (``'1-3,5'``). Surrounding whitespace and empty tokens (e.g. ``''`` or a
    trailing comma) are ignored. A reversed range such as ``'3-1'`` is treated
    like ``'1-3'``. IDs are returned in input order and duplicates are kept.

    :param filter_data: The filter payload; a falsy value (e.g. ``None``)
        yields an empty list.
    :return: The expanded list of integer IDs.
    :raises ValueError: If a token is neither an integer nor an ``a-b`` range.
    """
    filtered_nodes: List[int] = []

    if not filter_data:
        return filtered_nodes

    for filter_part in filter_data.ids:
        # An entry may itself hold a comma-separated list, so split it first.
        for token in filter_part.split(","):
            token = token.strip()
            if not token:
                # Skip blanks, e.g. an empty filter field sent as [""].
                continue
            # NOTE(review): ranges are materialized eagerly, so a very large
            # span (e.g. "1-1000000000") builds a huge list from client
            # input; consider capping the total number of IDs.
            filtered_nodes.extend(_parse_id_token(token))

    return filtered_nodes


def _parse_id_token(token: str) -> range:
    """Parse one stripped ``'n'`` or ``'a-b'`` token into an inclusive range."""
    start_text, separator, end_text = token.partition("-")
    try:
        start = int(start_text)
        end = int(end_text) if separator else start
    except ValueError as err:
        # Covers non-numeric text as well as malformed ranges like "1-2-3"
        # (end_text "2-3") or "-5" (empty start_text).
        raise ValueError(f"Invalid node ID filter entry: {token!r}") from err
    if start > end:
        start, end = end, start
    return range(start, end + 1)