| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends
- from app.core.websocket_manager import manager
- from app.core import security
- from app.core.config import settings
- from jose import jwt, JWTError
- from app.api.v1 import deps
- from sqlalchemy.orm import Session
- from app.core.database import get_db, SessionLocal
- from app.models.user import User
# Router that collects this module's WebSocket route(s). The endpoint docstring
# gives the full URL as /api/v1/ws/messages, so this is presumably included
# under a "/ws" prefix elsewhere — confirm against the router registration.
router = APIRouter()
async def get_user_from_token(token: str, db: Session):
    """Resolve the user identified by a JWT access token.

    Decodes ``token`` with ``settings.SECRET_KEY`` and ``settings.ALGORITHM``
    and looks up the ``User`` whose ``id`` equals the integer value of the
    token's ``sub`` claim.

    Args:
        token: Encoded JWT as supplied by the client.
        db: SQLAlchemy session used for the user lookup.

    Returns:
        The matching ``User``, or ``None`` when the token cannot be decoded or
        verified, has no ``sub`` claim, has a ``sub`` that cannot be converted
        to an integer id, or names a user that does not exist.
    """
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None

    user_id = payload.get("sub")
    if user_id is None:
        return None

    try:
        # A correctly signed token may still carry a non-numeric subject; treat
        # that as an authentication failure instead of letting the error escape
        # to the WebSocket handler.
        user_pk = int(user_id)
    except (TypeError, ValueError):
        return None

    # NOTE(review): this is a synchronous ORM query inside a coroutine, so it
    # blocks the event loop while it runs.
    return db.query(User).filter(User.id == user_pk).first()
@router.websocket("/messages")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str = Query(...)
):
    """
    Platform-wide WebSocket connection endpoint.
    Connection URL: ws://host/api/v1/ws/messages?token=YOUR_JWT_TOKEN

    The ``token`` query parameter is resolved to a user first. If it does not
    resolve, the socket is closed with code 4001 and the handler returns.
    Otherwise the socket is registered with ``manager`` under the user's id.
    The handler then loops, replying "pong" to each "ping" text message and
    ignoring any other text. The socket is always unregistered when the loop
    ends, however it ends.
    """
    db = SessionLocal()
    try:
        user = await get_user_from_token(token, db)
        # Read the id while the session is still open so nothing below touches
        # a detached ORM instance.
        user_id = user.id if user else None
    finally:
        db.close()

    if user_id is None:
        # NOTE(review): this closes before accept(); under ASGI that rejects the
        # handshake (typically as HTTP 403), so clients may never observe close
        # code 4001 — confirm what clients rely on.
        await websocket.close(code=4001, reason="Authentication failed")
        return

    await manager.connect(websocket, user_id)
    try:
        while True:
            # Receive client heartbeat
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in ``finally``.
        pass
    finally:
        # Unregister on every exit path, not just WebSocketDisconnect. Any other
        # exception from receive_text/send_text still propagates, but the
        # manager no longer keeps a reference to the dead socket.
        manager.disconnect(websocket, user_id)
|