Skip to content

Commit 2ac5329

Browse files
authored
feat(coderd/database): generate foreign key constraints and add database.IsForeignKeyViolation (#9657)
* feat(coderd/database): generate foreign key constraints, add database.IsForeignKeyViolation * address PR comments
1 parent a6f7f71 commit 2ac5329

File tree

5 files changed

+160
-4
lines changed

5 files changed

+160
-4
lines changed

coderd/apikey.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func (api *API) postToken(rw http.ResponseWriter, r *http.Request) {
9191
TokenName: tokenName,
9292
})
9393
if err != nil {
94-
if database.IsUniqueViolation(err, database.UniqueIndexApiKeyName) {
94+
if database.IsUniqueViolation(err, database.UniqueIndexAPIKeyName) {
9595
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
9696
Message: fmt.Sprintf("A token with name %q already exists.", tokenName),
9797
Validations: []codersdk.ValidationError{{

coderd/database/errors.go

+22
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,28 @@ func IsUniqueViolation(err error, uniqueConstraints ...UniqueConstraint) bool {
3737
return false
3838
}
3939

40+
// IsForeignKeyViolation checks if the error is due to a foreign key violation.
41+
// If one or more specific foreign key constraints are given as arguments,
42+
// the error must be caused by one of them. If no constraints are given,
43+
// this function returns true for any foreign key violation.
44+
func IsForeignKeyViolation(err error, foreignKeyConstraints ...ForeignKeyConstraint) bool {
45+
var pqErr *pq.Error
46+
if errors.As(err, &pqErr) {
47+
if pqErr.Code.Name() == "foreign_key_violation" {
48+
if len(foreignKeyConstraints) == 0 {
49+
return true
50+
}
51+
for _, fc := range foreignKeyConstraints {
52+
if pqErr.Constraint == string(fc) {
53+
return true
54+
}
55+
}
56+
}
57+
}
58+
59+
return false
60+
}
61+
4062
// IsQueryCanceledError checks if the error is due to a query being canceled.
4163
func IsQueryCanceledError(err error) bool {
4264
var pqErr *pq.Error

coderd/database/foreign_key_constraint.go

+49
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/unique_constraint.go

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

scripts/dbgen/main.go

+86-1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ return %s
8484
return xerrors.Errorf("generate unique constraints: %w", err)
8585
}
8686

87+
err = generateForeignKeyConstraints()
88+
if err != nil {
89+
return xerrors.Errorf("generate foreign key constraints: %w", err)
90+
}
91+
8792
return nil
8893
}
8994

@@ -125,7 +130,7 @@ func generateUniqueConstraints() error {
125130

126131
s := &bytes.Buffer{}
127132

128-
_, _ = fmt.Fprint(s, `// Code generated by gen/enum. DO NOT EDIT.
133+
_, _ = fmt.Fprint(s, `// Code generated by scripts/dbgen/main.go. DO NOT EDIT.
129134
package database
130135
`)
131136
_, _ = fmt.Fprint(s, `
@@ -160,6 +165,78 @@ const (
160165
return os.WriteFile(outputPath, data, 0o600)
161166
}
162167

168+
// generateForeignKeyConstraints generates the ForeignKeyConstraint enum.
169+
func generateForeignKeyConstraints() error {
170+
localPath, err := localFilePath()
171+
if err != nil {
172+
return err
173+
}
174+
databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database")
175+
176+
dump, err := os.Open(filepath.Join(databasePath, "dump.sql"))
177+
if err != nil {
178+
return err
179+
}
180+
defer dump.Close()
181+
182+
var foreignKeyConstraints []string
183+
dumpScanner := bufio.NewScanner(dump)
184+
query := ""
185+
for dumpScanner.Scan() {
186+
line := strings.TrimSpace(dumpScanner.Text())
187+
switch {
188+
case strings.HasPrefix(line, "--"):
189+
case line == "":
190+
case strings.HasSuffix(line, ";"):
191+
query += line
192+
if strings.Contains(query, "FOREIGN KEY") {
193+
foreignKeyConstraints = append(foreignKeyConstraints, query)
194+
}
195+
query = ""
196+
default:
197+
query += line + " "
198+
}
199+
}
200+
201+
if err := dumpScanner.Err(); err != nil {
202+
return err
203+
}
204+
205+
s := &bytes.Buffer{}
206+
207+
_, _ = fmt.Fprint(s, `// Code generated by scripts/dbgen/main.go. DO NOT EDIT.
208+
package database
209+
`)
210+
_, _ = fmt.Fprint(s, `
211+
// ForeignKeyConstraint represents a named foreign key constraint on a table.
212+
type ForeignKeyConstraint string
213+
214+
// ForeignKeyConstraint enums.
215+
const (
216+
`)
217+
for _, query := range foreignKeyConstraints {
218+
name := ""
219+
switch {
220+
case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"):
221+
name = strings.Split(query, " ")[6]
222+
default:
223+
return xerrors.Errorf("unknown foreign key constraint format: %s", query)
224+
}
225+
_, _ = fmt.Fprintf(s, "\tForeignKey%s ForeignKeyConstraint = %q // %s\n", nameFromSnakeCase(name), name, query)
226+
}
227+
_, _ = fmt.Fprint(s, ")\n")
228+
229+
outputPath := filepath.Join(databasePath, "foreign_key_constraint.go")
230+
231+
data, err := imports.Process(outputPath, s.Bytes(), &imports.Options{
232+
Comments: true,
233+
})
234+
if err != nil {
235+
return err
236+
}
237+
return os.WriteFile(outputPath, data, 0o600)
238+
}
239+
163240
type stubParams struct {
164241
FuncName string
165242
Parameters string
@@ -560,6 +637,14 @@ func nameFromSnakeCase(s string) string {
560637
ret += "JWT"
561638
case "idx":
562639
ret += "Index"
640+
case "api":
641+
ret += "API"
642+
case "uuid":
643+
ret += "UUID"
644+
case "gitsshkeys":
645+
ret += "GitSSHKeys"
646+
case "fkey":
647+
// ignore
563648
default:
564649
ret += strings.Title(ss)
565650
}

0 commit comments

Comments
 (0)