Skip to content

Commit 29f814d

Browse files
authored
Add automatic authentication monitoring for remote workloads with unauthenticated state detection (#2421)
* add authenticated token source with background monitoring * use single goroutine with exponentil backoff time * fix linting issue * fix workload status unintentional mapping * change the fn comment * flip the state on any error * fix linting
1 parent 8989d46 commit 29f814d

File tree

15 files changed

+647
-34
lines changed

15 files changed

+647
-34
lines changed

cmd/thv/app/list.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/spf13/cobra"
1010

11+
rt "github.com/stacklok/toolhive/pkg/container/runtime"
1112
"github.com/stacklok/toolhive/pkg/core"
1213
"github.com/stacklok/toolhive/pkg/logger"
1314
"github.com/stacklok/toolhive/pkg/workloads"
@@ -139,11 +140,17 @@ func printTextOutput(workloadList []core.Workload) {
139140

140141
// Print workload information
141142
for _, c := range workloadList {
143+
// Highlight unauthenticated workloads with a warning indicator
144+
status := string(c.Status)
145+
if c.Status == rt.WorkloadStatusUnauthenticated {
146+
status = "⚠️ " + status
147+
}
148+
142149
// Print workload information
143150
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%d\t%s\t%s\t%s\n",
144151
c.Name,
145152
c.Package,
146-
c.Status,
153+
status,
147154
c.URL,
148155
c.Port,
149156
c.ToolType,

cmd/thv/app/proxy.go

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package app
33
import (
44
"context"
55
"fmt"
6-
"net/http"
76
"net/url"
87
"os/signal"
98
"syscall"
@@ -19,6 +18,7 @@ import (
1918
"github.com/stacklok/toolhive/pkg/logger"
2019
"github.com/stacklok/toolhive/pkg/networking"
2120
"github.com/stacklok/toolhive/pkg/transport"
21+
"github.com/stacklok/toolhive/pkg/transport/middleware"
2222
"github.com/stacklok/toolhive/pkg/transport/proxy/transparent"
2323
"github.com/stacklok/toolhive/pkg/transport/types"
2424
)
@@ -375,18 +375,7 @@ func resolveClientSecret() (string, error) {
375375

376376
// createTokenInjectionMiddleware creates a middleware that injects the OAuth token into requests
377377
func createTokenInjectionMiddleware(tokenSource oauth2.TokenSource) types.MiddlewareFunction {
378-
return func(next http.Handler) http.Handler {
379-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
380-
token, err := tokenSource.Token()
381-
if err != nil {
382-
http.Error(w, "Unable to retrieve OAuth token", http.StatusUnauthorized)
383-
return
384-
}
385-
386-
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
387-
next.ServeHTTP(w, r)
388-
})
389-
}
378+
return middleware.CreateTokenInjectionMiddleware(tokenSource)
390379
}
391380

392381
// addExternalTokenMiddleware adds token exchange or token injection middleware to the middleware chain

docs/arch/02-core-concepts.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ A **workload** is the fundamental deployment unit in ToolHive. It represents eve
3737
- `removing` - Workload is being deleted
3838
- `error` - Workload encountered an error
3939
- `unhealthy` - Workload is running but unhealthy
40+
- `unauthenticated` - Remote workload cannot authenticate (expired tokens)
4041

4142
**Implementation:**
4243
- Interface: `pkg/workloads/manager.go`

docs/arch/08-workloads-lifecycle.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,24 @@ stateDiagram-v2
2121
2222
Running --> Stopping: Stop
2323
Running --> Unhealthy: Health Failed
24+
Running --> Unauthenticated: Auth Failed
2425
Running --> Stopped: Container Exit
2526
2627
Stopping --> Stopped: Success
2728
Stopped --> Starting: Restart
2829
Stopped --> Removing: Delete
2930
31+
Unauthenticated --> Starting: Re-authenticate
32+
Unauthenticated --> Removing: Delete
33+
3034
Removing --> [*]: Success
3135
Error --> Starting: Restart
3236
Error --> Removing: Delete
3337
```
3438

3539
**States**: `pkg/container/runtime/types.go`
3640
- `starting`, `running`, `stopping`, `stopped`
37-
- `removing`, `error`, `unhealthy`
41+
- `removing`, `error`, `unhealthy`, `unauthenticated`
3842

3943
## Core Operations
4044

docs/server/docs.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/server/swagger.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/server/swagger.yaml

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/auth/monitored_token_source.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"time"
8+
9+
"golang.org/x/oauth2"
10+
11+
"github.com/stacklok/toolhive/pkg/container/runtime"
12+
)
13+
14+
// StatusUpdater is an interface for updating workload authentication status.
15+
// This abstraction allows the monitored token source to work with any status management system
16+
// without creating import cycles.
17+
type StatusUpdater interface {
18+
SetWorkloadStatus(ctx context.Context, workloadName string, status runtime.WorkloadStatus, reason string) error
19+
}
20+
21+
// MonitoredTokenSource is a wrapper around an oauth2.TokenSource that monitors authentication
22+
// failures and automatically marks workloads as unauthenticated when tokens expire or fail.
23+
// It provides both per-request token retrieval and background monitoring.
24+
type MonitoredTokenSource struct {
25+
tokenSource oauth2.TokenSource
26+
workloadName string
27+
statusUpdater StatusUpdater
28+
monitoringCtx context.Context
29+
stopMonitoring chan struct{}
30+
stopOnce sync.Once
31+
32+
timer *time.Timer
33+
}
34+
35+
// NewMonitoredTokenSource creates a new MonitoredTokenSource that wraps the provided
36+
// oauth2.TokenSource and monitors it for authentication failures.
37+
func NewMonitoredTokenSource(
38+
ctx context.Context,
39+
tokenSource oauth2.TokenSource,
40+
workloadName string,
41+
statusUpdater StatusUpdater,
42+
) *MonitoredTokenSource {
43+
return &MonitoredTokenSource{
44+
tokenSource: tokenSource,
45+
workloadName: workloadName,
46+
statusUpdater: statusUpdater,
47+
monitoringCtx: ctx,
48+
stopMonitoring: make(chan struct{}),
49+
}
50+
}
51+
52+
// Token retrieves a token from the token source and will mark the workload as unauthenticated
53+
// if the token retrieval fails.
54+
func (mts *MonitoredTokenSource) Token() (*oauth2.Token, error) {
55+
tok, err := mts.tokenSource.Token()
56+
if err != nil {
57+
mts.markAsUnauthenticated(fmt.Sprintf("Token retrieval failed: %v", err))
58+
return nil, err
59+
}
60+
return tok, nil
61+
}
62+
63+
// StartBackgroundMonitoring starts the background monitoring goroutine that checks
64+
// token validity at expiry time and marks the workload as unauthenticated on the failure.
65+
func (mts *MonitoredTokenSource) StartBackgroundMonitoring() {
66+
if mts.timer == nil {
67+
mts.timer = time.NewTimer(time.Millisecond) // kick immediately
68+
}
69+
go mts.monitorLoop()
70+
}
71+
72+
func (mts *MonitoredTokenSource) monitorLoop() {
73+
for {
74+
select {
75+
case <-mts.monitoringCtx.Done():
76+
mts.stopTimer()
77+
return
78+
case <-mts.stopMonitoring:
79+
mts.stopTimer()
80+
return
81+
case <-mts.timer.C:
82+
shouldStop, next := mts.onTick()
83+
if shouldStop {
84+
mts.stopTimer()
85+
return
86+
}
87+
mts.resetTimer(next)
88+
}
89+
}
90+
}
91+
92+
func (mts *MonitoredTokenSource) stopTimer() {
93+
if mts.timer != nil && !mts.timer.Stop() {
94+
select {
95+
case <-mts.timer.C:
96+
default:
97+
}
98+
}
99+
}
100+
101+
func (mts *MonitoredTokenSource) resetTimer(d time.Duration) {
102+
mts.stopTimer()
103+
mts.timer.Reset(d)
104+
}
105+
106+
// onTick returns (shouldStop bool, nextDelay time.Duration)
107+
func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) {
108+
tok, err := mts.tokenSource.Token()
109+
if err != nil {
110+
// Any error → mark as unauthenticated and stop
111+
mts.markAsUnauthenticated(fmt.Sprintf("No valid token: %v", err))
112+
return true, 0
113+
}
114+
115+
// Success → schedule next check
116+
if tok.Expiry.IsZero() {
117+
// no expiry → nothing to monitor
118+
return true, 0
119+
}
120+
wait := time.Until(tok.Expiry)
121+
if wait < time.Second {
122+
wait = time.Second
123+
}
124+
return false, wait
125+
}
126+
127+
// markAsUnauthenticated marks the workload as unauthenticated.
128+
func (mts *MonitoredTokenSource) markAsUnauthenticated(reason string) {
129+
_ = mts.statusUpdater.SetWorkloadStatus(
130+
context.Background(),
131+
mts.workloadName,
132+
runtime.WorkloadStatusUnauthenticated,
133+
reason,
134+
)
135+
mts.stopOnce.Do(func() { close(mts.stopMonitoring) })
136+
}

0 commit comments

Comments
 (0)