Skip to content

Commit d743d3b

Browse files
authored
Merge pull request #36 from tt/keep-stdio-process-running-in-shell-mode
Keep standard I/O process running in shell mode
2 parents 9623d18 + d80b0f3 commit d743d3b

File tree

5 files changed

+83
-20
lines changed

5 files changed

+83
-20
lines changed

cmd/mcptools/commands/shell.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func ShellCmd() *cobra.Command { //nolint:gocyclo
4444
os.Exit(1)
4545
}
4646

47-
mcpClient, clientErr := CreateClientFunc(parsedArgs)
47+
mcpClient, clientErr := CreateClientFunc(parsedArgs, client.CloseTransportAfterExecute(false))
4848
if clientErr != nil {
4949
fmt.Fprintf(os.Stderr, "Error: %v\n", clientErr)
5050
os.Exit(1)

cmd/mcptools/commands/test_helpers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func setupMockClient(executeFunc func(method string, _ any) (map[string]any, err
3030
mockClient := client.NewWithTransport(mockTransport)
3131

3232
// Override the function that creates clients
33-
CreateClientFunc = func(_ []string) (*client.Client, error) {
33+
CreateClientFunc = func(_ []string, _ ...client.Option) (*client.Client, error) {
3434
return mockClient, nil
3535
}
3636

cmd/mcptools/commands/utils.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ var (
1717

1818
// CreateClientFunc is the function used to create MCP clients.
1919
// This can be replaced in tests to use a mock transport.
20-
var CreateClientFunc = func(args []string) (*client.Client, error) {
20+
var CreateClientFunc = func(args []string, opts ...client.Option) (*client.Client, error) {
2121
if len(args) == 0 {
2222
return nil, ErrCommandRequired
2323
}
@@ -42,7 +42,13 @@ var CreateClientFunc = func(args []string) (*client.Client, error) {
4242
return client.NewHTTP(args[0]), nil
4343
}
4444

45-
return client.NewStdio(args), nil
45+
c := client.NewStdio(args)
46+
47+
for _, opt := range opts {
48+
opt(c)
49+
}
50+
51+
return c, nil
4652
}
4753

4854
// ProcessFlags processes command line flags, sets the format option, and returns the remaining

pkg/client/client.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@ type Client struct {
1818
transport transport.Transport
1919
}
2020

21+
// Option provides a way for passing options to the Client to change its
22+
// configuration.
23+
type Option func(*Client)
24+
25+
// CloseTransportAfterExecute allows keeping a transport alive if supported by
26+
// the transport.
27+
func CloseTransportAfterExecute(closeTransport bool) Option {
28+
return func(c *Client) {
29+
t, ok := c.transport.(interface{ SetCloseAfterExecute(bool) })
30+
if ok {
31+
t.SetCloseAfterExecute(closeTransport)
32+
}
33+
}
34+
}
35+
2136
// NewWithTransport creates a new MCP client using the provided transport.
2237
// This allows callers to provide a custom transport implementation.
2338
func NewWithTransport(t transport.Transport) *Client {

pkg/transport/stdio.go

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,20 @@ import (
1414
// Stdio implements the Transport interface by executing a command
1515
// and communicating with it via stdin/stdout using JSON-RPC.
1616
type Stdio struct {
17+
process *stdioProcess
1718
command []string
1819
nextID int
1920
debug bool
2021
}
2122

23+
// stdioProcess reflects the state of a running command.
24+
type stdioProcess struct {
25+
stdin io.WriteCloser
26+
stdout io.ReadCloser
27+
cmd *exec.Cmd
28+
stderrBuf *bytes.Buffer
29+
}
30+
2231
// NewStdio creates a new Stdio transport that will execute the given command.
2332
// It communicates with the command using JSON-RPC over stdin/stdout.
2433
func NewStdio(command []string) *Stdio {
@@ -30,23 +39,41 @@ func NewStdio(command []string) *Stdio {
3039
}
3140
}
3241

42+
// SetCloseAfterExecute toggles whether the underlying process should be closed
43+
// or kept alive after each call to Execute.
44+
func (t *Stdio) SetCloseAfterExecute(v bool) {
45+
if v {
46+
t.process = nil
47+
} else {
48+
t.process = &stdioProcess{}
49+
}
50+
}
51+
3352
// Execute implements the Transport interface by spawning a subprocess
3453
// and communicating with it via JSON-RPC over stdin/stdout.
3554
func (t *Stdio) Execute(method string, params any) (map[string]any, error) {
36-
stdin, stdout, cmd, stderrBuf, err := t.setupCommand()
37-
if err != nil {
38-
return nil, err
55+
process := t.process
56+
if process == nil {
57+
process = &stdioProcess{}
58+
}
59+
60+
if process.cmd == nil {
61+
var err error
62+
process.stdin, process.stdout, process.cmd, process.stderrBuf, err = t.setupCommand()
63+
if err != nil {
64+
return nil, err
65+
}
3966
}
4067

4168
if t.debug {
4269
fmt.Fprintf(os.Stderr, "DEBUG: Starting initialization\n")
4370
}
4471

45-
if initErr := t.initialize(stdin, stdout); initErr != nil {
72+
if initErr := t.initialize(process.stdin, process.stdout); initErr != nil {
4673
if t.debug {
4774
fmt.Fprintf(os.Stderr, "DEBUG: Initialization failed: %v\n", initErr)
48-
if stderrBuf.Len() > 0 {
49-
fmt.Fprintf(os.Stderr, "DEBUG: stderr during init: %s\n", stderrBuf.String())
75+
if process.stderrBuf.Len() > 0 {
76+
fmt.Fprintf(os.Stderr, "DEBUG: stderr during init: %s\n", process.stderrBuf.String())
5077
}
5178
}
5279
return nil, initErr
@@ -64,43 +91,58 @@ func (t *Stdio) Execute(method string, params any) (map[string]any, error) {
6491
}
6592
t.nextID++
6693

67-
if sendErr := t.sendRequest(stdin, request); sendErr != nil {
94+
if sendErr := t.sendRequest(process.stdin, request); sendErr != nil {
6895
return nil, sendErr
6996
}
70-
_ = stdin.Close()
7197

72-
response, err := t.readResponse(stdout)
98+
response, err := t.readResponse(process.stdout)
99+
if err != nil {
100+
return nil, err
101+
}
102+
103+
err = t.closeProcess(process)
73104
if err != nil {
74105
return nil, err
75106
}
76107

108+
return response.Result, nil
109+
}
110+
111+
// closeProcess waits for the command to finish, returning any error.
112+
func (t *Stdio) closeProcess(process *stdioProcess) error {
113+
if t.process != nil {
114+
return nil
115+
}
116+
117+
_ = process.stdin.Close()
118+
77119
// Wait for the command to finish with a timeout to prevent zombie processes
78120
done := make(chan error, 1)
79121
go func() {
80-
done <- cmd.Wait()
122+
done <- process.cmd.Wait()
81123
}()
82124

83125
select {
84126
case waitErr := <-done:
85127
if t.debug {
86128
fmt.Fprintf(os.Stderr, "DEBUG: Command completed with err: %v\n", waitErr)
87-
if stderrBuf.Len() > 0 {
88-
fmt.Fprintf(os.Stderr, "DEBUG: stderr output:\n%s\n", stderrBuf.String())
129+
if process.stderrBuf.Len() > 0 {
130+
fmt.Fprintf(os.Stderr, "DEBUG: stderr output:\n%s\n", process.stderrBuf.String())
89131
}
90132
}
91133

92-
if waitErr != nil && stderrBuf.Len() > 0 {
93-
return nil, fmt.Errorf("command error: %w, stderr: %s", waitErr, stderrBuf.String())
134+
if waitErr != nil && process.stderrBuf.Len() > 0 {
135+
return fmt.Errorf("command error: %w, stderr: %s", waitErr, process.stderrBuf.String())
94136
}
95137
case <-time.After(1 * time.Second):
96138
if t.debug {
97139
fmt.Fprintf(os.Stderr, "DEBUG: Command timed out after 1 seconds\n")
98140
}
99141
// Kill the process if it times out
100-
_ = cmd.Process.Kill()
142+
_ = process.cmd.Process.Kill()
101143
}
102144

103-
return response.Result, nil
145+
return nil
104146
}
105147

106148
// setupCommand prepares and starts the command, returning the stdin/stdout pipes and any error.

0 commit comments

Comments
 (0)