| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- """
- Simple MCP stdio <-> WebSocket pipe with optional unified config.
- Version: 0.2.0
- Usage (env):
- export MCP_ENDPOINT=<ws_endpoint>
- # Windows (PowerShell): $env:MCP_ENDPOINT = "<ws_endpoint>"
- Start server process(es) from config:
- Run all configured servers (default)
- python mcp_pipe.py
- Run a single local server script (back-compat)
- python mcp_pipe.py path/to/server.py
- Config discovery order:
- $MCP_CONFIG, then ./mcp_config.json
- Env overrides:
- (none for proxy; uses current Python: python -m mcp_proxy)
- """
- import asyncio
- import websockets
- import subprocess
- import logging
- import os
- import signal
- import sys
- import json
- from dotenv import load_dotenv
# Pull settings such as MCP_ENDPOINT / MCP_CONFIG from a .env file when one exists.
load_dotenv()

# Process-wide logging: timestamped INFO-level records for every logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('MCP_PIPE')

# Reconnection backoff bounds, in seconds.
INITIAL_BACKOFF = 1  # first delay before reconnecting
MAX_BACKOFF = 600    # ceiling for the doubling delay
async def connect_with_retry(uri, target, *, stable_after=60.0):
    """Run the pipe session for *target* forever, reconnecting with exponential backoff.

    Each attempt awaits ``connect_to_server(uri, target)``. After an attempt
    ends (by raising, or by returning on its own), the next attempt first
    waits ``INITIAL_BACKOFF`` seconds, doubling per consecutive failure up to
    ``MAX_BACKOFF``. If the session had stayed up for at least *stable_after*
    seconds, it is considered healthy and the delay and attempt counter start
    over instead of continuing to grow from earlier failures.

    Args:
        uri: WebSocket endpoint, passed through to ``connect_to_server``.
        target: Configured server name or script path, passed through.
        stable_after: Session length in seconds after which the backoff resets.

    Never returns; only ``BaseException`` subclasses that ``except Exception``
    does not catch (such as task cancellation) escape the loop.
    """
    loop = asyncio.get_running_loop()
    reconnect_attempt = 0
    backoff = INITIAL_BACKOFF
    while True:  # Infinite reconnection
        if reconnect_attempt > 0:
            logger.info(f"[{target}] Waiting {backoff}s before reconnection attempt {reconnect_attempt}...")
            await asyncio.sleep(backoff)
            # If this attempt fails too, the next one waits twice as long.
            backoff = min(backoff * 2, MAX_BACKOFF)
        started = loop.time()
        try:
            await connect_to_server(uri, target)
        except Exception as e:
            reason = e
        else:
            # A session that ends without raising is still a lost session; going
            # through the backoff path prevents a tight, delay-free reconnect loop.
            reason = "session ended"
        if loop.time() - started >= stable_after:
            # The session was healthy for a while: start the backoff from scratch.
            reconnect_attempt = 0
            backoff = INITIAL_BACKOFF
        reconnect_attempt += 1
        logger.warning(f"[{target}] Connection closed (attempt {reconnect_attempt}): {reason}")
def _format_command_for_log(cmd):
    """Render *cmd* as one space-joined string for logging, masking header values.

    ``build_server_command`` passes HTTP headers as ``-H <name> <value>``.
    Values such as ``Authorization`` bearer tokens must not reach the logs,
    so any argument two positions after a literal ``-H`` is shown as ``***``.
    """
    parts = [str(part) for part in cmd]
    for i, part in enumerate(parts[:-2]):
        if part == "-H":
            parts[i + 2] = "***"
    return ' '.join(parts)


