Coverage for mpcforces_extractor\database\database.py: 74%

105 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-10-31 16:49 +0100

1from typing import List, Dict 

2from fastapi import HTTPException 

3from sqlmodel import Session, create_engine, SQLModel, Field, select, Column, JSON, text 

4from mpcforces_extractor.datastructure.rigids import MPC 

5from mpcforces_extractor.datastructure.rigids import MPC_CONFIG 

6from mpcforces_extractor.datastructure.entities import Node 

7from mpcforces_extractor.datastructure.subcases import Subcase 

8 

9 

class RunExtractorRequest(SQLModel, table=False):
    """
    Request model naming the two input files for an extractor run.

    Declared with ``table=False``, so it is a plain schema and is not
    mapped to a database table.
    """

    fem_filename: str  # the FEM input file name
    mpcf_filename: str  # the MPC-forces input file name

17 

18 

class DatabaseRequest(SQLModel, table=False):
    """
    Request model carrying the filename of a database to operate on.

    Declared with ``table=False``, so it is a plain schema and is not
    mapped to a database table.
    """

    # Name/path of the database file.
    # NOTE(review): presumably the SQLite file later passed to
    # ``MPCDatabase`` / ``reinitialize_db`` -- confirm against the endpoint.
    database_filename: str

25 

26 

class MPCDBModel(SQLModel, table=True):
    """
    Database representation of an :class:`MPC` element.

    Scalar attributes are flattened into simple column types, and the nested
    dictionaries are stored in JSON columns.

    Note: JSON object keys are always strings, so dictionaries read back from
    the database have ``str`` keys even if they were written with ``int`` keys.
    """

    id: int = Field(primary_key=True)  # MPC element id
    config: str = Field()  # MPC_CONFIG member *name* (enum stored as string)
    master_node: int = Field()  # id of the master node
    nodes: str = Field()  # comma-separated ids of ``MPC.nodes``
    part_id2nodes: Dict = Field(
        default_factory=dict, sa_column=Column(JSON)
    )  # MPC.part_id2node_ids, stored as JSON
    subcase_id2part_id2forces: Dict = Field(
        default_factory=dict, sa_column=Column(JSON)
    )  # result of MPC.get_subcase_id2part_id2force(), stored as JSON

    def to_mpc(self):
        """
        Convert this row back into an :class:`MPC` instance.

        Node ids are resolved through ``Node.node_id2node``, so the matching
        :class:`Node` objects must already be registered there.

        Returns:
            MPC: the reconstructed element. ``dofs`` is not persisted and is
            set to an empty string; ``part_id2node_ids`` is restored from
            ``part_id2nodes``.

        Raises:
            KeyError: if ``config`` is not an ``MPC_CONFIG`` member name, or a
                referenced node id is missing from ``Node.node_id2node``.
            ValueError: if ``nodes`` contains an entry that is not an integer.
        """
        # Guard against an empty/None column: "".split(",") would give [""]
        # and int("") would then fail.
        nodes_list = str(self.nodes).split(",") if self.nodes else []
        mpc = MPC(
            element_id=self.id,
            mpc_config=MPC_CONFIG[self.config],  # enum member from its name
            master_node=Node.node_id2node[self.master_node],
            nodes=[Node.node_id2node[int(node_id)] for node_id in nodes_list],
            dofs="",
        )
        mpc.part_id2node_ids = self.part_id2nodes
        return mpc

62 

63 

class NodeDBModel(SQLModel, table=True):
    """
    Row model persisting a single :class:`Node` and its three coordinates.
    """

    id: int = Field(primary_key=True)  # node id
    coord_x: float = Field()
    coord_y: float = Field()
    coord_z: float = Field()

    def to_node(self):
        """
        Rebuild a :class:`Node` from this row.

        Returns:
            Node: a node constructed with ``node_id=self.id`` and
            ``coords=[coord_x, coord_y, coord_z]``.
        """
        xyz = [self.coord_x, self.coord_y, self.coord_z]
        return Node(node_id=self.id, coords=xyz)

