@@ -2,12 +2,15 @@ package authzquery_test
2
2
3
3
import (
4
4
"context"
5
+ "database/sql"
5
6
"fmt"
6
7
"reflect"
7
8
"sort"
8
9
"strings"
9
10
"testing"
10
11
12
+ "golang.org/x/xerrors"
13
+
11
14
"github.com/coder/coder/coderd/rbac/regosql"
12
15
13
16
"github.com/google/uuid"
@@ -95,10 +98,11 @@ func (s *MethodTestSuite) RunMethodTest(testCaseF func(t *testing.T, db database
95
98
s .methodAccounting [methodName ]++
96
99
97
100
db := dbfake .New ()
101
+ fakeAuthorizer := & coderdtest.FakeAuthorizer {
102
+ AlwaysReturn : nil ,
103
+ }
98
104
rec := & coderdtest.RecordingAuthorizer {
99
- Wrapped : & coderdtest.FakeAuthorizer {
100
- AlwaysReturn : nil ,
101
- },
105
+ Wrapped : fakeAuthorizer ,
102
106
}
103
107
az := authzquery .NewAuthzQuerier (db , rec , slog .Make ())
104
108
actor := rbac.Subject {
@@ -118,22 +122,25 @@ MethodLoop:
118
122
for i := 0 ; i < azt .NumMethod (); i ++ {
119
123
method := azt .Method (i )
120
124
if method .Name == methodName {
125
+ if len (testCase .Assertions ) > 0 {
126
+ fakeAuthorizer .AlwaysReturn = xerrors .New ("Always fail authz" )
127
+ // If we have assertions, that means the method should FAIL
128
+ // if RBAC will disallow the request. The returned error should
129
+ // be expected to be a NotAuthorizedError.
130
+ erroredResp := reflect .ValueOf (az ).Method (i ).Call (append ([]reflect.Value {reflect .ValueOf (ctx )}, testCase .Inputs ... ))
131
+ err := findError (t , erroredResp )
132
+ require .Errorf (t , err , "method %q should an error with disallow authz" , testName )
133
+ require .ErrorIsf (t , err , sql .ErrNoRows , "error should match sql.ErrNoRows" )
134
+ require .ErrorAs (t , err , & authzquery.NotAuthorizedError {}, "error should be NotAuthorizedError" )
135
+ // Set things back to normal.
136
+ fakeAuthorizer .AlwaysReturn = nil
137
+ rec .Reset ()
138
+ }
139
+
121
140
resp := reflect .ValueOf (az ).Method (i ).Call (append ([]reflect.Value {reflect .ValueOf (ctx )}, testCase .Inputs ... ))
122
141
// TODO: Should we assert the object returned is the correct one?
123
- for _ , r := range resp {
124
- if r .Type ().Implements (reflect .TypeOf ((* error )(nil )).Elem ()) {
125
- if r .IsNil () {
126
- // no error!
127
- break
128
- }
129
- err , ok := r .Interface ().(error )
130
- if ! ok {
131
- t .Fatal ("error is not an error?!" )
132
- }
133
- require .NoError (t , err , "method %q returned an error" , testName )
134
- break
135
- }
136
- }
142
+ err := findError (t , resp )
143
+ require .NoError (t , err , "method %q returned an error" , testName )
137
144
found = true
138
145
break MethodLoop
139
146
}
@@ -155,6 +162,24 @@ MethodLoop:
155
162
require .NoError (t , rec .AllAsserted (), "all rbac calls must be asserted" )
156
163
}
157
164
165
+ func findError (t * testing.T , values []reflect.Value ) error {
166
+ for _ , r := range values {
167
+ if r .Type ().Implements (reflect .TypeOf ((* error )(nil )).Elem ()) {
168
+ if r .IsNil () {
169
+ // Error is found, but it's nil!
170
+ return nil
171
+ }
172
+ err , ok := r .Interface ().(error )
173
+ if ! ok {
174
+ t .Fatal ("error is not an error?!" )
175
+ }
176
+ return err
177
+ }
178
+ }
179
+ t .Fatal ("no expected error value found in responses (error can be nil)" )
180
+ panic ("unreachable" ) // For compile reasons
181
+ }
182
+
158
183
// A MethodCase contains the inputs to be provided to a single method call,
159
184
// and the assertions to be made on the RBAC checks.
160
185
type MethodCase struct {
0 commit comments