diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index f2e141dc74d27..30b8c5d33569c 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3,7 +3,10 @@ package dbauthz import ( "context" "database/sql" + "encoding/json" + "errors" "fmt" + "time" "github.com/google/uuid" "golang.org/x/exp/slices" @@ -14,6 +17,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/util/slice" ) var _ database.Store = (*querier)(nil) @@ -456,3 +460,2050 @@ func prepareSQLFilter(ctx context.Context, authorizer rbac.Authorizer, action rb return authorizer.Prepare(ctx, act, action, resourceType) } + +func (q *querier) Ping(ctx context.Context) (time.Duration, error) { + return q.db.Ping(ctx) +} + +// InTx runs the given function in a transaction. +func (q *querier) InTx(function func(querier database.Store) error, txOpts *sql.TxOptions) error { + return q.db.InTx(func(tx database.Store) error { + // Wrap the transaction store in a querier. + wrapped := New(tx, q.auth, q.log) + return function(wrapped) + }, txOpts) +} + +// authorizeReadFile is a hotfix for the fact that file permissions are +// independent of template permissions. This function checks if the user has +// update access to any of the file's templates. +func (q *querier) authorizeUpdateFileTemplate(ctx context.Context, file database.File) error { + tpls, err := q.db.GetFileTemplates(ctx, file.ID) + if err != nil { + return err + } + // There __should__ only be 1 template per file, but there can be more than + // 1, so check them all. + for _, tpl := range tpls { + // If the user has update access to any template, they have read access to the file. + if err := q.authorizeContext(ctx, rbac.ActionUpdate, tpl); err == nil { + return nil + } + } + + return NotAuthorizedError{ + Err: xerrors.Errorf("not authorized to read file %s", file.ID), + } +} + +func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, removed []string) error { + actor, ok := ActorFromContext(ctx) + if !ok { + return NoActorError + } + + roleAssign := rbac.ResourceRoleAssignment + shouldBeOrgRoles := false + if orgID != nil { + roleAssign = roleAssign.InOrg(*orgID) + shouldBeOrgRoles = true + } + + grantedRoles := append(added, removed...) + // Validate that the roles being assigned are valid. + for _, r := range grantedRoles { + _, isOrgRole := rbac.IsOrgRole(r) + if shouldBeOrgRoles && !isOrgRole { + return xerrors.Errorf("Must only update org roles") + } + if !shouldBeOrgRoles && isOrgRole { + return xerrors.Errorf("Must only update site wide roles") + } + + // All roles should be valid roles + if _, err := rbac.RoleByName(r); err != nil { + return xerrors.Errorf("%q is not a supported role", r) + } + } + + if len(added) > 0 { + if err := q.authorizeContext(ctx, rbac.ActionCreate, roleAssign); err != nil { + return err + } + } + + if len(removed) > 0 { + if err := q.authorizeContext(ctx, rbac.ActionDelete, roleAssign); err != nil { + return err + } + } + + for _, roleName := range grantedRoles { + if !rbac.CanAssignRole(actor.Roles, roleName) { + return xerrors.Errorf("not authorized to assign role %q", roleName) + } + } + + return nil +} + +func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { + // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. + return q.GetTemplatesWithFilter(ctx, arg) +} + +func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) error { + deleteF := func(ctx context.Context, id uuid.UUID) error { + return q.db.UpdateTemplateDeletedByID(ctx, database.UpdateTemplateDeletedByIDParams{ + ID: id, + Deleted: true, + UpdatedAt: database.Now(), + }) + } + return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id) +} + +func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + // An actor is authorized to read template group roles if they are authorized to read the template. + template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateGroupRoles(ctx, id) +} + +func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + // An actor is authorized to query template user roles if they are authorized to read the template. + template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateUserRoles(ctx, id) +} + +func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + return q.db.GetAuthorizedUserCount(ctx, arg, prepared) +} + +func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { + // TODO Implement this with a SQL filter. The count is incorrect without it. + rowUsers, err := q.db.GetUsers(ctx, arg) + if err != nil { + return nil, -1, err + } + + if len(rowUsers) == 0 { + return []database.User{}, 0, nil + } + + act, ok := ActorFromContext(ctx) + if !ok { + return nil, -1, NoActorError + } + + // TODO: Is this correct? Should we return a restricted user? + users := database.ConvertUserRows(rowUsers) + users, err = rbac.Filter(ctx, q.auth, act, rbac.ActionRead, users) + if err != nil { + return nil, -1, err + } + + return users, rowUsers[0].Count, nil +} + +func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error { + deleteF := func(ctx context.Context, id uuid.UUID) error { + return q.db.UpdateUserDeletedByID(ctx, database.UpdateUserDeletedByIDParams{ + ID: id, + Deleted: true, + }) + } + return deleteQ(q.log, q.auth, q.db.GetUserByID, deleteF)(ctx, id) +} + +func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + // TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier. + return q.GetWorkspaces(ctx, arg) +} + +func (q *querier) SoftDeleteWorkspaceByID(ctx context.Context, id uuid.UUID) error { + return deleteQ(q.log, q.auth, q.db.GetWorkspaceByID, func(ctx context.Context, id uuid.UUID) error { + return q.db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ + ID: id, + Deleted: true, + }) + })(ctx, id) +} + +func authorizedTemplateVersionFromJob(ctx context.Context, q *querier, job database.ProvisionerJob) (database.TemplateVersion, error) { + switch job.Type { + case database.ProvisionerJobTypeTemplateVersionDryRun: + // TODO: This is really unfortunate that we need to inspect the json + // payload. We should fix this. + tmp := struct { + TemplateVersionID uuid.UUID `json:"template_version_id"` + }{} + err := json.Unmarshal(job.Input, &tmp) + if err != nil { + return database.TemplateVersion{}, xerrors.Errorf("dry-run unmarshal: %w", err) + } + // Authorized call to get template version. + tv, err := q.GetTemplateVersionByID(ctx, tmp.TemplateVersionID) + if err != nil { + return database.TemplateVersion{}, err + } + return tv, nil + case database.ProvisionerJobTypeTemplateVersionImport: + // Authorized call to get template version. + tv, err := q.GetTemplateVersionByJobID(ctx, job.ID) + if err != nil { + return database.TemplateVersion{}, err + } + return tv, nil + default: + return database.TemplateVersion{}, xerrors.Errorf("unknown job type: %q", job.Type) + } +} + +func (q *querier) AcquireLock(ctx context.Context, id int64) error { + return q.db.AcquireLock(ctx, id) +} + +// TODO: We need to create a ProvisionerJob resource type +func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { + // if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + // return database.ProvisionerJob{}, err + // } + return q.db.AcquireProvisionerJob(ctx, arg) +} + +func (q *querier) DeleteAPIKeyByID(ctx context.Context, id string) error { + return deleteQ(q.log, q.auth, q.db.GetAPIKeyByID, q.db.DeleteAPIKeyByID)(ctx, id) +} + +func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error { + // TODO: This is not 100% correct because it omits apikey IDs. + err := q.authorizeContext(ctx, rbac.ActionDelete, + rbac.ResourceAPIKey.WithOwner(userID.String())) + if err != nil { + return err + } + return q.db.DeleteAPIKeysByUserID(ctx, userID) +} + +func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error { + // TODO: This is not 100% correct because it omits apikey IDs. + err := q.authorizeContext(ctx, rbac.ActionDelete, + rbac.ResourceAPIKey.WithOwner(userID.String())) + if err != nil { + return err + } + return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID) +} + +func (q *querier) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error { + return deleteQ(q.log, q.auth, q.db.GetGitSSHKey, q.db.DeleteGitSSHKey)(ctx, userID) +} + +func (q *querier) DeleteGroupByID(ctx context.Context, id uuid.UUID) error { + return deleteQ(q.log, q.auth, q.db.GetGroupByID, q.db.DeleteGroupByID)(ctx, id) +} + +func (q *querier) DeleteGroupMemberFromGroup(ctx context.Context, arg database.DeleteGroupMemberFromGroupParams) error { + // Deleting a group member counts as updating a group. + fetch := func(ctx context.Context, arg database.DeleteGroupMemberFromGroupParams) (database.Group, error) { + return q.db.GetGroupByID(ctx, arg.GroupID) + } + return update(q.log, q.auth, fetch, q.db.DeleteGroupMemberFromGroup)(ctx, arg) +} + +func (q *querier) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error { + // This will remove the user from all groups in the org. This counts as updating a group. + // NOTE: instead of fetching all groups in the org with arg.UserID as a member, we instead + // check if the caller has permission to update any group in the org. + fetch := func(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) (rbac.Objecter, error) { + return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil + } + return update(q.log, q.auth, fetch, q.db.DeleteGroupMembersByOrgAndUser)(ctx, arg) +} + +func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) { + err := deleteQ(q.log, q.auth, q.db.GetLicenseByID, func(ctx context.Context, id int32) error { + _, err := q.db.DeleteLicense(ctx, id) + return err + })(ctx, id) + if err != nil { + return -1, err + } + return id, nil +} + +func (q *querier) DeleteOldWorkspaceAgentStartupLogs(ctx context.Context) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceSystem); err != nil { + return err + } + return q.db.DeleteOldWorkspaceAgentStartupLogs(ctx) +} + +func (q *querier) DeleteOldWorkspaceAgentStats(ctx context.Context) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceSystem); err != nil { + return err + } + return q.db.DeleteOldWorkspaceAgentStats(ctx) +} + +func (q *querier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceSystem); err != nil { + return err + } + return q.db.DeleteReplicasUpdatedBefore(ctx, updatedAt) +} + +func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { + return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) +} + +func (q *querier) GetAPIKeyByName(ctx context.Context, arg database.GetAPIKeyByNameParams) (database.APIKey, error) { + return fetch(q.log, q.auth, q.db.GetAPIKeyByName)(ctx, arg) +} + +func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) { + return fetchWithPostFilter(q.auth, q.db.GetAPIKeysByLoginType)(ctx, loginType) +} + +func (q *querier) GetAPIKeysByUserID(ctx context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) { + return fetchWithPostFilter(q.auth, q.db.GetAPIKeysByUserID)(ctx, database.GetAPIKeysByUserIDParams{LoginType: params.LoginType, UserID: params.UserID}) +} + +func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) { + return fetchWithPostFilter(q.auth, q.db.GetAPIKeysLastUsedAfter)(ctx, lastUsed) +} + +func (q *querier) GetActiveUserCount(ctx context.Context) (int64, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return 0, err + } + return q.db.GetActiveUserCount(ctx) +} + +func (q *querier) GetAppSecurityKey(ctx context.Context) (string, error) { + // No authz checks + return q.db.GetAppSecurityKey(ctx) +} + +func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) { + // To optimize audit logs, we only check the global audit log permission once. + // This is because we expect a large unbounded set of audit logs, and applying a SQL + // filter would slow down the query for no benefit. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceAuditLog); err != nil { + return nil, err + } + return q.db.GetAuditLogsOffset(ctx, arg) +} + +func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return database.GetAuthorizationUserRolesRow{}, err + } + return q.db.GetAuthorizationUserRoles(ctx, userID) +} + +func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return "", err + } + return q.db.GetDERPMeshKey(ctx) +} + +func (q *querier) GetDefaultProxyConfig(ctx context.Context) (database.GetDefaultProxyConfigRow, error) { + // No authz checks + return q.db.GetDefaultProxyConfig(ctx) +} + +// Only used by metrics cache. +func (q *querier) GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]database.GetDeploymentDAUsRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetDeploymentDAUs(ctx, tzOffset) +} + +func (q *querier) GetDeploymentID(ctx context.Context) (string, error) { + // No authz checks + return q.db.GetDeploymentID(ctx) +} + +func (q *querier) GetDeploymentWorkspaceAgentStats(ctx context.Context, createdAfter time.Time) (database.GetDeploymentWorkspaceAgentStatsRow, error) { + return q.db.GetDeploymentWorkspaceAgentStats(ctx, createdAfter) +} + +func (q *querier) GetDeploymentWorkspaceStats(ctx context.Context) (database.GetDeploymentWorkspaceStatsRow, error) { + return q.db.GetDeploymentWorkspaceStats(ctx) +} + +func (q *querier) GetFileByHashAndCreator(ctx context.Context, arg database.GetFileByHashAndCreatorParams) (database.File, error) { + file, err := q.db.GetFileByHashAndCreator(ctx, arg) + if err != nil { + return database.File{}, err + } + err = q.authorizeContext(ctx, rbac.ActionRead, file) + if err != nil { + // Check the user's access to the file's templates. + if q.authorizeUpdateFileTemplate(ctx, file) != nil { + return database.File{}, err + } + } + + return file, nil +} + +func (q *querier) GetFileByID(ctx context.Context, id uuid.UUID) (database.File, error) { + file, err := q.db.GetFileByID(ctx, id) + if err != nil { + return database.File{}, err + } + err = q.authorizeContext(ctx, rbac.ActionRead, file) + if err != nil { + // Check the user's access to the file's templates. + if q.authorizeUpdateFileTemplate(ctx, file) != nil { + return database.File{}, err + } + } + + return file, nil +} + +func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]database.GetFileTemplatesRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetFileTemplates(ctx, fileID) +} + +func (q *querier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) + if err != nil { + return -1, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + // TODO: This should be the only implementation. + return q.GetAuthorizedUserCount(ctx, arg, prep) +} + +func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { + return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg) +} + +func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { + return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID) +} + +func (q *querier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) { + return fetch(q.log, q.auth, q.db.GetGroupByID)(ctx, id) +} + +func (q *querier) GetGroupByOrgAndName(ctx context.Context, arg database.GetGroupByOrgAndNameParams) (database.Group, error) { + return fetch(q.log, q.auth, q.db.GetGroupByOrgAndName)(ctx, arg) +} + +func (q *querier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]database.User, error) { + if _, err := q.GetGroupByID(ctx, groupID); err != nil { // AuthZ check + return nil, err + } + return q.db.GetGroupMembers(ctx, groupID) +} + +func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]database.Group, error) { + return fetchWithPostFilter(q.auth, q.db.GetGroupsByOrganizationID)(ctx, organizationID) +} + +func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return "", err + } + return q.db.GetLastUpdateCheck(ctx) +} + +func (q *querier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { + if _, err := q.GetWorkspaceByID(ctx, workspaceID); err != nil { + return database.WorkspaceBuild{}, err + } + return q.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) +} + +func (q *querier) GetLatestWorkspaceBuilds(ctx context.Context) ([]database.WorkspaceBuild, error) { + // This function is a system function until we implement a join for workspace builds. + // This is because we need to query for all related workspaces to the returned builds. + // This is a very inefficient method of fetching the latest workspace builds. + // We should just join the rbac properties. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetLatestWorkspaceBuilds(ctx) +} + +func (q *querier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceBuild, error) { + // This is not ideal as not all builds will be returned if the workspace cannot be read. + // This should probably be handled differently? Maybe join workspace builds with workspace + // ownership properties and filter on that. + for _, id := range ids { + _, err := q.GetWorkspaceByID(ctx, id) + if err != nil { + return nil, err + } + } + + return q.db.GetLatestWorkspaceBuildsByWorkspaceIDs(ctx, ids) +} + +func (q *querier) GetLicenseByID(ctx context.Context, id int32) (database.License, error) { + return fetch(q.log, q.auth, q.db.GetLicenseByID)(ctx, id) +} + +func (q *querier) GetLicenses(ctx context.Context) ([]database.License, error) { + fetch := func(ctx context.Context, _ interface{}) ([]database.License, error) { + return q.db.GetLicenses(ctx) + } + return fetchWithPostFilter(q.auth, fetch)(ctx, nil) +} + +func (q *querier) GetLogoURL(ctx context.Context) (string, error) { + // No authz checks + return q.db.GetLogoURL(ctx) +} + +func (q *querier) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) { + return fetch(q.log, q.auth, q.db.GetOrganizationByID)(ctx, id) +} + +func (q *querier) GetOrganizationByName(ctx context.Context, name string) (database.Organization, error) { + return fetch(q.log, q.auth, q.db.GetOrganizationByName)(ctx, name) +} + +func (q *querier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { + // TODO: This should be rewritten to return a list of database.OrganizationMember for consistent RBAC objects. + // Currently this row returns a list of org ids per user, which is challenging to check against the RBAC system. + return fetchWithPostFilter(q.auth, q.db.GetOrganizationIDsByMemberIDs)(ctx, ids) +} + +func (q *querier) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { + return fetch(q.log, q.auth, q.db.GetOrganizationMemberByUserID)(ctx, arg) +} + +func (q *querier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { + return fetchWithPostFilter(q.auth, q.db.GetOrganizationMembershipsByUserID)(ctx, userID) +} + +func (q *querier) GetOrganizations(ctx context.Context) ([]database.Organization, error) { + fetch := func(ctx context.Context, _ interface{}) ([]database.Organization, error) { + return q.db.GetOrganizations(ctx) + } + return fetchWithPostFilter(q.auth, fetch)(ctx, nil) +} + +func (q *querier) GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]database.Organization, error) { + return fetchWithPostFilter(q.auth, q.db.GetOrganizationsByUserID)(ctx, userID) +} + +func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) { + version, err := q.db.GetTemplateVersionByJobID(ctx, jobID) + if err != nil { + return nil, err + } + object := version.RBACObjectNoTemplate() + if version.TemplateID.Valid { + tpl, err := q.db.GetTemplateByID(ctx, version.TemplateID.UUID) + if err != nil { + return nil, err + } + object = version.RBACObject(tpl) + } + + err = q.authorizeContext(ctx, rbac.ActionRead, object) + if err != nil { + return nil, err + } + return q.db.GetParameterSchemasByJobID(ctx, jobID) +} + +func (q *querier) GetPreviousTemplateVersion(ctx context.Context, arg database.GetPreviousTemplateVersionParams) (database.TemplateVersion, error) { + // An actor can read the previous template version if they can read the related template. + // If no linked template exists, we check if the actor can read *a* template. + if !arg.TemplateID.Valid { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(arg.OrganizationID)); err != nil { + return database.TemplateVersion{}, err + } + } + if _, err := q.GetTemplateByID(ctx, arg.TemplateID.UUID); err != nil { + return database.TemplateVersion{}, err + } + return q.db.GetPreviousTemplateVersion(ctx, arg) +} + +func (q *querier) GetProvisionerDaemons(ctx context.Context) ([]database.ProvisionerDaemon, error) { + fetch := func(ctx context.Context, _ interface{}) ([]database.ProvisionerDaemon, error) { + return q.db.GetProvisionerDaemons(ctx) + } + return fetchWithPostFilter(q.auth, fetch)(ctx, nil) +} + +func (q *querier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { + job, err := q.db.GetProvisionerJobByID(ctx, id) + if err != nil { + return database.ProvisionerJob{}, err + } + + switch job.Type { + case database.ProvisionerJobTypeWorkspaceBuild: + // Authorized call to get workspace build. If we can read the build, we + // can read the job. + _, err := q.GetWorkspaceBuildByJobID(ctx, id) + if err != nil { + return database.ProvisionerJob{}, err + } + case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: + // Authorized call to get template version. + _, err := authorizedTemplateVersionFromJob(ctx, q, job) + if err != nil { + return database.ProvisionerJob{}, err + } + default: + return database.ProvisionerJob{}, xerrors.Errorf("unknown job type: %q", job.Type) + } + + return job, nil +} + +// TODO: we need to add a provisioner job resource +func (q *querier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) { + // if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + // return nil, err + // } + return q.db.GetProvisionerJobsByIDs(ctx, ids) +} + +// TODO: We need to create a ProvisionerJob resource type +func (q *querier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ProvisionerJob, error) { + // if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + // return nil, err + // } + return q.db.GetProvisionerJobsCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetProvisionerLogsAfterID(ctx context.Context, arg database.GetProvisionerLogsAfterIDParams) ([]database.ProvisionerJobLog, error) { + // Authorized read on job lets the actor also read the logs. + _, err := q.GetProvisionerJobByID(ctx, arg.JobID) + if err != nil { + return nil, err + } + return q.db.GetProvisionerLogsAfterID(ctx, arg) +} + +func (q *querier) GetQuotaAllowanceForUser(ctx context.Context, userID uuid.UUID) (int64, error) { + err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID)) + if err != nil { + return -1, err + } + return q.db.GetQuotaAllowanceForUser(ctx, userID) +} + +func (q *querier) GetQuotaConsumedForUser(ctx context.Context, userID uuid.UUID) (int64, error) { + err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID)) + if err != nil { + return -1, err + } + return q.db.GetQuotaConsumedForUser(ctx, userID) +} + +func (q *querier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.Replica, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetReplicasUpdatedAfter(ctx, updatedAt) +} + +func (q *querier) GetServiceBanner(ctx context.Context) (string, error) { + // No authz checks + return q.db.GetServiceBanner(ctx) +} + +// Only used by metrics cache. +func (q *querier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return database.GetTemplateAverageBuildTimeRow{}, err + } + return q.db.GetTemplateAverageBuildTime(ctx, arg) +} + +func (q *querier) GetTemplateByID(ctx context.Context, id uuid.UUID) (database.Template, error) { + return fetch(q.log, q.auth, q.db.GetTemplateByID)(ctx, id) +} + +func (q *querier) GetTemplateByOrganizationAndName(ctx context.Context, arg database.GetTemplateByOrganizationAndNameParams) (database.Template, error) { + return fetch(q.log, q.auth, q.db.GetTemplateByOrganizationAndName)(ctx, arg) +} + +// Only used by metrics cache. +func (q *querier) GetTemplateDAUs(ctx context.Context, arg database.GetTemplateDAUsParams) ([]database.GetTemplateDAUsRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetTemplateDAUs(ctx, arg) +} + +func (q *querier) GetTemplateVersionByID(ctx context.Context, tvid uuid.UUID) (database.TemplateVersion, error) { + tv, err := q.db.GetTemplateVersionByID(ctx, tvid) + if err != nil { + return database.TemplateVersion{}, err + } + if !tv.TemplateID.Valid { + // If no linked template exists, check if the actor can read a template in the organization. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil { + return database.TemplateVersion{}, err + } + } else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil { + // An actor can read the template version if they can read the related template. + return database.TemplateVersion{}, err + } + return tv, nil +} + +func (q *querier) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (database.TemplateVersion, error) { + tv, err := q.db.GetTemplateVersionByJobID(ctx, jobID) + if err != nil { + return database.TemplateVersion{}, err + } + if !tv.TemplateID.Valid { + // If no linked template exists, check if the actor can read a template in the organization. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil { + return database.TemplateVersion{}, err + } + } else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil { + // An actor can read the template version if they can read the related template. + return database.TemplateVersion{}, err + } + return tv, nil +} + +func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) { + tv, err := q.db.GetTemplateVersionByTemplateIDAndName(ctx, arg) + if err != nil { + return database.TemplateVersion{}, err + } + if !tv.TemplateID.Valid { + // If no linked template exists, check if the actor can read a template in the organization. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil { + return database.TemplateVersion{}, err + } + } else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil { + // An actor can read the template version if they can read the related template. + return database.TemplateVersion{}, err + } + return tv, nil +} + +func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { + // An actor can read template version parameters if they can read the related template. + tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID) + if err != nil { + return nil, err + } + + var object rbac.Objecter + template, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + object = rbac.ResourceTemplate.InOrg(tv.OrganizationID) + } else { + object = tv.RBACObject(template) + } + + if err := q.authorizeContext(ctx, rbac.ActionRead, object); err != nil { + return nil, err + } + return q.db.GetTemplateVersionParameters(ctx, templateVersionID) +} + +func (q *querier) GetTemplateVersionVariables(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionVariable, error) { + tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID) + if err != nil { + return nil, err + } + + var object rbac.Objecter + template, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + object = rbac.ResourceTemplate.InOrg(tv.OrganizationID) + } else { + object = tv.RBACObject(template) + } + + if err := q.authorizeContext(ctx, rbac.ActionRead, object); err != nil { + return nil, err + } + return q.db.GetTemplateVersionVariables(ctx, templateVersionID) +} + +// GetTemplateVersionsByIDs is only used for workspace build data. +// The workspace is already fetched. +func (q *querier) GetTemplateVersionsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.TemplateVersion, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetTemplateVersionsByIDs(ctx, ids) +} + +func (q *querier) GetTemplateVersionsByTemplateID(ctx context.Context, arg database.GetTemplateVersionsByTemplateIDParams) ([]database.TemplateVersion, error) { + // An actor can read template versions if they can read the related template. + template, err := q.db.GetTemplateByID(ctx, arg.TemplateID) + if err != nil { + return nil, err + } + + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + + return q.db.GetTemplateVersionsByTemplateID(ctx, arg) +} + +func (q *querier) GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.TemplateVersion, error) { + // An actor can read execute this query if they can read all templates. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.All()); err != nil { + return nil, err + } + return q.db.GetTemplateVersionsCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetTemplates(ctx context.Context) ([]database.Template, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetTemplates(ctx) +} + +func (q *querier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceTemplate.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.GetAuthorizedTemplates(ctx, arg, prep) +} + +func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetUnexpiredLicenses(ctx) +} + +func (q *querier) GetUserByEmailOrUsername(ctx context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { + return fetch(q.log, q.auth, q.db.GetUserByEmailOrUsername)(ctx, arg) +} + +func (q *querier) GetUserByID(ctx context.Context, id uuid.UUID) (database.User, error) { + return fetch(q.log, q.auth, q.db.GetUserByID)(ctx, id) +} + +func (q *querier) GetUserCount(ctx context.Context) (int64, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return 0, err + } + return q.db.GetUserCount(ctx) +} + +func (q *querier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return database.UserLink{}, err + } + return q.db.GetUserLinkByLinkedID(ctx, linkedID) +} + +func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return database.UserLink{}, err + } + return q.db.GetUserLinkByUserIDLoginType(ctx, arg) +} + +func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { + // TODO: We should use GetUsersWithCount with a better method signature. + return fetchWithPostFilter(q.auth, q.db.GetUsers)(ctx, arg) +} + +// GetUsersByIDs is only used for usernames on workspace return data. +// This function should be replaced by joining this data to the workspace query +// itself. +func (q *querier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]database.User, error) { + for _, uid := range ids { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(uid)); err != nil { + return nil, err + } + } + return q.db.GetUsersByIDs(ctx, ids) +} + +// GetWorkspaceAgentByAuthToken is used in http middleware to get the workspace agent. +// This should only be used by a system user in that middleware. +func (q *querier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return database.WorkspaceAgent{}, err + } + return q.db.GetWorkspaceAgentByAuthToken(ctx, authToken) +} + +func (q *querier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { + if _, err := q.GetWorkspaceByAgentID(ctx, id); err != nil { + return database.WorkspaceAgent{}, err + } + return q.db.GetWorkspaceAgentByID(ctx, id) +} + +// GetWorkspaceAgentByInstanceID might want to be a system call? Unsure exactly, +// but this will fail. Need to figure out what AuthInstanceID is, and if it +// is essentially an auth token. But the caller using this function is not +// an authenticated user. So this authz check will fail. +func (q *querier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (database.WorkspaceAgent, error) { + agent, err := q.db.GetWorkspaceAgentByInstanceID(ctx, authInstanceID) + if err != nil { + return database.WorkspaceAgent{}, err + } + _, err = q.GetWorkspaceByAgentID(ctx, agent.ID) + if err != nil { + return database.WorkspaceAgent{}, err + } + return agent, nil +} + +func (q *querier) GetWorkspaceAgentMetadata(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentMetadatum, error) { + workspace, err := q.db.GetWorkspaceByAgentID(ctx, workspaceAgentID) + if err != nil { + return nil, err + } + + err = q.authorizeContext(ctx, rbac.ActionRead, workspace) + if err != nil { + return nil, err + } + + return q.db.GetWorkspaceAgentMetadata(ctx, workspaceAgentID) +} + +func (q *querier) GetWorkspaceAgentStartupLogsAfter(ctx context.Context, arg database.GetWorkspaceAgentStartupLogsAfterParams) ([]database.WorkspaceAgentStartupLog, error) { + _, err := q.GetWorkspaceAgentByID(ctx, arg.AgentID) + if err != nil { + return nil, err + } + return q.db.GetWorkspaceAgentStartupLogsAfter(ctx, arg) +} + +func (q *querier) GetWorkspaceAgentStats(ctx context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsRow, error) { + return q.db.GetWorkspaceAgentStats(ctx, createdAfter) +} + +func (q *querier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsAndLabelsRow, error) { + return q.db.GetWorkspaceAgentStatsAndLabels(ctx, createdAfter) +} + +// GetWorkspaceAgentsByResourceIDs +// The workspace/job is already fetched. +func (q *querier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetWorkspaceAgentsByResourceIDs(ctx, ids) +} + +func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetWorkspaceAgentsCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgent, error) { + workspace, err := q.GetWorkspaceByID(ctx, workspaceID) + if err != nil { + return nil, err + } + + return q.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID) +} + +func (q *querier) GetWorkspaceAppByAgentIDAndSlug(ctx context.Context, arg database.GetWorkspaceAppByAgentIDAndSlugParams) (database.WorkspaceApp, error) { + // If we can fetch the workspace, we can fetch the apps. Use the authorized call. + if _, err := q.GetWorkspaceByAgentID(ctx, arg.AgentID); err != nil { + return database.WorkspaceApp{}, err + } + + return q.db.GetWorkspaceAppByAgentIDAndSlug(ctx, arg) +} + +func (q *querier) GetWorkspaceAppsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.WorkspaceApp, error) { + if _, err := q.GetWorkspaceByAgentID(ctx, agentID); err != nil { + return nil, err + } + return q.db.GetWorkspaceAppsByAgentID(ctx, agentID) +} + +// GetWorkspaceAppsByAgentIDs +// The workspace/job is already fetched. +func (q *querier) GetWorkspaceAppsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetWorkspaceAppsByAgentIDs(ctx, ids) +} + +func (q *querier) GetWorkspaceAppsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceApp, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetWorkspaceAppsCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetWorkspaceBuildByID(ctx context.Context, buildID uuid.UUID) (database.WorkspaceBuild, error) { + build, err := q.db.GetWorkspaceBuildByID(ctx, buildID) + if err != nil { + return database.WorkspaceBuild{}, err + } + if _, err := q.GetWorkspaceByID(ctx, build.WorkspaceID); err != nil { + return database.WorkspaceBuild{}, err + } + return build, nil +} + +func (q *querier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UUID) (database.WorkspaceBuild, error) { + build, err := q.db.GetWorkspaceBuildByJobID(ctx, jobID) + if err != nil { + return database.WorkspaceBuild{}, err + } + // Authorized fetch + _, err = q.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return database.WorkspaceBuild{}, err + } + return build, nil +} + +func (q *querier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (database.WorkspaceBuild, error) { + if _, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID); err != nil { + return database.WorkspaceBuild{}, err + } + return q.db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, arg) +} + +func (q *querier) GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) { + // Authorized call to get the workspace build. If we can read the build, + // we can read the params. + _, err := q.GetWorkspaceBuildByID(ctx, workspaceBuildID) + if err != nil { + return nil, err + } + + return q.db.GetWorkspaceBuildParameters(ctx, workspaceBuildID) +} + +func (q *querier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg database.GetWorkspaceBuildsByWorkspaceIDParams) ([]database.WorkspaceBuild, error) { + if _, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID); err != nil { + return nil, err + } + return q.db.GetWorkspaceBuildsByWorkspaceID(ctx, arg) +} + +// Telemetry related functions. These functions are system functions for returning +// telemetry data. Never called by a user. + +func (q *querier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceBuild, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetWorkspaceBuildsCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceByAgentID)(ctx, agentID) +} + +func (q *querier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceByID)(ctx, id) +} + +func (q *querier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceByOwnerIDAndName)(ctx, arg) +} + +func (q *querier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceByWorkspaceAppID)(ctx, workspaceAppID) +} + +func (q *querier) GetWorkspaceProxies(ctx context.Context) ([]database.WorkspaceProxy, error) { + return fetchWithPostFilter(q.auth, func(ctx context.Context, _ interface{}) ([]database.WorkspaceProxy, error) { + return q.db.GetWorkspaceProxies(ctx) + })(ctx, nil) +} + +func (q *querier) GetWorkspaceProxyByHostname(ctx context.Context, params database.GetWorkspaceProxyByHostnameParams) (database.WorkspaceProxy, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return database.WorkspaceProxy{}, err + } + return q.db.GetWorkspaceProxyByHostname(ctx, params) +} + +func (q *querier) GetWorkspaceProxyByID(ctx context.Context, id uuid.UUID) (database.WorkspaceProxy, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceProxyByID)(ctx, id) +} + +func (q *querier) GetWorkspaceProxyByName(ctx context.Context, name string) (database.WorkspaceProxy, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceProxyByName)(ctx, name) +} + +func (q *querier) GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (database.WorkspaceResource, error) { + // TODO: Optimize this + resource, err := q.db.GetWorkspaceResourceByID(ctx, id) + if err != nil { + return database.WorkspaceResource{}, err + } + + _, err = q.GetProvisionerJobByID(ctx, resource.JobID) + if err != nil { + return database.WorkspaceResource{}, err + } + + return resource, nil +} + +// GetWorkspaceResourceMetadataByResourceIDs is only used for build data. +// The workspace/job is already fetched. +func (q *querier) GetWorkspaceResourceMetadataByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceResourceMetadatum, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetWorkspaceResourceMetadataByResourceIDs(ctx, ids) +} + +func (q *querier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResourceMetadatum, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetWorkspaceResourceMetadataCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { + job, err := q.db.GetProvisionerJobByID(ctx, jobID) + if err != nil { + return nil, err + } + var obj rbac.Objecter + switch job.Type { + case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: + // We don't need to do an authorized check, but this helper function + // handles the job type for us. + // TODO: Do not duplicate auth checks. + tv, err := authorizedTemplateVersionFromJob(ctx, q, job) + if err != nil { + return nil, err + } + if !tv.TemplateID.Valid { + // Orphaned template version + obj = tv.RBACObjectNoTemplate() + } else { + template, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) + if err != nil { + return nil, err + } + obj = template.RBACObject() + } + case database.ProvisionerJobTypeWorkspaceBuild: + build, err := q.db.GetWorkspaceBuildByJobID(ctx, jobID) + if err != nil { + return nil, err + } + workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return nil, err + } + obj = workspace + default: + return nil, xerrors.Errorf("unknown job type: %s", job.Type) + } + + if err := q.authorizeContext(ctx, rbac.ActionRead, obj); err != nil { + return nil, err + } + return q.db.GetWorkspaceResourcesByJobID(ctx, jobID) +} + +// GetWorkspaceResourcesByJobIDs is only used for workspace build data. +// The workspace is already fetched. +// TODO: Find a way to replace this with proper authz. +func (q *querier) GetWorkspaceResourcesByJobIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceResource, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetWorkspaceResourcesByJobIDs(ctx, ids) +} + +func (q *querier) GetWorkspaceResourcesCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResource, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetWorkspaceResourcesCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.GetWorkspacesRow, error) { + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceWorkspace.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.GetAuthorizedWorkspaces(ctx, arg, prep) +} + +func (q *querier) GetWorkspacesEligibleForAutoStartStop(ctx context.Context, now time.Time) ([]database.Workspace, error) { + return q.db.GetWorkspacesEligibleForAutoStartStop(ctx, now) +} + +func (q *querier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { + return insert(q.log, q.auth, + rbac.ResourceAPIKey.WithOwner(arg.UserID.String()), + q.db.InsertAPIKey)(ctx, arg) +} + +func (q *querier) InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (database.Group, error) { + // This method creates a new group. + return insert(q.log, q.auth, rbac.ResourceGroup.InOrg(organizationID), q.db.InsertAllUsersGroup)(ctx, organizationID) +} + +func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) { + return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg) +} + +func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return err + } + return q.db.InsertDERPMeshKey(ctx, value) +} + +func (q *querier) InsertDeploymentID(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return err + } + return q.db.InsertDeploymentID(ctx, value) +} + +func (q *querier) InsertFile(ctx context.Context, arg database.InsertFileParams) (database.File, error) { + return insert(q.log, q.auth, rbac.ResourceFile.WithOwner(arg.CreatedBy.String()), q.db.InsertFile)(ctx, arg) +} + +func (q *querier) InsertGitAuthLink(ctx context.Context, arg database.InsertGitAuthLinkParams) (database.GitAuthLink, error) { + return insert(q.log, q.auth, rbac.ResourceUserData.WithOwner(arg.UserID.String()).WithID(arg.UserID), q.db.InsertGitAuthLink)(ctx, arg) +} + +func (q *querier) InsertGitSSHKey(ctx context.Context, arg database.InsertGitSSHKeyParams) (database.GitSSHKey, error) { + return insert(q.log, q.auth, rbac.ResourceUserData.WithOwner(arg.UserID.String()).WithID(arg.UserID), q.db.InsertGitSSHKey)(ctx, arg) +} + +func (q *querier) InsertGroup(ctx context.Context, arg database.InsertGroupParams) (database.Group, error) { + return insert(q.log, q.auth, rbac.ResourceGroup.InOrg(arg.OrganizationID), q.db.InsertGroup)(ctx, arg) +} + +func (q *querier) InsertGroupMember(ctx context.Context, arg database.InsertGroupMemberParams) error { + fetch := func(ctx context.Context, arg database.InsertGroupMemberParams) (database.Group, error) { + return q.db.GetGroupByID(ctx, arg.GroupID) + } + return update(q.log, q.auth, fetch, q.db.InsertGroupMember)(ctx, arg) +} + +func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseParams) (database.License, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceLicense); err != nil { + return database.License{}, err + } + return q.db.InsertLicense(ctx, arg) +} + +func (q *querier) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { + return insert(q.log, q.auth, rbac.ResourceOrganization, q.db.InsertOrganization)(ctx, arg) +} + +func (q *querier) InsertOrganizationMember(ctx context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { + // All roles are added roles. Org member is always implied. + addedRoles := append(arg.Roles, rbac.RoleOrgMember(arg.OrganizationID)) + err := q.canAssignRoles(ctx, &arg.OrganizationID, addedRoles, []string{}) + if err != nil { + return database.OrganizationMember{}, err + } + + obj := rbac.ResourceOrganizationMember.InOrg(arg.OrganizationID).WithID(arg.UserID) + return insert(q.log, q.auth, obj, q.db.InsertOrganizationMember)(ctx, arg) +} + +// TODO: We need to create a ProvisionerDaemon resource type +func (q *querier) InsertProvisionerDaemon(ctx context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { + // if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + // return database.ProvisionerDaemon{}, err + // } + return q.db.InsertProvisionerDaemon(ctx, arg) +} + +// TODO: We need to create a ProvisionerJob resource type +func (q *querier) InsertProvisionerJob(ctx context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { + // if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + // return database.ProvisionerJob{}, err + // } + return q.db.InsertProvisionerJob(ctx, arg) +} + +// TODO: We need to create a ProvisionerJob resource type +func (q *querier) InsertProvisionerJobLogs(ctx context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) { + // if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + // return nil, err + // } + return q.db.InsertProvisionerJobLogs(ctx, arg) +} + +func (q *querier) InsertReplica(ctx context.Context, arg database.InsertReplicaParams) (database.Replica, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return database.Replica{}, err + } + return q.db.InsertReplica(ctx, arg) +} + +func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) (database.Template, error) { + obj := rbac.ResourceTemplate.InOrg(arg.OrganizationID) + return insert(q.log, q.auth, obj, q.db.InsertTemplate)(ctx, arg) +} + +func (q *querier) InsertTemplateVersion(ctx context.Context, arg database.InsertTemplateVersionParams) (database.TemplateVersion, error) { + if !arg.TemplateID.Valid { + // Making a new template version is the same permission as creating a new template. + err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceTemplate.InOrg(arg.OrganizationID)) + if err != nil { + return database.TemplateVersion{}, err + } + } else { + // Must do an authorized fetch to prevent leaking template ids this way. + tpl, err := q.GetTemplateByID(ctx, arg.TemplateID.UUID) + if err != nil { + return database.TemplateVersion{}, err + } + // Check the create permission on the template. + err = q.authorizeContext(ctx, rbac.ActionCreate, tpl) + if err != nil { + return database.TemplateVersion{}, err + } + } + + return q.db.InsertTemplateVersion(ctx, arg) +} + +func (q *querier) InsertTemplateVersionParameter(ctx context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return database.TemplateVersionParameter{}, err + } + return q.db.InsertTemplateVersionParameter(ctx, arg) +} + +func (q *querier) InsertTemplateVersionVariable(ctx context.Context, arg database.InsertTemplateVersionVariableParams) (database.TemplateVersionVariable, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return database.TemplateVersionVariable{}, err + } + return q.db.InsertTemplateVersionVariable(ctx, arg) +} + +func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams) (database.User, error) { + // Always check if the assigned roles can actually be assigned by this actor. + impliedRoles := append([]string{rbac.RoleMember()}, arg.RBACRoles...) + err := q.canAssignRoles(ctx, nil, impliedRoles, []string{}) + if err != nil { + return database.User{}, err + } + obj := rbac.ResourceUser + return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg) +} + +func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { + // This will add the user to all named groups. This counts as updating a group. + // NOTE: instead of checking if the user has permission to update each group, we instead + // check if the user has permission to update *a* group in the org. + fetch := func(ctx context.Context, arg database.InsertUserGroupsByNameParams) (rbac.Objecter, error) { + return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil + } + return update(q.log, q.auth, fetch, q.db.InsertUserGroupsByName)(ctx, arg) +} + +// TODO: Should this be in system.go? +func (q *querier) InsertUserLink(ctx context.Context, arg database.InsertUserLinkParams) (database.UserLink, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceUser.WithID(arg.UserID)); err != nil { + return database.UserLink{}, err + } + return q.db.InsertUserLink(ctx, arg) +} + +func (q *querier) InsertWorkspace(ctx context.Context, arg database.InsertWorkspaceParams) (database.Workspace, error) { + obj := rbac.ResourceWorkspace.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID) + return insert(q.log, q.auth, obj, q.db.InsertWorkspace)(ctx, arg) +} + +// Provisionerd server functions + +func (q *querier) InsertWorkspaceAgent(ctx context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return database.WorkspaceAgent{}, err + } + return q.db.InsertWorkspaceAgent(ctx, arg) +} + +func (q *querier) InsertWorkspaceAgentMetadata(ctx context.Context, arg database.InsertWorkspaceAgentMetadataParams) error { + // We don't check for workspace ownership here since the agent metadata may + // be associated with an orphaned agent used by a dry run build. + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return err + } + + return q.db.InsertWorkspaceAgentMetadata(ctx, arg) +} + +func (q *querier) InsertWorkspaceAgentStartupLogs(ctx context.Context, arg database.InsertWorkspaceAgentStartupLogsParams) ([]database.WorkspaceAgentStartupLog, error) { + return q.db.InsertWorkspaceAgentStartupLogs(ctx, arg) +} + +func (q *querier) InsertWorkspaceAgentStat(ctx context.Context, arg database.InsertWorkspaceAgentStatParams) (database.WorkspaceAgentStat, error) { + // TODO: This is a workspace agent operation. Should users be able to query this? + // Not really sure what this is for. + workspace, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID) + if err != nil { + return database.WorkspaceAgentStat{}, err + } + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) + if err != nil { + return database.WorkspaceAgentStat{}, err + } + return q.db.InsertWorkspaceAgentStat(ctx, arg) +} + +func (q *querier) InsertWorkspaceApp(ctx context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return database.WorkspaceApp{}, err + } + return q.db.InsertWorkspaceApp(ctx, arg) +} + +func (q *querier) InsertWorkspaceBuild(ctx context.Context, arg database.InsertWorkspaceBuildParams) (database.WorkspaceBuild, error) { + w, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID) + if err != nil { + return database.WorkspaceBuild{}, err + } + + var action rbac.Action = rbac.ActionUpdate + if arg.Transition == database.WorkspaceTransitionDelete { + action = rbac.ActionDelete + } + + if err = q.authorizeContext(ctx, action, w); err != nil { + return database.WorkspaceBuild{}, err + } + + return q.db.InsertWorkspaceBuild(ctx, arg) +} + +func (q *querier) InsertWorkspaceBuildParameters(ctx context.Context, arg database.InsertWorkspaceBuildParametersParams) error { + // TODO: Optimize this. We always have the workspace and build already fetched. + build, err := q.db.GetWorkspaceBuildByID(ctx, arg.WorkspaceBuildID) + if err != nil { + return err + } + + workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return err + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) + if err != nil { + return err + } + + return q.db.InsertWorkspaceBuildParameters(ctx, arg) +} + +func (q *querier) InsertWorkspaceProxy(ctx context.Context, arg database.InsertWorkspaceProxyParams) (database.WorkspaceProxy, error) { + return insert(q.log, q.auth, rbac.ResourceWorkspaceProxy, q.db.InsertWorkspaceProxy)(ctx, arg) +} + +func (q *querier) InsertWorkspaceResource(ctx context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return database.WorkspaceResource{}, err + } + return q.db.InsertWorkspaceResource(ctx, arg) +} + +func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.InsertWorkspaceResourceMetadata(ctx, arg) +} + +func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { + fetch := func(ctx context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { + return q.db.GetWorkspaceProxyByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg) +} + +func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) { + return q.db.TryAcquireLock(ctx, id) +} + +func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error { + fetch := func(ctx context.Context, arg database.UpdateAPIKeyByIDParams) (database.APIKey, error) { + return q.db.GetAPIKeyByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg) +} + +func (q *querier) UpdateGitAuthLink(ctx context.Context, arg database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { + fetch := func(ctx context.Context, arg database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { + return q.db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateGitAuthLink)(ctx, arg) +} + +func (q *querier) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { + fetch := func(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { + return q.db.GetGitSSHKey(ctx, arg.UserID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateGitSSHKey)(ctx, arg) +} + +func (q *querier) UpdateGroupByID(ctx context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { + fetch := func(ctx context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { + return q.db.GetGroupByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateGroupByID)(ctx, arg) +} + +func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { + // Authorized fetch will check that the actor has read access to the org member since the org member is returned. + member, err := q.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{ + OrganizationID: arg.OrgID, + UserID: arg.UserID, + }) + if err != nil { + return database.OrganizationMember{}, err + } + + // The org member role is always implied. + impliedTypes := append(arg.GrantedRoles, rbac.RoleOrgMember(arg.OrgID)) + added, removed := rbac.ChangeRoleSet(member.Roles, impliedTypes) + err = q.canAssignRoles(ctx, &arg.OrgID, added, removed) + if err != nil { + return database.OrganizationMember{}, err + } + + return q.db.UpdateMemberRoles(ctx, arg) +} + +// TODO: We need to create a ProvisionerJob resource type +func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error { + // if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + // return err + // } + return q.db.UpdateProvisionerJobByID(ctx, arg) +} + +func (q *querier) UpdateProvisionerJobWithCancelByID(ctx context.Context, arg database.UpdateProvisionerJobWithCancelByIDParams) error { + job, err := q.db.GetProvisionerJobByID(ctx, arg.ID) + if err != nil { + return err + } + + switch job.Type { + case database.ProvisionerJobTypeWorkspaceBuild: + build, err := q.db.GetWorkspaceBuildByJobID(ctx, arg.ID) + if err != nil { + return err + } + workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return err + } + + template, err := q.db.GetTemplateByID(ctx, workspace.TemplateID) + if err != nil { + return err + } + + // Template can specify if cancels are allowed. + // Would be nice to have a way in the rbac rego to do this. + if !template.AllowUserCancelWorkspaceJobs { + // Only owners can cancel workspace builds + actor, ok := ActorFromContext(ctx) + if !ok { + return NoActorError + } + if !slice.Contains(actor.Roles.Names(), rbac.RoleOwner()) { + return xerrors.Errorf("only owners can cancel workspace builds") + } + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) + if err != nil { + return err + } + case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: + // Authorized call to get template version. + templateVersion, err := authorizedTemplateVersionFromJob(ctx, q, job) + if err != nil { + return err + } + + if templateVersion.TemplateID.Valid { + template, err := q.db.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) + if err != nil { + return err + } + err = q.authorizeContext(ctx, rbac.ActionUpdate, templateVersion.RBACObject(template)) + if err != nil { + return err + } + } else { + err = q.authorizeContext(ctx, rbac.ActionUpdate, templateVersion.RBACObjectNoTemplate()) + if err != nil { + return err + } + } + default: + return xerrors.Errorf("unknown job type: %q", job.Type) + } + return q.db.UpdateProvisionerJobWithCancelByID(ctx, arg) +} + +// TODO: We need to create a ProvisionerJob resource type +func (q *querier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error { + // if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + // return err + // } + return q.db.UpdateProvisionerJobWithCompleteByID(ctx, arg) +} + +func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaParams) (database.Replica, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + return database.Replica{}, err + } + return q.db.UpdateReplica(ctx, arg) +} + +func (q *querier) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) { + // UpdateTemplateACL uses the ActionCreate action. Only users that can create the template + // may update the ACL. + fetch := func(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) { + return q.db.GetTemplateByID(ctx, arg.ID) + } + return fetchAndQuery(q.log, q.auth, rbac.ActionCreate, fetch, q.db.UpdateTemplateACLByID)(ctx, arg) +} + +func (q *querier) UpdateTemplateActiveVersionByID(ctx context.Context, arg database.UpdateTemplateActiveVersionByIDParams) error { + fetch := func(ctx context.Context, arg database.UpdateTemplateActiveVersionByIDParams) (database.Template, error) { + return q.db.GetTemplateByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateTemplateActiveVersionByID)(ctx, arg) +} + +// Deprecated: use SoftDeleteTemplateByID instead. +func (q *querier) UpdateTemplateDeletedByID(ctx context.Context, arg database.UpdateTemplateDeletedByIDParams) error { + return q.SoftDeleteTemplateByID(ctx, arg.ID) +} + +func (q *querier) UpdateTemplateMetaByID(ctx context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) { + fetch := func(ctx context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) { + return q.db.GetTemplateByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateTemplateMetaByID)(ctx, arg) +} + +func (q *querier) UpdateTemplateScheduleByID(ctx context.Context, arg database.UpdateTemplateScheduleByIDParams) (database.Template, error) { + fetch := func(ctx context.Context, arg database.UpdateTemplateScheduleByIDParams) (database.Template, error) { + return q.db.GetTemplateByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateTemplateScheduleByID)(ctx, arg) +} + +func (q *querier) UpdateTemplateVersionByID(ctx context.Context, arg database.UpdateTemplateVersionByIDParams) (database.TemplateVersion, error) { + // An actor is allowed to update the template version if they are authorized to update the template. + tv, err := q.db.GetTemplateVersionByID(ctx, arg.ID) + if err != nil { + return database.TemplateVersion{}, err + } + var obj rbac.Objecter + if !tv.TemplateID.Valid { + obj = rbac.ResourceTemplate.InOrg(tv.OrganizationID) + } else { + tpl, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) + if err != nil { + return database.TemplateVersion{}, err + } + obj = tpl + } + if err := q.authorizeContext(ctx, rbac.ActionUpdate, obj); err != nil { + return database.TemplateVersion{}, err + } + return q.db.UpdateTemplateVersionByID(ctx, arg) +} + +func (q *querier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context, arg database.UpdateTemplateVersionDescriptionByJobIDParams) error { + // An actor is allowed to update the template version description if they are authorized to update the template. + tv, err := q.db.GetTemplateVersionByJobID(ctx, arg.JobID) + if err != nil { + return err + } + var obj rbac.Objecter + if !tv.TemplateID.Valid { + obj = rbac.ResourceTemplate.InOrg(tv.OrganizationID) + } else { + tpl, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) + if err != nil { + return err + } + obj = tpl + } + if err := q.authorizeContext(ctx, rbac.ActionUpdate, obj); err != nil { + return err + } + return q.db.UpdateTemplateVersionDescriptionByJobID(ctx, arg) +} + +func (q *querier) UpdateTemplateVersionGitAuthProvidersByJobID(ctx context.Context, arg database.UpdateTemplateVersionGitAuthProvidersByJobIDParams) error { + // An actor is allowed to update the template version git auth providers if they are authorized to update the template. + tv, err := q.db.GetTemplateVersionByJobID(ctx, arg.JobID) + if err != nil { + return err + } + var obj rbac.Objecter + if !tv.TemplateID.Valid { + obj = rbac.ResourceTemplate.InOrg(tv.OrganizationID) + } else { + tpl, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) + if err != nil { + return err + } + obj = tpl + } + if err := q.authorizeContext(ctx, rbac.ActionUpdate, obj); err != nil { + return err + } + return q.db.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, arg) +} + +// UpdateUserDeletedByID +// Deprecated: Delete this function in favor of 'SoftDeleteUserByID'. Deletes are +// irreversible. +func (q *querier) UpdateUserDeletedByID(ctx context.Context, arg database.UpdateUserDeletedByIDParams) error { + fetch := func(ctx context.Context, arg database.UpdateUserDeletedByIDParams) (database.User, error) { + return q.db.GetUserByID(ctx, arg.ID) + } + // This uses the rbac.ActionDelete action always as this function should always delete. + // We should delete this function in favor of 'SoftDeleteUserByID'. + return deleteQ(q.log, q.auth, fetch, q.db.UpdateUserDeletedByID)(ctx, arg) +} + +func (q *querier) UpdateUserHashedPassword(ctx context.Context, arg database.UpdateUserHashedPasswordParams) error { + user, err := q.db.GetUserByID(ctx, arg.ID) + if err != nil { + return err + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, user.UserDataRBACObject()) + if err != nil { + // Admins can update passwords for other users. + err = q.authorizeContext(ctx, rbac.ActionUpdate, user.RBACObject()) + if err != nil { + return err + } + } + + return q.db.UpdateUserHashedPassword(ctx, arg) +} + +func (q *querier) UpdateUserLastSeenAt(ctx context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { + fetch := func(ctx context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { + return q.db.GetUserByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateUserLastSeenAt)(ctx, arg) +} + +func (q *querier) UpdateUserLink(ctx context.Context, arg database.UpdateUserLinkParams) (database.UserLink, error) { + fetch := func(ctx context.Context, arg database.UpdateUserLinkParams) (database.UserLink, error) { + return q.db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: arg.UserID, + LoginType: arg.LoginType, + }) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateUserLink)(ctx, arg) +} + +func (q *querier) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUserLinkedIDParams) (database.UserLink, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + return database.UserLink{}, err + } + return q.db.UpdateUserLinkedID(ctx, arg) +} + +func (q *querier) UpdateUserProfile(ctx context.Context, arg database.UpdateUserProfileParams) (database.User, error) { + u, err := q.db.GetUserByID(ctx, arg.ID) + if err != nil { + return database.User{}, err + } + if err := q.authorizeContext(ctx, rbac.ActionUpdate, u.UserDataRBACObject()); err != nil { + return database.User{}, err + } + return q.db.UpdateUserProfile(ctx, arg) +} + +// UpdateUserRoles updates the site roles of a user. The validation for this function include more than +// just a basic RBAC check. +func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRolesParams) (database.User, error) { + // We need to fetch the user being updated to identify the change in roles. + // This requires read access on the user in question, since the user is + // returned from this function. + user, err := fetch(q.log, q.auth, q.db.GetUserByID)(ctx, arg.ID) + if err != nil { + return database.User{}, err + } + + // The member role is always implied. + impliedTypes := append(arg.GrantedRoles, rbac.RoleMember()) + // If the changeset is nothing, less rbac checks need to be done. + added, removed := rbac.ChangeRoleSet(user.RBACRoles, impliedTypes) + err = q.canAssignRoles(ctx, nil, added, removed) + if err != nil { + return database.User{}, err + } + + return q.db.UpdateUserRoles(ctx, arg) +} + +func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) { + fetch := func(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) { + return q.db.GetUserByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateUserStatus)(ctx, arg) +} + +func (q *querier) UpdateWorkspace(ctx context.Context, arg database.UpdateWorkspaceParams) (database.Workspace, error) { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateWorkspace)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg database.UpdateWorkspaceAgentConnectionByIDParams) error { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + return err + } + return q.db.UpdateWorkspaceAgentConnectionByID(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAgentLifecycleStateByID(ctx context.Context, arg database.UpdateWorkspaceAgentLifecycleStateByIDParams) error { + agent, err := q.db.GetWorkspaceAgentByID(ctx, arg.ID) + if err != nil { + return err + } + + workspace, err := q.db.GetWorkspaceByAgentID(ctx, agent.ID) + if err != nil { + return err + } + + if err := q.authorizeContext(ctx, rbac.ActionUpdate, workspace); err != nil { + return err + } + + return q.db.UpdateWorkspaceAgentLifecycleStateByID(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAgentMetadata(ctx context.Context, arg database.UpdateWorkspaceAgentMetadataParams) error { + workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.WorkspaceAgentID) + if err != nil { + return err + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) + if err != nil { + return err + } + + return q.db.UpdateWorkspaceAgentMetadata(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAgentStartupByID(ctx context.Context, arg database.UpdateWorkspaceAgentStartupByIDParams) error { + agent, err := q.db.GetWorkspaceAgentByID(ctx, arg.ID) + if err != nil { + return err + } + + workspace, err := q.db.GetWorkspaceByAgentID(ctx, agent.ID) + if err != nil { + return err + } + + if err := q.authorizeContext(ctx, rbac.ActionUpdate, workspace); err != nil { + return err + } + + return q.db.UpdateWorkspaceAgentStartupByID(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAgentStartupLogOverflowByID(ctx context.Context, arg database.UpdateWorkspaceAgentStartupLogOverflowByIDParams) error { + agent, err := q.db.GetWorkspaceAgentByID(ctx, arg.ID) + if err != nil { + return err + } + + workspace, err := q.db.GetWorkspaceByAgentID(ctx, agent.ID) + if err != nil { + return err + } + + if err := q.authorizeContext(ctx, rbac.ActionUpdate, workspace); err != nil { + return err + } + + return q.db.UpdateWorkspaceAgentStartupLogOverflowByID(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAppHealthByID(ctx context.Context, arg database.UpdateWorkspaceAppHealthByIDParams) error { + // TODO: This is a workspace agent operation. Should users be able to query this? + workspace, err := q.db.GetWorkspaceByWorkspaceAppID(ctx, arg.ID) + if err != nil { + return err + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace.RBACObject()) + if err != nil { + return err + } + return q.db.UpdateWorkspaceAppHealthByID(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAutostart(ctx context.Context, arg database.UpdateWorkspaceAutostartParams) error { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceAutostartParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceAutostart)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceBuildByID(ctx context.Context, arg database.UpdateWorkspaceBuildByIDParams) (database.WorkspaceBuild, error) { + build, err := q.db.GetWorkspaceBuildByID(ctx, arg.ID) + if err != nil { + return database.WorkspaceBuild{}, err + } + + workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return database.WorkspaceBuild{}, err + } + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace.RBACObject()) + if err != nil { + return database.WorkspaceBuild{}, err + } + + return q.db.UpdateWorkspaceBuildByID(ctx, arg) +} + +// UpdateWorkspaceBuildCostByID is used by the provisioning system to update the cost of a workspace build. +func (q *querier) UpdateWorkspaceBuildCostByID(ctx context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) (database.WorkspaceBuild, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + return database.WorkspaceBuild{}, err + } + return q.db.UpdateWorkspaceBuildCostByID(ctx, arg) +} + +// Deprecated: Use SoftDeleteWorkspaceByID +func (q *querier) UpdateWorkspaceDeletedByID(ctx context.Context, arg database.UpdateWorkspaceDeletedByIDParams) error { + // TODO deleteQ me, placeholder for database.Store + fetch := func(ctx context.Context, arg database.UpdateWorkspaceDeletedByIDParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + // This function is always used to deleteQ. + return deleteQ(q.log, q.auth, fetch, q.db.UpdateWorkspaceDeletedByID)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceLastUsedAt(ctx context.Context, arg database.UpdateWorkspaceLastUsedAtParams) error { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceLastUsedAtParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceLastUsedAt)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceProxy(ctx context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { + return q.db.GetWorkspaceProxyByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateWorkspaceProxy)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceProxyDeleted(ctx context.Context, arg database.UpdateWorkspaceProxyDeletedParams) error { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceProxyDeletedParams) (database.WorkspaceProxy, error) { + return q.db.GetWorkspaceProxyByID(ctx, arg.ID) + } + return deleteQ(q.log, q.auth, fetch, q.db.UpdateWorkspaceProxyDeleted)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceTTL(ctx context.Context, arg database.UpdateWorkspaceTTLParams) error { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceTTLParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceTTL)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceTTLToBeWithinTemplateMax(ctx context.Context, arg database.UpdateWorkspaceTTLToBeWithinTemplateMaxParams) error { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceTTLToBeWithinTemplateMaxParams) (database.Template, error) { + return q.db.GetTemplateByID(ctx, arg.TemplateID) + } + return fetchAndExec(q.log, q.auth, rbac.ActionUpdate, fetch, q.db.UpdateWorkspaceTTLToBeWithinTemplateMax)(ctx, arg) +} + +func (q *querier) UpsertAppSecurityKey(ctx context.Context, data string) error { + // No authz checks as this is done during startup + return q.db.UpsertAppSecurityKey(ctx, data) +} + +func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + return err + } + return q.db.UpsertDefaultProxy(ctx, arg) +} + +func (q *querier) UpsertLastUpdateCheck(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + return err + } + return q.db.UpsertLastUpdateCheck(ctx, value) +} + +func (q *querier) UpsertLogoURL(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceDeploymentValues); err != nil { + return err + } + return q.db.UpsertLogoURL(ctx, value) +} + +func (q *querier) UpsertServiceBanner(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceDeploymentValues); err != nil { + return err + } + return q.db.UpsertServiceBanner(ctx, value) +} diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index c89829b34542a..eaa99482a812e 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2,8 +2,11 @@ package dbauthz_test import ( "context" + "database/sql" + "encoding/json" "reflect" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -16,6 +19,7 @@ import ( "github.com/coder/coder/coderd/database/dbfake" "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/util/slice" ) func TestAsNoActor(t *testing.T) { @@ -155,3 +159,1428 @@ func must[T any](value T, err error) T { } return value } + +func (s *MethodTestSuite) TestAPIKey() { + s.Run("DeleteAPIKeyByID", s.Subtest(func(db database.Store, check *expects) { + key, _ := dbgen.APIKey(s.T(), db, database.APIKey{}) + check.Args(key.ID).Asserts(key, rbac.ActionDelete).Returns() + })) + s.Run("GetAPIKeyByID", s.Subtest(func(db database.Store, check *expects) { + key, _ := dbgen.APIKey(s.T(), db, database.APIKey{}) + check.Args(key.ID).Asserts(key, rbac.ActionRead).Returns(key) + })) + s.Run("GetAPIKeyByName", s.Subtest(func(db database.Store, check *expects) { + key, _ := dbgen.APIKey(s.T(), db, database.APIKey{ + TokenName: "marge-cat", + LoginType: database.LoginTypeToken, + }) + check.Args(database.GetAPIKeyByNameParams{ + TokenName: key.TokenName, + UserID: key.UserID, + }).Asserts(key, rbac.ActionRead).Returns(key) + })) + s.Run("GetAPIKeysByLoginType", s.Subtest(func(db database.Store, check *expects) { + a, _ := dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypePassword}) + b, _ := dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypePassword}) + _, _ = dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypeGithub}) + check.Args(database.LoginTypePassword). + Asserts(a, rbac.ActionRead, b, rbac.ActionRead). + Returns(slice.New(a, b)) + })) + s.Run("GetAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) { + idAB := uuid.New() + idC := uuid.New() + + keyA, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: idAB, LoginType: database.LoginTypeToken}) + keyB, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: idAB, LoginType: database.LoginTypeToken}) + _, _ = dbgen.APIKey(s.T(), db, database.APIKey{UserID: idC, LoginType: database.LoginTypeToken}) + + check.Args(database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: idAB}). + Asserts(keyA, rbac.ActionRead, keyB, rbac.ActionRead). + Returns(slice.New(keyA, keyB)) + })) + s.Run("GetAPIKeysLastUsedAfter", s.Subtest(func(db database.Store, check *expects) { + a, _ := dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(time.Hour)}) + b, _ := dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(time.Hour)}) + _, _ = dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(-time.Hour)}) + check.Args(time.Now()). + Asserts(a, rbac.ActionRead, b, rbac.ActionRead). + Returns(slice.New(a, b)) + })) + s.Run("InsertAPIKey", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertAPIKeyParams{ + UserID: u.ID, + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + }).Asserts(rbac.ResourceAPIKey.WithOwner(u.ID.String()), rbac.ActionCreate) + })) + s.Run("UpdateAPIKeyByID", s.Subtest(func(db database.Store, check *expects) { + a, _ := dbgen.APIKey(s.T(), db, database.APIKey{}) + check.Args(database.UpdateAPIKeyByIDParams{ + ID: a.ID, + }).Asserts(a, rbac.ActionUpdate).Returns() + })) +} + +func (s *MethodTestSuite) TestAuditLogs() { + s.Run("InsertAuditLog", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertAuditLogParams{ + ResourceType: database.ResourceTypeOrganization, + Action: database.AuditActionCreate, + }).Asserts(rbac.ResourceAuditLog, rbac.ActionCreate) + })) + s.Run("GetAuditLogsOffset", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.AuditLog(s.T(), db, database.AuditLog{}) + _ = dbgen.AuditLog(s.T(), db, database.AuditLog{}) + check.Args(database.GetAuditLogsOffsetParams{ + Limit: 10, + }).Asserts(rbac.ResourceAuditLog, rbac.ActionRead) + })) +} + +func (s *MethodTestSuite) TestFile() { + s.Run("GetFileByHashAndCreator", s.Subtest(func(db database.Store, check *expects) { + f := dbgen.File(s.T(), db, database.File{}) + check.Args(database.GetFileByHashAndCreatorParams{ + Hash: f.Hash, + CreatedBy: f.CreatedBy, + }).Asserts(f, rbac.ActionRead).Returns(f) + })) + s.Run("GetFileByID", s.Subtest(func(db database.Store, check *expects) { + f := dbgen.File(s.T(), db, database.File{}) + check.Args(f.ID).Asserts(f, rbac.ActionRead).Returns(f) + })) + s.Run("InsertFile", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertFileParams{ + CreatedBy: u.ID, + }).Asserts(rbac.ResourceFile.WithOwner(u.ID.String()), rbac.ActionCreate) + })) +} + +func (s *MethodTestSuite) TestGroup() { + s.Run("DeleteGroupByID", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(g.ID).Asserts(g, rbac.ActionDelete).Returns() + })) + s.Run("DeleteGroupMemberFromGroup", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + m := dbgen.GroupMember(s.T(), db, database.GroupMember{ + GroupID: g.ID, + }) + check.Args(database.DeleteGroupMemberFromGroupParams{ + UserID: m.UserID, + GroupID: g.ID, + }).Asserts(g, rbac.ActionUpdate).Returns() + })) + s.Run("GetGroupByID", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(g.ID).Asserts(g, rbac.ActionRead).Returns(g) + })) + s.Run("GetGroupByOrgAndName", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(database.GetGroupByOrgAndNameParams{ + OrganizationID: g.OrganizationID, + Name: g.Name, + }).Asserts(g, rbac.ActionRead).Returns(g) + })) + s.Run("GetGroupMembers", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMember{}) + check.Args(g.ID).Asserts(g, rbac.ActionRead) + })) + s.Run("InsertAllUsersGroup", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(o.ID).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionCreate) + })) + s.Run("InsertGroup", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(database.InsertGroupParams{ + OrganizationID: o.ID, + Name: "test", + }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionCreate) + })) + s.Run("InsertGroupMember", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(database.InsertGroupMemberParams{ + UserID: uuid.New(), + GroupID: g.ID, + }).Asserts(g, rbac.ActionUpdate).Returns() + })) + s.Run("InsertUserGroupsByName", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID}) + check.Args(database.InsertUserGroupsByNameParams{ + OrganizationID: o.ID, + UserID: u1.ID, + GroupNames: slice.New(g1.Name, g2.Name), + }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns() + })) + s.Run("DeleteGroupMembersByOrgAndUser", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g2.ID, UserID: u1.ID}) + check.Args(database.DeleteGroupMembersByOrgAndUserParams{ + OrganizationID: o.ID, + UserID: u1.ID, + }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns() + })) + s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(database.UpdateGroupByIDParams{ + ID: g.ID, + }).Asserts(g, rbac.ActionUpdate) + })) +} + +func (s *MethodTestSuite) TestProvsionerJob() { + s.Run("Build/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) + check.Args(j.ID).Asserts(w, rbac.ActionRead).Returns(j) + })) + s.Run("TemplateVersion/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionImport, + }) + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + JobID: j.ID, + }) + check.Args(j.ID).Asserts(v.RBACObject(tpl), rbac.ActionRead).Returns(j) + })) + s.Run("TemplateVersionDryRun/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + }) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(struct { + TemplateVersionID uuid.UUID `json:"template_version_id"` + }{TemplateVersionID: v.ID})), + }) + check.Args(j.ID).Asserts(v.RBACObject(tpl), rbac.ActionRead).Returns(j) + })) + s.Run("Build/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{AllowUserCancelWorkspaceJobs: true}) + w := dbgen.Workspace(s.T(), db, database.Workspace{TemplateID: tpl.ID}) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}).Asserts(w, rbac.ActionUpdate).Returns() + })) + s.Run("BuildFalseCancel/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{AllowUserCancelWorkspaceJobs: false}) + w := dbgen.Workspace(s.T(), db, database.Workspace{TemplateID: tpl.ID}) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}).Asserts(w, rbac.ActionUpdate).Returns() + })) + s.Run("TemplateVersion/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionImport, + }) + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + JobID: j.ID, + }) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}). + Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns() + })) + s.Run("TemplateVersionNoTemplate/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionImport, + }) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: uuid.Nil, Valid: false}, + JobID: j.ID, + }) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}). + Asserts(v.RBACObjectNoTemplate(), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns() + })) + s.Run("TemplateVersionDryRun/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + }) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(struct { + TemplateVersionID uuid.UUID `json:"template_version_id"` + }{TemplateVersionID: v.ID})), + }) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}). + Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns() + })) + s.Run("GetProvisionerJobsByIDs", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + b := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + check.Args([]uuid.UUID{a.ID, b.ID}).Asserts().Returns(slice.New(a, b)) + })) + s.Run("GetProvisionerLogsAfterID", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) + check.Args(database.GetProvisionerLogsAfterIDParams{ + JobID: j.ID, + }).Asserts(w, rbac.ActionRead).Returns([]database.ProvisionerJobLog{}) + })) +} + +func (s *MethodTestSuite) TestLicense() { + s.Run("GetLicenses", s.Subtest(func(db database.Store, check *expects) { + l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ + UUID: uuid.New(), + }) + require.NoError(s.T(), err) + check.Args().Asserts(l, rbac.ActionRead). + Returns([]database.License{l}) + })) + s.Run("InsertLicense", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertLicenseParams{}). + Asserts(rbac.ResourceLicense, rbac.ActionCreate) + })) + s.Run("UpsertLogoURL", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts(rbac.ResourceDeploymentValues, rbac.ActionCreate) + })) + s.Run("UpsertServiceBanner", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts(rbac.ResourceDeploymentValues, rbac.ActionCreate) + })) + s.Run("GetLicenseByID", s.Subtest(func(db database.Store, check *expects) { + l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ + UUID: uuid.New(), + }) + require.NoError(s.T(), err) + check.Args(l.ID).Asserts(l, rbac.ActionRead).Returns(l) + })) + s.Run("DeleteLicense", s.Subtest(func(db database.Store, check *expects) { + l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ + UUID: uuid.New(), + }) + require.NoError(s.T(), err) + check.Args(l.ID).Asserts(l, rbac.ActionDelete) + })) + s.Run("GetDeploymentID", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts().Returns("") + })) + s.Run("GetDefaultProxyConfig", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{ + DisplayName: "Default", + IconUrl: "/emojis/1f3e1.png", + }) + })) + s.Run("GetLogoURL", s.Subtest(func(db database.Store, check *expects) { + err := db.UpsertLogoURL(context.Background(), "value") + require.NoError(s.T(), err) + check.Args().Asserts().Returns("value") + })) + s.Run("GetServiceBanner", s.Subtest(func(db database.Store, check *expects) { + err := db.UpsertServiceBanner(context.Background(), "value") + require.NoError(s.T(), err) + check.Args().Asserts().Returns("value") + })) +} + +func (s *MethodTestSuite) TestOrganization() { + s.Run("GetGroupsByOrganizationID", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + a := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + b := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + check.Args(o.ID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead). + Returns([]database.Group{a, b}) + })) + s.Run("GetOrganizationByID", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(o.ID).Asserts(o, rbac.ActionRead).Returns(o) + })) + s.Run("GetOrganizationByName", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(o.Name).Asserts(o, rbac.ActionRead).Returns(o) + })) + s.Run("GetOrganizationIDsByMemberIDs", s.Subtest(func(db database.Store, check *expects) { + oa := dbgen.Organization(s.T(), db, database.Organization{}) + ob := dbgen.Organization(s.T(), db, database.Organization{}) + ma := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{OrganizationID: oa.ID}) + mb := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{OrganizationID: ob.ID}) + check.Args([]uuid.UUID{ma.UserID, mb.UserID}). + Asserts(rbac.ResourceUser.WithID(ma.UserID), rbac.ActionRead, rbac.ResourceUser.WithID(mb.UserID), rbac.ActionRead) + })) + s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) { + mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{}) + check.Args(database.GetOrganizationMemberByUserIDParams{ + OrganizationID: mem.OrganizationID, + UserID: mem.UserID, + }).Asserts(mem, rbac.ActionRead).Returns(mem) + })) + s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID}) + b := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID}) + check.Args(u.ID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("GetOrganizations", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.Organization(s.T(), db, database.Organization{}) + b := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args().Asserts(a, rbac.ActionRead, b, rbac.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("GetOrganizationsByUserID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + a := dbgen.Organization(s.T(), db, database.Organization{}) + _ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: a.ID}) + b := dbgen.Organization(s.T(), db, database.Organization{}) + _ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: b.ID}) + check.Args(u.ID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("InsertOrganization", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertOrganizationParams{ + ID: uuid.New(), + Name: "random", + }).Asserts(rbac.ResourceOrganization, rbac.ActionCreate) + })) + s.Run("InsertOrganizationMember", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u := dbgen.User(s.T(), db, database.User{}) + + check.Args(database.InsertOrganizationMemberParams{ + OrganizationID: o.ID, + UserID: u.ID, + Roles: []string{rbac.RoleOrgAdmin(o.ID)}, + }).Asserts( + rbac.ResourceRoleAssignment.InOrg(o.ID), rbac.ActionCreate, + rbac.ResourceOrganizationMember.InOrg(o.ID).WithID(u.ID), rbac.ActionCreate) + })) + s.Run("UpdateMemberRoles", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u := dbgen.User(s.T(), db, database.User{}) + mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{ + OrganizationID: o.ID, + UserID: u.ID, + Roles: []string{rbac.RoleOrgAdmin(o.ID)}, + }) + out := mem + out.Roles = []string{} + + check.Args(database.UpdateMemberRolesParams{ + GrantedRoles: []string{}, + UserID: u.ID, + OrgID: o.ID, + }).Asserts( + mem, rbac.ActionRead, + rbac.ResourceRoleAssignment.InOrg(o.ID), rbac.ActionCreate, // org-mem + rbac.ResourceRoleAssignment.InOrg(o.ID), rbac.ActionDelete, // org-admin + ).Returns(out) + })) +} + +func (s *MethodTestSuite) TestWorkspaceProxy() { + s.Run("InsertWorkspaceProxy", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceProxyParams{ + ID: uuid.New(), + }).Asserts(rbac.ResourceWorkspaceProxy, rbac.ActionCreate) + })) + s.Run("RegisterWorkspaceProxy", s.Subtest(func(db database.Store, check *expects) { + p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + check.Args(database.RegisterWorkspaceProxyParams{ + ID: p.ID, + }).Asserts(p, rbac.ActionUpdate) + })) + s.Run("GetWorkspaceProxyByID", s.Subtest(func(db database.Store, check *expects) { + p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + check.Args(p.ID).Asserts(p, rbac.ActionRead).Returns(p) + })) + s.Run("UpdateWorkspaceProxyDeleted", s.Subtest(func(db database.Store, check *expects) { + p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + check.Args(database.UpdateWorkspaceProxyDeletedParams{ + ID: p.ID, + Deleted: true, + }).Asserts(p, rbac.ActionDelete) + })) + s.Run("GetWorkspaceProxies", s.Subtest(func(db database.Store, check *expects) { + p1, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + p2, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + check.Args().Asserts(p1, rbac.ActionRead, p2, rbac.ActionRead).Returns(slice.New(p1, p2)) + })) +} + +func (s *MethodTestSuite) TestTemplate() { + s.Run("GetPreviousTemplateVersion", s.Subtest(func(db database.Store, check *expects) { + tvid := uuid.New() + now := time.Now() + o1 := dbgen.Organization(s.T(), db, database.Organization{}) + t1 := dbgen.Template(s.T(), db, database.Template{ + OrganizationID: o1.ID, + ActiveVersionID: tvid, + }) + _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + CreatedAt: now.Add(-time.Hour), + ID: tvid, + Name: t1.Name, + OrganizationID: o1.ID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + b := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + CreatedAt: now.Add(-2 * time.Hour), + Name: t1.Name, + OrganizationID: o1.ID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(database.GetPreviousTemplateVersionParams{ + Name: t1.Name, + OrganizationID: o1.ID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }).Asserts(t1, rbac.ActionRead).Returns(b) + })) + s.Run("GetTemplateByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(t1.ID).Asserts(t1, rbac.ActionRead).Returns(t1) + })) + s.Run("GetTemplateByOrganizationAndName", s.Subtest(func(db database.Store, check *expects) { + o1 := dbgen.Organization(s.T(), db, database.Organization{}) + t1 := dbgen.Template(s.T(), db, database.Template{ + OrganizationID: o1.ID, + }) + check.Args(database.GetTemplateByOrganizationAndNameParams{ + Name: t1.Name, + OrganizationID: o1.ID, + }).Asserts(t1, rbac.ActionRead).Returns(t1) + })) + s.Run("GetTemplateVersionByJobID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(tv.JobID).Asserts(t1, rbac.ActionRead).Returns(tv) + })) + s.Run("GetTemplateVersionByTemplateIDAndName", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(database.GetTemplateVersionByTemplateIDAndNameParams{ + Name: tv.Name, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }).Asserts(t1, rbac.ActionRead).Returns(tv) + })) + s.Run("GetTemplateVersionParameters", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(tv.ID).Asserts(t1, rbac.ActionRead).Returns([]database.TemplateVersionParameter{}) + })) + s.Run("GetTemplateVersionVariables", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + tvv1 := dbgen.TemplateVersionVariable(s.T(), db, database.TemplateVersionVariable{ + TemplateVersionID: tv.ID, + }) + check.Args(tv.ID).Asserts(t1, rbac.ActionRead).Returns([]database.TemplateVersionVariable{tvv1}) + })) + s.Run("GetTemplateGroupRoles", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(t1.ID).Asserts(t1, rbac.ActionRead) + })) + s.Run("GetTemplateUserRoles", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(t1.ID).Asserts(t1, rbac.ActionRead) + })) + s.Run("GetTemplateVersionByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(tv.ID).Asserts(t1, rbac.ActionRead).Returns(tv) + })) + s.Run("GetTemplateVersionsByTemplateID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + a := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + b := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(database.GetTemplateVersionsByTemplateIDParams{ + TemplateID: t1.ID, + }).Asserts(t1, rbac.ActionRead). + Returns(slice.New(a, b)) + })) + s.Run("GetTemplateVersionsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + now := time.Now() + t1 := dbgen.Template(s.T(), db, database.Template{}) + _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + CreatedAt: now.Add(-time.Hour), + }) + _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + CreatedAt: now.Add(-2 * time.Hour), + }) + check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), rbac.ActionRead) + })) + s.Run("GetTemplatesWithFilter", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.Template(s.T(), db, database.Template{}) + // No asserts because SQLFilter. + check.Args(database.GetTemplatesWithFilterParams{}). + Asserts().Returns(slice.New(a)) + })) + s.Run("GetAuthorizedTemplates", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.Template(s.T(), db, database.Template{}) + // No asserts because SQLFilter. + check.Args(database.GetTemplatesWithFilterParams{}, emptyPreparedAuthorized{}). + Asserts(). + Returns(slice.New(a)) + })) + s.Run("InsertTemplate", s.Subtest(func(db database.Store, check *expects) { + orgID := uuid.New() + check.Args(database.InsertTemplateParams{ + Provisioner: "echo", + OrganizationID: orgID, + }).Asserts(rbac.ResourceTemplate.InOrg(orgID), rbac.ActionCreate) + })) + s.Run("InsertTemplateVersion", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.InsertTemplateVersionParams{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + OrganizationID: t1.OrganizationID, + }).Asserts(t1, rbac.ActionRead, t1, rbac.ActionCreate) + })) + s.Run("SoftDeleteTemplateByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(t1.ID).Asserts(t1, rbac.ActionDelete) + })) + s.Run("UpdateTemplateACLByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateTemplateACLByIDParams{ + ID: t1.ID, + }).Asserts(t1, rbac.ActionCreate).Returns(t1) + })) + s.Run("UpdateTemplateActiveVersionByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{ + ActiveVersionID: uuid.New(), + }) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + ID: t1.ActiveVersionID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(database.UpdateTemplateActiveVersionByIDParams{ + ID: t1.ID, + ActiveVersionID: tv.ID, + }).Asserts(t1, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateTemplateDeletedByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateTemplateDeletedByIDParams{ + ID: t1.ID, + Deleted: true, + }).Asserts(t1, rbac.ActionDelete).Returns() + })) + s.Run("UpdateTemplateMetaByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateTemplateMetaByIDParams{ + ID: t1.ID, + }).Asserts(t1, rbac.ActionUpdate) + })) + s.Run("UpdateTemplateVersionByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(database.UpdateTemplateVersionByIDParams{ + ID: tv.ID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + Name: tv.Name, + UpdatedAt: tv.UpdatedAt, + }).Asserts(t1, rbac.ActionUpdate).Returns(tv) + })) + s.Run("UpdateTemplateVersionDescriptionByJobID", s.Subtest(func(db database.Store, check *expects) { + jobID := uuid.New() + t1 := dbgen.Template(s.T(), db, database.Template{}) + _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + JobID: jobID, + }) + check.Args(database.UpdateTemplateVersionDescriptionByJobIDParams{ + JobID: jobID, + Readme: "foo", + }).Asserts(t1, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateTemplateVersionGitAuthProvidersByJobID", s.Subtest(func(db database.Store, check *expects) { + jobID := uuid.New() + t1 := dbgen.Template(s.T(), db, database.Template{}) + _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + JobID: jobID, + }) + check.Args(database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{ + JobID: jobID, + GitAuthProviders: []string{}, + }).Asserts(t1, rbac.ActionUpdate).Returns() + })) +} + +func (s *MethodTestSuite) TestUser() { + s.Run("DeleteAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(rbac.ResourceAPIKey.WithOwner(u.ID.String()), rbac.ActionDelete).Returns() + })) + s.Run("GetQuotaAllowanceForUser", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(u, rbac.ActionRead).Returns(int64(0)) + })) + s.Run("GetQuotaConsumedForUser", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(u, rbac.ActionRead).Returns(int64(0)) + })) + s.Run("GetUserByEmailOrUsername", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.GetUserByEmailOrUsernameParams{ + Username: u.Username, + Email: u.Email, + }).Asserts(u, rbac.ActionRead).Returns(u) + })) + s.Run("GetUserByID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(u, rbac.ActionRead).Returns(u) + })) + s.Run("GetUsersByIDs", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now().Add(-time.Hour)}) + b := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now()}) + check.Args([]uuid.UUID{a.ID, b.ID}). + Asserts(a, rbac.ActionRead, b, rbac.ActionRead). + Returns(slice.New(a, b)) + })) + s.Run("GetAuthorizedUserCount", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.User(s.T(), db, database.User{}) + check.Args(database.GetFilteredUserCountParams{}, emptyPreparedAuthorized{}).Asserts().Returns(int64(1)) + })) + s.Run("GetFilteredUserCount", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.User(s.T(), db, database.User{}) + check.Args(database.GetFilteredUserCountParams{}).Asserts().Returns(int64(1)) + })) + s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"}) + b := dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"}) + check.Args(database.GetUsersParams{}). + Asserts(a, rbac.ActionRead, b, rbac.ActionRead) + })) + s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"}) + b := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-b-user"}) + check.Args(database.GetUsersParams{}).Asserts(a, rbac.ActionRead, b, rbac.ActionRead) + })) + s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertUserParams{ + ID: uuid.New(), + LoginType: database.LoginTypePassword, + }).Asserts(rbac.ResourceRoleAssignment, rbac.ActionCreate, rbac.ResourceUser, rbac.ActionCreate) + })) + s.Run("InsertUserLink", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertUserLinkParams{ + UserID: u.ID, + LoginType: database.LoginTypeOIDC, + }).Asserts(u, rbac.ActionUpdate) + })) + s.Run("SoftDeleteUserByID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(u, rbac.ActionDelete).Returns() + })) + s.Run("UpdateUserDeletedByID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{Deleted: true}) + check.Args(database.UpdateUserDeletedByIDParams{ + ID: u.ID, + Deleted: true, + }).Asserts(u, rbac.ActionDelete).Returns() + })) + s.Run("UpdateUserHashedPassword", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserHashedPasswordParams{ + ID: u.ID, + }).Asserts(u.UserDataRBACObject(), rbac.ActionUpdate).Returns() + })) + s.Run("UpdateUserLastSeenAt", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserLastSeenAtParams{ + ID: u.ID, + UpdatedAt: u.UpdatedAt, + LastSeenAt: u.LastSeenAt, + }).Asserts(u, rbac.ActionUpdate).Returns(u) + })) + s.Run("UpdateUserProfile", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserProfileParams{ + ID: u.ID, + Email: u.Email, + Username: u.Username, + UpdatedAt: u.UpdatedAt, + }).Asserts(u.UserDataRBACObject(), rbac.ActionUpdate).Returns(u) + })) + s.Run("UpdateUserStatus", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserStatusParams{ + ID: u.ID, + Status: u.Status, + UpdatedAt: u.UpdatedAt, + }).Asserts(u, rbac.ActionUpdate).Returns(u) + })) + s.Run("DeleteGitSSHKey", s.Subtest(func(db database.Store, check *expects) { + key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) + check.Args(key.UserID).Asserts(key, rbac.ActionDelete).Returns() + })) + s.Run("GetGitSSHKey", s.Subtest(func(db database.Store, check *expects) { + key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) + check.Args(key.UserID).Asserts(key, rbac.ActionRead).Returns(key) + })) + s.Run("InsertGitSSHKey", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertGitSSHKeyParams{ + UserID: u.ID, + }).Asserts(rbac.ResourceUserData.WithID(u.ID).WithOwner(u.ID.String()), rbac.ActionCreate) + })) + s.Run("UpdateGitSSHKey", s.Subtest(func(db database.Store, check *expects) { + key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) + check.Args(database.UpdateGitSSHKeyParams{ + UserID: key.UserID, + UpdatedAt: key.UpdatedAt, + }).Asserts(key, rbac.ActionUpdate).Returns(key) + })) + s.Run("GetGitAuthLink", s.Subtest(func(db database.Store, check *expects) { + link := dbgen.GitAuthLink(s.T(), db, database.GitAuthLink{}) + check.Args(database.GetGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }).Asserts(link, rbac.ActionRead).Returns(link) + })) + s.Run("InsertGitAuthLink", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertGitAuthLinkParams{ + ProviderID: uuid.NewString(), + UserID: u.ID, + }).Asserts(rbac.ResourceUserData.WithOwner(u.ID.String()).WithID(u.ID), rbac.ActionCreate) + })) + s.Run("UpdateGitAuthLink", s.Subtest(func(db database.Store, check *expects) { + link := dbgen.GitAuthLink(s.T(), db, database.GitAuthLink{}) + check.Args(database.UpdateGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: link.OAuthExpiry, + UpdatedAt: link.UpdatedAt, + }).Asserts(link, rbac.ActionUpdate).Returns(link) + })) + s.Run("UpdateUserLink", s.Subtest(func(db database.Store, check *expects) { + link := dbgen.UserLink(s.T(), db, database.UserLink{}) + check.Args(database.UpdateUserLinkParams{ + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: link.OAuthExpiry, + UserID: link.UserID, + LoginType: link.LoginType, + }).Asserts(link, rbac.ActionUpdate).Returns(link) + })) + s.Run("UpdateUserRoles", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{RBACRoles: []string{rbac.RoleTemplateAdmin()}}) + o := u + o.RBACRoles = []string{rbac.RoleUserAdmin()} + check.Args(database.UpdateUserRolesParams{ + GrantedRoles: []string{rbac.RoleUserAdmin()}, + ID: u.ID, + }).Asserts( + u, rbac.ActionRead, + rbac.ResourceRoleAssignment, rbac.ActionCreate, + rbac.ResourceRoleAssignment, rbac.ActionDelete, + ).Returns(o) + })) +} + +func (s *MethodTestSuite) TestWorkspace() { + s.Run("GetWorkspaceByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(ws.ID).Asserts(ws, rbac.ActionRead) + })) + s.Run("GetWorkspaces", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.Workspace(s.T(), db, database.Workspace{}) + _ = dbgen.Workspace(s.T(), db, database.Workspace{}) + // No asserts here because SQLFilter. + check.Args(database.GetWorkspacesParams{}).Asserts() + })) + s.Run("GetAuthorizedWorkspaces", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.Workspace(s.T(), db, database.Workspace{}) + _ = dbgen.Workspace(s.T(), db, database.Workspace{}) + // No asserts here because SQLFilter. + check.Args(database.GetWorkspacesParams{}, emptyPreparedAuthorized{}).Asserts() + })) + s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args(ws.ID).Asserts(ws, rbac.ActionRead).Returns(b) + })) + s.Run("GetLatestWorkspaceBuildsByWorkspaceIDs", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args([]uuid.UUID{ws.ID}).Asserts(ws, rbac.ActionRead).Returns(slice.New(b)) + })) + s.Run("GetWorkspaceAgentByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(agt) + })) + s.Run("GetWorkspaceAgentByInstanceID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(agt.AuthInstanceID.String).Asserts(ws, rbac.ActionRead).Returns(agt) + })) + s.Run("UpdateWorkspaceAgentLifecycleStateByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agt.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateWorkspaceAgentStartupLogOverflowByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(database.UpdateWorkspaceAgentStartupLogOverflowByIDParams{ + ID: agt.ID, + StartupLogsOverflowed: true, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateWorkspaceAgentStartupByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(database.UpdateWorkspaceAgentStartupByIDParams{ + ID: agt.ID, + Subsystem: database.WorkspaceAgentSubsystemNone, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("GetWorkspaceAgentStartupLogsAfter", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(database.GetWorkspaceAgentStartupLogsAfterParams{ + AgentID: agt.ID, + }).Asserts(ws, rbac.ActionRead).Returns([]database.WorkspaceAgentStartupLog{}) + })) + s.Run("GetWorkspaceAppByAgentIDAndSlug", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + + check.Args(database.GetWorkspaceAppByAgentIDAndSlugParams{ + AgentID: agt.ID, + Slug: app.Slug, + }).Asserts(ws, rbac.ActionRead).Returns(app) + })) + s.Run("GetWorkspaceAppsByAgentID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + a := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + b := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + + check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("GetWorkspaceBuildByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args(build.ID).Asserts(ws, rbac.ActionRead).Returns(build) + })) + s.Run("GetWorkspaceBuildByJobID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args(build.JobID).Asserts(ws, rbac.ActionRead).Returns(build) + })) + s.Run("GetWorkspaceBuildByWorkspaceIDAndBuildNumber", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 10}) + check.Args(database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ + WorkspaceID: ws.ID, + BuildNumber: build.BuildNumber, + }).Asserts(ws, rbac.ActionRead).Returns(build) + })) + s.Run("GetWorkspaceBuildParameters", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args(build.ID).Asserts(ws, rbac.ActionRead). + Returns([]database.WorkspaceBuildParameter{}) + })) + s.Run("GetWorkspaceBuildsByWorkspaceID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 1}) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 2}) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 3}) + check.Args(database.GetWorkspaceBuildsByWorkspaceIDParams{WorkspaceID: ws.ID}).Asserts(ws, rbac.ActionRead) // ordering + })) + s.Run("GetWorkspaceByAgentID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(ws) + })) + s.Run("GetWorkspaceByOwnerIDAndName", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.GetWorkspaceByOwnerIDAndNameParams{ + OwnerID: ws.OwnerID, + Deleted: ws.Deleted, + Name: ws.Name, + }).Asserts(ws, rbac.ActionRead).Returns(ws) + })) + s.Run("GetWorkspaceResourceByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + _ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + check.Args(res.ID).Asserts(ws, rbac.ActionRead).Returns(res) + })) + s.Run("Build/GetWorkspaceResourcesByJobID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + job := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) + check.Args(job.ID).Asserts(ws, rbac.ActionRead).Returns([]database.WorkspaceResource{}) + })) + s.Run("Template/GetWorkspaceResourcesByJobID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: uuid.New()}) + job := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport}) + check.Args(job.ID).Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionRead}).Returns([]database.WorkspaceResource{}) + })) + s.Run("InsertWorkspace", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(database.InsertWorkspaceParams{ + ID: uuid.New(), + OwnerID: u.ID, + OrganizationID: o.ID, + }).Asserts(rbac.ResourceWorkspace.WithOwner(u.ID.String()).InOrg(o.ID), rbac.ActionCreate) + })) + s.Run("Start/InsertWorkspaceBuild", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.InsertWorkspaceBuildParams{ + WorkspaceID: w.ID, + Transition: database.WorkspaceTransitionStart, + Reason: database.BuildReasonInitiator, + }).Asserts(w, rbac.ActionUpdate) + })) + s.Run("Delete/InsertWorkspaceBuild", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.InsertWorkspaceBuildParams{ + WorkspaceID: w.ID, + Transition: database.WorkspaceTransitionDelete, + Reason: database.BuildReasonInitiator, + }).Asserts(w, rbac.ActionDelete) + })) + s.Run("InsertWorkspaceBuildParameters", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: w.ID}) + check.Args(database.InsertWorkspaceBuildParametersParams{ + WorkspaceBuildID: b.ID, + Name: []string{"foo", "bar"}, + Value: []string{"baz", "qux"}, + }).Asserts(w, rbac.ActionUpdate) + })) + s.Run("UpdateWorkspace", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + expected := w + expected.Name = "" + check.Args(database.UpdateWorkspaceParams{ + ID: w.ID, + }).Asserts(w, rbac.ActionUpdate).Returns(expected) + })) + s.Run("InsertWorkspaceAgentStat", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.InsertWorkspaceAgentStatParams{ + WorkspaceID: ws.ID, + }).Asserts(ws, rbac.ActionUpdate) + })) + s.Run("UpdateWorkspaceAppHealthByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + check.Args(database.UpdateWorkspaceAppHealthByIDParams{ + ID: app.ID, + Health: database.WorkspaceAppHealthDisabled, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateWorkspaceAutostart", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.UpdateWorkspaceAutostartParams{ + ID: ws.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateWorkspaceBuildByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + check.Args(database.UpdateWorkspaceBuildByIDParams{ + ID: build.ID, + UpdatedAt: build.UpdatedAt, + Deadline: build.Deadline, + ProvisionerState: []byte{}, + }).Asserts(ws, rbac.ActionUpdate).Returns(build) + })) + s.Run("SoftDeleteWorkspaceByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + ws.Deleted = true + check.Args(ws.ID).Asserts(ws, rbac.ActionDelete).Returns() + })) + s.Run("UpdateWorkspaceDeletedByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{Deleted: true}) + check.Args(database.UpdateWorkspaceDeletedByIDParams{ + ID: ws.ID, + Deleted: true, + }).Asserts(ws, rbac.ActionDelete).Returns() + })) + s.Run("UpdateWorkspaceLastUsedAt", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.UpdateWorkspaceLastUsedAtParams{ + ID: ws.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateWorkspaceTTL", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.UpdateWorkspaceTTLParams{ + ID: ws.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("GetWorkspaceByWorkspaceAppID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + check.Args(app.ID).Asserts(ws, rbac.ActionRead).Returns(ws) + })) +} + +func (s *MethodTestSuite) TestExtraMethods() { + s.Run("GetProvisionerDaemons", s.Subtest(func(db database.Store, check *expects) { + d, err := db.InsertProvisionerDaemon(context.Background(), database.InsertProvisionerDaemonParams{ + ID: uuid.New(), + }) + s.NoError(err, "insert provisioner daemon") + check.Args().Asserts(d, rbac.ActionRead) + })) +} + +func (s *MethodTestSuite) TestSystemFunctions() { + s.Run("UpdateUserLinkedID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + l := dbgen.UserLink(s.T(), db, database.UserLink{UserID: u.ID}) + check.Args(database.UpdateUserLinkedIDParams{ + UserID: u.ID, + LinkedID: l.LinkedID, + LoginType: database.LoginTypeGithub, + }).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns(l) + })) + s.Run("UpsertDefaultProxy", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.UpsertDefaultProxyParams{}).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns() + })) + s.Run("GetUserLinkByLinkedID", s.Subtest(func(db database.Store, check *expects) { + l := dbgen.UserLink(s.T(), db, database.UserLink{}) + check.Args(l.LinkedID).Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(l) + })) + s.Run("GetUserLinkByUserIDLoginType", s.Subtest(func(db database.Store, check *expects) { + l := dbgen.UserLink(s.T(), db, database.UserLink{}) + check.Args(database.GetUserLinkByUserIDLoginTypeParams{ + UserID: l.UserID, + LoginType: l.LoginType, + }).Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(l) + })) + s.Run("GetLatestWorkspaceBuilds", s.Subtest(func(db database.Store, check *expects) { + dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) + dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspaceAgentByAuthToken", s.Subtest(func(db database.Store, check *expects) { + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{}) + check.Args(agt.AuthToken).Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(agt) + })) + s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(int64(0)) + })) + s.Run("GetUnexpiredLicenses", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetAuthorizationUserRoles", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetDERPMeshKey", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("InsertDERPMeshKey", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts(rbac.ResourceSystem, rbac.ActionCreate).Returns() + })) + s.Run("InsertDeploymentID", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts(rbac.ResourceSystem, rbac.ActionCreate).Returns() + })) + s.Run("InsertReplica", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertReplicaParams{ + ID: uuid.New(), + }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) + s.Run("UpdateReplica", s.Subtest(func(db database.Store, check *expects) { + replica, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New()}) + require.NoError(s.T(), err) + check.Args(database.UpdateReplicaParams{ + ID: replica.ID, + DatabaseLatency: 100, + }).Asserts(rbac.ResourceSystem, rbac.ActionUpdate) + })) + s.Run("DeleteReplicasUpdatedBefore", s.Subtest(func(db database.Store, check *expects) { + _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()}) + require.NoError(s.T(), err) + check.Args(time.Now().Add(time.Hour)).Asserts(rbac.ResourceSystem, rbac.ActionDelete) + })) + s.Run("GetReplicasUpdatedAfter", s.Subtest(func(db database.Store, check *expects) { + _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()}) + require.NoError(s.T(), err) + check.Args(time.Now().Add(time.Hour*-1)).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetUserCount", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(int64(0)) + })) + s.Run("GetTemplates", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.Template(s.T(), db, database.Template{}) + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("UpdateWorkspaceBuildCostByID", s.Subtest(func(db database.Store, check *expects) { + b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) + o := b + o.DailyCost = 10 + check.Args(database.UpdateWorkspaceBuildCostByIDParams{ + ID: b.ID, + DailyCost: 10, + }).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns(o) + })) + s.Run("UpsertLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts(rbac.ResourceSystem, rbac.ActionUpdate) + })) + s.Run("GetLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) { + err := db.UpsertLastUpdateCheck(context.Background(), "value") + require.NoError(s.T(), err) + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspaceBuildsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspaceAgentsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspaceAppsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspaceResourcesCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspaceResourceMetadataCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceResourceMetadatums(s.T(), db, database.WorkspaceResourceMetadatum{}) + check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("DeleteOldWorkspaceAgentStats", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionDelete) + })) + s.Run("GetProvisionerJobsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + // TODO: add provisioner job resource type + _ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts( /*rbac.ResourceSystem, rbac.ActionRead*/ ) + })) + s.Run("GetTemplateVersionsByIDs", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + t2 := dbgen.Template(s.T(), db, database.Template{}) + tv1 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + tv2 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t2.ID, Valid: true}, + }) + tv3 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t2.ID, Valid: true}, + }) + check.Args([]uuid.UUID{tv1.ID, tv2.ID, tv3.ID}). + Asserts(rbac.ResourceSystem, rbac.ActionRead). + Returns(slice.New(tv1, tv2, tv3)) + })) + s.Run("GetWorkspaceAppsByAgentIDs", s.Subtest(func(db database.Store, check *expects) { + aWs := dbgen.Workspace(s.T(), db, database.Workspace{}) + aBuild := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: aWs.ID, JobID: uuid.New()}) + aRes := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: aBuild.JobID}) + aAgt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: aRes.ID}) + a := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: aAgt.ID}) + + bWs := dbgen.Workspace(s.T(), db, database.Workspace{}) + bBuild := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: bWs.ID, JobID: uuid.New()}) + bRes := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: bBuild.JobID}) + bAgt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: bRes.ID}) + b := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: bAgt.ID}) + + check.Args([]uuid.UUID{a.AgentID, b.AgentID}). + Asserts(rbac.ResourceSystem, rbac.ActionRead). + Returns([]database.WorkspaceApp{a, b}) + })) + s.Run("GetWorkspaceResourcesByJobIDs", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: uuid.New()}) + tJob := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport}) + + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + wJob := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) + check.Args([]uuid.UUID{tJob.ID, wJob.ID}). + Asserts(rbac.ResourceSystem, rbac.ActionRead). + Returns([]database.WorkspaceResource{}) + })) + s.Run("GetWorkspaceResourceMetadataByResourceIDs", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + _ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) + a := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + b := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + check.Args([]uuid.UUID{a.ID, b.ID}). + Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspaceAgentsByResourceIDs", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args([]uuid.UUID{res.ID}). + Asserts(rbac.ResourceSystem, rbac.ActionRead). + Returns([]database.WorkspaceAgent{agt}) + })) + s.Run("GetProvisionerJobsByIDs", s.Subtest(func(db database.Store, check *expects) { + // TODO: add a ProvisionerJob resource type + a := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + b := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + check.Args([]uuid.UUID{a.ID, b.ID}). + Asserts( /*rbac.ResourceSystem, rbac.ActionRead*/ ). + Returns(slice.New(a, b)) + })) + s.Run("InsertWorkspaceAgent", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAgentParams{ + ID: uuid.New(), + StartupScriptBehavior: database.StartupScriptBehaviorNonBlocking, + }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) + s.Run("InsertWorkspaceApp", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAppParams{ + ID: uuid.New(), + Health: database.WorkspaceAppHealthDisabled, + SharingLevel: database.AppSharingLevelOwner, + }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) + s.Run("InsertWorkspaceResourceMetadata", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceResourceMetadataParams{ + WorkspaceResourceID: uuid.New(), + }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) + s.Run("UpdateWorkspaceAgentConnectionByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(database.UpdateWorkspaceAgentConnectionByIDParams{ + ID: agt.ID, + }).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns() + })) + s.Run("AcquireProvisionerJob", s.Subtest(func(db database.Store, check *expects) { + // TODO: we need to create a ProvisionerJob resource + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + StartedAt: sql.NullTime{Valid: false}, + }) + check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}). + Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ ) + })) + s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) { + // TODO: we need to create a ProvisionerJob resource + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + check.Args(database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: j.ID, + }).Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ ) + })) + s.Run("UpdateProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { + // TODO: we need to create a ProvisionerJob resource + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + check.Args(database.UpdateProvisionerJobByIDParams{ + ID: j.ID, + UpdatedAt: time.Now(), + }).Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ ) + })) + s.Run("InsertProvisionerJob", s.Subtest(func(db database.Store, check *expects) { + // TODO: we need to create a ProvisionerJob resource + check.Args(database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeWorkspaceBuild, + }).Asserts( /*rbac.ResourceSystem, rbac.ActionCreate*/ ) + })) + s.Run("InsertProvisionerJobLogs", s.Subtest(func(db database.Store, check *expects) { + // TODO: we need to create a ProvisionerJob resource + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + check.Args(database.InsertProvisionerJobLogsParams{ + JobID: j.ID, + }).Asserts( /*rbac.ResourceSystem, rbac.ActionCreate*/ ) + })) + s.Run("InsertProvisionerDaemon", s.Subtest(func(db database.Store, check *expects) { + // TODO: we need to create a ProvisionerDaemon resource + check.Args(database.InsertProvisionerDaemonParams{ + ID: uuid.New(), + }).Asserts( /*rbac.ResourceSystem, rbac.ActionCreate*/ ) + })) + s.Run("InsertTemplateVersionParameter", s.Subtest(func(db database.Store, check *expects) { + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{}) + check.Args(database.InsertTemplateVersionParameterParams{ + TemplateVersionID: v.ID, + }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) + s.Run("InsertWorkspaceResource", s.Subtest(func(db database.Store, check *expects) { + r := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{}) + check.Args(database.InsertWorkspaceResourceParams{ + ID: r.ID, + Transition: database.WorkspaceTransitionStart, + }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) +} diff --git a/coderd/database/dbauthz/querier.go b/coderd/database/dbauthz/querier.go deleted file mode 100644 index 30d1dfdff3647..0000000000000 --- a/coderd/database/dbauthz/querier.go +++ /dev/null @@ -1,1634 +0,0 @@ -package dbauthz - -import ( - "context" - "database/sql" - "encoding/json" - "errors" - "time" - - "github.com/google/uuid" - "golang.org/x/xerrors" - - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/coderd/util/slice" -) - -func (q *querier) Ping(ctx context.Context) (time.Duration, error) { - return q.db.Ping(ctx) -} - -func (q *querier) AcquireLock(ctx context.Context, id int64) error { - return q.db.AcquireLock(ctx, id) -} - -func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) { - return q.db.TryAcquireLock(ctx, id) -} - -// InTx runs the given function in a transaction. -func (q *querier) InTx(function func(querier database.Store) error, txOpts *sql.TxOptions) error { - return q.db.InTx(func(tx database.Store) error { - // Wrap the transaction store in a querier. - wrapped := New(tx, q.auth, q.log) - return function(wrapped) - }, txOpts) -} - -func (q *querier) DeleteAPIKeyByID(ctx context.Context, id string) error { - return deleteQ(q.log, q.auth, q.db.GetAPIKeyByID, q.db.DeleteAPIKeyByID)(ctx, id) -} - -func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { - return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) -} - -func (q *querier) GetAPIKeyByName(ctx context.Context, arg database.GetAPIKeyByNameParams) (database.APIKey, error) { - return fetch(q.log, q.auth, q.db.GetAPIKeyByName)(ctx, arg) -} - -func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) { - return fetchWithPostFilter(q.auth, q.db.GetAPIKeysByLoginType)(ctx, loginType) -} - -func (q *querier) GetAPIKeysByUserID(ctx context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) { - return fetchWithPostFilter(q.auth, q.db.GetAPIKeysByUserID)(ctx, database.GetAPIKeysByUserIDParams{LoginType: params.LoginType, UserID: params.UserID}) -} - -func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) { - return fetchWithPostFilter(q.auth, q.db.GetAPIKeysLastUsedAfter)(ctx, lastUsed) -} - -func (q *querier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { - return insert(q.log, q.auth, - rbac.ResourceAPIKey.WithOwner(arg.UserID.String()), - q.db.InsertAPIKey)(ctx, arg) -} - -func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error { - fetch := func(ctx context.Context, arg database.UpdateAPIKeyByIDParams) (database.APIKey, error) { - return q.db.GetAPIKeyByID(ctx, arg.ID) - } - return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg) -} - -func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) { - return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg) -} - -func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) { - // To optimize audit logs, we only check the global audit log permission once. - // This is because we expect a large unbounded set of audit logs, and applying a SQL - // filter would slow down the query for no benefit. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceAuditLog); err != nil { - return nil, err - } - return q.db.GetAuditLogsOffset(ctx, arg) -} - -func (q *querier) GetFileByHashAndCreator(ctx context.Context, arg database.GetFileByHashAndCreatorParams) (database.File, error) { - file, err := q.db.GetFileByHashAndCreator(ctx, arg) - if err != nil { - return database.File{}, err - } - err = q.authorizeContext(ctx, rbac.ActionRead, file) - if err != nil { - // Check the user's access to the file's templates. - if q.authorizeUpdateFileTemplate(ctx, file) != nil { - return database.File{}, err - } - } - - return file, nil -} - -func (q *querier) GetFileByID(ctx context.Context, id uuid.UUID) (database.File, error) { - file, err := q.db.GetFileByID(ctx, id) - if err != nil { - return database.File{}, err - } - err = q.authorizeContext(ctx, rbac.ActionRead, file) - if err != nil { - // Check the user's access to the file's templates. - if q.authorizeUpdateFileTemplate(ctx, file) != nil { - return database.File{}, err - } - } - - return file, nil -} - -// authorizeReadFile is a hotfix for the fact that file permissions are -// independent of template permissions. This function checks if the user has -// update access to any of the file's templates. -func (q *querier) authorizeUpdateFileTemplate(ctx context.Context, file database.File) error { - tpls, err := q.db.GetFileTemplates(ctx, file.ID) - if err != nil { - return err - } - // There __should__ only be 1 template per file, but there can be more than - // 1, so check them all. - for _, tpl := range tpls { - // If the user has update access to any template, they have read access to the file. - if err := q.authorizeContext(ctx, rbac.ActionUpdate, tpl); err == nil { - return nil - } - } - - return NotAuthorizedError{ - Err: xerrors.Errorf("not authorized to read file %s", file.ID), - } -} - -func (q *querier) InsertFile(ctx context.Context, arg database.InsertFileParams) (database.File, error) { - return insert(q.log, q.auth, rbac.ResourceFile.WithOwner(arg.CreatedBy.String()), q.db.InsertFile)(ctx, arg) -} - -func (q *querier) DeleteGroupByID(ctx context.Context, id uuid.UUID) error { - return deleteQ(q.log, q.auth, q.db.GetGroupByID, q.db.DeleteGroupByID)(ctx, id) -} - -func (q *querier) DeleteGroupMemberFromGroup(ctx context.Context, arg database.DeleteGroupMemberFromGroupParams) error { - // Deleting a group member counts as updating a group. - fetch := func(ctx context.Context, arg database.DeleteGroupMemberFromGroupParams) (database.Group, error) { - return q.db.GetGroupByID(ctx, arg.GroupID) - } - return update(q.log, q.auth, fetch, q.db.DeleteGroupMemberFromGroup)(ctx, arg) -} - -func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { - // This will add the user to all named groups. This counts as updating a group. - // NOTE: instead of checking if the user has permission to update each group, we instead - // check if the user has permission to update *a* group in the org. - fetch := func(ctx context.Context, arg database.InsertUserGroupsByNameParams) (rbac.Objecter, error) { - return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil - } - return update(q.log, q.auth, fetch, q.db.InsertUserGroupsByName)(ctx, arg) -} - -func (q *querier) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error { - // This will remove the user from all groups in the org. This counts as updating a group. - // NOTE: instead of fetching all groups in the org with arg.UserID as a member, we instead - // check if the caller has permission to update any group in the org. - fetch := func(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) (rbac.Objecter, error) { - return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil - } - return update(q.log, q.auth, fetch, q.db.DeleteGroupMembersByOrgAndUser)(ctx, arg) -} - -func (q *querier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) { - return fetch(q.log, q.auth, q.db.GetGroupByID)(ctx, id) -} - -func (q *querier) GetGroupByOrgAndName(ctx context.Context, arg database.GetGroupByOrgAndNameParams) (database.Group, error) { - return fetch(q.log, q.auth, q.db.GetGroupByOrgAndName)(ctx, arg) -} - -func (q *querier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]database.User, error) { - if _, err := q.GetGroupByID(ctx, groupID); err != nil { // AuthZ check - return nil, err - } - return q.db.GetGroupMembers(ctx, groupID) -} - -func (q *querier) InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (database.Group, error) { - // This method creates a new group. - return insert(q.log, q.auth, rbac.ResourceGroup.InOrg(organizationID), q.db.InsertAllUsersGroup)(ctx, organizationID) -} - -func (q *querier) InsertGroup(ctx context.Context, arg database.InsertGroupParams) (database.Group, error) { - return insert(q.log, q.auth, rbac.ResourceGroup.InOrg(arg.OrganizationID), q.db.InsertGroup)(ctx, arg) -} - -func (q *querier) InsertGroupMember(ctx context.Context, arg database.InsertGroupMemberParams) error { - fetch := func(ctx context.Context, arg database.InsertGroupMemberParams) (database.Group, error) { - return q.db.GetGroupByID(ctx, arg.GroupID) - } - return update(q.log, q.auth, fetch, q.db.InsertGroupMember)(ctx, arg) -} - -func (q *querier) UpdateGroupByID(ctx context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { - fetch := func(ctx context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { - return q.db.GetGroupByID(ctx, arg.ID) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateGroupByID)(ctx, arg) -} - -func (q *querier) UpdateProvisionerJobWithCancelByID(ctx context.Context, arg database.UpdateProvisionerJobWithCancelByIDParams) error { - job, err := q.db.GetProvisionerJobByID(ctx, arg.ID) - if err != nil { - return err - } - - switch job.Type { - case database.ProvisionerJobTypeWorkspaceBuild: - build, err := q.db.GetWorkspaceBuildByJobID(ctx, arg.ID) - if err != nil { - return err - } - workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) - if err != nil { - return err - } - - template, err := q.db.GetTemplateByID(ctx, workspace.TemplateID) - if err != nil { - return err - } - - // Template can specify if cancels are allowed. - // Would be nice to have a way in the rbac rego to do this. - if !template.AllowUserCancelWorkspaceJobs { - // Only owners can cancel workspace builds - actor, ok := ActorFromContext(ctx) - if !ok { - return NoActorError - } - if !slice.Contains(actor.Roles.Names(), rbac.RoleOwner()) { - return xerrors.Errorf("only owners can cancel workspace builds") - } - } - - err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) - if err != nil { - return err - } - case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: - // Authorized call to get template version. - templateVersion, err := authorizedTemplateVersionFromJob(ctx, q, job) - if err != nil { - return err - } - - if templateVersion.TemplateID.Valid { - template, err := q.db.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) - if err != nil { - return err - } - err = q.authorizeContext(ctx, rbac.ActionUpdate, templateVersion.RBACObject(template)) - if err != nil { - return err - } - } else { - err = q.authorizeContext(ctx, rbac.ActionUpdate, templateVersion.RBACObjectNoTemplate()) - if err != nil { - return err - } - } - default: - return xerrors.Errorf("unknown job type: %q", job.Type) - } - return q.db.UpdateProvisionerJobWithCancelByID(ctx, arg) -} - -func (q *querier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - job, err := q.db.GetProvisionerJobByID(ctx, id) - if err != nil { - return database.ProvisionerJob{}, err - } - - switch job.Type { - case database.ProvisionerJobTypeWorkspaceBuild: - // Authorized call to get workspace build. If we can read the build, we - // can read the job. - _, err := q.GetWorkspaceBuildByJobID(ctx, id) - if err != nil { - return database.ProvisionerJob{}, err - } - case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: - // Authorized call to get template version. - _, err := authorizedTemplateVersionFromJob(ctx, q, job) - if err != nil { - return database.ProvisionerJob{}, err - } - default: - return database.ProvisionerJob{}, xerrors.Errorf("unknown job type: %q", job.Type) - } - - return job, nil -} - -func (q *querier) GetProvisionerLogsAfterID(ctx context.Context, arg database.GetProvisionerLogsAfterIDParams) ([]database.ProvisionerJobLog, error) { - // Authorized read on job lets the actor also read the logs. - _, err := q.GetProvisionerJobByID(ctx, arg.JobID) - if err != nil { - return nil, err - } - return q.db.GetProvisionerLogsAfterID(ctx, arg) -} - -func (q *querier) GetWorkspaceAgentStartupLogsAfter(ctx context.Context, arg database.GetWorkspaceAgentStartupLogsAfterParams) ([]database.WorkspaceAgentStartupLog, error) { - _, err := q.GetWorkspaceAgentByID(ctx, arg.AgentID) - if err != nil { - return nil, err - } - return q.db.GetWorkspaceAgentStartupLogsAfter(ctx, arg) -} - -func (q *querier) GetLicenses(ctx context.Context) ([]database.License, error) { - fetch := func(ctx context.Context, _ interface{}) ([]database.License, error) { - return q.db.GetLicenses(ctx) - } - return fetchWithPostFilter(q.auth, fetch)(ctx, nil) -} - -func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseParams) (database.License, error) { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceLicense); err != nil { - return database.License{}, err - } - return q.db.InsertLicense(ctx, arg) -} - -func (q *querier) UpsertLogoURL(ctx context.Context, value string) error { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceDeploymentValues); err != nil { - return err - } - return q.db.UpsertLogoURL(ctx, value) -} - -func (q *querier) UpsertServiceBanner(ctx context.Context, value string) error { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceDeploymentValues); err != nil { - return err - } - return q.db.UpsertServiceBanner(ctx, value) -} - -func (q *querier) GetLicenseByID(ctx context.Context, id int32) (database.License, error) { - return fetch(q.log, q.auth, q.db.GetLicenseByID)(ctx, id) -} - -func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) { - err := deleteQ(q.log, q.auth, q.db.GetLicenseByID, func(ctx context.Context, id int32) error { - _, err := q.db.DeleteLicense(ctx, id) - return err - })(ctx, id) - if err != nil { - return -1, err - } - return id, nil -} - -func (q *querier) GetDefaultProxyConfig(ctx context.Context) (database.GetDefaultProxyConfigRow, error) { - // No authz checks - return q.db.GetDefaultProxyConfig(ctx) -} - -func (q *querier) GetDeploymentID(ctx context.Context) (string, error) { - // No authz checks - return q.db.GetDeploymentID(ctx) -} - -func (q *querier) GetLogoURL(ctx context.Context) (string, error) { - // No authz checks - return q.db.GetLogoURL(ctx) -} - -func (q *querier) GetAppSecurityKey(ctx context.Context) (string, error) { - // No authz checks - return q.db.GetAppSecurityKey(ctx) -} - -func (q *querier) UpsertAppSecurityKey(ctx context.Context, data string) error { - // No authz checks as this is done during startup - return q.db.UpsertAppSecurityKey(ctx, data) -} - -func (q *querier) GetServiceBanner(ctx context.Context) (string, error) { - // No authz checks - return q.db.GetServiceBanner(ctx) -} - -func (q *querier) GetProvisionerDaemons(ctx context.Context) ([]database.ProvisionerDaemon, error) { - fetch := func(ctx context.Context, _ interface{}) ([]database.ProvisionerDaemon, error) { - return q.db.GetProvisionerDaemons(ctx) - } - return fetchWithPostFilter(q.auth, fetch)(ctx, nil) -} - -func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]database.Group, error) { - return fetchWithPostFilter(q.auth, q.db.GetGroupsByOrganizationID)(ctx, organizationID) -} - -func (q *querier) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) { - return fetch(q.log, q.auth, q.db.GetOrganizationByID)(ctx, id) -} - -func (q *querier) GetOrganizationByName(ctx context.Context, name string) (database.Organization, error) { - return fetch(q.log, q.auth, q.db.GetOrganizationByName)(ctx, name) -} - -func (q *querier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { - // TODO: This should be rewritten to return a list of database.OrganizationMember for consistent RBAC objects. - // Currently this row returns a list of org ids per user, which is challenging to check against the RBAC system. - return fetchWithPostFilter(q.auth, q.db.GetOrganizationIDsByMemberIDs)(ctx, ids) -} - -func (q *querier) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { - return fetch(q.log, q.auth, q.db.GetOrganizationMemberByUserID)(ctx, arg) -} - -func (q *querier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { - return fetchWithPostFilter(q.auth, q.db.GetOrganizationMembershipsByUserID)(ctx, userID) -} - -func (q *querier) GetOrganizations(ctx context.Context) ([]database.Organization, error) { - fetch := func(ctx context.Context, _ interface{}) ([]database.Organization, error) { - return q.db.GetOrganizations(ctx) - } - return fetchWithPostFilter(q.auth, fetch)(ctx, nil) -} - -func (q *querier) GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]database.Organization, error) { - return fetchWithPostFilter(q.auth, q.db.GetOrganizationsByUserID)(ctx, userID) -} - -func (q *querier) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { - return insert(q.log, q.auth, rbac.ResourceOrganization, q.db.InsertOrganization)(ctx, arg) -} - -func (q *querier) InsertOrganizationMember(ctx context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { - // All roles are added roles. Org member is always implied. - addedRoles := append(arg.Roles, rbac.RoleOrgMember(arg.OrganizationID)) - err := q.canAssignRoles(ctx, &arg.OrganizationID, addedRoles, []string{}) - if err != nil { - return database.OrganizationMember{}, err - } - - obj := rbac.ResourceOrganizationMember.InOrg(arg.OrganizationID).WithID(arg.UserID) - return insert(q.log, q.auth, obj, q.db.InsertOrganizationMember)(ctx, arg) -} - -func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { - // Authorized fetch will check that the actor has read access to the org member since the org member is returned. - member, err := q.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{ - OrganizationID: arg.OrgID, - UserID: arg.UserID, - }) - if err != nil { - return database.OrganizationMember{}, err - } - - // The org member role is always implied. - impliedTypes := append(arg.GrantedRoles, rbac.RoleOrgMember(arg.OrgID)) - added, removed := rbac.ChangeRoleSet(member.Roles, impliedTypes) - err = q.canAssignRoles(ctx, &arg.OrgID, added, removed) - if err != nil { - return database.OrganizationMember{}, err - } - - return q.db.UpdateMemberRoles(ctx, arg) -} - -func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, removed []string) error { - actor, ok := ActorFromContext(ctx) - if !ok { - return NoActorError - } - - roleAssign := rbac.ResourceRoleAssignment - shouldBeOrgRoles := false - if orgID != nil { - roleAssign = roleAssign.InOrg(*orgID) - shouldBeOrgRoles = true - } - - grantedRoles := append(added, removed...) - // Validate that the roles being assigned are valid. - for _, r := range grantedRoles { - _, isOrgRole := rbac.IsOrgRole(r) - if shouldBeOrgRoles && !isOrgRole { - return xerrors.Errorf("Must only update org roles") - } - if !shouldBeOrgRoles && isOrgRole { - return xerrors.Errorf("Must only update site wide roles") - } - - // All roles should be valid roles - if _, err := rbac.RoleByName(r); err != nil { - return xerrors.Errorf("%q is not a supported role", r) - } - } - - if len(added) > 0 { - if err := q.authorizeContext(ctx, rbac.ActionCreate, roleAssign); err != nil { - return err - } - } - - if len(removed) > 0 { - if err := q.authorizeContext(ctx, rbac.ActionDelete, roleAssign); err != nil { - return err - } - } - - for _, roleName := range grantedRoles { - if !rbac.CanAssignRole(actor.Roles, roleName) { - return xerrors.Errorf("not authorized to assign role %q", roleName) - } - } - - return nil -} - -func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) { - version, err := q.db.GetTemplateVersionByJobID(ctx, jobID) - if err != nil { - return nil, err - } - object := version.RBACObjectNoTemplate() - if version.TemplateID.Valid { - tpl, err := q.db.GetTemplateByID(ctx, version.TemplateID.UUID) - if err != nil { - return nil, err - } - object = version.RBACObject(tpl) - } - - err = q.authorizeContext(ctx, rbac.ActionRead, object) - if err != nil { - return nil, err - } - return q.db.GetParameterSchemasByJobID(ctx, jobID) -} - -func (q *querier) GetPreviousTemplateVersion(ctx context.Context, arg database.GetPreviousTemplateVersionParams) (database.TemplateVersion, error) { - // An actor can read the previous template version if they can read the related template. - // If no linked template exists, we check if the actor can read *a* template. - if !arg.TemplateID.Valid { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(arg.OrganizationID)); err != nil { - return database.TemplateVersion{}, err - } - } - if _, err := q.GetTemplateByID(ctx, arg.TemplateID.UUID); err != nil { - return database.TemplateVersion{}, err - } - return q.db.GetPreviousTemplateVersion(ctx, arg) -} - -func (q *querier) GetTemplateByID(ctx context.Context, id uuid.UUID) (database.Template, error) { - return fetch(q.log, q.auth, q.db.GetTemplateByID)(ctx, id) -} - -func (q *querier) GetTemplateByOrganizationAndName(ctx context.Context, arg database.GetTemplateByOrganizationAndNameParams) (database.Template, error) { - return fetch(q.log, q.auth, q.db.GetTemplateByOrganizationAndName)(ctx, arg) -} - -func (q *querier) GetTemplateVersionByID(ctx context.Context, tvid uuid.UUID) (database.TemplateVersion, error) { - tv, err := q.db.GetTemplateVersionByID(ctx, tvid) - if err != nil { - return database.TemplateVersion{}, err - } - if !tv.TemplateID.Valid { - // If no linked template exists, check if the actor can read a template in the organization. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil { - return database.TemplateVersion{}, err - } - } else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil { - // An actor can read the template version if they can read the related template. - return database.TemplateVersion{}, err - } - return tv, nil -} - -func (q *querier) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (database.TemplateVersion, error) { - tv, err := q.db.GetTemplateVersionByJobID(ctx, jobID) - if err != nil { - return database.TemplateVersion{}, err - } - if !tv.TemplateID.Valid { - // If no linked template exists, check if the actor can read a template in the organization. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil { - return database.TemplateVersion{}, err - } - } else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil { - // An actor can read the template version if they can read the related template. - return database.TemplateVersion{}, err - } - return tv, nil -} - -func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) { - tv, err := q.db.GetTemplateVersionByTemplateIDAndName(ctx, arg) - if err != nil { - return database.TemplateVersion{}, err - } - if !tv.TemplateID.Valid { - // If no linked template exists, check if the actor can read a template in the organization. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil { - return database.TemplateVersion{}, err - } - } else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil { - // An actor can read the template version if they can read the related template. - return database.TemplateVersion{}, err - } - return tv, nil -} - -func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { - // An actor can read template version parameters if they can read the related template. - tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID) - if err != nil { - return nil, err - } - - var object rbac.Objecter - template, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - return nil, err - } - object = rbac.ResourceTemplate.InOrg(tv.OrganizationID) - } else { - object = tv.RBACObject(template) - } - - if err := q.authorizeContext(ctx, rbac.ActionRead, object); err != nil { - return nil, err - } - return q.db.GetTemplateVersionParameters(ctx, templateVersionID) -} - -func (q *querier) GetTemplateVersionVariables(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionVariable, error) { - tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID) - if err != nil { - return nil, err - } - - var object rbac.Objecter - template, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - return nil, err - } - object = rbac.ResourceTemplate.InOrg(tv.OrganizationID) - } else { - object = tv.RBACObject(template) - } - - if err := q.authorizeContext(ctx, rbac.ActionRead, object); err != nil { - return nil, err - } - return q.db.GetTemplateVersionVariables(ctx, templateVersionID) -} - -func (q *querier) GetTemplateVersionsByTemplateID(ctx context.Context, arg database.GetTemplateVersionsByTemplateIDParams) ([]database.TemplateVersion, error) { - // An actor can read template versions if they can read the related template. - template, err := q.db.GetTemplateByID(ctx, arg.TemplateID) - if err != nil { - return nil, err - } - - if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { - return nil, err - } - - return q.db.GetTemplateVersionsByTemplateID(ctx, arg) -} - -func (q *querier) GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.TemplateVersion, error) { - // An actor can read execute this query if they can read all templates. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.All()); err != nil { - return nil, err - } - return q.db.GetTemplateVersionsCreatedAfter(ctx, createdAt) -} - -func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { - // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. - return q.GetTemplatesWithFilter(ctx, arg) -} - -func (q *querier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { - prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceTemplate.Type) - if err != nil { - return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) - } - return q.db.GetAuthorizedTemplates(ctx, arg, prep) -} - -func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) (database.Template, error) { - obj := rbac.ResourceTemplate.InOrg(arg.OrganizationID) - return insert(q.log, q.auth, obj, q.db.InsertTemplate)(ctx, arg) -} - -func (q *querier) InsertTemplateVersion(ctx context.Context, arg database.InsertTemplateVersionParams) (database.TemplateVersion, error) { - if !arg.TemplateID.Valid { - // Making a new template version is the same permission as creating a new template. - err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceTemplate.InOrg(arg.OrganizationID)) - if err != nil { - return database.TemplateVersion{}, err - } - } else { - // Must do an authorized fetch to prevent leaking template ids this way. - tpl, err := q.GetTemplateByID(ctx, arg.TemplateID.UUID) - if err != nil { - return database.TemplateVersion{}, err - } - // Check the create permission on the template. - err = q.authorizeContext(ctx, rbac.ActionCreate, tpl) - if err != nil { - return database.TemplateVersion{}, err - } - } - - return q.db.InsertTemplateVersion(ctx, arg) -} - -func (q *querier) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) { - // UpdateTemplateACL uses the ActionCreate action. Only users that can create the template - // may update the ACL. - fetch := func(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) { - return q.db.GetTemplateByID(ctx, arg.ID) - } - return fetchAndQuery(q.log, q.auth, rbac.ActionCreate, fetch, q.db.UpdateTemplateACLByID)(ctx, arg) -} - -func (q *querier) UpdateTemplateActiveVersionByID(ctx context.Context, arg database.UpdateTemplateActiveVersionByIDParams) error { - fetch := func(ctx context.Context, arg database.UpdateTemplateActiveVersionByIDParams) (database.Template, error) { - return q.db.GetTemplateByID(ctx, arg.ID) - } - return update(q.log, q.auth, fetch, q.db.UpdateTemplateActiveVersionByID)(ctx, arg) -} - -func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) error { - deleteF := func(ctx context.Context, id uuid.UUID) error { - return q.db.UpdateTemplateDeletedByID(ctx, database.UpdateTemplateDeletedByIDParams{ - ID: id, - Deleted: true, - UpdatedAt: database.Now(), - }) - } - return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id) -} - -// Deprecated: use SoftDeleteTemplateByID instead. -func (q *querier) UpdateTemplateDeletedByID(ctx context.Context, arg database.UpdateTemplateDeletedByIDParams) error { - return q.SoftDeleteTemplateByID(ctx, arg.ID) -} - -func (q *querier) UpdateTemplateMetaByID(ctx context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) { - fetch := func(ctx context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) { - return q.db.GetTemplateByID(ctx, arg.ID) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateTemplateMetaByID)(ctx, arg) -} - -func (q *querier) UpdateTemplateScheduleByID(ctx context.Context, arg database.UpdateTemplateScheduleByIDParams) (database.Template, error) { - fetch := func(ctx context.Context, arg database.UpdateTemplateScheduleByIDParams) (database.Template, error) { - return q.db.GetTemplateByID(ctx, arg.ID) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateTemplateScheduleByID)(ctx, arg) -} - -func (q *querier) UpdateTemplateVersionByID(ctx context.Context, arg database.UpdateTemplateVersionByIDParams) (database.TemplateVersion, error) { - // An actor is allowed to update the template version if they are authorized to update the template. - tv, err := q.db.GetTemplateVersionByID(ctx, arg.ID) - if err != nil { - return database.TemplateVersion{}, err - } - var obj rbac.Objecter - if !tv.TemplateID.Valid { - obj = rbac.ResourceTemplate.InOrg(tv.OrganizationID) - } else { - tpl, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) - if err != nil { - return database.TemplateVersion{}, err - } - obj = tpl - } - if err := q.authorizeContext(ctx, rbac.ActionUpdate, obj); err != nil { - return database.TemplateVersion{}, err - } - return q.db.UpdateTemplateVersionByID(ctx, arg) -} - -func (q *querier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context, arg database.UpdateTemplateVersionDescriptionByJobIDParams) error { - // An actor is allowed to update the template version description if they are authorized to update the template. - tv, err := q.db.GetTemplateVersionByJobID(ctx, arg.JobID) - if err != nil { - return err - } - var obj rbac.Objecter - if !tv.TemplateID.Valid { - obj = rbac.ResourceTemplate.InOrg(tv.OrganizationID) - } else { - tpl, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) - if err != nil { - return err - } - obj = tpl - } - if err := q.authorizeContext(ctx, rbac.ActionUpdate, obj); err != nil { - return err - } - return q.db.UpdateTemplateVersionDescriptionByJobID(ctx, arg) -} - -func (q *querier) UpdateTemplateVersionGitAuthProvidersByJobID(ctx context.Context, arg database.UpdateTemplateVersionGitAuthProvidersByJobIDParams) error { - // An actor is allowed to update the template version git auth providers if they are authorized to update the template. - tv, err := q.db.GetTemplateVersionByJobID(ctx, arg.JobID) - if err != nil { - return err - } - var obj rbac.Objecter - if !tv.TemplateID.Valid { - obj = rbac.ResourceTemplate.InOrg(tv.OrganizationID) - } else { - tpl, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) - if err != nil { - return err - } - obj = tpl - } - if err := q.authorizeContext(ctx, rbac.ActionUpdate, obj); err != nil { - return err - } - return q.db.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, arg) -} - -func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - // An actor is authorized to read template group roles if they are authorized to read the template. - template, err := q.db.GetTemplateByID(ctx, id) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { - return nil, err - } - return q.db.GetTemplateGroupRoles(ctx, id) -} - -func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - // An actor is authorized to query template user roles if they are authorized to read the template. - template, err := q.db.GetTemplateByID(ctx, id) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { - return nil, err - } - return q.db.GetTemplateUserRoles(ctx, id) -} - -func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error { - // TODO: This is not 100% correct because it omits apikey IDs. - err := q.authorizeContext(ctx, rbac.ActionDelete, - rbac.ResourceAPIKey.WithOwner(userID.String())) - if err != nil { - return err - } - return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID) -} - -func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error { - // TODO: This is not 100% correct because it omits apikey IDs. - err := q.authorizeContext(ctx, rbac.ActionDelete, - rbac.ResourceAPIKey.WithOwner(userID.String())) - if err != nil { - return err - } - return q.db.DeleteAPIKeysByUserID(ctx, userID) -} - -func (q *querier) GetQuotaAllowanceForUser(ctx context.Context, userID uuid.UUID) (int64, error) { - err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID)) - if err != nil { - return -1, err - } - return q.db.GetQuotaAllowanceForUser(ctx, userID) -} - -func (q *querier) GetQuotaConsumedForUser(ctx context.Context, userID uuid.UUID) (int64, error) { - err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID)) - if err != nil { - return -1, err - } - return q.db.GetQuotaConsumedForUser(ctx, userID) -} - -func (q *querier) GetUserByEmailOrUsername(ctx context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { - return fetch(q.log, q.auth, q.db.GetUserByEmailOrUsername)(ctx, arg) -} - -func (q *querier) GetUserByID(ctx context.Context, id uuid.UUID) (database.User, error) { - return fetch(q.log, q.auth, q.db.GetUserByID)(ctx, id) -} - -// GetUsersByIDs is only used for usernames on workspace return data. -// This function should be replaced by joining this data to the workspace query -// itself. -func (q *querier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]database.User, error) { - for _, uid := range ids { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(uid)); err != nil { - return nil, err - } - } - return q.db.GetUsersByIDs(ctx, ids) -} - -func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - return q.db.GetAuthorizedUserCount(ctx, arg, prepared) -} - -func (q *querier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { - prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) - if err != nil { - return -1, xerrors.Errorf("(dev error) prepare sql filter: %w", err) - } - // TODO: This should be the only implementation. - return q.GetAuthorizedUserCount(ctx, arg, prep) -} - -func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { - // TODO: We should use GetUsersWithCount with a better method signature. - return fetchWithPostFilter(q.auth, q.db.GetUsers)(ctx, arg) -} - -func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { - // TODO Implement this with a SQL filter. The count is incorrect without it. - rowUsers, err := q.db.GetUsers(ctx, arg) - if err != nil { - return nil, -1, err - } - - if len(rowUsers) == 0 { - return []database.User{}, 0, nil - } - - act, ok := ActorFromContext(ctx) - if !ok { - return nil, -1, NoActorError - } - - // TODO: Is this correct? Should we return a restricted user? - users := database.ConvertUserRows(rowUsers) - users, err = rbac.Filter(ctx, q.auth, act, rbac.ActionRead, users) - if err != nil { - return nil, -1, err - } - - return users, rowUsers[0].Count, nil -} - -func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams) (database.User, error) { - // Always check if the assigned roles can actually be assigned by this actor. - impliedRoles := append([]string{rbac.RoleMember()}, arg.RBACRoles...) - err := q.canAssignRoles(ctx, nil, impliedRoles, []string{}) - if err != nil { - return database.User{}, err - } - obj := rbac.ResourceUser - return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg) -} - -// TODO: Should this be in system.go? -func (q *querier) InsertUserLink(ctx context.Context, arg database.InsertUserLinkParams) (database.UserLink, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceUser.WithID(arg.UserID)); err != nil { - return database.UserLink{}, err - } - return q.db.InsertUserLink(ctx, arg) -} - -func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error { - deleteF := func(ctx context.Context, id uuid.UUID) error { - return q.db.UpdateUserDeletedByID(ctx, database.UpdateUserDeletedByIDParams{ - ID: id, - Deleted: true, - }) - } - return deleteQ(q.log, q.auth, q.db.GetUserByID, deleteF)(ctx, id) -} - -// UpdateUserDeletedByID -// Deprecated: Delete this function in favor of 'SoftDeleteUserByID'. Deletes are -// irreversible. -func (q *querier) UpdateUserDeletedByID(ctx context.Context, arg database.UpdateUserDeletedByIDParams) error { - fetch := func(ctx context.Context, arg database.UpdateUserDeletedByIDParams) (database.User, error) { - return q.db.GetUserByID(ctx, arg.ID) - } - // This uses the rbac.ActionDelete action always as this function should always delete. - // We should delete this function in favor of 'SoftDeleteUserByID'. - return deleteQ(q.log, q.auth, fetch, q.db.UpdateUserDeletedByID)(ctx, arg) -} - -func (q *querier) UpdateUserHashedPassword(ctx context.Context, arg database.UpdateUserHashedPasswordParams) error { - user, err := q.db.GetUserByID(ctx, arg.ID) - if err != nil { - return err - } - - err = q.authorizeContext(ctx, rbac.ActionUpdate, user.UserDataRBACObject()) - if err != nil { - // Admins can update passwords for other users. - err = q.authorizeContext(ctx, rbac.ActionUpdate, user.RBACObject()) - if err != nil { - return err - } - } - - return q.db.UpdateUserHashedPassword(ctx, arg) -} - -func (q *querier) UpdateUserLastSeenAt(ctx context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { - fetch := func(ctx context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { - return q.db.GetUserByID(ctx, arg.ID) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateUserLastSeenAt)(ctx, arg) -} - -func (q *querier) UpdateUserProfile(ctx context.Context, arg database.UpdateUserProfileParams) (database.User, error) { - u, err := q.db.GetUserByID(ctx, arg.ID) - if err != nil { - return database.User{}, err - } - if err := q.authorizeContext(ctx, rbac.ActionUpdate, u.UserDataRBACObject()); err != nil { - return database.User{}, err - } - return q.db.UpdateUserProfile(ctx, arg) -} - -func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) { - fetch := func(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) { - return q.db.GetUserByID(ctx, arg.ID) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateUserStatus)(ctx, arg) -} - -func (q *querier) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error { - return deleteQ(q.log, q.auth, q.db.GetGitSSHKey, q.db.DeleteGitSSHKey)(ctx, userID) -} - -func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { - return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID) -} - -func (q *querier) InsertGitSSHKey(ctx context.Context, arg database.InsertGitSSHKeyParams) (database.GitSSHKey, error) { - return insert(q.log, q.auth, rbac.ResourceUserData.WithOwner(arg.UserID.String()).WithID(arg.UserID), q.db.InsertGitSSHKey)(ctx, arg) -} - -func (q *querier) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { - fetch := func(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { - return q.db.GetGitSSHKey(ctx, arg.UserID) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateGitSSHKey)(ctx, arg) -} - -func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { - return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg) -} - -func (q *querier) InsertGitAuthLink(ctx context.Context, arg database.InsertGitAuthLinkParams) (database.GitAuthLink, error) { - return insert(q.log, q.auth, rbac.ResourceUserData.WithOwner(arg.UserID.String()).WithID(arg.UserID), q.db.InsertGitAuthLink)(ctx, arg) -} - -func (q *querier) UpdateGitAuthLink(ctx context.Context, arg database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { - fetch := func(ctx context.Context, arg database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { - return q.db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateGitAuthLink)(ctx, arg) -} - -func (q *querier) UpdateUserLink(ctx context.Context, arg database.UpdateUserLinkParams) (database.UserLink, error) { - fetch := func(ctx context.Context, arg database.UpdateUserLinkParams) (database.UserLink, error) { - return q.db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ - UserID: arg.UserID, - LoginType: arg.LoginType, - }) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateUserLink)(ctx, arg) -} - -// UpdateUserRoles updates the site roles of a user. The validation for this function include more than -// just a basic RBAC check. -func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRolesParams) (database.User, error) { - // We need to fetch the user being updated to identify the change in roles. - // This requires read access on the user in question, since the user is - // returned from this function. - user, err := fetch(q.log, q.auth, q.db.GetUserByID)(ctx, arg.ID) - if err != nil { - return database.User{}, err - } - - // The member role is always implied. - impliedTypes := append(arg.GrantedRoles, rbac.RoleMember()) - // If the changeset is nothing, less rbac checks need to be done. - added, removed := rbac.ChangeRoleSet(user.RBACRoles, impliedTypes) - err = q.canAssignRoles(ctx, nil, added, removed) - if err != nil { - return database.User{}, err - } - - return q.db.UpdateUserRoles(ctx, arg) -} - -func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - // TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier. - return q.GetWorkspaces(ctx, arg) -} - -func (q *querier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.GetWorkspacesRow, error) { - prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceWorkspace.Type) - if err != nil { - return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) - } - return q.db.GetAuthorizedWorkspaces(ctx, arg, prep) -} - -func (q *querier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { - if _, err := q.GetWorkspaceByID(ctx, workspaceID); err != nil { - return database.WorkspaceBuild{}, err - } - return q.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) -} - -func (q *querier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceBuild, error) { - // This is not ideal as not all builds will be returned if the workspace cannot be read. - // This should probably be handled differently? Maybe join workspace builds with workspace - // ownership properties and filter on that. - for _, id := range ids { - _, err := q.GetWorkspaceByID(ctx, id) - if err != nil { - return nil, err - } - } - - return q.db.GetLatestWorkspaceBuildsByWorkspaceIDs(ctx, ids) -} - -func (q *querier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { - if _, err := q.GetWorkspaceByAgentID(ctx, id); err != nil { - return database.WorkspaceAgent{}, err - } - return q.db.GetWorkspaceAgentByID(ctx, id) -} - -// GetWorkspaceAgentByInstanceID might want to be a system call? Unsure exactly, -// but this will fail. Need to figure out what AuthInstanceID is, and if it -// is essentially an auth token. But the caller using this function is not -// an authenticated user. So this authz check will fail. -func (q *querier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (database.WorkspaceAgent, error) { - agent, err := q.db.GetWorkspaceAgentByInstanceID(ctx, authInstanceID) - if err != nil { - return database.WorkspaceAgent{}, err - } - _, err = q.GetWorkspaceByAgentID(ctx, agent.ID) - if err != nil { - return database.WorkspaceAgent{}, err - } - return agent, nil -} - -func (q *querier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgent, error) { - workspace, err := q.GetWorkspaceByID(ctx, workspaceID) - if err != nil { - return nil, err - } - - return q.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID) -} - -func (q *querier) UpdateWorkspaceAgentLifecycleStateByID(ctx context.Context, arg database.UpdateWorkspaceAgentLifecycleStateByIDParams) error { - agent, err := q.db.GetWorkspaceAgentByID(ctx, arg.ID) - if err != nil { - return err - } - - workspace, err := q.db.GetWorkspaceByAgentID(ctx, agent.ID) - if err != nil { - return err - } - - if err := q.authorizeContext(ctx, rbac.ActionUpdate, workspace); err != nil { - return err - } - - return q.db.UpdateWorkspaceAgentLifecycleStateByID(ctx, arg) -} - -func (q *querier) UpdateWorkspaceAgentStartupLogOverflowByID(ctx context.Context, arg database.UpdateWorkspaceAgentStartupLogOverflowByIDParams) error { - agent, err := q.db.GetWorkspaceAgentByID(ctx, arg.ID) - if err != nil { - return err - } - - workspace, err := q.db.GetWorkspaceByAgentID(ctx, agent.ID) - if err != nil { - return err - } - - if err := q.authorizeContext(ctx, rbac.ActionUpdate, workspace); err != nil { - return err - } - - return q.db.UpdateWorkspaceAgentStartupLogOverflowByID(ctx, arg) -} - -func (q *querier) UpdateWorkspaceAgentStartupByID(ctx context.Context, arg database.UpdateWorkspaceAgentStartupByIDParams) error { - agent, err := q.db.GetWorkspaceAgentByID(ctx, arg.ID) - if err != nil { - return err - } - - workspace, err := q.db.GetWorkspaceByAgentID(ctx, agent.ID) - if err != nil { - return err - } - - if err := q.authorizeContext(ctx, rbac.ActionUpdate, workspace); err != nil { - return err - } - - return q.db.UpdateWorkspaceAgentStartupByID(ctx, arg) -} - -func (q *querier) GetWorkspaceAppByAgentIDAndSlug(ctx context.Context, arg database.GetWorkspaceAppByAgentIDAndSlugParams) (database.WorkspaceApp, error) { - // If we can fetch the workspace, we can fetch the apps. Use the authorized call. - if _, err := q.GetWorkspaceByAgentID(ctx, arg.AgentID); err != nil { - return database.WorkspaceApp{}, err - } - - return q.db.GetWorkspaceAppByAgentIDAndSlug(ctx, arg) -} - -func (q *querier) GetWorkspaceAppsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.WorkspaceApp, error) { - if _, err := q.GetWorkspaceByAgentID(ctx, agentID); err != nil { - return nil, err - } - return q.db.GetWorkspaceAppsByAgentID(ctx, agentID) -} - -func (q *querier) GetWorkspaceBuildByID(ctx context.Context, buildID uuid.UUID) (database.WorkspaceBuild, error) { - build, err := q.db.GetWorkspaceBuildByID(ctx, buildID) - if err != nil { - return database.WorkspaceBuild{}, err - } - if _, err := q.GetWorkspaceByID(ctx, build.WorkspaceID); err != nil { - return database.WorkspaceBuild{}, err - } - return build, nil -} - -func (q *querier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UUID) (database.WorkspaceBuild, error) { - build, err := q.db.GetWorkspaceBuildByJobID(ctx, jobID) - if err != nil { - return database.WorkspaceBuild{}, err - } - // Authorized fetch - _, err = q.GetWorkspaceByID(ctx, build.WorkspaceID) - if err != nil { - return database.WorkspaceBuild{}, err - } - return build, nil -} - -func (q *querier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (database.WorkspaceBuild, error) { - if _, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID); err != nil { - return database.WorkspaceBuild{}, err - } - return q.db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, arg) -} - -func (q *querier) GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) { - // Authorized call to get the workspace build. If we can read the build, - // we can read the params. - _, err := q.GetWorkspaceBuildByID(ctx, workspaceBuildID) - if err != nil { - return nil, err - } - - return q.db.GetWorkspaceBuildParameters(ctx, workspaceBuildID) -} - -func (q *querier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg database.GetWorkspaceBuildsByWorkspaceIDParams) ([]database.WorkspaceBuild, error) { - if _, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID); err != nil { - return nil, err - } - return q.db.GetWorkspaceBuildsByWorkspaceID(ctx, arg) -} - -func (q *querier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { - return fetch(q.log, q.auth, q.db.GetWorkspaceByAgentID)(ctx, agentID) -} - -func (q *querier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { - return fetch(q.log, q.auth, q.db.GetWorkspaceByID)(ctx, id) -} - -func (q *querier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { - return fetch(q.log, q.auth, q.db.GetWorkspaceByOwnerIDAndName)(ctx, arg) -} - -func (q *querier) GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (database.WorkspaceResource, error) { - // TODO: Optimize this - resource, err := q.db.GetWorkspaceResourceByID(ctx, id) - if err != nil { - return database.WorkspaceResource{}, err - } - - _, err = q.GetProvisionerJobByID(ctx, resource.JobID) - if err != nil { - return database.WorkspaceResource{}, err - } - - return resource, nil -} - -func (q *querier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - job, err := q.db.GetProvisionerJobByID(ctx, jobID) - if err != nil { - return nil, err - } - var obj rbac.Objecter - switch job.Type { - case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: - // We don't need to do an authorized check, but this helper function - // handles the job type for us. - // TODO: Do not duplicate auth checks. - tv, err := authorizedTemplateVersionFromJob(ctx, q, job) - if err != nil { - return nil, err - } - if !tv.TemplateID.Valid { - // Orphaned template version - obj = tv.RBACObjectNoTemplate() - } else { - template, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) - if err != nil { - return nil, err - } - obj = template.RBACObject() - } - case database.ProvisionerJobTypeWorkspaceBuild: - build, err := q.db.GetWorkspaceBuildByJobID(ctx, jobID) - if err != nil { - return nil, err - } - workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) - if err != nil { - return nil, err - } - obj = workspace - default: - return nil, xerrors.Errorf("unknown job type: %s", job.Type) - } - - if err := q.authorizeContext(ctx, rbac.ActionRead, obj); err != nil { - return nil, err - } - return q.db.GetWorkspaceResourcesByJobID(ctx, jobID) -} - -func (q *querier) InsertWorkspace(ctx context.Context, arg database.InsertWorkspaceParams) (database.Workspace, error) { - obj := rbac.ResourceWorkspace.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID) - return insert(q.log, q.auth, obj, q.db.InsertWorkspace)(ctx, arg) -} - -func (q *querier) InsertWorkspaceBuild(ctx context.Context, arg database.InsertWorkspaceBuildParams) (database.WorkspaceBuild, error) { - w, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID) - if err != nil { - return database.WorkspaceBuild{}, err - } - - var action rbac.Action = rbac.ActionUpdate - if arg.Transition == database.WorkspaceTransitionDelete { - action = rbac.ActionDelete - } - - if err = q.authorizeContext(ctx, action, w); err != nil { - return database.WorkspaceBuild{}, err - } - - return q.db.InsertWorkspaceBuild(ctx, arg) -} - -func (q *querier) InsertWorkspaceBuildParameters(ctx context.Context, arg database.InsertWorkspaceBuildParametersParams) error { - // TODO: Optimize this. We always have the workspace and build already fetched. - build, err := q.db.GetWorkspaceBuildByID(ctx, arg.WorkspaceBuildID) - if err != nil { - return err - } - - workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) - if err != nil { - return err - } - - err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) - if err != nil { - return err - } - - return q.db.InsertWorkspaceBuildParameters(ctx, arg) -} - -func (q *querier) UpdateWorkspace(ctx context.Context, arg database.UpdateWorkspaceParams) (database.Workspace, error) { - fetch := func(ctx context.Context, arg database.UpdateWorkspaceParams) (database.Workspace, error) { - return q.db.GetWorkspaceByID(ctx, arg.ID) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateWorkspace)(ctx, arg) -} - -func (q *querier) InsertWorkspaceAgentStat(ctx context.Context, arg database.InsertWorkspaceAgentStatParams) (database.WorkspaceAgentStat, error) { - // TODO: This is a workspace agent operation. Should users be able to query this? - // Not really sure what this is for. - workspace, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID) - if err != nil { - return database.WorkspaceAgentStat{}, err - } - err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) - if err != nil { - return database.WorkspaceAgentStat{}, err - } - return q.db.InsertWorkspaceAgentStat(ctx, arg) -} - -func (q *querier) InsertWorkspaceAgentMetadata(ctx context.Context, arg database.InsertWorkspaceAgentMetadataParams) error { - // We don't check for workspace ownership here since the agent metadata may - // be associated with an orphaned agent used by a dry run build. - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return err - } - - return q.db.InsertWorkspaceAgentMetadata(ctx, arg) -} - -func (q *querier) UpdateWorkspaceAgentMetadata(ctx context.Context, arg database.UpdateWorkspaceAgentMetadataParams) error { - workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.WorkspaceAgentID) - if err != nil { - return err - } - - err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) - if err != nil { - return err - } - - return q.db.UpdateWorkspaceAgentMetadata(ctx, arg) -} - -func (q *querier) GetWorkspaceAgentMetadata(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentMetadatum, error) { - workspace, err := q.db.GetWorkspaceByAgentID(ctx, workspaceAgentID) - if err != nil { - return nil, err - } - - err = q.authorizeContext(ctx, rbac.ActionRead, workspace) - if err != nil { - return nil, err - } - - return q.db.GetWorkspaceAgentMetadata(ctx, workspaceAgentID) -} - -func (q *querier) UpdateWorkspaceAppHealthByID(ctx context.Context, arg database.UpdateWorkspaceAppHealthByIDParams) error { - // TODO: This is a workspace agent operation. Should users be able to query this? - workspace, err := q.db.GetWorkspaceByWorkspaceAppID(ctx, arg.ID) - if err != nil { - return err - } - - err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace.RBACObject()) - if err != nil { - return err - } - return q.db.UpdateWorkspaceAppHealthByID(ctx, arg) -} - -func (q *querier) UpdateWorkspaceAutostart(ctx context.Context, arg database.UpdateWorkspaceAutostartParams) error { - fetch := func(ctx context.Context, arg database.UpdateWorkspaceAutostartParams) (database.Workspace, error) { - return q.db.GetWorkspaceByID(ctx, arg.ID) - } - return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceAutostart)(ctx, arg) -} - -func (q *querier) UpdateWorkspaceBuildByID(ctx context.Context, arg database.UpdateWorkspaceBuildByIDParams) (database.WorkspaceBuild, error) { - build, err := q.db.GetWorkspaceBuildByID(ctx, arg.ID) - if err != nil { - return database.WorkspaceBuild{}, err - } - - workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) - if err != nil { - return database.WorkspaceBuild{}, err - } - err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace.RBACObject()) - if err != nil { - return database.WorkspaceBuild{}, err - } - - return q.db.UpdateWorkspaceBuildByID(ctx, arg) -} - -func (q *querier) SoftDeleteWorkspaceByID(ctx context.Context, id uuid.UUID) error { - return deleteQ(q.log, q.auth, q.db.GetWorkspaceByID, func(ctx context.Context, id uuid.UUID) error { - return q.db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ - ID: id, - Deleted: true, - }) - })(ctx, id) -} - -// Deprecated: Use SoftDeleteWorkspaceByID -func (q *querier) UpdateWorkspaceDeletedByID(ctx context.Context, arg database.UpdateWorkspaceDeletedByIDParams) error { - // TODO deleteQ me, placeholder for database.Store - fetch := func(ctx context.Context, arg database.UpdateWorkspaceDeletedByIDParams) (database.Workspace, error) { - return q.db.GetWorkspaceByID(ctx, arg.ID) - } - // This function is always used to deleteQ. - return deleteQ(q.log, q.auth, fetch, q.db.UpdateWorkspaceDeletedByID)(ctx, arg) -} - -func (q *querier) UpdateWorkspaceLastUsedAt(ctx context.Context, arg database.UpdateWorkspaceLastUsedAtParams) error { - fetch := func(ctx context.Context, arg database.UpdateWorkspaceLastUsedAtParams) (database.Workspace, error) { - return q.db.GetWorkspaceByID(ctx, arg.ID) - } - return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceLastUsedAt)(ctx, arg) -} - -func (q *querier) UpdateWorkspaceTTLToBeWithinTemplateMax(ctx context.Context, arg database.UpdateWorkspaceTTLToBeWithinTemplateMaxParams) error { - fetch := func(ctx context.Context, arg database.UpdateWorkspaceTTLToBeWithinTemplateMaxParams) (database.Template, error) { - return q.db.GetTemplateByID(ctx, arg.TemplateID) - } - return fetchAndExec(q.log, q.auth, rbac.ActionUpdate, fetch, q.db.UpdateWorkspaceTTLToBeWithinTemplateMax)(ctx, arg) -} - -func (q *querier) UpdateWorkspaceTTL(ctx context.Context, arg database.UpdateWorkspaceTTLParams) error { - fetch := func(ctx context.Context, arg database.UpdateWorkspaceTTLParams) (database.Workspace, error) { - return q.db.GetWorkspaceByID(ctx, arg.ID) - } - return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceTTL)(ctx, arg) -} - -func (q *querier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { - return fetch(q.log, q.auth, q.db.GetWorkspaceByWorkspaceAppID)(ctx, workspaceAppID) -} - -func (q *querier) GetWorkspaceProxies(ctx context.Context) ([]database.WorkspaceProxy, error) { - return fetchWithPostFilter(q.auth, func(ctx context.Context, _ interface{}) ([]database.WorkspaceProxy, error) { - return q.db.GetWorkspaceProxies(ctx) - })(ctx, nil) -} - -func (q *querier) GetWorkspaceProxyByID(ctx context.Context, id uuid.UUID) (database.WorkspaceProxy, error) { - return fetch(q.log, q.auth, q.db.GetWorkspaceProxyByID)(ctx, id) -} - -func (q *querier) GetWorkspaceProxyByName(ctx context.Context, name string) (database.WorkspaceProxy, error) { - return fetch(q.log, q.auth, q.db.GetWorkspaceProxyByName)(ctx, name) -} - -func (q *querier) InsertWorkspaceProxy(ctx context.Context, arg database.InsertWorkspaceProxyParams) (database.WorkspaceProxy, error) { - return insert(q.log, q.auth, rbac.ResourceWorkspaceProxy, q.db.InsertWorkspaceProxy)(ctx, arg) -} - -func (q *querier) UpdateWorkspaceProxy(ctx context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { - fetch := func(ctx context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { - return q.db.GetWorkspaceProxyByID(ctx, arg.ID) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateWorkspaceProxy)(ctx, arg) -} - -func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { - fetch := func(ctx context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { - return q.db.GetWorkspaceProxyByID(ctx, arg.ID) - } - return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg) -} - -func (q *querier) UpdateWorkspaceProxyDeleted(ctx context.Context, arg database.UpdateWorkspaceProxyDeletedParams) error { - fetch := func(ctx context.Context, arg database.UpdateWorkspaceProxyDeletedParams) (database.WorkspaceProxy, error) { - return q.db.GetWorkspaceProxyByID(ctx, arg.ID) - } - return deleteQ(q.log, q.auth, fetch, q.db.UpdateWorkspaceProxyDeleted)(ctx, arg) -} - -func authorizedTemplateVersionFromJob(ctx context.Context, q *querier, job database.ProvisionerJob) (database.TemplateVersion, error) { - switch job.Type { - case database.ProvisionerJobTypeTemplateVersionDryRun: - // TODO: This is really unfortunate that we need to inspect the json - // payload. We should fix this. - tmp := struct { - TemplateVersionID uuid.UUID `json:"template_version_id"` - }{} - err := json.Unmarshal(job.Input, &tmp) - if err != nil { - return database.TemplateVersion{}, xerrors.Errorf("dry-run unmarshal: %w", err) - } - // Authorized call to get template version. - tv, err := q.GetTemplateVersionByID(ctx, tmp.TemplateVersionID) - if err != nil { - return database.TemplateVersion{}, err - } - return tv, nil - case database.ProvisionerJobTypeTemplateVersionImport: - // Authorized call to get template version. - tv, err := q.GetTemplateVersionByJobID(ctx, job.ID) - if err != nil { - return database.TemplateVersion{}, err - } - return tv, nil - default: - return database.TemplateVersion{}, xerrors.Errorf("unknown job type: %q", job.Type) - } -} diff --git a/coderd/database/dbauthz/querier_test.go b/coderd/database/dbauthz/querier_test.go deleted file mode 100644 index 00355d6f98e0a..0000000000000 --- a/coderd/database/dbauthz/querier_test.go +++ /dev/null @@ -1,1155 +0,0 @@ -package dbauthz_test - -import ( - "context" - "encoding/json" - "time" - - "github.com/google/uuid" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/database/dbgen" - "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/coderd/util/slice" -) - -func (s *MethodTestSuite) TestAPIKey() { - s.Run("DeleteAPIKeyByID", s.Subtest(func(db database.Store, check *expects) { - key, _ := dbgen.APIKey(s.T(), db, database.APIKey{}) - check.Args(key.ID).Asserts(key, rbac.ActionDelete).Returns() - })) - s.Run("GetAPIKeyByID", s.Subtest(func(db database.Store, check *expects) { - key, _ := dbgen.APIKey(s.T(), db, database.APIKey{}) - check.Args(key.ID).Asserts(key, rbac.ActionRead).Returns(key) - })) - s.Run("GetAPIKeyByName", s.Subtest(func(db database.Store, check *expects) { - key, _ := dbgen.APIKey(s.T(), db, database.APIKey{ - TokenName: "marge-cat", - LoginType: database.LoginTypeToken, - }) - check.Args(database.GetAPIKeyByNameParams{ - TokenName: key.TokenName, - UserID: key.UserID, - }).Asserts(key, rbac.ActionRead).Returns(key) - })) - s.Run("GetAPIKeysByLoginType", s.Subtest(func(db database.Store, check *expects) { - a, _ := dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypePassword}) - b, _ := dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypePassword}) - _, _ = dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypeGithub}) - check.Args(database.LoginTypePassword). - Asserts(a, rbac.ActionRead, b, rbac.ActionRead). - Returns(slice.New(a, b)) - })) - s.Run("GetAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) { - idAB := uuid.New() - idC := uuid.New() - - keyA, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: idAB, LoginType: database.LoginTypeToken}) - keyB, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: idAB, LoginType: database.LoginTypeToken}) - _, _ = dbgen.APIKey(s.T(), db, database.APIKey{UserID: idC, LoginType: database.LoginTypeToken}) - - check.Args(database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: idAB}). - Asserts(keyA, rbac.ActionRead, keyB, rbac.ActionRead). - Returns(slice.New(keyA, keyB)) - })) - s.Run("GetAPIKeysLastUsedAfter", s.Subtest(func(db database.Store, check *expects) { - a, _ := dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(time.Hour)}) - b, _ := dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(time.Hour)}) - _, _ = dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(-time.Hour)}) - check.Args(time.Now()). - Asserts(a, rbac.ActionRead, b, rbac.ActionRead). - Returns(slice.New(a, b)) - })) - s.Run("InsertAPIKey", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.InsertAPIKeyParams{ - UserID: u.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }).Asserts(rbac.ResourceAPIKey.WithOwner(u.ID.String()), rbac.ActionCreate) - })) - s.Run("UpdateAPIKeyByID", s.Subtest(func(db database.Store, check *expects) { - a, _ := dbgen.APIKey(s.T(), db, database.APIKey{}) - check.Args(database.UpdateAPIKeyByIDParams{ - ID: a.ID, - }).Asserts(a, rbac.ActionUpdate).Returns() - })) -} - -func (s *MethodTestSuite) TestAuditLogs() { - s.Run("InsertAuditLog", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertAuditLogParams{ - ResourceType: database.ResourceTypeOrganization, - Action: database.AuditActionCreate, - }).Asserts(rbac.ResourceAuditLog, rbac.ActionCreate) - })) - s.Run("GetAuditLogsOffset", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.AuditLog(s.T(), db, database.AuditLog{}) - _ = dbgen.AuditLog(s.T(), db, database.AuditLog{}) - check.Args(database.GetAuditLogsOffsetParams{ - Limit: 10, - }).Asserts(rbac.ResourceAuditLog, rbac.ActionRead) - })) -} - -func (s *MethodTestSuite) TestFile() { - s.Run("GetFileByHashAndCreator", s.Subtest(func(db database.Store, check *expects) { - f := dbgen.File(s.T(), db, database.File{}) - check.Args(database.GetFileByHashAndCreatorParams{ - Hash: f.Hash, - CreatedBy: f.CreatedBy, - }).Asserts(f, rbac.ActionRead).Returns(f) - })) - s.Run("GetFileByID", s.Subtest(func(db database.Store, check *expects) { - f := dbgen.File(s.T(), db, database.File{}) - check.Args(f.ID).Asserts(f, rbac.ActionRead).Returns(f) - })) - s.Run("InsertFile", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.InsertFileParams{ - CreatedBy: u.ID, - }).Asserts(rbac.ResourceFile.WithOwner(u.ID.String()), rbac.ActionCreate) - })) -} - -func (s *MethodTestSuite) TestGroup() { - s.Run("DeleteGroupByID", s.Subtest(func(db database.Store, check *expects) { - g := dbgen.Group(s.T(), db, database.Group{}) - check.Args(g.ID).Asserts(g, rbac.ActionDelete).Returns() - })) - s.Run("DeleteGroupMemberFromGroup", s.Subtest(func(db database.Store, check *expects) { - g := dbgen.Group(s.T(), db, database.Group{}) - m := dbgen.GroupMember(s.T(), db, database.GroupMember{ - GroupID: g.ID, - }) - check.Args(database.DeleteGroupMemberFromGroupParams{ - UserID: m.UserID, - GroupID: g.ID, - }).Asserts(g, rbac.ActionUpdate).Returns() - })) - s.Run("GetGroupByID", s.Subtest(func(db database.Store, check *expects) { - g := dbgen.Group(s.T(), db, database.Group{}) - check.Args(g.ID).Asserts(g, rbac.ActionRead).Returns(g) - })) - s.Run("GetGroupByOrgAndName", s.Subtest(func(db database.Store, check *expects) { - g := dbgen.Group(s.T(), db, database.Group{}) - check.Args(database.GetGroupByOrgAndNameParams{ - OrganizationID: g.OrganizationID, - Name: g.Name, - }).Asserts(g, rbac.ActionRead).Returns(g) - })) - s.Run("GetGroupMembers", s.Subtest(func(db database.Store, check *expects) { - g := dbgen.Group(s.T(), db, database.Group{}) - _ = dbgen.GroupMember(s.T(), db, database.GroupMember{}) - check.Args(g.ID).Asserts(g, rbac.ActionRead) - })) - s.Run("InsertAllUsersGroup", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - check.Args(o.ID).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionCreate) - })) - s.Run("InsertGroup", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - check.Args(database.InsertGroupParams{ - OrganizationID: o.ID, - Name: "test", - }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionCreate) - })) - s.Run("InsertGroupMember", s.Subtest(func(db database.Store, check *expects) { - g := dbgen.Group(s.T(), db, database.Group{}) - check.Args(database.InsertGroupMemberParams{ - UserID: uuid.New(), - GroupID: g.ID, - }).Asserts(g, rbac.ActionUpdate).Returns() - })) - s.Run("InsertUserGroupsByName", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - u1 := dbgen.User(s.T(), db, database.User{}) - g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) - g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) - _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID}) - check.Args(database.InsertUserGroupsByNameParams{ - OrganizationID: o.ID, - UserID: u1.ID, - GroupNames: slice.New(g1.Name, g2.Name), - }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns() - })) - s.Run("DeleteGroupMembersByOrgAndUser", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - u1 := dbgen.User(s.T(), db, database.User{}) - g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) - g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) - _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID}) - _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g2.ID, UserID: u1.ID}) - check.Args(database.DeleteGroupMembersByOrgAndUserParams{ - OrganizationID: o.ID, - UserID: u1.ID, - }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns() - })) - s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) { - g := dbgen.Group(s.T(), db, database.Group{}) - check.Args(database.UpdateGroupByIDParams{ - ID: g.ID, - }).Asserts(g, rbac.ActionUpdate) - })) -} - -func (s *MethodTestSuite) TestProvsionerJob() { - s.Run("Build/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { - w := dbgen.Workspace(s.T(), db, database.Workspace{}) - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeWorkspaceBuild, - }) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) - check.Args(j.ID).Asserts(w, rbac.ActionRead).Returns(j) - })) - s.Run("TemplateVersion/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeTemplateVersionImport, - }) - tpl := dbgen.Template(s.T(), db, database.Template{}) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - JobID: j.ID, - }) - check.Args(j.ID).Asserts(v.RBACObject(tpl), rbac.ActionRead).Returns(j) - })) - s.Run("TemplateVersionDryRun/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { - tpl := dbgen.Template(s.T(), db, database.Template{}) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - }) - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - Input: must(json.Marshal(struct { - TemplateVersionID uuid.UUID `json:"template_version_id"` - }{TemplateVersionID: v.ID})), - }) - check.Args(j.ID).Asserts(v.RBACObject(tpl), rbac.ActionRead).Returns(j) - })) - s.Run("Build/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { - tpl := dbgen.Template(s.T(), db, database.Template{AllowUserCancelWorkspaceJobs: true}) - w := dbgen.Workspace(s.T(), db, database.Workspace{TemplateID: tpl.ID}) - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeWorkspaceBuild, - }) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) - check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}).Asserts(w, rbac.ActionUpdate).Returns() - })) - s.Run("BuildFalseCancel/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { - tpl := dbgen.Template(s.T(), db, database.Template{AllowUserCancelWorkspaceJobs: false}) - w := dbgen.Workspace(s.T(), db, database.Workspace{TemplateID: tpl.ID}) - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeWorkspaceBuild, - }) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) - check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}).Asserts(w, rbac.ActionUpdate).Returns() - })) - s.Run("TemplateVersion/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeTemplateVersionImport, - }) - tpl := dbgen.Template(s.T(), db, database.Template{}) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - JobID: j.ID, - }) - check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}). - Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns() - })) - s.Run("TemplateVersionNoTemplate/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeTemplateVersionImport, - }) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: uuid.Nil, Valid: false}, - JobID: j.ID, - }) - check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}). - Asserts(v.RBACObjectNoTemplate(), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns() - })) - s.Run("TemplateVersionDryRun/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { - tpl := dbgen.Template(s.T(), db, database.Template{}) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, - }) - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - Input: must(json.Marshal(struct { - TemplateVersionID uuid.UUID `json:"template_version_id"` - }{TemplateVersionID: v.ID})), - }) - check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}). - Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns() - })) - s.Run("GetProvisionerJobsByIDs", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) - b := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) - check.Args([]uuid.UUID{a.ID, b.ID}).Asserts().Returns(slice.New(a, b)) - })) - s.Run("GetProvisionerLogsAfterID", s.Subtest(func(db database.Store, check *expects) { - w := dbgen.Workspace(s.T(), db, database.Workspace{}) - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeWorkspaceBuild, - }) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) - check.Args(database.GetProvisionerLogsAfterIDParams{ - JobID: j.ID, - }).Asserts(w, rbac.ActionRead).Returns([]database.ProvisionerJobLog{}) - })) -} - -func (s *MethodTestSuite) TestLicense() { - s.Run("GetLicenses", s.Subtest(func(db database.Store, check *expects) { - l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ - UUID: uuid.New(), - }) - require.NoError(s.T(), err) - check.Args().Asserts(l, rbac.ActionRead). - Returns([]database.License{l}) - })) - s.Run("InsertLicense", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertLicenseParams{}). - Asserts(rbac.ResourceLicense, rbac.ActionCreate) - })) - s.Run("UpsertLogoURL", s.Subtest(func(db database.Store, check *expects) { - check.Args("value").Asserts(rbac.ResourceDeploymentValues, rbac.ActionCreate) - })) - s.Run("UpsertServiceBanner", s.Subtest(func(db database.Store, check *expects) { - check.Args("value").Asserts(rbac.ResourceDeploymentValues, rbac.ActionCreate) - })) - s.Run("GetLicenseByID", s.Subtest(func(db database.Store, check *expects) { - l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ - UUID: uuid.New(), - }) - require.NoError(s.T(), err) - check.Args(l.ID).Asserts(l, rbac.ActionRead).Returns(l) - })) - s.Run("DeleteLicense", s.Subtest(func(db database.Store, check *expects) { - l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ - UUID: uuid.New(), - }) - require.NoError(s.T(), err) - check.Args(l.ID).Asserts(l, rbac.ActionDelete) - })) - s.Run("GetDeploymentID", s.Subtest(func(db database.Store, check *expects) { - check.Args().Asserts().Returns("") - })) - s.Run("GetDefaultProxyConfig", s.Subtest(func(db database.Store, check *expects) { - check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{ - DisplayName: "Default", - IconUrl: "/emojis/1f3e1.png", - }) - })) - s.Run("GetLogoURL", s.Subtest(func(db database.Store, check *expects) { - err := db.UpsertLogoURL(context.Background(), "value") - require.NoError(s.T(), err) - check.Args().Asserts().Returns("value") - })) - s.Run("GetServiceBanner", s.Subtest(func(db database.Store, check *expects) { - err := db.UpsertServiceBanner(context.Background(), "value") - require.NoError(s.T(), err) - check.Args().Asserts().Returns("value") - })) -} - -func (s *MethodTestSuite) TestOrganization() { - s.Run("GetGroupsByOrganizationID", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - a := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) - b := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) - check.Args(o.ID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead). - Returns([]database.Group{a, b}) - })) - s.Run("GetOrganizationByID", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - check.Args(o.ID).Asserts(o, rbac.ActionRead).Returns(o) - })) - s.Run("GetOrganizationByName", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - check.Args(o.Name).Asserts(o, rbac.ActionRead).Returns(o) - })) - s.Run("GetOrganizationIDsByMemberIDs", s.Subtest(func(db database.Store, check *expects) { - oa := dbgen.Organization(s.T(), db, database.Organization{}) - ob := dbgen.Organization(s.T(), db, database.Organization{}) - ma := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{OrganizationID: oa.ID}) - mb := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{OrganizationID: ob.ID}) - check.Args([]uuid.UUID{ma.UserID, mb.UserID}). - Asserts(rbac.ResourceUser.WithID(ma.UserID), rbac.ActionRead, rbac.ResourceUser.WithID(mb.UserID), rbac.ActionRead) - })) - s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) { - mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{}) - check.Args(database.GetOrganizationMemberByUserIDParams{ - OrganizationID: mem.OrganizationID, - UserID: mem.UserID, - }).Asserts(mem, rbac.ActionRead).Returns(mem) - })) - s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID}) - b := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID}) - check.Args(u.ID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead).Returns(slice.New(a, b)) - })) - s.Run("GetOrganizations", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.Organization(s.T(), db, database.Organization{}) - b := dbgen.Organization(s.T(), db, database.Organization{}) - check.Args().Asserts(a, rbac.ActionRead, b, rbac.ActionRead).Returns(slice.New(a, b)) - })) - s.Run("GetOrganizationsByUserID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - a := dbgen.Organization(s.T(), db, database.Organization{}) - _ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: a.ID}) - b := dbgen.Organization(s.T(), db, database.Organization{}) - _ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: b.ID}) - check.Args(u.ID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead).Returns(slice.New(a, b)) - })) - s.Run("InsertOrganization", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertOrganizationParams{ - ID: uuid.New(), - Name: "random", - }).Asserts(rbac.ResourceOrganization, rbac.ActionCreate) - })) - s.Run("InsertOrganizationMember", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - u := dbgen.User(s.T(), db, database.User{}) - - check.Args(database.InsertOrganizationMemberParams{ - OrganizationID: o.ID, - UserID: u.ID, - Roles: []string{rbac.RoleOrgAdmin(o.ID)}, - }).Asserts( - rbac.ResourceRoleAssignment.InOrg(o.ID), rbac.ActionCreate, - rbac.ResourceOrganizationMember.InOrg(o.ID).WithID(u.ID), rbac.ActionCreate) - })) - s.Run("UpdateMemberRoles", s.Subtest(func(db database.Store, check *expects) { - o := dbgen.Organization(s.T(), db, database.Organization{}) - u := dbgen.User(s.T(), db, database.User{}) - mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{ - OrganizationID: o.ID, - UserID: u.ID, - Roles: []string{rbac.RoleOrgAdmin(o.ID)}, - }) - out := mem - out.Roles = []string{} - - check.Args(database.UpdateMemberRolesParams{ - GrantedRoles: []string{}, - UserID: u.ID, - OrgID: o.ID, - }).Asserts( - mem, rbac.ActionRead, - rbac.ResourceRoleAssignment.InOrg(o.ID), rbac.ActionCreate, // org-mem - rbac.ResourceRoleAssignment.InOrg(o.ID), rbac.ActionDelete, // org-admin - ).Returns(out) - })) -} - -func (s *MethodTestSuite) TestWorkspaceProxy() { - s.Run("InsertWorkspaceProxy", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceProxyParams{ - ID: uuid.New(), - }).Asserts(rbac.ResourceWorkspaceProxy, rbac.ActionCreate) - })) - s.Run("RegisterWorkspaceProxy", s.Subtest(func(db database.Store, check *expects) { - p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) - check.Args(database.RegisterWorkspaceProxyParams{ - ID: p.ID, - }).Asserts(p, rbac.ActionUpdate) - })) - s.Run("GetWorkspaceProxyByID", s.Subtest(func(db database.Store, check *expects) { - p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) - check.Args(p.ID).Asserts(p, rbac.ActionRead).Returns(p) - })) - s.Run("UpdateWorkspaceProxyDeleted", s.Subtest(func(db database.Store, check *expects) { - p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) - check.Args(database.UpdateWorkspaceProxyDeletedParams{ - ID: p.ID, - Deleted: true, - }).Asserts(p, rbac.ActionDelete) - })) - s.Run("GetWorkspaceProxies", s.Subtest(func(db database.Store, check *expects) { - p1, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) - p2, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) - check.Args().Asserts(p1, rbac.ActionRead, p2, rbac.ActionRead).Returns(slice.New(p1, p2)) - })) -} - -func (s *MethodTestSuite) TestTemplate() { - s.Run("GetPreviousTemplateVersion", s.Subtest(func(db database.Store, check *expects) { - tvid := uuid.New() - now := time.Now() - o1 := dbgen.Organization(s.T(), db, database.Organization{}) - t1 := dbgen.Template(s.T(), db, database.Template{ - OrganizationID: o1.ID, - ActiveVersionID: tvid, - }) - _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - CreatedAt: now.Add(-time.Hour), - ID: tvid, - Name: t1.Name, - OrganizationID: o1.ID, - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - b := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - CreatedAt: now.Add(-2 * time.Hour), - Name: t1.Name, - OrganizationID: o1.ID, - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - check.Args(database.GetPreviousTemplateVersionParams{ - Name: t1.Name, - OrganizationID: o1.ID, - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }).Asserts(t1, rbac.ActionRead).Returns(b) - })) - s.Run("GetTemplateByID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - check.Args(t1.ID).Asserts(t1, rbac.ActionRead).Returns(t1) - })) - s.Run("GetTemplateByOrganizationAndName", s.Subtest(func(db database.Store, check *expects) { - o1 := dbgen.Organization(s.T(), db, database.Organization{}) - t1 := dbgen.Template(s.T(), db, database.Template{ - OrganizationID: o1.ID, - }) - check.Args(database.GetTemplateByOrganizationAndNameParams{ - Name: t1.Name, - OrganizationID: o1.ID, - }).Asserts(t1, rbac.ActionRead).Returns(t1) - })) - s.Run("GetTemplateVersionByJobID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - check.Args(tv.JobID).Asserts(t1, rbac.ActionRead).Returns(tv) - })) - s.Run("GetTemplateVersionByTemplateIDAndName", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - check.Args(database.GetTemplateVersionByTemplateIDAndNameParams{ - Name: tv.Name, - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }).Asserts(t1, rbac.ActionRead).Returns(tv) - })) - s.Run("GetTemplateVersionParameters", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - check.Args(tv.ID).Asserts(t1, rbac.ActionRead).Returns([]database.TemplateVersionParameter{}) - })) - s.Run("GetTemplateVersionVariables", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - tvv1 := dbgen.TemplateVersionVariable(s.T(), db, database.TemplateVersionVariable{ - TemplateVersionID: tv.ID, - }) - check.Args(tv.ID).Asserts(t1, rbac.ActionRead).Returns([]database.TemplateVersionVariable{tvv1}) - })) - s.Run("GetTemplateGroupRoles", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - check.Args(t1.ID).Asserts(t1, rbac.ActionRead) - })) - s.Run("GetTemplateUserRoles", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - check.Args(t1.ID).Asserts(t1, rbac.ActionRead) - })) - s.Run("GetTemplateVersionByID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - check.Args(tv.ID).Asserts(t1, rbac.ActionRead).Returns(tv) - })) - s.Run("GetTemplateVersionsByTemplateID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - a := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - b := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - check.Args(database.GetTemplateVersionsByTemplateIDParams{ - TemplateID: t1.ID, - }).Asserts(t1, rbac.ActionRead). - Returns(slice.New(a, b)) - })) - s.Run("GetTemplateVersionsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - now := time.Now() - t1 := dbgen.Template(s.T(), db, database.Template{}) - _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - CreatedAt: now.Add(-time.Hour), - }) - _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - CreatedAt: now.Add(-2 * time.Hour), - }) - check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), rbac.ActionRead) - })) - s.Run("GetTemplatesWithFilter", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.Template(s.T(), db, database.Template{}) - // No asserts because SQLFilter. - check.Args(database.GetTemplatesWithFilterParams{}). - Asserts().Returns(slice.New(a)) - })) - s.Run("GetAuthorizedTemplates", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.Template(s.T(), db, database.Template{}) - // No asserts because SQLFilter. - check.Args(database.GetTemplatesWithFilterParams{}, emptyPreparedAuthorized{}). - Asserts(). - Returns(slice.New(a)) - })) - s.Run("InsertTemplate", s.Subtest(func(db database.Store, check *expects) { - orgID := uuid.New() - check.Args(database.InsertTemplateParams{ - Provisioner: "echo", - OrganizationID: orgID, - }).Asserts(rbac.ResourceTemplate.InOrg(orgID), rbac.ActionCreate) - })) - s.Run("InsertTemplateVersion", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - check.Args(database.InsertTemplateVersionParams{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - OrganizationID: t1.OrganizationID, - }).Asserts(t1, rbac.ActionRead, t1, rbac.ActionCreate) - })) - s.Run("SoftDeleteTemplateByID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - check.Args(t1.ID).Asserts(t1, rbac.ActionDelete) - })) - s.Run("UpdateTemplateACLByID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - check.Args(database.UpdateTemplateACLByIDParams{ - ID: t1.ID, - }).Asserts(t1, rbac.ActionCreate).Returns(t1) - })) - s.Run("UpdateTemplateActiveVersionByID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{ - ActiveVersionID: uuid.New(), - }) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - ID: t1.ActiveVersionID, - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - check.Args(database.UpdateTemplateActiveVersionByIDParams{ - ID: t1.ID, - ActiveVersionID: tv.ID, - }).Asserts(t1, rbac.ActionUpdate).Returns() - })) - s.Run("UpdateTemplateDeletedByID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - check.Args(database.UpdateTemplateDeletedByIDParams{ - ID: t1.ID, - Deleted: true, - }).Asserts(t1, rbac.ActionDelete).Returns() - })) - s.Run("UpdateTemplateMetaByID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - check.Args(database.UpdateTemplateMetaByIDParams{ - ID: t1.ID, - }).Asserts(t1, rbac.ActionUpdate) - })) - s.Run("UpdateTemplateVersionByID", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - check.Args(database.UpdateTemplateVersionByIDParams{ - ID: tv.ID, - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - Name: tv.Name, - UpdatedAt: tv.UpdatedAt, - }).Asserts(t1, rbac.ActionUpdate).Returns(tv) - })) - s.Run("UpdateTemplateVersionDescriptionByJobID", s.Subtest(func(db database.Store, check *expects) { - jobID := uuid.New() - t1 := dbgen.Template(s.T(), db, database.Template{}) - _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - JobID: jobID, - }) - check.Args(database.UpdateTemplateVersionDescriptionByJobIDParams{ - JobID: jobID, - Readme: "foo", - }).Asserts(t1, rbac.ActionUpdate).Returns() - })) - s.Run("UpdateTemplateVersionGitAuthProvidersByJobID", s.Subtest(func(db database.Store, check *expects) { - jobID := uuid.New() - t1 := dbgen.Template(s.T(), db, database.Template{}) - _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - JobID: jobID, - }) - check.Args(database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{ - JobID: jobID, - GitAuthProviders: []string{}, - }).Asserts(t1, rbac.ActionUpdate).Returns() - })) -} - -func (s *MethodTestSuite) TestUser() { - s.Run("DeleteAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(u.ID).Asserts(rbac.ResourceAPIKey.WithOwner(u.ID.String()), rbac.ActionDelete).Returns() - })) - s.Run("GetQuotaAllowanceForUser", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(u.ID).Asserts(u, rbac.ActionRead).Returns(int64(0)) - })) - s.Run("GetQuotaConsumedForUser", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(u.ID).Asserts(u, rbac.ActionRead).Returns(int64(0)) - })) - s.Run("GetUserByEmailOrUsername", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetUserByEmailOrUsernameParams{ - Username: u.Username, - Email: u.Email, - }).Asserts(u, rbac.ActionRead).Returns(u) - })) - s.Run("GetUserByID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(u.ID).Asserts(u, rbac.ActionRead).Returns(u) - })) - s.Run("GetUsersByIDs", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now().Add(-time.Hour)}) - b := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now()}) - check.Args([]uuid.UUID{a.ID, b.ID}). - Asserts(a, rbac.ActionRead, b, rbac.ActionRead). - Returns(slice.New(a, b)) - })) - s.Run("GetAuthorizedUserCount", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetFilteredUserCountParams{}, emptyPreparedAuthorized{}).Asserts().Returns(int64(1)) - })) - s.Run("GetFilteredUserCount", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetFilteredUserCountParams{}).Asserts().Returns(int64(1)) - })) - s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"}) - b := dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"}) - check.Args(database.GetUsersParams{}). - Asserts(a, rbac.ActionRead, b, rbac.ActionRead) - })) - s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"}) - b := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-b-user"}) - check.Args(database.GetUsersParams{}).Asserts(a, rbac.ActionRead, b, rbac.ActionRead) - })) - s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertUserParams{ - ID: uuid.New(), - LoginType: database.LoginTypePassword, - }).Asserts(rbac.ResourceRoleAssignment, rbac.ActionCreate, rbac.ResourceUser, rbac.ActionCreate) - })) - s.Run("InsertUserLink", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.InsertUserLinkParams{ - UserID: u.ID, - LoginType: database.LoginTypeOIDC, - }).Asserts(u, rbac.ActionUpdate) - })) - s.Run("SoftDeleteUserByID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(u.ID).Asserts(u, rbac.ActionDelete).Returns() - })) - s.Run("UpdateUserDeletedByID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{Deleted: true}) - check.Args(database.UpdateUserDeletedByIDParams{ - ID: u.ID, - Deleted: true, - }).Asserts(u, rbac.ActionDelete).Returns() - })) - s.Run("UpdateUserHashedPassword", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserHashedPasswordParams{ - ID: u.ID, - }).Asserts(u.UserDataRBACObject(), rbac.ActionUpdate).Returns() - })) - s.Run("UpdateUserLastSeenAt", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserLastSeenAtParams{ - ID: u.ID, - UpdatedAt: u.UpdatedAt, - LastSeenAt: u.LastSeenAt, - }).Asserts(u, rbac.ActionUpdate).Returns(u) - })) - s.Run("UpdateUserProfile", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserProfileParams{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - UpdatedAt: u.UpdatedAt, - }).Asserts(u.UserDataRBACObject(), rbac.ActionUpdate).Returns(u) - })) - s.Run("UpdateUserStatus", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.UpdateUserStatusParams{ - ID: u.ID, - Status: u.Status, - UpdatedAt: u.UpdatedAt, - }).Asserts(u, rbac.ActionUpdate).Returns(u) - })) - s.Run("DeleteGitSSHKey", s.Subtest(func(db database.Store, check *expects) { - key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) - check.Args(key.UserID).Asserts(key, rbac.ActionDelete).Returns() - })) - s.Run("GetGitSSHKey", s.Subtest(func(db database.Store, check *expects) { - key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) - check.Args(key.UserID).Asserts(key, rbac.ActionRead).Returns(key) - })) - s.Run("InsertGitSSHKey", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.InsertGitSSHKeyParams{ - UserID: u.ID, - }).Asserts(rbac.ResourceUserData.WithID(u.ID).WithOwner(u.ID.String()), rbac.ActionCreate) - })) - s.Run("UpdateGitSSHKey", s.Subtest(func(db database.Store, check *expects) { - key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) - check.Args(database.UpdateGitSSHKeyParams{ - UserID: key.UserID, - UpdatedAt: key.UpdatedAt, - }).Asserts(key, rbac.ActionUpdate).Returns(key) - })) - s.Run("GetGitAuthLink", s.Subtest(func(db database.Store, check *expects) { - link := dbgen.GitAuthLink(s.T(), db, database.GitAuthLink{}) - check.Args(database.GetGitAuthLinkParams{ - ProviderID: link.ProviderID, - UserID: link.UserID, - }).Asserts(link, rbac.ActionRead).Returns(link) - })) - s.Run("InsertGitAuthLink", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(database.InsertGitAuthLinkParams{ - ProviderID: uuid.NewString(), - UserID: u.ID, - }).Asserts(rbac.ResourceUserData.WithOwner(u.ID.String()).WithID(u.ID), rbac.ActionCreate) - })) - s.Run("UpdateGitAuthLink", s.Subtest(func(db database.Store, check *expects) { - link := dbgen.GitAuthLink(s.T(), db, database.GitAuthLink{}) - check.Args(database.UpdateGitAuthLinkParams{ - ProviderID: link.ProviderID, - UserID: link.UserID, - OAuthAccessToken: link.OAuthAccessToken, - OAuthRefreshToken: link.OAuthRefreshToken, - OAuthExpiry: link.OAuthExpiry, - UpdatedAt: link.UpdatedAt, - }).Asserts(link, rbac.ActionUpdate).Returns(link) - })) - s.Run("UpdateUserLink", s.Subtest(func(db database.Store, check *expects) { - link := dbgen.UserLink(s.T(), db, database.UserLink{}) - check.Args(database.UpdateUserLinkParams{ - OAuthAccessToken: link.OAuthAccessToken, - OAuthRefreshToken: link.OAuthRefreshToken, - OAuthExpiry: link.OAuthExpiry, - UserID: link.UserID, - LoginType: link.LoginType, - }).Asserts(link, rbac.ActionUpdate).Returns(link) - })) - s.Run("UpdateUserRoles", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{RBACRoles: []string{rbac.RoleTemplateAdmin()}}) - o := u - o.RBACRoles = []string{rbac.RoleUserAdmin()} - check.Args(database.UpdateUserRolesParams{ - GrantedRoles: []string{rbac.RoleUserAdmin()}, - ID: u.ID, - }).Asserts( - u, rbac.ActionRead, - rbac.ResourceRoleAssignment, rbac.ActionCreate, - rbac.ResourceRoleAssignment, rbac.ActionDelete, - ).Returns(o) - })) -} - -func (s *MethodTestSuite) TestWorkspace() { - s.Run("GetWorkspaceByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - check.Args(ws.ID).Asserts(ws, rbac.ActionRead) - })) - s.Run("GetWorkspaces", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.Workspace(s.T(), db, database.Workspace{}) - _ = dbgen.Workspace(s.T(), db, database.Workspace{}) - // No asserts here because SQLFilter. - check.Args(database.GetWorkspacesParams{}).Asserts() - })) - s.Run("GetAuthorizedWorkspaces", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.Workspace(s.T(), db, database.Workspace{}) - _ = dbgen.Workspace(s.T(), db, database.Workspace{}) - // No asserts here because SQLFilter. - check.Args(database.GetWorkspacesParams{}, emptyPreparedAuthorized{}).Asserts() - })) - s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) - check.Args(ws.ID).Asserts(ws, rbac.ActionRead).Returns(b) - })) - s.Run("GetLatestWorkspaceBuildsByWorkspaceIDs", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) - check.Args([]uuid.UUID{ws.ID}).Asserts(ws, rbac.ActionRead).Returns(slice.New(b)) - })) - s.Run("GetWorkspaceAgentByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(agt) - })) - s.Run("GetWorkspaceAgentByInstanceID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(agt.AuthInstanceID.String).Asserts(ws, rbac.ActionRead).Returns(agt) - })) - s.Run("UpdateWorkspaceAgentLifecycleStateByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(database.UpdateWorkspaceAgentLifecycleStateByIDParams{ - ID: agt.ID, - LifecycleState: database.WorkspaceAgentLifecycleStateCreated, - }).Asserts(ws, rbac.ActionUpdate).Returns() - })) - s.Run("UpdateWorkspaceAgentStartupLogOverflowByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(database.UpdateWorkspaceAgentStartupLogOverflowByIDParams{ - ID: agt.ID, - StartupLogsOverflowed: true, - }).Asserts(ws, rbac.ActionUpdate).Returns() - })) - s.Run("UpdateWorkspaceAgentStartupByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(database.UpdateWorkspaceAgentStartupByIDParams{ - ID: agt.ID, - Subsystem: database.WorkspaceAgentSubsystemNone, - }).Asserts(ws, rbac.ActionUpdate).Returns() - })) - s.Run("GetWorkspaceAgentStartupLogsAfter", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(database.GetWorkspaceAgentStartupLogsAfterParams{ - AgentID: agt.ID, - }).Asserts(ws, rbac.ActionRead).Returns([]database.WorkspaceAgentStartupLog{}) - })) - s.Run("GetWorkspaceAppByAgentIDAndSlug", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) - - check.Args(database.GetWorkspaceAppByAgentIDAndSlugParams{ - AgentID: agt.ID, - Slug: app.Slug, - }).Asserts(ws, rbac.ActionRead).Returns(app) - })) - s.Run("GetWorkspaceAppsByAgentID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - a := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) - b := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) - - check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(slice.New(a, b)) - })) - s.Run("GetWorkspaceBuildByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) - check.Args(build.ID).Asserts(ws, rbac.ActionRead).Returns(build) - })) - s.Run("GetWorkspaceBuildByJobID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) - check.Args(build.JobID).Asserts(ws, rbac.ActionRead).Returns(build) - })) - s.Run("GetWorkspaceBuildByWorkspaceIDAndBuildNumber", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 10}) - check.Args(database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ - WorkspaceID: ws.ID, - BuildNumber: build.BuildNumber, - }).Asserts(ws, rbac.ActionRead).Returns(build) - })) - s.Run("GetWorkspaceBuildParameters", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) - check.Args(build.ID).Asserts(ws, rbac.ActionRead). - Returns([]database.WorkspaceBuildParameter{}) - })) - s.Run("GetWorkspaceBuildsByWorkspaceID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 1}) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 2}) - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 3}) - check.Args(database.GetWorkspaceBuildsByWorkspaceIDParams{WorkspaceID: ws.ID}).Asserts(ws, rbac.ActionRead) // ordering - })) - s.Run("GetWorkspaceByAgentID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(ws) - })) - s.Run("GetWorkspaceByOwnerIDAndName", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - check.Args(database.GetWorkspaceByOwnerIDAndNameParams{ - OwnerID: ws.OwnerID, - Deleted: ws.Deleted, - Name: ws.Name, - }).Asserts(ws, rbac.ActionRead).Returns(ws) - })) - s.Run("GetWorkspaceResourceByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - _ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - check.Args(res.ID).Asserts(ws, rbac.ActionRead).Returns(res) - })) - s.Run("Build/GetWorkspaceResourcesByJobID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - job := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) - check.Args(job.ID).Asserts(ws, rbac.ActionRead).Returns([]database.WorkspaceResource{}) - })) - s.Run("Template/GetWorkspaceResourcesByJobID", s.Subtest(func(db database.Store, check *expects) { - tpl := dbgen.Template(s.T(), db, database.Template{}) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: uuid.New()}) - job := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport}) - check.Args(job.ID).Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionRead}).Returns([]database.WorkspaceResource{}) - })) - s.Run("InsertWorkspace", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - o := dbgen.Organization(s.T(), db, database.Organization{}) - check.Args(database.InsertWorkspaceParams{ - ID: uuid.New(), - OwnerID: u.ID, - OrganizationID: o.ID, - }).Asserts(rbac.ResourceWorkspace.WithOwner(u.ID.String()).InOrg(o.ID), rbac.ActionCreate) - })) - s.Run("Start/InsertWorkspaceBuild", s.Subtest(func(db database.Store, check *expects) { - w := dbgen.Workspace(s.T(), db, database.Workspace{}) - check.Args(database.InsertWorkspaceBuildParams{ - WorkspaceID: w.ID, - Transition: database.WorkspaceTransitionStart, - Reason: database.BuildReasonInitiator, - }).Asserts(w, rbac.ActionUpdate) - })) - s.Run("Delete/InsertWorkspaceBuild", s.Subtest(func(db database.Store, check *expects) { - w := dbgen.Workspace(s.T(), db, database.Workspace{}) - check.Args(database.InsertWorkspaceBuildParams{ - WorkspaceID: w.ID, - Transition: database.WorkspaceTransitionDelete, - Reason: database.BuildReasonInitiator, - }).Asserts(w, rbac.ActionDelete) - })) - s.Run("InsertWorkspaceBuildParameters", s.Subtest(func(db database.Store, check *expects) { - w := dbgen.Workspace(s.T(), db, database.Workspace{}) - b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: w.ID}) - check.Args(database.InsertWorkspaceBuildParametersParams{ - WorkspaceBuildID: b.ID, - Name: []string{"foo", "bar"}, - Value: []string{"baz", "qux"}, - }).Asserts(w, rbac.ActionUpdate) - })) - s.Run("UpdateWorkspace", s.Subtest(func(db database.Store, check *expects) { - w := dbgen.Workspace(s.T(), db, database.Workspace{}) - expected := w - expected.Name = "" - check.Args(database.UpdateWorkspaceParams{ - ID: w.ID, - }).Asserts(w, rbac.ActionUpdate).Returns(expected) - })) - s.Run("InsertWorkspaceAgentStat", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - check.Args(database.InsertWorkspaceAgentStatParams{ - WorkspaceID: ws.ID, - }).Asserts(ws, rbac.ActionUpdate) - })) - s.Run("UpdateWorkspaceAppHealthByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) - check.Args(database.UpdateWorkspaceAppHealthByIDParams{ - ID: app.ID, - Health: database.WorkspaceAppHealthDisabled, - }).Asserts(ws, rbac.ActionUpdate).Returns() - })) - s.Run("UpdateWorkspaceAutostart", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - check.Args(database.UpdateWorkspaceAutostartParams{ - ID: ws.ID, - }).Asserts(ws, rbac.ActionUpdate).Returns() - })) - s.Run("UpdateWorkspaceBuildByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - check.Args(database.UpdateWorkspaceBuildByIDParams{ - ID: build.ID, - UpdatedAt: build.UpdatedAt, - Deadline: build.Deadline, - ProvisionerState: []byte{}, - }).Asserts(ws, rbac.ActionUpdate).Returns(build) - })) - s.Run("SoftDeleteWorkspaceByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - ws.Deleted = true - check.Args(ws.ID).Asserts(ws, rbac.ActionDelete).Returns() - })) - s.Run("UpdateWorkspaceDeletedByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{Deleted: true}) - check.Args(database.UpdateWorkspaceDeletedByIDParams{ - ID: ws.ID, - Deleted: true, - }).Asserts(ws, rbac.ActionDelete).Returns() - })) - s.Run("UpdateWorkspaceLastUsedAt", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - check.Args(database.UpdateWorkspaceLastUsedAtParams{ - ID: ws.ID, - }).Asserts(ws, rbac.ActionUpdate).Returns() - })) - s.Run("UpdateWorkspaceTTL", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - check.Args(database.UpdateWorkspaceTTLParams{ - ID: ws.ID, - }).Asserts(ws, rbac.ActionUpdate).Returns() - })) - s.Run("GetWorkspaceByWorkspaceAppID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) - check.Args(app.ID).Asserts(ws, rbac.ActionRead).Returns(ws) - })) -} - -func (s *MethodTestSuite) TestExtraMethods() { - s.Run("GetProvisionerDaemons", s.Subtest(func(db database.Store, check *expects) { - d, err := db.InsertProvisionerDaemon(context.Background(), database.InsertProvisionerDaemonParams{ - ID: uuid.New(), - }) - s.NoError(err, "insert provisioner daemon") - check.Args().Asserts(d, rbac.ActionRead) - })) -} diff --git a/coderd/database/dbauthz/system.go b/coderd/database/dbauthz/system.go deleted file mode 100644 index f1ca0a686e29a..0000000000000 --- a/coderd/database/dbauthz/system.go +++ /dev/null @@ -1,440 +0,0 @@ -package dbauthz - -import ( - "context" - "time" - - "github.com/google/uuid" - - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/rbac" -) - -func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]database.GetFileTemplatesRow, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetFileTemplates(ctx, fileID) -} - -// GetWorkspaceAppsByAgentIDs -// The workspace/job is already fetched. -func (q *querier) GetWorkspaceAppsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetWorkspaceAppsByAgentIDs(ctx, ids) -} - -// GetWorkspaceAgentsByResourceIDs -// The workspace/job is already fetched. -func (q *querier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceAgent, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetWorkspaceAgentsByResourceIDs(ctx, ids) -} - -// GetWorkspaceResourceMetadataByResourceIDs is only used for build data. -// The workspace/job is already fetched. -func (q *querier) GetWorkspaceResourceMetadataByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceResourceMetadatum, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetWorkspaceResourceMetadataByResourceIDs(ctx, ids) -} - -// TODO: we need to add a provisioner job resource -func (q *querier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) { - // if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - // return nil, err - // } - return q.db.GetProvisionerJobsByIDs(ctx, ids) -} - -// GetTemplateVersionsByIDs is only used for workspace build data. -// The workspace is already fetched. -func (q *querier) GetTemplateVersionsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.TemplateVersion, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetTemplateVersionsByIDs(ctx, ids) -} - -// GetWorkspaceResourcesByJobIDs is only used for workspace build data. -// The workspace is already fetched. -// TODO: Find a way to replace this with proper authz. -func (q *querier) GetWorkspaceResourcesByJobIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceResource, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetWorkspaceResourcesByJobIDs(ctx, ids) -} - -func (q *querier) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUserLinkedIDParams) (database.UserLink, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { - return database.UserLink{}, err - } - return q.db.UpdateUserLinkedID(ctx, arg) -} - -func (q *querier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return database.UserLink{}, err - } - return q.db.GetUserLinkByLinkedID(ctx, linkedID) -} - -func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return database.UserLink{}, err - } - return q.db.GetUserLinkByUserIDLoginType(ctx, arg) -} - -func (q *querier) GetLatestWorkspaceBuilds(ctx context.Context) ([]database.WorkspaceBuild, error) { - // This function is a system function until we implement a join for workspace builds. - // This is because we need to query for all related workspaces to the returned builds. - // This is a very inefficient method of fetching the latest workspace builds. - // We should just join the rbac properties. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetLatestWorkspaceBuilds(ctx) -} - -// GetWorkspaceAgentByAuthToken is used in http middleware to get the workspace agent. -// This should only be used by a system user in that middleware. -func (q *querier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return database.WorkspaceAgent{}, err - } - return q.db.GetWorkspaceAgentByAuthToken(ctx, authToken) -} - -func (q *querier) GetActiveUserCount(ctx context.Context) (int64, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return 0, err - } - return q.db.GetActiveUserCount(ctx) -} - -func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetUnexpiredLicenses(ctx) -} - -func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return database.GetAuthorizationUserRolesRow{}, err - } - return q.db.GetAuthorizationUserRoles(ctx, userID) -} - -func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return "", err - } - return q.db.GetDERPMeshKey(ctx) -} - -func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return err - } - return q.db.InsertDERPMeshKey(ctx, value) -} - -func (q *querier) InsertDeploymentID(ctx context.Context, value string) error { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return err - } - return q.db.InsertDeploymentID(ctx, value) -} - -func (q *querier) InsertReplica(ctx context.Context, arg database.InsertReplicaParams) (database.Replica, error) { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return database.Replica{}, err - } - return q.db.InsertReplica(ctx, arg) -} - -func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaParams) (database.Replica, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { - return database.Replica{}, err - } - return q.db.UpdateReplica(ctx, arg) -} - -func (q *querier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error { - if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceSystem); err != nil { - return err - } - return q.db.DeleteReplicasUpdatedBefore(ctx, updatedAt) -} - -func (q *querier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.Replica, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetReplicasUpdatedAfter(ctx, updatedAt) -} - -func (q *querier) GetUserCount(ctx context.Context) (int64, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return 0, err - } - return q.db.GetUserCount(ctx) -} - -func (q *querier) GetTemplates(ctx context.Context) ([]database.Template, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetTemplates(ctx) -} - -// Only used by metrics cache. -func (q *querier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return database.GetTemplateAverageBuildTimeRow{}, err - } - return q.db.GetTemplateAverageBuildTime(ctx, arg) -} - -// Only used by metrics cache. -func (q *querier) GetTemplateDAUs(ctx context.Context, arg database.GetTemplateDAUsParams) ([]database.GetTemplateDAUsRow, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetTemplateDAUs(ctx, arg) -} - -// Only used by metrics cache. -func (q *querier) GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]database.GetDeploymentDAUsRow, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetDeploymentDAUs(ctx, tzOffset) -} - -// UpdateWorkspaceBuildCostByID is used by the provisioning system to update the cost of a workspace build. -func (q *querier) UpdateWorkspaceBuildCostByID(ctx context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) (database.WorkspaceBuild, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { - return database.WorkspaceBuild{}, err - } - return q.db.UpdateWorkspaceBuildCostByID(ctx, arg) -} - -func (q *querier) UpsertLastUpdateCheck(ctx context.Context, value string) error { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { - return err - } - return q.db.UpsertLastUpdateCheck(ctx, value) -} - -func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return "", err - } - return q.db.GetLastUpdateCheck(ctx) -} - -// Telemetry related functions. These functions are system functions for returning -// telemetry data. Never called by a user. - -func (q *querier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceBuild, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetWorkspaceBuildsCreatedAfter(ctx, createdAt) -} - -func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetWorkspaceAgentsCreatedAfter(ctx, createdAt) -} - -func (q *querier) GetWorkspaceAppsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceApp, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetWorkspaceAppsCreatedAfter(ctx, createdAt) -} - -func (q *querier) GetWorkspaceResourcesCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResource, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetWorkspaceResourcesCreatedAfter(ctx, createdAt) -} - -func (q *querier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResourceMetadatum, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.GetWorkspaceResourceMetadataCreatedAfter(ctx, createdAt) -} - -func (q *querier) DeleteOldWorkspaceAgentStats(ctx context.Context) error { - if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceSystem); err != nil { - return err - } - return q.db.DeleteOldWorkspaceAgentStats(ctx) -} - -func (q *querier) DeleteOldWorkspaceAgentStartupLogs(ctx context.Context) error { - if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceSystem); err != nil { - return err - } - return q.db.DeleteOldWorkspaceAgentStartupLogs(ctx) -} - -func (q *querier) GetDeploymentWorkspaceAgentStats(ctx context.Context, createdAfter time.Time) (database.GetDeploymentWorkspaceAgentStatsRow, error) { - return q.db.GetDeploymentWorkspaceAgentStats(ctx, createdAfter) -} - -func (q *querier) GetWorkspaceAgentStats(ctx context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsRow, error) { - return q.db.GetWorkspaceAgentStats(ctx, createdAfter) -} - -func (q *querier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsAndLabelsRow, error) { - return q.db.GetWorkspaceAgentStatsAndLabels(ctx, createdAfter) -} - -func (q *querier) GetDeploymentWorkspaceStats(ctx context.Context) (database.GetDeploymentWorkspaceStatsRow, error) { - return q.db.GetDeploymentWorkspaceStats(ctx) -} - -func (q *querier) GetWorkspacesEligibleForAutoStartStop(ctx context.Context, now time.Time) ([]database.Workspace, error) { - return q.db.GetWorkspacesEligibleForAutoStartStop(ctx, now) -} - -// TODO: We need to create a ProvisionerJob resource type -func (q *querier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ProvisionerJob, error) { - // if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - // return nil, err - // } - return q.db.GetProvisionerJobsCreatedAfter(ctx, createdAt) -} - -// Provisionerd server functions - -func (q *querier) InsertWorkspaceAgent(ctx context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return database.WorkspaceAgent{}, err - } - return q.db.InsertWorkspaceAgent(ctx, arg) -} - -func (q *querier) InsertWorkspaceApp(ctx context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return database.WorkspaceApp{}, err - } - return q.db.InsertWorkspaceApp(ctx, arg) -} - -func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return nil, err - } - return q.db.InsertWorkspaceResourceMetadata(ctx, arg) -} - -func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg database.UpdateWorkspaceAgentConnectionByIDParams) error { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { - return err - } - return q.db.UpdateWorkspaceAgentConnectionByID(ctx, arg) -} - -// TODO: We need to create a ProvisionerJob resource type -func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { - // if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { - // return database.ProvisionerJob{}, err - // } - return q.db.AcquireProvisionerJob(ctx, arg) -} - -// TODO: We need to create a ProvisionerJob resource type -func (q *querier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error { - // if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { - // return err - // } - return q.db.UpdateProvisionerJobWithCompleteByID(ctx, arg) -} - -// TODO: We need to create a ProvisionerJob resource type -func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error { - // if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { - // return err - // } - return q.db.UpdateProvisionerJobByID(ctx, arg) -} - -// TODO: We need to create a ProvisionerJob resource type -func (q *querier) InsertProvisionerJob(ctx context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { - // if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - // return database.ProvisionerJob{}, err - // } - return q.db.InsertProvisionerJob(ctx, arg) -} - -// TODO: We need to create a ProvisionerJob resource type -func (q *querier) InsertProvisionerJobLogs(ctx context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) { - // if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - // return nil, err - // } - return q.db.InsertProvisionerJobLogs(ctx, arg) -} - -func (q *querier) InsertWorkspaceAgentStartupLogs(ctx context.Context, arg database.InsertWorkspaceAgentStartupLogsParams) ([]database.WorkspaceAgentStartupLog, error) { - return q.db.InsertWorkspaceAgentStartupLogs(ctx, arg) -} - -// TODO: We need to create a ProvisionerDaemon resource type -func (q *querier) InsertProvisionerDaemon(ctx context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { - // if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - // return database.ProvisionerDaemon{}, err - // } - return q.db.InsertProvisionerDaemon(ctx, arg) -} - -func (q *querier) InsertTemplateVersionParameter(ctx context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return database.TemplateVersionParameter{}, err - } - return q.db.InsertTemplateVersionParameter(ctx, arg) -} - -func (q *querier) InsertTemplateVersionVariable(ctx context.Context, arg database.InsertTemplateVersionVariableParams) (database.TemplateVersionVariable, error) { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return database.TemplateVersionVariable{}, err - } - return q.db.InsertTemplateVersionVariable(ctx, arg) -} - -func (q *querier) InsertWorkspaceResource(ctx context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { - if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { - return database.WorkspaceResource{}, err - } - return q.db.InsertWorkspaceResource(ctx, arg) -} - -func (q *querier) GetWorkspaceProxyByHostname(ctx context.Context, params database.GetWorkspaceProxyByHostnameParams) (database.WorkspaceProxy, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return database.WorkspaceProxy{}, err - } - return q.db.GetWorkspaceProxyByHostname(ctx, params) -} - -func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { - return err - } - return q.db.UpsertDefaultProxy(ctx, arg) -} diff --git a/coderd/database/dbauthz/system_test.go b/coderd/database/dbauthz/system_test.go deleted file mode 100644 index 98f9e493e2177..0000000000000 --- a/coderd/database/dbauthz/system_test.go +++ /dev/null @@ -1,301 +0,0 @@ -package dbauthz_test - -import ( - "context" - "database/sql" - "encoding/json" - "time" - - "github.com/google/uuid" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/database/dbgen" - "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/coderd/util/slice" -) - -func (s *MethodTestSuite) TestSystemFunctions() { - s.Run("UpdateUserLinkedID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - l := dbgen.UserLink(s.T(), db, database.UserLink{UserID: u.ID}) - check.Args(database.UpdateUserLinkedIDParams{ - UserID: u.ID, - LinkedID: l.LinkedID, - LoginType: database.LoginTypeGithub, - }).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns(l) - })) - s.Run("UpsertDefaultProxy", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.UpsertDefaultProxyParams{}).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns() - })) - s.Run("GetUserLinkByLinkedID", s.Subtest(func(db database.Store, check *expects) { - l := dbgen.UserLink(s.T(), db, database.UserLink{}) - check.Args(l.LinkedID).Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(l) - })) - s.Run("GetUserLinkByUserIDLoginType", s.Subtest(func(db database.Store, check *expects) { - l := dbgen.UserLink(s.T(), db, database.UserLink{}) - check.Args(database.GetUserLinkByUserIDLoginTypeParams{ - UserID: l.UserID, - LoginType: l.LoginType, - }).Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(l) - })) - s.Run("GetLatestWorkspaceBuilds", s.Subtest(func(db database.Store, check *expects) { - dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) - dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) - check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetWorkspaceAgentByAuthToken", s.Subtest(func(db database.Store, check *expects) { - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{}) - check.Args(agt.AuthToken).Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(agt) - })) - s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) { - check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(int64(0)) - })) - s.Run("GetUnexpiredLicenses", s.Subtest(func(db database.Store, check *expects) { - check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetAuthorizationUserRoles", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - check.Args(u.ID).Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetDERPMeshKey", s.Subtest(func(db database.Store, check *expects) { - check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("InsertDERPMeshKey", s.Subtest(func(db database.Store, check *expects) { - check.Args("value").Asserts(rbac.ResourceSystem, rbac.ActionCreate).Returns() - })) - s.Run("InsertDeploymentID", s.Subtest(func(db database.Store, check *expects) { - check.Args("value").Asserts(rbac.ResourceSystem, rbac.ActionCreate).Returns() - })) - s.Run("InsertReplica", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertReplicaParams{ - ID: uuid.New(), - }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) - })) - s.Run("UpdateReplica", s.Subtest(func(db database.Store, check *expects) { - replica, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New()}) - require.NoError(s.T(), err) - check.Args(database.UpdateReplicaParams{ - ID: replica.ID, - DatabaseLatency: 100, - }).Asserts(rbac.ResourceSystem, rbac.ActionUpdate) - })) - s.Run("DeleteReplicasUpdatedBefore", s.Subtest(func(db database.Store, check *expects) { - _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()}) - require.NoError(s.T(), err) - check.Args(time.Now().Add(time.Hour)).Asserts(rbac.ResourceSystem, rbac.ActionDelete) - })) - s.Run("GetReplicasUpdatedAfter", s.Subtest(func(db database.Store, check *expects) { - _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()}) - require.NoError(s.T(), err) - check.Args(time.Now().Add(time.Hour*-1)).Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetUserCount", s.Subtest(func(db database.Store, check *expects) { - check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(int64(0)) - })) - s.Run("GetTemplates", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.Template(s.T(), db, database.Template{}) - check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("UpdateWorkspaceBuildCostByID", s.Subtest(func(db database.Store, check *expects) { - b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) - o := b - o.DailyCost = 10 - check.Args(database.UpdateWorkspaceBuildCostByIDParams{ - ID: b.ID, - DailyCost: 10, - }).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns(o) - })) - s.Run("UpsertLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) { - check.Args("value").Asserts(rbac.ResourceSystem, rbac.ActionUpdate) - })) - s.Run("GetLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) { - err := db.UpsertLastUpdateCheck(context.Background(), "value") - require.NoError(s.T(), err) - check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetWorkspaceBuildsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{CreatedAt: time.Now().Add(-time.Hour)}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetWorkspaceAgentsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{CreatedAt: time.Now().Add(-time.Hour)}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetWorkspaceAppsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{CreatedAt: time.Now().Add(-time.Hour)}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetWorkspaceResourcesCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{CreatedAt: time.Now().Add(-time.Hour)}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetWorkspaceResourceMetadataCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.WorkspaceResourceMetadatums(s.T(), db, database.WorkspaceResourceMetadatum{}) - check.Args(time.Now()).Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("DeleteOldWorkspaceAgentStats", s.Subtest(func(db database.Store, check *expects) { - check.Args().Asserts(rbac.ResourceSystem, rbac.ActionDelete) - })) - s.Run("GetProvisionerJobsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { - // TODO: add provisioner job resource type - _ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{CreatedAt: time.Now().Add(-time.Hour)}) - check.Args(time.Now()).Asserts( /*rbac.ResourceSystem, rbac.ActionRead*/ ) - })) - s.Run("GetTemplateVersionsByIDs", s.Subtest(func(db database.Store, check *expects) { - t1 := dbgen.Template(s.T(), db, database.Template{}) - t2 := dbgen.Template(s.T(), db, database.Template{}) - tv1 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, - }) - tv2 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t2.ID, Valid: true}, - }) - tv3 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{UUID: t2.ID, Valid: true}, - }) - check.Args([]uuid.UUID{tv1.ID, tv2.ID, tv3.ID}). - Asserts(rbac.ResourceSystem, rbac.ActionRead). - Returns(slice.New(tv1, tv2, tv3)) - })) - s.Run("GetWorkspaceAppsByAgentIDs", s.Subtest(func(db database.Store, check *expects) { - aWs := dbgen.Workspace(s.T(), db, database.Workspace{}) - aBuild := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: aWs.ID, JobID: uuid.New()}) - aRes := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: aBuild.JobID}) - aAgt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: aRes.ID}) - a := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: aAgt.ID}) - - bWs := dbgen.Workspace(s.T(), db, database.Workspace{}) - bBuild := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: bWs.ID, JobID: uuid.New()}) - bRes := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: bBuild.JobID}) - bAgt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: bRes.ID}) - b := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: bAgt.ID}) - - check.Args([]uuid.UUID{a.AgentID, b.AgentID}). - Asserts(rbac.ResourceSystem, rbac.ActionRead). - Returns([]database.WorkspaceApp{a, b}) - })) - s.Run("GetWorkspaceResourcesByJobIDs", s.Subtest(func(db database.Store, check *expects) { - tpl := dbgen.Template(s.T(), db, database.Template{}) - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: uuid.New()}) - tJob := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport}) - - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - wJob := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) - check.Args([]uuid.UUID{tJob.ID, wJob.ID}). - Asserts(rbac.ResourceSystem, rbac.ActionRead). - Returns([]database.WorkspaceResource{}) - })) - s.Run("GetWorkspaceResourceMetadataByResourceIDs", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - _ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) - a := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - b := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - check.Args([]uuid.UUID{a.ID, b.ID}). - Asserts(rbac.ResourceSystem, rbac.ActionRead) - })) - s.Run("GetWorkspaceAgentsByResourceIDs", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args([]uuid.UUID{res.ID}). - Asserts(rbac.ResourceSystem, rbac.ActionRead). - Returns([]database.WorkspaceAgent{agt}) - })) - s.Run("GetProvisionerJobsByIDs", s.Subtest(func(db database.Store, check *expects) { - // TODO: add a ProvisionerJob resource type - a := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) - b := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) - check.Args([]uuid.UUID{a.ID, b.ID}). - Asserts( /*rbac.ResourceSystem, rbac.ActionRead*/ ). - Returns(slice.New(a, b)) - })) - s.Run("InsertWorkspaceAgent", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceAgentParams{ - ID: uuid.New(), - StartupScriptBehavior: database.StartupScriptBehaviorNonBlocking, - }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) - })) - s.Run("InsertWorkspaceApp", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceAppParams{ - ID: uuid.New(), - Health: database.WorkspaceAppHealthDisabled, - SharingLevel: database.AppSharingLevelOwner, - }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) - })) - s.Run("InsertWorkspaceResourceMetadata", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceResourceMetadataParams{ - WorkspaceResourceID: uuid.New(), - }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) - })) - s.Run("UpdateWorkspaceAgentConnectionByID", s.Subtest(func(db database.Store, check *expects) { - ws := dbgen.Workspace(s.T(), db, database.Workspace{}) - build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) - res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) - check.Args(database.UpdateWorkspaceAgentConnectionByIDParams{ - ID: agt.ID, - }).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns() - })) - s.Run("AcquireProvisionerJob", s.Subtest(func(db database.Store, check *expects) { - // TODO: we need to create a ProvisionerJob resource - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ - StartedAt: sql.NullTime{Valid: false}, - }) - check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}). - Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ ) - })) - s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) { - // TODO: we need to create a ProvisionerJob resource - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) - check.Args(database.UpdateProvisionerJobWithCompleteByIDParams{ - ID: j.ID, - }).Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ ) - })) - s.Run("UpdateProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { - // TODO: we need to create a ProvisionerJob resource - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) - check.Args(database.UpdateProvisionerJobByIDParams{ - ID: j.ID, - UpdatedAt: time.Now(), - }).Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ ) - })) - s.Run("InsertProvisionerJob", s.Subtest(func(db database.Store, check *expects) { - // TODO: we need to create a ProvisionerJob resource - check.Args(database.InsertProvisionerJobParams{ - ID: uuid.New(), - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - Type: database.ProvisionerJobTypeWorkspaceBuild, - }).Asserts( /*rbac.ResourceSystem, rbac.ActionCreate*/ ) - })) - s.Run("InsertProvisionerJobLogs", s.Subtest(func(db database.Store, check *expects) { - // TODO: we need to create a ProvisionerJob resource - j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) - check.Args(database.InsertProvisionerJobLogsParams{ - JobID: j.ID, - }).Asserts( /*rbac.ResourceSystem, rbac.ActionCreate*/ ) - })) - s.Run("InsertProvisionerDaemon", s.Subtest(func(db database.Store, check *expects) { - // TODO: we need to create a ProvisionerDaemon resource - check.Args(database.InsertProvisionerDaemonParams{ - ID: uuid.New(), - }).Asserts( /*rbac.ResourceSystem, rbac.ActionCreate*/ ) - })) - s.Run("InsertTemplateVersionParameter", s.Subtest(func(db database.Store, check *expects) { - v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{}) - check.Args(database.InsertTemplateVersionParameterParams{ - TemplateVersionID: v.ID, - }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) - })) - s.Run("InsertWorkspaceResource", s.Subtest(func(db database.Store, check *expects) { - r := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{}) - check.Args(database.InsertWorkspaceResourceParams{ - ID: r.ID, - Transition: database.WorkspaceTransitionStart, - }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) - })) -} diff --git a/coderd/database/gen/authz/main.go b/coderd/database/gen/authz/main.go new file mode 100644 index 0000000000000..2c781faedfeab --- /dev/null +++ b/coderd/database/gen/authz/main.go @@ -0,0 +1,199 @@ +package main + +import ( + "go/format" + "go/token" + "log" + "os" + + "github.com/dave/dst" + "github.com/dave/dst/decorator" + "github.com/dave/dst/decorator/resolver/goast" + "github.com/dave/dst/decorator/resolver/guess" + "golang.org/x/xerrors" +) + +func main() { + err := run() + if err != nil { + log.Fatal(err) + } +} + +func run() error { + funcs, err := readStoreInterface() + if err != nil { + return err + } + funcByName := map[string]struct{}{} + for _, f := range funcs { + funcByName[f.Name] = struct{}{} + } + declByName := map[string]*dst.FuncDecl{} + + dbauthz, err := os.ReadFile("./dbauthz/dbauthz.go") + if err != nil { + return xerrors.Errorf("read dbauthz: %w", err) + } + + // Required to preserve imports! + f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), "dbauthz", goast.New()).Parse(dbauthz) + if err != nil { + return xerrors.Errorf("parse dbauthz: %w", err) + } + + for i := 0; i < len(f.Decls); i++ { + funcDecl, ok := f.Decls[i].(*dst.FuncDecl) + if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 { + continue + } + // Check if the receiver is the struct we're interested in + starExpr, ok := funcDecl.Recv.List[0].Type.(*dst.StarExpr) + if !ok { + continue + } + ident, ok := starExpr.X.(*dst.Ident) + if !ok || ident.Name != "querier" { + continue + } + if _, ok := funcByName[funcDecl.Name.Name]; !ok { + continue + } + declByName[funcDecl.Name.Name] = funcDecl + f.Decls = append(f.Decls[:i], f.Decls[i+1:]...) + i-- + } + + for _, fn := range funcs { + decl, ok := declByName[fn.Name] + if !ok { + // Not implemented! + decl = &dst.FuncDecl{ + Name: dst.NewIdent(fn.Name), + Type: &dst.FuncType{ + Func: true, + TypeParams: fn.Func.TypeParams, + Params: fn.Func.Params, + Results: fn.Func.Results, + Decs: fn.Func.Decs, + }, + Recv: &dst.FieldList{ + List: []*dst.Field{{ + Names: []*dst.Ident{dst.NewIdent("q")}, + Type: dst.NewIdent("*querier"), + }}, + }, + Decs: dst.FuncDeclDecorations{ + NodeDecs: dst.NodeDecs{ + Before: dst.EmptyLine, + After: dst.EmptyLine, + }, + }, + Body: &dst.BlockStmt{ + List: []dst.Stmt{ + &dst.ExprStmt{ + X: &dst.CallExpr{ + Fun: &dst.Ident{ + Name: "panic", + }, + Args: []dst.Expr{ + &dst.BasicLit{ + Kind: token.STRING, + Value: "\"Not implemented\"", + }, + }, + }, + }, + }, + }, + } + } + f.Decls = append(f.Decls, decl) + } + + file, err := os.OpenFile("./dbauthz/dbauthz.go", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755) + if err != nil { + return xerrors.Errorf("open dbauthz: %w", err) + } + defer file.Close() + + // Required to preserve imports! + restorer := decorator.NewRestorerWithImports("dbauthz", guess.New()) + restored, err := restorer.RestoreFile(f) + if err != nil { + return xerrors.Errorf("restore dbauthz: %w", err) + } + err = format.Node(file, restorer.Fset, restored) + return err +} + +type storeMethod struct { + Name string + Func *dst.FuncType +} + +func readStoreInterface() ([]storeMethod, error) { + querier, err := os.ReadFile("./querier.go") + if err != nil { + return nil, xerrors.Errorf("read querier: %w", err) + } + f, err := decorator.Parse(querier) + if err != nil { + return nil, err + } + + var sqlcQuerier *dst.InterfaceType + for _, decl := range f.Decls { + genDecl, ok := decl.(*dst.GenDecl) + if !ok { + continue + } + + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*dst.TypeSpec) + if !ok { + continue + } + if typeSpec.Name.Name != "sqlcQuerier" { + continue + } + sqlcQuerier, ok = typeSpec.Type.(*dst.InterfaceType) + if !ok { + return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type) + } + break + } + } + if sqlcQuerier == nil { + return nil, xerrors.Errorf("sqlcQuerier not found") + } + funcs := []storeMethod{} + for _, method := range sqlcQuerier.Methods.List { + funcType, ok := method.Type.(*dst.FuncType) + if !ok { + continue + } + + for _, t := range []*dst.FieldList{funcType.Params, funcType.Results} { + if t == nil { + continue + } + for _, f := range t.List { + ident, ok := f.Type.(*dst.Ident) + if !ok { + continue + } + if !ident.IsExported() { + continue + } + ident.Path = "github.com/coder/coder/coderd/database" + } + } + + funcs = append(funcs, storeMethod{ + Name: method.Names[0].Name, + Func: funcType, + }) + } + return funcs, nil +} diff --git a/coderd/database/generate.sh b/coderd/database/generate.sh index c17bf695461bf..f94ba151c07a9 100755 --- a/coderd/database/generate.sh +++ b/coderd/database/generate.sh @@ -63,6 +63,9 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") go run gen/fake/main.go go run golang.org/x/tools/cmd/goimports@latest -w ./dbfake/dbfake.go + go run gen/authz/main.go + go run golang.org/x/tools/cmd/goimports@latest -w ./dbauthz/dbauthz.go + go run gen/metrics/main.go go run golang.org/x/tools/cmd/goimports@latest -w ./dbmetrics/dbmetrics.go )