From 05d86e54f3b4e0495651f6808fe139964e98f3c1 Mon Sep 17 00:00:00 2001 From: Arthur Normand Date: Thu, 17 Nov 2022 21:21:32 -0500 Subject: [PATCH 1/7] Added client certs to Oauth HTTPClient context --- cli/server.go | 47 ++------------------------------ coderd/coderd.go | 6 ++-- coderd/httpmw/oauth2.go | 11 +++++++- coderd/httpmw/oauth2_test.go | 12 ++++---- codersdk/certificateUtils.go | 53 ++++++++++++++++++++++++++++++++++++ 5 files changed, 74 insertions(+), 55 deletions(-) create mode 100644 codersdk/certificateUtils.go diff --git a/cli/server.go b/cli/server.go index e8a009a8977c4..8165f8f03c292 100644 --- a/cli/server.go +++ b/cli/server.go @@ -393,7 +393,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co return xerrors.Errorf("OIDC issuer URL must be set!") } - ctx, err := handleOauth2ClientCertificates(ctx, cfg) + ctx, err := codersdk.HandleOauth2ClientCertificates(ctx, cfg.TLS) if err != nil { return xerrors.Errorf("configure oidc client certificates: %w", err) } @@ -964,31 +964,6 @@ func printLogo(cmd *cobra.Command) { _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s - Software development on your infrastucture\n", cliui.Styles.Bold.Render("Coder "+buildinfo.Version())) } -func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) { - if len(tlsCertFiles) != len(tlsKeyFiles) { - return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times") - } - if len(tlsCertFiles) == 0 { - return nil, xerrors.New("--tls-cert-file is required when tls is enabled") - } - if len(tlsKeyFiles) == 0 { - return nil, xerrors.New("--tls-key-file is required when tls is enabled") - } - - certs := make([]tls.Certificate, len(tlsCertFiles)) - for i := range tlsCertFiles { - certFile, keyFile := tlsCertFiles[i], tlsKeyFiles[i] - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, xerrors.Errorf("load TLS key pair %d (%q, %q): %w", i, certFile, keyFile, err) - } - - certs[i] = cert - } - - return certs, nil -} - func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (*tls.Config, error) { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, @@ -1021,7 +996,7 @@ func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles return nil, xerrors.Errorf("unrecognized tls client auth: %q", tlsClientAuth) } - certs, err := loadCertificates(tlsCertFiles, tlsKeyFiles) + certs, err := codersdk.LoadCertificates(tlsCertFiles, tlsKeyFiles) if err != nil { return nil, xerrors.Errorf("load certificates: %w", err) } @@ -1278,21 +1253,3 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg } return connectionURL, ep.Stop, nil } - -func handleOauth2ClientCertificates(ctx context.Context, cfg *codersdk.DeploymentConfig) (context.Context, error) { - if cfg.TLS.ClientCertFile.Value != "" && cfg.TLS.ClientKeyFile.Value != "" { - certificates, err := loadCertificates([]string{cfg.TLS.ClientCertFile.Value}, []string{cfg.TLS.ClientKeyFile.Value}) - if err != nil { - return nil, err - } - - return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ //nolint:gosec - Certificates: certificates, - }, - }, - }), nil - } - return ctx, nil -} diff --git a/coderd/coderd.go b/coderd/coderd.go index 3f7d3d7211321..6746d0608d896 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -275,7 +275,7 @@ func New(options *Options) *API { for _, gitAuthConfig := range options.GitAuthConfigs { r.Route(fmt.Sprintf("/%s", gitAuthConfig.ID), func(r chi.Router) { r.Use( - httpmw.ExtractOAuth2(gitAuthConfig), + httpmw.ExtractOAuth2(gitAuthConfig, options.DeploymentConfig.TLS), apiKeyMiddleware, ) r.Get("/callback", api.gitAuthCallback(gitAuthConfig)) @@ -421,12 +421,12 @@ func New(options *Options) *API { r.Get("/authmethods", api.userAuthMethods) r.Route("/oauth2", func(r chi.Router) { r.Route("/github", func(r chi.Router) { - r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config)) + r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.DeploymentConfig.TLS)) r.Get("/callback", api.userOAuth2Github) }) }) r.Route("/oidc/callback", func(r chi.Router) { - r.Use(httpmw.ExtractOAuth2(options.OIDCConfig)) + r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.DeploymentConfig.TLS)) r.Get("/", api.userOIDC) }) r.Group(func(r chi.Router) { diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go index aeebb544e351e..4203a4afc0d61 100644 --- a/coderd/httpmw/oauth2.go +++ b/coderd/httpmw/oauth2.go @@ -40,7 +40,7 @@ func OAuth2(r *http.Request) OAuth2State { // ExtractOAuth2 is a middleware for automatically redirecting to OAuth // URLs, and handling the exchange inbound. Any route that does not have // a "code" URL parameter will be redirected. -func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler { +func ExtractOAuth2(config OAuth2Config, tlsConfig *codersdk.TLSConfig) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -114,6 +114,15 @@ func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler { redirect = stateRedirect.Value } + ctx, err = codersdk.HandleOauth2ClientCertificates(ctx, tlsConfig) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error reading client certificates.", + Detail: err.Error(), + }) + return + } + oauthToken, err := config.Exchange(ctx, code) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ diff --git a/coderd/httpmw/oauth2_test.go b/coderd/httpmw/oauth2_test.go index dd5c7c6bc7b35..17f37bf5d64bd 100644 --- a/coderd/httpmw/oauth2_test.go +++ b/coderd/httpmw/oauth2_test.go @@ -39,14 +39,14 @@ func TestOAuth2(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/", nil) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(nil, nil)(nil).ServeHTTP(res, req) require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode) }) t.Run("RedirectWithoutCode", func(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req) location := res.Header().Get("Location") if !assert.NotEmpty(t, location) { return @@ -59,14 +59,14 @@ func TestOAuth2(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/?code=something", nil) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req) require.Equal(t, http.StatusBadRequest, res.Result().StatusCode) }) t.Run("NoStateCookie", func(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/?code=something&state=test", nil) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req) require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) }) t.Run("MismatchedState", func(t *testing.T) { @@ -77,7 +77,7 @@ func TestOAuth2(t *testing.T) { Value: "mismatch", }) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req) require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) }) t.Run("ExchangeCodeAndState", func(t *testing.T) { @@ -92,7 +92,7 @@ func TestOAuth2(t *testing.T) { Value: "/dashboard", }) res := httptest.NewRecorder() - httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { state := httpmw.OAuth2(r) require.Equal(t, "/dashboard", state.Redirect) })).ServeHTTP(res, req) diff --git a/codersdk/certificateUtils.go b/codersdk/certificateUtils.go new file mode 100644 index 0000000000000..b4e6889f23284 --- /dev/null +++ b/codersdk/certificateUtils.go @@ -0,0 +1,53 @@ +package codersdk + +import ( + "context" + "crypto/tls" + "net/http" + + "golang.org/x/oauth2" + "golang.org/x/xerrors" +) + +func LoadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) { + if len(tlsCertFiles) != len(tlsKeyFiles) { + return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times") + } + if len(tlsCertFiles) == 0 { + return nil, xerrors.New("--tls-cert-file is required when tls is enabled") + } + if len(tlsKeyFiles) == 0 { + return nil, xerrors.New("--tls-key-file is required when tls is enabled") + } + + certs := make([]tls.Certificate, len(tlsCertFiles)) + for i := range tlsCertFiles { + certFile, keyFile := tlsCertFiles[i], tlsKeyFiles[i] + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, xerrors.Errorf("load TLS key pair %d (%q, %q): %w", i, certFile, keyFile, err) + } + + certs[i] = cert + } + + return certs, nil +} + +func HandleOauth2ClientCertificates(ctx context.Context, cfg *TLSConfig) (context.Context, error) { + if cfg != nil && cfg.ClientCertFile.Value != "" && cfg.ClientKeyFile.Value != "" { + certificates, err := LoadCertificates([]string{cfg.ClientCertFile.Value}, []string{cfg.ClientKeyFile.Value}) + if err != nil { + return nil, err + } + + return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ //nolint:gosec + Certificates: certificates, + }, + }, + }), nil + } + return ctx, nil +} From a936cbae66a9eb4c5ebe9d1e13fb3d69a6a53ab1 Mon Sep 17 00:00:00 2001 From: Arthur Normand Date: Tue, 6 Dec 2022 01:18:46 -0500 Subject: [PATCH 2/7] Created an HttpClient option which can be configured with TLS client certs --- cli/server.go | 57 ++++++++++++++++++++++++++++++++---- coderd/coderd.go | 7 +++-- coderd/httpmw/oauth2.go | 7 +++-- codersdk/certificateUtils.go | 53 --------------------------------- 4 files changed, 60 insertions(+), 64 deletions(-) delete mode 100644 codersdk/certificateUtils.go diff --git a/cli/server.go b/cli/server.go index d92606d30d004..ecaab83d9b115 100644 --- a/cli/server.go +++ b/cli/server.go @@ -347,6 +347,11 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co return xerrors.Errorf("parse real ip config: %w", err) } + clientTLSHTTPClient, err := clientTLSHTTPClient(cfg.TLS) + if err != nil { + return xerrors.Errorf("configure http client certificates: %w", err) + } + options := &coderd.Options{ AccessURL: accessURLParsed, AppHostname: appHostname, @@ -369,6 +374,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co DeploymentConfig: cfg, PrometheusRegistry: prometheus.NewRegistry(), APIRateLimit: cfg.APIRateLimit.Value, + HttpClient: clientTLSHTTPClient, } if tlsConfig != nil { options.TLSCertificates = tlsConfig.Certificates @@ -416,11 +422,6 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co return xerrors.Errorf("OIDC issuer URL must be set!") } - ctx, err := codersdk.HandleOauth2ClientCertificates(ctx, cfg.TLS) - if err != nil { - return xerrors.Errorf("configure oidc client certificates: %w", err) - } - if cfg.OIDC.IgnoreEmailVerified.Value { logger.Warn(ctx, "coder will not check email_verified for OIDC logins") } @@ -994,6 +995,31 @@ func printLogo(cmd *cobra.Command) { _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s - Software development on your infrastucture\n", cliui.Styles.Bold.Render("Coder "+buildinfo.Version())) } +func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) { + if len(tlsCertFiles) != len(tlsKeyFiles) { + return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times") + } + if len(tlsCertFiles) == 0 { + return nil, xerrors.New("--tls-cert-file is required when tls is enabled") + } + if len(tlsKeyFiles) == 0 { + return nil, xerrors.New("--tls-key-file is required when tls is enabled") + } + + certs := make([]tls.Certificate, len(tlsCertFiles)) + for i := range tlsCertFiles { + certFile, keyFile := tlsCertFiles[i], tlsKeyFiles[i] + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, xerrors.Errorf("load TLS key pair %d (%q, %q): %w", i, certFile, keyFile, err) + } + + certs[i] = cert + } + + return certs, nil +} + func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (*tls.Config, error) { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, @@ -1026,7 +1052,7 @@ func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles return nil, xerrors.Errorf("unrecognized tls client auth: %q", tlsClientAuth) } - certs, err := codersdk.LoadCertificates(tlsCertFiles, tlsKeyFiles) + certs, err := loadCertificates(tlsCertFiles, tlsKeyFiles) if err != nil { return nil, xerrors.Errorf("load certificates: %w", err) } @@ -1283,3 +1309,22 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg } return connectionURL, ep.Stop, nil } + +// clientTLSHTTPClient creates a http client that will use the configured client certs hwne making HTTP calls +func clientTLSHTTPClient(cfg *codersdk.TLSConfig) (*http.Client, error) { + if cfg != nil && cfg.ClientCertFile.Value != "" && cfg.ClientKeyFile.Value != "" { + certificates, err := loadCertificates([]string{cfg.ClientCertFile.Value}, []string{cfg.ClientKeyFile.Value}) + if err != nil { + return nil, err + } + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ //nolint:gosec + Certificates: certificates, + }, + }, + }, nil + } + return &http.Client{}, nil +} diff --git a/coderd/coderd.go b/coderd/coderd.go index 634f60dc461cf..4719b8fe6ad97 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -107,6 +107,7 @@ type Options struct { Experimental bool DeploymentConfig *codersdk.DeploymentConfig UpdateCheckOptions *updatecheck.Options // Set non-nil to enable update checking. + HttpClient *http.Client } // New constructs a Coder API handler. @@ -279,7 +280,7 @@ func New(options *Options) *API { for _, gitAuthConfig := range options.GitAuthConfigs { r.Route(fmt.Sprintf("/%s", gitAuthConfig.ID), func(r chi.Router) { r.Use( - httpmw.ExtractOAuth2(gitAuthConfig, options.DeploymentConfig.TLS), + httpmw.ExtractOAuth2(gitAuthConfig, options.HttpClient), apiKeyMiddleware, ) r.Get("/callback", api.gitAuthCallback(gitAuthConfig)) @@ -426,12 +427,12 @@ func New(options *Options) *API { r.Get("/authmethods", api.userAuthMethods) r.Route("/oauth2", func(r chi.Router) { r.Route("/github", func(r chi.Router) { - r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.DeploymentConfig.TLS)) + r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HttpClient)) r.Get("/callback", api.userOAuth2Github) }) }) r.Route("/oidc/callback", func(r chi.Router) { - r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.DeploymentConfig.TLS)) + r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HttpClient)) r.Get("/", api.userOIDC) }) r.Group(func(r chi.Router) { diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go index 4203a4afc0d61..80f9b587a4f78 100644 --- a/coderd/httpmw/oauth2.go +++ b/coderd/httpmw/oauth2.go @@ -40,10 +40,14 @@ func OAuth2(r *http.Request) OAuth2State { // ExtractOAuth2 is a middleware for automatically redirecting to OAuth // URLs, and handling the exchange inbound. Any route that does not have // a "code" URL parameter will be redirected. -func ExtractOAuth2(config OAuth2Config, tlsConfig *codersdk.TLSConfig) func(http.Handler) http.Handler { +func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + if client != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, client) + } + // Interfaces can hold a nil value if config == nil || reflect.ValueOf(config).IsNil() { httpapi.Write(ctx, rw, http.StatusPreconditionRequired, codersdk.Response{ @@ -114,7 +118,6 @@ func ExtractOAuth2(config OAuth2Config, tlsConfig *codersdk.TLSConfig) func(http redirect = stateRedirect.Value } - ctx, err = codersdk.HandleOauth2ClientCertificates(ctx, tlsConfig) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading client certificates.", diff --git a/codersdk/certificateUtils.go b/codersdk/certificateUtils.go deleted file mode 100644 index b4e6889f23284..0000000000000 --- a/codersdk/certificateUtils.go +++ /dev/null @@ -1,53 +0,0 @@ -package codersdk - -import ( - "context" - "crypto/tls" - "net/http" - - "golang.org/x/oauth2" - "golang.org/x/xerrors" -) - -func LoadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) { - if len(tlsCertFiles) != len(tlsKeyFiles) { - return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times") - } - if len(tlsCertFiles) == 0 { - return nil, xerrors.New("--tls-cert-file is required when tls is enabled") - } - if len(tlsKeyFiles) == 0 { - return nil, xerrors.New("--tls-key-file is required when tls is enabled") - } - - certs := make([]tls.Certificate, len(tlsCertFiles)) - for i := range tlsCertFiles { - certFile, keyFile := tlsCertFiles[i], tlsKeyFiles[i] - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, xerrors.Errorf("load TLS key pair %d (%q, %q): %w", i, certFile, keyFile, err) - } - - certs[i] = cert - } - - return certs, nil -} - -func HandleOauth2ClientCertificates(ctx context.Context, cfg *TLSConfig) (context.Context, error) { - if cfg != nil && cfg.ClientCertFile.Value != "" && cfg.ClientKeyFile.Value != "" { - certificates, err := LoadCertificates([]string{cfg.ClientCertFile.Value}, []string{cfg.ClientKeyFile.Value}) - if err != nil { - return nil, err - } - - return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ //nolint:gosec - Certificates: certificates, - }, - }, - }), nil - } - return ctx, nil -} From 2223349f293ea1dbf2808dc3b4f3755fd6b76fa7 Mon Sep 17 00:00:00 2001 From: Arthur Normand Date: Tue, 6 Dec 2022 01:20:27 -0500 Subject: [PATCH 3/7] Cleaning --- coderd/httpmw/oauth2.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go index 80f9b587a4f78..5463dffd60b0e 100644 --- a/coderd/httpmw/oauth2.go +++ b/coderd/httpmw/oauth2.go @@ -118,14 +118,6 @@ func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) redirect = stateRedirect.Value } - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error reading client certificates.", - Detail: err.Error(), - }) - return - } - oauthToken, err := config.Exchange(ctx, code) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ From 8c73209e7fb44b8573bc9adf0b5b66bd714864ff Mon Sep 17 00:00:00 2001 From: Arthur Normand Date: Tue, 6 Dec 2022 01:22:10 -0500 Subject: [PATCH 4/7] Cleaning --- coderd/httpmw/oauth2.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go index 5463dffd60b0e..949c7059cfbfe 100644 --- a/coderd/httpmw/oauth2.go +++ b/coderd/httpmw/oauth2.go @@ -44,9 +44,7 @@ func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - if client != nil { - ctx = context.WithValue(ctx, oauth2.HTTPClient, client) - } + ctx = context.WithValue(ctx, oauth2.HTTPClient, client) // Interfaces can hold a nil value if config == nil || reflect.ValueOf(config).IsNil() { From af2e5648859795eb118ace92c8292bfbbb32d124 Mon Sep 17 00:00:00 2001 From: Arthur Normand Date: Tue, 6 Dec 2022 01:33:24 -0500 Subject: [PATCH 5/7] Fix tests --- coderd/httpmw/oauth2.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go index 949c7059cfbfe..5463dffd60b0e 100644 --- a/coderd/httpmw/oauth2.go +++ b/coderd/httpmw/oauth2.go @@ -44,7 +44,9 @@ func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = context.WithValue(ctx, oauth2.HTTPClient, client) + if client != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, client) + } // Interfaces can hold a nil value if config == nil || reflect.ValueOf(config).IsNil() { From b98e78d45c12e6aa4710deea591e7ead0b0d9d95 Mon Sep 17 00:00:00 2001 From: Arthur Normand Date: Tue, 6 Dec 2022 01:37:16 -0500 Subject: [PATCH 6/7] Fix lint --- cli/server.go | 2 +- coderd/coderd.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cli/server.go b/cli/server.go index ecaab83d9b115..fd980d4033e5e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -374,7 +374,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co DeploymentConfig: cfg, PrometheusRegistry: prometheus.NewRegistry(), APIRateLimit: cfg.APIRateLimit.Value, - HttpClient: clientTLSHTTPClient, + HTTPClient: clientTLSHTTPClient, } if tlsConfig != nil { options.TLSCertificates = tlsConfig.Certificates diff --git a/coderd/coderd.go b/coderd/coderd.go index 4719b8fe6ad97..9d19577fb0afb 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -107,7 +107,7 @@ type Options struct { Experimental bool DeploymentConfig *codersdk.DeploymentConfig UpdateCheckOptions *updatecheck.Options // Set non-nil to enable update checking. - HttpClient *http.Client + HTTPClient *http.Client } // New constructs a Coder API handler. @@ -280,7 +280,7 @@ func New(options *Options) *API { for _, gitAuthConfig := range options.GitAuthConfigs { r.Route(fmt.Sprintf("/%s", gitAuthConfig.ID), func(r chi.Router) { r.Use( - httpmw.ExtractOAuth2(gitAuthConfig, options.HttpClient), + httpmw.ExtractOAuth2(gitAuthConfig, options.HTTPClient), apiKeyMiddleware, ) r.Get("/callback", api.gitAuthCallback(gitAuthConfig)) @@ -427,12 +427,12 @@ func New(options *Options) *API { r.Get("/authmethods", api.userAuthMethods) r.Route("/oauth2", func(r chi.Router) { r.Route("/github", func(r chi.Router) { - r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HttpClient)) + r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient)) r.Get("/callback", api.userOAuth2Github) }) }) r.Route("/oidc/callback", func(r chi.Router) { - r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HttpClient)) + r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient)) r.Get("/", api.userOIDC) }) r.Group(func(r chi.Router) { From a5e6429aa86375736e4fcdba369c585b886bfd27 Mon Sep 17 00:00:00 2001 From: Arthur Normand Date: Sat, 10 Dec 2022 18:00:55 -0500 Subject: [PATCH 7/7] Cleaning and adding httpClient to context --- cli/server.go | 61 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/cli/server.go b/cli/server.go index fd980d4033e5e..1c233232c807a 100644 --- a/cli/server.go +++ b/cli/server.go @@ -199,6 +199,16 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co listener = tls.NewListener(listener, tlsConfig) } + ctx, httpClient, err := configureHTTPClient( + ctx, + cfg.TLS.ClientCertFile.Value, + cfg.TLS.ClientKeyFile.Value, + cfg.TLS.ClientCAFile.Value, + ) + if err != nil { + return xerrors.Errorf("configure http client: %w", err) + } + tcpAddr, valid := listener.Addr().(*net.TCPAddr) if !valid { return xerrors.New("must be listening on tcp") @@ -347,11 +357,6 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co return xerrors.Errorf("parse real ip config: %w", err) } - clientTLSHTTPClient, err := clientTLSHTTPClient(cfg.TLS) - if err != nil { - return xerrors.Errorf("configure http client certificates: %w", err) - } - options := &coderd.Options{ AccessURL: accessURLParsed, AppHostname: appHostname, @@ -374,7 +379,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co DeploymentConfig: cfg, PrometheusRegistry: prometheus.NewRegistry(), APIRateLimit: cfg.APIRateLimit.Value, - HTTPClient: clientTLSHTTPClient, + HTTPClient: httpClient, } if tlsConfig != nil { options.TLSCertificates = tlsConfig.Certificates @@ -1079,19 +1084,27 @@ func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles return nil, nil //nolint:nilnil } + err = configureCAPool(tlsClientCAFile, tlsConfig) + if err != nil { + return nil, err + } + + return tlsConfig, nil +} + +func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error { if tlsClientCAFile != "" { caPool := x509.NewCertPool() data, err := os.ReadFile(tlsClientCAFile) if err != nil { - return nil, xerrors.Errorf("read %q: %w", tlsClientCAFile, err) + return xerrors.Errorf("read %q: %w", tlsClientCAFile, err) } if !caPool.AppendCertsFromPEM(data) { - return nil, xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file") + return xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file") } tlsConfig.ClientCAs = caPool } - - return tlsConfig, nil + return nil } //nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive) @@ -1310,21 +1323,27 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg return connectionURL, ep.Stop, nil } -// clientTLSHTTPClient creates a http client that will use the configured client certs hwne making HTTP calls -func clientTLSHTTPClient(cfg *codersdk.TLSConfig) (*http.Client, error) { - if cfg != nil && cfg.ClientCertFile.Value != "" && cfg.ClientKeyFile.Value != "" { - certificates, err := loadCertificates([]string{cfg.ClientCertFile.Value}, []string{cfg.ClientKeyFile.Value}) +func configureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile string, tlsClientCAFile string) (context.Context, *http.Client, error) { + if clientCertFile != "" && clientKeyFile != "" { + certificates, err := loadCertificates([]string{clientCertFile}, []string{clientKeyFile}) if err != nil { - return nil, err + return ctx, nil, err } - return &http.Client{ + tlsClientConfig := &tls.Config{ //nolint:gosec + Certificates: certificates, + } + err = configureCAPool(tlsClientCAFile, tlsClientConfig) + if err != nil { + return nil, nil, err + } + + httpClient := &http.Client{ Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ //nolint:gosec - Certificates: certificates, - }, + TLSClientConfig: tlsClientConfig, }, - }, nil + } + return context.WithValue(ctx, oauth2.HTTPClient, httpClient), httpClient, nil } - return &http.Client{}, nil + return ctx, &http.Client{}, nil }