tasks.py

from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from fastapi.responses import Response
from sqlalchemy.orm import Session
from backend.app.api import deps
from backend.app.core.database import get_db
from backend.app.models import sql_models
from backend.app.schemas import schemas
from backend.app.services.scheduler import update_job, run_analysis_pipeline
import openpyxl
import io
import json

router = APIRouter()
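
# NOTE: this router is mounted by the application elsewhere; the prefix below is an
# assumption for illustration only, e.g.:
#   app.include_router(router, prefix="/api/v1/tasks", tags=["tasks"])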


@router.get("", response_model=List[schemas.Task])
def read_tasks(
    db: Session = Depends(get_db),
    skip: int = 0,
    limit: int = 100,
    current_user: sql_models.User = Depends(deps.get_current_user)
) -> Any:
    return db.query(sql_models.Task).offset(skip).limit(limit).all()


@router.get("/export_template")
def export_template(
    db: Session = Depends(get_db),
    current_user: sql_models.User = Depends(deps.get_current_user)
) -> Any:
    """
    Export task configuration template (including current data)
    """
    wb = openpyxl.Workbook()
    ws = wb.active
    ws.title = "Tasks"
    # Headers
    headers = ["Task Name", "Model Config Name", "Camera Names", "Cron Expression", "Rules (JSON)"]
    ws.append(headers)
    # Data
    tasks = db.query(sql_models.Task).all()
    # Pre-fetch cameras for name lookup (optimization)
    all_cameras = {c.id: c.name for c in db.query(sql_models.Camera).all()}
    for task in tasks:
        # Get Model Name
        model_name = task.model_config.name if task.model_config else ""
        # Get Camera Names
        cam_names = []
        if task.camera_ids:
            for cid in task.camera_ids:
                if cid in all_cameras:
                    cam_names.append(all_cameras[cid])
        camera_names_str = ",".join(cam_names)
        # Rules to JSON
        rules_str = json.dumps(task.rules, ensure_ascii=False) if task.rules else "[]"
        ws.append([task.name, model_name, camera_names_str, task.cron_expression, rules_str])
    # Adjust column widths
    ws.column_dimensions['A'].width = 25
    ws.column_dimensions['B'].width = 25
    ws.column_dimensions['C'].width = 40
    ws.column_dimensions['D'].width = 20
    ws.column_dimensions['E'].width = 50
    # Save to an in-memory buffer and return as an xlsx download
    buffer = io.BytesIO()
    wb.save(buffer)
    buffer.seek(0)
    return Response(
        content=buffer.getvalue(),
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": "attachment; filename=tasks_template.xlsx"}
    )
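
# Illustrative template row for the export/import format above (the column names come
# from the code; the cell values here are made-up examples only):
#   Task Name          | Model Config Name | Camera Names    | Cron Expression | Rules (JSON)
#   "Night inspection" | "default-model"   | "Gate 1,Gate 2" | "*/10 * * * *"  | "[]"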


@router.post("/import_template")
async def import_template(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: sql_models.User = Depends(deps.get_current_user)
) -> Any:
    """
    Import tasks from template
    """
    # openpyxl can only read .xlsx workbooks, so reject legacy .xls uploads up front
    if not file.filename.endswith('.xlsx'):
        raise HTTPException(status_code=400, detail="Invalid file format. Please upload an .xlsx Excel file.")
    try:
        contents = await file.read()
        wb = openpyxl.load_workbook(io.BytesIO(contents))
        ws = wb.active
        # Read headers
        rows = list(ws.rows)
        if not rows:
            return {"message": "Empty file"}
        header_row = rows[0]
        headers = [cell.value for cell in header_row]
        # Validate headers
        required_headers = ["Task Name", "Model Config Name", "Camera Names", "Cron Expression", "Rules (JSON)"]
        if not all(h in headers for h in required_headers):
            raise HTTPException(status_code=400, detail=f"Invalid template format. Required columns: {', '.join(required_headers)}")
        name_idx = headers.index("Task Name")
        model_idx = headers.index("Model Config Name")
        cams_idx = headers.index("Camera Names")
        cron_idx = headers.index("Cron Expression")
        rules_idx = headers.index("Rules (JSON)")
        added_count = 0
        updated_count = 0
        # Pre-fetch lookup maps
        # Model Configs: Name -> ID
        model_map = {m.name: m.id for m in db.query(sql_models.ModelConfig).all()}
        # Cameras: Name -> ID
        camera_map = {c.name: c.id for c in db.query(sql_models.Camera).all()}
        # Process data
        for row in rows[1:]:
            task_name = row[name_idx].value
            model_config_name = row[model_idx].value
            camera_names_str = row[cams_idx].value
            cron_expr = row[cron_idx].value
            rules_str = row[rules_idx].value
            if not task_name:
                continue
            # Resolve Model ID
            model_config_id = model_map.get(model_config_name)
            if not model_config_id:
                # Optionally skip or log a warning. Here we skip.
                continue
            # Resolve Camera IDs
            camera_ids = []
            if camera_names_str:
                cam_names = [c.strip() for c in str(camera_names_str).split(",")]
                for cname in cam_names:
                    if cname in camera_map:
                        camera_ids.append(camera_map[cname])
            # Parse Rules; fall back to an empty list on malformed JSON
            try:
                rules = json.loads(rules_str) if rules_str else []
            except (json.JSONDecodeError, TypeError):
                rules = []
            # Check if task exists
            existing_task = db.query(sql_models.Task).filter(sql_models.Task.name == task_name).first()
            if existing_task:
                # Update
                existing_task.model_config_id = model_config_id
                existing_task.camera_ids = camera_ids
                existing_task.rules = rules
                existing_task.cron_expression = cron_expr
                # If task is running, update scheduler
                if existing_task.is_running:
                    update_job(existing_task.id, existing_task.cron_expression, True)
                updated_count += 1
            else:
                # Create
                new_task = sql_models.Task(
                    name=task_name,
                    model_config_id=model_config_id,
                    camera_ids=camera_ids,
                    rules=rules,
                    cron_expression=cron_expr,
                    is_running=False
                )
                db.add(new_task)
                added_count += 1
        db.commit()
        return {
            "message": "Import successful",
            "added": added_count,
            "updated": updated_count
        }
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}")
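
# A minimal client-side sketch of the import call (illustrative only; the mount prefix
# and auth scheme depend on how the application wires this router):
#
#   import requests
#   with open("tasks_template.xlsx", "rb") as f:
#       requests.post(
#           "http://localhost:8000/api/v1/tasks/import_template",
#           files={"file": ("tasks_template.xlsx", f)},
#           headers={"Authorization": "Bearer <token>"},
#       )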


@router.post("", response_model=schemas.Task)
def create_task(
    *,
    db: Session = Depends(get_db),
    task_in: schemas.TaskCreate,
    current_user: sql_models.User = Depends(deps.get_current_user)
) -> Any:
    task = sql_models.Task(**task_in.model_dump())
    db.add(task)
    db.commit()
    db.refresh(task)
    return task
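
# Illustrative request body for creating a task (field names inferred from how Task is
# constructed in import_template above; the authoritative shape is schemas.TaskCreate):
#   {
#       "name": "Night inspection",
#       "model_config_id": 1,
#       "camera_ids": [1, 2],
#       "rules": [],
#       "cron_expression": "*/10 * * * *",
#       "is_running": false
#   }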


@router.put("/{id}", response_model=schemas.Task)
def update_task(
    *,
    db: Session = Depends(get_db),
    id: int,
    task_in: schemas.TaskUpdate,
    current_user: sql_models.User = Depends(deps.get_current_user)
) -> Any:
    task = db.query(sql_models.Task).filter(sql_models.Task.id == id).first()
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    # Update fields
    for key, value in task_in.model_dump().items():
        setattr(task, key, value)
    db.commit()
    db.refresh(task)
    # Update scheduler if running
    if task.is_running:
        update_job(task.id, task.cron_expression, True)
    return task


@router.delete("/{id}", response_model=schemas.Task)
def delete_task(
    *,
    db: Session = Depends(get_db),
    id: int,
    current_user: sql_models.User = Depends(deps.get_current_user)
) -> Any:
    task = db.query(sql_models.Task).filter(sql_models.Task.id == id).first()
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    # Stop scheduler
    update_job(task.id, task.cron_expression, False)
    # Delete associated logs first
    db.query(sql_models.TaskLog).filter(sql_models.TaskLog.task_id == id).delete()
    db.delete(task)
    db.commit()
    return task


@router.post("/{id}/toggle", response_model=schemas.Task)
def toggle_task(
    *,
    db: Session = Depends(get_db),
    id: int,
    toggle_in: schemas.TaskToggle,
    current_user: sql_models.User = Depends(deps.get_current_user)
) -> Any:
    task = db.query(sql_models.Task).filter(sql_models.Task.id == id).first()
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    task.is_running = toggle_in.running
    db.commit()
    db.refresh(task)
    update_job(task.id, task.cron_expression, task.is_running)
    return task


@router.post("/test")
async def test_task(
    *,
    db: Session = Depends(get_db),
    task_in: schemas.TaskTest,
    current_user: sql_models.User = Depends(deps.get_current_user)
) -> Any:
    # 1. Get Model Config
    model_config = db.query(sql_models.ModelConfig).filter(sql_models.ModelConfig.id == task_in.model_config_id).first()
    if not model_config:
        raise HTTPException(status_code=404, detail="Model configuration not found")
    # 2. Run the pipeline without persisting results. A test run has no real task row,
    #    so saving logs with a dummy task_id would violate the TaskLog foreign key;
    #    run_analysis_pipeline accepts save_to_db to skip the DB write for this case.
    #    Per the requirement, test output goes to the console logs only.
    results = await run_analysis_pipeline(
        task_id=0,  # Dummy ID, won't be used if save_to_db=False
        camera_ids=task_in.camera_ids,
        model_config=model_config,
        rules=task_in.rules,
        save_to_db=False  # Don't save test runs to DB to avoid polluting logs/FK issues
    )
    return {"status": "success", "results": results}