Skip to content

feat: Add TLS support #556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/coder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ jobs:
gcloud config set compute/zone us-central1-a
gcloud compute scp ./dist/coder_*_linux_amd64.deb coder:/tmp/coder.deb
gcloud compute ssh coder -- sudo dpkg -i /tmp/coder.deb
gcloud compute ssh coder -- sudo systemctl daemon-reload

- name: Start
run: gcloud compute ssh coder -- sudo service coder restart
Expand Down
200 changes: 175 additions & 25 deletions cli/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@ package cli

import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/pem"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"strconv"
"time"

"github.com/briandowns/spinner"
Expand All @@ -36,23 +40,23 @@ import (

func start() *cobra.Command {
var (
accessURL string
address string
dev bool
postgresURL string
provisionerDaemonCount uint8
dev bool
tlsCertFile string
tlsClientCAFile string
tlsClientAuth string
tlsEnable bool
tlsKeyFile string
tlsMinVersion string
useTunnel bool
)
root := &cobra.Command{
Use: "start",
RunE: func(cmd *cobra.Command, args []string) error {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), ` ▄█▀ ▀█▄
▄▄ ▀▀▀ █▌ ██▀▀█▄ ▐█
▄▄██▀▀█▄▄▄ ██ ██ █▀▀█ ▐█▀▀██ ▄█▀▀█ █▀▀
█▌ ▄▌ ▐█ █▌ ▀█▄▄▄█▌ █ █ ▐█ ██ ██▀▀ █
██████▀▄█ ▀▀▀▀ ▀▀▀▀ ▀▀▀▀▀ ▀▀▀▀ ▀

`)

printLogo(cmd)
if postgresURL == "" {
// Default to the environment variable!
postgresURL = os.Getenv("CODER_PG_CONNECTION_URL")
Expand All @@ -63,6 +67,17 @@ func start() *cobra.Command {
return xerrors.Errorf("listen %q: %w", address, err)
}
defer listener.Close()

tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
}
if tlsEnable {
listener, err = configureTLS(tlsConfig, listener, tlsMinVersion, tlsClientAuth, tlsCertFile, tlsKeyFile, tlsClientCAFile)
if err != nil {
return xerrors.Errorf("configure tls: %w", err)
}
}

tcpAddr, valid := listener.Addr().(*net.TCPAddr)
if !valid {
return xerrors.New("must be listening on tcp")
Expand All @@ -76,7 +91,12 @@ func start() *cobra.Command {
Scheme: "http",
Host: tcpAddr.String(),
}
accessURL := localURL
if tlsEnable {
localURL.Scheme = "https"
}
if accessURL == "" {
accessURL = localURL.String()
}
var tunnelErr <-chan error
// If we're attempting to tunnel in dev-mode, the access URL
// needs to be changed to use the tunnel.
Expand All @@ -88,27 +108,25 @@ func start() *cobra.Command {
IsConfirm: true,
})
if err == nil {
var accessURLRaw string
accessURLRaw, tunnelErr, err = tunnel.New(cmd.Context(), localURL.String())
accessURL, tunnelErr, err = tunnel.New(cmd.Context(), localURL.String())
if err != nil {
return xerrors.Errorf("create tunnel: %w", err)
}
accessURL, err = url.Parse(accessURLRaw)
if err != nil {
return xerrors.Errorf("parse: %w", err)
}

_, _ = fmt.Fprintf(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render(cliui.Styles.Wrap.Render(cliui.Styles.Prompt.String()+`Tunnel started. Your deployment is accessible at:`))+"\n "+cliui.Styles.Field.Render(accessURL.String()))
_, _ = fmt.Fprintf(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render(cliui.Styles.Wrap.Render(cliui.Styles.Prompt.String()+`Tunnel started. Your deployment is accessible at:`))+"\n "+cliui.Styles.Field.Render(accessURL))
}
}
validator, err := idtoken.NewValidator(cmd.Context(), option.WithoutAuthentication())
if err != nil {
return err
}

accessURLParsed, err := url.Parse(accessURL)
if err != nil {
return xerrors.Errorf("parse access url %q: %w", accessURL, err)
}
logger := slog.Make(sloghuman.Sink(os.Stderr))
options := &coderd.Options{
AccessURL: accessURL,
AccessURL: accessURLParsed,
Logger: logger.Named("coderd"),
Database: databasefake.New(),
Pubsub: database.NewPubsubInMemory(),
Expand Down Expand Up @@ -137,6 +155,13 @@ func start() *cobra.Command {

handler, closeCoderd := coderd.New(options)
client := codersdk.New(localURL)
if tlsEnable {
// Use the TLS config here. This client is used for creating the
// default user, among other things.
client.HTTPClient.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}
}

provisionerDaemons := make([]*provisionerd.Server, 0)
for i := uint8(0); i < provisionerDaemonCount; i++ {
Expand All @@ -152,10 +177,18 @@ func start() *cobra.Command {
}
}()

errCh := make(chan error)
errCh := make(chan error, 1)
shutdownConnsCtx, shutdownConns := context.WithCancel(cmd.Context())
defer shutdownConns()
go func() {
defer close(errCh)
errCh <- http.Serve(listener, handler)
server := http.Server{
Handler: handler,
BaseContext: func(_ net.Listener) context.Context {
return shutdownConnsCtx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cool! When cancelling this context, is it a hard cutoff to websockets (ie just severs the connection), or does it send a close frame?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard cutoff right now, but we technically could send a shutdown signal. Maybe I'll do that...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't necessarily think either is right or wrong, just asking for my own curiosity mostly!

},
}
errCh <- server.Serve(listener)
}()

config := createConfig(cmd)
Expand Down Expand Up @@ -271,6 +304,7 @@ func start() *cobra.Command {
}

_, _ = fmt.Fprintf(cmd.OutOrStdout(), cliui.Styles.Prompt.String()+"Waiting for WebSocket connections to close...\n")
shutdownConns()
closeCoderd()
return nil
},
Expand All @@ -279,11 +313,42 @@ func start() *cobra.Command {
if defaultAddress == "" {
defaultAddress = "127.0.0.1:3000"
}
root.Flags().StringVarP(&address, "address", "a", defaultAddress, "The address to serve the API and dashboard.")
root.Flags().BoolVarP(&dev, "dev", "", false, "Serve Coder in dev mode for tinkering.")
root.Flags().StringVarP(&postgresURL, "postgres-url", "", "", "URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).")
root.Flags().StringVarP(&accessURL, "access-url", "", os.Getenv("CODER_ACCESS_URL"), "Specifies the external URL to access Coder (uses $CODER_ACCESS_URL).")
root.Flags().StringVarP(&address, "address", "a", defaultAddress, "The address to serve the API and dashboard (uses $CODER_ADDRESS).")
defaultDev, _ := strconv.ParseBool(os.Getenv("CODER_DEV_MODE"))
root.Flags().BoolVarP(&dev, "dev", "", defaultDev, "Serve Coder in dev mode for tinkering (uses $CODER_DEV_MODE).")
root.Flags().StringVarP(&postgresURL, "postgres-url", "", "",
"URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).")
root.Flags().Uint8VarP(&provisionerDaemonCount, "provisioner-daemons", "", 1, "The amount of provisioner daemons to create on start.")
root.Flags().BoolVarP(&useTunnel, "tunnel", "", true, "Serve dev mode through a Cloudflare Tunnel for easy setup.")
defaultTLSEnable, _ := strconv.ParseBool(os.Getenv("CODER_TLS_ENABLE"))
root.Flags().BoolVarP(&tlsEnable, "tls-enable", "", defaultTLSEnable, "Specifies if TLS will be enabled (uses $CODER_TLS_ENABLE).")
root.Flags().StringVarP(&tlsCertFile, "tls-cert-file", "", os.Getenv("CODER_TLS_CERT_FILE"),
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
"To configure the listener to use a CA certificate, concatenate the primary certificate "+
"and the CA certificate together. The primary certificate should appear first in the combined file (uses $CODER_TLS_CERT_FILE).")
root.Flags().StringVarP(&tlsClientCAFile, "tls-client-ca-file", "", os.Getenv("CODER_TLS_CLIENT_CA_FILE"),
"PEM-encoded Certificate Authority file used for checking the authenticity of client (uses $CODER_TLS_CLIENT_CA_FILE).")
defaultTLSClientAuth := os.Getenv("CODER_TLS_CLIENT_AUTH")
if defaultTLSClientAuth == "" {
defaultTLSClientAuth = "request"
}
root.Flags().StringVarP(&tlsClientAuth, "tls-client-auth", "", defaultTLSClientAuth,
`Specifies the policy the server will follow for TLS Client Authentication. `+
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify" (uses $CODER_TLS_CLIENT_AUTH).`)
root.Flags().StringVarP(&tlsKeyFile, "tls-key-file", "", os.Getenv("CODER_TLS_KEY_FILE"),
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file (uses $CODER_TLS_KEY_FILE).")
defaultTLSMinVersion := os.Getenv("CODER_TLS_MIN_VERSION")
if defaultTLSMinVersion == "" {
defaultTLSMinVersion = "tls12"
}
root.Flags().StringVarP(&tlsMinVersion, "tls-min-version", "", defaultTLSMinVersion,
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13" (uses $CODER_TLS_MIN_VERSION).`)
defaultTunnelRaw := os.Getenv("CODER_DEV_TUNNEL")
if defaultTunnelRaw == "" {
defaultTunnelRaw = "true"
}
defaultTunnel, _ := strconv.ParseBool(defaultTunnelRaw)
root.Flags().BoolVarP(&useTunnel, "tunnel", "", defaultTunnel, "Serve dev mode through a Cloudflare Tunnel for easy setup (uses $CODER_DEV_TUNNEL).")
_ = root.Flags().MarkHidden("tunnel")

return root
Expand Down Expand Up @@ -346,3 +411,88 @@ func newProvisionerDaemon(ctx context.Context, client *codersdk.Client, logger s
WorkDirectory: tempDir,
}), nil
}

func printLogo(cmd *cobra.Command) {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), ` ▄█▀ ▀█▄
▄▄ ▀▀▀ █▌ ██▀▀█▄ ▐█
▄▄██▀▀█▄▄▄ ██ ██ █▀▀█ ▐█▀▀██ ▄█▀▀█ █▀▀
█▌ ▄▌ ▐█ █▌ ▀█▄▄▄█▌ █ █ ▐█ ██ ██▀▀ █
██████▀▄█ ▀▀▀▀ ▀▀▀▀ ▀▀▀▀▀ ▀▀▀▀ ▀

`)
}

func configureTLS(tlsConfig *tls.Config, listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFile, tlsKeyFile, tlsClientCAFile string) (net.Listener, error) {
switch tlsMinVersion {
case "tls10":
tlsConfig.MinVersion = tls.VersionTLS10
case "tls11":
tlsConfig.MinVersion = tls.VersionTLS11
case "tls12":
tlsConfig.MinVersion = tls.VersionTLS12
case "tls13":
tlsConfig.MinVersion = tls.VersionTLS13
default:
return nil, xerrors.Errorf("unrecognized tls version: %q", tlsMinVersion)
}

switch tlsClientAuth {
case "none":
tlsConfig.ClientAuth = tls.NoClientCert
case "request":
tlsConfig.ClientAuth = tls.RequestClientCert
case "require-any":
tlsConfig.ClientAuth = tls.RequireAnyClientCert
case "verify-if-given":
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
case "require-and-verify":
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
default:
return nil, xerrors.Errorf("unrecognized tls client auth: %q", tlsClientAuth)
}

if tlsCertFile == "" {
return nil, xerrors.New("tls-cert-file is required when tls is enabled")
}
if tlsKeyFile == "" {
return nil, xerrors.New("tls-key-file is required when tls is enabled")
}

certPEMBlock, err := os.ReadFile(tlsCertFile)
if err != nil {
return nil, xerrors.Errorf("read file %q: %w", tlsCertFile, err)
}
keyPEMBlock, err := os.ReadFile(tlsKeyFile)
if err != nil {
return nil, xerrors.Errorf("read file %q: %w", tlsKeyFile, err)
}
keyBlock, _ := pem.Decode(keyPEMBlock)
if keyBlock == nil {
return nil, xerrors.New("decoded pem is blank")
}
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return nil, xerrors.Errorf("create key pair: %w", err)
}
tlsConfig.GetCertificate = func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &cert, nil
}

certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(certPEMBlock)
tlsConfig.RootCAs = certPool

if tlsClientCAFile != "" {
caPool := x509.NewCertPool()
data, err := ioutil.ReadFile(tlsClientCAFile)
if err != nil {
return nil, xerrors.Errorf("read %q: %w", tlsClientCAFile, err)
}
if !caPool.AppendCertsFromPEM(data) {
return nil, xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file")
}
tlsConfig.ClientCAs = caPool
}

return tls.NewListener(listener, tlsConfig), nil
}
Loading