Coverage for mpcforces_extractor\api\db\test_database.py: 100%
78 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-06 21:34 +0100
1import os
2import pytest
3from mpcforces_extractor.api.db.database import MPCDatabase
4from mpcforces_extractor.datastructure.rigids import MPC, MPC_CONFIG
5from mpcforces_extractor.datastructure.entities import Node, Element
6from mpcforces_extractor.datastructure.subcases import Subcase
7from fastapi import HTTPException
# Module-level cache for the shared test database: get_db() fills it on first
# use and test_teardown() closes it. It must exist before any test runs.
db_save = None
async def get_db():
    """Return the shared test database, building and populating it on first use.

    The first call does the following:

    * resets the MPC registry;
    * creates seven nodes (node 7 is deliberately left unused);
    * creates two MPCs: an RBE2 with id 1 and an RBE3 with id 2;
    * creates two elements;
    * creates one subcase (id 1, time 1.0) with the force vector
      ``[1.0, 0, 0, 0, 0, 0]`` on nodes 1-6;
    * populates ``MPCDatabase("test.db")``.

    The resulting database is cached in the module-level ``db_save``. Every
    test in this module therefore operates on the same instance, and the
    tests depend on their execution order (e.g. ``test_remove_mpc`` deletes
    MPC 1 for everyone that runs afterwards).

    This is a plain helper, not a test, so it carries no pytest marker.

    Returns:
        The cached ``MPCDatabase`` instance.
    """
    global db_save  # the cache lives at module level

    # Compare against None explicitly: the cache is "empty" only when unset,
    # regardless of how the database object evaluates in a boolean context.
    if db_save is not None:
        return db_save

    # Define the initial MPC instances
    node1 = Node(1, [0, 0, 0])
    node2 = Node(2, [1, 2, 3])
    node3 = Node(3, [4, 5, 6])
    node4 = Node(4, [0, 0, 0])
    node5 = Node(5, [1, 2, 3])
    node6 = Node(6, [4, 5, 6])
    Node(7, [0, 0, 0])  # Unused node

    MPC.reset()
    MPC(
        element_id=1,
        mpc_config=MPC_CONFIG.RBE2,
        master_node=node1,
        nodes=[node2, node3],
        dofs="",
    )
    MPC(
        element_id=2,
        mpc_config=MPC_CONFIG.RBE3,
        master_node=node4,
        nodes=[node5, node6],
        dofs="",
    )

    Element(1, 1, [node2, node3])
    Element(2, 2, [node6, node5])

    subcase = Subcase(1, 1.0)
    for node_id in range(1, 7):
        # A fresh list per node, so no two nodes share the same force object.
        subcase.add_force(node_id, [1.0, 0, 0, 0, 0, 0])

    db = MPCDatabase("test.db")
    db.populate_database()
    db_save = db  # Save the initialized database for subsequent calls
    return db_save
@pytest.mark.asyncio
async def test_initialize_database():
    """The freshly populated database holds exactly the two seeded MPCs."""
    database = await get_db()
    stored_mpcs = await database.get_mpcs()
    assert len(stored_mpcs) == 2
@pytest.mark.asyncio
async def test_get_mpc():
    """MPC 1 can be fetched and reports its RBE2 configuration."""
    database = await get_db()
    fetched = await database.get_mpc(1)
    assert fetched.id == 1
    assert fetched.config == "RBE2"
@pytest.mark.asyncio
async def test_remove_mpc():
    """After removing MPC 1, fetching it raises ``HTTPException``.

    This mutates the shared database, so MPC 1 is gone for later tests.
    """
    database = await get_db()
    target_id = 1
    await database.remove_mpc(target_id)
    with pytest.raises(HTTPException):
        await database.get_mpc(target_id)
@pytest.mark.asyncio
async def test_remove_mpc_not_exist():
    """Removing an MPC id that was never created raises ``HTTPException``."""
    database = await get_db()
    unknown_id = 3
    with pytest.raises(HTTPException):
        await database.remove_mpc(unknown_id)
@pytest.mark.asyncio
async def test_get_nodes():
    """Node listing honours the offset and the ``load_all_nodes`` switch.

    The test makes three checks:

    * by default, six nodes are expected;
    * a page starting at offset 1 returns one fewer;
    * after re-populating with ``load_all_nodes=True``, seven nodes are
      expected (node 7 is the one ``get_db`` leaves unused).
    """
    database = await get_db()

    every_node = await database.get_all_nodes()
    assert len(every_node) == 6

    skip = 1
    page = await database.get_nodes(skip, 100)
    expected_page_size = len(every_node) - skip
    assert len(page) == expected_page_size

    database.populate_database(load_all_nodes=True)
    refreshed = await database.get_all_nodes()
    assert len(refreshed) == 7
@pytest.mark.asyncio
async def test_subcases():
    """The single seeded subcase is stored with its id, time and node forces."""
    database = await get_db()
    stored = await database.get_subcases()
    assert len(stored) == 1

    first = stored[0]
    expected_forces = [1.0, 0, 0, 0, 0, 0]
    assert first.id == 1
    assert first.time == 1.0
    # Force entries are keyed by the node id as a string.
    assert first.node_id2forces["1"] == expected_forces
# Remove test.db once the other tests in this module have run.
def test_teardown():
    """Close the shared database and delete its ``test.db`` file.

    This is written as a test so that it runs after the tests above it
    (pytest runs a module's tests in file order by default).

    It handles three edge cases:

    * It does nothing with the cache when no database was ever created
      (e.g. when this test is selected on its own).
    * It clears the cache afterwards, so a later ``get_db()`` call builds a
      new database instead of returning a closed one.
    * It skips the delete when the file is absent.
    """
    global db_save  # reset the module-level cache after closing it

    if db_save is not None:
        # Close before deleting: an open handle can block removal on Windows.
        db_save.close()
        db_save = None

    if os.path.exists("test.db"):
        os.remove("test.db")