Skip to content

Commit f0eddba

Browse files
Kira-PilotEmyrk
andauthored
chore: Support anonymously embedded fields for audit diffs (coder#5746)
- Anonymously embedded structs are expanded as top level fields. - Unit tests for anonymously embedded structs Co-authored-by: Steven Masley <stevenmasley@coder.com>
1 parent e37bff6 commit f0eddba

File tree

2 files changed

+152
-13
lines changed

2 files changed

+152
-13
lines changed

enterprise/audit/diff.go

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"reflect"
77

88
"github.com/google/uuid"
9+
"golang.org/x/xerrors"
910

1011
"github.com/coder/coder/coderd/audit"
1112
"github.com/coder/coder/coderd/database"
@@ -18,11 +19,9 @@ func structName(t reflect.Type) string {
1819
func diffValues(left, right any, table Table) audit.Map {
1920
var (
2021
baseDiff = audit.Map{}
21-
22-
leftV = reflect.ValueOf(left)
23-
24-
rightV = reflect.ValueOf(right)
25-
rightT = reflect.TypeOf(right)
22+
rightT = reflect.TypeOf(right)
23+
leftV = reflect.ValueOf(left)
24+
rightV = reflect.ValueOf(right)
2625

2726
diffKey = table[structName(rightT)]
2827
)
@@ -31,19 +30,25 @@ func diffValues(left, right any, table Table) audit.Map {
3130
panic(fmt.Sprintf("dev error: type %q (type %T) attempted audit but not auditable", rightT.Name(), right))
3231
}
3332

34-
for i := 0; i < rightT.NumField(); i++ {
35-
if !rightT.Field(i).IsExported() {
36-
continue
37-
}
33+
// allFields contains all top level fields of the struct.
34+
allFields, err := flattenStructFields(leftV, rightV)
35+
if err != nil {
36+
// This should never happen. Only structs should be flattened. If an
37+
// error occurs, an unsupported or non-struct type was passed in.
38+
panic(fmt.Sprintf("dev error: failed to flatten struct fields: %v", err))
39+
}
3840

41+
for _, field := range allFields {
3942
var (
40-
leftF = leftV.Field(i)
41-
rightF = rightV.Field(i)
43+
leftF = field.LeftF
44+
rightF = field.RightF
4245

4346
leftI = leftF.Interface()
4447
rightI = rightF.Interface()
48+
)
4549

46-
diffName = rightT.Field(i).Tag.Get("json")
50+
var (
51+
diffName = field.FieldType.Tag.Get("json")
4752
)
4853

4954
atype, ok := diffKey[diffName]
@@ -145,6 +150,64 @@ func convertDiffType(left, right any) (newLeft, newRight any, changed bool) {
145150
}
146151
}
147152

153+
// fieldDiff has all the required information to return an audit diff for a
154+
// given field.
155+
type fieldDiff struct {
156+
FieldType reflect.StructField
157+
LeftF reflect.Value
158+
RightF reflect.Value
159+
}
160+
161+
// flattenStructFields will return all top level fields for a given structure.
162+
// Only anonymously embedded structs will be recursively flattened such that their
163+
// fields are returned as top level fields. Named nested structs will be returned
164+
// as a single field.
165+
// Conflicting field names need to be handled by the caller.
166+
func flattenStructFields(leftV, rightV reflect.Value) ([]fieldDiff, error) {
167+
// Dereference pointers if the field is a pointer field.
168+
if leftV.Kind() == reflect.Ptr {
169+
leftV = derefPointer(leftV)
170+
rightV = derefPointer(rightV)
171+
}
172+
173+
if leftV.Kind() != reflect.Struct {
174+
return nil, xerrors.Errorf("%q is not a struct, kind=%s", leftV.String(), leftV.Kind())
175+
}
176+
177+
var allFields []fieldDiff
178+
rightT := rightV.Type()
179+
180+
// Loop through all top level fields of the struct.
181+
for i := 0; i < rightT.NumField(); i++ {
182+
if !rightT.Field(i).IsExported() {
183+
continue
184+
}
185+
186+
var (
187+
leftF = leftV.Field(i)
188+
rightF = rightV.Field(i)
189+
)
190+
191+
if rightT.Field(i).Anonymous {
192+
// Anonymous fields are recursively flattened.
193+
anonFields, err := flattenStructFields(leftF, rightF)
194+
if err != nil {
195+
return nil, xerrors.Errorf("flatten anonymous field %q: %w", rightT.Field(i).Name, err)
196+
}
197+
allFields = append(allFields, anonFields...)
198+
continue
199+
}
200+
201+
// Single fields append as is.
202+
allFields = append(allFields, fieldDiff{
203+
LeftF: leftF,
204+
RightF: rightF,
205+
FieldType: rightT.Field(i),
206+
})
207+
}
208+
return allFields, nil
209+
}
210+
148211
// derefPointer deferences a reflect.Value that is a pointer to its underlying
149212
// value. It dereferences recursively until it finds a non-pointer value. If the
150213
// pointer is nil, it will be coerced to the zero value of the underlying type.

enterprise/audit/diff_internal_test.go

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ func Test_diffValues(t *testing.T) {
6666
})
6767
})
6868

69-
//nolint:revive
7069
t.Run("PointerField", func(t *testing.T) {
7170
t.Parallel()
7271

@@ -98,6 +97,83 @@ func Test_diffValues(t *testing.T) {
9897
})
9998
})
10099

