Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cmd/thv/app/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/spf13/cobra"

rt "github.com/stacklok/toolhive/pkg/container/runtime"
"github.com/stacklok/toolhive/pkg/core"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/workloads"
Expand Down Expand Up @@ -139,11 +140,17 @@ func printTextOutput(workloadList []core.Workload) {

// Print workload information
for _, c := range workloadList {
// Highlight unauthenticated workloads with a warning indicator
status := string(c.Status)
if c.Status == rt.WorkloadStatusUnauthenticated {
status = "⚠️ " + status
}

// Print workload information
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%d\t%s\t%s\t%s\n",
c.Name,
c.Package,
c.Status,
status,
c.URL,
c.Port,
c.ToolType,
Expand Down
15 changes: 2 additions & 13 deletions cmd/thv/app/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package app
import (
"context"
"fmt"
"net/http"
"net/url"
"os/signal"
"syscall"
Expand All @@ -19,6 +18,7 @@ import (
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/networking"
"github.com/stacklok/toolhive/pkg/transport"
"github.com/stacklok/toolhive/pkg/transport/middleware"
"github.com/stacklok/toolhive/pkg/transport/proxy/transparent"
"github.com/stacklok/toolhive/pkg/transport/types"
)
Expand Down Expand Up @@ -375,18 +375,7 @@ func resolveClientSecret() (string, error) {

// createTokenInjectionMiddleware creates a middleware that injects the OAuth token into requests
func createTokenInjectionMiddleware(tokenSource oauth2.TokenSource) types.MiddlewareFunction {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := tokenSource.Token()
if err != nil {
http.Error(w, "Unable to retrieve OAuth token", http.StatusUnauthorized)
return
}

r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
next.ServeHTTP(w, r)
})
}
return middleware.CreateTokenInjectionMiddleware(tokenSource)
}

// addExternalTokenMiddleware adds token exchange or token injection middleware to the middleware chain
Expand Down
1 change: 1 addition & 0 deletions docs/arch/02-core-concepts.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ A **workload** is the fundamental deployment unit in ToolHive. It represents eve
- `removing` - Workload is being deleted
- `error` - Workload encountered an error
- `unhealthy` - Workload is running but unhealthy
- `unauthenticated` - Remote workload cannot authenticate (expired tokens)

**Implementation:**
- Interface: `pkg/workloads/manager.go`
Expand Down
6 changes: 5 additions & 1 deletion docs/arch/08-workloads-lifecycle.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,24 @@ stateDiagram-v2

Running --> Stopping: Stop
Running --> Unhealthy: Health Failed
Running --> Unauthenticated: Auth Failed
Running --> Stopped: Container Exit

Stopping --> Stopped: Success
Stopped --> Starting: Restart
Stopped --> Removing: Delete

Unauthenticated --> Starting: Re-authenticate
Unauthenticated --> Removing: Delete

Removing --> [*]: Success
Error --> Starting: Restart
Error --> Removing: Delete
```

**States**: `pkg/container/runtime/types.go`
- `starting`, `running`, `stopping`, `stopped`
- `removing`, `error`, `unhealthy`
- `removing`, `error`, `unhealthy`, `unauthenticated`

## Core Operations

Expand Down
2 changes: 1 addition & 1 deletion docs/server/docs.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/server/swagger.json

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions docs/server/swagger.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

136 changes: 136 additions & 0 deletions pkg/auth/monitored_token_source.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package auth

import (
"context"
"fmt"
"sync"
"time"

"golang.org/x/oauth2"

"github.com/stacklok/toolhive/pkg/container/runtime"
)

// StatusUpdater is an interface for updating workload authentication status.
// This abstraction allows the monitored token source to work with any status management system
// without creating import cycles.
type StatusUpdater interface {
SetWorkloadStatus(ctx context.Context, workloadName string, status runtime.WorkloadStatus, reason string) error
}

// MonitoredTokenSource is a wrapper around an oauth2.TokenSource that monitors authentication
// failures and automatically marks workloads as unauthenticated when tokens expire or fail.
// It provides both per-request token retrieval and background monitoring.
type MonitoredTokenSource struct {
tokenSource oauth2.TokenSource
workloadName string
statusUpdater StatusUpdater
monitoringCtx context.Context
stopMonitoring chan struct{}
stopOnce sync.Once

timer *time.Timer
}

// NewMonitoredTokenSource creates a new MonitoredTokenSource that wraps the provided
// oauth2.TokenSource and monitors it for authentication failures.
func NewMonitoredTokenSource(
ctx context.Context,
tokenSource oauth2.TokenSource,
workloadName string,
statusUpdater StatusUpdater,
) *MonitoredTokenSource {
return &MonitoredTokenSource{
tokenSource: tokenSource,
workloadName: workloadName,
statusUpdater: statusUpdater,
monitoringCtx: ctx,
stopMonitoring: make(chan struct{}),
}
}

// Token retrieves a token from the token source and will mark the workload as unauthenticated
// if the token retrieval fails.
func (mts *MonitoredTokenSource) Token() (*oauth2.Token, error) {
tok, err := mts.tokenSource.Token()
if err != nil {
mts.markAsUnauthenticated(fmt.Sprintf("Token retrieval failed: %v", err))
return nil, err
}
return tok, nil
}

// StartBackgroundMonitoring starts the background monitoring goroutine that checks
// token validity at expiry time and marks the workload as unauthenticated on the failure.
func (mts *MonitoredTokenSource) StartBackgroundMonitoring() {
if mts.timer == nil {
mts.timer = time.NewTimer(time.Millisecond) // kick immediately
}
go mts.monitorLoop()
}

func (mts *MonitoredTokenSource) monitorLoop() {
for {
select {
case <-mts.monitoringCtx.Done():
mts.stopTimer()
return
case <-mts.stopMonitoring:
mts.stopTimer()
return
case <-mts.timer.C:
shouldStop, next := mts.onTick()
if shouldStop {
mts.stopTimer()
return
}
mts.resetTimer(next)
}
}
}

func (mts *MonitoredTokenSource) stopTimer() {
if mts.timer != nil && !mts.timer.Stop() {
select {
case <-mts.timer.C:
default:
}
}
}

func (mts *MonitoredTokenSource) resetTimer(d time.Duration) {
mts.stopTimer()
mts.timer.Reset(d)
}

// onTick returns (shouldStop bool, nextDelay time.Duration)
func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) {
tok, err := mts.tokenSource.Token()
if err != nil {
// Any error → mark as unauthenticated and stop
mts.markAsUnauthenticated(fmt.Sprintf("No valid token: %v", err))
return true, 0
}

// Success → schedule next check
if tok.Expiry.IsZero() {
// no expiry → nothing to monitor
return true, 0
}
wait := time.Until(tok.Expiry)
if wait < time.Second {
wait = time.Second
}
return false, wait
}

// markAsUnauthenticated marks the workload as unauthenticated.
func (mts *MonitoredTokenSource) markAsUnauthenticated(reason string) {
_ = mts.statusUpdater.SetWorkloadStatus(
context.Background(),
mts.workloadName,
runtime.WorkloadStatusUnauthenticated,
reason,
)
mts.stopOnce.Do(func() { close(mts.stopMonitoring) })
}
Loading
Loading