deps.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. from typing import Generator, Optional
  2. from fastapi import Depends, HTTPException, status, Response
  3. from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
  4. from jose import jwt, JWTError
  5. from sqlalchemy.orm import Session
  6. from datetime import datetime, timedelta
  7. from app.core import security
  8. from app.core.config import settings
  9. from app.core.database import SessionLocal
  10. from app.models.user import User
  11. from app.models.application import Application
  12. from app.schemas.token import TokenPayload
  13. reusable_oauth2 = OAuth2PasswordBearer(
  14. tokenUrl=f"{settings.API_V1_STR}/auth/login",
  15. auto_error=False # Allow optional token
  16. )
  17. token_header_scheme = APIKeyHeader(name="X-App-Access-Token", auto_error=False)
  18. def get_db() -> Generator:
  19. try:
  20. db = SessionLocal()
  21. yield db
  22. finally:
  23. db.close()
  24. def get_current_user(
  25. response: Response,
  26. db: Session = Depends(get_db),
  27. token: str = Depends(reusable_oauth2)
  28. ) -> User:
  29. if not token:
  30. raise HTTPException(
  31. status_code=status.HTTP_403_FORBIDDEN,
  32. detail="Not authenticated",
  33. )
  34. try:
  35. payload = jwt.decode(
  36. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  37. )
  38. token_data = TokenPayload(**payload)
  39. # Sliding Expiration Check
  40. # If token is valid but expires soon (e.g. less than half of total lifetime), renew it
  41. exp = payload.get("exp")
  42. is_long_term = payload.get("long_term", False)
  43. if exp:
  44. now = datetime.now().timestamp()
  45. remaining_seconds = exp - now
  46. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 / 2
  47. if is_long_term:
  48. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG * 60 / 2
  49. # If remaining time is less than half of the configured expiration time
  50. if remaining_seconds < threshold:
  51. expires_delta = None
  52. if is_long_term:
  53. expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
  54. # Issue new token
  55. new_token = security.create_access_token(
  56. subject=token_data.sub,
  57. expires_delta=expires_delta,
  58. is_long_term=is_long_term
  59. )
  60. # Set in response header
  61. response.headers["X-New-Token"] = new_token
  62. except (JWTError, Exception):
  63. raise HTTPException(
  64. status_code=status.HTTP_403_FORBIDDEN,
  65. detail="Could not validate credentials",
  66. )
  67. # Ensure it's a user token (numeric ID)
  68. if not token_data.sub or not token_data.sub.isdigit():
  69. raise HTTPException(
  70. status_code=status.HTTP_403_FORBIDDEN,
  71. detail="Invalid token type",
  72. )
  73. user = db.query(User).filter(User.id == int(token_data.sub)).first()
  74. if not user:
  75. raise HTTPException(status_code=404, detail="User not found")
  76. return user
  77. def get_current_active_user(
  78. current_user: User = Depends(get_current_user),
  79. ) -> User:
  80. if current_user.status != "ACTIVE":
  81. raise HTTPException(status_code=400, detail="Inactive user")
  82. return current_user
  83. def get_current_user_optional(
  84. response: Response,
  85. db: Session = Depends(get_db),
  86. token: Optional[str] = Depends(reusable_oauth2)
  87. ) -> Optional[User]:
  88. """
  89. Returns the user if the token is valid, otherwise None.
  90. Does NOT raise 403.
  91. """
  92. if not token:
  93. return None
  94. try:
  95. payload = jwt.decode(
  96. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  97. )
  98. token_data = TokenPayload(**payload)
  99. # Sliding Expiration Check for Optional Auth
  100. exp = payload.get("exp")
  101. is_long_term = payload.get("long_term", False)
  102. if exp:
  103. now = datetime.now().timestamp()
  104. remaining_seconds = exp - now
  105. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 / 2
  106. if is_long_term:
  107. threshold = settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG * 60 / 2
  108. if remaining_seconds < threshold:
  109. expires_delta = None
  110. if is_long_term:
  111. expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES_LONG)
  112. new_token = security.create_access_token(
  113. subject=token_data.sub,
  114. expires_delta=expires_delta,
  115. is_long_term=is_long_term
  116. )
  117. response.headers["X-New-Token"] = new_token
  118. except (JWTError, Exception):
  119. return None
  120. if not token_data.sub or not token_data.sub.isdigit():
  121. return None
  122. user = db.query(User).filter(User.id == int(token_data.sub)).first()
  123. return user
  124. def get_current_active_user_optional(
  125. current_user: Optional[User] = Depends(get_current_user_optional),
  126. ) -> Optional[User]:
  127. if current_user and current_user.status == "ACTIVE":
  128. return current_user
  129. return None
  130. def get_current_app(
  131. db: Session = Depends(get_db),
  132. token: Optional[str] = Depends(reusable_oauth2),
  133. access_token: Optional[str] = Depends(token_header_scheme)
  134. ) -> Application:
  135. """
  136. Get application from token (Machine-to-Machine auth).
  137. Supports:
  138. 1. JWT Bearer Token (Subject: "app:{id}")
  139. 2. Permanent Access Token (Header: X-App-Access-Token)
  140. """
  141. # 1. Try Access Token first if present
  142. if access_token:
  143. # Use simple auth with permanent token
  144. app = db.query(Application).filter(Application.access_token == access_token).first()
  145. if not app:
  146. raise HTTPException(status_code=403, detail="Invalid access token")
  147. return app
  148. # 2. Try JWT Bearer Token
  149. if not token:
  150. raise HTTPException(
  151. status_code=status.HTTP_403_FORBIDDEN,
  152. detail="Not authenticated",
  153. )
  154. try:
  155. payload = jwt.decode(
  156. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  157. )
  158. token_data = TokenPayload(**payload)
  159. except (JWTError, Exception):
  160. raise HTTPException(
  161. status_code=status.HTTP_403_FORBIDDEN,
  162. detail="Could not validate credentials",
  163. )
  164. sub = token_data.sub
  165. if not sub or not sub.startswith("app:"):
  166. raise HTTPException(status_code=403, detail="Not an app token")
  167. try:
  168. app_id = int(sub.split(":")[1])
  169. except (ValueError, IndexError):
  170. raise HTTPException(status_code=403, detail="Invalid app token format")
  171. app = db.query(Application).filter(Application.id == app_id).first()
  172. if not app:
  173. raise HTTPException(status_code=404, detail="App not found")
  174. return app