# backup_service.py
  1. import os
  2. import shutil
  3. import zipfile
  4. import pandas as pd
  5. import io
  6. import csv
  7. from datetime import datetime
  8. from typing import List, Dict, Any
  9. from sqlalchemy.orm import Session
  10. from sqlalchemy import inspect, Boolean, Integer
  11. from fastapi import HTTPException
  12. from app.core.database import SessionLocal
  13. from app.models.application import Application
  14. from app.models.mapping import AppUserMapping
  15. from app.models.user import User
  16. from app.models.backup import BackupRecord, BackupType, BackupSettings
  17. from app.core.scheduler import scheduler
  18. from app.core.security import verify_password
  19. from app.services.captcha_service import CaptchaService
  20. from app.services.sms_service import SmsService
  21. from app.services.log_service import LogService
  22. from app.schemas.operation_log import ActionType
  23. from apscheduler.triggers.cron import CronTrigger
  # Backups are stored in a "backups" folder three directory levels above this
  # file, which is expected to be the backend root.
  _SERVICE_FILE = os.path.abspath(__file__)
  _BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_SERVICE_FILE)))
  BACKUP_DIR = os.path.join(_BACKEND_ROOT, "backups")
  26. class BackupService:
  @staticmethod
  def ensure_backup_dir():
      """Create BACKUP_DIR (including missing parents) when the path does not exist yet."""
      if os.path.exists(BACKUP_DIR):
          # Something is already at that path; nothing to create.
          return
      os.makedirs(BACKUP_DIR, exist_ok=True)
  @staticmethod
  def create_backup(db: Session, backup_type: BackupType = BackupType.MANUAL) -> BackupRecord:
      """Export applications, mappings and users to CSV, zip them and record the backup.

      The CSV files are staged in a temporary directory under BACKUP_DIR and
      packed into ``backup_<timestamp>.zip``. The staging directory is removed
      whether or not the backup succeeds. Mappings are written one file per
      application as ``mappings/mappings_<sanitized app_id>.csv``; applications
      without mappings get no file.

      Args:
          db: Session used for the export queries and to persist the record.
          backup_type: Value stored on the resulting BackupRecord.

      Returns:
          The committed and refreshed BackupRecord describing the new archive.

      Raises:
          Exception: Any failure is re-raised unchanged after the session has
              been rolled back and a partially written zip has been deleted.
      """
      BackupService.ensure_backup_dir()
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
      zip_filename = f"backup_{timestamp}.zip"
      zip_filepath = os.path.join(BACKUP_DIR, zip_filename)

      # Temp dir for csvs
      temp_dir = os.path.join(BACKUP_DIR, f"temp_{timestamp}")
      os.makedirs(temp_dir, exist_ok=True)

      try:
          # 1. Export Applications
          apps = db.query(Application).all()
          apps_df = pd.read_sql(db.query(Application).statement, db.bind)
          apps_df.to_csv(os.path.join(temp_dir, "applications.csv"), index=False, encoding='utf-8-sig')

          # 2. Export Mappings (one CSV per application)
          mappings_dir = os.path.join(temp_dir, "mappings")
          os.makedirs(mappings_dir, exist_ok=True)
          # Lower-cased names already written. Sanitizing can map different
          # app_ids to the same name ("a.b" and "ab"), and some file systems
          # ignore case, so a clash would overwrite another app's mappings.
          taken_names = set()
          for app in apps:
              app_mappings_query = db.query(AppUserMapping).filter(AppUserMapping.app_id == app.id)
              app_mappings_df = pd.read_sql(app_mappings_query.statement, db.bind)
              if app_mappings_df.empty:
                  continue

              # Use app_id (string) for filename, sanitized
              safe_app_id = "".join(c for c in (app.app_id or "") if c.isalnum() or c in ('-', '_'))
              if not safe_app_id:
                  safe_app_id = f"app_{app.id}"
              # Disambiguate with the numeric primary key until the name is free.
              while safe_app_id.lower() in taken_names:
                  safe_app_id = f"{safe_app_id}_{app.id}"
              taken_names.add(safe_app_id.lower())

              filename = f"mappings_{safe_app_id}.csv"
              app_mappings_df.to_csv(os.path.join(mappings_dir, filename), index=False, encoding='utf-8-sig')

          # 3. Export Users
          users_df = pd.read_sql(db.query(User).statement, db.bind)
          users_df.to_csv(os.path.join(temp_dir, "users.csv"), index=False, encoding='utf-8-sig')

          # 4. Zip, storing paths relative to the staging dir with '/' separators
          with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
              for root, _dirs, files in os.walk(temp_dir):
                  for file in files:
                      file_path = os.path.join(root, file)
                      arcname = os.path.relpath(file_path, temp_dir).replace(os.path.sep, '/')
                      zipf.write(file_path, arcname)

          # 5. Record the archive
          backup_record = BackupRecord(
              filename=zip_filename,
              file_path=zip_filepath,
              backup_type=backup_type,
              content_types="users,applications,mappings",
              file_size=os.path.getsize(zip_filepath),
          )
          db.add(backup_record)
          db.commit()
      except Exception:
          # Leave the session usable for the caller and drop the partial zip.
          db.rollback()
          if os.path.exists(zip_filepath):
              os.remove(zip_filepath)
          raise
      finally:
          # Best-effort cleanup: a failure here must not hide the real error
          # or fail a backup that has already been committed.
          shutil.rmtree(temp_dir, ignore_errors=True)

      db.refresh(backup_record)
      return backup_record
  @staticmethod
  def save_uploaded_backup(db: Session, file_content: bytes, filename: str) -> BackupRecord:
      """Store an uploaded backup archive in BACKUP_DIR and record it.

      The stored name is ``<sanitized base>_<timestamp>.zip``. The base keeps
      only alphanumerics, '.', '-' and '_' from ``filename`` and falls back to
      "backup" when nothing usable remains.

      Args:
          db: Session used to persist the BackupRecord.
          file_content: Raw bytes of the uploaded archive.
          filename: Client-supplied file name (untrusted).

      Returns:
          The committed and refreshed BackupRecord for the stored file.

      Raises:
          HTTPException: 400 when the upload is not a zip archive.
          Exception: Write or database failures are re-raised after the
              session is rolled back and the written file is removed.
      """
      # Reject non-zip uploads up front instead of failing later at restore time.
      if not zipfile.is_zipfile(io.BytesIO(file_content or b"")):
          raise HTTPException(status_code=400, detail="Invalid backup file format")

      BackupService.ensure_backup_dir()
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

      # Sanitize filename
      safe_filename = "".join(c for c in (filename or "") if c.isalnum() or c in ('.', '-', '_'))
      # Drop a trailing ".zip" in any letter case. os.path.splitext() is not
      # used because it treats a bare ".zip" as a name with no extension.
      base = safe_filename[:-4] if safe_filename.lower().endswith('.zip') else safe_filename
      # No leading/trailing dots (hidden files) and never an empty base.
      base = base.strip('.') or "backup"

      # The timestamp suffix avoids overwriting an earlier upload of the same name.
      final_filename = f"{base}_{timestamp}.zip"
      file_path = os.path.join(BACKUP_DIR, final_filename)

      try:
          with open(file_path, "wb") as f:
              f.write(file_content)

          backup_record = BackupRecord(
              filename=final_filename,
              file_path=file_path,
              backup_type=BackupType.MANUAL,
              # Content is not inspected; assume a full backup.
              content_types="users,applications,mappings",
              file_size=os.path.getsize(file_path),
          )
          db.add(backup_record)
          db.commit()
      except Exception:
          # Do not leave an archive on disk that no record points at.
          db.rollback()
          if os.path.exists(file_path):
              os.remove(file_path)
          raise

      db.refresh(backup_record)
      return backup_record
  @staticmethod
  def get_settings(db: Session) -> BackupSettings:
      """Return the first BackupSettings row, creating a default one if none exists.

      The default row has auto backup disabled and a backup time of "02:00".
      """
      existing = db.query(BackupSettings).first()
      if existing:
          return existing

      defaults = BackupSettings(auto_backup_enabled=False, backup_time="02:00")
      db.add(defaults)
      db.commit()
      db.refresh(defaults)
      return defaults
  @staticmethod
  def update_settings(db: Session, auto_backup_enabled: bool, backup_time: str):
      """Persist the auto-backup settings and re-apply them to the scheduler.

      Args:
          db: Session used to load and save the settings row.
          auto_backup_enabled: Whether the scheduled backup job should exist.
          backup_time: Daily run time as "HH:MM".

      Returns:
          The refreshed BackupSettings row.

      Raises:
          HTTPException: 400 when auto backup is being enabled with a
              ``backup_time`` that is not a valid "HH:MM" value.
      """
      if auto_backup_enabled:
          # Same parsing as configure_scheduler. Without this check a bad value
          # is stored as "enabled" while no job is ever scheduled.
          try:
              hour_text, minute_text = backup_time.split(":")
              time_is_valid = 0 <= int(hour_text) <= 23 and 0 <= int(minute_text) <= 59
          except (AttributeError, ValueError):
              time_is_valid = False
          if not time_is_valid:
              raise HTTPException(status_code=400, detail="Invalid backup time, expected HH:MM")

      settings = BackupService.get_settings(db)
      settings.auto_backup_enabled = auto_backup_enabled
      settings.backup_time = backup_time
      settings.updated_at = datetime.now()
      db.commit()
      db.refresh(settings)

      # Update Scheduler
      BackupService.configure_scheduler(settings)
      return settings
  @staticmethod
  def configure_scheduler(settings: BackupSettings):
      """Replace the "auto_backup_job" scheduler job according to ``settings``.

      Any existing job is removed first. When auto backup is enabled, a daily
      cron job running ``BackupService.perform_auto_backup`` at
      ``settings.backup_time`` ("HH:MM") is added. An unparseable time leaves
      no job scheduled and is logged as a warning rather than raised, so a bad
      stored value does not break the caller.
      """
      import logging  # function-scope import: the module does not import logging

      job_id = "auto_backup_job"
      if scheduler.get_job(job_id):
          scheduler.remove_job(job_id)

      if not settings.auto_backup_enabled:
          return

      try:
          hour, minute = settings.backup_time.split(":")
          trigger = CronTrigger(hour=int(hour), minute=int(minute))
      except (AttributeError, TypeError, ValueError):
          # Best effort: leave the job unscheduled, but leave a trace.
          logging.getLogger(__name__).warning(
              "Auto backup is enabled but backup_time %r is not a valid HH:MM value; "
              "no backup job was scheduled",
              settings.backup_time,
          )
          return

      scheduler.add_job(
          BackupService.perform_auto_backup,
          trigger=trigger,
          id=job_id,
          replace_existing=True,
      )
  @staticmethod
  def perform_auto_backup():
      """Run an AUTO backup in its own session and stamp ``last_backup_at``.

      A fresh SessionLocal is opened for the run and is closed whether or not
      the backup succeeds.
      """
      session = SessionLocal()
      try:
          BackupService.create_backup(session, BackupType.AUTO)

          # Record when the automatic backup last completed.
          current_settings = BackupService.get_settings(session)
          current_settings.last_backup_at = datetime.now()
          session.commit()
      finally:
          session.close()
  @staticmethod
  def init_scheduler(db: Session):
      """Load the backup settings (creating defaults if needed) and apply them to the scheduler."""
      BackupService.configure_scheduler(BackupService.get_settings(db))
  175. # --- Restore Logic ---
  @staticmethod
  def get_model_columns(model):
      """Return the attribute keys of every column attribute mapped on ``model``."""
      mapper = inspect(model).mapper
      column_keys = []
      for attr in mapper.column_attrs:
          column_keys.append(attr.key)
      return column_keys
  @staticmethod
  def preview_restore(db: Session, backup_id: int, restore_type: str):
      """Return the CSV headers of a backup next to the target model's columns.

      Args:
          db: Session used to look up the BackupRecord.
          backup_id: Primary key of the backup to inspect.
          restore_type: "APPLICATIONS", "USERS" or "MAPPINGS".

      Returns:
          A dict with:
            * ``csv_headers``: header row of the relevant CSV. For MAPPINGS
              this is the first mapping file found. Empty when the file is
              missing or has no rows.
            * ``db_columns``: column attribute names of the target model.
            * ``mapping_files``: for MAPPINGS, one ``{"filename", "app_id"}``
              entry per mapping CSV in the archive; otherwise empty.

      Raises:
          HTTPException: 404 when the record or its file is missing, 400 for
              an unknown ``restore_type`` or an unreadable archive/CSV.
      """
      backup = db.query(BackupRecord).filter(BackupRecord.id == backup_id).first()
      if not backup or not os.path.exists(backup.file_path):
          raise HTTPException(status_code=404, detail="Backup file not found")

      single_file_targets = {
          "APPLICATIONS": ("applications.csv", Application),
          "USERS": ("users.csv", User),
      }
      if restore_type == "MAPPINGS":
          # Mappings are spread over several files; headers come from the first.
          csv_filename, model = None, AppUserMapping
      elif restore_type in single_file_targets:
          csv_filename, model = single_file_targets[restore_type]
      else:
          raise HTTPException(status_code=400, detail="Invalid restore type")

      db_columns = BackupService.get_model_columns(model)
      csv_headers = []
      mapping_files = []  # For MAPPINGS type
      prefix, suffix = "mappings_", ".csv"

      try:
          with zipfile.ZipFile(backup.file_path, 'r') as zipf:
              names = zipf.namelist()
              target_file = None

              if restore_type == "MAPPINGS":
                  for name in names:
                      nested = name.startswith("mappings/")
                      flat = name.startswith(prefix) and "/" not in name
                      if not name.endswith(suffix) or not (nested or flat):
                          continue

                      # File names look like mappings_{safe_app_id}.csv. Strip
                      # only the leading prefix and trailing extension, because
                      # the app id itself may contain "mappings_".
                      app_id_display = "Unknown"
                      filename_base = os.path.basename(name)
                      if filename_base.startswith(prefix):
                          app_id_display = filename_base[len(prefix):-len(suffix)]

                      mapping_files.append({
                          "filename": name,
                          "app_id": app_id_display
                      })
                      # Use the first one found to get headers
                      if not target_file:
                          target_file = name
              elif csv_filename in names:
                  target_file = csv_filename

              if not target_file:
                  # The backup may simply not contain this file (e.g. no mappings).
                  return {"csv_headers": [], "db_columns": db_columns, "mapping_files": []}

              with zipf.open(target_file, 'r') as f:
                  # zipf.open yields bytes; wrap it so csv can read text.
                  wrapper = io.TextIOWrapper(f, encoding='utf-8-sig')
                  csv_headers = next(csv.reader(wrapper), [])
      except (zipfile.BadZipFile, UnicodeDecodeError, csv.Error) as exc:
          raise HTTPException(status_code=400, detail="Invalid backup file format") from exc

      return {"csv_headers": csv_headers, "db_columns": db_columns, "mapping_files": mapping_files}
  @staticmethod
  def send_restore_sms(captcha_id: str, captcha_code: str, user: User):
      """Check the image captcha, then send an SMS code to ``user.mobile``.

      Raises:
          HTTPException: 400 when the captcha check fails, or when sending
              fails with an error that has no ``detail`` attribute. Errors
              that do carry ``detail`` are re-raised as they are.
      """
      # 1. Verify Captcha
      captcha_ok = CaptchaService.verify_captcha(captcha_id, captcha_code)
      if not captcha_ok:
          raise HTTPException(status_code=400, detail="图形验证码错误")

      # 2. Send SMS to current user
      try:
          SmsService.send_code(user.mobile)
      except Exception as err:
          if not hasattr(err, "detail"):
              raise HTTPException(status_code=400, detail="发送短信失败")
          # Already carries a client-facing detail; pass it through.
          raise err
  @staticmethod
  def _is_mapping_csv(name: str) -> bool:
      """Return True for a mapping CSV, nested (mappings/...) or flat (mappings_*.csv at the root)."""
      if not name.endswith(".csv"):
          return False
      return name.startswith("mappings/") or (name.startswith("mappings_") and "/" not in name)

  @staticmethod
  def _coerce_csv_value(col_type, val):
      """Convert a raw CSV cell into a value suited to the SQLAlchemy column type.

      Empty or missing cells become None. Boolean columns accept
      true/1/t/yes and false/0/f/no (case-insensitive). Integer columns are
      parsed with int(). Values that cannot be converted are returned as they
      are.
      """
      # csv.DictReader yields None for cells missing from a short row.
      if val is None or val == "":
          return None
      if isinstance(col_type, Boolean):
          lowered = str(val).lower()
          if lowered in ('true', '1', 't', 'yes'):
              return True
          if lowered in ('false', '0', 'f', 'no'):
              return False
          return val
      if isinstance(col_type, Integer):
          try:
              return int(val)
          except ValueError:
              return val  # Keep as is
      return val

  @staticmethod
  def restore_data(
      db: Session,
      current_user: User,
      backup_id: int,
      restore_type: str,
      field_mapping: Dict[str, str],
      password: str,
      sms_code: str,
      selected_files: List[str] = None
  ):
      """Restore rows from a backup archive into the database.

      Rows are upserted by ``id``: when a mapped ``id`` matches an existing
      row, that row's mapped attributes are overwritten; otherwise a new object
      is added. Everything is committed in one transaction and an operation
      log entry is written afterwards.

      Args:
          db: Session used for lookups and writes.
          current_user: Operator. Their password and SMS code are verified.
          backup_id: Primary key of the BackupRecord to restore from.
          restore_type: "APPLICATIONS", "USERS" or "MAPPINGS".
          field_mapping: ``{"csv_col": "db_col"}``. Entries with an empty
              ``db_col`` are ignored.
          password: Operator's password.
          sms_code: SMS code previously sent to the operator.
          selected_files: For MAPPINGS, archive member names to restore.
              Empty or None restores every mapping file.

      Returns:
          ``{"message": ..., "count": <rows processed>}``.

      Raises:
          HTTPException: 400 for a wrong password or SMS code, an unknown
              ``restore_type``, a ``field_mapping`` target that is not a
              column of the model, or an unreadable archive. 404 when the
              backup is missing. 500 for any other failure (after rollback).
      """
      # 1. Verification
      if not verify_password(password, current_user.password_hash):
          raise HTTPException(status_code=400, detail="密码错误")
      if not SmsService.verify_code(current_user.mobile, sms_code):
          raise HTTPException(status_code=400, detail="短信验证码错误或已过期")

      backup = db.query(BackupRecord).filter(BackupRecord.id == backup_id).first()
      if not backup or not os.path.exists(backup.file_path):
          raise HTTPException(status_code=404, detail="Backup file not found")

      # 2. Determine Model and Files
      if restore_type == "APPLICATIONS":
          model, target_files = Application, ["applications.csv"]
      elif restore_type == "USERS":
          model, target_files = User, ["users.csv"]
      elif restore_type == "MAPPINGS":
          model, target_files = AppUserMapping, []  # resolved from the archive below
      else:
          raise HTTPException(status_code=400, detail="Invalid restore type")

      # Column objects by attribute name, used for type conversion.
      columns = inspect(model).columns

      # field_mapping comes from the client. Allow only real column names so it
      # cannot set arbitrary attributes on existing rows or break model(**data).
      unknown_targets = sorted({
          str(db_col)
          for db_col in field_mapping.values()
          if db_col and (not isinstance(db_col, str) or db_col not in columns)
      })
      if unknown_targets:
          raise HTTPException(
              status_code=400,
              detail=f"Unknown target columns: {', '.join(unknown_targets)}"
          )

      # 3. Process Restore
      restored_count = 0
      try:
          with zipfile.ZipFile(backup.file_path, 'r') as zipf:
              archive_names = zipf.namelist()
              if restore_type == "MAPPINGS":
                  # All mapping files (nested and flat), in archive order.
                  target_files = [n for n in archive_names if BackupService._is_mapping_csv(n)]
                  if selected_files:
                      wanted = set(selected_files)
                      target_files = [n for n in target_files if n in wanted]
              else:
                  # A backup may lack the file; then there is nothing to restore.
                  present = set(archive_names)
                  target_files = [n for n in target_files if n in present]

              for filename in target_files:
                  with zipf.open(filename, 'r') as f:
                      reader = csv.DictReader(io.TextIOWrapper(f, encoding='utf-8-sig'))
                      for row in reader:
                          # Build the attribute dict from the mapped CSV columns.
                          data = {}
                          for csv_col, db_col in field_mapping.items():
                              if db_col and csv_col in row:
                                  data[db_col] = BackupService._coerce_csv_value(
                                      columns[db_col].type, row[csv_col]
                                  )
                          if not data:
                              # Nothing mapped for this row; do not insert an empty object.
                              continue

                          # Upsert: update the row with this id if there is one, else insert.
                          # Without a mapped id a new row is always added, which may
                          # create duplicates.
                          record_id = data.get('id')
                          existing = None
                          if record_id:
                              existing = db.query(model).filter(model.id == record_id).first()
                          if existing:
                              for key, value in data.items():
                                  setattr(existing, key, value)
                          else:
                              db.add(model(**data))
                          restored_count += 1

          db.commit()

          # Log Operation
          LogService.create_log(
              db=db,
              operator_id=current_user.id,
              action_type=ActionType.UPDATE,  # Using UPDATE generic for Restore
              details={
                  "event": "restore_data",
                  "type": restore_type,
                  "backup_id": backup_id,
                  "count": restored_count
              }
          )
          return {"message": f"Successfully restored {restored_count} records", "count": restored_count}
      except zipfile.BadZipFile as exc:
          db.rollback()
          raise HTTPException(status_code=400, detail="Invalid backup file format") from exc
      except Exception as exc:
          db.rollback()
          raise HTTPException(status_code=500, detail=f"Restore failed: {str(exc)}") from exc
  385. except Exception as e:
  386. db.rollback()
  387. raise HTTPException(status_code=500, detail=f"Restore failed: {str(e)}")