Coverage for mpcforces_extractor\database\test_database.py: 98%
47 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-10-31 16:53 +0100
1import os
2import pytest
3from mpcforces_extractor.database.database import MPCDatabase
4from mpcforces_extractor.datastructure.rigids import MPC, MPC_CONFIG
5from mpcforces_extractor.datastructure.entities import Node, Element
6from fastapi import HTTPException
# Module-level cache for the shared MPCDatabase. get_db() fills it on its
# first call and test_teardown() closes it; binding it here lets both refer
# to it via ``global`` before any test has run.
db_save: "MPCDatabase | None" = None
async def get_db():
    """Return the shared test ``MPCDatabase``, building it on first use.

    On the first call this registers two MPCs (element 1 as RBE2 with master
    node 1, element 2 as RBE3 with master node 4) and two elements, then
    creates ``test.db`` and calls ``populate_database()`` on it. The instance
    is cached in the module global ``db_save``, and later calls return that
    cached instance unchanged, so all tests share one database.

    This is a helper coroutine, not a test, so it carries no
    ``pytest.mark.asyncio`` marker; async tests ``await`` it directly.
    """
    global db_save  # the cache lives at module level so every test shares it

    # Compare against None explicitly: a valid database object must not be
    # rebuilt merely because it happens to evaluate as falsy.
    if db_save is not None:
        return db_save

    # Define the nodes used by the MPCs and elements below.
    node1 = Node(1, [0, 0, 0])
    node2 = Node(2, [1, 2, 3])
    node3 = Node(3, [4, 5, 6])
    node4 = Node(4, [0, 0, 0])
    node5 = Node(5, [1, 2, 3])
    node6 = Node(6, [4, 5, 6])

    # NOTE(review): presumably clears MPCs registered by earlier setup so the
    # counts asserted in the tests are deterministic — confirm in MPC.reset.
    MPC.reset()
    MPC(
        element_id=1,
        mpc_config=MPC_CONFIG.RBE2,
        master_node=node1,
        nodes=[node2, node3],
        dofs="",
    )
    MPC(
        element_id=2,
        mpc_config=MPC_CONFIG.RBE3,
        master_node=node4,
        nodes=[node5, node6],
        dofs="",
    )

    # The Element instances are not kept here; presumably the constructor
    # registers them somewhere populate_database() reads — confirm in Element.
    Element(1, 1, [node2, node3])
    Element(2, 2, [node6, node5])

    db = MPCDatabase("test.db")
    db.populate_database()
    db_save = db  # cache for subsequent calls
    return db_save
@pytest.mark.asyncio
async def test_initialize_database():
    """The shared database holds exactly the two MPCs registered by get_db()."""
    database = await get_db()
    all_mpcs = await database.get_mpcs()
    assert len(all_mpcs) == 2  # both MPCs made it into the database
@pytest.mark.asyncio
async def test_get_mpc():
    """Looking up id 1 yields the RBE2 MPC created in get_db()."""
    database = await get_db()
    first_mpc = await database.get_mpc(1)
    assert first_mpc.id == 1
    assert first_mpc.config == "RBE2"
@pytest.mark.asyncio
async def test_remove_mpc():
    """After removing MPC 1, looking it up raises HTTPException.

    The lookup before removal guards against a vacuous pass: without it the
    test would also succeed if MPC 1 had never been in the database.

    NOTE: this mutates the database shared via get_db(), so it must run after
    test_get_mpc, which expects MPC 1 to exist (pytest's default file order).
    """
    db = await get_db()
    existing = await db.get_mpc(1)  # precondition: MPC 1 is present
    assert existing.id == 1
    await db.remove_mpc(1)
    with pytest.raises(HTTPException):
        await db.get_mpc(1)
# Remove test.db after all tests have run. This is an ordinary test that
# relies on pytest executing tests in file order, so it comes last.
def test_teardown():
    """Close the shared database, if one was created, and delete ``test.db``.

    Safe to run on its own (e.g. ``pytest -k teardown``): when no database
    was built, or the file is already gone, the corresponding step is skipped
    instead of raising.
    """
    global db_save
    if db_save is not None:
        db_save.close()
        # Drop the reference so get_db() cannot hand out a closed database.
        db_save = None
    if os.path.exists("test.db"):
        os.remove("test.db")
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures to the shell/CI instead of always exiting with 0.
    raise SystemExit(pytest.main(["-s", "-v", __file__]))