Skip to content

Commit 08b4eb3

Browse files
authored
fix: refresh all oauth links on external auth page (#11646)
* fix: refresh all oauth links on external auth page
1 parent d583aca commit 08b4eb3

File tree

5 files changed

+85
-11
lines changed

5 files changed

+85
-11
lines changed

coderd/externalauth.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,6 @@ func (api *API) listUserExternalAuths(rw http.ResponseWriter, r *http.Request) {
362362
if err == nil && valid {
363363
links[i] = newLink
364364
}
365-
break
366365
}
367366
}
368367
}

coderd/externalauth_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import (
1818

1919
"github.com/coder/coder/v2/coderd/coderdtest"
2020
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
21+
"github.com/coder/coder/v2/coderd/database"
22+
"github.com/coder/coder/v2/coderd/database/dbauthz"
2123
"github.com/coder/coder/v2/coderd/database/dbtime"
2224
"github.com/coder/coder/v2/coderd/externalauth"
2325
"github.com/coder/coder/v2/coderd/httpapi"
@@ -198,6 +200,66 @@ func TestExternalAuthManagement(t *testing.T) {
198200
require.Len(t, list.Providers, 2)
199201
require.Len(t, list.Links, 0)
200202
})
203+
t.Run("RefreshAllProviders", func(t *testing.T) {
204+
t.Parallel()
205+
const githubID = "fake-github"
206+
const gitlabID = "fake-gitlab"
207+
208+
githubCalled := false
209+
githubApp := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error {
210+
githubCalled = true
211+
return nil
212+
}))
213+
gitlabCalled := false
214+
gitlab := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error {
215+
gitlabCalled = true
216+
return nil
217+
}))
218+
219+
owner, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
220+
ExternalAuthConfigs: []*externalauth.Config{
221+
githubApp.ExternalAuthConfig(t, githubID, nil, func(cfg *externalauth.Config) {
222+
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
223+
}),
224+
gitlab.ExternalAuthConfig(t, gitlabID, nil, func(cfg *externalauth.Config) {
225+
cfg.Type = codersdk.EnhancedExternalAuthProviderGitLab.String()
226+
}),
227+
},
228+
})
229+
ownerUser := coderdtest.CreateFirstUser(t, owner)
230+
// Just a regular user
231+
client, user := coderdtest.CreateAnotherUser(t, owner, ownerUser.OrganizationID)
232+
ctx := testutil.Context(t, testutil.WaitLong)
233+
234+
// Log into github & gitlab
235+
githubApp.ExternalLogin(t, client)
236+
gitlab.ExternalLogin(t, client)
237+
238+
links, err := db.GetExternalAuthLinksByUserID(
239+
dbauthz.As(ctx, coderdtest.AuthzUserSubject(user, ownerUser.OrganizationID)), user.ID)
240+
require.NoError(t, err)
241+
require.Len(t, links, 2)
242+
243+
// Expire the links
244+
for _, l := range links {
245+
_, err := db.UpdateExternalAuthLink(dbauthz.As(ctx, coderdtest.AuthzUserSubject(user, ownerUser.OrganizationID)), database.UpdateExternalAuthLinkParams{
246+
ProviderID: l.ProviderID,
247+
UserID: l.UserID,
248+
UpdatedAt: dbtime.Now(),
249+
OAuthAccessToken: l.OAuthAccessToken,
250+
OAuthRefreshToken: l.OAuthRefreshToken,
251+
OAuthExpiry: time.Now().Add(time.Hour * -1),
252+
OAuthExtra: l.OAuthExtra,
253+
})
254+
require.NoErrorf(t, err, "expire key for %s", l.ProviderID)
255+
}
256+
257+
list, err := client.ListExternalAuths(ctx)
258+
require.NoError(t, err)
259+
require.Len(t, list.Links, 2)
260+
require.True(t, githubCalled, "github should be refreshed")
261+
require.True(t, gitlabCalled, "gitlab should be refreshed")
262+
})
201263
}
202264

203265
func TestExternalAuthDevice(t *testing.T) {

codersdk/client.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ func ExpectJSONMime(res *http.Response) error {
336336

337337
// ReadBodyAsError reads the response as a codersdk.Response, and
338338
// wraps it in a codersdk.Error type for easy marshaling.
339+
//
340+
// This will always return an error, so only call it if the response failed
341+
// your expectations. Usually via status code checking.
342+
// nolint:staticcheck
339343
func ReadBodyAsError(res *http.Response) error {
340344
if res == nil {
341345
return xerrors.Errorf("no body returned")

codersdk/client_internal_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,17 @@ func Test_readBodyAsError(t *testing.T) {
283283
assert.Equal(t, unexpectedJSON, sdkErr.Response.Detail)
284284
},
285285
},
286+
{
287+
// Even status code 200 should be considered an error if this function
288+
// is called. There are parts of the code that require this function
289+
// to always return an error.
290+
name: "OKResp",
291+
req: nil,
292+
res: newResponse(http.StatusOK, jsonCT, marshal(map[string]any{})),
293+
assert: func(t *testing.T, err error) {
294+
require.Error(t, err)
295+
},
296+
},
286297
}
287298

288299
for _, c := range tests {

enterprise/coderd/proxyhealth/proxyhealth.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -321,19 +321,17 @@ func (p *ProxyHealth) runOnce(ctx context.Context, now time.Time) (map[uuid.UUID
321321
// readable.
322322
builder.WriteString(fmt.Sprintf("unexpected status code %d. ", resp.StatusCode))
323323
builder.WriteString(fmt.Sprintf("\nEncountered error, send a request to %q from the Coderd environment to debug this issue.", reqURL))
324+
// err will always be non-nil
324325
err := codersdk.ReadBodyAsError(resp)
325-
if err != nil {
326-
var apiErr *codersdk.Error
327-
if xerrors.As(err, &apiErr) {
328-
builder.WriteString(fmt.Sprintf("\nError Message: %s\nError Detail: %s", apiErr.Message, apiErr.Detail))
329-
for _, v := range apiErr.Validations {
330-
// Pretty sure this is not possible from the called endpoint, but just in case.
331-
builder.WriteString(fmt.Sprintf("\n\tValidation: %s=%s", v.Field, v.Detail))
332-
}
333-
} else {
334-
builder.WriteString(fmt.Sprintf("\nError: %s", err.Error()))
326+
var apiErr *codersdk.Error
327+
if xerrors.As(err, &apiErr) {
328+
builder.WriteString(fmt.Sprintf("\nError Message: %s\nError Detail: %s", apiErr.Message, apiErr.Detail))
329+
for _, v := range apiErr.Validations {
330+
// Pretty sure this is not possible from the called endpoint, but just in case.
331+
builder.WriteString(fmt.Sprintf("\n\tValidation: %s=%s", v.Field, v.Detail))
335332
}
336333
}
334+
builder.WriteString(fmt.Sprintf("\nError: %s", err.Error()))
337335

338336
status.Report.Errors = []string{builder.String()}
339337
case err != nil:

0 commit comments

Comments
 (0)