Skip to content

Commit 2a57ea7

Browse files
authored
feat: add audit package (#1046)
1 parent a2dd618 commit 2a57ea7

File tree

8 files changed

+411
-2
lines changed

8 files changed

+411
-2
lines changed

coderd/audit/diff.go

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package audit
2+
3+
import (
4+
"fmt"
5+
"reflect"
6+
)
7+
8+
// TODO: this might need to be in the database package.
9+
type Map map[string]interface{}
10+
11+
func Empty[T Auditable]() T {
12+
var t T
13+
return t
14+
}
15+
16+
// Diff compares two auditable resources and produces a Map of the changed
17+
// values.
18+
func Diff[T Auditable](left, right T) Map {
19+
// Values are equal, return an empty diff.
20+
if reflect.DeepEqual(left, right) {
21+
return Map{}
22+
}
23+
24+
return diffValues(left, right, AuditableResources)
25+
}
26+
27+
func structName(t reflect.Type) string {
28+
return t.PkgPath() + "." + t.Name()
29+
}
30+
31+
func diffValues[T any](left, right T, table Table) Map {
32+
var (
33+
baseDiff = Map{}
34+
35+
leftV = reflect.ValueOf(left)
36+
37+
rightV = reflect.ValueOf(right)
38+
rightT = reflect.TypeOf(right)
39+
40+
diffKey = table[structName(rightT)]
41+
)
42+
43+
if diffKey == nil {
44+
panic(fmt.Sprintf("dev error: type %q (type %T) attempted audit but not auditable", rightT.Name(), right))
45+
}
46+
47+
for i := 0; i < rightT.NumField(); i++ {
48+
var (
49+
leftF = leftV.Field(i)
50+
rightF = rightV.Field(i)
51+
52+
leftI = leftF.Interface()
53+
rightI = rightF.Interface()
54+
55+
diffName = rightT.Field(i).Tag.Get("json")
56+
)
57+
58+
atype, ok := diffKey[diffName]
59+
if !ok {
60+
panic(fmt.Sprintf("dev error: field %q lacks audit information", diffName))
61+
}
62+
63+
if atype == ActionIgnore {
64+
continue
65+
}
66+
67+
// If the field is a pointer, dereference it. Nil pointers are coerced
68+
// to the zero value of their underlying type.
69+
if leftF.Kind() == reflect.Ptr && rightF.Kind() == reflect.Ptr {
70+
leftF, rightF = derefPointer(leftF), derefPointer(rightF)
71+
leftI, rightI = leftF.Interface(), rightF.Interface()
72+
}
73+
74+
// Recursively walk up nested structs.
75+
if rightF.Kind() == reflect.Struct {
76+
baseDiff[diffName] = diffValues(leftI, rightI, table)
77+
continue
78+
}
79+
80+
if !reflect.DeepEqual(leftI, rightI) {
81+
switch atype {
82+
case ActionTrack:
83+
baseDiff[diffName] = rightI
84+
case ActionSecret:
85+
baseDiff[diffName] = reflect.Zero(rightF.Type()).Interface()
86+
}
87+
}
88+
}
89+
90+
return baseDiff
91+
}
92+
93+
// derefPointer deferences a reflect.Value that is a pointer to its underlying
94+
// value. It dereferences recursively until it finds a non-pointer value. If the
95+
// pointer is nil, it will be coerced to the zero value of the underlying type.
96+
func derefPointer(ptr reflect.Value) reflect.Value {
97+
if !ptr.IsNil() {
98+
// Grab the value the pointer references.
99+
ptr = ptr.Elem()
100+
} else {
101+
// Coerce nil ptrs to zero'd values of their underlying type.
102+
ptr = reflect.Zero(ptr.Type().Elem())
103+
}
104+
105+
// Recursively deref nested pointers.
106+
if ptr.Kind() == reflect.Ptr {
107+
return derefPointer(ptr)
108+
}
109+
110+
return ptr
111+
}

coderd/audit/diff_internal_test.go

+163
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package audit
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"k8s.io/utils/pointer"
8+
)
9+
10+
func Test_diffValues(t *testing.T) {
11+
t.Parallel()
12+
13+
t.Run("Normal", func(t *testing.T) {
14+
t.Parallel()
15+
16+
type foo struct {
17+
Bar string `json:"bar"`
18+
Baz int64 `json:"baz"`
19+
}
20+
21+
table := auditMap(map[any]map[string]Action{
22+
&foo{}: {
23+
"bar": ActionTrack,
24+
"baz": ActionTrack,
25+
},
26+
})
27+
28+
runDiffTests(t, table, []diffTest{
29+
{
30+
name: "LeftEmpty",
31+
left: foo{Bar: "", Baz: 0}, right: foo{Bar: "bar", Baz: 10},
32+
exp: Map{
33+
"bar": "bar",
34+
"baz": int64(10),
35+
},
36+
},
37+
{
38+
name: "RightEmpty",
39+
left: foo{Bar: "Bar", Baz: 10}, right: foo{Bar: "", Baz: 0},
40+
exp: Map{
41+
"bar": "",
42+
"baz": int64(0),
43+
},
44+
},
45+
{
46+
name: "NoChange",
47+
left: foo{Bar: "", Baz: 0}, right: foo{Bar: "", Baz: 0},
48+
exp: Map{},
49+
},
50+
{
51+
name: "SingleFieldChange",
52+
left: foo{Bar: "", Baz: 0}, right: foo{Bar: "Bar", Baz: 0},
53+
exp: Map{
54+
"bar": "Bar",
55+
},
56+
},
57+
})
58+
})
59+
60+
t.Run("PointerField", func(t *testing.T) {
61+
t.Parallel()
62+
63+
type foo struct {
64+
Bar *string `json:"bar"`
65+
}
66+
67+
table := auditMap(map[any]map[string]Action{
68+
&foo{}: {
69+
"bar": ActionTrack,
70+
},
71+
})
72+
73+
runDiffTests(t, table, []diffTest{
74+
{
75+
name: "LeftNil",
76+
left: foo{Bar: nil}, right: foo{Bar: pointer.StringPtr("baz")},
77+
exp: Map{"bar": "baz"},
78+
},
79+
{
80+
name: "RightNil",
81+
left: foo{Bar: pointer.StringPtr("baz")}, right: foo{Bar: nil},
82+
exp: Map{"bar": ""},
83+
},
84+
})
85+
})
86+
87+
t.Run("NestedStruct", func(t *testing.T) {
88+
t.Parallel()
89+
90+
type bar struct {
91+
Baz string `json:"baz"`
92+
}
93+
94+
type foo struct {
95+
Bar *bar `json:"bar"`
96+
}
97+
98+
table := auditMap(map[any]map[string]Action{
99+
&foo{}: {
100+
"bar": ActionTrack,
101+
},
102+
&bar{}: {
103+
"baz": ActionTrack,
104+
},
105+
})
106+
107+
runDiffTests(t, table, []diffTest{
108+
{
109+
name: "LeftEmpty",
110+
left: foo{Bar: &bar{}}, right: foo{Bar: &bar{Baz: "baz"}},
111+
exp: Map{
112+
"bar": Map{
113+
"baz": "baz",
114+
},
115+
},
116+
},
117+
{
118+
name: "RightEmpty",
119+
left: foo{Bar: &bar{Baz: "baz"}}, right: foo{Bar: &bar{}},
120+
exp: Map{
121+
"bar": Map{
122+
"baz": "",
123+
},
124+
},
125+
},
126+
{
127+
name: "LeftNil",
128+
left: foo{Bar: nil}, right: foo{Bar: &bar{}},
129+
exp: Map{
130+
"bar": Map{},
131+
},
132+
},
133+
{
134+
name: "RightNil",
135+
left: foo{Bar: &bar{Baz: "baz"}}, right: foo{Bar: nil},
136+
exp: Map{
137+
"bar": Map{
138+
"baz": "",
139+
},
140+
},
141+
},
142+
})
143+
})
144+
}
145+
146+
type diffTest struct {
147+
name string
148+
left, right any
149+
exp any
150+
}
151+
152+
func runDiffTests(t *testing.T, table Table, tests []diffTest) {
153+
t.Helper()
154+
155+
for _, test := range tests {
156+
t.Run(test.name, func(t *testing.T) {
157+
assert.Equal(t,
158+
test.exp,
159+
diffValues(test.left, test.right, table),
160+
)
161+
})
162+
}
163+
}

