| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457 |
- import os
- import shutil
- import zipfile
- import pandas as pd
- import io
- import csv
- from datetime import datetime
- from typing import List, Dict, Any
- from sqlalchemy.orm import Session
- from sqlalchemy import inspect, Boolean, Integer
- from fastapi import HTTPException
- from app.core.database import SessionLocal
- from app.models.application import Application
- from app.models.mapping import AppUserMapping
- from app.models.user import User
- from app.models.backup import BackupRecord, BackupType, BackupSettings
- from app.core.scheduler import scheduler
- from app.core.security import verify_password
- from app.services.captcha_service import CaptchaService
- from app.services.sms_service import SmsService
- from app.services.log_service import LogService
- from app.schemas.operation_log import ActionType
- from apscheduler.triggers.cron import CronTrigger
# Ensure backup directory exists relative to backend root.
# Three dirname() hops from this module's absolute path climb
# services -> app -> backend root, then append "backups" — assumes this
# file lives at <backend_root>/app/services/; verify if the layout changes.
BACKUP_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "backups")
class BackupService:
    @staticmethod
    def ensure_backup_dir():
        """Create BACKUP_DIR (and parents) if it does not already exist.

        `os.makedirs(..., exist_ok=True)` already tolerates an existing
        directory, so the previous `os.path.exists()` pre-check was
        redundant and introduced a needless check-then-act race.
        """
        os.makedirs(BACKUP_DIR, exist_ok=True)
    @staticmethod
    def create_backup(db: Session, backup_type: BackupType = BackupType.MANUAL) -> BackupRecord:
        """Export users, applications and per-app mappings to CSV, zip them,
        and persist a BackupRecord row pointing at the archive.

        Archive layout (consumed by preview_restore / restore_data):
            applications.csv
            users.csv
            mappings/mappings_{safe_app_id}.csv   (one per app that has rows)

        The temp CSV directory is always removed; the zip is removed only
        if an exception occurred (the exception is then re-raised).
        """
        BackupService.ensure_backup_dir()
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_filename = f"backup_{timestamp}"
        zip_filename = f"{base_filename}.zip"
        zip_filepath = os.path.join(BACKUP_DIR, zip_filename)

        # Temp dir for csvs; timestamped so concurrent backups don't collide.
        temp_dir = os.path.join(BACKUP_DIR, f"temp_{timestamp}")
        os.makedirs(temp_dir, exist_ok=True)

        try:
            # 1. Export Applications.
            # apps is kept as ORM objects for the per-app loop below;
            # the DataFrame export runs the same query through pandas.
            apps = db.query(Application).all()
            apps_df = pd.read_sql(db.query(Application).statement, db.bind)
            # utf-8-sig BOM so the CSVs open correctly in Excel.
            apps_df.to_csv(os.path.join(temp_dir, "applications.csv"), index=False, encoding='utf-8-sig')

            # 2. Export Mappings (Separate by Application)
            mappings_dir = os.path.join(temp_dir, "mappings")
            os.makedirs(mappings_dir, exist_ok=True)

            for app in apps:
                # Use app_id (string) for filename, sanitized to alnum/-/_
                # so it is always a safe filename component.
                safe_app_id = "".join([c for c in app.app_id if c.isalnum() or c in ('-', '_')])
                if not safe_app_id:
                    # Fall back to the numeric PK if sanitization stripped everything.
                    safe_app_id = f"app_{app.id}"

                # NOTE(review): filter compares AppUserMapping.app_id against the
                # numeric app.id (not the string app.app_id) — presumably the FK
                # column; confirm against the model definitions.
                app_mappings_query = db.query(AppUserMapping).filter(AppUserMapping.app_id == app.id)
                # Only emit a CSV when there is at least one mapping row.
                if app_mappings_query.count() > 0:
                    app_mappings_df = pd.read_sql(app_mappings_query.statement, db.bind)
                    filename = f"mappings_{safe_app_id}.csv"
                    app_mappings_df.to_csv(os.path.join(mappings_dir, filename), index=False, encoding='utf-8-sig')
            # 3. Export Users
            users_df = pd.read_sql(db.query(User).statement, db.bind)
            users_df.to_csv(os.path.join(temp_dir, "users.csv"), index=False, encoding='utf-8-sig')

            # 4. Zip everything under temp_dir, preserving relative layout.
            with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
                for root, dirs, files in os.walk(temp_dir):
                    for file in files:
                        file_path = os.path.join(root, file)
                        # Forward slashes in arcnames keep the archive portable
                        # across platforms (zip spec uses '/').
                        arcname = os.path.relpath(file_path, temp_dir).replace(os.path.sep, '/')
                        zipf.write(file_path, arcname)

            # Get file size for the record.
            file_size = os.path.getsize(zip_filepath)

            # Record the backup in the database.
            backup_record = BackupRecord(
                filename=zip_filename,
                file_path=zip_filepath,
                backup_type=backup_type,
                content_types="users,applications,mappings",
                file_size=file_size
            )
            db.add(backup_record)
            db.commit()
            db.refresh(backup_record)

            return backup_record

        except Exception as e:
            # Remove a partially-written archive so a broken zip is never
            # left on disk without a matching BackupRecord.
            if os.path.exists(zip_filepath):
                os.remove(zip_filepath)
            raise e

        finally:
            # Temp CSVs are intermediate artifacts; always clean them up.
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
- @staticmethod
- def save_uploaded_backup(db: Session, file_content: bytes, filename: str) -> BackupRecord:
- BackupService.ensure_backup_dir()
-
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- # Sanitize filename
- safe_filename = "".join([c for c in filename if c.isalnum() or c in ('.', '-', '_')])
- if not safe_filename.endswith('.zip'):
- safe_filename += '.zip'
-
- # avoid overwrite
- base, ext = os.path.splitext(safe_filename)
- final_filename = f"{base}_{timestamp}{ext}"
-
- file_path = os.path.join(BACKUP_DIR, final_filename)
-
- with open(file_path, "wb") as f:
- f.write(file_content)
-
- file_size = os.path.getsize(file_path)
-
- # Determine content types (optional)
- content_types = "users,applications,mappings"
-
- backup_record = BackupRecord(
- filename=final_filename,
- file_path=file_path,
- backup_type=BackupType.MANUAL,
- content_types=content_types,
- file_size=file_size
- )
- db.add(backup_record)
- db.commit()
- db.refresh(backup_record)
-
- return backup_record
- @staticmethod
- def get_settings(db: Session) -> BackupSettings:
- settings = db.query(BackupSettings).first()
- if not settings:
- settings = BackupSettings(auto_backup_enabled=False, backup_time="02:00")
- db.add(settings)
- db.commit()
- db.refresh(settings)
- return settings
- @staticmethod
- def update_settings(db: Session, auto_backup_enabled: bool, backup_time: str):
- settings = BackupService.get_settings(db)
- settings.auto_backup_enabled = auto_backup_enabled
- settings.backup_time = backup_time
- settings.updated_at = datetime.now()
- db.commit()
- db.refresh(settings)
-
- # Update Scheduler
- BackupService.configure_scheduler(settings)
-
- return settings
- @staticmethod
- def configure_scheduler(settings: BackupSettings):
- job_id = "auto_backup_job"
- if scheduler.get_job(job_id):
- scheduler.remove_job(job_id)
-
- if settings.auto_backup_enabled:
- try:
- hour, minute = settings.backup_time.split(":")
- trigger = CronTrigger(hour=int(hour), minute=int(minute))
- scheduler.add_job(
- BackupService.perform_auto_backup,
- trigger=trigger,
- id=job_id,
- replace_existing=True
- )
- except ValueError:
- # Handle invalid time format if necessary
- pass
- @staticmethod
- def perform_auto_backup():
- db = SessionLocal()
- try:
- BackupService.create_backup(db, BackupType.AUTO)
- # Update last_backup_at
- settings = BackupService.get_settings(db)
- settings.last_backup_at = datetime.now()
- db.commit()
- finally:
- db.close()
-
- @staticmethod
- def init_scheduler(db: Session):
- settings = BackupService.get_settings(db)
- BackupService.configure_scheduler(settings)
- # --- Restore Logic ---
- @staticmethod
- def get_model_columns(model):
- return [c.key for c in inspect(model).mapper.column_attrs]
    @staticmethod
    def preview_restore(db: Session, backup_id: int, restore_type: str):
        """Inspect a backup archive and report what a restore would map.

        Returns {"csv_headers": [...], "db_columns": [...], "mapping_files": [...]}
        so the UI can let the user align CSV columns to model columns.
        For MAPPINGS, mapping_files lists every per-app CSV found and the
        headers come from the first such file.

        Raises 404 if the record or file is missing, 400 for an unknown
        restore_type or a corrupt zip.
        """
        backup = db.query(BackupRecord).filter(BackupRecord.id == backup_id).first()
        if not backup or not os.path.exists(backup.file_path):
            raise HTTPException(status_code=404, detail="Backup file not found")
        csv_filename = ""
        model = None

        if restore_type == "APPLICATIONS":
            csv_filename = "applications.csv"
            model = Application
        elif restore_type == "USERS":
            csv_filename = "users.csv"
            model = User
        elif restore_type == "MAPPINGS":
            # For mappings there is no single CSV; we scan the archive below
            # and use the first mapping file found just to read its headers.
            model = AppUserMapping
        else:
            raise HTTPException(status_code=400, detail="Invalid restore type")
        db_columns = BackupService.get_model_columns(model)
        csv_headers = []
        mapping_files = []  # Only populated for the MAPPINGS type.
        try:
            with zipfile.ZipFile(backup.file_path, 'r') as zipf:
                target_file = None

                if restore_type == "MAPPINGS":
                    # Accept both layouts: nested "mappings/xxx.csv" (current
                    # create_backup output) and flat "mappings_xxx.csv"
                    # (legacy / hand-made archives).
                    for name in zipf.namelist():
                        if (name.startswith("mappings/") and name.endswith(".csv")) or \
                           (name.startswith("mappings_") and name.endswith(".csv") and "/" not in name):

                            # Recover the app id for display from the filename,
                            # which create_backup writes as mappings_{safe_app_id}.csv.
                            app_id_display = "Unknown"
                            filename_base = os.path.basename(name)
                            if filename_base.startswith("mappings_"):
                                # safe_app_id was sanitized to alnum/-/_ at backup time,
                                # so this is a display approximation of the real app_id.
                                app_id_display = filename_base.replace("mappings_", "").replace(".csv", "")

                            mapping_files.append({
                                "filename": name,
                                "app_id": app_id_display
                            })

                            # Use the first file found as the header source.
                            if not target_file:
                                target_file = name
                else:
                    if csv_filename in zipf.namelist():
                        target_file = csv_filename

                if not target_file:
                    # The backup may legitimately lack this file (e.g. no mappings
                    # existed at backup time) — report empty rather than erroring.
                    return {"csv_headers": [], "db_columns": db_columns, "mapping_files": []}
                with zipf.open(target_file, 'r') as f:
                    # zipf.open yields a binary stream; wrap for text/CSV reading.
                    # utf-8-sig strips the BOM written by create_backup.
                    wrapper = io.TextIOWrapper(f, encoding='utf-8-sig')
                    reader = csv.reader(wrapper)
                    try:
                        csv_headers = next(reader)
                    except StopIteration:
                        # Completely empty CSV: no header row.
                        csv_headers = []
        except zipfile.BadZipFile:
            raise HTTPException(status_code=400, detail="Invalid backup file format")
        return {"csv_headers": csv_headers, "db_columns": db_columns, "mapping_files": mapping_files}