79 

80 

class SubcaseDBModel(SQLModel, table=True):
    """
    Row model persisting a :class:`Subcase` together with its per-node forces.
    """

    id: int = Field(primary_key=True)  # subcase id
    # Per-node force data, stored in a JSON column.
    # NOTE(review): JSON object keys are always strings, so after a round trip
    # through the database the node ids here are ``str`` -- confirm that
    # ``Subcase.add_force`` is fine with that.
    node_id2forces: Dict = Field(default_factory=dict, sa_column=Column(JSON))
    time: float = Field()

    def to_subcase(self):
        """
        Rebuild a :class:`Subcase` from this row.

        Creates the subcase with this row's id and time, then replays every
        ``node_id -> forces`` entry through ``Subcase.add_force``.

        Returns:
            Subcase: the reconstructed subcase.
        """
        result = Subcase(subcase_id=self.id, time=self.time)
        for nid, force_values in self.node_id2forces.items():
            result.add_force(nid, force_values)
        return result

98 

99 

class MPCDatabase:
    """
    SQLite-backed store for MPC instances, nodes and subcases.

    MPC and subcase rows are also cached in memory (``self.mpcs`` and
    ``self.subcases``, keyed by id) after the database is populated or
    reopened. Node rows are always queried from the database.

    NOTE(review): the ``async`` accessors below run synchronous database
    calls directly, so they block the event loop while they execute.
    """

    def __init__(self, file_path: str):
        """
        Create an engine for the SQLite database at ``file_path``.

        No tables are created here. Call :meth:`populate_database` to build
        them, or :meth:`reinitialize_db` to load an existing file.

        Args:
            file_path: path of the SQLite database file.
        """
        self.mpcs = {}  # MPCDBModel rows keyed by MPC id
        self.subcases = {}  # SubcaseDBModel rows keyed by subcase id
        self.engine = create_engine(f"sqlite:///{file_path}")

    def close(self):
        """
        Dispose of the engine's connection pool and drop the engine reference.

        Safe to call more than once.
        """
        if self.engine is not None:
            self.engine.dispose()
        self.engine = None

    def reinitialize_db(self, file_path: str):
        """
        Switch to the database in ``file_path`` and reload the in-memory
        MPC and subcase caches from it.

        Any previous engine is disposed first so its pooled connections are
        not leaked.

        Args:
            file_path: path of a SQLite database file, expected to already
                contain the MPC and subcase tables.
        """
        if self.engine is not None:
            self.engine.dispose()
        self.engine = create_engine(f"sqlite:///{file_path}")
        with Session(self.engine) as session:
            self._load_cache(session)

    def populate_database(self, load_all_nodes=False):
        """
        Rebuild the MPC, node and subcase tables from the in-memory model.

        Drops and recreates the tables, then fills them from
        ``MPC.id_2_instance``, ``Node.node_id2node`` and ``Subcase.subcases``.
        Afterwards it refreshes the in-memory caches from the database.

        Args:
            load_all_nodes: if True, store every node in ``Node.node_id2node``.
                Otherwise store only the nodes referenced by some MPC (its
                ``nodes`` plus its ``master_node``).
        """
        self._drop_tables()
        # Recreate the tables from the SQLModel table metadata.
        SQLModel.metadata.create_all(self.engine)

        with Session(self.engine) as session:
            if load_all_nodes:
                nodes = Node.node_id2node.values()
            else:
                nodes = self._collect_mpc_nodes()
            for node in nodes:
                session.add(self._node_to_db_model(node))

            for mpc in MPC.id_2_instance.values():
                session.add(self._mpc_to_db_model(mpc))

            for subcase in Subcase.subcases:
                session.add(
                    SubcaseDBModel(
                        id=subcase.subcase_id,
                        node_id2forces=subcase.node_id2forces,
                        time=subcase.time,
                    )
                )

            session.commit()
            self._load_cache(session)

    async def get_mpcs(self) -> List[MPCDBModel]:
        """
        Return all cached MPC rows.
        """
        return list(self.mpcs.values())

    async def get_mpc(self, mpc_id: int) -> MPCDBModel:
        """
        Return the cached MPC row with id ``mpc_id``.

        Raises:
            HTTPException: 404 if no MPC with that id is cached.
        """
        try:
            return self.mpcs[mpc_id]
        except KeyError:
            raise self._mpc_not_found(mpc_id) from None

    async def get_nodes(self, offset: int, limit: int = 100) -> List[NodeDBModel]:
        """
        Return one page of node rows, queried from the database.

        Args:
            offset: number of rows to skip.
            limit: maximum number of rows to return.
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel).offset(offset).limit(limit)
            return session.exec(statement).all()

    async def get_all_nodes(self) -> List[NodeDBModel]:
        """
        Return every node row, queried from the database.
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel)
            return session.exec(statement).all()

    async def remove_mpc(self, mpc_id: int):
        """
        Remove the MPC with id ``mpc_id`` from the in-memory cache.

        Only ``self.mpcs`` is modified; the database row is left untouched.

        Raises:
            HTTPException: 404 if no MPC with that id is cached.
        """
        try:
            del self.mpcs[mpc_id]
        except KeyError:
            raise self._mpc_not_found(mpc_id) from None

    async def get_subcases(self) -> List[SubcaseDBModel]:
        """
        Return all cached subcase rows.
        """
        return list(self.subcases.values())

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _drop_tables(self):
        """Drop the MPC, node and subcase tables (if present) and commit."""
        with Session(self.engine) as session:
            session.exec(text("DROP TABLE IF EXISTS mpcdbmodel"))
            session.exec(text("DROP TABLE IF EXISTS nodedbmodel"))
            session.exec(text("DROP TABLE IF EXISTS subcasedbmodel"))
            # Commit explicitly. Closing a Session rolls back its transaction,
            # so without this the DROPs only persist if the driver happened
            # to run them outside a transaction.
            session.commit()

    def _load_cache(self, session):
        """Refresh ``self.mpcs`` and ``self.subcases`` from the database."""
        self.mpcs = {mpc.id: mpc for mpc in session.exec(select(MPCDBModel)).all()}
        self.subcases = {
            subcase.id: subcase
            for subcase in session.exec(select(SubcaseDBModel)).all()
        }

    @staticmethod
    def _collect_mpc_nodes():
        """Return the set of nodes (including master nodes) used by any MPC."""
        unique_nodes = set()
        for mpc in MPC.id_2_instance.values():
            unique_nodes.update(mpc.nodes)
            unique_nodes.add(mpc.master_node)
        return unique_nodes

    @staticmethod
    def _node_to_db_model(node):
        """Convert a Node (``id`` plus ``coords[0..2]``) into a NodeDBModel row."""
        return NodeDBModel(
            id=node.id,
            coord_x=node.coords[0],
            coord_y=node.coords[1],
            coord_z=node.coords[2],
        )

    @staticmethod
    def _mpc_to_db_model(mpc):
        """Convert an MPC instance into its MPCDBModel row."""
        # The return value is unused. Presumably this call primes state that
        # get_subcase_id2part_id2force() relies on -- TODO confirm.
        mpc.get_part_id2force(None)
        sub2part2force = mpc.get_subcase_id2part_id2force()
        return MPCDBModel(
            id=mpc.element_id,
            config=mpc.mpc_config.name,  # store enum as its member name
            master_node=mpc.master_node.id,
            nodes=",".join([str(node.id) for node in mpc.nodes]),
            part_id2nodes=mpc.part_id2node_ids,
            subcase_id2part_id2forces=sub2part2force,
        )

    @staticmethod
    def _mpc_not_found(mpc_id: int) -> HTTPException:
        """Build the 404 raised when an MPC id is not in the cache."""
        return HTTPException(
            status_code=404, detail=f"MPC with id {mpc_id} does not exist"
        )