| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371 |
- from typing import Generator, Optional, Union
- from fastapi import Depends, HTTPException, status, Response, Header, Request
- from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
- from jose import jwt, JWTError
- from sqlalchemy.orm import Session
- from datetime import datetime, timedelta
- from app.core import security
- from app.core.config import settings
- from app.core.database import SessionLocal
- from app.core.utils import get_client_ip
- from app.models.user import User
- from app.models.application import Application
- from app.models.admin_api_key import AdminApiKey
- from app.schemas.token import TokenPayload
- from app.services.signature_service import SignatureService
# Bearer-token scheme for end-user JWTs. With auto_error=False a missing token
# is passed to dependencies as None, so each dependency decides whether that
# is an error (get_current_user) or simply "anonymous" (the *_optional ones).
reusable_oauth2 = OAuth2PasswordBearer(
    tokenUrl=f"{settings.API_V1_STR}/auth/login",
    auto_error=False,  # allow an optional token
)

# Header carrying an application's permanent access token.
token_header_scheme = APIKeyHeader(
    name="X-App-Access-Token",
    auto_error=False,
)

# Header carrying an admin API key (consumed by get_admin_api_key).
admin_key_header = APIKeyHeader(
    name="X-Admin-Api-Key",
    auto_error=False,
)
def get_db() -> Generator:
    """Yield a database session for the duration of a request, then close it.

    The session is created *before* entering the ``try`` block: if
    ``SessionLocal()`` itself fails, that original error propagates instead of
    being masked by an ``UnboundLocalError`` raised from ``db.close()`` in the
    ``finally`` clause.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
def get_current_user(
    response: Response,
    db: Session = Depends(get_db),
    token: str = Depends(reusable_oauth2),
) -> User:
    """Return the user identified by the bearer JWT.

    Sliding expiration: once the user has been resolved, if less than half of
    the token's configured lifetime remains, a renewed token is placed in the
    ``X-New-Token`` response header. Tokens with a truthy ``long_term`` claim
    use ``ACCESS_TOKEN_EXPIRE_MINUTES_LONG`` both for the threshold and for the
    renewed token's lifetime; other tokens use ``ACCESS_TOKEN_EXPIRE_MINUTES``.

    Raises:
        HTTPException(403): no token; the token fails to decode or does not fit
            ``TokenPayload``; or its subject is not a numeric user id.
        HTTPException(404): no user with that id exists.
    """
    if not token:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authenticated",
        )
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = TokenPayload(**payload)
    except (JWTError, ValueError, TypeError) as exc:
        # JWTError: bad signature / expired / invalid claims. ValueError and
        # TypeError: payload does not fit TokenPayload (presumably a pydantic
        # model, whose ValidationError subclasses ValueError). Other errors are
        # server faults and are deliberately not disguised as a 403.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        ) from exc

    # Ensure it's a user token (numeric ID); app tokens use "app:{id}".
    if not token_data.sub or not token_data.sub.isdigit():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid token type",
        )
    user = db.query(User).filter(User.id == int(token_data.sub)).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Sliding expiration. Done only after the subject is confirmed to be an
    # existing user, so no token is minted for app tokens or unknown users.
    exp = payload.get("exp")
    if exp:
        is_long_term = payload.get("long_term", False)
        lifetime_minutes = (
            settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
            if is_long_term
            else settings.ACCESS_TOKEN_EXPIRE_MINUTES
        )
        # datetime.now() is naive local time; .timestamp() converts it to POSIX
        # seconds, the same unit as the JWT "exp" claim.
        remaining_seconds = float(exp) - datetime.now().timestamp()
        if remaining_seconds < lifetime_minutes * 60 / 2:
            # None presumably lets create_access_token apply its default
            # (short) lifetime.
            expires_delta = (
                timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
                if is_long_term
                else None
            )
            response.headers["X-New-Token"] = security.create_access_token(
                subject=token_data.sub,
                expires_delta=expires_delta,
                is_long_term=is_long_term,
            )
    return user
def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the authenticated user only if their status is "ACTIVE".

    Raises:
        HTTPException(400): the user's status is anything other than "ACTIVE".
    """
    if current_user.status == "ACTIVE":
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
def get_current_user_optional(
    response: Response,
    db: Session = Depends(get_db),
    token: Optional[str] = Depends(reusable_oauth2),
) -> Optional[User]:
    """Return the user identified by the bearer JWT, or None.

    Never raises for authentication problems: a missing token, a token that
    fails to decode, a non-numeric subject, or an unknown user id all yield
    ``None``.

    Sliding expiration: if the token belongs to an existing user and less than
    half of its configured lifetime remains, a renewed token is placed in the
    ``X-New-Token`` response header (long-term tokens, flagged by a truthy
    ``long_term`` claim, use ``ACCESS_TOKEN_EXPIRE_MINUTES_LONG``).
    """
    if not token:
        return None
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = TokenPayload(**payload)
    except (JWTError, ValueError, TypeError):
        # Undecodable / invalid token: treat as anonymous.
        return None

    # Only numeric subjects are user tokens; app tokens use "app:{id}".
    if not token_data.sub or not token_data.sub.isdigit():
        return None
    user = db.query(User).filter(User.id == int(token_data.sub)).first()
    if user is None:
        return None

    # Renew only after the user is confirmed. This dependency returns None
    # instead of raising, so a header set earlier would actually reach the
    # client -- app tokens or tokens of deleted users would get renewed.
    exp = payload.get("exp")
    if exp:
        is_long_term = payload.get("long_term", False)
        lifetime_minutes = (
            settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
            if is_long_term
            else settings.ACCESS_TOKEN_EXPIRE_MINUTES
        )
        # Naive local now -> POSIX seconds, same unit as the "exp" claim.
        remaining_seconds = float(exp) - datetime.now().timestamp()
        if remaining_seconds < lifetime_minutes * 60 / 2:
            expires_delta = (
                timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
                if is_long_term
                else None
            )
            response.headers["X-New-Token"] = security.create_access_token(
                subject=token_data.sub,
                expires_delta=expires_delta,
                is_long_term=is_long_term,
            )
    return user
