# backup_service.py
  1. import os
  2. import shutil
  3. import zipfile
  4. import pandas as pd
  5. import io
  6. import csv
  7. from datetime import datetime
  8. from typing import List, Dict, Any
  9. from sqlalchemy.orm import Session
  10. from sqlalchemy import inspect, Boolean, Integer
  11. from fastapi import HTTPException
  12. from app.core.database import SessionLocal
  13. from app.models.application import Application
  14. from app.models.mapping import AppUserMapping
  15. from app.models.user import User
  16. from app.models.backup import BackupRecord, BackupType, BackupSettings
  17. from app.core.scheduler import scheduler
  18. from app.core.security import verify_password
  19. from app.services.captcha_service import CaptchaService
  20. from app.services.sms_service import SmsService
  21. from app.services.log_service import LogService
  22. from app.schemas.operation_log import ActionType
  23. from apscheduler.triggers.cron import CronTrigger
  24. # Ensure backup directory exists relative to backend root
  25. BACKUP_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "backups")
  26. class BackupService:
  27. @staticmethod
  28. def ensure_backup_dir():
  29. if not os.path.exists(BACKUP_DIR):
  30. os.makedirs(BACKUP_DIR, exist_ok=True)
  31. @staticmethod
  32. def create_backup(db: Session, backup_type: BackupType = BackupType.MANUAL) -> BackupRecord:
  33. BackupService.ensure_backup_dir()
  34. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  35. base_filename = f"backup_{timestamp}"
  36. zip_filename = f"{base_filename}.zip"
  37. zip_filepath = os.path.join(BACKUP_DIR, zip_filename)
  38. # Temp dir for csvs
  39. temp_dir = os.path.join(BACKUP_DIR, f"temp_{timestamp}")
  40. os.makedirs(temp_dir, exist_ok=True)
  41. try:
  42. # 1. Export Applications
  43. apps = db.query(Application).all()
  44. apps_df = pd.read_sql(db.query(Application).statement, db.bind)
  45. apps_df.to_csv(os.path.join(temp_dir, "applications.csv"), index=False, encoding='utf-8-sig')
  46. # 2. Export Mappings (Separate by Application)
  47. mappings_dir = os.path.join(temp_dir, "mappings")
  48. os.makedirs(mappings_dir, exist_ok=True)
  49. for app in apps:
  50. # Use app_id (string) for filename, sanitized
  51. safe_app_id = "".join([c for c in app.app_id if c.isalnum() or c in ('-', '_')])
  52. if not safe_app_id:
  53. safe_app_id = f"app_{app.id}"
  54. app_mappings_query = db.query(AppUserMapping).filter(AppUserMapping.app_id == app.id)
  55. # Check if there are any mappings
  56. if app_mappings_query.count() > 0:
  57. app_mappings_df = pd.read_sql(app_mappings_query.statement, db.bind)
  58. filename = f"mappings_{safe_app_id}.csv"
  59. app_mappings_df.to_csv(os.path.join(mappings_dir, filename), index=False, encoding='utf-8-sig')
  60. # 3. Export Users
  61. users_df = pd.read_sql(db.query(User).statement, db.bind)
  62. users_df.to_csv(os.path.join(temp_dir, "users.csv"), index=False, encoding='utf-8-sig')
  63. # 4. Zip
  64. with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
  65. for root, dirs, files in os.walk(temp_dir):
  66. for file in files:
  67. file_path = os.path.join(root, file)
  68. zipf.write(file_path, file.replace(temp_dir, "").lstrip(os.sep))
  69. # Get file size
  70. file_size = os.path.getsize(zip_filepath)
  71. # Record
  72. backup_record = BackupRecord(
  73. filename=zip_filename,
  74. file_path=zip_filepath,
  75. backup_type=backup_type,
  76. content_types="users,applications,mappings",
  77. file_size=file_size
  78. )
  79. db.add(backup_record)
  80. db.commit()
  81. db.refresh(backup_record)
  82. return backup_record
  83. except Exception as e:
  84. # Cleanup zip if failed
  85. if os.path.exists(zip_filepath):
  86. os.remove(zip_filepath)
  87. raise e
  88. finally:
  89. # Cleanup temp
  90. if os.path.exists(temp_dir):
  91. shutil.rmtree(temp_dir)
  92. @staticmethod
  93. def get_settings(db: Session) -> BackupSettings:
  94. settings = db.query(BackupSettings).first()
  95. if not settings:
  96. settings = BackupSettings(auto_backup_enabled=False, backup_time="02:00")
  97. db.add(settings)
  98. db.commit()
  99. db.refresh(settings)
  100. return settings
  101. @staticmethod
  102. def update_settings(db: Session, auto_backup_enabled: bool, backup_time: str):
  103. settings = BackupService.get_settings(db)
  104. settings.auto_backup_enabled = auto_backup_enabled
  105. settings.backup_time = backup_time
  106. settings.updated_at = datetime.now()
  107. db.commit()
  108. db.refresh(settings)
  109. # Update Scheduler
  110. BackupService.configure_scheduler(settings)
  111. return settings
  112. @staticmethod
  113. def configure_scheduler(settings: BackupSettings):
  114. job_id = "auto_backup_job"
  115. if scheduler.get_job(job_id):
  116. scheduler.remove_job(job_id)
  117. if settings.auto_backup_enabled:
  118. try:
  119. hour, minute = settings.backup_time.split(":")
  120. trigger = CronTrigger(hour=int(hour), minute=int(minute))
  121. scheduler.add_job(
  122. BackupService.perform_auto_backup,
  123. trigger=trigger,
  124. id=job_id,
  125. replace_existing=True
  126. )
  127. except ValueError:
  128. # Handle invalid time format if necessary
  129. pass
  130. @staticmethod
  131. def perform_auto_backup():
  132. db = SessionLocal()
  133. try:
  134. BackupService.create_backup(db, BackupType.AUTO)
  135. # Update last_backup_at
  136. settings = BackupService.get_settings(db)
  137. settings.last_backup_at = datetime.now()
  138. db.commit()
  139. finally:
  140. db.close()
  141. @staticmethod
  142. def init_scheduler(db: Session):
  143. settings = BackupService.get_settings(db)
  144. BackupService.configure_scheduler(settings)
  145. # --- Restore Logic ---
  146. @staticmethod
  147. def get_model_columns(model):
  148. return [c.key for c in inspect(model).mapper.column_attrs]
  149. @staticmethod
  150. def preview_restore(db: Session, backup_id: int, restore_type: str):
  151. backup = db.query(BackupRecord).filter(BackupRecord.id == backup_id).first()
  152. if not backup or not os.path.exists(backup.file_path):
  153. raise HTTPException(status_code=404, detail="Backup file not found")
  154. csv_filename = ""
  155. model = None
  156. if restore_type == "APPLICATIONS":
  157. csv_filename = "applications.csv"
  158. model = Application
  159. elif restore_type == "USERS":
  160. csv_filename = "users.csv"
  161. model = User
  162. elif restore_type == "MAPPINGS":
  163. # For mappings, we just check the first file in mappings/ dir to get headers
  164. # Logic: list zip contents, find first file starting with mappings/
  165. model = AppUserMapping
  166. else:
  167. raise HTTPException(status_code=400, detail="Invalid restore type")
  168. db_columns = BackupService.get_model_columns(model)
  169. csv_headers = []
  170. try:
  171. with zipfile.ZipFile(backup.file_path, 'r') as zipf:
  172. target_file = None
  173. if restore_type == "MAPPINGS":
  174. for name in zipf.namelist():
  175. if name.startswith("mappings/") and name.endswith(".csv"):
  176. target_file = name
  177. break
  178. else:
  179. if csv_filename in zipf.namelist():
  180. target_file = csv_filename
  181. if not target_file:
  182. # It's possible the backup doesn't have this file (e.g. empty mappings)
  183. return {"csv_headers": [], "db_columns": db_columns}
  184. with zipf.open(target_file, 'r') as f:
  185. # zipf.open returns bytes, need text wrapper
  186. wrapper = io.TextIOWrapper(f, encoding='utf-8-sig')
  187. reader = csv.reader(wrapper)
  188. try:
  189. csv_headers = next(reader)
  190. except StopIteration:
  191. csv_headers = []
  192. except zipfile.BadZipFile:
  193. raise HTTPException(status_code=400, detail="Invalid backup file format")
  194. return {"csv_headers": csv_headers, "db_columns": db_columns}
  195. @staticmethod
  196. def send_restore_sms(captcha_id: str, captcha_code: str, user: User):
  197. # 1. Verify Captcha
  198. if not CaptchaService.verify_captcha(captcha_id, captcha_code):
  199. raise HTTPException(status_code=400, detail="图形验证码错误")
  200. # 2. Send SMS to current user
  201. try:
  202. SmsService.send_code(user.mobile)
  203. except Exception as e:
  204. if hasattr(e, "detail"):
  205. raise e
  206. raise HTTPException(status_code=400, detail="发送短信失败")
  207. @staticmethod
  208. def restore_data(
  209. db: Session,
  210. current_user: User,
  211. backup_id: int,
  212. restore_type: str,
  213. field_mapping: Dict[str, str],
  214. password: str,
  215. sms_code: str
  216. ):
  217. # 1. Verification
  218. # Verify Password
  219. if not verify_password(password, current_user.password_hash):
  220. raise HTTPException(status_code=400, detail="密码错误")
  221. # Verify SMS Code
  222. if not SmsService.verify_code(current_user.mobile, sms_code):
  223. raise HTTPException(status_code=400, detail="短信验证码错误或已过期")
  224. backup = db.query(BackupRecord).filter(BackupRecord.id == backup_id).first()
  225. if not backup or not os.path.exists(backup.file_path):
  226. raise HTTPException(status_code=404, detail="Backup file not found")
  227. # 2. Determine Model and Files
  228. model = None
  229. target_files = []
  230. if restore_type == "APPLICATIONS":
  231. target_files = ["applications.csv"]
  232. model = Application
  233. elif restore_type == "USERS":
  234. target_files = ["users.csv"]
  235. model = User
  236. elif restore_type == "MAPPINGS":
  237. with zipfile.ZipFile(backup.file_path, 'r') as zipf:
  238. target_files = [name for name in zipf.namelist() if name.startswith("mappings/") and name.endswith(".csv")]
  239. model = AppUserMapping
  240. else:
  241. raise HTTPException(status_code=400, detail="Invalid restore type")
  242. # 3. Process Restore
  243. restored_count = 0
  244. try:
  245. # Get columns and types for type conversion
  246. mapper = inspect(model)
  247. columns = mapper.columns
  248. with zipfile.ZipFile(backup.file_path, 'r') as zipf:
  249. for filename in target_files:
  250. if filename not in zipf.namelist():
  251. continue
  252. with zipf.open(filename, 'r') as f:
  253. wrapper = io.TextIOWrapper(f, encoding='utf-8-sig')
  254. # Use DictReader but we need to map headers manually based on field_mapping
  255. # Actually we can just read rows and map values
  256. reader = csv.DictReader(wrapper)
  257. for row in reader:
  258. # Construct data dict based on mapping
  259. # field_mapping: { "csv_col": "db_col" }
  260. data = {}
  261. for csv_col, db_col in field_mapping.items():
  262. if csv_col in row and db_col: # if db_col is not empty/none
  263. val = row[csv_col]
  264. # Type Conversion
  265. if db_col in columns:
  266. col_type = columns[db_col].type
  267. # Handle Boolean
  268. if isinstance(col_type, Boolean):
  269. if str(val).lower() in ('true', '1', 't', 'yes'):
  270. val = True
  271. elif str(val).lower() in ('false', '0', 'f', 'no'):
  272. val = False
  273. else:
  274. if val == "": val = None
  275. # Handle Integer
  276. elif isinstance(col_type, Integer):
  277. if val == "":
  278. val = None
  279. else:
  280. try:
  281. val = int(val)
  282. except ValueError:
  283. pass # Keep as is or ignore
  284. # Handle Empty Strings for others
  285. elif val == "":
  286. val = None
  287. data[db_col] = val
  288. # Upsert Logic
  289. # We assume 'id' is present if mapped.
  290. # If id exists, merge. Else add.
  291. if 'id' in data and data['id']:
  292. existing = db.query(model).filter(model.id == data['id']).first()
  293. if existing:
  294. for k, v in data.items():
  295. setattr(existing, k, v)
  296. else:
  297. obj = model(**data)
  298. db.add(obj)
  299. else:
  300. # No ID, just add? Might create duplicates.
  301. # Ideally we should map unique keys.
  302. # For now, let's assume ID is required for restore to work correctly with relationships
  303. obj = model(**data)
  304. db.add(obj)
  305. restored_count += 1
  306. db.commit()
  307. # Log Operation
  308. LogService.create_log(
  309. db=db,
  310. operator_id=current_user.id,
  311. action_type=ActionType.UPDATE, # Using UPDATE generic for Restore
  312. details={
  313. "event": "restore_data",
  314. "type": restore_type,
  315. "backup_id": backup_id,
  316. "count": restored_count
  317. }
  318. )
  319. return {"message": f"Successfully restored {restored_count} records", "count": restored_count}
  320. except Exception as e:
  321. db.rollback()
  322. raise HTTPException(status_code=500, detail=f"Restore failed: {str(e)}")