Skip to content

Commit 0f77246

Browse files
authored
feat(core): propagate token clientID on configured claim via interceptor into shared context metadata (#2760)
- adds config for a client ID claim in OIDC access tokens (dot notation capable) - utility functions to set and get the client ID from golang context metadata (which is made available across gRPC service boundaries, unlike golang context keys that are _not_ saved to the context metadata) - reads config into various interceptors - each interceptor propagates the clientID from the parsed token for downstream consumers - improvements to logs alongside `azp` claim - unit and integration tests
1 parent 279bacd commit 0f77246

File tree

10 files changed

+445
-28
lines changed

10 files changed

+445
-28
lines changed

docs/Configuring.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,9 @@ server:
352352
353353
## Dot notation is used to access the groups claim
354354
group_claim: "realm_access.roles"
355+
356+
# Dot notation is used to access the claim the represents the idP client ID
357+
client_id_claim: # azp
355358
356359
## Deprecated: Use standard casbin policy groupings (g, <user/group>, <role>)
357360
## Maps the external role to the OpenTDF role

opentdf-dev.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ server:
7272
username_claim: # preferred_username
7373
# That claim to access groups (i.e. realm_access.roles)
7474
groups_claim: # realm_access.roles
75+
# Claim the represents the idP client ID
76+
client_id_claim: # azp
7577
## Extends the builtin policy
7678
extension: |
7779
g, opentdf-admin, role:admin

opentdf-ers-mode.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ server:
2828
default: #"role:standard"
2929
## Dot notation is used to access nested claims (i.e. realm_access.roles)
3030
claim: # realm_access.roles
31+
# Claim the represents the idP client ID
32+
client_id_claim: # azp
3133
## Maps the external role to the opentdf role
3234
## Note: left side is used in the policy, right side is the external role
3335
map:

opentdf-example.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ server:
5353
username_claim: # preferred_username
5454
# That claim to access groups (i.e. realm_access.roles)
5555
groups_claim: # realm_access.roles
56+
# Claim the represents the idP client ID
57+
client_id_claim: # azp
5658
## Extends the builtin policy
5759
extension: |
5860
g, opentdf-admin, role:admin

opentdf-kas-mode.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ server:
4545
default: #"role:standard"
4646
## Dot notation is used to access nested claims (i.e. realm_access.roles)
4747
claim: # realm_access.roles
48+
# Claim the represents the idP client ID
49+
client_id_claim: # azp
4850
## Maps the external role to the opentdf role
4951
## Note: left side is used in the policy, right side is the external role
5052
map:

service/internal/auth/authn.go

Lines changed: 82 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"github.com/lestrrat-go/jwx/v2/jwk"
2323
"github.com/lestrrat-go/jwx/v2/jws"
2424
"github.com/lestrrat-go/jwx/v2/jwt"
25-
"google.golang.org/grpc/metadata"
2625

2726
sdkAudit "github.com/opentdf/platform/sdk/audit"
2827
"github.com/opentdf/platform/service/logger"
@@ -62,6 +61,11 @@ var (
6261
jwa.PS384: true,
6362
jwa.PS512: true,
6463
}
64+
65+
// Exported error variables for client ID processing
66+
ErrClientIDClaimNotConfigured = errors.New("no client ID claim configured")
67+
ErrClientIDClaimNotFound = errors.New("client ID claim not found")
68+
ErrClientIDClaimNotString = errors.New("client ID claim is not a string")
6569
)
6670

