# Coverage report for mpcforces_extractor\visualization\api.py: 0%
# 131 statements
# Generated by coverage.py v7.6.4 at 2024-10-31 17:43 +0100

1import os 

2from typing import List 

3from fastapi import ( 

4 FastAPI, 

5 HTTPException, 

6 status, 

7 Request, 

8 Query, 

9 Form, 

10 UploadFile, 

11) 

12from fastapi.templating import Jinja2Templates 

13from fastapi.staticfiles import StaticFiles 

14from fastapi.responses import HTMLResponse 

15from mpcforces_extractor.database.database import ( 

16 MPCDatabase, 

17 MPCDBModel, 

18 NodeDBModel, 

19 SubcaseDBModel, 

20 RunExtractorRequest, 

21 DatabaseRequest, 

22) 

23from mpcforces_extractor.force_extractor import MPCForceExtractor 

24from mpcforces_extractor.datastructure.entities import Element, Node, Element1D 

25from mpcforces_extractor.datastructure.subcases import Subcase 

26from mpcforces_extractor.datastructure.rigids import MPC 

27 

28 

ITEMS_PER_PAGE = 100  # Fixed number of items returned per page by paginated endpoints
UPLOAD_FOLDER = "data/uploads"  # Relative to the current working directory
OUTPUT_FOLDER = "data/output"  # Relative to the current working directory

# Resolve the frontend assets relative to this module instead of the current
# working directory, so the app also starts when it is not launched from the
# repository root (this module lives next to the "frontend" folder).
_FRONTEND_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "frontend")

# Setup Jinja2 templates
templates = Jinja2Templates(directory=os.path.join(_FRONTEND_DIR, "templates"))
app = FastAPI()

# Mount the static files directory
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(_FRONTEND_DIR, "static")),
    name="static",
)

46 

47 

# API endpoint returning every MPC stored in the database
@app.get("/api/v1/mpcs", response_model=List[MPCDBModel])
async def get_mpcs() -> List[MPCDBModel]:
    """Get all MPCs.

    Raises:
        HTTPException: 500 if no database has been loaded into ``app.db`` yet.
    """
    db_ready = hasattr(app, "db")
    if db_ready:
        return await app.db.get_mpcs()
    raise HTTPException(status_code=500, detail="Database not initialized")

55 

56 

# API endpoint returning a single MPC looked up by its ID
@app.get("/api/v1/mpcs/{mpc_id}", response_model=MPCDBModel)
async def get_mpc(mpc_id: int) -> MPCDBModel:
    """Get info about a specific MPC.

    Raises:
        HTTPException: 500 if no database is loaded, 404 if the database
            returns ``None`` for ``mpc_id``.
    """
    if not hasattr(app, "db"):
        raise HTTPException(status_code=500, detail="Database not initialized")

    found = await app.db.get_mpc(mpc_id)
    if found is not None:
        return found

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"MPC with id: {mpc_id} does not exist",
    )

71 

72 

# API endpoint deleting an MPC by its ID
@app.delete("/api/v1/mpcs/{mpc_id}")
async def remove_mpc(mpc_id: int):
    """Remove an MPC.

    Returns a confirmation message once ``app.db.remove_mpc`` has completed.
    NOTE(review): no existence check is done here, so whether an unknown ID is
    reported as an error depends on ``MPCDatabase.remove_mpc`` — confirm.

    Raises:
        HTTPException: 500 if no database has been loaded yet.
    """
    if hasattr(app, "db"):
        await app.db.remove_mpc(mpc_id)
        return {"message": f"MPC with id: {mpc_id} removed"}
    raise HTTPException(status_code=500, detail="Database not initialized")

82 

83 

@app.get("/api/v1/nodes", response_model=List[NodeDBModel])
async def get_nodes(page: int = Query(1, ge=1)) -> List[NodeDBModel]:
    """
    Get one page of nodes (ITEMS_PER_PAGE nodes per page, pages start at 1).

    Raises:
        HTTPException: 500 if no database is loaded, 404 if the requested page
            contains no nodes.
    """
    if not hasattr(app, "db"):
        raise HTTPException(status_code=500, detail="Database not initialized")

    # Page N starts after the (N - 1) full pages that precede it.
    page_nodes = await app.db.get_nodes(
        offset=ITEMS_PER_PAGE * (page - 1), limit=ITEMS_PER_PAGE
    )

    if page_nodes:
        return page_nodes
    raise HTTPException(status_code=404, detail="No nodes found")

104 

105 

