| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- from typing import Generator, Optional, Union
- from fastapi import Depends, HTTPException, status, Response, Header
- from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
- from jose import jwt, JWTError
- from sqlalchemy.orm import Session
- from datetime import datetime, timedelta
- from app.core import security
- from app.core.config import settings
- from app.core.database import SessionLocal
- from app.models.user import User
- from app.models.application import Application
- from app.schemas.token import TokenPayload
- from app.services.signature_service import SignatureService
# Bearer-token extractor for user JWTs (login endpoint advertised for the
# OpenAPI docs). auto_error=False lets a request without a token reach the
# dependencies below, which decide themselves how to react to a missing token.
reusable_oauth2 = OAuth2PasswordBearer(
    tokenUrl=f"{settings.API_V1_STR}/auth/login",
    auto_error=False,
)

# Extractor for the permanent application access token carried in the
# "X-App-Access-Token" header; likewise optional (auto_error=False).
token_header_scheme = APIKeyHeader(
    name="X-App-Access-Token",
    auto_error=False,
)
def get_db() -> Generator:
    """Yield a database session and close it when the generator finishes.

    The session is created *before* entering ``try`` so that a failure in
    ``SessionLocal()`` propagates as-is. Previously it surfaced as an
    ``UnboundLocalError`` from the ``finally`` clause, because ``db`` was
    never bound.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
def get_current_user(
    response: Response,
    db: Session = Depends(get_db),
    token: str = Depends(reusable_oauth2),
) -> User:
    """Resolve the authenticated user from a bearer JWT.

    The token is decoded with ``settings.SECRET_KEY`` / ``settings.ALGORITHM``
    and its subject must be a numeric user id. Sliding expiration applies:
    when less than half of the configured lifetime remains, a new token is
    issued for the client in the ``X-New-Token`` response header. The lifetime
    is ``ACCESS_TOKEN_EXPIRE_MINUTES``, or ``ACCESS_TOKEN_EXPIRE_MINUTES_LONG``
    for tokens carrying ``long_term``.

    Args:
        response: Outgoing response; receives ``X-New-Token`` on renewal.
        db: Database session.
        token: Raw bearer token; falsy when the request carried none.

    Returns:
        The ``User`` whose id is the token subject.

    Raises:
        HTTPException: 403 if the token is missing, fails to decode/validate,
            or is not a user token (non-numeric subject); 404 if no user
            has that id.
    """
    if not token:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authenticated",
        )
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = TokenPayload(**payload)
    except (JWTError, ValueError, TypeError) as err:
        # JWTError: bad signature / malformed / expired token.
        # ValueError/TypeError: claims rejected by TokenPayload (assumes it is
        # a pydantic model, whose ValidationError subclasses ValueError).
        # Other exceptions are server bugs and must not masquerade as a 403.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        ) from err

    # Ensure it's a user token (numeric ID); app tokens use "app:<id>".
    if not token_data.sub or not token_data.sub.isdigit():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid token type",
        )
    user = db.query(User).filter(User.id == int(token_data.sub)).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Sliding expiration. This runs only once the subject is confirmed to be an
    # existing user, so no renewal token is minted for a request that is
    # about to be rejected.
    exp = payload.get("exp")
    if exp:
        is_long_term = payload.get("long_term", False)
        lifetime_minutes = (
            settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
            if is_long_term
            else settings.ACCESS_TOKEN_EXPIRE_MINUTES
        )
        remaining_seconds = exp - datetime.now().timestamp()
        if remaining_seconds < lifetime_minutes * 60 / 2:
            # None presumably lets create_access_token apply its default lifetime.
            expires_delta = (
                timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
                if is_long_term
                else None
            )
            response.headers["X-New-Token"] = security.create_access_token(
                subject=token_data.sub,
                expires_delta=expires_delta,
                is_long_term=is_long_term,
            )
    return user
def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the authenticated user, but only if the account is ACTIVE.

    Raises:
        HTTPException: 400 ("Inactive user") when ``status`` is not "ACTIVE".
    """
    is_active = current_user.status == "ACTIVE"
    if is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