def get_current_active_user_optional(
    current_user: Optional[User] = Depends(get_current_user_optional),
) -> Optional[User]:
    """Return the optional user only when present and "ACTIVE"; otherwise None.

    NOTE(review): get_current_user_optional may already have set X-New-Token
    for a user whose status is not "ACTIVE", even though this returns None.
    """
    if not current_user or current_user.status != "ACTIVE":
        return None
    return current_user
def get_current_app(
    db: Session = Depends(get_db),
    token: Optional[str] = Depends(reusable_oauth2),
    access_token: Optional[str] = Depends(token_header_scheme),
) -> Application:
    """Resolve the calling application (machine-to-machine auth).

    Credentials are checked in this order:
    1. Permanent access token in the ``X-App-Access-Token`` header.
    2. JWT bearer token whose subject is ``"app:{id}"``.

    Raises:
        HTTPException(403): invalid access token; missing or undecodable JWT;
            a JWT that is not an app token; or a malformed app id.
        HTTPException(404): the app id in the JWT does not exist.
    """
    # 1. A permanent access token takes precedence when present.
    if access_token:
        # NOTE(review): matched by equality against the stored column, so the
        # tokens appear to be stored in plaintext -- consider hashing them.
        app = (
            db.query(Application)
            .filter(Application.access_token == access_token)
            .first()
        )
        if not app:
            raise HTTPException(status_code=403, detail="Invalid access token")
        return app

    # 2. Fall back to a JWT bearer token.
    if not token:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authenticated",
        )
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = TokenPayload(**payload)
    except (JWTError, ValueError, TypeError) as exc:
        # Only decoding / payload-validation failures are credential errors.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        ) from exc

    sub = token_data.sub
    if not sub or not sub.startswith("app:"):
        raise HTTPException(status_code=403, detail="Not an app token")

    try:
        app_id = int(sub.split(":")[1])
    except (ValueError, IndexError) as exc:
        raise HTTPException(
            status_code=403, detail="Invalid app token format"
        ) from exc

    app = db.query(Application).filter(Application.id == app_id).first()
    if not app:
        raise HTTPException(status_code=404, detail="App not found")
    return app
