diff --git a/jesse/__init__.py b/jesse/__init__.py index a74eca941..99a5fdf78 100644 --- a/jesse/__init__.py +++ b/jesse/__init__.py @@ -244,6 +244,49 @@ def run() -> None: uvicorn.run(fastapi_app, host=host, port=port, log_level="info") +@cli.command() +@click.option( + "--transport", + type=click.Choice(["stdio", "sse", "streamable-http", "http"]), + default="stdio", + show_default=True, + help="Transport protocol to expose for the FastMCP server.", +) +@click.option( + "--host", + default="127.0.0.1", + show_default=True, + help="Host address for HTTP or SSE transports.", +) +@click.option( + "--port", + default=8765, + show_default=True, + help="Port for HTTP or SSE transports.", +) +@click.option( + "--hide-banner", + is_flag=True, + default=False, + help="Hide the FastMCP startup banner output.", +) +def fastmcp(transport: str, host: str, port: int, hide_banner: bool) -> None: + """Run the FastMCP server that exposes Jesse management tools.""" + + from jesse.services.mcp import run_fastmcp_server + + transport_kwargs = {} + if transport in {"http", "sse", "streamable-http"}: + transport_kwargs["host"] = host + transport_kwargs["port"] = port + + run_fastmcp_server( + transport=transport, + show_banner=not hide_banner, + **transport_kwargs, + ) + + @fastapi_app.on_event("shutdown") def shutdown_event(): from jesse.services.db import database diff --git a/jesse/services/mcp/__init__.py b/jesse/services/mcp/__init__.py new file mode 100644 index 000000000..138530f9a --- /dev/null +++ b/jesse/services/mcp/__init__.py @@ -0,0 +1,11 @@ +"""Integration with FastMCP for managing Jesse background processes.""" + +from .server import ( + create_fastmcp_server, + run_fastmcp_server, +) + +__all__ = [ + "create_fastmcp_server", + "run_fastmcp_server", +] diff --git a/jesse/services/mcp/server.py b/jesse/services/mcp/server.py new file mode 100644 index 000000000..e614073c6 --- /dev/null +++ b/jesse/services/mcp/server.py @@ -0,0 +1,171 @@ +"""FastMCP server integration for Jesse.""" + +from __future__ import annotations + +import uuid +from typing import Any, Dict, List, Optional + +import jesse.helpers as jh +from fastmcp import Context, FastMCP +from pydantic import BaseModel, Field + +from jesse.services.multiprocessing import process_manager +from jesse.services.web import BacktestRequestJson + +_DEFAULT_INSTRUCTIONS = ( + "Control Jesse's background workers via the Model Context Protocol. " + "Use the available tools to start new backtests, inspect running tasks, " + "and request graceful shutdowns of worker processes." +) + + +class StartBacktestPayload(BaseModel): + """Parameters accepted by the FastMCP backtest tool.""" + + exchange: str + routes: List[Dict[str, str]] + data_routes: List[Dict[str, str]] + config: Dict[str, Any] + start_date: str + finish_date: str + id: Optional[str] = Field( + default=None, + description=( + "Unique client identifier for the spawned worker. If omitted, an ID " + "is generated automatically." + ), + ) + debug_mode: bool = False + export_csv: bool = False + export_json: bool = False + export_chart: bool = False + export_tradingview: bool = False + fast_mode: bool = False + benchmark: bool = False + + def to_backtest_request(self) -> BacktestRequestJson: + """Convert payload to the internal request schema.""" + + payload = self.model_dump() + if not payload.get("id"): + payload["id"] = uuid.uuid4().hex + return BacktestRequestJson.model_validate(payload) + + +class CancelProcessPayload(BaseModel): + """Payload for requesting cancellation of a running worker.""" + + client_id: str = Field(description="Identifier of the worker to cancel.") + + +def create_fastmcp_server( + *, + name: str | None = None, + instructions: str | None = None, +) -> FastMCP: + """Instantiate a FastMCP server with Jesse specific tools.""" + + server = FastMCP( + name=name or "Jesse MCP Server", + instructions=instructions or _DEFAULT_INSTRUCTIONS, + ) + + @server.tool( + name="start_backtest", + description="Start a Jesse backtest in a managed worker process.", + tags=["backtest", "jesse"], + ) + def start_backtest( + payload: StartBacktestPayload, + ctx: Context | None = None, + ) -> Dict[str, str]: + request = payload.to_backtest_request() + jh.validate_cwd() + + from jesse.modes.backtest_mode import run as run_backtest + + process_manager.add_task( + run_backtest, + request.id, + request.debug_mode, + request.config, + request.exchange, + request.routes, + request.data_routes, + request.start_date, + request.finish_date, + None, + request.export_chart, + request.export_tradingview, + request.export_csv, + request.export_json, + request.fast_mode, + request.benchmark, + ) + + message = f"Backtest session {request.id} started" + if ctx is not None: + ctx.info(message) + return {"status": "started", "id": request.id} + + @server.tool( + name="list_processes", + description="List active Jesse worker identifiers managed by the process manager.", + tags=["jesse"], + ) + def list_processes() -> List[str]: + return sorted(process_manager.active_workers) + + @server.tool( + name="cancel_process", + description="Request cancellation of a Jesse worker by its client identifier.", + tags=["jesse"], + ) + def cancel_process( + payload: CancelProcessPayload, + ctx: Context | None = None, + ) -> Dict[str, str]: + process_manager.cancel_process(payload.client_id) + message = f"Cancellation requested for worker {payload.client_id}" + if ctx is not None: + ctx.info(message) + return {"status": "requested", "id": payload.client_id} + + @server.tool( + name="flush_processes", + description="Terminate all running Jesse worker processes.", + tags=["jesse"], + ) + def flush_processes(ctx: Context | None = None) -> Dict[str, str]: + process_manager.flush() + message = "All Jesse workers were terminated" + if ctx is not None: + ctx.info(message) + return {"status": "terminated"} + + return server + + +def run_fastmcp_server( + *, + transport: str = "stdio", + host: str | None = None, + port: int | None = None, + show_banner: bool = True, + **transport_kwargs: Any, +) -> None: + """Run the FastMCP server using the requested transport.""" + + jh.validate_cwd() + server = create_fastmcp_server() + + supported_http_transports = {"http", "sse", "streamable-http"} + kwargs: Dict[str, Any] = dict(transport_kwargs) + + if transport in supported_http_transports: + if host is not None: + kwargs.setdefault("host", host) + if port is not None: + kwargs.setdefault("port", port) + + server.run(transport=transport, show_banner=show_banner, **kwargs) diff --git a/requirements.txt b/requirements.txt index caf3d2bac..66517765b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,9 +19,10 @@ simplejson~=3.16.0 aioredis~=1.3.1 redis~=4.1.4 fastapi~=0.111.1 -uvicorn~=0.29.0 +fastmcp~=2.12.4 +uvicorn~=0.37.0 websockets>=10.0.0 -python-dotenv~=0.19.2 +python-dotenv~=1.1.0 aiofiles~=0.7.0 numba~=0.61.0rc2 PyJWT~=2.8.0 diff --git a/tests/test_fastmcp_integration.py b/tests/test_fastmcp_integration.py new file mode 100644 index 000000000..22cd3efcb --- /dev/null +++ b/tests/test_fastmcp_integration.py @@ -0,0 +1,50 @@ +"""Tests for the FastMCP integration layer.""" + +from pathlib import Path +from types import SimpleNamespace +import sys + + +if "pkg_resources" not in sys.modules: # pragma: no cover - test helper setup + class _PkgResourcesStub: + def resource_filename(self, package: str, resource: str) -> str: + base = Path(__file__).resolve().parent.parent / package + target = base if not resource else base / resource + return str(target) + + def get_distribution(self, name: str) -> SimpleNamespace: + return SimpleNamespace(version="0.0.0") + + sys.modules["pkg_resources"] = _PkgResourcesStub() + +from jesse.services.mcp import create_fastmcp_server +from jesse.services.mcp.server import StartBacktestPayload + + +def test_fastmcp_server_registers_expected_tools(): + server = create_fastmcp_server() + tool_names = set(server._tool_manager._tools.keys()) + assert {"start_backtest", "list_processes", "cancel_process", "flush_processes"} <= tool_names + + +def test_start_backtest_payload_generates_identifier(): + payload = StartBacktestPayload( + exchange="Sandbox", + routes=[ + { + "exchange": "Sandbox", + "symbol": "BTC-USDT", + "timeframe": "1m", + "strategy": "Example", + } + ], + data_routes=[{"exchange": "Sandbox", "symbol": "BTC-USDT", "timeframe": "5m"}], + config={}, + start_date="2020-01-01", + finish_date="2020-01-02", + ) + + request = payload.to_backtest_request() + + assert request.id is not None + assert request.exchange == "Sandbox"