From 04a2cae0b598f064d10e3f3990c3716512077bdf Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 11 Jul 2023 18:56:12 -0400 Subject: [PATCH 1/5] chore: add owner to resourceUser rbac object --- coderd/coderdtest/authorize.go | 2 +- coderd/database/dbauthz/dbauthz.go | 8 ++++---- coderd/database/dbauthz/dbauthz_test.go | 2 +- coderd/database/modelmethods.go | 6 +++--- coderd/database/modelqueries.go | 5 ++++- coderd/rbac/regosql/configs.go | 16 ++++++++++++++++ coderd/rbac/roles.go | 14 ++++++++++---- coderd/rbac/roles_test.go | 2 +- 8 files changed, 40 insertions(+), 15 deletions(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 5e3918e7e6f02..83a3cd40d649e 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -116,7 +116,7 @@ func (RBACAsserter) convertObjects(t *testing.T, objs ...interface{}) []rbac.Obj case codersdk.TemplateVersion: robj = rbac.ResourceTemplate.InOrg(obj.OrganizationID) case codersdk.User: - robj = rbac.ResourceUser.WithID(obj.ID) + robj = rbac.ResourceUser.WithID(obj.ID).WithOwner(obj.ID.String()) case codersdk.Workspace: robj = rbac.ResourceWorkspace.WithID(obj.ID).InOrg(obj.OrganizationID).WithOwner(obj.OwnerID.String()) default: diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 41fa20392fadf..fab229e2acfc1 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1181,7 +1181,7 @@ func (q *querier) GetProvisionerLogsAfterID(ctx context.Context, arg database.Ge } func (q *querier) GetQuotaAllowanceForUser(ctx context.Context, userID uuid.UUID) (int64, error) { - err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID)) + err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID).WithOwner(userID.String())) if err != nil { return -1, err } @@ -1189,7 +1189,7 @@ func (q *querier) GetQuotaAllowanceForUser(ctx context.Context, userID uuid.UUID } func (q *querier) GetQuotaConsumedForUser(ctx context.Context, userID uuid.UUID) (int64, error) { - err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID)) + err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID).WithOwner(userID.String())) if err != nil { return -1, err } @@ -1436,7 +1436,7 @@ func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([] // itself. func (q *querier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]database.User, error) { for _, uid := range ids { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(uid)); err != nil { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(uid).WithOwner(uid.String())); err != nil { return nil, err } } @@ -1942,7 +1942,7 @@ func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.Inser // TODO: Should this be in system.go? 
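Every ResourceUser object in this patch is built the same way: the user is both the resource and its owner. A minimal sketch of that shape (the package and helper name are hypothetical; WithID and WithOwner are the calls used in the hunks above):

    package rbacexample

    import (
        "github.com/google/uuid"

        "github.com/coder/coder/coderd/rbac"
    )

    // userObject mirrors the call sites above: the user resource carries its
    // own ID as the owner, so a member reading their own user record is
    // allowed through the role's user-scoped permissions, while roles that
    // need to see other users (auditor, template admin, and so on) are given
    // a site-level ResourceUser read grant later in this patch.
    func userObject(id uuid.UUID) rbac.Object {
        return rbac.ResourceUser.WithID(id).WithOwner(id.String())
    }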
func (q *querier) InsertUserLink(ctx context.Context, arg database.InsertUserLinkParams) (database.UserLink, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceUser.WithID(arg.UserID)); err != nil { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceUser.WithID(arg.UserID).WithOwner(arg.UserID.String())); err != nil { return database.UserLink{}, err } return q.db.InsertUserLink(ctx, arg) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index bde4a1dfd5ef4..e9079491b4736 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -521,7 +521,7 @@ func (s *MethodTestSuite) TestOrganization() { ma := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{OrganizationID: oa.ID}) mb := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{OrganizationID: ob.ID}) check.Args([]uuid.UUID{ma.UserID, mb.UserID}). - Asserts(rbac.ResourceUser.WithID(ma.UserID), rbac.ActionRead, rbac.ResourceUser.WithID(mb.UserID), rbac.ActionRead) + Asserts(rbac.ResourceUser.WithID(ma.UserID).WithOwner(ma.UserID.String()), rbac.ActionRead, rbac.ResourceUser.WithID(mb.UserID).WithOwner(mb.UserID.String()), rbac.ActionRead) })) s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) { mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{}) diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index bb7dfdd1bb818..d8b7731c82aab 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -201,7 +201,7 @@ func (m GetOrganizationIDsByMemberIDsRow) RBACObject() rbac.Object { // TODO: This feels incorrect as we are really returning a list of orgmembers. // This return type should be refactored to return a list of orgmembers, not this // special type. - return rbac.ResourceUser.WithID(m.UserID) + return rbac.ResourceUser.WithID(m.UserID).WithOwner(m.UserID.String()) } func (o Organization) RBACObject() rbac.Object { @@ -233,7 +233,7 @@ func (f File) RBACObject() rbac.Object { // If you are trying to get the RBAC object for the UserData, use // u.UserDataRBACObject() instead. 
func (u User) RBACObject() rbac.Object { - return rbac.ResourceUser.WithID(u.ID) + return rbac.ResourceUser.WithID(u.ID).WithOwner(u.ID.String()) } func (u User) UserDataRBACObject() rbac.Object { @@ -241,7 +241,7 @@ func (u User) UserDataRBACObject() rbac.Object { } func (u GetUsersRow) RBACObject() rbac.Object { - return rbac.ResourceUser.WithID(u.ID) + return rbac.ResourceUser.WithID(u.ID).WithOwner(u.ID.String()) } func (u GitSSHKey) RBACObject() rbac.Object { diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 4ff6e6e2d154a..d0a7142a41bb6 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -256,7 +256,10 @@ type userQuerier interface { } func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.UserConverter(), + }) + if err != nil { return -1, xerrors.Errorf("compile authorized filter: %w", err) } diff --git a/coderd/rbac/regosql/configs.go b/coderd/rbac/regosql/configs.go index 475d317cd53ab..6c33eadb4c97b 100644 --- a/coderd/rbac/regosql/configs.go +++ b/coderd/rbac/regosql/configs.go @@ -22,6 +22,22 @@ func userACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher { return ACLGroupMatcher(m, "user_acl", []string{"input", "object", "acl_user_list"}) } +func UserConverter() *sqltypes.VariableConverter { + matcher := sqltypes.NewVariableConverter().RegisterMatcher( + resourceIDMatcher(), + // Users are never owned by an organization. + sqltypes.AlwaysFalse(organizationOwnerMatcher()), + // Users are always owned by themselves. + sqltypes.StringVarMatcher("id :: text", []string{"input", "object", "owner"}), + ) + matcher.RegisterMatcher( + // No ACLs on the user type + sqltypes.AlwaysFalse(groupACLMatcher(matcher)), + sqltypes.AlwaysFalse(userACLMatcher(matcher)), + ) + return matcher +} + func TemplateConverter() *sqltypes.VariableConverter { matcher := sqltypes.NewVariableConverter().RegisterMatcher( resourceIDMatcher(), diff --git a/coderd/rbac/roles.go b/coderd/rbac/roles.go index ee3805b716402..8e6ed66dce546 100644 --- a/coderd/rbac/roles.go +++ b/coderd/rbac/roles.go @@ -145,14 +145,18 @@ func ReloadBuiltinRoles(opts *RoleOptions) { Name: member, DisplayName: "", Site: Permissions(map[string][]Action{ - // All users can read all other users and know they exist. - ResourceUser.Type: {ActionRead}, ResourceRoleAssignment.Type: {ActionRead}, // All users can see the provisioner daemons. ResourceProvisionerDaemon.Type: {ActionRead}, }), - Org: map[string][]Permission{}, - User: allPermsExcept(ResourceWorkspaceLocked), + Org: map[string][]Permission{}, + User: append(allPermsExcept(ResourceWorkspaceLocked, ResourceUser), + Permissions(map[string][]Action{ + // Users cannot do create/update/delete on themselves, but they + // can read their own details. + ResourceUser.Type: {ActionRead}, + })..., + ), }.withCachedRegoValue() auditorRole := Role{ @@ -163,6 +167,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) { // are not in. 
ResourceTemplate.Type: {ActionRead}, ResourceAuditLog.Type: {ActionRead}, + ResourceUser.Type: {ActionRead}, }), Org: map[string][]Permission{}, User: []Permission{}, @@ -172,6 +177,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) { Name: templateAdmin, DisplayName: "Template Admin", Site: Permissions(map[string][]Action{ + ResourceUser.Type: {ActionRead}, ResourceTemplate.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete}, // CRUD all files, even those they did not upload. ResourceFile.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete}, diff --git a/coderd/rbac/roles_test.go b/coderd/rbac/roles_test.go index 4c8b90bdfdb67..9d68dce3b92d2 100644 --- a/coderd/rbac/roles_test.go +++ b/coderd/rbac/roles_test.go @@ -106,7 +106,7 @@ func TestRolePermissions(t *testing.T) { { Name: "MyUser", Actions: []rbac.Action{rbac.ActionRead}, - Resource: rbac.ResourceUser.WithID(currentUser), + Resource: rbac.ResourceUser.WithID(currentUser).WithOwner(currentUser.String()), AuthorizeMap: map[bool][]authSubject{ true: {owner, memberMe, orgMemberMe, orgAdmin, otherOrgMember, otherOrgAdmin, templateAdmin, userAdmin}, false: {}, From 76e724c31ac9c2d1ffd03ed28bfdee299675ddd4 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 12 Jul 2023 08:53:08 -0400 Subject: [PATCH 2/5] chore: templates conditionally return created by If omitted, the caller does not have permission to view said data --- coderd/templates.go | 16 ++++++++- coderd/templateversions.go | 34 ++++++++++++------- .../TemplateStats/TemplateStats.tsx | 11 +++--- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/coderd/templates.go b/coderd/templates.go index b2cfb4bf3c229..4841b418a763a 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -679,11 +679,19 @@ func (api *API) templateExamples(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, ex) } +// getCreatedByNamesByTemplateIDs returns a map of template IDs to the +// usernames of the users who created them. If the caller does not have +// permission to view the given creator, then the username will be the empty +// string. func getCreatedByNamesByTemplateIDs(ctx context.Context, db database.Store, templates []database.Template) (map[string]string, error) { creators := make(map[string]string, len(templates)) for _, template := range templates { creator, err := db.GetUserByID(ctx, template.CreatedBy) if err != nil { + if errors.Is(err, sql.ErrNoRows) || dbauthz.IsNotAuthorizedError(err) { + // Users might be omitted if the caller does not have access. + continue + } return map[string]string{}, err } creators[template.ID.String()] = creator.Username @@ -713,6 +721,12 @@ func (api *API) convertTemplate( buildTimeStats := api.metricsCache.TemplateBuildTimeStats(template.ID) + // Only include this uuid if the user has permission to view the user. + // We know this if the username is not empty. 
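+	// A zero CreatedByID therefore means the caller could not read the
+	// creator (or the creator row no longer exists). The frontend relies on
+	// the same convention by only rendering the creator when
+	// created_by_name is non-empty (see TemplateStats).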
+ createdBy := uuid.Nil + if createdByName != "" { + createdBy = template.CreatedBy + } return codersdk.Template{ ID: template.ID, CreatedAt: template.CreatedAt, @@ -728,7 +742,7 @@ func (api *API) convertTemplate( Icon: template.Icon, DefaultTTLMillis: time.Duration(template.DefaultTTL).Milliseconds(), MaxTTLMillis: time.Duration(template.MaxTTL).Milliseconds(), - CreatedByID: template.CreatedBy, + CreatedByID: createdBy, CreatedByName: createdByName, AllowUserAutostart: template.AllowUserAutostart, AllowUserAutostop: template.AllowUserAutostop, diff --git a/coderd/templateversions.go b/coderd/templateversions.go index 37a7bba98b2be..b68ab888343eb 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -20,6 +20,7 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" @@ -53,8 +54,9 @@ func (api *API) templateVersion(rw http.ResponseWriter, r *http.Request) { return } + // User can be the empty user if the caller does not have permission. user, err := api.Database.GetUserByID(ctx, templateVersion.CreatedBy) - if err != nil { + if err != nil && !dbauthz.IsNotAuthorizedError(err) { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error on fetching user.", Detail: err.Error(), @@ -165,7 +167,7 @@ func (api *API) patchTemplateVersion(rw http.ResponseWriter, r *http.Request) { } user, err := api.Database.GetUserByID(ctx, templateVersion.CreatedBy) - if err != nil { + if err != nil && !dbauthz.IsNotAuthorizedError(err) { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error on fetching user.", Detail: err.Error(), @@ -843,7 +845,7 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) { } user, err := api.Database.GetUserByID(ctx, templateVersion.CreatedBy) - if err != nil { + if err != nil && !dbauthz.IsNotAuthorizedError(err) { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error on fetching user.", Detail: err.Error(), @@ -1012,7 +1014,7 @@ func (api *API) previousTemplateVersionByOrganizationTemplateAndName(rw http.Res } user, err := api.Database.GetUserByID(ctx, templateVersion.CreatedBy) - if err != nil { + if err != nil && !dbauthz.IsNotAuthorizedError(err) { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error on fetching user.", Detail: err.Error(), @@ -1325,7 +1327,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht aReq.New = templateVersion user, err := api.Database.GetUserByID(ctx, templateVersion.CreatedBy) - if err != nil { + if err != nil && !dbauthz.IsNotAuthorizedError(err) { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error on fetching user.", Detail: err.Error(), @@ -1404,14 +1406,20 @@ func (api *API) templateVersionLogs(rw http.ResponseWriter, r *http.Request) { } func convertTemplateVersion(version database.TemplateVersion, job codersdk.ProvisionerJob, user database.User, warnings []codersdk.TemplateVersionWarning) codersdk.TemplateVersion { - createdBy := codersdk.User{ - ID: user.ID, - Username: user.Username, - Email: user.Email, - CreatedAt: user.CreatedAt, - Status: codersdk.UserStatus(user.Status), - Roles: []codersdk.Role{}, - AvatarURL: 
user.AvatarURL.String, + // Only populate these fields if the user is not nil. + // It is usually nil because the caller cannot access the user + // resource in question. + var createdBy codersdk.User + if user.ID != uuid.Nil { + createdBy = codersdk.User{ + ID: user.ID, + Username: user.Username, + Email: user.Email, + CreatedAt: user.CreatedAt, + Status: codersdk.UserStatus(user.Status), + Roles: []codersdk.Role{}, + AvatarURL: user.AvatarURL.String, + } } return codersdk.TemplateVersion{ diff --git a/site/src/components/TemplateStats/TemplateStats.tsx b/site/src/components/TemplateStats/TemplateStats.tsx index 7b646a7c40342..76fc9abdaedae 100644 --- a/site/src/components/TemplateStats/TemplateStats.tsx +++ b/site/src/components/TemplateStats/TemplateStats.tsx @@ -7,6 +7,7 @@ import { formatTemplateActiveDevelopers, } from "utils/templates" import { Template, TemplateVersion } from "../../api/typesGenerated" +import { Maybe } from "components/Conditionals/Maybe" const Language = { usedByLabel: "Used by", @@ -56,10 +57,12 @@ export const TemplateStats: FC = ({ label={Language.lastUpdateLabel} value={createDayString(template.updated_at)} /> - + + + ) } From 34a0e42009b87db4fe673a33f3f871e29d857580 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 12 Jul 2023 11:04:39 -0400 Subject: [PATCH 3/5] feat: make GetUsers query use sql filter --- coderd/database/dbauthz/dbauthz.go | 33 +++++----- coderd/database/dbmock/dbmock.go | 15 +++++ coderd/database/modelqueries.go | 61 +++++++++++++++++++ coderd/database/queries.sql.go | 1 + coderd/database/queries/users.sql | 1 + site/src/components/Navbar/Navbar.tsx | 2 + .../src/components/Navbar/NavbarView.test.tsx | 8 +++ site/src/components/Navbar/NavbarView.tsx | 10 ++- 8 files changed, 115 insertions(+), 16 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index fab229e2acfc1..717792b6bb48b 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -620,8 +620,12 @@ func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFi } func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { - // TODO Implement this with a SQL filter. The count is incorrect without it. - rowUsers, err := q.db.GetUsers(ctx, arg) + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) + if err != nil { + return nil, -1, xerrors.Errorf("failed to prepare sql filter: %w", err) + } + + rowUsers, err := q.db.GetAuthorizedUsers(ctx, arg, prep) if err != nil { return nil, -1, err } @@ -630,18 +634,8 @@ func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersPa return []database.User{}, 0, nil } - act, ok := ActorFromContext(ctx) - if !ok { - return nil, -1, NoActorError - } - // TODO: Is this correct? Should we return a restricted user? users := database.ConvertUserRows(rowUsers) - users, err = rbac.Filter(ctx, q.auth, act, rbac.ActionRead, users) - if err != nil { - return nil, -1, err - } - return users, rowUsers[0].Count, nil } @@ -699,6 +693,13 @@ func authorizedTemplateVersionFromJob(ctx context.Context, q *querier, job datab } } +// GetAuthorizedUsers is not required for dbauthz since GetUsers is already +// authenticated. +func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + // GetUsers is authenticated. 
+ return q.GetUsers(ctx, arg) +} + func (q *querier) AcquireLock(ctx context.Context, id int64) error { return q.db.AcquireLock(ctx, id) } @@ -1427,8 +1428,12 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database } func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { - // TODO: We should use GetUsersWithCount with a better method signature. - return fetchWithPostFilter(q.auth, q.db.GetUsers)(ctx, arg) + // This does the filtering in SQL. + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.GetAuthorizedUsers(ctx, arg, prep) } // GetUsersByIDs is only used for usernames on workspace return data. diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index deab31927154f..9d635bd77e0e4 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -446,6 +446,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedUserCount(arg0, arg1, arg2 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUserCount", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUserCount), arg0, arg1, arg2) } +// GetAuthorizedUsers mocks base method. +func (m *MockStore) GetAuthorizedUsers(arg0 context.Context, arg1 database.GetUsersParams, arg2 rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAuthorizedUsers", arg0, arg1, arg2) + ret0, _ := ret[0].([]database.GetUsersRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAuthorizedUsers indicates an expected call of GetAuthorizedUsers. +func (mr *MockStoreMockRecorder) GetAuthorizedUsers(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUsers", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUsers), arg0, arg1, arg2) +} + // GetAuthorizedWorkspaces mocks base method. 
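The dbauthz methods above (GetUsers, GetUsersWithCount) and the new dbmock stub all lean on the same compile-to-SQL step. A minimal, isolated sketch of that step (the package and function name are hypothetical; the CompileToSQL call and converter config are the ones used in modelqueries.go, and the resulting fragment is appended to the users query as an extra AND clause):

    package rbacexample

    import (
        "context"

        "github.com/coder/coder/coderd/rbac"
        "github.com/coder/coder/coderd/rbac/regosql"
    )

    // compileUserFilter turns a prepared authorization check into a SQL
    // fragment scoped to the users table. UserConverter maps the rego
    // references onto users columns: the owner is compared against
    // "id :: text", while the organization-owner and ACL branches collapse
    // to false because users are never organization-owned and carry no ACLs.
    func compileUserFilter(ctx context.Context, prepared rbac.PreparedAuthorized) (string, error) {
        return prepared.CompileToSQL(ctx, regosql.ConvertConfig{
            VariableConverter: regosql.UserConverter(),
        })
    }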
func (m *MockStore) GetAuthorizedWorkspaces(arg0 context.Context, arg1 database.GetWorkspacesParams, arg2 rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index d0a7142a41bb6..08ccb31e2fd5f 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -252,9 +252,70 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa } type userQuerier interface { + GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) } +func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.UserConverter(), + }) + + if err != nil { + return nil, xerrors.Errorf("compile authorized filter: %w", err) + } + + filtered, err := insertAuthorizedFilter(getUsers, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return nil, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: GetAuthorizedUsers :many\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.AfterID, + arg.Search, + pq.Array(arg.Status), + pq.Array(arg.RbacRole), + arg.LastSeenBefore, + arg.LastSeenAfter, + arg.OffsetOpt, + arg.LimitOpt, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetUsersRow + for rows.Next() { + var i GetUsersRow + if err := rows.Scan( + &i.ID, + &i.Email, + &i.Username, + &i.HashedPassword, + &i.CreatedAt, + &i.UpdatedAt, + &i.Status, + &i.RBACRoles, + &i.LoginType, + &i.AvatarURL, + &i.Deleted, + &i.LastSeenAt, + &i.Count, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ VariableConverter: regosql.UserConverter(), diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 32d85aa9d6516..08c0777b2a307 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5301,6 +5301,7 @@ WHERE ELSE true END -- End of filters + ORDER BY -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination. LOWER(username) ASC OFFSET $7 diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 75cc85cdf90de..28f7a5ca6ba0b 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -208,6 +208,7 @@ WHERE ELSE true END -- End of filters + ORDER BY -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination. 
LOWER(username) ASC OFFSET @offset_opt diff --git a/site/src/components/Navbar/Navbar.tsx b/site/src/components/Navbar/Navbar.tsx index 88042b5bceb8d..0e31f49aa6f99 100644 --- a/site/src/components/Navbar/Navbar.tsx +++ b/site/src/components/Navbar/Navbar.tsx @@ -16,6 +16,7 @@ export const Navbar: FC = () => { const canViewAuditLog = featureVisibility["audit_log"] && Boolean(permissions.viewAuditLog) const canViewDeployment = Boolean(permissions.viewDeploymentValues) + const canViewUsers = Boolean(permissions.readAllUsers) const onSignOut = () => authSend("SIGN_OUT") const proxyContextValue = useProxy() const dashboard = useDashboard() @@ -29,6 +30,7 @@ export const Navbar: FC = () => { onSignOut={onSignOut} canViewAuditLog={canViewAuditLog} canViewDeployment={canViewDeployment} + canViewUsers={canViewUsers} proxyContextValue={ dashboard.experiments.includes("moons") ? proxyContextValue : undefined } diff --git a/site/src/components/Navbar/NavbarView.test.tsx b/site/src/components/Navbar/NavbarView.test.tsx index 55f5dd35901c3..63dc3bcb067ea 100644 --- a/site/src/components/Navbar/NavbarView.test.tsx +++ b/site/src/components/Navbar/NavbarView.test.tsx @@ -48,6 +48,7 @@ describe("NavbarView", () => { onSignOut={noop} canViewAuditLog canViewDeployment + canViewUsers />, ) const workspacesLink = await screen.findByText(navLanguage.workspaces) @@ -62,6 +63,7 @@ describe("NavbarView", () => { onSignOut={noop} canViewAuditLog canViewDeployment + canViewUsers />, ) const templatesLink = await screen.findByText(navLanguage.templates) @@ -76,6 +78,7 @@ describe("NavbarView", () => { onSignOut={noop} canViewAuditLog canViewDeployment + canViewUsers />, ) const userLink = await screen.findByText(navLanguage.users) @@ -98,6 +101,7 @@ describe("NavbarView", () => { onSignOut={noop} canViewAuditLog canViewDeployment + canViewUsers />, ) @@ -115,6 +119,7 @@ describe("NavbarView", () => { onSignOut={noop} canViewAuditLog canViewDeployment + canViewUsers />, ) const auditLink = await screen.findByText(navLanguage.audit) @@ -129,6 +134,7 @@ describe("NavbarView", () => { onSignOut={noop} canViewAuditLog={false} canViewDeployment + canViewUsers />, ) const auditLink = screen.queryByText(navLanguage.audit) @@ -143,6 +149,7 @@ describe("NavbarView", () => { onSignOut={noop} canViewAuditLog canViewDeployment + canViewUsers />, ) const auditLink = await screen.findByText(navLanguage.deployment) @@ -159,6 +166,7 @@ describe("NavbarView", () => { onSignOut={noop} canViewAuditLog={false} canViewDeployment={false} + canViewUsers={false} />, ) const auditLink = screen.queryByText(navLanguage.deployment) diff --git a/site/src/components/Navbar/NavbarView.tsx b/site/src/components/Navbar/NavbarView.tsx index a2ae924fbb039..513b46f896b3a 100644 --- a/site/src/components/Navbar/NavbarView.tsx +++ b/site/src/components/Navbar/NavbarView.tsx @@ -36,6 +36,7 @@ export interface NavbarViewProps { onSignOut: () => void canViewAuditLog: boolean canViewDeployment: boolean + canViewUsers: boolean proxyContextValue?: ProxyContextValue } @@ -43,6 +44,7 @@ export const Language = { workspaces: "Workspaces", templates: "Templates", users: "Users", + groups: "Groups", audit: "Audit", deployment: "Deployment", } @@ -52,8 +54,9 @@ const NavItems: React.FC< className?: string canViewAuditLog: boolean canViewDeployment: boolean + canViewUsers: boolean }> -> = ({ className, canViewAuditLog, canViewDeployment }) => { +> = ({ className, canViewAuditLog, canViewUsers, canViewDeployment }) => { const styles = useStyles() const location = 
useLocation() @@ -77,7 +80,7 @@ const NavItems: React.FC< - {Language.users} + {canViewUsers ? Language.users : Language.groups} {canViewAuditLog && ( @@ -105,6 +108,7 @@ export const NavbarView: FC = ({ onSignOut, canViewAuditLog, canViewDeployment, + canViewUsers, proxyContextValue, }) => { const styles = useStyles() @@ -142,6 +146,7 @@ export const NavbarView: FC = ({ @@ -158,6 +163,7 @@ export const NavbarView: FC = ({ className={styles.desktopNavItems} canViewAuditLog={canViewAuditLog} canViewDeployment={canViewDeployment} + canViewUsers={canViewUsers} /> Date: Wed, 12 Jul 2023 11:16:35 -0400 Subject: [PATCH 4/5] Include custom db funcs in gen --- coderd/database/dbauthz/dbauthz.go | 90 +- coderd/database/dbfake/dbfake.go | 6344 ++++++++++++------------ coderd/database/dbmetrics/dbmetrics.go | 78 +- scripts/dbgen/main.go | 54 +- 4 files changed, 3329 insertions(+), 3237 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 717792b6bb48b..3124bad7cd38d 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -575,11 +575,6 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, r return nil } -func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { - // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. - return q.GetTemplatesWithFilter(ctx, arg) -} - func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) error { deleteF := func(ctx context.Context, id uuid.UUID) error { return q.db.UpdateTemplateDeletedByID(ctx, database.UpdateTemplateDeletedByIDParams{ @@ -591,34 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id) } -func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - // An actor is authorized to read template group roles if they are authorized to read the template. - template, err := q.db.GetTemplateByID(ctx, id) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { - return nil, err - } - return q.db.GetTemplateGroupRoles(ctx, id) -} - -func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - // An actor is authorized to query template user roles if they are authorized to read the template. 
- template, err := q.db.GetTemplateByID(ctx, id) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { - return nil, err - } - return q.db.GetTemplateUserRoles(ctx, id) -} - -func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - return q.db.GetAuthorizedUserCount(ctx, arg, prepared) -} - func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) if err != nil { @@ -649,11 +616,6 @@ func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error { return deleteQ(q.log, q.auth, q.db.GetUserByID, deleteF)(ctx, id) } -func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - // TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier. - return q.GetWorkspaces(ctx, arg) -} - func (q *querier) SoftDeleteWorkspaceByID(ctx context.Context, id uuid.UUID) error { return deleteQ(q.log, q.auth, q.db.GetWorkspaceByID, func(ctx context.Context, id uuid.UUID) error { return q.db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ @@ -693,13 +655,6 @@ func authorizedTemplateVersionFromJob(ctx context.Context, q *querier, job datab } } -// GetAuthorizedUsers is not required for dbauthz since GetUsers is already -// authenticated. -func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { - // GetUsers is authenticated. - return q.GetUsers(ctx, arg) -} - func (q *querier) AcquireLock(ctx context.Context, id int64) error { return q.db.AcquireLock(ctx, id) } @@ -2647,3 +2602,48 @@ func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (d } return q.db.UpsertTailnetCoordinator(ctx, id) } + +func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { + // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. + return q.GetTemplatesWithFilter(ctx, arg) +} + +func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + // An actor is authorized to read template group roles if they are authorized to read the template. + template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateGroupRoles(ctx, id) +} + +func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + // An actor is authorized to query template user roles if they are authorized to read the template. 
+ template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateUserRoles(ctx, id) +} + +func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + // TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier. + return q.GetWorkspaces(ctx, arg) +} + +// GetAuthorizedUsers is not required for dbauthz since GetUsers is already +// authenticated. +func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + // GetUsers is authenticated. + return q.GetUsers(ctx, arg) +} + +func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + return q.db.GetAuthorizedUserCount(ctx, arg, prepared) +} diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 810f8e0b929d9..21011fbd489b9 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -23,6 +23,7 @@ import ( "github.com/coder/coder/coderd/database/db2sdk" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/rbac/regosql" "github.com/coder/coder/coderd/util/slice" "github.com/coder/coder/codersdk" ) @@ -265,3575 +266,3441 @@ func (q *fakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) { return database.User{}, sql.ErrNoRows } -func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - if err := validateDatabaseType(params); err != nil { - return 0, err +func convertUsers(users []database.User, count int64) []database.GetUsersRow { + rows := make([]database.GetUsersRow, len(users)) + for i, u := range users { + rows[i] = database.GetUsersRow{ + ID: u.ID, + Email: u.Email, + Username: u.Username, + HashedPassword: u.HashedPassword, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, + Status: u.Status, + RBACRoles: u.RBACRoles, + LoginType: u.LoginType, + AvatarURL: u.AvatarURL, + Deleted: u.Deleted, + LastSeenAt: u.LastSeenAt, + Count: count, + } } - q.mutex.RLock() - defer q.mutex.RUnlock() + return rows +} - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return -1, err +// mapAgentStatus determines the agent status based on different timestamps like created_at, last_connected_at, disconnected_at, etc. +// The function must be in sync with: coderd/workspaceagents.go:convertWorkspaceAgent. +func mapAgentStatus(dbAgent database.WorkspaceAgent, agentInactiveDisconnectTimeoutSeconds int64) string { + var status string + connectionTimeout := time.Duration(dbAgent.ConnectionTimeoutSeconds) * time.Second + switch { + case !dbAgent.FirstConnectedAt.Valid: + switch { + case connectionTimeout > 0 && database.Now().Sub(dbAgent.CreatedAt) > connectionTimeout: + // If the agent took too long to connect the first time, + // mark it as timed out. + status = "timeout" + default: + // If the agent never connected, it's waiting for the compute + // to start up. 
+ status = "connecting" } + case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time): + // If we've disconnected after our last connection, we know the + // agent is no longer connected. + status = "disconnected" + case database.Now().Sub(dbAgent.LastConnectedAt.Time) > time.Duration(agentInactiveDisconnectTimeoutSeconds)*time.Second: + // The connection died without updating the last connected. + status = "disconnected" + case dbAgent.LastConnectedAt.Valid: + // The agent should be assumed connected if it's under inactivity timeouts + // and last connected at has been properly set. + status = "connected" + default: + panic("unknown agent status: " + status) } + return status +} - users := make([]database.User, 0, len(q.users)) - - for _, user := range q.users { - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { - continue +func convertToWorkspaceRows(workspaces []database.Workspace, count int64) []database.GetWorkspacesRow { + rows := make([]database.GetWorkspacesRow, len(workspaces)) + for i, w := range workspaces { + rows[i] = database.GetWorkspacesRow{ + ID: w.ID, + CreatedAt: w.CreatedAt, + UpdatedAt: w.UpdatedAt, + OwnerID: w.OwnerID, + OrganizationID: w.OrganizationID, + TemplateID: w.TemplateID, + Deleted: w.Deleted, + Name: w.Name, + AutostartSchedule: w.AutostartSchedule, + Ttl: w.Ttl, + LastUsedAt: w.LastUsedAt, + Count: count, } - - users = append(users, user) } + return rows +} - // Filter out deleted since they should never be returned.. - tmp := make([]database.User, 0, len(users)) - for _, user := range users { - if !user.Deleted { - tmp = append(tmp, user) +func (q *fakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) { + for _, workspace := range q.workspaces { + if workspace.ID == id { + return workspace, nil } } - users = tmp + return database.Workspace{}, sql.ErrNoRows +} - if params.Search != "" { - tmp := make([]database.User, 0, len(users)) - for i, user := range users { - if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } +func (q *fakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) { + var agent database.WorkspaceAgent + for _, _agent := range q.workspaceAgents { + if _agent.ID == agentID { + agent = _agent + break } - users = tmp + } + if agent.ID == uuid.Nil { + return database.Workspace{}, sql.ErrNoRows } - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } + var resource database.WorkspaceResource + for _, _resource := range q.workspaceResources { + if _resource.ID == agent.ResourceID { + resource = _resource + break } - users = usersFilteredByStatus + } + if resource.ID == uuid.Nil { + return database.Workspace{}, sql.ErrNoRows } - if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { - usersFilteredByRole := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { - 
usersFilteredByRole = append(usersFilteredByRole, users[i]) - } + var build database.WorkspaceBuild + for _, _build := range q.workspaceBuilds { + if _build.JobID == resource.JobID { + build = _build + break } + } + if build.ID == uuid.Nil { + return database.Workspace{}, sql.ErrNoRows + } - users = usersFilteredByRole + for _, workspace := range q.workspaces { + if workspace.ID == build.WorkspaceID { + return workspace, nil + } } - return int64(len(users)), nil + return database.Workspace{}, sql.ErrNoRows } -func convertUsers(users []database.User, count int64) []database.GetUsersRow { - rows := make([]database.GetUsersRow, len(users)) - for i, u := range users { - rows[i] = database.GetUsersRow{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - HashedPassword: u.HashedPassword, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, - Status: u.Status, - RBACRoles: u.RBACRoles, - LoginType: u.LoginType, - AvatarURL: u.AvatarURL, - Deleted: u.Deleted, - LastSeenAt: u.LastSeenAt, - Count: count, +func (q *fakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { + for _, history := range q.workspaceBuilds { + if history.ID == id { + return history, nil } } - - return rows + return database.WorkspaceBuild{}, sql.ErrNoRows } -//nolint:gocyclo -func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err +func (q *fakeQuerier) getLatestWorkspaceBuildByWorkspaceIDNoLock(_ context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { + var row database.WorkspaceBuild + var buildNum int32 = -1 + for _, workspaceBuild := range q.workspaceBuilds { + if workspaceBuild.WorkspaceID == workspaceID && workspaceBuild.BuildNumber > buildNum { + row = workspaceBuild + buildNum = workspaceBuild.BuildNumber + } } + if buildNum == -1 { + return database.WorkspaceBuild{}, sql.ErrNoRows + } + return row, nil +} - q.mutex.RLock() - defer q.mutex.RUnlock() - - if prepared != nil { - // Call this to match the same function calls as the SQL implementation. - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return nil, err +func (q *fakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (database.Template, error) { + for _, template := range q.templates { + if template.ID == id { + return template.DeepCopy(), nil } } + return database.Template{}, sql.ErrNoRows +} - workspaces := make([]database.Workspace, 0) - for _, workspace := range q.workspaces { - if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { +func (q *fakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { + for _, templateVersion := range q.templateVersions { + if templateVersion.ID != templateVersionID { continue } + return templateVersion, nil + } + return database.TemplateVersion{}, sql.ErrNoRows +} - if arg.OwnerUsername != "" { - owner, err := q.getUserByIDNoLock(workspace.OwnerID) - if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) { - continue - } +func (q *fakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { + // The schema sorts this by created at, so we iterate the array backwards. 
+ for i := len(q.workspaceAgents) - 1; i >= 0; i-- { + agent := q.workspaceAgents[i] + if agent.ID == id { + return agent, nil } + } + return database.WorkspaceAgent{}, sql.ErrNoRows +} - if arg.TemplateName != "" { - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) { +func (q *fakeQuerier) getWorkspaceAgentsByResourceIDsNoLock(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { + workspaceAgents := make([]database.WorkspaceAgent, 0) + for _, agent := range q.workspaceAgents { + for _, resourceID := range resourceIDs { + if agent.ResourceID != resourceID { continue } + workspaceAgents = append(workspaceAgents, agent) } + } + return workspaceAgents, nil +} - if !arg.Deleted && workspace.Deleted { +func (q *fakeQuerier) getProvisionerJobByIDNoLock(_ context.Context, id uuid.UUID) (database.ProvisionerJob, error) { + for _, provisionerJob := range q.provisionerJobs { + if provisionerJob.ID != id { continue } + return provisionerJob, nil + } + return database.ProvisionerJob{}, sql.ErrNoRows +} - if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) { +func (q *fakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { + resources := make([]database.WorkspaceResource, 0) + for _, resource := range q.workspaceResources { + if resource.JobID != jobID { continue } + resources = append(resources, resource) + } + return resources, nil +} - if arg.Status != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - // This logic should match the logic in the workspace.sql file. - var statusMatch bool - switch database.WorkspaceStatus(arg.Status) { - case database.WorkspaceStatusPending: - statusMatch = isNull(job.StartedAt) - case database.WorkspaceStatusStarting: - statusMatch = isNotNull(job.StartedAt) && - isNull(job.CanceledAt) && - isNull(job.CompletedAt) && - time.Since(job.UpdatedAt) < 30*time.Second && - build.Transition == database.WorkspaceTransitionStart +func (q *fakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) { + for _, group := range q.groups { + if group.ID == id { + return group, nil + } + } - case database.WorkspaceStatusRunning: - statusMatch = isNotNull(job.CompletedAt) && - isNull(job.CanceledAt) && - isNull(job.Error) && - build.Transition == database.WorkspaceTransitionStart + return database.Group{}, sql.ErrNoRows +} - case database.WorkspaceStatusStopping: - statusMatch = isNotNull(job.StartedAt) && - isNull(job.CanceledAt) && - isNull(job.CompletedAt) && - time.Since(job.UpdatedAt) < 30*time.Second && - build.Transition == database.WorkspaceTransitionStop +// isNull is only used in dbfake, so reflect is ok. Use this to make the logic +// look more similar to the postgres. 
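+// For example, isNotNull(sql.NullTime{Valid: true}) and isNull(sql.NullString{})
+// are both true, mirroring the "IS NOT NULL" / "IS NULL" checks in the
+// corresponding Postgres queries.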
+func isNull(v interface{}) bool { + return !isNotNull(v) +} - case database.WorkspaceStatusStopped: - statusMatch = isNotNull(job.CompletedAt) && - isNull(job.CanceledAt) && - isNull(job.Error) && - build.Transition == database.WorkspaceTransitionStop - case database.WorkspaceStatusFailed: - statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) || - (isNotNull(job.CompletedAt) && isNotNull(job.Error)) +func isNotNull(v interface{}) bool { + return reflect.ValueOf(v).FieldByName("Valid").Bool() +} - case database.WorkspaceStatusCanceling: - statusMatch = isNotNull(job.CanceledAt) && - isNull(job.CompletedAt) +// ErrUnimplemented is returned by methods only used by the enterprise/tailnet.pgCoord. This coordinator explicitly +// depends on postgres triggers that announce changes on the pubsub. Implementing support for this in the fake +// database would strongly couple the fakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little +// sense to directly test the pgCoord against anything other than postgres. The fakeQuerier is designed to allow us to +// test the Coderd API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, +// these methods remain unimplemented in the fakeQuerier. +var ErrUnimplemented = xerrors.New("unimplemented") - case database.WorkspaceStatusCanceled: - statusMatch = isNotNull(job.CanceledAt) && - isNotNull(job.CompletedAt) +func (*fakeQuerier) AcquireLock(_ context.Context, _ int64) error { + return xerrors.New("AcquireLock must only be called within a transaction") +} - case database.WorkspaceStatusDeleted: - statusMatch = isNotNull(job.StartedAt) && - isNull(job.CanceledAt) && - isNotNull(job.CompletedAt) && - time.Since(job.UpdatedAt) < 30*time.Second && - build.Transition == database.WorkspaceTransitionDelete && - isNull(job.Error) +func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { + if err := validateDatabaseType(arg); err != nil { + return database.ProvisionerJob{}, err + } - case database.WorkspaceStatusDeleting: - statusMatch = isNull(job.CompletedAt) && - isNull(job.CanceledAt) && - isNull(job.Error) && - build.Transition == database.WorkspaceTransitionDelete + q.mutex.Lock() + defer q.mutex.Unlock() - default: - return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status) - } - if !statusMatch { + for index, provisionerJob := range q.provisionerJobs { + if provisionerJob.StartedAt.Valid { + continue + } + found := false + for _, provisionerType := range arg.Types { + if provisionerJob.Provisioner != provisionerType { continue } + found = true + break } - - if arg.HasAgent != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - - var workspaceResourceIDs []uuid.UUID - for _, wr := range workspaceResources { - workspaceResourceIDs = append(workspaceResourceIDs, wr.ID) - } - - workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs) + if !found { + continue + } + tags := map[string]string{} + if arg.Tags != nil { + err := 
json.Unmarshal(arg.Tags, &tags) if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) - } - - var hasAgentMatched bool - for _, wa := range workspaceAgents { - if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent { - hasAgentMatched = true - } - } - - if !hasAgentMatched { - continue + return provisionerJob, xerrors.Errorf("unmarshal: %w", err) } } - if len(arg.TemplateIds) > 0 { - match := false - for _, id := range arg.TemplateIds { - if workspace.TemplateID == id { - match = true - break - } + missing := false + for key, value := range provisionerJob.Tags { + provided, found := tags[key] + if !found { + missing = true + break } - if !match { - continue + if provided != value { + missing = true + break } } - - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil { + if missing { continue } - workspaces = append(workspaces, workspace) + provisionerJob.StartedAt = arg.StartedAt + provisionerJob.UpdatedAt = arg.StartedAt.Time + provisionerJob.WorkerID = arg.WorkerID + q.provisionerJobs[index] = provisionerJob + return provisionerJob, nil } + return database.ProvisionerJob{}, sql.ErrNoRows +} - // Sort workspaces (ORDER BY) - isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool { - return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart - } +func (*fakeQuerier) CleanTailnetCoordinators(_ context.Context) error { + return ErrUnimplemented +} - preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{} - preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{} - preloadedUsers := map[uuid.UUID]database.User{} +func (q *fakeQuerier) DeleteAPIKeyByID(_ context.Context, id string) error { + q.mutex.Lock() + defer q.mutex.Unlock() - for _, w := range workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) - if err == nil { - preloadedWorkspaceBuilds[w.ID] = build - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get latest build: %w", err) + for index, apiKey := range q.apiKeys { + if apiKey.ID != id { + continue } + q.apiKeys[index] = q.apiKeys[len(q.apiKeys)-1] + q.apiKeys = q.apiKeys[:len(q.apiKeys)-1] + return nil + } + return sql.ErrNoRows +} - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err == nil { - preloadedProvisionerJobs[w.ID] = job - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } +func (q *fakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { + q.mutex.Lock() + defer q.mutex.Unlock() - user, err := q.getUserByIDNoLock(w.OwnerID) - if err == nil { - preloadedUsers[w.ID] = user - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user: %w", err) + for i := len(q.apiKeys) - 1; i >= 0; i-- { + if q.apiKeys[i].UserID == userID { + q.apiKeys = append(q.apiKeys[:i], q.apiKeys[i+1:]...) 
} } - sort.Slice(workspaces, func(i, j int) bool { - w1 := workspaces[i] - w2 := workspaces[j] - - // Order by: running first - w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID]) - w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID]) + return nil +} - if w1IsRunning && !w2IsRunning { - return true - } +func (q *fakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { + q.mutex.Lock() + defer q.mutex.Unlock() - if !w1IsRunning && w2IsRunning { - return false + for i := len(q.apiKeys) - 1; i >= 0; i-- { + if q.apiKeys[i].UserID == userID && q.apiKeys[i].Scope == database.APIKeyScopeApplicationConnect { + q.apiKeys = append(q.apiKeys[:i], q.apiKeys[i+1:]...) } + } - // Order by: usernames - if w1.ID != w2.ID { - return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username}) - } + return nil +} - // Order by: workspace names - return sort.StringsAreSorted([]string{w1.Name, w2.Name}) - }) +func (*fakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { + return ErrUnimplemented +} - beforePageCount := len(workspaces) +func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error { + q.mutex.Lock() + defer q.mutex.Unlock() - if arg.Offset > 0 { - if int(arg.Offset) > len(workspaces) { - return []database.GetWorkspacesRow{}, nil + for index, key := range q.gitSSHKey { + if key.UserID != userID { + continue } - workspaces = workspaces[arg.Offset:] + q.gitSSHKey[index] = q.gitSSHKey[len(q.gitSSHKey)-1] + q.gitSSHKey = q.gitSSHKey[:len(q.gitSSHKey)-1] + return nil } - if arg.Limit > 0 { - if int(arg.Limit) > len(workspaces) { - return convertToWorkspaceRows(workspaces, int64(beforePageCount)), nil + return sql.ErrNoRows +} + +func (q *fakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + for i, group := range q.groups { + if group.ID == id { + q.groups = append(q.groups[:i], q.groups[i+1:]...) + return nil } - workspaces = workspaces[:arg.Limit] } - return convertToWorkspaceRows(workspaces, int64(beforePageCount)), nil + return sql.ErrNoRows } -// mapAgentStatus determines the agent status based on different timestamps like created_at, last_connected_at, disconnected_at, etc. -// The function must be in sync with: coderd/workspaceagents.go:convertWorkspaceAgent. -func mapAgentStatus(dbAgent database.WorkspaceAgent, agentInactiveDisconnectTimeoutSeconds int64) string { - var status string - connectionTimeout := time.Duration(dbAgent.ConnectionTimeoutSeconds) * time.Second - switch { - case !dbAgent.FirstConnectedAt.Valid: - switch { - case connectionTimeout > 0 && database.Now().Sub(dbAgent.CreatedAt) > connectionTimeout: - // If the agent took too long to connect the first time, - // mark it as timed out. - status = "timeout" - default: - // If the agent never connected, it's waiting for the compute - // to start up. - status = "connecting" +func (q *fakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database.DeleteGroupMemberFromGroupParams) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + for i, member := range q.groupMembers { + if member.UserID == arg.UserID && member.GroupID == arg.GroupID { + q.groupMembers = append(q.groupMembers[:i], q.groupMembers[i+1:]...) } - case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time): - // If we've disconnected after our last connection, we know the - // agent is no longer connected. 
- status = "disconnected" - case database.Now().Sub(dbAgent.LastConnectedAt.Time) > time.Duration(agentInactiveDisconnectTimeoutSeconds)*time.Second: - // The connection died without updating the last connected. - status = "disconnected" - case dbAgent.LastConnectedAt.Valid: - // The agent should be assumed connected if it's under inactivity timeouts - // and last connected at has been properly set. - status = "connected" - default: - panic("unknown agent status: " + status) } - return status + return nil } -func convertToWorkspaceRows(workspaces []database.Workspace, count int64) []database.GetWorkspacesRow { - rows := make([]database.GetWorkspacesRow, len(workspaces)) - for i, w := range workspaces { - rows[i] = database.GetWorkspacesRow{ - ID: w.ID, - CreatedAt: w.CreatedAt, - UpdatedAt: w.UpdatedAt, - OwnerID: w.OwnerID, - OrganizationID: w.OrganizationID, - TemplateID: w.TemplateID, - Deleted: w.Deleted, - Name: w.Name, - AutostartSchedule: w.AutostartSchedule, - Ttl: w.Ttl, - LastUsedAt: w.LastUsedAt, - Count: count, +func (q *fakeQuerier) DeleteGroupMembersByOrgAndUser(_ context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + newMembers := q.groupMembers[:0] + for _, member := range q.groupMembers { + if member.UserID != arg.UserID { + // Do not delete the other members + newMembers = append(newMembers, member) + } else if member.UserID == arg.UserID { + // We only want to delete from groups in the organization in the args. + for _, group := range q.groups { + // Find the group that the member is apartof. + if group.ID == member.GroupID { + // Only add back the member if the organization ID does not match + // the arg organization ID. Since the arg is saying which + // org to delete. 
+ if group.OrganizationID != arg.OrganizationID { + newMembers = append(newMembers, member) + } + break + } + } } } - return rows + q.groupMembers = newMembers + + return nil } -func (q *fakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) { - for _, workspace := range q.workspaces { - if workspace.ID == id { - return workspace, nil +func (q *fakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, l := range q.licenses { + if l.ID == id { + q.licenses[index] = q.licenses[len(q.licenses)-1] + q.licenses = q.licenses[:len(q.licenses)-1] + return id, nil } } - return database.Workspace{}, sql.ErrNoRows + return 0, sql.ErrNoRows } -func (q *fakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) { - var agent database.WorkspaceAgent - for _, _agent := range q.workspaceAgents { - if _agent.ID == agentID { - agent = _agent - break - } - } - if agent.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } +func (*fakeQuerier) DeleteOldWorkspaceAgentStartupLogs(_ context.Context) error { + // noop + return nil +} - var resource database.WorkspaceResource - for _, _resource := range q.workspaceResources { - if _resource.ID == agent.ResourceID { - resource = _resource - break - } - } - if resource.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } +func (*fakeQuerier) DeleteOldWorkspaceAgentStats(_ context.Context) error { + // no-op + return nil +} - var build database.WorkspaceBuild - for _, _build := range q.workspaceBuilds { - if _build.JobID == resource.JobID { - build = _build - break - } - } - if build.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } +func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time.Time) error { + q.mutex.Lock() + defer q.mutex.Unlock() - for _, workspace := range q.workspaces { - if workspace.ID == build.WorkspaceID { - return workspace, nil + for i, replica := range q.replicas { + if replica.UpdatedAt.Before(before) { + q.replicas = append(q.replicas[:i], q.replicas[i+1:]...) 
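+ // Splice the stale replica out of the slice in place.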
} } - return database.Workspace{}, sql.ErrNoRows + return nil } -func (q *fakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { - for _, history := range q.workspaceBuilds { - if history.ID == id { - return history, nil - } - } - return database.WorkspaceBuild{}, sql.ErrNoRows +func (*fakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + return database.DeleteTailnetAgentRow{}, ErrUnimplemented } -func (q *fakeQuerier) getLatestWorkspaceBuildByWorkspaceIDNoLock(_ context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { - var row database.WorkspaceBuild - var buildNum int32 = -1 - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.WorkspaceID == workspaceID && workspaceBuild.BuildNumber > buildNum { - row = workspaceBuild - buildNum = workspaceBuild.BuildNumber +func (*fakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + return database.DeleteTailnetClientRow{}, ErrUnimplemented +} + +func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, apiKey := range q.apiKeys { + if apiKey.ID == id { + return apiKey, nil } } - if buildNum == -1 { - return database.WorkspaceBuild{}, sql.ErrNoRows - } - return row, nil + return database.APIKey{}, sql.ErrNoRows } -func (q *fakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (database.Template, error) { - for _, template := range q.templates { - if template.ID == id { - return template.DeepCopy(), nil +func (q *fakeQuerier) GetAPIKeyByName(_ context.Context, params database.GetAPIKeyByNameParams) (database.APIKey, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + if params.TokenName == "" { + return database.APIKey{}, sql.ErrNoRows + } + for _, apiKey := range q.apiKeys { + if params.UserID == apiKey.UserID && params.TokenName == apiKey.TokenName { + return apiKey, nil } } - return database.Template{}, sql.ErrNoRows + return database.APIKey{}, sql.ErrNoRows } -func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { - if err := validateDatabaseType(arg); err != nil { +func (q *fakeQuerier) GetAPIKeysByLoginType(_ context.Context, t database.LoginType) ([]database.APIKey, error) { + if err := validateDatabaseType(t); err != nil { return nil, err } q.mutex.RLock() defer q.mutex.RUnlock() - // Call this to match the same function calls as the SQL implementation. 
- if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL()) - if err != nil { - return nil, err + apiKeys := make([]database.APIKey, 0) + for _, key := range q.apiKeys { + if key.LoginType == t { + apiKeys = append(apiKeys, key) } } + return apiKeys, nil +} - var templates []database.Template - for _, template := range q.templates { - if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { - continue - } +func (q *fakeQuerier) GetAPIKeysByUserID(_ context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - if template.Deleted != arg.Deleted { - continue - } - if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID { - continue + apiKeys := make([]database.APIKey, 0) + for _, key := range q.apiKeys { + if key.UserID == params.UserID && key.LoginType == params.LoginType { + apiKeys = append(apiKeys, key) } + } + return apiKeys, nil +} - if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) { - continue - } +func (q *fakeQuerier) GetAPIKeysLastUsedAfter(_ context.Context, after time.Time) ([]database.APIKey, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - if len(arg.IDs) > 0 { - match := false - for _, id := range arg.IDs { - if template.ID == id { - match = true - break - } - } - if !match { - continue - } + apiKeys := make([]database.APIKey, 0) + for _, key := range q.apiKeys { + if key.LastUsed.After(after) { + apiKeys = append(apiKeys, key) } - templates = append(templates, template.DeepCopy()) } - if len(templates) > 0 { - slices.SortFunc(templates, func(i, j database.Template) bool { - if i.Name != j.Name { - return i.Name < j.Name - } - return i.ID.String() < j.ID.String() - }) - return templates, nil - } - - return nil, sql.ErrNoRows + return apiKeys, nil } -func (q *fakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { - for _, templateVersion := range q.templateVersions { - if templateVersion.ID != templateVersionID { - continue +func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + active := int64(0) + for _, u := range q.users { + if u.Status == database.UserStatusActive && !u.Deleted { + active++ } - return templateVersion, nil } - return database.TemplateVersion{}, sql.ErrNoRows + return active, nil } -func (q *fakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) { +func (q *fakeQuerier) GetAppSecurityKey(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() - var template database.Template - for _, t := range q.templates { - if t.ID == id { - template = t - break - } - } + return q.appSecurityKey, nil +} - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows +func (q *fakeQuerier) GetAuditLogsOffset(_ context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) { + if err := validateDatabaseType(arg); err != nil { + return nil, err } - users := make([]database.TemplateUser, 0, len(template.UserACL)) - for k, v := range template.UserACL { - user, err := q.getUserByIDNoLock(uuid.MustParse(k)) - if err != nil && xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user by ID: %w", err) + q.mutex.RLock() + defer q.mutex.RUnlock() + + logs := make([]database.GetAuditLogsOffsetRow, 0, arg.Limit) + + // q.auditLogs are already sorted by time DESC, so 
no need to sort after the fact. + for _, alog := range q.auditLogs { + if arg.Offset > 0 { + arg.Offset-- + continue } - // We don't delete users from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { + if arg.Action != "" && !strings.Contains(string(alog.Action), arg.Action) { continue } - - if user.Deleted || user.Status == database.UserStatusSuspended { + if arg.ResourceType != "" && !strings.Contains(string(alog.ResourceType), arg.ResourceType) { + continue + } + if arg.ResourceID != uuid.Nil && alog.ResourceID != arg.ResourceID { continue } + if arg.Username != "" { + user, err := q.getUserByIDNoLock(alog.UserID) + if err == nil && !strings.EqualFold(arg.Username, user.Username) { + continue + } + } + if arg.Email != "" { + user, err := q.getUserByIDNoLock(alog.UserID) + if err == nil && !strings.EqualFold(arg.Email, user.Email) { + continue + } + } + if !arg.DateFrom.IsZero() { + if alog.Time.Before(arg.DateFrom) { + continue + } + } + if !arg.DateTo.IsZero() { + if alog.Time.After(arg.DateTo) { + continue + } + } + if arg.BuildReason != "" { + workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID) + if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) { + continue + } + } - users = append(users, database.TemplateUser{ - User: user, - Actions: v, + user, err := q.getUserByIDNoLock(alog.UserID) + userValid := err == nil + + logs = append(logs, database.GetAuditLogsOffsetRow{ + ID: alog.ID, + RequestID: alog.RequestID, + OrganizationID: alog.OrganizationID, + Ip: alog.Ip, + UserAgent: alog.UserAgent, + ResourceType: alog.ResourceType, + ResourceID: alog.ResourceID, + ResourceTarget: alog.ResourceTarget, + ResourceIcon: alog.ResourceIcon, + Action: alog.Action, + Diff: alog.Diff, + StatusCode: alog.StatusCode, + AdditionalFields: alog.AdditionalFields, + UserID: alog.UserID, + UserUsername: sql.NullString{String: user.Username, Valid: userValid}, + UserEmail: sql.NullString{String: user.Email, Valid: userValid}, + UserCreatedAt: sql.NullTime{Time: user.CreatedAt, Valid: userValid}, + UserStatus: database.NullUserStatus{UserStatus: user.Status, Valid: userValid}, + UserRoles: user.RBACRoles, + Count: 0, }) + + if len(logs) >= int(arg.Limit) { + break + } } - return users, nil + count := int64(len(logs)) + for i := range logs { + logs[i].Count = count + } + + return logs, nil } -func (q *fakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { +func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() - var template database.Template - for _, t := range q.templates { - if t.ID == id { - template = t + var user *database.User + roles := make([]string, 0) + for _, u := range q.users { + if u.ID == userID { + u := u + roles = append(roles, u.RBACRoles...) + roles = append(roles, "member") + user = &u break } } - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows + for _, mem := range q.organizationMembers { + if mem.UserID == userID { + roles = append(roles, mem.Roles...) 
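+ // Each organization membership also grants the implicit organization-member role for that organization.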
+ roles = append(roles, "organization-member:"+mem.OrganizationID.String()) + } } - groups := make([]database.TemplateGroup, 0, len(template.GroupACL)) - for k, v := range template.GroupACL { - group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k)) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get group by ID: %w", err) - } - // We don't delete groups from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { - continue + var groups []string + for _, member := range q.groupMembers { + if member.UserID == userID { + groups = append(groups, member.GroupID.String()) } + } - groups = append(groups, database.TemplateGroup{ - Group: group, - Actions: v, - }) + if user == nil { + return database.GetAuthorizationUserRolesRow{}, sql.ErrNoRows } - return groups, nil + return database.GetAuthorizationUserRolesRow{ + ID: userID, + Username: user.Username, + Status: user.Status, + Roles: roles, + Groups: groups, + }, nil } -func (q *fakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { - // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.workspaceAgents) - 1; i >= 0; i-- { - agent := q.workspaceAgents[i] - if agent.ID == id { - return agent, nil - } - } - return database.WorkspaceAgent{}, sql.ErrNoRows -} +func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() -func (q *fakeQuerier) getWorkspaceAgentsByResourceIDsNoLock(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { - workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.workspaceAgents { - for _, resourceID := range resourceIDs { - if agent.ResourceID != resourceID { - continue - } - workspaceAgents = append(workspaceAgents, agent) - } - } - return workspaceAgents, nil + return q.derpMeshKey, nil } -func (q *fakeQuerier) getProvisionerJobByIDNoLock(_ context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - for _, provisionerJob := range q.provisionerJobs { - if provisionerJob.ID != id { - continue - } - return provisionerJob, nil - } - return database.ProvisionerJob{}, sql.ErrNoRows +func (q *fakeQuerier) GetDefaultProxyConfig(_ context.Context) (database.GetDefaultProxyConfigRow, error) { + return database.GetDefaultProxyConfigRow{ + DisplayName: q.defaultProxyDisplayName, + IconUrl: q.defaultProxyIconURL, + }, nil } -func (q *fakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - if resource.JobID != jobID { +func (q *fakeQuerier) GetDeploymentDAUs(_ context.Context, tzOffset int32) ([]database.GetDeploymentDAUsRow, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + seens := make(map[time.Time]map[uuid.UUID]struct{}) + + for _, as := range q.workspaceAgentStats { + if as.ConnectionCount == 0 { continue } - resources = append(resources, resource) + date := as.CreatedAt.UTC().Add(time.Duration(tzOffset) * -1 * time.Hour).Truncate(time.Hour * 24) + + dateEntry := seens[date] + if dateEntry == nil { + dateEntry = make(map[uuid.UUID]struct{}) + } + dateEntry[as.UserID] = struct{}{} + seens[date] = dateEntry } - return resources, nil -} -func (q *fakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) { - for _, group := range 
q.groups { - if group.ID == id { - return group, nil + seenKeys := maps.Keys(seens) + sort.Slice(seenKeys, func(i, j int) bool { + return seenKeys[i].Before(seenKeys[j]) + }) + + var rs []database.GetDeploymentDAUsRow + for _, key := range seenKeys { + ids := seens[key] + for id := range ids { + rs = append(rs, database.GetDeploymentDAUsRow{ + Date: key, + UserID: id, + }) } } - return database.Group{}, sql.ErrNoRows + return rs, nil } -// isNull is only used in dbfake, so reflect is ok. Use this to make the logic -// look more similar to the postgres. -func isNull(v interface{}) bool { - return !isNotNull(v) -} +func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() -func isNotNull(v interface{}) bool { - return reflect.ValueOf(v).FieldByName("Valid").Bool() + return q.deploymentID, nil } -// ErrUnimplemented is returned by methods only used by the enterprise/tailnet.pgCoord. This coordinator explicitly -// depends on postgres triggers that announce changes on the pubsub. Implementing support for this in the fake -// database would strongly couple the fakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little -// sense to directly test the pgCoord against anything other than postgres. The fakeQuerier is designed to allow us to -// test the Coderd API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, -// these methods remain unimplemented in the fakeQuerier. -var ErrUnimplemented = xerrors.New("unimplemented") +func (q *fakeQuerier) GetDeploymentWorkspaceAgentStats(_ context.Context, createdAfter time.Time) (database.GetDeploymentWorkspaceAgentStatsRow, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() -func (*fakeQuerier) AcquireLock(_ context.Context, _ int64) error { - return xerrors.New("AcquireLock must only be called within a transaction") -} + agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) + for _, agentStat := range q.workspaceAgentStats { + if agentStat.CreatedAt.After(createdAfter) { + agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) + } + } -func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerJob{}, err + latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} + for _, agentStat := range q.workspaceAgentStats { + if agentStat.CreatedAt.After(createdAfter) { + latestAgentStats[agentStat.AgentID] = agentStat + } } - q.mutex.Lock() - defer q.mutex.Unlock() + stat := database.GetDeploymentWorkspaceAgentStatsRow{} + for _, agentStat := range latestAgentStats { + stat.SessionCountVSCode += agentStat.SessionCountVSCode + stat.SessionCountJetBrains += agentStat.SessionCountJetBrains + stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY + stat.SessionCountSSH += agentStat.SessionCountSSH + } - for index, provisionerJob := range q.provisionerJobs { - if provisionerJob.StartedAt.Valid { + latencies := make([]float64, 0) + for _, agentStat := range agentStatsCreatedAfter { + if agentStat.ConnectionMedianLatencyMS <= 0 { continue } - found := false - for _, provisionerType := range arg.Types { - if provisionerJob.Provisioner != provisionerType { - continue - } - found = true - break + stat.WorkspaceRxBytes += agentStat.RxBytes + stat.WorkspaceTxBytes += agentStat.TxBytes + latencies = append(latencies, 
agentStat.ConnectionMedianLatencyMS) + } + + tryPercentile := func(fs []float64, p float64) float64 { + if len(fs) == 0 { + return -1 } - if !found { + sort.Float64s(fs) + return fs[int(float64(len(fs))*p/100)] + } + + stat.WorkspaceConnectionLatency50 = tryPercentile(latencies, 50) + stat.WorkspaceConnectionLatency95 = tryPercentile(latencies, 95) + + return stat, nil +} + +func (q *fakeQuerier) GetDeploymentWorkspaceStats(ctx context.Context) (database.GetDeploymentWorkspaceStatsRow, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + stat := database.GetDeploymentWorkspaceStatsRow{} + for _, workspace := range q.workspaces { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) + if err != nil { + return stat, err + } + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err != nil { + return stat, err + } + if !job.StartedAt.Valid { + stat.PendingWorkspaces++ continue } - tags := map[string]string{} - if arg.Tags != nil { - err := json.Unmarshal(arg.Tags, &tags) - if err != nil { - return provisionerJob, xerrors.Errorf("unmarshal: %w", err) - } + if job.StartedAt.Valid && + !job.CanceledAt.Valid && + time.Since(job.UpdatedAt) <= 30*time.Second && + !job.CompletedAt.Valid { + stat.BuildingWorkspaces++ + continue } - - missing := false - for key, value := range provisionerJob.Tags { - provided, found := tags[key] - if !found { - missing = true - break + if job.CompletedAt.Valid && + !job.CanceledAt.Valid && + !job.Error.Valid { + if build.Transition == database.WorkspaceTransitionStart { + stat.RunningWorkspaces++ } - if provided != value { - missing = true - break + if build.Transition == database.WorkspaceTransitionStop { + stat.StoppedWorkspaces++ } + continue } - if missing { + if job.CanceledAt.Valid || job.Error.Valid { + stat.FailedWorkspaces++ continue } - provisionerJob.StartedAt = arg.StartedAt - provisionerJob.UpdatedAt = arg.StartedAt.Time - provisionerJob.WorkerID = arg.WorkerID - q.provisionerJobs[index] = provisionerJob - return provisionerJob, nil } - return database.ProvisionerJob{}, sql.ErrNoRows + return stat, nil } -func (*fakeQuerier) CleanTailnetCoordinators(_ context.Context) error { - return ErrUnimplemented -} +func (q *fakeQuerier) GetFileByHashAndCreator(_ context.Context, arg database.GetFileByHashAndCreatorParams) (database.File, error) { + if err := validateDatabaseType(arg); err != nil { + return database.File{}, err + } -func (q *fakeQuerier) DeleteAPIKeyByID(_ context.Context, id string) error { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() - for index, apiKey := range q.apiKeys { - if apiKey.ID != id { - continue + for _, file := range q.files { + if file.Hash == arg.Hash && file.CreatedBy == arg.CreatedBy { + return file, nil } - q.apiKeys[index] = q.apiKeys[len(q.apiKeys)-1] - q.apiKeys = q.apiKeys[:len(q.apiKeys)-1] - return nil } - return sql.ErrNoRows + return database.File{}, sql.ErrNoRows } -func (q *fakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() +func (q *fakeQuerier) GetFileByID(_ context.Context, id uuid.UUID) (database.File, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - for i := len(q.apiKeys) - 1; i >= 0; i-- { - if q.apiKeys[i].UserID == userID { - q.apiKeys = append(q.apiKeys[:i], q.apiKeys[i+1:]...) 
+ for _, file := range q.files { + if file.ID == id { + return file, nil } } - - return nil + return database.File{}, sql.ErrNoRows } -func (q *fakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() +func (q *fakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]database.GetFileTemplatesRow, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - for i := len(q.apiKeys) - 1; i >= 0; i-- { - if q.apiKeys[i].UserID == userID && q.apiKeys[i].Scope == database.APIKeyScopeApplicationConnect { - q.apiKeys = append(q.apiKeys[:i], q.apiKeys[i+1:]...) + rows := make([]database.GetFileTemplatesRow, 0) + var file database.File + for _, f := range q.files { + if f.ID == id { + file = f + break } } + if file.Hash == "" { + return rows, nil + } - return nil -} + for _, job := range q.provisionerJobs { + if job.FileID == id { + for _, version := range q.templateVersions { + if version.JobID == job.ID { + for _, template := range q.templates { + if template.ID == version.TemplateID.UUID { + rows = append(rows, database.GetFileTemplatesRow{ + FileID: file.ID, + FileCreatedBy: file.CreatedBy, + TemplateID: template.ID, + TemplateOrganizationID: template.OrganizationID, + TemplateCreatedBy: template.CreatedBy, + UserACL: template.UserACL, + GroupACL: template.GroupACL, + }) + } + } + } + } + } + } -func (*fakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { - return ErrUnimplemented + return rows, nil } -func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() +func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { + if err := validateDatabaseType(arg); err != nil { + return 0, err + } + count, err := q.GetAuthorizedUserCount(ctx, arg, nil) + return count, err +} - for index, key := range q.gitSSHKey { - if key.UserID != userID { +func (q *fakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { + if err := validateDatabaseType(arg); err != nil { + return database.GitAuthLink{}, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + for _, gitAuthLink := range q.gitAuthLinks { + if arg.UserID != gitAuthLink.UserID { continue } - q.gitSSHKey[index] = q.gitSSHKey[len(q.gitSSHKey)-1] - q.gitSSHKey = q.gitSSHKey[:len(q.gitSSHKey)-1] - return nil + if arg.ProviderID != gitAuthLink.ProviderID { + continue + } + return gitAuthLink, nil } - return sql.ErrNoRows + return database.GitAuthLink{}, sql.ErrNoRows } -func (q *fakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() +func (q *fakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - for i, group := range q.groups { - if group.ID == id { - q.groups = append(q.groups[:i], q.groups[i+1:]...) 
- return nil + for _, key := range q.gitSSHKey { + if key.UserID == userID { + return key, nil } } + return database.GitSSHKey{}, sql.ErrNoRows +} - return sql.ErrNoRows +func (q *fakeQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + return q.getGroupByIDNoLock(ctx, id) } -func (q *fakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database.DeleteGroupMemberFromGroupParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() +func (q *fakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGroupByOrgAndNameParams) (database.Group, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Group{}, err + } - for i, member := range q.groupMembers { - if member.UserID == arg.UserID && member.GroupID == arg.GroupID { - q.groupMembers = append(q.groupMembers[:i], q.groupMembers[i+1:]...) + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, group := range q.groups { + if group.OrganizationID == arg.OrganizationID && + group.Name == arg.Name { + return group, nil } } - return nil + + return database.Group{}, sql.ErrNoRows } -func (q *fakeQuerier) DeleteGroupMembersByOrgAndUser(_ context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() +func (q *fakeQuerier) GetGroupMembers(_ context.Context, groupID uuid.UUID) ([]database.User, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - newMembers := q.groupMembers[:0] + var members []database.GroupMember for _, member := range q.groupMembers { - if member.UserID != arg.UserID { - // Do not delete the other members - newMembers = append(newMembers, member) - } else if member.UserID == arg.UserID { - // We only want to delete from groups in the organization in the args. - for _, group := range q.groups { - // Find the group that the member is apartof. - if group.ID == member.GroupID { - // Only add back the member if the organization ID does not match - // the arg organization ID. Since the arg is saying which - // org to delete. 
- if group.OrganizationID != arg.OrganizationID { - newMembers = append(newMembers, member) - } - break - } - } + if member.GroupID == groupID { + members = append(members, member) } } - q.groupMembers = newMembers - - return nil -} -func (q *fakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + users := make([]database.User, 0, len(members)) - for index, l := range q.licenses { - if l.ID == id { - q.licenses[index] = q.licenses[len(q.licenses)-1] - q.licenses = q.licenses[:len(q.licenses)-1] - return id, nil + for _, member := range members { + for _, user := range q.users { + if user.ID == member.UserID && user.Status == database.UserStatusActive && !user.Deleted { + users = append(users, user) + break + } } } - return 0, sql.ErrNoRows -} - -func (*fakeQuerier) DeleteOldWorkspaceAgentStartupLogs(_ context.Context) error { - // noop - return nil -} -func (*fakeQuerier) DeleteOldWorkspaceAgentStats(_ context.Context) error { - // no-op - return nil + return users, nil } -func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time.Time) error { - q.mutex.Lock() - defer q.mutex.Unlock() +func (q *fakeQuerier) GetGroupsByOrganizationID(_ context.Context, organizationID uuid.UUID) ([]database.Group, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - for i, replica := range q.replicas { - if replica.UpdatedAt.Before(before) { - q.replicas = append(q.replicas[:i], q.replicas[i+1:]...) + var groups []database.Group + for _, group := range q.groups { + // Omit the allUsers group. + if group.OrganizationID == organizationID && group.ID != organizationID { + groups = append(groups, group) } } - return nil -} - -func (*fakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { - return database.DeleteTailnetAgentRow{}, ErrUnimplemented -} - -func (*fakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { - return database.DeleteTailnetClientRow{}, ErrUnimplemented + return groups, nil } -func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { +func (q *fakeQuerier) GetHungProvisionerJobs(_ context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, apiKey := range q.apiKeys { - if apiKey.ID == id { - return apiKey, nil + hungJobs := []database.ProvisionerJob{} + for _, provisionerJob := range q.provisionerJobs { + if provisionerJob.StartedAt.Valid && !provisionerJob.CompletedAt.Valid && provisionerJob.UpdatedAt.Before(hungSince) { + hungJobs = append(hungJobs, provisionerJob) } } - return database.APIKey{}, sql.ErrNoRows + return hungJobs, nil } -func (q *fakeQuerier) GetAPIKeyByName(_ context.Context, params database.GetAPIKeyByNameParams) (database.APIKey, error) { +func (q *fakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() - if params.TokenName == "" { - return database.APIKey{}, sql.ErrNoRows - } - for _, apiKey := range q.apiKeys { - if params.UserID == apiKey.UserID && params.TokenName == apiKey.TokenName { - return apiKey, nil - } + if q.lastUpdateCheck == nil { + return "", sql.ErrNoRows } - return database.APIKey{}, sql.ErrNoRows + return string(q.lastUpdateCheck), nil } -func (q *fakeQuerier) GetAPIKeysByLoginType(_ context.Context, t database.LoginType) ([]database.APIKey, error) { - if err := 
validateDatabaseType(t); err != nil { - return nil, err - } - +func (q *fakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { q.mutex.RLock() defer q.mutex.RUnlock() - apiKeys := make([]database.APIKey, 0) - for _, key := range q.apiKeys { - if key.LoginType == t { - apiKeys = append(apiKeys, key) - } - } - return apiKeys, nil + return q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspaceID) } -func (q *fakeQuerier) GetAPIKeysByUserID(_ context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) { +func (q *fakeQuerier) GetLatestWorkspaceBuilds(_ context.Context) ([]database.WorkspaceBuild, error) { q.mutex.RLock() defer q.mutex.RUnlock() - apiKeys := make([]database.APIKey, 0) - for _, key := range q.apiKeys { - if key.UserID == params.UserID && key.LoginType == params.LoginType { - apiKeys = append(apiKeys, key) + builds := make(map[uuid.UUID]database.WorkspaceBuild) + buildNumbers := make(map[uuid.UUID]int32) + for _, workspaceBuild := range q.workspaceBuilds { + id := workspaceBuild.WorkspaceID + if workspaceBuild.BuildNumber > buildNumbers[id] { + builds[id] = workspaceBuild + buildNumbers[id] = workspaceBuild.BuildNumber } } - return apiKeys, nil + var returnBuilds []database.WorkspaceBuild + for i, n := range buildNumbers { + if n > 0 { + b := builds[i] + returnBuilds = append(returnBuilds, b) + } + } + if len(returnBuilds) == 0 { + return nil, sql.ErrNoRows + } + return returnBuilds, nil } -func (q *fakeQuerier) GetAPIKeysLastUsedAfter(_ context.Context, after time.Time) ([]database.APIKey, error) { +func (q *fakeQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceBuild, error) { q.mutex.RLock() defer q.mutex.RUnlock() - apiKeys := make([]database.APIKey, 0) - for _, key := range q.apiKeys { - if key.LastUsed.After(after) { - apiKeys = append(apiKeys, key) + builds := make(map[uuid.UUID]database.WorkspaceBuild) + buildNumbers := make(map[uuid.UUID]int32) + for _, workspaceBuild := range q.workspaceBuilds { + for _, id := range ids { + if id == workspaceBuild.WorkspaceID && workspaceBuild.BuildNumber > buildNumbers[id] { + builds[id] = workspaceBuild + buildNumbers[id] = workspaceBuild.BuildNumber + } } } - return apiKeys, nil + var returnBuilds []database.WorkspaceBuild + for i, n := range buildNumbers { + if n > 0 { + b := builds[i] + returnBuilds = append(returnBuilds, b) + } + } + if len(returnBuilds) == 0 { + return nil, sql.ErrNoRows + } + return returnBuilds, nil } -func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) { +func (q *fakeQuerier) GetLicenseByID(_ context.Context, id int32) (database.License, error) { q.mutex.RLock() defer q.mutex.RUnlock() - active := int64(0) - for _, u := range q.users { - if u.Status == database.UserStatusActive && !u.Deleted { - active++ + for _, license := range q.licenses { + if license.ID == id { + return license, nil } } - return active, nil + return database.License{}, sql.ErrNoRows } -func (q *fakeQuerier) GetAppSecurityKey(_ context.Context) (string, error) { +func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.appSecurityKey, nil -} - -func (q *fakeQuerier) GetAuditLogsOffset(_ context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } + results := 
append([]database.License{}, q.licenses...) + sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) + return results, nil +} +func (q *fakeQuerier) GetLogoURL(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() - logs := make([]database.GetAuditLogsOffsetRow, 0, arg.Limit) - - // q.auditLogs are already sorted by time DESC, so no need to sort after the fact. - for _, alog := range q.auditLogs { - if arg.Offset > 0 { - arg.Offset-- - continue - } - if arg.Action != "" && !strings.Contains(string(alog.Action), arg.Action) { - continue - } - if arg.ResourceType != "" && !strings.Contains(string(alog.ResourceType), arg.ResourceType) { - continue - } - if arg.ResourceID != uuid.Nil && alog.ResourceID != arg.ResourceID { - continue - } - if arg.Username != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Username, user.Username) { - continue - } - } - if arg.Email != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Email, user.Email) { - continue - } - } - if !arg.DateFrom.IsZero() { - if alog.Time.Before(arg.DateFrom) { - continue - } - } - if !arg.DateTo.IsZero() { - if alog.Time.After(arg.DateTo) { - continue - } - } - if arg.BuildReason != "" { - workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID) - if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) { - continue - } - } - - user, err := q.getUserByIDNoLock(alog.UserID) - userValid := err == nil - - logs = append(logs, database.GetAuditLogsOffsetRow{ - ID: alog.ID, - RequestID: alog.RequestID, - OrganizationID: alog.OrganizationID, - Ip: alog.Ip, - UserAgent: alog.UserAgent, - ResourceType: alog.ResourceType, - ResourceID: alog.ResourceID, - ResourceTarget: alog.ResourceTarget, - ResourceIcon: alog.ResourceIcon, - Action: alog.Action, - Diff: alog.Diff, - StatusCode: alog.StatusCode, - AdditionalFields: alog.AdditionalFields, - UserID: alog.UserID, - UserUsername: sql.NullString{String: user.Username, Valid: userValid}, - UserEmail: sql.NullString{String: user.Email, Valid: userValid}, - UserCreatedAt: sql.NullTime{Time: user.CreatedAt, Valid: userValid}, - UserStatus: database.NullUserStatus{UserStatus: user.Status, Valid: userValid}, - UserRoles: user.RBACRoles, - Count: 0, - }) - - if len(logs) >= int(arg.Limit) { - break - } + if q.logoURL == "" { + return "", sql.ErrNoRows } - count := int64(len(logs)) - for i := range logs { - logs[i].Count = count - } + return q.logoURL, nil +} - return logs, nil +func (q *fakeQuerier) GetOAuthSigningKey(_ context.Context) (string, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + return q.oauthSigningKey, nil } -func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) { +func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) { q.mutex.RLock() defer q.mutex.RUnlock() - var user *database.User - roles := make([]string, 0) - for _, u := range q.users { - if u.ID == userID { - u := u - roles = append(roles, u.RBACRoles...) - roles = append(roles, "member") - user = &u - break + for _, organization := range q.organizations { + if organization.ID == id { + return organization, nil } } + return database.Organization{}, sql.ErrNoRows +} - for _, mem := range q.organizationMembers { - if mem.UserID == userID { - roles = append(roles, mem.Roles...) 
- roles = append(roles, "organization-member:"+mem.OrganizationID.String()) - } - } +func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (database.Organization, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - var groups []string - for _, member := range q.groupMembers { - if member.UserID == userID { - groups = append(groups, member.GroupID.String()) + for _, organization := range q.organizations { + if organization.Name == name { + return organization, nil } } - - if user == nil { - return database.GetAuthorizationUserRolesRow{}, sql.ErrNoRows - } - - return database.GetAuthorizationUserRolesRow{ - ID: userID, - Username: user.Username, - Status: user.Status, - Roles: roles, - Groups: groups, - }, nil + return database.Organization{}, sql.ErrNoRows } -func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) { +func (q *fakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.derpMeshKey, nil + getOrganizationIDsByMemberIDRows := make([]database.GetOrganizationIDsByMemberIDsRow, 0, len(ids)) + for _, userID := range ids { + userOrganizationIDs := make([]uuid.UUID, 0) + for _, membership := range q.organizationMembers { + if membership.UserID == userID { + userOrganizationIDs = append(userOrganizationIDs, membership.OrganizationID) + } + } + getOrganizationIDsByMemberIDRows = append(getOrganizationIDsByMemberIDRows, database.GetOrganizationIDsByMemberIDsRow{ + UserID: userID, + OrganizationIDs: userOrganizationIDs, + }) + } + if len(getOrganizationIDsByMemberIDRows) == 0 { + return nil, sql.ErrNoRows + } + return getOrganizationIDsByMemberIDRows, nil } -func (q *fakeQuerier) GetDefaultProxyConfig(_ context.Context) (database.GetDefaultProxyConfigRow, error) { - return database.GetDefaultProxyConfigRow{ - DisplayName: q.defaultProxyDisplayName, - IconUrl: q.defaultProxyIconURL, - }, nil -} +func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { + if err := validateDatabaseType(arg); err != nil { + return database.OrganizationMember{}, err + } -func (q *fakeQuerier) GetDeploymentDAUs(_ context.Context, tzOffset int32) ([]database.GetDeploymentDAUsRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() - seens := make(map[time.Time]map[uuid.UUID]struct{}) - - for _, as := range q.workspaceAgentStats { - if as.ConnectionCount == 0 { + for _, organizationMember := range q.organizationMembers { + if organizationMember.OrganizationID != arg.OrganizationID { continue } - date := as.CreatedAt.UTC().Add(time.Duration(tzOffset) * -1 * time.Hour).Truncate(time.Hour * 24) - - dateEntry := seens[date] - if dateEntry == nil { - dateEntry = make(map[uuid.UUID]struct{}) + if organizationMember.UserID != arg.UserID { + continue } - dateEntry[as.UserID] = struct{}{} - seens[date] = dateEntry + return organizationMember, nil } + return database.OrganizationMember{}, sql.ErrNoRows +} - seenKeys := maps.Keys(seens) - sort.Slice(seenKeys, func(i, j int) bool { - return seenKeys[i].Before(seenKeys[j]) - }) +func (q *fakeQuerier) GetOrganizationMembershipsByUserID(_ context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - var rs []database.GetDeploymentDAUsRow - for _, key := range seenKeys { - ids := seens[key] - for id := range ids { - rs = append(rs, 
database.GetDeploymentDAUsRow{ - Date: key, - UserID: id, - }) + var memberships []database.OrganizationMember + for _, organizationMember := range q.organizationMembers { + mem := organizationMember + if mem.UserID != userID { + continue } + memberships = append(memberships, mem) } - - return rs, nil + return memberships, nil } -func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) { +func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.deploymentID, nil + if len(q.organizations) == 0 { + return nil, sql.ErrNoRows + } + return q.organizations, nil } -func (q *fakeQuerier) GetDeploymentWorkspaceAgentStats(_ context.Context, createdAfter time.Time) (database.GetDeploymentWorkspaceAgentStatsRow, error) { +func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UUID) ([]database.Organization, error) { q.mutex.RLock() defer q.mutex.RUnlock() - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) + organizations := make([]database.Organization, 0) + for _, organizationMember := range q.organizationMembers { + if organizationMember.UserID != userID { + continue } - } - - latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - latestAgentStats[agentStat.AgentID] = agentStat + for _, organization := range q.organizations { + if organization.ID != organizationMember.OrganizationID { + continue + } + organizations = append(organizations, organization) } } - - stat := database.GetDeploymentWorkspaceAgentStatsRow{} - for _, agentStat := range latestAgentStats { - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH + if len(organizations) == 0 { + return nil, sql.ErrNoRows } + return organizations, nil +} - latencies := make([]float64, 0) - for _, agentStat := range agentStatsCreatedAfter { - if agentStat.ConnectionMedianLatencyMS <= 0 { +func (q *fakeQuerier) GetParameterSchemasByJobID(_ context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + parameters := make([]database.ParameterSchema, 0) + for _, parameterSchema := range q.parameterSchemas { + if parameterSchema.JobID != jobID { continue } - stat.WorkspaceRxBytes += agentStat.RxBytes - stat.WorkspaceTxBytes += agentStat.TxBytes - latencies = append(latencies, agentStat.ConnectionMedianLatencyMS) + parameters = append(parameters, parameterSchema) } - - tryPercentile := func(fs []float64, p float64) float64 { - if len(fs) == 0 { - return -1 - } - sort.Float64s(fs) - return fs[int(float64(len(fs))*p/100)] + if len(parameters) == 0 { + return nil, sql.ErrNoRows } - - stat.WorkspaceConnectionLatency50 = tryPercentile(latencies, 50) - stat.WorkspaceConnectionLatency95 = tryPercentile(latencies, 95) - - return stat, nil + sort.Slice(parameters, func(i, j int) bool { + return parameters[i].Index < parameters[j].Index + }) + return parameters, nil } -func (q *fakeQuerier) GetDeploymentWorkspaceStats(ctx context.Context) (database.GetDeploymentWorkspaceStatsRow, error) { +func (q 
*fakeQuerier) GetPreviousTemplateVersion(_ context.Context, arg database.GetPreviousTemplateVersionParams) (database.TemplateVersion, error) { + if err := validateDatabaseType(arg); err != nil { + return database.TemplateVersion{}, err + } + q.mutex.RLock() defer q.mutex.RUnlock() - stat := database.GetDeploymentWorkspaceStatsRow{} - for _, workspace := range q.workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return stat, err + var currentTemplateVersion database.TemplateVersion + for _, templateVersion := range q.templateVersions { + if templateVersion.TemplateID != arg.TemplateID { + continue } - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return stat, err + if templateVersion.Name != arg.Name { + continue } - if !job.StartedAt.Valid { - stat.PendingWorkspaces++ + if templateVersion.OrganizationID != arg.OrganizationID { continue } - if job.StartedAt.Valid && - !job.CanceledAt.Valid && - time.Since(job.UpdatedAt) <= 30*time.Second && - !job.CompletedAt.Valid { - stat.BuildingWorkspaces++ + currentTemplateVersion = templateVersion + break + } + + previousTemplateVersions := make([]database.TemplateVersion, 0) + for _, templateVersion := range q.templateVersions { + if templateVersion.ID == currentTemplateVersion.ID { continue } - if job.CompletedAt.Valid && - !job.CanceledAt.Valid && - !job.Error.Valid { - if build.Transition == database.WorkspaceTransitionStart { - stat.RunningWorkspaces++ - } - if build.Transition == database.WorkspaceTransitionStop { - stat.StoppedWorkspaces++ - } + if templateVersion.OrganizationID != arg.OrganizationID { continue } - if job.CanceledAt.Valid || job.Error.Valid { - stat.FailedWorkspaces++ + if templateVersion.TemplateID != currentTemplateVersion.TemplateID { continue } + + if templateVersion.CreatedAt.Before(currentTemplateVersion.CreatedAt) { + previousTemplateVersions = append(previousTemplateVersions, templateVersion) + } } - return stat, nil -} -func (q *fakeQuerier) GetFileByHashAndCreator(_ context.Context, arg database.GetFileByHashAndCreatorParams) (database.File, error) { - if err := validateDatabaseType(arg); err != nil { - return database.File{}, err + if len(previousTemplateVersions) == 0 { + return database.TemplateVersion{}, sql.ErrNoRows } + sort.Slice(previousTemplateVersions, func(i, j int) bool { + return previousTemplateVersions[i].CreatedAt.After(previousTemplateVersions[j].CreatedAt) + }) + + return previousTemplateVersions[0], nil +} + +func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.ProvisionerDaemon, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, file := range q.files { - if file.Hash == arg.Hash && file.CreatedBy == arg.CreatedBy { - return file, nil - } + if len(q.provisionerDaemons) == 0 { + return nil, sql.ErrNoRows } - return database.File{}, sql.ErrNoRows + return q.provisionerDaemons, nil } -func (q *fakeQuerier) GetFileByID(_ context.Context, id uuid.UUID) (database.File, error) { +func (q *fakeQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, file := range q.files { - if file.ID == id { - return file, nil - } - } - return database.File{}, sql.ErrNoRows + return q.getProvisionerJobByIDNoLock(ctx, id) } -func (q *fakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]database.GetFileTemplatesRow, error) { +func (q *fakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids 
[]uuid.UUID) ([]database.ProvisionerJob, error) { q.mutex.RLock() defer q.mutex.RUnlock() - rows := make([]database.GetFileTemplatesRow, 0) - var file database.File - for _, f := range q.files { - if f.ID == id { - file = f - break + jobs := make([]database.ProvisionerJob, 0) + for _, job := range q.provisionerJobs { + for _, id := range ids { + if id == job.ID { + jobs = append(jobs, job) + break + } } } - if file.Hash == "" { - return rows, nil + if len(jobs) == 0 { + return nil, sql.ErrNoRows } + return jobs, nil +} + +func (q *fakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(_ context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + jobs := make([]database.GetProvisionerJobsByIDsWithQueuePositionRow, 0) + queuePosition := int64(1) for _, job := range q.provisionerJobs { - if job.FileID == id { - for _, version := range q.templateVersions { - if version.JobID == job.ID { - for _, template := range q.templates { - if template.ID == version.TemplateID.UUID { - rows = append(rows, database.GetFileTemplatesRow{ - FileID: file.ID, - FileCreatedBy: file.CreatedBy, - TemplateID: template.ID, - TemplateOrganizationID: template.OrganizationID, - TemplateCreatedBy: template.CreatedBy, - UserACL: template.UserACL, - GroupACL: template.GroupACL, - }) - } - } + for _, id := range ids { + if id == job.ID { + job := database.GetProvisionerJobsByIDsWithQueuePositionRow{ + ProvisionerJob: job, } + if !job.ProvisionerJob.StartedAt.Valid { + job.QueuePosition = queuePosition + } + jobs = append(jobs, job) + break } } + if !job.StartedAt.Valid { + queuePosition++ + } } - - return rows, nil + for _, job := range jobs { + if !job.ProvisionerJob.StartedAt.Valid { + // Set it to the max position! 
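+ // queuePosition was incremented for every not-yet-started job above, so it now sits just past the end of the pending queue.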
+ job.QueueSize = queuePosition + } + } + return jobs, nil } -func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { - if err := validateDatabaseType(arg); err != nil { - return 0, err +func (q *fakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after time.Time) ([]database.ProvisionerJob, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + jobs := make([]database.ProvisionerJob, 0) + for _, job := range q.provisionerJobs { + if job.CreatedAt.After(after) { + jobs = append(jobs, job) + } } - count, err := q.GetAuthorizedUserCount(ctx, arg, nil) - return count, err + return jobs, nil } -func (q *fakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { +func (q *fakeQuerier) GetProvisionerLogsAfterID(_ context.Context, arg database.GetProvisionerLogsAfterIDParams) ([]database.ProvisionerJobLog, error) { if err := validateDatabaseType(arg); err != nil { - return database.GitAuthLink{}, err + return nil, err } q.mutex.RLock() defer q.mutex.RUnlock() - for _, gitAuthLink := range q.gitAuthLinks { - if arg.UserID != gitAuthLink.UserID { + + logs := make([]database.ProvisionerJobLog, 0) + for _, jobLog := range q.provisionerJobLogs { + if jobLog.JobID != arg.JobID { continue } - if arg.ProviderID != gitAuthLink.ProviderID { + if jobLog.ID <= arg.CreatedAfter { continue } - return gitAuthLink, nil + logs = append(logs, jobLog) } - return database.GitAuthLink{}, sql.ErrNoRows + return logs, nil } -func (q *fakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) { +func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UUID) (int64, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, key := range q.gitSSHKey { - if key.UserID == userID { - return key, nil + var sum int64 + for _, member := range q.groupMembers { + if member.UserID != userID { + continue + } + for _, group := range q.groups { + if group.ID == member.GroupID { + sum += int64(group.QuotaAllowance) + } } } - return database.GitSSHKey{}, sql.ErrNoRows + return sum, nil } -func (q *fakeQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) { +func (q *fakeQuerier) GetQuotaConsumedForUser(_ context.Context, userID uuid.UUID) (int64, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.getGroupByIDNoLock(ctx, id) -} + var sum int64 + for _, workspace := range q.workspaces { + if workspace.OwnerID != userID { + continue + } + if workspace.Deleted { + continue + } -func (q *fakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGroupByOrgAndNameParams) (database.Group, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err + var lastBuild database.WorkspaceBuild + for _, build := range q.workspaceBuilds { + if build.WorkspaceID != workspace.ID { + continue + } + if build.CreatedAt.After(lastBuild.CreatedAt) { + lastBuild = build + } + } + sum += int64(lastBuild.DailyCost) } + return sum, nil +} +func (q *fakeQuerier) GetReplicasUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.Replica, error) { q.mutex.RLock() defer q.mutex.RUnlock() - - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return group, nil + replicas := make([]database.Replica, 0) + for _, replica := range q.replicas { + if replica.UpdatedAt.After(updatedAt) && !replica.StoppedAt.Valid { + replicas = append(replicas, 
replica) } } - - return database.Group{}, sql.ErrNoRows + return replicas, nil } -func (q *fakeQuerier) GetGroupMembers(_ context.Context, groupID uuid.UUID) ([]database.User, error) { +func (q *fakeQuerier) GetServiceBanner(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() - var members []database.GroupMember - for _, member := range q.groupMembers { - if member.GroupID == groupID { - members = append(members, member) - } + if q.serviceBanner == nil { + return "", sql.ErrNoRows } - users := make([]database.User, 0, len(members)) - - for _, member := range members { - for _, user := range q.users { - if user.ID == member.UserID && user.Status == database.UserStatusActive && !user.Deleted { - users = append(users, user) - break - } - } - } + return string(q.serviceBanner), nil +} - return users, nil +func (*fakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) { + return nil, ErrUnimplemented } -func (q *fakeQuerier) GetGroupsByOrganizationID(_ context.Context, organizationID uuid.UUID) ([]database.Group, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (*fakeQuerier) GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) { + return nil, ErrUnimplemented +} - var groups []database.Group - for _, group := range q.groups { - // Omit the allUsers group. - if group.OrganizationID == organizationID && group.ID != organizationID { - groups = append(groups, group) - } +func (q *fakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { + if err := validateDatabaseType(arg); err != nil { + return database.GetTemplateAverageBuildTimeRow{}, err } - return groups, nil -} - -func (q *fakeQuerier) GetHungProvisionerJobs(_ context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) { + var emptyRow database.GetTemplateAverageBuildTimeRow + var ( + startTimes []float64 + stopTimes []float64 + deleteTimes []float64 + ) q.mutex.RLock() defer q.mutex.RUnlock() + for _, wb := range q.workspaceBuilds { + version, err := q.getTemplateVersionByIDNoLock(ctx, wb.TemplateVersionID) + if err != nil { + return emptyRow, err + } + if version.TemplateID != arg.TemplateID { + continue + } - hungJobs := []database.ProvisionerJob{} - for _, provisionerJob := range q.provisionerJobs { - if provisionerJob.StartedAt.Valid && !provisionerJob.CompletedAt.Valid && provisionerJob.UpdatedAt.Before(hungSince) { - hungJobs = append(hungJobs, provisionerJob) + job, err := q.getProvisionerJobByIDNoLock(ctx, wb.JobID) + if err != nil { + return emptyRow, err + } + if job.CompletedAt.Valid { + took := job.CompletedAt.Time.Sub(job.StartedAt.Time).Seconds() + switch wb.Transition { + case database.WorkspaceTransitionStart: + startTimes = append(startTimes, took) + case database.WorkspaceTransitionStop: + stopTimes = append(stopTimes, took) + case database.WorkspaceTransitionDelete: + deleteTimes = append(deleteTimes, took) + } } } - return hungJobs, nil -} - -func (q *fakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - if q.lastUpdateCheck == nil { - return "", sql.ErrNoRows + tryPercentile := func(fs []float64, p float64) float64 { + if len(fs) == 0 { + return -1 + } + sort.Float64s(fs) + return fs[int(float64(len(fs))*p/100)] } - return string(q.lastUpdateCheck), nil + + var row database.GetTemplateAverageBuildTimeRow + row.Delete50, row.Delete95 = 
tryPercentile(deleteTimes, 50), tryPercentile(deleteTimes, 95) + row.Stop50, row.Stop95 = tryPercentile(stopTimes, 50), tryPercentile(stopTimes, 95) + row.Start50, row.Start95 = tryPercentile(startTimes, 50), tryPercentile(startTimes, 95) + return row, nil } -func (q *fakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { +func (q *fakeQuerier) GetTemplateByID(ctx context.Context, id uuid.UUID) (database.Template, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspaceID) + return q.getTemplateByIDNoLock(ctx, id) } -func (q *fakeQuerier) GetLatestWorkspaceBuilds(_ context.Context) ([]database.WorkspaceBuild, error) { +func (q *fakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg database.GetTemplateByOrganizationAndNameParams) (database.Template, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Template{}, err + } + q.mutex.RLock() defer q.mutex.RUnlock() - builds := make(map[uuid.UUID]database.WorkspaceBuild) - buildNumbers := make(map[uuid.UUID]int32) - for _, workspaceBuild := range q.workspaceBuilds { - id := workspaceBuild.WorkspaceID - if workspaceBuild.BuildNumber > buildNumbers[id] { - builds[id] = workspaceBuild - buildNumbers[id] = workspaceBuild.BuildNumber + for _, template := range q.templates { + if template.OrganizationID != arg.OrganizationID { + continue } - } - var returnBuilds []database.WorkspaceBuild - for i, n := range buildNumbers { - if n > 0 { - b := builds[i] - returnBuilds = append(returnBuilds, b) + if !strings.EqualFold(template.Name, arg.Name) { + continue } + if template.Deleted != arg.Deleted { + continue + } + return template.DeepCopy(), nil } - if len(returnBuilds) == 0 { - return nil, sql.ErrNoRows - } - return returnBuilds, nil + return database.Template{}, sql.ErrNoRows } -func (q *fakeQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceBuild, error) { +func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, arg database.GetTemplateDAUsParams) ([]database.GetTemplateDAUsRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() - builds := make(map[uuid.UUID]database.WorkspaceBuild) - buildNumbers := make(map[uuid.UUID]int32) - for _, workspaceBuild := range q.workspaceBuilds { - for _, id := range ids { - if id == workspaceBuild.WorkspaceID && workspaceBuild.BuildNumber > buildNumbers[id] { - builds[id] = workspaceBuild - buildNumbers[id] = workspaceBuild.BuildNumber - } + seens := make(map[time.Time]map[uuid.UUID]struct{}) + + for _, as := range q.workspaceAgentStats { + if as.TemplateID != arg.TemplateID { + continue } - } - var returnBuilds []database.WorkspaceBuild - for i, n := range buildNumbers { - if n > 0 { - b := builds[i] - returnBuilds = append(returnBuilds, b) + if as.ConnectionCount == 0 { + continue } - } - if len(returnBuilds) == 0 { - return nil, sql.ErrNoRows - } - return returnBuilds, nil -} -func (q *fakeQuerier) GetLicenseByID(_ context.Context, id int32) (database.License, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + date := as.CreatedAt.UTC().Add(time.Duration(arg.TzOffset) * time.Hour * -1).Truncate(time.Hour * 24) - for _, license := range q.licenses { - if license.ID == id { - return license, nil + dateEntry := seens[date] + if dateEntry == nil { + dateEntry = make(map[uuid.UUID]struct{}) } + dateEntry[as.UserID] = struct{}{} + seens[date] = dateEntry } - return database.License{}, 
sql.ErrNoRows -} - -func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - results := append([]database.License{}, q.licenses...) - sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) - return results, nil -} - -func (q *fakeQuerier) GetLogoURL(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + seenKeys := maps.Keys(seens) + sort.Slice(seenKeys, func(i, j int) bool { + return seenKeys[i].Before(seenKeys[j]) + }) - if q.logoURL == "" { - return "", sql.ErrNoRows + var rs []database.GetTemplateDAUsRow + for _, key := range seenKeys { + ids := seens[key] + for id := range ids { + rs = append(rs, database.GetTemplateDAUsRow{ + Date: key, + UserID: id, + }) + } } - return q.logoURL, nil + return rs, nil } -func (q *fakeQuerier) GetOAuthSigningKey(_ context.Context) (string, error) { +func (q *fakeQuerier) GetTemplateVersionByID(ctx context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.oauthSigningKey, nil + return q.getTemplateVersionByIDNoLock(ctx, templateVersionID) } -func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) { +func (q *fakeQuerier) GetTemplateVersionByJobID(_ context.Context, jobID uuid.UUID) (database.TemplateVersion, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, organization := range q.organizations { - if organization.ID == id { - return organization, nil + for _, templateVersion := range q.templateVersions { + if templateVersion.JobID != jobID { + continue } + return templateVersion, nil } - return database.Organization{}, sql.ErrNoRows + return database.TemplateVersion{}, sql.ErrNoRows } -func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (database.Organization, error) { +func (q *fakeQuerier) GetTemplateVersionByTemplateIDAndName(_ context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) { + if err := validateDatabaseType(arg); err != nil { + return database.TemplateVersion{}, err + } + q.mutex.RLock() defer q.mutex.RUnlock() - for _, organization := range q.organizations { - if organization.Name == name { - return organization, nil + for _, templateVersion := range q.templateVersions { + if templateVersion.TemplateID != arg.TemplateID { + continue } + if !strings.EqualFold(templateVersion.Name, arg.Name) { + continue + } + return templateVersion, nil } - return database.Organization{}, sql.ErrNoRows + return database.TemplateVersion{}, sql.ErrNoRows } -func (q *fakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { +func (q *fakeQuerier) GetTemplateVersionParameters(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { q.mutex.RLock() defer q.mutex.RUnlock() - getOrganizationIDsByMemberIDRows := make([]database.GetOrganizationIDsByMemberIDsRow, 0, len(ids)) - for _, userID := range ids { - userOrganizationIDs := make([]uuid.UUID, 0) - for _, membership := range q.organizationMembers { - if membership.UserID == userID { - userOrganizationIDs = append(userOrganizationIDs, membership.OrganizationID) - } + parameters := make([]database.TemplateVersionParameter, 0) + for _, param := range q.templateVersionParameters { + if param.TemplateVersionID != templateVersionID { + continue } - 
getOrganizationIDsByMemberIDRows = append(getOrganizationIDsByMemberIDRows, database.GetOrganizationIDsByMemberIDsRow{ - UserID: userID, - OrganizationIDs: userOrganizationIDs, - }) - } - if len(getOrganizationIDsByMemberIDRows) == 0 { - return nil, sql.ErrNoRows + parameters = append(parameters, param) } - return getOrganizationIDsByMemberIDRows, nil + sort.Slice(parameters, func(i, j int) bool { + if parameters[i].DisplayOrder != parameters[j].DisplayOrder { + return parameters[i].DisplayOrder < parameters[j].DisplayOrder + } + return strings.ToLower(parameters[i].Name) < strings.ToLower(parameters[j].Name) + }) + return parameters, nil } -func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { - if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err - } - +func (q *fakeQuerier) GetTemplateVersionVariables(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionVariable, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, organizationMember := range q.organizationMembers { - if organizationMember.OrganizationID != arg.OrganizationID { - continue - } - if organizationMember.UserID != arg.UserID { + variables := make([]database.TemplateVersionVariable, 0) + for _, variable := range q.templateVersionVariables { + if variable.TemplateVersionID != templateVersionID { continue } - return organizationMember, nil + variables = append(variables, variable) } - return database.OrganizationMember{}, sql.ErrNoRows + return variables, nil } -func (q *fakeQuerier) GetOrganizationMembershipsByUserID(_ context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { +func (q *fakeQuerier) GetTemplateVersionsByIDs(_ context.Context, ids []uuid.UUID) ([]database.TemplateVersion, error) { q.mutex.RLock() defer q.mutex.RUnlock() - var memberships []database.OrganizationMember - for _, organizationMember := range q.organizationMembers { - mem := organizationMember - if mem.UserID != userID { - continue + versions := make([]database.TemplateVersion, 0) + for _, version := range q.templateVersions { + for _, id := range ids { + if id == version.ID { + versions = append(versions, version) + break + } } - memberships = append(memberships, mem) } - return memberships, nil -} - -func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if len(q.organizations) == 0 { + if len(versions) == 0 { return nil, sql.ErrNoRows } - return q.organizations, nil + + return versions, nil } -func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UUID) ([]database.Organization, error) { +func (q *fakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg database.GetTemplateVersionsByTemplateIDParams) (version []database.TemplateVersion, err error) { + if err := validateDatabaseType(arg); err != nil { + return version, err + } + q.mutex.RLock() defer q.mutex.RUnlock() - organizations := make([]database.Organization, 0) - for _, organizationMember := range q.organizationMembers { - if organizationMember.UserID != userID { + for _, templateVersion := range q.templateVersions { + if templateVersion.TemplateID.UUID != arg.TemplateID { continue } - for _, organization := range q.organizations { - if organization.ID != organizationMember.OrganizationID { - continue + version = append(version, templateVersion) + } + + // Database orders by created_at 
+ slices.SortFunc(version, func(a, b database.TemplateVersion) bool { + if a.CreatedAt.Equal(b.CreatedAt) { + // Technically the postgres database also orders by uuid. So match + // that behavior + return a.ID.String() < b.ID.String() + } + return a.CreatedAt.Before(b.CreatedAt) + }) + + if arg.AfterID != uuid.Nil { + found := false + for i, v := range version { + if v.ID == arg.AfterID { + // We want to return all users after index i. + version = version[i+1:] + found = true + break } - organizations = append(organizations, organization) + } + + // If no users after the time, then we return an empty list. + if !found { + return nil, sql.ErrNoRows } } - if len(organizations) == 0 { + + if arg.OffsetOpt > 0 { + if int(arg.OffsetOpt) > len(version)-1 { + return nil, sql.ErrNoRows + } + version = version[arg.OffsetOpt:] + } + + if arg.LimitOpt > 0 { + if int(arg.LimitOpt) > len(version) { + arg.LimitOpt = int32(len(version)) + } + version = version[:arg.LimitOpt] + } + + if len(version) == 0 { return nil, sql.ErrNoRows } - return organizations, nil + + return version, nil } -func (q *fakeQuerier) GetParameterSchemasByJobID(_ context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) { +func (q *fakeQuerier) GetTemplateVersionsCreatedAfter(_ context.Context, after time.Time) ([]database.TemplateVersion, error) { q.mutex.RLock() defer q.mutex.RUnlock() - parameters := make([]database.ParameterSchema, 0) - for _, parameterSchema := range q.parameterSchemas { - if parameterSchema.JobID != jobID { - continue + versions := make([]database.TemplateVersion, 0) + for _, version := range q.templateVersions { + if version.CreatedAt.After(after) { + versions = append(versions, version) } - parameters = append(parameters, parameterSchema) } - if len(parameters) == 0 { - return nil, sql.ErrNoRows + return versions, nil +} + +func (q *fakeQuerier) GetTemplates(_ context.Context) ([]database.Template, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + templates := slices.Clone(q.templates) + for i := range templates { + templates[i] = templates[i].DeepCopy() } - sort.Slice(parameters, func(i, j int) bool { - return parameters[i].Index < parameters[j].Index + slices.SortFunc(templates, func(i, j database.Template) bool { + if i.Name != j.Name { + return i.Name < j.Name + } + return i.ID.String() < j.ID.String() }) - return parameters, nil + + return templates, nil } -func (q *fakeQuerier) GetPreviousTemplateVersion(_ context.Context, arg database.GetPreviousTemplateVersionParams) (database.TemplateVersion, error) { +func (q *fakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersion{}, err + return nil, err } + return q.GetAuthorizedTemplates(ctx, arg, nil) +} + +func (q *fakeQuerier) GetUnexpiredLicenses(_ context.Context) ([]database.License, error) { q.mutex.RLock() defer q.mutex.RUnlock() - var currentTemplateVersion database.TemplateVersion - for _, templateVersion := range q.templateVersions { - if templateVersion.TemplateID != arg.TemplateID { - continue - } - if templateVersion.Name != arg.Name { - continue - } - if templateVersion.OrganizationID != arg.OrganizationID { - continue - } - currentTemplateVersion = templateVersion - break - } - - previousTemplateVersions := make([]database.TemplateVersion, 0) - for _, templateVersion := range q.templateVersions { - if templateVersion.ID == currentTemplateVersion.ID { - continue - } - 
if templateVersion.OrganizationID != arg.OrganizationID { - continue - } - if templateVersion.TemplateID != currentTemplateVersion.TemplateID { - continue - } - - if templateVersion.CreatedAt.Before(currentTemplateVersion.CreatedAt) { - previousTemplateVersions = append(previousTemplateVersions, templateVersion) + now := time.Now() + var results []database.License + for _, l := range q.licenses { + if l.Exp.After(now) { + results = append(results, l) } } + sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) + return results, nil +} - if len(previousTemplateVersions) == 0 { - return database.TemplateVersion{}, sql.ErrNoRows +func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { + if err := validateDatabaseType(arg); err != nil { + return database.User{}, err } - sort.Slice(previousTemplateVersions, func(i, j int) bool { - return previousTemplateVersions[i].CreatedAt.After(previousTemplateVersions[j].CreatedAt) - }) - - return previousTemplateVersions[0], nil -} - -func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.ProvisionerDaemon, error) { q.mutex.RLock() defer q.mutex.RUnlock() - if len(q.provisionerDaemons) == 0 { - return nil, sql.ErrNoRows + for _, user := range q.users { + if !user.Deleted && (strings.EqualFold(user.Email, arg.Email) || strings.EqualFold(user.Username, arg.Username)) { + return user, nil + } } - return q.provisionerDaemons, nil + return database.User{}, sql.ErrNoRows } -func (q *fakeQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { +func (q *fakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.User, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.getProvisionerJobByIDNoLock(ctx, id) + return q.getUserByIDNoLock(id) } -func (q *fakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) { +func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { q.mutex.RLock() defer q.mutex.RUnlock() - jobs := make([]database.ProvisionerJob, 0) - for _, job := range q.provisionerJobs { - for _, id := range ids { - if id == job.ID { - jobs = append(jobs, job) - break - } + existing := int64(0) + for _, u := range q.users { + if !u.Deleted { + existing++ } } - if len(jobs) == 0 { - return nil, sql.ErrNoRows - } - - return jobs, nil + return existing, nil } -func (q *fakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(_ context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { +func (q *fakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) { q.mutex.RLock() defer q.mutex.RUnlock() - jobs := make([]database.GetProvisionerJobsByIDsWithQueuePositionRow, 0) - queuePosition := int64(1) - for _, job := range q.provisionerJobs { - for _, id := range ids { - if id == job.ID { - job := database.GetProvisionerJobsByIDsWithQueuePositionRow{ - ProvisionerJob: job, - } - if !job.ProvisionerJob.StartedAt.Valid { - job.QueuePosition = queuePosition - } - jobs = append(jobs, job) - break - } - } - if !job.StartedAt.Valid { - queuePosition++ - } - } - for _, job := range jobs { - if !job.ProvisionerJob.StartedAt.Valid { - // Set it to the max position! 
- job.QueueSize = queuePosition + for _, link := range q.userLinks { + if link.LinkedID == id { + return link, nil } } - return jobs, nil + return database.UserLink{}, sql.ErrNoRows } -func (q *fakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after time.Time) ([]database.ProvisionerJob, error) { +func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { + if err := validateDatabaseType(params); err != nil { + return database.UserLink{}, err + } + q.mutex.RLock() defer q.mutex.RUnlock() - jobs := make([]database.ProvisionerJob, 0) - for _, job := range q.provisionerJobs { - if job.CreatedAt.After(after) { - jobs = append(jobs, job) + for _, link := range q.userLinks { + if link.UserID == params.UserID && link.LoginType == params.LoginType { + return link, nil } } - return jobs, nil + return database.UserLink{}, sql.ErrNoRows } -func (q *fakeQuerier) GetProvisionerLogsAfterID(_ context.Context, arg database.GetProvisionerLogsAfterIDParams) ([]database.ProvisionerJobLog, error) { - if err := validateDatabaseType(arg); err != nil { +func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) { + if err := validateDatabaseType(params); err != nil { return nil, err } q.mutex.RLock() defer q.mutex.RUnlock() - logs := make([]database.ProvisionerJobLog, 0) - for _, jobLog := range q.provisionerJobLogs { - if jobLog.JobID != arg.JobID { - continue - } - if jobLog.ID <= arg.CreatedAfter { - continue + // Avoid side-effect of sorting. + users := make([]database.User, len(q.users)) + copy(users, q.users) + + // Database orders by username + slices.SortFunc(users, func(a, b database.User) bool { + return strings.ToLower(a.Username) < strings.ToLower(b.Username) + }) + + // Filter out deleted since they should never be returned.. + tmp := make([]database.User, 0, len(users)) + for _, user := range users { + if !user.Deleted { + tmp = append(tmp, user) } - logs = append(logs, jobLog) } - return logs, nil -} + users = tmp -func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UUID) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + if params.AfterID != uuid.Nil { + found := false + for i, v := range users { + if v.ID == params.AfterID { + // We want to return all users after index i. + users = users[i+1:] + found = true + break + } + } - var sum int64 - for _, member := range q.groupMembers { - if member.UserID != userID { - continue + // If no users after the time, then we return an empty list. 
+ if !found { + return []database.GetUsersRow{}, nil } - for _, group := range q.groups { - if group.ID == member.GroupID { - sum += int64(group.QuotaAllowance) + } + + if params.Search != "" { + tmp := make([]database.User, 0, len(users)) + for i, user := range users { + if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) + } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) } } + users = tmp } - return sum, nil -} -func (q *fakeQuerier) GetQuotaConsumedForUser(_ context.Context, userID uuid.UUID) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var sum int64 - for _, workspace := range q.workspaces { - if workspace.OwnerID != userID { - continue - } - if workspace.Deleted { - continue + if len(params.Status) > 0 { + usersFilteredByStatus := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { + return strings.EqualFold(string(a), string(b)) + }) { + usersFilteredByStatus = append(usersFilteredByStatus, users[i]) + } } + users = usersFilteredByStatus + } - var lastBuild database.WorkspaceBuild - for _, build := range q.workspaceBuilds { - if build.WorkspaceID != workspace.ID { - continue + if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { + usersFilteredByRole := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { + usersFilteredByRole = append(usersFilteredByRole, users[i]) } - if build.CreatedAt.After(lastBuild.CreatedAt) { - lastBuild = build + } + users = usersFilteredByRole + } + + if !params.LastSeenBefore.IsZero() { + usersFilteredByLastSeen := make([]database.User, 0, len(users)) + for i, user := range users { + if user.LastSeenAt.Before(params.LastSeenBefore) { + usersFilteredByLastSeen = append(usersFilteredByLastSeen, users[i]) } } - sum += int64(lastBuild.DailyCost) + users = usersFilteredByLastSeen } - return sum, nil -} -func (q *fakeQuerier) GetReplicasUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.Replica, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - replicas := make([]database.Replica, 0) - for _, replica := range q.replicas { - if replica.UpdatedAt.After(updatedAt) && !replica.StoppedAt.Valid { - replicas = append(replicas, replica) + if !params.LastSeenAfter.IsZero() { + usersFilteredByLastSeen := make([]database.User, 0, len(users)) + for i, user := range users { + if user.LastSeenAt.After(params.LastSeenAfter) { + usersFilteredByLastSeen = append(usersFilteredByLastSeen, users[i]) + } } + users = usersFilteredByLastSeen } - return replicas, nil -} -func (q *fakeQuerier) GetServiceBanner(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + beforePageCount := len(users) - if q.serviceBanner == nil { - return "", sql.ErrNoRows + if params.OffsetOpt > 0 { + if int(params.OffsetOpt) > len(users)-1 { + return []database.GetUsersRow{}, nil + } + users = users[params.OffsetOpt:] } - return string(q.serviceBanner), nil -} - -func (*fakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) { - return nil, ErrUnimplemented -} + if params.LimitOpt > 0 { + if int(params.LimitOpt) > len(users) { + params.LimitOpt = int32(len(users)) + } + users = users[:params.LimitOpt] + } -func (*fakeQuerier) 
GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) { - return nil, ErrUnimplemented + return convertUsers(users, int64(beforePageCount)), nil } -func (q *fakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { - if err := validateDatabaseType(arg); err != nil { - return database.GetTemplateAverageBuildTimeRow{}, err - } - - var emptyRow database.GetTemplateAverageBuildTimeRow - var ( - startTimes []float64 - stopTimes []float64 - deleteTimes []float64 - ) +func (q *fakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]database.User, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, wb := range q.workspaceBuilds { - version, err := q.getTemplateVersionByIDNoLock(ctx, wb.TemplateVersionID) - if err != nil { - return emptyRow, err - } - if version.TemplateID != arg.TemplateID { - continue - } - job, err := q.getProvisionerJobByIDNoLock(ctx, wb.JobID) - if err != nil { - return emptyRow, err - } - if job.CompletedAt.Valid { - took := job.CompletedAt.Time.Sub(job.StartedAt.Time).Seconds() - switch wb.Transition { - case database.WorkspaceTransitionStart: - startTimes = append(startTimes, took) - case database.WorkspaceTransitionStop: - stopTimes = append(stopTimes, took) - case database.WorkspaceTransitionDelete: - deleteTimes = append(deleteTimes, took) + users := make([]database.User, 0) + for _, user := range q.users { + for _, id := range ids { + if user.ID != id { + continue } + users = append(users, user) } } + return users, nil +} - tryPercentile := func(fs []float64, p float64) float64 { - if len(fs) == 0 { - return -1 +func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + // The schema sorts this by created at, so we iterate the array backwards. 
+ for i := len(q.workspaceAgents) - 1; i >= 0; i-- { + agent := q.workspaceAgents[i] + if agent.AuthToken == authToken { + return agent, nil } - sort.Float64s(fs) - return fs[int(float64(len(fs))*p/100)] } - - var row database.GetTemplateAverageBuildTimeRow - row.Delete50, row.Delete95 = tryPercentile(deleteTimes, 50), tryPercentile(deleteTimes, 95) - row.Stop50, row.Stop95 = tryPercentile(stopTimes, 50), tryPercentile(stopTimes, 95) - row.Start50, row.Start95 = tryPercentile(startTimes, 50), tryPercentile(startTimes, 95) - return row, nil + return database.WorkspaceAgent{}, sql.ErrNoRows } -func (q *fakeQuerier) GetTemplateByID(ctx context.Context, id uuid.UUID) (database.Template, error) { +func (q *fakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.getTemplateByIDNoLock(ctx, id) + return q.getWorkspaceAgentByIDNoLock(ctx, id) } -func (q *fakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg database.GetTemplateByOrganizationAndNameParams) (database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Template{}, err - } - +func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceID string) (database.WorkspaceAgent, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, template := range q.templates { - if template.OrganizationID != arg.OrganizationID { - continue - } - if !strings.EqualFold(template.Name, arg.Name) { - continue - } - if template.Deleted != arg.Deleted { - continue + // The schema sorts this by created at, so we iterate the array backwards. + for i := len(q.workspaceAgents) - 1; i >= 0; i-- { + agent := q.workspaceAgents[i] + if agent.AuthInstanceID.Valid && agent.AuthInstanceID.String == instanceID { + return agent, nil } - return template.DeepCopy(), nil } - return database.Template{}, sql.ErrNoRows + return database.WorkspaceAgent{}, sql.ErrNoRows } -func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, arg database.GetTemplateDAUsParams) ([]database.GetTemplateDAUsRow, error) { +func (q *fakeQuerier) GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (database.GetWorkspaceAgentLifecycleStateByIDRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() - seens := make(map[time.Time]map[uuid.UUID]struct{}) - - for _, as := range q.workspaceAgentStats { - if as.TemplateID != arg.TemplateID { - continue - } - if as.ConnectionCount == 0 { - continue - } - - date := as.CreatedAt.UTC().Add(time.Duration(arg.TzOffset) * time.Hour * -1).Truncate(time.Hour * 24) - - dateEntry := seens[date] - if dateEntry == nil { - dateEntry = make(map[uuid.UUID]struct{}) - } - dateEntry[as.UserID] = struct{}{} - seens[date] = dateEntry - } - - seenKeys := maps.Keys(seens) - sort.Slice(seenKeys, func(i, j int) bool { - return seenKeys[i].Before(seenKeys[j]) - }) - - var rs []database.GetTemplateDAUsRow - for _, key := range seenKeys { - ids := seens[key] - for id := range ids { - rs = append(rs, database.GetTemplateDAUsRow{ - Date: key, - UserID: id, - }) - } + agent, err := q.getWorkspaceAgentByIDNoLock(ctx, id) + if err != nil { + return database.GetWorkspaceAgentLifecycleStateByIDRow{}, err } - - return rs, nil -} - -func (q *fakeQuerier) GetTemplateVersionByID(ctx context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getTemplateVersionByIDNoLock(ctx, templateVersionID) + return 
database.GetWorkspaceAgentLifecycleStateByIDRow{ + LifecycleState: agent.LifecycleState, + StartedAt: agent.StartedAt, + ReadyAt: agent.ReadyAt, + }, nil } -func (q *fakeQuerier) GetTemplateVersionByJobID(_ context.Context, jobID uuid.UUID) (database.TemplateVersion, error) { +func (q *fakeQuerier) GetWorkspaceAgentMetadata(_ context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentMetadatum, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, templateVersion := range q.templateVersions { - if templateVersion.JobID != jobID { - continue + metadata := make([]database.WorkspaceAgentMetadatum, 0) + for _, m := range q.workspaceAgentMetadata { + if m.WorkspaceAgentID == workspaceAgentID { + metadata = append(metadata, m) } - return templateVersion, nil } - return database.TemplateVersion{}, sql.ErrNoRows + return metadata, nil } -func (q *fakeQuerier) GetTemplateVersionByTemplateIDAndName(_ context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) { +func (q *fakeQuerier) GetWorkspaceAgentStartupLogsAfter(_ context.Context, arg database.GetWorkspaceAgentStartupLogsAfterParams) ([]database.WorkspaceAgentStartupLog, error) { if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersion{}, err + return nil, err } q.mutex.RLock() defer q.mutex.RUnlock() - for _, templateVersion := range q.templateVersions { - if templateVersion.TemplateID != arg.TemplateID { - continue - } - if !strings.EqualFold(templateVersion.Name, arg.Name) { + logs := []database.WorkspaceAgentStartupLog{} + for _, log := range q.workspaceAgentLogs { + if log.AgentID != arg.AgentID { continue } - return templateVersion, nil - } - return database.TemplateVersion{}, sql.ErrNoRows -} - -func (q *fakeQuerier) GetTemplateVersionParameters(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - parameters := make([]database.TemplateVersionParameter, 0) - for _, param := range q.templateVersionParameters { - if param.TemplateVersionID != templateVersionID { + if arg.CreatedAfter != 0 && log.ID <= arg.CreatedAfter { continue } - parameters = append(parameters, param) + logs = append(logs, log) } - sort.Slice(parameters, func(i, j int) bool { - if parameters[i].DisplayOrder != parameters[j].DisplayOrder { - return parameters[i].DisplayOrder < parameters[j].DisplayOrder - } - return strings.ToLower(parameters[i].Name) < strings.ToLower(parameters[j].Name) - }) - return parameters, nil + return logs, nil } -func (q *fakeQuerier) GetTemplateVersionVariables(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionVariable, error) { +func (q *fakeQuerier) GetWorkspaceAgentStats(_ context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() - variables := make([]database.TemplateVersionVariable, 0) - for _, variable := range q.templateVersionVariables { - if variable.TemplateVersionID != templateVersionID { - continue + agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) + for _, agentStat := range q.workspaceAgentStats { + if agentStat.CreatedAt.After(createdAfter) { + agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) } - variables = append(variables, variable) } - return variables, nil -} - -func (q *fakeQuerier) GetTemplateVersionsByIDs(_ context.Context, ids []uuid.UUID) ([]database.TemplateVersion, error) { - q.mutex.RLock() - defer 
q.mutex.RUnlock() - versions := make([]database.TemplateVersion, 0) - for _, version := range q.templateVersions { - for _, id := range ids { - if id == version.ID { - versions = append(versions, version) - break - } + latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} + for _, agentStat := range q.workspaceAgentStats { + if agentStat.CreatedAt.After(createdAfter) { + latestAgentStats[agentStat.AgentID] = agentStat } } - if len(versions) == 0 { - return nil, sql.ErrNoRows - } - - return versions, nil -} -func (q *fakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg database.GetTemplateVersionsByTemplateIDParams) (version []database.TemplateVersion, err error) { - if err := validateDatabaseType(arg); err != nil { - return version, err + statByAgent := map[uuid.UUID]database.GetWorkspaceAgentStatsRow{} + for _, agentStat := range latestAgentStats { + stat := statByAgent[agentStat.AgentID] + stat.SessionCountVSCode += agentStat.SessionCountVSCode + stat.SessionCountJetBrains += agentStat.SessionCountJetBrains + stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY + stat.SessionCountSSH += agentStat.SessionCountSSH + statByAgent[stat.AgentID] = stat } - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, templateVersion := range q.templateVersions { - if templateVersion.TemplateID.UUID != arg.TemplateID { + latenciesByAgent := map[uuid.UUID][]float64{} + minimumDateByAgent := map[uuid.UUID]time.Time{} + for _, agentStat := range agentStatsCreatedAfter { + if agentStat.ConnectionMedianLatencyMS <= 0 { continue } - version = append(version, templateVersion) - } - - // Database orders by created_at - slices.SortFunc(version, func(a, b database.TemplateVersion) bool { - if a.CreatedAt.Equal(b.CreatedAt) { - // Technically the postgres database also orders by uuid. So match - // that behavior - return a.ID.String() < b.ID.String() - } - return a.CreatedAt.Before(b.CreatedAt) - }) - - if arg.AfterID != uuid.Nil { - found := false - for i, v := range version { - if v.ID == arg.AfterID { - // We want to return all users after index i. - version = version[i+1:] - found = true - break - } - } - - // If no users after the time, then we return an empty list. 
- if !found { - return nil, sql.ErrNoRows + stat := statByAgent[agentStat.AgentID] + minimumDate := minimumDateByAgent[agentStat.AgentID] + if agentStat.CreatedAt.Before(minimumDate) || minimumDate.IsZero() { + minimumDateByAgent[agentStat.AgentID] = agentStat.CreatedAt } + stat.WorkspaceRxBytes += agentStat.RxBytes + stat.WorkspaceTxBytes += agentStat.TxBytes + statByAgent[agentStat.AgentID] = stat + latenciesByAgent[agentStat.AgentID] = append(latenciesByAgent[agentStat.AgentID], agentStat.ConnectionMedianLatencyMS) } - if arg.OffsetOpt > 0 { - if int(arg.OffsetOpt) > len(version)-1 { - return nil, sql.ErrNoRows + tryPercentile := func(fs []float64, p float64) float64 { + if len(fs) == 0 { + return -1 } - version = version[arg.OffsetOpt:] + sort.Float64s(fs) + return fs[int(float64(len(fs))*p/100)] } - if arg.LimitOpt > 0 { - if int(arg.LimitOpt) > len(version) { - arg.LimitOpt = int32(len(version)) + for _, stat := range statByAgent { + stat.AggregatedFrom = minimumDateByAgent[stat.AgentID] + statByAgent[stat.AgentID] = stat + + latencies, ok := latenciesByAgent[stat.AgentID] + if !ok { + continue } - version = version[:arg.LimitOpt] + stat.WorkspaceConnectionLatency50 = tryPercentile(latencies, 50) + stat.WorkspaceConnectionLatency95 = tryPercentile(latencies, 95) + statByAgent[stat.AgentID] = stat } - if len(version) == 0 { - return nil, sql.ErrNoRows + stats := make([]database.GetWorkspaceAgentStatsRow, 0, len(statByAgent)) + for _, agent := range statByAgent { + stats = append(stats, agent) } - - return version, nil + return stats, nil } -func (q *fakeQuerier) GetTemplateVersionsCreatedAfter(_ context.Context, after time.Time) ([]database.TemplateVersion, error) { +func (q *fakeQuerier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsAndLabelsRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() - versions := make([]database.TemplateVersion, 0) - for _, version := range q.templateVersions { - if version.CreatedAt.After(after) { - versions = append(versions, version) - } - } - return versions, nil -} + agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) + latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} -func (q *fakeQuerier) GetTemplates(_ context.Context) ([]database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + for _, agentStat := range q.workspaceAgentStats { + if agentStat.CreatedAt.After(createdAfter) { + agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) + latestAgentStats[agentStat.AgentID] = agentStat + } + } - templates := slices.Clone(q.templates) - for i := range templates { - templates[i] = templates[i].DeepCopy() + statByAgent := map[uuid.UUID]database.GetWorkspaceAgentStatsAndLabelsRow{} + + // Session and connection metrics + for _, agentStat := range latestAgentStats { + stat := statByAgent[agentStat.AgentID] + stat.SessionCountVSCode += agentStat.SessionCountVSCode + stat.SessionCountJetBrains += agentStat.SessionCountJetBrains + stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY + stat.SessionCountSSH += agentStat.SessionCountSSH + stat.ConnectionCount += agentStat.ConnectionCount + if agentStat.ConnectionMedianLatencyMS >= 0 && stat.ConnectionMedianLatencyMS < agentStat.ConnectionMedianLatencyMS { + stat.ConnectionMedianLatencyMS = agentStat.ConnectionMedianLatencyMS + } + statByAgent[agentStat.AgentID] = stat } - slices.SortFunc(templates, func(i, j database.Template) bool { - if i.Name != j.Name { - return i.Name < 
j.Name + + // Tx, Rx metrics + for _, agentStat := range agentStatsCreatedAfter { + stat := statByAgent[agentStat.AgentID] + stat.RxBytes += agentStat.RxBytes + stat.TxBytes += agentStat.TxBytes + statByAgent[agentStat.AgentID] = stat + } + + // Labels + for _, agentStat := range agentStatsCreatedAfter { + stat := statByAgent[agentStat.AgentID] + + user, err := q.getUserByIDNoLock(agentStat.UserID) + if err != nil { + return nil, err } - return i.ID.String() < j.ID.String() - }) - return templates, nil -} + stat.Username = user.Username -func (q *fakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err + workspace, err := q.getWorkspaceByIDNoLock(ctx, agentStat.WorkspaceID) + if err != nil { + return nil, err + } + stat.WorkspaceName = workspace.Name + + agent, err := q.getWorkspaceAgentByIDNoLock(ctx, agentStat.AgentID) + if err != nil { + return nil, err + } + stat.AgentName = agent.Name + + statByAgent[agentStat.AgentID] = stat } - return q.GetAuthorizedTemplates(ctx, arg, nil) + stats := make([]database.GetWorkspaceAgentStatsAndLabelsRow, 0, len(statByAgent)) + for _, agent := range statByAgent { + stats = append(stats, agent) + } + return stats, nil } -func (q *fakeQuerier) GetUnexpiredLicenses(_ context.Context) ([]database.License, error) { +func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { q.mutex.RLock() defer q.mutex.RUnlock() - now := time.Now() - var results []database.License - for _, l := range q.licenses { - if l.Exp.After(now) { - results = append(results, l) + return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) +} + +func (q *fakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceAgent, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + workspaceAgents := make([]database.WorkspaceAgent, 0) + for _, agent := range q.workspaceAgents { + if agent.CreatedAt.After(after) { + workspaceAgents = append(workspaceAgents, agent) } } - sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) - return results, nil + return workspaceAgents, nil } -func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { +func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgent, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + // Get latest build for workspace. + workspaceBuild, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspaceID) + if err != nil { + return nil, xerrors.Errorf("get latest workspace build: %w", err) + } + + // Get resources for build. 
+ resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, workspaceBuild.JobID) + if err != nil { + return nil, xerrors.Errorf("get workspace resources: %w", err) + } + if len(resources) == 0 { + return []database.WorkspaceAgent{}, nil + } + + resourceIDs := make([]uuid.UUID, len(resources)) + for i, resource := range resources { + resourceIDs[i] = resource.ID + } + + agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) + if err != nil { + return nil, xerrors.Errorf("get workspace agents: %w", err) + } + + return agents, nil +} + +func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndSlug(_ context.Context, arg database.GetWorkspaceAppByAgentIDAndSlugParams) (database.WorkspaceApp, error) { if err := validateDatabaseType(arg); err != nil { - return database.User{}, err + return database.WorkspaceApp{}, err } q.mutex.RLock() defer q.mutex.RUnlock() - for _, user := range q.users { - if !user.Deleted && (strings.EqualFold(user.Email, arg.Email) || strings.EqualFold(user.Username, arg.Username)) { - return user, nil + for _, app := range q.workspaceApps { + if app.AgentID != arg.AgentID { + continue + } + if app.Slug != arg.Slug { + continue } + return app, nil } - return database.User{}, sql.ErrNoRows + return database.WorkspaceApp{}, sql.ErrNoRows } -func (q *fakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.User, error) { +func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.getUserByIDNoLock(id) + apps := make([]database.WorkspaceApp, 0) + for _, app := range q.workspaceApps { + if app.AgentID == id { + apps = append(apps, app) + } + } + if len(apps) == 0 { + return nil, sql.ErrNoRows + } + return apps, nil } -func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { +func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) { q.mutex.RLock() defer q.mutex.RUnlock() - existing := int64(0) - for _, u := range q.users { - if !u.Deleted { - existing++ + apps := make([]database.WorkspaceApp, 0) + for _, app := range q.workspaceApps { + for _, id := range ids { + if app.AgentID == id { + apps = append(apps, app) + break + } } } - return existing, nil + return apps, nil } -func (q *fakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) { +func (q *fakeQuerier) GetWorkspaceAppsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceApp, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, link := range q.userLinks { - if link.LinkedID == id { - return link, nil + apps := make([]database.WorkspaceApp, 0) + for _, app := range q.workspaceApps { + if app.CreatedAt.After(after) { + apps = append(apps, app) } } - return database.UserLink{}, sql.ErrNoRows + return apps, nil } -func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { - if err := validateDatabaseType(params); err != nil { - return database.UserLink{}, err - } +func (q *fakeQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + return q.getWorkspaceBuildByIDNoLock(ctx, id) +} +func (q *fakeQuerier) GetWorkspaceBuildByJobID(_ context.Context, jobID uuid.UUID) (database.WorkspaceBuild, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, link := range q.userLinks { - if 
link.UserID == params.UserID && link.LoginType == params.LoginType { - return link, nil + for _, build := range q.workspaceBuilds { + if build.JobID == jobID { + return build, nil } } - return database.UserLink{}, sql.ErrNoRows + return database.WorkspaceBuild{}, sql.ErrNoRows } -func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) { - if err := validateDatabaseType(params); err != nil { - return nil, err +func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(_ context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (database.WorkspaceBuild, error) { + if err := validateDatabaseType(arg); err != nil { + return database.WorkspaceBuild{}, err } q.mutex.RLock() defer q.mutex.RUnlock() - // Avoid side-effect of sorting. - users := make([]database.User, len(q.users)) - copy(users, q.users) - - // Database orders by username - slices.SortFunc(users, func(a, b database.User) bool { - return strings.ToLower(a.Username) < strings.ToLower(b.Username) - }) - - // Filter out deleted since they should never be returned.. - tmp := make([]database.User, 0, len(users)) - for _, user := range users { - if !user.Deleted { - tmp = append(tmp, user) + for _, workspaceBuild := range q.workspaceBuilds { + if workspaceBuild.WorkspaceID != arg.WorkspaceID { + continue + } + if workspaceBuild.BuildNumber != arg.BuildNumber { + continue } + return workspaceBuild, nil } - users = tmp + return database.WorkspaceBuild{}, sql.ErrNoRows +} - if params.AfterID != uuid.Nil { - found := false - for i, v := range users { - if v.ID == params.AfterID { - // We want to return all users after index i. - users = users[i+1:] - found = true - break - } - } +func (q *fakeQuerier) GetWorkspaceBuildParameters(_ context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - // If no users after the time, then we return an empty list. 
- if !found { - return []database.GetUsersRow{}, nil + params := make([]database.WorkspaceBuildParameter, 0) + for _, param := range q.workspaceBuildParameters { + if param.WorkspaceBuildID != workspaceBuildID { + continue } + params = append(params, param) } + return params, nil +} - if params.Search != "" { - tmp := make([]database.User, 0, len(users)) - for i, user := range users { - if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } - } - users = tmp +func (q *fakeQuerier) GetWorkspaceBuildsByWorkspaceID(_ context.Context, + params database.GetWorkspaceBuildsByWorkspaceIDParams, +) ([]database.WorkspaceBuild, error) { + if err := validateDatabaseType(params); err != nil { + return nil, err } - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } - } - users = usersFilteredByStatus - } + q.mutex.RLock() + defer q.mutex.RUnlock() - if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { - usersFilteredByRole := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { - usersFilteredByRole = append(usersFilteredByRole, users[i]) - } + history := make([]database.WorkspaceBuild, 0) + for _, workspaceBuild := range q.workspaceBuilds { + if workspaceBuild.CreatedAt.Before(params.Since) { + continue + } + if workspaceBuild.WorkspaceID == params.WorkspaceID { + history = append(history, workspaceBuild) } - users = usersFilteredByRole } - if !params.LastSeenBefore.IsZero() { - usersFilteredByLastSeen := make([]database.User, 0, len(users)) - for i, user := range users { - if user.LastSeenAt.Before(params.LastSeenBefore) { - usersFilteredByLastSeen = append(usersFilteredByLastSeen, users[i]) + // Order by build_number + slices.SortFunc(history, func(a, b database.WorkspaceBuild) bool { + // use greater than since we want descending order + return a.BuildNumber > b.BuildNumber + }) + + if params.AfterID != uuid.Nil { + found := false + for i, v := range history { + if v.ID == params.AfterID { + // We want to return all builds after index i. + history = history[i+1:] + found = true + break } } - users = usersFilteredByLastSeen - } - if !params.LastSeenAfter.IsZero() { - usersFilteredByLastSeen := make([]database.User, 0, len(users)) - for i, user := range users { - if user.LastSeenAt.After(params.LastSeenAfter) { - usersFilteredByLastSeen = append(usersFilteredByLastSeen, users[i]) - } + // If no builds after the time, then we return an empty list. 
+ if !found { + return nil, sql.ErrNoRows } - users = usersFilteredByLastSeen } - beforePageCount := len(users) - if params.OffsetOpt > 0 { - if int(params.OffsetOpt) > len(users)-1 { - return []database.GetUsersRow{}, nil + if int(params.OffsetOpt) > len(history)-1 { + return nil, sql.ErrNoRows } - users = users[params.OffsetOpt:] + history = history[params.OffsetOpt:] } if params.LimitOpt > 0 { - if int(params.LimitOpt) > len(users) { - params.LimitOpt = int32(len(users)) + if int(params.LimitOpt) > len(history) { + params.LimitOpt = int32(len(history)) } - users = users[:params.LimitOpt] + history = history[:params.LimitOpt] } - return convertUsers(users, int64(beforePageCount)), nil + if len(history) == 0 { + return nil, sql.ErrNoRows + } + return history, nil } -func (q *fakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]database.User, error) { +func (q *fakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceBuild, error) { q.mutex.RLock() defer q.mutex.RUnlock() - users := make([]database.User, 0) - for _, user := range q.users { - for _, id := range ids { - if user.ID != id { - continue - } - users = append(users, user) + workspaceBuilds := make([]database.WorkspaceBuild, 0) + for _, workspaceBuild := range q.workspaceBuilds { + if workspaceBuild.CreatedAt.After(after) { + workspaceBuilds = append(workspaceBuilds, workspaceBuild) } } - return users, nil + return workspaceBuilds, nil } -func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { +func (q *fakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() - // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.workspaceAgents) - 1; i >= 0; i-- { - agent := q.workspaceAgents[i] - if agent.AuthToken == authToken { - return agent, nil - } - } - return database.WorkspaceAgent{}, sql.ErrNoRows + return q.getWorkspaceByAgentIDNoLock(ctx, agentID) } -func (q *fakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { +func (q *fakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.getWorkspaceAgentByIDNoLock(ctx, id) + return q.getWorkspaceByIDNoLock(ctx, id) } -func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceID string) (database.WorkspaceAgent, error) { +func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Workspace{}, err + } + q.mutex.RLock() defer q.mutex.RUnlock() - // The schema sorts this by created at, so we iterate the array backwards. 
- for i := len(q.workspaceAgents) - 1; i >= 0; i-- { - agent := q.workspaceAgents[i] - if agent.AuthInstanceID.Valid && agent.AuthInstanceID.String == instanceID { - return agent, nil + var found *database.Workspace + for _, workspace := range q.workspaces { + workspace := workspace + if workspace.OwnerID != arg.OwnerID { + continue + } + if !strings.EqualFold(workspace.Name, arg.Name) { + continue + } + if workspace.Deleted != arg.Deleted { + continue + } + + // Return the most recent workspace with the given name + if found == nil || workspace.CreatedAt.After(found.CreatedAt) { + found = &workspace } } - return database.WorkspaceAgent{}, sql.ErrNoRows + if found != nil { + return *found, nil + } + return database.Workspace{}, sql.ErrNoRows } -func (q *fakeQuerier) GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (database.GetWorkspaceAgentLifecycleStateByIDRow, error) { +func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { + if err := validateDatabaseType(workspaceAppID); err != nil { + return database.Workspace{}, err + } + q.mutex.RLock() defer q.mutex.RUnlock() - agent, err := q.getWorkspaceAgentByIDNoLock(ctx, id) - if err != nil { - return database.GetWorkspaceAgentLifecycleStateByIDRow{}, err + for _, workspaceApp := range q.workspaceApps { + workspaceApp := workspaceApp + if workspaceApp.ID == workspaceAppID { + return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID) + } } - return database.GetWorkspaceAgentLifecycleStateByIDRow{ - LifecycleState: agent.LifecycleState, - StartedAt: agent.StartedAt, - ReadyAt: agent.ReadyAt, - }, nil + return database.Workspace{}, sql.ErrNoRows } -func (q *fakeQuerier) GetWorkspaceAgentMetadata(_ context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentMetadatum, error) { +func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) { q.mutex.RLock() defer q.mutex.RUnlock() - metadata := make([]database.WorkspaceAgentMetadatum, 0) - for _, m := range q.workspaceAgentMetadata { - if m.WorkspaceAgentID == workspaceAgentID { - metadata = append(metadata, m) + cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies)) + + for _, p := range q.workspaceProxies { + if !p.Deleted { + cpy = append(cpy, p) } } - return metadata, nil + return cpy, nil } -func (q *fakeQuerier) GetWorkspaceAgentStartupLogsAfter(_ context.Context, arg database.GetWorkspaceAgentStartupLogsAfterParams) ([]database.WorkspaceAgentStartupLog, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - +func (q *fakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, params database.GetWorkspaceProxyByHostnameParams) (database.WorkspaceProxy, error) { q.mutex.RLock() defer q.mutex.RUnlock() - logs := []database.WorkspaceAgentStartupLog{} - for _, log := range q.workspaceAgentLogs { - if log.AgentID != arg.AgentID { + // Return zero rows if this is called with a non-sanitized hostname. The SQL + // version of this query does the same thing. + if !validProxyByHostnameRegex.MatchString(params.Hostname) { + return database.WorkspaceProxy{}, sql.ErrNoRows + } + + // This regex matches the SQL version. 
+ accessURLRegex := regexp.MustCompile(`[^:]*://` + regexp.QuoteMeta(params.Hostname) + `([:/]?.)*`) + + for _, proxy := range q.workspaceProxies { + if proxy.Deleted { continue } - if arg.CreatedAfter != 0 && log.ID <= arg.CreatedAfter { - continue + if params.AllowAccessUrl && accessURLRegex.MatchString(proxy.Url) { + return proxy, nil + } + + // Compile the app hostname regex. This is slow sadly. + if params.AllowWildcardHostname { + wildcardRegexp, err := httpapi.CompileHostnamePattern(proxy.WildcardHostname) + if err != nil { + return database.WorkspaceProxy{}, xerrors.Errorf("compile hostname pattern %q for proxy %q (%s): %w", proxy.WildcardHostname, proxy.Name, proxy.ID.String(), err) + } + if _, ok := httpapi.ExecuteHostnamePattern(wildcardRegexp, params.Hostname); ok { + return proxy, nil + } } - logs = append(logs, log) } - return logs, nil + + return database.WorkspaceProxy{}, sql.ErrNoRows } -func (q *fakeQuerier) GetWorkspaceAgentStats(_ context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsRow, error) { +func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) { q.mutex.RLock() defer q.mutex.RUnlock() - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) + for _, proxy := range q.workspaceProxies { + if proxy.ID == id { + return proxy, nil } } + return database.WorkspaceProxy{}, sql.ErrNoRows +} - latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - latestAgentStats[agentStat.AgentID] = agentStat +func (q *fakeQuerier) GetWorkspaceProxyByName(_ context.Context, name string) (database.WorkspaceProxy, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, proxy := range q.workspaceProxies { + if proxy.Deleted { + continue + } + if proxy.Name == name { + return proxy, nil } } + return database.WorkspaceProxy{}, sql.ErrNoRows +} - statByAgent := map[uuid.UUID]database.GetWorkspaceAgentStatsRow{} - for _, agentStat := range latestAgentStats { - stat := statByAgent[agentStat.AgentID] - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - statByAgent[stat.AgentID] = stat - } - - latenciesByAgent := map[uuid.UUID][]float64{} - minimumDateByAgent := map[uuid.UUID]time.Time{} - for _, agentStat := range agentStatsCreatedAfter { - if agentStat.ConnectionMedianLatencyMS <= 0 { - continue - } - stat := statByAgent[agentStat.AgentID] - minimumDate := minimumDateByAgent[agentStat.AgentID] - if agentStat.CreatedAt.Before(minimumDate) || minimumDate.IsZero() { - minimumDateByAgent[agentStat.AgentID] = agentStat.CreatedAt - } - stat.WorkspaceRxBytes += agentStat.RxBytes - stat.WorkspaceTxBytes += agentStat.TxBytes - statByAgent[agentStat.AgentID] = stat - latenciesByAgent[agentStat.AgentID] = append(latenciesByAgent[agentStat.AgentID], agentStat.ConnectionMedianLatencyMS) - } - - tryPercentile := func(fs []float64, p float64) float64 { - if len(fs) == 0 { - return -1 - } - sort.Float64s(fs) - return fs[int(float64(len(fs))*p/100)] - } - - for _, stat := range statByAgent { - stat.AggregatedFrom = minimumDateByAgent[stat.AgentID] - 
statByAgent[stat.AgentID] = stat - - latencies, ok := latenciesByAgent[stat.AgentID] - if !ok { - continue - } - stat.WorkspaceConnectionLatency50 = tryPercentile(latencies, 50) - stat.WorkspaceConnectionLatency95 = tryPercentile(latencies, 95) - statByAgent[stat.AgentID] = stat - } - - stats := make([]database.GetWorkspaceAgentStatsRow, 0, len(statByAgent)) - for _, agent := range statByAgent { - stats = append(stats, agent) - } - return stats, nil -} - -func (q *fakeQuerier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsAndLabelsRow, error) { +func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) (database.WorkspaceResource, error) { q.mutex.RLock() defer q.mutex.RUnlock() - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} - - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) - latestAgentStats[agentStat.AgentID] = agentStat - } - } - - statByAgent := map[uuid.UUID]database.GetWorkspaceAgentStatsAndLabelsRow{} - - // Session and connection metrics - for _, agentStat := range latestAgentStats { - stat := statByAgent[agentStat.AgentID] - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - stat.ConnectionCount += agentStat.ConnectionCount - if agentStat.ConnectionMedianLatencyMS >= 0 && stat.ConnectionMedianLatencyMS < agentStat.ConnectionMedianLatencyMS { - stat.ConnectionMedianLatencyMS = agentStat.ConnectionMedianLatencyMS - } - statByAgent[agentStat.AgentID] = stat - } - - // Tx, Rx metrics - for _, agentStat := range agentStatsCreatedAfter { - stat := statByAgent[agentStat.AgentID] - stat.RxBytes += agentStat.RxBytes - stat.TxBytes += agentStat.TxBytes - statByAgent[agentStat.AgentID] = stat - } - - // Labels - for _, agentStat := range agentStatsCreatedAfter { - stat := statByAgent[agentStat.AgentID] - - user, err := q.getUserByIDNoLock(agentStat.UserID) - if err != nil { - return nil, err - } - - stat.Username = user.Username - - workspace, err := q.getWorkspaceByIDNoLock(ctx, agentStat.WorkspaceID) - if err != nil { - return nil, err - } - stat.WorkspaceName = workspace.Name - - agent, err := q.getWorkspaceAgentByIDNoLock(ctx, agentStat.AgentID) - if err != nil { - return nil, err + for _, resource := range q.workspaceResources { + if resource.ID == id { + return resource, nil } - stat.AgentName = agent.Name - - statByAgent[agentStat.AgentID] = stat - } - - stats := make([]database.GetWorkspaceAgentStatsAndLabelsRow, 0, len(statByAgent)) - for _, agent := range statByAgent { - stats = append(stats, agent) } - return stats, nil -} - -func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) + return database.WorkspaceResource{}, sql.ErrNoRows } -func (q *fakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceAgent, error) { +func (q *fakeQuerier) GetWorkspaceResourceMetadataByResourceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceResourceMetadatum, error) { 
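For orientation, the tryPercentile closure in the agent-stats block shown above sorts the samples and indexes at len(fs)*p/100, returning -1 for an empty slice. A standalone sketch with made-up latency values, only to show how the index works out:

package main

import (
	"fmt"
	"sort"
)

// tryPercentile follows the closure shown in the diff above: sort the
// samples, then index at len*p/100; an empty slice yields -1.
func tryPercentile(fs []float64, p float64) float64 {
	if len(fs) == 0 {
		return -1
	}
	sort.Float64s(fs)
	return fs[int(float64(len(fs))*p/100)]
}

func main() {
	latencies := []float64{40, 10, 30, 20} // hypothetical median latencies in ms
	fmt.Println(tryPercentile(latencies, 50)) // index 2 of [10 20 30 40] -> 30
	fmt.Println(tryPercentile(latencies, 95)) // index 3 -> 40
	fmt.Println(tryPercentile(nil, 50))       // -1
}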
q.mutex.RLock() defer q.mutex.RUnlock() - workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.workspaceAgents { - if agent.CreatedAt.After(after) { - workspaceAgents = append(workspaceAgents, agent) + metadata := make([]database.WorkspaceResourceMetadatum, 0) + for _, metadatum := range q.workspaceResourceMetadata { + for _, id := range ids { + if metadatum.WorkspaceResourceID == id { + metadata = append(metadata, metadatum) + } } } - return workspaceAgents, nil + return metadata, nil } -func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Get latest build for workspace. - workspaceBuild, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspaceID) - if err != nil { - return nil, xerrors.Errorf("get latest workspace build: %w", err) - } - - // Get resources for build. - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, workspaceBuild.JobID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - if len(resources) == 0 { - return []database.WorkspaceAgent{}, nil - } - - resourceIDs := make([]uuid.UUID, len(resources)) - for i, resource := range resources { - resourceIDs[i] = resource.ID - } - - agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) +func (q *fakeQuerier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, after time.Time) ([]database.WorkspaceResourceMetadatum, error) { + resources, err := q.GetWorkspaceResourcesCreatedAfter(ctx, after) if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) + return nil, err } - - return agents, nil -} - -func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndSlug(_ context.Context, arg database.GetWorkspaceAppByAgentIDAndSlugParams) (database.WorkspaceApp, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceApp{}, err + resourceIDs := map[uuid.UUID]struct{}{} + for _, resource := range resources { + resourceIDs[resource.ID] = struct{}{} } q.mutex.RLock() defer q.mutex.RUnlock() - for _, app := range q.workspaceApps { - if app.AgentID != arg.AgentID { - continue - } - if app.Slug != arg.Slug { + metadata := make([]database.WorkspaceResourceMetadatum, 0) + for _, m := range q.workspaceResourceMetadata { + _, ok := resourceIDs[m.WorkspaceResourceID] + if !ok { continue } - return app, nil + metadata = append(metadata, m) } - return database.WorkspaceApp{}, sql.ErrNoRows + return metadata, nil } -func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) { +func (q *fakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { q.mutex.RLock() defer q.mutex.RUnlock() - apps := make([]database.WorkspaceApp, 0) - for _, app := range q.workspaceApps { - if app.AgentID == id { - apps = append(apps, app) - } - } - if len(apps) == 0 { - return nil, sql.ErrNoRows - } - return apps, nil + return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID) } -func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) { +func (q *fakeQuerier) GetWorkspaceResourcesByJobIDs(_ context.Context, jobIDs []uuid.UUID) ([]database.WorkspaceResource, error) { q.mutex.RLock() defer q.mutex.RUnlock() - apps := make([]database.WorkspaceApp, 0) - for _, app := range q.workspaceApps { - for _, id := range ids { 
- if app.AgentID == id { - apps = append(apps, app) - break + resources := make([]database.WorkspaceResource, 0) + for _, resource := range q.workspaceResources { + for _, jobID := range jobIDs { + if resource.JobID != jobID { + continue } + resources = append(resources, resource) } } - return apps, nil + return resources, nil } -func (q *fakeQuerier) GetWorkspaceAppsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceApp, error) { +func (q *fakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceResource, error) { q.mutex.RLock() defer q.mutex.RUnlock() - apps := make([]database.WorkspaceApp, 0) - for _, app := range q.workspaceApps { - if app.CreatedAt.After(after) { - apps = append(apps, app) + resources := make([]database.WorkspaceResource, 0) + for _, resource := range q.workspaceResources { + if resource.CreatedAt.After(after) { + resources = append(resources, resource) } } - return apps, nil + return resources, nil } -func (q *fakeQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.GetWorkspacesRow, error) { + if err := validateDatabaseType(arg); err != nil { + return nil, err + } - return q.getWorkspaceBuildByIDNoLock(ctx, id) + // A nil auth filter means no auth filter. + workspaceRows, err := q.GetAuthorizedWorkspaces(ctx, arg, nil) + return workspaceRows, err } -func (q *fakeQuerier) GetWorkspaceBuildByJobID(_ context.Context, jobID uuid.UUID) (database.WorkspaceBuild, error) { +func (q *fakeQuerier) GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() - for _, build := range q.workspaceBuilds { - if build.JobID == jobID { - return build, nil + workspaces := []database.Workspace{} + for _, workspace := range q.workspaces { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) + if err != nil { + return nil, err } - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} -func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(_ context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (database.WorkspaceBuild, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceBuild{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.WorkspaceID != arg.WorkspaceID { + if build.Transition == database.WorkspaceTransitionStart && + !build.Deadline.IsZero() && + build.Deadline.Before(now) && + !workspace.LockedAt.Valid { + workspaces = append(workspaces, workspace) continue } - if workspaceBuild.BuildNumber != arg.BuildNumber { + + if build.Transition == database.WorkspaceTransitionStop && + workspace.AutostartSchedule.Valid && + !workspace.LockedAt.Valid { + workspaces = append(workspaces, workspace) continue } - return workspaceBuild, nil - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} -func (q *fakeQuerier) GetWorkspaceBuildParameters(_ context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - params := make([]database.WorkspaceBuildParameter, 0) - for _, param := range q.workspaceBuildParameters { - if param.WorkspaceBuildID != workspaceBuildID { + job, err := 
q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err != nil { + return nil, xerrors.Errorf("get provisioner job by ID: %w", err) + } + if db2sdk.ProvisionerJobStatus(job) == codersdk.ProvisionerJobFailed { + workspaces = append(workspaces, workspace) continue } - params = append(params, param) - } - return params, nil -} - -func (q *fakeQuerier) GetWorkspaceBuildsByWorkspaceID(_ context.Context, - params database.GetWorkspaceBuildsByWorkspaceIDParams, -) ([]database.WorkspaceBuild, error) { - if err := validateDatabaseType(params); err != nil { - return nil, err - } - q.mutex.RLock() - defer q.mutex.RUnlock() - - history := make([]database.WorkspaceBuild, 0) - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.CreatedAt.Before(params.Since) { + template, err := q.GetTemplateByID(ctx, workspace.TemplateID) + if err != nil { + return nil, xerrors.Errorf("get template by ID: %w", err) + } + if !workspace.LockedAt.Valid && template.InactivityTTL > 0 { + workspaces = append(workspaces, workspace) continue } - if workspaceBuild.WorkspaceID == params.WorkspaceID { - history = append(history, workspaceBuild) + if workspace.LockedAt.Valid && template.LockedTTL > 0 { + workspaces = append(workspaces, workspace) + continue } } - // Order by build_number - slices.SortFunc(history, func(a, b database.WorkspaceBuild) bool { - // use greater than since we want descending order - return a.BuildNumber > b.BuildNumber - }) - - if params.AfterID != uuid.Nil { - found := false - for i, v := range history { - if v.ID == params.AfterID { - // We want to return all builds after index i. - history = history[i+1:] - found = true - break - } - } + return workspaces, nil +} - // If no builds after the time, then we return an empty list. - if !found { - return nil, sql.ErrNoRows - } +func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { + if err := validateDatabaseType(arg); err != nil { + return database.APIKey{}, err } - if params.OffsetOpt > 0 { - if int(params.OffsetOpt) > len(history)-1 { - return nil, sql.ErrNoRows - } - history = history[params.OffsetOpt:] + q.mutex.Lock() + defer q.mutex.Unlock() + + if arg.LifetimeSeconds == 0 { + arg.LifetimeSeconds = 86400 } - if params.LimitOpt > 0 { - if int(params.LimitOpt) > len(history) { - params.LimitOpt = int32(len(history)) + for _, u := range q.users { + if u.ID == arg.UserID && u.Deleted { + return database.APIKey{}, xerrors.Errorf("refusing to create APIKey for deleted user") } - history = history[:params.LimitOpt] } - if len(history) == 0 { - return nil, sql.ErrNoRows + //nolint:gosimple + key := database.APIKey{ + ID: arg.ID, + LifetimeSeconds: arg.LifetimeSeconds, + HashedSecret: arg.HashedSecret, + IPAddress: arg.IPAddress, + UserID: arg.UserID, + ExpiresAt: arg.ExpiresAt, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + LastUsed: arg.LastUsed, + LoginType: arg.LoginType, + Scope: arg.Scope, + TokenName: arg.TokenName, } - return history, nil + q.apiKeys = append(q.apiKeys, key) + return key, nil } -func (q *fakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *fakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID) (database.Group, error) { + return q.InsertGroup(ctx, database.InsertGroupParams{ + ID: orgID, + Name: database.AllUsersGroup, + OrganizationID: orgID, + }) +} - workspaceBuilds := make([]database.WorkspaceBuild, 0) - for _, 
workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.CreatedAt.After(after) { - workspaceBuilds = append(workspaceBuilds, workspaceBuild) - } +func (q *fakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) { + if err := validateDatabaseType(arg); err != nil { + return database.AuditLog{}, err } - return workspaceBuilds, nil + + q.mutex.Lock() + defer q.mutex.Unlock() + + alog := database.AuditLog(arg) + + q.auditLogs = append(q.auditLogs, alog) + slices.SortFunc(q.auditLogs, func(a, b database.AuditLog) bool { + return a.Time.Before(b.Time) + }) + + return alog, nil } -func (q *fakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *fakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error { + q.mutex.Lock() + defer q.mutex.Unlock() - return q.getWorkspaceByAgentIDNoLock(ctx, agentID) + q.derpMeshKey = id + return nil } -func (q *fakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *fakeQuerier) InsertDeploymentID(_ context.Context, id string) error { + q.mutex.Lock() + defer q.mutex.Unlock() - return q.getWorkspaceByIDNoLock(ctx, id) + q.deploymentID = id + return nil } -func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { +func (q *fakeQuerier) InsertFile(_ context.Context, arg database.InsertFileParams) (database.File, error) { if err := validateDatabaseType(arg); err != nil { - return database.Workspace{}, err + return database.File{}, err } - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() - var found *database.Workspace - for _, workspace := range q.workspaces { - workspace := workspace - if workspace.OwnerID != arg.OwnerID { - continue - } - if !strings.EqualFold(workspace.Name, arg.Name) { - continue - } - if workspace.Deleted != arg.Deleted { - continue - } + //nolint:gosimple + file := database.File{ + ID: arg.ID, + Hash: arg.Hash, + CreatedAt: arg.CreatedAt, + CreatedBy: arg.CreatedBy, + Mimetype: arg.Mimetype, + Data: arg.Data, + } + q.files = append(q.files, file) + return file, nil +} - // Return the most recent workspace with the given name - if found == nil || workspace.CreatedAt.After(found.CreatedAt) { - found = &workspace - } +func (q *fakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGitAuthLinkParams) (database.GitAuthLink, error) { + if err := validateDatabaseType(arg); err != nil { + return database.GitAuthLink{}, err } - if found != nil { - return *found, nil + + q.mutex.Lock() + defer q.mutex.Unlock() + // nolint:gosimple + gitAuthLink := database.GitAuthLink{ + ProviderID: arg.ProviderID, + UserID: arg.UserID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + OAuthAccessToken: arg.OAuthAccessToken, + OAuthRefreshToken: arg.OAuthRefreshToken, + OAuthExpiry: arg.OAuthExpiry, } - return database.Workspace{}, sql.ErrNoRows + q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink) + return gitAuthLink, nil } -func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { - if err := validateDatabaseType(workspaceAppID); err != nil { - return database.Workspace{}, err +func (q *fakeQuerier) InsertGitSSHKey(_ context.Context, arg database.InsertGitSSHKeyParams) (database.GitSSHKey, error) 
{ + if err := validateDatabaseType(arg); err != nil { + return database.GitSSHKey{}, err } - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() - for _, workspaceApp := range q.workspaceApps { - workspaceApp := workspaceApp - if workspaceApp.ID == workspaceAppID { - return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID) - } + //nolint:gosimple + gitSSHKey := database.GitSSHKey{ + UserID: arg.UserID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + PrivateKey: arg.PrivateKey, + PublicKey: arg.PublicKey, } - return database.Workspace{}, sql.ErrNoRows + q.gitSSHKey = append(q.gitSSHKey, gitSSHKey) + return gitSSHKey, nil } -func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *fakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupParams) (database.Group, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Group{}, err + } - cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies)) + q.mutex.Lock() + defer q.mutex.Unlock() - for _, p := range q.workspaceProxies { - if !p.Deleted { - cpy = append(cpy, p) + for _, group := range q.groups { + if group.OrganizationID == arg.OrganizationID && + group.Name == arg.Name { + return database.Group{}, errDuplicateKey } } - return cpy, nil -} - -func (q *fakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, params database.GetWorkspaceProxyByHostnameParams) (database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - // Return zero rows if this is called with a non-sanitized hostname. The SQL - // version of this query does the same thing. - if !validProxyByHostnameRegex.MatchString(params.Hostname) { - return database.WorkspaceProxy{}, sql.ErrNoRows + //nolint:gosimple + group := database.Group{ + ID: arg.ID, + Name: arg.Name, + OrganizationID: arg.OrganizationID, + AvatarURL: arg.AvatarURL, + QuotaAllowance: arg.QuotaAllowance, } - // This regex matches the SQL version. - accessURLRegex := regexp.MustCompile(`[^:]*://` + regexp.QuoteMeta(params.Hostname) + `([:/]?.)*`) + q.groups = append(q.groups, group) - for _, proxy := range q.workspaceProxies { - if proxy.Deleted { - continue - } - if params.AllowAccessUrl && accessURLRegex.MatchString(proxy.Url) { - return proxy, nil - } + return group, nil +} - // Compile the app hostname regex. This is slow sadly. 
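The wildcard branch that follows in this hunk relies on httpapi.CompileHostnamePattern and httpapi.ExecuteHostnamePattern from the coder codebase. The sketch below is not that implementation; it is only a rough stdlib illustration of the general idea (turn a "*."-prefixed hostname into a regexp with one capture group for the subdomain), and the pattern and hostnames are hypothetical.

package main

import (
	"fmt"
	"regexp"
	"strings"
)

// compileWildcard is an illustrative stand-in for the idea behind a wildcard
// hostname pattern: "*.proxies.example.com" becomes a regexp whose single
// capture group is the part matched by "*".
func compileWildcard(pattern string) (*regexp.Regexp, error) {
	if !strings.HasPrefix(pattern, "*.") {
		return nil, fmt.Errorf("pattern %q must start with %q", pattern, "*.")
	}
	suffix := regexp.QuoteMeta(pattern[1:]) // ".proxies.example.com"
	return regexp.Compile(`^([^.]+)` + suffix + `$`)
}

func main() {
	re, err := compileWildcard("*.proxies.example.com") // hypothetical wildcard hostname
	if err != nil {
		panic(err)
	}
	for _, host := range []string{"app1.proxies.example.com", "example.com"} {
		if m := re.FindStringSubmatch(host); m != nil {
			fmt.Printf("%s -> subdomain %q\n", host, m[1])
		} else {
			fmt.Printf("%s -> no match\n", host)
		}
	}
}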
- if params.AllowWildcardHostname { - wildcardRegexp, err := httpapi.CompileHostnamePattern(proxy.WildcardHostname) - if err != nil { - return database.WorkspaceProxy{}, xerrors.Errorf("compile hostname pattern %q for proxy %q (%s): %w", proxy.WildcardHostname, proxy.Name, proxy.ID.String(), err) - } - if _, ok := httpapi.ExecuteHostnamePattern(wildcardRegexp, params.Hostname); ok { - return proxy, nil - } - } +func (q *fakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGroupMemberParams) error { + if err := validateDatabaseType(arg); err != nil { + return err } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() - for _, proxy := range q.workspaceProxies { - if proxy.ID == id { - return proxy, nil + for _, member := range q.groupMembers { + if member.GroupID == arg.GroupID && + member.UserID == arg.UserID { + return errDuplicateKey } } - return database.WorkspaceProxy{}, sql.ErrNoRows + + //nolint:gosimple + q.groupMembers = append(q.groupMembers, database.GroupMember{ + GroupID: arg.GroupID, + UserID: arg.UserID, + }) + + return nil } -func (q *fakeQuerier) GetWorkspaceProxyByName(_ context.Context, name string) (database.WorkspaceProxy, error) { +func (q *fakeQuerier) InsertLicense( + _ context.Context, arg database.InsertLicenseParams, +) (database.License, error) { + if err := validateDatabaseType(arg); err != nil { + return database.License{}, err + } + q.mutex.Lock() defer q.mutex.Unlock() - for _, proxy := range q.workspaceProxies { - if proxy.Deleted { - continue - } - if proxy.Name == name { - return proxy, nil - } + l := database.License{ + ID: q.lastLicenseID + 1, + UploadedAt: arg.UploadedAt, + JWT: arg.JWT, + Exp: arg.Exp, } - return database.WorkspaceProxy{}, sql.ErrNoRows + q.lastLicenseID = l.ID + q.licenses = append(q.licenses, l) + return l, nil } -func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) (database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, resource := range q.workspaceResources { - if resource.ID == id { - return resource, nil - } +func (q *fakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Organization{}, err } - return database.WorkspaceResource{}, sql.ErrNoRows -} -func (q *fakeQuerier) GetWorkspaceResourceMetadataByResourceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceResourceMetadatum, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() - metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, metadatum := range q.workspaceResourceMetadata { - for _, id := range ids { - if metadatum.WorkspaceResourceID == id { - metadata = append(metadata, metadatum) - } - } + organization := database.Organization{ + ID: arg.ID, + Name: arg.Name, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, } - return metadata, nil + q.organizations = append(q.organizations, organization) + return organization, nil } -func (q *fakeQuerier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, after time.Time) ([]database.WorkspaceResourceMetadatum, error) { - resources, err := q.GetWorkspaceResourcesCreatedAfter(ctx, after) - if err != nil { - return nil, err - } - resourceIDs 
:= map[uuid.UUID]struct{}{} - for _, resource := range resources { - resourceIDs[resource.ID] = struct{}{} +func (q *fakeQuerier) InsertOrganizationMember(_ context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { + if err := validateDatabaseType(arg); err != nil { + return database.OrganizationMember{}, err } - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() - metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, m := range q.workspaceResourceMetadata { - _, ok := resourceIDs[m.WorkspaceResourceID] - if !ok { - continue - } - metadata = append(metadata, m) + //nolint:gosimple + organizationMember := database.OrganizationMember{ + OrganizationID: arg.OrganizationID, + UserID: arg.UserID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Roles: arg.Roles, } - return metadata, nil + q.organizationMembers = append(q.organizationMembers, organizationMember) + return organizationMember, nil } -func (q *fakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID) -} +func (q *fakeQuerier) InsertProvisionerDaemon(_ context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { + if err := validateDatabaseType(arg); err != nil { + return database.ProvisionerDaemon{}, err + } -func (q *fakeQuerier) GetWorkspaceResourcesByJobIDs(_ context.Context, jobIDs []uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - for _, jobID := range jobIDs { - if resource.JobID != jobID { - continue - } - resources = append(resources, resource) - } + daemon := database.ProvisionerDaemon{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + Name: arg.Name, + Provisioners: arg.Provisioners, + Tags: arg.Tags, } - return resources, nil + q.provisionerDaemons = append(q.provisionerDaemons, daemon) + return daemon, nil } -func (q *fakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *fakeQuerier) InsertProvisionerJob(_ context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { + if err := validateDatabaseType(arg); err != nil { + return database.ProvisionerJob{}, err + } - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - if resource.CreatedAt.After(after) { - resources = append(resources, resource) - } + q.mutex.Lock() + defer q.mutex.Unlock() + + job := database.ProvisionerJob{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + OrganizationID: arg.OrganizationID, + InitiatorID: arg.InitiatorID, + Provisioner: arg.Provisioner, + StorageMethod: arg.StorageMethod, + FileID: arg.FileID, + Type: arg.Type, + Input: arg.Input, + Tags: arg.Tags, } - return resources, nil + q.provisionerJobs = append(q.provisionerJobs, job) + return job, nil } -func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.GetWorkspacesRow, error) { +func (q *fakeQuerier) InsertProvisionerJobLogs(_ context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) { if err := 
validateDatabaseType(arg); err != nil { return nil, err } - // A nil auth filter means no auth filter. - workspaceRows, err := q.GetAuthorizedWorkspaces(ctx, arg, nil) - return workspaceRows, err -} - -func (q *fakeQuerier) GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaces := []database.Workspace{} - for _, workspace := range q.workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, err - } + q.mutex.Lock() + defer q.mutex.Unlock() - if build.Transition == database.WorkspaceTransitionStart && - !build.Deadline.IsZero() && - build.Deadline.Before(now) && - !workspace.LockedAt.Valid { - workspaces = append(workspaces, workspace) - continue - } - - if build.Transition == database.WorkspaceTransitionStop && - workspace.AutostartSchedule.Valid && - !workspace.LockedAt.Valid { - workspaces = append(workspaces, workspace) - continue - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job by ID: %w", err) - } - if db2sdk.ProvisionerJobStatus(job) == codersdk.ProvisionerJobFailed { - workspaces = append(workspaces, workspace) - continue - } - - template, err := q.GetTemplateByID(ctx, workspace.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template by ID: %w", err) - } - if !workspace.LockedAt.Valid && template.InactivityTTL > 0 { - workspaces = append(workspaces, workspace) - continue - } - if workspace.LockedAt.Valid && template.LockedTTL > 0 { - workspaces = append(workspaces, workspace) - continue - } + logs := make([]database.ProvisionerJobLog, 0) + id := int64(1) + if len(q.provisionerJobLogs) > 0 { + id = q.provisionerJobLogs[len(q.provisionerJobLogs)-1].ID } - - return workspaces, nil + for index, output := range arg.Output { + id++ + logs = append(logs, database.ProvisionerJobLog{ + ID: id, + JobID: arg.JobID, + CreatedAt: arg.CreatedAt[index], + Source: arg.Source[index], + Level: arg.Level[index], + Stage: arg.Stage[index], + Output: output, + }) + } + q.provisionerJobLogs = append(q.provisionerJobLogs, logs...) 
+ return logs, nil } -func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { +func (q *fakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) { if err := validateDatabaseType(arg); err != nil { - return database.APIKey{}, err + return database.Replica{}, err } q.mutex.Lock() defer q.mutex.Unlock() - if arg.LifetimeSeconds == 0 { - arg.LifetimeSeconds = 86400 - } - - for _, u := range q.users { - if u.ID == arg.UserID && u.Deleted { - return database.APIKey{}, xerrors.Errorf("refusing to create APIKey for deleted user") - } - } - - //nolint:gosimple - key := database.APIKey{ + replica := database.Replica{ ID: arg.ID, - LifetimeSeconds: arg.LifetimeSeconds, - HashedSecret: arg.HashedSecret, - IPAddress: arg.IPAddress, - UserID: arg.UserID, - ExpiresAt: arg.ExpiresAt, CreatedAt: arg.CreatedAt, + StartedAt: arg.StartedAt, UpdatedAt: arg.UpdatedAt, - LastUsed: arg.LastUsed, - LoginType: arg.LoginType, - Scope: arg.Scope, - TokenName: arg.TokenName, + Hostname: arg.Hostname, + RegionID: arg.RegionID, + RelayAddress: arg.RelayAddress, + Version: arg.Version, + DatabaseLatency: arg.DatabaseLatency, } - q.apiKeys = append(q.apiKeys, key) - return key, nil -} - -func (q *fakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID) (database.Group, error) { - return q.InsertGroup(ctx, database.InsertGroupParams{ - ID: orgID, - Name: database.AllUsersGroup, - OrganizationID: orgID, - }) + q.replicas = append(q.replicas, replica) + return replica, nil } -func (q *fakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) { +func (q *fakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTemplateParams) (database.Template, error) { if err := validateDatabaseType(arg); err != nil { - return database.AuditLog{}, err + return database.Template{}, err } q.mutex.Lock() defer q.mutex.Unlock() - alog := database.AuditLog(arg) - - q.auditLogs = append(q.auditLogs, alog) - slices.SortFunc(q.auditLogs, func(a, b database.AuditLog) bool { - return a.Time.Before(b.Time) - }) - - return alog, nil -} - -func (q *fakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.derpMeshKey = id - return nil -} - -func (q *fakeQuerier) InsertDeploymentID(_ context.Context, id string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.deploymentID = id - return nil + //nolint:gosimple + template := database.Template{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + OrganizationID: arg.OrganizationID, + Name: arg.Name, + Provisioner: arg.Provisioner, + ActiveVersionID: arg.ActiveVersionID, + Description: arg.Description, + CreatedBy: arg.CreatedBy, + UserACL: arg.UserACL, + GroupACL: arg.GroupACL, + DisplayName: arg.DisplayName, + Icon: arg.Icon, + AllowUserCancelWorkspaceJobs: arg.AllowUserCancelWorkspaceJobs, + AllowUserAutostart: true, + AllowUserAutostop: true, + } + q.templates = append(q.templates, template) + return template.DeepCopy(), nil } -func (q *fakeQuerier) InsertFile(_ context.Context, arg database.InsertFileParams) (database.File, error) { +func (q *fakeQuerier) InsertTemplateVersion(_ context.Context, arg database.InsertTemplateVersionParams) (database.TemplateVersion, error) { if err := validateDatabaseType(arg); err != nil { - return database.File{}, err + return database.TemplateVersion{}, err + } + + if len(arg.Message) > 1048576 { + return 
database.TemplateVersion{}, xerrors.New("message too long") } q.mutex.Lock() defer q.mutex.Unlock() //nolint:gosimple - file := database.File{ - ID: arg.ID, - Hash: arg.Hash, - CreatedAt: arg.CreatedAt, - CreatedBy: arg.CreatedBy, - Mimetype: arg.Mimetype, - Data: arg.Data, + version := database.TemplateVersion{ + ID: arg.ID, + TemplateID: arg.TemplateID, + OrganizationID: arg.OrganizationID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Name: arg.Name, + Message: arg.Message, + Readme: arg.Readme, + JobID: arg.JobID, + CreatedBy: arg.CreatedBy, } - q.files = append(q.files, file) - return file, nil + q.templateVersions = append(q.templateVersions, version) + return version, nil } -func (q *fakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGitAuthLinkParams) (database.GitAuthLink, error) { +func (q *fakeQuerier) InsertTemplateVersionParameter(_ context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) { if err := validateDatabaseType(arg); err != nil { - return database.GitAuthLink{}, err + return database.TemplateVersionParameter{}, err } q.mutex.Lock() defer q.mutex.Unlock() - // nolint:gosimple - gitAuthLink := database.GitAuthLink{ - ProviderID: arg.ProviderID, - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OAuthAccessToken: arg.OAuthAccessToken, - OAuthRefreshToken: arg.OAuthRefreshToken, - OAuthExpiry: arg.OAuthExpiry, + + //nolint:gosimple + param := database.TemplateVersionParameter{ + TemplateVersionID: arg.TemplateVersionID, + Name: arg.Name, + DisplayName: arg.DisplayName, + Description: arg.Description, + Type: arg.Type, + Mutable: arg.Mutable, + DefaultValue: arg.DefaultValue, + Icon: arg.Icon, + Options: arg.Options, + ValidationError: arg.ValidationError, + ValidationRegex: arg.ValidationRegex, + ValidationMin: arg.ValidationMin, + ValidationMax: arg.ValidationMax, + ValidationMonotonic: arg.ValidationMonotonic, + Required: arg.Required, + DisplayOrder: arg.DisplayOrder, + Ephemeral: arg.Ephemeral, } - q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink) - return gitAuthLink, nil + q.templateVersionParameters = append(q.templateVersionParameters, param) + return param, nil } -func (q *fakeQuerier) InsertGitSSHKey(_ context.Context, arg database.InsertGitSSHKeyParams) (database.GitSSHKey, error) { +func (q *fakeQuerier) InsertTemplateVersionVariable(_ context.Context, arg database.InsertTemplateVersionVariableParams) (database.TemplateVersionVariable, error) { if err := validateDatabaseType(arg); err != nil { - return database.GitSSHKey{}, err + return database.TemplateVersionVariable{}, err } q.mutex.Lock() defer q.mutex.Unlock() //nolint:gosimple - gitSSHKey := database.GitSSHKey{ - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - PrivateKey: arg.PrivateKey, - PublicKey: arg.PublicKey, + variable := database.TemplateVersionVariable{ + TemplateVersionID: arg.TemplateVersionID, + Name: arg.Name, + Description: arg.Description, + Type: arg.Type, + Value: arg.Value, + DefaultValue: arg.DefaultValue, + Required: arg.Required, + Sensitive: arg.Sensitive, } - q.gitSSHKey = append(q.gitSSHKey, gitSSHKey) - return gitSSHKey, nil + q.templateVersionVariables = append(q.templateVersionVariables, variable) + return variable, nil } -func (q *fakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupParams) (database.Group, error) { +func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, 
error) { if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err + return database.User{}, err + } + + // There is a common bug when using dbfake that 2 inserted users have the + // same created_at time. This causes user order to not be deterministic, + // which breaks some unit tests. + // To fix this, we make sure that the created_at time is always greater + // than the last user's created_at time. + allUsers, _ := q.GetUsers(context.Background(), database.GetUsersParams{}) + if len(allUsers) > 0 { + lastUser := allUsers[len(allUsers)-1] + if arg.CreatedAt.Before(lastUser.CreatedAt) || + arg.CreatedAt.Equal(lastUser.CreatedAt) { + // 1 ms is a good enough buffer. + arg.CreatedAt = lastUser.CreatedAt.Add(time.Millisecond) + } } q.mutex.Lock() defer q.mutex.Unlock() - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return database.Group{}, errDuplicateKey + for _, user := range q.users { + if user.Username == arg.Username && !user.Deleted { + return database.User{}, errDuplicateKey } } - //nolint:gosimple - group := database.Group{ + user := database.User{ ID: arg.ID, - Name: arg.Name, - OrganizationID: arg.OrganizationID, - AvatarURL: arg.AvatarURL, - QuotaAllowance: arg.QuotaAllowance, + Email: arg.Email, + HashedPassword: arg.HashedPassword, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Username: arg.Username, + Status: database.UserStatusActive, + RBACRoles: arg.RBACRoles, + LoginType: arg.LoginType, } - - q.groups = append(q.groups, group) - - return group, nil + q.users = append(q.users, user) + return user, nil } -func (q *fakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGroupMemberParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - +func (q *fakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { q.mutex.Lock() defer q.mutex.Unlock() - for _, member := range q.groupMembers { - if member.GroupID == arg.GroupID && - member.UserID == arg.UserID { - return errDuplicateKey + var groupIDs []uuid.UUID + for _, group := range q.groups { + for _, groupName := range arg.GroupNames { + if group.Name == groupName { + groupIDs = append(groupIDs, group.ID) + } } } - //nolint:gosimple - q.groupMembers = append(q.groupMembers, database.GroupMember{ - GroupID: arg.GroupID, - UserID: arg.UserID, - }) + for _, groupID := range groupIDs { + q.groupMembers = append(q.groupMembers, database.GroupMember{ + UserID: arg.UserID, + GroupID: groupID, + }) + } return nil } -func (q *fakeQuerier) InsertLicense( - _ context.Context, arg database.InsertLicenseParams, -) (database.License, error) { - if err := validateDatabaseType(arg); err != nil { - return database.License{}, err - } - +func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) { q.mutex.Lock() defer q.mutex.Unlock() - l := database.License{ - ID: q.lastLicenseID + 1, - UploadedAt: arg.UploadedAt, - JWT: arg.JWT, - Exp: arg.Exp, - } - q.lastLicenseID = l.ID - q.licenses = append(q.licenses, l) - return l, nil -} - -func (q *fakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Organization{}, err + //nolint:gosimple + link := database.UserLink{ + UserID: args.UserID, + LoginType: args.LoginType, + LinkedID: args.LinkedID, + OAuthAccessToken: 
args.OAuthAccessToken, + OAuthRefreshToken: args.OAuthRefreshToken, + OAuthExpiry: args.OAuthExpiry, } - q.mutex.Lock() - defer q.mutex.Unlock() + q.userLinks = append(q.userLinks, link) - organization := database.Organization{ - ID: arg.ID, - Name: arg.Name, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - } - q.organizations = append(q.organizations, organization) - return organization, nil + return link, nil } -func (q *fakeQuerier) InsertOrganizationMember(_ context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { +func (q *fakeQuerier) InsertWorkspace(_ context.Context, arg database.InsertWorkspaceParams) (database.Workspace, error) { if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err + return database.Workspace{}, err } q.mutex.Lock() defer q.mutex.Unlock() //nolint:gosimple - organizationMember := database.OrganizationMember{ - OrganizationID: arg.OrganizationID, - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Roles: arg.Roles, + workspace := database.Workspace{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + OwnerID: arg.OwnerID, + OrganizationID: arg.OrganizationID, + TemplateID: arg.TemplateID, + Name: arg.Name, + AutostartSchedule: arg.AutostartSchedule, + Ttl: arg.Ttl, + LastUsedAt: arg.LastUsedAt, } - q.organizationMembers = append(q.organizationMembers, organizationMember) - return organizationMember, nil + q.workspaces = append(q.workspaces, workspace) + return workspace, nil } -func (q *fakeQuerier) InsertProvisionerDaemon(_ context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { +func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerDaemon{}, err + return database.WorkspaceAgent{}, err } q.mutex.Lock() defer q.mutex.Unlock() - daemon := database.ProvisionerDaemon{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - Name: arg.Name, - Provisioners: arg.Provisioners, - Tags: arg.Tags, + agent := database.WorkspaceAgent{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + ResourceID: arg.ResourceID, + AuthToken: arg.AuthToken, + AuthInstanceID: arg.AuthInstanceID, + EnvironmentVariables: arg.EnvironmentVariables, + Name: arg.Name, + Architecture: arg.Architecture, + OperatingSystem: arg.OperatingSystem, + Directory: arg.Directory, + StartupScriptBehavior: arg.StartupScriptBehavior, + StartupScript: arg.StartupScript, + InstanceMetadata: arg.InstanceMetadata, + ResourceMetadata: arg.ResourceMetadata, + ConnectionTimeoutSeconds: arg.ConnectionTimeoutSeconds, + TroubleshootingURL: arg.TroubleshootingURL, + MOTDFile: arg.MOTDFile, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + ShutdownScript: arg.ShutdownScript, } - q.provisionerDaemons = append(q.provisionerDaemons, daemon) - return daemon, nil -} -func (q *fakeQuerier) InsertProvisionerJob(_ context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerJob{}, err - } + q.workspaceAgents = append(q.workspaceAgents, agent) + return agent, nil +} +func (q *fakeQuerier) InsertWorkspaceAgentMetadata(_ context.Context, arg database.InsertWorkspaceAgentMetadataParams) error { q.mutex.Lock() defer q.mutex.Unlock() - job := 
database.ProvisionerJob{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OrganizationID: arg.OrganizationID, - InitiatorID: arg.InitiatorID, - Provisioner: arg.Provisioner, - StorageMethod: arg.StorageMethod, - FileID: arg.FileID, - Type: arg.Type, - Input: arg.Input, - Tags: arg.Tags, + //nolint:gosimple + metadatum := database.WorkspaceAgentMetadatum{ + WorkspaceAgentID: arg.WorkspaceAgentID, + Script: arg.Script, + DisplayName: arg.DisplayName, + Key: arg.Key, + Timeout: arg.Timeout, + Interval: arg.Interval, } - q.provisionerJobs = append(q.provisionerJobs, job) - return job, nil + + q.workspaceAgentMetadata = append(q.workspaceAgentMetadata, metadatum) + return nil } -func (q *fakeQuerier) InsertProvisionerJobLogs(_ context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) { +func (q *fakeQuerier) InsertWorkspaceAgentStartupLogs(_ context.Context, arg database.InsertWorkspaceAgentStartupLogsParams) ([]database.WorkspaceAgentStartupLog, error) { if err := validateDatabaseType(arg); err != nil { return nil, err } @@ -3841,699 +3708,368 @@ func (q *fakeQuerier) InsertProvisionerJobLogs(_ context.Context, arg database.I q.mutex.Lock() defer q.mutex.Unlock() - logs := make([]database.ProvisionerJobLog, 0) - id := int64(1) - if len(q.provisionerJobLogs) > 0 { - id = q.provisionerJobLogs[len(q.provisionerJobLogs)-1].ID + logs := []database.WorkspaceAgentStartupLog{} + id := int64(0) + if len(q.workspaceAgentLogs) > 0 { + id = q.workspaceAgentLogs[len(q.workspaceAgentLogs)-1].ID } + outputLength := int32(0) for index, output := range arg.Output { id++ - logs = append(logs, database.ProvisionerJobLog{ + logs = append(logs, database.WorkspaceAgentStartupLog{ ID: id, - JobID: arg.JobID, + AgentID: arg.AgentID, CreatedAt: arg.CreatedAt[index], - Source: arg.Source[index], Level: arg.Level[index], - Stage: arg.Stage[index], Output: output, }) + outputLength += int32(len(output)) } - q.provisionerJobLogs = append(q.provisionerJobLogs, logs...) - return logs, nil -} - -func (q *fakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Replica{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - replica := database.Replica{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - StartedAt: arg.StartedAt, - UpdatedAt: arg.UpdatedAt, - Hostname: arg.Hostname, - RegionID: arg.RegionID, - RelayAddress: arg.RelayAddress, - Version: arg.Version, - DatabaseLatency: arg.DatabaseLatency, + for index, agent := range q.workspaceAgents { + if agent.ID != arg.AgentID { + continue + } + // Greater than 1MB, same as the PostgreSQL constraint! + if agent.StartupLogsLength+outputLength > (1 << 20) { + return nil, &pq.Error{ + Constraint: "max_startup_logs_length", + Table: "workspace_agents", + } + } + agent.StartupLogsLength += outputLength + q.workspaceAgents[index] = agent + break } - q.replicas = append(q.replicas, replica) - return replica, nil + q.workspaceAgentLogs = append(q.workspaceAgentLogs, logs...) 
+ return logs, nil } -func (q *fakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTemplateParams) (database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Template{}, err +func (q *fakeQuerier) InsertWorkspaceAgentStat(_ context.Context, p database.InsertWorkspaceAgentStatParams) (database.WorkspaceAgentStat, error) { + if err := validateDatabaseType(p); err != nil { + return database.WorkspaceAgentStat{}, err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - template := database.Template{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OrganizationID: arg.OrganizationID, - Name: arg.Name, - Provisioner: arg.Provisioner, - ActiveVersionID: arg.ActiveVersionID, - Description: arg.Description, - CreatedBy: arg.CreatedBy, - UserACL: arg.UserACL, - GroupACL: arg.GroupACL, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - AllowUserCancelWorkspaceJobs: arg.AllowUserCancelWorkspaceJobs, - AllowUserAutostart: true, - AllowUserAutostop: true, + stat := database.WorkspaceAgentStat{ + ID: p.ID, + CreatedAt: p.CreatedAt, + WorkspaceID: p.WorkspaceID, + AgentID: p.AgentID, + UserID: p.UserID, + ConnectionsByProto: p.ConnectionsByProto, + ConnectionCount: p.ConnectionCount, + RxPackets: p.RxPackets, + RxBytes: p.RxBytes, + TxPackets: p.TxPackets, + TxBytes: p.TxBytes, + TemplateID: p.TemplateID, + SessionCountVSCode: p.SessionCountVSCode, + SessionCountJetBrains: p.SessionCountJetBrains, + SessionCountReconnectingPTY: p.SessionCountReconnectingPTY, + SessionCountSSH: p.SessionCountSSH, + ConnectionMedianLatencyMS: p.ConnectionMedianLatencyMS, } - q.templates = append(q.templates, template) - return template.DeepCopy(), nil + q.workspaceAgentStats = append(q.workspaceAgentStats, stat) + return stat, nil } -func (q *fakeQuerier) InsertTemplateVersion(_ context.Context, arg database.InsertTemplateVersionParams) (database.TemplateVersion, error) { +func (q *fakeQuerier) InsertWorkspaceApp(_ context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) { if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersion{}, err - } - - if len(arg.Message) > 1048576 { - return database.TemplateVersion{}, xerrors.New("message too long") + return database.WorkspaceApp{}, err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - version := database.TemplateVersion{ - ID: arg.ID, - TemplateID: arg.TemplateID, - OrganizationID: arg.OrganizationID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Name: arg.Name, - Message: arg.Message, - Readme: arg.Readme, - JobID: arg.JobID, - CreatedBy: arg.CreatedBy, - } - q.templateVersions = append(q.templateVersions, version) - return version, nil -} - -func (q *fakeQuerier) InsertTemplateVersionParameter(_ context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) { - if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersionParameter{}, err + if arg.SharingLevel == "" { + arg.SharingLevel = database.AppSharingLevelOwner } - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - param := database.TemplateVersionParameter{ - TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Description: arg.Description, - Type: arg.Type, - Mutable: arg.Mutable, - DefaultValue: arg.DefaultValue, - Icon: arg.Icon, - Options: arg.Options, - ValidationError: arg.ValidationError, - ValidationRegex: arg.ValidationRegex, 
- ValidationMin: arg.ValidationMin, - ValidationMax: arg.ValidationMax, - ValidationMonotonic: arg.ValidationMonotonic, - Required: arg.Required, - DisplayOrder: arg.DisplayOrder, - Ephemeral: arg.Ephemeral, + // nolint:gosimple + workspaceApp := database.WorkspaceApp{ + ID: arg.ID, + AgentID: arg.AgentID, + CreatedAt: arg.CreatedAt, + Slug: arg.Slug, + DisplayName: arg.DisplayName, + Icon: arg.Icon, + Command: arg.Command, + Url: arg.Url, + External: arg.External, + Subdomain: arg.Subdomain, + SharingLevel: arg.SharingLevel, + HealthcheckUrl: arg.HealthcheckUrl, + HealthcheckInterval: arg.HealthcheckInterval, + HealthcheckThreshold: arg.HealthcheckThreshold, + Health: arg.Health, } - q.templateVersionParameters = append(q.templateVersionParameters, param) - return param, nil + q.workspaceApps = append(q.workspaceApps, workspaceApp) + return workspaceApp, nil } -func (q *fakeQuerier) InsertTemplateVersionVariable(_ context.Context, arg database.InsertTemplateVersionVariableParams) (database.TemplateVersionVariable, error) { +func (q *fakeQuerier) InsertWorkspaceBuild(_ context.Context, arg database.InsertWorkspaceBuildParams) (database.WorkspaceBuild, error) { if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersionVariable{}, err + return database.WorkspaceBuild{}, err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - variable := database.TemplateVersionVariable{ + workspaceBuild := database.WorkspaceBuild{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + WorkspaceID: arg.WorkspaceID, TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - Description: arg.Description, - Type: arg.Type, - Value: arg.Value, - DefaultValue: arg.DefaultValue, - Required: arg.Required, - Sensitive: arg.Sensitive, + BuildNumber: arg.BuildNumber, + Transition: arg.Transition, + InitiatorID: arg.InitiatorID, + JobID: arg.JobID, + ProvisionerState: arg.ProvisionerState, + Deadline: arg.Deadline, + Reason: arg.Reason, } - q.templateVersionVariables = append(q.templateVersionVariables, variable) - return variable, nil + q.workspaceBuilds = append(q.workspaceBuilds, workspaceBuild) + return workspaceBuild, nil } -func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) { +func (q *fakeQuerier) InsertWorkspaceBuildParameters(_ context.Context, arg database.InsertWorkspaceBuildParametersParams) error { if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - // There is a common bug when using dbfake that 2 inserted users have the - // same created_at time. This causes user order to not be deterministic, - // which breaks some unit tests. - // To fix this, we make sure that the created_at time is always greater - // than the last user's created_at time. - allUsers, _ := q.GetUsers(context.Background(), database.GetUsersParams{}) - if len(allUsers) > 0 { - lastUser := allUsers[len(allUsers)-1] - if arg.CreatedAt.Before(lastUser.CreatedAt) || - arg.CreatedAt.Equal(lastUser.CreatedAt) { - // 1 ms is a good enough buffer. 
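The InsertUser code in this hunk keeps dbfake created_at values strictly increasing so that user ordering stays deterministic in tests. A minimal standalone sketch of the same bump-by-one-millisecond idea; the helper name is illustrative, not part of the patch.

package main

import (
	"fmt"
	"time"
)

// bumpIfNotAfter returns t unless it is at or before prev, in which case it
// returns prev plus one millisecond, matching the buffer InsertUser applies
// so two users never share a created_at timestamp.
func bumpIfNotAfter(t, prev time.Time) time.Time {
	if t.After(prev) {
		return t
	}
	return prev.Add(time.Millisecond)
}

func main() {
	prev := time.Date(2023, 7, 11, 12, 0, 0, 0, time.UTC)
	same := prev
	later := prev.Add(time.Second)

	fmt.Println(bumpIfNotAfter(same, prev))  // bumped to prev + 1ms
	fmt.Println(bumpIfNotAfter(later, prev)) // unchanged
}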
- arg.CreatedAt = lastUser.CreatedAt.Add(time.Millisecond) - } + return err } q.mutex.Lock() defer q.mutex.Unlock() - for _, user := range q.users { - if user.Username == arg.Username && !user.Deleted { - return database.User{}, errDuplicateKey - } - } - - user := database.User{ - ID: arg.ID, - Email: arg.Email, - HashedPassword: arg.HashedPassword, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Username: arg.Username, - Status: database.UserStatusActive, - RBACRoles: arg.RBACRoles, - LoginType: arg.LoginType, + for index, name := range arg.Name { + q.workspaceBuildParameters = append(q.workspaceBuildParameters, database.WorkspaceBuildParameter{ + WorkspaceBuildID: arg.WorkspaceBuildID, + Name: name, + Value: arg.Value[index], + }) } - q.users = append(q.users, user) - return user, nil + return nil } -func (q *fakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { +func (q *fakeQuerier) InsertWorkspaceProxy(_ context.Context, arg database.InsertWorkspaceProxyParams) (database.WorkspaceProxy, error) { q.mutex.Lock() defer q.mutex.Unlock() - var groupIDs []uuid.UUID - for _, group := range q.groups { - for _, groupName := range arg.GroupNames { - if group.Name == groupName { - groupIDs = append(groupIDs, group.ID) - } + for _, p := range q.workspaceProxies { + if !p.Deleted && p.Name == arg.Name { + return database.WorkspaceProxy{}, errDuplicateKey } } - for _, groupID := range groupIDs { - q.groupMembers = append(q.groupMembers, database.GroupMember{ - UserID: arg.UserID, - GroupID: groupID, - }) + p := database.WorkspaceProxy{ + ID: arg.ID, + Name: arg.Name, + DisplayName: arg.DisplayName, + Icon: arg.Icon, + TokenHashedSecret: arg.TokenHashedSecret, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Deleted: false, } - - return nil + q.workspaceProxies = append(q.workspaceProxies, p) + return p, nil } -func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) { +func (q *fakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { + if err := validateDatabaseType(arg); err != nil { + return database.WorkspaceResource{}, err + } + q.mutex.Lock() defer q.mutex.Unlock() //nolint:gosimple - link := database.UserLink{ - UserID: args.UserID, - LoginType: args.LoginType, - LinkedID: args.LinkedID, - OAuthAccessToken: args.OAuthAccessToken, - OAuthRefreshToken: args.OAuthRefreshToken, - OAuthExpiry: args.OAuthExpiry, + resource := database.WorkspaceResource{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + JobID: arg.JobID, + Transition: arg.Transition, + Type: arg.Type, + Name: arg.Name, + Hide: arg.Hide, + Icon: arg.Icon, + DailyCost: arg.DailyCost, } - - q.userLinks = append(q.userLinks, link) - - return link, nil + q.workspaceResources = append(q.workspaceResources, resource) + return resource, nil } -func (q *fakeQuerier) InsertWorkspace(_ context.Context, arg database.InsertWorkspaceParams) (database.Workspace, error) { +func (q *fakeQuerier) InsertWorkspaceResourceMetadata(_ context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) { if err := validateDatabaseType(arg); err != nil { - return database.Workspace{}, err + return nil, err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - workspace := database.Workspace{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OwnerID: arg.OwnerID, - 
OrganizationID: arg.OrganizationID, - TemplateID: arg.TemplateID, - Name: arg.Name, - AutostartSchedule: arg.AutostartSchedule, - Ttl: arg.Ttl, - LastUsedAt: arg.LastUsedAt, + metadata := make([]database.WorkspaceResourceMetadatum, 0) + id := int64(1) + if len(q.workspaceResourceMetadata) > 0 { + id = q.workspaceResourceMetadata[len(q.workspaceResourceMetadata)-1].ID } - q.workspaces = append(q.workspaces, workspace) - return workspace, nil -} - -func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceAgent{}, err + for index, key := range arg.Key { + id++ + value := arg.Value[index] + metadata = append(metadata, database.WorkspaceResourceMetadatum{ + ID: id, + WorkspaceResourceID: arg.WorkspaceResourceID, + Key: key, + Value: sql.NullString{ + String: value, + Valid: value != "", + }, + Sensitive: arg.Sensitive[index], + }) } + q.workspaceResourceMetadata = append(q.workspaceResourceMetadata, metadata...) + return metadata, nil +} +func (q *fakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { q.mutex.Lock() defer q.mutex.Unlock() - agent := database.WorkspaceAgent{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - ResourceID: arg.ResourceID, - AuthToken: arg.AuthToken, - AuthInstanceID: arg.AuthInstanceID, - EnvironmentVariables: arg.EnvironmentVariables, - Name: arg.Name, - Architecture: arg.Architecture, - OperatingSystem: arg.OperatingSystem, - Directory: arg.Directory, - StartupScriptBehavior: arg.StartupScriptBehavior, - StartupScript: arg.StartupScript, - InstanceMetadata: arg.InstanceMetadata, - ResourceMetadata: arg.ResourceMetadata, - ConnectionTimeoutSeconds: arg.ConnectionTimeoutSeconds, - TroubleshootingURL: arg.TroubleshootingURL, - MOTDFile: arg.MOTDFile, - LifecycleState: database.WorkspaceAgentLifecycleStateCreated, - ShutdownScript: arg.ShutdownScript, + for i, p := range q.workspaceProxies { + if p.ID == arg.ID { + p.Url = arg.Url + p.WildcardHostname = arg.WildcardHostname + p.UpdatedAt = database.Now() + q.workspaceProxies[i] = p + return p, nil + } } + return database.WorkspaceProxy{}, sql.ErrNoRows +} - q.workspaceAgents = append(q.workspaceAgents, agent) - return agent, nil +func (*fakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) { + return false, xerrors.New("TryAcquireLock must only be called within a transaction") } -func (q *fakeQuerier) InsertWorkspaceAgentMetadata(_ context.Context, arg database.InsertWorkspaceAgentMetadataParams) error { +func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - metadatum := database.WorkspaceAgentMetadatum{ - WorkspaceAgentID: arg.WorkspaceAgentID, - Script: arg.Script, - DisplayName: arg.DisplayName, - Key: arg.Key, - Timeout: arg.Timeout, - Interval: arg.Interval, + for index, apiKey := range q.apiKeys { + if apiKey.ID != arg.ID { + continue + } + apiKey.LastUsed = arg.LastUsed + apiKey.ExpiresAt = arg.ExpiresAt + apiKey.IPAddress = arg.IPAddress + q.apiKeys[index] = apiKey + return nil } - - q.workspaceAgentMetadata = append(q.workspaceAgentMetadata, metadatum) - return nil + return sql.ErrNoRows } -func (q *fakeQuerier) InsertWorkspaceAgentStartupLogs(_ context.Context, 
arg database.InsertWorkspaceAgentStartupLogsParams) ([]database.WorkspaceAgentStartupLog, error) { +func (q *fakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { if err := validateDatabaseType(arg); err != nil { - return nil, err + return database.GitAuthLink{}, err } q.mutex.Lock() defer q.mutex.Unlock() - - logs := []database.WorkspaceAgentStartupLog{} - id := int64(0) - if len(q.workspaceAgentLogs) > 0 { - id = q.workspaceAgentLogs[len(q.workspaceAgentLogs)-1].ID - } - outputLength := int32(0) - for index, output := range arg.Output { - id++ - logs = append(logs, database.WorkspaceAgentStartupLog{ - ID: id, - AgentID: arg.AgentID, - CreatedAt: arg.CreatedAt[index], - Level: arg.Level[index], - Output: output, - }) - outputLength += int32(len(output)) - } - for index, agent := range q.workspaceAgents { - if agent.ID != arg.AgentID { + for index, gitAuthLink := range q.gitAuthLinks { + if gitAuthLink.ProviderID != arg.ProviderID { continue } - // Greater than 1MB, same as the PostgreSQL constraint! - if agent.StartupLogsLength+outputLength > (1 << 20) { - return nil, &pq.Error{ - Constraint: "max_startup_logs_length", - Table: "workspace_agents", - } + if gitAuthLink.UserID != arg.UserID { + continue } - agent.StartupLogsLength += outputLength - q.workspaceAgents[index] = agent - break + gitAuthLink.UpdatedAt = arg.UpdatedAt + gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken + gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken + gitAuthLink.OAuthExpiry = arg.OAuthExpiry + q.gitAuthLinks[index] = gitAuthLink + + return gitAuthLink, nil } - q.workspaceAgentLogs = append(q.workspaceAgentLogs, logs...) - return logs, nil + return database.GitAuthLink{}, sql.ErrNoRows } -func (q *fakeQuerier) InsertWorkspaceAgentStat(_ context.Context, p database.InsertWorkspaceAgentStatParams) (database.WorkspaceAgentStat, error) { - if err := validateDatabaseType(p); err != nil { - return database.WorkspaceAgentStat{}, err +func (q *fakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { + if err := validateDatabaseType(arg); err != nil { + return database.GitSSHKey{}, err } q.mutex.Lock() defer q.mutex.Unlock() - stat := database.WorkspaceAgentStat{ - ID: p.ID, - CreatedAt: p.CreatedAt, - WorkspaceID: p.WorkspaceID, - AgentID: p.AgentID, - UserID: p.UserID, - ConnectionsByProto: p.ConnectionsByProto, - ConnectionCount: p.ConnectionCount, - RxPackets: p.RxPackets, - RxBytes: p.RxBytes, - TxPackets: p.TxPackets, - TxBytes: p.TxBytes, - TemplateID: p.TemplateID, - SessionCountVSCode: p.SessionCountVSCode, - SessionCountJetBrains: p.SessionCountJetBrains, - SessionCountReconnectingPTY: p.SessionCountReconnectingPTY, - SessionCountSSH: p.SessionCountSSH, - ConnectionMedianLatencyMS: p.ConnectionMedianLatencyMS, + for index, key := range q.gitSSHKey { + if key.UserID != arg.UserID { + continue + } + key.UpdatedAt = arg.UpdatedAt + key.PrivateKey = arg.PrivateKey + key.PublicKey = arg.PublicKey + q.gitSSHKey[index] = key + return key, nil } - q.workspaceAgentStats = append(q.workspaceAgentStats, stat) - return stat, nil + return database.GitSSHKey{}, sql.ErrNoRows } -func (q *fakeQuerier) InsertWorkspaceApp(_ context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) { +func (q *fakeQuerier) UpdateGroupByID(_ context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { if err := validateDatabaseType(arg); err != nil { - return 
database.WorkspaceApp{}, err + return database.Group{}, err } q.mutex.Lock() defer q.mutex.Unlock() - if arg.SharingLevel == "" { - arg.SharingLevel = database.AppSharingLevelOwner - } - - // nolint:gosimple - workspaceApp := database.WorkspaceApp{ - ID: arg.ID, - AgentID: arg.AgentID, - CreatedAt: arg.CreatedAt, - Slug: arg.Slug, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - Command: arg.Command, - Url: arg.Url, - External: arg.External, - Subdomain: arg.Subdomain, - SharingLevel: arg.SharingLevel, - HealthcheckUrl: arg.HealthcheckUrl, - HealthcheckInterval: arg.HealthcheckInterval, - HealthcheckThreshold: arg.HealthcheckThreshold, - Health: arg.Health, + for i, group := range q.groups { + if group.ID == arg.ID { + group.Name = arg.Name + group.AvatarURL = arg.AvatarURL + group.QuotaAllowance = arg.QuotaAllowance + q.groups[i] = group + return group, nil + } } - q.workspaceApps = append(q.workspaceApps, workspaceApp) - return workspaceApp, nil + return database.Group{}, sql.ErrNoRows } -func (q *fakeQuerier) InsertWorkspaceBuild(_ context.Context, arg database.InsertWorkspaceBuildParams) (database.WorkspaceBuild, error) { +func (q *fakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceBuild{}, err + return database.OrganizationMember{}, err } q.mutex.Lock() defer q.mutex.Unlock() - workspaceBuild := database.WorkspaceBuild{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - WorkspaceID: arg.WorkspaceID, - TemplateVersionID: arg.TemplateVersionID, - BuildNumber: arg.BuildNumber, - Transition: arg.Transition, - InitiatorID: arg.InitiatorID, - JobID: arg.JobID, - ProvisionerState: arg.ProvisionerState, - Deadline: arg.Deadline, - Reason: arg.Reason, + for i, mem := range q.organizationMembers { + if mem.UserID == arg.UserID && mem.OrganizationID == arg.OrgID { + uniqueRoles := make([]string, 0, len(arg.GrantedRoles)) + exist := make(map[string]struct{}) + for _, r := range arg.GrantedRoles { + if _, ok := exist[r]; ok { + continue + } + exist[r] = struct{}{} + uniqueRoles = append(uniqueRoles, r) + } + sort.Strings(uniqueRoles) + + mem.Roles = uniqueRoles + q.organizationMembers[i] = mem + return mem, nil + } } - q.workspaceBuilds = append(q.workspaceBuilds, workspaceBuild) - return workspaceBuild, nil + + return database.OrganizationMember{}, sql.ErrNoRows } -func (q *fakeQuerier) InsertWorkspaceBuildParameters(_ context.Context, arg database.InsertWorkspaceBuildParametersParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, name := range arg.Name { - q.workspaceBuildParameters = append(q.workspaceBuildParameters, database.WorkspaceBuildParameter{ - WorkspaceBuildID: arg.WorkspaceBuildID, - Name: name, - Value: arg.Value[index], - }) - } - return nil -} - -func (q *fakeQuerier) InsertWorkspaceProxy(_ context.Context, arg database.InsertWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, p := range q.workspaceProxies { - if !p.Deleted && p.Name == arg.Name { - return database.WorkspaceProxy{}, errDuplicateKey - } - } - - p := database.WorkspaceProxy{ - ID: arg.ID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - TokenHashedSecret: arg.TokenHashedSecret, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Deleted: false, - } - q.workspaceProxies = 
append(q.workspaceProxies, p) - return p, nil -} - -func (q *fakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceResource{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - resource := database.WorkspaceResource{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - JobID: arg.JobID, - Transition: arg.Transition, - Type: arg.Type, - Name: arg.Name, - Hide: arg.Hide, - Icon: arg.Icon, - DailyCost: arg.DailyCost, - } - q.workspaceResources = append(q.workspaceResources, resource) - return resource, nil -} - -func (q *fakeQuerier) InsertWorkspaceResourceMetadata(_ context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - metadata := make([]database.WorkspaceResourceMetadatum, 0) - id := int64(1) - if len(q.workspaceResourceMetadata) > 0 { - id = q.workspaceResourceMetadata[len(q.workspaceResourceMetadata)-1].ID - } - for index, key := range arg.Key { - id++ - value := arg.Value[index] - metadata = append(metadata, database.WorkspaceResourceMetadatum{ - ID: id, - WorkspaceResourceID: arg.WorkspaceResourceID, - Key: key, - Value: sql.NullString{ - String: value, - Valid: value != "", - }, - Sensitive: arg.Sensitive[index], - }) - } - q.workspaceResourceMetadata = append(q.workspaceResourceMetadata, metadata...) - return metadata, nil -} - -func (q *fakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Url = arg.Url - p.WildcardHostname = arg.WildcardHostname - p.UpdatedAt = database.Now() - q.workspaceProxies[i] = p - return p, nil - } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (*fakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) { - return false, xerrors.New("TryAcquireLock must only be called within a transaction") -} - -func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, apiKey := range q.apiKeys { - if apiKey.ID != arg.ID { - continue - } - apiKey.LastUsed = arg.LastUsed - apiKey.ExpiresAt = arg.ExpiresAt - apiKey.IPAddress = arg.IPAddress - q.apiKeys[index] = apiKey - return nil - } - return sql.ErrNoRows -} - -func (q *fakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { - if err := validateDatabaseType(arg); err != nil { - return database.GitAuthLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for index, gitAuthLink := range q.gitAuthLinks { - if gitAuthLink.ProviderID != arg.ProviderID { - continue - } - if gitAuthLink.UserID != arg.UserID { - continue - } - gitAuthLink.UpdatedAt = arg.UpdatedAt - gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken - gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken - gitAuthLink.OAuthExpiry = arg.OAuthExpiry - q.gitAuthLinks[index] = gitAuthLink - - return gitAuthLink, nil - } - return database.GitAuthLink{}, sql.ErrNoRows -} - -func (q *fakeQuerier) UpdateGitSSHKey(_ context.Context, arg 
database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { - if err := validateDatabaseType(arg); err != nil { - return database.GitSSHKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, key := range q.gitSSHKey { - if key.UserID != arg.UserID { - continue - } - key.UpdatedAt = arg.UpdatedAt - key.PrivateKey = arg.PrivateKey - key.PublicKey = arg.PublicKey - q.gitSSHKey[index] = key - return key, nil - } - return database.GitSSHKey{}, sql.ErrNoRows -} - -func (q *fakeQuerier) UpdateGroupByID(_ context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, group := range q.groups { - if group.ID == arg.ID { - group.Name = arg.Name - group.AvatarURL = arg.AvatarURL - group.QuotaAllowance = arg.QuotaAllowance - q.groups[i] = group - return group, nil - } - } - return database.Group{}, sql.ErrNoRows -} - -func (q *fakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { - if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, mem := range q.organizationMembers { - if mem.UserID == arg.UserID && mem.OrganizationID == arg.OrgID { - uniqueRoles := make([]string, 0, len(arg.GrantedRoles)) - exist := make(map[string]struct{}) - for _, r := range arg.GrantedRoles { - if _, ok := exist[r]; ok { - continue - } - exist[r] = struct{}{} - uniqueRoles = append(uniqueRoles, r) - } - sort.Strings(uniqueRoles) - - mem.Roles = uniqueRoles - q.organizationMembers[i] = mem - return mem, nil - } - } - - return database.OrganizationMember{}, sql.ErrNoRows -} - -func (q *fakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { +func (q *fakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { if err := validateDatabaseType(arg); err != nil { return err } @@ -5248,169 +4784,669 @@ func (q *fakeQuerier) UpdateWorkspaceLastUsedAt(_ context.Context, arg database. 
if workspace.ID != arg.ID { continue } - workspace.LastUsedAt = arg.LastUsedAt - q.workspaces[index] = workspace - return nil + workspace.LastUsedAt = arg.LastUsedAt + q.workspaces[index] = workspace + return nil + } + + return sql.ErrNoRows +} + +func (q *fakeQuerier) UpdateWorkspaceLockedAt(_ context.Context, arg database.UpdateWorkspaceLockedAtParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, workspace := range q.workspaces { + if workspace.ID != arg.ID { + continue + } + workspace.LockedAt = arg.LockedAt + workspace.LastUsedAt = database.Now() + q.workspaces[index] = workspace + return nil + } + + return sql.ErrNoRows +} + +func (q *fakeQuerier) UpdateWorkspaceProxy(_ context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, p := range q.workspaceProxies { + if p.Name == arg.Name && p.ID != arg.ID { + return database.WorkspaceProxy{}, errDuplicateKey + } + } + + for i, p := range q.workspaceProxies { + if p.ID == arg.ID { + p.Name = arg.Name + p.DisplayName = arg.DisplayName + p.Icon = arg.Icon + if len(p.TokenHashedSecret) > 0 { + p.TokenHashedSecret = arg.TokenHashedSecret + } + q.workspaceProxies[i] = p + return p, nil + } + } + return database.WorkspaceProxy{}, sql.ErrNoRows +} + +func (q *fakeQuerier) UpdateWorkspaceProxyDeleted(_ context.Context, arg database.UpdateWorkspaceProxyDeletedParams) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + for i, p := range q.workspaceProxies { + if p.ID == arg.ID { + p.Deleted = arg.Deleted + p.UpdatedAt = database.Now() + q.workspaceProxies[i] = p + return nil + } + } + return sql.ErrNoRows +} + +func (q *fakeQuerier) UpdateWorkspaceTTL(_ context.Context, arg database.UpdateWorkspaceTTLParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, workspace := range q.workspaces { + if workspace.ID != arg.ID { + continue + } + workspace.Ttl = arg.Ttl + q.workspaces[index] = workspace + return nil + } + + return sql.ErrNoRows +} + +func (q *fakeQuerier) UpdateWorkspaceTTLToBeWithinTemplateMax(_ context.Context, arg database.UpdateWorkspaceTTLToBeWithinTemplateMaxParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, workspace := range q.workspaces { + if workspace.TemplateID != arg.TemplateID || !workspace.Ttl.Valid || workspace.Ttl.Int64 < arg.TemplateMaxTTL { + continue + } + + workspace.Ttl = sql.NullInt64{Int64: arg.TemplateMaxTTL, Valid: true} + q.workspaces[index] = workspace + } + + return nil +} + +func (q *fakeQuerier) UpsertAppSecurityKey(_ context.Context, data string) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + q.appSecurityKey = data + return nil +} + +func (q *fakeQuerier) UpsertDefaultProxy(_ context.Context, arg database.UpsertDefaultProxyParams) error { + q.defaultProxyDisplayName = arg.DisplayName + q.defaultProxyIconURL = arg.IconUrl + return nil +} + +func (q *fakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + q.lastUpdateCheck = []byte(data) + return nil +} + +func (q *fakeQuerier) UpsertLogoURL(_ context.Context, data string) error { + q.mutex.RLock() + defer q.mutex.RUnlock() + + q.logoURL = data + return nil +} + +func (q *fakeQuerier) UpsertOAuthSigningKey(_ context.Context, value string) 
error { + q.mutex.Lock() + defer q.mutex.Unlock() + + q.oauthSigningKey = value + return nil +} + +func (q *fakeQuerier) UpsertServiceBanner(_ context.Context, data string) error { + q.mutex.RLock() + defer q.mutex.RUnlock() + + q.serviceBanner = []byte(data) + return nil +} + +func (*fakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + return database.TailnetAgent{}, ErrUnimplemented +} + +func (*fakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { + return database.TailnetClient{}, ErrUnimplemented +} + +func (*fakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { + return database.TailnetCoordinator{}, ErrUnimplemented +} + +func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { + if err := validateDatabaseType(arg); err != nil { + return nil, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL()) + if err != nil { + return nil, err + } + } + + var templates []database.Template + for _, template := range q.templates { + if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { + continue + } + + if template.Deleted != arg.Deleted { + continue + } + if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID { + continue + } + + if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) { + continue + } + + if len(arg.IDs) > 0 { + match := false + for _, id := range arg.IDs { + if template.ID == id { + match = true + break + } + } + if !match { + continue + } + } + templates = append(templates, template.DeepCopy()) + } + if len(templates) > 0 { + slices.SortFunc(templates, func(i, j database.Template) bool { + if i.Name != j.Name { + return i.Name < j.Name + } + return i.ID.String() < j.ID.String() + }) + return templates, nil + } + + return nil, sql.ErrNoRows +} + +func (q *fakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + var template database.Template + for _, t := range q.templates { + if t.ID == id { + template = t + break + } + } + + if template.ID == uuid.Nil { + return nil, sql.ErrNoRows + } + + groups := make([]database.TemplateGroup, 0, len(template.GroupACL)) + for k, v := range template.GroupACL { + group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k)) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get group by ID: %w", err) + } + // We don't delete groups from the map if they + // get deleted so just skip. 
+ if xerrors.Is(err, sql.ErrNoRows) { + continue + } + + groups = append(groups, database.TemplateGroup{ + Group: group, + Actions: v, + }) + } + + return groups, nil +} + +func (q *fakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + var template database.Template + for _, t := range q.templates { + if t.ID == id { + template = t + break + } + } + + if template.ID == uuid.Nil { + return nil, sql.ErrNoRows + } + + users := make([]database.TemplateUser, 0, len(template.UserACL)) + for k, v := range template.UserACL { + user, err := q.getUserByIDNoLock(uuid.MustParse(k)) + if err != nil && xerrors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get user by ID: %w", err) + } + // We don't delete users from the map if they + // get deleted so just skip. + if xerrors.Is(err, sql.ErrNoRows) { + continue + } + + if user.Deleted || user.Status == database.UserStatusSuspended { + continue + } + + users = append(users, database.TemplateUser{ + User: user, + Actions: v, + }) + } + + return users, nil +} + +//nolint:gocyclo +func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + if err := validateDatabaseType(arg); err != nil { + return nil, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + if prepared != nil { + // Call this to match the same function calls as the SQL implementation. + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + if err != nil { + return nil, err + } + } + + workspaces := make([]database.Workspace, 0) + for _, workspace := range q.workspaces { + if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { + continue + } + + if arg.OwnerUsername != "" { + owner, err := q.getUserByIDNoLock(workspace.OwnerID) + if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) { + continue + } + } + + if arg.TemplateName != "" { + template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) + if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) { + continue + } + } + + if !arg.Deleted && workspace.Deleted { + continue + } + + if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) { + continue + } + + if arg.Status != "" { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) + if err != nil { + return nil, xerrors.Errorf("get latest build: %w", err) + } + + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err != nil { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } + + // This logic should match the logic in the workspace.sql file. 
+ var statusMatch bool + switch database.WorkspaceStatus(arg.Status) { + case database.WorkspaceStatusPending: + statusMatch = isNull(job.StartedAt) + case database.WorkspaceStatusStarting: + statusMatch = isNotNull(job.StartedAt) && + isNull(job.CanceledAt) && + isNull(job.CompletedAt) && + time.Since(job.UpdatedAt) < 30*time.Second && + build.Transition == database.WorkspaceTransitionStart + + case database.WorkspaceStatusRunning: + statusMatch = isNotNull(job.CompletedAt) && + isNull(job.CanceledAt) && + isNull(job.Error) && + build.Transition == database.WorkspaceTransitionStart + + case database.WorkspaceStatusStopping: + statusMatch = isNotNull(job.StartedAt) && + isNull(job.CanceledAt) && + isNull(job.CompletedAt) && + time.Since(job.UpdatedAt) < 30*time.Second && + build.Transition == database.WorkspaceTransitionStop + + case database.WorkspaceStatusStopped: + statusMatch = isNotNull(job.CompletedAt) && + isNull(job.CanceledAt) && + isNull(job.Error) && + build.Transition == database.WorkspaceTransitionStop + case database.WorkspaceStatusFailed: + statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) || + (isNotNull(job.CompletedAt) && isNotNull(job.Error)) + + case database.WorkspaceStatusCanceling: + statusMatch = isNotNull(job.CanceledAt) && + isNull(job.CompletedAt) + + case database.WorkspaceStatusCanceled: + statusMatch = isNotNull(job.CanceledAt) && + isNotNull(job.CompletedAt) + + case database.WorkspaceStatusDeleted: + statusMatch = isNotNull(job.StartedAt) && + isNull(job.CanceledAt) && + isNotNull(job.CompletedAt) && + time.Since(job.UpdatedAt) < 30*time.Second && + build.Transition == database.WorkspaceTransitionDelete && + isNull(job.Error) + + case database.WorkspaceStatusDeleting: + statusMatch = isNull(job.CompletedAt) && + isNull(job.CanceledAt) && + isNull(job.Error) && + build.Transition == database.WorkspaceTransitionDelete + + default: + return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status) + } + if !statusMatch { + continue + } + } + + if arg.HasAgent != "" { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) + if err != nil { + return nil, xerrors.Errorf("get latest build: %w", err) + } + + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err != nil { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } + + workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) + if err != nil { + return nil, xerrors.Errorf("get workspace resources: %w", err) + } + + var workspaceResourceIDs []uuid.UUID + for _, wr := range workspaceResources { + workspaceResourceIDs = append(workspaceResourceIDs, wr.ID) + } + + workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs) + if err != nil { + return nil, xerrors.Errorf("get workspace agents: %w", err) + } + + var hasAgentMatched bool + for _, wa := range workspaceAgents { + if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent { + hasAgentMatched = true + } + } + + if !hasAgentMatched { + continue + } + } + + if len(arg.TemplateIds) > 0 { + match := false + for _, id := range arg.TemplateIds { + if workspace.TemplateID == id { + match = true + break + } + } + if !match { + continue + } + } + + // If the filter exists, ensure the object is authorized. 
+ if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil { + continue + } + workspaces = append(workspaces, workspace) } - return sql.ErrNoRows -} - -func (q *fakeQuerier) UpdateWorkspaceLockedAt(_ context.Context, arg database.UpdateWorkspaceLockedAtParams) error { - if err := validateDatabaseType(arg); err != nil { - return err + // Sort workspaces (ORDER BY) + isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool { + return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart } - q.mutex.Lock() - defer q.mutex.Unlock() + preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{} + preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{} + preloadedUsers := map[uuid.UUID]database.User{} - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue + for _, w := range workspaces { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) + if err == nil { + preloadedWorkspaceBuilds[w.ID] = build + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get latest build: %w", err) } - workspace.LockedAt = arg.LockedAt - workspace.LastUsedAt = database.Now() - q.workspaces[index] = workspace - return nil - } - - return sql.ErrNoRows -} -func (q *fakeQuerier) UpdateWorkspaceProxy(_ context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err == nil { + preloadedProvisionerJobs[w.ID] = job + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } - for _, p := range q.workspaceProxies { - if p.Name == arg.Name && p.ID != arg.ID { - return database.WorkspaceProxy{}, errDuplicateKey + user, err := q.getUserByIDNoLock(w.OwnerID) + if err == nil { + preloadedUsers[w.ID] = user + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get user: %w", err) } } - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Name = arg.Name - p.DisplayName = arg.DisplayName - p.Icon = arg.Icon - if len(p.TokenHashedSecret) > 0 { - p.TokenHashedSecret = arg.TokenHashedSecret - } - q.workspaceProxies[i] = p - return p, nil + sort.Slice(workspaces, func(i, j int) bool { + w1 := workspaces[i] + w2 := workspaces[j] + + // Order by: running first + w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID]) + w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID]) + + if w1IsRunning && !w2IsRunning { + return true } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} -func (q *fakeQuerier) UpdateWorkspaceProxyDeleted(_ context.Context, arg database.UpdateWorkspaceProxyDeletedParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() + if !w1IsRunning && w2IsRunning { + return false + } - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Deleted = arg.Deleted - p.UpdatedAt = database.Now() - q.workspaceProxies[i] = p - return nil + // Order by: usernames + if w1.ID != w2.ID { + return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username}) } - } - return sql.ErrNoRows -} -func (q *fakeQuerier) UpdateWorkspaceTTL(_ context.Context, arg database.UpdateWorkspaceTTLParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } + // Order by: workspace names + 
return sort.StringsAreSorted([]string{w1.Name, w2.Name}) + }) - q.mutex.Lock() - defer q.mutex.Unlock() + beforePageCount := len(workspaces) - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue + if arg.Offset > 0 { + if int(arg.Offset) > len(workspaces) { + return []database.GetWorkspacesRow{}, nil } - workspace.Ttl = arg.Ttl - q.workspaces[index] = workspace - return nil + workspaces = workspaces[arg.Offset:] + } + if arg.Limit > 0 { + if int(arg.Limit) > len(workspaces) { + return convertToWorkspaceRows(workspaces, int64(beforePageCount)), nil + } + workspaces = workspaces[:arg.Limit] } - return sql.ErrNoRows + return convertToWorkspaceRows(workspaces, int64(beforePageCount)), nil } -func (q *fakeQuerier) UpdateWorkspaceTTLToBeWithinTemplateMax(_ context.Context, arg database.UpdateWorkspaceTTLToBeWithinTemplateMaxParams) error { +func (q *fakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { if err := validateDatabaseType(arg); err != nil { - return err + return nil, err } - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.TemplateID != arg.TemplateID || !workspace.Ttl.Valid || workspace.Ttl.Int64 < arg.TemplateMaxTTL { - continue + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.UserConverter(), + }) + if err != nil { + return nil, err } - - workspace.Ttl = sql.NullInt64{Int64: arg.TemplateMaxTTL, Valid: true} - q.workspaces[index] = workspace } - return nil -} + users, err := q.GetUsers(ctx, arg) + if err != nil { + return nil, err + } -func (q *fakeQuerier) UpsertAppSecurityKey(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() - q.appSecurityKey = data - return nil -} + filteredUsers := make([]database.GetUsersRow, 0, len(users)) + for _, user := range users { + // If the filter exists, ensure the object is authorized. + if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { + continue + } -func (q *fakeQuerier) UpsertDefaultProxy(_ context.Context, arg database.UpsertDefaultProxyParams) error { - q.defaultProxyDisplayName = arg.DisplayName - q.defaultProxyIconURL = arg.IconUrl - return nil + filteredUsers = append(filteredUsers, user) + } + return filteredUsers, nil } -func (q *fakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.lastUpdateCheck = []byte(data) - return nil -} +func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + if err := validateDatabaseType(params); err != nil { + return 0, err + } -func (q *fakeQuerier) UpsertLogoURL(_ context.Context, data string) error { q.mutex.RLock() defer q.mutex.RUnlock() - q.logoURL = data - return nil -} + // Call this to match the same function calls as the SQL implementation. 
+ if prepared != nil { + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + if err != nil { + return -1, err + } + } -func (q *fakeQuerier) UpsertOAuthSigningKey(_ context.Context, value string) error { - q.mutex.Lock() - defer q.mutex.Unlock() + users := make([]database.User, 0, len(q.users)) - q.oauthSigningKey = value - return nil -} + for _, user := range q.users { + // If the filter exists, ensure the object is authorized. + if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { + continue + } -func (q *fakeQuerier) UpsertServiceBanner(_ context.Context, data string) error { - q.mutex.RLock() - defer q.mutex.RUnlock() + users = append(users, user) + } - q.serviceBanner = []byte(data) - return nil -} + // Filter out deleted since they should never be returned.. + tmp := make([]database.User, 0, len(users)) + for _, user := range users { + if !user.Deleted { + tmp = append(tmp, user) + } + } + users = tmp -func (*fakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { - return database.TailnetAgent{}, ErrUnimplemented -} + if params.Search != "" { + tmp := make([]database.User, 0, len(users)) + for i, user := range users { + if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) + } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) + } + } + users = tmp + } -func (*fakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { - return database.TailnetClient{}, ErrUnimplemented -} + if len(params.Status) > 0 { + usersFilteredByStatus := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { + return strings.EqualFold(string(a), string(b)) + }) { + usersFilteredByStatus = append(usersFilteredByStatus, users[i]) + } + } + users = usersFilteredByStatus + } -func (*fakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { - return database.TailnetCoordinator{}, ErrUnimplemented + if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { + usersFilteredByRole := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { + usersFilteredByRole = append(usersFilteredByRole, users[i]) + } + } + + users = usersFilteredByRole + } + + return int64(len(users)), nil } diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index ec28fd428a102..d685f37ce3510 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" "golang.org/x/exp/slices" @@ -73,41 +74,6 @@ func (m metricsStore) InTx(f func(database.Store) error, options *sql.TxOptions) return err } -func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { - start := time.Now() - templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds()) - return templates, err -} - -func (m metricsStore) GetTemplateGroupRoles(ctx 
context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - start := time.Now() - roles, err := m.s.GetTemplateGroupRoles(ctx, id) - m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds()) - return roles, err -} - -func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - start := time.Now() - roles, err := m.s.GetTemplateUserRoles(ctx, id) - m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds()) - return roles, err -} - -func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - start := time.Now() - workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds()) - return workspaces, err -} - -func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - start := time.Now() - count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds()) - return count, err -} - func (m metricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error { start := time.Now() err := m.s.AcquireLock(ctx, pgAdvisoryXactLock) @@ -1639,3 +1605,45 @@ func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) return m.s.UpsertTailnetCoordinator(ctx, id) } + +func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { + start := time.Now() + templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds()) + return templates, err +} + +func (m metricsStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + start := time.Now() + roles, err := m.s.GetTemplateGroupRoles(ctx, id) + m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds()) + return roles, err +} + +func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + start := time.Now() + roles, err := m.s.GetTemplateUserRoles(ctx, id) + m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds()) + return roles, err +} + +func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + start := time.Now() + workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds()) + return workspaces, err +} + +func (m metricsStore) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + start := time.Now() + r0, r1 := m.s.GetAuthorizedUsers(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedUsers").Observe(time.Since(start).Seconds()) + return r0, r1 +} + +func (m metricsStore) GetAuthorizedUserCount(ctx 
context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + start := time.Now() + count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds()) + return count, err +} diff --git a/scripts/dbgen/main.go b/scripts/dbgen/main.go index 8f87892510103..c6ddb7858e7f9 100644 --- a/scripts/dbgen/main.go +++ b/scripts/dbgen/main.go @@ -32,6 +32,11 @@ func init() { if err != nil { panic(err) } + customFuncs, err := readCustomQuerierFunctions() + if err != nil { + panic(err) + } + funcs = append(funcs, customFuncs...) funcByName = map[string]struct{}{} for _, f := range funcs { funcByName[f.Name] = struct{}{} @@ -423,11 +428,25 @@ func readQuerierFunctions() ([]querierFunction, error) { return nil, err } querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", "querier.go") + return loadQuerierFunctions(querierPath, "sqlcQuerier") +} + +// readCustomQuerierFunctions reads the functions from coderd/database/modelqueries.go +func readCustomQuerierFunctions() ([]querierFunction, error) { + localPath, err := localFilePath() + if err != nil { + return nil, err + } + querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", "modelqueries.go") + return loadQuerierFunctions(querierPath, "customQuerier") +} - querierData, err := os.ReadFile(querierPath) +func loadQuerierFunctions(filename string, interfaceName string) ([]querierFunction, error) { + querierData, err := os.ReadFile(filename) if err != nil { return nil, xerrors.Errorf("read querier: %w", err) } + f, err := decorator.Parse(querierData) if err != nil { return nil, err @@ -447,7 +466,7 @@ func readQuerierFunctions() ([]querierFunction, error) { } // This is the name of the interface. If that ever changes, // this will need to be updated. - if typeSpec.Name.Name != "sqlcQuerier" { + if typeSpec.Name.Name != interfaceName { continue } querier, ok = typeSpec.Type.(*dst.InterfaceType) @@ -461,7 +480,9 @@ func readQuerierFunctions() ([]querierFunction, error) { return nil, xerrors.Errorf("querier not found") } funcs := []querierFunction{} - for _, method := range querier.Methods.List { + allMethods := interfaceMethods(querier) + + for _, method := range allMethods { funcType, ok := method.Type.(*dst.FuncType) if !ok { continue @@ -540,3 +561,30 @@ func nameFromSnakeCase(s string) string { } return ret } + +// interfaceMethods returns all embedded methods of an interface. +func interfaceMethods(i *dst.InterfaceType) []*dst.Field { + var allMethods []*dst.Field + for _, field := range i.Methods.List { + switch fieldType := field.Type.(type) { + case *dst.FuncType: + allMethods = append(allMethods, field) + case *dst.InterfaceType: + allMethods = append(allMethods, interfaceMethods(fieldType)...) + case *dst.Ident: + // Embedded interfaces are Idents -> TypeSpec -> InterfaceType + // If the embedded interface is not in the parsed file, then + // the Obj will be nil. + if fieldType.Obj != nil { + objDecl, ok := fieldType.Obj.Decl.(*dst.TypeSpec) + if ok { + isInterface, ok := objDecl.Type.(*dst.InterfaceType) + if ok { + allMethods = append(allMethods, interfaceMethods(isInterface)...) 
+ } + } + } + } + } + return allMethods +} From a9b7704a194996130fe21ab2b6a43f83bc07e0d4 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 12 Jul 2023 11:55:41 -0400 Subject: [PATCH 5/5] feat: implement proper sql filter for users query --- coderd/database/dbmetrics/dbmetrics.go | 1 - coderd/database/queries.sql.go | 2 + coderd/database/queries/users.sql | 2 + coderd/rbac/input.json | 26 +++++------ coderd/rbac/regosql/compile_test.go | 20 ++++++++ coderd/rbac/regosql/configs.go | 5 +- coderd/rbac/regosql/sqltypes/always_false.go | 49 +++++++++++++------- 7 files changed, 71 insertions(+), 34 deletions(-) diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index d685f37ce3510..1bbbf42b3ab16 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -9,7 +9,6 @@ import ( "time" "github.com/google/uuid" - "github.com/prometheus/client_golang/prometheus" "golang.org/x/exp/slices" diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 08c0777b2a307..c30194fb0bdf9 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5302,6 +5302,8 @@ WHERE END -- End of filters + -- Authorize Filter clause will be injected below in GetAuthorizedUserCount + -- @authorize_filter ORDER BY -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination. LOWER(username) ASC OFFSET $7 diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 28f7a5ca6ba0b..c96cdbd0aa7e1 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -209,6 +209,8 @@ WHERE END -- End of filters + -- Authorize Filter clause will be injected below in GetAuthorizedUserCount + -- @authorize_filter ORDER BY -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination. 
LOWER(username) ASC OFFSET @offset_opt diff --git a/coderd/rbac/input.json b/coderd/rbac/input.json index 5e464168ac5ac..71a81ec81de15 100644 --- a/coderd/rbac/input.json +++ b/coderd/rbac/input.json @@ -1,12 +1,11 @@ { - "action": "never-match-action", + "action": "read", "object": { "id": "9046b041-58ed-47a3-9c3a-de302577875a", - "owner": "00000000-0000-0000-0000-000000000000", - "org_owner": "bf7b72bd-a2b1-4ef2-962c-1d698e0483f6", - "type": "workspace", + "owner": "9046b041-58ed-47a3-9c3a-de302577875a", + "org_owner": "00000000-0000-0000-0000-000000000000", + "type": "user", "acl_user_list": { - "f041847d-711b-40da-a89a-ede39f70dc7f": ["create"] }, "acl_group_list": {} }, @@ -14,20 +13,21 @@ "id": "10d03e62-7703-4df5-a358-4f76577d4e2f", "roles": [ { - "name": "owner", - "display_name": "Owner", + "name": "member", + "display_name": "Member", "site": [ + ], + "org": {}, + "user": [ { "negate": false, - "resource_type": "*", - "action": "*" + "resource_type": "user", + "action": "read" } - ], - "org": {}, - "user": [] + ] } ], - "groups": ["b617a647-b5d0-4cbe-9e40-26f89710bf18"], + "groups": [], "scope": { "name": "Scope_all", "display_name": "All operations", diff --git a/coderd/rbac/regosql/compile_test.go b/coderd/rbac/regosql/compile_test.go index 6c350b7834639..5673b8621c2c7 100644 --- a/coderd/rbac/regosql/compile_test.go +++ b/coderd/rbac/regosql/compile_test.go @@ -242,6 +242,26 @@ neq(input.object.owner, ""); p("false")), VariableConverter: regosql.TemplateConverter(), }, + { + Name: "UserNoOrgOwner", + Queries: []string{ + `input.object.org_owner != ""`, + }, + ExpectedSQL: p("'' != ''"), + VariableConverter: regosql.UserConverter(), + }, + { + Name: "UserOwnsSelf", + Queries: []string{ + `"10d03e62-7703-4df5-a358-4f76577d4e2f" = input.object.owner; + input.object.owner != ""; + input.object.org_owner = ""`, + }, + VariableConverter: regosql.UserConverter(), + ExpectedSQL: p( + p("'10d03e62-7703-4df5-a358-4f76577d4e2f' = id :: text") + " AND " + p("id :: text != ''") + " AND " + p("'' = ''"), + ), + }, } for _, tc := range testCases { diff --git a/coderd/rbac/regosql/configs.go b/coderd/rbac/regosql/configs.go index 6c33eadb4c97b..a2f1db4a0cba9 100644 --- a/coderd/rbac/regosql/configs.go +++ b/coderd/rbac/regosql/configs.go @@ -25,8 +25,9 @@ func userACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher { func UserConverter() *sqltypes.VariableConverter { matcher := sqltypes.NewVariableConverter().RegisterMatcher( resourceIDMatcher(), - // Users are never owned by an organization. - sqltypes.AlwaysFalse(organizationOwnerMatcher()), + // Users are never owned by an organization, so always return the empty string + // for the org owner. + sqltypes.StringVarMatcher("''", []string{"input", "object", "org_owner"}), // Users are always owned by themselves. 
sqltypes.StringVarMatcher("id :: text", []string{"input", "object", "owner"}), ) diff --git a/coderd/rbac/regosql/sqltypes/always_false.go b/coderd/rbac/regosql/sqltypes/always_false.go index 93831d844c8b1..da2c1891dae2b 100644 --- a/coderd/rbac/regosql/sqltypes/always_false.go +++ b/coderd/rbac/regosql/sqltypes/always_false.go @@ -1,45 +1,58 @@ package sqltypes import ( + "strconv" + "github.com/open-policy-agent/opa/ast" ) var ( - _ Node = alwaysFalse{} - _ VariableMatcher = alwaysFalse{} + _ Node = constBoolean{} + _ VariableMatcher = constBoolean{} ) -type alwaysFalse struct { - Matcher VariableMatcher +type constBoolean struct { + Matcher VariableMatcher + constant bool InnerNode Node } // AlwaysFalse overrides the inner node with a constant "false". func AlwaysFalse(m VariableMatcher) VariableMatcher { - return alwaysFalse{ - Matcher: m, + return constBoolean{ + Matcher: m, + constant: false, + } +} + +func AlwaysTrue(m VariableMatcher) VariableMatcher { + return constBoolean{ + Matcher: m, + constant: true, } } // AlwaysFalseNode is mainly used for unit testing to make a Node immediately. func AlwaysFalseNode(n Node) Node { - return alwaysFalse{ + return constBoolean{ InnerNode: n, Matcher: nil, + constant: false, } } // UseAs uses a type no one supports to always override with false. -func (alwaysFalse) UseAs() Node { return alwaysFalse{} } +func (constBoolean) UseAs() Node { return constBoolean{} } -func (f alwaysFalse) ConvertVariable(rego ast.Ref) (Node, bool) { +func (f constBoolean) ConvertVariable(rego ast.Ref) (Node, bool) { if f.Matcher != nil { n, ok := f.Matcher.ConvertVariable(rego) if ok { - return alwaysFalse{ + return constBoolean{ Matcher: f.Matcher, InnerNode: n, + constant: f.constant, }, true } } @@ -47,18 +60,18 @@ func (f alwaysFalse) ConvertVariable(rego ast.Ref) (Node, bool) { return nil, false } -func (alwaysFalse) SQLString(_ *SQLGenerator) string { - return "false" +func (c constBoolean) SQLString(_ *SQLGenerator) string { + return strconv.FormatBool(c.constant) } -func (alwaysFalse) ContainsSQL(_ *SQLGenerator, _ Node) (string, error) { - return "false", nil +func (c constBoolean) ContainsSQL(_ *SQLGenerator, _ Node) (string, error) { + return strconv.FormatBool(c.constant), nil } -func (alwaysFalse) ContainedInSQL(_ *SQLGenerator, _ Node) (string, error) { - return "false", nil +func (c constBoolean) ContainedInSQL(_ *SQLGenerator, _ Node) (string, error) { + return strconv.FormatBool(c.constant), nil } -func (alwaysFalse) EqualsSQLString(_ *SQLGenerator, _ bool, _ Node) (string, error) { - return "false", nil +func (c constBoolean) EqualsSQLString(_ *SQLGenerator, _ bool, _ Node) (string, error) { + return strconv.FormatBool(c.constant), nil }