diff --git a/cli/exp_scaletest.go b/cli/exp_scaletest.go index a844a7e8c6258..4580ff3e1bc8a 100644 --- a/cli/exp_scaletest.go +++ b/cli/exp_scaletest.go @@ -864,6 +864,7 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command { tickInterval time.Duration bytesPerTick int64 ssh bool + disableDirect bool useHostLogin bool app string template string @@ -1023,15 +1024,16 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command { // Setup our workspace agent connection. config := workspacetraffic.Config{ - AgentID: agent.ID, - BytesPerTick: bytesPerTick, - Duration: strategy.timeout, - TickInterval: tickInterval, - ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agent.Name), - WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agent.Name), - SSH: ssh, - Echo: ssh, - App: appConfig, + AgentID: agent.ID, + BytesPerTick: bytesPerTick, + Duration: strategy.timeout, + TickInterval: tickInterval, + ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agent.Name), + WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agent.Name), + SSH: ssh, + DisableDirect: disableDirect, + Echo: ssh, + App: appConfig, } if webClient != nil { @@ -1117,6 +1119,13 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command { Description: "Send traffic over SSH, cannot be used with --app.", Value: serpent.BoolOf(&ssh), }, + { + Flag: "disable-direct", + Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_DISABLE_DIRECT_CONNECTIONS", + Default: "false", + Description: "Disable direct connections for SSH traffic to workspaces. Does nothing if `--ssh` is not also set.", + Value: serpent.BoolOf(&disableDirect), + }, { Flag: "app", Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_APP", diff --git a/cli/exp_task_status_test.go b/cli/exp_task_status_test.go index 6631980ac1fbd..b520d2728804e 100644 --- a/cli/exp_task_status_test.go +++ b/cli/exp_task_status_test.go @@ -243,13 +243,12 @@ STATE CHANGED STATUS STATE MESSAGE ctx = testutil.Context(t, testutil.WaitShort) now = time.Now().UTC() // TODO: replace with quartz srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now))) - client = new(codersdk.Client) + client = codersdk.New(testutil.MustURL(t, srv.URL)) sb = strings.Builder{} args = []string{"exp", "task", "status", "--watch-interval", testutil.IntervalFast.String()} ) t.Cleanup(srv.Close) - client.URL = testutil.MustURL(t, srv.URL) args = append(args, tc.args...) inv, root := clitest.New(t, args...) inv.Stdout = &sb diff --git a/cli/exp_taskcreate_test.go b/cli/exp_taskcreate_test.go index 121f22eb525f6..f49c2fee1194a 100644 --- a/cli/exp_taskcreate_test.go +++ b/cli/exp_taskcreate_test.go @@ -5,14 +5,12 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" "strings" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" @@ -236,7 +234,7 @@ func TestTaskCreate(t *testing.T) { var ( ctx = testutil.Context(t, testutil.WaitShort) srv = httptest.NewServer(tt.handler(t, ctx)) - client = new(codersdk.Client) + client = codersdk.New(testutil.MustURL(t, srv.URL)) args = []string{"exp", "task", "create"} sb strings.Builder err error @@ -244,9 +242,6 @@ func TestTaskCreate(t *testing.T) { t.Cleanup(srv.Close) - client.URL, err = url.Parse(srv.URL) - require.NoError(t, err) - inv, root := clitest.New(t, append(args, tt.args...)...) inv.Environ = serpent.ParseEnviron(tt.env, "") inv.Stdout = &sb diff --git a/cli/root.go b/cli/root.go index b3e67a46ad463..ed6869b6a1c49 100644 --- a/cli/root.go +++ b/cli/root.go @@ -635,6 +635,9 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod } func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error { + if client.SessionTokenProvider == nil { + client.SessionTokenProvider = codersdk.FixedSessionTokenProvider{} + } transport := http.DefaultTransport transport = wrapTransportWithTelemetryHeader(transport, inv) if !r.noVersionCheck { diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index c7f7d35937198..a76f6447dcabd 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -641,7 +641,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken // ExternalLogin does the oauth2 flow for external auth providers. This requires // an authenticated coder client. -func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) { +func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...codersdk.RequestOption) { coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID)) require.NoError(t, err) f.SetRedirect(t, coderOauthURL.String()) @@ -660,11 +660,7 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil) require.NoError(t, err) // External auth flow requires the user be authenticated. - headerName := client.SessionTokenHeader - if headerName == "" { - headerName = codersdk.SessionTokenHeader - } - req.Header.Set(headerName, client.SessionToken()) + opts = append([]codersdk.RequestOption{client.SessionTokenProvider.AsRequestOption()}, opts...) if cli.Jar == nil { cli.Jar, err = cookiejar.New(nil) require.NoError(t, err, "failed to create cookie jar") diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 68bed8f2ef5e9..40caad0818802 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2653,835 +2653,716 @@ func (s *MethodTestSuite) TestCryptoKeys() { } func (s *MethodTestSuite) TestSystemFunctions() { - s.Run("UpdateUserLinkedID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - l := dbgen.UserLink(s.T(), db, database.UserLink{UserID: u.ID}) - check.Args(database.UpdateUserLinkedIDParams{ - UserID: u.ID, - LinkedID: l.LinkedID, - LoginType: database.LoginTypeGithub, - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(l) - })) - s.Run("GetLatestWorkspaceAppStatusesByWorkspaceIDs", s.Subtest(func(db database.Store, check *expects) { - check.Args([]uuid.UUID{}).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("UpdateUserLinkedID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + l := testutil.Fake(s.T(), faker, database.UserLink{UserID: u.ID}) + arg := database.UpdateUserLinkedIDParams{UserID: u.ID, LinkedID: l.LinkedID, LoginType: database.LoginTypeGithub} + dbm.EXPECT().UpdateUserLinkedID(gomock.Any(), arg).Return(l, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(l) + })) + s.Run("GetLatestWorkspaceAppStatusesByWorkspaceIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ids := []uuid.UUID{uuid.New()} + dbm.EXPECT().GetLatestWorkspaceAppStatusesByWorkspaceIDs(gomock.Any(), ids).Return([]database.WorkspaceAppStatus{}, nil).AnyTimes() + check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetWorkspaceAppStatusesByAppIDs", s.Subtest(func(db database.Store, check *expects) { - check.Args([]uuid.UUID{}).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetWorkspaceAppStatusesByAppIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ids := []uuid.UUID{uuid.New()} + dbm.EXPECT().GetWorkspaceAppStatusesByAppIDs(gomock.Any(), ids).Return([]database.WorkspaceAppStatus{}, nil).AnyTimes() + check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetLatestWorkspaceBuildsByWorkspaceIDs", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{}) - b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) - check.Args([]uuid.UUID{ws.ID}).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(slice.New(b)) + s.Run("GetLatestWorkspaceBuildsByWorkspaceIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + wsID := uuid.New() + b := testutil.Fake(s.T(), faker, database.WorkspaceBuild{}) + dbm.EXPECT().GetLatestWorkspaceBuildsByWorkspaceIDs(gomock.Any(), []uuid.UUID{wsID}).Return([]database.WorkspaceBuild{b}, nil).AnyTimes() + check.Args([]uuid.UUID{wsID}).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(slice.New(b)) })) - s.Run("UpsertDefaultProxy", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.UpsertDefaultProxyParams{}).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() + s.Run("UpsertDefaultProxy", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.UpsertDefaultProxyParams{} + dbm.EXPECT().UpsertDefaultProxy(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() })) - s.Run("GetUserLinkByLinkedID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - l := dbgen.UserLink(s.T(), db, database.UserLink{UserID: u.ID}) + s.Run("GetUserLinkByLinkedID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + l := testutil.Fake(s.T(), faker, database.UserLink{}) + dbm.EXPECT().GetUserLinkByLinkedID(gomock.Any(), l.LinkedID).Return(l, nil).AnyTimes() check.Args(l.LinkedID).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(l) })) - s.Run("GetUserLinkByUserIDLoginType", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - l := dbgen.UserLink(s.T(), db, database.UserLink{}) - check.Args(database.GetUserLinkByUserIDLoginTypeParams{ - UserID: l.UserID, - LoginType: l.LoginType, - }).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(l) + s.Run("GetUserLinkByUserIDLoginType", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + l := testutil.Fake(s.T(), faker, database.UserLink{}) + arg := database.GetUserLinkByUserIDLoginTypeParams{UserID: l.UserID, LoginType: l.LoginType} + dbm.EXPECT().GetUserLinkByUserIDLoginType(gomock.Any(), arg).Return(l, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(l) })) - s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetActiveUserCount", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetActiveUserCount(gomock.Any(), false).Return(int64(0), nil).AnyTimes() check.Args(false).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(int64(0)) })) - s.Run("GetAuthorizationUserRoles", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) + s.Run("GetAuthorizationUserRoles", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + dbm.EXPECT().GetAuthorizationUserRoles(gomock.Any(), u.ID).Return(database.GetAuthorizationUserRolesRow{}, nil).AnyTimes() check.Args(u.ID).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetDERPMeshKey", s.Subtest(func(db database.Store, check *expects) { - db.InsertDERPMeshKey(context.Background(), "testing") + s.Run("GetDERPMeshKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetDERPMeshKey(gomock.Any()).Return("testing", nil).AnyTimes() check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("InsertDERPMeshKey", s.Subtest(func(db database.Store, check *expects) { + s.Run("InsertDERPMeshKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().InsertDERPMeshKey(gomock.Any(), "value").Return(nil).AnyTimes() check.Args("value").Asserts(rbac.ResourceSystem, policy.ActionCreate).Returns() })) - s.Run("InsertDeploymentID", s.Subtest(func(db database.Store, check *expects) { + s.Run("InsertDeploymentID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().InsertDeploymentID(gomock.Any(), "value").Return(nil).AnyTimes() check.Args("value").Asserts(rbac.ResourceSystem, policy.ActionCreate).Returns() })) - s.Run("InsertReplica", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertReplicaParams{ - ID: uuid.New(), - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertReplica", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertReplicaParams{ID: uuid.New()} + dbm.EXPECT().InsertReplica(gomock.Any(), arg).Return(database.Replica{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("UpdateReplica", s.Subtest(func(db database.Store, check *expects) { - replica, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New()}) - require.NoError(s.T(), err) - check.Args(database.UpdateReplicaParams{ - ID: replica.ID, - DatabaseLatency: 100, - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate) + s.Run("UpdateReplica", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + rep := testutil.Fake(s.T(), faker, database.Replica{}) + arg := database.UpdateReplicaParams{ID: rep.ID, DatabaseLatency: 100} + dbm.EXPECT().UpdateReplica(gomock.Any(), arg).Return(rep, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("DeleteReplicasUpdatedBefore", s.Subtest(func(db database.Store, check *expects) { - _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()}) - require.NoError(s.T(), err) - check.Args(time.Now().Add(time.Hour)).Asserts(rbac.ResourceSystem, policy.ActionDelete) + s.Run("DeleteReplicasUpdatedBefore", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := dbtime.Now().Add(time.Hour) + dbm.EXPECT().DeleteReplicasUpdatedBefore(gomock.Any(), t).Return(nil).AnyTimes() + check.Args(t).Asserts(rbac.ResourceSystem, policy.ActionDelete) })) - s.Run("GetReplicasUpdatedAfter", s.Subtest(func(db database.Store, check *expects) { - _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()}) - require.NoError(s.T(), err) - check.Args(time.Now().Add(time.Hour*-1)).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetReplicasUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := dbtime.Now().Add(-time.Hour) + dbm.EXPECT().GetReplicasUpdatedAfter(gomock.Any(), t).Return([]database.Replica{}, nil).AnyTimes() + check.Args(t).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetUserCount", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetUserCount", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetUserCount(gomock.Any(), false).Return(int64(0), nil).AnyTimes() check.Args(false).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(int64(0)) })) - s.Run("GetTemplates", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - _ = dbgen.Template(s.T(), db, database.Template{}) - check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead) - })) - s.Run("UpdateWorkspaceBuildCostByID", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) - o := b - o.DailyCost = 10 - check.Args(database.UpdateWorkspaceBuildCostByIDParams{ - ID: b.ID, - DailyCost: 10, - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate) - })) - s.Run("UpdateWorkspaceBuildProvisionerStateByID", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - check.Args(database.UpdateWorkspaceBuildProvisionerStateByIDParams{ - ID: build.ID, - ProvisionerState: []byte("testing"), - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate) - })) - s.Run("UpsertLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) { - check.Args("value").Asserts(rbac.ResourceSystem, policy.ActionUpdate) - })) - s.Run("GetLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) { - err := db.UpsertLastUpdateCheck(context.Background(), "value") - require.NoError(s.T(), err) + s.Run("GetTemplates", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetTemplates(gomock.Any()).Return([]database.Template{}, nil).AnyTimes() check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetWorkspaceBuildsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{CreatedAt: time.Now().Add(-time.Hour)}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, policy.ActionRead) - })) - s.Run("GetWorkspaceAgentsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - _ = dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{CreatedAt: time.Now().Add(-time.Hour)}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("UpdateWorkspaceBuildCostByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + b := testutil.Fake(s.T(), faker, database.WorkspaceBuild{}) + arg := database.UpdateWorkspaceBuildCostByIDParams{ID: b.ID, DailyCost: 10} + dbm.EXPECT().UpdateWorkspaceBuildCostByID(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("GetWorkspaceAppsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - _ = dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{CreatedAt: time.Now().Add(-time.Hour), OpenIn: database.WorkspaceAppOpenInSlimWindow}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("UpdateWorkspaceBuildProvisionerStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + b := testutil.Fake(s.T(), faker, database.WorkspaceBuild{}) + arg := database.UpdateWorkspaceBuildProvisionerStateByIDParams{ID: b.ID, ProvisionerState: []byte("testing")} + dbm.EXPECT().UpdateWorkspaceBuildProvisionerStateByID(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("GetWorkspaceResourcesCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - _ = dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{CreatedAt: time.Now().Add(-time.Hour)}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("UpsertLastUpdateCheck", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertLastUpdateCheck(gomock.Any(), "value").Return(nil).AnyTimes() + check.Args("value").Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("GetWorkspaceResourceMetadataCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - _ = dbgen.WorkspaceResourceMetadatums(s.T(), db, database.WorkspaceResourceMetadatum{}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetLastUpdateCheck", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetLastUpdateCheck(gomock.Any()).Return("value", nil).AnyTimes() + check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("DeleteOldWorkspaceAgentStats", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetWorkspaceBuildsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ts := dbtime.Now() + dbm.EXPECT().GetWorkspaceBuildsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceBuild{}, nil).AnyTimes() + check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) + s.Run("GetWorkspaceAgentsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ts := dbtime.Now() + dbm.EXPECT().GetWorkspaceAgentsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceAgent{}, nil).AnyTimes() + check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) + s.Run("GetWorkspaceAppsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ts := dbtime.Now() + dbm.EXPECT().GetWorkspaceAppsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceApp{}, nil).AnyTimes() + check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) + s.Run("GetWorkspaceResourcesCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ts := dbtime.Now() + dbm.EXPECT().GetWorkspaceResourcesCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceResource{}, nil).AnyTimes() + check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) + s.Run("GetWorkspaceResourceMetadataCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ts := dbtime.Now() + dbm.EXPECT().GetWorkspaceResourceMetadataCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceResourceMetadatum{}, nil).AnyTimes() + check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) + s.Run("DeleteOldWorkspaceAgentStats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().DeleteOldWorkspaceAgentStats(gomock.Any()).Return(nil).AnyTimes() check.Args().Asserts(rbac.ResourceSystem, policy.ActionDelete) })) - s.Run("GetProvisionerJobsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{CreatedAt: time.Now().Add(-time.Hour)}) - check.Args(time.Now()).Asserts(rbac.ResourceProvisionerJobs, policy.ActionRead) - })) - s.Run("GetTemplateVersionsByIDs", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - t1 := dbgen.Template(s.T(), db, database.Template{}) - t2 := dbgen.Template(s.T(), db, database.Template{}) - tv1 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - tv2 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t2.ID, Valid: true}, - }) - tv3 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t2.ID, Valid: true}, - }) - check.Args([]uuid.UUID{tv1.ID, tv2.ID, tv3.ID}). + s.Run("GetProvisionerJobsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ts := dbtime.Now() + dbm.EXPECT().GetProvisionerJobsCreatedAfter(gomock.Any(), ts).Return([]database.ProvisionerJob{}, nil).AnyTimes() + check.Args(ts).Asserts(rbac.ResourceProvisionerJobs, policy.ActionRead) + })) + s.Run("GetTemplateVersionsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + tv1 := testutil.Fake(s.T(), faker, database.TemplateVersion{}) + tv2 := testutil.Fake(s.T(), faker, database.TemplateVersion{}) + tv3 := testutil.Fake(s.T(), faker, database.TemplateVersion{}) + ids := []uuid.UUID{tv1.ID, tv2.ID, tv3.ID} + dbm.EXPECT().GetTemplateVersionsByIDs(gomock.Any(), ids).Return([]database.TemplateVersion{tv1, tv2, tv3}, nil).AnyTimes() + check.Args(ids). Asserts(rbac.ResourceSystem, policy.ActionRead). Returns(slice.New(tv1, tv2, tv3)) })) - s.Run("GetParameterSchemasByJobID", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - tpl := dbgen.Template(s.T(), db, database.Template{}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - }) - job := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: tv.JobID}) - check.Args(job.ID). + s.Run("GetParameterSchemasByJobID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + tpl := testutil.Fake(s.T(), faker, database.Template{}) + v := testutil.Fake(s.T(), faker, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}}) + jobID := v.JobID + dbm.EXPECT().GetTemplateVersionByJobID(gomock.Any(), jobID).Return(v, nil).AnyTimes() + dbm.EXPECT().GetTemplateByID(gomock.Any(), tpl.ID).Return(tpl, nil).AnyTimes() + dbm.EXPECT().GetParameterSchemasByJobID(gomock.Any(), jobID).Return([]database.ParameterSchema{}, nil).AnyTimes() + check.Args(jobID). Asserts(tpl, policy.ActionRead). ErrorsWithInMemDB(sql.ErrNoRows). Returns([]database.ParameterSchema{}) })) - s.Run("GetWorkspaceAppsByAgentIDs", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - aWs := dbgen.Workspace(s.T(), db, database.WorkspaceTable{}) - aBuild := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: aWs.ID, JobID: uuid.New()}) - aRes := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: aBuild.JobID}) - aAgt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: aRes.ID}) - a := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: aAgt.ID, OpenIn: database.WorkspaceAppOpenInSlimWindow}) - - bWs := dbgen.Workspace(s.T(), db, database.WorkspaceTable{}) - bBuild := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: bWs.ID, JobID: uuid.New()}) - bRes := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: bBuild.JobID}) - bAgt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: bRes.ID}) - b := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: bAgt.ID, OpenIn: database.WorkspaceAppOpenInSlimWindow}) - - check.Args([]uuid.UUID{a.AgentID, b.AgentID}). + s.Run("GetWorkspaceAppsByAgentIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + a := testutil.Fake(s.T(), faker, database.WorkspaceApp{}) + b := testutil.Fake(s.T(), faker, database.WorkspaceApp{}) + ids := []uuid.UUID{a.AgentID, b.AgentID} + dbm.EXPECT().GetWorkspaceAppsByAgentIDs(gomock.Any(), ids).Return([]database.WorkspaceApp{a, b}, nil).AnyTimes() + check.Args(ids). Asserts(rbac.ResourceSystem, policy.ActionRead). Returns([]database.WorkspaceApp{a, b}) })) - s.Run("GetWorkspaceResourcesByJobIDs", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - tpl := dbgen.Template(s.T(), db, database.Template{}) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: uuid.New()}) - tJob := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport}) - - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - wJob := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) - check.Args([]uuid.UUID{tJob.ID, wJob.ID}). + s.Run("GetWorkspaceResourcesByJobIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ids := []uuid.UUID{uuid.New(), uuid.New()} + dbm.EXPECT().GetWorkspaceResourcesByJobIDs(gomock.Any(), ids).Return([]database.WorkspaceResource{}, nil).AnyTimes() + check.Args(ids). Asserts(rbac.ResourceSystem, policy.ActionRead). Returns([]database.WorkspaceResource{}) })) - s.Run("GetWorkspaceResourceMetadataByResourceIDs", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - _ = dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) - a := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - b := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - check.Args([]uuid.UUID{a.ID, b.ID}). + s.Run("GetWorkspaceResourceMetadataByResourceIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ids := []uuid.UUID{uuid.New(), uuid.New()} + dbm.EXPECT().GetWorkspaceResourceMetadataByResourceIDs(gomock.Any(), ids).Return([]database.WorkspaceResourceMetadatum{}, nil).AnyTimes() + check.Args(ids). Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetWorkspaceAgentsByResourceIDs", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args([]uuid.UUID{res.ID}). + s.Run("GetWorkspaceAgentsByResourceIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + resID := uuid.New() + agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) + dbm.EXPECT().GetWorkspaceAgentsByResourceIDs(gomock.Any(), []uuid.UUID{resID}).Return([]database.WorkspaceAgent{agt}, nil).AnyTimes() + check.Args([]uuid.UUID{resID}). Asserts(rbac.ResourceSystem, policy.ActionRead). Returns([]database.WorkspaceAgent{agt}) })) - s.Run("GetProvisionerJobsByIDs", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - a := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{OrganizationID: o.ID}) - b := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{OrganizationID: o.ID}) - check.Args([]uuid.UUID{a.ID, b.ID}). - Asserts(rbac.ResourceProvisionerJobs.InOrg(o.ID), policy.ActionRead). + s.Run("GetProvisionerJobsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + org := testutil.Fake(s.T(), faker, database.Organization{}) + a := testutil.Fake(s.T(), faker, database.ProvisionerJob{OrganizationID: org.ID}) + b := testutil.Fake(s.T(), faker, database.ProvisionerJob{OrganizationID: org.ID}) + ids := []uuid.UUID{a.ID, b.ID} + dbm.EXPECT().GetProvisionerJobsByIDs(gomock.Any(), ids).Return([]database.ProvisionerJob{a, b}, nil).AnyTimes() + check.Args(ids). + Asserts(rbac.ResourceProvisionerJobs.InOrg(org.ID), policy.ActionRead). Returns(slice.New(a, b)) })) - s.Run("DeleteWorkspaceSubAgentByID", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{}) - u := dbgen.User(s.T(), db, database.User{}) - o := dbgen.Organization(s.T(), db, database.Organization{}) - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{Type: database.ProvisionerJobTypeWorkspaceBuild}) - tpl := dbgen.Template(s.T(), db, database.Template{CreatedBy: u.ID, OrganizationID: o.ID}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - JobID: j.ID, - OrganizationID: o.ID, - CreatedBy: u.ID, - }) - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{OwnerID: u.ID, TemplateID: tpl.ID, OrganizationID: o.ID}) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: j.ID, TemplateVersionID: tv.ID}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: j.ID}) - agent := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - _ = dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID, ParentID: uuid.NullUUID{Valid: true, UUID: agent.ID}}) + s.Run("DeleteWorkspaceSubAgentByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + ws := testutil.Fake(s.T(), faker, database.Workspace{}) + agent := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) + dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(ws, nil).AnyTimes() + dbm.EXPECT().DeleteWorkspaceSubAgentByID(gomock.Any(), agent.ID).Return(nil).AnyTimes() check.Args(agent.ID).Asserts(ws, policy.ActionDeleteAgent) })) - s.Run("GetWorkspaceAgentsByParentID", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{}) - u := dbgen.User(s.T(), db, database.User{}) - o := dbgen.Organization(s.T(), db, database.Organization{}) - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{Type: database.ProvisionerJobTypeWorkspaceBuild}) - tpl := dbgen.Template(s.T(), db, database.Template{CreatedBy: u.ID, OrganizationID: o.ID}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - JobID: j.ID, - OrganizationID: o.ID, - CreatedBy: u.ID, - }) - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{OwnerID: u.ID, TemplateID: tpl.ID, OrganizationID: o.ID}) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: j.ID, TemplateVersionID: tv.ID}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: j.ID}) - agent := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - _ = dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID, ParentID: uuid.NullUUID{Valid: true, UUID: agent.ID}}) - check.Args(agent.ID).Asserts(ws, policy.ActionRead) - })) - s.Run("InsertWorkspaceAgent", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - o := dbgen.Organization(s.T(), db, database.Organization{}) - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{Type: database.ProvisionerJobTypeWorkspaceBuild}) - tpl := dbgen.Template(s.T(), db, database.Template{CreatedBy: u.ID, OrganizationID: o.ID}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - JobID: j.ID, - OrganizationID: o.ID, - CreatedBy: u.ID, - }) - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{OwnerID: u.ID, TemplateID: tpl.ID, OrganizationID: o.ID}) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: j.ID, TemplateVersionID: tv.ID}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: j.ID}) - check.Args(database.InsertWorkspaceAgentParams{ - ID: uuid.New(), - ResourceID: res.ID, - Name: "dev", - APIKeyScope: database.AgentKeyScopeEnumAll, - }).Asserts(ws, policy.ActionCreateAgent) - })) - s.Run("UpsertWorkspaceApp", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{}) - u := dbgen.User(s.T(), db, database.User{}) - o := dbgen.Organization(s.T(), db, database.Organization{}) - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{Type: database.ProvisionerJobTypeWorkspaceBuild}) - tpl := dbgen.Template(s.T(), db, database.Template{CreatedBy: u.ID, OrganizationID: o.ID}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - JobID: j.ID, - OrganizationID: o.ID, - CreatedBy: u.ID, - }) - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{OwnerID: u.ID, TemplateID: tpl.ID, OrganizationID: o.ID}) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: j.ID, TemplateVersionID: tv.ID}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: j.ID}) - agent := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(database.UpsertWorkspaceAppParams{ - ID: uuid.New(), - AgentID: agent.ID, - Health: database.WorkspaceAppHealthDisabled, - SharingLevel: database.AppSharingLevelOwner, - OpenIn: database.WorkspaceAppOpenInSlimWindow, - }).Asserts(ws, policy.ActionUpdate) - })) - s.Run("InsertWorkspaceResourceMetadata", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceResourceMetadataParams{ - WorkspaceResourceID: uuid.New(), - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) - })) - s.Run("UpdateWorkspaceAgentConnectionByID", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(database.UpdateWorkspaceAgentConnectionByIDParams{ - ID: agt.ID, - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() + s.Run("GetWorkspaceAgentsByParentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + ws := testutil.Fake(s.T(), faker, database.Workspace{}) + parent := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) + child := testutil.Fake(s.T(), faker, database.WorkspaceAgent{ParentID: uuid.NullUUID{Valid: true, UUID: parent.ID}}) + dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), parent.ID).Return(ws, nil).AnyTimes() + dbm.EXPECT().GetWorkspaceAgentsByParentID(gomock.Any(), parent.ID).Return([]database.WorkspaceAgent{child}, nil).AnyTimes() + check.Args(parent.ID).Asserts(ws, policy.ActionRead) })) - s.Run("AcquireProvisionerJob", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ - StartedAt: sql.NullTime{Valid: false}, - UpdatedAt: time.Now(), - }) - check.Args(database.AcquireProvisionerJobParams{ - StartedAt: sql.NullTime{Valid: true, Time: time.Now()}, - OrganizationID: j.OrganizationID, - Types: []database.ProvisionerType{j.Provisioner}, - ProvisionerTags: must(json.Marshal(j.Tags)), - }).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) - })) - s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{}) - check.Args(database.UpdateProvisionerJobWithCompleteByIDParams{ - ID: j.ID, - }).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) - })) - s.Run("UpdateProvisionerJobWithCompleteWithStartedAtByID", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{}) - check.Args(database.UpdateProvisionerJobWithCompleteWithStartedAtByIDParams{ - ID: j.ID, - }).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) - })) - s.Run("UpdateProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{}) - check.Args(database.UpdateProvisionerJobByIDParams{ - ID: j.ID, - UpdatedAt: time.Now(), - }).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) - })) - s.Run("UpdateProvisionerJobLogsLength", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{}) - check.Args(database.UpdateProvisionerJobLogsLengthParams{ - ID: j.ID, - LogsLength: 100, - }).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) - })) - s.Run("UpdateProvisionerJobLogsOverflowed", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{}) - check.Args(database.UpdateProvisionerJobLogsOverflowedParams{ - ID: j.ID, - LogsOverflowed: true, - }).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) - })) - s.Run("InsertProvisionerJob", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - check.Args(database.InsertProvisionerJobParams{ + s.Run("InsertWorkspaceAgent", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + ws := testutil.Fake(s.T(), faker, database.Workspace{}) + res := testutil.Fake(s.T(), faker, database.WorkspaceResource{}) + arg := database.InsertWorkspaceAgentParams{ID: uuid.New(), ResourceID: res.ID, Name: "dev", APIKeyScope: database.AgentKeyScopeEnumAll} + dbm.EXPECT().GetWorkspaceByResourceID(gomock.Any(), res.ID).Return(ws, nil).AnyTimes() + dbm.EXPECT().InsertWorkspaceAgent(gomock.Any(), arg).Return(testutil.Fake(s.T(), faker, database.WorkspaceAgent{ResourceID: res.ID}), nil).AnyTimes() + check.Args(arg).Asserts(ws, policy.ActionCreateAgent) + })) + s.Run("UpsertWorkspaceApp", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + ws := testutil.Fake(s.T(), faker, database.Workspace{}) + agent := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) + arg := database.UpsertWorkspaceAppParams{ID: uuid.New(), AgentID: agent.ID, Health: database.WorkspaceAppHealthDisabled, SharingLevel: database.AppSharingLevelOwner, OpenIn: database.WorkspaceAppOpenInSlimWindow} + dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(ws, nil).AnyTimes() + dbm.EXPECT().UpsertWorkspaceApp(gomock.Any(), arg).Return(testutil.Fake(s.T(), faker, database.WorkspaceApp{AgentID: agent.ID}), nil).AnyTimes() + check.Args(arg).Asserts(ws, policy.ActionUpdate) + })) + s.Run("InsertWorkspaceResourceMetadata", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceResourceMetadataParams{WorkspaceResourceID: uuid.New()} + dbm.EXPECT().InsertWorkspaceResourceMetadata(gomock.Any(), arg).Return([]database.WorkspaceResourceMetadatum{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) + })) + s.Run("UpdateWorkspaceAgentConnectionByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) + arg := database.UpdateWorkspaceAgentConnectionByIDParams{ID: agt.ID} + dbm.EXPECT().UpdateWorkspaceAgentConnectionByID(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() + })) + s.Run("AcquireProvisionerJob", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.AcquireProvisionerJobParams{StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, OrganizationID: uuid.New(), Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, ProvisionerTags: json.RawMessage("{}")} + dbm.EXPECT().AcquireProvisionerJob(gomock.Any(), arg).Return(testutil.Fake(s.T(), faker, database.ProvisionerJob{}), nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) + })) + s.Run("UpdateProvisionerJobWithCompleteByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{}) + arg := database.UpdateProvisionerJobWithCompleteByIDParams{ID: j.ID} + dbm.EXPECT().UpdateProvisionerJobWithCompleteByID(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) + })) + s.Run("UpdateProvisionerJobWithCompleteWithStartedAtByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{}) + arg := database.UpdateProvisionerJobWithCompleteWithStartedAtByIDParams{ID: j.ID} + dbm.EXPECT().UpdateProvisionerJobWithCompleteWithStartedAtByID(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) + })) + s.Run("UpdateProvisionerJobByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{}) + arg := database.UpdateProvisionerJobByIDParams{ID: j.ID, UpdatedAt: dbtime.Now()} + dbm.EXPECT().UpdateProvisionerJobByID(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) + })) + s.Run("UpdateProvisionerJobLogsLength", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{}) + arg := database.UpdateProvisionerJobLogsLengthParams{ID: j.ID, LogsLength: 100} + dbm.EXPECT().UpdateProvisionerJobLogsLength(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) + })) + s.Run("UpdateProvisionerJobLogsOverflowed", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{}) + arg := database.UpdateProvisionerJobLogsOverflowedParams{ID: j.ID, LogsOverflowed: true} + dbm.EXPECT().UpdateProvisionerJobLogsOverflowed(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) + })) + s.Run("InsertProvisionerJob", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, Type: database.ProvisionerJobTypeWorkspaceBuild, Input: json.RawMessage("{}"), - }).Asserts( /* rbac.ResourceProvisionerJobs, policy.ActionCreate */ ) - })) - s.Run("InsertProvisionerJobLogs", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{}) - check.Args(database.InsertProvisionerJobLogsParams{ - JobID: j.ID, - }).Asserts( /* rbac.ResourceProvisionerJobs, policy.ActionUpdate */ ) - })) - s.Run("InsertProvisionerJobTimings", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{}) - check.Args(database.InsertProvisionerJobTimingsParams{ - JobID: j.ID, - }).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) - })) - s.Run("UpsertProvisionerDaemon", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - org := dbgen.Organization(s.T(), db, database.Organization{}) + } + dbm.EXPECT().InsertProvisionerJob(gomock.Any(), arg).Return(testutil.Fake(s.T(), gofakeit.New(0), database.ProvisionerJob{}), nil).AnyTimes() + check.Args(arg).Asserts( /* rbac.ResourceProvisionerJobs, policy.ActionCreate */ ) + })) + s.Run("InsertProvisionerJobLogs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{}) + arg := database.InsertProvisionerJobLogsParams{JobID: j.ID} + dbm.EXPECT().InsertProvisionerJobLogs(gomock.Any(), arg).Return([]database.ProvisionerJobLog{}, nil).AnyTimes() + check.Args(arg).Asserts( /* rbac.ResourceProvisionerJobs, policy.ActionUpdate */ ) + })) + s.Run("InsertProvisionerJobTimings", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{}) + arg := database.InsertProvisionerJobTimingsParams{JobID: j.ID} + dbm.EXPECT().InsertProvisionerJobTimings(gomock.Any(), arg).Return([]database.ProvisionerJobTiming{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceProvisionerJobs, policy.ActionUpdate) + })) + s.Run("UpsertProvisionerDaemon", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + org := testutil.Fake(s.T(), faker, database.Organization{}) pd := rbac.ResourceProvisionerDaemon.InOrg(org.ID) - check.Args(database.UpsertProvisionerDaemonParams{ + argOrg := database.UpsertProvisionerDaemonParams{ OrganizationID: org.ID, Provisioners: []database.ProvisionerType{}, - Tags: database.StringMap(map[string]string{ - provisionersdk.TagScope: provisionersdk.ScopeOrganization, - }), - }).Asserts(pd, policy.ActionCreate) - check.Args(database.UpsertProvisionerDaemonParams{ + Tags: database.StringMap(map[string]string{provisionersdk.TagScope: provisionersdk.ScopeOrganization}), + } + dbm.EXPECT().UpsertProvisionerDaemon(gomock.Any(), argOrg).Return(testutil.Fake(s.T(), faker, database.ProvisionerDaemon{OrganizationID: org.ID}), nil).AnyTimes() + check.Args(argOrg).Asserts(pd, policy.ActionCreate) + + argUser := database.UpsertProvisionerDaemonParams{ OrganizationID: org.ID, Provisioners: []database.ProvisionerType{}, - Tags: database.StringMap(map[string]string{ - provisionersdk.TagScope: provisionersdk.ScopeUser, - provisionersdk.TagOwner: "11111111-1111-1111-1111-111111111111", - }), - }).Asserts(pd.WithOwner("11111111-1111-1111-1111-111111111111"), policy.ActionCreate) + Tags: database.StringMap(map[string]string{provisionersdk.TagScope: provisionersdk.ScopeUser, provisionersdk.TagOwner: "11111111-1111-1111-1111-111111111111"}), + } + dbm.EXPECT().UpsertProvisionerDaemon(gomock.Any(), argUser).Return(testutil.Fake(s.T(), faker, database.ProvisionerDaemon{OrganizationID: org.ID}), nil).AnyTimes() + check.Args(argUser).Asserts(pd.WithOwner("11111111-1111-1111-1111-111111111111"), policy.ActionCreate) })) - s.Run("InsertTemplateVersionParameter", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{}) - check.Args(database.InsertTemplateVersionParameterParams{ - TemplateVersionID: v.ID, - Options: json.RawMessage("{}"), - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertTemplateVersionParameter", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + v := testutil.Fake(s.T(), faker, database.TemplateVersion{}) + arg := database.InsertTemplateVersionParameterParams{TemplateVersionID: v.ID, Options: json.RawMessage("{}")} + dbm.EXPECT().InsertTemplateVersionParameter(gomock.Any(), arg).Return(testutil.Fake(s.T(), faker, database.TemplateVersionParameter{TemplateVersionID: v.ID}), nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("InsertWorkspaceAppStatus", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - check.Args(database.InsertWorkspaceAppStatusParams{ - ID: uuid.New(), - State: "working", - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertWorkspaceAppStatus", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceAppStatusParams{ID: uuid.New(), State: "working"} + dbm.EXPECT().InsertWorkspaceAppStatus(gomock.Any(), arg).Return(testutil.Fake(s.T(), gofakeit.New(0), database.WorkspaceAppStatus{ID: arg.ID, State: arg.State}), nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("InsertWorkspaceResource", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - check.Args(database.InsertWorkspaceResourceParams{ - ID: uuid.New(), - Transition: database.WorkspaceTransitionStart, - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertWorkspaceResource", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceResourceParams{ID: uuid.New(), Transition: database.WorkspaceTransitionStart} + dbm.EXPECT().InsertWorkspaceResource(gomock.Any(), arg).Return(testutil.Fake(s.T(), faker, database.WorkspaceResource{ID: arg.ID}), nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("DeleteOldWorkspaceAgentLogs", s.Subtest(func(db database.Store, check *expects) { - check.Args(time.Time{}).Asserts(rbac.ResourceSystem, policy.ActionDelete) + s.Run("DeleteOldWorkspaceAgentLogs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := time.Time{} + dbm.EXPECT().DeleteOldWorkspaceAgentLogs(gomock.Any(), t).Return(nil).AnyTimes() + check.Args(t).Asserts(rbac.ResourceSystem, policy.ActionDelete) })) - s.Run("InsertWorkspaceAgentStats", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceAgentStatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate).Errors(errMatchAny) + s.Run("InsertWorkspaceAgentStats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceAgentStatsParams{} + dbm.EXPECT().InsertWorkspaceAgentStats(gomock.Any(), arg).Return(xerrors.New("any error")).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate).Errors(errMatchAny) })) - s.Run("InsertWorkspaceAppStats", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceAppStatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertWorkspaceAppStats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceAppStatsParams{} + dbm.EXPECT().InsertWorkspaceAppStats(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("UpsertWorkspaceAppAuditSession", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - pj := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: pj.ID}) - agent := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agent.ID}) - check.Args(database.UpsertWorkspaceAppAuditSessionParams{ - AgentID: agent.ID, - AppID: app.ID, - UserID: u.ID, - Ip: "127.0.0.1", - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate) - })) - s.Run("InsertWorkspaceAgentScriptTimings", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - check.Args(database.InsertWorkspaceAgentScriptTimingsParams{ - ScriptID: uuid.New(), - Stage: database.WorkspaceAgentScriptTimingStageStart, - Status: database.WorkspaceAgentScriptTimingStatusOk, - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("UpsertWorkspaceAppAuditSession", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + agent := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) + app := testutil.Fake(s.T(), faker, database.WorkspaceApp{}) + arg := database.UpsertWorkspaceAppAuditSessionParams{AgentID: agent.ID, AppID: app.ID, UserID: u.ID, Ip: "127.0.0.1"} + dbm.EXPECT().UpsertWorkspaceAppAuditSession(gomock.Any(), arg).Return(true, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("InsertWorkspaceAgentScripts", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceAgentScriptsParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertWorkspaceAgentScriptTimings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceAgentScriptTimingsParams{ScriptID: uuid.New(), Stage: database.WorkspaceAgentScriptTimingStageStart, Status: database.WorkspaceAgentScriptTimingStatusOk} + dbm.EXPECT().InsertWorkspaceAgentScriptTimings(gomock.Any(), arg).Return(testutil.Fake(s.T(), gofakeit.New(0), database.WorkspaceAgentScriptTiming{ScriptID: arg.ScriptID}), nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("InsertWorkspaceAgentMetadata", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - check.Args(database.InsertWorkspaceAgentMetadataParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertWorkspaceAgentScripts", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceAgentScriptsParams{} + dbm.EXPECT().InsertWorkspaceAgentScripts(gomock.Any(), arg).Return([]database.WorkspaceAgentScript{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("InsertWorkspaceAgentLogs", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceAgentLogsParams{}).Asserts() + s.Run("InsertWorkspaceAgentMetadata", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceAgentMetadataParams{} + dbm.EXPECT().InsertWorkspaceAgentMetadata(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("InsertWorkspaceAgentLogSources", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceAgentLogSourcesParams{}).Asserts() + s.Run("InsertWorkspaceAgentLogs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceAgentLogsParams{} + dbm.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), arg).Return([]database.WorkspaceAgentLog{}, nil).AnyTimes() + check.Args(arg).Asserts() })) - s.Run("GetTemplateDAUs", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.GetTemplateDAUsParams{}).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("InsertWorkspaceAgentLogSources", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceAgentLogSourcesParams{} + dbm.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), arg).Return([]database.WorkspaceAgentLogSource{}, nil).AnyTimes() + check.Args(arg).Asserts() })) - s.Run("GetActiveWorkspaceBuildsByTemplateID", s.Subtest(func(db database.Store, check *expects) { - check.Args(uuid.New()). - Asserts(rbac.ResourceSystem, policy.ActionRead). - ErrorsWithInMemDB(sql.ErrNoRows). - Returns([]database.WorkspaceBuild{}) + s.Run("GetTemplateDAUs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetTemplateDAUsParams{} + dbm.EXPECT().GetTemplateDAUs(gomock.Any(), arg).Return([]database.GetTemplateDAUsRow{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) + s.Run("GetActiveWorkspaceBuildsByTemplateID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + id := uuid.New() + dbm.EXPECT().GetActiveWorkspaceBuildsByTemplateID(gomock.Any(), id).Return([]database.WorkspaceBuild{}, nil).AnyTimes() + check.Args(id).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.WorkspaceBuild{}) })) - s.Run("GetDeploymentDAUs", s.Subtest(func(db database.Store, check *expects) { - check.Args(int32(0)).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetDeploymentDAUs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + tz := int32(0) + dbm.EXPECT().GetDeploymentDAUs(gomock.Any(), tz).Return([]database.GetDeploymentDAUsRow{}, nil).AnyTimes() + check.Args(tz).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetAppSecurityKey", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetAppSecurityKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetAppSecurityKey(gomock.Any()).Return("", sql.ErrNoRows).AnyTimes() check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead).ErrorsWithPG(sql.ErrNoRows) })) - s.Run("UpsertAppSecurityKey", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpsertAppSecurityKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertAppSecurityKey(gomock.Any(), "foo").Return(nil).AnyTimes() check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("GetApplicationName", s.Subtest(func(db database.Store, check *expects) { - db.UpsertApplicationName(context.Background(), "foo") + s.Run("GetApplicationName", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetApplicationName(gomock.Any()).Return("foo", nil).AnyTimes() check.Args().Asserts() })) - s.Run("UpsertApplicationName", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpsertApplicationName", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertApplicationName(gomock.Any(), "").Return(nil).AnyTimes() check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("GetHealthSettings", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetHealthSettings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetHealthSettings(gomock.Any()).Return("{}", nil).AnyTimes() check.Args().Asserts() })) - s.Run("UpsertHealthSettings", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpsertHealthSettings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertHealthSettings(gomock.Any(), "foo").Return(nil).AnyTimes() check.Args("foo").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("GetNotificationsSettings", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetNotificationsSettings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetNotificationsSettings(gomock.Any()).Return("{}", nil).AnyTimes() check.Args().Asserts() })) - s.Run("UpsertNotificationsSettings", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpsertNotificationsSettings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertNotificationsSettings(gomock.Any(), "foo").Return(nil).AnyTimes() check.Args("foo").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("GetDeploymentWorkspaceAgentStats", s.Subtest(func(db database.Store, check *expects) { - check.Args(time.Time{}).Asserts() + s.Run("GetDeploymentWorkspaceAgentStats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := time.Time{} + dbm.EXPECT().GetDeploymentWorkspaceAgentStats(gomock.Any(), t).Return(database.GetDeploymentWorkspaceAgentStatsRow{}, nil).AnyTimes() + check.Args(t).Asserts() })) - s.Run("GetDeploymentWorkspaceAgentUsageStats", s.Subtest(func(db database.Store, check *expects) { - check.Args(time.Time{}).Asserts() + s.Run("GetDeploymentWorkspaceAgentUsageStats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := time.Time{} + dbm.EXPECT().GetDeploymentWorkspaceAgentUsageStats(gomock.Any(), t).Return(database.GetDeploymentWorkspaceAgentUsageStatsRow{}, nil).AnyTimes() + check.Args(t).Asserts() })) - s.Run("GetDeploymentWorkspaceStats", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetDeploymentWorkspaceStats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetDeploymentWorkspaceStats(gomock.Any()).Return(database.GetDeploymentWorkspaceStatsRow{}, nil).AnyTimes() check.Args().Asserts() })) - s.Run("GetFileTemplates", s.Subtest(func(db database.Store, check *expects) { - check.Args(uuid.New()).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetFileTemplates", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + id := uuid.New() + dbm.EXPECT().GetFileTemplates(gomock.Any(), id).Return([]database.GetFileTemplatesRow{}, nil).AnyTimes() + check.Args(id).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetProvisionerJobsToBeReaped", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.GetProvisionerJobsToBeReapedParams{}).Asserts(rbac.ResourceProvisionerJobs, policy.ActionRead) + s.Run("GetProvisionerJobsToBeReaped", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetProvisionerJobsToBeReapedParams{} + dbm.EXPECT().GetProvisionerJobsToBeReaped(gomock.Any(), arg).Return([]database.ProvisionerJob{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceProvisionerJobs, policy.ActionRead) })) - s.Run("UpsertOAuthSigningKey", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpsertOAuthSigningKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertOAuthSigningKey(gomock.Any(), "foo").Return(nil).AnyTimes() check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("GetOAuthSigningKey", s.Subtest(func(db database.Store, check *expects) { - db.UpsertOAuthSigningKey(context.Background(), "foo") + s.Run("GetOAuthSigningKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetOAuthSigningKey(gomock.Any()).Return("foo", nil).AnyTimes() check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("UpsertCoordinatorResumeTokenSigningKey", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpsertCoordinatorResumeTokenSigningKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), "foo").Return(nil).AnyTimes() check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("GetCoordinatorResumeTokenSigningKey", s.Subtest(func(db database.Store, check *expects) { - db.UpsertCoordinatorResumeTokenSigningKey(context.Background(), "foo") + s.Run("GetCoordinatorResumeTokenSigningKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("foo", nil).AnyTimes() check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("InsertMissingGroups", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertMissingGroupsParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate).Errors(errMatchAny) - })) - s.Run("UpdateUserLoginType", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserLoginTypeParams{ - NewLoginType: database.LoginTypePassword, - UserID: u.ID, - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate) - })) - s.Run("GetWorkspaceAgentStatsAndLabels", s.Subtest(func(db database.Store, check *expects) { - check.Args(time.Time{}).Asserts() - })) - s.Run("GetWorkspaceAgentUsageStatsAndLabels", s.Subtest(func(db database.Store, check *expects) { - check.Args(time.Time{}).Asserts() - })) - s.Run("GetWorkspaceAgentStats", s.Subtest(func(db database.Store, check *expects) { - check.Args(time.Time{}).Asserts() - })) - s.Run("GetWorkspaceAgentUsageStats", s.Subtest(func(db database.Store, check *expects) { - check.Args(time.Time{}).Asserts() - })) - s.Run("GetWorkspaceProxyByHostname", s.Subtest(func(db database.Store, check *expects) { - p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{ - WildcardHostname: "*.example.com", - }) - check.Args(database.GetWorkspaceProxyByHostnameParams{ - Hostname: "foo.example.com", - AllowWildcardHostname: true, - }).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(p) - })) - s.Run("GetTemplateAverageBuildTime", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.GetTemplateAverageBuildTimeParams{}).Asserts(rbac.ResourceSystem, policy.ActionRead) - })) - s.Run("GetWorkspacesByTemplateID", s.Subtest(func(db database.Store, check *expects) { - check.Args(uuid.Nil).Asserts(rbac.ResourceSystem, policy.ActionRead) - })) - s.Run("GetWorkspacesEligibleForTransition", s.Subtest(func(db database.Store, check *expects) { - check.Args(time.Time{}).Asserts() + s.Run("InsertMissingGroups", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertMissingGroupsParams{} + dbm.EXPECT().InsertMissingGroups(gomock.Any(), arg).Return([]database.Group{}, xerrors.New("any error")).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate).Errors(errMatchAny) })) - s.Run("InsertTemplateVersionVariable", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - check.Args(database.InsertTemplateVersionVariableParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("UpdateUserLoginType", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserLoginTypeParams{NewLoginType: database.LoginTypePassword, UserID: u.ID} + dbm.EXPECT().UpdateUserLoginType(gomock.Any(), arg).Return(testutil.Fake(s.T(), faker, database.User{}), nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate) + })) + s.Run("GetWorkspaceAgentStatsAndLabels", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := time.Time{} + dbm.EXPECT().GetWorkspaceAgentStatsAndLabels(gomock.Any(), t).Return([]database.GetWorkspaceAgentStatsAndLabelsRow{}, nil).AnyTimes() + check.Args(t).Asserts() + })) + s.Run("GetWorkspaceAgentUsageStatsAndLabels", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := time.Time{} + dbm.EXPECT().GetWorkspaceAgentUsageStatsAndLabels(gomock.Any(), t).Return([]database.GetWorkspaceAgentUsageStatsAndLabelsRow{}, nil).AnyTimes() + check.Args(t).Asserts() + })) + s.Run("GetWorkspaceAgentStats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := time.Time{} + dbm.EXPECT().GetWorkspaceAgentStats(gomock.Any(), t).Return([]database.GetWorkspaceAgentStatsRow{}, nil).AnyTimes() + check.Args(t).Asserts() + })) + s.Run("GetWorkspaceAgentUsageStats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := time.Time{} + dbm.EXPECT().GetWorkspaceAgentUsageStats(gomock.Any(), t).Return([]database.GetWorkspaceAgentUsageStatsRow{}, nil).AnyTimes() + check.Args(t).Asserts() + })) + s.Run("GetWorkspaceProxyByHostname", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + p := testutil.Fake(s.T(), faker, database.WorkspaceProxy{WildcardHostname: "*.example.com"}) + arg := database.GetWorkspaceProxyByHostnameParams{Hostname: "foo.example.com", AllowWildcardHostname: true} + dbm.EXPECT().GetWorkspaceProxyByHostname(gomock.Any(), arg).Return(p, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(p) + })) + s.Run("GetTemplateAverageBuildTime", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetTemplateAverageBuildTimeParams{} + dbm.EXPECT().GetTemplateAverageBuildTime(gomock.Any(), arg).Return(database.GetTemplateAverageBuildTimeRow{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) + s.Run("GetWorkspacesByTemplateID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + id := uuid.Nil + dbm.EXPECT().GetWorkspacesByTemplateID(gomock.Any(), id).Return([]database.WorkspaceTable{}, nil).AnyTimes() + check.Args(id).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) + s.Run("GetWorkspacesEligibleForTransition", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + t := time.Time{} + dbm.EXPECT().GetWorkspacesEligibleForTransition(gomock.Any(), t).Return([]database.GetWorkspacesEligibleForTransitionRow{}, nil).AnyTimes() + check.Args(t).Asserts() + })) + s.Run("InsertTemplateVersionVariable", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertTemplateVersionVariableParams{} + dbm.EXPECT().InsertTemplateVersionVariable(gomock.Any(), arg).Return(testutil.Fake(s.T(), gofakeit.New(0), database.TemplateVersionVariable{}), nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("InsertTemplateVersionWorkspaceTag", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - check.Args(database.InsertTemplateVersionWorkspaceTagParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertTemplateVersionWorkspaceTag", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertTemplateVersionWorkspaceTagParams{} + dbm.EXPECT().InsertTemplateVersionWorkspaceTag(gomock.Any(), arg).Return(testutil.Fake(s.T(), gofakeit.New(0), database.TemplateVersionWorkspaceTag{}), nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("UpdateInactiveUsersToDormant", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.UpdateInactiveUsersToDormantParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate). - ErrorsWithInMemDB(sql.ErrNoRows). - Returns([]database.UpdateInactiveUsersToDormantRow{}) + s.Run("UpdateInactiveUsersToDormant", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.UpdateInactiveUsersToDormantParams{} + dbm.EXPECT().UpdateInactiveUsersToDormant(gomock.Any(), arg).Return([]database.UpdateInactiveUsersToDormantRow{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate).Returns([]database.UpdateInactiveUsersToDormantRow{}) })) - s.Run("GetWorkspaceUniqueOwnerCountByTemplateIDs", s.Subtest(func(db database.Store, check *expects) { - check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetWorkspaceUniqueOwnerCountByTemplateIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ids := []uuid.UUID{uuid.New()} + dbm.EXPECT().GetWorkspaceUniqueOwnerCountByTemplateIDs(gomock.Any(), ids).Return([]database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow{}, nil).AnyTimes() + check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetWorkspaceAgentScriptsByAgentIDs", s.Subtest(func(db database.Store, check *expects) { - check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetWorkspaceAgentScriptsByAgentIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ids := []uuid.UUID{uuid.New()} + dbm.EXPECT().GetWorkspaceAgentScriptsByAgentIDs(gomock.Any(), ids).Return([]database.WorkspaceAgentScript{}, nil).AnyTimes() + check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetWorkspaceAgentLogSourcesByAgentIDs", s.Subtest(func(db database.Store, check *expects) { - check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetWorkspaceAgentLogSourcesByAgentIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + ids := []uuid.UUID{uuid.New()} + dbm.EXPECT().GetWorkspaceAgentLogSourcesByAgentIDs(gomock.Any(), ids).Return([]database.WorkspaceAgentLogSource{}, nil).AnyTimes() + check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetProvisionerJobsByIDsWithQueuePosition", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.GetProvisionerJobsByIDsWithQueuePositionParams{}).Asserts() + s.Run("GetProvisionerJobsByIDsWithQueuePosition", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetProvisionerJobsByIDsWithQueuePositionParams{} + dbm.EXPECT().GetProvisionerJobsByIDsWithQueuePosition(gomock.Any(), arg).Return([]database.GetProvisionerJobsByIDsWithQueuePositionRow{}, nil).AnyTimes() + check.Args(arg).Asserts() })) - s.Run("GetReplicaByID", s.Subtest(func(db database.Store, check *expects) { - check.Args(uuid.New()).Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows) + s.Run("GetReplicaByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + id := uuid.New() + dbm.EXPECT().GetReplicaByID(gomock.Any(), id).Return(database.Replica{}, sql.ErrNoRows).AnyTimes() + check.Args(id).Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows) })) - s.Run("GetWorkspaceAgentAndLatestBuildByAuthToken", s.Subtest(func(db database.Store, check *expects) { - check.Args(uuid.New()).Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows) + s.Run("GetWorkspaceAgentAndLatestBuildByAuthToken", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + tok := uuid.New() + dbm.EXPECT().GetWorkspaceAgentAndLatestBuildByAuthToken(gomock.Any(), tok).Return(database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{}, sql.ErrNoRows).AnyTimes() + check.Args(tok).Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows) })) - s.Run("GetUserLinksByUserID", s.Subtest(func(db database.Store, check *expects) { - check.Args(uuid.New()).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetUserLinksByUserID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + id := uuid.New() + dbm.EXPECT().GetUserLinksByUserID(gomock.Any(), id).Return([]database.UserLink{}, nil).AnyTimes() + check.Args(id).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("DeleteRuntimeConfig", s.Subtest(func(db database.Store, check *expects) { + s.Run("DeleteRuntimeConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().DeleteRuntimeConfig(gomock.Any(), "test").Return(nil).AnyTimes() check.Args("test").Asserts(rbac.ResourceSystem, policy.ActionDelete) })) - s.Run("GetRuntimeConfig", s.Subtest(func(db database.Store, check *expects) { - _ = db.UpsertRuntimeConfig(context.Background(), database.UpsertRuntimeConfigParams{ - Key: "test", - Value: "value", - }) + s.Run("GetRuntimeConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetRuntimeConfig(gomock.Any(), "test").Return("value", nil).AnyTimes() check.Args("test").Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("UpsertRuntimeConfig", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.UpsertRuntimeConfigParams{ - Key: "test", - Value: "value", - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) - })) - s.Run("GetFailedWorkspaceBuildsByTemplateID", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.GetFailedWorkspaceBuildsByTemplateIDParams{ - TemplateID: uuid.New(), - Since: dbtime.Now(), - }).Asserts(rbac.ResourceSystem, policy.ActionRead) - })) - s.Run("GetNotificationReportGeneratorLogByTemplate", s.Subtest(func(db database.Store, check *expects) { - _ = db.UpsertNotificationReportGeneratorLog(context.Background(), database.UpsertNotificationReportGeneratorLogParams{ - NotificationTemplateID: notifications.TemplateWorkspaceBuildsFailedReport, - LastGeneratedAt: dbtime.Now(), - }) - check.Args(notifications.TemplateWorkspaceBuildsFailedReport).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("UpsertRuntimeConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.UpsertRuntimeConfigParams{Key: "test", Value: "value"} + dbm.EXPECT().UpsertRuntimeConfig(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("GetWorkspaceBuildStatsByTemplates", s.Subtest(func(db database.Store, check *expects) { - check.Args(dbtime.Now()).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetFailedWorkspaceBuildsByTemplateID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetFailedWorkspaceBuildsByTemplateIDParams{TemplateID: uuid.New(), Since: dbtime.Now()} + dbm.EXPECT().GetFailedWorkspaceBuildsByTemplateID(gomock.Any(), arg).Return([]database.GetFailedWorkspaceBuildsByTemplateIDRow{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("UpsertNotificationReportGeneratorLog", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.UpsertNotificationReportGeneratorLogParams{ - NotificationTemplateID: uuid.New(), - LastGeneratedAt: dbtime.Now(), - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("GetNotificationReportGeneratorLogByTemplate", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetNotificationReportGeneratorLogByTemplate(gomock.Any(), notifications.TemplateWorkspaceBuildsFailedReport).Return(database.NotificationReportGeneratorLog{}, nil).AnyTimes() + check.Args(notifications.TemplateWorkspaceBuildsFailedReport).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetProvisionerJobTimingsByJobID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - org := dbgen.Organization(s.T(), db, database.Organization{}) - tpl := dbgen.Template(s.T(), db, database.Template{ - OrganizationID: org.ID, - CreatedBy: u.ID, - }) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - OrganizationID: org.ID, - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - CreatedBy: u.ID, - }) - w := dbgen.Workspace(s.T(), db, database.WorkspaceTable{ - OwnerID: u.ID, - OrganizationID: org.ID, - TemplateID: tpl.ID, - }) - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeWorkspaceBuild, - }) - b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID, TemplateVersionID: tv.ID}) - t := dbgen.ProvisionerJobTimings(s.T(), db, b, 2) - check.Args(j.ID).Asserts(w, policy.ActionRead).Returns(t) + s.Run("GetWorkspaceBuildStatsByTemplates", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + at := dbtime.Now() + dbm.EXPECT().GetWorkspaceBuildStatsByTemplates(gomock.Any(), at).Return([]database.GetWorkspaceBuildStatsByTemplatesRow{}, nil).AnyTimes() + check.Args(at).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetWorkspaceAgentScriptTimingsByBuildID", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - workspace := dbgen.Workspace(s.T(), db, database.WorkspaceTable{}) - job := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeWorkspaceBuild, - }) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: job.ID, WorkspaceID: workspace.ID}) - resource := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{ - JobID: build.JobID, - }) - agent := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ - ResourceID: resource.ID, - }) - script := dbgen.WorkspaceAgentScript(s.T(), db, database.WorkspaceAgentScript{ - WorkspaceAgentID: agent.ID, - }) - timing := dbgen.WorkspaceAgentScriptTiming(s.T(), db, database.WorkspaceAgentScriptTiming{ - ScriptID: script.ID, - }) - rows := []database.GetWorkspaceAgentScriptTimingsByBuildIDRow{ - { - StartedAt: timing.StartedAt, - EndedAt: timing.EndedAt, - Stage: timing.Stage, - ScriptID: timing.ScriptID, - ExitCode: timing.ExitCode, - Status: timing.Status, - DisplayName: script.DisplayName, - WorkspaceAgentID: agent.ID, - WorkspaceAgentName: agent.Name, - }, - } - check.Args(build.ID).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(rows) + s.Run("UpsertNotificationReportGeneratorLog", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.UpsertNotificationReportGeneratorLogParams{NotificationTemplateID: uuid.New(), LastGeneratedAt: dbtime.Now()} + dbm.EXPECT().UpsertNotificationReportGeneratorLog(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("DisableForeignKeysAndTriggers", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetProvisionerJobTimingsByJobID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{Type: database.ProvisionerJobTypeWorkspaceBuild}) + b := testutil.Fake(s.T(), faker, database.WorkspaceBuild{JobID: j.ID}) + ws := testutil.Fake(s.T(), faker, database.Workspace{ID: b.WorkspaceID}) + dbm.EXPECT().GetProvisionerJobByID(gomock.Any(), j.ID).Return(j, nil).AnyTimes() + dbm.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), j.ID).Return(b, nil).AnyTimes() + dbm.EXPECT().GetWorkspaceByID(gomock.Any(), b.WorkspaceID).Return(ws, nil).AnyTimes() + dbm.EXPECT().GetProvisionerJobTimingsByJobID(gomock.Any(), j.ID).Return([]database.ProvisionerJobTiming{}, nil).AnyTimes() + check.Args(j.ID).Asserts(ws, policy.ActionRead) + })) + s.Run("GetWorkspaceAgentScriptTimingsByBuildID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{}) + dbm.EXPECT().GetWorkspaceAgentScriptTimingsByBuildID(gomock.Any(), build.ID).Return([]database.GetWorkspaceAgentScriptTimingsByBuildIDRow{}, nil).AnyTimes() + check.Args(build.ID).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.GetWorkspaceAgentScriptTimingsByBuildIDRow{}) + })) + s.Run("DisableForeignKeysAndTriggers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().DisableForeignKeysAndTriggers(gomock.Any()).Return(nil).AnyTimes() check.Args().Asserts() })) - s.Run("InsertWorkspaceModule", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeWorkspaceBuild, - }) - check.Args(database.InsertWorkspaceModuleParams{ - JobID: j.ID, - Transition: database.WorkspaceTransitionStart, - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertWorkspaceModule", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{Type: database.ProvisionerJobTypeWorkspaceBuild}) + arg := database.InsertWorkspaceModuleParams{JobID: j.ID, Transition: database.WorkspaceTransitionStart} + dbm.EXPECT().InsertWorkspaceModule(gomock.Any(), arg).Return(testutil.Fake(s.T(), faker, database.WorkspaceModule{JobID: j.ID}), nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("GetWorkspaceModulesByJobID", s.Subtest(func(db database.Store, check *expects) { - check.Args(uuid.New()).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetWorkspaceModulesByJobID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + id := uuid.New() + dbm.EXPECT().GetWorkspaceModulesByJobID(gomock.Any(), id).Return([]database.WorkspaceModule{}, nil).AnyTimes() + check.Args(id).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetWorkspaceModulesCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - check.Args(dbtime.Now()).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("GetWorkspaceModulesCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + at := dbtime.Now() + dbm.EXPECT().GetWorkspaceModulesCreatedAfter(gomock.Any(), at).Return([]database.WorkspaceModule{}, nil).AnyTimes() + check.Args(at).Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("GetTelemetryItem", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetTelemetryItem", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetTelemetryItem(gomock.Any(), "test").Return(database.TelemetryItem{}, sql.ErrNoRows).AnyTimes() check.Args("test").Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows) })) - s.Run("GetTelemetryItems", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetTelemetryItems", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetTelemetryItems(gomock.Any()).Return([]database.TelemetryItem{}, nil).AnyTimes() check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead) })) - s.Run("InsertTelemetryItemIfNotExists", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertTelemetryItemIfNotExistsParams{ - Key: "test", - Value: "value", - }).Asserts(rbac.ResourceSystem, policy.ActionCreate) + s.Run("InsertTelemetryItemIfNotExists", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertTelemetryItemIfNotExistsParams{Key: "test", Value: "value"} + dbm.EXPECT().InsertTelemetryItemIfNotExists(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionCreate) })) - s.Run("UpsertTelemetryItem", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.UpsertTelemetryItemParams{ - Key: "test", - Value: "value", - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate) + s.Run("UpsertTelemetryItem", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.UpsertTelemetryItemParams{Key: "test", Value: "value"} + dbm.EXPECT().UpsertTelemetryItem(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate) })) - s.Run("GetOAuth2GithubDefaultEligible", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetOAuth2GithubDefaultEligible", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetOAuth2GithubDefaultEligible(gomock.Any()).Return(false, sql.ErrNoRows).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Errors(sql.ErrNoRows) })) - s.Run("UpsertOAuth2GithubDefaultEligible", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpsertOAuth2GithubDefaultEligible", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertOAuth2GithubDefaultEligible(gomock.Any(), true).Return(nil).AnyTimes() check.Args(true).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("GetWebpushVAPIDKeys", s.Subtest(func(db database.Store, check *expects) { - require.NoError(s.T(), db.UpsertWebpushVAPIDKeys(context.Background(), database.UpsertWebpushVAPIDKeysParams{ - VapidPublicKey: "test", - VapidPrivateKey: "test", - })) - check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(database.GetWebpushVAPIDKeysRow{ - VapidPublicKey: "test", - VapidPrivateKey: "test", - }) + s.Run("GetWebpushVAPIDKeys", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetWebpushVAPIDKeys(gomock.Any()).Return(database.GetWebpushVAPIDKeysRow{VapidPublicKey: "test", VapidPrivateKey: "test"}, nil).AnyTimes() + check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(database.GetWebpushVAPIDKeysRow{VapidPublicKey: "test", VapidPrivateKey: "test"}) })) - s.Run("UpsertWebpushVAPIDKeys", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.UpsertWebpushVAPIDKeysParams{ - VapidPublicKey: "test", - VapidPrivateKey: "test", - }).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + s.Run("UpsertWebpushVAPIDKeys", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.UpsertWebpushVAPIDKeysParams{VapidPublicKey: "test", VapidPrivateKey: "test"} + dbm.EXPECT().UpsertWebpushVAPIDKeys(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("Build/GetProvisionerJobByIDForUpdate", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - o := dbgen.Organization(s.T(), db, database.Organization{}) - tpl := dbgen.Template(s.T(), db, database.Template{ - OrganizationID: o.ID, - CreatedBy: u.ID, - }) - w := dbgen.Workspace(s.T(), db, database.WorkspaceTable{ - OwnerID: u.ID, - OrganizationID: o.ID, - TemplateID: tpl.ID, - }) - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeWorkspaceBuild, - }) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - JobID: j.ID, - OrganizationID: o.ID, - CreatedBy: u.ID, - }) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{ - JobID: j.ID, - WorkspaceID: w.ID, - TemplateVersionID: tv.ID, - }) + s.Run("Build/GetProvisionerJobByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{Type: database.ProvisionerJobTypeWorkspaceBuild}) + dbm.EXPECT().GetProvisionerJobByIDForUpdate(gomock.Any(), j.ID).Return(j, nil).AnyTimes() + // Minimal assertion check argument + b := testutil.Fake(s.T(), faker, database.WorkspaceBuild{JobID: j.ID}) + w := testutil.Fake(s.T(), faker, database.Workspace{ID: b.WorkspaceID}) + dbm.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), j.ID).Return(b, nil).AnyTimes() + dbm.EXPECT().GetWorkspaceByID(gomock.Any(), b.WorkspaceID).Return(w, nil).AnyTimes() check.Args(j.ID).Asserts(w, policy.ActionRead).Returns(j) })) - s.Run("TemplateVersion/GetProvisionerJobByIDForUpdate", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeTemplateVersionImport, - }) - tpl := dbgen.Template(s.T(), db, database.Template{}) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - JobID: j.ID, - }) - check.Args(j.ID).Asserts(v.RBACObject(tpl), policy.ActionRead).Returns(j) + s.Run("TemplateVersion/GetProvisionerJobByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{Type: database.ProvisionerJobTypeTemplateVersionImport}) + tpl := testutil.Fake(s.T(), faker, database.Template{}) + tv := testutil.Fake(s.T(), faker, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}}) + dbm.EXPECT().GetProvisionerJobByIDForUpdate(gomock.Any(), j.ID).Return(j, nil).AnyTimes() + dbm.EXPECT().GetTemplateVersionByJobID(gomock.Any(), j.ID).Return(tv, nil).AnyTimes() + dbm.EXPECT().GetTemplateByID(gomock.Any(), tpl.ID).Return(tpl, nil).AnyTimes() + check.Args(j.ID).Asserts(tv.RBACObject(tpl), policy.ActionRead).Returns(j) })) - s.Run("TemplateVersionDryRun/GetProvisionerJobByIDForUpdate", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - tpl := dbgen.Template(s.T(), db, database.Template{}) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - }) - j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - Input: must(json.Marshal(struct { - TemplateVersionID uuid.UUID `json:"template_version_id"` - }{TemplateVersionID: v.ID})), - }) - check.Args(j.ID).Asserts(v.RBACObject(tpl), policy.ActionRead).Returns(j) + s.Run("TemplateVersionDryRun/GetProvisionerJobByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + tpl := testutil.Fake(s.T(), faker, database.Template{}) + tv := testutil.Fake(s.T(), faker, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}}) + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{}) + j.Type = database.ProvisionerJobTypeTemplateVersionDryRun + j.Input = must(json.Marshal(struct { + TemplateVersionID uuid.UUID `json:"template_version_id"` + }{TemplateVersionID: tv.ID})) + dbm.EXPECT().GetProvisionerJobByIDForUpdate(gomock.Any(), j.ID).Return(j, nil).AnyTimes() + dbm.EXPECT().GetTemplateVersionByID(gomock.Any(), tv.ID).Return(tv, nil).AnyTimes() + dbm.EXPECT().GetTemplateByID(gomock.Any(), tpl.ID).Return(tpl, nil).AnyTimes() + check.Args(j.ID).Asserts(tv.RBACObject(tpl), policy.ActionRead).Returns(j) })) } diff --git a/coderd/mcp/mcp_test.go b/coderd/mcp/mcp_test.go index 0c53c899b9830..b7b5a714780d9 100644 --- a/coderd/mcp/mcp_test.go +++ b/coderd/mcp/mcp_test.go @@ -115,7 +115,7 @@ func TestMCPHTTP_ToolRegistration(t *testing.T) { require.Contains(t, err.Error(), "client cannot be nil", "Should reject nil client with appropriate error message") // Test registering tools with valid client should succeed - client := &codersdk.Client{} + client := codersdk.New(testutil.MustURL(t, "http://not-used")) err = server.RegisterTools(client) require.NoError(t, err) diff --git a/codersdk/client.go b/codersdk/client.go index 105c8437f841b..b6f10465e3a07 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -108,8 +108,9 @@ var loggableMimeTypes = map[string]struct{}{ // New creates a Coder client for the provided URL. func New(serverURL *url.URL) *Client { return &Client{ - URL: serverURL, - HTTPClient: &http.Client{}, + URL: serverURL, + HTTPClient: &http.Client{}, + SessionTokenProvider: FixedSessionTokenProvider{}, } } @@ -118,18 +119,14 @@ func New(serverURL *url.URL) *Client { type Client struct { // mu protects the fields sessionToken, logger, and logBodies. These // need to be safe for concurrent access. - mu sync.RWMutex - sessionToken string - logger slog.Logger - logBodies bool + mu sync.RWMutex + SessionTokenProvider SessionTokenProvider + logger slog.Logger + logBodies bool HTTPClient *http.Client URL *url.URL - // SessionTokenHeader is an optional custom header to use for setting tokens. By - // default 'Coder-Session-Token' is used. - SessionTokenHeader string - // PlainLogger may be set to log HTTP traffic in a human-readable form. // It uses the LogBodies option. PlainLogger io.Writer @@ -176,14 +173,20 @@ func (c *Client) SetLogBodies(logBodies bool) { func (c *Client) SessionToken() string { c.mu.RLock() defer c.mu.RUnlock() - return c.sessionToken + return c.SessionTokenProvider.GetSessionToken() } -// SetSessionToken returns the currently set token for the client. +// SetSessionToken sets a fixed token for the client. +// Deprecated: Create a new client instead of changing the token after creation. func (c *Client) SetSessionToken(token string) { + c.SetSessionTokenProvider(FixedSessionTokenProvider{SessionToken: token}) +} + +// SetSessionTokenProvider sets the session token provider for the client. +func (c *Client) SetSessionTokenProvider(provider SessionTokenProvider) { c.mu.Lock() defer c.mu.Unlock() - c.sessionToken = token + c.SessionTokenProvider = provider } func prefixLines(prefix, s []byte) []byte { @@ -199,6 +202,14 @@ func prefixLines(prefix, s []byte) []byte { // Request performs a HTTP request with the body provided. The caller is // responsible for closing the response body. func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) { + opts = append([]RequestOption{c.SessionTokenProvider.AsRequestOption()}, opts...) + return c.RequestWithoutSessionToken(ctx, method, path, body, opts...) +} + +// RequestWithoutSessionToken performs a HTTP request. It is similar to Request, but does not set +// the session token in the request header, nor does it make a call to the SessionTokenProvider. +// This allows session token providers to call this method without causing reentrancy issues. +func (c *Client) RequestWithoutSessionToken(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) { if ctx == nil { return nil, xerrors.Errorf("context should not be nil") } @@ -248,12 +259,6 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac return nil, xerrors.Errorf("create request: %w", err) } - tokenHeader := c.SessionTokenHeader - if tokenHeader == "" { - tokenHeader = SessionTokenHeader - } - req.Header.Set(tokenHeader, c.SessionToken()) - if r != nil { req.Header.Set("Content-Type", "application/json") } @@ -345,20 +350,10 @@ func (c *Client) Dial(ctx context.Context, path string, opts *websocket.DialOpti return nil, err } - tokenHeader := c.SessionTokenHeader - if tokenHeader == "" { - tokenHeader = SessionTokenHeader - } - if opts == nil { opts = &websocket.DialOptions{} } - if opts.HTTPHeader == nil { - opts.HTTPHeader = http.Header{} - } - if opts.HTTPHeader.Get(tokenHeader) == "" { - opts.HTTPHeader.Set(tokenHeader, c.SessionToken()) - } + c.SessionTokenProvider.SetDialOption(opts) conn, resp, err := websocket.Dial(ctx, u.String(), opts) if resp != nil && resp.Body != nil { diff --git a/codersdk/credentials.go b/codersdk/credentials.go new file mode 100644 index 0000000000000..06dc8cc22a114 --- /dev/null +++ b/codersdk/credentials.go @@ -0,0 +1,55 @@ +package codersdk + +import ( + "net/http" + + "github.com/coder/websocket" +) + +// SessionTokenProvider provides the session token to access the Coder service (coderd). +// @typescript-ignore SessionTokenProvider +type SessionTokenProvider interface { + // AsRequestOption returns a request option that attaches the session token to an HTTP request. + AsRequestOption() RequestOption + // SetDialOption sets the session token on a websocket request via DialOptions + SetDialOption(options *websocket.DialOptions) + // GetSessionToken returns the session token as a string. + GetSessionToken() string +} + +// FixedSessionTokenProvider provides a given, fixed, session token. E.g. one read from file or environment variable +// at the program start. +// @typescript-ignore FixedSessionTokenProvider +type FixedSessionTokenProvider struct { + SessionToken string + // SessionTokenHeader is an optional custom header to use for setting tokens. By + // default, 'Coder-Session-Token' is used. + SessionTokenHeader string +} + +func (f FixedSessionTokenProvider) AsRequestOption() RequestOption { + return func(req *http.Request) { + tokenHeader := f.SessionTokenHeader + if tokenHeader == "" { + tokenHeader = SessionTokenHeader + } + req.Header.Set(tokenHeader, f.SessionToken) + } +} + +func (f FixedSessionTokenProvider) GetSessionToken() string { + return f.SessionToken +} + +func (f FixedSessionTokenProvider) SetDialOption(opts *websocket.DialOptions) { + tokenHeader := f.SessionTokenHeader + if tokenHeader == "" { + tokenHeader = SessionTokenHeader + } + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + if opts.HTTPHeader.Get(tokenHeader) == "" { + opts.HTTPHeader.Set(tokenHeader, f.SessionToken) + } +} diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index ddaec06388238..29ddbd1f53094 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -215,12 +215,12 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * options.BlockEndpoints = true } - headers := make(http.Header) - tokenHeader := codersdk.SessionTokenHeader - if c.client.SessionTokenHeader != "" { - tokenHeader = c.client.SessionTokenHeader + wsOptions := &websocket.DialOptions{ + HTTPClient: c.client.HTTPClient, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, } - headers.Set(tokenHeader, c.client.SessionToken()) + c.client.SessionTokenProvider.SetDialOption(wsOptions) // New context, separate from dialCtx. We don't want to cancel the // connection if dialCtx is canceled. @@ -236,12 +236,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * return nil, xerrors.Errorf("parse url: %w", err) } - dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{ - HTTPClient: c.client.HTTPClient, - HTTPHeader: headers, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) + dialer := NewWebsocketDialer(options.Logger, coordinateURL, wsOptions) clk := quartz.NewReal() controller := tailnet.NewController(options.Logger, dialer) controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk) diff --git a/enterprise/coderd/workspaceproxy_test.go b/enterprise/coderd/workspaceproxy_test.go index 7024ad2366423..28d46c0137b0d 100644 --- a/enterprise/coderd/workspaceproxy_test.go +++ b/enterprise/coderd/workspaceproxy_test.go @@ -312,8 +312,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) // Register req := wsproxysdk.RegisterWorkspaceProxyRequest{ @@ -427,8 +426,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) req := wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: "https://proxy.coder.test", @@ -472,8 +470,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) err = proxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{ ReplicaID: uuid.New(), @@ -501,8 +498,7 @@ func TestProxyRegisterDeregister(t *testing.T) { // Register a replica on proxy 2. This shouldn't be returned by replicas // for proxy 1. - proxyClient2 := wsproxysdk.New(client.URL) - proxyClient2.SetSessionToken(createRes2.ProxyToken) + proxyClient2 := wsproxysdk.New(client.URL, createRes2.ProxyToken) _, err = proxyClient2.RegisterWorkspaceProxy(ctx, wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: "https://other.proxy.coder.test", WildcardHostname: "*.other.proxy.coder.test", @@ -516,8 +512,7 @@ func TestProxyRegisterDeregister(t *testing.T) { require.NoError(t, err) // Register replica 1. - proxyClient1 := wsproxysdk.New(client.URL) - proxyClient1.SetSessionToken(createRes1.ProxyToken) + proxyClient1 := wsproxysdk.New(client.URL, createRes1.ProxyToken) req1 := wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: "https://one.proxy.coder.test", WildcardHostname: "*.one.proxy.coder.test", @@ -574,8 +569,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) for i := 0; i < 100; i++ { ok := false @@ -652,8 +646,7 @@ func TestIssueSignedAppToken(t *testing.T) { t.Run("BadAppRequest", func(t *testing.T) { t.Parallel() - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(proxyRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken) ctx := testutil.Context(t, testutil.WaitLong) _, err := proxyClient.IssueSignedAppToken(ctx, workspaceapps.IssueTokenRequest{ @@ -674,8 +667,7 @@ func TestIssueSignedAppToken(t *testing.T) { } t.Run("OK", func(t *testing.T) { t.Parallel() - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(proxyRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken) ctx := testutil.Context(t, testutil.WaitLong) _, err := proxyClient.IssueSignedAppToken(ctx, goodRequest) @@ -684,8 +676,7 @@ func TestIssueSignedAppToken(t *testing.T) { t.Run("OKHTML", func(t *testing.T) { t.Parallel() - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(proxyRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken) rw := httptest.NewRecorder() ctx := testutil.Context(t, testutil.WaitLong) @@ -1032,8 +1023,7 @@ func TestGetCryptoKeys(t *testing.T) { Name: testutil.GetRandomName(t), }) - client := wsproxysdk.New(cclient.URL) - client.SetSessionToken(cclient.SessionToken()) + client := wsproxysdk.New(cclient.URL, cclient.SessionToken()) _, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) require.Error(t, err) diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index c2ac1baf2db4e..6e1da2f25853d 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -163,11 +163,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { return nil, err } - client := wsproxysdk.New(opts.DashboardURL) - err := client.SetSessionToken(opts.ProxySessionToken) - if err != nil { - return nil, xerrors.Errorf("set client token: %w", err) - } + client := wsproxysdk.New(opts.DashboardURL, opts.ProxySessionToken) // Use the configured client if provided. if opts.HTTPClient != nil { diff --git a/enterprise/wsproxy/wsproxy_test.go b/enterprise/wsproxy/wsproxy_test.go index 523d429476243..0e8e61af88995 100644 --- a/enterprise/wsproxy/wsproxy_test.go +++ b/enterprise/wsproxy/wsproxy_test.go @@ -577,8 +577,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) { t.Cleanup(srv.Close) // Register a proxy. - wsproxyClient := wsproxysdk.New(primaryAccessURL) - wsproxyClient.SetSessionToken(token) + wsproxyClient := wsproxysdk.New(primaryAccessURL, token) hostname, err := cryptorand.String(6) require.NoError(t, err) replicaID := uuid.New() @@ -879,8 +878,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) { require.Contains(t, respJSON.Warnings[0], "High availability networking") // Deregister the other replica. - wsproxyClient := wsproxysdk.New(api.AccessURL) - wsproxyClient.SetSessionToken(proxy.Options.ProxySessionToken) + wsproxyClient := wsproxysdk.New(api.AccessURL, proxy.Options.ProxySessionToken) err = wsproxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{ ReplicaID: otherReplicaID, }) diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 72f5a4291c40e..15400a2d33c16 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -33,15 +33,20 @@ type Client struct { // New creates a external proxy client for the provided primary coder server // URL. -func New(serverURL *url.URL) *Client { +func New(serverURL *url.URL, sessionToken string) *Client { sdkClient := codersdk.New(serverURL) - sdkClient.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader - + sdkClient.SessionTokenProvider = codersdk.FixedSessionTokenProvider{ + SessionToken: sessionToken, + SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader, + } sdkClientIgnoreRedirects := codersdk.New(serverURL) sdkClientIgnoreRedirects.HTTPClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse } - sdkClientIgnoreRedirects.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader + sdkClientIgnoreRedirects.SessionTokenProvider = codersdk.FixedSessionTokenProvider{ + SessionToken: sessionToken, + SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader, + } return &Client{ SDKClient: sdkClient, @@ -49,14 +54,6 @@ func New(serverURL *url.URL) *Client { } } -// SetSessionToken sets the session token for the client. An error is returned -// if the session token is not in the correct format for external proxies. -func (c *Client) SetSessionToken(token string) error { - c.SDKClient.SetSessionToken(token) - c.sdkClientIgnoreRedirects.SetSessionToken(token) - return nil -} - // SessionToken returns the currently set token for the client. func (c *Client) SessionToken() string { return c.SDKClient.SessionToken() @@ -506,17 +503,12 @@ func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) { if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } - coordinateHeaders := make(http.Header) - tokenHeader := codersdk.SessionTokenHeader - if c.SDKClient.SessionTokenHeader != "" { - tokenHeader = c.SDKClient.SessionTokenHeader + wsOptions := &websocket.DialOptions{ + HTTPClient: c.SDKClient.HTTPClient, } - coordinateHeaders.Set(tokenHeader, c.SessionToken()) + c.SDKClient.SessionTokenProvider.SetDialOption(wsOptions) - return workspacesdk.NewWebsocketDialer(logger, coordinateURL, &websocket.DialOptions{ - HTTPClient: c.SDKClient.HTTPClient, - HTTPHeader: coordinateHeaders, - }), nil + return workspacesdk.NewWebsocketDialer(logger, coordinateURL, wsOptions), nil } type CryptoKeysResponse struct { diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go index aada23da9dc12..6b4da6831c9bf 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go @@ -60,8 +60,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) { u, err := url.Parse(srv.URL) require.NoError(t, err) - client := wsproxysdk.New(u) - client.SetSessionToken(expectedProxyToken) + client := wsproxysdk.New(u, expectedProxyToken) ctx := testutil.Context(t, testutil.WaitLong) @@ -111,8 +110,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) { u, err := url.Parse(srv.URL) require.NoError(t, err) - client := wsproxysdk.New(u) - _ = client.SetSessionToken(expectedProxyToken) + client := wsproxysdk.New(u, expectedProxyToken) ctx := testutil.Context(t, testutil.WaitLong) diff --git a/go.mod b/go.mod index f111e6e6260d7..dd8109b35bcf0 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ replace github.com/tcnksm/go-httpstat => github.com/coder/go-httpstat v0.0.0-202 // There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here: // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main -replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20250729141742-067f1e5d9716 +replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20250829055706-6eafe0f9199e // This is replaced to include // 1. a fix for a data race: c.f. https://github.com/tailscale/wireguard-go/pull/25 @@ -530,7 +530,7 @@ require ( github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/tmaxmax/go-sse v0.10.0 // indirect - github.com/ulikunitz/xz v0.5.12 // indirect + github.com/ulikunitz/xz v0.5.15 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.36.0 // indirect diff --git a/go.sum b/go.sum index ba73e2228f398..b0ec2563d5dbf 100644 --- a/go.sum +++ b/go.sum @@ -928,8 +928,8 @@ github.com/coder/serpent v0.10.0 h1:ofVk9FJXSek+SmL3yVE3GoArP83M+1tX+H7S4t8BSuM= github.com/coder/serpent v0.10.0/go.mod h1:cZFW6/fP+kE9nd/oRkEHJpG6sXCtQ+AX7WMMEHv0Y3Q= github.com/coder/ssh v0.0.0-20231128192721-70855dedb788 h1:YoUSJ19E8AtuUFVYBpXuOD6a/zVP3rcxezNsoDseTUw= github.com/coder/ssh v0.0.0-20231128192721-70855dedb788/go.mod h1:aGQbuCLyhRLMzZF067xc84Lh7JDs1FKwCmF1Crl9dxQ= -github.com/coder/tailscale v1.1.1-0.20250729141742-067f1e5d9716 h1:hi7o0sA+RPBq8Rvvz+hNrC/OTL2897OKREMIRIuQeTs= -github.com/coder/tailscale v1.1.1-0.20250729141742-067f1e5d9716/go.mod h1:l7ml5uu7lFh5hY28lGYM4b/oFSmuPHYX6uk4RAu23Lc= +github.com/coder/tailscale v1.1.1-0.20250829055706-6eafe0f9199e h1:9RKGKzGLHtTvVBQublzDGtCtal3cXP13diCHoAIGPeI= +github.com/coder/tailscale v1.1.1-0.20250829055706-6eafe0f9199e/go.mod h1:jU9T1vEs+DOs8NtGp1F2PT0/TOGVwtg/JCCKYRgvMOs= github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e h1:JNLPDi2P73laR1oAclY6jWzAbucf70ASAvf5mh2cME0= github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e/go.mod h1:Gz/z9Hbn+4KSp8A2FBtNszfLSdT2Tn/uAKGuVqqWmDI= github.com/coder/terraform-provider-coder/v2 v2.10.0 h1:cGPMfARGHKb80kZsbDX/t/YKwMOwI5zkIyVCQziHR2M= @@ -1828,8 +1828,8 @@ github.com/u-root/u-root v0.14.0/go.mod h1:hAyZorapJe4qzbLWlAkmSVCJGbfoU9Pu4jpJ1 github.com/u-root/uio v0.0.0-20240209044354-b3d14b93376a h1:BH1SOPEvehD2kVrndDnGJiUF0TrBpNs+iyYocu6h0og= github.com/u-root/uio v0.0.0-20240209044354-b3d14b93376a/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= -github.com/ulikunitz/xz v0.5.12 h1:37Nm15o69RwBkXM0J6A5OlE67RZTfzUxTj8fB3dfcsc= -github.com/ulikunitz/xz v0.5.12/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= +github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU= github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= diff --git a/scaletest/workspacetraffic/config.go b/scaletest/workspacetraffic/config.go index 6ef0760ff3013..0948d35ea7dbb 100644 --- a/scaletest/workspacetraffic/config.go +++ b/scaletest/workspacetraffic/config.go @@ -28,6 +28,9 @@ type Config struct { SSH bool `json:"ssh"` + // Ignored unless SSH is true. + DisableDirect bool `json:"ssh_disable_direct"` + // Echo controls whether the agent should echo the data it receives. // If false, the agent will discard the data. Note that setting this // to true will double the amount of data read from the agent for diff --git a/scaletest/workspacetraffic/conn.go b/scaletest/workspacetraffic/conn.go index 7640203e6c224..3b516c6347225 100644 --- a/scaletest/workspacetraffic/conn.go +++ b/scaletest/workspacetraffic/conn.go @@ -6,7 +6,6 @@ import ( "errors" "io" "net" - "net/http" "sync" "time" @@ -144,7 +143,7 @@ func (c *rptyConn) Close() (err error) { } //nolint:revive // Ignore requestPTY control flag. -func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, cmd string, requestPTY bool) (rwc *countReadWriteCloser, err error) { +func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, cmd string, requestPTY bool, blockEndpoints bool) (rwc *countReadWriteCloser, err error) { var closers []func() error defer func() { if err != nil { @@ -156,7 +155,9 @@ func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, } }() - agentConn, err := workspacesdk.New(client).DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{}) + agentConn, err := workspacesdk.New(client).DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{ + BlockEndpoints: blockEndpoints, + }) if err != nil { return nil, xerrors.Errorf("dial workspace agent: %w", err) } @@ -267,18 +268,13 @@ func (w *wrappedSSHConn) Write(p []byte) (n int, err error) { } func appClientConn(ctx context.Context, client *codersdk.Client, url string) (*countReadWriteCloser, error) { - headers := http.Header{} - tokenHeader := codersdk.SessionTokenHeader - if client.SessionTokenHeader != "" { - tokenHeader = client.SessionTokenHeader + wsOptions := &websocket.DialOptions{ + HTTPClient: client.HTTPClient, } - headers.Set(tokenHeader, client.SessionToken()) + client.SessionTokenProvider.SetDialOption(wsOptions) //nolint:bodyclose // The websocket conn manages the body. - conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ - HTTPClient: client.HTTPClient, - HTTPHeader: headers, - }) + conn, _, err := websocket.Dial(ctx, url, wsOptions) if err != nil { return nil, xerrors.Errorf("websocket dial: %w", err) } diff --git a/scaletest/workspacetraffic/run.go b/scaletest/workspacetraffic/run.go index cad6a9d51c6ce..7dd7cb6803695 100644 --- a/scaletest/workspacetraffic/run.go +++ b/scaletest/workspacetraffic/run.go @@ -111,7 +111,7 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error) // If echo is enabled, disable PTY to avoid double echo and // reduce CPU usage. requestPTY := !r.cfg.Echo - conn, err = connectSSH(ctx, r.client, agentID, command, requestPTY) + conn, err = connectSSH(ctx, r.client, agentID, command, requestPTY, r.cfg.DisableDirect) if err != nil { logger.Error(ctx, "connect to workspace agent via ssh", slog.Error(err)) return xerrors.Errorf("connect to workspace via ssh: %w", err) diff --git a/scaletest/workspacetraffic/run_test.go b/scaletest/workspacetraffic/run_test.go index 59801e68d8f62..dd84747886456 100644 --- a/scaletest/workspacetraffic/run_test.go +++ b/scaletest/workspacetraffic/run_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "runtime" "slices" "strings" @@ -313,9 +314,7 @@ func TestRun(t *testing.T) { readMetrics = &testMetrics{} writeMetrics = &testMetrics{} ) - client := &codersdk.Client{ - HTTPClient: &http.Client{}, - } + client := codersdk.New(&url.URL{}) runner := workspacetraffic.NewRunner(client, workspacetraffic.Config{ BytesPerTick: int64(bytesPerTick), TickInterval: tickInterval,