diff --git a/backend/cmd/headlamp.go b/backend/cmd/headlamp.go index 0907aab85bf..821733b96aa 100644 --- a/backend/cmd/headlamp.go +++ b/backend/cmd/headlamp.go @@ -46,15 +46,15 @@ import ( "github.com/kubernetes-sigs/headlamp/backend/pkg/auth" "github.com/kubernetes-sigs/headlamp/backend/pkg/cache" cfg "github.com/kubernetes-sigs/headlamp/backend/pkg/config" + "github.com/kubernetes-sigs/headlamp/backend/pkg/portforward" + "github.com/prometheus/client_golang/prometheus/promhttp" headlampcfg "github.com/kubernetes-sigs/headlamp/backend/pkg/headlampconfig" "github.com/kubernetes-sigs/headlamp/backend/pkg/helm" "github.com/kubernetes-sigs/headlamp/backend/pkg/kubeconfig" "github.com/kubernetes-sigs/headlamp/backend/pkg/logger" "github.com/kubernetes-sigs/headlamp/backend/pkg/plugins" - "github.com/kubernetes-sigs/headlamp/backend/pkg/portforward" "github.com/kubernetes-sigs/headlamp/backend/pkg/telemetry" - "github.com/prometheus/client_golang/prometheus/promhttp" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" @@ -390,427 +390,592 @@ func addPluginListRoute(config *HeadlampConfig, r *mux.Router) { }).Methods("GET") } -//nolint:gocognit,funlen,gocyclo func createHeadlampHandler(config *HeadlampConfig) http.Handler { - kubeConfigPath := config.KubeConfigPath + setupLogger(config) + setupPluginHandlers(config) + setupInClusterContext(config) + setupStaticFiles(config) - config.StaticPluginDir = os.Getenv("HEADLAMP_STATIC_PLUGINS_DIR") + router := createRouter(config) + setupRoutes(config, router) + return finalizeHandler(config, router) +} + +func setupLogger(config *HeadlampConfig) { logger.Log(logger.LevelInfo, nil, nil, "Creating Headlamp handler") logger.Log(logger.LevelInfo, nil, nil, "Listen address: "+fmt.Sprintf("%s:%d", config.ListenAddr, config.Port)) - logger.Log(logger.LevelInfo, nil, nil, "Kubeconfig path: "+kubeConfigPath) + logger.Log(logger.LevelInfo, nil, nil, "Kubeconfig path: "+config.KubeConfigPath) logger.Log(logger.LevelInfo, nil, nil, "Static plugin dir: "+config.StaticPluginDir) logger.Log(logger.LevelInfo, nil, nil, "Plugins dir: "+config.PluginDir) logger.Log(logger.LevelInfo, nil, nil, "Dynamic clusters support: "+fmt.Sprint(config.EnableDynamicClusters)) logger.Log(logger.LevelInfo, nil, nil, "Helm support: "+fmt.Sprint(config.EnableHelm)) logger.Log(logger.LevelInfo, nil, nil, "Proxy URLs: "+fmt.Sprint(config.ProxyURLs)) +} +func setupPluginHandlers(config *HeadlampConfig) { plugins.PopulatePluginsCache(config.StaticPluginDir, config.PluginDir, config.cache) - skipFunc := kubeconfig.SkipKubeContextInCommaSeparatedString(config.SkippedKubeContexts) - if !config.UseInCluster || config.WatchPluginsChanges { - // in-cluster mode is unlikely to want reloading plugins. pluginEventChan := make(chan string) go plugins.Watch(config.PluginDir, pluginEventChan) go plugins.HandlePluginEvents(config.StaticPluginDir, config.PluginDir, pluginEventChan, config.cache) - // in-cluster mode is unlikely to want reloading kubeconfig. - go kubeconfig.LoadAndWatchFiles(config.KubeConfigStore, kubeConfigPath, kubeconfig.KubeConfig, skipFunc) + + skipFunc := kubeconfig.SkipKubeContextInCommaSeparatedString(config.SkippedKubeContexts) + go kubeconfig.LoadAndWatchFiles(config.KubeConfigStore, config.KubeConfigPath, kubeconfig.KubeConfig, skipFunc) } +} - // In-cluster +func setupInClusterContext(config *HeadlampConfig) { if config.UseInCluster { - context, err := kubeconfig.GetInClusterContext(config.oidcIdpIssuerURL, - config.oidcClientID, config.oidcClientSecret, - strings.Join(config.oidcScopes, ",")) + context, err := kubeconfig.GetInClusterContext( + config.oidcIdpIssuerURL, + config.oidcClientID, + config.oidcClientSecret, + strings.Join(config.oidcScopes, ","), + ) if err != nil { logger.Log(logger.LevelError, nil, err, "Failed to get in-cluster context") + return } context.Source = kubeconfig.InCluster - - err = context.SetupProxy() - if err != nil { + if err := context.SetupProxy(); err != nil { logger.Log(logger.LevelError, nil, err, "Failed to setup proxy for in-cluster context") + return } - err = config.KubeConfigStore.AddContext(context) - if err != nil { + if err := config.KubeConfigStore.AddContext(context); err != nil { logger.Log(logger.LevelError, nil, err, "Failed to add in-cluster context") } } +} +func setupStaticFiles(config *HeadlampConfig) { if config.StaticDir != "" { baseURLReplace(config.StaticDir, config.BaseURL) } +} - // For when using a base-url, like "/headlamp" with a reverse proxy. - var r *mux.Router +func createRouter(config *HeadlampConfig) *mux.Router { if config.BaseURL == "" { - r = mux.NewRouter() - } else { - baseRoute := mux.NewRouter() - r = baseRoute.PathPrefix(config.BaseURL).Subrouter() + return mux.NewRouter() } - fmt.Println("*** Headlamp Server ***") - fmt.Println(" API Routers:") + baseRoute := mux.NewRouter() - // load kubeConfig clusters - err := kubeconfig.LoadAndStoreKubeConfigs(config.KubeConfigStore, kubeConfigPath, kubeconfig.KubeConfig, skipFunc) - if err != nil { - logger.Log(logger.LevelError, nil, err, "loading kubeconfig") - } + return baseRoute.PathPrefix(config.BaseURL).Subrouter() +} + +func setupRoutes(config *HeadlampConfig, router *mux.Router) { + setupMetricsRoute(config, router) + setupKubeconfigRoutes(config) + addPluginRoutes(config, router) + config.handleClusterRequests(router) + setupExternalProxyRoute(config, router) + setupConfigRoute(router, config) + setupWebsocketRoute(router, config) + setupClusterSetupRoute(router, config) + setupOIDCRoutes(router, config) + setupPortForwardRoutes(router, config) + setupNodeDrainRoutes(router, config) +} - // Prometheus metrics endpoint - // to enable this endpoint, run command run-backend-with-metrics - // or set the environment variable HEADLAMP_CONFIG_METRICS_ENABLED=true +func setupMetricsRoute(config *HeadlampConfig, router *mux.Router) { if config.Metrics != nil && config.telemetryConfig.MetricsEnabled != nil && *config.telemetryConfig.MetricsEnabled { - r.Handle("/metrics", promhttp.Handler()) + router.Handle("/metrics", promhttp.Handler()) logger.Log(logger.LevelInfo, nil, nil, "prometheus metrics endpoint: /metrics") } +} + +func setupKubeconfigRoutes(config *HeadlampConfig) { + skipFunc := kubeconfig.SkipKubeContextInCommaSeparatedString(config.SkippedKubeContexts) + if err := kubeconfig.LoadAndStoreKubeConfigs( + config.KubeConfigStore, + config.KubeConfigPath, + kubeconfig.KubeConfig, + skipFunc, + ); err != nil { + logger.Log(logger.LevelError, nil, err, "loading kubeconfig") + } - // load dynamic clusters kubeConfigPersistenceFile, err := defaultHeadlampKubeConfigFile() if err != nil { logger.Log(logger.LevelError, nil, err, "getting default kubeconfig persistence file") } - err = kubeconfig.LoadAndStoreKubeConfigs(config.KubeConfigStore, kubeConfigPersistenceFile, - kubeconfig.DynamicCluster, skipFunc) - if err != nil { + if err := kubeconfig.LoadAndStoreKubeConfigs( + config.KubeConfigStore, + kubeConfigPersistenceFile, + kubeconfig.DynamicCluster, + skipFunc, + ); err != nil { logger.Log(logger.LevelError, nil, err, "loading dynamic kubeconfig") } +} - addPluginRoutes(config, r) +func setupExternalProxyRoute(config *HeadlampConfig, router *mux.Router) { + router.HandleFunc("/externalproxy", func(w http.ResponseWriter, r *http.Request) { + handleExternalProxy(config, w, r) + }) +} - config.handleClusterRequests(r) +func setupConfigRoute(router *mux.Router, config *HeadlampConfig) { + router.HandleFunc("/config", config.getConfig).Methods("GET") +} - r.HandleFunc("/externalproxy", func(w http.ResponseWriter, r *http.Request) { - proxyURL := r.Header.Get("proxy-to") - if proxyURL == "" && r.Header.Get("Forward-to") != "" { - proxyURL = r.Header.Get("Forward-to") - } +func setupWebsocketRoute(router *mux.Router, config *HeadlampConfig) { + router.HandleFunc("/wsMultiplexer", config.multiplexer.HandleClientWebSocket) +} - if proxyURL == "" { - logger.Log(logger.LevelError, map[string]string{"proxyURL": proxyURL}, - errors.New("proxy URL is empty"), "proxy URL is empty") - http.Error(w, "proxy URL is empty", http.StatusBadRequest) +func setupClusterSetupRoute(router *mux.Router, config *HeadlampConfig) { + config.addClusterSetupRoute(router) +} - return - } +func setupOIDCRoutes(router *mux.Router, config *HeadlampConfig) { + oauthRequestMap := make(map[string]*OauthConfig) - url, err := url.Parse(proxyURL) - if err != nil { - logger.Log(logger.LevelError, map[string]string{"proxyURL": proxyURL}, - err, "The provided proxy URL is invalid") - http.Error(w, fmt.Sprintf("The provided proxy URL is invalid: %v", err), http.StatusBadRequest) + router.HandleFunc("/oidc", func(w http.ResponseWriter, r *http.Request) { + handleOIDC(config, w, r, oauthRequestMap) + }).Queries("cluster", "{cluster}") - return - } + router.HandleFunc("/oidc-callback", func(w http.ResponseWriter, r *http.Request) { + handleOIDCCallback(config, w, r, oauthRequestMap) + }) +} - isURLContainedInProxyURLs := false +func handleOIDCCallback( + config *HeadlampConfig, + w http.ResponseWriter, + r *http.Request, + oauthRequestMap map[string]*OauthConfig, +) { + state := r.URL.Query().Get("state") - for _, proxyURL := range config.ProxyURLs { - g := glob.MustCompile(proxyURL) - if g.Match(url.String()) { - isURLContainedInProxyURLs = true - break - } - } + if err := validateState(state); err != nil { + writeOIDCError(w, err, "invalid state") + return + } - if !isURLContainedInProxyURLs { - logger.Log(logger.LevelError, nil, err, "no allowed proxy url match, request denied") - http.Error(w, "no allowed proxy url match, request denied ", http.StatusBadRequest) + oauthConfig := getOAuthConfigFromState(state, oauthRequestMap) + if oauthConfig == nil { + writeOIDCError(w, nil, "invalid request") + return + } - return - } + processOIDCCallback(config, w, r, oauthConfig, state) +} - ctx := context.Background() +func writeOIDCError(w http.ResponseWriter, err error, msg string) { + if err != nil { + msg = fmt.Sprintf("%s: %v", msg, err) + } - proxyReq, err := http.NewRequestWithContext(ctx, r.Method, proxyURL, r.Body) - if err != nil { - logger.Log(logger.LevelError, nil, err, "creating request") - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, msg, http.StatusBadRequest) +} - return - } +func validateState(state string) error { + if _, err := base64.StdEncoding.DecodeString(state); err != nil { + return fmt.Errorf("failed to decode state: %w", err) + } - // We may want to filter some headers, otherwise we could just use a shallow copy - proxyReq.Header = make(http.Header) - for h, val := range r.Header { - proxyReq.Header[h] = val - } + return nil +} - // Disable caching - w.Header().Set("Cache-Control", "no-cache, private, max-age=0") - w.Header().Set("Expires", time.Unix(0, 0).Format(http.TimeFormat)) - w.Header().Set("Pragma", "no-cache") - w.Header().Set("X-Accel-Expires", "0") +func getOAuthConfigFromState( + state string, + oauthRequestMap map[string]*OauthConfig, +) *OauthConfig { + return oauthRequestMap[state] +} - client := http.Client{} +func processOIDCCallback( + config *HeadlampConfig, + w http.ResponseWriter, + r *http.Request, + oauthConfig *OauthConfig, + state string, +) { + token, err := exchangeCode(oauthConfig, r) + if err != nil { + writeOIDCError(w, err, "failed to exchange token") + return + } - resp, err := client.Do(proxyReq) - if err != nil { - logger.Log(logger.LevelError, nil, err, "making request") - http.Error(w, err.Error(), http.StatusBadGateway) + tokenType := "id_token" + if config.oidcUseAccessToken { + tokenType = "access_token" + } - return - } + rawToken := getRawToken(token, tokenType) + if rawToken == "" { + writeOIDCError(w, nil, fmt.Sprintf("no %s field", tokenType)) + return + } - defer resp.Body.Close() + cacheToken(config.cache, rawToken, token.RefreshToken) + verifyToken(oauthConfig, rawToken) + redirectUser(config, w, r, state, rawToken) +} - // Check that the server actually sent compressed data - var reader io.ReadCloser +func exchangeCode(oauthConfig *OauthConfig, r *http.Request) (*oauth2.Token, error) { + return oauthConfig.Config.Exchange(oauthConfig.Ctx, r.URL.Query().Get("code")) +} - switch resp.Header.Get("Content-Encoding") { - case "gzip": - reader, err = gzip.NewReader(resp.Body) - if err != nil { - logger.Log(logger.LevelError, nil, err, "reading gzip response") - http.Error(w, err.Error(), http.StatusInternalServerError) +func getRawToken(token *oauth2.Token, tokenType string) string { + rawToken, _ := token.Extra(tokenType).(string) + return rawToken +} - return - } - defer reader.Close() - default: - reader = resp.Body - } +func cacheToken(cache cache.Cache[interface{}], rawToken, refreshToken string) { + _ = cache.Set(context.Background(), + fmt.Sprintf("oidc-token-%s", rawToken), + refreshToken) +} - respBody, err := io.ReadAll(reader) - if err != nil { - logger.Log(logger.LevelError, nil, err, "reading response") - http.Error(w, err.Error(), http.StatusBadGateway) +func verifyToken(oauthConfig *OauthConfig, rawToken string) { + _, _ = oauthConfig.Verifier.Verify(oauthConfig.Ctx, rawToken) +} - return - } +func redirectUser( + config *HeadlampConfig, + w http.ResponseWriter, + r *http.Request, + state, + rawToken string, +) { + redirectURL := buildRedirectURL(config, state, rawToken) + http.Redirect(w, r, redirectURL, http.StatusSeeOther) +} - _, err = w.Write(respBody) - if err != nil { - logger.Log(logger.LevelError, nil, err, "writing response") - http.Error(w, err.Error(), http.StatusInternalServerError) +func buildRedirectURL(config *HeadlampConfig, state, rawToken string) string { + base := "http://localhost:3000/" + if !config.DevMode { + base = "/" + } - return - } + if trimmed := strings.Trim(config.BaseURL, "/"); trimmed != "" { + base += trimmed + "/" + } - defer resp.Body.Close() - }) + decodedState, _ := base64.StdEncoding.DecodeString(state) - // Configuration - r.HandleFunc("/config", config.getConfig).Methods("GET") + return fmt.Sprintf("%sauth?cluster=%s&token=%s", + base, decodedState, rawToken) +} - // Websocket connections - r.HandleFunc("/wsMultiplexer", config.multiplexer.HandleClientWebSocket) +func handleOIDC( + config *HeadlampConfig, + w http.ResponseWriter, + r *http.Request, + oauthRequestMap map[string]*OauthConfig, +) { + ctx := createOIDCContext(config) + cluster := r.URL.Query().Get("cluster") - config.addClusterSetupRoute(r) + kContext, err := getKubeContext(config, cluster) + if err != nil { + handleOIDCContextError(w, r, cluster, err) + return + } - oauthRequestMap := make(map[string]*OauthConfig) + oidcConfig, err := getOIDCConfig(config, r, kContext, ctx) + if err != nil { + handleOIDCConfigError(w, cluster, err) + return + } - r.HandleFunc("/oidc", func(w http.ResponseWriter, r *http.Request) { - ctx := context.Background() - cluster := r.URL.Query().Get("cluster") - if config.Insecure { - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec - } - InsecureClient := &http.Client{Transport: tr} - ctx = oidc.ClientContext(ctx, InsecureClient) - } + state := createOIDCState(cluster) + oauthRequestMap[state] = oidcConfig + http.Redirect(w, r, oidcConfig.Config.AuthCodeURL(state), http.StatusFound) +} - kContext, err := config.KubeConfigStore.GetContext(cluster) - if err != nil { - logger.Log(logger.LevelError, map[string]string{"cluster": cluster}, - err, "failed to get context") +func createOIDCContext(config *HeadlampConfig) context.Context { + ctx := context.Background() - http.NotFound(w, r) - return + if config.Insecure { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec } + ctx = oidc.ClientContext(ctx, &http.Client{Transport: tr}) + } - oidcAuthConfig, err := kContext.OidcConfig() - if err != nil { - logger.Log(logger.LevelError, map[string]string{"cluster": cluster}, - err, "failed to get oidc config") + return ctx +} - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } +func getKubeContext(config *HeadlampConfig, cluster string) (*kubeconfig.Context, error) { + return config.KubeConfigStore.GetContext(cluster) +} - if config.oidcValidatorIdpIssuerURL != "" { - ctx = oidc.InsecureIssuerURLContext(ctx, config.oidcValidatorIdpIssuerURL) - } +func handleOIDCContextError(w http.ResponseWriter, r *http.Request, cluster string, err error) { + logger.Log(logger.LevelError, map[string]string{"cluster": cluster}, err, "failed to get context") + http.NotFound(w, r) +} - provider, err := oidc.NewProvider(ctx, oidcAuthConfig.IdpIssuerURL) - if err != nil { - logger.Log(logger.LevelError, map[string]string{"idpIssuerURL": oidcAuthConfig.IdpIssuerURL}, - err, "failed to get provider") +func getOIDCConfig( + config *HeadlampConfig, + r *http.Request, + kContext *kubeconfig.Context, + ctx context.Context, +) (*OauthConfig, error) { + oidcAuthConfig, err := kContext.OidcConfig() + if err != nil { + return nil, fmt.Errorf("failed to get oidc config: %w", err) + } - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + if config.oidcValidatorIdpIssuerURL != "" { + ctx = oidc.InsecureIssuerURLContext(ctx, config.oidcValidatorIdpIssuerURL) + } - validatorClientID := oidcAuthConfig.ClientID - if config.oidcValidatorClientID != "" { - validatorClientID = config.oidcValidatorClientID - } - oidcConfig := &oidc.Config{ - ClientID: validatorClientID, - } + provider, err := oidc.NewProvider(ctx, oidcAuthConfig.IdpIssuerURL) + if err != nil { + return nil, fmt.Errorf("failed to create provider: %w", err) + } - verifier := provider.Verifier(oidcConfig) - oauthConfig := &oauth2.Config{ + validatorClientID := getValidatorClientID(config, oidcAuthConfig) + verifier := provider.Verifier(&oidc.Config{ClientID: validatorClientID}) + + return &OauthConfig{ + Config: &oauth2.Config{ ClientID: oidcAuthConfig.ClientID, ClientSecret: oidcAuthConfig.ClientSecret, Endpoint: provider.Endpoint(), RedirectURL: getOidcCallbackURL(r, config), Scopes: append([]string{oidc.ScopeOpenID}, oidcAuthConfig.Scopes...), - } - /* we encode the cluster to base64 and set it as state so that when getting redirected - by oidc we can use this state value to get cluster name - */ - state := base64.StdEncoding.EncodeToString([]byte(cluster)) - oauthRequestMap[state] = &OauthConfig{Config: oauthConfig, Verifier: verifier, Ctx: ctx} - http.Redirect(w, r, oauthConfig.AuthCodeURL(state), http.StatusFound) - }).Queries("cluster", "{cluster}") + }, + Verifier: verifier, + Ctx: ctx, + }, nil +} - r.HandleFunc("/portforward", func(w http.ResponseWriter, r *http.Request) { - portforward.StartPortForward(config.KubeConfigStore, config.cache, w, r) - }).Methods("POST") +func handleOIDCConfigError(w http.ResponseWriter, cluster string, err error) { + logger.Log(logger.LevelError, map[string]string{"cluster": cluster}, err, "failed to get oidc config") + http.Error(w, err.Error(), http.StatusInternalServerError) +} - r.HandleFunc("/portforward", func(w http.ResponseWriter, r *http.Request) { - portforward.StopOrDeletePortForward(config.cache, w, r) - }).Methods("DELETE") +func createOIDCState(cluster string) string { + return base64.StdEncoding.EncodeToString([]byte(cluster)) +} - r.HandleFunc("/portforward/list", func(w http.ResponseWriter, r *http.Request) { - portforward.GetPortForwards(config.cache, w, r) - }) +func getValidatorClientID(config *HeadlampConfig, oidcAuthConfig *kubeconfig.OidcConfig) string { + if config.oidcValidatorClientID != "" { + return config.oidcValidatorClientID + } - r.HandleFunc("/drain-node", config.handleNodeDrain).Methods("POST") - r.HandleFunc("/drain-node-status", - config.handleNodeDrainStatus).Methods("GET").Queries("cluster", "{cluster}", "nodeName", "{node}") - r.HandleFunc("/portforward", func(w http.ResponseWriter, r *http.Request) { - portforward.GetPortForwardByID(config.cache, w, r) - }).Methods("GET") + return oidcAuthConfig.ClientID +} - r.HandleFunc("/oidc-callback", func(w http.ResponseWriter, r *http.Request) { - state := r.URL.Query().Get("state") +func handleExternalProxy(config *HeadlampConfig, w http.ResponseWriter, r *http.Request) { + proxyURL := getProxyURLFromRequest(r) + if proxyURL == "" { + handleProxyURLError(w, "proxy URL is empty") + return + } - decodedState, err := base64.StdEncoding.DecodeString(state) - if err != nil { - logger.Log(logger.LevelError, nil, err, "failed to decode state") - http.Error(w, "wrong state set, invalid request "+err.Error(), http.StatusBadRequest) + parsedURL, err := parseProxyURL(proxyURL) + if err != nil { + handleProxyURLError(w, fmt.Sprintf("The provided proxy URL is invalid: %v", err)) + return + } - return - } + if !isURLAllowed(config.ProxyURLs, parsedURL.String()) { + handleProxyURLError(w, "no allowed proxy url match, request denied") + return + } - if state == "" { - logger.Log(logger.LevelError, nil, err, "invalid request state is empty") - http.Error(w, "invalid request state is empty", http.StatusBadRequest) + proxyReq, err := createProxyRequest(r, proxyURL) + if err != nil { + handleProxyError(w, err, "creating request", http.StatusInternalServerError) + return + } - return - } + copyHeaders(r.Header, proxyReq.Header) + setNoCacheHeaders(w) - //nolint:nestif - if oauthConfig, ok := oauthRequestMap[state]; ok { - oauth2Token, err := oauthConfig.Config.Exchange(oauthConfig.Ctx, r.URL.Query().Get("code")) - if err != nil { - logger.Log(logger.LevelError, nil, err, "failed to exchange token") - http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) + resp, err := executeProxyRequest(proxyReq) + if err != nil { + handleProxyError(w, err, "making request", http.StatusBadGateway) + return + } + defer resp.Body.Close() - return - } + if err := processAndWriteProxyResponse(w, resp); err != nil { + handleProxyError(w, err, "processing response", http.StatusInternalServerError) + } +} - tokenType := "id_token" - if config.oidcUseAccessToken { - tokenType = "access_token" - } +func getProxyURLFromRequest(r *http.Request) string { + proxyURL := r.Header.Get("proxy-to") + if proxyURL == "" { + proxyURL = r.Header.Get("Forward-to") + } - rawUserToken, ok := oauth2Token.Extra(tokenType).(string) - if !ok { - logger.Log(logger.LevelError, nil, err, fmt.Sprintf("no %s field in oauth2 token", tokenType)) - http.Error(w, fmt.Sprintf("No %s field in oauth2 token.", tokenType), http.StatusInternalServerError) + return proxyURL +} - return - } +func handleProxyURLError(w http.ResponseWriter, message string) { + logger.Log(logger.LevelError, map[string]string{"proxyURL": message}, + errors.New(message), message) + http.Error(w, message, http.StatusBadRequest) +} - if err := config.cache.Set(context.Background(), - fmt.Sprintf("oidc-token-%s", rawUserToken), oauth2Token.RefreshToken); err != nil { - logger.Log(logger.LevelError, nil, err, "failed to cache refresh token") - http.Error(w, "Failed to cache refresh token: "+err.Error(), http.StatusInternalServerError) +func parseProxyURL(proxyURL string) (*url.URL, error) { + parsedURL, err := url.Parse(proxyURL) + if err != nil { + logger.Log(logger.LevelError, map[string]string{"proxyURL": proxyURL}, + err, "The provided proxy URL is invalid") - return - } + return nil, err + } - idToken, err := oauthConfig.Verifier.Verify(oauthConfig.Ctx, rawUserToken) - if err != nil { - logger.Log(logger.LevelError, nil, err, "failed to verify ID Token") - http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) + return parsedURL, nil +} - return - } +func isURLAllowed(allowedURLs []string, targetURL string) bool { + for _, allowedURL := range allowedURLs { + g := glob.MustCompile(allowedURL) + if g.Match(targetURL) { + return true + } + } - resp := struct { - OAuth2Token *oauth2.Token - IDTokenClaims *json.RawMessage // ID Token payload is just JSON. - }{oauth2Token, new(json.RawMessage)} + logger.Log(logger.LevelError, nil, nil, "no allowed proxy url match, request denied") - if err := idToken.Claims(&resp.IDTokenClaims); err != nil { - logger.Log(logger.LevelError, nil, err, "failed to get id token claims") - http.Error(w, err.Error(), http.StatusInternalServerError) + return false +} - return - } +func createProxyRequest(r *http.Request, proxyURL string) (*http.Request, error) { + ctx := context.Background() - var redirectURL string - if config.DevMode { - redirectURL = "http://localhost:3000/" - } else { - redirectURL = "/" - } + proxyReq, err := http.NewRequestWithContext(ctx, r.Method, proxyURL, r.Body) + if err != nil { + logger.Log(logger.LevelError, nil, err, "creating request") + return nil, err + } - baseURL := strings.Trim(config.BaseURL, "/") - if baseURL != "" { - redirectURL += baseURL + "/" - } + return proxyReq, nil +} - redirectURL += fmt.Sprintf("auth?cluster=%1s&token=%2s", decodedState, rawUserToken) - http.Redirect(w, r, redirectURL, http.StatusSeeOther) - } else { - http.Error(w, "invalid request", http.StatusBadRequest) - return +func handleProxyError(w http.ResponseWriter, err error, context string, statusCode int) { + logger.Log(logger.LevelError, nil, err, context) + http.Error(w, err.Error(), statusCode) +} + +func copyHeaders(source http.Header, target http.Header) { + for h, val := range source { + target[h] = val + } +} + +func setNoCacheHeaders(w http.ResponseWriter) { + w.Header().Set("Cache-Control", "no-cache, private, max-age=0") + w.Header().Set("Expires", time.Unix(0, 0).Format(http.TimeFormat)) + w.Header().Set("Pragma", "no-cache") + w.Header().Set("X-Accel-Expires", "0") +} + +func executeProxyRequest(proxyReq *http.Request) (*http.Response, error) { + client := http.Client{} + return client.Do(proxyReq) +} + +func processAndWriteProxyResponse(w http.ResponseWriter, resp *http.Response) error { + reader, err := getResponseReader(resp) + if err != nil { + return err + } + + defer func() { + if reader != resp.Body { + reader.Close() } - }) + }() + + respBody, err := io.ReadAll(reader) + if err != nil { + logger.Log(logger.LevelError, nil, err, "reading response") + return err + } + + if _, err = w.Write(respBody); err != nil { + logger.Log(logger.LevelError, nil, err, "writing response") + return err + } - // Serve the frontend if needed + return nil +} + +func finalizeHandler(config *HeadlampConfig, router *mux.Router) http.Handler { if config.StaticDir != "" { staticPath := config.StaticDir - - if isWindows { - // We supPort unix paths on windows. So "frontend/static" works. - if strings.Contains(config.StaticDir, "/") { - staticPath = filepath.FromSlash(config.StaticDir) - } + if isWindows && strings.Contains(config.StaticDir, "/") { + staticPath = filepath.FromSlash(config.StaticDir) } - spa := spaHandler{staticPath: staticPath, indexPath: "index.html", baseURL: config.BaseURL} - r.PathPrefix("/").Handler(spa) - - http.Handle("/", r) + spa := spaHandler{ + staticPath: staticPath, + indexPath: "index.html", + baseURL: config.BaseURL, + } + router.PathPrefix("/").Handler(spa) } - // On dev mode we're loose about where connections come from if config.DevMode { - headers := handlers.AllowedHeaders([]string{ - "X-HEADLAMP_BACKEND-TOKEN", "X-Requested-With", "Content-Type", - "Authorization", "Forward-To", - "KUBECONFIG", "X-HEADLAMP-USER-ID", - }) - methods := handlers.AllowedMethods([]string{"GET", "POST", "PUT", "HEAD", "DELETE", "PATCH", "OPTIONS"}) - origins := handlers.AllowedOrigins([]string{"*"}) + return setupCORS(router) + } - return handlers.CORS(headers, methods, origins)(r) + return router +} + +func setupCORS(router *mux.Router) http.Handler { + headers := handlers.AllowedHeaders([]string{ + "X-HEADLAMP_BACKEND-TOKEN", "X-Requested-With", "Content-Type", + "Authorization", "Forward-To", + "KUBECONFIG", "X-HEADLAMP-USER-ID", + }) + methods := handlers.AllowedMethods([]string{"GET", "POST", "PUT", "HEAD", "DELETE", "PATCH", "OPTIONS"}) + origins := handlers.AllowedOrigins([]string{"*"}) + + return handlers.CORS(headers, methods, origins)(router) +} + +func getResponseReader(resp *http.Response) (io.ReadCloser, error) { + switch resp.Header.Get("Content-Encoding") { + case "gzip": + reader, err := gzip.NewReader(resp.Body) + if err != nil { + logger.Log(logger.LevelError, nil, err, "reading gzip response") + return nil, err + } + + return reader, nil + default: + return resp.Body, nil } +} + +func setupPortForwardRoutes(router *mux.Router, config *HeadlampConfig) { + router.HandleFunc("/portforward", func(w http.ResponseWriter, r *http.Request) { + portforward.StartPortForward(config.KubeConfigStore, config.cache, w, r) + }).Methods("POST") + + router.HandleFunc("/portforward", func(w http.ResponseWriter, r *http.Request) { + portforward.StopOrDeletePortForward(config.cache, w, r) + }).Methods("DELETE") - return r + router.HandleFunc("/portforward/list", func(w http.ResponseWriter, r *http.Request) { + portforward.GetPortForwards(config.cache, w, r) + }) + + router.HandleFunc("/portforward", func(w http.ResponseWriter, r *http.Request) { + portforward.GetPortForwardByID(config.cache, w, r) + }).Methods("GET") +} + +func setupNodeDrainRoutes(router *mux.Router, config *HeadlampConfig) { + router.HandleFunc("/drain-node", config.handleNodeDrain).Methods("POST") + router.HandleFunc("/drain-node-status", + config.handleNodeDrainStatus).Methods("GET").Queries("cluster", "{cluster}", "nodeName", "{node}") } func parseClusterAndToken(r *http.Request) (string, string) {