def get_current_user_optional(
    response: Response,
    db: Session = Depends(get_db),
    token: Optional[str] = Depends(reusable_oauth2)
) -> Optional[User]:
    """
    Returns the user if the token is valid, otherwise None.
    Does NOT raise 403.

    The token must decode with ``settings.SECRET_KEY`` / ``settings.ALGORITHM``
    and carry a numeric subject (user id). When a user is found and less than
    half of the token's configured lifetime remains, a renewed token is placed
    in the ``X-New-Token`` response header (sliding expiration).
    """
    if not token:
        return None
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = TokenPayload(**payload)
    except (JWTError, ValueError, TypeError):
        # Invalid/expired token, or claims rejected by TokenPayload: treat the
        # caller as anonymous. Unexpected errors are not swallowed.
        return None

    # Only user tokens (numeric subject) map to a user.
    if not token_data.sub or not token_data.sub.isdigit():
        return None
    user = db.query(User).filter(User.id == int(token_data.sub)).first()
    if user is None:
        # No renewal token for a subject that does not exist.
        return None

    # Sliding expiration for optional auth.
    exp = payload.get("exp")
    if exp:
        is_long_term = payload.get("long_term", False)
        lifetime_minutes = (
            settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
            if is_long_term
            else settings.ACCESS_TOKEN_EXPIRE_MINUTES
        )
        remaining_seconds = exp - datetime.now().timestamp()
        if remaining_seconds < lifetime_minutes * 60 / 2:
            expires_delta = (
                timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
                if is_long_term
                else None
            )
            response.headers["X-New-Token"] = security.create_access_token(
                subject=token_data.sub,
                expires_delta=expires_delta,
                is_long_term=is_long_term,
            )
    return user
def get_current_active_user_optional(
    current_user: Optional[User] = Depends(get_current_user_optional),
) -> Optional[User]:
    """Return the optional user only when present and ACTIVE; otherwise None."""
    if not current_user or current_user.status != "ACTIVE":
        return None
    return current_user
def get_current_app(
    db: Session = Depends(get_db),
    token: Optional[str] = Depends(reusable_oauth2),
    access_token: Optional[str] = Depends(token_header_scheme)
) -> Application:
    """
    Get application from token (Machine-to-Machine auth).
    Supports:
    1. JWT Bearer Token (Subject: "app:{id}")
    2. Permanent Access Token (Header: X-App-Access-Token)

    The access-token header takes precedence. If it is present, the bearer
    token is not examined.

    Raises:
        HTTPException: 403 for a missing, invalid or non-app credential, or a
            malformed "app:" subject; 404 if the referenced app does not exist.
    """
    # 1. Permanent access token, looked up directly in the database.
    if access_token:
        app = db.query(Application).filter(Application.access_token == access_token).first()
        if not app:
            raise HTTPException(status_code=403, detail="Invalid access token")
        return app

    # 2. JWT bearer token.
    if not token:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authenticated",
        )
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = TokenPayload(**payload)
    except (JWTError, ValueError, TypeError) as err:
        # Decode/validation failures only; other exceptions are server bugs.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        ) from err

    prefix = "app:"
    sub = token_data.sub
    if not sub or not sub.startswith(prefix):
        raise HTTPException(status_code=403, detail="Not an app token")

    # Everything after the prefix must be the integer id. split(":")[1] used to
    # silently accept trailing junk such as "app:5:extra".
    try:
        app_id = int(sub[len(prefix):])
    except ValueError as err:
        raise HTTPException(status_code=403, detail="Invalid app token format") from err

    app = db.query(Application).filter(Application.id == app_id).first()
    if not app:
        raise HTTPException(status_code=404, detail="App not found")
    return app
# Union type for a caller that may be either a user or an application.
AuthSubject = Union[User, Application]


