@@ -21,6 +21,7 @@ import (
21
21
"runtime"
22
22
"strconv"
23
23
"strings"
24
+ "sync/atomic"
24
25
"testing"
25
26
"time"
26
27
@@ -240,20 +241,64 @@ func TestServer(t *testing.T) {
240
241
err := root .ExecuteContext (ctx )
241
242
require .Error (t , err )
242
243
})
243
- t .Run ("TLSNoCertFile " , func (t * testing.T ) {
244
+ t .Run ("TLSInvalid " , func (t * testing.T ) {
244
245
t .Parallel ()
245
- ctx , cancelFunc := context .WithCancel (context .Background ())
246
- defer cancelFunc ()
247
246
248
- root , _ := clitest .New (t ,
249
- "server" ,
250
- "--in-memory" ,
251
- "--address" , ":0" ,
252
- "--tls-enable" ,
253
- "--cache-dir" , t .TempDir (),
254
- )
255
- err := root .ExecuteContext (ctx )
256
- require .Error (t , err )
247
+ cert1Path , key1Path := generateTLSCertificate (t )
248
+ cert2Path , key2Path := generateTLSCertificate (t )
249
+
250
+ cases := []struct {
251
+ name string
252
+ args []string
253
+ errContains string
254
+ }{
255
+ {
256
+ name : "NoCertAndKey" ,
257
+ args : []string {"--tls-enable" },
258
+ errContains : "--tls-cert-file is required when tls is enabled" ,
259
+ },
260
+ {
261
+ name : "NoCert" ,
262
+ args : []string {"--tls-enable" , "--tls-key-file" , key1Path },
263
+ errContains : "--tls-cert-file and --tls-key-file must be used the same amount of times" ,
264
+ },
265
+ {
266
+ name : "NoKey" ,
267
+ args : []string {"--tls-enable" , "--tls-cert-file" , cert1Path },
268
+ errContains : "--tls-cert-file and --tls-key-file must be used the same amount of times" ,
269
+ },
270
+ {
271
+ name : "MismatchedCount" ,
272
+ args : []string {"--tls-enable" , "--tls-cert-file" , cert1Path , "--tls-key-file" , key1Path , "--tls-cert-file" , cert2Path },
273
+ errContains : "--tls-cert-file and --tls-key-file must be used the same amount of times" ,
274
+ },
275
+ {
276
+ name : "MismatchedCertAndKey" ,
277
+ args : []string {"--tls-enable" , "--tls-cert-file" , cert1Path , "--tls-key-file" , key2Path },
278
+ errContains : "load TLS key pair" ,
279
+ },
280
+ }
281
+
282
+ for _ , c := range cases {
283
+ c := c
284
+ t .Run (c .name , func (t * testing.T ) {
285
+ t .Parallel ()
286
+ ctx , cancelFunc := context .WithCancel (context .Background ())
287
+ defer cancelFunc ()
288
+
289
+ args := []string {
290
+ "server" ,
291
+ "--in-memory" ,
292
+ "--address" , ":0" ,
293
+ "--cache-dir" , t .TempDir (),
294
+ }
295
+ args = append (args , c .args ... )
296
+ root , _ := clitest .New (t , args ... )
297
+ err := root .ExecuteContext (ctx )
298
+ require .Error (t , err )
299
+ require .ErrorContains (t , err , c .errContains )
300
+ })
301
+ }
257
302
})
258
303
t .Run ("TLSValid" , func (t * testing.T ) {
259
304
t .Parallel ()
@@ -293,6 +338,86 @@ func TestServer(t *testing.T) {
293
338
cancelFunc ()
294
339
require .NoError (t , <- errC )
295
340
})
341
+ t .Run ("TLSValidMultiple" , func (t * testing.T ) {
342
+ t .Parallel ()
343
+ ctx , cancelFunc := context .WithCancel (context .Background ())
344
+ defer cancelFunc ()
345
+
346
+ cert1Path , key1Path := generateTLSCertificate (t , "alpaca.com" )
347
+ cert2Path , key2Path := generateTLSCertificate (t , "*.llama.com" )
348
+ root , cfg := clitest .New (t ,
349
+ "server" ,
350
+ "--in-memory" ,
351
+ "--address" , ":0" ,
352
+ "--tls-enable" ,
353
+ "--tls-cert-file" , cert1Path ,
354
+ "--tls-key-file" , key1Path ,
355
+ "--tls-cert-file" , cert2Path ,
356
+ "--tls-key-file" , key2Path ,
357
+ "--cache-dir" , t .TempDir (),
358
+ )
359
+ errC := make (chan error , 1 )
360
+ go func () {
361
+ errC <- root .ExecuteContext (ctx )
362
+ }()
363
+ accessURL := waitAccessURL (t , cfg )
364
+ require .Equal (t , "https" , accessURL .Scheme )
365
+ originalHost := accessURL .Host
366
+
367
+ var (
368
+ expectAddr string
369
+ dials int64
370
+ )
371
+ client := codersdk .New (accessURL )
372
+ client .HTTPClient = & http.Client {
373
+ Transport : & http.Transport {
374
+ DialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
375
+ atomic .AddInt64 (& dials , 1 )
376
+ assert .Equal (t , expectAddr , addr )
377
+
378
+ host , _ , err := net .SplitHostPort (addr )
379
+ require .NoError (t , err )
380
+
381
+ // Always connect to the accessURL ip:port regardless of
382
+ // hostname.
383
+ conn , err := tls .Dial (network , originalHost , & tls.Config {
384
+ MinVersion : tls .VersionTLS12 ,
385
+ //nolint:gosec
386
+ InsecureSkipVerify : true ,
387
+ ServerName : host ,
388
+ })
389
+ if err != nil {
390
+ return nil , err
391
+ }
392
+
393
+ // We can't call conn.VerifyHostname because it requires
394
+ // that the certificates are valid, so we call
395
+ // VerifyHostname on the first certificate instead.
396
+ require .Len (t , conn .ConnectionState ().PeerCertificates , 1 )
397
+ err = conn .ConnectionState ().PeerCertificates [0 ].VerifyHostname (host )
398
+ assert .NoError (t , err , "invalid cert common name" )
399
+ return conn , nil
400
+ },
401
+ },
402
+ }
403
+
404
+ // Use the first certificate and hostname.
405
+ client .URL .Host = "alpaca.com:443"
406
+ expectAddr = "alpaca.com:443"
407
+ _ , err := client .HasFirstUser (ctx )
408
+ require .NoError (t , err )
409
+ require .EqualValues (t , 1 , atomic .LoadInt64 (& dials ))
410
+
411
+ // Use the second certificate (wildcard) and hostname.
412
+ client .URL .Host = "hi.llama.com:443"
413
+ expectAddr = "hi.llama.com:443"
414
+ _ , err = client .HasFirstUser (ctx )
415
+ require .NoError (t , err )
416
+ require .EqualValues (t , 2 , atomic .LoadInt64 (& dials ))
417
+
418
+ cancelFunc ()
419
+ require .NoError (t , <- errC )
420
+ })
296
421
// This cannot be ran in parallel because it uses a signal.
297
422
//nolint:paralleltest
298
423
t .Run ("Shutdown" , func (t * testing.T ) {
@@ -480,16 +605,22 @@ func TestServer(t *testing.T) {
480
605
})
481
606
}
482
607
483
- func generateTLSCertificate (t testing.TB ) (certPath , keyPath string ) {
608
+ func generateTLSCertificate (t testing.TB , commonName ... string ) (certPath , keyPath string ) {
484
609
dir := t .TempDir ()
485
610
611
+ commonNameStr := "localhost"
612
+ if len (commonName ) > 0 {
613
+ commonNameStr = commonName [0 ]
614
+ }
486
615
privateKey , err := ecdsa .GenerateKey (elliptic .P256 (), rand .Reader )
487
616
require .NoError (t , err )
488
617
template := x509.Certificate {
489
618
SerialNumber : big .NewInt (1 ),
490
619
Subject : pkix.Name {
491
620
Organization : []string {"Acme Co" },
621
+ CommonName : commonNameStr ,
492
622
},
623
+ DNSNames : []string {commonNameStr },
493
624
NotBefore : time .Now (),
494
625
NotAfter : time .Now ().Add (time .Hour * 24 * 180 ),
495
626
@@ -498,6 +629,7 @@ func generateTLSCertificate(t testing.TB) (certPath, keyPath string) {
498
629
BasicConstraintsValid : true ,
499
630
IPAddresses : []net.IP {net .ParseIP ("127.0.0.1" )},
500
631
}
632
+
501
633
derBytes , err := x509 .CreateCertificate (rand .Reader , & template , & template , & privateKey .PublicKey , privateKey )
502
634
require .NoError (t , err )
503
635
certFile , err := os .CreateTemp (dir , "" )
0 commit comments