Skip to content

Commit 6756fa4

Browse files
committed
implement notifs
1 parent 5910f30 commit 6756fa4

File tree

1 file changed

+78
-23
lines changed

1 file changed

+78
-23
lines changed

pkg/transport/stdio.go

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,12 @@ func (t *Stdio) setupCommand() (stdin io.WriteCloser, stdout io.ReadCloser, cmd
203203
// initialize sends the initialization request and waits for response and then sends the initialized
204204
// notification.
205205
func (t *Stdio) initialize(stdin io.WriteCloser, stdout io.ReadCloser) error {
206+
// Create initialization request with current ID
207+
initRequestID := t.nextID
206208
initRequest := Request{
207209
JSONRPC: "2.0",
208210
Method: "initialize",
209-
ID: t.nextID,
211+
ID: initRequestID,
210212
Params: map[string]any{
211213
"clientInfo": map[string]any{
212214
"name": "f/mcptools",
@@ -222,11 +224,13 @@ func (t *Stdio) initialize(stdin io.WriteCloser, stdout io.ReadCloser) error {
222224
return fmt.Errorf("init request failed: %w", err)
223225
}
224226

227+
// readResponse now properly checks for matching response ID
225228
_, err := t.readResponse(stdout)
226229
if err != nil {
227230
return fmt.Errorf("init response failed: %w", err)
228231
}
229232

233+
// Send initialized notification (notifications don't have IDs)
230234
initNotification := Request{
231235
JSONRPC: "2.0",
232236
Method: "notifications/initialized",
@@ -272,34 +276,85 @@ func (t *Stdio) sendRequest(stdin io.WriteCloser, request Request) error {
272276
return nil
273277
}
274278

275-
// readResponse reads and parses a JSON-RPC response.
279+
// readResponse reads and parses a JSON-RPC response matching the given request ID.
276280
func (t *Stdio) readResponse(stdout io.ReadCloser) (*Response, error) {
277281
reader := bufio.NewReader(stdout)
278-
line, err := reader.ReadBytes('\n')
279-
if err != nil {
280-
return nil, fmt.Errorf("error reading from stdout: %w", err)
281-
}
282282

283-
if t.debug {
284-
fmt.Fprintf(os.Stderr, "DEBUG: Read from stdout: %s", string(line))
285-
}
283+
// Keep track of the expected response ID (the last request ID we sent)
284+
expectedID := t.nextID - 1
286285

287-
if len(line) == 0 {
288-
return nil, fmt.Errorf("no response from command")
289-
}
286+
for {
287+
line, err := reader.ReadBytes('\n')
288+
if err != nil {
289+
return nil, fmt.Errorf("error reading from stdout: %w", err)
290+
}
290291

291-
var response Response
292-
if unmarshalErr := json.Unmarshal(line, &response); unmarshalErr != nil {
293-
return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, string(line))
294-
}
292+
if t.debug {
293+
fmt.Fprintf(os.Stderr, "DEBUG: Read from stdout: %s", string(line))
294+
}
295295

296-
if response.Error != nil {
297-
return nil, fmt.Errorf("RPC error %d: %s", response.Error.Code, response.Error.Message)
298-
}
296+
if len(line) == 0 {
297+
return nil, fmt.Errorf("no response from command")
298+
}
299299

300-
if t.debug {
301-
fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response\n")
302-
}
300+
// First check if this is a notification (no ID field)
301+
var msg map[string]interface{}
302+
if err := json.Unmarshal(line, &msg); err != nil {
303+
return nil, fmt.Errorf("error unmarshaling message: %w, response: %s", err, string(line))
304+
}
303305

304-
return &response, nil
306+
// If it's a notification, display it and continue reading
307+
if methodVal, hasMethod := msg["method"]; hasMethod && msg["id"] == nil {
308+
method, ok := methodVal.(string)
309+
if ok && method == "notifications/message" {
310+
if paramsVal, hasParams := msg["params"].(map[string]interface{}); hasParams {
311+
level, _ := paramsVal["level"].(string)
312+
data, _ := paramsVal["data"].(string)
313+
314+
// Format and print the notification based on level
315+
switch level {
316+
case "error":
317+
fmt.Fprintf(os.Stderr, "\033[31m[ERROR] %s\033[0m\n", data) // Red
318+
case "warning":
319+
fmt.Fprintf(os.Stderr, "\033[33m[WARNING] %s\033[0m\n", data) // Yellow
320+
case "alert":
321+
fmt.Fprintf(os.Stderr, "\033[35m[ALERT] %s\033[0m\n", data) // Magenta
322+
case "info":
323+
fmt.Fprintf(os.Stderr, "\033[36m[INFO] %s\033[0m\n", data) // Cyan
324+
default:
325+
fmt.Fprintf(os.Stderr, "\033[37m[%s] %s\033[0m\n", level, data) // White for unknown levels
326+
}
327+
}
328+
} else {
329+
// For other notification types
330+
fmt.Fprintf(os.Stderr, "[Notification] %s\n", string(line))
331+
}
332+
continue
333+
}
334+
335+
// Parse as a proper response
336+
var response Response
337+
if unmarshalErr := json.Unmarshal(line, &response); unmarshalErr != nil {
338+
return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, string(line))
339+
}
340+
341+
// If this response has an ID field and it matches our expected ID, or if it has an error, return it
342+
if response.ID == expectedID || response.Error != nil {
343+
if response.Error != nil {
344+
return nil, fmt.Errorf("RPC error %d: %s", response.Error.Code, response.Error.Message)
345+
}
346+
347+
if t.debug {
348+
fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response with matching ID: %d\n", response.ID)
349+
}
350+
351+
return &response, nil
352+
}
353+
354+
// Otherwise, this is a response for a different request
355+
if t.debug {
356+
fmt.Fprintf(os.Stderr, "DEBUG: Received response for request ID %d, expecting %d. Continuing to read.\n",
357+
response.ID, expectedID)
358+
}
359+
}
305360
}

0 commit comments

Comments
 (0)