Skip to content

Commit 6325a9e

Browse files
authored
feat: support multiple certificates in coder server and helm (#4150)
1 parent a1056bf commit 6325a9e

File tree

6 files changed

+294
-78
lines changed

6 files changed

+294
-78
lines changed

cli/server.go

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"crypto/tls"
66
"crypto/x509"
77
"database/sql"
8-
"encoding/pem"
98
"errors"
109
"fmt"
1110
"io"
@@ -106,11 +105,11 @@ func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, error))
106105
telemetryEnable bool
107106
telemetryTraceEnable bool
108107
telemetryURL string
109-
tlsCertFile string
108+
tlsCertFiles []string
110109
tlsClientCAFile string
111110
tlsClientAuth string
112111
tlsEnable bool
113-
tlsKeyFile string
112+
tlsKeyFiles []string
114113
tlsMinVersion string
115114
tunnel bool
116115
traceEnable bool
@@ -221,7 +220,7 @@ func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, error))
221220
defer listener.Close()
222221

223222
if tlsEnable {
224-
listener, err = configureTLS(listener, tlsMinVersion, tlsClientAuth, tlsCertFile, tlsKeyFile, tlsClientCAFile)
223+
listener, err = configureServerTLS(listener, tlsMinVersion, tlsClientAuth, tlsCertFiles, tlsKeyFiles, tlsClientCAFile)
225224
if err != nil {
226225
return xerrors.Errorf("configure tls: %w", err)
227226
}
@@ -842,17 +841,17 @@ func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, error))
842841
_ = root.Flags().MarkHidden("telemetry-url")
843842
cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false,
844843
"Whether TLS will be enabled.")
845-
cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
846-
"Path to the certificate for TLS. It requires a PEM-encoded file. "+
844+
cliflag.StringArrayVarP(root.Flags(), &tlsCertFiles, "tls-cert-file", "", "CODER_TLS_CERT_FILE", []string{},
845+
"Path to each certificate for TLS. It requires a PEM-encoded file. "+
847846
"To configure the listener to use a CA certificate, concatenate the primary certificate "+
848847
"and the CA certificate together. The primary certificate should appear first in the combined file.")
849848
cliflag.StringVarP(root.Flags(), &tlsClientCAFile, "tls-client-ca-file", "", "CODER_TLS_CLIENT_CA_FILE", "",
850849
"PEM-encoded Certificate Authority file used for checking the authenticity of client")
851850
cliflag.StringVarP(root.Flags(), &tlsClientAuth, "tls-client-auth", "", "CODER_TLS_CLIENT_AUTH", "request",
852851
`Policy the server will follow for TLS Client Authentication. `+
853852
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify"`)
854-
cliflag.StringVarP(root.Flags(), &tlsKeyFile, "tls-key-file", "", "CODER_TLS_KEY_FILE", "",
855-
"Path to the private key for the certificate. It requires a PEM-encoded file")
853+
cliflag.StringArrayVarP(root.Flags(), &tlsKeyFiles, "tls-key-file", "", "CODER_TLS_KEY_FILE", []string{},
854+
"Paths to the private keys for each of the certificates. It requires a PEM-encoded file")
856855
cliflag.StringVarP(root.Flags(), &tlsMinVersion, "tls-min-version", "", "CODER_TLS_MIN_VERSION", "tls12",
857856
`Minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`)
858857
cliflag.BoolVarP(root.Flags(), &tunnel, "tunnel", "", "CODER_TUNNEL", false,
@@ -1040,7 +1039,32 @@ func printLogo(cmd *cobra.Command, spooky bool) {
10401039
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s - Remote development on your infrastucture\n", cliui.Styles.Bold.Render("Coder "+buildinfo.Version()))
10411040
}
10421041

1043-
func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFile, tlsKeyFile, tlsClientCAFile string) (net.Listener, error) {
1042+
func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) {
1043+
if len(tlsCertFiles) != len(tlsKeyFiles) {
1044+
return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times")
1045+
}
1046+
if len(tlsCertFiles) == 0 {
1047+
return nil, xerrors.New("--tls-cert-file is required when tls is enabled")
1048+
}
1049+
if len(tlsKeyFiles) == 0 {
1050+
return nil, xerrors.New("--tls-key-file is required when tls is enabled")
1051+
}
1052+
1053+
certs := make([]tls.Certificate, len(tlsCertFiles))
1054+
for i := range tlsCertFiles {
1055+
certFile, keyFile := tlsCertFiles[i], tlsKeyFiles[i]
1056+
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
1057+
if err != nil {
1058+
return nil, xerrors.Errorf("load TLS key pair %d (%q, %q): %w", i, certFile, keyFile, err)
1059+
}
1060+
1061+
certs[i] = cert
1062+
}
1063+
1064+
return certs, nil
1065+
}
1066+
1067+
func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (net.Listener, error) {
10441068
tlsConfig := &tls.Config{
10451069
MinVersion: tls.VersionTLS12,
10461070
}
@@ -1072,36 +1096,31 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi
10721096
return nil, xerrors.Errorf("unrecognized tls client auth: %q", tlsClientAuth)
10731097
}
10741098

1075-
if tlsCertFile == "" {
1076-
return nil, xerrors.New("tls-cert-file is required when tls is enabled")
1077-
}
1078-
if tlsKeyFile == "" {
1079-
return nil, xerrors.New("tls-key-file is required when tls is enabled")
1080-
}
1081-
1082-
certPEMBlock, err := os.ReadFile(tlsCertFile)
1083-
if err != nil {
1084-
return nil, xerrors.Errorf("read file %q: %w", tlsCertFile, err)
1085-
}
1086-
keyPEMBlock, err := os.ReadFile(tlsKeyFile)
1099+
certs, err := loadCertificates(tlsCertFiles, tlsKeyFiles)
10871100
if err != nil {
1088-
return nil, xerrors.Errorf("read file %q: %w", tlsKeyFile, err)
1089-
}
1090-
keyBlock, _ := pem.Decode(keyPEMBlock)
1091-
if keyBlock == nil {
1092-
return nil, xerrors.New("decoded pem is blank")
1093-
}
1094-
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
1095-
if err != nil {
1096-
return nil, xerrors.Errorf("create key pair: %w", err)
1097-
}
1098-
tlsConfig.GetCertificate = func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
1099-
return &cert, nil
1101+
return nil, xerrors.Errorf("load certificates: %w", err)
11001102
}
1103+
tlsConfig.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
1104+
// If there's only one certificate, return it.
1105+
if len(certs) == 1 {
1106+
return &certs[0], nil
1107+
}
1108+
1109+
// Expensively check which certificate matches the client hello.
1110+
for _, cert := range certs {
1111+
cert := cert
1112+
if err := hi.SupportsCertificate(&cert); err == nil {
1113+
return &cert, nil
1114+
}
1115+
}
11011116

