Skip to content

Commit 192c81e

Browse files
authored
chore: refactor codersdk to use SessionTokenProvider (coder#19565)
Refactors `codersdk.Client`'s use of session tokens to use a `SessionTokenProvider`, which abstracts the obtaining and storing of the session token. The main motiviation is to unify Agent authentication an an upstack PR, which can use cloud instance identity via token exchange, rather than a fixed session token. However, the abstraction could also allow functionality like obtaining the session token from other external sources like the OS credential manager, or an external secret/key management system like Vault.
1 parent f721f3d commit 192c81e

File tree

15 files changed

+128
-123
lines changed

15 files changed

+128
-123
lines changed

cli/exp_task_status_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,12 @@ STATE CHANGED STATUS STATE MESSAGE
243243
ctx = testutil.Context(t, testutil.WaitShort)
244244
now = time.Now().UTC() // TODO: replace with quartz
245245
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now)))
246-
client = new(codersdk.Client)
246+
client = codersdk.New(testutil.MustURL(t, srv.URL))
247247
sb = strings.Builder{}
248248
args = []string{"exp", "task", "status", "--watch-interval", testutil.IntervalFast.String()}
249249
)
250250

251251
t.Cleanup(srv.Close)
252-
client.URL = testutil.MustURL(t, srv.URL)
253252
args = append(args, tc.args...)
254253
inv, root := clitest.New(t, args...)
255254
inv.Stdout = &sb

