Coverage for mpcforces_extractor\api\db\database.py: 77%
124 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-28 21:26 +0100
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-28 21:26 +0100
1from typing import List, Optional
2from fastapi import HTTPException
3from sqlmodel import Session, create_engine, SQLModel, select, text
4from sqlalchemy.sql.expression import asc, desc
5from mpcforces_extractor.datastructure.rigids import MPC
6from mpcforces_extractor.datastructure.entities import Node
7from mpcforces_extractor.datastructure.subcases import Subcase
8from mpcforces_extractor.api.db.models import (
9 RBE2DBModel,
10 RBE3DBModel,
11 NodeDBModel,
12 SubcaseDBModel,
13)
14from mpcforces_extractor.datastructure.rigids import MPC_CONFIG
class MPCDatabase:
    """
    A Database class to store MPC instances, Nodes and Subcases.

    RBE2/RBE3 elements and subcases are mirrored in in-memory dictionaries
    (keyed by id) whenever the database is populated or reinitialized; nodes
    are always queried from the SQLite database.
    """

    # State of the most recent ``get_nodes`` request. It allows the previously
    # built query to be reused and avoids rewriting the force columns when the
    # requested subcase has not changed.
    last_sort_column = "id"
    last_sort_direction = 1
    last_subcase_id = None
    last_query = None
    last_filter = None

    def __init__(self, file_path: str):
        """
        Development database creation and population.

        - file_path: Path of the SQLite file the engine is bound to.
        """
        self.rbe2s = {}
        self.rbe3s = {}
        self.subcases = {}
        self.engine = create_engine(f"sqlite:///{file_path}")

    def close(self):
        """
        Close the database connection.

        Safe to call more than once: a second call is a no-op.
        """
        if self.engine is not None:
            self.engine.dispose()
        self.engine = None

    def reinitialize_db(self, file_path: str):
        """
        Reinitialize the database with the data from the file.

        - file_path: Path of the SQLite file to (re)connect to.
        """
        # Release the previous engine (if any) before replacing it.
        self.close()
        self.engine = create_engine(f"sqlite:///{file_path}")
        with Session(self.engine) as session:
            self._load_cache(session)
        # The cached query / subcase belong to the previous database state.
        self._reset_query_cache()

    def populate_database(self, load_all_nodes=False):
        """
        Function to populate the database from MPC instances.

        All four tables are dropped and recreated, then filled from the
        in-memory ``Node``, ``MPC`` and ``Subcase`` registries.

        - load_all_nodes: If True every known node is stored, otherwise only
          the nodes referenced by an MPC.
        """
        # Delete the existing data by dropping all tables.
        with Session(self.engine) as session:
            for table_name in (
                "RBE2DBModel",
                "RBE3DBModel",
                "nodedbmodel",
                "subcasedbmodel",
            ):
                session.exec(text(f"DROP TABLE IF EXISTS {table_name}"))
            # Commit explicitly: closing a session discards uncommitted work.
            session.commit()

        # Create the tables again
        SQLModel.metadata.create_all(self.engine)

        with Session(self.engine) as session:
            self.populate_nodes(load_all_nodes, session)
            self.populate_mpcs(session)

            # Populate Subcases
            for subcase in Subcase.subcases:
                session.add(
                    SubcaseDBModel(
                        id=subcase.subcase_id,
                        node_id2forces=subcase.node_id2forces,
                        time=subcase.time,
                    )
                )

            # Commit to the database
            session.commit()
            self._load_cache(session)

        # The freshly created node rows carry no force data yet, so a repeated
        # request for the previously served subcase must not be skipped.
        self._reset_query_cache()

    def _load_cache(self, session):
        """Refresh the in-memory RBE2 / RBE3 / subcase dictionaries from the session."""
        self.rbe2s = {
            rbe2.id: rbe2 for rbe2 in session.exec(select(RBE2DBModel)).all()
        }
        self.rbe3s = {
            rbe3.id: rbe3 for rbe3 in session.exec(select(RBE3DBModel)).all()
        }
        self.subcases = {
            subcase.id: subcase
            for subcase in session.exec(select(SubcaseDBModel)).all()
        }

    def _reset_query_cache(self):
        """Forget the state of the last ``get_nodes`` request."""
        self.last_sort_column = "id"
        self.last_sort_direction = 1
        self.last_subcase_id = None
        self.last_query = None
        self.last_filter = None

    @staticmethod
    def _node_to_db_model(node):
        """Build the database row for a single node (id and x/y/z coordinates)."""
        return NodeDBModel(
            id=node.id,
            coord_x=node.coords[0],
            coord_y=node.coords[1],
            coord_z=node.coords[2],
        )

    def populate_nodes(self, load_all_nodes=False, session=None):
        """
        Function to populate the database with nodes.

        - load_all_nodes: If True every node in ``Node.node_id2node`` is added,
          otherwise only slave and master nodes of the known MPCs.
        - session: The open session the rows are added to (required; the
          caller is responsible for committing).
        """
        if load_all_nodes:  # Load in all the nodes
            nodes = Node.node_id2node.values()
        else:  # load in just the nodes that are used in the MPCs
            nodes = set()
            for mpc_config in MPC_CONFIG:
                if mpc_config.value not in MPC.config_2_id_2_instance:
                    continue
                for mpc in MPC.config_2_id_2_instance[mpc_config.value].values():
                    nodes.update(mpc.nodes)
                    nodes.add(mpc.master_node)

        for node in nodes:
            session.add(self._node_to_db_model(node))

    def populate_mpcs(self, session):
        """
        Function to populate the database with MPCs.

        - session: The open session the rows are added to (the caller is
          responsible for committing).

        Raises ValueError if an MPC has a config that is neither RBE2 nor RBE3.
        """
        for mpc_config in MPC_CONFIG:
            if mpc_config.value not in MPC.config_2_id_2_instance:
                continue
            for mpc in MPC.config_2_id_2_instance[mpc_config.value].values():
                mpc.get_part_id2force(None)
                sub2part2force = mpc.get_subcase_id2part_id2force()

                if mpc_config == MPC_CONFIG.RBE2:
                    model_cls = RBE2DBModel
                elif mpc_config == MPC_CONFIG.RBE3:
                    model_cls = RBE3DBModel
                else:
                    raise ValueError(f"Unknown MPC config {mpc_config}")

                # Both models take the same fields, so build the row once.
                session.add(
                    model_cls(
                        id=mpc.element_id,
                        config=mpc.mpc_config.name,  # Store enum as string
                        master_node=mpc.master_node.id,
                        nodes=",".join([str(node.id) for node in mpc.nodes]),
                        part_id2nodes=mpc.part_id2node_ids,
                        subcase_id2part_id2forces=sub2part2force,
                    )
                )

    async def get_rbe2s(self) -> List[RBE2DBModel]:
        """
        Get all RBE2 MPCs
        """
        return list(self.rbe2s.values())

    async def get_rbe3s(self) -> List[RBE3DBModel]:
        """
        Get all RBE3 MPCs
        """
        return list(self.rbe3s.values())

    def _apply_subcase_forces(self, session, subcase_id):
        """
        Write the forces and moments of the given subcase onto the node rows.

        Raises HTTPException (404) if the subcase is unknown. Force entries
        whose node is not present in the node table are skipped.
        """
        try:
            subcase = self.subcases[subcase_id]
        except KeyError as err:
            raise HTTPException(
                status_code=404,
                detail=f"Subcase with id {subcase_id} does not exist",
            ) from err

        for node_id, forces in subcase.node_id2forces.items():
            node = session.exec(
                select(NodeDBModel).filter(NodeDBModel.id == node_id)
            ).first()
            if node is None:
                # Only a subset of the nodes may be stored (see populate_nodes).
                continue
            fx, fy, fz = forces[0], forces[1], forces[2]
            mx, my, mz = forces[3], forces[4], forces[5]
            node.fx = fx
            node.fy = fy
            node.fz = fz
            node.fabs = (fx**2 + fy**2 + fz**2) ** 0.5
            node.mx = mx
            node.my = my
            node.mz = mz
            node.mabs = (mx**2 + my**2 + mz**2) ** 0.5

        session.commit()
        # Remember the subcase only once its data is actually persisted.
        self.last_subcase_id = subcase_id

    async def get_nodes(
        self,
        *,
        offset: int,
        limit: int = 100,
        sort_column: str = "id",
        sort_direction: int = 1,
        node_ids: Optional[List[int]] = None,
        subcase_id: Optional[int] = None,
    ) -> List[NodeDBModel]:
        """
        Get nodes for pagination, sorting, and filtering.

        - offset: The offset for pagination.
        - limit: The limit for pagination (default: 100).
        - sort_column: The column to sort by (default: 'id').
        - sort_direction: The direction of sorting (1 for ascending, -1 for descending).
        - node_ids: An optional list of node IDs to filter by (default: None).
        - subcase_id: The subcase whose forces should be present on the nodes;
          0 or None means no force data is needed (default: None).

        Raises HTTPException 404 for an unknown subcase and 400 for an unknown
        sort column.
        """
        # Start a session with the database engine
        with Session(self.engine) as session:
            # Add force data only if the subcase differs from the previous
            # request. This runs before the cached-query shortcut so that a
            # subcase change is never answered with stale force values.
            if subcase_id not in (None, 0, self.last_subcase_id):
                self._apply_subcase_forces(session, subcase_id)

            # early return if the last query is the same
            if (
                self.last_query is not None
                and self.last_sort_column == sort_column
                and self.last_sort_direction == sort_direction
                and self.last_filter == node_ids
            ):
                return session.exec(
                    self.last_query.offset(offset).limit(limit)
                ).all()

            # Create the base query
            query = select(NodeDBModel)

            # Apply filtering by node IDs if provided
            if node_ids:
                query = query.filter(NodeDBModel.id.in_(node_ids))

            # Apply sorting based on the specified column and direction; any
            # other direction value leaves the query unordered.
            if sort_direction in (1, -1):
                try:
                    column = getattr(NodeDBModel, sort_column)
                except AttributeError as err:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Unknown sort column '{sort_column}'",
                    ) from err
                ordering = asc(column) if sort_direction == 1 else desc(column)
                query = query.order_by(ordering)

            # caching for speed
            self.last_query = query
            self.last_sort_column = sort_column
            self.last_sort_direction = sort_direction
            self.last_filter = node_ids

            # Execute the query and return the results (with pagination)
            return session.exec(query.offset(offset).limit(limit)).all()

    async def get_all_nodes(
        self, node_ids: Optional[List[int]] = None
    ) -> List[NodeDBModel]:
        """
        Get all nodes, optionally restricted to the given node IDs
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel)
            if node_ids:
                statement = statement.filter(NodeDBModel.id.in_(node_ids))
            return session.exec(statement).all()

    async def remove_mpc(self, mpc_id: int):
        """
        Remove a specific MPC from the in-memory RBE2 / RBE3 collections.

        The database rows are not touched. Raises HTTPException (404) if no
        MPC with the given id is known.
        """
        removed = False
        for cache in (self.rbe2s, self.rbe3s):
            if mpc_id in cache:
                del cache[mpc_id]
                removed = True
        if not removed:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            )

    async def get_subcases(self) -> List[SubcaseDBModel]:
        """
        Get all subcases
        """
        return list(self.subcases.values())