# Union type: the caller may be either an end user or an application.
AuthSubject = Union[User, Application]


def get_current_user_or_app(
    # --- user authentication ---
    token: Optional[str] = Depends(reusable_oauth2),
    # --- application authentication (headers) ---
    x_app_id: Optional[str] = Header(None, alias="X-App-Id"),
    x_timestamp: Optional[str] = Header(None, alias="X-Timestamp"),
    x_sign: Optional[str] = Header(None, alias="X-Sign"),
    # --- database session ---
    db: Session = Depends(get_db),
) -> AuthSubject:
    """Authenticate the caller as an active user (JWT) or an application (signature).

    1. If a bearer JWT is present, decodes, has a numeric subject and belongs
       to a user whose status is "ACTIVE", that user is returned.
    2. Otherwise, if X-App-Id, X-Timestamp and X-Sign are all present, the
       application with that ``app_id`` is returned when
       ``SignatureService.verify_signature`` accepts the signature.

    Only token decoding errors mean "invalid token, try the next method";
    database or other unexpected errors propagate instead of being reported
    as an authentication failure.

    Raises:
        HTTPException(401): neither method succeeded.
    """
    # 1. User authentication (JWT)
    if token:
        try:
            payload = jwt.decode(
                token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
            )
            token_data = TokenPayload(**payload)
        except (JWTError, ValueError, TypeError):
            token_data = None  # invalid token: fall through to app authentication
        if token_data is not None and token_data.sub and token_data.sub.isdigit():
            user = db.query(User).filter(User.id == int(token_data.sub)).first()
            if user and user.status == "ACTIVE":
                return user

    # 2. Application authentication (signature)
    if x_app_id and x_timestamp and x_sign:
        app = db.query(Application).filter(Application.app_id == x_app_id).first()
        if app:
            params = {
                "app_id": x_app_id,
                "timestamp": x_timestamp,
                "sign": x_sign,
            }
            # NOTE(review): this function enforces no freshness window on
            # x_timestamp, and the signed params cover only app_id/timestamp
            # (not the request path or body). Unless verify_signature checks
            # the timestamp, captured headers could be replayed -- confirm.
            if SignatureService.verify_signature(app.app_secret, params, x_sign):
                return app

    # 3. Neither method succeeded
    raise HTTPException(
        status_code=401,
        detail="Authentication failed: Invalid Token or Signature",
    )
