|
5 | 5 | "bytes"
|
6 | 6 | "fmt"
|
7 | 7 | "go/format"
|
8 |
| - "go/parser" |
9 | 8 | "go/token"
|
10 | 9 | "os"
|
11 | 10 | "path"
|
@@ -419,49 +418,43 @@ type querierFunction struct {
|
419 | 418 |
|
420 | 419 | // readQuerierFunctions reads the functions from coderd/database/querier.go
|
421 | 420 | func readQuerierFunctions() ([]querierFunction, error) {
|
422 |
| - localPath, err := localFilePath() |
| 421 | + f, err := parseDBFile("querier.go") |
423 | 422 | if err != nil {
|
424 |
| - return nil, err |
| 423 | + return nil, xerrors.Errorf("parse querier.go: %w", err) |
425 | 424 | }
|
426 |
| - |
427 |
| - // Parse the database package as a whole so all references are resolved across |
428 |
| - // files. |
429 |
| - dirPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database") |
430 |
| - packages, err := decorator.ParseDir(token.NewFileSet(), dirPath, func(info os.FileInfo) bool { |
431 |
| - if strings.HasSuffix(info.Name(), "_test.go") { |
432 |
| - return false |
433 |
| - } |
434 |
| - if !strings.HasSuffix(info.Name(), ".go") { |
435 |
| - return false |
436 |
| - } |
437 |
| - return true |
438 |
| - }, parser.ParseComments) |
439 |
| - |
440 |
| - dbPackage := packages["database"] |
441 |
| - |
442 |
| - findFile := func(name string) *dst.File { |
443 |
| - for k, v := range dbPackage.Files { |
444 |
| - if strings.HasSuffix(k, name) { |
445 |
| - return v |
446 |
| - } |
447 |
| - } |
448 |
| - return nil |
449 |
| - } |
450 |
| - |
451 |
| - funcs, err := loadInterfaceFuncs(findFile("querier.go"), "sqlcQuerier") |
| 425 | + funcs, err := loadInterfaceFuncs(f, "sqlcQuerier") |
452 | 426 | if err != nil {
|
453 | 427 | return nil, xerrors.Errorf("load interface %s funcs: %w", "sqlcQuerier", err)
|
454 | 428 | }
|
455 | 429 |
|
| 430 | + customFile, err := parseDBFile("modelqueries.go") |
| 431 | + if err != nil { |
| 432 | + return nil, xerrors.Errorf("parse modelqueriers.go: %w", err) |
| 433 | + } |
456 | 434 | // Custom funcs should be appended after the regular functions
|
457 |
| - customFuncs, err := loadInterfaceFuncs(findFile("modelqueries.go"), "customQuerier") |
| 435 | + customFuncs, err := loadInterfaceFuncs(customFile, "customQuerier") |
458 | 436 | if err != nil {
|
459 | 437 | return nil, xerrors.Errorf("load interface %s funcs: %w", "customQuerier", err)
|
460 | 438 | }
|
461 | 439 |
|
462 | 440 | return append(funcs, customFuncs...), nil
|
463 | 441 | }
|
464 | 442 |
|
| 443 | +func parseDBFile(filename string) (*dst.File, error) { |
| 444 | + localPath, err := localFilePath() |
| 445 | + if err != nil { |
| 446 | + return nil, err |
| 447 | + } |
| 448 | + |
| 449 | + querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", filename) |
| 450 | + querierData, err := os.ReadFile(querierPath) |
| 451 | + if err != nil { |
| 452 | + return nil, xerrors.Errorf("read %s: %w", filename, err) |
| 453 | + } |
| 454 | + f, err := decorator.Parse(querierData) |
| 455 | + return f, err |
| 456 | +} |
| 457 | + |
465 | 458 | func loadInterfaceFuncs(f *dst.File, interfaceName string) ([]querierFunction, error) {
|
466 | 459 | var querier *dst.InterfaceType
|
467 | 460 | for _, decl := range f.Decls {
|
|
0 commit comments