from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends
from app.core.websocket_manager import manager
from app.core import security
from app.core.config import settings
from jose import jwt, JWTError
from app.api.v1 import deps
from sqlalchemy.orm import Session
from app.core.database import get_db, SessionLocal
from app.models.user import User

router = APIRouter()


async def get_user_from_token(token: str, db: Session):
    """Resolve the user identified by a JWT access token.

    Args:
        token: Encoded JWT; its ``sub`` claim must hold the user's integer id.
        db: SQLAlchemy session used for the lookup.

    Returns:
        The matching ``User`` row, or ``None`` when the token fails to decode,
        has no ``sub`` claim, carries a ``sub`` that is not an integer, or
        names no existing user.
    """
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None

    subject = payload.get("sub")
    if subject is None:
        return None
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        # A correctly signed token with a malformed subject must be treated as
        # an authentication failure, not crash the handler.
        return None

    # NOTE(review): this is a synchronous ORM query inside a coroutine, so it
    # blocks the event loop while it runs.
    return db.query(User).filter(User.id == user_id).first()


@router.websocket("/messages")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str = Query(...),
):
    """Platform-wide WebSocket connection endpoint.

    Connection URL: ws://host/api/v1/ws/messages?token=YOUR_JWT_TOKEN

    Authenticates the ``token`` query parameter and registers the socket with
    the connection manager under the user's id. It then answers ``"ping"``
    heartbeat frames with ``"pong"`` until the client disconnects. An invalid
    token closes the socket with code 4001.
    """
    db = SessionLocal()
    try:
        user = await get_user_from_token(token, db)
        # Read the id while the session is still open.
        user_id = user.id if user else None
    finally:
        db.close()

    if user_id is None:
        # Per the ASGI spec, a close sent before the handshake is accepted is
        # turned into an HTTP 403 rejection, which drops code 4001 and the
        # reason. Accept first so the client actually receives them.
        await websocket.accept()
        await websocket.close(code=4001, reason="Authentication failed")
        return

    await manager.connect(websocket, user_id)
    try:
        while True:
            # Client heartbeat.
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in ``finally``.
        pass
    finally:
        # Always unregister, even if the loop dies with an unexpected error,
        # so the manager never keeps a dead socket.
        manager.disconnect(websocket, user_id)