Skip to content

Commit 5876edd

Browse files
committed
workspaceusage: improve locking and tests
1 parent c99327c commit 5876edd

File tree

2 files changed

+87
-42
lines changed

2 files changed

+87
-42
lines changed

coderd/workspaceusage/tracker.go

+62-36
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,24 @@ type Store interface {
2828
// It keeps an internal map of workspace IDs that have been used and
2929
// periodically flushes this to its configured Store.
3030
type Tracker struct {
31-
log slog.Logger // you know, for logs
32-
mut sync.Mutex // protects m
33-
m map[uuid.UUID]struct{} // stores workspace ids
34-
s Store // for flushing data
35-
tickCh <-chan time.Time // controls flush interval
36-
stopTick func() // stops flushing
37-
stopCh chan struct{} // signals us to stop
38-
stopOnce sync.Once // because you only stop once
39-
doneCh chan struct{} // signifies that we have stopped
40-
flushCh chan int // used for testing.
31+
log slog.Logger // you know, for logs
32+
flushLock sync.Mutex // protects m
33+
m *uuidSet // stores workspace ids
34+
s Store // for flushing data
35+
tickCh <-chan time.Time // controls flush interval
36+
stopTick func() // stops flushing
37+
stopCh chan struct{} // signals us to stop
38+
stopOnce sync.Once // because you only stop once
39+
doneCh chan struct{} // signifies that we have stopped
40+
flushCh chan int // used for testing.
4141
}
4242

4343
// New returns a new Tracker. It is the caller's responsibility
4444
// to call Close().
4545
func New(s Store, opts ...Option) *Tracker {
4646
hb := &Tracker{
4747
log: slog.Make(sloghuman.Sink(os.Stderr)),
48-
m: make(map[uuid.UUID]struct{}, 0),
48+
m: &uuidSet{},
4949
s: s,
5050
tickCh: nil,
5151
stopTick: nil,
@@ -103,44 +103,40 @@ func WithTickChannel(c chan time.Time) Option {
103103
// Add marks the workspace with the given ID as having been used recently.
104104
// Tracker will periodically flush this to its configured Store.
105105
func (wut *Tracker) Add(workspaceID uuid.UUID) {
106-
wut.mut.Lock()
107-
wut.m[workspaceID] = struct{}{}
108-
wut.mut.Unlock()
106+
wut.m.Add(workspaceID)
109107
}
110108

111-
// flushLocked updates last_used_at of all current workspace IDs.
112-
// MUST HOLD LOCK BEFORE CALLING
113-
func (wut *Tracker) flushLocked(now time.Time) {
114-
if wut.mut.TryLock() {
115-
panic("developer error: must lock before calling flush()")
116-
}
117-
count := len(wut.m)
118-
defer func() { // only used for testing
119-
if wut.flushCh != nil {
109+
// flush updates last_used_at of all current workspace IDs.
110+
// If this is held while a previous flush is in progress, it will
111+
// deadlock until the previous flush has completed.
112+
func (wut *Tracker) flush(now time.Time) {
113+
var count int
114+
if wut.flushCh != nil { // only used for testing
115+
defer func() {
120116
wut.flushCh <- count
121-
}
122-
}()
117+
}()
118+
}
119+
120+
// Copy our current set of IDs
121+
ids := wut.m.UniqueAndClear()
122+
count = len(ids)
123123
if count == 0 {
124124
wut.log.Debug(context.Background(), "nothing to flush")
125125
return
126126
}
127-
// Copy our current set of IDs
128-
ids := make([]uuid.UUID, 0)
129-
for k := range wut.m {
130-
ids = append(ids, k)
131-
}
132-
// Reset our internal map
133-
wut.m = make(map[uuid.UUID]struct{})
127+
134128
// For ease of testing, sort the IDs lexically
135129
sort.Slice(ids, func(i, j int) bool {
136130
// For some unfathomable reason, byte arrays are not comparable?
137131
return strings.Compare(ids[i].String(), ids[j].String()) < 0
138132
})
139133
// Set a short-ish timeout for this. We don't want to hang forever.
140-
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
134+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
141135
defer cancel()
142136
// nolint: gocritic // system function
143137
authCtx := dbauthz.AsSystemRestricted(ctx)
138+
wut.flushLock.Lock()
139+
defer wut.flushLock.Unlock()
144140
if err := wut.s.BatchUpdateWorkspaceLastUsedAt(authCtx, database.BatchUpdateWorkspaceLastUsedAtParams{
145141
LastUsedAt: now,
146142
IDs: ids,
@@ -164,9 +160,7 @@ func (wut *Tracker) Loop() {
164160
if !ok {
165161
return
166162
}
167-
wut.mut.Lock()
168-
wut.flushLocked(now.UTC())
169-
wut.mut.Unlock()
163+
wut.flush(now.UTC())
170164
}
171165
}
172166
}
@@ -179,3 +173,35 @@ func (wut *Tracker) Close() {
179173
<-wut.doneCh
180174
})
181175
}
176+
177+
// uuidSet is a set of UUIDs. Safe for concurrent usage.
178+
// The zero value can be used.
179+
type uuidSet struct {
180+
l sync.Mutex
181+
m map[uuid.UUID]struct{}
182+
}
183+
184+
func (s *uuidSet) Add(id uuid.UUID) {
185+
s.l.Lock()
186+
defer s.l.Unlock()
187+
if s.m == nil {
188+
s.m = make(map[uuid.UUID]struct{})
189+
}
190+
s.m[id] = struct{}{}
191+
}
192+
193+
// UniqueAndClear returns the unique set of entries in s and
194+
// resets the internal map.
195+
func (s *uuidSet) UniqueAndClear() []uuid.UUID {
196+
s.l.Lock()
197+
defer s.l.Unlock()
198+
if s.m == nil {
199+
s.m = make(map[uuid.UUID]struct{})
200+
}
201+
l := make([]uuid.UUID, 0)
202+
for k := range s.m {
203+
l = append(l, k)
204+
}
205+
s.m = make(map[uuid.UUID]struct{})
206+
return l
207+
}

coderd/workspaceusage/tracker_test.go

+25-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package workspaceusage_test
33
import (
44
"sort"
55
"strings"
6+
"sync"
67
"testing"
78
"time"
89

@@ -61,17 +62,35 @@ func TestTracker(t *testing.T) {
6162
return strings.Compare(ids[i].String(), ids[j].String()) < 0
6263
})
6364

64-
for _, id := range ids {
65-
wut.Add(id)
66-
}
67-
6865
now = dbtime.Now()
6966
mDB.EXPECT().BatchUpdateWorkspaceLastUsedAt(gomock.Any(), database.BatchUpdateWorkspaceLastUsedAtParams{
7067
LastUsedAt: now,
7168
IDs: ids,
7269
}).Times(1)
73-
tickCh <- now
74-
count = <-flushCh
70+
// Try to force a race condition.
71+
var wg sync.WaitGroup
72+
numTicks := 10
73+
count = 0
74+
wg.Add(1)
75+
go func() {
76+
defer wg.Done()
77+
for _, id := range ids {
78+
wut.Add(id)
79+
}
80+
}()
81+
for i := 0; i < numTicks; i++ {
82+
wg.Add(1)
83+
go func() {
84+
defer wg.Done()
85+
tickCh <- now
86+
}()
87+
}
88+
89+
for i := 0; i < numTicks; i++ {
90+
count += <-flushCh
91+
}
92+
93+
wg.Wait()
7594
require.Equal(t, 11, count, "expected one flush with eleven ids")
7695
}
7796

0 commit comments

Comments
 (0)