Skip to content

chore: unit test to enforce authorized queries match args #11211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 15, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
chore: unit test to enforce authorized queries match args
  • Loading branch information
Emyrk committed Dec 14, 2023
commit 00c61960434e6e6609cd07ea534a76524c7f0ce9
116 changes: 116 additions & 0 deletions coderd/database/gentest/modelqueries_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package gentest

import (
"go/ast"
"go/parser"
"go/token"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)

// TestCustomQueriesSynced makes sure the manual custom queries in modelqueries.go
// are synced with the autogenerated queries.sql.go. This should probably be
// autogenerated, but it's not atm and this is easy to throw in to elevate a better
// error message.
//
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
// test. Ping @Emyrk to fix it again.
func TestCustomQueriesSyncedRowScan(t *testing.T) {
t.Parallel()

funcsToTrack := map[string]string{
"GetTemplates": "GetAuthorizedTemplates",
"GetWorkspaces": "GetAuthorizedWorkspaces",
"GetUsers": "GetAuthorizedUsers",
}
parsedRowScanArgs := make(map[string][]ast.Expr)

// Scan custom
var custom []string
for _, fn := range funcsToTrack {
custom = append(custom, fn)
}
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "../modelqueries.go", nil, parser.SkipObjectResolution)
require.NoError(t, err)

for _, decl := range f.Decls {
if fn, ok := decl.(*ast.FuncDecl); ok {
if slices.Contains(custom, fn.Name.Name) {
parsedRowScanArgs[fn.Name.String()] = pullRowScanArgs(fn)
}
}
}

fset = token.NewFileSet()
f, err = parser.ParseFile(fset, "../queries.sql.go", nil, parser.SkipObjectResolution)
require.NoError(t, err)

for _, decl := range f.Decls {
if fn, ok := decl.(*ast.FuncDecl); ok {
if _, ok := funcsToTrack[fn.Name.Name]; ok {
parsedRowScanArgs[fn.Name.String()] = pullRowScanArgs(fn)
}
}
}

for a, b := range funcsToTrack {
if !compareArgs(t, a, b, parsedRowScanArgs[a], parsedRowScanArgs[b]) {
defer func() {
// Run this at the end so the suggested fix is the last thing printed.
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' in their function bodies. "+
"Make sure to copy the function body from the autogenerated %q body. "+
"Specifically the parameters for 'rows.Scan()'.", a, b, a)
}()
}
}

}

func compareArgs(t *testing.T, aName, bName string, a, b []ast.Expr) bool {
if a == nil {
t.Errorf("The function %q is missing", aName)
return false
}
if b == nil {
t.Errorf("The function %q is missing", bName)
return false
}
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched args for %s and %s", aName, bName)
}

func argList(t *testing.T, args []ast.Expr) []string {
defer func() {
if r := recover(); r != nil {
t.Errorf("Recovered in f reading arg names: %s", r)
}
}()

var argNames []string
for _, arg := range args {
sel := arg.(*ast.UnaryExpr).X.(*ast.SelectorExpr).Sel
argNames = append(argNames, sel.Name)
}
return argNames
}

func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
for _, exp := range fn.Body.List {
if forStmt, ok := exp.(*ast.ForStmt); ok {
// This came from the debugger window and tracking it down.
rowScan := (forStmt.Body.
// Second statement in the for loop is the if statement
// with rows.can
List[1].(*ast.IfStmt).
// This is the err := rows.Scan()
Init.(*ast.AssignStmt).
// Rhs is the row.Scan part
Rhs)[0].(*ast.CallExpr)
return rowScan.Args
}
}
return nil
}