Skip to content

chore: fix concurrent CommitQuota transactions for unrelated users/orgs #15261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions coderd/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Store interface {
wrapper

Ping(ctx context.Context) (time.Duration, error)
PGLocks(ctx context.Context) (PGLocks, error)
InTx(func(Store) error, *TxOptions) error
}

Expand All @@ -48,13 +49,26 @@ type DBTX interface {
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
}

func WithSerialRetryCount(count int) func(*sqlQuerier) {
return func(q *sqlQuerier) {
q.serialRetryCount = count
}
}

// New creates a new database store using a SQL database connection.
func New(sdb *sql.DB) Store {
func New(sdb *sql.DB, opts ...func(*sqlQuerier)) Store {
dbx := sqlx.NewDb(sdb, "postgres")
return &sqlQuerier{
q := &sqlQuerier{
db: dbx,
sdb: dbx,
// This is an arbitrary number.
serialRetryCount: 3,
}

for _, opt := range opts {
opt(q)
}
return q
}

// TxOptions is used to pass some execution metadata to the callers.
Expand Down Expand Up @@ -104,6 +118,10 @@ type querier interface {
type sqlQuerier struct {
sdb *sqlx.DB
db DBTX

// serialRetryCount is the number of times to retry a transaction
// if it fails with a serialization error.
serialRetryCount int
}

func (*sqlQuerier) Wrappers() []string {
Expand Down Expand Up @@ -143,11 +161,9 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *TxOptions) error {
// If we are in a transaction already, the parent InTx call will handle the retry.
// We do not want to duplicate those retries.
if !inTx && sqlOpts.Isolation == sql.LevelSerializable {
// This is an arbitrarily chosen number.
const retryAmount = 3
var err error
attempts := 0
for attempts = 0; attempts < retryAmount; attempts++ {
for attempts = 0; attempts < q.serialRetryCount; attempts++ {
txOpts.executionCount++
err = q.runTx(function, sqlOpts)
if err == nil {
Expand Down Expand Up @@ -203,3 +219,10 @@ func (q *sqlQuerier) runTx(function func(Store) error, txOpts *sql.TxOptions) er
}
return nil
}

func safeString(s *string) string {
if s == nil {
return "<nil>"
}
return *s
}
4 changes: 4 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,10 @@ func (q *querier) Ping(ctx context.Context) (time.Duration, error) {
return q.db.Ping(ctx)
}

func (q *querier) PGLocks(ctx context.Context) (database.PGLocks, error) {
return q.db.PGLocks(ctx)
}

// InTx runs the given function in a transaction.
func (q *querier) InTx(function func(querier database.Store) error, txOpts *database.TxOptions) error {
return q.db.InTx(func(tx database.Store) error {
Expand Down
5 changes: 4 additions & 1 deletion coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ func TestDBAuthzRecursive(t *testing.T) {
for i := 2; i < method.Type.NumIn(); i++ {
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
}
if method.Name == "InTx" || method.Name == "Ping" || method.Name == "Wrappers" {
if method.Name == "InTx" ||
method.Name == "Ping" ||
method.Name == "Wrappers" ||
method.Name == "PGLocks" {
continue
}
// Log the name of the last method, so if there is a panic, it is
Expand Down
1 change: 1 addition & 0 deletions coderd/database/dbauthz/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var errMatchAny = xerrors.New("match any error")
var skipMethods = map[string]string{
"InTx": "Not relevant",
"Ping": "Not relevant",
"PGLocks": "Not relevant",
"Wrappers": "Not relevant",
"AcquireLock": "Not relevant",
"TryAcquireLock": "Not relevant",
Expand Down
127 changes: 127 additions & 0 deletions coderd/database/dbfake/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package dbfake

import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/testutil"
)

type OrganizationBuilder struct {
t *testing.T
db database.Store
seed database.Organization
allUsersAllowance int32
members []uuid.UUID
groups map[database.Group][]uuid.UUID
}

func Organization(t *testing.T, db database.Store) OrganizationBuilder {
return OrganizationBuilder{
t: t,
db: db,
members: []uuid.UUID{},
groups: make(map[database.Group][]uuid.UUID),
}
}

type OrganizationResponse struct {
Org database.Organization
AllUsersGroup database.Group
Members []database.OrganizationMember
Groups []database.Group
}

func (b OrganizationBuilder) EveryoneAllowance(allowance int) OrganizationBuilder {
//nolint: revive // returns modified struct
b.allUsersAllowance = int32(allowance)
return b
}

func (b OrganizationBuilder) Seed(seed database.Organization) OrganizationBuilder {
//nolint: revive // returns modified struct
b.seed = seed
return b
}

func (b OrganizationBuilder) Members(users ...database.User) OrganizationBuilder {
for _, u := range users {
//nolint: revive // returns modified struct
b.members = append(b.members, u.ID)
}
return b
}

func (b OrganizationBuilder) Group(seed database.Group, members ...database.User) OrganizationBuilder {
//nolint: revive // returns modified struct
b.groups[seed] = []uuid.UUID{}
for _, u := range members {
//nolint: revive // returns modified struct
b.groups[seed] = append(b.groups[seed], u.ID)
}
return b
}

func (b OrganizationBuilder) Do() OrganizationResponse {
org := dbgen.Organization(b.t, b.db, b.seed)

ctx := testutil.Context(b.t, testutil.WaitShort)
//nolint:gocritic // builder code needs perms
ctx = dbauthz.AsSystemRestricted(ctx)
everyone, err := b.db.InsertAllUsersGroup(ctx, org.ID)
require.NoError(b.t, err)

if b.allUsersAllowance > 0 {
everyone, err = b.db.UpdateGroupByID(ctx, database.UpdateGroupByIDParams{
Name: everyone.Name,
DisplayName: everyone.DisplayName,
AvatarURL: everyone.AvatarURL,
QuotaAllowance: b.allUsersAllowance,
ID: everyone.ID,
})
require.NoError(b.t, err)
}

members := make([]database.OrganizationMember, 0)
if len(b.members) > 0 {
for _, u := range b.members {
newMem := dbgen.OrganizationMember(b.t, b.db, database.OrganizationMember{
UserID: u,
OrganizationID: org.ID,
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
Roles: nil,
})
members = append(members, newMem)
}
}

groups := make([]database.Group, 0)
if len(b.groups) > 0 {
for g, users := range b.groups {
g.OrganizationID = org.ID
group := dbgen.Group(b.t, b.db, g)
groups = append(groups, group)

for _, u := range users {
dbgen.GroupMember(b.t, b.db, database.GroupMemberTable{
UserID: u,
GroupID: group.ID,
})
}
}
}

return OrganizationResponse{
Org: org,
AllUsersGroup: everyone,
Members: members,
Groups: groups,
}
}
2 changes: 2 additions & 0 deletions coderd/database/dbgen/dbgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ func OrganizationMember(t testing.TB, db database.Store, orig database.Organizat
}

func Group(t testing.TB, db database.Store, orig database.Group) database.Group {
t.Helper()

name := takeFirst(orig.Name, testutil.GetRandomName(t))
group, err := db.InsertGroup(genCtx, database.InsertGroupParams{
ID: takeFirst(orig.ID, uuid.New()),
Expand Down
4 changes: 4 additions & 0 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,10 @@ func (*FakeQuerier) Ping(_ context.Context) (time.Duration, error) {
return 0, nil
}

func (*FakeQuerier) PGLocks(_ context.Context) (database.PGLocks, error) {
return []database.PGLock{}, nil
}

func (tx *fakeTx) AcquireLock(_ context.Context, id int64) error {
if _, ok := tx.FakeQuerier.locks[id]; ok {
return xerrors.Errorf("cannot acquire lock %d: already held", id)
Expand Down
7 changes: 7 additions & 0 deletions coderd/database/dbmetrics/querymetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions coderd/database/dbmock/dbmock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion coderd/database/dbtestutil/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) {
if o.dumpOnFailure {
t.Cleanup(func() { DumpOnFailure(t, connectionURL) })
}
db = database.New(sqlDB)
// Unit tests should not retry serial transaction failures.
db = database.New(sqlDB, database.WithSerialRetryCount(1))

ps, err = pubsub.New(context.Background(), o.logger, sqlDB, connectionURL)
require.NoError(t, err)
Expand Down
73 changes: 73 additions & 0 deletions coderd/database/dbtestutil/tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package dbtestutil

import (
"sync"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/database"
)

type DBTx struct {
database.Store
mu sync.Mutex
done chan error
finalErr chan error
}

// StartTx starts a transaction and returns a DBTx object. This allows running
// 2 transactions concurrently in a test more easily.
// Example:
//
// a := StartTx(t, db, opts)
// b := StartTx(t, db, opts)
//
// a.GetUsers(...)
// b.GetUsers(...)
//
// require.NoError(t, a.Done()
func StartTx(t *testing.T, db database.Store, opts *database.TxOptions) *DBTx {
done := make(chan error)
finalErr := make(chan error)
txC := make(chan database.Store)

go func() {
t.Helper()
once := sync.Once{}
count := 0

err := db.InTx(func(store database.Store) error {
// InTx can be retried
once.Do(func() {
txC <- store
})
count++
if count > 1 {
// If you recursively call InTx, then don't use this.
t.Logf("InTx called more than once: %d", count)
assert.NoError(t, xerrors.New("InTx called more than once, this is not allowed with the StartTx helper"))
}

<-done
// Just return nil. The caller should be checking their own errors.
return nil
}, opts)
finalErr <- err
}()

txStore := <-txC
close(txC)

return &DBTx{Store: txStore, done: done, finalErr: finalErr}
}

// Done can only be called once. If you call it twice, it will panic.
func (tx *DBTx) Done() error {
tx.mu.Lock()
defer tx.mu.Unlock()

close(tx.done)
return <-tx.finalErr
}
Loading
Loading