# backup_service.py
  1. import os
  2. import shutil
  3. import zipfile
  4. import pandas as pd
  5. import io
  6. import csv
  7. from datetime import datetime
  8. from typing import List, Dict, Any
  9. from sqlalchemy.orm import Session
  10. from sqlalchemy import inspect
  11. from fastapi import HTTPException
  12. from app.core.database import SessionLocal
  13. from app.models.application import Application
  14. from app.models.mapping import AppUserMapping
  15. from app.models.user import User
  16. from app.models.backup import BackupRecord, BackupType, BackupSettings
  17. from app.core.scheduler import scheduler
  18. from app.core.security import verify_password
  19. from app.services.captcha_service import CaptchaService
  20. from app.services.log_service import LogService
  21. from app.schemas.operation_log import ActionType
  22. from apscheduler.triggers.cron import CronTrigger
  23. # Ensure backup directory exists relative to backend root
  24. BACKUP_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "backups")
  25. class BackupService:
  26. @staticmethod
  27. def ensure_backup_dir():
  28. if not os.path.exists(BACKUP_DIR):
  29. os.makedirs(BACKUP_DIR, exist_ok=True)
  30. @staticmethod
  31. def create_backup(db: Session, backup_type: BackupType = BackupType.MANUAL) -> BackupRecord:
  32. BackupService.ensure_backup_dir()
  33. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  34. base_filename = f"backup_{timestamp}"
  35. zip_filename = f"{base_filename}.zip"
  36. zip_filepath = os.path.join(BACKUP_DIR, zip_filename)
  37. # Temp dir for csvs
  38. temp_dir = os.path.join(BACKUP_DIR, f"temp_{timestamp}")
  39. os.makedirs(temp_dir, exist_ok=True)
  40. try:
  41. # 1. Export Applications
  42. apps = db.query(Application).all()
  43. apps_df = pd.read_sql(db.query(Application).statement, db.bind)
  44. apps_df.to_csv(os.path.join(temp_dir, "applications.csv"), index=False, encoding='utf-8-sig')
  45. # 2. Export Mappings (Separate by Application)
  46. mappings_dir = os.path.join(temp_dir, "mappings")
  47. os.makedirs(mappings_dir, exist_ok=True)
  48. for app in apps:
  49. # Use app_id (string) for filename, sanitized
  50. safe_app_id = "".join([c for c in app.app_id if c.isalnum() or c in ('-', '_')])
  51. if not safe_app_id:
  52. safe_app_id = f"app_{app.id}"
  53. app_mappings_query = db.query(AppUserMapping).filter(AppUserMapping.app_id == app.id)
  54. # Check if there are any mappings
  55. if app_mappings_query.count() > 0:
  56. app_mappings_df = pd.read_sql(app_mappings_query.statement, db.bind)
  57. filename = f"mappings_{safe_app_id}.csv"
  58. app_mappings_df.to_csv(os.path.join(mappings_dir, filename), index=False, encoding='utf-8-sig')
  59. # 3. Export Users
  60. users_df = pd.read_sql(db.query(User).statement, db.bind)
  61. users_df.to_csv(os.path.join(temp_dir, "users.csv"), index=False, encoding='utf-8-sig')
  62. # 4. Zip
  63. with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
  64. for root, dirs, files in os.walk(temp_dir):
  65. for file in files:
  66. file_path = os.path.join(root, file)
  67. zipf.write(file_path, file.replace(temp_dir, "").lstrip(os.sep))
  68. # Get file size
  69. file_size = os.path.getsize(zip_filepath)
  70. # Record
  71. backup_record = BackupRecord(
  72. filename=zip_filename,
  73. file_path=zip_filepath,
  74. backup_type=backup_type,
  75. content_types="users,applications,mappings",
  76. file_size=file_size
  77. )
  78. db.add(backup_record)
  79. db.commit()
  80. db.refresh(backup_record)
  81. return backup_record
  82. except Exception as e:
  83. # Cleanup zip if failed
  84. if os.path.exists(zip_filepath):
  85. os.remove(zip_filepath)
  86. raise e
  87. finally:
  88. # Cleanup temp
  89. if os.path.exists(temp_dir):
  90. shutil.rmtree(temp_dir)
  91. @staticmethod
  92. def get_settings(db: Session) -> BackupSettings:
  93. settings = db.query(BackupSettings).first()
  94. if not settings:
  95. settings = BackupSettings(auto_backup_enabled=False, backup_time="02:00")
  96. db.add(settings)
  97. db.commit()
  98. db.refresh(settings)
  99. return settings
  100. @staticmethod
  101. def update_settings(db: Session, auto_backup_enabled: bool, backup_time: str):
  102. settings = BackupService.get_settings(db)
  103. settings.auto_backup_enabled = auto_backup_enabled
  104. settings.backup_time = backup_time
  105. settings.updated_at = datetime.now()
  106. db.commit()
  107. db.refresh(settings)
  108. # Update Scheduler
  109. BackupService.configure_scheduler(settings)
  110. return settings
  111. @staticmethod
  112. def configure_scheduler(settings: BackupSettings):
  113. job_id = "auto_backup_job"
  114. if scheduler.get_job(job_id):
  115. scheduler.remove_job(job_id)
  116. if settings.auto_backup_enabled:
  117. try:
  118. hour, minute = settings.backup_time.split(":")
  119. trigger = CronTrigger(hour=int(hour), minute=int(minute))
  120. scheduler.add_job(
  121. BackupService.perform_auto_backup,
  122. trigger=trigger,
  123. id=job_id,
  124. replace_existing=True
  125. )
  126. except ValueError:
  127. # Handle invalid time format if necessary
  128. pass
  129. @staticmethod
  130. def perform_auto_backup():
  131. db = SessionLocal()
  132. try:
  133. BackupService.create_backup(db, BackupType.AUTO)
  134. # Update last_backup_at
  135. settings = BackupService.get_settings(db)
  136. settings.last_backup_at = datetime.now()
  137. db.commit()
  138. finally:
  139. db.close()
  140. @staticmethod
  141. def init_scheduler(db: Session):
  142. settings = BackupService.get_settings(db)
  143. BackupService.configure_scheduler(settings)
  144. # --- Restore Logic ---
  145. @staticmethod
  146. def get_model_columns(model):
  147. return [c.key for c in inspect(model).mapper.column_attrs]
  148. @staticmethod
  149. def preview_restore(db: Session, backup_id: int, restore_type: str):
  150. backup = db.query(BackupRecord).filter(BackupRecord.id == backup_id).first()
  151. if not backup or not os.path.exists(backup.file_path):
  152. raise HTTPException(status_code=404, detail="Backup file not found")
  153. csv_filename = ""
  154. model = None
  155. if restore_type == "APPLICATIONS":
  156. csv_filename = "applications.csv"
  157. model = Application
  158. elif restore_type == "USERS":
  159. csv_filename = "users.csv"
  160. model = User
  161. elif restore_type == "MAPPINGS":
  162. # For mappings, we just check the first file in mappings/ dir to get headers
  163. # Logic: list zip contents, find first file starting with mappings/
  164. model = AppUserMapping
  165. else:
  166. raise HTTPException(status_code=400, detail="Invalid restore type")
  167. db_columns = BackupService.get_model_columns(model)
  168. csv_headers = []
  169. try:
  170. with zipfile.ZipFile(backup.file_path, 'r') as zipf:
  171. target_file = None
  172. if restore_type == "MAPPINGS":
  173. for name in zipf.namelist():
  174. if name.startswith("mappings/") and name.endswith(".csv"):
  175. target_file = name
  176. break
  177. else:
  178. if csv_filename in zipf.namelist():
  179. target_file = csv_filename
  180. if not target_file:
  181. # It's possible the backup doesn't have this file (e.g. empty mappings)
  182. return {"csv_headers": [], "db_columns": db_columns}
  183. with zipf.open(target_file, 'r') as f:
  184. # zipf.open returns bytes, need text wrapper
  185. wrapper = io.TextIOWrapper(f, encoding='utf-8-sig')
  186. reader = csv.reader(wrapper)
  187. try:
  188. csv_headers = next(reader)
  189. except StopIteration:
  190. csv_headers = []
  191. except zipfile.BadZipFile:
  192. raise HTTPException(status_code=400, detail="Invalid backup file format")
  193. return {"csv_headers": csv_headers, "db_columns": db_columns}
    @staticmethod
    def restore_data(
        db: Session,
        current_user: User,
        backup_id: int,
        restore_type: str,
        field_mapping: Dict[str, str],
        password: str,
        captcha_id: str,
        captcha_code: str
    ) -> Dict[str, Any]:
        """Restore rows of one table type from a backup zip into the database.

        field_mapping maps CSV column name -> DB column name; only mapped
        columns are written. Rows with a mapped, non-empty 'id' are upserted
        (update if the id exists, insert otherwise); rows without an id are
        inserted blindly, which may create duplicates. A captcha and the
        operator's password are required as a destructive-action safeguard.

        Raises 400 on bad captcha/password or invalid restore_type, 404 when
        the backup record or its file is missing, and 500 (after rollback)
        on any failure during the restore itself.
        """
        # 1. Verification: captcha first, then the current user's password.
        if not CaptchaService.verify_captcha(captcha_id, captcha_code):
            raise HTTPException(status_code=400, detail="验证码错误")
        if not verify_password(password, current_user.password_hash):
            raise HTTPException(status_code=400, detail="密码错误")
        backup = db.query(BackupRecord).filter(BackupRecord.id == backup_id).first()
        if not backup or not os.path.exists(backup.file_path):
            raise HTTPException(status_code=404, detail="Backup file not found")
        # 2. Determine Model and Files to read from the archive.
        model = None
        target_files = []
        if restore_type == "APPLICATIONS":
            target_files = ["applications.csv"]
            model = Application
        elif restore_type == "USERS":
            target_files = ["users.csv"]
            model = User
        elif restore_type == "MAPPINGS":
            # Mappings are split across one CSV per application under mappings/.
            with zipfile.ZipFile(backup.file_path, 'r') as zipf:
                target_files = [name for name in zipf.namelist() if name.startswith("mappings/") and name.endswith(".csv")]
            model = AppUserMapping
        else:
            raise HTTPException(status_code=400, detail="Invalid restore type")
        # 3. Process Restore
        restored_count = 0
        try:
            with zipfile.ZipFile(backup.file_path, 'r') as zipf:
                for filename in target_files:
                    if filename not in zipf.namelist():
                        continue
                    with zipf.open(filename, 'r') as f:
                        # zipf.open returns bytes; wrap for text-mode CSV parsing.
                        # utf-8-sig strips the BOM written by the backup exporter.
                        wrapper = io.TextIOWrapper(f, encoding='utf-8-sig')
                        reader = csv.DictReader(wrapper)
                        for row in reader:
                            # Build the row dict from field_mapping
                            # ({ "csv_col": "db_col" }); unmapped CSV columns
                            # are dropped.
                            data = {}
                            for csv_col, db_col in field_mapping.items():
                                if csv_col in row and db_col:  # skip empty/None db_col targets
                                    val = row[csv_col]
                                    # CSV has no NULL: treat empty string as None.
                                    if val == "":
                                        val = None
                                    data[db_col] = val
                            # Upsert: if an 'id' was mapped and non-empty, update
                            # the existing row in place; otherwise insert new.
                            if 'id' in data and data['id']:
                                existing = db.query(model).filter(model.id == data['id']).first()
                                if existing:
                                    for k, v in data.items():
                                        setattr(existing, k, v)
                                else:
                                    obj = model(**data)
                                    db.add(obj)
                            else:
                                # No id mapped: plain insert. NOTE(review): this can
                                # create duplicates; mapping the id is recommended
                                # so relationships survive the restore.
                                obj = model(**data)
                                db.add(obj)
                            restored_count += 1
            db.commit()
            # Audit trail for the restore operation.
            LogService.create_log(
                db=db,
                operator_id=current_user.id,
                action_type=ActionType.UPDATE,  # no dedicated RESTORE action type exists
                details={
                    "event": "restore_data",
                    "type": restore_type,
                    "backup_id": backup_id,
                    "count": restored_count
                }
            )
            return {"message": f"Successfully restored {restored_count} records", "count": restored_count}
        except Exception as e:
            # Any failure rolls the whole restore back and surfaces as a 500.
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Restore failed: {str(e)}")