mcp_pipe.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
"""
Simple MCP stdio <-> WebSocket pipe with optional unified config.

Version: 0.2.0

Usage (env):
    export MCP_ENDPOINT=<ws_endpoint>
    # Windows (PowerShell): $env:MCP_ENDPOINT = "<ws_endpoint>"

Start server process(es) from config:
    # Run all configured servers (default)
    python mcp_pipe.py

    # Run a single local server script (back-compat)
    python mcp_pipe.py path/to/server.py

Config discovery order:
    $MCP_CONFIG, then ./mcp_config.json

Env overrides:
    (none for proxy; uses current Python: python -m mcp_proxy)
"""
  17. import asyncio
  18. import websockets
  19. import subprocess
  20. import logging
  21. import os
  22. import signal
  23. import sys
  24. import json
  25. from dotenv import load_dotenv
  26. # Auto-load environment variables from a .env file if present
  27. load_dotenv()
  28. # Configure logging
  29. logging.basicConfig(
  30. level=logging.INFO,
  31. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  32. )
  33. logger = logging.getLogger('MCP_PIPE')
  34. # Reconnection settings
  35. INITIAL_BACKOFF = 1 # Initial wait time in seconds
  36. MAX_BACKOFF = 600 # Maximum wait time in seconds
  37. async def connect_with_retry(uri, target):
  38. """Connect to WebSocket server with retry mechanism for a given server target."""
  39. reconnect_attempt = 0
  40. backoff = INITIAL_BACKOFF
  41. while True: # Infinite reconnection
  42. try:
  43. if reconnect_attempt > 0:
  44. logger.info(f"[{target}] Waiting {backoff}s before reconnection attempt {reconnect_attempt}...")
  45. await asyncio.sleep(backoff)
  46. # Attempt to connect
  47. await connect_to_server(uri, target)
  48. except Exception as e:
  49. reconnect_attempt += 1
  50. logger.warning(f"[{target}] Connection closed (attempt {reconnect_attempt}): {e}")
  51. # Calculate wait time for next reconnection (exponential backoff)
  52. backoff = min(backoff * 2, MAX_BACKOFF)
  53. async def connect_to_server(uri, target):
  54. """Connect to WebSocket server and pipe stdio for the given server target."""
  55. try:
  56. logger.info(f"[{target}] Connecting to WebSocket server...")
  57. async with websockets.connect(uri) as websocket:
  58. logger.info(f"[{target}] Successfully connected to WebSocket server")
  59. # Start server process (built from CLI arg or config)
  60. cmd, env = build_server_command(target)
  61. process = subprocess.Popen(
  62. cmd,
  63. stdin=subprocess.PIPE,
  64. stdout=subprocess.PIPE,
  65. stderr=subprocess.PIPE,
  66. encoding='utf-8',
  67. text=True,
  68. env=env
  69. )
  70. logger.info(f"[{target}] Started server process: {' '.join(cmd)}")
  71. # Create two tasks: read from WebSocket and write to process, read from process and write to WebSocket
  72. await asyncio.gather(
  73. pipe_websocket_to_process(websocket, process, target),
  74. pipe_process_to_websocket(process, websocket, target),
  75. pipe_process_stderr_to_terminal(process, target)
  76. )
  77. except websockets.exceptions.ConnectionClosed as e:
  78. logger.error(f"[{target}] WebSocket connection closed: {e}")
  79. raise # Re-throw exception to trigger reconnection
  80. except Exception as e:
  81. logger.error(f"[{target}] Connection error: {e}")
  82. raise # Re-throw exception
  83. finally:
  84. # Ensure the child process is properly terminated
  85. if 'process' in locals():
  86. logger.info(f"[{target}] Terminating server process")
  87. try:
  88. process.terminate()
  89. process.wait(timeout=5)
  90. except subprocess.TimeoutExpired:
  91. process.kill()
  92. logger.info(f"[{target}] Server process terminated")
  93. async def pipe_websocket_to_process(websocket, process, target):
  94. """Read data from WebSocket and write to process stdin"""
  95. try:
  96. while True:
  97. # Read message from WebSocket
  98. message = await websocket.recv()
  99. logger.debug(f"[{target}] << {message[:120]}...")
  100. # Write to process stdin (in text mode)
  101. if isinstance(message, bytes):
  102. message = message.decode('utf-8')
  103. process.stdin.write(message + '\n')
  104. process.stdin.flush()
  105. except Exception as e:
  106. logger.error(f"[{target}] Error in WebSocket to process pipe: {e}")
  107. raise # Re-throw exception to trigger reconnection
  108. finally:
  109. # Close process stdin
  110. if not process.stdin.closed:
  111. process.stdin.close()
  112. async def pipe_process_to_websocket(process, websocket, target):
  113. """Read data from process stdout and send to WebSocket"""
  114. try:
  115. while True:
  116. # Read data from process stdout
  117. data = await asyncio.to_thread(process.stdout.readline)
  118. if not data: # If no data, the process may have ended
  119. logger.info(f"[{target}] Process has ended output")
  120. break
  121. # Send data to WebSocket
  122. logger.debug(f"[{target}] >> {data[:120]}...")
  123. # In text mode, data is already a string, no need to decode
  124. await websocket.send(data)
  125. except Exception as e:
  126. logger.error(f"[{target}] Error in process to WebSocket pipe: {e}")
  127. raise # Re-throw exception to trigger reconnection
  128. async def pipe_process_stderr_to_terminal(process, target):
  129. """Read data from process stderr and print to terminal"""
  130. try:
  131. while True:
  132. # Read data from process stderr
  133. data = await asyncio.to_thread(process.stderr.readline)
  134. if not data: # If no data, the process may have ended
  135. logger.info(f"[{target}] Process has ended stderr output")
  136. break
  137. # Print stderr data to terminal (in text mode, data is already a string)
  138. sys.stderr.write(data)
  139. sys.stderr.flush()
  140. except Exception as e:
  141. logger.error(f"[{target}] Error in process stderr pipe: {e}")
  142. raise # Re-throw exception to trigger reconnection
  143. def signal_handler(sig, frame):
  144. """Handle interrupt signals"""
  145. logger.info("Received interrupt signal, shutting down...")
  146. sys.exit(0)
  147. def load_config():
  148. """Load JSON config from $MCP_CONFIG or ./mcp_config.json. Return dict or {}."""
  149. path = os.environ.get("MCP_CONFIG") or os.path.join(os.getcwd(), "mcp_config.json")
  150. if not os.path.exists(path):
  151. return {}
  152. try:
  153. with open(path, "r", encoding="utf-8") as f:
  154. return json.load(f)
  155. except Exception as e:
  156. logger.warning(f"Failed to load config {path}: {e}")
  157. return {}
  158. def build_server_command(target=None):
  159. """Build [cmd,...] and env for the server process for a given target.
  160. Priority:
  161. - If target matches a server in config.mcpServers: use its definition
  162. - Else: treat target as a Python script path (back-compat)
  163. If target is None, read from sys.argv[1].
  164. """
  165. if target is None:
  166. assert len(sys.argv) >= 2, "missing server name or script path"
  167. target = sys.argv[1]
  168. cfg = load_config()
  169. servers = cfg.get("mcpServers", {}) if isinstance(cfg, dict) else {}
  170. if target in servers:
  171. entry = servers[target] or {}
  172. if entry.get("disabled"):
  173. raise RuntimeError(f"Server '{target}' is disabled in config")
  174. typ = (entry.get("type") or entry.get("transportType") or "stdio").lower()
  175. # environment for child process
  176. child_env = os.environ.copy()
  177. for k, v in (entry.get("env") or {}).items():
  178. child_env[str(k)] = str(v)
  179. if typ == "stdio":
  180. command = entry.get("command")
  181. args = entry.get("args") or []
  182. if not command:
  183. raise RuntimeError(f"Server '{target}' is missing 'command'")
  184. return [command, *args], child_env
  185. if typ in ("sse", "http", "streamablehttp"):
  186. url = entry.get("url")
  187. if not url:
  188. raise RuntimeError(f"Server '{target}' (type {typ}) is missing 'url'")
  189. # Unified approach: always use current Python to run mcp-proxy module
  190. cmd = [sys.executable, "-m", "mcp_proxy"]
  191. if typ in ("http", "streamablehttp"):
  192. cmd += ["--transport", "streamablehttp"]
  193. # optional headers: {"Authorization": "Bearer xxx"}
  194. headers = entry.get("headers") or {}
  195. for hk, hv in headers.items():
  196. cmd += ["-H", hk, str(hv)]
  197. cmd.append(url)
  198. return cmd, child_env
  199. raise RuntimeError(f"Unsupported server type: {typ}")
  200. # Fallback to script path (back-compat)
  201. script_path = target
  202. if not os.path.exists(script_path):
  203. raise RuntimeError(
  204. f"'{target}' is neither a configured server nor an existing script"
  205. )
  206. return [sys.executable, script_path], os.environ.copy()
  207. if __name__ == "__main__":
  208. # Register signal handler
  209. signal.signal(signal.SIGINT, signal_handler)
  210. # Get token from environment variable or command line arguments
  211. endpoint_url = os.environ.get('MCP_ENDPOINT')
  212. if not endpoint_url:
  213. logger.error("Please set the `MCP_ENDPOINT` environment variable")
  214. sys.exit(1)
  215. # Determine target: default to all if no arg; single target otherwise
  216. target_arg = sys.argv[1] if len(sys.argv) >= 2 else None
  217. async def _main():
  218. if not target_arg:
  219. cfg = load_config()
  220. servers_cfg = (cfg.get("mcpServers") or {})
  221. all_servers = list(servers_cfg.keys())
  222. enabled = [name for name, entry in servers_cfg.items() if not (entry or {}).get("disabled")]
  223. skipped = [name for name in all_servers if name not in enabled]
  224. if skipped:
  225. logger.info(f"Skipping disabled servers: {', '.join(skipped)}")
  226. if not enabled:
  227. raise RuntimeError("No enabled mcpServers found in config")
  228. logger.info(f"Starting servers: {', '.join(enabled)}")
  229. tasks = [asyncio.create_task(connect_with_retry(endpoint_url, t)) for t in enabled]
  230. # Run all forever; if any crashes it will auto-retry inside
  231. await asyncio.gather(*tasks)
  232. else:
  233. if os.path.exists(target_arg):
  234. await connect_with_retry(endpoint_url, target_arg)
  235. else:
  236. logger.error("Argument must be a local Python script path. To run configured servers, run without arguments.")
  237. sys.exit(1)
  238. try:
  239. asyncio.run(_main())
  240. except KeyboardInterrupt:
  241. logger.info("Program interrupted by user")
  242. except Exception as e:
  243. logger.error(f"Program execution error: {e}")