Skip to content

Commit ced50fc

Browse files
Emyrkstirby
authored andcommitted
chore: fix concurrent CommitQuota transactions for unrelated users/orgs (#15261)
The failure condition being fixed is `w1` and `w2` could belong to different users, organizations, and templates and still cause a serializable failure if run concurrently. This is because the old query did a `seq scan` on the `workspace_builds` table. Since that is the table being updated, we really want to prevent that. So before this would fail for any 2 workspaces. Now it only fails if `w1` and `w2` are owned by the same user and organization. (cherry picked from commit 854044e)
1 parent 0cd5066 commit ced50fc

File tree

15 files changed

+982
-23
lines changed

15 files changed

+982
-23
lines changed

coderd/database/db.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ type Store interface {
2828
wrapper
2929

3030
Ping(ctx context.Context) (time.Duration, error)
31+
PGLocks(ctx context.Context) (PGLocks, error)
3132
InTx(func(Store) error, *TxOptions) error
3233
}
3334

@@ -48,13 +49,26 @@ type DBTX interface {
4849
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
4950
}
5051

52+
func WithSerialRetryCount(count int) func(*sqlQuerier) {
53+
return func(q *sqlQuerier) {
54+
q.serialRetryCount = count
55+
}
56+
}
57+
5158
// New creates a new database store using a SQL database connection.
52-
func New(sdb *sql.DB) Store {
59+
func New(sdb *sql.DB, opts ...func(*sqlQuerier)) Store {
5360
dbx := sqlx.NewDb(sdb, "postgres")
54-
return &sqlQuerier{
61+
q := &sqlQuerier{
5562
db: dbx,
5663
sdb: dbx,
64+
// This is an arbitrary number.
65+
serialRetryCount: 3,
66+
}
67+
68+
for _, opt := range opts {
69+
opt(q)
5770
}
71+
return q
5872
}
5973

6074
// TxOptions is used to pass some execution metadata to the callers.
@@ -104,6 +118,10 @@ type querier interface {
104118
type sqlQuerier struct {
105119
sdb *sqlx.DB
106120
db DBTX
121+
122+
// serialRetryCount is the number of times to retry a transaction
123+
// if it fails with a serialization error.
124+
serialRetryCount int
107125
}
108126

109127
func (*sqlQuerier) Wrappers() []string {
@@ -143,11 +161,9 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *TxOptions) error {
143161
// If we are in a transaction already, the parent InTx call will handle the retry.
144162
// We do not want to duplicate those retries.
145163
if !inTx && sqlOpts.Isolation == sql.LevelSerializable {
146-
// This is an arbitrarily chosen number.
147-
const retryAmount = 3
148164
var err error
149165
attempts := 0
150-
for attempts = 0; attempts < retryAmount; attempts++ {
166+
for attempts = 0; attempts < q.serialRetryCount; attempts++ {
151167
txOpts.executionCount++
152168
err = q.runTx(function, sqlOpts)
153169
if err == nil {
@@ -203,3 +219,10 @@ func (q *sqlQuerier) runTx(function func(Store) error, txOpts *sql.TxOptions) er
203219
}
204220
return nil
205221
}
222+
223+
func safeString(s *string) string {
224+
if s == nil {
225+
return "<nil>"
226+
}
227+
return *s
228+
}

coderd/database/dbauthz/dbauthz.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,10 @@ func (q *querier) Ping(ctx context.Context) (time.Duration, error) {
603603
return q.db.Ping(ctx)
604604
}
605605

606+
func (q *querier) PGLocks(ctx context.Context) (database.PGLocks, error) {
607+
return q.db.PGLocks(ctx)
608+
}
609+
606610
// InTx runs the given function in a transaction.
607611
func (q *querier) InTx(function func(querier database.Store) error, txOpts *database.TxOptions) error {
608612
return q.db.InTx(func(tx database.Store) error {

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,10 @@ func TestDBAuthzRecursive(t *testing.T) {
152152
for i := 2; i < method.Type.NumIn(); i++ {
153153
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
154154
}
155-
if method.Name == "InTx" || method.Name == "Ping" || method.Name == "Wrappers" {
155+
if method.Name == "InTx" ||
156+
method.Name == "Ping" ||
157+
method.Name == "Wrappers" ||
158+
method.Name == "PGLocks" {
156159
continue
157160
}
158161
// Log the name of the last method, so if there is a panic, it is

coderd/database/dbauthz/setup_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ var errMatchAny = xerrors.New("match any error")
3434
var skipMethods = map[string]string{
3535
"InTx": "Not relevant",
3636
"Ping": "Not relevant",
37+
"PGLocks": "Not relevant",
3738
"Wrappers": "Not relevant",
3839
"AcquireLock": "Not relevant",
3940
"TryAcquireLock": "Not relevant",

coderd/database/dbfake/builder.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package dbfake
2+
3+
import (
4+
"testing"
5+
6+
"github.com/google/uuid"
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/coder/coder/v2/coderd/database"
10+
"github.com/coder/coder/v2/coderd/database/dbauthz"
11+
"github.com/coder/coder/v2/coderd/database/dbgen"
12+
"github.com/coder/coder/v2/coderd/database/dbtime"
13+
"github.com/coder/coder/v2/testutil"
14+
)
15+
16+
type OrganizationBuilder struct {
17+
t *testing.T
18+
db database.Store
19+
seed database.Organization
20+
allUsersAllowance int32
21+
members []uuid.UUID
22+
groups map[database.Group][]uuid.UUID
23+
}
24+
25+
func Organization(t *testing.T, db database.Store) OrganizationBuilder {
26+
return OrganizationBuilder{
27+
t: t,
28+
db: db,
29+
members: []uuid.UUID{},
30+
groups: make(map[database.Group][]uuid.UUID),
31+
}
32+
}
33+
34+
type OrganizationResponse struct {
35+
Org database.Organization
36+
AllUsersGroup database.Group
37+
Members []database.OrganizationMember
38+
Groups []database.Group
39+
}
40+
41+
func (b OrganizationBuilder) EveryoneAllowance(allowance int) OrganizationBuilder {
42+
//nolint: revive // returns modified struct
43+
b.allUsersAllowance = int32(allowance)
44+
return b
45+
}
46+
47+
func (b OrganizationBuilder) Seed(seed database.Organization) OrganizationBuilder {
48+
//nolint: revive // returns modified struct
49+
b.seed = seed
50+
return b
51+
}
52+
53+
func (b OrganizationBuilder) Members(users ...database.User) OrganizationBuilder {
54+
for _, u := range users {
55+
//nolint: revive // returns modified struct
56+
b.members = append(b.members, u.ID)
57+
}
58+
return b
59+
}
60+
61+
func (b OrganizationBuilder) Group(seed database.Group, members ...database.User) OrganizationBuilder {
62+
//nolint: revive // returns modified struct
63+
b.groups[seed] = []uuid.UUID{}
64+
for _, u := range members {
65+
//nolint: revive // returns modified struct
66+
b.groups[seed] = append(b.groups[seed], u.ID)
67+
}
68+
return b
69+
}
70+
71+
func (b OrganizationBuilder) Do() OrganizationResponse {
72+
org := dbgen.Organization(b.t, b.db, b.seed)
73+
74+
ctx := testutil.Context(b.t, testutil.WaitShort)
75+
//nolint:gocritic // builder code needs perms
76+
ctx = dbauthz.AsSystemRestricted(ctx)
77+
everyone, err := b.db.InsertAllUsersGroup(ctx, org.ID)
78+
require.NoError(b.t, err)
79+
80+
if b.allUsersAllowance > 0 {
81+
everyone, err = b.db.UpdateGroupByID(ctx, database.UpdateGroupByIDParams{
82+
Name: everyone.Name,
83+
DisplayName: everyone.DisplayName,
84+
AvatarURL: everyone.AvatarURL,
85+
QuotaAllowance: b.allUsersAllowance,
86+
ID: everyone.ID,
87+
})
88+
require.NoError(b.t, err)
89+
}
90+
91+
members := make([]database.OrganizationMember, 0)
92+
if len(b.members) > 0 {
93+
for _, u := range b.members {
94+
newMem := dbgen.OrganizationMember(b.t, b.db, database.OrganizationMember{
95+
UserID: u,
96+
OrganizationID: org.ID,
97+
CreatedAt: dbtime.Now(),
98+
UpdatedAt: dbtime.Now(),
99+
Roles: nil,
100+
})
101+
members = append(members, newMem)
102+
}
103+
}
104+
105+
groups := make([]database.Group, 0)
106+
if len(b.groups) > 0 {
107+
for g, users := range b.groups {
108+
g.OrganizationID = org.ID
109+
group := dbgen.Group(b.t, b.db, g)
110+
groups = append(groups, group)
111+
112+
for _, u := range users {
113+
dbgen.GroupMember(b.t, b.db, database.GroupMemberTable{
114+
UserID: u,
115+
GroupID: group.ID,
116+
})
117+
}
118+
}
119+
}
120+
121+
return OrganizationResponse{
122+
Org: org,
123+
AllUsersGroup: everyone,
124+
Members: members,
125+
Groups: groups,
126+
}
127+
}

coderd/database/dbgen/dbgen.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,8 @@ func OrganizationMember(t testing.TB, db database.Store, orig database.Organizat
407407
}
408408

409409
func Group(t testing.TB, db database.Store, orig database.Group) database.Group {
410+
t.Helper()
411+
410412
name := takeFirst(orig.Name, testutil.GetRandomName(t))
411413
group, err := db.InsertGroup(genCtx, database.InsertGroupParams{
412414
ID: takeFirst(orig.ID, uuid.New()),

coderd/database/dbmem/dbmem.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,10 @@ func (*FakeQuerier) Ping(_ context.Context) (time.Duration, error) {
339339
return 0, nil
340340
}
341341

342+
func (*FakeQuerier) PGLocks(_ context.Context) (database.PGLocks, error) {
343+
return []database.PGLock{}, nil
344+
}
345+
342346
func (tx *fakeTx) AcquireLock(_ context.Context, id int64) error {
343347
if _, ok := tx.FakeQuerier.locks[id]; ok {
344348
return xerrors.Errorf("cannot acquire lock %d: already held", id)

coderd/database/dbmetrics/querymetrics.go

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbtestutil/db.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) {
135135
if o.dumpOnFailure {
136136
t.Cleanup(func() { DumpOnFailure(t, connectionURL) })
137137
}
138-
db = database.New(sqlDB)
138+
// Unit tests should not retry serial transaction failures.
139+
db = database.New(sqlDB, database.WithSerialRetryCount(1))
139140

140141
ps, err = pubsub.New(context.Background(), o.logger, sqlDB, connectionURL)
141142
require.NoError(t, err)

coderd/database/dbtestutil/tx.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package dbtestutil
2+
3+
import (
4+
"sync"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"golang.org/x/xerrors"
9+
10+
"github.com/coder/coder/v2/coderd/database"
11+
)
12+
13+
type DBTx struct {
14+
database.Store
15+
mu sync.Mutex
16+
done chan error
17+
finalErr chan error
18+
}
19+
20+
// StartTx starts a transaction and returns a DBTx object. This allows running
21+
// 2 transactions concurrently in a test more easily.
22+
// Example:
23+
//
24+
// a := StartTx(t, db, opts)
25+
// b := StartTx(t, db, opts)
26+
//
27+
// a.GetUsers(...)
28+
// b.GetUsers(...)
29+
//
30+
// require.NoError(t, a.Done()
31+
func StartTx(t *testing.T, db database.Store, opts *database.TxOptions) *DBTx {
32+
done := make(chan error)
33+
finalErr := make(chan error)
34+
txC := make(chan database.Store)
35+
36+
go func() {
37+
t.Helper()
38+
once := sync.Once{}
39+
count := 0
40+
41+
err := db.InTx(func(store database.Store) error {
42+
// InTx can be retried
43+
once.Do(func() {
44+
txC <- store
45+
})
46+
count++
47+
if count > 1 {
48+
// If you recursively call InTx, then don't use this.
49+
t.Logf("InTx called more than once: %d", count)
50+
assert.NoError(t, xerrors.New("InTx called more than once, this is not allowed with the StartTx helper"))
51+
}
52+
53+
<-done
54+
// Just return nil. The caller should be checking their own errors.
55+
return nil
56+
}, opts)
57+
finalErr <- err
58+
}()
59+
60+
txStore := <-txC
61+
close(txC)
62+
63+
return &DBTx{Store: txStore, done: done, finalErr: finalErr}
64+
}
65+
66+
// Done can only be called once. If you call it twice, it will panic.
67+
func (tx *DBTx) Done() error {
68+
tx.mu.Lock()
69+
defer tx.mu.Unlock()
70+
71+
close(tx.done)
72+
return <-tx.finalErr
73+
}

0 commit comments

Comments
 (0)