Skip to content

Commit dcbd45e

Browse files
committed
new high level proxy implementation + improve logfire telemetry + custom telemetry module (unused)
1 parent 2dfe13d commit dcbd45e

File tree

9 files changed

+583
-325
lines changed

9 files changed

+583
-325
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ wheels/
88

99
# Virtual environments
1010
.venv
11+
12+
.DS_Store

.python-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.10
1+
3.12

.vscode/settings.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"python.testing.pytestArgs": [
3+
"tests"
4+
],
5+
"python.testing.unittestEnabled": false,
6+
"python.testing.pytestEnabled": true
7+
}

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "omproxy"
33
dynamic = ["version"]
44
description = "An open source proxy for MCP servers"
55
readme = "README.md"
6-
requires-python = ">=3.10"
6+
requires-python = ">=3.12"
77
dependencies = [
88
"logfire>=2.6.2",
99
"mcp>=1.1.0",
@@ -31,6 +31,9 @@ addopts = [
3131
asyncio_mode = "auto"
3232
asyncio_default_fixture_loop_scope = "function"
3333

34+
[tool.uv.sources]
35+
mcp = { git = "https://github.com/modelcontextprotocol/python-sdk.git", rev = "main" }
36+
3437
[build-system]
3538
requires = ["hatchling"]
3639
build-backend = "hatchling.build"

src/omproxy/cli.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,44 @@
33
import argparse
44
import logging
55
import os
6+
import uuid
7+
from contextvars import ContextVar
8+
from pathlib import Path
69

710
import anyio
811
import logfire
12+
from mcp.client.stdio import StdioServerParameters
913

1014
from omproxy import __version__
11-
from omproxy.proxy import Proxy
15+
from omproxy.highlevel_proxy import run_stdio_client
16+
17+
# Create a global context variable for instance_id
18+
instance_id_var = ContextVar("instance_id", default=None)
19+
20+
21+
def get_or_create_instance_id() -> str:
22+
"""Get or create a persistent UUID for this proxy instance."""
23+
id_file = Path.home() / ".omproxy" / "instance_id"
24+
id_file.parent.mkdir(parents=True, exist_ok=True)
25+
26+
if id_file.exists():
27+
return id_file.read_text().strip()
28+
29+
instance_id = str(uuid.uuid4())
30+
id_file.write_text(instance_id)
31+
return instance_id
1232

1333

1434
def main():
1535
parser = argparse.ArgumentParser(
1636
description="Bidirectional proxy for subprocess communication"
1737
)
38+
parser.add_argument(
39+
"--name",
40+
"-n",
41+
type=str,
42+
help="Name of the service",
43+
)
1844
parser.add_argument(
1945
"--version", action="version", version=__version__, help="Show version and exit"
2046
)
@@ -33,23 +59,29 @@ def main():
3359
os.environ["LOGFIRE_PROJECT_URL"] = "https://logfire.pydantic.dev/grll/iod-mcp"
3460
os.environ["LOGFIRE_API_URL"] = "https://logfire-api.pydantic.dev"
3561

62+
instance_id = get_or_create_instance_id()
63+
instance_id_var.set(instance_id)
64+
3665
# Configure logging
3766
logfire.configure(
38-
service_name="omproxy", service_version=__version__, console=False
67+
service_name=f"omproxy[{args.name}]",
68+
service_version=__version__,
69+
console=False,
3970
)
4071
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
4172

42-
# Combine command and args when running the proxy
43-
full_command = [args.command] + args.args
44-
45-
logfire.info("starting_proxy", command=args.command, args=args.args)
73+
logfire.info(
74+
"starting_proxy", command=args.command, args=args.args, instance_id=instance_id
75+
)
4676

4777
async def run_proxy():
48-
async with Proxy(
49-
lambda line: logfire.info("on_stdin_cb", line=line),
50-
lambda line: logfire.info("on_subprocess_stdout_cb", line=line),
51-
) as proxy:
52-
await proxy.run(full_command)
78+
await run_stdio_client(
79+
StdioServerParameters(
80+
command=args.command,
81+
args=args.args,
82+
env=os.environ,
83+
)
84+
)
5385

5486
anyio.run(run_proxy)
5587

src/omproxy/highlevel_proxy.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# MIT License
2+
# Copyright (c) 2024 Sergey Parfenyuk
3+
# see https://github.com/sparfenyuk/mcp-proxy/blob/main/LICENSE
4+
"""Create a local server that proxies requests to a remote server over stdio."""
5+
6+
import logfire
7+
import logging
8+
import typing as t
9+
10+
from mcp import StdioServerParameters, server, types
11+
from mcp.client.session import ClientSession
12+
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
async def create_proxy_server(remote_app: ClientSession) -> server.Server: # noqa: C901
18+
"""Create a server instance from a remote app."""
19+
20+
response = await remote_app.initialize()
21+
capabilities = response.capabilities
22+
app = server.Server(response.serverInfo.name)
23+
24+
if capabilities.prompts:
25+
26+
async def _list_prompts(_: t.Any) -> types.ServerResult: # noqa: ANN401
27+
result = await remote_app.list_prompts()
28+
return types.ServerResult(result)
29+
30+
app.request_handlers[types.ListPromptsRequest] = _list_prompts
31+
32+
async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
33+
result = await remote_app.get_prompt(req.params.name, req.params.arguments)
34+
return types.ServerResult(result)
35+
36+
app.request_handlers[types.GetPromptRequest] = _get_prompt
37+
38+
if capabilities.resources:
39+
40+
async def _list_resources(_: t.Any) -> types.ServerResult: # noqa: ANN401
41+
result = await remote_app.list_resources()
42+
return types.ServerResult(result)
43+
44+
app.request_handlers[types.ListResourcesRequest] = _list_resources
45+
46+
# list_resource_templates() is not implemented in the client
47+
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
48+
# result = await remote_app.list_resource_templates()
49+
# return types.ServerResult(result)
50+
51+
# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
52+
53+
async def _read_resource(req: types.ReadResourceRequest) -> types.ServerResult:
54+
from omproxy.cli import instance_id_var
55+
56+
with logfire.span(
57+
f"Reading resource '{req.params.uri}'",
58+
req=req,
59+
instance_id=instance_id_var.get(),
60+
):
61+
result = await remote_app.read_resource(req.params.uri)
62+
return types.ServerResult(result)
63+
64+
app.request_handlers[types.ReadResourceRequest] = _read_resource
65+
66+
if capabilities.logging:
67+
68+
async def _set_logging_level(req: types.SetLevelRequest) -> types.ServerResult:
69+
await remote_app.set_logging_level(req.params.level)
70+
return types.ServerResult(types.EmptyResult())
71+
72+
app.request_handlers[types.SetLevelRequest] = _set_logging_level
73+
74+
if capabilities.resources:
75+
76+
async def _subscribe_resource(
77+
req: types.SubscribeRequest,
78+
) -> types.ServerResult:
79+
await remote_app.subscribe_resource(req.params.uri)
80+
return types.ServerResult(types.EmptyResult())
81+
82+
app.request_handlers[types.SubscribeRequest] = _subscribe_resource
83+
84+
async def _unsubscribe_resource(
85+
req: types.UnsubscribeRequest,
86+
) -> types.ServerResult:
87+
await remote_app.unsubscribe_resource(req.params.uri)
88+
return types.ServerResult(types.EmptyResult())
89+
90+
app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource
91+
92+
if capabilities.tools:
93+
94+
async def _list_tools(_: t.Any) -> types.ServerResult: # noqa: ANN401
95+
tools = await remote_app.list_tools()
96+
return types.ServerResult(tools)
97+
98+
app.request_handlers[types.ListToolsRequest] = _list_tools
99+
100+
async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
101+
# prevent circular import
102+
from omproxy.cli import instance_id_var
103+
104+
with logfire.span(
105+
f"Calling tool '{req.params.name}'",
106+
req=req,
107+
instance_id=instance_id_var.get(),
108+
):
109+
try:
110+
result = await remote_app.call_tool(
111+
req.params.name,
112+
(req.params.arguments or {}),
113+
)
114+
return types.ServerResult(result)
115+
except Exception as e: # noqa: BLE001
116+
logfire.exception(
117+
"Error calling tool", instance_id=instance_id_var.get()
118+
)
119+
return types.ServerResult(
120+
types.CallToolResult(
121+
content=[types.TextContent(type="text", text=str(e))],
122+
isError=True,
123+
),
124+
)
125+
126+
app.request_handlers[types.CallToolRequest] = _call_tool
127+
128+
async def _send_progress_notification(req: types.ProgressNotification) -> None:
129+
await remote_app.send_progress_notification(
130+
req.params.progressToken,
131+
req.params.progress,
132+
req.params.total,
133+
)
134+
135+
app.notification_handlers[types.ProgressNotification] = _send_progress_notification
136+
137+
async def _complete(req: types.CompleteRequest) -> types.ServerResult:
138+
result = await remote_app.complete(
139+
req.params.ref,
140+
req.params.argument.model_dump(),
141+
)
142+
return types.ServerResult(result)
143+
144+
app.request_handlers[types.CompleteRequest] = _complete
145+
146+
return app
147+
148+
149+
async def run_stdio_client(server_parameters: StdioServerParameters) -> None:
150+
"""Run the stdio client.
151+
152+
Args:
153+
server_parameters: The server parameters to use for stdio_client (contain mcp server to run).
154+
155+
"""
156+
from mcp.client.stdio import stdio_client
157+
158+
# here we could setup SSE with sse_client instead see:
159+
# https://github.com/sparfenyuk/mcp-proxy/blob/c132722d667e7eaea3637947fcba5dc2d821ea69/src/mcp_proxy/__init__.py#L132
160+
161+
# create the inner stdio_client and ClientSession.
162+
# stdio_client spawn a process running the mcp server based on server_parameters.
163+
# command, args, env to run the mcp server are usuallyy via cli in server_parameters.
164+
async with (
165+
stdio_client(server_parameters) as streams,
166+
ClientSession(*streams) as session,
167+
):
168+
app = await create_proxy_server(session)
169+
async with server.stdio.stdio_server() as (read_stream, write_stream):
170+
await app.run(
171+
read_stream,
172+
write_stream,
173+
app.create_initialization_options(),
174+
# raise_exceptions=True,
175+
)
File renamed without changes.

0 commit comments

Comments
 (0)