|
1 | 1 | package audit
|
2 | 2 |
|
3 | 3 | import (
|
4 |
| - "database/sql" |
5 |
| - "fmt" |
6 |
| - "reflect" |
7 |
| - |
8 |
| - "github.com/google/uuid" |
| 4 | + "github.com/coder/coder/coderd/database" |
9 | 5 | )
|
10 | 6 |
|
11 |
| -// TODO: this might need to be in the database package. |
12 |
| -type Map map[string]interface{} |
| 7 | +// Auditable is mostly a marker interface. It contains a definitive list of all |
| 8 | +// auditable types. If you want to audit a new type, first define it in |
| 9 | +// AuditableResources, then add it to this interface. |
| 10 | +type Auditable interface { |
| 11 | + database.APIKey | |
| 12 | + database.Organization | |
| 13 | + database.OrganizationMember | |
| 14 | + database.Template | |
| 15 | + database.TemplateVersion | |
| 16 | + database.User | |
| 17 | + database.Workspace | |
| 18 | + database.GitSSHKey |
| 19 | +} |
| 20 | + |
| 21 | +// Map is a map of changed fields in an audited resource. It maps field names to |
| 22 | +// the old and new value for that field. |
| 23 | +type Map map[string]OldNew |
| 24 | + |
| 25 | +// OldNew is a pair of values representing the old value and the new value. |
| 26 | +type OldNew struct { |
| 27 | + Old any |
| 28 | + New any |
| 29 | + Secret bool |
| 30 | +} |
13 | 31 |
|
| 32 | +// Empty returns a default value of type T. |
14 | 33 | func Empty[T Auditable]() T {
|
15 | 34 | var t T
|
16 | 35 | return t
|
17 | 36 | }
|
18 | 37 |
|
19 | 38 | // Diff compares two auditable resources and produces a Map of the changed
|
20 | 39 | // values.
|
21 |
| -func Diff[T Auditable](left, right T) Map { |
22 |
| - // Values are equal, return an empty diff. |
23 |
| - if reflect.DeepEqual(left, right) { |
24 |
| - return Map{} |
25 |
| - } |
26 |
| - |
27 |
| - return diffValues(left, right, AuditableResources) |
28 |
| -} |
29 |
| - |
30 |
| -func structName(t reflect.Type) string { |
31 |
| - return t.PkgPath() + "." + t.Name() |
32 |
| -} |
33 |
| - |
34 |
| -func diffValues[T any](left, right T, table Table) Map { |
35 |
| - var ( |
36 |
| - baseDiff = Map{} |
37 |
| - |
38 |
| - leftV = reflect.ValueOf(left) |
39 |
| - |
40 |
| - rightV = reflect.ValueOf(right) |
41 |
| - rightT = reflect.TypeOf(right) |
42 |
| - |
43 |
| - diffKey = table[structName(rightT)] |
44 |
| - ) |
45 |
| - |
46 |
| - if diffKey == nil { |
47 |
| - panic(fmt.Sprintf("dev error: type %q (type %T) attempted audit but not auditable", rightT.Name(), right)) |
48 |
| - } |
49 |
| - |
50 |
| - for i := 0; i < rightT.NumField(); i++ { |
51 |
| - var ( |
52 |
| - leftF = leftV.Field(i) |
53 |
| - rightF = rightV.Field(i) |
54 |
| - |
55 |
| - leftI = leftF.Interface() |
56 |
| - rightI = rightF.Interface() |
57 |
| - |
58 |
| - diffName = rightT.Field(i).Tag.Get("json") |
59 |
| - ) |
60 |
| - |
61 |
| - atype, ok := diffKey[diffName] |
62 |
| - if !ok { |
63 |
| - panic(fmt.Sprintf("dev error: field %q lacks audit information", diffName)) |
64 |
| - } |
65 |
| - |
66 |
| - if atype == ActionIgnore { |
67 |
| - continue |
68 |
| - } |
69 |
| - |
70 |
| - // coerce struct types that would produce bad diffs. |
71 |
| - if leftI, rightI, ok = convertDiffType(leftI, rightI); ok { |
72 |
| - leftF, rightF = reflect.ValueOf(leftI), reflect.ValueOf(rightI) |
73 |
| - } |
74 |
| - |
75 |
| - // If the field is a pointer, dereference it. Nil pointers are coerced |
76 |
| - // to the zero value of their underlying type. |
77 |
| - if leftF.Kind() == reflect.Ptr && rightF.Kind() == reflect.Ptr { |
78 |
| - leftF, rightF = derefPointer(leftF), derefPointer(rightF) |
79 |
| - leftI, rightI = leftF.Interface(), rightF.Interface() |
80 |
| - } |
81 |
| - |
82 |
| - // Recursively walk up nested structs. |
83 |
| - if rightF.Kind() == reflect.Struct { |
84 |
| - baseDiff[diffName] = diffValues(leftI, rightI, table) |
85 |
| - continue |
86 |
| - } |
87 |
| - |
88 |
| - if !reflect.DeepEqual(leftI, rightI) { |
89 |
| - switch atype { |
90 |
| - case ActionTrack: |
91 |
| - baseDiff[diffName] = rightI |
92 |
| - case ActionSecret: |
93 |
| - baseDiff[diffName] = reflect.Zero(rightF.Type()).Interface() |
94 |
| - } |
95 |
| - } |
96 |
| - } |
97 |
| - |
98 |
| - return baseDiff |
99 |
| -} |
100 |
| - |
101 |
| -// convertDiffType converts external struct types to primitive types. |
102 |
| -// |
103 |
| -//nolint:forcetypeassert |
104 |
| -func convertDiffType(left, right any) (newLeft, newRight any, changed bool) { |
105 |
| - switch typed := left.(type) { |
106 |
| - case uuid.UUID: |
107 |
| - return typed.String(), right.(uuid.UUID).String(), true |
108 |
| - |
109 |
| - case uuid.NullUUID: |
110 |
| - leftStr, _ := typed.MarshalText() |
111 |
| - rightStr, _ := right.(uuid.NullUUID).MarshalText() |
112 |
| - return string(leftStr), string(rightStr), true |
113 |
| - |
114 |
| - case sql.NullString: |
115 |
| - leftStr := typed.String |
116 |
| - if !typed.Valid { |
117 |
| - leftStr = "null" |
118 |
| - } |
119 |
| - |
120 |
| - rightStr := right.(sql.NullString).String |
121 |
| - if !right.(sql.NullString).Valid { |
122 |
| - rightStr = "null" |
123 |
| - } |
124 |
| - |
125 |
| - return leftStr, rightStr, true |
126 |
| - |
127 |
| - case sql.NullInt64: |
128 |
| - var leftInt64Ptr *int64 |
129 |
| - var rightInt64Ptr *int64 |
130 |
| - if !typed.Valid { |
131 |
| - leftInt64Ptr = nil |
132 |
| - } else { |
133 |
| - leftInt64Ptr = ptr(typed.Int64) |
134 |
| - } |
135 |
| - |
136 |
| - rightInt64Ptr = ptr(right.(sql.NullInt64).Int64) |
137 |
| - if !right.(sql.NullInt64).Valid { |
138 |
| - rightInt64Ptr = nil |
139 |
| - } |
140 |
| - |
141 |
| - return leftInt64Ptr, rightInt64Ptr, true |
142 |
| - |
143 |
| - default: |
144 |
| - return left, right, false |
145 |
| - } |
146 |
| -} |
147 |
| - |
148 |
| -// derefPointer deferences a reflect.Value that is a pointer to its underlying |
149 |
| -// value. It dereferences recursively until it finds a non-pointer value. If the |
150 |
| -// pointer is nil, it will be coerced to the zero value of the underlying type. |
151 |
| -func derefPointer(ptr reflect.Value) reflect.Value { |
152 |
| - if !ptr.IsNil() { |
153 |
| - // Grab the value the pointer references. |
154 |
| - ptr = ptr.Elem() |
155 |
| - } else { |
156 |
| - // Coerce nil ptrs to zero'd values of their underlying type. |
157 |
| - ptr = reflect.Zero(ptr.Type().Elem()) |
158 |
| - } |
159 |
| - |
160 |
| - // Recursively deref nested pointers. |
161 |
| - if ptr.Kind() == reflect.Ptr { |
162 |
| - return derefPointer(ptr) |
163 |
| - } |
| 40 | +func Diff[T Auditable](a Auditor, left, right T) Map { return a.diff(left, right) } |
164 | 41 |
|
165 |
| - return ptr |
| 42 | +// Differ is used so the enterprise version can implement the diff function in |
| 43 | +// the Auditor feature interface. Only types in the same package as the |
| 44 | +// interface can implement unexported methods. |
| 45 | +type Differ struct { |
| 46 | + DiffFn func(old, new any) Map |
166 | 47 | }
|
167 | 48 |
|
168 |
| -func ptr[T any](x T) *T { |
169 |
| - return &x |
| 49 | +//nolint:unused |
| 50 | +func (d Differ) diff(old, new any) Map { |
| 51 | + return d.DiffFn(old, new) |
170 | 52 | }
|
0 commit comments