100+
//nolint:revive
101+
t.Run("EmbeddedStruct", func(t *testing.T) {
102+
t.Parallel()
103+
104+
type Bar struct {
105+
Baz int `json:"baz"`
106+
Buzz string `json:"buzz"`
107+
}
108+
109+
type PtrBar struct {
110+
Qux string `json:"qux"`
111+
}
112+
113+
type foo struct {
114+
Bar
115+
*PtrBar
116+
TopLevel string `json:"top_level"`
117+
}
118+
119+
table := auditMap(map[any]map[string]Action{
120+
&foo{}: {
121+
"baz": ActionTrack,
122+
"buzz": ActionTrack,
123+
"qux": ActionTrack,
124+
"top_level": ActionTrack,
125+
},
126+
})
127+
128+
runDiffValuesTests(t, table, []diffTest{
129+
{
130+
name: "SingleFieldChange",
131+
left: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}},
132+
right: foo{TopLevel: "top-after", Bar: Bar{Baz: 0, Buzz: "after"}, PtrBar: &PtrBar{Qux: "qux-after"}},
133+
exp: audit.Map{
134+
"baz": audit.OldNew{Old: 1, New: 0},
135+
"buzz": audit.OldNew{Old: "before", New: "after"},
136+
"qux": audit.OldNew{Old: "qux-before", New: "qux-after"},
137+
"top_level": audit.OldNew{Old: "top-before", New: "top-after"},
138+
},
139+
},
140+
{
141+
name: "Empty",
142+
left: foo{},
143+
right: foo{},
144+
exp: audit.Map{},
145+
},
146+
{
147+
name: "NoChange",
148+
left: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}},
149+
right: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}},
150+
exp: audit.Map{},
151+
},
152+
{
153+
name: "LeftEmpty",
154+
left: foo{},
155+
right: foo{TopLevel: "top-after", Bar: Bar{Baz: 1, Buzz: "after"}, PtrBar: &PtrBar{Qux: "qux-after"}},
156+
exp: audit.Map{
157+
"baz": audit.OldNew{Old: 0, New: 1},
158+
"buzz": audit.OldNew{Old: "", New: "after"},
159+
"qux": audit.OldNew{Old: "", New: "qux-after"},
160+
"top_level": audit.OldNew{Old: "", New: "top-after"},
161+
},
162+
},
163+
{
164+
name: "RightNil",
165+
left: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}},
166+
right: foo{},
167+
exp: audit.Map{
168+
"baz": audit.OldNew{Old: 1, New: 0},
169+
"buzz": audit.OldNew{Old: "before", New: ""},
170+
"qux": audit.OldNew{Old: "qux-before", New: ""},
171+
"top_level": audit.OldNew{Old: "top-before", New: ""},
172+
},
173+
},
174+
})
175+
})
176+
101177
// We currently don't support nested structs.
102178
// t.Run("NestedStruct", func(t *testing.T) {
103179
// t.Parallel()

0 commit comments

Comments
 (0)