Przeglądaj źródła

新增任务、模型 导出导出功能。新增wsl docker 配置

liuq 3 miesięcy temu
rodzic
commit
90a38858b9

+ 123 - 1
backend/app/api/endpoints/models_api.py

@@ -1,11 +1,14 @@
 from typing import Any, List
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
+from fastapi.responses import Response
 from sqlalchemy.orm import Session
 from backend.app.api import deps
 from backend.app.core.database import get_db
 from backend.app.models import sql_models
 from backend.app.schemas import schemas
 from backend.app.services.llm_agent import test_connection
+import openpyxl
+import io
 
 router = APIRouter()
 
@@ -18,6 +21,125 @@ def read_models(
 ) -> Any:
     return db.query(sql_models.ModelConfig).offset(skip).limit(limit).all()
 
+@router.get("/export_template")
+def export_template(
+    db: Session = Depends(get_db),
+    current_user: sql_models.User = Depends(deps.get_current_user)
+) -> Any:
+    """
+    Export model configuration template (including current data)
+    """
+    wb = openpyxl.Workbook()
+    ws = wb.active
+    ws.title = "Models"
+    
+    # Headers
+    headers = ["Name", "Base URL", "API Key", "Model Name"]
+    ws.append(headers)
+    
+    # Data
+    models = db.query(sql_models.ModelConfig).all()
+    for model in models:
+        ws.append([model.name, model.base_url, model.api_key, model.model_name])
+        
+    # Adjust column width
+    ws.column_dimensions['A'].width = 20
+    ws.column_dimensions['B'].width = 40
+    ws.column_dimensions['C'].width = 40
+    ws.column_dimensions['D'].width = 20
+    
+    # Save to buffer
+    buffer = io.BytesIO()
+    wb.save(buffer)
+    buffer.seek(0)
+    
+    return Response(
+        content=buffer.getvalue(),
+        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+        headers={"Content-Disposition": "attachment; filename=models_template.xlsx"}
+    )
+
+@router.post("/import_template")
+async def import_template(
+    file: UploadFile = File(...),
+    db: Session = Depends(get_db),
+    current_user: sql_models.User = Depends(deps.get_current_user)
+) -> Any:
+    """
+    Import models from template
+    """
+    if not file.filename.endswith(('.xlsx', '.xls')):
+        raise HTTPException(status_code=400, detail="Invalid file format. Please upload an Excel file.")
+    
+    try:
+        contents = await file.read()
+        wb = openpyxl.load_workbook(io.BytesIO(contents))
+        ws = wb.active
+        
+        # Read headers
+        rows = list(ws.rows)
+        if not rows:
+            return {"message": "Empty file"}
+            
+        header_row = rows[0]
+        headers = [cell.value for cell in header_row]
+        
+        # Validate headers
+        required_headers = ["Name", "Base URL", "API Key", "Model Name"]
+        if not all(h in headers for h in required_headers):
+            raise HTTPException(status_code=400, detail=f"Invalid template format. Required columns: {', '.join(required_headers)}")
+            
+        name_idx = headers.index("Name")
+        url_idx = headers.index("Base URL")
+        key_idx = headers.index("API Key")
+        model_name_idx = headers.index("Model Name")
+        
+        added_count = 0
+        updated_count = 0
+        
+        # Process data
+        for row in rows[1:]:
+            name = row[name_idx].value
+            base_url = row[url_idx].value
+            api_key = row[key_idx].value
+            model_name = row[model_name_idx].value
+            
+            if not name or not base_url or not model_name:
+                continue
+                
+            # Check if model exists by Name
+            existing_model = db.query(sql_models.ModelConfig).filter(sql_models.ModelConfig.name == name).first()
+            
+            if existing_model:
+                # Update details
+                existing_model.base_url = base_url
+                existing_model.api_key = api_key
+                existing_model.model_name = model_name
+                updated_count += 1
+            else:
+                # Create new model
+                new_model = sql_models.ModelConfig(
+                    name=name,
+                    base_url=base_url,
+                    api_key=api_key,
+                    model_name=model_name
+                )
+                db.add(new_model)
+                added_count += 1
+                
+        db.commit()
+        
+        return {
+            "message": "Import successful",
+            "added": added_count,
+            "updated": updated_count
+        }
+        
+    except HTTPException as e:
+        raise e
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}")
+
 @router.post("", response_model=schemas.ModelConfig)
 def create_model(
     *,

+ 179 - 1
backend/app/api/endpoints/tasks.py

@@ -1,11 +1,15 @@
 from typing import Any, List
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
+from fastapi.responses import Response
 from sqlalchemy.orm import Session
 from backend.app.api import deps
 from backend.app.core.database import get_db
 from backend.app.models import sql_models
 from backend.app.schemas import schemas
 from backend.app.services.scheduler import update_job, run_analysis_pipeline
+import openpyxl
+import io
+import json
 
 router = APIRouter()
 
@@ -18,6 +22,180 @@ def read_tasks(
 ) -> Any:
     return db.query(sql_models.Task).offset(skip).limit(limit).all()
 
+@router.get("/export_template")
+def export_template(
+    db: Session = Depends(get_db),
+    current_user: sql_models.User = Depends(deps.get_current_user)
+) -> Any:
+    """
+    Export task configuration template (including current data)
+    """
+    wb = openpyxl.Workbook()
+    ws = wb.active
+    ws.title = "Tasks"
+    
+    # Headers
+    headers = ["Task Name", "Model Config Name", "Camera Names", "Cron Expression", "Rules (JSON)"]
+    ws.append(headers)
+    
+    # Data
+    tasks = db.query(sql_models.Task).all()
+    
+    # Pre-fetch cameras for name lookup (optimization)
+    all_cameras = {c.id: c.name for c in db.query(sql_models.Camera).all()}
+    
+    for task in tasks:
+        # Get Model Name
+        model_name = task.model_config.name if task.model_config else ""
+        
+        # Get Camera Names
+        cam_names = []
+        if task.camera_ids:
+            for cid in task.camera_ids:
+                if cid in all_cameras:
+                    cam_names.append(all_cameras[cid])
+        camera_names_str = ",".join(cam_names)
+        
+        # Rules to JSON
+        rules_str = json.dumps(task.rules, ensure_ascii=False) if task.rules else "[]"
+        
+        ws.append([task.name, model_name, camera_names_str, task.cron_expression, rules_str])
+        
+    # Adjust column width
+    ws.column_dimensions['A'].width = 25
+    ws.column_dimensions['B'].width = 25
+    ws.column_dimensions['C'].width = 40
+    ws.column_dimensions['D'].width = 20
+    ws.column_dimensions['E'].width = 50
+    
+    # Save to buffer
+    buffer = io.BytesIO()
+    wb.save(buffer)
+    buffer.seek(0)
+    
+    return Response(
+        content=buffer.getvalue(),
+        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+        headers={"Content-Disposition": "attachment; filename=tasks_template.xlsx"}
+    )
+
+@router.post("/import_template")
+async def import_template(
+    file: UploadFile = File(...),
+    db: Session = Depends(get_db),
+    current_user: sql_models.User = Depends(deps.get_current_user)
+) -> Any:
+    """
+    Import tasks from template
+    """
+    if not file.filename.endswith(('.xlsx', '.xls')):
+        raise HTTPException(status_code=400, detail="Invalid file format. Please upload an Excel file.")
+    
+    try:
+        contents = await file.read()
+        wb = openpyxl.load_workbook(io.BytesIO(contents))
+        ws = wb.active
+        
+        # Read headers
+        rows = list(ws.rows)
+        if not rows:
+            return {"message": "Empty file"}
+            
+        header_row = rows[0]
+        headers = [cell.value for cell in header_row]
+        
+        # Validate headers
+        required_headers = ["Task Name", "Model Config Name", "Camera Names", "Cron Expression", "Rules (JSON)"]
+        if not all(h in headers for h in required_headers):
+            raise HTTPException(status_code=400, detail=f"Invalid template format. Required columns: {', '.join(required_headers)}")
+            
+        name_idx = headers.index("Task Name")
+        model_idx = headers.index("Model Config Name")
+        cams_idx = headers.index("Camera Names")
+        cron_idx = headers.index("Cron Expression")
+        rules_idx = headers.index("Rules (JSON)")
+        
+        added_count = 0
+        updated_count = 0
+        
+        # Pre-fetch lookup maps
+        # Model Configs: Name -> ID
+        model_map = {m.name: m.id for m in db.query(sql_models.ModelConfig).all()}
+        # Cameras: Name -> ID
+        camera_map = {c.name: c.id for c in db.query(sql_models.Camera).all()}
+        
+        # Process data
+        for row in rows[1:]:
+            task_name = row[name_idx].value
+            model_config_name = row[model_idx].value
+            camera_names_str = row[cams_idx].value
+            cron_expr = row[cron_idx].value
+            rules_str = row[rules_idx].value
+            
+            if not task_name:
+                continue
+                
+            # Resolve Model ID
+            model_config_id = model_map.get(model_config_name)
+            if not model_config_id:
+                # Optionally skip or log warning. Here we skip.
+                continue
+                
+            # Resolve Camera IDs
+            camera_ids = []
+            if camera_names_str:
+                cam_names = [c.strip() for c in str(camera_names_str).split(",")]
+                for cname in cam_names:
+                    if cname in camera_map:
+                        camera_ids.append(camera_map[cname])
+            
+            # Parse Rules
+            try:
+                rules = json.loads(rules_str) if rules_str else []
+            except:
+                rules = []
+                
+            # Check if task exists
+            existing_task = db.query(sql_models.Task).filter(sql_models.Task.name == task_name).first()
+            
+            if existing_task:
+                # Update
+                existing_task.model_config_id = model_config_id
+                existing_task.camera_ids = camera_ids
+                existing_task.rules = rules
+                existing_task.cron_expression = cron_expr
+                
+                # If task is running, update scheduler
+                if existing_task.is_running:
+                     update_job(existing_task.id, existing_task.cron_expression, True)
+                     
+                updated_count += 1
+            else:
+                # Create
+                new_task = sql_models.Task(
+                    name=task_name,
+                    model_config_id=model_config_id,
+                    camera_ids=camera_ids,
+                    rules=rules,
+                    cron_expression=cron_expr,
+                    is_running=False
+                )
+                db.add(new_task)
+                added_count += 1
+                
+        db.commit()
+        
+        return {
+            "message": "Import successful",
+            "added": added_count,
+            "updated": updated_count
+        }
+        
+    except HTTPException as e:
+        raise e
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}")
+
 @router.post("", response_model=schemas.Task)
 def create_task(
     *,

+ 45 - 0
docker-compose.wsl.yml

@@ -0,0 +1,45 @@
+version: '3.8'
+
+services:
+  app:
+    image: ai-watch-platform
+    build:
+      context: .
+      network: host
+    container_name: ai_watch_app
+    restart: always
+    ports:
+      - "8000:8000"
+    environment:
+      - MYSQL_SERVER=db
+      - MYSQL_PORT=3306
+      - MYSQL_USER=root
+      - MYSQL_PASSWORD=root_password
+      - MYSQL_DB=ai_watch
+      - TZ=Asia/Shanghai
+    volumes:
+      # Reports volume
+      - ./reports:/app/reports
+      # Snapshots volume (nested inside the app structure)
+      - ./snapshots:/app/backend/app/static/snapshots
+    depends_on:
+      - db
+
+  db:
+    image: mysql:8.0
+    container_name: ai_watch_db
+    restart: always
+    environment:
+      - MYSQL_ROOT_PASSWORD=root_password
+      - MYSQL_DATABASE=ai_watch
+      - TZ=Asia/Shanghai
+    volumes:
+      # Use named volume to avoid permission issues on Windows/WSL with mounted drives
+      - db_data:/var/lib/mysql
+    ports:
+      - "3307:3306"
+    command: --default-authentication-plugin=mysql_native_password
+
+volumes:
+  db_data:
+

+ 52 - 2
frontend/src/views/Models.vue

@@ -1,6 +1,7 @@
 <script setup lang="ts">
-import { ref, onMounted } from 'vue'
-import api from '../api'
+import { ref, onMounted, computed } from 'vue'
+import api, { getBaseURL } from '../api'
+import { useAuthStore } from '../stores/auth'
 import { ElMessage, ElMessageBox } from 'element-plus'
 
 interface ModelConfig {
@@ -11,6 +12,7 @@ interface ModelConfig {
   model_name: string
 }
 
+const authStore = useAuthStore()
 const models = ref<ModelConfig[]>([])
 const dialogVisible = ref(false)
 const testing = ref(false)
@@ -24,11 +26,46 @@ const form = ref({
   model_name: 'gpt-3.5-turbo'
 })
 
+const uploadHeaders = computed(() => {
+  return {
+    Authorization: `Bearer ${authStore.token}`
+  }
+})
+
+const uploadUrl = computed(() => `${getBaseURL()}/models/import_template`)
+
 const fetchModels = async () => {
   const res = await api.get('/models')
   models.value = res.data
 }
 
+const handleExport = async () => {
+  try {
+    const response = await api.get('/models/export_template', {
+      responseType: 'blob'
+    })
+    const url = window.URL.createObjectURL(new Blob([response.data]))
+    const link = document.createElement('a')
+    link.href = url
+    link.setAttribute('download', 'models_template.xlsx')
+    document.body.appendChild(link)
+    link.click()
+    document.body.removeChild(link)
+  } catch (e) {
+    ElMessage.error('导出失败')
+  }
+}
+
+const handleImportSuccess = (response: any) => {
+  ElMessage.success(`导入成功: 新增 ${response.added} 条, 更新 ${response.updated} 条`)
+  fetchModels()
+}
+
+const handleImportError = (error: any) => {
+  const msg = JSON.parse(error.message).detail || '导入失败'
+  ElMessage.error(msg)
+}
+
 const resetForm = () => {
   form.value = {
     name: '',
@@ -117,6 +154,19 @@ onMounted(fetchModels)
   <div class="models-view">
     <div class="toolbar">
       <el-button type="primary" @click="handleAdd">添加模型配置</el-button>
+      <el-button type="success" @click="handleExport">导出模板</el-button>
+      <el-upload
+        class="upload-demo"
+        :action="uploadUrl"
+        :headers="uploadHeaders"
+        :show-file-list="false"
+        :on-success="handleImportSuccess"
+        :on-error="handleImportError"
+        accept=".xlsx,.xls"
+        style="display: inline-block; margin-left: 12px;"
+      >
+        <el-button type="warning">导入模板</el-button>
+      </el-upload>
     </div>
     <el-table :data="models" style="width: 100%">
       <el-table-column prop="id" label="ID" width="60" />

+ 52 - 2
frontend/src/views/Tasks.vue

@@ -1,6 +1,7 @@
 <script setup lang="ts">
-import { ref, onMounted } from 'vue'
-import api from '../api'
+import { ref, onMounted, computed } from 'vue'
+import api, { getBaseURL } from '../api'
+import { useAuthStore } from '../stores/auth'
 import { ElMessage, ElMessageBox } from 'element-plus'
 
 interface TaskRule {
@@ -18,6 +19,7 @@ interface Task {
   camera_ids: number[]
 }
 
+const authStore = useAuthStore()
 const tasks = ref<Task[]>([])
 const dialogVisible = ref(false)
 const isEdit = ref(false)
@@ -39,6 +41,14 @@ const form = ref({
 const cameras = ref<any[]>([])
 const models = ref<any[]>([])
 
+const uploadHeaders = computed(() => {
+  return {
+    Authorization: `Bearer ${authStore.token}`
+  }
+})
+
+const uploadUrl = computed(() => `${getBaseURL()}/tasks/import_template`)
+
 const fetchTasks = async () => {
   const res = await api.get('/tasks')
   tasks.value = res.data
@@ -53,6 +63,33 @@ const fetchData = async () => {
   models.value = modelRes.data
 }
 
+const handleExport = async () => {
+  try {
+    const response = await api.get('/tasks/export_template', {
+      responseType: 'blob'
+    })
+    const url = window.URL.createObjectURL(new Blob([response.data]))
+    const link = document.createElement('a')
+    link.href = url
+    link.setAttribute('download', 'tasks_template.xlsx')
+    document.body.appendChild(link)
+    link.click()
+    document.body.removeChild(link)
+  } catch (e) {
+    ElMessage.error('导出失败')
+  }
+}
+
+const handleImportSuccess = (response: any) => {
+  ElMessage.success(`导入成功: 新增 ${response.added} 条, 更新 ${response.updated} 条`)
+  fetchTasks()
+}
+
+const handleImportError = (error: any) => {
+  const msg = JSON.parse(error.message).detail || '导入失败'
+  ElMessage.error(msg)
+}
+
 const resetForm = () => {
   form.value = {
     name: '',
@@ -247,6 +284,19 @@ onMounted(() => {
   <div class="tasks-view">
     <div class="toolbar">
       <el-button type="primary" @click="handleAdd">新建任务</el-button>
+      <el-button type="success" @click="handleExport">导出模板</el-button>
+      <el-upload
+        class="upload-demo"
+        :action="uploadUrl"
+        :headers="uploadHeaders"
+        :show-file-list="false"
+        :on-success="handleImportSuccess"
+        :on-error="handleImportError"
+        accept=".xlsx,.xls"
+        style="display: inline-block; margin-left: 12px;"
+      >
+        <el-button type="warning">导入模板</el-button>
+      </el-upload>
     </div>
     <el-table :data="tasks" style="width: 100%">
       <el-table-column prop="id" label="ID" width="60" />