Coverage for mpcforces_extractor\database\database.py: 74%
105 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-10-31 16:49 +0100
« prev ^ index » next coverage.py v7.6.4, created at 2024-10-31 16:49 +0100
1from typing import List, Dict
2from fastapi import HTTPException
3from sqlmodel import Session, create_engine, SQLModel, Field, select, Column, JSON, text
4from mpcforces_extractor.datastructure.rigids import MPC
5from mpcforces_extractor.datastructure.rigids import MPC_CONFIG
6from mpcforces_extractor.datastructure.entities import Node
7from mpcforces_extractor.datastructure.subcases import Subcase
class RunExtractorRequest(SQLModel, table=False):
    """
    Request body that names the two input files the extractor is run on.

    Not a table model (``table=False``); it is only used for validation of
    incoming request data.
    """

    # Name/path of the FEM input file.
    fem_filename: str
    # Name/path of the mpcf file (presumably the MPC forces output) - confirm.
    mpcf_filename: str
class DatabaseRequest(SQLModel, table=False):
    """
    Request model that carries the filename of a database file.

    Not a table model (``table=False``); it is only used for validation of
    incoming request data.
    """

    # Name/path of the database file the request refers to.
    database_filename: str
class MPCDBModel(SQLModel, table=True):
    """
    Database representation of an MPC instance.

    Node references are flattened to integer ids so the row can be stored
    in SQLite; the two dictionaries live in JSON columns.
    """

    id: int = Field(primary_key=True)
    config: str = Field()  # name of an MPC_CONFIG member
    master_node: int = Field()  # id of the master node
    nodes: str = Field()  # comma-separated ids of the MPC's nodes
    part_id2nodes: Dict = Field(
        default_factory=dict, sa_column=Column(JSON)
    )  # stored as a JSON dictionary; restored into MPC.part_id2node_ids
    subcase_id2part_id2forces: Dict = Field(
        default_factory=dict, sa_column=Column(JSON)
    )  # stored as a JSON dictionary

    def to_mpc(self):
        """
        Convert this row back into an MPC object.

        Node ids are resolved through ``Node.node_id2node``, so every
        referenced node must already be registered there.

        :returns: a new MPC whose ``part_id2node_ids`` is set from
            ``part_id2nodes``. ``subcase_id2part_id2forces`` is NOT
            transferred to the MPC.
        :raises KeyError: if ``config`` is not an MPC_CONFIG member name, or
            if the master node or any node id is missing from
            ``Node.node_id2node``.
        """
        # Guard against None/empty columns and ignore empty tokens (e.g. a
        # trailing comma) so int() never sees an empty string.
        node_ids = (
            [int(token) for token in str(self.nodes).split(",") if token.strip()]
            if self.nodes
            else []
        )
        mpc = MPC(
            element_id=self.id,
            mpc_config=MPC_CONFIG[self.config],  # string name -> enum member
            master_node=Node.node_id2node[self.master_node],
            nodes=[Node.node_id2node[node_id] for node_id in node_ids],
            dofs="",
        )
        # NOTE(review): JSON object keys are always strings, so after a round
        # trip through the database this dict may have str keys even if int
        # keys were stored - confirm what MPC.part_id2node_ids expects.
        mpc.part_id2node_ids = self.part_id2nodes
        return mpc
class NodeDBModel(SQLModel, table=True):
    """
    Database row for a single Node: its id plus x/y/z coordinates stored in
    three separate float columns.
    """

    id: int = Field(primary_key=True)
    coord_x: float = Field()
    coord_y: float = Field()
    coord_z: float = Field()

    def to_node(self):
        """
        Rebuild a Node from this row, passing the three coordinate columns
        as a ``[x, y, z]`` list.
        """
        xyz = [self.coord_x, self.coord_y, self.coord_z]
        return Node(node_id=self.id, coords=xyz)
class SubcaseDBModel(SQLModel, table=True):
    """
    Database row for a Subcase: its id, its time value and the per-node
    forces stored in a JSON column.
    """

    id: int = Field(primary_key=True)
    node_id2forces: Dict = Field(default_factory=dict, sa_column=Column(JSON))
    time: float = Field()

    def to_subcase(self):
        """
        Rebuild a Subcase from this row.

        Every entry of ``node_id2forces`` is replayed through
        ``Subcase.add_force`` in dictionary order.
        """
        restored = Subcase(subcase_id=self.id, time=self.time)
        # NOTE(review): JSON object keys are strings, so node ids read back
        # from the database may be str rather than int - confirm add_force
        # handles that.
        for nid, force_values in self.node_id2forces.items():
            restored.add_force(nid, force_values)
        return restored
