Skip to content

Commit 37f8310

Browse files
authored
Fix Graphql tool discovery #131)
1 parent 468830b commit 37f8310

File tree

2 files changed

+78
-28
lines changed

2 files changed

+78
-28
lines changed

src/transports/graphql/graphql_transport.go

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -271,52 +271,46 @@ func (t *GraphQLClientTransport) RegisterToolProvider(
271271
return nil, fmt.Errorf("introspection failed: %w", err)
272272
}
273273

274-
// Build tool list
274+
// Build tool list with optional filtering by operation type/name
275275
var toolsList []Tool
276276

277-
// Register query fields
278-
for _, f := range resp.Schema.QueryType.Fields {
277+
opType := strings.ToLower(prov.OperationType)
278+
279+
// Helper to register a field if it matches the optional OperationName
280+
addTool := func(fieldName string, descPtr *string) {
281+
if prov.OperationName != nil && *prov.OperationName != fieldName {
282+
return
283+
}
279284
desc := ""
280-
if f.Description != nil {
281-
desc = *f.Description
285+
if descPtr != nil {
286+
desc = *descPtr
282287
}
283288
toolsList = append(toolsList, Tool{
284-
Name: fmt.Sprintf("%s.%s", prov.Name, f.Name),
289+
Name: fmt.Sprintf("%s.%s", prov.Name, fieldName),
285290
Description: desc,
286291
Inputs: ToolInputOutputSchema{Required: nil},
287292
Provider: prov,
288293
})
289294
}
290295

296+
// Register query fields
297+
if opType == "" || opType == "query" {
298+
for _, f := range resp.Schema.QueryType.Fields {
299+
addTool(f.Name, f.Description)
300+
}
301+
}
302+
291303
// Register mutation fields
292-
if resp.Schema.MutationType != nil {
304+
if (opType == "" || opType == "mutation") && resp.Schema.MutationType != nil {
293305
for _, f := range resp.Schema.MutationType.Fields {
294-
desc := ""
295-
if f.Description != nil {
296-
desc = *f.Description
297-
}
298-
toolsList = append(toolsList, Tool{
299-
Name: fmt.Sprintf("%s.%s", prov.Name, f.Name),
300-
Description: desc,
301-
Inputs: ToolInputOutputSchema{Required: nil},
302-
Provider: prov,
303-
})
306+
addTool(f.Name, f.Description)
304307
}
305308
}
306309

307310
// Register subscription fields
308-
if resp.Schema.SubscriptionType != nil {
311+
if (opType == "" || opType == "subscription") && resp.Schema.SubscriptionType != nil {
309312
for _, f := range resp.Schema.SubscriptionType.Fields {
310-
desc := ""
311-
if f.Description != nil {
312-
desc = *f.Description
313-
}
314-
toolsList = append(toolsList, Tool{
315-
Name: fmt.Sprintf("%s.%s", prov.Name, f.Name),
316-
Description: desc,
317-
Inputs: ToolInputOutputSchema{Required: nil},
318-
Provider: prov,
319-
})
313+
addTool(f.Name, f.Description)
320314
}
321315
}
322316
return toolsList, nil

src/transports/graphql/graphql_transport_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,59 @@ func TestGraphQLClientTransport_RegisterAndCall(t *testing.T) {
6060
t.Fatalf("unexpected result: %#v", res)
6161
}
6262
}
63+
64+
// TestGraphQLClientTransport_RegisterToolFiltering ensures that tools are
65+
// filtered by provider OperationType and OperationName to avoid duplicates.
66+
func TestGraphQLClientTransport_RegisterToolFiltering(t *testing.T) {
67+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
68+
var req struct {
69+
Query string `json:"query"`
70+
}
71+
json.NewDecoder(r.Body).Decode(&req)
72+
if strings.Contains(req.Query, "__schema") {
73+
// Return one query field and one subscription field
74+
resp := map[string]any{"data": map[string]any{"__schema": map[string]any{
75+
"queryType": map[string]any{"fields": []map[string]any{{"name": "echo"}, {"name": "ping"}}},
76+
"subscriptionType": map[string]any{"fields": []map[string]any{{"name": "updates"}}},
77+
}}}
78+
w.Header().Set("Content-Type", "application/json")
79+
json.NewEncoder(w).Encode(resp)
80+
return
81+
}
82+
http.Error(w, "bad request", http.StatusBadRequest)
83+
}))
84+
defer server.Close()
85+
86+
tr := NewGraphQLClientTransport(nil)
87+
ctx := context.Background()
88+
89+
// Query provider should only register query fields and respect OperationName
90+
qName := "echo"
91+
provQuery := &GraphQLProvider{
92+
BaseProvider: BaseProvider{Name: "gql", ProviderType: ProviderGraphQL},
93+
URL: server.URL,
94+
OperationType: "query",
95+
OperationName: &qName,
96+
}
97+
tools, err := tr.RegisterToolProvider(ctx, provQuery)
98+
if err != nil {
99+
t.Fatalf("register error: %v", err)
100+
}
101+
if len(tools) != 1 || tools[0].Name != "gql.echo" {
102+
t.Fatalf("unexpected tools: %#v", tools)
103+
}
104+
105+
// Subscription provider should only register subscription field
106+
provSub := &GraphQLProvider{
107+
BaseProvider: BaseProvider{Name: "gqlsub", ProviderType: ProviderGraphQL},
108+
URL: server.URL,
109+
OperationType: "subscription",
110+
}
111+
tools, err = tr.RegisterToolProvider(ctx, provSub)
112+
if err != nil {
113+
t.Fatalf("register error: %v", err)
114+
}
115+
if len(tools) != 1 || tools[0].Name != "gqlsub.updates" {
116+
t.Fatalf("unexpected tools for subscription: %#v", tools)
117+
}
118+
}

0 commit comments

Comments
 (0)