Coverage for mpcforces_extractor\api\db\database.py: 94%
71 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-06 21:34 +0100
1from typing import List
2from fastapi import HTTPException
3from sqlmodel import Session, create_engine, SQLModel, select, text
4from mpcforces_extractor.datastructure.rigids import MPC
5from mpcforces_extractor.datastructure.entities import Node
6from mpcforces_extractor.datastructure.subcases import Subcase
7from mpcforces_extractor.api.db.models import (
8 MPCDBModel,
9 NodeDBModel,
10 SubcaseDBModel,
11)
class MPCDatabase:
    """
    A Database class to store MPC instances, Nodes and Subcases.

    The data lives in an SQLite file reached through an SQLAlchemy engine.
    MPCs and subcases are additionally cached in memory (``self.mpcs`` and
    ``self.subcases``, both keyed by the row ``id``); the MPC/subcase getters
    read from these caches, while the node getters query the database.
    """

    def __init__(self, file_path: str):
        """
        Create the engine for the SQLite database stored at ``file_path``.

        The in-memory caches start empty; ``populate_database`` or
        ``reinitialize_db`` fill them.
        """
        # In-memory caches: id -> MPCDBModel / id -> SubcaseDBModel
        self.mpcs = {}
        self.subcases = {}
        self.engine = create_engine(f"sqlite:///{file_path}")

    def close(self):
        """
        Close the database connection.

        Disposes the engine's connection pool and drops the reference to it.
        Calling this again after the engine is gone is a no-op.
        """
        if self.engine is None:
            return
        self.engine.dispose()
        self.engine = None

    def reinitialize_db(self, file_path: str):
        """
        Reinitialize the database with the data from the file.

        Points the instance at the SQLite file ``file_path`` and reloads the
        in-memory MPC and subcase caches from it. A previously created engine
        is disposed so its pooled connections are not leaked.
        """
        old_engine = self.engine
        self.engine = create_engine(f"sqlite:///{file_path}")
        if old_engine is not None:
            # Release the pool of the engine being replaced.
            old_engine.dispose()
        with Session(self.engine) as session:
            self._load_cache(session)

    def populate_database(self, load_all_nodes=False):
        """
        Rebuild the database from the in-memory MPC, Node and Subcase instances.

        Drops and recreates the tables, then inserts:
        - nodes: every node in ``Node.node_id2node`` when ``load_all_nodes``
          is true, otherwise only the nodes referenced by an MPC (its
          ``nodes`` plus its ``master_node``);
        - one ``MPCDBModel`` per instance in ``MPC.id_2_instance``;
        - one ``SubcaseDBModel`` per entry in ``Subcase.subcases``.
        Finally the in-memory caches are refreshed from the database.
        """
        # Drop the existing tables so the data is rebuilt from scratch.
        with Session(self.engine) as session:
            for drop_statement in (
                "DROP TABLE IF EXISTS mpcdbmodel",
                "DROP TABLE IF EXISTS nodedbmodel",
                "DROP TABLE IF EXISTS subcasedbmodel",
            ):
                session.exec(text(drop_statement))
            # Commit explicitly: without it the drops only persist because the
            # pysqlite driver does not wrap DDL in a transaction.
            session.commit()

        # Create the tables again
        SQLModel.metadata.create_all(self.engine)

        with Session(self.engine) as session:
            for node in self._nodes_to_store(load_all_nodes):
                session.add(self._node_to_db(node))

            for mpc in MPC.id_2_instance.values():
                # NOTE(review): result unused; presumably called for its side
                # effect of computing the forces read on the next line — confirm.
                mpc.get_part_id2force(None)
                sub2part2force = mpc.get_subcase_id2part_id2force()

                # Convert MPC instance to MPCDBModel
                session.add(
                    MPCDBModel(
                        id=mpc.element_id,
                        config=mpc.mpc_config.name,  # Store enum as string
                        master_node=mpc.master_node.id,
                        nodes=",".join(str(node.id) for node in mpc.nodes),
                        part_id2nodes=mpc.part_id2node_ids,
                        subcase_id2part_id2forces=sub2part2force,
                    )
                )

            # Subcases
            for subcase in Subcase.subcases:
                session.add(
                    SubcaseDBModel(
                        id=subcase.subcase_id,
                        node_id2forces=subcase.node_id2forces,
                        time=subcase.time,
                    )
                )

            # Commit to the database, then refresh the caches from it
            session.commit()
            self._load_cache(session)

    @staticmethod
    def _nodes_to_store(load_all_nodes: bool):
        """Return the Node instances to persist: all known nodes, or only those used by an MPC."""
        if load_all_nodes:
            return Node.node_id2node.values()
        unique_nodes = set()
        for mpc in MPC.id_2_instance.values():
            unique_nodes.update(mpc.nodes)
            unique_nodes.add(mpc.master_node)
        return unique_nodes

    @staticmethod
    def _node_to_db(node) -> NodeDBModel:
        """Convert a Node instance into its NodeDBModel row."""
        return NodeDBModel(
            id=node.id,
            coord_x=node.coords[0],
            coord_y=node.coords[1],
            coord_z=node.coords[2],
        )

    def _load_cache(self, session) -> None:
        """Reload the in-memory MPC and subcase caches (keyed by id) from ``session``."""
        self.mpcs = {mpc.id: mpc for mpc in session.exec(select(MPCDBModel)).all()}
        self.subcases = {
            subcase.id: subcase
            for subcase in session.exec(select(SubcaseDBModel)).all()
        }

    async def get_mpcs(self) -> List[MPCDBModel]:
        """
        Get all MPCs from the in-memory cache.
        """
        return list(self.mpcs.values())

    async def get_mpc(self, mpc_id: int) -> MPCDBModel:
        """
        Get a specific MPC from the in-memory cache.

        Raises HTTPException (404) if no MPC with ``mpc_id`` is cached.
        """
        try:
            return self.mpcs[mpc_id]
        except KeyError:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            ) from None

    async def get_nodes(self, offset: int, limit: int = 100) -> List[NodeDBModel]:
        """
        Get nodes for pagination.

        Returns at most ``limit`` nodes, skipping the first ``offset``.
        NOTE(review): uses a synchronous Session inside an async method, so the
        query blocks the event loop while it runs.
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel).offset(offset).limit(limit)
            return session.exec(statement).all()

    async def get_all_nodes(self) -> List[NodeDBModel]:
        """
        Get all nodes stored in the database.
        """
        with Session(self.engine) as session:
            statement = select(NodeDBModel)
            return session.exec(statement).all()

    async def remove_mpc(self, mpc_id: int):
        """
        Remove a specific MPC from the in-memory cache.

        The database row is not deleted here, so a later ``reinitialize_db``
        loads it again.

        Raises HTTPException (404) if no MPC with ``mpc_id`` is cached.
        """
        if mpc_id not in self.mpcs:
            raise HTTPException(
                status_code=404, detail=f"MPC with id {mpc_id} does not exist"
            )
        del self.mpcs[mpc_id]

    async def get_subcases(self) -> List[SubcaseDBModel]:
        """
        Get all subcases from the in-memory cache.
        """
        return list(self.subcases.values())