Skip to content

Commit 1ae65f3

Browse files
committed
chore: add query to fetch top level idp claim fields
Used for idp sync settings. Tests WIP
1 parent 4fedc7c commit 1ae65f3

File tree

9 files changed

+360
-0
lines changed

9 files changed

+360
-0
lines changed

coderd/database/dbauthz/dbauthz.go

+12
Original file line numberDiff line numberDiff line change
@@ -3283,6 +3283,18 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
32833283
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
32843284
}
32853285

3286+
func (q *querier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) {
3287+
resource := rbac.ResourceIdpsyncSettings
3288+
if organizationID != uuid.Nil {
3289+
resource = resource.InOrg(organizationID)
3290+
}
3291+
3292+
if err := q.authorizeContext(ctx, policy.ActionRead, resource); err != nil {
3293+
return nil, err
3294+
}
3295+
return q.db.OIDCClaimFields(ctx, organizationID)
3296+
}
3297+
32863298
func (q *querier) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
32873299
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.OrganizationMembers)(ctx, arg)
32883300
}

coderd/database/dbmem/dbmem.go

+29
Original file line numberDiff line numberDiff line change
@@ -8409,6 +8409,35 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
84098409
return shares, nil
84108410
}
84118411

8412+
func (q *FakeQuerier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) {
8413+
orgMembers := q.getOrganizationMemberNoLock(organizationID)
8414+
8415+
var fields []string
8416+
for _, link := range q.userLinks {
8417+
if organizationID != uuid.Nil {
8418+
inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool {
8419+
return organizationMember.UserID == link.UserID
8420+
})
8421+
if !inOrg {
8422+
continue
8423+
}
8424+
}
8425+
8426+
if link.LoginType != database.LoginTypeOIDC {
8427+
continue
8428+
}
8429+
8430+
for k := range link.Claims.IDTokenClaims {
8431+
fields = append(fields, k)
8432+
}
8433+
for k := range link.Claims.UserInfoClaims {
8434+
fields = append(fields, k)
8435+
}
8436+
}
8437+
8438+
return slice.Unique(fields), nil
8439+
}
8440+
84128441
func (q *FakeQuerier) OrganizationMembers(_ context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
84138442
if err := validateDatabaseType(arg); err != nil {
84148443
return []database.OrganizationMembersRow{}, err

coderd/database/dbmetrics/querymetrics.go

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

+15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelqueries.go

+7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package database
33
import (
44
"context"
55
"database/sql"
6+
"encoding/json"
67
"fmt"
78
"strings"
89

@@ -527,3 +528,9 @@ func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
527528
filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1)
528529
return filtered, nil
529530
}
531+
532+
// UpdateUserLinkRawJSON is a custom query for unit testing. Do not ever expose this
533+
func (q *sqlQuerier) UpdateUserLinkRawJSON(ctx context.Context, userID uuid.UUID, data json.RawMessage) error {
534+
_, err := q.sdb.Exec("INSERT INTO user_links (user_id, claims) VALUES ($1, $2)", userID, data)
535+
return err
536+
}

coderd/database/oidcclaims_test.go

+186
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package database_test
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"testing"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/coder/coder/v2/coderd/database"
12+
"github.com/coder/coder/v2/coderd/database/dbfake"
13+
"github.com/coder/coder/v2/coderd/database/dbtestutil"
14+
"github.com/coder/coder/v2/coderd/database/dbtime"
15+
"github.com/coder/coder/v2/testutil"
16+
)
17+
18+
type extraKeys struct {
19+
database.UserLinkClaims
20+
Foo string `json:"foo"`
21+
}
22+
23+
func TestOIDCClaims(t *testing.T) {
24+
t.Parallel()
25+
26+
toJSON := func(a any) json.RawMessage {
27+
b, _ := json.Marshal(a)
28+
return b
29+
}
30+
31+
db, _ := dbtestutil.NewDB(t)
32+
g := userGenerator{t: t, db: db}
33+
34+
// https://en.wikipedia.org/wiki/Alice_and_Bob#Cast_of_characters
35+
alice := g.withLink(database.LoginTypeOIDC, toJSON(extraKeys{
36+
UserLinkClaims: database.UserLinkClaims{
37+
IDTokenClaims: map[string]interface{}{
38+
"sub": "alice",
39+
},
40+
UserInfoClaims: map[string]interface{}{
41+
"sub": "alice",
42+
},
43+
},
44+
// Always should be a no-op
45+
Foo: "bar",
46+
}))
47+
bob := g.withLink(database.LoginTypeOIDC, toJSON(database.UserLinkClaims{
48+
IDTokenClaims: map[string]interface{}{
49+
"sub": "bob",
50+
},
51+
UserInfoClaims: map[string]interface{}{
52+
"sub": "bob",
53+
},
54+
}))
55+
charlie := g.withLink(database.LoginTypeOIDC, toJSON(database.UserLinkClaims{
56+
IDTokenClaims: map[string]interface{}{
57+
"sub": "charlie",
58+
},
59+
UserInfoClaims: map[string]interface{}{
60+
"sub": "charlie",
61+
},
62+
}))
63+
64+
// users that just try to cause problems, but should not affect the output of
65+
// queries.
66+
problematics := []database.User{
67+
g.withLink(database.LoginTypeOIDC, toJSON(database.UserLinkClaims{})), // null claims
68+
g.withLink(database.LoginTypeOIDC, []byte(`{}`)), // empty claims
69+
g.withLink(database.LoginTypeOIDC, []byte(`{"foo": "bar"}`)), // random keys
70+
g.noLink(database.LoginTypeOIDC), // no link
71+
72+
g.withLink(database.LoginTypeGithub, toJSON(database.UserLinkClaims{
73+
IDTokenClaims: map[string]interface{}{
74+
"not": "allowed",
75+
},
76+
UserInfoClaims: map[string]interface{}{
77+
"do-not": "look",
78+
},
79+
})), // github should be omitted
80+
81+
// extra random users
82+
g.noLink(database.LoginTypeGithub),
83+
g.noLink(database.LoginTypePassword),
84+
}
85+
86+
// Insert some orgs, users, and links
87+
orgA := dbfake.Organization(t, db).Members(
88+
append(problematics,
89+
alice,
90+
bob)...,
91+
).Do()
92+
orgB := dbfake.Organization(t, db).Members(
93+
append(problematics,
94+
charlie,
95+
)...,
96+
).Do()
97+
98+
// Verify the OIDC claim fields
99+
requireClaims(t, db, orgA.Org.ID, []string{"sub"})
100+
requireClaims(t, db, orgB.Org.ID, []string{"sub"})
101+
}
102+
103+
func requireClaims(t *testing.T, db database.Store, orgID uuid.UUID, want []string) {
104+
t.Helper()
105+
106+
ctx := testutil.Context(t, testutil.WaitMedium)
107+
got, err := db.OIDCClaimFields(ctx, orgID)
108+
require.NoError(t, err)
109+
110+
require.ElementsMatch(t, want, got)
111+
}
112+
113+
type userGenerator struct {
114+
t *testing.T
115+
db database.Store
116+
}
117+
118+
func (g userGenerator) noLink(lt database.LoginType) database.User {
119+
return g.user(lt, false, nil)
120+
}
121+
122+
func (g userGenerator) withLink(lt database.LoginType, rawJSON json.RawMessage) database.User {
123+
return g.user(lt, true, rawJSON)
124+
}
125+
126+
func (g userGenerator) user(lt database.LoginType, createLink bool, rawJSON json.RawMessage) database.User {
127+
t := g.t
128+
db := g.db
129+
130+
t.Helper()
131+
132+
u, err := db.InsertUser(context.Background(), database.InsertUserParams{
133+
ID: uuid.New(),
134+
Email: testutil.GetRandomName(t),
135+
Username: testutil.GetRandomName(t),
136+
Name: testutil.GetRandomName(t),
137+
CreatedAt: dbtime.Now(),
138+
UpdatedAt: dbtime.Now(),
139+
RBACRoles: []string{},
140+
LoginType: lt,
141+
Status: string(database.UserStatusActive),
142+
})
143+
require.NoError(t, err)
144+
145+
if !createLink {
146+
return u
147+
}
148+
149+
link, err := db.InsertUserLink(context.Background(), database.InsertUserLinkParams{
150+
UserID: u.ID,
151+
LoginType: lt,
152+
Claims: database.UserLinkClaims{},
153+
})
154+
require.NoError(t, err)
155+
156+
if sql, ok := db.(rawUpdater); ok {
157+
// The only way to put arbitrary json into the db for testing edge cases.
158+
// Making this a public API would be a mistake.
159+
err = sql.UpdateUserLinkRawJSON(context.Background(), u.ID, rawJSON)
160+
require.NoError(t, err)
161+
} else {
162+
// no need to test the json key logic in dbmem. Everything is type safe.
163+
var claims database.UserLinkClaims
164+
err := json.Unmarshal(rawJSON, &claims)
165+
require.NoError(t, err)
166+
167+
_, err = db.UpdateUserLink(context.Background(), database.UpdateUserLinkParams{
168+
OAuthAccessToken: link.OAuthAccessToken,
169+
OAuthAccessTokenKeyID: link.OAuthAccessTokenKeyID,
170+
OAuthRefreshToken: link.OAuthRefreshToken,
171+
OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID,
172+
OAuthExpiry: link.OAuthExpiry,
173+
UserID: link.UserID,
174+
LoginType: link.LoginType,
175+
// The new claims
176+
Claims: claims,
177+
})
178+
require.NoError(t, err)
179+
}
180+
181+
return u
182+
}
183+
184+
type rawUpdater interface {
185+
UpdateUserLinkRawJSON(ctx context.Context, userID uuid.UUID, data json.RawMessage) error
186+
}

coderd/database/querier.go

+4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

+61
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)