cli/exp_taskcreate_test.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@ import (
55
"fmt"
66
"net/http"
77
"net/http/httptest"
8-
"net/url"
98
"strings"
109
"testing"
1110
"time"
1211

1312
"github.com/google/uuid"
1413
"github.com/stretchr/testify/assert"
15-
"github.com/stretchr/testify/require"
1614

1715
"github.com/coder/coder/v2/cli/clitest"
1816
"github.com/coder/coder/v2/cli/cliui"
@@ -236,17 +234,14 @@ func TestTaskCreate(t *testing.T) {
236234
var (
237235
ctx = testutil.Context(t, testutil.WaitShort)
238236
srv = httptest.NewServer(tt.handler(t, ctx))
239-
client = new(codersdk.Client)
237+
client = codersdk.New(testutil.MustURL(t, srv.URL))
240238
args = []string{"exp", "task", "create"}
241239
sb strings.Builder
242240
err error
243241
)
244242

245243
t.Cleanup(srv.Close)
246244

247-
client.URL, err = url.Parse(srv.URL)
248-
require.NoError(t, err)
249-
250245
inv, root := clitest.New(t, append(args, tt.args...)...)
251246
inv.Environ = serpent.ParseEnviron(tt.env, "")
252247
inv.Stdout = &sb

cli/root.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,9 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod
635635
}
636636

637637
func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error {
638+
if client.SessionTokenProvider == nil {
639+
client.SessionTokenProvider = codersdk.FixedSessionTokenProvider{}
640+
}
638641
transport := http.DefaultTransport
639642
transport = wrapTransportWithTelemetryHeader(transport, inv)
640643
if !r.noVersionCheck {

coderd/coderdtest/oidctest/idp.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken
641641

642642
// ExternalLogin does the oauth2 flow for external auth providers. This requires
643643
// an authenticated coder client.
644-
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) {
644+
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...codersdk.RequestOption) {
645645
coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID))
646646
require.NoError(t, err)
647647
f.SetRedirect(t, coderOauthURL.String())
@@ -660,11 +660,7 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
660660
req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil)
661661
require.NoError(t, err)
662662
// External auth flow requires the user be authenticated.
663-
headerName := client.SessionTokenHeader
664-
if headerName == "" {
665-
headerName = codersdk.SessionTokenHeader
666-
}
667-
req.Header.Set(headerName, client.SessionToken())
663+
opts = append([]codersdk.RequestOption{client.SessionTokenProvider.AsRequestOption()}, opts...)
668664
if cli.Jar == nil {
669665
cli.Jar, err = cookiejar.New(nil)
670666
require.NoError(t, err, "failed to create cookie jar")

coderd/mcp/mcp_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func TestMCPHTTP_ToolRegistration(t *testing.T) {
115115
require.Contains(t, err.Error(), "client cannot be nil", "Should reject nil client with appropriate error message")
116116

117117
// Test registering tools with valid client should succeed
118-
client := &codersdk.Client{}
118+
client := codersdk.New(testutil.MustURL(t, "http://not-used"))
119119
err = server.RegisterTools(client)
120120
require.NoError(t, err)
121121

codersdk/client.go

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@ var loggableMimeTypes = map[string]struct{}{
108108
// New creates a Coder client for the provided URL.
109109
func New(serverURL *url.URL) *Client {
110110
return &Client{
111-
URL: serverURL,
112-
HTTPClient: &http.Client{},
111+
URL: serverURL,
112+
HTTPClient: &http.Client{},
113+
SessionTokenProvider: FixedSessionTokenProvider{},
113114
}
114115
}
115116

@@ -118,18 +119,14 @@ func New(serverURL *url.URL) *Client {
118119
type Client struct {
119120
// mu protects the fields sessionToken, logger, and logBodies. These
120121
// need to be safe for concurrent access.
121-
mu sync.RWMutex
122-
sessionToken string
123-
logger slog.Logger
124-
logBodies bool
122+
mu sync.RWMutex
123+
SessionTokenProvider SessionTokenProvider
124+
logger slog.Logger
125+
logBodies bool
125126

126127
HTTPClient *http.Client
127128
URL *url.URL
128129

129-
// SessionTokenHeader is an optional custom header to use for setting tokens. By
130-
// default 'Coder-Session-Token' is used.
131-
SessionTokenHeader string
132-
133130
// PlainLogger may be set to log HTTP traffic in a human-readable form.
134131
// It uses the LogBodies option.
135132
PlainLogger io.Writer
@@ -176,14 +173,20 @@ func (c *Client) SetLogBodies(logBodies bool) {
176173
func (c *Client) SessionToken() string {
177174
c.mu.RLock()
178175
defer c.mu.RUnlock()
179-
return c.sessionToken
176+
return c.SessionTokenProvider.GetSessionToken()
180177
}
181178

182-
// SetSessionToken returns the currently set token for the client.
179+
// SetSessionToken sets a fixed token for the client.
180+
// Deprecated: Create a new client instead of changing the token after creation.
183181
func (c *Client) SetSessionToken(token string) {
182+
c.SetSessionTokenProvider(FixedSessionTokenProvider{SessionToken: token})
183+
}
184+
185+
// SetSessionTokenProvider sets the session token provider for the client.
186+
func (c *Client) SetSessionTokenProvider(provider SessionTokenProvider) {
184187
c.mu.Lock()
185188
defer c.mu.Unlock()
186-
c.sessionToken = token
189+
c.SessionTokenProvider = provider
187190
}
188191

189192
func prefixLines(prefix, s []byte) []byte {
@@ -199,6 +202,14 @@ func prefixLines(prefix, s []byte) []byte {
199202
// Request performs a HTTP request with the body provided. The caller is
200203
// responsible for closing the response body.
201204
func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
205+
opts = append([]RequestOption{c.SessionTokenProvider.AsRequestOption()}, opts...)
206+
return c.RequestWithoutSessionToken(ctx, method, path, body, opts...)
207+
}
208+
209+
// RequestWithoutSessionToken performs a HTTP request. It is similar to Request, but does not set
210+
// the session token in the request header, nor does it make a call to the SessionTokenProvider.
211+
// This allows session token providers to call this method without causing reentrancy issues.
212+
func (c *Client) RequestWithoutSessionToken(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
202213
if ctx == nil {
203214
return nil, xerrors.Errorf("context should not be nil")
204215
}
@@ -248,12 +259,6 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
248259
return nil, xerrors.Errorf("create request: %w", err)
249260
}
250261

251-
tokenHeader := c.SessionTokenHeader
252-
if tokenHeader == "" {
253-
tokenHeader = SessionTokenHeader
254-
}
255-
req.Header.Set(tokenHeader, c.SessionToken())
256-
257262
if r != nil {
258263
req.Header.Set("Content-Type", "application/json")
259264
}
@@ -345,20 +350,10 @@ func (c *Client) Dial(ctx context.Context, path string, opts *websocket.DialOpti
345350
return nil, err
346351
}
347352

348-
tokenHeader := c.SessionTokenHeader
349-
if tokenHeader == "" {
350-
tokenHeader = SessionTokenHeader
351-
}
352-
353353
if opts == nil {
354354
opts = &websocket.DialOptions{}
355355
}
356-
if opts.HTTPHeader == nil {
357-
opts.HTTPHeader = http.Header{}
358-
}
359-
if opts.HTTPHeader.Get(tokenHeader) == "" {
360-
opts.HTTPHeader.Set(tokenHeader, c.SessionToken())
361-
}
356+
c.SessionTokenProvider.SetDialOption(opts)
362357

363358
conn, resp, err := websocket.Dial(ctx, u.String(), opts)
364359
if resp != nil && resp.Body != nil {

codersdk/credentials.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package codersdk
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/coder/websocket"
7+
)
8+
9+
// SessionTokenProvider provides the session token to access the Coder service (coderd).
10+
// @typescript-ignore SessionTokenProvider
11+
type SessionTokenProvider interface {
12+
// AsRequestOption returns a request option that attaches the session token to an HTTP request.
13+
AsRequestOption() RequestOption
14+
// SetDialOption sets the session token on a websocket request via DialOptions
15+
SetDialOption(options *websocket.DialOptions)
16+
// GetSessionToken returns the session token as a string.
17+
GetSessionToken() string
18+
}
19+
20+
// FixedSessionTokenProvider provides a given, fixed, session token. E.g. one read from file or environment variable
21+
// at the program start.
22+
// @typescript-ignore FixedSessionTokenProvider
23+
type FixedSessionTokenProvider struct {
24+
SessionToken string
25+
// SessionTokenHeader is an optional custom header to use for setting tokens. By
26+
// default, 'Coder-Session-Token' is used.
27+
SessionTokenHeader string
28+
}
29+
30+
func (f FixedSessionTokenProvider) AsRequestOption() RequestOption {
31+
return func(req *http.Request) {
32+
tokenHeader := f.SessionTokenHeader
33+
if tokenHeader == "" {
34+
tokenHeader = SessionTokenHeader
35+
}
36+
req.Header.Set(tokenHeader, f.SessionToken)
37+
}
38+
}
39+
40+
func (f FixedSessionTokenProvider) GetSessionToken() string {
41+
return f.SessionToken
42+
}
43+
44+
func (f FixedSessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
45+
tokenHeader := f.SessionTokenHeader
46+
if tokenHeader == "" {
47+
tokenHeader = SessionTokenHeader
48+
}
49+
if opts.HTTPHeader == nil {
50+
opts.HTTPHeader = http.Header{}
51+
}
52+
if opts.HTTPHeader.Get(tokenHeader) == "" {
53+
opts.HTTPHeader.Set(tokenHeader, f.SessionToken)
54+
}
55+
}

codersdk/workspacesdk/workspacesdk.go

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,12 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
215215
options.BlockEndpoints = true
216216
}
217217

218-
headers := make(http.Header)
219-
tokenHeader := codersdk.SessionTokenHeader
220-
if c.client.SessionTokenHeader != "" {
221-
tokenHeader = c.client.SessionTokenHeader
218+
wsOptions := &websocket.DialOptions{
219+
HTTPClient: c.client.HTTPClient,
220+
// Need to disable compression to avoid a data-race.
221+
CompressionMode: websocket.CompressionDisabled,
222222
}
223-
headers.Set(tokenHeader, c.client.SessionToken())
223+
c.client.SessionTokenProvider.SetDialOption(wsOptions)
224224

225225
// New context, separate from dialCtx. We don't want to cancel the
226226
// connection if dialCtx is canceled.
@@ -236,12 +236,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
236236
return nil, xerrors.Errorf("parse url: %w", err)
237237
}
238238

239-
dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
240-
HTTPClient: c.client.HTTPClient,
241-
HTTPHeader: headers,
242-
// Need to disable compression to avoid a data-race.
243-
CompressionMode: websocket.CompressionDisabled,
244-
})
239+
dialer := NewWebsocketDialer(options.Logger, coordinateURL, wsOptions)
245240
clk := quartz.NewReal()
246241
controller := tailnet.NewController(options.Logger, dialer)
247242
controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk)

enterprise/coderd/workspaceproxy_test.go

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
312312
})
313313
require.NoError(t, err)
314314

315-
proxyClient := wsproxysdk.New(client.URL)
316-
proxyClient.SetSessionToken(createRes.ProxyToken)
315+
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
317316

318317
// Register
319318
req := wsproxysdk.RegisterWorkspaceProxyRequest{
@@ -427,8 +426,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
427426
})
428427
require.NoError(t, err)
429428

