Coverage for mpcforces_extractor\api\db\database.py: 77%

124 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-28 21:26 +0100

1from typing import List, Optional 

2from fastapi import HTTPException 

3from sqlmodel import Session, create_engine, SQLModel, select, text 

4from sqlalchemy.sql.expression import asc, desc 

5from mpcforces_extractor.datastructure.rigids import MPC 

6from mpcforces_extractor.datastructure.entities import Node 

7from mpcforces_extractor.datastructure.subcases import Subcase 

8from mpcforces_extractor.api.db.models import ( 

9 RBE2DBModel, 

10 RBE3DBModel, 

11 NodeDBModel, 

12 SubcaseDBModel, 

13) 

14from mpcforces_extractor.datastructure.rigids import MPC_CONFIG 

15 

16 

class MPCDatabase:
    """
    A Database class to store MPC instances, Nodes and Subcases

    Rows are persisted through SQLModel in the SQLite file given to the
    constructor. RBE2/RBE3 elements and subcases are additionally mirrored
    in in-memory dictionaries keyed by their id, which back the simple
    getter methods.
    """

    # State of the most recent ``get_nodes`` call. It lets consecutive
    # requests with the same sorting/filter reuse the already built query and
    # avoids rewriting the force columns when the subcase did not change.
    last_sort_column = "id"
    last_sort_direction = 1
    last_subcase_id = None
    last_query = None
    last_filter = None

    def __init__(self, file_path: str):
        """
        Create the engine for the SQLite database stored at ``file_path``.

        The in-memory caches start empty; they are filled by
        ``populate_database`` or ``reinitialize_db``.
        """
        self.engine = None
        self.rbe2s = {}
        self.rbe3s = {}
        self.subcases = {}

        self.engine = create_engine(f"sqlite:///{file_path}")

    def close(self):
        """
        Close the database connection

        Calling this more than once is harmless.
        """
        if self.engine is not None:
            self.engine.dispose()
        self.engine = None

    def reinitialize_db(self, file_path: str):
        """
        Reinitialize the database with the data from the file

        The previous engine (if any) is disposed, a new engine is created for
        ``file_path`` and the in-memory caches are reloaded from it.
        """
        # Release the connections of the engine that is being replaced.
        if self.engine is not None:
            self.engine.dispose()
        self.engine = create_engine(f"sqlite:///{file_path}")

        # Cached query state refers to the previous database content.
        self._reset_node_query_cache()

        with Session(self.engine) as session:
            self._load_caches(session)

    def _reset_node_query_cache(self):
        """Forget the query/subcase state remembered by ``get_nodes``."""
        self.last_sort_column = "id"
        self.last_sort_direction = 1
        self.last_subcase_id = None
        self.last_query = None
        self.last_filter = None

    def _load_caches(self, session):
        """Refresh the id -> row dictionaries from the database."""
        self.rbe2s = {
            rbe2.id: rbe2 for rbe2 in session.exec(select(RBE2DBModel)).all()
        }
        self.rbe3s = {
            rbe3.id: rbe3 for rbe3 in session.exec(select(RBE3DBModel)).all()
        }
        self.subcases = {
            subcase.id: subcase
            for subcase in session.exec(select(SubcaseDBModel)).all()
        }

    def populate_database(self, load_all_nodes=False):
        """
        Function to populate the database from MPC instances

        Drops the existing tables, recreates them, writes nodes, MPCs and
        subcases from the in-memory data structures and reloads the caches.

        - load_all_nodes: store every known node instead of only the nodes
          referenced by MPCs (default: False).
        """
        # Delete the existing data by dropping all tables.
        with Session(self.engine) as session:
            for table_name in (
                "RBE2DBModel",
                "RBE3DBModel",
                "nodedbmodel",
                "subcasedbmodel",
            ):
                session.exec(text(f"DROP TABLE IF EXISTS {table_name}"))
            # Commit explicitly instead of relying on the driver applying
            # DDL outside of the session's transaction.
            session.commit()

        # Create the tables again
        SQLModel.metadata.create_all(self.engine)

        # The node table was rebuilt: force columns are empty again, so the
        # remembered subcase/query of get_nodes is no longer valid.
        self._reset_node_query_cache()

        with Session(self.engine) as session:
            self.populate_nodes(load_all_nodes, session)
            self.populate_mpcs(session)

            # Populate Subcases
            for subcase in Subcase.subcases:
                session.add(
                    SubcaseDBModel(
                        id=subcase.subcase_id,
                        node_id2forces=subcase.node_id2forces,
                        time=subcase.time,
                    )
                )

            # Commit to the database
            session.commit()

            self._load_caches(session)

    def populate_nodes(self, load_all_nodes=False, session=None):
        """
        Function to populate the database with nodes

        - load_all_nodes: if True every node in ``Node.node_id2node`` is
          added, otherwise only slave and master nodes of the known MPCs.
        - session: the open session the rows are added to. The caller is
          responsible for committing.

        Raises ValueError if no session is given.
        """
        if session is None:
            raise ValueError("populate_nodes requires an open database session")

        if load_all_nodes:  # Load in all the nodes
            nodes = Node.node_id2node.values()
        else:  # load in just the nodes that are used in the MPCs
            nodes = set()
            for mpc_config in MPC_CONFIG:
                if mpc_config.value not in MPC.config_2_id_2_instance:
                    continue
                for mpc in MPC.config_2_id_2_instance[mpc_config.value].values():
                    nodes.update(mpc.nodes)
                    nodes.add(mpc.master_node)

        for node in nodes:
            session.add(
                NodeDBModel(
                    id=node.id,
                    coord_x=node.coords[0],
                    coord_y=node.coords[1],
                    coord_z=node.coords[2],
                )
            )

    def populate_mpcs(self, session):
        """
        Function to populate the database with MPCs

        Adds one RBE2DBModel / RBE3DBModel row per MPC instance to
        ``session``. The caller is responsible for committing.

        Raises ValueError for an MPC whose config is neither RBE2 nor RBE3.
        """
        for mpc_config in MPC_CONFIG:
            if mpc_config.value not in MPC.config_2_id_2_instance:
                continue
            for mpc in MPC.config_2_id_2_instance[mpc_config.value].values():
                # Called before reading the per-subcase forces, as in the
                # original flow (presumably it prepares them - verify in MPC).
                mpc.get_part_id2force(None)
                sub2part2force = mpc.get_subcase_id2part_id2force()

                if mpc_config == MPC_CONFIG.RBE2:
                    model_cls = RBE2DBModel
                elif mpc_config == MPC_CONFIG.RBE3:
                    model_cls = RBE3DBModel
                else:
                    raise ValueError(f"Unknown MPC config {mpc_config}")

                # Add to the session
                session.add(
                    model_cls(
                        id=mpc.element_id,
                        config=mpc.mpc_config.name,  # Store enum as string
                        master_node=mpc.master_node.id,
                        nodes=",".join([str(node.id) for node in mpc.nodes]),
                        part_id2nodes=mpc.part_id2node_ids,
                        subcase_id2part_id2forces=sub2part2force,
                    )
                )

    async def get_rbe2s(self) -> List[RBE2DBModel]:
        """
        Get all RBE2 elements from the in-memory cache
        """
        return list(self.rbe2s.values())

    async def get_rbe3s(self) -> List[RBE3DBModel]:
        """
        Get all RBE3 elements from the in-memory cache
        """
        return list(self.rbe3s.values())

    def _apply_subcase_forces(self, session, subcase_id):
        """
        Write the forces of ``subcase_id`` into the node rows and commit.

        Raises KeyError if the subcase is unknown.
        """
        subcase = self.subcases[subcase_id]
        for node_id, forces in subcase.node_id2forces.items():
            node = session.exec(
                select(NodeDBModel).where(NodeDBModel.id == node_id)
            ).first()
            if node is None:
                # Forces exist for a node that was not stored (e.g. only the
                # MPC nodes were loaded); there is no row to update.
                continue
            node.fx = forces[0]
            node.fy = forces[1]
            node.fz = forces[2]
            node.fabs = (forces[0] ** 2 + forces[1] ** 2 + forces[2] ** 2) ** 0.5
            node.mx = forces[3]
            node.my = forces[4]
            node.mz = forces[5]
            node.mabs = (forces[3] ** 2 + forces[4] ** 2 + forces[5] ** 2) ** 0.5
        session.commit()
        # Remember the subcase only once its forces are really stored.
        self.last_subcase_id = subcase_id

    async def get_nodes(
        self,
        *,
        offset: int,
        limit: int = 100,
        sort_column: str = "id",
        sort_direction: int = 1,
        node_ids: Optional[List[int]] = None,
        subcase_id: Optional[int] = None,
    ) -> List[NodeDBModel]:
        """
        Get nodes for pagination, sorting, and filtering.

        - offset: The offset for pagination.
        - limit: The limit for pagination (default: 100).
        - sort_column: The column to sort by (default: 'id').
        - sort_direction: The direction of sorting (1 for ascending, -1 for descending).
        - node_ids: An optional list of node IDs to filter by (default: None).
        - subcase_id: The subcase whose forces should be present on the nodes.
          None or 0 means no force data is needed (default: None).
        """
        # 0 (or None) for subcase means that it is not necessary to add
        # forces data as the request is about coords or ids only.
        forces_requested = subcase_id is not None and subcase_id != 0

        # Start a session with the database engine
        with Session(self.engine) as session:
            # Update the force columns before any cached query is reused;
            # otherwise a subcase change with unchanged sorting/filter would
            # return the forces of the previous subcase.
            if forces_requested and subcase_id != self.last_subcase_id:
                self._apply_subcase_forces(session, subcase_id)

            same_as_last_query = (
                self.last_query is not None
                and self.last_sort_column == sort_column
                and self.last_sort_direction == sort_direction
                and self.last_filter == node_ids
            )

            if same_as_last_query:
                query = self.last_query
            else:
                # Create the base query
                query = select(NodeDBModel)

                # Apply filtering by node IDs if provided
                if node_ids:
                    query = query.where(NodeDBModel.id.in_(node_ids))

                # Apply sorting based on the specified column and direction
                if sort_direction == 1:
                    query = query.order_by(asc(getattr(NodeDBModel, sort_column)))
                elif sort_direction == -1:
                    query = query.order_by(desc(getattr(NodeDBModel, sort_column)))

                # caching for speed
                self.last_query = query
                self.last_sort_column = sort_column
                self.last_sort_direction = sort_direction
                self.last_filter = node_ids

            # Execute the query and return the results (with pagination)
            return session.exec(query.offset(offset).limit(limit)).all()

    async def get_all_nodes(
        self, node_ids: Optional[List[int]] = None
    ) -> List[NodeDBModel]:
        """
        Get all nodes

        - node_ids: optional list of node IDs to restrict the result to; an
          empty list or None returns every node.
        """
        statement = select(NodeDBModel)
        if node_ids:
            statement = statement.where(NodeDBModel.id.in_(node_ids))
        with Session(self.engine) as session:
            return session.exec(statement).all()

    async def remove_mpc(self, mpc_id: int):
        """
        Remove a specific MPC

        The element is removed from the in-memory RBE2/RBE3 caches; rows
        already persisted in the database file are not touched.

        Raises HTTPException (404) if no RBE2 or RBE3 with that id is cached.
        """
        removed = False
        for cache in (self.rbe2s, self.rbe3s):
            if mpc_id in cache:
                del cache[mpc_id]
                removed = True

        if not removed:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            )

    async def get_subcases(self) -> List[SubcaseDBModel]:
        """
        Get all subcases from the in-memory cache
        """
        return list(self.subcases.values())