Skip to content

chore: unit test to enforce authorized queries match args #11211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions coderd/database/gentest/modelqueries_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package gentest_test

import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)

// TestCustomQueriesSynced makes sure the manual custom queries in modelqueries.go
// are synced with the autogenerated queries.sql.go. This should probably be
// autogenerated, but it's not atm and this is easy to throw in to elevate a better
// error message.
//
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
// test. Ping @Emyrk to fix it again.
func TestCustomQueriesSyncedRowScan(t *testing.T) {
t.Parallel()

funcsToTrack := map[string]string{
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
"GetWorkspaces": "GetAuthorizedWorkspaces",
"GetUsers": "GetAuthorizedUsers",
}

// Scan custom
var custom []string
for _, fn := range funcsToTrack {
custom = append(custom, fn)
}

customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
return slices.Contains(custom, name)
})
generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
_, ok := funcsToTrack[name]
return ok
})
merged := customFns
for k, v := range generatedFns {
merged[k] = v
}

for a, b := range funcsToTrack {
a, b := a, b
if !compareFns(t, a, b, merged[a], merged[b]) {
//nolint:revive
defer func() {
// Run this at the end so the suggested fix is the last thing printed.
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
"and 'db.QueryContext()' arguments in their function bodies. "+
"Make sure to copy the function body from the autogenerated %q body. "+
"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
}()
}
}
}

type parsedFunc struct {
RowScanArgs []ast.Expr
QueryArgs []ast.Expr
}

func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
require.NoErrorf(t, err, "failed to parse file %q", filename)

parsed := make(map[string]*parsedFunc)
for _, decl := range f.Decls {
if fn, ok := decl.(*ast.FuncDecl); ok {
if trackFunc(fn.Name.Name) {
parsed[fn.Name.String()] = &parsedFunc{
RowScanArgs: pullRowScanArgs(fn),
QueryArgs: pullQueryArgs(fn),
}
}
}
}

return parsed
}

func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
if a == nil {
t.Errorf("The function %q is missing", aName)
return false
}
if b == nil {
t.Errorf("The function %q is missing", bName)
return false
}
r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)
if len(a.QueryArgs) > 2 && len(b.QueryArgs) > 2 {
// This is because the actual query param name is different. One uses the
// const, the other uses a variable that is a mutation of the original query.
a.QueryArgs[1] = b.QueryArgs[1]
}
q := compareArgs(t, "db.QueryContext() arguments", aName, bName, a.QueryArgs, b.QueryArgs)
return r && q
}

func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
}

func argList(t *testing.T, args []ast.Expr) []string {
defer func() {
if r := recover(); r != nil {
t.Errorf("Recovered in f reading arg names: %s", r)
}
}()

var argNames []string
for _, arg := range args {
argname := "unknown"
// This is "&i.Arg" style stuff
if unary, ok := arg.(*ast.UnaryExpr); ok {
argname = unary.X.(*ast.SelectorExpr).Sel.Name
}
if ident, ok := arg.(*ast.Ident); ok {
argname = ident.Name
}
if sel, ok := arg.(*ast.SelectorExpr); ok {
argname = sel.Sel.Name
}
if call, ok := arg.(*ast.CallExpr); ok {
// Eh, this is pg.Array style stuff. Do a best effort.
argname = fmt.Sprintf("call(%d)", len(call.Args))
if fnCall, ok := call.Fun.(*ast.SelectorExpr); ok {
argname = fmt.Sprintf("%s(%d)", fnCall.Sel.Name, len(call.Args))
}
}

if argname == "unknown" {
t.Errorf("Unknown arg, cannot parse: %T", arg)
}
argNames = append(argNames, argname)
}
return argNames
}

func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
for _, exp := range fn.Body.List {
// find "rows, err :="
if assign, ok := exp.(*ast.AssignStmt); ok {
if len(assign.Lhs) == 2 {
if id, ok := assign.Lhs[0].(*ast.Ident); ok && id.Name == "rows" {
// This is rows, err :=
query := assign.Rhs[0].(*ast.CallExpr)
if qSel, ok := query.Fun.(*ast.SelectorExpr); ok && qSel.Sel.Name == "QueryContext" {
return query.Args
}
}
}
}
}
return nil
}

func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
for _, exp := range fn.Body.List {
if forStmt, ok := exp.(*ast.ForStmt); ok {
// This came from the debugger window and tracking it down.
rowScan := (forStmt.Body.
// Second statement in the for loop is the if statement
// with rows.can
List[1].(*ast.IfStmt).
// This is the err := rows.Scan()
Init.(*ast.AssignStmt).
// Rhs is the row.Scan part
Rhs)[0].(*ast.CallExpr)
return rowScan.Args
}
}
return nil
}