Skip to content

Commit a96376e

Browse files
authored
chore: Add "required" to allow requring url params (coder#6994)
1 parent 3cca30c commit a96376e

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

coderd/httpapi/queryparams.go

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,16 @@ type QueryParamParser struct {
2424
// Parsed is a map of all query params that were parsed. This is useful
2525
// for checking if extra query params were passed in.
2626
Parsed map[string]bool
27+
// RequiredParams is a map of all query params that are required. This is useful
28+
// for forcing a value to be provided.
29+
RequiredParams map[string]bool
2730
}
2831

2932
func NewQueryParamParser() *QueryParamParser {
3033
return &QueryParamParser{
31-
Errors: []codersdk.ValidationError{},
32-
Parsed: map[string]bool{},
34+
Errors: []codersdk.ValidationError{},
35+
Parsed: map[string]bool{},
36+
RequiredParams: map[string]bool{},
3337
}
3438
}
3539

@@ -51,6 +55,20 @@ func (p *QueryParamParser) addParsed(key string) {
5155
p.Parsed[key] = true
5256
}
5357

58+
func (p *QueryParamParser) UInt(vals url.Values, def uint64, queryParam string) uint64 {
59+
v, err := parseQueryParam(p, vals, func(v string) (uint64, error) {
60+
return strconv.ParseUint(v, 10, 64)
61+
}, def, queryParam)
62+
if err != nil {
63+
p.Errors = append(p.Errors, codersdk.ValidationError{
64+
Field: queryParam,
65+
Detail: fmt.Sprintf("Query param %q must be a valid positive integer (%s)", queryParam, err.Error()),
66+
})
67+
return 0
68+
}
69+
return v
70+
}
71+
5472
func (p *QueryParamParser) Int(vals url.Values, def int, queryParam string) int {
5573
v, err := parseQueryParam(p, vals, strconv.Atoi, def, queryParam)
5674
if err != nil {
@@ -62,6 +80,11 @@ func (p *QueryParamParser) Int(vals url.Values, def int, queryParam string) int
6280
return v
6381
}
6482

83+
func (p *QueryParamParser) Required(queryParam string) *QueryParamParser {
84+
p.RequiredParams[queryParam] = true
85+
return p
86+
}
87+
6588
func (p *QueryParamParser) UUIDorMe(vals url.Values, def uuid.UUID, me uuid.UUID, queryParam string) uuid.UUID {
6689
return ParseCustom(p, vals, def, queryParam, func(v string) (uuid.UUID, error) {
6790
if v == "me" {
@@ -178,6 +201,16 @@ func ParseCustomList[T any](parser *QueryParamParser, vals url.Values, def []T,
178201

179202
func parseQueryParam[T any](parser *QueryParamParser, vals url.Values, parse func(v string) (T, error), def T, queryParam string) (T, error) {
180203
parser.addParsed(queryParam)
204+
// If the query param is required and not present, return an error.
205+
if parser.RequiredParams[queryParam] && (!vals.Has(queryParam)) {
206+
parser.Errors = append(parser.Errors, codersdk.ValidationError{
207+
Field: queryParam,
208+
Detail: fmt.Sprintf("Query param %q is required", queryParam),
209+
})
210+
return def, nil
211+
}
212+
213+
// If the query param is not present, return the default value.
181214
if !vals.Has(queryParam) || vals.Get(queryParam) == "" {
182215
return def, nil
183216
}

coderd/httpapi/queryparams_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,43 @@ func TestParseQueryParams(t *testing.T) {
195195
testQueryParams(t, expParams, parser, parser.Int)
196196
})
197197

198+
t.Run("UInt", func(t *testing.T) {
199+
t.Parallel()
200+
expParams := []queryParamTestCase[uint64]{
201+
{
202+
QueryParam: "valid_integer",
203+
Value: "100",
204+
Expected: 100,
205+
},
206+
{
207+
QueryParam: "empty",
208+
Value: "",
209+
Expected: 0,
210+
},
211+
{
212+
QueryParam: "no_value",
213+
NoSet: true,
214+
Default: 5,
215+
Expected: 5,
216+
},
217+
{
218+
QueryParam: "negative",
219+
Value: "-10",
220+
Default: 5,
221+
ExpectedErrorContains: "must be a valid positive integer",
222+
},
223+
{
224+
QueryParam: "invalid_integer",
225+
Value: "bogus",
226+
Expected: 0,
227+
ExpectedErrorContains: "must be a valid positive integer",
228+
},
229+
}
230+
231+
parser := httpapi.NewQueryParamParser()
232+
testQueryParams(t, expParams, parser, parser.UInt)
233+
})
234+
198235
t.Run("UUIDs", func(t *testing.T) {
199236
t.Parallel()
200237
expParams := []queryParamTestCase[[]uuid.UUID]{
@@ -237,6 +274,15 @@ func TestParseQueryParams(t *testing.T) {
237274
parser := httpapi.NewQueryParamParser()
238275
testQueryParams(t, expParams, parser, parser.UUIDs)
239276
})
277+
278+
t.Run("Required", func(t *testing.T) {
279+
t.Parallel()
280+
281+
parser := httpapi.NewQueryParamParser()
282+
parser.Required("test_value")
283+
parser.UUID(url.Values{}, uuid.New(), "test_value")
284+
require.Len(t, parser.Errors, 1)
285+
})
240286
}
241287

242288
func testQueryParams[T any](t *testing.T, testCases []queryParamTestCase[T], parser *httpapi.QueryParamParser, parse func(vals url.Values, def T, queryParam string) T) {

0 commit comments

Comments
 (0)