Skip to content

Commit 035092a

Browse files
committed
Return err on invalid params
That way the caller can just check err instead of the length of errors.
1 parent 7b132e4 commit 035092a

File tree

3 files changed

+34
-30
lines changed

3 files changed

+34
-30
lines changed

enterprise/coderd/identityprovider/authorize.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,15 @@ type authorizeParams struct {
2626
state string
2727
}
2828

29-
func extractAuthorizeParams(r *http.Request, callbackURL string) (authorizeParams, []codersdk.ValidationError, error) {
29+
func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizeParams, []codersdk.ValidationError, error) {
3030
p := httpapi.NewQueryParamParser()
3131
vals := r.URL.Query()
3232

3333
p.RequiredNotEmpty("state", "response_type", "client_id")
3434

35-
cb, err := url.Parse(callbackURL)
36-
if err != nil {
37-
return authorizeParams{}, nil, err
38-
}
3935
params := authorizeParams{
4036
clientID: p.String(vals, "", "client_id"),
41-
redirectURL: p.RedirectURL(vals, cb, "redirect_uri"),
37+
redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"),
4238
responseType: httpapi.ParseCustom(p, vals, "", "response_type", httpapi.ParseEnum[codersdk.OAuth2ProviderResponseType]),
4339
scope: p.Strings(vals, []string{}, "scope"),
4440
state: p.String(vals, "", "state"),
@@ -48,7 +44,10 @@ func extractAuthorizeParams(r *http.Request, callbackURL string) (authorizeParam
4844
_ = p.String(vals, "", "redirected")
4945

5046
p.ErrorExcessParams(vals)
51-
return params, p.Errors, nil
47+
if len(p.Errors) > 0 {
48+
return authorizeParams{}, p.Errors, xerrors.Errorf("invalid query params: %w", p.Errors)
49+
}
50+
return params, nil, nil
5251
}
5352

5453
/**
@@ -63,17 +62,20 @@ func Authorize(db database.Store, accessURL *url.URL) http.HandlerFunc {
6362
apiKey := httpmw.APIKey(r)
6463
app := httpmw.OAuth2ProviderApp(r)
6564

66-
params, validationErrs, err := extractAuthorizeParams(r, app.CallbackURL)
65+
callbackURL, err := url.Parse(app.CallbackURL)
6766
if err != nil {
6867
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
6968
Message: "Failed to validate query parameters.",
7069
Detail: err.Error(),
7170
})
7271
return
7372
}
74-
if len(validationErrs) > 0 {
73+
74+
params, validationErrs, err := extractAuthorizeParams(r, callbackURL)
75+
if err != nil {
7576
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
7677
Message: "Invalid query params.",
78+
Detail: err.Error(),
7779
Validations: validationErrs,
7880
})
7981
return

enterprise/coderd/identityprovider/middleware.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,7 @@ func authorizeMW(accessURL *url.URL) func(next http.Handler) http.Handler {
5757
return
5858
}
5959

60-
// Extract the form parameters for two reasons:
61-
// 1. We need the redirect URI to build the cancel URI.
62-
// 2. Since validation will run once the user clicks "allow", it is
63-
// better to validate now to avoid wasting the user's time clicking a
64-
// button that will just error anyway.
65-
params, errs, err := extractAuthorizeParams(r, app.CallbackURL)
60+
callbackURL, err := url.Parse(app.CallbackURL)
6661
if err != nil {
6762
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
6863
Status: http.StatusInternalServerError,
@@ -75,9 +70,16 @@ func authorizeMW(accessURL *url.URL) func(next http.Handler) http.Handler {
7570
})
7671
return
7772
}
78-
if len(errs) > 0 {
79-
errStr := make([]string, len(errs))
80-
for i, err := range errs {
73+
74+
// Extract the form parameters for two reasons:
75+
// 1. We need the redirect URI to build the cancel URI.
76+
// 2. Since validation will run once the user clicks "allow", it is
77+
// better to validate now to avoid wasting the user's time clicking a
78+
// button that will just error anyway.
79+
params, validationErrs, err := extractAuthorizeParams(r, callbackURL)
80+
if err != nil {
81+
errStr := make([]string, len(validationErrs))
82+
for i, err := range validationErrs {
8183
errStr[i] = err.Detail
8284
}
8385
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{

enterprise/coderd/identityprovider/tokens.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,49 +33,49 @@ type tokenParams struct {
3333
redirectURL *url.URL
3434
}
3535

36-
func extractTokenParams(r *http.Request, callbackURL string) (tokenParams, []codersdk.ValidationError, error) {
36+
func extractTokenParams(r *http.Request, callbackURL *url.URL) (tokenParams, []codersdk.ValidationError, error) {
3737
p := httpapi.NewQueryParamParser()
3838
err := r.ParseForm()
3939
if err != nil {
4040
return tokenParams{}, nil, xerrors.Errorf("parse form: %w", err)
4141
}
42-
43-
cb, err := url.Parse(callbackURL)
44-
if err != nil {
45-
return tokenParams{}, nil, err
46-
}
47-
4842
p.RequiredNotEmpty("grant_type", "client_secret", "client_id", "code")
4943

5044
vals := r.Form
5145
params := tokenParams{
5246
clientID: p.String(vals, "", "client_id"),
5347
clientSecret: p.String(vals, "", "client_secret"),
5448
code: p.String(vals, "", "code"),
55-
redirectURL: p.RedirectURL(vals, cb, "redirect_uri"),
49+
redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"),
5650
grantType: httpapi.ParseCustom(p, vals, "", "grant_type", httpapi.ParseEnum[codersdk.OAuth2ProviderGrantType]),
5751
}
5852

5953
p.ErrorExcessParams(vals)
60-
return params, p.Errors, nil
54+
if len(p.Errors) > 0 {
55+
return tokenParams{}, p.Errors, xerrors.Errorf("invalid query params: %w", p.Errors)
56+
}
57+
return params, nil, nil
6158
}
6259

6360
func Tokens(db database.Store, defaultLifetime time.Duration) http.HandlerFunc {
6461
return func(rw http.ResponseWriter, r *http.Request) {
6562
ctx := r.Context()
6663
app := httpmw.OAuth2ProviderApp(r)
6764

68-
params, validationErrs, err := extractTokenParams(r, app.CallbackURL)
65+
callbackURL, err := url.Parse(app.CallbackURL)
6966
if err != nil {
70-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
67+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
7168
Message: "Failed to validate form values.",
7269
Detail: err.Error(),
7370
})
7471
return
7572
}
76-
if len(validationErrs) > 0 {
73+
74+
params, validationErrs, err := extractTokenParams(r, callbackURL)
75+
if err != nil {
7776
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
7877
Message: "Invalid query params.",
78+
Detail: err.Error(),
7979
Validations: validationErrs,
8080
})
8181
return

0 commit comments

Comments
 (0)