|
|
@@ -2,6 +2,10 @@ from typing import Dict, List
|
|
|
from fastapi import WebSocket
|
|
|
import logging
|
|
|
import json
|
|
|
+import asyncio
|
|
|
+import os
|
|
|
+from app.core.cache import redis_client
|
|
|
+from app.core.config import settings
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@@ -9,39 +13,236 @@ class ConnectionManager:
|
|
|
def __init__(self):
|
|
|
# 存储格式: {user_id: [WebSocket1, WebSocket2, ...]}
|
|
|
self.active_connections: Dict[int, List[WebSocket]] = {}
|
|
|
+ # 进程唯一标识(用于避免重复发送)
|
|
|
+ self.instance_id = f"{os.getpid()}_{id(self)}"
|
|
|
+ # Redis 订阅者任务
|
|
|
+ self._subscriber_task = None
|
|
|
+ # Redis Pub/Sub 对象
|
|
|
+ self._pubsub = None
|
|
|
+ # 异步 Redis 客户端(用于订阅)
|
|
|
+ self._async_redis = None
|
|
|
+
|
|
|
+ async def _start_subscriber(self):
|
|
|
+ """启动 Redis 订阅者,监听消息频道"""
|
|
|
+ try:
|
|
|
+ # 使用 redis.asyncio 创建异步客户端(redis 4.6+ 支持)
|
|
|
+ import redis.asyncio as aioredis
|
|
|
+
|
|
|
+ # 构建 Redis URL,处理密码为 None 的情况
|
|
|
+ if settings.REDIS_PASSWORD:
|
|
|
+ redis_url = f"redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}"
|
|
|
+ else:
|
|
|
+ redis_url = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}"
|
|
|
+
|
|
|
+ self._async_redis = aioredis.from_url(
|
|
|
+ redis_url,
|
|
|
+ decode_responses=True
|
|
|
+ )
|
|
|
+
|
|
|
+ self._pubsub = self._async_redis.pubsub()
|
|
|
+ # 订阅个人消息频道和广播频道
|
|
|
+ await self._pubsub.subscribe("ws:personal", "ws:broadcast")
|
|
|
+
|
|
|
+ logger.info(f"WebSocket Redis subscriber started (instance: {self.instance_id})")
|
|
|
+
|
|
|
+ # 持续监听消息
|
|
|
+ async for message in self._pubsub.listen():
|
|
|
+ if message['type'] == 'message':
|
|
|
+ try:
|
|
|
+ data = json.loads(message['data'])
|
|
|
+ channel = message['channel']
|
|
|
+
|
|
|
+ if channel == 'ws:personal':
|
|
|
+ # 个人消息
|
|
|
+ user_id = data.get('user_id')
|
|
|
+ msg_data = data.get('message')
|
|
|
+ sender_instance = data.get('instance_id')
|
|
|
+
|
|
|
+ # 如果不是自己发送的,且本地有该用户的连接,则发送
|
|
|
+ if sender_instance != self.instance_id and user_id in self.active_connections:
|
|
|
+ worker_pid = os.getpid()
|
|
|
+ local_connections_count = len(self.active_connections.get(user_id, []))
|
|
|
+ logger.debug(
|
|
|
+ f"Received Redis message for user {user_id}. "
|
|
|
+ f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
|
|
|
+ f"Sender Instance: {sender_instance}, "
|
|
|
+ f"Local connections: {local_connections_count}"
|
|
|
+ )
|
|
|
+ await self._send_to_local_connections(user_id, msg_data)
|
|
|
+ elif sender_instance == self.instance_id:
|
|
|
+ logger.debug(
|
|
|
+ f"Ignored own message for user {user_id} "
|
|
|
+ f"(Worker PID: {os.getpid()}, Instance ID: {self.instance_id})"
|
|
|
+ )
|
|
|
+
|
|
|
+ elif channel == 'ws:broadcast':
|
|
|
+ # 广播消息
|
|
|
+ msg_data = data.get('message')
|
|
|
+ sender_instance = data.get('instance_id')
|
|
|
+
|
|
|
+ # 如果不是自己发送的,则广播给所有本地连接
|
|
|
+ if sender_instance != self.instance_id:
|
|
|
+ worker_pid = os.getpid()
|
|
|
+ total_connections = sum(len(conns) for conns in self.active_connections.values())
|
|
|
+ logger.debug(
|
|
|
+ f"Received Redis broadcast message. "
|
|
|
+ f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
|
|
|
+ f"Sender Instance: {sender_instance}, "
|
|
|
+ f"Total local connections: {total_connections}"
|
|
|
+ )
|
|
|
+ await self._broadcast_to_local_connections(msg_data)
|
|
|
+ else:
|
|
|
+ logger.debug(
|
|
|
+ f"Ignored own broadcast message "
|
|
|
+ f"(Worker PID: {os.getpid()}, Instance ID: {self.instance_id})"
|
|
|
+ )
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error processing Redis message: {e}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Redis subscriber error: {e}")
|
|
|
+ # 如果订阅失败,尝试重新连接
|
|
|
+ await asyncio.sleep(5)
|
|
|
+ if self._pubsub:
|
|
|
+ await self._start_subscriber()
|
|
|
+
|
|
|
+ async def _send_to_local_connections(self, user_id: int, message: dict):
|
|
|
+ """向本地连接发送消息"""
|
|
|
+ if user_id in self.active_connections:
|
|
|
+ connections = self.active_connections[user_id][:]
|
|
|
+ for connection in connections:
|
|
|
+ try:
|
|
|
+ await connection.send_json(message)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error sending message to user {user_id}: {e}")
|
|
|
+ # 连接可能已断开,从列表中移除
|
|
|
+ if connection in self.active_connections[user_id]:
|
|
|
+ self.active_connections[user_id].remove(connection)
|
|
|
+
|
|
|
+ async def _broadcast_to_local_connections(self, message: str):
|
|
|
+ """广播给所有本地连接"""
|
|
|
+ all_connections = []
|
|
|
+ for connection_list in self.active_connections.values():
|
|
|
+ all_connections.extend(connection_list[:])
|
|
|
+
|
|
|
+ for connection in all_connections:
|
|
|
+ try:
|
|
|
+ await connection.send_text(message)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error broadcasting message: {e}")
|
|
|
|
|
|
async def connect(self, websocket: WebSocket, user_id: int):
|
|
|
await websocket.accept()
|
|
|
if user_id not in self.active_connections:
|
|
|
self.active_connections[user_id] = []
|
|
|
self.active_connections[user_id].append(websocket)
|
|
|
- logger.info(f"User {user_id} connected via WebSocket. Total connections: {len(self.active_connections[user_id])}")
|
|
|
+
|
|
|
+ # 增强日志:记录 worker 信息
|
|
|
+ worker_pid = os.getpid()
|
|
|
+ logger.info(
|
|
|
+ f"User {user_id} connected via WebSocket. "
|
|
|
+ f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
|
|
|
+ f"Total connections for this user: {len(self.active_connections[user_id])}, "
|
|
|
+ f"Total users on this worker: {len(self.active_connections)}"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 首次连接时启动订阅者(如果还没启动)
|
|
|
+ if self._subscriber_task is None or self._subscriber_task.done():
|
|
|
+ self._subscriber_task = asyncio.create_task(self._start_subscriber())
|
|
|
|
|
|
def disconnect(self, websocket: WebSocket, user_id: int):
|
|
|
+ worker_pid = os.getpid()
|
|
|
+ was_connected = user_id in self.active_connections
|
|
|
+
|
|
|
if user_id in self.active_connections:
|
|
|
if websocket in self.active_connections[user_id]:
|
|
|
self.active_connections[user_id].remove(websocket)
|
|
|
if not self.active_connections[user_id]:
|
|
|
del self.active_connections[user_id]
|
|
|
- logger.info(f"User {user_id} disconnected")
|
|
|
+
|
|
|
+ # 增强日志:记录 worker 信息和断开后的状态
|
|
|
+ remaining_connections = len(self.active_connections.get(user_id, []))
|
|
|
+ logger.info(
|
|
|
+ f"User {user_id} disconnected from WebSocket. "
|
|
|
+ f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
|
|
|
+ f"Remaining connections for this user: {remaining_connections}, "
|
|
|
+ f"Total users on this worker: {len(self.active_connections)}"
|
|
|
+ )
|
|
|
|
|
|
async def broadcast(self, message: str):
|
|
|
- for connection_list in self.active_connections.values():
|
|
|
- for connection in connection_list:
|
|
|
- await connection.send_text(message)
|
|
|
+ """广播消息给所有连接的客户端(跨进程)"""
|
|
|
+ # 先发送给本地连接
|
|
|
+ await self._broadcast_to_local_connections(message)
|
|
|
+
|
|
|
+ # 再发布到 Redis,让其他进程也能收到
|
|
|
+ try:
|
|
|
+ payload = {
|
|
|
+ "message": message,
|
|
|
+ "instance_id": self.instance_id
|
|
|
+ }
|
|
|
+ redis_client.publish("ws:broadcast", json.dumps(payload))
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error publishing broadcast to Redis: {e}")
|
|
|
|
|
|
async def send_personal_message(self, message: dict, user_id: int):
|
|
|
"""
|
|
|
- 向特定用户的所有在线设备推送消息
|
|
|
+ 向特定用户的所有在线设备推送消息(跨进程)
|
|
|
"""
|
|
|
- if user_id in self.active_connections:
|
|
|
- # 复制一份列表进行遍历,防止发送过程中连接断开导致列表变化
|
|
|
- connections = self.active_connections[user_id][:]
|
|
|
- for connection in connections:
|
|
|
- try:
|
|
|
- await connection.send_json(message)
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"Error sending message to user {user_id}: {e}")
|
|
|
- # 可以在此处处理无效连接的清理
|
|
|
+ worker_pid = os.getpid()
|
|
|
+ local_has_connection = user_id in self.active_connections
|
|
|
+ local_connections_count = len(self.active_connections.get(user_id, []))
|
|
|
+
|
|
|
+ # 先发送给本地连接(如果本地有该用户的连接)
|
|
|
+ await self._send_to_local_connections(user_id, message)
|
|
|
+
|
|
|
+ # 再发布到 Redis,让其他进程也能收到
|
|
|
+ try:
|
|
|
+ payload = {
|
|
|
+ "user_id": user_id,
|
|
|
+ "message": message,
|
|
|
+ "instance_id": self.instance_id
|
|
|
+ }
|
|
|
+ redis_client.publish("ws:personal", json.dumps(payload))
|
|
|
+
|
|
|
+ # 记录发送日志
|
|
|
+ logger.debug(
|
|
|
+ f"Sending message to user {user_id}. "
|
|
|
+ f"Worker PID: {worker_pid}, Instance ID: {self.instance_id}, "
|
|
|
+ f"Local connection: {local_has_connection}, "
|
|
|
+ f"Local connections count: {local_connections_count}, "
|
|
|
+ f"Published to Redis: ws:personal"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error publishing message to Redis for user {user_id}: {e}")
|
|
|
+
|
|
|
+ async def shutdown(self):
|
|
|
+ """关闭订阅者和连接"""
|
|
|
+ logger.info(f"Shutting down WebSocket manager (instance: {self.instance_id})...")
|
|
|
+
|
|
|
+ # 取消订阅任务
|
|
|
+ if self._subscriber_task and not self._subscriber_task.done():
|
|
|
+ self._subscriber_task.cancel()
|
|
|
+ try:
|
|
|
+ await self._subscriber_task
|
|
|
+ except asyncio.CancelledError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # 关闭 Redis 订阅
|
|
|
+ if self._pubsub:
|
|
|
+ try:
|
|
|
+ await self._pubsub.unsubscribe()
|
|
|
+ await self._pubsub.close()
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error closing pubsub: {e}")
|
|
|
+
|
|
|
+ # 关闭异步 Redis 客户端
|
|
|
+ if self._async_redis:
|
|
|
+ try:
|
|
|
+ await self._async_redis.close()
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error closing async redis: {e}")
|
|
|
+
|
|
|
+ logger.info("WebSocket manager shutdown complete")
|
|
|
|
|
|
manager = ConnectionManager()
|