Skip to content

Commit 72afbc9

Browse files
authored
auth: change OAuthHandler to take http Request and Response (#603)
Change OAuthHandler signature from func(context.Context, OAuthHandlerArgs) to func(req *http.Request, res *http.Response). - Remove OAuthHandlerArgs struct - Update HTTPTransport to pass req and resp to handler - Update tests to use new signature - Handler can now call oauthex.GetProtectedResourceMetadataFromHeader with proper validation against request URL This change fixes an impedance mismatch between OAuthHandler and the protected resource metadata functions of the oauthex package. The new signature allows handlers to properly validate resource metadata against the request URL, as required by RFC 9728. Fixes #600
1 parent d256a9c commit 72afbc9

File tree

2 files changed

+13
-30
lines changed

2 files changed

+13
-30
lines changed

auth/client.go

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,19 @@ package auth
88

99
import (
1010
"bytes"
11-
"context"
1211
"errors"
1312
"io"
1413
"net/http"
1514
"sync"
1615

17-
"github.com/modelcontextprotocol/go-sdk/oauthex"
1816
"golang.org/x/oauth2"
1917
)
2018

2119
// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization
2220
// is approved, or an error if not.
23-
type OAuthHandler func(context.Context, OAuthHandlerArgs) (oauth2.TokenSource, error)
24-
25-
// OAuthHandlerArgs are arguments to an [OAuthHandler].
26-
type OAuthHandlerArgs struct {
27-
// The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header.
28-
// Empty if not present or there was an error obtaining it.
29-
ResourceMetadataURL string
30-
}
21+
// The handler receives the HTTP request and response that triggered the authentication flow.
22+
// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader].
23+
type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error)
3124

3225
// HTTPTransport is an [http.RoundTripper] that follows the MCP
3326
// OAuth protocol when it encounters a 401 Unauthorized response.
@@ -112,10 +105,7 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
112105
// TODO: We hold the lock for the entire OAuth flow. This could be a long
113106
// time. Is there a better way?
114107
if _, ok := t.opts.Base.(*oauth2.Transport); !ok {
115-
authHeaders := resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]
116-
ts, err := t.handler(req.Context(), OAuthHandlerArgs{
117-
ResourceMetadataURL: extractResourceMetadataURL(authHeaders),
118-
})
108+
ts, err := t.handler(req, resp)
119109
if err != nil {
120110
return nil, err
121111
}
@@ -131,11 +121,3 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
131121

132122
return t.opts.Base.RoundTrip(req)
133123
}
134-
135-
func extractResourceMetadataURL(authHeaders []string) string {
136-
cs, err := oauthex.ParseWWWAuthenticate(authHeaders)
137-
if err != nil {
138-
return ""
139-
}
140-
return oauthex.ResourceMetadataURL(cs)
141-
}

auth/client_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
package auth
88

99
import (
10-
"context"
1110
"errors"
1211
"fmt"
1312
"io"
@@ -65,10 +64,11 @@ func TestHTTPTransport(t *testing.T) {
6564

6665
t.Run("successful auth flow", func(t *testing.T) {
6766
var handlerCalls int
68-
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {
67+
handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) {
6968
handlerCalls++
70-
if args.ResourceMetadataURL != "http://metadata.example.com" {
71-
t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com")
69+
// Verify that the response has the expected WWW-Authenticate header
70+
if res.Header.Get("WWW-Authenticate") != `Bearer resource_metadata="http://metadata.example.com"` {
71+
t.Errorf("handler got WWW-Authenticate header %q", res.Header.Get("WWW-Authenticate"))
7272
}
7373
return fakeTokenSource, nil
7474
}
@@ -108,9 +108,10 @@ func TestHTTPTransport(t *testing.T) {
108108
})
109109

110110
t.Run("request body is cloned", func(t *testing.T) {
111-
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {
112-
if args.ResourceMetadataURL != "http://metadata.example.com" {
113-
t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com")
111+
handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) {
112+
// Verify that the response has the expected WWW-Authenticate header
113+
if res.Header.Get("WWW-Authenticate") != `Bearer resource_metadata="http://metadata.example.com"` {
114+
t.Errorf("handler got WWW-Authenticate header %q", res.Header.Get("WWW-Authenticate"))
114115
}
115116
return fakeTokenSource, nil
116117
}
@@ -134,7 +135,7 @@ func TestHTTPTransport(t *testing.T) {
134135

135136
t.Run("handler returns error", func(t *testing.T) {
136137
handlerErr := errors.New("user rejected auth")
137-
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {
138+
handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) {
138139
return nil, handlerErr
139140
}
140141

0 commit comments

Comments
 (0)