@app.get("/api/v1/nodes/all", response_model=List[NodeDBModel])
async def get_all_nodes() -> List[NodeDBModel]:
    """
    Get all nodes (not paginated).

    Returns:
        Every node returned by ``app.db.get_all_nodes()``.

    Raises:
        HTTPException: 500 if no database is loaded, 404 if it returns no nodes.
    """
    if not hasattr(app, "db"):
        raise HTTPException(status_code=500, detail="Database not initialized")

    nodes = await app.db.get_all_nodes()

    if not nodes:
        raise HTTPException(status_code=404, detail="No nodes found")

    return nodes

120 

121 

@app.get("/api/v1/nodes/filter/{filter_input}", response_model=List[NodeDBModel])
async def get_nodes_filtered(filter_input: str) -> List[NodeDBModel]:
    """
    Get nodes filtered by a string, taken from all nodes (not paginated).

    The filter is a comma-separated list of parts. Each part is either a
    single node ID (``'2'``) or an inclusive ID range (``'1-3'``), e.g.
    ``'1-3,7,10-12'``. Whitespace around parts is ignored and empty parts
    (such as a trailing comma) are skipped. A range whose start exceeds its
    end matches nothing. Each matching node is returned once, in the order it
    is first matched.

    Raises:
        HTTPException: 500 if no database is loaded; 400 if a part is not a
            valid integer ID or ``start-end`` range.
    """
    if not hasattr(app, "db"):
        raise HTTPException(status_code=500, detail="Database not initialized")

    nodes = await app.db.get_all_nodes() or []

    # Index once so single-ID lookups are O(1) instead of a scan per ID;
    # setdefault keeps the first node per ID, matching a linear first-hit search.
    nodes_by_id = {}
    for node in nodes:
        nodes_by_id.setdefault(node.id, node)

    # Insertion-ordered dict keyed by ID: overlapping parts (e.g. "1-3,2")
    # do not produce duplicate nodes in the result.
    matched = {}

    for part in filter_input.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            # The split is inside the try so malformed ranges such as "1-2-3"
            # produce a 400 rather than an unpacking error (500).
            try:
                start, end = part.split("-")
                start_id = int(start.strip())
                end_id = int(end.strip())
            except ValueError as err:
                raise HTTPException(
                    status_code=400, detail="Invalid range in filter"
                ) from err
            for node in nodes:
                if start_id <= node.id <= end_id:
                    matched.setdefault(node.id, node)
        else:
            try:
                node_id = int(part)
            except ValueError as err:
                raise HTTPException(
                    status_code=400, detail="Invalid ID in filter"
                ) from err
            node = nodes_by_id.get(node_id)
            if node is not None:
                matched.setdefault(node.id, node)

    return list(matched.values())

160 

161 

# API endpoint returning every subcase in the database
@app.get("/api/v1/subcases", response_model=List[SubcaseDBModel])
async def get_subcases() -> List[SubcaseDBModel]:
    """Get all subcases.

    Raises:
        HTTPException: 500 if no database has been loaded into ``app.db`` yet.
    """
    if hasattr(app, "db"):
        return await app.db.get_subcases()
    raise HTTPException(status_code=500, detail="Database not initialized")

169 

170 

@app.post("/api/v1/upload-chunk")
async def upload_chunk(
    file: UploadFile, filename: str = Form(...), offset: int = Form(...)
):
    """
    Upload a chunk of a file.

    The chunk's bytes are written into ``UPLOAD_FOLDER/<filename>`` starting
    at byte position ``offset``. A chunk with ``offset == 0`` starts a new
    upload and discards any previous file of the same name. Later chunks are
    written into the existing file, so a file sent in several chunks is
    reassembled instead of being reduced to its last chunk.
    NOTE(review): assumes the frontend sends the offset-0 chunk before the
    others — confirm.

    Raises:
        HTTPException: 400 if ``filename`` is not a plain file name (empty,
            "." / "..", or containing a directory component) or if
            ``offset`` is negative.
    """
    # The filename is client-supplied: refuse anything that could resolve
    # outside UPLOAD_FOLDER (e.g. "../x" or an absolute path).
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=400, detail="Invalid filename")
    if offset < 0:
        raise HTTPException(status_code=400, detail="Invalid offset")

    file_path = os.path.join(UPLOAD_FOLDER, filename)

    # Create the upload directory if it doesn't exist
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)

    content = await file.read()

    # Only the first chunk (re)creates the file; later chunks must keep what
    # was already received. Append mode ("ab") would ignore the seek, so the
    # existing file is opened for in-place writing ("r+b") instead.
    if offset == 0 or not os.path.exists(file_path):
        mode = "wb"
    else:
        mode = "r+b"
    with open(file_path, mode) as f:
        f.seek(offset)
        f.write(content)

    return {"message": "Chunk uploaded successfully!"}

