Coverage for mpcforces_extractor\database\database.py: 76%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-25 00:11 +0200

1from typing import List, Dict 

2from fastapi import HTTPException 

3from sqlmodel import Session, create_engine, SQLModel, Field, select, Column, JSON, text 

4from mpcforces_extractor.datastructure.rigids import MPC 

5from mpcforces_extractor.datastructure.rigids import MPC_CONFIG 

6from mpcforces_extractor.datastructure.entities import Node 

7from mpcforces_extractor.datastructure.subcases import Subcase 

8 

9 

class RunExtractorRequest(SQLModel, table=False):
    """
    Payload describing one extractor run.

    Declared with ``table=False``, so it is a plain data model and no
    database table is created for it.
    """

    # Name of the FEM input file
    fem_filename: str
    # Name of the MPCF (MPC forces) input file
    mpcf_filename: str

17 

18 

class MPCDBModel(SQLModel, table=True):
    """
    Table model holding the persisted form of an MPC instance.

    Scalar attributes are stored in plain columns; the two nested
    dictionaries are stored in JSON columns.
    """

    id: int = Field(primary_key=True)
    # Name of the MPC_CONFIG enum member
    config: str = Field()
    # Id of the master node
    master_node: int = Field()
    # Comma-separated ids of the slave nodes
    nodes: str = Field()
    # Mapping part id -> node ids, kept in a JSON column
    part_id2nodes: Dict = Field(default_factory=dict, sa_column=Column(JSON))
    # Mapping subcase id -> part id -> forces, kept in a JSON column
    subcase_id2part_id2forces: Dict = Field(
        default_factory=dict, sa_column=Column(JSON)
    )

    def to_mpc(self):
        """
        Rebuild an MPC object from this database row.

        Node ids are resolved through ``Node.node_id2node`` and the config
        name through ``MPC_CONFIG``; either lookup raises ``KeyError`` when
        the key is unknown.

        Returns:
            The MPC created from this row, with ``part_id2node_ids`` set to
            the stored ``part_id2nodes`` mapping.
        """
        print(f"Converting MPCDBModel to MPC: id={self.id}, nodes={self.nodes}")

        # A missing or empty node string means there are no slave nodes
        if self.nodes:
            raw_node_ids = str(self.nodes).split(",")
        else:
            raw_node_ids = []

        # Resolve in the same order as the constructor arguments below
        config = MPC_CONFIG[self.config]  # enum member looked up by name
        master = Node.node_id2node[self.master_node]
        slaves = [Node.node_id2node[int(raw_id)] for raw_id in raw_node_ids]

        restored = MPC(
            element_id=self.id,
            mpc_config=config,
            master_node=master,
            nodes=slaves,
            dofs="",
        )
        restored.part_id2node_ids = self.part_id2nodes
        return restored

54 

55 

class NodeDBModel(SQLModel, table=True):
    """
    Table model holding the persisted form of a Node instance.

    The three coordinates are stored in separate float columns.
    """

    id: int = Field(primary_key=True)
    coord_x: float = Field()
    coord_y: float = Field()
    coord_z: float = Field()

    def to_node(self):
        """
        Rebuild a Node object from this database row.

        Returns:
            A Node created with this row's id and its ``[x, y, z]``
            coordinates.
        """
        coordinates = [self.coord_x, self.coord_y, self.coord_z]
        return Node(node_id=self.id, coords=coordinates)

71 

72 

class SubcaseDBModel(SQLModel, table=True):
    """
    Table model holding the persisted form of a Subcase instance.
    """

    id: int = Field(primary_key=True)
    # Mapping node id -> forces, kept in a JSON column
    node_id2forces: Dict = Field(default_factory=dict, sa_column=Column(JSON))
    time: float = Field()

    def to_subcase(self):
        """
        Rebuild a Subcase object from this database row.

        Returns:
            A Subcase created with this row's id and time, with every
            stored force entry re-added through ``add_force``.
        """
        restored = Subcase(subcase_id=self.id, time=self.time)
        # NOTE(review): JSON object keys are strings, so ids read back from
        # the JSON column may be str rather than int - confirm add_force
        # copes with that.
        for stored_node_id, stored_forces in self.node_id2forces.items():
            restored.add_force(stored_node_id, stored_forces)
        return restored

90 

91 

