Skip to content

feat: Guard search queries against common mistakes #6404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Mar 2, 2023
Merged
Prev Previous commit
Next Next commit
Fix unit tests
  • Loading branch information
Emyrk committed Mar 1, 2023
commit 533b55d54ed09b43187970dc5d550d79cf3048bd
21 changes: 7 additions & 14 deletions coderd/httpapi/queryparams.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ type QueryParamParser struct {
Errors []codersdk.ValidationError
// Parsed is a map of all query params that were parsed. This is useful
// for checking if extra query params were passed in.
Parsed map[string]bool
Parsed map[string]int
}

func NewQueryParamParser() *QueryParamParser {
return &QueryParamParser{
Errors: []codersdk.ValidationError{},
Parsed: map[string]bool{},
Parsed: map[string]int{},
}
}

Expand All @@ -48,12 +48,7 @@ func (p *QueryParamParser) ErrorExcessParams(values url.Values) {
}

func (p *QueryParamParser) addParsed(key string) {
p.Parsed[key] = true
}

func (p *QueryParamParser) hasParsed(key string) bool {
_, ok := p.Parsed[key]
return ok
p.Parsed[key]++
}

func (p *QueryParamParser) Int(vals url.Values, def int, queryParam string) int {
Expand Down Expand Up @@ -94,13 +89,13 @@ func (p *QueryParamParser) UUIDs(vals url.Values, def []uuid.UUID, queryParam st
}

func (p *QueryParamParser) Time(vals url.Values, def time.Time, queryParam string, format string) time.Time {
v, err := parseQueryParam(p, vals, func(v string) (time.Time, error) {
return time.Parse(queryParam, format)
v, err := parseQueryParam(p, vals, func(term string) (time.Time, error) {
return time.Parse(format, term)
}, def, queryParam)
if err != nil {
p.Errors = append(p.Errors, codersdk.ValidationError{
Field: queryParam,
Detail: fmt.Sprintf("Query param %q must be a valid date format (%s)", queryParam, format),
Detail: fmt.Sprintf("Query param %q must be a valid date format (%s): %s", queryParam, format, err.Error()),
})
}
return v
Expand Down Expand Up @@ -182,13 +177,11 @@ func ParseCustomList[T any](parser *QueryParamParser, vals url.Values, def []T,
}

func parseQueryParam[T any](parser *QueryParamParser, vals url.Values, parse func(v string) (T, error), def T, queryParam string) (T, error) {
if parser.hasParsed(queryParam) {
return def, xerrors.Errorf("query param %q provided more than once", queryParam)
}
parser.addParsed(queryParam)
if !vals.Has(queryParam) || vals.Get(queryParam) == "" {
return def, nil
}

str := vals.Get(queryParam)
return parse(str)
}
59 changes: 57 additions & 2 deletions coderd/httpapi/queryparams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"net/http"
"net/url"
"testing"
"time"

"github.com/coder/coder/coderd/database"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
Expand All @@ -26,6 +29,51 @@ type queryParamTestCase[T any] struct {
func TestParseQueryParams(t *testing.T) {
t.Parallel()

t.Run("Enum", func(t *testing.T) {
t.Parallel()

expParams := []queryParamTestCase[database.ResourceType]{
{
QueryParam: "resource_type",
Value: string(database.ResourceTypeWorkspace),
Expected: database.ResourceTypeWorkspace,
},
{
QueryParam: "bad_type",
Value: "foo",
ExpectedErrorContains: "not a valid value",
},
}

parser := httpapi.NewQueryParamParser()
testQueryParams(t, expParams, parser, func(vals url.Values, def database.ResourceType, queryParam string) database.ResourceType {
return httpapi.ParseCustom(parser, vals, def, queryParam, httpapi.ParseEnum[database.ResourceType])
})
})

t.Run("Time", func(t *testing.T) {
t.Parallel()
const layout = "2006-01-02"

expParams := []queryParamTestCase[time.Time]{
{
QueryParam: "date",
Value: "2010-01-01",
Expected: must(time.Parse(layout, "2010-01-01")),
},
{
QueryParam: "bad_date",
Value: "2010",
ExpectedErrorContains: "must be a valid date format",
},
}

parser := httpapi.NewQueryParamParser()
testQueryParams(t, expParams, parser, func(vals url.Values, def time.Time, queryParam string) time.Time {
return parser.Time(vals, time.Time{}, queryParam, layout)
})
})

t.Run("UUID", func(t *testing.T) {
t.Parallel()
me := uuid.New()
Expand Down Expand Up @@ -187,8 +235,8 @@ func testQueryParams[T any](t *testing.T, testCases []queryParamTestCase[T], par
for _, c := range testCases {
// !! Do not run these in parallel !!
t.Run(c.QueryParam, func(t *testing.T) {
v := parse(v, c.Default, c.QueryParam)
require.Equal(t, c.Expected, v, fmt.Sprintf("param=%q value=%q", c.QueryParam, c.Value))
value := parse(v, c.Default, c.QueryParam)
require.Equal(t, c.Expected, value, fmt.Sprintf("param=%q value=%q", c.QueryParam, c.Value))
if c.ExpectedErrorContains != "" {
errors := parser.Errors
require.True(t, len(errors) > 0, "error exist")
Expand All @@ -199,3 +247,10 @@ func testQueryParams[T any](t *testing.T, testCases []queryParamTestCase[T], par
})
}
}

func must[T any](value T, err error) T {
if err != nil {
panic(err)
}
return value
}
24 changes: 18 additions & 6 deletions coderd/searchquery/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func Audit(query string) (database.GetAuditLogsOffsetParams, []codersdk.Validati
// Always lowercase for all searches.
query = strings.ToLower(query)
values, errors := searchTerms(query, func(term string, values url.Values) error {
values.Set("resource_type", term)
values.Add("resource_type", term)
return nil
})
if len(errors) > 0 {
Expand All @@ -46,7 +46,7 @@ func Users(query string) (database.GetUsersParams, []codersdk.ValidationError) {
// Always lowercase for all searches.
query = strings.ToLower(query)
values, errors := searchTerms(query, func(term string, values url.Values) error {
values.Set("search", term)
values.Add("search", term)
return nil
})
if len(errors) > 0 {
Expand Down Expand Up @@ -82,10 +82,10 @@ func Workspace(query string, page codersdk.Pagination, agentInactiveDisconnectTi
parts := splitQueryParameterByDelimiter(term, '/', false)
switch len(parts) {
case 1:
values.Set("name", parts[0])
values.Add("name", parts[0])
case 2:
values.Set("owner", parts[0])
values.Set("name", parts[1])
values.Add("owner", parts[0])
values.Add("name", parts[1])
default:
return xerrors.Errorf("Query element %q can only contain 1 '/'", term)
}
Expand Down Expand Up @@ -124,7 +124,7 @@ func searchTerms(query string, defaultKey func(term string, values url.Values) e
}
}
case 2:
searchValues.Set(strings.ToLower(parts[0]), parts[1])
searchValues.Add(strings.ToLower(parts[0]), parts[1])
default:
return nil, []codersdk.ValidationError{
{
Expand All @@ -134,6 +134,18 @@ func searchTerms(query string, defaultKey func(term string, values url.Values) e
}
}
}

for k := range searchValues {
if len(searchValues[k]) > 1 {
return nil, []codersdk.ValidationError{
{
Field: "q",
Detail: fmt.Sprintf("Query parameter %q provided more than once, found %d times", k, len(searchValues[k])),
},
}
}
}

return searchValues, nil
}

Expand Down