-
- @staticmethod
- def send_restore_sms(captcha_id: str, captcha_code: str, user: User):
- # 1. Verify Captcha
- if not CaptchaService.verify_captcha(captcha_id, captcha_code):
- raise HTTPException(status_code=400, detail="图形验证码错误")
-
- # 2. Send SMS to current user
- try:
- SmsService.send_code(user.mobile)
- except Exception as e:
- if hasattr(e, "detail"):
- raise e
- raise HTTPException(status_code=400, detail="发送短信失败")
    @staticmethod
    def restore_data(
        db: Session,
        current_user: User,
        backup_id: int,
        restore_type: str,
        field_mapping: Dict[str, str],
        password: str,
        sms_code: str,
        selected_files: List[str] = None
    ):
        """Restore rows from a backup archive into the database.

        Flow: (1) verify the operator's password and SMS code, (2) resolve
        which CSV files in the archive to read for *restore_type*
        (APPLICATIONS / USERS / MAPPINGS), (3) for each CSV row, remap
        columns via field_mapping {csv_col: db_col}, coerce values to the
        model's column types, and upsert by primary key. The whole restore
        is one transaction: any failure rolls back and raises a 500.

        selected_files (MAPPINGS only) restricts which per-app mapping CSVs
        are restored; None means all of them.
        """
        # 1. Verification
        # Verify the operator's password.
        if not verify_password(password, current_user.password_hash):
            raise HTTPException(status_code=400, detail="密码错误")

        # Verify the SMS code sent via send_restore_sms.
        if not SmsService.verify_code(current_user.mobile, sms_code):
            raise HTTPException(status_code=400, detail="短信验证码错误或已过期")
        backup = db.query(BackupRecord).filter(BackupRecord.id == backup_id).first()
        if not backup or not os.path.exists(backup.file_path):
            raise HTTPException(status_code=404, detail="Backup file not found")
        # 2. Determine Model and Files
        model = None
        target_files = []

        if restore_type == "APPLICATIONS":
            target_files = ["applications.csv"]
            model = Application
        elif restore_type == "USERS":
            target_files = ["users.csv"]
            model = User
        elif restore_type == "MAPPINGS":
            with zipfile.ZipFile(backup.file_path, 'r') as zipf:
                # Find all mapping files — both nested ("mappings/x.csv",
                # current layout) and flat ("mappings_x.csv", legacy layout).
                all_mapping_files = []
                for name in zipf.namelist():
                    if (name.startswith("mappings/") and name.endswith(".csv")) or \
                       (name.startswith("mappings_") and name.endswith(".csv") and "/" not in name):
                        all_mapping_files.append(name)

                # Restrict to the caller's selection when one was provided.
                if selected_files:
                    target_files = [f for f in all_mapping_files if f in selected_files]
                else:
                    target_files = all_mapping_files

            model = AppUserMapping
        else:
            raise HTTPException(status_code=400, detail="Invalid restore type")
        # 3. Process Restore
        restored_count = 0
        try:
            # Column metadata drives the per-value type coercion below.
            mapper = inspect(model)
            columns = mapper.columns
            with zipfile.ZipFile(backup.file_path, 'r') as zipf:
                for filename in target_files:
                    # Resolve the actual archive member: a selected nested
                    # mapping path may only exist flat in older archives.
                    actual_filename = filename
                    if filename not in zipf.namelist():
                        # Try the flat name if it's a nested mapping path.
                        if restore_type == "MAPPINGS" and filename.startswith("mappings/") and "/" in filename:
                            flat_name = filename.split("/")[-1]
                            if flat_name in zipf.namelist():
                                actual_filename = flat_name
                            else:
                                continue  # Missing in both layouts: skip this file.
                        else:
                            continue

                    with zipf.open(actual_filename, 'r') as f:
                        # Binary stream from the zip; utf-8-sig strips the BOM
                        # written at backup time.
                        wrapper = io.TextIOWrapper(f, encoding='utf-8-sig')
                        # DictReader keys rows by the CSV header; values are then
                        # remapped to DB columns via field_mapping.
                        reader = csv.DictReader(wrapper)

                        for row in reader:
                            # Build the column dict for this row based on the
                            # user-chosen mapping: { "csv_col": "db_col" }.
                            data = {}
                            for csv_col, db_col in field_mapping.items():
                                if csv_col in row and db_col:  # skip unmapped (empty/None) targets
                                    val = row[csv_col]

                                    # Type coercion from CSV strings to the
                                    # declared SQLAlchemy column type.
                                    if db_col in columns:
                                        col_type = columns[db_col].type

                                        # Booleans: accept common textual forms;
                                        # empty string means NULL.
                                        if isinstance(col_type, Boolean):
                                            if str(val).lower() in ('true', '1', 't', 'yes'):
                                                val = True
                                            elif str(val).lower() in ('false', '0', 'f', 'no'):
                                                val = False
                                            else:
                                                if val == "": val = None

                                        # Integers: empty -> NULL; unparseable
                                        # values pass through unchanged (the DB
                                        # layer will reject them if invalid).
                                        elif isinstance(col_type, Integer):
                                            if val == "":
                                                val = None
                                            else:
                                                try:
                                                    val = int(val)
                                                except ValueError:
                                                    pass  # keep the raw string
                                        # All other types: empty string -> NULL.
                                        elif val == "":
                                            val = None

                                    data[db_col] = val

                            # Upsert by primary key: update the existing row when
                            # an 'id' was mapped and matches, otherwise insert.
                            # NOTE(review): rows without a mapped id are always
                            # inserted, which can create duplicates on repeated
                            # restores — presumably acceptable; confirm.
                            if 'id' in data and data['id']:
                                existing = db.query(model).filter(model.id == data['id']).first()
                                if existing:
                                    for k, v in data.items():
                                        setattr(existing, k, v)
                                else:
                                    obj = model(**data)
                                    db.add(obj)
                            else:
                                obj = model(**data)
                                db.add(obj)

                            restored_count += 1

            # Single commit: all files restore atomically.
            db.commit()

            # Audit-log the restore under the generic UPDATE action type.
            LogService.create_log(
                db=db,
                operator_id=current_user.id,
                action_type=ActionType.UPDATE,  # no dedicated RESTORE action type
                details={
                    "event": "restore_data",
                    "type": restore_type,
                    "backup_id": backup_id,
                    "count": restored_count
                }
            )

            return {"message": f"Successfully restored {restored_count} records", "count": restored_count}
        except Exception as e:
            # Roll back the whole transaction so a partial restore never persists.
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Restore failed: {str(e)}")
|