Skip to content

fix: copy StringMap on insert and query in dbmem #11206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,10 @@ func (q *FakeQuerier) getProvisionerJobByIDNoLock(_ context.Context, id uuid.UUI
if provisionerJob.ID != id {
continue
}
// clone the Tags before returning, since maps are reference types and
// we don't want the caller to be able to mutate the map we have inside
// dbmem!
provisionerJob.Tags = maps.Clone(provisionerJob.Tags)
return provisionerJob, nil
}
return database.ProvisionerJob{}, sql.ErrNoRows
Expand Down Expand Up @@ -779,6 +783,10 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
provisionerJob.WorkerID = arg.WorkerID
provisionerJob.JobStatus = provisonerJobStatus(provisionerJob)
q.provisionerJobs[index] = provisionerJob
// clone the Tags before returning, since maps are reference types and
// we don't want the caller to be able to mutate the map we have inside
// dbmem!
provisionerJob.Tags = maps.Clone(provisionerJob.Tags)
return provisionerJob, nil
}
return database.ProvisionerJob{}, sql.ErrNoRows
Expand Down Expand Up @@ -1884,6 +1892,10 @@ func (q *FakeQuerier) GetHungProvisionerJobs(_ context.Context, hungSince time.T
hungJobs := []database.ProvisionerJob{}
for _, provisionerJob := range q.provisionerJobs {
if provisionerJob.StartedAt.Valid && !provisionerJob.CompletedAt.Valid && provisionerJob.UpdatedAt.Before(hungSince) {
// clone the Tags before appending, since maps are reference types and
// we don't want the caller to be able to mutate the map we have inside
// dbmem!
provisionerJob.Tags = maps.Clone(provisionerJob.Tags)
hungJobs = append(hungJobs, provisionerJob)
}
}
Expand Down Expand Up @@ -2191,7 +2203,15 @@ func (q *FakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi
if len(q.provisionerDaemons) == 0 {
return nil, sql.ErrNoRows
}
return q.provisionerDaemons, nil
// copy the data so that the caller can't manipulate any data inside dbmem
// after returning
out := make([]database.ProvisionerDaemon, len(q.provisionerDaemons))
copy(out, q.provisionerDaemons)
for i := range out {
// maps are reference types, so we need to clone them
out[i].Tags = maps.Clone(out[i].Tags)
}
return out, nil
}

func (q *FakeQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) {
Expand All @@ -2209,6 +2229,10 @@ func (q *FakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID
for _, job := range q.provisionerJobs {
for _, id := range ids {
if id == job.ID {
// clone the Tags before appending, since maps are reference types and
// we don't want the caller to be able to mutate the map we have inside
// dbmem!
job.Tags = maps.Clone(job.Tags)
jobs = append(jobs, job)
break
}
Expand All @@ -2230,6 +2254,10 @@ func (q *FakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(_ context.Context
for _, job := range q.provisionerJobs {
for _, id := range ids {
if id == job.ID {
// clone the Tags before appending, since maps are reference types and
// we don't want the caller to be able to mutate the map we have inside
// dbmem!
job.Tags = maps.Clone(job.Tags)
job := database.GetProvisionerJobsByIDsWithQueuePositionRow{
ProvisionerJob: job,
}
Expand Down Expand Up @@ -2260,6 +2288,10 @@ func (q *FakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after ti
jobs := make([]database.ProvisionerJob, 0)
for _, job := range q.provisionerJobs {
if job.CreatedAt.After(after) {
// clone the Tags before appending, since maps are reference types and
// we don't want the caller to be able to mutate the map we have inside
// dbmem!
job.Tags = maps.Clone(job.Tags)
jobs = append(jobs, job)
}
}
Expand Down Expand Up @@ -4969,7 +5001,7 @@ func (q *FakeQuerier) InsertProvisionerJob(_ context.Context, arg database.Inser
FileID: arg.FileID,
Type: arg.Type,
Input: arg.Input,
Tags: arg.Tags,
Tags: maps.Clone(arg.Tags),
TraceMetadata: arg.TraceMetadata,
}
job.JobStatus = provisonerJobStatus(job)
Expand Down Expand Up @@ -6973,7 +7005,7 @@ func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.Up
continue
}
d.Provisioners = arg.Provisioners
d.Tags = arg.Tags
d.Tags = maps.Clone(arg.Tags)
d.Version = arg.Version
d.LastSeenAt = arg.LastSeenAt
return d, nil
Expand All @@ -6984,7 +7016,7 @@ func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.Up
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Provisioners: arg.Provisioners,
Tags: arg.Tags,
Tags: maps.Clone(arg.Tags),
ReplicaID: uuid.NullUUID{},
LastSeenAt: arg.LastSeenAt,
Version: arg.Version,
Expand Down