430-
proxyClient := wsproxysdk.New(client.URL)
431-
proxyClient.SetSessionToken(createRes.ProxyToken)
429+
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
432430

433431
req := wsproxysdk.RegisterWorkspaceProxyRequest{
434432
AccessURL: "https://proxy.coder.test",
@@ -472,8 +470,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
472470
})
473471
require.NoError(t, err)
474472

475-
proxyClient := wsproxysdk.New(client.URL)
476-
proxyClient.SetSessionToken(createRes.ProxyToken)
473+
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
477474

478475
err = proxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{
479476
ReplicaID: uuid.New(),
@@ -501,8 +498,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
501498

502499
// Register a replica on proxy 2. This shouldn't be returned by replicas
503500
// for proxy 1.
504-
proxyClient2 := wsproxysdk.New(client.URL)
505-
proxyClient2.SetSessionToken(createRes2.ProxyToken)
501+
proxyClient2 := wsproxysdk.New(client.URL, createRes2.ProxyToken)
506502
_, err = proxyClient2.RegisterWorkspaceProxy(ctx, wsproxysdk.RegisterWorkspaceProxyRequest{
507503
AccessURL: "https://other.proxy.coder.test",
508504
WildcardHostname: "*.other.proxy.coder.test",
@@ -516,8 +512,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
516512
require.NoError(t, err)
517513

518514
// Register replica 1.
519-
proxyClient1 := wsproxysdk.New(client.URL)
520-
proxyClient1.SetSessionToken(createRes1.ProxyToken)
515+
proxyClient1 := wsproxysdk.New(client.URL, createRes1.ProxyToken)
521516
req1 := wsproxysdk.RegisterWorkspaceProxyRequest{
522517
AccessURL: "https://one.proxy.coder.test",
523518
WildcardHostname: "*.one.proxy.coder.test",
@@ -574,8 +569,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
574569
})
575570
require.NoError(t, err)
576571

577-
proxyClient := wsproxysdk.New(client.URL)
578-
proxyClient.SetSessionToken(createRes.ProxyToken)
572+
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
579573

580574
for i := 0; i < 100; i++ {
581575
ok := false
@@ -652,8 +646,7 @@ func TestIssueSignedAppToken(t *testing.T) {
652646

653647
t.Run("BadAppRequest", func(t *testing.T) {
654648
t.Parallel()
655-
proxyClient := wsproxysdk.New(client.URL)
656-
proxyClient.SetSessionToken(proxyRes.ProxyToken)
649+
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
657650

658651
ctx := testutil.Context(t, testutil.WaitLong)
659652
_, err := proxyClient.IssueSignedAppToken(ctx, workspaceapps.IssueTokenRequest{
@@ -674,8 +667,7 @@ func TestIssueSignedAppToken(t *testing.T) {
674667
}
675668
t.Run("OK", func(t *testing.T) {
676669
t.Parallel()
677-
proxyClient := wsproxysdk.New(client.URL)
678-
proxyClient.SetSessionToken(proxyRes.ProxyToken)
670+
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
679671

680672
ctx := testutil.Context(t, testutil.WaitLong)
681673
_, err := proxyClient.IssueSignedAppToken(ctx, goodRequest)
@@ -684,8 +676,7 @@ func TestIssueSignedAppToken(t *testing.T) {
684676

685677
t.Run("OKHTML", func(t *testing.T) {
686678
t.Parallel()
687-
proxyClient := wsproxysdk.New(client.URL)
688-
proxyClient.SetSessionToken(proxyRes.ProxyToken)
679+
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
689680

690681
rw := httptest.NewRecorder()
691682
ctx := testutil.Context(t, testutil.WaitLong)
@@ -1032,8 +1023,7 @@ func TestGetCryptoKeys(t *testing.T) {
10321023
Name: testutil.GetRandomName(t),
10331024
})
10341025

1035-
client := wsproxysdk.New(cclient.URL)
1036-
client.SetSessionToken(cclient.SessionToken())
1026+
client := wsproxysdk.New(cclient.URL, cclient.SessionToken())
10371027

10381028
_, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
10391029
require.Error(t, err)

enterprise/wsproxy/wsproxy.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
163163
return nil, err
164164
}
165165

166-
client := wsproxysdk.New(opts.DashboardURL)
167-
err := client.SetSessionToken(opts.ProxySessionToken)
168-
if err != nil {
169-
return nil, xerrors.Errorf("set client token: %w", err)
170-
}
166+
client := wsproxysdk.New(opts.DashboardURL, opts.ProxySessionToken)
171167

172168
// Use the configured client if provided.
173169
if opts.HTTPClient != nil {

0 commit comments

Comments
 (0)