@@ -18,6 +18,8 @@ import (
18
18
"github.com/coder/coder/v2/buildinfo"
19
19
"github.com/coder/coder/v2/coderd/coderdtest"
20
20
"github.com/coder/coder/v2/coderd/database"
21
+ "github.com/coder/coder/v2/coderd/database/dbauthz"
22
+ "github.com/coder/coder/v2/coderd/provisionerkey"
21
23
"github.com/coder/coder/v2/coderd/rbac"
22
24
"github.com/coder/coder/v2/coderd/util/ptr"
23
25
"github.com/coder/coder/v2/codersdk"
@@ -552,6 +554,152 @@ func TestProvisionerDaemonServe(t *testing.T) {
552
554
require .NoError (t , err )
553
555
require .Len (t , daemons , 0 )
554
556
})
557
+
558
+ t .Run ("ProvisionerKeyAuth" , func (t * testing.T ) {
559
+ t .Parallel ()
560
+
561
+ insertParams , token , err := provisionerkey .New (uuid .Nil , "dont-TEST-me" )
562
+ require .NoError (t , err )
563
+
564
+ tcs := []struct {
565
+ name string
566
+ psk string
567
+ multiOrgFeatureEnabled bool
568
+ multiOrgExperimentEnabled bool
569
+ insertParams database.InsertProvisionerKeyParams
570
+ requestProvisionerKey string
571
+ requestPSK string
572
+ errStatusCode int
573
+ }{
574
+ {
575
+ name : "MultiOrgDisabledPSKAuthOK" ,
576
+ psk : "provisionersftw" ,
577
+ requestPSK : "provisionersftw" ,
578
+ },
579
+ {
580
+ name : "MultiOrgExperimentDisabledPSKAuthOK" ,
581
+ multiOrgFeatureEnabled : true ,
582
+ psk : "provisionersftw" ,
583
+ requestPSK : "provisionersftw" ,
584
+ },
585
+ {
586
+ name : "MultiOrgFeatureDisabledPSKAuthOK" ,
587
+ multiOrgExperimentEnabled : true ,
588
+ psk : "provisionersftw" ,
589
+ requestPSK : "provisionersftw" ,
590
+ },
591
+ {
592
+ name : "MultiOrgEnabledPSKAuthOK" ,
593
+ psk : "provisionersftw" ,
594
+ multiOrgFeatureEnabled : true ,
595
+ multiOrgExperimentEnabled : true ,
596
+ requestPSK : "provisionersftw" ,
597
+ },
598
+ {
599
+ name : "MultiOrgEnabledKeyAuthOK" ,
600
+ psk : "provisionersftw" ,
601
+ multiOrgFeatureEnabled : true ,
602
+ multiOrgExperimentEnabled : true ,
603
+ insertParams : insertParams ,
604
+ requestProvisionerKey : token ,
605
+ },
606
+ {
607
+ name : "MultiOrgEnabledPSKAuthDisabled" ,
608
+ multiOrgFeatureEnabled : true ,
609
+ multiOrgExperimentEnabled : true ,
610
+ requestPSK : "provisionersftw" ,
611
+ errStatusCode : http .StatusUnauthorized ,
612
+ },
613
+ {
614
+ name : "WrongKey" ,
615
+ multiOrgFeatureEnabled : true ,
616
+ multiOrgExperimentEnabled : true ,
617
+ requestProvisionerKey : "provisionersftw" ,
618
+ errStatusCode : http .StatusUnauthorized ,
619
+ },
620
+ {
621
+ name : "IdOKKeyWrong" ,
622
+ multiOrgFeatureEnabled : true ,
623
+ multiOrgExperimentEnabled : true ,
624
+ requestProvisionerKey : insertParams .ID .String () + ":" + "wrong" ,
625
+ errStatusCode : http .StatusUnauthorized ,
626
+ },
627
+ {
628
+ name : "IdWrongKeyOK" ,
629
+ multiOrgFeatureEnabled : true ,
630
+ multiOrgExperimentEnabled : true ,
631
+ requestProvisionerKey : uuid .NewString () + ":" + token ,
632
+ errStatusCode : http .StatusUnauthorized ,
633
+ },
634
+ {
635
+ name : "TokenOnly" ,
636
+ multiOrgFeatureEnabled : true ,
637
+ multiOrgExperimentEnabled : true ,
638
+ requestProvisionerKey : token ,
639
+ errStatusCode : http .StatusUnauthorized ,
640
+ },
641
+ }
642
+
643
+ for _ , tc := range tcs {
644
+ t .Run (tc .name , func (t * testing.T ) {
645
+ t .Parallel ()
646
+ features := license.Features {
647
+ codersdk .FeatureExternalProvisionerDaemons : 1 ,
648
+ }
649
+ if tc .multiOrgFeatureEnabled {
650
+ features [codersdk .FeatureMultipleOrganizations ] = 1
651
+ }
652
+ dv := coderdtest .DeploymentValues (t )
653
+ if tc .multiOrgExperimentEnabled {
654
+ dv .Experiments .Append (string (codersdk .ExperimentMultiOrganization ))
655
+ }
656
+ client , db , user := coderdenttest .NewWithDatabase (t , & coderdenttest.Options {
657
+ LicenseOptions : & coderdenttest.LicenseOptions {
658
+ Features : features ,
659
+ },
660
+ ProvisionerDaemonPSK : tc .psk ,
661
+ Options : & coderdtest.Options {
662
+ DeploymentValues : dv ,
663
+ },
664
+ })
665
+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
666
+ defer cancel ()
667
+
668
+ if tc .insertParams .Name != "" {
669
+ tc .insertParams .OrganizationID = user .OrganizationID
670
+ // nolint:gocritic // test
671
+ _ , err := db .InsertProvisionerKey (dbauthz .AsSystemRestricted (ctx ), tc .insertParams )
672
+ require .NoError (t , err )
673
+ }
674
+
675
+ another := codersdk .New (client .URL )
676
+ srv , err := another .ServeProvisionerDaemon (ctx , codersdk.ServeProvisionerDaemonRequest {
677
+ ID : uuid .New (),
678
+ Name : testutil .MustRandString (t , 63 ),
679
+ Organization : user .OrganizationID ,
680
+ Provisioners : []codersdk.ProvisionerType {
681
+ codersdk .ProvisionerTypeEcho ,
682
+ },
683
+ Tags : map [string ]string {
684
+ provisionersdk .TagScope : provisionersdk .ScopeOrganization ,
685
+ },
686
+ PreSharedKey : tc .requestPSK ,
687
+ ProvisionerKey : tc .requestProvisionerKey ,
688
+ })
689
+ if tc .errStatusCode != 0 {
690
+ require .Error (t , err )
691
+ var apiError * codersdk.Error
692
+ require .ErrorAs (t , err , & apiError )
693
+ require .Equal (t , http .StatusUnauthorized , apiError .StatusCode ())
694
+ return
695
+ }
696
+
697
+ require .NoError (t , err )
698
+ err = srv .DRPCConn ().Close ()
699
+ require .NoError (t , err )
700
+ })
701
+ }
702
+ })
555
703
}
556
704
557
705
func TestGetProvisionerDaemons (t * testing.T ) {
0 commit comments