Skip to content

Commit a4961b9

Browse files
committed
Add unit test
1 parent 6d6a46e commit a4961b9

File tree

9 files changed

+118
-15
lines changed

9 files changed

+118
-15
lines changed

coderd/audit/diff.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ type Auditable interface {
1717
database.WorkspaceBuild |
1818
database.AuditableGroup |
1919
database.License |
20-
database.WorkspaceProxy
20+
database.WorkspaceProxy |
21+
database.OauthMergeState
2122
}
2223

2324
// Map is a map of changed fields in an audited resource. It maps field names to

coderd/audit/request.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ func ResourceTarget[T Auditable](tgt T) string {
8484
return strconv.Itoa(int(typed.ID))
8585
case database.WorkspaceProxy:
8686
return typed.Name
87+
case database.OauthMergeState:
88+
return typed.StateString
8789
default:
8890
panic(fmt.Sprintf("unknown resource %T", tgt))
8991
}
@@ -111,6 +113,9 @@ func ResourceID[T Auditable](tgt T) uuid.UUID {
111113
return typed.UUID
112114
case database.WorkspaceProxy:
113115
return typed.ID
116+
case database.OauthMergeState:
117+
// The merge state is for the given user
118+
return typed.UserID
114119
default:
115120
panic(fmt.Sprintf("unknown resource %T", tgt))
116121
}
@@ -138,6 +143,8 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
138143
return database.ResourceTypeLicense
139144
case database.WorkspaceProxy:
140145
return database.ResourceTypeWorkspaceProxy
146+
case database.OauthMergeState:
147+
return database.ResourceTypeConvertLogin
141148
default:
142149
panic(fmt.Sprintf("unknown resource %T", typed))
143150
}

coderd/database/dbfake/dbfake.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,9 +1175,9 @@ func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time
11751175
return nil
11761176
}
11771177

