瀏覽代碼

修改为redis发送消息和通知

liuq 1 月之前
父節點
當前提交
5d627899ea
共有 3 個文件被更改,包括 225 次插入18 次删除
  1. 3 2
      backend/Dockerfile
  2. 216 15
      backend/app/core/websocket_manager.py
  3. 6 1
      backend/app/main.py

+ 3 - 2
backend/Dockerfile

@@ -27,5 +27,6 @@ RUN pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple \
 
 COPY . .
 
-# Run application
-CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
+# Run application with Gunicorn + Uvicorn Workers (multi-process)
+# This enables cross-process WebSocket message broadcasting via Redis Pub/Sub
+CMD ["gunicorn", "app.main:app", "-k", "uvicorn.workers.UvicornWorker", "--workers", "4", "--bind", "0.0.0.0:8000", "--timeout", "120", "--keep-alive", "5"]

+ 216 - 15
backend/app/core/websocket_manager.py

@@ -2,6 +2,10 @@ from typing import Dict, List
 from fastapi import WebSocket
 import logging
 import json
+import asyncio
+import os
+from app.core.cache import redis_client
+from app.core.config import settings
 
 logger = logging.getLogger(__name__)
 
@@ -9,39 +13,236 @@ class ConnectionManager:
     def __init__(self):
         # 存储格式: {user_id: [WebSocket1, WebSocket2, ...]}
         self.active_connections: Dict[int, List[WebSocket]] = {}
+        # 进程唯一标识(用于避免重复发送)
+        self.instance_id = f"{os.getpid()}_{id(self)}"
+        # Redis 订阅者任务
+        self._subscriber_task = None
+        # Redis Pub/Sub 对象
+        self._pubsub = None
+        # 异步 Redis 客户端(用于订阅)
+        self._async_redis = None
+        
+    async def _start_subscriber(self):
+        """启动 Redis 订阅者,监听消息频道"""
+        try:
+            # 使用 redis.asyncio 创建异步客户端(redis 4.6+ 支持)
+            import redis.asyncio as aioredis
+            
+            # 构建 Redis URL,处理密码为 None 的情况
+            if settings.REDIS_PASSWORD:
+                redis_url = f"redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}"
+            else:
+                redis_url = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}"
+            
+            self._async_redis = aioredis.from_url(
+                redis_url,
+                decode_responses=True
+            )
+            
+            self._pubsub = self._async_redis.pubsub()
+            # 订阅个人消息频道和广播频道
+            await self._pubsub.subscribe("ws:personal", "ws:broadcast")
+            
+            logger.info(f"WebSocket Redis subscriber started (instance: {self.instance_id})")
+            
+            # 持续监听消息
+            async for message in self._pubsub.listen():
+                if message['type'] == 'message':
+                    try:
+                        data = json.loads(message['data'])
+                        channel = message['channel']
+                        
+                        if channel == 'ws:personal':
+                            # 个人消息
+                            user_id = data.get('user_id')
+                            msg_data = data.get('message')
+                            sender_instance = data.get('instance_id')
+                            
+                            # 如果不是自己发送的,且本地有该用户的连接,则发送
+                            if sender_instance != self.instance_id and user_id in self.active_connections:
+                                worker_pid = os.getpid()
+                                local_connections_count = len(self.active_connections.get(user_id, []))
+                                logger.debug(
+                                    f"Received Redis message for user {user_id}. "
+                                    f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
+                                    f"Sender Instance: {sender_instance}, "
+                                    f"Local connections: {local_connections_count}"
+                                )
+                                await self._send_to_local_connections(user_id, msg_data)
+                            elif sender_instance == self.instance_id:
+                                logger.debug(
+                                    f"Ignored own message for user {user_id} "
+                                    f"(Worker PID: {os.getpid()}, Instance ID: {self.instance_id})"
+                                )
+                                
+                        elif channel == 'ws:broadcast':
+                            # 广播消息
+                            msg_data = data.get('message')
+                            sender_instance = data.get('instance_id')
+                            
+                            # 如果不是自己发送的,则广播给所有本地连接
+                            if sender_instance != self.instance_id:
+                                worker_pid = os.getpid()
+                                total_connections = sum(len(conns) for conns in self.active_connections.values())
+                                logger.debug(
+                                    f"Received Redis broadcast message. "
+                                    f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
+                                    f"Sender Instance: {sender_instance}, "
+                                    f"Total local connections: {total_connections}"
+                                )
+                                await self._broadcast_to_local_connections(msg_data)
+                            else:
+                                logger.debug(
+                                    f"Ignored own broadcast message "
+                                    f"(Worker PID: {os.getpid()}, Instance ID: {self.instance_id})"
+                                )
+                                
+                    except Exception as e:
+                        logger.error(f"Error processing Redis message: {e}")
+                        
+        except Exception as e:
+            logger.error(f"Redis subscriber error: {e}")
+            # 如果订阅失败,尝试重新连接
+            await asyncio.sleep(5)
+            if self._pubsub:
+                await self._start_subscriber()
+    
+    async def _send_to_local_connections(self, user_id: int, message: dict):
+        """向本地连接发送消息"""
+        if user_id in self.active_connections:
+            connections = self.active_connections[user_id][:]
+            for connection in connections:
+                try:
+                    await connection.send_json(message)
+                except Exception as e:
+                    logger.error(f"Error sending message to user {user_id}: {e}")
+                    # 连接可能已断开,从列表中移除
+                    if connection in self.active_connections[user_id]:
+                        self.active_connections[user_id].remove(connection)
+    
+    async def _broadcast_to_local_connections(self, message: str):
+        """广播给所有本地连接"""
+        all_connections = []
+        for connection_list in self.active_connections.values():
+            all_connections.extend(connection_list[:])
+        
+        for connection in all_connections:
+            try:
+                await connection.send_text(message)
+            except Exception as e:
+                logger.error(f"Error broadcasting message: {e}")
 
     async def connect(self, websocket: WebSocket, user_id: int):
         await websocket.accept()
         if user_id not in self.active_connections:
             self.active_connections[user_id] = []
         self.active_connections[user_id].append(websocket)
