Skip to content

Commit 212429b

Browse files
committed
chore: fix NoRefresh to honor unlimited tokens
- improve testing coverage of gitauth
1 parent 796a975 commit 212429b

File tree

3 files changed

+128
-14
lines changed

3 files changed

+128
-14
lines changed

coderd/coderdtest/oidctest/idp.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ import (
4141
type FakeIDP struct {
4242
issuer string
4343
key *rsa.PrivateKey
44-
provider providerJSON
44+
provider ProviderJSON
4545
handler http.Handler
4646
cfg *oauth2.Config
4747

@@ -181,16 +181,20 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
181181
return idp
182182
}
183183

184+
func (f *FakeIDP) WellknownConfig() ProviderJSON {
185+
return f.provider
186+
}
187+
184188
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
185189
t.Helper()
186190

187191
u, err := url.Parse(issuer)
188192
require.NoError(t, err, "invalid issuer URL")
189193

190194
f.issuer = issuer
191-
// providerJSON is the JSON representation of the OpenID Connect provider
195+
// ProviderJSON is the JSON representation of the OpenID Connect provider
192196
// These are all the urls that the IDP will respond to.
193-
f.provider = providerJSON{
197+
f.provider = ProviderJSON{
194198
Issuer: issuer,
195199
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
196200
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
@@ -220,6 +224,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
220224
return srv
221225
}
222226

227+
// GenerateAuthenticatedToken skips all oauth2 flows, and just generates a
228+
// valid token for some given claims.
229+
func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) {
230+
state := uuid.NewString()
231+
f.stateToIDTokenClaims.Store(state, claims)
232+
code := f.newCode(state)
233+
return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
234+
}
235+
223236
// Login does the full OIDC flow starting at the "LoginButton".
224237
// The client argument is just to get the URL of the Coder instance.
225238
//
@@ -333,7 +346,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
333346
return resp, nil
334347
}
335348

