Skip to content

Commit 15e4de1

Browse files
committed
chore: First attempt at upgrading bad validation errors
1 parent 96ff400 commit 15e4de1

File tree

4 files changed

+154
-15
lines changed

4 files changed

+154
-15
lines changed

coderd/httpapi/httpapi.go

+21-15
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,24 @@ import (
1212
"strings"
1313
"time"
1414

15+
"github.com/coder/coder/coderd/httpapi/validate"
16+
1517
"github.com/go-playground/validator/v10"
1618
"golang.org/x/xerrors"
1719

1820
"github.com/coder/coder/coderd/tracing"
1921
"github.com/coder/coder/codersdk"
2022
)
2123

22-
var Validate *validator.Validate
24+
var Validate *validate.Validator
2325

2426
// This init is used to create a validator and register validation-specific
2527
// functionality for the HTTP API.
2628
//
2729
// A single validator instance is used, because it caches struct parsing.
2830
func init() {
29-
Validate = validator.New()
31+
Validate = validate.New()
32+
3033
Validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
3134
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
3235
if name == "-" {
@@ -35,14 +38,13 @@ func init() {
3538
return name
3639
})
3740

38-
nameValidator := func(fl validator.FieldLevel) bool {
41+
nameValidator := func(fl validator.FieldLevel) error {
3942
f := fl.Field().Interface()
4043
str, ok := f.(string)
4144
if !ok {
42-
return false
45+
return xerrors.New("name must be a string")
4346
}
44-
valid := NameValid(str)
45-
return valid == nil
47+
return NameValid(str)
4648
}
4749
for _, tag := range []string{"username", "template_name", "workspace_name"} {
4850
err := Validate.RegisterValidation(tag, nameValidator)
@@ -51,28 +53,26 @@ func init() {
5153
}
5254
}
5355

54-
templateDisplayNameValidator := func(fl validator.FieldLevel) bool {
56+
templateDisplayNameValidator := func(fl validator.FieldLevel) error {
5557
f := fl.Field().Interface()
5658
str, ok := f.(string)
5759
if !ok {
58-
return false
60+
return xerrors.New("template_display_name must be a string")
5961
}
60-
valid := TemplateDisplayNameValid(str)
61-
return valid == nil
62+
return TemplateDisplayNameValid(str)
6263
}
6364
err := Validate.RegisterValidation("template_display_name", templateDisplayNameValidator)
6465
if err != nil {
6566
panic(err)
6667
}
6768

68-
templateVersionNameValidator := func(fl validator.FieldLevel) bool {
69+
templateVersionNameValidator := func(fl validator.FieldLevel) error {
6970
f := fl.Field().Interface()
7071
str, ok := f.(string)
7172
if !ok {
72-
return false
73+
return xerrors.New("template_version_name must be a string")
7374
}
74-
valid := TemplateVersionNameValid(str)
75-
return valid == nil
75+
return TemplateVersionNameValid(str)
7676
}
7777
err = Validate.RegisterValidation("template_version_name", templateVersionNameValidator)
7878
if err != nil {
@@ -168,9 +168,15 @@ func Read(ctx context.Context, rw http.ResponseWriter, r *http.Request, value in
168168
if errors.As(err, &validationErrors) {
169169
apiErrors := make([]codersdk.ValidationError, 0, len(validationErrors))
170170
for _, validationError := range validationErrors {
171+
detail := fmt.Sprintf("Validation failed for tag %q with value: \"%v\"", validationError.Tag(), validationError.Value())
172+
var custom validate.DetailedFieldError
173+
if xerrors.As(validationError, &custom) {
174+
detail = fmt.Sprintf("Validation failed for tag %q=\"%v\": %s", validationError.Tag(), validationError.Value(), custom.Error())
175+
}
176+
171177
apiErrors = append(apiErrors, codersdk.ValidationError{
172178
Field: validationError.Field(),
173-
Detail: fmt.Sprintf("Validation failed for tag %q with value: \"%v\"", validationError.Tag(), validationError.Value()),
179+
Detail: detail,
174180
})
175181
}
176182
Write(ctx, rw, http.StatusBadRequest, codersdk.Response{

coderd/httpapi/httpapi_test.go

+22
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,28 @@ func TestRead(t *testing.T) {
121121
require.Equal(t, "value", v.Validations[0].Field)
122122
require.Equal(t, "Validation failed for tag \"required\" with value: \"\"", v.Validations[0].Detail)
123123
})
124+
125+
t.Run("CustomValidateFailure", func(t *testing.T) {
126+
t.Parallel()
127+
type toValidate struct {
128+
Value []string `json:"value" validate:"dive,username"`
129+
}
130+
ctx := context.Background()
131+
rw := httptest.NewRecorder()
132+
r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}"))
133+
134+
validate := toValidate{
135+
Value: []string{"+", "random_valid", "n"},
136+
}
137+
require.False(t, httpapi.Read(ctx, rw, r, &validate))
138+
var v codersdk.Response
139+
err := json.NewDecoder(rw.Body).Decode(&v)
140+
require.NoError(t, err)
141+
require.Len(t, v.Validations, 1)
142+
require.Equal(t, "value", v.Validations[0].Field)
143+
fmt.Println(v.Validations[0].Detail)
144+
require.Equal(t, "Validation failed for tag \"username\" with value: \"\"", v.Validations[0].Detail)
145+
})
124146
}
125147

126148
func TestWebsocketCloseMsg(t *testing.T) {

coderd/httpapi/validate/validate.go

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package validate
2+
3+
import (
4+
"fmt"
5+
"sync"
6+
7+
"github.com/go-playground/validator/v10"
8+
"golang.org/x/xerrors"
9+
)
10+
11+
type Validator struct {
12+
*validator.Validate
13+
}
14+
15+
func New() *Validator {
16+
return &Validator{
17+
Validate: validator.New(),
18+
}
19+
}
20+
21+
type customErrors struct {
22+
Exported any `json:""` // Any is the actual struct to validate
23+
24+
sync.Mutex
25+
errors map[string]error
26+
}
27+
28+
func (ce *customErrors) AddError(field string, err error) {
29+
ce.Lock()
30+
defer ce.Unlock()
31+
32+
ce.errors[field] = err
33+
}
34+
35+
// DetailedFieldError includes a custom "reason" error to explain why the
36+
// validation failed.
37+
type DetailedFieldError struct {
38+
validator.FieldError
39+
Reason error
40+
}
41+
42+
// Struct overrides the default Struct method to allow for custom errors.
43+
func (v *Validator) Struct(value interface{}) error {
44+
c := &customErrors{
45+
errors: make(map[string]error),
46+
Exported: value,
47+
}
48+
err := v.Validate.Struct(c)
49+
if err == nil && len(c.errors) == 0 {
50+
return nil
51+
}
52+
53+
var validErrors validator.ValidationErrors
54+
if xerrors.As(err, &validErrors) {
55+
for i, ve := range validErrors {
56+
fieldName := ve.Namespace()
57+
if reason, ok := c.errors[fieldName]; ok {
58+
validErrors[i] = DetailedFieldError{
59+
FieldError: ve,
60+
Reason: reason,
61+
}
62+
delete(c.errors, fieldName)
63+
}
64+
}
65+
if len(c.errors) > 0 {
66+
panic(fmt.Sprintf("%d custom errors remain: %v", len(c.errors), c.errors))
67+
}
68+
return validErrors
69+
}
70+
return err
71+
}
72+
73+
type FuncWithError func(fl validator.FieldLevel) error
74+
75+
// RegisterValidation adds a validation with the given tag
76+
//
77+
// NOTES:
78+
// - if the key already exists, the previous validation function will be replaced.
79+
// - this method is not thread-safe it is intended that these all be registered prior to any validation
80+
func (v *Validator) RegisterValidation(tag string, fn FuncWithError, callValidationEvenIfNull ...bool) error {
81+
return v.Validate.RegisterValidation(tag, func(fl validator.FieldLevel) bool {
82+
err := fn(fl)
83+
if err != nil {
84+
top := fl.Top().Interface()
85+
ce, ok := (top).(*customErrors)
86+
if ok {
87+
// We cannot get the full namespace resolution. So hopefully
88+
// the field name with the parent type is unique enough.
89+
namespace := fmt.Sprintf("%s.%s=%v",
90+
fl.Parent().Type().Name(), fl.FieldName(), fl.Field().Interface())
91+
ce.AddError(namespace, err)
92+
// Always return false, because this error will be added after.
93+
return true
94+
}
95+
return false
96+
}
97+
return true
98+
}, callValidationEvenIfNull...)
99+
}

codersdk/validate.go

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package codersdk
2+
3+
type Validator struct {
4+
// ValidationErrors is all the validation errors encountered during the
5+
// validation process.
6+
ValidationErrors []ValidationError `json:"validations,omitempty"`
7+
}
8+
9+
func Validate() *Validator {
10+
return &Validator{}
11+
}
12+

0 commit comments

Comments
 (0)