-        logger.info(f"User {user_id} connected via WebSocket. Total connections: {len(self.active_connections[user_id])}")
+        
+        # 增强日志:记录 worker 信息
+        worker_pid = os.getpid()
+        logger.info(
+            f"User {user_id} connected via WebSocket. "
+            f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
+            f"Total connections for this user: {len(self.active_connections[user_id])}, "
+            f"Total users on this worker: {len(self.active_connections)}"
+        )
+        
+        # 首次连接时启动订阅者(如果还没启动)
+        if self._subscriber_task is None or self._subscriber_task.done():
+            self._subscriber_task = asyncio.create_task(self._start_subscriber())
 
     def disconnect(self, websocket: WebSocket, user_id: int):
+        worker_pid = os.getpid()
+        was_connected = user_id in self.active_connections
+        
         if user_id in self.active_connections:
             if websocket in self.active_connections[user_id]:
                 self.active_connections[user_id].remove(websocket)
             if not self.active_connections[user_id]:
                 del self.active_connections[user_id]
-        logger.info(f"User {user_id} disconnected")
+        
+        # 增强日志:记录 worker 信息和断开后的状态
+        remaining_connections = len(self.active_connections.get(user_id, []))
+        logger.info(
+            f"User {user_id} disconnected from WebSocket. "
+            f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
+            f"Remaining connections for this user: {remaining_connections}, "
+            f"Total users on this worker: {len(self.active_connections)}"
+        )
 
     async def broadcast(self, message: str):
-        for connection_list in self.active_connections.values():
-            for connection in connection_list:
-                await connection.send_text(message)
+        """广播消息给所有连接的客户端(跨进程)"""
+        # 先发送给本地连接
+        await self._broadcast_to_local_connections(message)
+        
+        # 再发布到 Redis,让其他进程也能收到
+        try:
+            payload = {
+                "message": message,
+                "instance_id": self.instance_id
+            }
+            redis_client.publish("ws:broadcast", json.dumps(payload))
+        except Exception as e:
+            logger.error(f"Error publishing broadcast to Redis: {e}")
 
     async def send_personal_message(self, message: dict, user_id: int):
         """
-        向特定用户的所有在线设备推送消息
+        向特定用户的所有在线设备推送消息(跨进程)
         """
-        if user_id in self.active_connections:
-            # 复制一份列表进行遍历,防止发送过程中连接断开导致列表变化
-            connections = self.active_connections[user_id][:]
-            for connection in connections:
-                try:
-                    await connection.send_json(message)
-                except Exception as e:
-                    logger.error(f"Error sending message to user {user_id}: {e}")
-                    # 可以在此处处理无效连接的清理
+        worker_pid = os.getpid()
+        local_has_connection = user_id in self.active_connections
+        local_connections_count = len(self.active_connections.get(user_id, []))
+        
+        # 先发送给本地连接(如果本地有该用户的连接)
+        await self._send_to_local_connections(user_id, message)
+        
+        # 再发布到 Redis,让其他进程也能收到
+        try:
+            payload = {
+                "user_id": user_id,
+                "message": message,
+                "instance_id": self.instance_id
+            }
+            redis_client.publish("ws:personal", json.dumps(payload))
+            
+            # 记录发送日志
+            logger.debug(
+                f"Sending message to user {user_id}. "
+                f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
+                f"Local connection: {local_has_connection}, "
+                f"Local connections count: {local_connections_count}, "
+                f"Published to Redis: ws:personal"
+            )
+        except Exception as e:
+            logger.error(f"Error publishing message to Redis for user {user_id}: {e}")
+    
+    async def shutdown(self):
+        """关闭订阅者和连接"""
+        logger.info(f"Shutting down WebSocket manager (instance: {self.instance_id})...")
+        
+        # 取消订阅任务
+        if self._subscriber_task and not self._subscriber_task.done():
+            self._subscriber_task.cancel()
+            try:
+                await self._subscriber_task
+            except asyncio.CancelledError:
+                pass
+        
+        # 关闭 Redis 订阅
+        if self._pubsub:
+            try:
+                await self._pubsub.unsubscribe()
+                await self._pubsub.close()
+            except Exception as e:
+                logger.error(f"Error closing pubsub: {e}")
+        
+        # 关闭异步 Redis 客户端
+        if self._async_redis:
+            try:
+                await self._async_redis.close()
+            except Exception as e:
+                logger.error(f"Error closing async redis: {e}")
+        
+        logger.info("WebSocket manager shutdown complete")
 
 manager = ConnectionManager()

+ 6 - 1
backend/app/main.py

@@ -74,7 +74,12 @@ async def lifespan(app: FastAPI):
         logger.error(f"Failed to init scheduler or fix user names: {e}")
 
     yield
-    # Shutdown logic if any (e.g. close connections)
+    
+    # Shutdown logic
+    logger.info("Shutting down WebSocket manager...")
+    from app.core.websocket_manager import manager
+    await manager.shutdown()
+    
     scheduler.shutdown()
     logger.info("Shutting down...")