class MPCDatabase:
    """
    A Database class to store MPC instances, Nodes and Subcases

    MPCs and subcases are read once at construction time and served from
    the in-memory dictionaries ``self.mpcs`` and ``self.subcases`` (both
    keyed by id). Nodes are queried from the database on every request.
    """

    def __init__(self):
        """
        Development database creation and population

        Recreates all tables in ``db.db``, fills them from the currently
        loaded MPC / Node / Subcase instances and caches the MPC and
        subcase rows in memory.
        """

        # Create the SQLite engine
        self.engine = create_engine("sqlite:///db.db")

        # Reset the schema for development purposes. Dropping and
        # re-creating the tables already empties them, and unlike a raw
        # "DELETE FROM <table>" it also works on a fresh database file in
        # which the tables do not exist yet.
        SQLModel.metadata.drop_all(self.engine)
        SQLModel.metadata.create_all(self.engine)

        self.populate_database()

        # Read from the database
        with Session(self.engine) as session:
            self.mpcs = {mpc.id: mpc for mpc in session.exec(select(MPCDBModel)).all()}
            self.subcases = {
                subcase.id: subcase
                for subcase in session.exec(select(SubcaseDBModel)).all()
            }

    def populate_database(self, load_all_nodes=False):
        """
        Function to populate the database from MPC instances

        Args:
            load_all_nodes: when true, every node in ``Node.node_id2node``
                is stored; otherwise only the nodes referenced by an MPC
                (its slave nodes and its master node) are stored.
        """
        with Session(self.engine) as session:

            if load_all_nodes:  # Load in all the nodes
                nodes_to_store = Node.node_id2node.values()
            else:  # load in just the nodes that are used in the MPCs
                nodes_to_store = set()
                for mpc in MPC.id_2_instance.values():
                    nodes_to_store.update(mpc.nodes)
                    nodes_to_store.add(mpc.master_node)

            for node in nodes_to_store:
                session.add(
                    NodeDBModel(
                        id=node.id,
                        coord_x=node.coords[0],
                        coord_y=node.coords[1],
                        coord_z=node.coords[2],
                    )
                )

            for mpc in MPC.id_2_instance.values():
                # Called before reading the per-subcase forces, exactly as
                # the original ordering did
                mpc.get_part_id2force(None)
                sub2part2force = mpc.get_subcase_id2part_id2force()

                # Convert MPC instance to MPCDBModel
                session.add(
                    MPCDBModel(
                        id=mpc.element_id,
                        config=mpc.mpc_config.name,  # Store enum as string
                        master_node=mpc.master_node.id,
                        nodes=",".join([str(node.id) for node in mpc.nodes]),
                        part_id2nodes=mpc.part_id2node_ids,
                        subcase_id2part_id2forces=sub2part2force,
                    )
                )

            # Subcases
            for subcase in Subcase.subcases:
                session.add(
                    SubcaseDBModel(
                        id=subcase.subcase_id,
                        node_id2forces=subcase.node_id2forces,
                        time=subcase.time,
                    )
                )

            # Commit to the database
            session.commit()

    async def get_mpcs(self) -> List[MPCDBModel]:
        """
        Get all MPCs

        Returns:
            The cached MPC rows as a list.
        """
        return list(self.mpcs.values())

    async def get_mpc(self, mpc_id: int) -> MPCDBModel:
        """
        Get a specific MPC

        Raises:
            HTTPException: status 404 when no cached MPC has ``mpc_id``.
        """
        if mpc_id not in self.mpcs:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            )
        return self.mpcs[mpc_id]

    async def get_nodes(self, offset: int, limit: int = 100) -> List[NodeDBModel]:
        """
        Get nodes for pagination

        Args:
            offset: number of rows to skip.
            limit: maximum number of rows to return.
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel).offset(offset).limit(limit)
            return session.exec(statement).all()

    async def get_all_nodes(self) -> List[NodeDBModel]:
        """
        Get all nodes
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel)
            return session.exec(statement).all()

    async def remove_mpc(self, mpc_id: int):
        """
        Remove a specific MPC

        Raises:
            HTTPException: status 404 when no cached MPC has ``mpc_id``.
        """
        if mpc_id not in self.mpcs:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            )
        # NOTE(review): only the in-memory cache is changed here; the row
        # is not deleted from the database - confirm this is intended.
        del self.mpcs[mpc_id]

    async def get_subcases(self) -> List[SubcaseDBModel]:
        """
        Get all subcases

        Returns:
            The cached subcase rows as a list.
        """
        return list(self.subcases.values())

238 return list(self.subcases.values())