ws.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends
  2. from app.core.websocket_manager import manager
  3. from app.core import security
  4. from app.core.config import settings
  5. from jose import jwt, JWTError
  6. from app.api.v1 import deps
  7. from sqlalchemy.orm import Session
  8. from app.core.database import get_db, SessionLocal
  9. from app.models.user import User
  10. router = APIRouter()
  11. async def get_user_from_token(token: str, db: Session):
  12. try:
  13. payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
  14. user_id = payload.get("sub")
  15. if user_id is None:
  16. return None
  17. return db.query(User).filter(User.id == int(user_id)).first()
  18. except JWTError:
  19. return None
  20. @router.websocket("/messages")
  21. async def websocket_endpoint(
  22. websocket: WebSocket,
  23. token: str = Query(...)
  24. ):
  25. """
  26. 全平台通用的 WebSocket 连接端点
  27. 连接 URL: ws://host/api/v1/ws/messages?token=YOUR_JWT_TOKEN
  28. """
  29. db = SessionLocal()
  30. try:
  31. user = await get_user_from_token(token, db)
  32. finally:
  33. db.close()
  34. if not user:
  35. await websocket.close(code=4001, reason="Authentication failed")
  36. return
  37. await manager.connect(websocket, user.id)
  38. try:
  39. while True:
  40. # 接收客户端心跳
  41. data = await websocket.receive_text()
  42. if data == "ping":
  43. await websocket.send_text("pong")
  44. except WebSocketDisconnect:
  45. manager.disconnect(websocket, user.id)