|
1 | 1 | package gitauth_test
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "context" |
| 5 | + "net/http" |
| 6 | + "net/http/httptest" |
4 | 7 | "net/url"
|
5 | 8 | "testing"
|
| 9 | + "time" |
6 | 10 |
|
7 | 11 | "github.com/stretchr/testify/require"
|
| 12 | + "golang.org/x/oauth2" |
| 13 | + "golang.org/x/xerrors" |
8 | 14 |
|
| 15 | + "github.com/coder/coder/coderd/database" |
| 16 | + "github.com/coder/coder/coderd/database/dbfake" |
| 17 | + "github.com/coder/coder/coderd/database/dbgen" |
9 | 18 | "github.com/coder/coder/coderd/gitauth"
|
10 | 19 | "github.com/coder/coder/codersdk"
|
11 | 20 | )
|
12 | 21 |
|
| 22 | +func TestRefreshToken(t *testing.T) { |
| 23 | + t.Parallel() |
| 24 | + t.Run("FalseIfNoRefresh", func(t *testing.T) { |
| 25 | + t.Parallel() |
| 26 | + config := &gitauth.Config{ |
| 27 | + NoRefresh: true, |
| 28 | + } |
| 29 | + _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ |
| 30 | + OAuthExpiry: time.Time{}, |
| 31 | + }) |
| 32 | + require.NoError(t, err) |
| 33 | + require.False(t, refreshed) |
| 34 | + }) |
| 35 | + t.Run("FalseIfTokenSourceFails", func(t *testing.T) { |
| 36 | + t.Parallel() |
| 37 | + config := &gitauth.Config{ |
| 38 | + OAuth2Config: &oauth2Config{ |
| 39 | + tokenError: xerrors.New("failure"), |
| 40 | + }, |
| 41 | + } |
| 42 | + _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{}) |
| 43 | + require.NoError(t, err) |
| 44 | + require.False(t, refreshed) |
| 45 | + }) |
| 46 | + t.Run("ValidateServerError", func(t *testing.T) { |
| 47 | + t.Parallel() |
| 48 | + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 49 | + w.WriteHeader(http.StatusInternalServerError) |
| 50 | + w.Write([]byte("Failure")) |
| 51 | + })) |
| 52 | + config := &gitauth.Config{ |
| 53 | + OAuth2Config: &oauth2Config{}, |
| 54 | + ValidateURL: srv.URL, |
| 55 | + } |
| 56 | + _, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{}) |
| 57 | + require.ErrorContains(t, err, "Failure") |
| 58 | + }) |
| 59 | + t.Run("ValidateFailure", func(t *testing.T) { |
| 60 | + t.Parallel() |
| 61 | + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 62 | + w.WriteHeader(http.StatusUnauthorized) |
| 63 | + w.Write([]byte("Not permitted")) |
| 64 | + })) |
| 65 | + config := &gitauth.Config{ |
| 66 | + OAuth2Config: &oauth2Config{}, |
| 67 | + ValidateURL: srv.URL, |
| 68 | + } |
| 69 | + _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{}) |
| 70 | + require.NoError(t, err) |
| 71 | + require.False(t, refreshed) |
| 72 | + }) |
| 73 | + t.Run("ValidateNoUpdate", func(t *testing.T) { |
| 74 | + t.Parallel() |
| 75 | + validated := make(chan struct{}) |
| 76 | + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 77 | + w.WriteHeader(http.StatusOK) |
| 78 | + close(validated) |
| 79 | + })) |
| 80 | + accessToken := "testing" |
| 81 | + config := &gitauth.Config{ |
| 82 | + OAuth2Config: &oauth2Config{ |
| 83 | + accessToken: accessToken, |
| 84 | + }, |
| 85 | + ValidateURL: srv.URL, |
| 86 | + } |
| 87 | + _, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ |
| 88 | + OAuthAccessToken: accessToken, |
| 89 | + }) |
| 90 | + require.NoError(t, err) |
| 91 | + require.True(t, valid) |
| 92 | + <-validated |
| 93 | + }) |
| 94 | + t.Run("Updates", func(t *testing.T) { |
| 95 | + t.Parallel() |
| 96 | + config := &gitauth.Config{ |
| 97 | + ID: "test", |
| 98 | + OAuth2Config: &oauth2Config{ |
| 99 | + accessToken: "updated", |
| 100 | + }, |
| 101 | + } |
| 102 | + db := dbfake.New() |
| 103 | + link := dbgen.GitAuthLink(t, db, database.GitAuthLink{ |
| 104 | + ProviderID: config.ID, |
| 105 | + OAuthAccessToken: "initial", |
| 106 | + }) |
| 107 | + _, valid, err := config.RefreshToken(context.Background(), db, link) |
| 108 | + require.NoError(t, err) |
| 109 | + require.True(t, valid) |
| 110 | + }) |
| 111 | +} |
| 112 | + |
13 | 113 | func TestConvertYAML(t *testing.T) {
|
14 | 114 | t.Parallel()
|
15 | 115 | for _, tc := range []struct {
|
@@ -90,3 +190,40 @@ func TestConvertYAML(t *testing.T) {
|
90 | 190 | require.Equal(t, "https://auth.com?client_id=id&redirect_uri=%2Fgitauth%2Fgitlab%2Fcallback&response_type=code&scope=read", config[0].AuthCodeURL(""))
|
91 | 191 | })
|
92 | 192 | }
|
| 193 | + |
| 194 | +type oauth2Config struct { |
| 195 | + accessToken string |
| 196 | + tokenError error |
| 197 | +} |
| 198 | + |
| 199 | +func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string { |
| 200 | + return "/?state=" + url.QueryEscape(state) |
| 201 | +} |
| 202 | + |
| 203 | +func (o *oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) { |
| 204 | + return &oauth2.Token{ |
| 205 | + AccessToken: o.accessToken, |
| 206 | + RefreshToken: "refresh", |
| 207 | + Expiry: database.Now().Add(time.Hour), |
| 208 | + }, nil |
| 209 | +} |
| 210 | + |
| 211 | +func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource { |
| 212 | + return &oauth2TokenSource{ |
| 213 | + err: o.tokenError, |
| 214 | + accessToken: o.accessToken, |
| 215 | + } |
| 216 | +} |
| 217 | + |
| 218 | +type oauth2TokenSource struct { |
| 219 | + accessToken string |
| 220 | + err error |
| 221 | +} |
| 222 | + |
| 223 | +func (o *oauth2TokenSource) Token() (*oauth2.Token, error) { |
| 224 | + return &oauth2.Token{ |
| 225 | + AccessToken: o.accessToken, |
| 226 | + RefreshToken: "refresh", |
| 227 | + Expiry: database.Now().Add(time.Hour), |
| 228 | + }, o.err |
| 229 | +} |
0 commit comments