coderd/audit/diff_test.go

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package audit_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
8+
"github.com/coder/coder/coderd/audit"
9+
"github.com/coder/coder/coderd/database"
10+
)
11+
12+
func TestDiff(t *testing.T) {
13+
t.Parallel()
14+
15+
t.Run("Normal", func(t *testing.T) {
16+
t.Parallel()
17+
18+
runDiffTests(t, []diffTest[database.User]{
19+
{
20+
name: "LeftEmpty",
21+
left: audit.Empty[database.User](), right: database.User{Username: "colin", Email: "colin@coder.com"},
22+
exp: audit.Map{
23+
"email": "colin@coder.com",
24+
},
25+
},
26+
{
27+
name: "RightEmpty",
28+
left: database.User{Username: "colin", Email: "colin@coder.com"}, right: audit.Empty[database.User](),
29+
exp: audit.Map{
30+
"email": "",
31+
},
32+
},
33+
{
34+
name: "NoChange",
35+
left: audit.Empty[database.User](), right: audit.Empty[database.User](),
36+
exp: audit.Map{},
37+
},
38+
})
39+
})
40+
}
41+
42+
type diffTest[T audit.Auditable] struct {
43+
name string
44+
left, right T
45+
exp audit.Map
46+
}
47+
48+
func runDiffTests[T audit.Auditable](t *testing.T, tests []diffTest[T]) {
49+
t.Helper()
50+
51+
for _, test := range tests {
52+
t.Run(test.name, func(t *testing.T) {
53+
require.Equal(t,
54+
test.exp,
55+
audit.Diff(test.left, test.right),
56+
)
57+
})
58+
}
59+
}

0 commit comments

Comments
 (0)