Coverage for mpcforces_extractor\api\db\test_database.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-16 01:41 +0100

1import os 

2import pytest 

3from mpcforces_extractor.api.db.database import MPCDatabase 

4from mpcforces_extractor.datastructure.rigids import MPC, MPC_CONFIG 

5from mpcforces_extractor.datastructure.entities import Node, Element 

6from mpcforces_extractor.datastructure.subcases import Subcase 

7 

# Module-level cache for the database built by get_db(). It stays None until
# the first get_db() call stores the populated MPCDatabase here, and
# test_teardown() closes whatever is stored in it.
db_save = None  # Ensure db_save is defined before use

10 

11 

async def get_db():
    """Return the shared ``MPCDatabase`` used by the tests in this module.

    The first call builds a small model (seven nodes, one RBE2 and one RBE3
    MPC, two elements and one subcase with a force on each of nodes 1-6),
    creates the database backed by ``test.db``, populates it and caches it in
    the module-level ``db_save``. Every later call returns the cached
    instance without rebuilding anything.

    This is a helper coroutine rather than a test, so it carries no
    ``pytest.mark.asyncio`` marker.

    Returns:
        The populated ``MPCDatabase`` instance.
    """
    global db_save  # the cache is rebound below

    # Compare against None explicitly so a cache hit does not depend on the
    # truthiness of the database object.
    if db_save is not None:
        return db_save

    # Nodes referenced by the two MPCs and the two elements below.
    node1 = Node(1, [0, 0, 0])
    node2 = Node(2, [1, 2, 3])
    node3 = Node(3, [4, 5, 6])
    node4 = Node(4, [0, 0, 0])
    node5 = Node(5, [1, 2, 3])
    node6 = Node(6, [4, 5, 6])
    # Seventh node that no MPC or element refers to; test_get_nodes expects
    # it to show up only after populate_database(load_all_nodes=True).
    Node(7, [0, 0, 0])

    # NOTE(review): the MPC/Element/Subcase objects are not passed to the
    # database explicitly - presumably their constructors register them
    # somewhere populate_database() reads from; confirm in the datastructure
    # modules.
    MPC.reset()
    MPC(
        element_id=1,
        mpc_config=MPC_CONFIG.RBE2,
        master_node=node1,
        nodes=[node2, node3],
        dofs="",
    )
    MPC(
        element_id=2,
        mpc_config=MPC_CONFIG.RBE3,
        master_node=node4,
        nodes=[node5, node6],
        dofs="",
    )

    Element(1, 1, [node2, node3])
    Element(2, 2, [node6, node5])

    # One subcase with the same force vector on nodes 1..6; a fresh list is
    # created per node so the entries never share a list object.
    subcase = Subcase(1, 1.0)
    for node_id in range(1, 7):
        subcase.add_force(node_id, [1.0, 0, 0, 0, 0, 0])

    db = MPCDatabase("test.db")
    db.populate_database()
    db_save = db  # cache the initialized database for later calls
    return db_save

59 

60 

@pytest.mark.asyncio
async def test_initialize_database():
    """The freshly populated database holds exactly one RBE2 and one RBE3."""
    database = await get_db()

    # Check initial population
    rbe2_rows = await database.get_rbe2s()
    assert len(rbe2_rows) == 1

    rbe3_rows = await database.get_rbe3s()
    assert len(rbe3_rows) == 1

66 

67 

@pytest.mark.asyncio
async def test_get_nodes():
    """Check node counts, offset-based paging and the ``load_all_nodes`` flag."""
    database = await get_db()

    # After the default population six nodes are stored.
    every_node = await database.get_all_nodes()
    assert len(every_node) == 6

    # Skipping one row with a limit of 10 must return all remaining nodes.
    offset = 1
    page = await database.get_nodes(offset=offset, limit=10)
    expected_page_size = len(every_node) - offset
    assert len(page) == expected_page_size

    # Re-populating with load_all_nodes=True brings the count to seven.
    database.populate_database(load_all_nodes=True)
    reloaded_nodes = await database.get_all_nodes()
    assert len(reloaded_nodes) == 7

79 

80 

@pytest.mark.asyncio
async def test_subcases():
    """The single stored subcase keeps its id, time and per-node forces."""
    database = await get_db()

    stored_subcases = await database.get_subcases()
    assert len(stored_subcases) == 1

    only_subcase = stored_subcases[0]
    assert only_subcase.id == 1
    assert only_subcase.time == 1.0

    # The force mapping is looked up with the node id as a string key.
    expected_force = [1.0, 0, 0, 0, 0, 0]
    assert only_subcase.node_id2forces["1"] == expected_force

90 

91 

# Close the shared database and remove test.db after all tests have run.
def test_teardown():
    """Release the cached database and delete its ``test.db`` file.

    Safe to run even when ``get_db()`` was never called (for example when
    this test is selected on its own): the close step is skipped if no
    database is cached, and a missing ``test.db`` is not treated as an error.
    """
    global db_save  # reset below so a closed database is never handed out

    if db_save is not None:
        db_save.close()
        db_save = None

    try:
        os.remove("test.db")
    except FileNotFoundError:
        # Nothing to clean up - the database file was never created.
        pass