From 9e830f3b722fcc6ecae12439079a854ac9631a3d Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 17 Jun 2022 10:47:31 -0700 Subject: [PATCH] fake database locks during transactions Signed-off-by: Spike Curtis --- coderd/database/databasefake/databasefake.go | 262 +++++++++--------- coderd/database/databasefake/reentrantlock.go | 98 +++++++ go.sum | 1 + 3 files changed, 232 insertions(+), 129 deletions(-) create mode 100644 coderd/database/databasefake/reentrantlock.go diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index e41c70b874ed8..443bf89259c10 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -17,6 +17,8 @@ import ( // New returns an in-memory fake of the database. func New() database.Store { return &fakeQuerier{ + mutex: newReentrantLock(), + apiKeys: make([]database.APIKey, 0), organizationMembers: make([]database.OrganizationMember, 0), organizations: make([]database.Organization, 0), @@ -42,7 +44,7 @@ func New() database.Store { // fakeQuerier replicates database functionality to enable quick testing. type fakeQuerier struct { - mutex sync.RWMutex + mutex sync.Locker // Legacy tables apiKeys []database.APIKey @@ -72,6 +74,8 @@ type fakeQuerier struct { // InTx doesn't rollback data properly for in-memory yet. func (q *fakeQuerier) InTx(fn func(database.Store) error) error { + q.mutex.Lock() + defer q.mutex.Unlock() return fn(q) } @@ -119,8 +123,8 @@ func (q *fakeQuerier) DeleteParameterValueByID(_ context.Context, id uuid.UUID) } func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, apiKey := range q.apiKeys { if apiKey.ID == id { @@ -131,8 +135,8 @@ func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIK } func (q *fakeQuerier) GetAPIKeysLastUsedAfter(_ context.Context, after time.Time) ([]database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() apiKeys := make([]database.APIKey, 0) for _, key := range q.apiKeys { @@ -159,8 +163,8 @@ func (q *fakeQuerier) DeleteAPIKeyByID(_ context.Context, id string) error { } func (q *fakeQuerier) GetFileByHash(_ context.Context, hash string) (database.File, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, file := range q.files { if file.Hash == hash { @@ -171,8 +175,8 @@ func (q *fakeQuerier) GetFileByHash(_ context.Context, hash string) (database.Fi } func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, user := range q.users { if user.Email == arg.Email || user.Username == arg.Username { @@ -183,8 +187,8 @@ func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.G } func (q *fakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, user := range q.users { if user.ID == id { @@ -195,15 +199,15 @@ func (q *fakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.Use } func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() return int64(len(q.users)), nil } func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() // Avoid side-effect of sorting. users := make([]database.User, len(q.users)) @@ -280,8 +284,8 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams } func (q *fakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() users := make([]database.User, 0) for _, user := range q.users { @@ -296,8 +300,8 @@ func (q *fakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]datab } func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var user *database.User roles := make([]string, 0) @@ -331,8 +335,8 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U } func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaces := make([]database.Workspace, 0) for _, workspace := range q.workspaces { @@ -381,8 +385,8 @@ func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspace } func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, workspace := range q.workspaces { if workspace.ID.String() == id.String() { @@ -393,8 +397,8 @@ func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (databas } func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var found *database.Workspace for _, workspace := range q.workspaces { @@ -421,8 +425,8 @@ func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg databa } func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() apps := make([]database.WorkspaceApp, 0) for _, app := range q.workspaceApps { @@ -437,8 +441,8 @@ func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) } func (q *fakeQuerier) GetWorkspaceAppsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() apps := make([]database.WorkspaceApp, 0) for _, app := range q.workspaceApps { @@ -450,8 +454,8 @@ func (q *fakeQuerier) GetWorkspaceAppsCreatedAfter(_ context.Context, after time } func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() apps := make([]database.WorkspaceApp, 0) for _, app := range q.workspaceApps { @@ -466,8 +470,8 @@ func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.U } func (q *fakeQuerier) GetWorkspacesAutostart(_ context.Context) ([]database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaces := make([]database.Workspace, 0) for _, ws := range q.workspaces { if ws.AutostartSchedule.String != "" { @@ -480,8 +484,8 @@ func (q *fakeQuerier) GetWorkspacesAutostart(_ context.Context) ([]database.Work } func (q *fakeQuerier) GetWorkspaceOwnerCountsByTemplateIDs(_ context.Context, templateIDs []uuid.UUID) ([]database.GetWorkspaceOwnerCountsByTemplateIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() counts := map[uuid.UUID]map[uuid.UUID]struct{}{} for _, templateID := range templateIDs { @@ -520,8 +524,8 @@ func (q *fakeQuerier) GetWorkspaceOwnerCountsByTemplateIDs(_ context.Context, te } func (q *fakeQuerier) GetWorkspaceBuildByID(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, history := range q.workspaceBuilds { if history.ID.String() == id.String() { @@ -532,8 +536,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByID(_ context.Context, id uuid.UUID) (da } func (q *fakeQuerier) GetWorkspaceBuildByJobID(_ context.Context, jobID uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, build := range q.workspaceBuilds { if build.JobID.String() == jobID.String() { @@ -544,8 +548,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByJobID(_ context.Context, jobID uuid.UUI } func (q *fakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(_ context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var row database.WorkspaceBuild var buildNum int32 @@ -562,8 +566,8 @@ func (q *fakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(_ context.Context, wo } func (q *fakeQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() builds := make(map[uuid.UUID]database.WorkspaceBuild) buildNumbers := make(map[uuid.UUID]int32) @@ -590,8 +594,8 @@ func (q *fakeQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(_ context.Context, func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceID(_ context.Context, params database.GetWorkspaceBuildByWorkspaceIDParams) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() history := make([]database.WorkspaceBuild, 0) for _, workspaceBuild := range q.workspaceBuilds { @@ -644,8 +648,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceID(_ context.Context, } func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndName(_ context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndNameParams) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, workspaceBuild := range q.workspaceBuilds { if workspaceBuild.WorkspaceID.String() != arg.WorkspaceID.String() { @@ -660,8 +664,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndName(_ context.Context, a } func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(_ context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, workspaceBuild := range q.workspaceBuilds { if workspaceBuild.WorkspaceID.String() != arg.WorkspaceID.String() { @@ -676,8 +680,8 @@ func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(_ context.Con } func (q *fakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaceBuilds := make([]database.WorkspaceBuild, 0) for _, workspaceBuild := range q.workspaceBuilds { @@ -689,8 +693,8 @@ func (q *fakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after ti } func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() if len(q.organizations) == 0 { return nil, sql.ErrNoRows @@ -699,8 +703,8 @@ func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organizati } func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, organization := range q.organizations { if organization.ID == id { @@ -711,8 +715,8 @@ func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (data } func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, organization := range q.organizations { if organization.Name == name { @@ -723,8 +727,8 @@ func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (dat } func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UUID) ([]database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() organizations := make([]database.Organization, 0) for _, organizationMember := range q.organizationMembers { @@ -745,8 +749,8 @@ func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UU } func (q *fakeQuerier) GetParameterValuesByScope(_ context.Context, arg database.GetParameterValuesByScopeParams) ([]database.ParameterValue, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() parameterValues := make([]database.ParameterValue, 0) for _, parameterValue := range q.parameterValues { @@ -765,8 +769,8 @@ func (q *fakeQuerier) GetParameterValuesByScope(_ context.Context, arg database. } func (q *fakeQuerier) GetTemplateByID(_ context.Context, id uuid.UUID) (database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, template := range q.templates { if template.ID.String() == id.String() { @@ -777,8 +781,8 @@ func (q *fakeQuerier) GetTemplateByID(_ context.Context, id uuid.UUID) (database } func (q *fakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg database.GetTemplateByOrganizationAndNameParams) (database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, template := range q.templates { if template.OrganizationID != arg.OrganizationID { @@ -796,8 +800,8 @@ func (q *fakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg da } func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.UpdateTemplateMetaByIDParams) error { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for idx, tpl := range q.templates { if tpl.ID != arg.ID { @@ -815,8 +819,8 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd } func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var templates []database.Template for _, template := range q.templates { @@ -853,8 +857,8 @@ func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.Get } func (q *fakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg database.GetTemplateVersionsByTemplateIDParams) (version []database.TemplateVersion, err error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, templateVersion := range q.templateVersions { if templateVersion.TemplateID.UUID.String() != arg.TemplateID.String() { @@ -912,8 +916,8 @@ func (q *fakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg dat } func (q *fakeQuerier) GetTemplateVersionsCreatedAfter(_ context.Context, after time.Time) ([]database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() versions := make([]database.TemplateVersion, 0) for _, version := range q.templateVersions { @@ -925,8 +929,8 @@ func (q *fakeQuerier) GetTemplateVersionsCreatedAfter(_ context.Context, after t } func (q *fakeQuerier) GetTemplateVersionByTemplateIDAndName(_ context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, templateVersion := range q.templateVersions { if templateVersion.TemplateID != arg.TemplateID { @@ -941,8 +945,8 @@ func (q *fakeQuerier) GetTemplateVersionByTemplateIDAndName(_ context.Context, a } func (q *fakeQuerier) GetTemplateVersionByID(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, templateVersion := range q.templateVersions { if templateVersion.ID.String() != templateVersionID.String() { @@ -954,8 +958,8 @@ func (q *fakeQuerier) GetTemplateVersionByID(_ context.Context, templateVersionI } func (q *fakeQuerier) GetTemplateVersionByJobID(_ context.Context, jobID uuid.UUID) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, templateVersion := range q.templateVersions { if templateVersion.JobID.String() != jobID.String() { @@ -967,8 +971,8 @@ func (q *fakeQuerier) GetTemplateVersionByJobID(_ context.Context, jobID uuid.UU } func (q *fakeQuerier) GetParameterSchemasByJobID(_ context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() parameters := make([]database.ParameterSchema, 0) for _, parameterSchema := range q.parameterSchemas { @@ -984,8 +988,8 @@ func (q *fakeQuerier) GetParameterSchemasByJobID(_ context.Context, jobID uuid.U } func (q *fakeQuerier) GetParameterSchemasCreatedAfter(_ context.Context, after time.Time) ([]database.ParameterSchema, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() parameters := make([]database.ParameterSchema, 0) for _, parameterSchema := range q.parameterSchemas { @@ -997,8 +1001,8 @@ func (q *fakeQuerier) GetParameterSchemasCreatedAfter(_ context.Context, after t } func (q *fakeQuerier) GetParameterValueByScopeAndName(_ context.Context, arg database.GetParameterValueByScopeAndNameParams) (database.ParameterValue, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, parameterValue := range q.parameterValues { if parameterValue.Scope != arg.Scope { @@ -1016,15 +1020,15 @@ func (q *fakeQuerier) GetParameterValueByScopeAndName(_ context.Context, arg dat } func (q *fakeQuerier) GetTemplates(_ context.Context) ([]database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() return q.templates[:], nil } func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, organizationMember := range q.organizationMembers { if organizationMember.OrganizationID != arg.OrganizationID { @@ -1039,8 +1043,8 @@ func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg datab } func (q *fakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() getOrganizationIDsByMemberIDRows := make([]database.GetOrganizationIDsByMemberIDsRow, 0, len(ids)) for _, userID := range ids { @@ -1062,8 +1066,8 @@ func (q *fakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uui } func (q *fakeQuerier) GetOrganizationMembershipsByUserID(_ context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() var memberships []database.OrganizationMember for _, organizationMember := range q.organizationMembers { @@ -1099,8 +1103,8 @@ func (q *fakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMe } func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.ProvisionerDaemon, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() if len(q.provisionerDaemons) == 0 { return nil, sql.ErrNoRows @@ -1109,8 +1113,8 @@ func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi } func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() // The schema sorts this by created at, so we iterate the array backwards. for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- { @@ -1123,8 +1127,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken } func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() // The schema sorts this by created at, so we iterate the array backwards. for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- { @@ -1137,8 +1141,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (da } func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceID string) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() // The schema sorts this by created at, so we iterate the array backwards. for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- { @@ -1151,8 +1155,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI } func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaceAgents := make([]database.WorkspaceAgent, 0) for _, agent := range q.provisionerJobAgents { @@ -1170,8 +1174,8 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourc } func (q *fakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() workspaceAgents := make([]database.WorkspaceAgent, 0) for _, agent := range q.provisionerJobAgents { @@ -1183,8 +1187,8 @@ func (q *fakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after ti } func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndName(_ context.Context, arg database.GetWorkspaceAppByAgentIDAndNameParams) (database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, app := range q.workspaceApps { if app.AgentID != arg.AgentID { @@ -1199,8 +1203,8 @@ func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndName(_ context.Context, arg dat } func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) (database.ProvisionerDaemon, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, provisionerDaemon := range q.provisionerDaemons { if provisionerDaemon.ID.String() != id.String() { @@ -1212,8 +1216,8 @@ func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) } func (q *fakeQuerier) GetProvisionerJobByID(_ context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, provisionerJob := range q.provisionerJobs { if provisionerJob.ID.String() != id.String() { @@ -1225,8 +1229,8 @@ func (q *fakeQuerier) GetProvisionerJobByID(_ context.Context, id uuid.UUID) (da } func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) (database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, resource := range q.provisionerJobResources { if resource.ID.String() == id.String() { @@ -1237,8 +1241,8 @@ func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) } func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() resources := make([]database.WorkspaceResource, 0) for _, resource := range q.provisionerJobResources { @@ -1254,8 +1258,8 @@ func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid } func (q *fakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() resources := make([]database.WorkspaceResource, 0) for _, resource := range q.provisionerJobResources { @@ -1267,8 +1271,8 @@ func (q *fakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after } func (q *fakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() jobs := make([]database.ProvisionerJob, 0) for _, job := range q.provisionerJobs { @@ -1287,8 +1291,8 @@ func (q *fakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID } func (q *fakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after time.Time) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() jobs := make([]database.ProvisionerJob, 0) for _, job := range q.provisionerJobs { @@ -1300,8 +1304,8 @@ func (q *fakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after ti } func (q *fakeQuerier) GetProvisionerLogsByIDBetween(_ context.Context, arg database.GetProvisionerLogsByIDBetweenParams) ([]database.ProvisionerJobLog, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() logs := make([]database.ProvisionerJobLog, 0) for _, jobLog := range q.provisionerJobLogs { @@ -1986,8 +1990,8 @@ func (q *fakeQuerier) InsertGitSSHKey(_ context.Context, arg database.InsertGitS } func (q *fakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, key := range q.gitSSHKey { if key.UserID == userID { @@ -2030,8 +2034,8 @@ func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error } func (q *fakeQuerier) GetAuditLogsBefore(_ context.Context, arg database.GetAuditLogsBeforeParams) ([]database.AuditLog, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() logs := make([]database.AuditLog, 0) start := database.AuditLog{} @@ -2089,8 +2093,8 @@ func (q *fakeQuerier) InsertDeploymentID(_ context.Context, id string) error { } func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() return q.deploymentID, nil } diff --git a/coderd/database/databasefake/reentrantlock.go b/coderd/database/databasefake/reentrantlock.go new file mode 100644 index 0000000000000..97f31754df87f --- /dev/null +++ b/coderd/database/databasefake/reentrantlock.go @@ -0,0 +1,98 @@ +package databasefake + +import ( + "bytes" + "runtime" + "strconv" + "sync" +) + +// reentrantLock is a lock that can be locked multiple times from the same goroutine. +// +// Go maintainers insist that this is a Bad Idea and refuse to implement it in the standard library, so let's talk about +// why we're doing it here. +// +// We want to support locking the fake database for the duration of a transaction, so that other goroutines cannot see +// uncommitted transactions. However, we also need to lock the database during queries that are not explicitly in a +// transaction. When a goroutine executing a transaction calls a query, it is already holding the lock, so attempting +// to lock a standard mutex again will deadlock. A reentrant lock neatly solves this problem. +// +// The argument I've heard around why reentrant locks are a Bad Idea points out that it indicates a problem with your +// interface, because some methods must leave the database in an inconsistent state. That criticism applies here +// because the whole reason we need transactions is sometimes individual queries leave the database in an inconsistent +// state. However valid the criticism, the assumption that it becomes the most important factor in all cases is flawed +// and, frankly, patronizing. +// +// Here we do not have the luxury of reinventing the interface, because this fake database is attempting to +// emulate another piece of software which does have this interface: postgres. Basically, the logic that enforces the +// consistency of the database resides at a higher layer than we are emulating. +// +// Some alternatives considered, but rejected: +// +// 1. create an explicit transaction type, which are serialized to a channel and then processed in order. +// * requires implementing each query function twice, once wrapping it in a transaction, and once doing the real +// work +// * cannot support recursive transactions +// 2. store whether we're in a transaction in the Context passed to the queries. +// * changes InTx(func(store) error) -> InTx(ctx, func(ctx2, store) error). Inside the transaction function, +// callers **must use** ctx2 to query. Use of other contexts, like ctx, will deadlock. Adding this tripmine +// to every use of transactions seems like a recipe for bugs that are hard to diagnose (and only show up with the +// fake database). +type reentrantLock struct { + c *sync.Cond + holder uint64 + n uint64 +} + +// getGID returns the goroutine ID. +// +// From https://blog.sgmansfield.com/2015/12/goroutine-ids/ +func getGID() uint64 { + b := make([]byte, 64) + b = b[:runtime.Stack(b, false)] + b = bytes.TrimPrefix(b, []byte("goroutine ")) + b = b[:bytes.IndexByte(b, ' ')] + n, _ := strconv.ParseUint(string(b), 10, 64) + return n +} + +func newReentrantLock() sync.Locker { + return &reentrantLock{ + c: sync.NewCond(&sync.Mutex{}), + holder: 0, + n: 0, + } +} + +func (r *reentrantLock) Lock() { + gID := getGID() + r.c.L.Lock() + defer r.c.L.Unlock() + for { + if r.holder == 0 { + // not held by any goroutine + r.holder = gID + break + } + if r.holder == gID { + // held by us + break + } + r.c.Wait() + } + r.n++ +} + +func (r *reentrantLock) Unlock() { + gID := getGID() + r.c.L.Lock() + defer r.c.L.Unlock() + if r.holder != gID { + panic("unlocked without holding lock") + } + r.n-- + if r.n == 0 { + r.holder = 0 + r.c.Signal() + } +} diff --git a/go.sum b/go.sum index e30eaae0b8da9..a0daa30502eff 100644 --- a/go.sum +++ b/go.sum @@ -1956,6 +1956,7 @@ golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200107162124-548cf772de50/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=