Coverage for mpcforces_extractor\datastructure\entities.py: 97%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-04 17:47 +0100

1import time 

2from typing import List, Dict 

3import networkx as nx 

4 

5 

class Node:
    """
    A mesh node identified by its id.

    Every instance registers itself in the class-wide ``node_id2node`` lookup
    (a later node with the same id replaces the earlier entry) and keeps an
    insertion-ordered, duplicate-free list of the elements attached to it.
    """

    # node id -> Node, filled by __init__ and emptied by reset()
    node_id2node: Dict = {}

    def __init__(self, node_id: int, coords: List):
        self.id = node_id
        self.coords = coords
        self.connected_elements = []
        Node.node_id2node[node_id] = self

    def add_element(self, element):
        """
        Attach ``element`` to this node unless it is already attached.

        Membership is tested with ``in`` (i.e. equality), so repeated calls
        with the same element leave a single entry.
        """
        already_attached = element in self.connected_elements
        if already_attached:
            return
        self.connected_elements.append(element)

    @staticmethod
    def reset() -> None:
        """
        Empty the class-wide node registry by binding a brand-new dict.

        References to the previous dict obtained before the reset keep their
        old contents.
        """
        Node.node_id2node = dict()

32 

33 

class Element1D:
    """
    A 1D element connecting exactly two nodes.

    Every instance is appended to the class-wide ``all_elements`` list.
    Unlike ``Element``, it does not register itself in the nodes'
    ``connected_elements``.
    """

    # every Element1D created since the last reset(), in creation order
    all_elements = []

    def __init__(self, element_id: int, property_id: int, node1: Node, node2: Node):
        self.id, self.property_id = element_id, property_id
        self.node1, self.node2 = node1, node2
        Element1D.all_elements.append(self)

    @staticmethod
    def reset():
        """
        Forget all registered 1D elements by binding a fresh empty list.
        """
        Element1D.all_elements = list()

54 

55 

class Element:
    """
    Store a 2D/3D element and maintain the class-wide node connectivity graph.

    Every instance registers itself in ``element_id2element``, attaches itself
    to each of its nodes and adds edges between its nodes to ``graph``.
    ``part_id2node_ids`` caches the connected components of that graph
    (see ``get_part_id2node_ids_graph``).
    """

    # element id -> Element, filled by __init__
    element_id2element: Dict = {}
    # undirected graph whose vertices are Node objects, shared by all elements
    graph = nx.Graph()
    # cached part id (1-based) -> list of node ids; emptied whenever an
    # element is created so the next lookup recomputes it
    part_id2node_ids = {}

    @staticmethod
    def reset_graph():
        """
        Reset the graph and the element/part lookup tables (very important for testing).
        """
        Element.graph = nx.Graph()
        Element.element_id2element = {}
        Element.part_id2node_ids = {}

    def __init__(self, element_id: int, property_id: int, nodes: list):
        """
        Create the element, register it and connect its nodes in the graph.

        :param element_id: id of the element (key in ``element_id2element``)
        :param property_id: property id of the element
        :param nodes: sequence of node objects (iterated more than once and
            sliced, so it must be a list/tuple, not an iterator)
        :raises ZeroDivisionError: if ``nodes`` is empty (centroid undefined)
        """
        self.id = element_id
        self.property_id = property_id
        self.nodes = nodes
        for node in nodes:
            node.add_element(self)

        # Graph - careful: careless implementation regarding nodes:
        # every node is connected to every other node.
        # Real implementation should be done depending on element keyword TODO
        # Each unordered pair is visited once. nx.Graph.add_edge is a no-op
        # for an existing edge, so no has_edge check is needed. Pairs sharing
        # an id are skipped so no self-loops are created.
        for index, node in enumerate(nodes):
            for node2 in nodes[index + 1:]:
                if node.id != node2.id:
                    Element.graph.add_edge(node, node2)

        self.centroid = self.__calculate_centroid()
        self.neighbors = []
        Element.element_id2element[self.id] = self
        # the graph changed, so any cached part decomposition is now stale
        Element.part_id2node_ids = {}

    def __calculate_centroid(self):
        """
        Return the arithmetic mean of the node coordinates as [x, y, z].

        Only the first three entries of each node's ``coords`` are used.
        Raises ZeroDivisionError when the element has no nodes.
        """
        node_count = len(self.nodes)
        return [
            sum(node.coords[axis] for node in self.nodes) / node_count
            for axis in range(3)
        ]

    @staticmethod
    def get_part_id2node_ids_graph(force_update: bool = False) -> Dict:
        """
        Return a mapping part id -> node ids, one part per connected component of the graph.

        Part ids start at 1 and follow the order yielded by
        ``nx.connected_components``. The result is cached in
        ``Element.part_id2node_ids`` and only recomputed when ``force_update``
        is True or the cache is empty.

        :param force_update: recompute even if a cached mapping exists
        :return: the (cached) ``Element.part_id2node_ids`` dict
        """
        if force_update or not Element.part_id2node_ids:
            start_time = time.time()
            print("Building the part_id2node_ids using the graph")

            print("...Calculating connected components")
            # connected_components only reads the graph; no defensive copy needed
            connected_components = list(nx.connected_components(Element.graph))

            print(
                "Finished calculating the connected components, returning part_id2node_ids"
            )
            print("..took ", round(time.time() - start_time, 2), "seconds")

            # Clear before refilling so a forced rebuild never keeps ids of
            # parts that no longer exist. Refilling in place keeps the dict's
            # identity, as before.
            part_id2node_ids = Element.part_id2node_ids
            part_id2node_ids.clear()
            for part_id, connected_component in enumerate(connected_components, start=1):
                part_id2node_ids[part_id] = [node.id for node in connected_component]

        return Element.part_id2node_ids