def get_current_user_or_app(
    # --- User authentication ---
    token: Optional[str] = Depends(reusable_oauth2),

    # --- Application authentication (via headers) ---
    x_app_id: Optional[str] = Header(None, alias="X-App-Id"),
    x_timestamp: Optional[str] = Header(None, alias="X-Timestamp"),
    x_sign: Optional[str] = Header(None, alias="X-Sign"),

    # --- Database session ---
    db: Session = Depends(get_db)
) -> AuthSubject:
    """Authenticate the caller as an ACTIVE user (JWT) or an application (signature).

    User authentication is tried first. If the token is invalid, is not a user
    token, or belongs to a missing/non-ACTIVE user, the call falls back to
    signature authentication using the X-App-Id / X-Timestamp / X-Sign headers.

    Raises:
        HTTPException: 401 when neither method succeeds.
    """
    # 1. Try user authentication (JWT).
    if token:
        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
            token_data = TokenPayload(**payload)
        except (JWTError, ValueError, TypeError):
            # Invalid token: fall through to application auth. The handler is
            # deliberately narrow so that DB or programming errors are not
            # silently swallowed (this was a bare `except:` before).
            token_data = None
        if (
            token_data is not None
            and isinstance(token_data.sub, str)
            and token_data.sub.isdigit()
        ):
            user = db.query(User).filter(User.id == int(token_data.sub)).first()
            if user and user.status == "ACTIVE":
                return user

    # 2. Try application authentication (signature).
    if x_app_id and x_timestamp and x_sign:
        app = db.query(Application).filter(Application.app_id == x_app_id).first()
        if app:
            # NOTE(review): the signed params cover only app_id and timestamp
            # (nothing request-specific), and no freshness check on x_timestamp
            # is visible here. A captured header triple may therefore be
            # replayable unless SignatureService.verify_signature enforces a
            # time window; confirm. Presumably verify_signature also excludes
            # the "sign" key when recomputing; verify.
            params = {
                "app_id": x_app_id,
                "timestamp": x_timestamp,
                "sign": x_sign
            }
            if SignatureService.verify_signature(app.app_secret, params, x_sign):
                return app

    # 3. Neither method succeeded.
    raise HTTPException(
        status_code=401,
        detail="Authentication failed: Invalid Token or Signature"
    )
def get_current_user_or_app_by_api_key(
    response: Response,
    db: Session = Depends(get_db),
    token: Optional[str] = Depends(reusable_oauth2),
    access_token: Optional[str] = Depends(token_header_scheme)
) -> AuthSubject:
    """
    Authenticate via a user JWT or an application API key.
    User authentication is tried first; if it fails, the application API key
    (X-App-Access-Token header) is tried.

    For an authenticated ACTIVE user, sliding expiration applies. When less than
    half of the configured lifetime remains, a renewed token is placed in the
    ``X-New-Token`` response header.

    Raises:
        HTTPException: 401 when neither method succeeds.
    """
    # 1. Try user authentication (JWT).
    if token:
        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
            token_data = TokenPayload(**payload)
        except (JWTError, ValueError, TypeError):
            # Invalid token: fall through to API-key auth. Only decode/parse
            # errors are caught. Previously a bare `except:` also hid failures
            # in the refresh code below, rejecting valid users with 401.
            token_data = None
        if (
            token_data is not None
            and isinstance(token_data.sub, str)
            and token_data.sub.isdigit()
        ):
            user = db.query(User).filter(User.id == int(token_data.sub)).first()
            if user and user.status == "ACTIVE":
                # Sliding expiration (same rule as the other user dependencies).
                exp = payload.get("exp")
                if exp:
                    is_long_term = payload.get("long_term", False)
                    lifetime_minutes = (
                        settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
                        if is_long_term
                        else settings.ACCESS_TOKEN_EXPIRE_MINUTES
                    )
                    remaining_seconds = exp - datetime.now().timestamp()
                    if remaining_seconds < lifetime_minutes * 60 / 2:
                        expires_delta = (
                            timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
                            if is_long_term
                            else None
                        )
                        response.headers["X-New-Token"] = security.create_access_token(
                            subject=token_data.sub,
                            expires_delta=expires_delta,
                            is_long_term=is_long_term,
                        )
                return user

    # 2. Try application API-key authentication.
    if access_token:
        app = db.query(Application).filter(Application.access_token == access_token).first()
        if app:
            return app

    # 3. Neither method succeeded.
    raise HTTPException(
        status_code=401,
        detail="Authentication failed: Invalid Token or API Key"
    )
|