deps.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from typing import Generator, Optional
  2. from fastapi import Depends, HTTPException, status, Response
  3. from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
  4. from jose import jwt, JWTError
  5. from sqlalchemy.orm import Session
  6. from datetime import datetime
  7. from app.core import security
  8. from app.core.config import settings
  9. from app.core.database import SessionLocal
  10. from app.models.user import User
  11. from app.models.application import Application
  12. from app.schemas.token import TokenPayload
  13. reusable_oauth2 = OAuth2PasswordBearer(
  14. tokenUrl=f"{settings.API_V1_STR}/auth/login",
  15. auto_error=False # Allow optional token
  16. )
  17. token_header_scheme = APIKeyHeader(name="X-App-Access-Token", auto_error=False)
  18. def get_db() -> Generator:
  19. try:
  20. db = SessionLocal()
  21. yield db
  22. finally:
  23. db.close()
  24. def get_current_user(
  25. response: Response,
  26. db: Session = Depends(get_db),
  27. token: str = Depends(reusable_oauth2)
  28. ) -> User:
  29. if not token:
  30. raise HTTPException(
  31. status_code=status.HTTP_403_FORBIDDEN,
  32. detail="Not authenticated",
  33. )
  34. try:
  35. payload = jwt.decode(
  36. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  37. )
  38. token_data = TokenPayload(**payload)
  39. # Sliding Expiration Check
  40. # If token is valid but expires soon (e.g. less than half of total lifetime), renew it
  41. exp = payload.get("exp")
  42. if exp:
  43. now = datetime.now().timestamp()
  44. remaining_seconds = exp - now
  45. # If remaining time is less than half of the configured expiration time
  46. if remaining_seconds < (settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 / 2):
  47. # Issue new token
  48. new_token = security.create_access_token(subject=token_data.sub)
  49. # Set in response header
  50. response.headers["X-New-Token"] = new_token
  51. except (JWTError, Exception):
  52. raise HTTPException(
  53. status_code=status.HTTP_403_FORBIDDEN,
  54. detail="Could not validate credentials",
  55. )
  56. # Ensure it's a user token (numeric ID)
  57. if not token_data.sub or not token_data.sub.isdigit():
  58. raise HTTPException(
  59. status_code=status.HTTP_403_FORBIDDEN,
  60. detail="Invalid token type",
  61. )
  62. user = db.query(User).filter(User.id == int(token_data.sub)).first()
  63. if not user:
  64. raise HTTPException(status_code=404, detail="User not found")
  65. return user
  66. def get_current_active_user(
  67. current_user: User = Depends(get_current_user),
  68. ) -> User:
  69. if current_user.status != "ACTIVE":
  70. raise HTTPException(status_code=400, detail="Inactive user")
  71. return current_user
  72. def get_current_user_optional(
  73. response: Response,
  74. db: Session = Depends(get_db),
  75. token: Optional[str] = Depends(reusable_oauth2)
  76. ) -> Optional[User]:
  77. """
  78. Returns the user if the token is valid, otherwise None.
  79. Does NOT raise 403.
  80. """
  81. if not token:
  82. return None
  83. try:
  84. payload = jwt.decode(
  85. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  86. )
  87. token_data = TokenPayload(**payload)
  88. # Sliding Expiration Check for Optional Auth
  89. exp = payload.get("exp")
  90. if exp:
  91. now = datetime.now().timestamp()
  92. remaining_seconds = exp - now
  93. if remaining_seconds < (settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 / 2):
  94. new_token = security.create_access_token(subject=token_data.sub)
  95. response.headers["X-New-Token"] = new_token
  96. except (JWTError, Exception):
  97. return None
  98. if not token_data.sub or not token_data.sub.isdigit():
  99. return None
  100. user = db.query(User).filter(User.id == int(token_data.sub)).first()
  101. return user
  102. def get_current_active_user_optional(
  103. current_user: Optional[User] = Depends(get_current_user_optional),
  104. ) -> Optional[User]:
  105. if current_user and current_user.status == "ACTIVE":
  106. return current_user
  107. return None
  108. def get_current_app(
  109. db: Session = Depends(get_db),
  110. token: Optional[str] = Depends(reusable_oauth2),
  111. access_token: Optional[str] = Depends(token_header_scheme)
  112. ) -> Application:
  113. """
  114. Get application from token (Machine-to-Machine auth).
  115. Supports:
  116. 1. JWT Bearer Token (Subject: "app:{id}")
  117. 2. Permanent Access Token (Header: X-App-Access-Token)
  118. """
  119. # 1. Try Access Token first if present
  120. if access_token:
  121. # Use simple auth with permanent token
  122. app = db.query(Application).filter(Application.access_token == access_token).first()
  123. if not app:
  124. raise HTTPException(status_code=403, detail="Invalid access token")
  125. return app
  126. # 2. Try JWT Bearer Token
  127. if not token:
  128. raise HTTPException(
  129. status_code=status.HTTP_403_FORBIDDEN,
  130. detail="Not authenticated",
  131. )
  132. try:
  133. payload = jwt.decode(
  134. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  135. )
  136. token_data = TokenPayload(**payload)
  137. except (JWTError, Exception):
  138. raise HTTPException(
  139. status_code=status.HTTP_403_FORBIDDEN,
  140. detail="Could not validate credentials",
  141. )
  142. sub = token_data.sub
  143. if not sub or not sub.startswith("app:"):
  144. raise HTTPException(status_code=403, detail="Not an app token")
  145. try:
  146. app_id = int(sub.split(":")[1])
  147. except (ValueError, IndexError):
  148. raise HTTPException(status_code=403, detail="Invalid app token format")
  149. app = db.query(Application).filter(Application.id == app_id).first()
  150. if not app:
  151. raise HTTPException(status_code=404, detail="App not found")
  152. return app