Coverage for mpcforces_extractor\visualization\api.py: 0%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-25 00:11 +0200

1import os 

2from typing import List 

3from fastapi import ( 

4 FastAPI, 

5 HTTPException, 

6 status, 

7 Request, 

8 Query, 

9 Form, 

10 UploadFile, 

11) 

12from fastapi.templating import Jinja2Templates 

13from fastapi.staticfiles import StaticFiles 

14from fastapi.responses import HTMLResponse 

15from mpcforces_extractor.database.database import ( 

16 MPCDatabase, 

17 MPCDBModel, 

18 NodeDBModel, 

19 SubcaseDBModel, 

20 RunExtractorRequest, 

21) 

22from mpcforces_extractor.force_extractor import MPCForceExtractor 

23from mpcforces_extractor.datastructure.entities import Element, Node, Element1D 

24from mpcforces_extractor.datastructure.subcases import Subcase 

25from mpcforces_extractor.datastructure.rigids import MPC 

26 

27 

# Page size used by the paginated node endpoint.
ITEMS_PER_PAGE = 100

# Template loader for the HTML views (the directory is a relative path).
templates = Jinja2Templates(
    directory="mpcforces_extractor/visualization/frontend/templates"
)

# Application object on which every route in this module is registered.
app = FastAPI()

38 

39 

@app.on_event("startup")
async def startup_event():
    """
    Startup hook: create an MPCDatabase and store it on the app as ``app.db``.
    """
    print("Connecting to the database")
    database = MPCDatabase()
    app.db = database

47 

48 

# Expose the frontend's static assets under the /static URL prefix
app.mount(
    path="/static",
    app=StaticFiles(directory="mpcforces_extractor/visualization/frontend/static"),
    name="static",
)

55 

56 

# List endpoint: every MPC known to the database
@app.get("/api/v1/mpcs", response_model=List[MPCDBModel])
async def get_mpcs() -> List[MPCDBModel]:
    """Return the result of ``app.db.get_mpcs()``, i.e. all MPCs."""
    all_mpcs = await app.db.get_mpcs()
    return all_mpcs

62 

63 

# Detail endpoint: a single MPC looked up by its ID
@app.get("/api/v1/mpcs/{mpc_id}", response_model=MPCDBModel)
async def get_mpc(mpc_id: int) -> MPCDBModel:
    """
    Return the MPC with the given ID.

    Raises:
        HTTPException: 404 when the database lookup returns None.
    """
    found = await app.db.get_mpc(mpc_id)
    if found is not None:
        return found

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"MPC with id: {mpc_id} does not exist",
    )

76 

77 

# Delete endpoint: drop a single MPC by its ID
@app.delete("/api/v1/mpcs/{mpc_id}")
async def remove_mpc(mpc_id: int):
    """Ask the database to remove the MPC, then return a confirmation message."""
    await app.db.remove_mpc(mpc_id)
    confirmation = f"MPC with id: {mpc_id} removed"
    return {"message": confirmation}

84 

85 

@app.get("/api/v1/nodes", response_model=List[NodeDBModel])
async def get_nodes(page: int = Query(1, ge=1)) -> List[NodeDBModel]:
    """
    Return one page of nodes; each page holds ITEMS_PER_PAGE entries.

    Args:
        page: 1-based page number (validated to be >= 1).

    Raises:
        HTTPException: 404 when the requested page contains no nodes.
    """
    # Page 1 starts at offset 0, page 2 at ITEMS_PER_PAGE, and so on
    first_index = ITEMS_PER_PAGE * (page - 1)
    page_of_nodes = await app.db.get_nodes(offset=first_index, limit=ITEMS_PER_PAGE)

    if page_of_nodes:
        return page_of_nodes

    raise HTTPException(status_code=404, detail="No nodes found")

102 

103 

@app.get("/api/v1/nodes/all", response_model=List[NodeDBModel])
async def get_all_nodes() -> List[NodeDBModel]:
    """
    Get all nodes (not paginated).

    Returns:
        Every node returned by ``app.db.get_all_nodes()``.

    Raises:
        HTTPException: 404 when the database returns no nodes.
    """
    nodes = await app.db.get_all_nodes()

    if not nodes:
        raise HTTPException(status_code=404, detail="No nodes found")

    return nodes

115 

116 

