Skip to content

chore: automatically generate dbauthz when new queries are added #8007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,051 changes: 2,051 additions & 0 deletions coderd/database/dbauthz/dbauthz.go

Large diffs are not rendered by default.

1,429 changes: 1,429 additions & 0 deletions coderd/database/dbauthz/dbauthz_test.go

Large diffs are not rendered by default.

1,634 changes: 0 additions & 1,634 deletions coderd/database/dbauthz/querier.go

This file was deleted.

1,155 changes: 0 additions & 1,155 deletions coderd/database/dbauthz/querier_test.go

This file was deleted.

440 changes: 0 additions & 440 deletions coderd/database/dbauthz/system.go

This file was deleted.

301 changes: 0 additions & 301 deletions coderd/database/dbauthz/system_test.go

This file was deleted.

199 changes: 199 additions & 0 deletions coderd/database/gen/authz/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package main

import (
"go/format"
"go/token"
"log"
"os"

"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/decorator/resolver/goast"
"github.com/dave/dst/decorator/resolver/guess"
"golang.org/x/xerrors"
)

func main() {
err := run()
if err != nil {
log.Fatal(err)
}
}

func run() error {
funcs, err := readStoreInterface()
if err != nil {
return err
}
funcByName := map[string]struct{}{}
for _, f := range funcs {
funcByName[f.Name] = struct{}{}
}
declByName := map[string]*dst.FuncDecl{}

dbauthz, err := os.ReadFile("./dbauthz/dbauthz.go")
if err != nil {
return xerrors.Errorf("read dbauthz: %w", err)
}

// Required to preserve imports!
f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), "dbauthz", goast.New()).Parse(dbauthz)
if err != nil {
return xerrors.Errorf("parse dbauthz: %w", err)
}

for i := 0; i < len(f.Decls); i++ {
funcDecl, ok := f.Decls[i].(*dst.FuncDecl)
if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
continue
}
// Check if the receiver is the struct we're interested in
starExpr, ok := funcDecl.Recv.List[0].Type.(*dst.StarExpr)
if !ok {
continue
}
ident, ok := starExpr.X.(*dst.Ident)
if !ok || ident.Name != "querier" {
continue
}
if _, ok := funcByName[funcDecl.Name.Name]; !ok {
continue
}
declByName[funcDecl.Name.Name] = funcDecl
f.Decls = append(f.Decls[:i], f.Decls[i+1:]...)
i--
}

for _, fn := range funcs {
decl, ok := declByName[fn.Name]
if !ok {
// Not implemented!
decl = &dst.FuncDecl{
Name: dst.NewIdent(fn.Name),
Type: &dst.FuncType{
Func: true,
TypeParams: fn.Func.TypeParams,
Params: fn.Func.Params,
Results: fn.Func.Results,
Decs: fn.Func.Decs,
},
Recv: &dst.FieldList{
List: []*dst.Field{{
Names: []*dst.Ident{dst.NewIdent("q")},
Type: dst.NewIdent("*querier"),
}},
},
Decs: dst.FuncDeclDecorations{
NodeDecs: dst.NodeDecs{
Before: dst.EmptyLine,
After: dst.EmptyLine,
},
},
Body: &dst.BlockStmt{
List: []dst.Stmt{
&dst.ExprStmt{
X: &dst.CallExpr{
Fun: &dst.Ident{
Name: "panic",
},
Args: []dst.Expr{
&dst.BasicLit{
Kind: token.STRING,
Value: "\"Not implemented\"",
},
},
},
},
},
},
Comment on lines +86 to +108
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is 100% fine and correct, however I definitely prefer your other package stub method as it's easier to read and modify.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I'm going to merge these funcs together afterwards

}
}
f.Decls = append(f.Decls, decl)
}

file, err := os.OpenFile("./dbauthz/dbauthz.go", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return xerrors.Errorf("open dbauthz: %w", err)
}
defer file.Close()

// Required to preserve imports!
restorer := decorator.NewRestorerWithImports("dbauthz", guess.New())
restored, err := restorer.RestoreFile(f)
if err != nil {
return xerrors.Errorf("restore dbauthz: %w", err)
}
err = format.Node(file, restorer.Fset, restored)
return err
}

type storeMethod struct {
Name string
Func *dst.FuncType
}

func readStoreInterface() ([]storeMethod, error) {
querier, err := os.ReadFile("./querier.go")
if err != nil {
return nil, xerrors.Errorf("read querier: %w", err)
}
f, err := decorator.Parse(querier)
if err != nil {
return nil, err
}

var sqlcQuerier *dst.InterfaceType
for _, decl := range f.Decls {
genDecl, ok := decl.(*dst.GenDecl)
if !ok {
continue
}

for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*dst.TypeSpec)
if !ok {
continue
}
if typeSpec.Name.Name != "sqlcQuerier" {
continue
}
sqlcQuerier, ok = typeSpec.Type.(*dst.InterfaceType)
if !ok {
return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type)
}
break
}
}
if sqlcQuerier == nil {
return nil, xerrors.Errorf("sqlcQuerier not found")
}
funcs := []storeMethod{}
for _, method := range sqlcQuerier.Methods.List {
funcType, ok := method.Type.(*dst.FuncType)
if !ok {
continue
}

for _, t := range []*dst.FieldList{funcType.Params, funcType.Results} {
if t == nil {
continue
}
for _, f := range t.List {
ident, ok := f.Type.(*dst.Ident)
if !ok {
continue
}
if !ident.IsExported() {
continue
}
ident.Path = "github.com/coder/coder/coderd/database"
}
}

funcs = append(funcs, storeMethod{
Name: method.Names[0].Name,
Func: funcType,
})
}
return funcs, nil
}
3 changes: 3 additions & 0 deletions coderd/database/generate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
go run gen/fake/main.go
go run golang.org/x/tools/cmd/goimports@latest -w ./dbfake/dbfake.go

go run gen/authz/main.go
go run golang.org/x/tools/cmd/goimports@latest -w ./dbauthz/dbauthz.go

go run gen/metrics/main.go
go run golang.org/x/tools/cmd/goimports@latest -w ./dbmetrics/dbmetrics.go
)