"""FastAPI authentication dependencies for users (JWT) and applications (JWT / access token / signature)."""
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Generator, Optional, Tuple, Union

from fastapi import Depends, HTTPException, status, Response, Header
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
from jose import jwt, JWTError
from sqlalchemy.orm import Session

from app.core import security
from app.core.config import settings
from app.core.database import SessionLocal
from app.models.user import User
from app.models.application import Application
from app.schemas.token import TokenPayload
from app.services.signature_service import SignatureService

# Bearer-token extractor. auto_error=False makes it yield None instead of
# raising when the Authorization header is absent, so each dependency below
# decides for itself how to react to a missing token.
reusable_oauth2 = OAuth2PasswordBearer(
    tokenUrl=f"{settings.API_V1_STR}/auth/login",
    auto_error=False,  # Allow optional token
)

# Extractor for the permanent application token sent as X-App-Access-Token.
token_header_scheme = APIKeyHeader(name="X-App-Access-Token", auto_error=False)

# Errors that mean "this token is not acceptable" rather than "the server is
# broken": python-jose raises JWTError subclasses for bad signatures/claims,
# and pydantic's ValidationError (raised by TokenPayload(**payload)) is a
# ValueError subclass. Anything else (e.g. DB failures) is allowed to surface.
_TOKEN_ERRORS = (JWTError, ValueError, TypeError)


def get_db() -> Generator[Session, None, None]:
    """Yield a database session and always close it afterwards.

    The session is created *before* entering ``try`` so that, if
    ``SessionLocal()`` itself fails, the original error propagates instead of
    being masked by an ``UnboundLocalError`` from the ``finally`` clause.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def _decode_token(token: str) -> Tuple[Dict[str, Any], TokenPayload]:
    """Verify and decode a JWT signed with ``settings.SECRET_KEY``.

    Returns the raw claims dict together with the parsed ``TokenPayload``.
    Raises one of ``_TOKEN_ERRORS`` when the token cannot be accepted
    (python-jose's ``jwt.decode`` also validates ``exp`` by default).
    """
    payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    return payload, TokenPayload(**payload)


def _user_id_from_sub(sub: Optional[str]) -> Optional[int]:
    """Return the numeric user id carried in a token subject, or None.

    User tokens use a plain decimal id as subject; app tokens use
    ``"app:{id}"`` and therefore yield None here. ``isdecimal`` (not
    ``isdigit``) is used so that every accepted string is parseable by ``int``.
    """
    if sub and sub.isdecimal():
        return int(sub)
    return None


def _maybe_attach_renewed_token(
    payload: Dict[str, Any], subject: str, response: Response
) -> None:
    """Sliding expiration: re-issue the token once it is past half its lifetime.

    The lifetime is ``ACCESS_TOKEN_EXPIRE_MINUTES_LONG`` for tokens whose
    claims carry ``long_term=True`` and ``ACCESS_TOKEN_EXPIRE_MINUTES``
    otherwise. When less than half of it remains, a new token for ``subject``
    is placed in the ``X-New-Token`` response header.
    """
    exp = payload.get("exp")
    if not exp:
        return

    is_long_term = payload.get("long_term", False)
    lifetime_minutes = (
        settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
        if is_long_term
        else settings.ACCESS_TOKEN_EXPIRE_MINUTES
    )
    remaining_seconds = exp - datetime.now(timezone.utc).timestamp()
    if remaining_seconds >= lifetime_minutes * 60 / 2:
        return

    # For regular tokens expires_delta stays None; presumably create_access_token
    # then applies its default lifetime — confirm in app.core.security.
    expires_delta = (
        timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG) if is_long_term else None
    )
    response.headers["X-New-Token"] = security.create_access_token(
        subject=subject,
        expires_delta=expires_delta,
        is_long_term=is_long_term,
    )


def get_current_user(
    response: Response,
    db: Session = Depends(get_db),
    token: Optional[str] = Depends(reusable_oauth2),
) -> User:
    """Resolve the authenticated user from a bearer JWT.

    On success, may attach a renewed token in the ``X-New-Token`` header
    (sliding expiration). Renewal only happens after the token is known to
    belong to an existing user.

    Raises:
        HTTPException 403: token missing, invalid, or not a user token.
        HTTPException 404: the token's user does not exist.
    """
    if not token:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authenticated",
        )
    try:
        payload, token_data = _decode_token(token)
    except _TOKEN_ERRORS:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        ) from None

    # Ensure it's a user token (numeric ID), not e.g. an "app:{id}" token.
    user_id = _user_id_from_sub(token_data.sub)
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid token type",
        )

    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    _maybe_attach_renewed_token(payload, token_data.sub, response)
    return user


def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the current user, rejecting users whose status is not "ACTIVE" (400)."""
    if current_user.status != "ACTIVE":
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user


def get_current_user_optional(
    response: Response,
    db: Session = Depends(get_db),
    token: Optional[str] = Depends(reusable_oauth2),
) -> Optional[User]:
    """
    Returns the user if the token is valid, otherwise None. Does NOT raise 403.

    A renewed token (``X-New-Token``) is only attached when a user is actually
    resolved, so app tokens or tokens of deleted users are never refreshed.
    """
    if not token:
        return None
    try:
        payload, token_data = _decode_token(token)
    except _TOKEN_ERRORS:
        return None

    user_id = _user_id_from_sub(token_data.sub)
    if user_id is None:
        return None

    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        return None

    _maybe_attach_renewed_token(payload, token_data.sub, response)
    return user


def get_current_active_user_optional(
    current_user: Optional[User] = Depends(get_current_user_optional),
) -> Optional[User]:
    """Return the current user only if one is resolved and its status is "ACTIVE"."""
    if current_user and current_user.status == "ACTIVE":
        return current_user
    return None


def get_current_app(
    db: Session = Depends(get_db),
    token: Optional[str] = Depends(reusable_oauth2),
    access_token: Optional[str] = Depends(token_header_scheme),
) -> Application:
    """
    Get application from token (Machine-to-Machine auth).

    Supports:
    1. Permanent Access Token (Header: X-App-Access-Token) — checked first.
    2. JWT Bearer Token (Subject: "app:{id}")

    Raises:
        HTTPException 403: missing/invalid credentials or a non-app token.
        HTTPException 404: the app referenced by a valid JWT does not exist.
    """
    # 1. Permanent access token takes precedence when present.
    if access_token:
        app = db.query(Application).filter(Application.access_token == access_token).first()
        if not app:
            raise HTTPException(status_code=403, detail="Invalid access token")
        return app

    # 2. Fall back to a JWT bearer token.
    if not token:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authenticated",
        )
    try:
        _, token_data = _decode_token(token)
    except _TOKEN_ERRORS:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        ) from None

    sub = token_data.sub
    if not sub or not sub.startswith("app:"):
        raise HTTPException(status_code=403, detail="Not an app token")

    try:
        app_id = int(sub.split(":")[1])
    except (ValueError, IndexError):
        raise HTTPException(status_code=403, detail="Invalid app token format") from None

    app = db.query(Application).filter(Application.id == app_id).first()
    if not app:
        raise HTTPException(status_code=404, detail="App not found")
    return app