336-
type providerJSON struct {
349+
// ProviderJSON is the .well-known/configuration JSON
350+
type ProviderJSON struct {
337351
Issuer string `json:"issuer"`
338352
AuthURL string `json:"authorization_endpoint"`
339353
TokenURL string `json:"token_endpoint"`

coderd/gitauth/config.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,26 @@ type Config struct {
6363
func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
6464
// If the token is expired and refresh is disabled, we prompt
6565
// the user to authenticate again.
66-
if c.NoRefresh && gitAuthLink.OAuthExpiry.Before(database.Now()) {
66+
if c.NoRefresh &&
67+
// If the time is set to 0, then it should never expire.
68+
// This is true for github, which has no expiry.
69+
!gitAuthLink.OAuthExpiry.IsZero() &&
70+
gitAuthLink.OAuthExpiry.Before(database.Now()) {
6771
return gitAuthLink, false, nil
6872
}
6973

74+
// This is additional defensive programming. Because TokenSource is an interface,
75+
// we cannot be sure that the implementation will treat an 'IsZero' time
76+
// as "not-expired". The default implementation does, but a custom implementation
77+
// might not. Removing the refreshToken will guarantee a refresh will fail.
78+
refreshToken := gitAuthLink.OAuthRefreshToken
79+
if c.NoRefresh {
80+
refreshToken = ""
81+
}
82+
7083
token, err := c.TokenSource(ctx, &oauth2.Token{
7184
AccessToken: gitAuthLink.OAuthAccessToken,
72-
RefreshToken: gitAuthLink.OAuthRefreshToken,
85+
RefreshToken: refreshToken,
7386
Expiry: gitAuthLink.OAuthExpiry,
7487
}).Token()
7588
if err != nil {
@@ -129,8 +142,13 @@ func (c *Config) ValidateToken(ctx context.Context, token string) (bool, *coders
129142
if err != nil {
130143
return false, nil, err
131144
}
145+
146+
cli := http.DefaultClient
147+
if v, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
148+
cli = v
149+
}
132150
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
133-
res, err := http.DefaultClient.Do(req)
151+
res, err := cli.Do(req)
134152
if err != nil {
135153
return false, nil, err
136154
}

coderd/gitauth/config_test.go

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ import (
88
"testing"
99
"time"
1010

11+
"github.com/golang-jwt/jwt/v4"
12+
13+
"github.com/google/uuid"
14+
15+
"github.com/coreos/go-oidc/v3/oidc"
16+
17+
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
18+
1119
"github.com/stretchr/testify/require"
1220
"golang.org/x/oauth2"
1321
"golang.org/x/xerrors"
@@ -22,17 +30,82 @@ import (
2230

2331
func TestRefreshToken(t *testing.T) {
2432
t.Parallel()
25-
t.Run("FalseIfNoRefresh", func(t *testing.T) {
33+
const providerID = "test-idp"
34+
expired := time.Now().Add(time.Hour * -1)
35+
t.Run("NoRefreshExpired", func(t *testing.T) {
2636
t.Parallel()
37+
38+
fake := oidctest.NewFakeIDP(t,
39+
// The IDP should not be contacted since the token is expired. An expired
40+
// token with 'NoRefresh' should early abort.
41+
oidctest.WithRefreshHook(func(_ string) error {
42+
t.Error("refresh on the IDP was called, but NoRefresh was set")
43+
return xerrors.New("should not be called")
44+
}),
45+
oidctest.WithDynamicUserInfo(func(_ string) jwt.MapClaims {
46+
t.Error("token was validated, but it was expired and this should never have happened.")
47+
return nil
48+
}),
49+
)
50+
51+
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
2752
config := &gitauth.Config{
28-
NoRefresh: true,
53+
ID: providerID,
54+
OAuth2Config: fake.OIDCConfig(t, nil),
55+
NoRefresh: true,
56+
ValidateURL: fake.WellknownConfig().UserInfoURL,
2957
}
30-
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
31-
OAuthExpiry: time.Time{},
58+
_, refreshed, err := config.RefreshToken(ctx, nil, database.GitAuthLink{
59+
ProviderID: providerID,
60+
UserID: uuid.New(),
61+
OAuthAccessToken: uuid.NewString(),
62+
OAuthRefreshToken: uuid.NewString(),
63+
OAuthExpiry: expired,
3264
})
3365
require.NoError(t, err)
3466
require.False(t, refreshed)
3567
})
68+
t.Run("NoRefreshNoExpiry", func(t *testing.T) {
69+
t.Parallel()
70+
71+
validated := false
72+
fake := oidctest.NewFakeIDP(t,
73+
// The IDP should not be contacted since the token is expired. An expired
74+
// token with 'NoRefresh' should early abort.
75+
oidctest.WithRefreshHook(func(_ string) error {
76+
t.Error("refresh on the IDP was called, but NoRefresh was set")
77+
return xerrors.New("should not be called")
78+
}),
79+
oidctest.WithDynamicUserInfo(func(_ string) jwt.MapClaims {
80+
validated = true
81+
return jwt.MapClaims{}
82+
}),
83+
)
84+
85+
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
86+
config := &gitauth.Config{
87+
ID: providerID,
88+
OAuth2Config: fake.OIDCConfig(t, nil),
89+
NoRefresh: true,
90+
ValidateURL: fake.WellknownConfig().UserInfoURL,
91+
}
92+
93+
token, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{})
94+
require.NoError(t, err)
95+
96+
_, refreshed, err := config.RefreshToken(ctx, nil, database.GitAuthLink{
97+
ProviderID: providerID,
98+
UserID: uuid.New(),
99+
OAuthAccessToken: token.AccessToken,
100+
// Pass a refresh token, but this should be ignored in this test!
101+
OAuthRefreshToken: token.RefreshToken,
102+
// Zero time used
103+
OAuthExpiry: time.Time{},
104+
})
105+
require.NoError(t, err)
106+
require.True(t, refreshed, "token without expiry is always valid")
107+
require.True(t, validated, "token should have been validated")
108+
})
36109
t.Run("FalseIfTokenSourceFails", func(t *testing.T) {
37110
t.Parallel()
38111
config := &gitauth.Config{
@@ -42,7 +115,9 @@ func TestRefreshToken(t *testing.T) {
42115
},
43116
},
44117
}
45-
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
118+
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
119+
OAuthExpiry: expired,
120+
})
46121
require.NoError(t, err)
47122
require.False(t, refreshed)
48123
})
@@ -56,7 +131,9 @@ func TestRefreshToken(t *testing.T) {
56131
OAuth2Config: &testutil.OAuth2Config{},
57132
ValidateURL: srv.URL,
58133
}
59-
_, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
134+
_, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
135+
OAuthExpiry: expired,
136+
})
60137
require.ErrorContains(t, err, "Failure")
61138
})
62139
t.Run("ValidateFailure", func(t *testing.T) {
@@ -69,7 +146,9 @@ func TestRefreshToken(t *testing.T) {
69146
OAuth2Config: &testutil.OAuth2Config{},
70147
ValidateURL: srv.URL,
71148
}
72-
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
149+
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
150+
OAuthExpiry: expired,
151+
})
73152
require.NoError(t, err)
74153
require.False(t, refreshed)
75154
})
@@ -100,6 +179,7 @@ func TestRefreshToken(t *testing.T) {
100179
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
101180
ProviderID: config.ID,
102181
OAuthAccessToken: "initial",
182+
OAuthExpiry: expired,
103183
})
104184
_, refreshed, err := config.RefreshToken(context.Background(), db, link)
105185
require.NoError(t, err)
@@ -124,6 +204,7 @@ func TestRefreshToken(t *testing.T) {
124204
}
125205
_, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
126206
OAuthAccessToken: accessToken,
207+
OAuthExpiry: expired,
127208
})
128209
require.NoError(t, err)
129210
require.True(t, valid)
@@ -143,6 +224,7 @@ func TestRefreshToken(t *testing.T) {
143224
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
144225
ProviderID: config.ID,
145226
OAuthAccessToken: "initial",
227+
OAuthExpiry: expired,
146228
})
147229
_, valid, err := config.RefreshToken(context.Background(), db, link)
148230
require.NoError(t, err)

0 commit comments

Comments
 (0)