Skip to content

Commit fad4574

Browse files
authored
fix: copy StringMap on insert and query in dbmem (#11206)
Addresses the issue in #11185 for the StringMap datatype. There are other slice data types in our database package that also need to be fixed, but that'll be a different PR
1 parent 32c93a8 commit fad4574

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

coderd/database/dbmem/dbmem.go

+36-4
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,10 @@ func (q *FakeQuerier) getProvisionerJobByIDNoLock(_ context.Context, id uuid.UUI
580580
if provisionerJob.ID != id {
581581
continue
582582
}
583+
// clone the Tags before returning, since maps are reference types and
584+
// we don't want the caller to be able to mutate the map we have inside
585+
// dbmem!
586+
provisionerJob.Tags = maps.Clone(provisionerJob.Tags)
583587
return provisionerJob, nil
584588
}
585589
return database.ProvisionerJob{}, sql.ErrNoRows
@@ -779,6 +783,10 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
779783
provisionerJob.WorkerID = arg.WorkerID
780784
provisionerJob.JobStatus = provisonerJobStatus(provisionerJob)
781785
q.provisionerJobs[index] = provisionerJob
786+
// clone the Tags before returning, since maps are reference types and
787+
// we don't want the caller to be able to mutate the map we have inside
788+
// dbmem!
789+
provisionerJob.Tags = maps.Clone(provisionerJob.Tags)
782790
return provisionerJob, nil
783791
}
784792
return database.ProvisionerJob{}, sql.ErrNoRows
@@ -1884,6 +1892,10 @@ func (q *FakeQuerier) GetHungProvisionerJobs(_ context.Context, hungSince time.T
18841892
hungJobs := []database.ProvisionerJob{}
18851893
for _, provisionerJob := range q.provisionerJobs {
18861894
if provisionerJob.StartedAt.Valid && !provisionerJob.CompletedAt.Valid && provisionerJob.UpdatedAt.Before(hungSince) {
1895+
// clone the Tags before appending, since maps are reference types and
1896+
// we don't want the caller to be able to mutate the map we have inside
1897+
// dbmem!
1898+
provisionerJob.Tags = maps.Clone(provisionerJob.Tags)
18871899
hungJobs = append(hungJobs, provisionerJob)
18881900
}
18891901
}
@@ -2191,7 +2203,15 @@ func (q *FakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi
21912203
if len(q.provisionerDaemons) == 0 {
21922204
return nil, sql.ErrNoRows
21932205
}
2194-
return q.provisionerDaemons, nil
2206+
// copy the data so that the caller can't manipulate any data inside dbmem
2207+
// after returning
2208+
out := make([]database.ProvisionerDaemon, len(q.provisionerDaemons))
2209+
copy(out, q.provisionerDaemons)
2210+
for i := range out {
2211+
// maps are reference types, so we need to clone them
2212+
out[i].Tags = maps.Clone(out[i].Tags)
2213+
}
2214+
return out, nil
21952215
}
21962216

21972217
func (q *FakeQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) {
@@ -2209,6 +2229,10 @@ func (q *FakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID
22092229
for _, job := range q.provisionerJobs {
22102230
for _, id := range ids {
22112231
if id == job.ID {
2232+
// clone the Tags before appending, since maps are reference types and
2233+
// we don't want the caller to be able to mutate the map we have inside
2234+
// dbmem!
2235+
job.Tags = maps.Clone(job.Tags)
22122236
jobs = append(jobs, job)
22132237
break
22142238
}
@@ -2230,6 +2254,10 @@ func (q *FakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(_ context.Context
22302254
for _, job := range q.provisionerJobs {
22312255
for _, id := range ids {
22322256
if id == job.ID {
2257+
// clone the Tags before appending, since maps are reference types and
2258+
// we don't want the caller to be able to mutate the map we have inside
2259+
// dbmem!
2260+
job.Tags = maps.Clone(job.Tags)
22332261
job := database.GetProvisionerJobsByIDsWithQueuePositionRow{
22342262
ProvisionerJob: job,
22352263
}
@@ -2260,6 +2288,10 @@ func (q *FakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after ti
22602288
jobs := make([]database.ProvisionerJob, 0)
22612289
for _, job := range q.provisionerJobs {
22622290
if job.CreatedAt.After(after) {
2291+
// clone the Tags before appending, since maps are reference types and
2292+
// we don't want the caller to be able to mutate the map we have inside
2293+
// dbmem!
2294+
job.Tags = maps.Clone(job.Tags)
22632295
jobs = append(jobs, job)
22642296
}
22652297
}
@@ -4969,7 +5001,7 @@ func (q *FakeQuerier) InsertProvisionerJob(_ context.Context, arg database.Inser
49695001
FileID: arg.FileID,
49705002
Type: arg.Type,
49715003
Input: arg.Input,
4972-
Tags: arg.Tags,
5004+
Tags: maps.Clone(arg.Tags),
49735005
TraceMetadata: arg.TraceMetadata,
49745006
}
49755007
job.JobStatus = provisonerJobStatus(job)
@@ -6993,7 +7025,7 @@ func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.Up
69937025
continue
69947026
}
69957027
d.Provisioners = arg.Provisioners
6996-
d.Tags = arg.Tags
7028+
d.Tags = maps.Clone(arg.Tags)
69977029
d.Version = arg.Version
69987030
d.LastSeenAt = arg.LastSeenAt
69997031
return d, nil
@@ -7004,7 +7036,7 @@ func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.Up
70047036
CreatedAt: arg.CreatedAt,
70057037
Name: arg.Name,
70067038
Provisioners: arg.Provisioners,
7007-
Tags: arg.Tags,
7039+
Tags: maps.Clone(arg.Tags),
70087040
ReplicaID: uuid.NullUUID{},
70097041
LastSeenAt: arg.LastSeenAt,
70107042
Version: arg.Version,

0 commit comments

Comments
 (0)