diff --git a/authclient/authclient.go b/authclient/authclient.go index e4cecf77..67d747d6 100644 --- a/authclient/authclient.go +++ b/authclient/authclient.go @@ -3,15 +3,19 @@ package authclient import ( "context" + "encoding/base64" + "encoding/json" "fmt" "io" "net" "net/http" + "net/http/cookiejar" "net/url" "os" "strings" "time" + "github.com/pomerium/pomerium/pkg/identity/oidc" "golang.org/x/sync/errgroup" ) @@ -41,6 +45,10 @@ func (client *AuthClient) GetJWT(ctx context.Context, serverURL *url.URL, onOpen return strings.TrimSpace(string(rawJWTBytes)), nil } + if client.cfg.deviceCodeFlow { + return client.runDeviceCodeFlow(ctx, serverURL) + } + li, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return "", fmt.Errorf("failed to start listener: %w", err) @@ -160,3 +168,110 @@ func (client *AuthClient) runOpenBrowser(ctx context.Context, li net.Listener, s _, _ = fmt.Fprintf(os.Stderr, "Your browser has been opened to visit:\n\n%s\n\n", string(bs)) return nil } + +type DeviceAuthTokenResponse struct { + Token string `json:"token"` +} + +func (client *AuthClient) runDeviceCodeFlow(ctx context.Context, requestURL *url.URL) (string, error) { + apiUrl := requestURL.ResolveReference(&url.URL{ + Path: "/.pomerium/api/v1/device_auth", + }) + + req, err := http.NewRequestWithContext(ctx, "GET", apiUrl.String(), nil) + if err != nil { + return "", err + } + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = client.cfg.tlsConfig + + jar, err := cookiejar.New(nil) + if err != nil { + return "", err + } + hc := &http.Client{ + Timeout: 10 * time.Minute, + Transport: transport, + Jar: jar, + } + + res, err := hc.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return "", fmt.Errorf("authentication failed: %s", res.Status) + } + + if res.Header.Get("Content-Type") != "application/json" { + return "", fmt.Errorf("unexpected content type: %s", res.Header.Get("Content-Type")) + } + + bytes, err := io.ReadAll(res.Body) + if err != nil { + return "", err + } + + var response oidc.UserDeviceAuthResponse + if err := json.Unmarshal(bytes, &response); err != nil { + return "", err + } + + fmt.Fprintf(os.Stderr, "Authenticate with your browser at %s\n", response.VerificationURIComplete) + + delay := time.Duration(response.InitialRetryDelay) * time.Second + numRetries := 10 + for i := 0; i < numRetries; i++ { + select { + case <-time.After(delay): + case <-ctx.Done(): + return "", ctx.Err() + } + req, err = http.NewRequestWithContext(ctx, "POST", apiUrl.String(), strings.NewReader(url.Values{ + "pomerium_device_auth_retry_token": {base64.StdEncoding.EncodeToString(response.RetryToken)}, + }.Encode())) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + res, err = hc.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + + switch res.StatusCode { + case http.StatusOK: + case http.StatusTooManyRequests: + if retryAfter := res.Header.Get("Retry-After"); retryAfter != "" { + if d, err := time.ParseDuration(retryAfter); err == nil { + delay = d + } + } + continue + default: + return "", fmt.Errorf("authentication failed: %s", res.Status) + } + + if res.Header.Get("Content-Type") != "application/json" { + return "", fmt.Errorf("unexpected content type: %s", res.Header.Get("Content-Type")) + } + + tokenBytes, err := io.ReadAll(res.Body) + if err != nil { + return "", err + } + + var tokenResponse DeviceAuthTokenResponse + if err := json.Unmarshal(tokenBytes, &tokenResponse); err != nil { + return "", err + } + + return tokenResponse.Token, nil + } + + return "", fmt.Errorf("authentication timed out after %d retries", numRetries) +} diff --git a/authclient/config.go b/authclient/config.go index e349c3af..48d306d7 100644 --- a/authclient/config.go +++ b/authclient/config.go @@ -8,6 +8,7 @@ import ( type config struct { open func(rawURL string) error + deviceCodeFlow bool serviceAccount string serviceAccountFile string tlsConfig *tls.Config @@ -58,3 +59,9 @@ func WithTLSConfig(tlsConfig *tls.Config) Option { cfg.tlsConfig = tlsConfig.Clone() } } + +func WithUseDeviceCodeFlow(enabled bool) Option { + return func(cfg *config) { + cfg.deviceCodeFlow = enabled + } +} diff --git a/cmd/pomerium-cli/kubernetes.go b/cmd/pomerium-cli/kubernetes.go index d7c7edbd..c1d829db 100644 --- a/cmd/pomerium-cli/kubernetes.go +++ b/cmd/pomerium-cli/kubernetes.go @@ -69,6 +69,7 @@ var kubernetesExecCredentialCmd = &cobra.Command{ } ac := authclient.New( + authclient.WithUseDeviceCodeFlow(browserOptions.useDeviceCodeFlow), authclient.WithBrowserCommand(browserOptions.command), authclient.WithServiceAccount(serviceAccountOptions.serviceAccount), authclient.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile), diff --git a/cmd/pomerium-cli/main.go b/cmd/pomerium-cli/main.go index 94d46bb2..c7f0e52a 100644 --- a/cmd/pomerium-cli/main.go +++ b/cmd/pomerium-cli/main.go @@ -123,13 +123,16 @@ func getTLSConfig() (*tls.Config, error) { } var browserOptions struct { - command string + command string + useDeviceCodeFlow bool } func addBrowserFlags(cmd *cobra.Command) { flags := cmd.Flags() flags.StringVar(&browserOptions.command, "browser-cmd", "", "custom browser command to run when opening a URL") + flags.BoolVar(&browserOptions.useDeviceCodeFlow, "use-device-code-flow", false, + "use device code flow for authentication instead of opening a browser") } var serviceAccountOptions struct {