class MPCDatabase:
    """
    A database class that stores MPC instances, Nodes and Subcases in SQLite
    and keeps in-memory caches of the MPC and subcase rows.
    """

    def __init__(self, file_path: str):
        """
        Create the SQLAlchemy engine for the SQLite file at ``file_path``.

        The MPC and subcase caches start empty; call ``populate_database`` or
        ``reinitialize_db`` to fill them.
        """
        self.mpcs = {}  # MPC id -> MPCDBModel
        self.subcases = {}  # subcase id -> SubcaseDBModel
        self.engine = create_engine(f"sqlite:///{file_path}")

    def close(self):
        """
        Dispose of the engine's connection pool and drop the engine.

        Safe to call more than once.
        """
        if self.engine is not None:
            self.engine.dispose()
        self.engine = None

    def reinitialize_db(self, file_path: str):
        """
        Point this instance at the SQLite file at ``file_path`` and reload the
        MPC and subcase caches from it.

        Any previous engine is disposed first so its pooled connections are
        not leaked.
        """
        self.close()
        self.engine = create_engine(f"sqlite:///{file_path}")
        with Session(self.engine) as session:
            self._load_cache(session)

    def populate_database(self, load_all_nodes=False):
        """
        Rebuild the database from the in-memory MPC, Node and Subcase
        instances, then refresh the MPC and subcase caches.

        The mpcdbmodel, nodedbmodel and subcasedbmodel tables are dropped and
        recreated, so any data previously stored in them is lost.

        :param load_all_nodes: if True, store every node in
            ``Node.node_id2node``; otherwise store only the nodes referenced
            by an MPC (its nodes and its master node).
        """
        # Drop the existing tables. Commit explicitly so the drops do not
        # depend on the driver's implicit transaction handling.
        with Session(self.engine) as session:
            session.exec(text("DROP TABLE IF EXISTS mpcdbmodel"))
            session.exec(text("DROP TABLE IF EXISTS nodedbmodel"))
            session.exec(text("DROP TABLE IF EXISTS subcasedbmodel"))
            session.commit()

        # Create the tables again
        SQLModel.metadata.create_all(self.engine)

        with Session(self.engine) as session:
            if load_all_nodes:
                nodes = Node.node_id2node.values()
            else:
                # Deduplicate by node id so a node shared by several MPCs (or
                # used both as master and dependent node) is inserted once.
                used_nodes = {}
                for mpc in MPC.id_2_instance.values():
                    for node in mpc.nodes:
                        used_nodes[node.id] = node
                    used_nodes[mpc.master_node.id] = mpc.master_node
                nodes = used_nodes.values()

            for node in nodes:
                session.add(self._node_to_db_model(node))

            for mpc in MPC.id_2_instance.values():
                # NOTE(review): the return value is unused; presumably this is
                # called for a side effect get_subcase_id2part_id2force relies
                # on - confirm before removing.
                mpc.get_part_id2force(None)
                sub2part2force = mpc.get_subcase_id2part_id2force()
                session.add(
                    MPCDBModel(
                        id=mpc.element_id,
                        config=mpc.mpc_config.name,  # store enum as its name
                        master_node=mpc.master_node.id,
                        nodes=",".join(str(node.id) for node in mpc.nodes),
                        part_id2nodes=mpc.part_id2node_ids,
                        subcase_id2part_id2forces=sub2part2force,
                    )
                )

            for subcase in Subcase.subcases:
                session.add(
                    SubcaseDBModel(
                        id=subcase.subcase_id,
                        node_id2forces=subcase.node_id2forces,
                        time=subcase.time,
                    )
                )

            session.commit()
            self._load_cache(session)

    @staticmethod
    def _node_to_db_model(node: Node) -> NodeDBModel:
        """Build the NodeDBModel row for a Node (id plus first three coords)."""
        return NodeDBModel(
            id=node.id,
            coord_x=node.coords[0],
            coord_y=node.coords[1],
            coord_z=node.coords[2],
        )

    def _load_cache(self, session: Session) -> None:
        """Reload ``self.mpcs`` and ``self.subcases`` from ``session``."""
        self.mpcs = {mpc.id: mpc for mpc in session.exec(select(MPCDBModel)).all()}
        self.subcases = {
            subcase.id: subcase
            for subcase in session.exec(select(SubcaseDBModel)).all()
        }

    async def get_mpcs(self) -> List[MPCDBModel]:
        """
        Get all cached MPC rows.
        """
        return list(self.mpcs.values())

    async def get_mpc(self, mpc_id: int) -> MPCDBModel:
        """
        Get a specific cached MPC row.

        :raises HTTPException: 404 if no MPC with ``mpc_id`` is cached.
        """
        try:
            return self.mpcs[mpc_id]
        except KeyError:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            ) from None

    async def get_nodes(self, offset: int, limit: int = 100) -> List[NodeDBModel]:
        """
        Get one page of node rows from the database.

        :param offset: number of rows to skip.
        :param limit: maximum number of rows to return.
        """
        # NOTE(review): the synchronous Session blocks the event loop while
        # the query runs.
        with Session(self.engine) as session:
            statement = select(NodeDBModel).offset(offset).limit(limit)
            return session.exec(statement).all()

    async def get_all_nodes(self) -> List[NodeDBModel]:
        """
        Get all node rows from the database.
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel)
            return session.exec(statement).all()

    async def remove_mpc(self, mpc_id: int):
        """
        Remove a specific MPC from the in-memory cache.

        The corresponding database row is left untouched.

        :raises HTTPException: 404 if no MPC with ``mpc_id`` is cached.
        """
        try:
            del self.mpcs[mpc_id]
        except KeyError:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            ) from None

    async def get_subcases(self) -> List[SubcaseDBModel]:
        """
        Get all cached subcase rows.
        """
        return list(self.subcases.values())