From cda1a3a5938c060d6ef4213552da653e180c0c63 Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Tue, 12 Aug 2025 11:20:12 +0200 Subject: [PATCH 1/3] chore(coderd/database/dbauthz): migrate TestUser to mocked db (#19305) Related to https://github.com/coder/internal/issues/869 --------- Co-authored-by: Steven Masley --- coderd/database/dbauthz/dbauthz_test.go | 596 ++++++++++-------------- 1 file changed, 257 insertions(+), 339 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 0638d3d007986..5c4f3a7d105fd 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1679,315 +1679,253 @@ func (s *MethodTestSuite) TestTemplate() { } func (s *MethodTestSuite) TestUser() { - s.Run("GetAuthorizedUsers", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - dbgen.User(s.T(), db, database.User{}) + s.Run("GetAuthorizedUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetUsersParams{} + dbm.EXPECT().GetAuthorizedUsers(gomock.Any(), arg, gomock.Any()).Return([]database.GetUsersRow{}, nil).AnyTimes() // No asserts because SQLFilter. - check.Args(database.GetUsersParams{}, emptyPreparedAuthorized{}). - Asserts() - })) - s.Run("DeleteAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(u.ID).Asserts(rbac.ResourceApiKey.WithOwner(u.ID.String()), policy.ActionDelete).Returns() - })) - s.Run("GetQuotaAllowanceForUser", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetQuotaAllowanceForUserParams{ - UserID: u.ID, - OrganizationID: uuid.New(), - }).Asserts(u, policy.ActionRead).Returns(int64(0)) + check.Args(arg, emptyPreparedAuthorized{}).Asserts() })) - s.Run("GetQuotaConsumedForUser", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetQuotaConsumedForUserParams{ - OwnerID: u.ID, - OrganizationID: uuid.New(), - }).Asserts(u, policy.ActionRead).Returns(int64(0)) - })) - s.Run("GetUserByEmailOrUsername", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetUserByEmailOrUsernameParams{ - Username: u.Username, - Email: u.Email, - }).Asserts(u, policy.ActionRead).Returns(u) - })) - s.Run("GetUserByID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) + s.Run("DeleteAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + key := testutil.Fake(s.T(), faker, database.APIKey{}) + dbm.EXPECT().DeleteAPIKeysByUserID(gomock.Any(), key.UserID).Return(nil).AnyTimes() + check.Args(key.UserID).Asserts(rbac.ResourceApiKey.WithOwner(key.UserID.String()), policy.ActionDelete).Returns() + })) + s.Run("GetQuotaAllowanceForUser", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.GetQuotaAllowanceForUserParams{UserID: u.ID, OrganizationID: uuid.New()} + dbm.EXPECT().GetQuotaAllowanceForUser(gomock.Any(), arg).Return(int64(0), nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionRead).Returns(int64(0)) + })) + s.Run("GetQuotaConsumedForUser", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.GetQuotaConsumedForUserParams{OwnerID: u.ID, OrganizationID: uuid.New()} + dbm.EXPECT().GetQuotaConsumedForUser(gomock.Any(), arg).Return(int64(0), nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionRead).Returns(int64(0)) + })) + s.Run("GetUserByEmailOrUsername", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.GetUserByEmailOrUsernameParams{Email: u.Email} + dbm.EXPECT().GetUserByEmailOrUsername(gomock.Any(), arg).Return(u, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionRead).Returns(u) + })) + s.Run("GetUserByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() check.Args(u.ID).Asserts(u, policy.ActionRead).Returns(u) })) - s.Run("GetUsersByIDs", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.User(s.T(), db, database.User{CreatedAt: dbtime.Now().Add(-time.Hour)}) - b := dbgen.User(s.T(), db, database.User{CreatedAt: dbtime.Now()}) - check.Args([]uuid.UUID{a.ID, b.ID}). - Asserts(a, policy.ActionRead, b, policy.ActionRead). - Returns(slice.New(a, b)) - })) - s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"}) - dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"}) - check.Args(database.GetUsersParams{}). - // Asserts are done in a SQL filter - Asserts() - })) - s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertUserParams{ - ID: uuid.New(), - LoginType: database.LoginTypePassword, - RBACRoles: []string{}, - }).Asserts(rbac.ResourceAssignRole, policy.ActionAssign, rbac.ResourceUser, policy.ActionCreate) - })) - s.Run("InsertUserLink", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.InsertUserLinkParams{ - UserID: u.ID, - LoginType: database.LoginTypeOIDC, - }).Asserts(u, policy.ActionUpdate) - })) - s.Run("UpdateUserDeletedByID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) + s.Run("GetUsersByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + a := testutil.Fake(s.T(), faker, database.User{CreatedAt: dbtime.Now().Add(-time.Hour)}) + b := testutil.Fake(s.T(), faker, database.User{CreatedAt: dbtime.Now()}) + ids := []uuid.UUID{a.ID, b.ID} + dbm.EXPECT().GetUsersByIDs(gomock.Any(), ids).Return([]database.User{a, b}, nil).AnyTimes() + check.Args(ids).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("GetUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetUsersParams{} + dbm.EXPECT().GetAuthorizedUsers(gomock.Any(), arg, gomock.Any()).Return([]database.GetUsersRow{}, nil).AnyTimes() + // Asserts are done in a SQL filter + check.Args(arg).Asserts() + })) + s.Run("InsertUser", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertUserParams{ID: uuid.New(), LoginType: database.LoginTypePassword, RBACRoles: []string{}} + dbm.EXPECT().InsertUser(gomock.Any(), arg).Return(database.User{ID: arg.ID, LoginType: arg.LoginType}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceAssignRole, policy.ActionAssign, rbac.ResourceUser, policy.ActionCreate) + })) + s.Run("InsertUserLink", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.InsertUserLinkParams{UserID: u.ID, LoginType: database.LoginTypeOIDC} + dbm.EXPECT().InsertUserLink(gomock.Any(), arg).Return(database.UserLink{}, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdate) + })) + s.Run("UpdateUserDeletedByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserDeletedByID(gomock.Any(), u.ID).Return(nil).AnyTimes() check.Args(u.ID).Asserts(u, policy.ActionDelete).Returns() })) - s.Run("UpdateUserGithubComUserID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserGithubComUserIDParams{ - ID: u.ID, - }).Asserts(u, policy.ActionUpdatePersonal) - })) - s.Run("UpdateUserHashedPassword", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserHashedPasswordParams{ - ID: u.ID, - }).Asserts(u, policy.ActionUpdatePersonal).Returns() - })) - s.Run("UpdateUserHashedOneTimePasscode", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserHashedOneTimePasscodeParams{ - ID: u.ID, - HashedOneTimePasscode: []byte{}, - OneTimePasscodeExpiresAt: sql.NullTime{Time: u.CreatedAt, Valid: true}, - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() - })) - s.Run("UpdateUserQuietHoursSchedule", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserQuietHoursScheduleParams{ - ID: u.ID, - }).Asserts(u, policy.ActionUpdatePersonal) - })) - s.Run("UpdateUserLastSeenAt", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserLastSeenAtParams{ - ID: u.ID, - UpdatedAt: u.UpdatedAt, - LastSeenAt: u.LastSeenAt, - }).Asserts(u, policy.ActionUpdate).Returns(u) - })) - s.Run("UpdateUserProfile", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserProfileParams{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Name: u.Name, - UpdatedAt: u.UpdatedAt, - }).Asserts(u, policy.ActionUpdatePersonal).Returns(u) - })) - s.Run("GetUserWorkspaceBuildParameters", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args( - database.GetUserWorkspaceBuildParametersParams{ - OwnerID: u.ID, - TemplateID: uuid.UUID{}, - }, - ).Asserts(u, policy.ActionReadPersonal).Returns( - []database.GetUserWorkspaceBuildParametersRow{}, - ) - })) - s.Run("GetUserThemePreference", s.Subtest(func(db database.Store, check *expects) { - ctx := context.Background() - u := dbgen.User(s.T(), db, database.User{}) - db.UpdateUserThemePreference(ctx, database.UpdateUserThemePreferenceParams{ - UserID: u.ID, - ThemePreference: "light", - }) + s.Run("UpdateUserGithubComUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserGithubComUserIDParams{ID: u.ID} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserGithubComUserID(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal) + })) + s.Run("UpdateUserHashedPassword", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserHashedPasswordParams{ID: u.ID} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserHashedPassword(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns() + })) + s.Run("UpdateUserHashedOneTimePasscode", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserHashedOneTimePasscodeParams{ID: u.ID} + dbm.EXPECT().UpdateUserHashedOneTimePasscode(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() + })) + s.Run("UpdateUserQuietHoursSchedule", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserQuietHoursScheduleParams{ID: u.ID} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserQuietHoursSchedule(gomock.Any(), arg).Return(database.User{}, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal) + })) + s.Run("UpdateUserLastSeenAt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserLastSeenAtParams{ID: u.ID, UpdatedAt: u.UpdatedAt, LastSeenAt: u.LastSeenAt} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserLastSeenAt(gomock.Any(), arg).Return(u, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdate).Returns(u) + })) + s.Run("UpdateUserProfile", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserProfileParams{ID: u.ID, Email: u.Email, Username: u.Username, Name: u.Name, UpdatedAt: u.UpdatedAt} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserProfile(gomock.Any(), arg).Return(u, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(u) + })) + s.Run("GetUserWorkspaceBuildParameters", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.GetUserWorkspaceBuildParametersParams{OwnerID: u.ID, TemplateID: uuid.Nil} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().GetUserWorkspaceBuildParameters(gomock.Any(), arg).Return([]database.GetUserWorkspaceBuildParametersRow{}, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionReadPersonal).Returns([]database.GetUserWorkspaceBuildParametersRow{}) + })) + s.Run("GetUserThemePreference", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().GetUserThemePreference(gomock.Any(), u.ID).Return("light", nil).AnyTimes() check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("light") })) - s.Run("UpdateUserThemePreference", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - uc := database.UserConfig{ - UserID: u.ID, - Key: "theme_preference", - Value: "dark", - } - check.Args(database.UpdateUserThemePreferenceParams{ - UserID: u.ID, - ThemePreference: uc.Value, - }).Asserts(u, policy.ActionUpdatePersonal).Returns(uc) - })) - s.Run("GetUserTerminalFont", s.Subtest(func(db database.Store, check *expects) { - ctx := context.Background() - u := dbgen.User(s.T(), db, database.User{}) - db.UpdateUserTerminalFont(ctx, database.UpdateUserTerminalFontParams{ - UserID: u.ID, - TerminalFont: "ibm-plex-mono", - }) + s.Run("UpdateUserThemePreference", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + uc := database.UserConfig{UserID: u.ID, Key: "theme_preference", Value: "dark"} + arg := database.UpdateUserThemePreferenceParams{UserID: u.ID, ThemePreference: uc.Value} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserThemePreference(gomock.Any(), arg).Return(uc, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(uc) + })) + s.Run("GetUserTerminalFont", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().GetUserTerminalFont(gomock.Any(), u.ID).Return("ibm-plex-mono", nil).AnyTimes() check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("ibm-plex-mono") })) - s.Run("UpdateUserTerminalFont", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - uc := database.UserConfig{ - UserID: u.ID, - Key: "terminal_font", - Value: "ibm-plex-mono", - } - check.Args(database.UpdateUserTerminalFontParams{ - UserID: u.ID, - TerminalFont: uc.Value, - }).Asserts(u, policy.ActionUpdatePersonal).Returns(uc) - })) - s.Run("UpdateUserStatus", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserStatusParams{ - ID: u.ID, - Status: u.Status, - UpdatedAt: u.UpdatedAt, - }).Asserts(u, policy.ActionUpdate).Returns(u) - })) - s.Run("DeleteGitSSHKey", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) + s.Run("UpdateUserTerminalFont", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + uc := database.UserConfig{UserID: u.ID, Key: "terminal_font", Value: "ibm-plex-mono"} + arg := database.UpdateUserTerminalFontParams{UserID: u.ID, TerminalFont: uc.Value} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserTerminalFont(gomock.Any(), arg).Return(uc, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(uc) + })) + s.Run("UpdateUserStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserStatusParams{ID: u.ID, Status: u.Status, UpdatedAt: u.UpdatedAt} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserStatus(gomock.Any(), arg).Return(u, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdate).Returns(u) + })) + s.Run("DeleteGitSSHKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + key := testutil.Fake(s.T(), faker, database.GitSSHKey{}) + dbm.EXPECT().GetGitSSHKey(gomock.Any(), key.UserID).Return(key, nil).AnyTimes() + dbm.EXPECT().DeleteGitSSHKey(gomock.Any(), key.UserID).Return(nil).AnyTimes() check.Args(key.UserID).Asserts(rbac.ResourceUserObject(key.UserID), policy.ActionUpdatePersonal).Returns() })) - s.Run("GetGitSSHKey", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) + s.Run("GetGitSSHKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + key := testutil.Fake(s.T(), faker, database.GitSSHKey{}) + dbm.EXPECT().GetGitSSHKey(gomock.Any(), key.UserID).Return(key, nil).AnyTimes() check.Args(key.UserID).Asserts(rbac.ResourceUserObject(key.UserID), policy.ActionReadPersonal).Returns(key) })) - s.Run("InsertGitSSHKey", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.InsertGitSSHKeyParams{ - UserID: u.ID, - }).Asserts(u, policy.ActionUpdatePersonal) - })) - s.Run("UpdateGitSSHKey", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) - check.Args(database.UpdateGitSSHKeyParams{ - UserID: key.UserID, - UpdatedAt: key.UpdatedAt, - }).Asserts(rbac.ResourceUserObject(key.UserID), policy.ActionUpdatePersonal).Returns(key) - })) - s.Run("GetExternalAuthLink", s.Subtest(func(db database.Store, check *expects) { - link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) - check.Args(database.GetExternalAuthLinkParams{ - ProviderID: link.ProviderID, - UserID: link.UserID, - }).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionReadPersonal).Returns(link) - })) - s.Run("InsertExternalAuthLink", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.InsertExternalAuthLinkParams{ - ProviderID: uuid.NewString(), - UserID: u.ID, - }).Asserts(u, policy.ActionUpdatePersonal) - })) - s.Run("UpdateExternalAuthLinkRefreshToken", s.Subtest(func(db database.Store, check *expects) { - link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) - check.Args(database.UpdateExternalAuthLinkRefreshTokenParams{ - OAuthRefreshToken: "", - OAuthRefreshTokenKeyID: "", - ProviderID: link.ProviderID, - UserID: link.UserID, - UpdatedAt: link.UpdatedAt, - }).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal) - })) - s.Run("UpdateExternalAuthLink", s.Subtest(func(db database.Store, check *expects) { - link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) - check.Args(database.UpdateExternalAuthLinkParams{ - ProviderID: link.ProviderID, - UserID: link.UserID, - OAuthAccessToken: link.OAuthAccessToken, - OAuthRefreshToken: link.OAuthRefreshToken, - OAuthExpiry: link.OAuthExpiry, - UpdatedAt: link.UpdatedAt, - }).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal).Returns(link) - })) - s.Run("UpdateUserLink", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - link := dbgen.UserLink(s.T(), db, database.UserLink{}) - check.Args(database.UpdateUserLinkParams{ - OAuthAccessToken: link.OAuthAccessToken, - OAuthRefreshToken: link.OAuthRefreshToken, - OAuthExpiry: link.OAuthExpiry, - UserID: link.UserID, - LoginType: link.LoginType, - Claims: database.UserLinkClaims{}, - }).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal).Returns(link) - })) - s.Run("UpdateUserRoles", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{RBACRoles: []string{codersdk.RoleTemplateAdmin}}) + s.Run("InsertGitSSHKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.InsertGitSSHKeyParams{UserID: u.ID} + dbm.EXPECT().InsertGitSSHKey(gomock.Any(), arg).Return(database.GitSSHKey{UserID: u.ID}, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal) + })) + s.Run("UpdateGitSSHKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + key := testutil.Fake(s.T(), faker, database.GitSSHKey{}) + arg := database.UpdateGitSSHKeyParams{UserID: key.UserID, UpdatedAt: key.UpdatedAt} + dbm.EXPECT().GetGitSSHKey(gomock.Any(), key.UserID).Return(key, nil).AnyTimes() + dbm.EXPECT().UpdateGitSSHKey(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(key, policy.ActionUpdatePersonal).Returns(key) + })) + s.Run("GetExternalAuthLink", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + link := testutil.Fake(s.T(), faker, database.ExternalAuthLink{}) + arg := database.GetExternalAuthLinkParams{ProviderID: link.ProviderID, UserID: link.UserID} + dbm.EXPECT().GetExternalAuthLink(gomock.Any(), arg).Return(link, nil).AnyTimes() + check.Args(arg).Asserts(link, policy.ActionReadPersonal).Returns(link) + })) + s.Run("InsertExternalAuthLink", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.InsertExternalAuthLinkParams{ProviderID: uuid.NewString(), UserID: u.ID} + dbm.EXPECT().InsertExternalAuthLink(gomock.Any(), arg).Return(database.ExternalAuthLink{}, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal) + })) + s.Run("UpdateExternalAuthLinkRefreshToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + link := testutil.Fake(s.T(), faker, database.ExternalAuthLink{}) + arg := database.UpdateExternalAuthLinkRefreshTokenParams{OAuthRefreshToken: "", OAuthRefreshTokenKeyID: "", ProviderID: link.ProviderID, UserID: link.UserID, UpdatedAt: link.UpdatedAt} + dbm.EXPECT().GetExternalAuthLink(gomock.Any(), database.GetExternalAuthLinkParams{ProviderID: link.ProviderID, UserID: link.UserID}).Return(link, nil).AnyTimes() + dbm.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(link, policy.ActionUpdatePersonal) + })) + s.Run("UpdateExternalAuthLink", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + link := testutil.Fake(s.T(), faker, database.ExternalAuthLink{}) + arg := database.UpdateExternalAuthLinkParams{ProviderID: link.ProviderID, UserID: link.UserID, OAuthAccessToken: link.OAuthAccessToken, OAuthRefreshToken: link.OAuthRefreshToken, OAuthExpiry: link.OAuthExpiry, UpdatedAt: link.UpdatedAt} + dbm.EXPECT().GetExternalAuthLink(gomock.Any(), database.GetExternalAuthLinkParams{ProviderID: link.ProviderID, UserID: link.UserID}).Return(link, nil).AnyTimes() + dbm.EXPECT().UpdateExternalAuthLink(gomock.Any(), arg).Return(link, nil).AnyTimes() + check.Args(arg).Asserts(link, policy.ActionUpdatePersonal).Returns(link) + })) + s.Run("UpdateUserLink", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + link := testutil.Fake(s.T(), faker, database.UserLink{}) + arg := database.UpdateUserLinkParams{OAuthAccessToken: link.OAuthAccessToken, OAuthRefreshToken: link.OAuthRefreshToken, OAuthExpiry: link.OAuthExpiry, UserID: link.UserID, LoginType: link.LoginType, Claims: database.UserLinkClaims{}} + dbm.EXPECT().GetUserLinkByUserIDLoginType(gomock.Any(), database.GetUserLinkByUserIDLoginTypeParams{UserID: link.UserID, LoginType: link.LoginType}).Return(link, nil).AnyTimes() + dbm.EXPECT().UpdateUserLink(gomock.Any(), arg).Return(link, nil).AnyTimes() + check.Args(arg).Asserts(link, policy.ActionUpdatePersonal).Returns(link) + })) + s.Run("UpdateUserRoles", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{RBACRoles: []string{codersdk.RoleTemplateAdmin}}) o := u o.RBACRoles = []string{codersdk.RoleUserAdmin} - check.Args(database.UpdateUserRolesParams{ - GrantedRoles: []string{codersdk.RoleUserAdmin}, - ID: u.ID, - }).Asserts( + arg := database.UpdateUserRolesParams{GrantedRoles: []string{codersdk.RoleUserAdmin}, ID: u.ID} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserRoles(gomock.Any(), arg).Return(o, nil).AnyTimes() + check.Args(arg).Asserts( u, policy.ActionRead, rbac.ResourceAssignRole, policy.ActionAssign, rbac.ResourceAssignRole, policy.ActionUnassign, ).Returns(o) })) - s.Run("AllUserIDs", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.User(s.T(), db, database.User{}) - b := dbgen.User(s.T(), db, database.User{}) + s.Run("AllUserIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + a := testutil.Fake(s.T(), faker, database.User{}) + b := testutil.Fake(s.T(), faker, database.User{}) + dbm.EXPECT().AllUserIDs(gomock.Any(), false).Return([]uuid.UUID{a.ID, b.ID}, nil).AnyTimes() check.Args(false).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(slice.New(a.ID, b.ID)) })) - s.Run("CustomRoles", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.CustomRolesParams{}).Asserts(rbac.ResourceAssignRole, policy.ActionRead).Returns([]database.CustomRole{}) + s.Run("CustomRoles", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.CustomRolesParams{} + dbm.EXPECT().CustomRoles(gomock.Any(), arg).Return([]database.CustomRole{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceAssignRole, policy.ActionRead).Returns([]database.CustomRole{}) })) - s.Run("Organization/DeleteCustomRole", s.Subtest(func(db database.Store, check *expects) { - customRole := dbgen.CustomRole(s.T(), db, database.CustomRole{ - OrganizationID: uuid.NullUUID{ - UUID: uuid.New(), - Valid: true, - }, - }) - check.Args(database.DeleteCustomRoleParams{ - Name: customRole.Name, - OrganizationID: customRole.OrganizationID, - }).Asserts( - rbac.ResourceAssignOrgRole.InOrg(customRole.OrganizationID.UUID), policy.ActionDelete) + s.Run("Organization/DeleteCustomRole", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + orgID := uuid.New() + arg := database.DeleteCustomRoleParams{Name: "role", OrganizationID: uuid.NullUUID{UUID: orgID, Valid: true}} + dbm.EXPECT().DeleteCustomRole(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceAssignOrgRole.InOrg(orgID), policy.ActionDelete) })) - s.Run("Site/DeleteCustomRole", s.Subtest(func(db database.Store, check *expects) { - customRole := dbgen.CustomRole(s.T(), db, database.CustomRole{ - OrganizationID: uuid.NullUUID{ - UUID: uuid.Nil, - Valid: false, - }, - }) - check.Args(database.DeleteCustomRoleParams{ - Name: customRole.Name, - }).Asserts( - // fails immediately, missing organization id - ).Errors(dbauthz.NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")}) + s.Run("Site/DeleteCustomRole", s.Mocked(func(_ *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.DeleteCustomRoleParams{Name: "role"} + check.Args(arg).Asserts().Errors(dbauthz.NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")}) })) - s.Run("Blank/UpdateCustomRole", s.Subtest(func(db database.Store, check *expects) { - dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) - customRole := dbgen.CustomRole(s.T(), db, database.CustomRole{ - OrganizationID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, - }) - // Blank is no perms in the role - check.Args(database.UpdateCustomRoleParams{ - Name: customRole.Name, - DisplayName: "Test Name", - OrganizationID: customRole.OrganizationID, - SitePermissions: nil, - OrgPermissions: nil, - UserPermissions: nil, - }).Asserts(rbac.ResourceAssignOrgRole.InOrg(customRole.OrganizationID.UUID), policy.ActionUpdate) - })) - s.Run("SitePermissions/UpdateCustomRole", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.UpdateCustomRoleParams{ + s.Run("Blank/UpdateCustomRole", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + orgID := uuid.New() + arg := database.UpdateCustomRoleParams{Name: "name", DisplayName: "Test Name", OrganizationID: uuid.NullUUID{UUID: orgID, Valid: true}} + dbm.EXPECT().UpdateCustomRole(gomock.Any(), arg).Return(database.CustomRole{}, nil).AnyTimes() + // Blank perms -> no escalation asserts beyond org role update + check.Args(arg).Asserts(rbac.ResourceAssignOrgRole.InOrg(orgID), policy.ActionUpdate) + })) + s.Run("SitePermissions/UpdateCustomRole", s.Mocked(func(_ *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.UpdateCustomRoleParams{ Name: "", OrganizationID: uuid.NullUUID{UUID: uuid.Nil, Valid: false}, DisplayName: "Test Name", @@ -1998,50 +1936,35 @@ func (s *MethodTestSuite) TestUser() { UserPermissions: db2sdk.List(codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{ codersdk.ResourceWorkspace: {codersdk.ActionRead}, }), convertSDKPerm), - }).Asserts( - // fails immediately, missing organization id - ).Errors(dbauthz.NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")}) + } + check.Args(arg).Asserts().Errors(dbauthz.NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")}) })) - s.Run("OrgPermissions/UpdateCustomRole", s.Subtest(func(db database.Store, check *expects) { + s.Run("OrgPermissions/UpdateCustomRole", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { orgID := uuid.New() - customRole := dbgen.CustomRole(s.T(), db, database.CustomRole{ - OrganizationID: uuid.NullUUID{ - UUID: orgID, - Valid: true, - }, - }) - - check.Args(database.UpdateCustomRoleParams{ - Name: customRole.Name, - DisplayName: "Test Name", - OrganizationID: customRole.OrganizationID, - SitePermissions: nil, + arg := database.UpdateCustomRoleParams{ + Name: "name", + DisplayName: "Test Name", + OrganizationID: uuid.NullUUID{UUID: orgID, Valid: true}, OrgPermissions: db2sdk.List(codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{ codersdk.ResourceTemplate: {codersdk.ActionCreate, codersdk.ActionRead}, }), convertSDKPerm), - UserPermissions: nil, - }).Asserts( - // First check + } + dbm.EXPECT().UpdateCustomRole(gomock.Any(), arg).Return(database.CustomRole{}, nil).AnyTimes() + check.Args(arg).Asserts( rbac.ResourceAssignOrgRole.InOrg(orgID), policy.ActionUpdate, // Escalation checks rbac.ResourceTemplate.InOrg(orgID), policy.ActionCreate, rbac.ResourceTemplate.InOrg(orgID), policy.ActionRead, ) })) - s.Run("Blank/InsertCustomRole", s.Subtest(func(db database.Store, check *expects) { - // Blank is no perms in the role + s.Run("Blank/InsertCustomRole", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { orgID := uuid.New() - check.Args(database.InsertCustomRoleParams{ - Name: "test", - DisplayName: "Test Name", - OrganizationID: uuid.NullUUID{UUID: orgID, Valid: true}, - SitePermissions: nil, - OrgPermissions: nil, - UserPermissions: nil, - }).Asserts(rbac.ResourceAssignOrgRole.InOrg(orgID), policy.ActionCreate) - })) - s.Run("SitePermissions/InsertCustomRole", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertCustomRoleParams{ + arg := database.InsertCustomRoleParams{Name: "test", DisplayName: "Test Name", OrganizationID: uuid.NullUUID{UUID: orgID, Valid: true}} + dbm.EXPECT().InsertCustomRole(gomock.Any(), arg).Return(database.CustomRole{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceAssignOrgRole.InOrg(orgID), policy.ActionCreate) + })) + s.Run("SitePermissions/InsertCustomRole", s.Mocked(func(_ *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertCustomRoleParams{ Name: "test", DisplayName: "Test Name", SitePermissions: db2sdk.List(codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{ @@ -2051,42 +1974,37 @@ func (s *MethodTestSuite) TestUser() { UserPermissions: db2sdk.List(codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{ codersdk.ResourceWorkspace: {codersdk.ActionRead}, }), convertSDKPerm), - }).Asserts( - // fails immediately, missing organization id - ).Errors(dbauthz.NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")}) + } + check.Args(arg).Asserts().Errors(dbauthz.NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")}) })) - s.Run("OrgPermissions/InsertCustomRole", s.Subtest(func(db database.Store, check *expects) { + s.Run("OrgPermissions/InsertCustomRole", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { orgID := uuid.New() - check.Args(database.InsertCustomRoleParams{ - Name: "test", - DisplayName: "Test Name", - OrganizationID: uuid.NullUUID{ - UUID: orgID, - Valid: true, - }, - SitePermissions: nil, + arg := database.InsertCustomRoleParams{ + Name: "test", + DisplayName: "Test Name", + OrganizationID: uuid.NullUUID{UUID: orgID, Valid: true}, OrgPermissions: db2sdk.List(codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{ codersdk.ResourceTemplate: {codersdk.ActionCreate, codersdk.ActionRead}, }), convertSDKPerm), - UserPermissions: nil, - }).Asserts( - // First check + } + dbm.EXPECT().InsertCustomRole(gomock.Any(), arg).Return(database.CustomRole{}, nil).AnyTimes() + check.Args(arg).Asserts( rbac.ResourceAssignOrgRole.InOrg(orgID), policy.ActionCreate, // Escalation checks rbac.ResourceTemplate.InOrg(orgID), policy.ActionCreate, rbac.ResourceTemplate.InOrg(orgID), policy.ActionRead, ) })) - s.Run("GetUserStatusCounts", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.GetUserStatusCountsParams{ - StartTime: time.Now().Add(-time.Hour * 24 * 30), - EndTime: time.Now(), - Interval: int32((time.Hour * 24).Seconds()), - }).Asserts(rbac.ResourceUser, policy.ActionRead) + s.Run("GetUserStatusCounts", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetUserStatusCountsParams{StartTime: time.Now().Add(-time.Hour * 24 * 30), EndTime: time.Now(), Interval: int32((time.Hour * 24).Seconds())} + dbm.EXPECT().GetUserStatusCounts(gomock.Any(), arg).Return([]database.GetUserStatusCountsRow{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceUser, policy.ActionRead) })) - s.Run("ValidateUserIDs", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args([]uuid.UUID{u.ID}).Asserts(rbac.ResourceSystem, policy.ActionRead) + s.Run("ValidateUserIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + ids := []uuid.UUID{u.ID} + dbm.EXPECT().ValidateUserIDs(gomock.Any(), ids).Return(database.ValidateUserIDsRow{}, nil).AnyTimes() + check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead) })) } From f349edcc3cb79bb717b563b11d205c4f630eaabd Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Tue, 12 Aug 2025 11:23:55 +0100 Subject: [PATCH 2/3] refactor: create tasks in coderd instead of frontend (#19280) Instead of creating tasks with a specialized call to `CreateWorkspace` on the frontend, we instead lift this to the backend and allow the frontend to simply call `CreateAITask`. --- coderd/aitasks.go | 110 ++++++++++++++++ coderd/aitasks_test.go | 124 +++++++++++++++++++ coderd/coderd.go | 9 ++ coderd/database/dbauthz/dbauthz.go | 11 ++ coderd/database/dbauthz/dbauthz_test.go | 14 +++ coderd/database/dbmetrics/querymetrics.go | 7 ++ coderd/database/dbmock/dbmock.go | 15 +++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 15 +++ coderd/database/queries/templateversions.sql | 7 ++ codersdk/aitasks.go | 27 ++++ site/src/api/api.ts | 12 ++ site/src/api/typesGenerated.ts | 8 ++ site/src/pages/TasksPage/TasksPage.tsx | 6 +- 14 files changed, 362 insertions(+), 4 deletions(-) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index a982ccc39b26b..e1d72f264a025 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -1,13 +1,20 @@ package coderd import ( + "database/sql" + "errors" "fmt" "net/http" + "slices" "strings" "github.com/google/uuid" + "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" ) @@ -61,3 +68,106 @@ func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) { Prompts: promptsByBuildID, }) } + +// This endpoint is experimental and not guaranteed to be stable, so we're not +// generating public-facing documentation for it. +func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + apiKey = httpmw.APIKey(r) + auditor = api.Auditor.Load() + mems = httpmw.OrganizationMembersParam(r) + ) + + var req codersdk.CreateTaskRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + hasAITask, err := api.Database.GetTemplateVersionHasAITask(ctx, req.TemplateVersionID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) { + httpapi.ResourceNotFound(rw) + return + } + + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching whether the template version has an AI task.", + Detail: err.Error(), + }) + return + } + if !hasAITask { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf(`Template does not have required parameter %q`, codersdk.AITaskPromptParameterName), + }) + return + } + + createReq := codersdk.CreateWorkspaceRequest{ + Name: req.Name, + TemplateVersionID: req.TemplateVersionID, + TemplateVersionPresetID: req.TemplateVersionPresetID, + RichParameterValues: []codersdk.WorkspaceBuildParameter{ + {Name: codersdk.AITaskPromptParameterName, Value: req.Prompt}, + }, + } + + var owner workspaceOwner + if mems.User != nil { + // This user fetch is an optimization path for the most common case of creating a + // task for 'Me'. + // + // This is also required to allow `owners` to create workspaces for users + // that are not in an organization. + owner = workspaceOwner{ + ID: mems.User.ID, + Username: mems.User.Username, + AvatarURL: mems.User.AvatarURL, + } + } else { + // A task can still be created if the caller can read the organization + // member. The organization is required, which can be sourced from the + // template. + // + // TODO: This code gets called twice for each workspace build request. + // This is inefficient and costs at most 2 extra RTTs to the DB. + // This can be optimized. It exists as it is now for code simplicity. + // The most common case is to create a workspace for 'Me'. Which does + // not enter this code branch. + template, ok := requestTemplate(ctx, rw, createReq, api.Database) + if !ok { + return + } + + // If the caller can find the organization membership in the same org + // as the template, then they can continue. + orgIndex := slices.IndexFunc(mems.Memberships, func(mem httpmw.OrganizationMember) bool { + return mem.OrganizationID == template.OrganizationID + }) + if orgIndex == -1 { + httpapi.ResourceNotFound(rw) + return + } + + member := mems.Memberships[orgIndex] + owner = workspaceOwner{ + ID: member.UserID, + Username: member.Username, + AvatarURL: member.AvatarURL, + } + } + + aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{ + Audit: *auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, + AdditionalFields: audit.AdditionalFields{ + WorkspaceOwner: owner.Username, + }, + }) + + defer commitAudit() + createWorkspace(ctx, aReq, apiKey.UserID, api, owner, createReq, rw, r) +} diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index 53f0174d6f03d..8d12dd3a5ec95 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -1,9 +1,11 @@ package coderd_test import ( + "net/http" "testing" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest" @@ -139,3 +141,125 @@ func TestAITasksPrompts(t *testing.T) { require.Empty(t, prompts.Prompts) }) } + +func TestTaskCreate(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + + taskName = "task-foo-bar-baz" + taskPrompt = "Some task prompt" + ) + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Given: A template with an "AI Prompt" parameter + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionApply: echo.ApplyComplete, + ProvisionPlan: []*proto.Response{ + {Type: &proto.Response_Plan{Plan: &proto.PlanComplete{ + Parameters: []*proto.RichParameter{{Name: "AI Prompt", Type: "string"}}, + HasAiTasks: true, + }}}, + }, + }) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + expClient := codersdk.NewExperimentalClient(client) + + // When: We attempt to create a Task. + workspace, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{ + Name: taskName, + TemplateVersionID: template.ActiveVersionID, + Prompt: taskPrompt, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + + // Then: We expect a workspace to have been created. + assert.Equal(t, taskName, workspace.Name) + assert.Equal(t, template.ID, workspace.TemplateID) + + // And: We expect it to have the "AI Prompt" parameter correctly set. + parameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) + require.NoError(t, err) + require.Len(t, parameters, 1) + assert.Equal(t, codersdk.AITaskPromptParameterName, parameters[0].Name) + assert.Equal(t, taskPrompt, parameters[0].Value) + }) + + t.Run("FailsOnNonTaskTemplate", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + + taskName = "task-foo-bar-baz" + taskPrompt = "Some task prompt" + ) + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Given: A template without an "AI Prompt" parameter + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + expClient := codersdk.NewExperimentalClient(client) + + // When: We attempt to create a Task. + _, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{ + Name: taskName, + TemplateVersionID: template.ActiveVersionID, + Prompt: taskPrompt, + }) + + // Then: We expect it to fail. + var sdkErr *codersdk.Error + require.Error(t, err) + require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error") + assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) + + t.Run("FailsOnInvalidTemplate", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + + taskName = "task-foo-bar-baz" + taskPrompt = "Some task prompt" + ) + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Given: A template + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + _ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + expClient := codersdk.NewExperimentalClient(client) + + // When: We attempt to create a Task with an invalid template version ID. + _, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{ + Name: taskName, + TemplateVersionID: uuid.New(), + Prompt: taskPrompt, + }) + + // Then: We expect it to fail. + var sdkErr *codersdk.Error + require.Error(t, err) + require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error") + assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) +} diff --git a/coderd/coderd.go b/coderd/coderd.go index 78ae849fd1894..2aa30c9d7a45c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -995,6 +995,15 @@ func New(options *Options) *API { r.Route("/aitasks", func(r chi.Router) { r.Get("/prompts", api.aiTasksPrompts) }) + r.Route("/tasks", func(r chi.Router) { + r.Use(apiRateLimiter) + + r.Route("/{user}", func(r chi.Router) { + r.Use(httpmw.ExtractOrganizationMembersParam(options.Database, api.HTTPAuth.Authorize)) + + r.Post("/", api.tasksCreate) + }) + }) r.Route("/mcp", func(r chi.Router) { r.Use( httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP), diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a4759c8636097..2cbcf1ec6f0d4 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2863,6 +2863,17 @@ func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg return tv, nil } +func (q *querier) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) { + // If we can successfully call `GetTemplateVersionByID`, then + // we know the actor has sufficient permissions to know if the + // template has an AI task. + if _, err := q.GetTemplateVersionByID(ctx, id); err != nil { + return false, err + } + + return q.db.GetTemplateVersionHasAITask(ctx, id) +} + func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { // An actor can read template version parameters if they can read the related template. tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 5c4f3a7d105fd..a5623fbcbcd36 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1443,6 +1443,20 @@ func (s *MethodTestSuite) TestTemplate() { }) check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), policy.ActionRead) })) + s.Run("GetTemplateVersionHasAITask", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u := dbgen.User(s.T(), db, database.User{}) + t := dbgen.Template(s.T(), db, database.Template{ + OrganizationID: o.ID, + CreatedBy: u.ID, + }) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + OrganizationID: o.ID, + TemplateID: uuid.NullUUID{UUID: t.ID, Valid: true}, + CreatedBy: u.ID, + }) + check.Args(tv.ID).Asserts(t, policy.ActionRead) + })) s.Run("GetTemplatesWithFilter", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) u := dbgen.User(s.T(), db, database.User{}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index cb61df85db007..9bfdbf049ac1a 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1531,6 +1531,13 @@ func (m queryMetricsStore) GetTemplateVersionByTemplateIDAndName(ctx context.Con return version, err } +func (m queryMetricsStore) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) { + start := time.Now() + r0, r1 := m.s.GetTemplateVersionHasAITask(ctx, id) + m.queryLatencies.WithLabelValues("GetTemplateVersionHasAITask").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { start := time.Now() parameters, err := m.s.GetTemplateVersionParameters(ctx, templateVersionID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 0966df6acfba4..934cd434426b2 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -3256,6 +3256,21 @@ func (mr *MockStoreMockRecorder) GetTemplateVersionByTemplateIDAndName(ctx, arg return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionByTemplateIDAndName", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionByTemplateIDAndName), ctx, arg) } +// GetTemplateVersionHasAITask mocks base method. +func (m *MockStore) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTemplateVersionHasAITask", ctx, id) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTemplateVersionHasAITask indicates an expected call of GetTemplateVersionHasAITask. +func (mr *MockStoreMockRecorder) GetTemplateVersionHasAITask(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionHasAITask", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionHasAITask), ctx, id) +} + // GetTemplateVersionParameters mocks base method. func (m *MockStore) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 042bee9d3701b..9c179351b26e3 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -354,6 +354,7 @@ type sqlcQuerier interface { GetTemplateVersionByID(ctx context.Context, id uuid.UUID) (TemplateVersion, error) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (TemplateVersion, error) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg GetTemplateVersionByTemplateIDAndNameParams) (TemplateVersion, error) + GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]TemplateVersionParameter, error) GetTemplateVersionTerraformValues(ctx context.Context, templateVersionID uuid.UUID) (TemplateVersionTerraformValue, error) GetTemplateVersionVariables(ctx context.Context, templateVersionID uuid.UUID) ([]TemplateVersionVariable, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index eef7b8b6819d0..c039b7f94e8d5 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -12870,6 +12870,21 @@ func (q *sqlQuerier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, return i, err } +const getTemplateVersionHasAITask = `-- name: GetTemplateVersionHasAITask :one +SELECT EXISTS ( + SELECT 1 + FROM template_versions + WHERE id = $1 AND has_ai_task = TRUE +) +` + +func (q *sqlQuerier) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) { + row := q.db.QueryRowContext(ctx, getTemplateVersionHasAITask, id) + var exists bool + err := row.Scan(&exists) + return exists, err +} + const getTemplateVersionsByIDs = `-- name: GetTemplateVersionsByIDs :many SELECT id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, has_ai_task, created_by_avatar_url, created_by_username, created_by_name diff --git a/coderd/database/queries/templateversions.sql b/coderd/database/queries/templateversions.sql index 5cf59fab30272..97fb6bd9ecc08 100644 --- a/coderd/database/queries/templateversions.sql +++ b/coderd/database/queries/templateversions.sql @@ -234,3 +234,10 @@ FROM WHERE template_versions.id IN (archived_versions.id) RETURNING template_versions.id; + +-- name: GetTemplateVersionHasAITask :one +SELECT EXISTS ( + SELECT 1 + FROM template_versions + WHERE id = $1 AND has_ai_task = TRUE +); diff --git a/codersdk/aitasks.go b/codersdk/aitasks.go index 89ca9c948f272..49d89bf5e2656 100644 --- a/codersdk/aitasks.go +++ b/codersdk/aitasks.go @@ -3,6 +3,7 @@ package codersdk import ( "context" "encoding/json" + "fmt" "net/http" "strings" @@ -44,3 +45,29 @@ func (c *ExperimentalClient) AITaskPrompts(ctx context.Context, buildIDs []uuid. var prompts AITasksPromptsResponse return prompts, json.NewDecoder(res.Body).Decode(&prompts) } + +type CreateTaskRequest struct { + Name string `json:"name"` + TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid"` + TemplateVersionPresetID uuid.UUID `json:"template_version_preset_id,omitempty" format:"uuid"` + Prompt string `json:"prompt"` +} + +func (c *ExperimentalClient) CreateTask(ctx context.Context, user string, request CreateTaskRequest) (Workspace, error) { + res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/tasks/%s", user), request) + if err != nil { + return Workspace{}, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusCreated { + return Workspace{}, ReadBodyAsError(res) + } + + var workspace Workspace + if err := json.NewDecoder(res.Body).Decode(&workspace); err != nil { + return Workspace{}, err + } + + return workspace, nil +} diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 2b21ddf1e8a08..b9d5f06924519 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -2665,6 +2665,18 @@ class ExperimentalApiMethods { return response.data; }; + + createTask = async ( + user: string, + req: TypesGen.CreateTaskRequest, + ): Promise => { + const response = await this.axios.post( + `/api/experimental/tasks/${user}`, + req, + ); + + return response.data; + }; } // This is a hard coded CSRF token/cookie pair for local development. In prod, diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 52fdb1d6effc4..6f5ab307a2fa8 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -476,6 +476,14 @@ export interface CreateProvisionerKeyResponse { readonly key: string; } +// From codersdk/aitasks.go +export interface CreateTaskRequest { + readonly name: string; + readonly template_version_id: string; + readonly template_version_preset_id?: string; + readonly prompt: string; +} + // From codersdk/organizations.go export interface CreateTemplateRequest { readonly name: string; diff --git a/site/src/pages/TasksPage/TasksPage.tsx b/site/src/pages/TasksPage/TasksPage.tsx index ce6ddea380046..2f6405e796134 100644 --- a/site/src/pages/TasksPage/TasksPage.tsx +++ b/site/src/pages/TasksPage/TasksPage.tsx @@ -741,13 +741,11 @@ export const data = { } } - const workspace = await API.createWorkspace(userId, { + const workspace = await API.experimental.createTask(userId, { name: `task-${generateWorkspaceName()}`, template_version_id: templateVersionId, template_version_preset_id: preset_id || undefined, - rich_parameter_values: [ - { name: AI_PROMPT_PARAMETER_NAME, value: prompt }, - ], + prompt, }); return { From 64f0aaa2b4abd9baddeadef448028701a00fdf25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=B1=E3=82=A4=E3=83=A9?= Date: Tue, 12 Aug 2025 04:38:20 -0600 Subject: [PATCH 3/3] chore: skip flaking classic parameters test (#19308) Causing test-js failures on main --- site/src/pages/WorkspacePage/WorkspacePage.test.tsx | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/site/src/pages/WorkspacePage/WorkspacePage.test.tsx b/site/src/pages/WorkspacePage/WorkspacePage.test.tsx index 288e80127af4a..18a9e1a6f7232 100644 --- a/site/src/pages/WorkspacePage/WorkspacePage.test.tsx +++ b/site/src/pages/WorkspacePage/WorkspacePage.test.tsx @@ -306,7 +306,10 @@ describe("WorkspacePage", () => { }); }); - it("updates the parameters when they are missing during update", async () => { + // Started flaking after upgrading react-router. Tests the old parameters path + // and isn't worth spending more time to fix since this code will be removed + // in a few releases when dynamic parameters takes over the world. + it.skip("updates the parameters when they are missing during update", async () => { // Mocks jest .spyOn(API, "getWorkspaceByOwnerAndName")