Skip to content

Commit df31636

Browse files
authored
feat: pass access_token to coder_git_auth resource (#6713)
This allows template authors to leverage git auth to perform custom actions, like clone repositories.
1 parent 79ae7cd commit df31636

20 files changed

+647
-479
lines changed

cli/create_test.go

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,16 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7-
"net/url"
87
"os"
98
"regexp"
109
"testing"
1110
"time"
1211

1312
"github.com/stretchr/testify/assert"
1413
"github.com/stretchr/testify/require"
15-
"golang.org/x/oauth2"
1614

1715
"github.com/coder/coder/cli/clitest"
1816
"github.com/coder/coder/coderd/coderdtest"
19-
"github.com/coder/coder/coderd/database"
2017
"github.com/coder/coder/coderd/gitauth"
2118
"github.com/coder/coder/codersdk"
2219
"github.com/coder/coder/provisioner/echo"
@@ -768,7 +765,7 @@ func TestCreateWithGitAuth(t *testing.T) {
768765

769766
client := coderdtest.New(t, &coderdtest.Options{
770767
GitAuthConfigs: []*gitauth.Config{{
771-
OAuth2Config: &oauth2Config{},
768+
OAuth2Config: &testutil.OAuth2Config{},
772769
ID: "github",
773770
Regex: regexp.MustCompile(`github\.com`),
774771
Type: codersdk.GitProviderGitHub,
@@ -836,31 +833,3 @@ func createTestParseResponseWithDefault(defaultValue string) []*proto.Parse_Resp
836833
},
837834
}}
838835
}
839-
840-
type oauth2Config struct{}
841-
842-
func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
843-
return "/?state=" + url.QueryEscape(state)
844-
}
845-
846-
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
847-
return &oauth2.Token{
848-
AccessToken: "token",
849-
RefreshToken: "refresh",
850-
Expiry: database.Now().Add(time.Hour),
851-
}, nil
852-
}
853-
854-
func (*oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
855-
return &oauth2TokenSource{}
856-
}
857-
858-
type oauth2TokenSource struct{}
859-
860-
func (*oauth2TokenSource) Token() (*oauth2.Token, error) {
861-
return &oauth2.Token{
862-
AccessToken: "token",
863-
RefreshToken: "refresh",
864-
Expiry: database.Now().Add(time.Hour),
865-
}, nil
866-
}

coderd/coderd.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -831,18 +831,14 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
831831

832832
mux := drpcmux.New()
833833

834-
gitAuthProviders := make([]string, 0, len(api.GitAuthConfigs))
835-
for _, cfg := range api.GitAuthConfigs {
836-
gitAuthProviders = append(gitAuthProviders, cfg.ID)
837-
}
838834
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
839835
AccessURL: api.AccessURL,
840836
ID: daemon.ID,
841837
OIDCConfig: api.OIDCConfig,
842838
Database: api.Database,
843839
Pubsub: api.Pubsub,
844840
Provisioners: daemon.Provisioners,
845-
GitAuthProviders: gitAuthProviders,
841+
GitAuthConfigs: api.GitAuthConfigs,
846842
Telemetry: api.Telemetry,
847843
Tags: tags,
848844
QuotaCommitter: &api.QuotaCommitter,

coderd/gitauth/config.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
package gitauth
22