def get_current_user_or_app_by_api_key(
    response: Response,
    db: Session = Depends(get_db),
    token: Optional[str] = Depends(reusable_oauth2),
    access_token: Optional[str] = Depends(token_header_scheme),
) -> AuthSubject:
    """Authenticate the caller as an active user (JWT) or an application (API key).

    User authentication is tried first. If it does not yield a user whose
    status is "ACTIVE", the ``X-App-Access-Token`` header is matched against
    ``Application.access_token``.

    For an authenticated user, sliding expiration applies: when less than half
    of the configured lifetime remains, a renewed token is placed in the
    ``X-New-Token`` response header.

    Only token decoding errors fall through to API-key authentication; database
    or token-issuing errors propagate instead of being reported as a 401.

    Raises:
        HTTPException(401): neither method succeeded.
    """
    # 1. User authentication (JWT)
    if token:
        try:
            payload = jwt.decode(
                token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
            )
            token_data = TokenPayload(**payload)
        except (JWTError, ValueError, TypeError):
            token_data = None  # invalid token: fall through to API-key auth
        if token_data is not None and token_data.sub and token_data.sub.isdigit():
            user = db.query(User).filter(User.id == int(token_data.sub)).first()
            if user and user.status == "ACTIVE":
                # Sliding expiration (same rules as get_current_user).
                exp = payload.get("exp")
                if exp:
                    is_long_term = payload.get("long_term", False)
                    lifetime_minutes = (
                        settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG
                        if is_long_term
                        else settings.ACCESS_TOKEN_EXPIRE_MINUTES
                    )
                    # Naive local now -> POSIX seconds, same unit as "exp".
                    remaining_seconds = float(exp) - datetime.now().timestamp()
                    if remaining_seconds < lifetime_minutes * 60 / 2:
                        expires_delta = (
                            timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
                            if is_long_term
                            else None
                        )
                        response.headers["X-New-Token"] = security.create_access_token(
                            subject=token_data.sub,
                            expires_delta=expires_delta,
                            is_long_term=is_long_term,
                        )
                return user

    # 2. Application API-key authentication
    if access_token:
        app = (
            db.query(Application)
            .filter(Application.access_token == access_token)
            .first()
        )
        if app:
            return app

    # 3. Neither method succeeded
    raise HTTPException(
        status_code=401,
        detail="Authentication failed: Invalid Token or API Key",
    )
# ---------------------------------------------------------------------------
# Admin API key authentication. Used only by the separate /admin-api/* group
# of endpoints and entirely unrelated to the Application permanent tokens and
# user JWTs above; do not reuse it on the original /users/* or
# /organizations/* routes.
# ---------------------------------------------------------------------------
def get_admin_api_key(
    request: Request,
    db: Session = Depends(get_db),
    raw_key: Optional[str] = Depends(admin_key_header),
) -> AdminApiKey:
    """Resolve the AdminApiKey matching the X-Admin-Api-Key request header.

    Steps:
    - narrow candidates by the key's first 12 characters (``key_prefix``),
      keeping only keys that are active and not revoked;
    - verify the full key against each candidate's ``key_hash`` with
      ``security.verify_password``;
    - on a match, record ``last_used_at`` / ``last_used_ip`` / ``use_count``
      (best effort: a failure there is rolled back and ignored).

    Raises:
        HTTPException(401): the header is missing or empty.
        HTTPException(403): the key is blank, unknown, inactive or revoked.
    """
    if not raw_key:
        raise HTTPException(status_code=401, detail="Missing X-Admin-Api-Key header")

    # Tolerate surrounding whitespace; the prefix length (12) must match the
    # length used when keys are generated.
    key_text = raw_key.strip()
    key_prefix = key_text[:12]
    if not key_prefix:
        raise HTTPException(status_code=403, detail="Invalid admin api key")

    usable_keys = db.query(AdminApiKey).filter(
        AdminApiKey.key_prefix == key_prefix,
        AdminApiKey.is_active.is_(True),
        AdminApiKey.is_revoked.is_(False),
    ).all()
    if not usable_keys:
        raise HTTPException(status_code=403, detail="Invalid or disabled admin api key")

    found: Optional[AdminApiKey] = None
    for candidate in usable_keys:
        try:
            hash_ok = security.verify_password(key_text, candidate.key_hash)
        except Exception:
            # A malformed stored hash must not abort the lookup; try the next.
            continue
        if hash_ok:
            found = candidate
            break

    if not found:
        raise HTTPException(status_code=403, detail="Invalid or disabled admin api key")

    # Record usage; failing here must not reject an otherwise valid key.
    try:
        found.last_used_at = datetime.now()
        found.last_used_ip = get_client_ip(request) if request else None
        found.use_count = (found.use_count or 0) + 1
        db.add(found)
        db.commit()
        db.refresh(found)
    except Exception:
        db.rollback()
    return found
|