194 

195 

@app.post("/api/v1/run-extractor")
async def run_extractor(request: RunExtractorRequest):
    """
    Run the extractor with the provided filenames.

    Resets the Node / Element1D / Element / Subcase / MPC state via their
    ``reset`` / ``reset_graph`` methods, builds the FEM and subcase data from
    ``UPLOAD_FOLDER/<fem_filename>`` and ``UPLOAD_FOLDER/<mpcf_filename>``,
    populates ``OUTPUT_FOLDER/FRONTEND_<fem name without extension>/db.db``,
    and installs it as ``app.db``.

    Raises:
        HTTPException: 400 if a filename is not a plain file name, 404 if an
            input file is not present in ``UPLOAD_FOLDER``, 500 if the
            extraction itself fails (detail carries the error message).
    """
    fem_file = request.fem_filename
    mpcf_file = request.mpcf_filename

    # Validate the client-supplied names before the try below, so these
    # errors keep their status instead of being re-wrapped as 500s.
    for name in (fem_file, mpcf_file):
        if name in ("", ".", "..") or os.path.basename(name) != name:
            raise HTTPException(status_code=400, detail=f"Invalid filename: {name}")
        if not os.path.isfile(os.path.join(UPLOAD_FOLDER, name)):
            raise HTTPException(status_code=404, detail=f"File {name} not found")

    print(f"Running extractor with files: {fem_file}, {mpcf_file}")

    blocksize = 8
    # splitext keeps multi-dot names distinct ("model.v2.fem" -> "model.v2"),
    # where split('.')[0] would collapse them to "model".
    model_output_folder = os.path.join(
        OUTPUT_FOLDER, f"FRONTEND_{os.path.splitext(fem_file)[0]}"
    )

    # NOTE(review): this work is synchronous and runs on the event loop,
    # blocking other requests while it executes.
    try:
        # Clear all instances left over from a previous run
        Node.reset()
        Element1D.reset()
        Element.reset_graph()
        Subcase.reset()
        MPC.reset()

        mpc_force_extractor = MPCForceExtractor(
            os.path.join(UPLOAD_FOLDER, fem_file),
            os.path.join(UPLOAD_FOLDER, mpcf_file),
            model_output_folder,
        )
        mpc_force_extractor.build_fem_and_subcase_data(blocksize)

        # Populate into a local first so a failed run does not leave app.db
        # pointing at a half-built database.
        db = MPCDatabase(os.path.join(model_output_folder, "db.db"))
        db.populate_database()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    app.db = db
    return {"message": "Extractor run successfully!"}

235 

236 

@app.post("/api/v1/import-db")
async def import_db(request: DatabaseRequest):
    """
    Import a database (db) file and reinitialize the database.

    Loads ``UPLOAD_FOLDER/<database_filename>`` as the active ``app.db``.
    If no database exists yet, one is created from that path first.

    Raises:
        HTTPException: 400 if ``database_filename`` is not a plain file name,
            404 if the file is not present in ``UPLOAD_FOLDER``.
    """
    db_file = request.database_filename

    # The name is client-supplied: refuse anything that could resolve outside
    # UPLOAD_FOLDER (e.g. "../x", an absolute path, or the folder itself).
    if db_file in ("", ".", "..") or os.path.basename(db_file) != db_file:
        raise HTTPException(
            status_code=400, detail=f"Invalid database filename: {db_file}"
        )

    db_path = os.path.join(UPLOAD_FOLDER, db_file)

    # Check that the file exists (and is a file, not a directory)
    if not os.path.isfile(db_path):
        raise HTTPException(
            status_code=404, detail=f"Database file {db_file} not found"
        )

    # Create the database object on first use; in every case (re)load it
    # from db_path.
    if not hasattr(app, "db"):
        app.db = MPCDatabase(db_path)
    app.db.reinitialize_db(db_path)
    return {"message": "Database imported successfully!"}

258 

259 

# HTML Section
# Route for the main page (MPC list)
@app.get("/mpcs", response_class=HTMLResponse)
async def read_mpcs(request: Request):
    """Serve the MPC list page by rendering the ``mpcs.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("mpcs.html", context)

266 

267 

# Route for the nodes view (HTML)
@app.get("/nodes", response_class=HTMLResponse)
async def read_nodes(request: Request):
    """Serve the nodes page by rendering the ``nodes.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("nodes.html", context)

273 

274 

# Route for the main view (HTML)
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the main.html template (the application's landing page)."""
    return templates.TemplateResponse("main.html", {"request": request})