Skip to content

Commit 6a560ae

Browse files
Emyrkpull[bot]
authored andcommitted
chore: make scim auth header case insensitive for 'bearer' (#15538)
Fixes status codes to return more than 500. The way we were using the package, it always returned a status code 500
1 parent a1230a7 commit 6a560ae

File tree

3 files changed

+91
-22
lines changed

3 files changed

+91
-22
lines changed

enterprise/coderd/scim.go

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package coderd
22

33
import (
4+
"bytes"
45
"crypto/subtle"
56
"database/sql"
67
"encoding/json"
@@ -26,16 +27,21 @@ import (
2627
)
2728

2829
func (api *API) scimVerifyAuthHeader(r *http.Request) bool {
29-
bearer := []byte("Bearer ")
30+
bearer := []byte("bearer ")
3031
hdr := []byte(r.Header.Get("Authorization"))
3132

32-
if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(hdr[:len(bearer)], bearer) == 1 {
33+
// Use toLower to make the comparison case-insensitive.
34+
if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(bytes.ToLower(hdr[:len(bearer)]), bearer) == 1 {
3335
hdr = hdr[len(bearer):]
3436
}
3537

3638
return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1
3739
}
3840

41+
func scimUnauthorized(rw http.ResponseWriter) {
42+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusUnauthorized, "invalidAuthorization", xerrors.New("invalid authorization")))
43+
}
44+
3945
// scimServiceProviderConfig returns a static SCIM service provider configuration.
4046
//
4147
// @Summary SCIM 2.0: Service Provider Config
@@ -114,7 +120,7 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques
114120
//nolint:revive
115121
func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
116122
if !api.scimVerifyAuthHeader(r) {
117-
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
123+
scimUnauthorized(rw)
118124
return
119125
}
120126

@@ -142,11 +148,11 @@ func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
142148
//nolint:revive
143149
func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) {
144150
if !api.scimVerifyAuthHeader(r) {
145-
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
151+
scimUnauthorized(rw)
146152
return
147153
}
148154

149-
_ = handlerutil.WriteError(rw, spec.ErrNotFound)
155+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("endpoint will always return 404")))
150156
}
151157

152158
// We currently use our own struct instead of using the SCIM package. This was
@@ -192,7 +198,7 @@ var SCIMAuditAdditionalFields = map[string]string{
192198
func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
193199
ctx := r.Context()
194200
if !api.scimVerifyAuthHeader(r) {
195-
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
201+
scimUnauthorized(rw)
196202
return
197203
}
198204

@@ -209,7 +215,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
209215
var sUser SCIMUser
210216
err := json.NewDecoder(r.Body).Decode(&sUser)
211217
if err != nil {
212-
_ = handlerutil.WriteError(rw, err)
218+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err))
213219
return
214220
}
215221

@@ -222,7 +228,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
222228
}
223229

224230
if email == "" {
225-
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidEmail"})
231+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidEmail", xerrors.New("no primary email provided")))
226232
return
227233
}
228234

@@ -232,7 +238,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
232238
Username: sUser.UserName,
233239
})
234240
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
235-
_ = handlerutil.WriteError(rw, err)
241+
_ = handlerutil.WriteError(rw, err) // internal error
236242
return
237243
}
238244
if err == nil {
@@ -248,7 +254,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
248254
UpdatedAt: dbtime.Now(),
249255
})
250256
if err != nil {
251-
_ = handlerutil.WriteError(rw, err)
257+
_ = handlerutil.WriteError(rw, err) // internal error
252258
return
253259
}
254260
aReq.New = newUser
@@ -284,14 +290,14 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
284290
//nolint:gocritic // SCIM operations are a system user
285291
orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database)
286292
if err != nil {
287-
_ = handlerutil.WriteError(rw, xerrors.Errorf("failed to get organization sync settings: %w", err))
293+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get organization sync settings: %w", err)))
288294
return
289295
}
290296
if orgSync.AssignDefault {
291297
//nolint:gocritic // SCIM operations are a system user
292298
defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
293299
if err != nil {
294-
_ = handlerutil.WriteError(rw, err)
300+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get default organization: %w", err)))
295301
return
296302
}
297303
organizations = append(organizations, defaultOrganization.ID)
@@ -309,7 +315,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
309315
SkipNotifications: true,
310316
})
311317
if err != nil {
312-
_ = handlerutil.WriteError(rw, err)
318+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to create user: %w", err)))
313319
return
314320
}
315321
aReq.New = dbUser
@@ -335,7 +341,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
335341
func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
336342
ctx := r.Context()
337343
if !api.scimVerifyAuthHeader(r) {
338-
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
344+
scimUnauthorized(rw)
339345
return
340346
}
341347

@@ -354,21 +360,21 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
354360
var sUser SCIMUser
355361
err := json.NewDecoder(r.Body).Decode(&sUser)
356362
if err != nil {
357-
_ = handlerutil.WriteError(rw, err)
363+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err))
358364
return
359365
}
360366
sUser.ID = id
361367

