deps.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. from typing import Generator, Optional, Union
  2. from fastapi import Depends, HTTPException, status, Response, Header, Request
  3. from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
  4. from jose import jwt, JWTError
  5. from sqlalchemy.orm import Session
  6. from datetime import datetime, timedelta
  7. from app.core import security
  8. from app.core.config import settings
  9. from app.core.database import SessionLocal
  10. from app.core.utils import get_client_ip
  11. from app.models.user import User
  12. from app.models.application import Application
  13. from app.models.admin_api_key import AdminApiKey
  14. from app.schemas.token import TokenPayload
  15. from app.services.signature_service import SignatureService
  16. reusable_oauth2 = OAuth2PasswordBearer(
  17. tokenUrl=f"{settings.API_V1_STR}/auth/login",
  18. auto_error=False # Allow optional token
  19. )
  20. token_header_scheme = APIKeyHeader(name="X-App-Access-Token", auto_error=False)
  21. admin_key_header = APIKeyHeader(name="X-Admin-Api-Key", auto_error=False)
  22. def get_db() -> Generator:
  23. try:
  24. db = SessionLocal()
  25. yield db
  26. finally:
  27. db.close()
  28. def get_current_user(
  29. response: Response,
  30. db: Session = Depends(get_db),
  31. token: str = Depends(reusable_oauth2)
  32. ) -> User:
  33. if not token:
  34. raise HTTPException(
  35. status_code=status.HTTP_403_FORBIDDEN,
  36. detail="Not authenticated",
  37. )
  38. try:
  39. payload = jwt.decode(
  40. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  41. )
  42. token_data = TokenPayload(**payload)
  43. # Sliding Expiration Check
  44. # If token is valid but expires soon (e.g. less than half of total lifetime), renew it
  45. exp = payload.get("exp")
  46. is_long_term = payload.get("long_term", False)
  47. if exp:
  48. now = datetime.now().timestamp()
  49. remaining_seconds = exp - now
  50. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 / 2
  51. if is_long_term:
  52. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG * 60 / 2
  53. # If remaining time is less than half of the configured expiration time
  54. if remaining_seconds < threshold:
  55. expires_delta = None
  56. if is_long_term:
  57. expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
  58. # Issue new token
  59. new_token = security.create_access_token(
  60. subject=token_data.sub,
  61. expires_delta=expires_delta,
  62. is_long_term=is_long_term
  63. )
  64. # Set in response header
  65. response.headers["X-New-Token"] = new_token
  66. except (JWTError, Exception):
  67. raise HTTPException(
  68. status_code=status.HTTP_403_FORBIDDEN,
  69. detail="Could not validate credentials",
  70. )
  71. # Ensure it's a user token (numeric ID)
  72. if not token_data.sub or not token_data.sub.isdigit():
  73. raise HTTPException(
  74. status_code=status.HTTP_403_FORBIDDEN,
  75. detail="Invalid token type",
  76. )
  77. user = db.query(User).filter(User.id == int(token_data.sub)).first()
  78. if not user:
  79. raise HTTPException(status_code=404, detail="User not found")
  80. return user
  81. def get_current_active_user(
  82. current_user: User = Depends(get_current_user),
  83. ) -> User:
  84. if current_user.status != "ACTIVE":
  85. raise HTTPException(status_code=400, detail="Inactive user")
  86. return current_user
  87. def get_current_user_optional(
  88. response: Response,
  89. db: Session = Depends(get_db),
  90. token: Optional[str] = Depends(reusable_oauth2)
  91. ) -> Optional[User]:
  92. """
  93. Returns the user if the token is valid, otherwise None.
  94. Does NOT raise 403.
  95. """
  96. if not token:
  97. return None
  98. try:
  99. payload = jwt.decode(
  100. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  101. )
  102. token_data = TokenPayload(**payload)
  103. # Sliding Expiration Check for Optional Auth
  104. exp = payload.get("exp")
  105. is_long_term = payload.get("long_term", False)
  106. if exp:
  107. now = datetime.now().timestamp()
  108. remaining_seconds = exp - now
  109. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 / 2
  110. if is_long_term:
  111. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG * 60 / 2
  112. if remaining_seconds < threshold:
  113. expires_delta = None
  114. if is_long_term:
  115. expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
  116. new_token = security.create_access_token(
  117. subject=token_data.sub,
  118. expires_delta=expires_delta,
  119. is_long_term=is_long_term
  120. )
  121. response.headers["X-New-Token"] = new_token
  122. except (JWTError, Exception):
  123. return None
  124. if not token_data.sub or not token_data.sub.isdigit():
  125. return None
  126. user = db.query(User).filter(User.id == int(token_data.sub)).first()
  127. return user
  128. def get_current_active_user_optional(
  129. current_user: Optional[User] = Depends(get_current_user_optional),
  130. ) -> Optional[User]:
  131. if current_user and current_user.status == "ACTIVE":
  132. return current_user
  133. return None
  134. def get_current_app(
  135. db: Session = Depends(get_db),
  136. token: Optional[str] = Depends(reusable_oauth2),
  137. access_token: Optional[str] = Depends(token_header_scheme)
  138. ) -> Application:
  139. """
  140. Get application from token (Machine-to-Machine auth).
  141. Supports:
  142. 1. JWT Bearer Token (Subject: "app:{id}")
  143. 2. Permanent Access Token (Header: X-App-Access-Token)
  144. """
  145. # 1. Try Access Token first if present
  146. if access_token:
  147. # Use simple auth with permanent token
  148. app = db.query(Application).filter(Application.access_token == access_token).first()
  149. if not app:
  150. raise HTTPException(status_code=403, detail="Invalid access token")
  151. return app
  152. # 2. Try JWT Bearer Token
  153. if not token:
  154. raise HTTPException(
  155. status_code=status.HTTP_403_FORBIDDEN,
  156. detail="Not authenticated",
  157. )
  158. try:
  159. payload = jwt.decode(
  160. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  161. )
  162. token_data = TokenPayload(**payload)
  163. except (JWTError, Exception):
  164. raise HTTPException(
  165. status_code=status.HTTP_403_FORBIDDEN,
  166. detail="Could not validate credentials",
  167. )
  168. sub = token_data.sub
  169. if not sub or not sub.startswith("app:"):
  170. raise HTTPException(status_code=403, detail="Not an app token")
  171. try:
  172. app_id = int(sub.split(":")[1])
  173. except (ValueError, IndexError):
  174. raise HTTPException(status_code=403, detail="Invalid app token format")
  175. app = db.query(Application).filter(Application.id == app_id).first()
  176. if not app:
  177. raise HTTPException(status_code=404, detail="App not found")
  178. return app
  179. # 定义一个联合类型,表示调用者可能是用户,也可能是应用
  180. AuthSubject = Union[User, Application]
  181. def get_current_user_or_app(
  182. # --- 用户认证参数 ---
  183. token: Optional[str] = Depends(reusable_oauth2),
  184. # --- 应用认证参数 (Header 方式) ---
  185. x_app_id: Optional[str] = Header(None, alias="X-App-Id"),
  186. x_timestamp: Optional[str] = Header(None, alias="X-Timestamp"),
  187. x_sign: Optional[str] = Header(None, alias="X-Sign"),
  188. # --- 数据库会话 ---
  189. db: Session = Depends(get_db)
  190. ) -> AuthSubject:
  191. # 1. 尝试用户认证 (JWT)
  192. if token:
  193. try:
  194. payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
  195. token_data = TokenPayload(**payload)
  196. if token_data.sub and token_data.sub.isdigit():
  197. user = db.query(User).filter(User.id == int(token_data.sub)).first()
  198. if user and user.status == "ACTIVE":
  199. return user
  200. except:
  201. pass # Token 无效,继续尝试应用认证
  202. # 2. 尝试应用认证 (签名)
  203. if x_app_id and x_timestamp and x_sign:
  204. app = db.query(Application).filter(Application.app_id == x_app_id).first()
  205. if app:
  206. # 验证签名
  207. params = {
  208. "app_id": x_app_id,
  209. "timestamp": x_timestamp,
  210. "sign": x_sign
  211. }
  212. if SignatureService.verify_signature(app.app_secret, params, x_sign):
  213. return app
  214. # 3. 均未通过
  215. raise HTTPException(
  216. status_code=401,
  217. detail="Authentication failed: Invalid Token or Signature"
  218. )
  219. def get_current_user_or_app_by_api_key(
  220. response: Response,
  221. db: Session = Depends(get_db),
  222. token: Optional[str] = Depends(reusable_oauth2),
  223. access_token: Optional[str] = Depends(token_header_scheme)
  224. ) -> AuthSubject:
  225. """
  226. 支持用户 JWT 和应用 API key 认证。
  227. 优先尝试用户认证,如果失败则尝试应用 API key 认证。
  228. """
  229. # 1. 尝试用户认证 (JWT)
  230. if token:
  231. try:
  232. payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
  233. token_data = TokenPayload(**payload)
  234. if token_data.sub and token_data.sub.isdigit():
  235. user = db.query(User).filter(User.id == int(token_data.sub)).first()
  236. if user and user.status == "ACTIVE":
  237. # 滑动过期检查(复用现有逻辑)
  238. exp = payload.get("exp")
  239. is_long_term = payload.get("long_term", False)
  240. if exp:
  241. now = datetime.now().timestamp()
  242. remaining_seconds = exp - now
  243. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 / 2
  244. if is_long_term:
  245. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG * 60 / 2
  246. if remaining_seconds < threshold:
  247. expires_delta = None
  248. if is_long_term:
  249. expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
  250. new_token = security.create_access_token(
  251. subject=token_data.sub,
  252. expires_delta=expires_delta,
  253. is_long_term=is_long_term
  254. )
  255. response.headers["X-New-Token"] = new_token
  256. return user
  257. except:
  258. pass # Token 无效,继续尝试应用认证
  259. # 2. 尝试应用 API key 认证
  260. if access_token:
  261. app = db.query(Application).filter(Application.access_token == access_token).first()
  262. if app:
  263. return app
  264. # 3. 均未通过
  265. raise HTTPException(
  266. status_code=401,
  267. detail="Authentication failed: Invalid Token or API Key"
  268. )
  269. # ---------------------------------------------------------------------------
  270. # Admin API Key 鉴权(仅供 /admin-api/* 一组独立接口使用,与上面 Application
  271. # 永久 Token / 用户 JWT 完全无关,避免误用到原 /users/*、/organizations/* 上)
  272. # ---------------------------------------------------------------------------
  273. def get_admin_api_key(
  274. request: Request,
  275. db: Session = Depends(get_db),
  276. raw_key: Optional[str] = Depends(admin_key_header),
  277. ) -> AdminApiKey:
  278. """根据请求头 X-Admin-Api-Key 解析有效的 AdminApiKey 实体。
  279. - 用 key 前缀粗筛
  280. - bcrypt 校验完整 key
  281. - 校验 is_active 且未吊销
  282. - 命中后更新 last_used_at / last_used_ip / use_count
  283. """
  284. if not raw_key:
  285. raise HTTPException(status_code=401, detail="Missing X-Admin-Api-Key header")
  286. # 容错:去除两侧空白;前缀长度 12(与生成时一致)
  287. raw_key = raw_key.strip()
  288. prefix = raw_key[:12]
  289. if not prefix:
  290. raise HTTPException(status_code=403, detail="Invalid admin api key")
  291. candidates = (
  292. db.query(AdminApiKey)
  293. .filter(
  294. AdminApiKey.key_prefix == prefix,
  295. AdminApiKey.is_active.is_(True),
  296. AdminApiKey.is_revoked.is_(False),
  297. )
  298. .all()
  299. )
  300. if not candidates:
  301. raise HTTPException(status_code=403, detail="Invalid or disabled admin api key")
  302. matched: Optional[AdminApiKey] = None
  303. for k in candidates:
  304. try:
  305. if security.verify_password(raw_key, k.key_hash):
  306. matched = k
  307. break
  308. except Exception:
  309. continue
  310. if not matched:
  311. raise HTTPException(status_code=403, detail="Invalid or disabled admin api key")
  312. # 更新使用信息(失败不影响主流程)
  313. try:
  314. matched.last_used_at = datetime.now()
  315. matched.last_used_ip = get_client_ip(request) if request else None
  316. matched.use_count = (matched.use_count or 0) + 1
  317. db.add(matched)
  318. db.commit()
  319. db.refresh(matched)
  320. except Exception:
  321. db.rollback()
  322. return matched