From 8ee7ad1dc5642b23c3d43344c78141997354d4e2 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 24 Mar 2022 18:28:51 +0000 Subject: [PATCH 1/3] feat: Add TLS support This adds numerous flags with inspiration taken from Vault for configuring TLS inside Coder. This enables secure deployments without a proxy, like Cloudflare. --- cli/start.go | 196 ++++++++++++++++++++++++++---- cli/start_test.go | 150 +++++++++++++++++++++++ coder.env | 5 +- codersdk/client.go | 9 +- codersdk/provisionerdaemons.go | 2 +- codersdk/workspaceresourceauth.go | 2 +- 6 files changed, 333 insertions(+), 31 deletions(-) diff --git a/cli/start.go b/cli/start.go index dc0205e71c733..6444d07313f0b 100644 --- a/cli/start.go +++ b/cli/start.go @@ -2,7 +2,10 @@ package cli import ( "context" + "crypto/tls" + "crypto/x509" "database/sql" + "encoding/pem" "fmt" "io/ioutil" "net" @@ -10,6 +13,7 @@ import ( "net/url" "os" "os/signal" + "strconv" "time" "github.com/briandowns/spinner" @@ -36,23 +40,23 @@ import ( func start() *cobra.Command { var ( + accessURL string address string + dev bool postgresURL string provisionerDaemonCount uint8 - dev bool + tlsCertFile string + tlsClientCAFile string + tlsClientAuth string + tlsEnable bool + tlsKeyFile string + tlsMinVersion string useTunnel bool ) root := &cobra.Command{ Use: "start", RunE: func(cmd *cobra.Command, args []string) error { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), ` ▄█▀ ▀█▄ - ▄▄ ▀▀▀ █▌ ██▀▀█▄ ▐█ - ▄▄██▀▀█▄▄▄ ██ ██ █▀▀█ ▐█▀▀██ ▄█▀▀█ █▀▀ -█▌ ▄▌ ▐█ █▌ ▀█▄▄▄█▌ █ █ ▐█ ██ ██▀▀ █ - ██████▀▄█ ▀▀▀▀ ▀▀▀▀ ▀▀▀▀▀ ▀▀▀▀ ▀ - -`) - + printLogo(cmd) if postgresURL == "" { // Default to the environment variable! postgresURL = os.Getenv("CODER_PG_CONNECTION_URL") @@ -63,6 +67,17 @@ func start() *cobra.Command { return xerrors.Errorf("listen %q: %w", address, err) } defer listener.Close() + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + if tlsEnable { + listener, err = configureTLS(tlsConfig, listener, tlsMinVersion, tlsClientAuth, tlsCertFile, tlsKeyFile, tlsClientCAFile) + if err != nil { + return xerrors.Errorf("configure tls: %w", err) + } + } + tcpAddr, valid := listener.Addr().(*net.TCPAddr) if !valid { return xerrors.New("must be listening on tcp") @@ -76,7 +91,12 @@ func start() *cobra.Command { Scheme: "http", Host: tcpAddr.String(), } - accessURL := localURL + if tlsEnable { + localURL.Scheme = "https" + } + if accessURL == "" { + accessURL = localURL.String() + } var tunnelErr <-chan error // If we're attempting to tunnel in dev-mode, the access URL // needs to be changed to use the tunnel. @@ -88,17 +108,11 @@ func start() *cobra.Command { IsConfirm: true, }) if err == nil { - var accessURLRaw string - accessURLRaw, tunnelErr, err = tunnel.New(cmd.Context(), localURL.String()) + accessURL, tunnelErr, err = tunnel.New(cmd.Context(), localURL.String()) if err != nil { return xerrors.Errorf("create tunnel: %w", err) } - accessURL, err = url.Parse(accessURLRaw) - if err != nil { - return xerrors.Errorf("parse: %w", err) - } - - _, _ = fmt.Fprintf(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render(cliui.Styles.Wrap.Render(cliui.Styles.Prompt.String()+`Tunnel started. Your deployment is accessible at:`))+"\n "+cliui.Styles.Field.Render(accessURL.String())) + _, _ = fmt.Fprintf(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render(cliui.Styles.Wrap.Render(cliui.Styles.Prompt.String()+`Tunnel started. Your deployment is accessible at:`))+"\n "+cliui.Styles.Field.Render(accessURL)) } } validator, err := idtoken.NewValidator(cmd.Context(), option.WithoutAuthentication()) @@ -106,9 +120,13 @@ func start() *cobra.Command { return err } + accessURLParsed, err := url.Parse(accessURL) + if err != nil { + return xerrors.Errorf("parse access url %q: %w", accessURL, err) + } logger := slog.Make(sloghuman.Sink(os.Stderr)) options := &coderd.Options{ - AccessURL: accessURL, + AccessURL: accessURLParsed, Logger: logger.Named("coderd"), Database: databasefake.New(), Pubsub: database.NewPubsubInMemory(), @@ -137,6 +155,13 @@ func start() *cobra.Command { handler, closeCoderd := coderd.New(options) client := codersdk.New(localURL) + if tlsEnable { + // Use the TLS config here. This client is used for creating the + // default user, among other things. + client.HTTPClient.Transport = &http.Transport{ + TLSClientConfig: tlsConfig, + } + } provisionerDaemons := make([]*provisionerd.Server, 0) for i := uint8(0); i < provisionerDaemonCount; i++ { @@ -153,9 +178,17 @@ func start() *cobra.Command { }() errCh := make(chan error) + shutdownConnsCtx, shutdownConns := context.WithCancel(cmd.Context()) + defer shutdownConns() go func() { defer close(errCh) - errCh <- http.Serve(listener, handler) + server := http.Server{ + Handler: handler, + BaseContext: func(_ net.Listener) context.Context { + return shutdownConnsCtx + }, + } + errCh <- server.Serve(listener) }() config := createConfig(cmd) @@ -271,6 +304,7 @@ func start() *cobra.Command { } _, _ = fmt.Fprintf(cmd.OutOrStdout(), cliui.Styles.Prompt.String()+"Waiting for WebSocket connections to close...\n") + shutdownConns() closeCoderd() return nil }, @@ -279,11 +313,42 @@ func start() *cobra.Command { if defaultAddress == "" { defaultAddress = "127.0.0.1:3000" } + root.Flags().StringVarP(&accessURL, "access-url", "", os.Getenv("CODER_ACCESS_URL"), "Specifies the external URL to access Coder.") root.Flags().StringVarP(&address, "address", "a", defaultAddress, "The address to serve the API and dashboard.") - root.Flags().BoolVarP(&dev, "dev", "", false, "Serve Coder in dev mode for tinkering.") - root.Flags().StringVarP(&postgresURL, "postgres-url", "", "", "URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).") + defaultDev, _ := strconv.ParseBool(os.Getenv("CODER_DEV_MODE")) + root.Flags().BoolVarP(&dev, "dev", "", defaultDev, "Serve Coder in dev mode for tinkering.") + root.Flags().StringVarP(&postgresURL, "postgres-url", "", "", + "URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).") root.Flags().Uint8VarP(&provisionerDaemonCount, "provisioner-daemons", "", 1, "The amount of provisioner daemons to create on start.") - root.Flags().BoolVarP(&useTunnel, "tunnel", "", true, "Serve dev mode through a Cloudflare Tunnel for easy setup.") + defaultTLSEnable, _ := strconv.ParseBool(os.Getenv("CODER_TLS_ENABLE")) + root.Flags().BoolVarP(&tlsEnable, "tls-enable", "", defaultTLSEnable, "Specifies if TLS will be enabled.") + root.Flags().StringVarP(&tlsCertFile, "tls-cert-file", "", os.Getenv("CODER_TLS_CERT_FILE"), + "Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+ + "To configure the listener to use a CA certificate, concatenate the primary certificate "+ + "and the CA certificate together. The primary certificate should appear first in the combined file.") + root.Flags().StringVarP(&tlsClientCAFile, "tls-client-ca-file", "", os.Getenv("CODER_TLS_CLIENT_CA_FILE"), + "PEM-encoded Certificate Authority file used for checking the authenticity of client.") + defaultTLSClientAuth := os.Getenv("CODER_TLS_CLIENT_AUTH") + if defaultTLSClientAuth == "" { + defaultTLSClientAuth = "request" + } + root.Flags().StringVarP(&tlsClientAuth, "tls-client-auth", "", defaultTLSClientAuth, + `Specifies the policy the server will follow for TLS Client Authentication. `+ + `Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify".`) + root.Flags().StringVarP(&tlsKeyFile, "tls-key-file", "", os.Getenv("CODER_TLS_KEY_FILE"), + "Specifies the path to the private key for the certificate. It requires a PEM-encoded file.") + defaultTLSMinVersion := os.Getenv("CODER_TLS_MIN_VERSION") + if defaultTLSMinVersion == "" { + defaultTLSMinVersion = "tls12" + } + root.Flags().StringVarP(&tlsMinVersion, "tls-min-version", "", defaultTLSMinVersion, + `Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13".`) + defaultTunnelRaw := os.Getenv("CODER_DEV_TUNNEL") + if defaultTunnelRaw == "" { + defaultTunnelRaw = "true" + } + defaultTunnel, _ := strconv.ParseBool(defaultTunnelRaw) + root.Flags().BoolVarP(&useTunnel, "tunnel", "", defaultTunnel, "Serve dev mode through a Cloudflare Tunnel for easy setup.") _ = root.Flags().MarkHidden("tunnel") return root @@ -346,3 +411,88 @@ func newProvisionerDaemon(ctx context.Context, client *codersdk.Client, logger s WorkDirectory: tempDir, }), nil } + +func printLogo(cmd *cobra.Command) { + _, _ = fmt.Fprintf(cmd.OutOrStdout(), ` ▄█▀ ▀█▄ + ▄▄ ▀▀▀ █▌ ██▀▀█▄ ▐█ + ▄▄██▀▀█▄▄▄ ██ ██ █▀▀█ ▐█▀▀██ ▄█▀▀█ █▀▀ +█▌ ▄▌ ▐█ █▌ ▀█▄▄▄█▌ █ █ ▐█ ██ ██▀▀ █ + ██████▀▄█ ▀▀▀▀ ▀▀▀▀ ▀▀▀▀▀ ▀▀▀▀ ▀ + +`) +} + +func configureTLS(tlsConfig *tls.Config, listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFile, tlsKeyFile, tlsClientCAFile string) (net.Listener, error) { + switch tlsMinVersion { + case "tls10": + tlsConfig.MinVersion = tls.VersionTLS10 + case "tls11": + tlsConfig.MinVersion = tls.VersionTLS11 + case "tls12": + tlsConfig.MinVersion = tls.VersionTLS12 + case "tls13": + tlsConfig.MinVersion = tls.VersionTLS13 + default: + return nil, xerrors.Errorf("unrecognized tls version: %q", tlsMinVersion) + } + + switch tlsClientAuth { + case "none": + tlsConfig.ClientAuth = tls.NoClientCert + case "request": + tlsConfig.ClientAuth = tls.RequestClientCert + case "require-any": + tlsConfig.ClientAuth = tls.RequireAnyClientCert + case "verify-if-given": + tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven + case "require-and-verify": + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + default: + return nil, xerrors.Errorf("unrecognized tls client auth: %q", tlsClientAuth) + } + + if tlsCertFile == "" { + return nil, xerrors.New("tls-cert-file is required when tls is enabled") + } + if tlsKeyFile == "" { + return nil, xerrors.New("tls-key-file is required when tls is enabled") + } + + certPEMBlock, err := os.ReadFile(tlsCertFile) + if err != nil { + return nil, xerrors.Errorf("read file %q: %w", tlsCertFile, err) + } + keyPEMBlock, err := os.ReadFile(tlsKeyFile) + if err != nil { + return nil, xerrors.Errorf("read file %q: %w", tlsKeyFile, err) + } + keyBlock, _ := pem.Decode(keyPEMBlock) + if keyBlock == nil { + return nil, xerrors.New("decoded pem is blank") + } + cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + return nil, xerrors.Errorf("create key pair: %w", err) + } + tlsConfig.GetCertificate = func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + return &cert, nil + } + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(certPEMBlock) + tlsConfig.RootCAs = certPool + + if tlsClientCAFile != "" { + caPool := x509.NewCertPool() + data, err := ioutil.ReadFile(tlsClientCAFile) + if err != nil { + return nil, xerrors.Errorf("read %q: %w", tlsClientCAFile, err) + } + if !caPool.AppendCertsFromPEM(data) { + return nil, xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file") + } + tlsConfig.ClientCAs = caPool + } + + return tls.NewListener(listener, tlsConfig), nil +} diff --git a/cli/start_test.go b/cli/start_test.go index 465d392647f3a..7839541dce8ae 100644 --- a/cli/start_test.go +++ b/cli/start_test.go @@ -2,7 +2,18 @@ package cli_test import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "net/http" "net/url" + "os" "runtime" "testing" "time" @@ -10,6 +21,7 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" "github.com/coder/coder/database/postgres" ) @@ -79,4 +91,142 @@ func TestStart(t *testing.T) { _, err = client.User(ctx, "") require.NoError(t, err) }) + t.Run("TLSBadVersion", func(t *testing.T) { + t.Parallel() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + root, _ := clitest.New(t, "start", "--dev", "--tunnel=false", "--address", ":0", + "--tls-enable", "--tls-min-version", "tls9") + err := root.ExecuteContext(ctx) + require.Error(t, err) + }) + t.Run("TLSBadClientAuth", func(t *testing.T) { + t.Parallel() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + root, _ := clitest.New(t, "start", "--dev", "--tunnel=false", "--address", ":0", + "--tls-enable", "--tls-client-auth", "something") + err := root.ExecuteContext(ctx) + require.Error(t, err) + }) + t.Run("TLSNoCertFile", func(t *testing.T) { + t.Parallel() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + root, _ := clitest.New(t, "start", "--dev", "--tunnel=false", "--address", ":0", + "--tls-enable") + err := root.ExecuteContext(ctx) + require.Error(t, err) + }) + t.Run("TLSValid", func(t *testing.T) { + t.Parallel() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + certPath, keyPath := generateTLSCertificate(t) + root, cfg := clitest.New(t, "start", "--dev", "--tunnel=false", "--address", ":0", + "--tls-enable", "--tls-cert-file", certPath, "--tls-key-file", keyPath) + go func() { + err := root.ExecuteContext(ctx) + require.ErrorIs(t, err, context.Canceled) + }() + var accessURLRaw string + require.Eventually(t, func() bool { + var err error + accessURLRaw, err = cfg.URL().Read() + return err == nil + }, 15*time.Second, 25*time.Millisecond) + accessURL, err := url.Parse(accessURLRaw) + require.NoError(t, err) + require.Equal(t, "https", accessURL.Scheme) + client := codersdk.New(accessURL) + client.HTTPClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, + }, + }, + } + _, err = client.HasFirstUser(ctx) + require.NoError(t, err) + }) + // This cannot be ran in parallel because it uses a signal. + //nolint:paralleltest + t.Run("Shutdown", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + root, cfg := clitest.New(t, "start", "--dev", "--tunnel=false", "--address", ":0", "--provisioner-daemons", "0") + done := make(chan struct{}) + go func() { + defer close(done) + err := root.ExecuteContext(ctx) + require.NoError(t, err) + }() + var token string + require.Eventually(t, func() bool { + var err error + token, err = cfg.Session().Read() + return err == nil + }, 15*time.Second, 25*time.Millisecond) + // Verify that authentication was properly set in dev-mode. + accessURL, err := cfg.URL().Read() + require.NoError(t, err) + parsed, err := url.Parse(accessURL) + require.NoError(t, err) + client := codersdk.New(parsed) + client.SessionToken = token + orgs, err := client.OrganizationsByUser(ctx, "") + require.NoError(t, err) + coderdtest.NewProvisionerDaemon(t, client) + + // Create a workspace so the cleanup occurs! + version := coderdtest.CreateProjectVersion(t, client, orgs[0].ID, nil) + coderdtest.AwaitProjectVersionJob(t, client, version.ID) + project := coderdtest.CreateProject(t, client, orgs[0].ID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, "", project.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + require.NoError(t, err) + currentProcess, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + err = currentProcess.Signal(os.Interrupt) + require.NoError(t, err) + <-done + }) +} + +func generateTLSCertificate(t testing.TB) (certPath, keyPath string) { + dir := t.TempDir() + + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Acme Co"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 180), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + certFile, err := os.CreateTemp(dir, "") + require.NoError(t, err) + defer certFile.Close() + _, err = certFile.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})) + require.NoError(t, err) + privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + keyFile, err := os.CreateTemp(dir, "") + require.NoError(t, err) + defer keyFile.Close() + err = pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}) + require.NoError(t, err) + return certFile.Name(), keyFile.Name() } diff --git a/coder.env b/coder.env index e26c7bad88029..dedef38a3782b 100644 --- a/coder.env +++ b/coder.env @@ -1,3 +1,6 @@ -# Runtime variables for "coder start". +# Run "coder start --help" to vie. CODER_ADDRESS= CODER_PG_CONNECTION_URL= +CODER_TLS_CERT_FILE= +CODER_TLS_ENABLE= +CODER_TLS_KEY_FILE= diff --git a/codersdk/client.go b/codersdk/client.go index d47cd5fc341ea..61d9e20512e3e 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -21,16 +21,15 @@ import ( func New(serverURL *url.URL) *Client { return &Client{ URL: serverURL, - httpClient: &http.Client{}, + HTTPClient: &http.Client{}, } } // Client is an HTTP caller for methods to the Coder API. type Client struct { - URL *url.URL + HTTPClient *http.Client SessionToken string - - httpClient *http.Client + URL *url.URL } // request performs an HTTP request with the body provided. @@ -71,7 +70,7 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac opt(req) } - resp, err := c.httpClient.Do(req) + resp, err := c.HTTPClient.Do(req) if err != nil { return nil, xerrors.Errorf("do: %w", err) } diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index 0a987cdbb23d5..f8080dcc8a9ee 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -59,7 +59,7 @@ func (c *Client) ListenProvisionerDaemon(ctx context.Context) (proto.DRPCProvisi return nil, xerrors.Errorf("parse url: %w", err) } conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ - HTTPClient: c.httpClient, + HTTPClient: c.HTTPClient, // Need to disable compression to avoid a data-race. CompressionMode: websocket.CompressionDisabled, }) diff --git a/codersdk/workspaceresourceauth.go b/codersdk/workspaceresourceauth.go index 26f075e70650c..9c4673f3a6237 100644 --- a/codersdk/workspaceresourceauth.go +++ b/codersdk/workspaceresourceauth.go @@ -30,7 +30,7 @@ func (c *Client) AuthWorkspaceGoogleInstanceIdentity(ctx context.Context, servic serviceAccount = "default" } if gcpClient == nil { - gcpClient = metadata.NewClient(c.httpClient) + gcpClient = metadata.NewClient(c.HTTPClient) } // "format=full" is required, otherwise the responding payload will be missing "instance_id". jwt, err := gcpClient.Get(fmt.Sprintf("instance/service-accounts/%s/identity?audience=coder&format=full", serviceAccount)) From 83165488b98aefe6f7af644641e3506f1bcfe879 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 24 Mar 2022 14:09:28 -0500 Subject: [PATCH 2/3] Update cli/start.go Co-authored-by: Colin Adler --- cli/start.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/start.go b/cli/start.go index 8f334ba337d4e..bd1e0f04ed418 100644 --- a/cli/start.go +++ b/cli/start.go @@ -177,7 +177,7 @@ func start() *cobra.Command { } }() - errCh := make(chan error) + errCh := make(chan error, 1) shutdownConnsCtx, shutdownConns := context.WithCancel(cmd.Context()) defer shutdownConns() go func() { From aa38753968391dfd908db709b2e86dca8675b4af Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 24 Mar 2022 19:10:41 +0000 Subject: [PATCH 3/3] Fix flag help in coder.env --- coder.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coder.env b/coder.env index dedef38a3782b..dc341d542aeed 100644 --- a/coder.env +++ b/coder.env @@ -1,4 +1,4 @@ -# Run "coder start --help" to vie. +# Run "coder start --help" for flag information. CODER_ADDRESS= CODER_PG_CONNECTION_URL= CODER_TLS_CERT_FILE=