362368
uid, err := uuid.Parse(id)
363369
if err != nil {
364-
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidId"})
370+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err)))
365371
return
366372
}
367373

368374
//nolint:gocritic // needed for SCIM
369375
dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid)
370376
if err != nil {
371-
_ = handlerutil.WriteError(rw, err)
377+
_ = handlerutil.WriteError(rw, err) // internal error
372378
return
373379
}
374380
aReq.Old = dbUser
@@ -400,7 +406,7 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
400406
UpdatedAt: dbtime.Now(),
401407
})
402408
if err != nil {
403-
_ = handlerutil.WriteError(rw, err)
409+
_ = handlerutil.WriteError(rw, err) // internal error
404410
return
405411
}
406412
dbUser = userNew

enterprise/coderd/scim/scimtypes.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
package scim
22

3-
import "time"
3+
import (
4+
"encoding/json"
5+
"time"
6+
7+
"github.com/imulab/go-scim/pkg/v2/spec"
8+
)
49

510
type ServiceProviderConfig struct {
611
Schemas []string `json:"schemas"`
@@ -44,3 +49,37 @@ type AuthenticationScheme struct {
4449
SpecURI string `json:"specUri"`
4550
DocURI string `json:"documentationUri"`
4651
}
52+
53+
// HTTPError wraps a *spec.Error for correct usage with
54+
// 'handlerutil.WriteError'. This error type is cursed to be
55+
// absolutely strange and specific to the SCIM library we use.
56+
//
57+
// The library expects *spec.Error to be returned on unwrap, and the
58+
// internal error description to be returned by a json.Marshal of the
59+
// top level error.
60+
type HTTPError struct {
61+
scim *spec.Error
62+
internal error
63+
}
64+
65+
func NewHTTPError(status int, eType string, err error) *HTTPError {
66+
return &HTTPError{
67+
scim: &spec.Error{
68+
Status: status,
69+
Type: eType,
70+
},
71+
internal: err,
72+
}
73+
}
74+
75+
func (e HTTPError) Error() string {
76+
return e.internal.Error()
77+
}
78+
79+
func (e HTTPError) MarshalJSON() ([]byte, error) {
80+
return json.Marshal(e.internal)
81+
}
82+
83+
func (e HTTPError) Unwrap() error {
84+
return e.scim
85+
}

enterprise/coderd/scim_test.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@ import (
66
"fmt"
77
"io"
88
"net/http"
9+
"net/http/httptest"
910
"testing"
1011

1112
"github.com/golang-jwt/jwt/v4"
13+
"github.com/imulab/go-scim/pkg/v2/handlerutil"
14+
"github.com/imulab/go-scim/pkg/v2/spec"
1215
"github.com/stretchr/testify/assert"
1316
"github.com/stretchr/testify/require"
17+
"golang.org/x/xerrors"
1418

1519
"github.com/coder/coder/v2/coderd/audit"
1620
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -22,6 +26,7 @@ import (
2226
"github.com/coder/coder/v2/enterprise/coderd"
2327
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
2428
"github.com/coder/coder/v2/enterprise/coderd/license"
29+
"github.com/coder/coder/v2/enterprise/coderd/scim"
2530
"github.com/coder/coder/v2/testutil"
2631
)
2732

@@ -59,7 +64,8 @@ func setScimAuth(key []byte) func(*http.Request) {
5964

6065
func setScimAuthBearer(key []byte) func(*http.Request) {
6166
return func(r *http.Request) {
62-
r.Header.Set("Authorization", "Bearer "+string(key))
67+
// Do strange casing to ensure it's case-insensitive
68+
r.Header.Set("Authorization", "beAreR "+string(key))
6369
}
6470
}
6571

@@ -111,7 +117,7 @@ func TestScim(t *testing.T) {
111117
res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{})
112118
require.NoError(t, err)
113119
defer res.Body.Close()
114-
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
120+
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
115121
})
116122

117123
t.Run("OK", func(t *testing.T) {
@@ -454,7 +460,7 @@ func TestScim(t *testing.T) {
454460
require.NoError(t, err)
455461
_, _ = io.Copy(io.Discard, res.Body)
456462
_ = res.Body.Close()
457-
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
463+
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
458464
})
459465

460466
t.Run("OK", func(t *testing.T) {
@@ -585,3 +591,21 @@ func TestScim(t *testing.T) {
585591
})
586592
})
587593
}
594+
595+
func TestScimError(t *testing.T) {
596+
t.Parallel()
597+
598+
// Demonstrates that we cannot use the standard errors
599+
rw := httptest.NewRecorder()
600+
_ = handlerutil.WriteError(rw, spec.ErrNotFound)
601+
resp := rw.Result()
602+
defer resp.Body.Close()
603+
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
604+
605+
// Our error wrapper works
606+
rw = httptest.NewRecorder()
607+
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("not found")))
608+
resp = rw.Result()
609+
defer resp.Body.Close()
610+
require.Equal(t, http.StatusNotFound, resp.StatusCode)
611+
}

0 commit comments

Comments
 (0)