diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 93adedf9e87b3..49ca69ddd1571 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -18,6 +18,8 @@ import ( // New returns an in-memory fake of the database. func New() database.Store { return &fakeQuerier{ + mutex: newReentrantLock(), + apiKeys: make([]database.APIKey, 0), organizationMembers: make([]database.OrganizationMember, 0), organizations: make([]database.Organization, 0), @@ -43,7 +45,7 @@ func New() database.Store { // fakeQuerier replicates database functionality to enable quick testing. type fakeQuerier struct { - mutex sync.RWMutex + mutex sync.Locker // Legacy tables apiKeys []database.APIKey @@ -73,6 +75,8 @@ type fakeQuerier struct { // InTx doesn't rollback data properly for in-memory yet. func (q *fakeQuerier) InTx(fn func(database.Store) error) error { + q.mutex.Lock() + defer q.mutex.Unlock() return fn(q) } @@ -133,8 +137,8 @@ func (q *fakeQuerier) DeleteParameterValueByID(_ context.Context, id uuid.UUID) } func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, apiKey := range q.apiKeys { if apiKey.ID == id { @@ -145,8 +149,8 @@ func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIK } func (q *fakeQuerier) GetAPIKeysLastUsedAfter(_ context.Context, after time.Time) ([]database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() apiKeys := make([]database.APIKey, 0) for _, key := range q.apiKeys { @@ -173,8 +177,8 @@ func (q *fakeQuerier) DeleteAPIKeyByID(_ context.Context, id string) error { } func (q *fakeQuerier) GetFileByHash(_ context.Context, hash string) (database.File, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, file := range q.files { if file.Hash == hash { @@ -185,8 +189,8 @@ func (q *fakeQuerier) GetFileByHash(_ context.Context, hash string) (database.Fi } func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, user := range q.users { if user.Email == arg.Email || user.Username == arg.Username { @@ -197,8 +201,8 @@ func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.G } func (q *fakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, user := range q.users { if user.ID == id { @@ -209,15 +213,15 @@ func (q *fakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.Use } func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() return int64(len(q.users)), nil } func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() // Avoid side-effect of sorting. users := make([]database.User, len(q.users)) @@ -294,8 +298,8 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams } func (q *fakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() users := make([]database.User, 0) for _, user := range q.users { @@ -310,8 +314,8 @@ func (q *fakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]datab } func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var user *database.User roles := make([]string, 0) @@ -345,8 +349,8 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U } func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaces := make([]database.Workspace, 0) for _, workspace := range q.workspaces { @@ -395,8 +399,8 @@ func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspace } func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, workspace := range q.workspaces { if workspace.ID.String() == id.String() { @@ -407,8 +411,8 @@ func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (databas } func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var found *database.Workspace for _, workspace := range q.workspaces { @@ -435,8 +439,8 @@ func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg databa } func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() apps := make([]database.WorkspaceApp, 0) for _, app := range q.workspaceApps { @@ -451,8 +455,8 @@ func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) } func (q *fakeQuerier) GetWorkspaceAppsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() apps := make([]database.WorkspaceApp, 0) for _, app := range q.workspaceApps { @@ -464,8 +468,8 @@ func (q *fakeQuerier) GetWorkspaceAppsCreatedAfter(_ context.Context, after time } func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() apps := make([]database.WorkspaceApp, 0) for _, app := range q.workspaceApps { @@ -480,8 +484,8 @@ func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.U } func (q *fakeQuerier) GetWorkspacesAutostart(_ context.Context) ([]database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaces := make([]database.Workspace, 0) for _, ws := range q.workspaces { if ws.AutostartSchedule.String != "" { @@ -494,8 +498,8 @@ func (q *fakeQuerier) GetWorkspacesAutostart(_ context.Context) ([]database.Work } func (q *fakeQuerier) GetWorkspaceOwnerCountsByTemplateIDs(_ context.Context, templateIDs []uuid.UUID) ([]database.GetWorkspaceOwnerCountsByTemplateIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() counts := map[uuid.UUID]map[uuid.UUID]struct{}{} for _, templateID := range templateIDs { @@ -534,8 +538,8 @@ func (q *fakeQuerier) GetWorkspaceOwnerCountsByTemplateIDs(_ context.Context, te } func (q *fakeQuerier) GetWorkspaceBuildByID(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, history := range q.workspaceBuilds { if history.ID.String() == id.String() { @@ -546,8 +550,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByID(_ context.Context, id uuid.UUID) (da } func (q *fakeQuerier) GetWorkspaceBuildByJobID(_ context.Context, jobID uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, build := range q.workspaceBuilds { if build.JobID.String() == jobID.String() { @@ -558,8 +562,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByJobID(_ context.Context, jobID uuid.UUI } func (q *fakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(_ context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var row database.WorkspaceBuild var buildNum int32 @@ -576,8 +580,8 @@ func (q *fakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(_ context.Context, wo } func (q *fakeQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() builds := make(map[uuid.UUID]database.WorkspaceBuild) buildNumbers := make(map[uuid.UUID]int32) @@ -604,8 +608,8 @@ func (q *fakeQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(_ context.Context, func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceID(_ context.Context, params database.GetWorkspaceBuildByWorkspaceIDParams) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() history := make([]database.WorkspaceBuild, 0) for _, workspaceBuild := range q.workspaceBuilds { @@ -658,8 +662,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceID(_ context.Context, } func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndName(_ context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndNameParams) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, workspaceBuild := range q.workspaceBuilds { if workspaceBuild.WorkspaceID.String() != arg.WorkspaceID.String() { @@ -674,8 +678,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndName(_ context.Context, a } func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(_ context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, workspaceBuild := range q.workspaceBuilds { if workspaceBuild.WorkspaceID.String() != arg.WorkspaceID.String() { @@ -690,8 +694,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(_ context.Con } func (q *fakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaceBuilds := make([]database.WorkspaceBuild, 0) for _, workspaceBuild := range q.workspaceBuilds { @@ -703,8 +707,8 @@ func (q *fakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after ti } func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() if len(q.organizations) == 0 { return nil, sql.ErrNoRows @@ -713,8 +717,8 @@ func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organizati } func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, organization := range q.organizations { if organization.ID == id { @@ -725,8 +729,8 @@ func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (data } func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, organization := range q.organizations { if organization.Name == name { @@ -737,8 +741,8 @@ func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (dat } func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UUID) ([]database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() organizations := make([]database.Organization, 0) for _, organizationMember := range q.organizationMembers { @@ -759,8 +763,8 @@ func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UU } func (q *fakeQuerier) ParameterValues(_ context.Context, arg database.ParameterValuesParams) ([]database.ParameterValue, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() parameterValues := make([]database.ParameterValue, 0) for _, parameterValue := range q.parameterValues { @@ -789,8 +793,8 @@ func (q *fakeQuerier) ParameterValues(_ context.Context, arg database.ParameterV } func (q *fakeQuerier) GetTemplateByID(_ context.Context, id uuid.UUID) (database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, template := range q.templates { if template.ID.String() == id.String() { @@ -801,8 +805,8 @@ func (q *fakeQuerier) GetTemplateByID(_ context.Context, id uuid.UUID) (database } func (q *fakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg database.GetTemplateByOrganizationAndNameParams) (database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, template := range q.templates { if template.OrganizationID != arg.OrganizationID { @@ -820,8 +824,8 @@ func (q *fakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg da } func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.UpdateTemplateMetaByIDParams) error { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for idx, tpl := range q.templates { if tpl.ID != arg.ID { @@ -839,8 +843,8 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd } func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var templates []database.Template for _, template := range q.templates { @@ -877,8 +881,8 @@ func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.Get } func (q *fakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg database.GetTemplateVersionsByTemplateIDParams) (version []database.TemplateVersion, err error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, templateVersion := range q.templateVersions { if templateVersion.TemplateID.UUID.String() != arg.TemplateID.String() { @@ -936,8 +940,8 @@ func (q *fakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg dat } func (q *fakeQuerier) GetTemplateVersionsCreatedAfter(_ context.Context, after time.Time) ([]database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() versions := make([]database.TemplateVersion, 0) for _, version := range q.templateVersions { @@ -949,8 +953,8 @@ func (q *fakeQuerier) GetTemplateVersionsCreatedAfter(_ context.Context, after t } func (q *fakeQuerier) GetTemplateVersionByTemplateIDAndName(_ context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, templateVersion := range q.templateVersions { if templateVersion.TemplateID != arg.TemplateID { @@ -965,8 +969,8 @@ func (q *fakeQuerier) GetTemplateVersionByTemplateIDAndName(_ context.Context, a } func (q *fakeQuerier) GetTemplateVersionByID(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, templateVersion := range q.templateVersions { if templateVersion.ID.String() != templateVersionID.String() { @@ -978,8 +982,8 @@ func (q *fakeQuerier) GetTemplateVersionByID(_ context.Context, templateVersionI } func (q *fakeQuerier) GetTemplateVersionByJobID(_ context.Context, jobID uuid.UUID) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, templateVersion := range q.templateVersions { if templateVersion.JobID.String() != jobID.String() { @@ -991,8 +995,8 @@ func (q *fakeQuerier) GetTemplateVersionByJobID(_ context.Context, jobID uuid.UU } func (q *fakeQuerier) GetParameterSchemasByJobID(_ context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() parameters := make([]database.ParameterSchema, 0) for _, parameterSchema := range q.parameterSchemas { @@ -1008,8 +1012,8 @@ func (q *fakeQuerier) GetParameterSchemasByJobID(_ context.Context, jobID uuid.U } func (q *fakeQuerier) GetParameterSchemasCreatedAfter(_ context.Context, after time.Time) ([]database.ParameterSchema, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() parameters := make([]database.ParameterSchema, 0) for _, parameterSchema := range q.parameterSchemas { @@ -1021,8 +1025,8 @@ func (q *fakeQuerier) GetParameterSchemasCreatedAfter(_ context.Context, after t } func (q *fakeQuerier) GetParameterValueByScopeAndName(_ context.Context, arg database.GetParameterValueByScopeAndNameParams) (database.ParameterValue, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, parameterValue := range q.parameterValues { if parameterValue.Scope != arg.Scope { @@ -1040,15 +1044,15 @@ func (q *fakeQuerier) GetParameterValueByScopeAndName(_ context.Context, arg dat } func (q *fakeQuerier) GetTemplates(_ context.Context) ([]database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() return q.templates[:], nil } func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, organizationMember := range q.organizationMembers { if organizationMember.OrganizationID != arg.OrganizationID { @@ -1063,8 +1067,8 @@ func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg datab } func (q *fakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() getOrganizationIDsByMemberIDRows := make([]database.GetOrganizationIDsByMemberIDsRow, 0, len(ids)) for _, userID := range ids { @@ -1086,8 +1090,8 @@ func (q *fakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uui } func (q *fakeQuerier) GetOrganizationMembershipsByUserID(_ context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var memberships []database.OrganizationMember for _, organizationMember := range q.organizationMembers { @@ -1123,8 +1127,8 @@ func (q *fakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMe } func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.ProvisionerDaemon, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() if len(q.provisionerDaemons) == 0 { return nil, sql.ErrNoRows @@ -1133,8 +1137,8 @@ func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi } func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() // The schema sorts this by created at, so we iterate the array backwards. for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- { @@ -1147,8 +1151,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken } func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() // The schema sorts this by created at, so we iterate the array backwards. for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- { @@ -1161,8 +1165,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (da } func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceID string) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() // The schema sorts this by created at, so we iterate the array backwards. for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- { @@ -1175,8 +1179,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI } func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaceAgents := make([]database.WorkspaceAgent, 0) for _, agent := range q.provisionerJobAgents { @@ -1194,8 +1198,8 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourc } func (q *fakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaceAgents := make([]database.WorkspaceAgent, 0) for _, agent := range q.provisionerJobAgents { @@ -1207,8 +1211,8 @@ func (q *fakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after ti } func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndName(_ context.Context, arg database.GetWorkspaceAppByAgentIDAndNameParams) (database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, app := range q.workspaceApps { if app.AgentID != arg.AgentID { @@ -1223,8 +1227,8 @@ func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndName(_ context.Context, arg dat } func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) (database.ProvisionerDaemon, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, provisionerDaemon := range q.provisionerDaemons { if provisionerDaemon.ID.String() != id.String() { @@ -1236,8 +1240,8 @@ func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) } func (q *fakeQuerier) GetProvisionerJobByID(_ context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, provisionerJob := range q.provisionerJobs { if provisionerJob.ID.String() != id.String() { @@ -1249,8 +1253,8 @@ func (q *fakeQuerier) GetProvisionerJobByID(_ context.Context, id uuid.UUID) (da } func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) (database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, resource := range q.provisionerJobResources { if resource.ID.String() == id.String() { @@ -1261,8 +1265,8 @@ func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) } func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() resources := make([]database.WorkspaceResource, 0) for _, resource := range q.provisionerJobResources { @@ -1278,8 +1282,8 @@ func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid } func (q *fakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() resources := make([]database.WorkspaceResource, 0) for _, resource := range q.provisionerJobResources { @@ -1291,8 +1295,8 @@ func (q *fakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after } func (q *fakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() jobs := make([]database.ProvisionerJob, 0) for _, job := range q.provisionerJobs { @@ -1311,8 +1315,8 @@ func (q *fakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID } func (q *fakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after time.Time) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() jobs := make([]database.ProvisionerJob, 0) for _, job := range q.provisionerJobs { @@ -1324,8 +1328,8 @@ func (q *fakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after ti } func (q *fakeQuerier) GetProvisionerLogsByIDBetween(_ context.Context, arg database.GetProvisionerLogsByIDBetweenParams) ([]database.ProvisionerJobLog, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() logs := make([]database.ProvisionerJobLog, 0) for _, jobLog := range q.provisionerJobLogs { @@ -2011,8 +2015,8 @@ func (q *fakeQuerier) InsertGitSSHKey(_ context.Context, arg database.InsertGitS } func (q *fakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, key := range q.gitSSHKey { if key.UserID == userID { @@ -2055,8 +2059,8 @@ func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error } func (q *fakeQuerier) GetAuditLogsBefore(_ context.Context, arg database.GetAuditLogsBeforeParams) ([]database.AuditLog, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() logs := make([]database.AuditLog, 0) start := database.AuditLog{} @@ -2114,8 +2118,8 @@ func (q *fakeQuerier) InsertDeploymentID(_ context.Context, id string) error { } func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() return q.deploymentID, nil } diff --git a/coderd/database/databasefake/reentrantlock.go b/coderd/database/databasefake/reentrantlock.go new file mode 100644 index 0000000000000..97f31754df87f --- /dev/null +++ b/coderd/database/databasefake/reentrantlock.go @@ -0,0 +1,98 @@ +package databasefake + +import ( + "bytes" + "runtime" + "strconv" + "sync" +) + +// reentrantLock is a lock that can be locked multiple times from the same goroutine. +// +// Go maintainers insist that this is a Bad Idea and refuse to implement it in the standard library, so let's talk about +// why we're doing it here. +// +// We want to support locking the fake database for the duration of a transaction, so that other goroutines cannot see +// uncommitted transactions. However, we also need to lock the database during queries that are not explicitly in a +// transaction. When a goroutine executing a transaction calls a query, it is already holding the lock, so attempting +// to lock a standard mutex again will deadlock. A reentrant lock neatly solves this problem. +// +// The argument I've heard around why reentrant locks are a Bad Idea points out that it indicates a problem with your +// interface, because some methods must leave the database in an inconsistent state. That criticism applies here +// because the whole reason we need transactions is sometimes individual queries leave the database in an inconsistent +// state. However valid the criticism, the assumption that it becomes the most important factor in all cases is flawed +// and, frankly, patronizing. +// +// Here we do not have the luxury of reinventing the interface, because this fake database is attempting to +// emulate another piece of software which does have this interface: postgres. Basically, the logic that enforces the +// consistency of the database resides at a higher layer than we are emulating. +// +// Some alternatives considered, but rejected: +// +// 1. create an explicit transaction type, which are serialized to a channel and then processed in order. +// * requires implementing each query function twice, once wrapping it in a transaction, and once doing the real +// work +// * cannot support recursive transactions +// 2. store whether we're in a transaction in the Context passed to the queries. +// * changes InTx(func(store) error) -> InTx(ctx, func(ctx2, store) error). Inside the transaction function, +// callers **must use** ctx2 to query. Use of other contexts, like ctx, will deadlock. Adding this tripmine +// to every use of transactions seems like a recipe for bugs that are hard to diagnose (and only show up with the +// fake database). +type reentrantLock struct { + c *sync.Cond + holder uint64 + n uint64 +} + +// getGID returns the goroutine ID. +// +// From https://blog.sgmansfield.com/2015/12/goroutine-ids/ +func getGID() uint64 { + b := make([]byte, 64) + b = b[:runtime.Stack(b, false)] + b = bytes.TrimPrefix(b, []byte("goroutine ")) + b = b[:bytes.IndexByte(b, ' ')] + n, _ := strconv.ParseUint(string(b), 10, 64) + return n +} + +func newReentrantLock() sync.Locker { + return &reentrantLock{ + c: sync.NewCond(&sync.Mutex{}), + holder: 0, + n: 0, + } +} + +func (r *reentrantLock) Lock() { + gID := getGID() + r.c.L.Lock() + defer r.c.L.Unlock() + for { + if r.holder == 0 { + // not held by any goroutine + r.holder = gID + break + } + if r.holder == gID { + // held by us + break + } + r.c.Wait() + } + r.n++ +} + +func (r *reentrantLock) Unlock() { + gID := getGID() + r.c.L.Lock() + defer r.c.L.Unlock() + if r.holder != gID { + panic("unlocked without holding lock") + } + r.n-- + if r.n == 0 { + r.holder = 0 + r.c.Signal() + } +} diff --git a/go.sum b/go.sum index c05c3275c9b28..f3fb10a21b0db 100644 --- a/go.sum +++ b/go.sum @@ -1958,6 +1958,7 @@ golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200107162124-548cf772de50/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=