Coverage for mpcforces_extractor\database\database.py: 76%
96 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-25 00:11 +0200
import logging
from typing import List, Dict

from fastapi import HTTPException
from sqlmodel import Session, create_engine, SQLModel, Field, select, Column, JSON, text

from mpcforces_extractor.datastructure.rigids import MPC
from mpcforces_extractor.datastructure.rigids import MPC_CONFIG
from mpcforces_extractor.datastructure.entities import Node
from mpcforces_extractor.datastructure.subcases import Subcase
class RunExtractorRequest(SQLModel, table=False):
    """
    Request body for running the extractor.

    A plain (non-table) SQLModel that carries the two input file names.
    """

    # Both fields are required (no defaults).
    fem_filename: str
    mpcf_filename: str
class MPCDBModel(SQLModel, table=True):
    """
    Database Representation of MPC Class

    One row per MPC element. Node references are stored by id (``master_node``
    as an int, ``nodes`` as a comma-separated string of ids); the per-part and
    per-subcase data are stored in JSON columns.
    """

    id: int = Field(primary_key=True)  # MPC element id
    config: str = Field()  # Name of the MPC_CONFIG enum member
    master_node: int = Field()  # Id of the master node
    nodes: str = Field()  # Comma-separated node ids, e.g. "1,2,3"
    # NOTE(review): JSON serialisation turns dict keys into strings, so rows read
    # back from the database have str keys in these two columns even if they were
    # written with int keys.
    part_id2nodes: Dict = Field(default_factory=dict, sa_column=Column(JSON))
    subcase_id2part_id2forces: Dict = Field(
        default_factory=dict, sa_column=Column(JSON)
    )

    def to_mpc(self):
        """
        Convert this row back into an ``MPC`` object.

        Node ids are resolved through ``Node.node_id2node``, so the referenced
        nodes must already be registered there. ``dofs`` is not stored in the
        database and is therefore set to ``""``.

        Returns:
            MPC: the reconstructed MPC, with ``part_id2node_ids`` taken from the
            ``part_id2nodes`` column.
        """
        # Diagnostic output goes through logging instead of a stray print().
        logging.getLogger(__name__).debug(
            "Converting MPCDBModel to MPC: id=%s, nodes=%s", self.id, self.nodes
        )
        # Only split when there is something to split (nodes may be empty/None).
        nodes_list = str(self.nodes).split(",") if self.nodes else []
        mpc = MPC(
            element_id=self.id,
            mpc_config=MPC_CONFIG[self.config],  # enum member looked up by name
            master_node=Node.node_id2node[self.master_node],
            nodes=[Node.node_id2node[int(node_id)] for node_id in nodes_list],
            dofs="",
        )
        mpc.part_id2node_ids = self.part_id2nodes
        return mpc
class NodeDBModel(SQLModel, table=True):
    """
    Database row for a single ``Node``: its id and its x/y/z coordinates.
    """

    id: int = Field(primary_key=True)  # node id
    # One column per coordinate component
    coord_x: float = Field()
    coord_y: float = Field()
    coord_z: float = Field()

    def to_node(self):
        """
        Rebuild a ``Node`` from this row.

        Returns:
            Node: node whose id is ``self.id`` and whose coords are the
            list ``[coord_x, coord_y, coord_z]``.
        """
        xyz = [self.coord_x, self.coord_y, self.coord_z]
        return Node(node_id=self.id, coords=xyz)
class SubcaseDBModel(SQLModel, table=True):
    """
    Database Representation of Subcase Class
    """

    id: int = Field(primary_key=True)  # subcase id
    # node id -> forces, stored as JSON. JSON object keys are always strings,
    # so rows loaded from the database have str keys here.
    node_id2forces: Dict = Field(default_factory=dict, sa_column=Column(JSON))
    time: float = Field()

    def to_subcase(self):
        """
        Convert this row back into a ``Subcase`` object.

        Because ``node_id2forces`` is a JSON column, a row read from the
        database has its node ids as strings ("1", "2", ...). They are converted
        back to ``int`` so they match the integer node ids used elsewhere
        (e.g. the ``int(node_id)`` lookups into ``Node.node_id2node``).

        Returns:
            Subcase: subcase with the stored id, time and per-node forces.

        Raises:
            ValueError: if a stored node-id key is not an integer string.
        """
        subcase = Subcase(subcase_id=self.id, time=self.time)
        for node_id, forces in self.node_id2forces.items():
            subcase.add_force(int(node_id), forces)
        return subcase