6771
const (
@@ -164,7 +168,7 @@ func NewAuthenticator(ctx context.Context, cfg Config, logger *logger.Logger, we
164168

165169
// Try an register oidc issuer to wellknown service but don't return an error if it fails
166170
if err := wellknownRegistration("platform_issuer", cfg.Issuer); err != nil {
167-
logger.Warn("failed to register platform issuer", slog.String("error", err.Error()))
171+
logger.Warn("failed to register platform issuer", slog.Any("error", err))
168172
}
169173

170174
var oidcConfigMap map[string]any
@@ -180,7 +184,7 @@ func NewAuthenticator(ctx context.Context, cfg Config, logger *logger.Logger, we
180184
}
181185

182186
if err := wellknownRegistration("idp", oidcConfigMap); err != nil {
183-
logger.Warn("failed to register platform idp information", slog.String("error", err.Error()))
187+
logger.Warn("failed to register platform idp information", slog.Any("error", err))
184188
}
185189

186190
return a, nil
@@ -212,6 +216,7 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
212216
}
213217

214218
dp := r.Header.Values("Dpop")
219+
log := a.logger
215220

216221
// Verify the token
217222
header := r.Header["Authorization"]
@@ -228,12 +233,12 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
228233
origin = "http://" + strings.TrimSuffix(origin, ":80")
229234
}
230235
}
231-
accessTok, ctxWithJWK, err := a.checkToken(r.Context(), header, receiverInfo{
236+
accessTok, ctx, err := a.checkToken(r.Context(), header, receiverInfo{
232237
u: []string{normalizeURL(origin, r.URL)},
233238
m: []string{r.Method},
234239
}, dp)
235240
if err != nil {
236-
slog.WarnContext(r.Context(),
241+
log.WarnContext(ctx,
237242
"unauthenticated",
238243
slog.Any("error", err),
239244
slog.Any("dpop", dp),
@@ -242,12 +247,19 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
242247
return
243248
}
244249

245-
md, ok := metadata.FromIncomingContext(ctxWithJWK)
246-
if !ok {
247-
md = metadata.New(nil)
250+
clientID, err := a.getClientIDFromToken(ctx, accessTok)
251+
if err != nil {
252+
log.WarnContext(
253+
ctx,
254+
"could not determine client ID from token",
255+
slog.Any("err", err),
256+
)
257+
} else {
258+
log = log.
259+
With("client_id", clientID).
260+
With("configured_client_id_claim_name", a.oidcConfiguration.Policy.ClientIDClaim)
261+
ctx = ctxAuth.ContextWithAuthnMetadata(ctx, clientID)
248262
}
249-
md.Append("access_token", ctxAuth.GetRawAccessTokenFromContext(ctxWithJWK, nil))
250-
ctxWithJWK = metadata.NewIncomingContext(ctxWithJWK, md)
251263

252264
// Check if the token is allowed to access the resource
253265
var action string
@@ -263,7 +275,8 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
263275
}
264276
if allow, err := a.enforcer.Enforce(accessTok, r.URL.Path, action); err != nil {
265277
if err.Error() == "permission denied" {
266-
a.logger.WarnContext(r.Context(),
278+
log.WarnContext(
279+
ctx,
267280
"permission denied",
268281
slog.String("azp", accessTok.Subject()),
269282
slog.Any("error", err),
@@ -274,12 +287,16 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
274287
http.Error(w, "internal server error", http.StatusInternalServerError)
275288
return
276289
} else if !allow {
277-
a.logger.WarnContext(r.Context(), "permission denied", slog.String("azp", accessTok.Subject()))
290+
log.WarnContext(
291+
ctx,
292+
"permission denied",
293+
slog.String("azp", accessTok.Subject()),
294+
)
278295
http.Error(w, "permission denied", http.StatusForbidden)
279296
return
280297
}
281298

282-
r = r.WithContext(ctxWithJWK)
299+
r = r.WithContext(ctx)
283300
handler.ServeHTTP(w, r)
284301
})
285302
}
@@ -296,6 +313,8 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor
296313
return next(ctx, req)
297314
}
298315

316+
log := a.logger
317+
299318
ri := receiverInfo{
300319
u: []string{req.Spec().Procedure},
301320
m: []string{http.MethodPost},
@@ -319,7 +338,7 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor
319338
resource := p[1] + "/" + p[2]
320339
action := getAction(p[2])
321340

322-
token, newCtx, err := a.checkToken(
341+
token, ctxWithJWK, err := a.checkToken(
323342
ctx,
324343
header,
325344
ri,
@@ -329,22 +348,38 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor
329348
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated"))
330349
}
331350

351+
clientID, err := a.getClientIDFromToken(ctxWithJWK, token)
352+
if err != nil {
353+
log.WarnContext(
354+
ctxWithJWK,
355+
"could not determine client ID from token",
356+
slog.Any("err", err),
357+
)
358+
} else {
359+
log = log.
360+
With("client_id", clientID).
361+
With("configured_client_id_claim_name", a.oidcConfiguration.Policy.ClientIDClaim)
362+
ctxWithJWK = ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID)
363+
}
364+
332365
// Check if the token is allowed to access the resource
333366
if allowed, err := a.enforcer.Enforce(token, resource, action); err != nil {
334367
if err.Error() == "permission denied" {
335-
a.logger.Warn("permission denied",
368+
log.WarnContext(
369+
ctxWithJWK,
370+
"permission denied",
336371
slog.String("azp", token.Subject()),
337372
slog.Any("error", err),
338373
)
339374
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
340375
}
341376
return nil, err
342377
} else if !allowed {
343-
a.logger.Warn("permission denied", slog.String("azp", token.Subject()))
378+
log.WarnContext(ctxWithJWK, "permission denied", slog.String("azp", token.Subject()))
344379
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
345380
}
346381

347-
return next(newCtx, req)
382+
return next(ctxWithJWK, req)
348383
})
349384
}
350385
return connect.UnaryInterceptorFunc(interceptor)
@@ -399,7 +434,7 @@ func (a *Authentication) checkToken(ctx context.Context, authHeader []string, dp
399434
case strings.HasPrefix(authHeader[0], "Bearer "):
400435
tokenRaw = strings.TrimPrefix(authHeader[0], "Bearer ")
401436
default:
402-
a.logger.Warn("failed to validate authentication header: not of type bearer or dpop", slog.String("header", authHeader[0]))
437+
a.logger.WarnContext(ctx, "failed to validate authentication header: not of type bearer or dpop", slog.String("header", authHeader[0]))
403438
return nil, nil, errors.New("not of type bearer or dpop")
404439
}
405440

@@ -431,12 +466,12 @@ func (a *Authentication) checkToken(ctx context.Context, authHeader []string, dp
431466
ctx = ctxAuth.ContextWithAuthNInfo(ctx, nil, accessToken, tokenRaw)
432467
return accessToken, ctx, nil
433468
}
434-
key, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader)
469+
dpopKey, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader)
435470
if err != nil {
436471
a.logger.Warn("failed to validate dpop", slog.Any("err", err))
437472
return nil, nil, err
438473
}
439-
ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, accessToken, tokenRaw)
474+
ctx = ctxAuth.ContextWithAuthNInfo(ctx, dpopKey, accessToken, tokenRaw)
440475
return accessToken, ctx, nil
441476
}
442477

