Coverage for mpcforces_extractor\api\db\test_database.py: 100%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-06 21:34 +0100

1import os 

2import pytest 

3from mpcforces_extractor.api.db.database import MPCDatabase 

4from mpcforces_extractor.datastructure.rigids import MPC, MPC_CONFIG 

5from mpcforces_extractor.datastructure.entities import Node, Element 

6from mpcforces_extractor.datastructure.subcases import Subcase 

7from fastapi import HTTPException 

8 

# Module-level cache for the shared MPCDatabase: get_db() fills it on its first
# call and test_teardown() closes it. Starts out empty.
db_save: "MPCDatabase | None" = None

11 

12 

async def get_db():
    """Return the shared test database, building it on the first call.

    On the first call this builds a small model and then opens
    ``MPCDatabase("test.db")`` and calls ``populate_database()``. The model has:

    - seven nodes;
    - an RBE2 MPC (element_id 1) and an RBE3 MPC (element_id 2);
    - two elements;
    - one subcase carrying a force vector on each of nodes 1-6.

    The instance is cached in the module-level ``db_save``. Every later call,
    and therefore every test, gets the same object and sees the changes made
    by earlier tests.

    This helper is not a test itself, so it carries no pytest marker; tests
    simply ``await`` it.

    Returns:
        The cached ``MPCDatabase`` instance.
    """
    global db_save  # rebound below once the database has been built

    # Explicit None check: only "not built yet" should trigger a rebuild, not
    # whatever truthiness the database object happens to have.
    if db_save is not None:
        return db_save

    # Master and dependent nodes for the two MPCs.
    node1 = Node(1, [0, 0, 0])
    node2 = Node(2, [1, 2, 3])
    node3 = Node(3, [4, 5, 6])
    node4 = Node(4, [0, 0, 0])
    node5 = Node(5, [1, 2, 3])
    node6 = Node(6, [4, 5, 6])
    # Node 7 is attached to no MPC or element. test_get_nodes expects 6 nodes
    # from the default population and 7 with load_all_nodes=True.
    Node(7, [0, 0, 0])

    # The constructed MPC/Element objects are discarded here. Presumably the
    # constructors register instances in class-level storage that
    # populate_database() reads, and MPC.reset() clears it first.
    # NOTE(review): confirm in mpcforces_extractor.datastructure.
    MPC.reset()
    MPC(
        element_id=1,
        mpc_config=MPC_CONFIG.RBE2,
        master_node=node1,
        nodes=[node2, node3],
        dofs="",
    )
    MPC(
        element_id=2,
        mpc_config=MPC_CONFIG.RBE3,
        master_node=node4,
        nodes=[node5, node6],
        dofs="",
    )

    Element(1, 1, [node2, node3])
    Element(2, 2, [node6, node5])

    # One subcase (id 1, time 1.0) with the same force vector on nodes 1-6.
    # A fresh list is built for every node, so no vector is shared.
    subcase = Subcase(1, 1.0)
    for node_id in range(1, 7):
        subcase.add_force(node_id, [1.0, 0, 0, 0, 0, 0])

    db = MPCDatabase("test.db")
    db.populate_database()
    db_save = db  # cache for subsequent calls
    return db_save

60 

61 

@pytest.mark.asyncio
async def test_initialize_database():
    """The freshly populated database lists exactly two MPCs."""
    database = await get_db()
    all_mpcs = await database.get_mpcs()
    assert len(all_mpcs) == 2

66 

67 

@pytest.mark.asyncio
async def test_get_mpc():
    """Looking up MPC 1 yields the record with id 1 and config "RBE2"."""
    database = await get_db()
    fetched = await database.get_mpc(1)
    assert fetched.id == 1
    assert fetched.config == "RBE2"

74 

75 

@pytest.mark.asyncio
async def test_remove_mpc():
    """After MPC 1 is removed, fetching it again raises HTTPException."""
    database = await get_db()
    target_id = 1
    await database.remove_mpc(target_id)
    with pytest.raises(HTTPException):
        await database.get_mpc(target_id)

82 

83 

@pytest.mark.asyncio
async def test_remove_mpc_not_exist():
    """Removing MPC id 3, which the fixture never creates, raises HTTPException."""
    database = await get_db()
    missing_id = 3
    with pytest.raises(HTTPException):
        await database.remove_mpc(missing_id)

89 

90 

@pytest.mark.asyncio
async def test_get_nodes():
    """Check node listing, offset paging and the ``load_all_nodes`` option.

    The default population lists six nodes. ``get_nodes(offset, 100)`` returns
    ``offset`` fewer than the full list. Repopulating with
    ``load_all_nodes=True`` raises the total to seven.
    """
    database = await get_db()
    every_node = await database.get_all_nodes()
    assert len(every_node) == 6

    skip = 1
    page = await database.get_nodes(skip, 100)
    expected_page_size = len(every_node) - skip
    assert len(page) == expected_page_size

    # This repopulates the shared, cached database seen by later tests.
    database.populate_database(load_all_nodes=True)
    reloaded = await database.get_all_nodes()
    assert len(reloaded) == 7

102 

103 

@pytest.mark.asyncio
async def test_subcases():
    """The single stored subcase keeps its id, time and node-1 force vector."""
    database = await get_db()
    stored = await database.get_subcases()
    assert len(stored) == 1

    first = stored[0]
    assert first.id == 1
    assert first.time == 1.0
    # Force vectors are looked up by the node id given as a string.
    assert first.node_id2forces["1"] == [1.0, 0, 0, 0, 0, 0]

113 

114 

# Remove the test.db file after all the tests above have run.
def test_teardown():
    """Close the cached database and delete its ``test.db`` file.

    This test is defined last so that, with pytest's default in-file ordering,
    it runs after the other tests. It is also safe to run on its own:
    - if no test built the database, ``db_save`` is still ``None`` and there
      is nothing to close;
    - a missing ``test.db`` is ignored.
    """
    if db_save is not None:
        db_save.close()
    try:
        os.remove("test.db")
    except FileNotFoundError:
        # Best-effort cleanup: nothing to delete.
        pass