Skip to content

chore: unit test to enforce authorized queries match args #11211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 15, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Also check querycontext arguments
  • Loading branch information
Emyrk committed Dec 14, 2023
commit 87bdb57af0a8b07aa686e80ef310313a3bf71fbe
123 changes: 94 additions & 29 deletions coderd/database/gentest/modelqueries_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gentest

import (
"fmt"
"go/ast"
"go/parser"
"go/token"
Expand All @@ -26,51 +27,65 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
"GetWorkspaces": "GetAuthorizedWorkspaces",
"GetUsers": "GetAuthorizedUsers",
}
parsedRowScanArgs := make(map[string][]ast.Expr)

// Scan custom
var custom []string
for _, fn := range funcsToTrack {
custom = append(custom, fn)
}
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "../modelqueries.go", nil, parser.SkipObjectResolution)
require.NoError(t, err)

for _, decl := range f.Decls {
if fn, ok := decl.(*ast.FuncDecl); ok {
if slices.Contains(custom, fn.Name.Name) {
parsedRowScanArgs[fn.Name.String()] = pullRowScanArgs(fn)
}
}
}

fset = token.NewFileSet()
f, err = parser.ParseFile(fset, "../queries.sql.go", nil, parser.SkipObjectResolution)
require.NoError(t, err)

for _, decl := range f.Decls {
if fn, ok := decl.(*ast.FuncDecl); ok {
if _, ok := funcsToTrack[fn.Name.Name]; ok {
parsedRowScanArgs[fn.Name.String()] = pullRowScanArgs(fn)
}
}
customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
return slices.Contains(custom, name)
})
generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
_, ok := funcsToTrack[name]
return ok
})
merged := customFns
for k, v := range generatedFns {
merged[k] = v
}

for a, b := range funcsToTrack {
if !compareArgs(t, a, b, parsedRowScanArgs[a], parsedRowScanArgs[b]) {
if !compareFns(t, a, b, merged[a], merged[b]) {
defer func() {
// Run this at the end so the suggested fix is the last thing printed.
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' in their function bodies. "+
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
"and 'db.QueryContext()' arguments in their function bodies. "+
"Make sure to copy the function body from the autogenerated %q body. "+
"Specifically the parameters for 'rows.Scan()'.", a, b, a)
"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
}()
}
}

}

func compareArgs(t *testing.T, aName, bName string, a, b []ast.Expr) bool {
type parsedFunc struct {
RowScanArgs []ast.Expr
QueryArgs []ast.Expr
}

func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
require.NoErrorf(t, err, "failed to parse file %q", filename)

parsed := make(map[string]*parsedFunc)
for _, decl := range f.Decls {
if fn, ok := decl.(*ast.FuncDecl); ok {
if trackFunc(fn.Name.Name) {
parsed[fn.Name.String()] = &parsedFunc{
RowScanArgs: pullRowScanArgs(fn),
QueryArgs: pullQueryArgs(fn),
}
}
}
}

return parsed
}

func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
if a == nil {
t.Errorf("The function %q is missing", aName)
return false
Expand All @@ -79,7 +94,18 @@ func compareArgs(t *testing.T, aName, bName string, a, b []ast.Expr) bool {
t.Errorf("The function %q is missing", bName)
return false
}
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched args for %s and %s", aName, bName)
r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)
if len(a.QueryArgs) > 2 && len(b.QueryArgs) > 2 {
// This is because the actual query param name is different. One uses the
// const, the other uses a variable that is a mutation of the original query.
a.QueryArgs[1] = b.QueryArgs[1]
}
q := compareArgs(t, "db.QueryContext() arguments", aName, bName, a.QueryArgs, b.QueryArgs)
return r && q
}

func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
}

func argList(t *testing.T, args []ast.Expr) []string {
Expand All @@ -91,12 +117,51 @@ func argList(t *testing.T, args []ast.Expr) []string {

var argNames []string
for _, arg := range args {
sel := arg.(*ast.UnaryExpr).X.(*ast.SelectorExpr).Sel
argNames = append(argNames, sel.Name)
argname := "unknown"
// This is "&i.Arg" style stuff
if unary, ok := arg.(*ast.UnaryExpr); ok {
argname = unary.X.(*ast.SelectorExpr).Sel.Name
}
if ident, ok := arg.(*ast.Ident); ok {
argname = ident.Name
}
if sel, ok := arg.(*ast.SelectorExpr); ok {
argname = sel.Sel.Name
}
if call, ok := arg.(*ast.CallExpr); ok {
// Eh, this is pg.Array style stuff. Do a best effort.
argname = fmt.Sprintf("call(%d)", len(call.Args))
if fnCall, ok := call.Fun.(*ast.SelectorExpr); ok {
argname = fmt.Sprintf("%s(%d)", fnCall.Sel.Name, len(call.Args))
}
}

if argname == "unknown" {
t.Errorf("Unknown arg, cannot parse: %T", arg)
}
argNames = append(argNames, argname)
}
return argNames
}

func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
for _, exp := range fn.Body.List {
// find "rows, err :="
if assign, ok := exp.(*ast.AssignStmt); ok {
if len(assign.Lhs) == 2 {
if id, ok := assign.Lhs[0].(*ast.Ident); ok && id.Name == "rows" {
// This is rows, err :=
query := assign.Rhs[0].(*ast.CallExpr)
if qSel, ok := query.Fun.(*ast.SelectorExpr); ok && qSel.Sel.Name == "QueryContext" {
return query.Args
}
}
}
}
}
return nil
}

func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
for _, exp := range fn.Body.List {
if forStmt, ok := exp.(*ast.ForStmt); ok {
Expand Down