@@ -668,7 +703,7 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header
668703
u = append(u, a.lookupGatewayPaths(ctx, path, header)...)
669704

670705
// Validate the token and create a JWT token
671-
_, nextCtx, err := a.checkToken(ctx, authHeader, receiverInfo{
706+
token, ctxWithJWK, err := a.checkToken(ctx, authHeader, receiverInfo{
672707
u: u,
673708
m: []string{http.MethodPost},
674709
}, header["Dpop"])
@@ -677,8 +712,33 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header
677712
}
678713

679714
// Return the next context with the token
680-
return nextCtx, nil
715+
clientID, err := a.getClientIDFromToken(ctxWithJWK, token)
716+
if err != nil {
717+
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated"))
718+
}
719+
return ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID), nil
681720
}
682721
}
683722
return ctx, nil
684723
}
724+
725+
// getClientIDFromToken returns the client ID from the token if found (dot notation)
726+
func (a *Authentication) getClientIDFromToken(ctx context.Context, tok jwt.Token) (string, error) {
727+
clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim
728+
if clientIDClaim == "" {
729+
return "", ErrClientIDClaimNotConfigured
730+
}
731+
claimsMap, err := tok.AsMap(ctx)
732+
if err != nil {
733+
return "", fmt.Errorf("failed to parse token as a map and find claim at [%s]: %w", clientIDClaim, err)
734+
}
735+
found := dotNotation(claimsMap, clientIDClaim)
736+
if found == nil {
737+
return "", fmt.Errorf("%w at [%s]", ErrClientIDClaimNotFound, clientIDClaim)
738+
}
739+
clientID, isString := found.(string)
740+
if !isString {
741+
return "", fmt.Errorf("%w at [%s]", ErrClientIDClaimNotString, clientIDClaim)
742+
}
743+
return clientID, nil
744+
}

0 commit comments

Comments
 (0)