@app.get("/api/v1/nodes/filter/{filter_input}", response_model=List[NodeDBModel])
async def get_nodes_filtered(filter_input: str) -> List[NodeDBModel]:
    """
    Get nodes filtered by a string, taken from all nodes (not paginated).

    The filter is a comma-separated list whose parts are either a single ID
    ('7') or an inclusive range ('1-3'), e.g. '1-3,7,9'. Matches are returned
    in the order the filter parts appear; overlapping parts may therefore
    yield the same node more than once.

    Args:
        filter_input: the raw filter string taken from the URL path.

    Raises:
        HTTPException: 400 when a part is neither a valid integer ID nor a
            valid 'start-end' range of integers.
    """
    nodes = await app.db.get_all_nodes()

    # Index nodes by ID once so single-ID lookups do not rescan the list.
    # setdefault keeps the first node per ID, like a first-match linear search.
    nodes_by_id = {}
    for node in nodes:
        nodes_by_id.setdefault(node.id, node)

    filtered_nodes = []
    for part in filter_input.split(","):
        part = part.strip()  # Trim whitespace
        if "-" in part:
            # Handle range like '1-3'
            try:
                # The unpacking is inside the try so that malformed ranges
                # such as '1-2-3' produce a 400 instead of an unhandled error.
                start, end = part.split("-")
                start_id = int(start.strip())
                end_id = int(end.strip())
            except ValueError as exc:
                raise HTTPException(
                    status_code=400, detail="Invalid range in filter"
                ) from exc
            filtered_nodes.extend(
                node for node in nodes if start_id <= node.id <= end_id
            )
        else:
            # Handle single ID
            try:
                node_id = int(part)
            except ValueError as exc:
                raise HTTPException(
                    status_code=400, detail="Invalid ID in filter"
                ) from exc
            node = nodes_by_id.get(node_id)
            if node is not None:
                filtered_nodes.append(node)
    return filtered_nodes

155 

156 

# List endpoint: every subcase known to the database
@app.get("/api/v1/subcases", response_model=List[SubcaseDBModel])
async def get_subcases() -> List[SubcaseDBModel]:
    """Return the result of ``app.db.get_subcases()``, i.e. all subcases."""
    all_subcases = await app.db.get_subcases()
    return all_subcases

162 

163 

@app.post("/api/v1/upload-chunk")
async def upload_chunk(
    file: UploadFile, filename: str = Form(...), offset: int = Form(...)
):
    """
    Upload a chunk of a file and write it at the given byte offset.

    A chunk with ``offset == 0`` starts a fresh upload and replaces any
    existing file of the same name; later chunks are written into the
    existing file at ``offset`` without discarding earlier chunks.

    Args:
        file: the uploaded chunk.
        filename: name of the target file inside the upload folder.
        offset: byte position in the target file where this chunk belongs.

    Raises:
        HTTPException: 400 when the filename has no usable final component
            or the offset is negative.
    """
    upload_folder = "data/uploads"

    # The filename is client-supplied: keep only its final path component so
    # a value like "../../x" cannot write outside the upload folder.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid filename")
    if offset < 0:
        raise HTTPException(status_code=400, detail="Invalid offset")

    # Create the upload directory if it doesn't exist
    os.makedirs(upload_folder, exist_ok=True)
    file_path = os.path.join(upload_folder, safe_name)

    content = await file.read()

    # "wb" truncates, so it is used only for the first chunk (or when the
    # file does not exist yet). "r+b" keeps existing data and honours seek();
    # append mode would ignore the seek and always write at the end.
    if offset == 0 or not os.path.exists(file_path):
        mode = "wb"
    else:
        mode = "r+b"

    with open(file_path, mode) as f:
        f.seek(offset)
        f.write(content)

    return {"message": "Chunk uploaded successfully!"}

188 

189 

@app.post("/api/v1/run-extractor")
async def run_extractor(request: RunExtractorRequest):
    """
    Run the extractor with the provided filenames.

    Resets the in-memory model state, builds an MPCForceExtractor for the
    two files named in the request (looked up in ``data/uploads``), builds
    the FEM and subcase data, and replaces ``app.db`` with a new MPCDatabase.

    Args:
        request: carries ``fem_filename`` and ``mpcf_filename``.

    Raises:
        HTTPException: 500 with the error text when any step of the
            extraction fails.
    """
    fem_file = request.fem_filename
    mpcf_file = request.mpcf_filename

    print(f"Running extractor with files: {fem_file}, {mpcf_file}")

    # NOTE(review): both filenames come from the client and are placed into
    # paths unchanged; confirm whether they should be restricted to bare names.
    input_folder = "data/uploads"
    output_folder = "data/output"
    blocksize = 8  # passed through to build_fem_and_subcase_data

    # The actual work is inside the try so failures are reported as an
    # HTTP 500 carrying the error message.
    try:
        # Clear all instances left over from a previous run
        Node.reset()
        Element1D.reset()
        Element.reset_graph()
        Subcase.reset()
        MPC.reset()

        mpc_force_extractor = MPCForceExtractor(
            input_folder + f"/{fem_file}",
            input_folder + f"/{mpcf_file}",
            output_folder + f"/FRONTEND_{fem_file.split('.')[0]}",
        )
        mpc_force_extractor.build_fem_and_subcase_data(blocksize)

        # Replace the database object after the rebuild
        app.db = MPCDatabase()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Extractor run successfully!"}

229 

230 

# HTML section
# Page route: MPC list view
@app.get("/mpcs", response_class=HTMLResponse)
async def read_mpcs(request: Request):
    """Serve the MPC page by rendering the mpcs.html template."""
    context = {"request": request}
    return templates.TemplateResponse("mpcs.html", context)

237 

238 

# Page route: nodes view
@app.get("/nodes", response_class=HTMLResponse)
async def read_nodes(request: Request):
    """Serve the nodes page by rendering the nodes.html template."""
    context = {"request": request}
    return templates.TemplateResponse("nodes.html", context)

244 

245 

# Page route: landing page
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the landing page by rendering the main.html template."""
    context = {"request": request}
    return templates.TemplateResponse("main.html", context)