deps.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. from typing import Generator, Optional, Union
  2. from fastapi import Depends, HTTPException, status, Response, Header
  3. from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
  4. from jose import jwt, JWTError
  5. from sqlalchemy.orm import Session
  6. from datetime import datetime, timedelta
  7. from app.core import security
  8. from app.core.config import settings
  9. from app.core.database import SessionLocal
  10. from app.models.user import User
  11. from app.models.application import Application
  12. from app.schemas.token import TokenPayload
  13. from app.services.signature_service import SignatureService
  # Header scheme carrying a permanent application access token. With
  # auto_error=False the dependency does not reject requests that omit the
  # X-App-Access-Token header; the dependencies below decide what to do.
  token_header_scheme = APIKeyHeader(auto_error=False, name="X-App-Access-Token")

  # Bearer-token scheme shared by the user and app JWT dependencies.
  # auto_error=False keeps the token optional: a missing token is handled by
  # each dependency (`if not token: ...`) instead of by FastAPI itself.
  reusable_oauth2 = OAuth2PasswordBearer(
      auto_error=False,
      tokenUrl=f"{settings.API_V1_STR}/auth/login",
  )
  def get_db() -> Generator:
      """Yield a database session for the duration of one dependency scope.

      The session is always closed once the consumer is done with it.

      The session is created *before* entering ``try``: if ``SessionLocal()``
      itself raises, the original error propagates. Previously ``finally``
      would reference an unassigned ``db`` and mask it with an
      ``UnboundLocalError``.
      """
      db = SessionLocal()
      try:
          yield db
      finally:
          db.close()
  def get_current_user(
      response: Response,
      db: Session = Depends(get_db),
      token: Optional[str] = Depends(reusable_oauth2),
  ) -> User:
      """Resolve the user identified by a required bearer JWT.

      The token's ``sub`` claim must be a numeric user id. If the token has less
      than half of its configured lifetime left, a renewed token is placed in the
      ``X-New-Token`` response header (sliding expiration).

      Only decoding/validation of the token is treated as a credential error.
      Failures while issuing the renewed token now surface as server errors
      instead of being reported as 403 "Could not validate credentials".

      Raises:
          HTTPException: 403 "Not authenticated" when no token is sent;
              403 "Could not validate credentials" when the token cannot be
              decoded or parsed into ``TokenPayload``; 403 "Invalid token type"
              when ``sub`` is not a numeric user id; 404 "User not found" when no
              ``User`` row has that id.
      """
      if not token:
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Not authenticated",
          )

      try:
          payload = jwt.decode(
              token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
          )
          token_data = TokenPayload(**payload)
      except (JWTError, ValueError, TypeError) as exc:
          # JWTError covers bad signature / expiry / malformed token;
          # ValueError covers TokenPayload validation errors (pydantic's
          # ValidationError subclasses ValueError).
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Could not validate credentials",
          ) from exc

      # Ensure it's a user token (numeric ID) rather than e.g. an "app:<id>" token.
      if not token_data.sub or not token_data.sub.isdigit():
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Invalid token type",
          )

      user = db.query(User).filter(User.id == int(token_data.sub)).first()
      if not user:
          raise HTTPException(status_code=404, detail="User not found")

      # Sliding expiration: renew only once the token is known to belong to an
      # existing user, and once less than half of its lifetime remains.
      exp = payload.get("exp")
      if exp:
          is_long_term = payload.get("long_term", False)
          lifetime_minutes = (
              settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
              if is_long_term
              else settings.ACCESS_TOKEN_EXPIRE_MINUTES
          )
          remaining_seconds = exp - datetime.now().timestamp()
          if remaining_seconds < lifetime_minutes * 60 / 2:
              response.headers["X-New-Token"] = security.create_access_token(
                  subject=token_data.sub,
                  # None for normal tokens; presumably create_access_token then
                  # applies its own default lifetime — confirm in app.core.security.
                  expires_delta=(
                      timedelta(minutes=lifetime_minutes) if is_long_term else None
                  ),
                  is_long_term=is_long_term,
              )

      return user
  def get_current_active_user(
      current_user: User = Depends(get_current_user),
  ) -> User:
      """Return the authenticated user only if their status is "ACTIVE".

      Raises:
          HTTPException: 400 "Inactive user" for any other status.
      """
      is_active = current_user.status == "ACTIVE"
      if is_active:
          return current_user
      raise HTTPException(status_code=400, detail="Inactive user")
  def get_current_user_optional(
      response: Response,
      db: Session = Depends(get_db),
      token: Optional[str] = Depends(reusable_oauth2),
  ) -> Optional[User]:
      """Return the user for a valid bearer JWT, otherwise None.

      Never raises 403: a missing, invalid, or expired token, a non-user subject
      (e.g. "app:<id>"), or an unknown user id all yield None.

      A renewed token is written to ``X-New-Token`` only after the subject has
      been resolved to an existing user. Previously the header was set before
      these checks, so an app token or a token for a deleted user was renewed
      and returned to the client even though this function returned None.

      Errors while issuing the renewed token propagate. Before, they were
      swallowed and the user was silently treated as anonymous.
      """
      if not token:
          return None

      try:
          payload = jwt.decode(
              token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
          )
          token_data = TokenPayload(**payload)
      except (JWTError, ValueError, TypeError):
          # Invalid/expired token or payload that does not fit TokenPayload.
          return None

      if not token_data.sub or not token_data.sub.isdigit():
          return None

      user = db.query(User).filter(User.id == int(token_data.sub)).first()
      if user is None:
          return None

      # Sliding expiration for optional auth.
      # NOTE(review): inactive users are still renewed here (the status check
      # lives in get_current_active_user_optional) — confirm this is intended.
      exp = payload.get("exp")
      if exp:
          is_long_term = payload.get("long_term", False)
          lifetime_minutes = (
              settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
              if is_long_term
              else settings.ACCESS_TOKEN_EXPIRE_MINUTES
          )
          remaining_seconds = exp - datetime.now().timestamp()
          if remaining_seconds < lifetime_minutes * 60 / 2:
              response.headers["X-New-Token"] = security.create_access_token(
                  subject=token_data.sub,
                  expires_delta=(
                      timedelta(minutes=lifetime_minutes) if is_long_term else None
                  ),
                  is_long_term=is_long_term,
              )

      return user
  def get_current_active_user_optional(
      current_user: Optional[User] = Depends(get_current_user_optional),
  ) -> Optional[User]:
      """Return the optionally-authenticated user if their status is "ACTIVE", else None."""
      if not current_user:
          return None
      return current_user if current_user.status == "ACTIVE" else None
  def get_current_app(
      db: Session = Depends(get_db),
      token: Optional[str] = Depends(reusable_oauth2),
      access_token: Optional[str] = Depends(token_header_scheme),
  ) -> Application:
      """Authenticate a machine-to-machine caller and return its Application.

      Two mechanisms are accepted, checked in this order:

      1. Permanent access token in the ``X-App-Access-Token`` header, matched
         against ``Application.access_token``.
      2. Bearer JWT whose subject has the form ``"app:{id}"``.

      Raises:
          HTTPException: 403 for an unknown access token, a missing or invalid
              JWT, a non-app subject, or a malformed ``app:`` subject;
              404 when the JWT names an application id that does not exist.
      """
      # Path 1: permanent header token wins whenever it is present.
      if access_token:
          matched_app = (
              db.query(Application)
              .filter(Application.access_token == access_token)
              .first()
          )
          if not matched_app:
              raise HTTPException(status_code=403, detail="Invalid access token")
          return matched_app

      # Path 2: bearer JWT.
      if not token:
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Not authenticated",
          )

      try:
          claims = TokenPayload(
              **jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
          )
      except Exception:
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Could not validate credentials",
          )

      subject = claims.sub
      if not (subject and subject.startswith("app:")):
          raise HTTPException(status_code=403, detail="Not an app token")

      try:
          app_pk = int(subject.split(":")[1])
      except (IndexError, ValueError):
          raise HTTPException(status_code=403, detail="Invalid app token format")

      found = db.query(Application).filter(Application.id == app_pk).first()
      if not found:
          raise HTTPException(status_code=404, detail="App not found")
      return found
  # Union type for an authenticated caller: either a human user or an
  # application (machine-to-machine client).
  AuthSubject = Union[
      User,
      Application,
  ]
  def get_current_user_or_app(
      # --- user authentication (bearer JWT) ---
      token: Optional[str] = Depends(reusable_oauth2),
      # --- application authentication (signed headers) ---
      x_app_id: Optional[str] = Header(None, alias="X-App-Id"),
      x_timestamp: Optional[str] = Header(None, alias="X-Timestamp"),
      x_sign: Optional[str] = Header(None, alias="X-Sign"),
      # --- database session ---
      db: Session = Depends(get_db),
  ) -> AuthSubject:
      """Authenticate the caller as an active user (JWT) or an application (signature).

      1. A bearer JWT whose ``sub`` is a numeric id of a user with status
         "ACTIVE" returns that user.
      2. Otherwise, if all of ``X-App-Id``, ``X-Timestamp`` and ``X-Sign`` are
         present, the application with matching ``app_id`` is returned when
         ``SignatureService.verify_signature`` accepts the signature.

      Only token decoding/validation failures fall through to step 2. Other
      errors (e.g. database failures) now propagate instead of being swallowed
      by a bare ``except``.

      Raises:
          HTTPException: 401 when neither method authenticates the caller.
      """
      # 1. User authentication via JWT.
      token_data = None
      if token:
          try:
              payload = jwt.decode(
                  token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
              )
              token_data = TokenPayload(**payload)
          except (JWTError, ValueError, TypeError):
              # Invalid/expired/malformed token: continue with application auth.
              token_data = None

      if token_data is not None and token_data.sub and token_data.sub.isdigit():
          user = db.query(User).filter(User.id == int(token_data.sub)).first()
          if user and user.status == "ACTIVE":
              return user

      # 2. Application authentication via signed headers.
      if x_app_id and x_timestamp and x_sign:
          app = db.query(Application).filter(Application.app_id == x_app_id).first()
          if app:
              # NOTE(review): x_timestamp is passed to verify_signature but is not
              # compared with the current time here. Unless SignatureService
              # enforces a freshness window, a captured (app_id, timestamp, sign)
              # triple looks replayable indefinitely — confirm.
              # NOTE(review): "sign" is included in params; presumably
              # verify_signature excludes it when recomputing — confirm.
              params = {
                  "app_id": x_app_id,
                  "timestamp": x_timestamp,
                  "sign": x_sign,
              }
              if SignatureService.verify_signature(app.app_secret, params, x_sign):
                  return app

      # 3. Neither method succeeded.
      raise HTTPException(
          status_code=401,
          detail="Authentication failed: Invalid Token or Signature",
      )
  def get_current_user_or_app_by_api_key(
      response: Response,
      db: Session = Depends(get_db),
      token: Optional[str] = Depends(reusable_oauth2),
      access_token: Optional[str] = Depends(token_header_scheme),
  ) -> AuthSubject:
      """Authenticate the caller via user JWT or application API key.

      User authentication is tried first: a bearer JWT whose ``sub`` is the
      numeric id of a user with status "ACTIVE" returns that user. In that case,
      if less than half of the token's configured lifetime remains, a renewed
      token is placed in the ``X-New-Token`` response header.

      Otherwise the permanent key from ``X-App-Access-Token`` is matched against
      ``Application.access_token``.

      Only token decoding/validation failures fall through to API-key auth.
      Previously a bare ``except`` also swallowed database errors and failures
      while issuing the renewed token, silently demoting an authenticated user
      to API-key/401 handling.

      Raises:
          HTTPException: 401 when neither method authenticates the caller.
      """
      # 1. User authentication via JWT.
      payload: dict = {}
      token_data = None
      if token:
          try:
              payload = jwt.decode(
                  token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
              )
              token_data = TokenPayload(**payload)
          except (JWTError, ValueError, TypeError):
              # Invalid/expired/malformed token: continue with API-key auth.
              token_data = None

      if token_data is not None and token_data.sub and token_data.sub.isdigit():
          user = db.query(User).filter(User.id == int(token_data.sub)).first()
          if user and user.status == "ACTIVE":
              # Sliding expiration (same rule as the other user dependencies).
              exp = payload.get("exp")
              if exp:
                  is_long_term = payload.get("long_term", False)
                  lifetime_minutes = (
                      settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
                      if is_long_term
                      else settings.ACCESS_TOKEN_EXPIRE_MINUTES
                  )
                  remaining_seconds = exp - datetime.now().timestamp()
                  if remaining_seconds < lifetime_minutes * 60 / 2:
                      response.headers["X-New-Token"] = security.create_access_token(
                          subject=token_data.sub,
                          expires_delta=(
                              timedelta(minutes=lifetime_minutes)
                              if is_long_term
                              else None
                          ),
                          is_long_term=is_long_term,
                      )
              return user

      # 2. Application API-key authentication.
      if access_token:
          app = db.query(Application).filter(Application.access_token == access_token).first()
          if app:
              return app

      # 3. Neither method succeeded.
      raise HTTPException(
          status_code=401,
          detail="Authentication failed: Invalid Token or API Key",
      )