@@ -5,18 +5,24 @@ import (
5
5
"database/sql"
6
6
"encoding/json"
7
7
"fmt"
8
+ "sync/atomic"
8
9
"testing"
9
10
"time"
10
11
11
12
"github.com/google/uuid"
13
+ "github.com/prometheus/client_golang/prometheus"
12
14
"github.com/stretchr/testify/assert"
13
15
"github.com/stretchr/testify/require"
14
16
"go.uber.org/goleak"
15
17
18
+ "cdr.dev/slog"
16
19
"cdr.dev/slog/sloggers/slogtest"
17
20
"github.com/coder/coder/v2/coderd/database"
21
+ "github.com/coder/coder/v2/coderd/database/dbauthz"
18
22
"github.com/coder/coder/v2/coderd/database/dbgen"
19
23
"github.com/coder/coder/v2/coderd/database/dbtestutil"
24
+ "github.com/coder/coder/v2/coderd/provisionerdserver"
25
+ "github.com/coder/coder/v2/coderd/rbac"
20
26
"github.com/coder/coder/v2/coderd/unhanger"
21
27
"github.com/coder/coder/v2/provisionersdk"
22
28
"github.com/coder/coder/v2/testutil"
@@ -37,7 +43,7 @@ func TestDetectorNoJobs(t *testing.T) {
37
43
statsCh = make (chan unhanger.Stats )
38
44
)
39
45
40
- detector := unhanger .New (ctx , db , pubsub , log , tickCh ).WithStatsChannel (statsCh )
46
+ detector := unhanger .New (ctx , wrapDBAuthz ( db , log ) , pubsub , log , tickCh ).WithStatsChannel (statsCh )
41
47
detector .Start ()
42
48
tickCh <- time .Now ()
43
49
@@ -84,7 +90,7 @@ func TestDetectorNoHungJobs(t *testing.T) {
84
90
})
85
91
}
86
92
87
- detector := unhanger .New (ctx , db , pubsub , log , tickCh ).WithStatsChannel (statsCh )
93
+ detector := unhanger .New (ctx , wrapDBAuthz ( db , log ) , pubsub , log , tickCh ).WithStatsChannel (statsCh )
88
94
detector .Start ()
89
95
tickCh <- now
90
96
@@ -190,7 +196,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
190
196
t .Log ("previous job ID: " , previousWorkspaceBuildJob .ID )
191
197
t .Log ("current job ID: " , currentWorkspaceBuildJob .ID )
192
198
193
- detector := unhanger .New (ctx , db , pubsub , log , tickCh ).WithStatsChannel (statsCh )
199
+ detector := unhanger .New (ctx , wrapDBAuthz ( db , log ) , pubsub , log , tickCh ).WithStatsChannel (statsCh )
194
200
detector .Start ()
195
201
tickCh <- now
196
202
@@ -313,7 +319,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
313
319
t .Log ("previous job ID: " , previousWorkspaceBuildJob .ID )
314
320
t .Log ("current job ID: " , currentWorkspaceBuildJob .ID )
315
321
316
- detector := unhanger .New (ctx , db , pubsub , log , tickCh ).WithStatsChannel (statsCh )
322
+ detector := unhanger .New (ctx , wrapDBAuthz ( db , log ) , pubsub , log , tickCh ).WithStatsChannel (statsCh )
317
323
detector .Start ()
318
324
tickCh <- now
319
325
@@ -406,7 +412,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
406
412
407
413
t .Log ("current job ID: " , currentWorkspaceBuildJob .ID )
408
414
409
- detector := unhanger .New (ctx , db , pubsub , log , tickCh ).WithStatsChannel (statsCh )
415
+ detector := unhanger .New (ctx , wrapDBAuthz ( db , log ) , pubsub , log , tickCh ).WithStatsChannel (statsCh )
410
416
detector .Start ()
411
417
tickCh <- now
412
418
@@ -469,29 +475,40 @@ func TestDetectorHungOtherJobTypes(t *testing.T) {
469
475
Type : database .ProvisionerJobTypeTemplateVersionImport ,
470
476
Input : []byte ("{}" ),
471
477
})
472
-
473
- // Template dry-run job.
474
- templateDryRunJob = dbgen .ProvisionerJob (t , db , pubsub , database.ProvisionerJob {
475
- CreatedAt : tenMinAgo ,
476
- UpdatedAt : sixMinAgo ,
477
- StartedAt : sql.NullTime {
478
- Time : tenMinAgo ,
479
- Valid : true ,
480
- },
478
+ _ = dbgen .TemplateVersion (t , db , database.TemplateVersion {
481
479
OrganizationID : org .ID ,
482
- InitiatorID : user .ID ,
483
- Provisioner : database .ProvisionerTypeEcho ,
484
- StorageMethod : database .ProvisionerStorageMethodFile ,
485
- FileID : file .ID ,
486
- Type : database .ProvisionerJobTypeTemplateVersionDryRun ,
487
- Input : []byte ("{}" ),
480
+ JobID : templateImportJob .ID ,
488
481
})
489
482
)
490
483
484
+ // Template dry-run job.
485
+ dryRunVersion := dbgen .TemplateVersion (t , db , database.TemplateVersion {
486
+ OrganizationID : org .ID ,
487
+ })
488
+ input , err := json .Marshal (provisionerdserver.TemplateVersionDryRunJob {
489
+ TemplateVersionID : dryRunVersion .ID ,
490
+ })
491
+ require .NoError (t , err )
492
+ templateDryRunJob := dbgen .ProvisionerJob (t , db , pubsub , database.ProvisionerJob {
493
+ CreatedAt : tenMinAgo ,
494
+ UpdatedAt : sixMinAgo ,
495
+ StartedAt : sql.NullTime {
496
+ Time : tenMinAgo ,
497
+ Valid : true ,
498
+ },
499
+ OrganizationID : org .ID ,
500
+ InitiatorID : user .ID ,
501
+ Provisioner : database .ProvisionerTypeEcho ,
502
+ StorageMethod : database .ProvisionerStorageMethodFile ,
503
+ FileID : file .ID ,
504
+ Type : database .ProvisionerJobTypeTemplateVersionDryRun ,
505
+ Input : input ,
506
+ })
507
+
491
508
t .Log ("template import job ID: " , templateImportJob .ID )
492
509
t .Log ("template dry-run job ID: " , templateDryRunJob .ID )
493
510
494
- detector := unhanger .New (ctx , db , pubsub , log , tickCh ).WithStatsChannel (statsCh )
511
+ detector := unhanger .New (ctx , wrapDBAuthz ( db , log ) , pubsub , log , tickCh ).WithStatsChannel (statsCh )
495
512
detector .Start ()
496
513
tickCh <- now
497
514
@@ -564,11 +581,15 @@ func TestDetectorHungCanceledJob(t *testing.T) {
564
581
Type : database .ProvisionerJobTypeTemplateVersionImport ,
565
582
Input : []byte ("{}" ),
566
583
})
584
+ _ = dbgen .TemplateVersion (t , db , database.TemplateVersion {
585
+ OrganizationID : org .ID ,
586
+ JobID : templateImportJob .ID ,
587
+ })
567
588
)
568
589
569
590
t .Log ("template import job ID: " , templateImportJob .ID )
570
591
571
- detector := unhanger .New (ctx , db , pubsub , log , tickCh ).WithStatsChannel (statsCh )
592
+ detector := unhanger .New (ctx , wrapDBAuthz ( db , log ) , pubsub , log , tickCh ).WithStatsChannel (statsCh )
572
593
detector .Start ()
573
594
tickCh <- now
574
595
@@ -657,6 +678,10 @@ func TestDetectorPushesLogs(t *testing.T) {
657
678
Type : database .ProvisionerJobTypeTemplateVersionImport ,
658
679
Input : []byte ("{}" ),
659
680
})
681
+ _ = dbgen .TemplateVersion (t , db , database.TemplateVersion {
682
+ OrganizationID : org .ID ,
683
+ JobID : templateImportJob .ID ,
684
+ })
660
685
)
661
686
662
687
t .Log ("template import job ID: " , templateImportJob .ID )
@@ -678,7 +703,7 @@ func TestDetectorPushesLogs(t *testing.T) {
678
703
require .Len (t , logs , 10 )
679
704
}
680
705
681
- detector := unhanger .New (ctx , db , pubsub , log , tickCh ).WithStatsChannel (statsCh )
706
+ detector := unhanger .New (ctx , wrapDBAuthz ( db , log ) , pubsub , log , tickCh ).WithStatsChannel (statsCh )
682
707
detector .Start ()
683
708
684
709
// Create pubsub subscription to listen for new log events.
@@ -752,7 +777,7 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
752
777
// Create unhanger.MaxJobsPerRun + 1 hung jobs.
753
778
now := time .Now ()
754
779
for i := 0 ; i < unhanger .MaxJobsPerRun + 1 ; i ++ {
755
- dbgen .ProvisionerJob (t , db , pubsub , database.ProvisionerJob {
780
+ pj := dbgen .ProvisionerJob (t , db , pubsub , database.ProvisionerJob {
756
781
CreatedAt : now .Add (- time .Hour ),
757
782
UpdatedAt : now .Add (- time .Hour ),
758
783
StartedAt : sql.NullTime {
@@ -767,9 +792,13 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
767
792
Type : database .ProvisionerJobTypeTemplateVersionImport ,
768
793
Input : []byte ("{}" ),
769
794
})
795
+ _ = dbgen .TemplateVersion (t , db , database.TemplateVersion {
796
+ OrganizationID : org .ID ,
797
+ JobID : pj .ID ,
798
+ })
770
799
}
771
800
772
- detector := unhanger .New (ctx , db , pubsub , log , tickCh ).WithStatsChannel (statsCh )
801
+ detector := unhanger .New (ctx , wrapDBAuthz ( db , log ) , pubsub , log , tickCh ).WithStatsChannel (statsCh )
773
802
detector .Start ()
774
803
tickCh <- now
775
804
@@ -788,3 +817,17 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
788
817
detector .Close ()
789
818
detector .Wait ()
790
819
}
820
+
821
+ // wrapDBAuthz adds our Authorization/RBAC around the given database store, to
822
+ // ensure the unhanger has the right permissions to do its work.
823
+ func wrapDBAuthz (db database.Store , logger slog.Logger ) database.Store {
824
+ accessControlStore := & atomic.Pointer [dbauthz.AccessControlStore ]{}
825
+ var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore {}
826
+ accessControlStore .Store (& acs )
827
+ return dbauthz .New (
828
+ db ,
829
+ rbac .NewStrictCachingAuthorizer (prometheus .NewRegistry ()),
830
+ logger ,
831
+ accessControlStore ,
832
+ )
833
+ }
0 commit comments