diff --git a/cli/server.go b/cli/server.go index 23a5d45c2fc5a..011d972aec265 100644 --- a/cli/server.go +++ b/cli/server.go @@ -207,6 +207,16 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co listener = tls.NewListener(listener, tlsConfig) } + ctx, httpClient, err := configureHTTPClient( + ctx, + cfg.TLS.ClientCertFile.Value, + cfg.TLS.ClientKeyFile.Value, + cfg.TLS.ClientCAFile.Value, + ) + if err != nil { + return xerrors.Errorf("configure http client: %w", err) + } + tcpAddr, valid := listener.Addr().(*net.TCPAddr) if !valid { return xerrors.New("must be listening on tcp") @@ -377,6 +387,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co DeploymentConfig: cfg, PrometheusRegistry: prometheus.NewRegistry(), APIRateLimit: cfg.APIRateLimit.Value, + HTTPClient: httpClient, } if tlsConfig != nil { options.TLSCertificates = tlsConfig.Certificates @@ -424,11 +435,6 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co return xerrors.Errorf("OIDC issuer URL must be set!") } - ctx, err := handleOauth2ClientCertificates(ctx, cfg) - if err != nil { - return xerrors.Errorf("configure oidc client certificates: %w", err) - } - if cfg.OIDC.IgnoreEmailVerified.Value { logger.Warn(ctx, "coder will not check email_verified for OIDC logins") } @@ -1088,19 +1094,27 @@ func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles return nil, nil //nolint:nilnil } + err = configureCAPool(tlsClientCAFile, tlsConfig) + if err != nil { + return nil, err + } + + return tlsConfig, nil +} + +func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error { if tlsClientCAFile != "" { caPool := x509.NewCertPool() data, err := os.ReadFile(tlsClientCAFile) if err != nil { - return nil, xerrors.Errorf("read %q: %w", tlsClientCAFile, err) + return xerrors.Errorf("read %q: %w", tlsClientCAFile, err) } if !caPool.AppendCertsFromPEM(data) { - return nil, xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file") + return xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file") } tlsConfig.ClientCAs = caPool } - - return tlsConfig, nil + return nil } //nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive) @@ -1319,20 +1333,27 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg return connectionURL, ep.Stop, nil } -func handleOauth2ClientCertificates(ctx context.Context, cfg *codersdk.DeploymentConfig) (context.Context, error) { - if cfg.TLS.ClientCertFile.Value != "" && cfg.TLS.ClientKeyFile.Value != "" { - certificates, err := loadCertificates([]string{cfg.TLS.ClientCertFile.Value}, []string{cfg.TLS.ClientKeyFile.Value}) +func configureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile string, tlsClientCAFile string) (context.Context, *http.Client, error) { + if clientCertFile != "" && clientKeyFile != "" { + certificates, err := loadCertificates([]string{clientCertFile}, []string{clientKeyFile}) if err != nil { - return nil, err + return ctx, nil, err } - return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ + tlsClientConfig := &tls.Config{ //nolint:gosec + Certificates: certificates, + } + err = configureCAPool(tlsClientCAFile, tlsClientConfig) + if err != nil { + return nil, nil, err + } + + httpClient := &http.Client{ Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ //nolint:gosec - Certificates: certificates, - }, + TLSClientConfig: tlsClientConfig, }, - }), nil + } + return context.WithValue(ctx, oauth2.HTTPClient, httpClient), httpClient, nil } - return ctx, nil + return ctx, &http.Client{}, nil } diff --git a/coderd/coderd.go b/coderd/coderd.go index 675d023e82194..87a136e527d90 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -107,6 +107,7 @@ type Options struct { Experimental bool DeploymentConfig *codersdk.DeploymentConfig UpdateCheckOptions *updatecheck.Options // Set non-nil to enable update checking. + HTTPClient *http.Client } // New constructs a Coder API handler. @@ -279,7 +280,7 @@ func New(options *Options) *API { for _, gitAuthConfig := range options.GitAuthConfigs { r.Route(fmt.Sprintf("/%s", gitAuthConfig.ID), func(r chi.Router) { r.Use( - httpmw.ExtractOAuth2(gitAuthConfig), + httpmw.ExtractOAuth2(gitAuthConfig, options.HTTPClient), apiKeyMiddleware, ) r.Get("/callback", api.gitAuthCallback(gitAuthConfig)) @@ -428,12 +429,12 @@ func New(options *Options) *API { r.Get("/authmethods", api.userAuthMethods) r.Route("/oauth2", func(r chi.Router) { r.Route("/github", func(r chi.Router) { - r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config)) + r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient)) r.Get("/callback", api.userOAuth2Github) }) }) r.Route("/oidc/callback", func(r chi.Router) { - r.Use(httpmw.ExtractOAuth2(options.OIDCConfig)) + r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient)) r.Get("/", api.userOIDC) }) r.Group(func(r chi.Router) { diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go index aeebb544e351e..5463dffd60b0e 100644 --- a/coderd/httpmw/oauth2.go +++ b/coderd/httpmw/oauth2.go @@ -40,10 +40,14 @@ func OAuth2(r *http.Request) OAuth2State { // ExtractOAuth2 is a middleware for automatically redirecting to OAuth // URLs, and handling the exchange inbound. Any route that does not have // a "code" URL parameter will be redirected. -func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler { +func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + if client != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, client) + } + // Interfaces can hold a nil value if config == nil || reflect.ValueOf(config).IsNil() { httpapi.Write(ctx, rw, http.StatusPreconditionRequired, codersdk.Response{ diff --git a/coderd/httpmw/oauth2_test.go b/coderd/httpmw/oauth2_test.go index dd5c7c6bc7b35..17f37bf5d64bd 100644 --- a/coderd/httpmw/oauth2_test.go +++ b/coderd/httpmw/oauth2_test.go @@ -39,14 +39,14 @@ func TestOAuth2(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/", nil) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(nil, nil)(nil).ServeHTTP(res, req) require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode) }) t.Run("RedirectWithoutCode", func(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req) location := res.Header().Get("Location") if !assert.NotEmpty(t, location) { return @@ -59,14 +59,14 @@ func TestOAuth2(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/?code=something", nil) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req) require.Equal(t, http.StatusBadRequest, res.Result().StatusCode) }) t.Run("NoStateCookie", func(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/?code=something&state=test", nil) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req) require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) }) t.Run("MismatchedState", func(t *testing.T) { @@ -77,7 +77,7 @@ func TestOAuth2(t *testing.T) { Value: "mismatch", }) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req) require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) }) t.Run("ExchangeCodeAndState", func(t *testing.T) { @@ -92,7 +92,7 @@ func TestOAuth2(t *testing.T) { Value: "/dashboard", }) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { state := httpmw.OAuth2(r) require.Equal(t, "/dashboard", state.Redirect) })).ServeHTTP(res, req)