Coverage for mpcforces_extractor\api\db\database.py: 94%

71 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-04 17:42 +0100

1from typing import List 

2from fastapi import HTTPException 

3from sqlmodel import Session, create_engine, SQLModel, select, text 

4from mpcforces_extractor.datastructure.rigids import MPC 

5from mpcforces_extractor.datastructure.entities import Node 

6from mpcforces_extractor.datastructure.subcases import Subcase 

7from mpcforces_extractor.api.db.models import ( 

8 MPCDBModel, 

9 NodeDBModel, 

10 SubcaseDBModel, 

11) 

12 

13 

class MPCDatabase:
    """
    A Database class to store MPC instances, Nodes and Subcases

    MPCs and subcases are mirrored in the in-memory dictionaries ``mpcs`` and
    ``subcases`` (keyed by their database id) and served from there; nodes are
    queried from the database on every request.
    """

    def __init__(self, file_path: str):
        """
        Create the engine for the SQLite database stored at ``file_path``

        The in-memory caches start empty; they are filled by
        ``populate_database`` or ``reinitialize_db``.
        """
        self.mpcs = {}
        self.subcases = {}
        self.engine = create_engine(f"sqlite:///{file_path}")

    def close(self):
        """
        Close the database connection

        Calling this on an already closed database is a no-op.
        """
        if self.engine is None:
            return
        self.engine.dispose()
        self.engine = None

    def reinitialize_db(self, file_path: str):
        """
        Reinitialize the database with the data from the file

        The engine is replaced by one bound to ``file_path`` and both caches
        are reloaded from the tables found there.
        """
        # Dispose the engine that is being replaced so that its pooled
        # connections are released instead of being left to the GC.
        self.close()
        self.engine = create_engine(f"sqlite:///{file_path}")
        with Session(self.engine) as session:
            self._load_cache(session)

    def populate_database(self, load_all_nodes=False):
        """
        Function to populate the database from MPC instances

        All three tables are dropped and recreated, then filled from
        ``Node.node_id2node`` / ``MPC.id_2_instance`` / ``Subcase.subcases``.
        Afterwards the in-memory caches are reloaded from the database.

        Args:
            load_all_nodes: If True every known node is stored, otherwise
                only the nodes referenced by an MPC (slave nodes and master
                node) are stored.
        """
        # Delete the existing data by dropping all tables
        with Session(self.engine) as session:
            session.exec(text("DROP TABLE IF EXISTS mpcdbmodel"))
            session.exec(text("DROP TABLE IF EXISTS nodedbmodel"))
            session.exec(text("DROP TABLE IF EXISTS subcasedbmodel"))
            # Commit explicitly: without it the drops only survive because
            # the sqlite driver happens to run DDL outside a transaction.
            session.commit()

        # Create the tables again
        SQLModel.metadata.create_all(self.engine)

        with Session(self.engine) as session:
            if load_all_nodes:  # Load in all the nodes
                nodes = Node.node_id2node.values()
            else:  # Load in just the nodes that are used in the MPCs
                nodes = set()
                for mpc in MPC.id_2_instance.values():
                    nodes.update(mpc.nodes)
                    nodes.add(mpc.master_node)
            session.add_all(self._to_db_node(node) for node in nodes)

            for mpc in MPC.id_2_instance.values():
                # NOTE(review): the return value is unused; presumably this
                # call prepares the forces read on the next line - confirm.
                mpc.get_part_id2force(None)
                sub2part2force = mpc.get_subcase_id2part_id2force()

                # Convert MPC instance to MPCDBModel
                session.add(
                    MPCDBModel(
                        id=mpc.element_id,
                        config=mpc.mpc_config.name,  # Store enum as string
                        master_node=mpc.master_node.id,
                        nodes=",".join(str(node.id) for node in mpc.nodes),
                        part_id2nodes=mpc.part_id2node_ids,
                        subcase_id2part_id2forces=sub2part2force,
                    )
                )

            # Subcases
            session.add_all(
                SubcaseDBModel(
                    id=subcase.subcase_id,
                    node_id2forces=subcase.node_id2forces,
                    time=subcase.time,
                )
                for subcase in Subcase.subcases
            )

            # Commit to the database
            session.commit()

            self._load_cache(session)

    @staticmethod
    def _to_db_node(node: Node) -> NodeDBModel:
        """
        Convert a Node instance to its NodeDBModel row
        """
        return NodeDBModel(
            id=node.id,
            coord_x=node.coords[0],
            coord_y=node.coords[1],
            coord_z=node.coords[2],
        )

    def _load_cache(self, session: Session) -> None:
        """
        Reload the MPC and subcase caches from the database behind ``session``
        """
        self.mpcs = {mpc.id: mpc for mpc in session.exec(select(MPCDBModel)).all()}
        self.subcases = {
            subcase.id: subcase
            for subcase in session.exec(select(SubcaseDBModel)).all()
        }

    async def get_mpcs(self) -> List[MPCDBModel]:
        """
        Get all MPCs (served from the in-memory cache)
        """
        return list(self.mpcs.values())

    async def get_mpc(self, mpc_id: int) -> MPCDBModel:
        """
        Get a specific MPC (served from the in-memory cache)

        Raises:
            HTTPException: 404 if no MPC with ``mpc_id`` is cached.
        """
        try:
            return self.mpcs[mpc_id]
        except KeyError:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            ) from None

    async def get_nodes(self, offset: int, limit: int = 100) -> List[NodeDBModel]:
        """
        Get nodes for pagination

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel).offset(offset).limit(limit)
            return session.exec(statement).all()

    async def get_all_nodes(self) -> List[NodeDBModel]:
        """
        Get all nodes
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel)
            return session.exec(statement).all()

    async def remove_mpc(self, mpc_id: int):
        """
        Remove a specific MPC

        Raises:
            HTTPException: 404 if no MPC with ``mpc_id`` is cached.
        """
        # NOTE(review): only the in-memory cache is changed; the database row
        # is kept, so reinitialize_db() brings the MPC back - confirm that
        # this is intended.
        try:
            del self.mpcs[mpc_id]
        except KeyError:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            ) from None

    async def get_subcases(self) -> List[SubcaseDBModel]:
        """
        Get all subcases (served from the in-memory cache)
        """
        return list(self.subcases.values())