# Union type: the caller may be either an end user or an application.
AuthSubject = Union[User, Application]


def get_current_user_or_app(
    # --- User authentication (JWT bearer) ---
    token: Optional[str] = Depends(reusable_oauth2),
    # --- Application authentication (signed headers) ---
    x_app_id: Optional[str] = Header(None, alias="X-App-Id"),
    x_timestamp: Optional[str] = Header(None, alias="X-Timestamp"),
    x_sign: Optional[str] = Header(None, alias="X-Sign"),
    # --- Database session ---
    db: Session = Depends(get_db),
) -> AuthSubject:
    """Authenticate the caller as an active user (JWT) or an app (signature).

    The JWT path is tried first. An invalid token, a non-user token, or a
    missing/inactive user falls through to signature authentication. Database
    errors are no longer swallowed: they propagate instead of being reported
    as an authentication failure.

    Raises:
        HTTPException 401: neither authentication method succeeded.
    """
    # 1. Try user authentication (JWT).
    if token:
        try:
            _, token_data = _decode_token(token)
        except _TOKEN_ERRORS:
            token_data = None  # Invalid token: continue with app authentication.
        if token_data is not None:
            user_id = _user_id_from_sub(token_data.sub)
            if user_id is not None:
                user = db.query(User).filter(User.id == user_id).first()
                if user and user.status == "ACTIVE":
                    return user

    # 2. Try application authentication (signature).
    if x_app_id and x_timestamp and x_sign:
        app = db.query(Application).filter(Application.app_id == x_app_id).first()
        if app:
            # NOTE(review): x_timestamp is only fed into the signature here; no
            # freshness window is enforced in this function, so a captured signed
            # request could be replayed unless SignatureService.verify_signature
            # checks the timestamp itself — confirm. Presumably verify_signature
            # also excludes the "sign" key when recomputing — verify.
            params = {
                "app_id": x_app_id,
                "timestamp": x_timestamp,
                "sign": x_sign,
            }
            if SignatureService.verify_signature(app.app_secret, params, x_sign):
                return app

    # 3. Neither method succeeded.
    raise HTTPException(
        status_code=401,
        detail="Authentication failed: Invalid Token or Signature",
    )