Coverage for mpcforces_extractor\database\test_database.py: 98%

47 statements  

coverage.py v7.6.4, created at 2024-10-31 16:53 +0100

1import os 

2import pytest 

3from mpcforces_extractor.database.database import MPCDatabase 

4from mpcforces_extractor.datastructure.rigids import MPC, MPC_CONFIG 

5from mpcforces_extractor.datastructure.entities import Node, Element 

6from fastapi import HTTPException 

7 

# Module-level cache for the shared MPCDatabase; get_db() fills it on first
# call and test_teardown() closes it. Bound to None up front so the name
# exists before any test touches it.
db_save = None

10 

11 

async def get_db():
    """Return the shared test database, building and populating it on first use.

    The first call resets MPC state via ``MPC.reset()`` and constructs:

    * two MPCs: element 1 as RBE2 and element 2 as RBE3;
    * two ``Element`` objects, built from six nodes.

    It then creates ``MPCDatabase("test.db")`` and calls ``populate_database()``.
    The instance is cached in the module-level ``db_save``, and every later call
    returns that same object. All tests in this module therefore share, and may
    mutate, one database, so they are order-dependent.

    This helper is not a test, so it carries no ``pytest.mark.asyncio`` marker.
    Callers simply ``await`` it from inside their own asyncio-marked tests.

    Returns:
        The cached ``MPCDatabase`` instance.
    """
    global db_save  # the cache lives at module level so all tests share it

    # Compare against None explicitly: a truthiness test would rebuild the
    # database if the cached object ever evaluated as falsy (__len__/__bool__).
    if db_save is not None:
        return db_save

    # Nodes 1-3 feed the RBE2 MPC below; nodes 4-6 feed the RBE3 MPC.
    node1 = Node(1, [0, 0, 0])
    node2 = Node(2, [1, 2, 3])
    node3 = Node(3, [4, 5, 6])
    node4 = Node(4, [0, 0, 0])
    node5 = Node(5, [1, 2, 3])
    node6 = Node(6, [4, 5, 6])

    MPC.reset()
    # The MPC/Element objects are not kept in locals. Presumably their
    # constructors register them in class-level registries that
    # populate_database() reads. TODO confirm against the datastructure module.
    MPC(
        element_id=1,
        mpc_config=MPC_CONFIG.RBE2,
        master_node=node1,
        nodes=[node2, node3],
        dofs="",
    )
    MPC(
        element_id=2,
        mpc_config=MPC_CONFIG.RBE3,
        master_node=node4,
        nodes=[node5, node6],
        dofs="",
    )

    Element(1, 1, [node2, node3])
    Element(2, 2, [node6, node5])

    db = MPCDatabase("test.db")
    db.populate_database()
    # Cache only after population succeeded, so a failure is retried next call.
    db_save = db
    return db_save

50 

51 

@pytest.mark.asyncio
async def test_initialize_database():
    """The freshly populated database should hold exactly the two seeded MPCs."""
    database = await get_db()
    all_mpcs = await database.get_mpcs()
    expected_count = 2
    assert len(all_mpcs) == expected_count

56 

57 

@pytest.mark.asyncio
async def test_get_mpc():
    """Looking up MPC 1 should return the RBE2 MPC seeded by get_db()."""
    database = await get_db()
    wanted_id = 1
    fetched = await database.get_mpc(wanted_id)
    assert fetched.id == wanted_id
    assert fetched.config == "RBE2"

64 

65 

@pytest.mark.asyncio
async def test_remove_mpc():
    """After MPC 1 is removed, fetching it again must raise HTTPException.

    This mutates the shared database returned by get_db(), so it must run
    after the tests that expect MPC 1 to exist.
    """
    database = await get_db()
    removed_id = 1
    await database.remove_mpc(removed_id)

    with pytest.raises(HTTPException):
        await database.get_mpc(removed_id)

72 

73 

# Remove test.db after the other tests have run. This is written as a test
# rather than a fixture finalizer, so it relies on pytest running this file's
# tests in definition order.
def test_teardown():
    """Close the shared database, if one was created, and delete ``test.db``.

    This is safe to run in isolation (e.g. ``pytest -k teardown``):

    * if get_db() was never called, there is nothing to close;
    * if the file is already gone, nothing is deleted.
    """
    global db_save

    if db_save is not None:
        db_save.close()
        # Drop the cached handle so a later get_db() call does not hand out
        # a closed database.
        db_save = None

    try:
        os.remove("test.db")
    except FileNotFoundError:
        pass  # already cleaned up (or never created) - nothing to do

78 

79 

if __name__ == "__main__":
    # pytest.main() returns the session exit code rather than exiting.
    # Re-raise it so that running this file as a script fails visibly
    # (non-zero status) when any test fails.
    raise SystemExit(pytest.main(["-s", "-v", __file__]))