class MPCDatabase:
    """
    A Database class to store MPC instances, Nodes and Subcases.

    Construction resets the database schema, fills it from the in-memory
    ``MPC.id_2_instance``, ``Node.node_id2node`` and ``Subcase.subcases``
    registries, and caches the stored MPC and subcase rows in ``self.mpcs``
    and ``self.subcases`` (both keyed by id).
    """

    def __init__(self, database_url: str = "sqlite:///db.db"):
        """
        Development database creation and population.

        Args:
            database_url: SQLAlchemy database URL. Defaults to the SQLite file
                ``db.db`` relative to the current working directory.
        """
        self.engine = create_engine(database_url)

        # Reset the schema for development purposes. drop_all only drops tables
        # that exist (their rows go with them), so it is safe on a fresh
        # database. The former "DELETE FROM <table>" statements were redundant
        # and failed with "no such table" when db.db did not exist yet.
        # Note: this drops every table registered on SQLModel.metadata.
        SQLModel.metadata.drop_all(self.engine)
        SQLModel.metadata.create_all(self.engine)

        self.populate_database()

        # Cache the stored rows in memory, keyed by primary key.
        with Session(self.engine) as session:
            self.mpcs: Dict[int, MPCDBModel] = {
                mpc.id: mpc for mpc in session.exec(select(MPCDBModel)).all()
            }
            self.subcases: Dict[int, SubcaseDBModel] = {
                subcase.id: subcase
                for subcase in session.exec(select(SubcaseDBModel)).all()
            }

    @staticmethod
    def _node_to_db_model(node) -> NodeDBModel:
        """Build the ``NodeDBModel`` row for a single ``Node``."""
        return NodeDBModel(
            id=node.id,
            coord_x=node.coords[0],
            coord_y=node.coords[1],
            coord_z=node.coords[2],
        )

    @staticmethod
    def _mpc_not_found(mpc_id: int) -> HTTPException:
        """Build the 404 error for an MPC id that is not in the cache."""
        return HTTPException(
            status_code=404, detail=f"MPC with id {mpc_id} does not exist"
        )

    def populate_database(self, load_all_nodes=False):
        """
        Function to populate the database from MPC instances.

        Inserts nodes, MPCs and subcases from the in-memory registries and
        commits once at the end. Rows are inserted, not upserted, so calling
        this again on an already-populated database is expected to fail on
        duplicate primary keys.

        Args:
            load_all_nodes: If True, store every node in ``Node.node_id2node``;
                otherwise store only the nodes referenced by an MPC (its
                ``nodes`` plus its ``master_node``).
        """
        with Session(self.engine) as session:
            if load_all_nodes:
                nodes = Node.node_id2node.values()
            else:
                # Keyed by node id so a node shared between MPCs is inserted
                # once (a repeated id would violate the primary key).
                used_nodes = {}
                for mpc in MPC.id_2_instance.values():
                    for node in mpc.nodes:
                        used_nodes[node.id] = node
                    used_nodes[mpc.master_node.id] = mpc.master_node
                nodes = used_nodes.values()

            for node in nodes:
                session.add(self._node_to_db_model(node))

            for mpc in MPC.id_2_instance.values():
                # Return value is discarded; called only for its side effects.
                # NOTE(review): presumably this fills mpc.part_id2node_ids,
                # which is stored below — confirm.
                mpc.get_part_id2force(None)
                sub2part2force = mpc.get_subcase_id2part_id2force()

                session.add(
                    MPCDBModel(
                        id=mpc.element_id,
                        config=mpc.mpc_config.name,  # enum stored by member name
                        master_node=mpc.master_node.id,
                        nodes=",".join(str(node.id) for node in mpc.nodes),
                        part_id2nodes=mpc.part_id2node_ids,
                        subcase_id2part_id2forces=sub2part2force,
                    )
                )

            for subcase in Subcase.subcases:
                session.add(
                    SubcaseDBModel(
                        id=subcase.subcase_id,
                        node_id2forces=subcase.node_id2forces,
                        time=subcase.time,
                    )
                )

            session.commit()

    async def get_mpcs(self) -> List[MPCDBModel]:
        """
        Get all MPCs from the in-memory cache.
        """
        return list(self.mpcs.values())

    async def get_mpc(self, mpc_id: int) -> MPCDBModel:
        """
        Get a specific MPC from the in-memory cache.

        Raises:
            HTTPException: 404 if no MPC with ``mpc_id`` is cached.
        """
        try:
            return self.mpcs[mpc_id]
        except KeyError:
            raise self._mpc_not_found(mpc_id) from None

    async def get_nodes(self, offset: int, limit: int = 100) -> List[NodeDBModel]:
        """
        Get nodes for pagination, ordered by node id.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.

        Note: the query runs synchronously and blocks the event loop while it
        executes.
        """
        with Session(self.engine) as session:
            # Explicit ORDER BY so consecutive pages are stable.
            statement = (
                select(NodeDBModel)
                .order_by(NodeDBModel.id)
                .offset(offset)
                .limit(limit)
            )
            return session.exec(statement).all()

    async def get_all_nodes(self) -> List[NodeDBModel]:
        """
        Get all nodes stored in the database.

        Note: the query runs synchronously and blocks the event loop while it
        executes.
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel)
            return session.exec(statement).all()

    async def remove_mpc(self, mpc_id: int):
        """
        Remove a specific MPC from the in-memory cache.

        NOTE(review): the row is not deleted from the database itself — confirm
        this is intended.

        Raises:
            HTTPException: 404 if no MPC with ``mpc_id`` is cached.
        """
        try:
            del self.mpcs[mpc_id]
        except KeyError:
            raise self._mpc_not_found(mpc_id) from None

    async def get_subcases(self) -> List[SubcaseDBModel]:
        """
        Get all subcases from the in-memory cache.
        """
        return list(self.subcases.values())