Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 49 additions & 12 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@ package cli

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"net/http/pprof"
Expand Down Expand Up @@ -392,12 +396,10 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co

// Redirect from the HTTP listener to the access URL if:
// 1. The redirect flag is enabled.
// 2. HTTP listening is enabled (obviously).
// 3. TLS is enabled (otherwise they're likely using a reverse proxy
// 2. TLS is enabled (otherwise they're likely using a reverse proxy
// which can do this instead).
// 4. The access URL has been set manually (not a tunnel).
// 5. The access URL is HTTPS.
shouldRedirectHTTPToAccessURL := cfg.TLS.RedirectHTTP.Value && cfg.HTTPAddress.Value != "" && cfg.TLS.Enable.Value && tunnel == nil && accessURLParsed.Scheme == "https"
// 3. The access URL is HTTPS.
shouldRedirectHTTPToAccessURL := cfg.TLS.RedirectHTTP.Value && cfg.TLS.Enable.Value && accessURLParsed.Scheme == "https"

// A newline is added before for visibility in terminal output.
cmd.Printf("\nView the Web UI: %s\n", accessURLParsed.String())
Expand Down Expand Up @@ -1162,12 +1164,6 @@ func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, er
if len(tlsCertFiles) != len(tlsKeyFiles) {
return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times")
}
if len(tlsCertFiles) == 0 {
return nil, xerrors.New("--tls-cert-file is required when tls is enabled")
}
if len(tlsKeyFiles) == 0 {
return nil, xerrors.New("--tls-key-file is required when tls is enabled")
}

certs := make([]tls.Certificate, len(tlsCertFiles))
for i := range tlsCertFiles {
Expand All @@ -1183,6 +1179,36 @@ func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, er
return certs, nil
}

// generateSelfSignedCertificate creates an unsafe self-signed certificate
// at random that allows users to proceed with setup in the event they
// haven't configured any TLS certificates.
func generateSelfSignedCertificate() (*tls.Certificate, error) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180),

KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}

derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, err
}

var cert tls.Certificate
cert.Certificate = append(cert.Certificate, derBytes)
cert.PrivateKey = privateKey
return &cert, nil
}

func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
Expand Down Expand Up @@ -1219,8 +1245,19 @@ func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles
if err != nil {
return nil, xerrors.Errorf("load certificates: %w", err)
}
var selfSignedCertificate *tls.Certificate
if len(certs) == 0 {
selfSignedCertificate, err = generateSelfSignedCertificate()
if err != nil {
return nil, xerrors.Errorf("generate self signed certificate: %w", err)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just certs = append(certs, selfSignedCertificate) instead of the extra logic?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Will fix!

}

tlsConfig.Certificates = certs
tlsConfig.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
if selfSignedCertificate != nil {
return selfSignedCertificate, nil
}
// If there's only one certificate, return it.
if len(certs) == 1 {
return &certs[0], nil
Expand Down Expand Up @@ -1485,7 +1522,7 @@ func configureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile stri

func redirectHTTPToAccessURL(handler http.Handler, accessURL *url.URL) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil {
if r.Host != accessURL.Host {
http.Redirect(w, r, accessURL.String(), http.StatusTemporaryRedirect)
return
}
Expand Down