# Coverage for mpcforces_extractor\api\db\test_database.py: 100%
# 61 statements
# Report generated by coverage.py v7.6.4 at 2024-11-16 01:41 +0100
import contextlib
import os

import pytest

from mpcforces_extractor.api.db.database import MPCDatabase
from mpcforces_extractor.datastructure.rigids import MPC, MPC_CONFIG
from mpcforces_extractor.datastructure.entities import Node, Element
from mpcforces_extractor.datastructure.subcases import Subcase
# Module-level cache for the shared test database. It stays None until
# get_db() builds the database on its first call.
db_save: "MPCDatabase | None" = None
async def get_db():
    """Return the shared test database, building and populating it on first use.

    On the first call this resets the MPC registry and constructs the fixture
    model:

    - seven nodes;
    - one RBE2 (element 1, master node 1) and one RBE3 (element 2, master node 4);
    - two elements;
    - one subcase (id 1, time 1.0) carrying the force vector
      ``[1.0, 0, 0, 0, 0, 0]`` on nodes 1-6.

    It then creates an ``MPCDatabase`` backed by ``test.db``, populates it,
    and caches it in the module-level ``db_save``. Every later call returns
    the cached instance without rebuilding anything.

    Returns:
        The shared ``MPCDatabase`` instance.
    """
    global db_save  # the cache is module-level so every test shares one database

    # Check for None explicitly: the truthiness of a database object is not a
    # reliable "already initialised" signal.
    if db_save is not None:
        return db_save

    node1 = Node(1, [0, 0, 0])
    node2 = Node(2, [1, 2, 3])
    node3 = Node(3, [4, 5, 6])
    node4 = Node(4, [0, 0, 0])
    node5 = Node(5, [1, 2, 3])
    node6 = Node(6, [4, 5, 6])
    # Node 7 is referenced by no MPC or element. test_get_nodes still expects
    # it once all nodes are loaded, so constructing it presumably registers it
    # globally. Keep the call even though the result is discarded.
    Node(7, [0, 0, 0])

    # Presumably clears MPCs registered by earlier construction in this process.
    MPC.reset()
    MPC(
        element_id=1,
        mpc_config=MPC_CONFIG.RBE2,
        master_node=node1,
        nodes=[node2, node3],
        dofs="",
    )
    MPC(
        element_id=2,
        mpc_config=MPC_CONFIG.RBE3,
        master_node=node4,
        nodes=[node5, node6],
        dofs="",
    )

    Element(1, 1, [node2, node3])
    Element(2, 2, [node6, node5])

    subcase = Subcase(1, 1.0)
    for node_id in range(1, 7):
        # Build a fresh list per node so the stored force vectors are not aliased.
        subcase.add_force(node_id, [1.0, 0, 0, 0, 0, 0])

    db = MPCDatabase("test.db")
    db.populate_database()
    db_save = db  # cache for subsequent calls
    return db_save
@pytest.mark.asyncio
async def test_initialize_database():
    """The initial population stores exactly one RBE2 and exactly one RBE3."""
    database = await get_db()

    rbe2_rows = await database.get_rbe2s()
    assert len(rbe2_rows) == 1

    rbe3_rows = await database.get_rbe3s()
    assert len(rbe3_rows) == 1
@pytest.mark.asyncio
async def test_get_nodes():
    """Check node pagination and the ``load_all_nodes`` population mode.

    The default population yields 6 nodes, and offset-based pagination drops
    the leading node. Re-populating with ``load_all_nodes=True`` also brings
    in the unused node 7.
    """
    db = await get_db()
    nodes_all = await db.get_all_nodes()
    assert len(nodes_all) == 6

    offset = 1
    # limit=10 exceeds the node count, so only the offset shortens the result.
    nodes = await db.get_nodes(offset=offset, limit=10)
    assert len(nodes) == len(nodes_all) - offset

    try:
        db.populate_database(load_all_nodes=True)
        assert len(await db.get_all_nodes()) == 7
    finally:
        # db is the module-wide shared instance. Restore the default
        # population so tests that run after this one do not inherit the
        # extra node.
        db.populate_database()
@pytest.mark.asyncio
async def test_subcases():
    """The single stored subcase keeps its id, its time and node 1's force vector."""
    database = await get_db()
    stored = await database.get_subcases()
    assert len(stored) == 1

    first = stored[0]
    expected_forces = [1.0, 0, 0, 0, 0, 0]
    assert first.id == 1
    assert first.time == 1.0
    # Node ids come back as string keys from the database layer.
    assert first.node_id2forces["1"] == expected_forces
# Remove test.db after all tests have run.
# NOTE(review): this relies on pytest running it last (file order). A
# module-scoped fixture with a finalizer would not depend on ordering.
def test_teardown():
    """Close the shared test database, if one was opened, and delete test.db."""
    global db_save
    # db_save is still None when get_db() never ran, e.g. when this test is
    # selected on its own.
    if db_save is not None:
        db_save.close()
        # Drop the closed instance so a later get_db() call rebuilds the
        # database instead of returning a closed one.
        db_save = None
    # The file is absent if get_db() never created it; there is nothing to clean then.
    with contextlib.suppress(FileNotFoundError):
        os.remove("test.db")