async def _run_pipes(websocket, process, target):
    """Run the three stdio/WebSocket pipes of one session until it must end.

    Always exits by raising. The first failure of the WebSocket->stdin or
    stdout->WebSocket pipe is re-raised. If the stdout pipe finishes normally
    (its reader hit EOF, typically because the child exited), a RuntimeError
    is raised so the caller's retry logic also covers a dead server process.
    The stderr echo never ends the session on its own. All three tasks are
    cancelled and awaited before this coroutine finishes.
    """
    inbound = asyncio.create_task(pipe_websocket_to_process(websocket, process, target))
    outbound = asyncio.create_task(pipe_process_to_websocket(process, websocket, target))
    stderr_echo = asyncio.create_task(pipe_process_stderr_to_terminal(process, target))
    tasks = (inbound, outbound, stderr_echo)
    try:
        done, _pending = await asyncio.wait(
            (inbound, outbound), return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            task.result()  # re-raises the pipe's exception, if it failed
        raise RuntimeError("Server process closed its stdout")
    finally:
        for task in tasks:
            task.cancel()
        # Await the cancelled tasks so none keeps running against a closed
        # WebSocket or gets reported as "Task exception was never retrieved".
        await asyncio.gather(*tasks, return_exceptions=True)


async def _stop_server_process(process, target, timeout=5):
    """Terminate *process*, escalating to kill() if it outlives *timeout* seconds.

    The waits run in a worker thread so a slow-to-exit child does not block
    the event loop shared by every other server session.
    """
    logger.info(f"[{target}] Terminating server process")
    try:
        process.terminate()
        await asyncio.to_thread(process.wait, timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        # Reap the killed child so it does not linger as a zombie.
        await asyncio.to_thread(process.wait)
    logger.info(f"[{target}] Server process terminated")


async def connect_to_server(uri, target):
    """Open one WebSocket session and bridge it to a freshly spawned server process.

    The child is built by ``build_server_command(target)`` and lives exactly as
    long as the session: WebSocket messages go to its stdin, its stdout lines
    are sent back over the WebSocket, and its stderr is echoed to our stderr.

    Args:
        uri: WebSocket endpoint to connect to.
        target: Configured server name or local script path.

    Raises:
        websockets.exceptions.ConnectionClosed: the WebSocket was closed.
        Exception: any other failure, including the child process exiting.
            Errors are logged and always re-raised so the caller can retry.
    """
    process = None
    try:
        logger.info(f"[{target}] Connecting to WebSocket server...")
        async with websockets.connect(uri) as websocket:
            logger.info(f"[{target}] Successfully connected to WebSocket server")
            # Start server process (built from CLI arg or config)
            cmd, env = build_server_command(target)
            process = subprocess.Popen(
                cmd,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                encoding='utf-8',
                text=True,
                env=env,
            )
            logger.info(f"[{target}] Started server process: {_format_command_for_log(cmd)}")
            await _run_pipes(websocket, process, target)
    except websockets.exceptions.ConnectionClosed as e:
        logger.error(f"[{target}] WebSocket connection closed: {e}")
        raise  # Re-throw exception to trigger reconnection
    except Exception as e:
        logger.error(f"[{target}] Connection error: {e}")
        raise  # Re-throw exception
    finally:
        # Ensure the child process is properly terminated
        if process is not None:
            await _stop_server_process(process, target)
async def pipe_websocket_to_process(websocket, process, target):
    """Forward WebSocket messages to the child's stdin, one line per message.

    Binary frames are decoded as UTF-8, and each message gets a trailing
    newline. The loop only ends by raising: errors are logged and re-raised
    so the session is torn down and retried. stdin is closed on the way out.

    Writes run in a worker thread. A blocking ``write``/``flush`` on a full pipe
    (a child that is not reading its stdin) would otherwise freeze the event
    loop, and with it every other server session sharing that loop.
    """
    stdin = process.stdin

    def _write_line(line):
        stdin.write(line)
        stdin.flush()

    write_in_flight = False
    try:
        while True:
            # Read message from WebSocket
            message = await websocket.recv()
            logger.debug(f"[{target}] << {message[:120]}...")

            # Write to process stdin (in text mode)
            if isinstance(message, bytes):
                message = message.decode('utf-8')
            write_in_flight = True
            await asyncio.to_thread(_write_line, message + '\n')
            write_in_flight = False
    except Exception as e:
        # An Exception here came from recv/decode, or from a write that has
        # already returned, so no worker thread is still using stdin.
        write_in_flight = False
        logger.error(f"[{target}] Error in WebSocket to process pipe: {e}")
        raise  # Re-throw exception to trigger reconnection
    finally:
        # If we were cancelled mid-write, the worker thread may still be blocked
        # inside write(). Closing the stream now could block the event loop
        # behind it, so leave the stream to process teardown in that case.
        if not write_in_flight and not stdin.closed:
            try:
                stdin.close()
            except OSError as e:
                # e.g. BrokenPipeError flushing leftovers to a dead child; do not
                # let it replace the error that ended the loop.
                logger.debug(f"[{target}] Ignoring error while closing stdin: {e}")
async def pipe_process_to_websocket(process, websocket, target):
    """Relay the child's stdout to the WebSocket, one line per message.

    Each ``readline`` runs in a worker thread so the event loop stays free.
    Returns normally once stdout reaches EOF; any failure is logged and
    re-raised.
    """
    try:
        while (line := await asyncio.to_thread(process.stdout.readline)):
            logger.debug(f"[{target}] >> {line[:120]}...")
            # The pipe is in text mode, so ``line`` is already a str.
            await websocket.send(line)
        # An empty read means EOF: the child closed stdout (usually it exited).
        logger.info(f"[{target}] Process has ended output")
    except Exception as e:
        logger.error(f"[{target}] Error in process to WebSocket pipe: {e}")
        raise  # Re-throw exception to trigger reconnection
async def pipe_process_stderr_to_terminal(process, target):
    """Echo the child's stderr to this program's stderr, line by line.

    Reads happen in a worker thread; returns once stderr reaches EOF.
    Errors are logged and re-raised.
    """
    try:
        while (chunk := await asyncio.to_thread(process.stderr.readline)):
            # Text-mode pipe: ``chunk`` is already a str. Flush so it shows up promptly.
            sys.stderr.write(chunk)
            sys.stderr.flush()
        logger.info(f"[{target}] Process has ended stderr output")
    except Exception as e:
        logger.error(f"[{target}] Error in process stderr pipe: {e}")
        raise  # Re-throw exception to trigger reconnection
def signal_handler(sig, frame):
    """Log the shutdown request and exit the program with status 0.

    ``sig`` and ``frame`` are required by the ``signal.signal`` callback
    signature and are not used.
    """
    del sig, frame  # unused; dictated by the handler signature
    logger.info("Received interrupt signal, shutting down...")
    raise SystemExit(0)
def load_config():
    """Load the JSON config that lists the MCP servers.

    The path is ``$MCP_CONFIG`` when that variable is set and non-empty,
    otherwise ``mcp_config.json`` in the current working directory.

    Returns:
        The parsed top-level JSON object as a dict. Returns ``{}`` when the
        file is missing, unreadable or not valid UTF-8/JSON, or when its top
        level is not a JSON object. Each problem except the missing default
        file is logged as a warning.
    """
    explicit_path = os.environ.get("MCP_CONFIG")
    path = explicit_path or os.path.join(os.getcwd(), "mcp_config.json")
    if not os.path.exists(path):
        if explicit_path:
            # The user pointed at a specific file; a typo should not be silent.
            logger.warning(f"Config file from MCP_CONFIG not found: {path}")
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and UTF-8 decode errors
        logger.warning(f"Failed to load config {path}: {e}")
        return {}
    if not isinstance(cfg, dict):
        # Callers do cfg.get(...); honour the "dict or {}" contract.
        logger.warning(f"Ignoring config {path}: top-level JSON value must be an object")
        return {}
    return cfg
def build_server_command(target=None):
    """Build the argv list and environment for one server process.

    Priority:
      - If *target* matches a server in ``config.mcpServers``, use its definition.
      - Otherwise treat *target* as a Python script path (back-compat) and run
        it with the current interpreter.

    Args:
        target: Server name or script path. When ``None``, ``sys.argv[1]`` is used.

    Returns:
        ``(cmd, env)``. ``cmd`` is a list of strings for ``subprocess.Popen``.
        ``env`` is a copy of ``os.environ``; for configured servers, the
        entry's ``env`` overrides are applied to it.

    Raises:
        RuntimeError: no target was given; the entry is disabled, lacks
            ``command``/``url``, or has an unsupported type; or the target is
            neither a configured server nor an existing script.
    """
    if target is None:
        if len(sys.argv) < 2:
            # Raise instead of assert: asserts are stripped under ``python -O``.
            raise RuntimeError("missing server name or script path")
        target = sys.argv[1]
    cfg = load_config()
    servers = cfg.get("mcpServers") if isinstance(cfg, dict) else None
    if not isinstance(servers, dict):
        # e.g. "mcpServers": null, or a list, in the config file
        servers = {}
    if target in servers:
        entry = servers[target] or {}
        if entry.get("disabled"):
            raise RuntimeError(f"Server '{target}' is disabled in config")
        typ = (entry.get("type") or entry.get("transportType") or "stdio").lower()
        # environment for child process
        child_env = os.environ.copy()
        for k, v in (entry.get("env") or {}).items():
            child_env[str(k)] = str(v)
        if typ == "stdio":
            command = entry.get("command")
            args = entry.get("args") or []
            if not command:
                raise RuntimeError(f"Server '{target}' is missing 'command'")
            # JSON args may hold numbers (e.g. a port); Popen accepts only strings.
            # str() matches how env values and header values are converted here.
            return [str(command), *(str(arg) for arg in args)], child_env
        if typ in ("sse", "http", "streamablehttp"):
            url = entry.get("url")
            if not url:
                raise RuntimeError(f"Server '{target}' (type {typ}) is missing 'url'")
            # Unified approach: always use current Python to run mcp-proxy module
            cmd = [sys.executable, "-m", "mcp_proxy"]
            if typ in ("http", "streamablehttp"):
                cmd += ["--transport", "streamablehttp"]
            # optional headers: {"Authorization": "Bearer xxx"}
            headers = entry.get("headers") or {}
            for hk, hv in headers.items():
                cmd += ["-H", str(hk), str(hv)]
            cmd.append(str(url))
            return cmd, child_env
        raise RuntimeError(f"Unsupported server type: {typ}")
    # Fallback to script path (back-compat)
    script_path = target
    if not os.path.exists(script_path):
        raise RuntimeError(
            f"'{target}' is neither a configured server nor an existing script"
        )
    return [sys.executable, script_path], os.environ.copy()
- if __name__ == "__main__":
- # Register signal handler
- signal.signal(signal.SIGINT, signal_handler)
-
- # Get token from environment variable or command line arguments
- endpoint_url = os.environ.get('MCP_ENDPOINT')
- if not endpoint_url:
- logger.error("Please set the `MCP_ENDPOINT` environment variable")
- sys.exit(1)
-
- # Determine target: default to all if no arg; single target otherwise
- target_arg = sys.argv[1] if len(sys.argv) >= 2 else None
- async def _main():
- if not target_arg:
- cfg = load_config()
- servers_cfg = (cfg.get("mcpServers") or {})
- all_servers = list(servers_cfg.keys())
- enabled = [name for name, entry in servers_cfg.items() if not (entry or {}).get("disabled")]
- skipped = [name for name in all_servers if name not in enabled]
- if skipped:
- logger.info(f"Skipping disabled servers: {', '.join(skipped)}")
- if not enabled:
- raise RuntimeError("No enabled mcpServers found in config")
- logger.info(f"Starting servers: {', '.join(enabled)}")
- tasks = [asyncio.create_task(connect_with_retry(endpoint_url, t)) for t in enabled]
- # Run all forever; if any crashes it will auto-retry inside
- await asyncio.gather(*tasks)
- else:
- if os.path.exists(target_arg):
- await connect_with_retry(endpoint_url, target_arg)
- else:
- logger.error("Argument must be a local Python script path. To run configured servers, run without arguments.")
- sys.exit(1)
- try:
- asyncio.run(_main())
- except KeyboardInterrupt:
- logger.info("Program interrupted by user")
- except Exception as e:
- logger.error(f"Program execution error: {e}")
|