1102-
certPool := x509.NewCertPool()
1103-
certPool.AppendCertsFromPEM(certPEMBlock)
1104-
tlsConfig.RootCAs = certPool
1117+
// Return the first certificate if we have one, or return nil so the
1118+
// server doesn't fail.
1119+
if len(certs) > 0 {
1120+
return &certs[0], nil
1121+
}
1122+
return nil, nil //nolint:nilnil
1123+
}
11051124

11061125
if tlsClientCAFile != "" {
11071126
caPool := x509.NewCertPool()

cli/server_test.go

Lines changed: 145 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"runtime"
2222
"strconv"
2323
"strings"
24+
"sync/atomic"
2425
"testing"
2526
"time"
2627

@@ -240,20 +241,64 @@ func TestServer(t *testing.T) {
240241
err := root.ExecuteContext(ctx)
241242
require.Error(t, err)
242243
})
243-
t.Run("TLSNoCertFile", func(t *testing.T) {
244+
t.Run("TLSInvalid", func(t *testing.T) {
244245
t.Parallel()
245-
ctx, cancelFunc := context.WithCancel(context.Background())
246-
defer cancelFunc()
247246

248-
root, _ := clitest.New(t,
249-
"server",
250-
"--in-memory",
251-
"--address", ":0",
252-
"--tls-enable",
253-
"--cache-dir", t.TempDir(),
254-
)
255-
err := root.ExecuteContext(ctx)
256-
require.Error(t, err)
247+
cert1Path, key1Path := generateTLSCertificate(t)
248+
cert2Path, key2Path := generateTLSCertificate(t)
249+
250+
cases := []struct {
251+
name string
252+
args []string
253+
errContains string
254+
}{
255+
{
256+
name: "NoCertAndKey",
257+
args: []string{"--tls-enable"},
258+
errContains: "--tls-cert-file is required when tls is enabled",
259+
},
260+
{
261+
name: "NoCert",
262+
args: []string{"--tls-enable", "--tls-key-file", key1Path},
263+
errContains: "--tls-cert-file and --tls-key-file must be used the same amount of times",
264+
},
265+
{
266+
name: "NoKey",
267+
args: []string{"--tls-enable", "--tls-cert-file", cert1Path},
268+
errContains: "--tls-cert-file and --tls-key-file must be used the same amount of times",
269+
},
270+
{
271+
name: "MismatchedCount",
272+
args: []string{"--tls-enable", "--tls-cert-file", cert1Path, "--tls-key-file", key1Path, "--tls-cert-file", cert2Path},
273+
errContains: "--tls-cert-file and --tls-key-file must be used the same amount of times",
274+
},
275+
{
276+
name: "MismatchedCertAndKey",
277+
args: []string{"--tls-enable", "--tls-cert-file", cert1Path, "--tls-key-file", key2Path},
278+
errContains: "load TLS key pair",
279+
},
280+
}
281+
282+
for _, c := range cases {
283+
c := c
284+
t.Run(c.name, func(t *testing.T) {
285+
t.Parallel()
286+
ctx, cancelFunc := context.WithCancel(context.Background())
287+
defer cancelFunc()
288+
289+
args := []string{
290+
"server",
291+
"--in-memory",
292+
"--address", ":0",
293+
"--cache-dir", t.TempDir(),
294+
}
295+
args = append(args, c.args...)
296+
root, _ := clitest.New(t, args...)
297+
err := root.ExecuteContext(ctx)
298+
require.Error(t, err)
299+
require.ErrorContains(t, err, c.errContains)
300+
})
301+
}
257302
})
258303
t.Run("TLSValid", func(t *testing.T) {
259304
t.Parallel()
@@ -293,6 +338,86 @@ func TestServer(t *testing.T) {
293338
cancelFunc()
294339
require.NoError(t, <-errC)
295340
})
341+
t.Run("TLSValidMultiple", func(t *testing.T) {
342+
t.Parallel()
343+
ctx, cancelFunc := context.WithCancel(context.Background())
344+
defer cancelFunc()
345+
346+
cert1Path, key1Path := generateTLSCertificate(t, "alpaca.com")
347+
cert2Path, key2Path := generateTLSCertificate(t, "*.llama.com")
348+
root, cfg := clitest.New(t,
349+
"server",
350+
"--in-memory",
351+
"--address", ":0",
352+
"--tls-enable",
353+
"--tls-cert-file", cert1Path,
354+
"--tls-key-file", key1Path,
355+
"--tls-cert-file", cert2Path,
356+
"--tls-key-file", key2Path,
357+
"--cache-dir", t.TempDir(),
358+
)
359+
errC := make(chan error, 1)
360+
go func() {
361+
errC <- root.ExecuteContext(ctx)
362+
}()
363+
accessURL := waitAccessURL(t, cfg)
364+
require.Equal(t, "https", accessURL.Scheme)
365+
originalHost := accessURL.Host
366+
367+
var (
368+
expectAddr string
369+
dials int64
370+
)
371+
client := codersdk.New(accessURL)
372+
client.HTTPClient = &http.Client{
373+
Transport: &http.Transport{
374+
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
375+
atomic.AddInt64(&dials, 1)
376+
assert.Equal(t, expectAddr, addr)
377+
378+
host, _, err := net.SplitHostPort(addr)
379+
require.NoError(t, err)
380+
381+
// Always connect to the accessURL ip:port regardless of
382+
// hostname.
383+
conn, err := tls.Dial(network, originalHost, &tls.Config{
384+
MinVersion: tls.VersionTLS12,
385+
//nolint:gosec
386+
InsecureSkipVerify: true,
387+
ServerName: host,
388+
})
389+
if err != nil {
390+
return nil, err
391+
}
392+
393+
// We can't call conn.VerifyHostname because it requires
394+
// that the certificates are valid, so we call
395+
// VerifyHostname on the first certificate instead.
396+
require.Len(t, conn.ConnectionState().PeerCertificates, 1)
397+
err = conn.ConnectionState().PeerCertificates[0].VerifyHostname(host)
398+
assert.NoError(t, err, "invalid cert common name")
399+
return conn, nil
400+
},
401+
},
402+
}
403+
404+
// Use the first certificate and hostname.
405+
client.URL.Host = "alpaca.com:443"
406+
expectAddr = "alpaca.com:443"
407+
_, err := client.HasFirstUser(ctx)
408+
require.NoError(t, err)
409+
require.EqualValues(t, 1, atomic.LoadInt64(&dials))
410+
411+
// Use the second certificate (wildcard) and hostname.
412+
client.URL.Host = "hi.llama.com:443"
413+
expectAddr = "hi.llama.com:443"
414+
_, err = client.HasFirstUser(ctx)
415+
require.NoError(t, err)
416+
require.EqualValues(t, 2, atomic.LoadInt64(&dials))
417+
418+
cancelFunc()
419+
require.NoError(t, <-errC)
420+
})
296421
// This cannot be ran in parallel because it uses a signal.
297422
//nolint:paralleltest
298423
t.Run("Shutdown", func(t *testing.T) {
@@ -480,16 +605,22 @@ func TestServer(t *testing.T) {
480605
})
481606
}
482607

483-
func generateTLSCertificate(t testing.TB) (certPath, keyPath string) {
608+
func generateTLSCertificate(t testing.TB, commonName ...string) (certPath, keyPath string) {
484609
dir := t.TempDir()
485610

611+
commonNameStr := "localhost"
612+
if len(commonName) > 0 {
613+
commonNameStr = commonName[0]
614+
}
486615
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
487616
require.NoError(t, err)
488617
template := x509.Certificate{
489618
SerialNumber: big.NewInt(1),
490619
Subject: pkix.Name{
491620
Organization: []string{"Acme Co"},
621+
CommonName: commonNameStr,
492622
},
623+
DNSNames: []string{commonNameStr},
493624
NotBefore: time.Now(),
494625
NotAfter: time.Now().Add(time.Hour * 24 * 180),
495626

@@ -498,6 +629,7 @@ func generateTLSCertificate(t testing.TB) (certPath, keyPath string) {
498629
BasicConstraintsValid: true,
499630
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
500631
}
632+
501633
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
502634
require.NoError(t, err)
503635
certFile, err := os.CreateTemp(dir, "")

helm/templates/NOTES.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{{- if .Values.coder.tls.secretName }}
2+
3+
WARN: coder.tls.secretName is deprecated and will be removed in a future
4+
release. Please use coder.tls.secretNames instead.
5+
{{- end }}
6+
7+
Enjoy Coder! Please create an issue at https://github.com/coder/coder if you run
8+
into any problems! :)

0 commit comments

Comments
 (0)