33
import (
4+
"context"
45
"fmt"
6+
"io"
7+
"net/http"
58
"net/url"
69
"regexp"
710

811
"golang.org/x/oauth2"
912
"golang.org/x/xerrors"
1013

14+
"github.com/coder/coder/coderd/database"
1115
"github.com/coder/coder/coderd/httpapi"
1216
"github.com/coder/coder/coderd/httpmw"
1317
"github.com/coder/coder/codersdk"
@@ -34,6 +38,77 @@ type Config struct {
3438
ValidateURL string
3539
}
3640

41+
// RefreshToken automatically refreshes the token if expired and permitted.
42+
// It returns the token and a bool indicating if the token was refreshed.
43+
func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
44+
// If the token is expired and refresh is disabled, we prompt
45+
// the user to authenticate again.
46+
if c.NoRefresh && gitAuthLink.OAuthExpiry.Before(database.Now()) {
47+
return gitAuthLink, false, nil
48+
}
49+
50+
token, err := c.TokenSource(ctx, &oauth2.Token{
51+
AccessToken: gitAuthLink.OAuthAccessToken,
52+
RefreshToken: gitAuthLink.OAuthRefreshToken,
53+
Expiry: gitAuthLink.OAuthExpiry,
54+
}).Token()
55+
if err != nil {
56+
// Even if the token fails to be obtained, we still return false because
57+
// we aren't trying to surface an error, we're just trying to obtain a valid token.
58+
return gitAuthLink, false, nil
59+
}
60+
61+
if c.ValidateURL != "" {
62+
valid, err := c.ValidateToken(ctx, token.AccessToken)
63+
if err != nil {
64+
return gitAuthLink, false, xerrors.Errorf("validate git auth token: %w", err)
65+
}
66+
if !valid {
67+
// The token is no longer valid!
68+
return gitAuthLink, false, nil
69+
}
70+
}
71+
72+
if token.AccessToken != gitAuthLink.OAuthAccessToken {
73+
// Update it
74+
gitAuthLink, err = db.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
75+
ProviderID: c.ID,
76+
UserID: gitAuthLink.UserID,
77+
UpdatedAt: database.Now(),
78+
OAuthAccessToken: token.AccessToken,
79+
OAuthRefreshToken: token.RefreshToken,
80+
OAuthExpiry: token.Expiry,
81+
})
82+
if err != nil {
83+
return gitAuthLink, false, xerrors.Errorf("update git auth link: %w", err)
84+
}
85+
}
86+
return gitAuthLink, true, nil
87+
}
88+
89+
// ValidateToken ensures the Git token provided is valid!
90+
func (c *Config) ValidateToken(ctx context.Context, token string) (bool, error) {
91+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.ValidateURL, nil)
92+
if err != nil {
93+
return false, err
94+
}
95+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
96+
res, err := http.DefaultClient.Do(req)
97+
if err != nil {
98+
return false, err
99+
}
100+
defer res.Body.Close()
101+
if res.StatusCode == http.StatusUnauthorized {
102+
// The token is no longer valid!
103+
return false, nil
104+
}
105+
if res.StatusCode != http.StatusOK {
106+
data, _ := io.ReadAll(res.Body)
107+
return false, xerrors.Errorf("status %d: body: %s", res.StatusCode, data)
108+
}
109+
return true, nil
110+
}
111+
37112
// ConvertConfig converts the SDK configuration entry format
38113
// to the parsed and ready-to-consume in coderd provider type.
39114
func ConvertConfig(entries []codersdk.GitAuthConfig, accessURL *url.URL) ([]*Config, error) {

coderd/gitauth/config_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,122 @@
11
package gitauth_test
22

33
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
47
"net/url"
58
"testing"
9+
"time"
610

711
"github.com/stretchr/testify/require"
12+
"golang.org/x/oauth2"
13+
"golang.org/x/xerrors"
814

15+
"github.com/coder/coder/coderd/database"
16+
"github.com/coder/coder/coderd/database/dbfake"
17+
"github.com/coder/coder/coderd/database/dbgen"
918
"github.com/coder/coder/coderd/gitauth"
1019
"github.com/coder/coder/codersdk"
20+
"github.com/coder/coder/testutil"
1121
)
1222

