Skip to content

feat: Handle pagination cases where after_id does not exist #1947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: Use correct db inside InTx
Add rules.go for catching this
  • Loading branch information
Emyrk committed Jun 1, 2022
commit cffcd7135958603d59abf8b8d595d3d9d7ad331d
6 changes: 3 additions & 3 deletions coderd/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
}

var organization database.Organization
err = api.Database.InTx(func(db database.Store) error {
organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{
err = api.Database.InTx(func(store database.Store) error {
organization, err = store.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: uuid.New(),
Name: req.Name,
CreatedAt: database.Now(),
Expand All @@ -67,7 +67,7 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
if err != nil {
return xerrors.Errorf("create organization: %w", err)
}
_, err = api.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
_, err = store.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: organization.ID,
UserID: apiKey.UserID,
CreatedAt: database.Now(),
Expand Down
10 changes: 5 additions & 5 deletions coderd/templateversions.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque
if paginationParams.AfterID != uuid.Nil {
// See if the record exists first. If the record does not exist, the pagination
// query will not work.
_, err := api.Database.GetTemplateVersionByID(r.Context(), paginationParams.AfterID)
_, err := store.GetTemplateVersionByID(r.Context(), paginationParams.AfterID)
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("record at \"after_id\" (%q) does not exists", paginationParams.AfterID.String()),
Expand All @@ -405,7 +405,7 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque
}
}

versions, err := api.Database.GetTemplateVersionsByTemplateID(r.Context(), database.GetTemplateVersionsByTemplateIDParams{
versions, err := store.GetTemplateVersionsByTemplateID(r.Context(), database.GetTemplateVersionsByTemplateIDParams{
TemplateID: template.ID,
AfterID: paginationParams.AfterID,
LimitOpt: int32(paginationParams.Limit),
Expand All @@ -426,7 +426,7 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque
for _, version := range versions {
jobIDs = append(jobIDs, version.JobID)
}
jobs, err := api.Database.GetProvisionerJobsByIDs(r.Context(), jobIDs)
jobs, err := store.GetProvisionerJobsByIDs(r.Context(), jobIDs)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get jobs: %s", err),
Expand Down Expand Up @@ -608,7 +608,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
}
}

provisionerJob, err = api.Database.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{
provisionerJob, err = db.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{
ID: jobID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
Expand All @@ -632,7 +632,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
}
}

templateVersion, err = api.Database.InsertTemplateVersion(r.Context(), database.InsertTemplateVersionParams{
templateVersion, err = db.InsertTemplateVersion(r.Context(), database.InsertTemplateVersionParams{
ID: uuid.New(),
TemplateID: templateID,
OrganizationID: organization.ID,
Expand Down
2 changes: 1 addition & 1 deletion coderd/workspacebuilds.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (api *API) workspaceBuilds(rw http.ResponseWriter, r *http.Request) {
OffsetOpt: int32(paginationParams.Offset),
LimitOpt: int32(paginationParams.Limit),
}
builds, err = api.Database.GetWorkspaceBuildByWorkspaceID(r.Context(), req)
builds, err = store.GetWorkspaceBuildByWorkspaceID(r.Context(), req)
if xerrors.Is(err, sql.ErrNoRows) {
err = nil
}
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ services:
condition: service_healthy
database:
image: "postgres:14.2"
ports:
- "5432:5432"
environment:
POSTGRES_USER: ${POSTGRES_USER:-username} # The PostgreSQL user (useful to connect to the database)
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password} # The PostgreSQL password (useful to connect to the database)
Expand Down
49 changes: 49 additions & 0 deletions scripts/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,52 @@ func doNotCallTFailNowInsideGoroutine(m dsl.Matcher) {
Where(m["t"].Type.Implements("testing.TB") && m["fail"].Text.Matches("^(FailNow|Fatal|Fatalf)$")).
Report("Do not call functions that may call t.FailNow in a goroutine, as this can cause data races (see testing.go:834)")
}

// InTx checks to ensure the database used inside the transaction closure is the transaction
// database, and not the original database that creates the tx.
func InTx(m dsl.Matcher) {
m.Import("github.com/coder/coder/coderd/database")

// ':=' and '=' are different matches. Really...
m.Match(`
$x.InTx(func($y database.Store) error {
$*_
$*_ := $x.$f($*_)
$*_
})
`, `
$x.InTx(func($y database.Store) error {
$*_
$*_ = $x.$f($*_)
$*_
})
`).Where(m["x"].Text != m["y"].Text && m["x"].Type.Implements("database.Store")).
At(m["f"]).
Report("Do not use the database directly within the InTx closure. Use '$y' instead of '$x'.")

// When using a tx closure, ensure that if you pass the db to another
// function inside the closure, it is the tx.
// This will miss more complex cases such as passing the db as apart
// of another struct.
m.Match(
`
$x.InTx(func($y database.Store) error {
$*_
$*_ := $f($*_, $x, $*_)
$*_
})
`, `
$x.InTx(func($y database.Store) error {
$*_
$*_ = $f($*_, $x, $*_)
$*_
})
`, `
$x.InTx(func($y database.Store) error {
$*_
$f($*_, $x, $*_)
$*_
})
`).Where(m["x"].Text != m["y"].Text && m["x"].Type.Implements("database.Store")).
At(m["f"]).Report("Pass the tx database into the '$f' function inside the closure. Use '$y' over $x'")
}