Skip to content

Commit 2f54f76

Browse files
authored
feat: allow IDP to return single string for roles/groups claim (#10993)
* feat: allow IDP to return single string instead of array for roles/groups claim This is to support ADFS
1 parent 3883d71 commit 2f54f76

File tree

3 files changed

+259
-39
lines changed

3 files changed

+259
-39
lines changed

coderd/userauth.go

Lines changed: 71 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,31 +1019,26 @@ func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interfac
10191019
if api.OIDCConfig.GroupField != "" {
10201020
usingGroups = true
10211021
groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField]
1022-
if ok && api.OIDCConfig.GroupField != "" {
1023-
// Convert the []interface{} we get to a []string.
1024-
groupsInterface, ok := groupsRaw.([]interface{})
1025-
if ok {
1026-
api.Logger.Debug(ctx, "groups returned in oidc claims",
1027-
slog.F("len", len(groupsInterface)),
1028-
slog.F("groups", groupsInterface),
1022+
if ok {
1023+
parsedGroups, err := parseStringSliceClaim(groupsRaw)
1024+
if err != nil {
1025+
api.Logger.Debug(ctx, "groups field was an unknown type in oidc claims",
1026+
slog.F("type", fmt.Sprintf("%T", groupsRaw)),
1027+
slog.Error(err),
10291028
)
1029+
return false, nil, err
1030+
}
10301031

1031-
for _, groupInterface := range groupsInterface {
1032-
group, ok := groupInterface.(string)
1033-
if !ok {
1034-
return false, nil, xerrors.Errorf("Invalid group type. Expected string, got: %T", groupInterface)
1035-
}
1036-
1037-
if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok {
1038-
group = mappedGroup
1039-
}
1032+
api.Logger.Debug(ctx, "groups returned in oidc claims",
1033+
slog.F("len", len(parsedGroups)),
1034+
slog.F("groups", parsedGroups),
1035+
)
10401036

1041-
groups = append(groups, group)
1037+
for _, group := range parsedGroups {
1038+
if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok {
1039+
group = mappedGroup
10421040
}
1043-
} else {
1044-
api.Logger.Debug(ctx, "groups field was an unknown type",
1045-
slog.F("type", fmt.Sprintf("%T", groupsRaw)),
1046-
)
1041+
groups = append(groups, group)
10471042
}
10481043
}
10491044
}
@@ -1079,10 +1074,11 @@ func (api *API) oidcRoles(ctx context.Context, rw http.ResponseWriter, r *http.R
10791074
rolesRow = []interface{}{}
10801075
}
10811076

1082-
rolesInterface, ok := rolesRow.([]interface{})
1083-
if !ok {
1084-
api.Logger.Error(ctx, "oidc claim user roles field was an unknown type",
1077+
parsedRoles, err := parseStringSliceClaim(rolesRow)
1078+
if err != nil {
1079+
api.Logger.Error(ctx, "oidc claims user roles field was an unknown type",
10851080
slog.F("type", fmt.Sprintf("%T", rolesRow)),
1081+
slog.Error(err),
10861082
)
10871083
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
10881084
Status: http.StatusInternalServerError,
@@ -1096,21 +1092,10 @@ func (api *API) oidcRoles(ctx context.Context, rw http.ResponseWriter, r *http.R
10961092
}
10971093

10981094
api.Logger.Debug(ctx, "roles returned in oidc claims",
1099-
slog.F("len", len(rolesInterface)),
1100-
slog.F("roles", rolesInterface),
1095+
slog.F("len", len(parsedRoles)),
1096+
slog.F("roles", parsedRoles),
11011097
)
1102-
for _, roleInterface := range rolesInterface {
1103-
role, ok := roleInterface.(string)
1104-
if !ok {
1105-
api.Logger.Error(ctx, "invalid oidc user role type",
1106-
slog.F("type", fmt.Sprintf("%T", rolesRow)),
1107-
)
1108-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
1109-
Message: fmt.Sprintf("Invalid user role type. Expected string, got: %T", roleInterface),
1110-
})
1111-
return nil, false
1112-
}
1113-
1098+
for _, role := range parsedRoles {
11141099
if mappedRoles, ok := api.OIDCConfig.UserRoleMapping[role]; ok {
11151100
if len(mappedRoles) == 0 {
11161101
continue
@@ -1449,7 +1434,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
14491434
if err != nil {
14501435
return httpError{
14511436
code: http.StatusBadRequest,
1452-
msg: "Invalid roles through OIDC claim",
1437+
msg: "Invalid roles through OIDC claims",
14531438
detail: fmt.Sprintf("Error from role assignment attempt: %s", err.Error()),
14541439
renderStaticPage: true,
14551440
}
@@ -1744,3 +1729,50 @@ func wrongLoginTypeHTTPError(user database.LoginType, params database.LoginType)
17441729
params, user, addedMsg),
17451730
}
17461731
}
1732+
1733+
// parseStringSliceClaim parses the claim for groups and roles, expected []string.
1734+
//
1735+
// Some providers like ADFS return a single string instead of an array if there
1736+
// is only 1 element. So this function handles the edge cases.
1737+
func parseStringSliceClaim(claim interface{}) ([]string, error) {
1738+
groups := make([]string, 0)
1739+
if claim == nil {
1740+
return groups, nil
1741+
}
1742+
1743+
// The simple case is the type is exactly what we expected
1744+
asStringArray, ok := claim.([]string)
1745+
if ok {
1746+
return asStringArray, nil
1747+
}
1748+
1749+
asArray, ok := claim.([]interface{})
1750+
if ok {
1751+
for i, item := range asArray {
1752+
asString, ok := item.(string)
1753+
if !ok {
1754+
return nil, xerrors.Errorf("invalid claim type. Element %d expected a string, got: %T", i, item)
1755+
}
1756+
groups = append(groups, asString)
1757+
}
1758+
return groups, nil
1759+
}
1760+
1761+
asString, ok := claim.(string)
1762+
if ok {
1763+
if asString == "" {
1764+
// Empty string should be 0 groups.
1765+
return []string{}, nil
1766+
}
1767+
// If it is a single string, first check if it is a csv.
1768+
// If a user hits this, it is likely a misconfiguration and they need
1769+
// to reconfigure their IDP to send an array instead.
1770+
if strings.Contains(asString, ",") {
1771+
return nil, xerrors.Errorf("invalid claim type. Got a csv string (%q), change this claim to return an array of strings instead.", asString)
1772+
}
1773+
return []string{asString}, nil
1774+
}
1775+
1776+
// Not sure what the user gave us.
1777+
return nil, xerrors.Errorf("invalid claim type. Expected an array of strings, got: %T", claim)
1778+
}

coderd/userauth_internal_test.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package coderd
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestParseStringSliceClaim(t *testing.T) {
11+
t.Parallel()
12+
13+
cases := []struct {
14+
Name string
15+
GoClaim interface{}
16+
// JSON Claim allows testing the json -> go conversion
17+
// of some strings.
18+
JSONClaim string
19+
ErrorExpected bool
20+
ExpectedSlice []string
21+
}{
22+
{
23+
Name: "Nil",
24+
GoClaim: nil,
25+
ExpectedSlice: []string{},
26+
},
27+
// Go Slices
28+
{
29+
Name: "EmptySlice",
30+
GoClaim: []string{},
31+
ExpectedSlice: []string{},
32+
},
33+
{
34+
Name: "StringSlice",
35+
GoClaim: []string{"a", "b", "c"},
36+
ExpectedSlice: []string{"a", "b", "c"},
37+
},
38+
{
39+
Name: "InterfaceSlice",
40+
GoClaim: []interface{}{"a", "b", "c"},
41+
ExpectedSlice: []string{"a", "b", "c"},
42+
},
43+
{
44+
Name: "MixedSlice",
45+
GoClaim: []interface{}{"a", string("b"), interface{}("c")},
46+
ExpectedSlice: []string{"a", "b", "c"},
47+
},
48+
{
49+
Name: "StringSliceOneElement",
50+
GoClaim: []string{"a"},
51+
ExpectedSlice: []string{"a"},
52+
},
53+
// Json Slices
54+
{
55+
Name: "JSONEmptySlice",
56+
JSONClaim: `[]`,
57+
ExpectedSlice: []string{},
58+
},
59+
{
60+
Name: "JSONStringSlice",
61+
JSONClaim: `["a", "b", "c"]`,
62+
ExpectedSlice: []string{"a", "b", "c"},
63+
},
64+
{
65+
Name: "JSONStringSliceOneElement",
66+
JSONClaim: `["a"]`,
67+
ExpectedSlice: []string{"a"},
68+
},
69+
// Go string
70+
{
71+
Name: "String",
72+
GoClaim: "a",
73+
ExpectedSlice: []string{"a"},
74+
},
75+
{
76+
Name: "EmptyString",
77+
GoClaim: "",
78+
ExpectedSlice: []string{},
79+
},
80+
{
81+
Name: "Interface",
82+
GoClaim: interface{}("a"),
83+
ExpectedSlice: []string{"a"},
84+
},
85+
// JSON string
86+
{
87+
Name: "JSONString",
88+
JSONClaim: `"a"`,
89+
ExpectedSlice: []string{"a"},
90+
},
91+
{
92+
Name: "JSONEmptyString",
93+
JSONClaim: `""`,
94+
ExpectedSlice: []string{},
95+
},
96+
// Go Errors
97+
{
98+
Name: "IntegerInSlice",
99+
GoClaim: []interface{}{"a", "b", 1},
100+
ErrorExpected: true,
101+
},
102+
// Json Errors
103+
{
104+
Name: "JSONIntegerInSlice",
105+
JSONClaim: `["a", "b", 1]`,
106+
ErrorExpected: true,
107+
},
108+
{
109+
Name: "JSON_CSV",
110+
JSONClaim: `"a,b,c"`,
111+
ErrorExpected: true,
112+
},
113+
}
114+
115+
for _, c := range cases {
116+
c := c
117+
t.Run(c.Name, func(t *testing.T) {
118+
t.Parallel()
119+
120+
if len(c.JSONClaim) > 0 {
121+
require.Nil(t, c.GoClaim, "go claim should be nil if json set")
122+
err := json.Unmarshal([]byte(c.JSONClaim), &c.GoClaim)
123+
require.NoError(t, err, "unmarshal json claim")
124+
}
125+
126+
found, err := parseStringSliceClaim(c.GoClaim)
127+
if c.ErrorExpected {
128+
require.Error(t, err)
129+
} else {
130+
require.NoError(t, err)
131+
require.ElementsMatch(t, c.ExpectedSlice, found, "expected groups")
132+
}
133+
})
134+
}
135+
}

enterprise/coderd/userauth_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,33 @@ func TestUserOIDC(t *testing.T) {
5555
runner.AssertRoles(t, "alice", []string{})
5656
})
5757

58+
// Some IDPs (ADFS) send the "string" type vs "[]string" if only
59+
// 1 role exists.
60+
t.Run("SingleRoleString", func(t *testing.T) {
61+
t.Parallel()
62+
63+
const oidcRoleName = "TemplateAuthor"
64+
runner := setupOIDCTest(t, oidcTestConfig{
65+
Config: func(cfg *coderd.OIDCConfig) {
66+
cfg.AllowSignups = true
67+
cfg.UserRoleField = "roles"
68+
cfg.UserRoleMapping = map[string][]string{
69+
oidcRoleName: {rbac.RoleTemplateAdmin()},
70+
}
71+
},
72+
})
73+
74+
// User starts with the owner role
75+
_, resp := runner.Login(t, jwt.MapClaims{
76+
"email": "alice@coder.com",
77+
// This is sent as a **string** intentionally instead
78+
// of an array.
79+
"roles": oidcRoleName,
80+
})
81+
require.Equal(t, http.StatusOK, resp.StatusCode)
82+
runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin()})
83+
})
84+
5885
// A user has some roles, then on an oauth refresh will lose said
5986
// roles from an updated claim.
6087
t.Run("NewUserAndRemoveRolesOnRefresh", func(t *testing.T) {
@@ -334,6 +361,32 @@ func TestUserOIDC(t *testing.T) {
334361
require.Equal(t, http.StatusOK, resp.StatusCode)
335362
runner.AssertGroups(t, "alice", []string{groupName})
336363
})
364+
365+
// Some IDPs (ADFS) send the "string" type vs "[]string" if only
366+
// 1 group exists.
367+
t.Run("SingleRoleGroup", func(t *testing.T) {
368+
t.Parallel()
369+
370+
const groupClaim = "custom-groups"
371+
const groupName = "bingbong"
372+
runner := setupOIDCTest(t, oidcTestConfig{
373+
Config: func(cfg *coderd.OIDCConfig) {
374+
cfg.AllowSignups = true
375+
cfg.GroupField = groupClaim
376+
cfg.CreateMissingGroups = true
377+
},
378+
})
379+
380+
// User starts with the owner role
381+
_, resp := runner.Login(t, jwt.MapClaims{
382+
"email": "alice@coder.com",
383+
// This is sent as a **string** intentionally instead
384+
// of an array.
385+
groupClaim: groupName,
386+
})
387+
require.Equal(t, http.StatusOK, resp.StatusCode)
388+
runner.AssertGroups(t, "alice", []string{groupName})
389+
})
337390
})
338391

339392
t.Run("Refresh", func(t *testing.T) {

0 commit comments

Comments
 (0)