23+
func TestRefreshToken(t *testing.T) {
24+
t.Parallel()
25+
t.Run("FalseIfNoRefresh", func(t *testing.T) {
26+
t.Parallel()
27+
config := &gitauth.Config{
28+
NoRefresh: true,
29+
}
30+
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
31+
OAuthExpiry: time.Time{},
32+
})
33+
require.NoError(t, err)
34+
require.False(t, refreshed)
35+
})
36+
t.Run("FalseIfTokenSourceFails", func(t *testing.T) {
37+
t.Parallel()
38+
config := &gitauth.Config{
39+
OAuth2Config: &testutil.OAuth2Config{
40+
TokenSourceFunc: func() (*oauth2.Token, error) {
41+
return nil, xerrors.New("failure")
42+
},
43+
},
44+
}
45+
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
46+
require.NoError(t, err)
47+
require.False(t, refreshed)
48+
})
49+
t.Run("ValidateServerError", func(t *testing.T) {
50+
t.Parallel()
51+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
52+
w.WriteHeader(http.StatusInternalServerError)
53+
w.Write([]byte("Failure"))
54+
}))
55+
config := &gitauth.Config{
56+
OAuth2Config: &testutil.OAuth2Config{},
57+
ValidateURL: srv.URL,
58+
}
59+
_, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
60+
require.ErrorContains(t, err, "Failure")
61+
})
62+
t.Run("ValidateFailure", func(t *testing.T) {
63+
t.Parallel()
64+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65+
w.WriteHeader(http.StatusUnauthorized)
66+
w.Write([]byte("Not permitted"))
67+
}))
68+
config := &gitauth.Config{
69+
OAuth2Config: &testutil.OAuth2Config{},
70+
ValidateURL: srv.URL,
71+
}
72+
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
73+
require.NoError(t, err)
74+
require.False(t, refreshed)
75+
})
76+
t.Run("ValidateNoUpdate", func(t *testing.T) {
77+
t.Parallel()
78+
validated := make(chan struct{})
79+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80+
w.WriteHeader(http.StatusOK)
81+
close(validated)
82+
}))
83+
accessToken := "testing"
84+
config := &gitauth.Config{
85+
OAuth2Config: &testutil.OAuth2Config{
86+
Token: &oauth2.Token{
87+
AccessToken: accessToken,
88+
},
89+
},
90+
ValidateURL: srv.URL,
91+
}
92+
_, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
93+
OAuthAccessToken: accessToken,
94+
})
95+
require.NoError(t, err)
96+
require.True(t, valid)
97+
<-validated
98+
})
99+
t.Run("Updates", func(t *testing.T) {
100+
t.Parallel()
101+
config := &gitauth.Config{
102+
ID: "test",
103+
OAuth2Config: &testutil.OAuth2Config{
104+
Token: &oauth2.Token{
105+
AccessToken: "updated",
106+
},
107+
},
108+
}
109+
db := dbfake.New()
110+
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
111+
ProviderID: config.ID,
112+
OAuthAccessToken: "initial",
113+
})
114+
_, valid, err := config.RefreshToken(context.Background(), db, link)
115+
require.NoError(t, err)
116+
require.True(t, valid)
117+
})
118+
}
119+
13120
func TestConvertYAML(t *testing.T) {
14121
t.Parallel()
15122
for _, tc := range []struct {

coderd/httpmw/apikey_test.go

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/coder/coder/coderd/httpmw"
2323
"github.com/coder/coder/codersdk"
2424
"github.com/coder/coder/cryptorand"
25+
"github.com/coder/coder/testutil"
2526
)
2627

2728
func randomAPIKeyParts() (id string, secret string) {
@@ -462,10 +463,8 @@ func TestAPIKey(t *testing.T) {
462463
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
463464
DB: db,
464465
OAuth2Configs: &httpmw.OAuth2Configs{
465-
Github: &oauth2Config{
466-
tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) {
467-
return oauthToken, nil
468-
}),
466+
Github: &testutil.OAuth2Config{
467+
Token: oauthToken,
469468
},
470469
},
471470
RedirectToLogin: false,
@@ -597,25 +596,3 @@ func TestAPIKey(t *testing.T) {
597596
require.Equal(t, sentAPIKey.LoginType, gotAPIKey.LoginType)
598597
})
599598
}
600-
601-
type oauth2Config struct {
602-
tokenSource oauth2TokenSource
603-
}
604-
605-
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
606-
return o.tokenSource
607-
}
608-
609-
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
610-
return ""
611-
}
612-
613-
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
614-
return &oauth2.Token{}, nil
615-
}
616-
617-
type oauth2TokenSource func() (*oauth2.Token, error)
618-
619-
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
620-
return o()
621-
}

0 commit comments

Comments
 (0)