1178-
func (q *fakeQuerier) DeleteUserOauthMergeStates(ctx context.Context, userID uuid.UUID) error {
1179-
q.mutex.RLock()
1180-
defer q.mutex.RUnlock()
1178+
func (q *fakeQuerier) DeleteUserOauthMergeStates(_ context.Context, userID uuid.UUID) error {
1179+
q.mutex.Lock()
1180+
defer q.mutex.Unlock()
11811181

11821182
i := 0
11831183
for {
@@ -3961,8 +3961,8 @@ func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
39613961
}
39623962

39633963
func (q *fakeQuerier) InsertUserOauthMergeState(_ context.Context, arg database.InsertUserOauthMergeStateParams) (database.OauthMergeState, error) {
3964-
q.mutex.RLock()
3965-
defer q.mutex.RUnlock()
3964+
q.mutex.Lock()
3965+
defer q.mutex.Unlock()
39663966

39673967
if err := validateDatabaseType(arg); err != nil {
39683968
return database.OauthMergeState{}, err

coderd/database/models.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/userauth.go

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,14 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
3838
var (
3939
ctx = r.Context()
4040
auditor = api.Auditor.Load()
41-
aReq, commitAudit = audit.InitRequest[database.APIKey](rw, &audit.RequestParams{
41+
aReq, commitAudit = audit.InitRequest[database.OauthMergeState](rw, &audit.RequestParams{
4242
Audit: *auditor,
4343
Log: api.Logger,
4444
Request: r,
4545
Action: database.AuditActionCreate,
4646
})
4747
)
48-
// TODO: @emyrk This does make a new api key. Make a new auditable resource
49-
// for oidc state.
50-
aReq.Old = database.APIKey{}
48+
aReq.Old = database.OauthMergeState{}
5149
defer commitAudit()
5250

5351
var req codersdk.ConvertLoginRequest
@@ -102,12 +100,12 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
102100
// We should only ever have 1 oauth merge state per user. So delete
103101
// any existing if they exist.
104102
//nolint:gocritic // Keeping the table clean
105-
err := api.Database.DeleteUserOauthMergeStates(dbauthz.AsSystemRestricted(ctx), user.ID)
103+
err := store.DeleteUserOauthMergeStates(dbauthz.AsSystemRestricted(ctx), user.ID)
106104
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
107105
return err
108106
}
109107

110-
mergeState, err = api.Database.InsertUserOauthMergeState(ctx, database.InsertUserOauthMergeStateParams{
108+
mergeState, err = store.InsertUserOauthMergeState(ctx, database.InsertUserOauthMergeStateParams{
111109
UserID: user.ID,
112110
StateString: stateString,
113111
ToLoginType: database.LoginType(req.ToLoginType),
@@ -134,6 +132,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
134132
ToLoginType: codersdk.LoginType(mergeState.ToLoginType),
135133
UserID: mergeState.UserID,
136134
})
135+
aReq.New = mergeState
137136
}
138137

139138
// Authenticates the user with an email and password.
@@ -532,6 +531,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
532531
Username: ghUser.GetLogin(),
533532
AvatarURL: ghUser.GetAvatarURL(),
534533
OauthConversionEnabled: api.DeploymentValues.EnableOauthAccountConversion.Value(),
534+
InitAuditRequest: func(params *audit.RequestParams) (*audit.Request[database.OauthMergeState], func()) {
535+
return audit.InitRequest[database.OauthMergeState](rw, params)
536+
},
535537
})
536538
var httpErr httpError
537539
if xerrors.As(err, &httpErr) {
@@ -864,6 +866,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
864866
UsingGroups: usingGroups,
865867
Groups: groups,
866868
OauthConversionEnabled: api.DeploymentValues.EnableOauthAccountConversion.Value(),
869+
InitAuditRequest: func(params *audit.RequestParams) (*audit.Request[database.OauthMergeState], func()) {
870+
return audit.InitRequest[database.OauthMergeState](rw, params)
871+
},
867872
})
868873
var httpErr httpError
869874
if xerrors.As(err, &httpErr) {
@@ -946,6 +951,8 @@ type oauthLoginParams struct {
946951
UsingGroups bool
947952
Groups []string
948953
OauthConversionEnabled bool
954+
955+
InitAuditRequest func(params *audit.RequestParams) (*audit.Request[database.OauthMergeState], func())
949956
}
950957

951958
type httpError struct {
@@ -997,13 +1004,27 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
9971004
if !params.OauthConversionEnabled {
9981005
return wrongLoginTypeErr
9991006
}
1000-
mergeState, err := api.Database.GetUserOauthMergeState(dbauthz.AsSystemRestricted(ctx), database.GetUserOauthMergeStateParams{
1007+
var (
1008+
auditor = *api.Auditor.Load()
1009+
oauthConvertAudit, commitOauthConvertAudit = params.InitAuditRequest(&audit.RequestParams{
1010+
Audit: auditor,
1011+
Log: api.Logger,
1012+
Request: r,
1013+
Action: database.AuditActionLogin,
1014+
})
1015+
)
1016+
defer commitOauthConvertAudit()
1017+
1018+
// nolint:gocritic // Required to auth the oidc convert
1019+
mergeState, err := tx.GetUserOauthMergeState(dbauthz.AsSystemRestricted(ctx), database.GetUserOauthMergeStateParams{
10011020
UserID: user.ID,
10021021
StateString: params.State.StateString,
10031022
})
10041023
if xerrors.Is(err, sql.ErrNoRows) {
10051024
return wrongLoginTypeErr
10061025
}
1026+
oauthConvertAudit.Old = mergeState
1027+
10071028
failedMsg := fmt.Sprintf("Request to convert login type from %s to %s failed", user.LoginType, params.LoginType)
10081029
if err != nil {
10091030
return httpError{

coderd/userauth_test.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,46 @@ func TestUserOIDC(t *testing.T) {
782782
})
783783
}
784784

785+
t.Run("OIDCConvert", func(t *testing.T) {
786+
t.Parallel()
787+
auditor := audit.NewMock()
788+
conf := coderdtest.NewOIDCConfig(t, "")
789+
790+
config := conf.OIDCConfig(t, nil)
791+
config.AllowSignups = true
792+
793+
cfg := coderdtest.DeploymentValues(t)
794+
cfg.EnableOauthAccountConversion = true
795+
client := coderdtest.New(t, &coderdtest.Options{
796+
Auditor: auditor,
797+
OIDCConfig: config,
798+
DeploymentValues: cfg,
799+
})
800+
owner := coderdtest.CreateFirstUser(t, client)
801+
802+
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
803+
804+
numLogs := len(auditor.AuditLogs())
805+
806+
code := conf.EncodeClaims(t, jwt.MapClaims{
807+
"email": userData.Email,
808+
})
809+
810+
numLogs++ // add an audit log for login
811+
ctx := testutil.Context(t, testutil.WaitShort)
812+
convertResponse, err := user.ConvertToOAuthLogin(ctx, codersdk.ConvertLoginRequest{
813+
ToLoginType: codersdk.LoginTypeOIDC,
814+
LoginWithPasswordRequest: codersdk.LoginWithPasswordRequest{
815+
Email: userData.Email,
816+
Password: "SomeSecurePassword!",
817+
},
818+
})
819+
require.NoError(t, err)
820+
821+
resp := oidcCallbackWithState(t, client, code, convertResponse.StateString)
822+
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
823+
})
824+
785825
t.Run("AlternateUsername", func(t *testing.T) {
786826
t.Parallel()
787827
auditor := audit.NewMock()
@@ -1004,17 +1044,21 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
10041044
}
10051045

10061046
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
1047+
return oidcCallbackWithState(t, client, code, "somestate")
1048+
}
1049+
1050+
func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string) *http.Response {
10071051
t.Helper()
10081052
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
10091053
return http.ErrUseLastResponse
10101054
}
1011-
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=somestate", code))
1055+
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=%s", code, state))
10121056
require.NoError(t, err)
10131057
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
10141058
require.NoError(t, err)
10151059
req.AddCookie(&http.Cookie{
10161060
Name: codersdk.OAuth2StateCookie,
1017-
Value: "somestate",
1061+
Value: state,
10181062
})
10191063
res, err := client.HTTPClient.Do(req)
10201064
require.NoError(t, err)

codersdk/deployment.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,8 @@ when required by your organization's security policy.`,
14351435
Value: &c.EnableOauthAccountConversion,
14361436
Group: &deploymentGroupNetworkingHTTP,
14371437
YAML: "enableOauthAuthConversion",
1438+
// Do not show this until the feature is fully ready.
1439+
Hidden: true,
14381440
},
14391441
{
14401442
Name: "Config Path",

codersdk/users.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,26 @@ func (c *Client) LoginWithPassword(ctx context.Context, req LoginWithPasswordReq
312312
return resp, nil
313313
}
314314

315+
// ConvertToOAuthLogin will send a request to convert the user from password
316+
// based authentication to oauth based. The response has the oauth state code
317+
// to use in the oauth flow.
318+
func (c *Client) ConvertToOAuthLogin(ctx context.Context, req ConvertLoginRequest) (OauthConversionResponse, error) {
319+
res, err := c.Request(ctx, http.MethodPost, "/api/v2/users/convert-login", req)
320+
if err != nil {
321+
return OauthConversionResponse{}, err
322+
}
323+
defer res.Body.Close()
324+
if res.StatusCode != http.StatusCreated {
325+
return OauthConversionResponse{}, ReadBodyAsError(res)
326+
}
327+
var resp OauthConversionResponse
328+
err = json.NewDecoder(res.Body).Decode(&resp)
329+
if err != nil {
330+
return OauthConversionResponse{}, err
331+
}
332+
return resp, nil
333+
}
334+
315335
// Logout calls the /logout API
316336
// Call `ClearSessionToken()` to clear the session token of the client.
317337
func (c *Client) Logout(ctx context.Context) error {

enterprise/audit/table.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,13 @@ var auditableResourcesTypes = map[any]map[string]Action{
155155
"scope": ActionIgnore,
156156
"token_name": ActionIgnore,
157157
},
158+
&database.OauthMergeState{}: {
159+
"state_string": ActionSecret,
160+
"created_at": ActionTrack,
161+
"expires_at": ActionTrack,
162+
"to_login_type": ActionTrack,
163+
"user_id": ActionTrack,
164+
},
158165
// TODO: track an ID here when the below ticket is completed:
159166
// https://github.com/coder/coder/pull/6012
160167
&database.License{}: {

0 commit comments

Comments
 (0)