Skip to content

Commit 6f7bb8d

Browse files
committed
chore: refactor codersdk to use SessionTokenProvider
1 parent d7ee101 commit 6f7bb8d

File tree

11 files changed

+119
-111
lines changed

11 files changed

+119
-111
lines changed

cli/root.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,9 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod
635635
}
636636

637637
func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error {
638+
if client.SessionTokenProvider == nil {
639+
client.SessionTokenProvider = codersdk.FixedSessionTokenProvider{}
640+
}
638641
transport := http.DefaultTransport
639642
transport = wrapTransportWithTelemetryHeader(transport, inv)
640643
if !r.noVersionCheck {

coderd/coderdtest/oidctest/idp.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken
641641

642642
// ExternalLogin does the oauth2 flow for external auth providers. This requires
643643
// an authenticated coder client.
644-
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) {
644+
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...codersdk.RequestOption) {
645645
coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID))
646646
require.NoError(t, err)
647647
f.SetRedirect(t, coderOauthURL.String())
@@ -660,11 +660,7 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
660660
req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil)
661661
require.NoError(t, err)
662662
// External auth flow requires the user be authenticated.
663-
headerName := client.SessionTokenHeader
664-
if headerName == "" {
665-
headerName = codersdk.SessionTokenHeader
666-
}
667-
req.Header.Set(headerName, client.SessionToken())
663+
opts = append([]codersdk.RequestOption{client.SessionTokenProvider.AsRequestOption()}, opts...)
668664
if cli.Jar == nil {
669665
cli.Jar, err = cookiejar.New(nil)
670666
require.NoError(t, err, "failed to create cookie jar")

codersdk/client.go

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@ var loggableMimeTypes = map[string]struct{}{
108108
// New creates a Coder client for the provided URL.
109109
func New(serverURL *url.URL) *Client {
110110
return &Client{
111-
URL: serverURL,
112-
HTTPClient: &http.Client{},
111+
URL: serverURL,
112+
HTTPClient: &http.Client{},
113+
SessionTokenProvider: FixedSessionTokenProvider{},
113114
}
114115
}
115116

@@ -118,18 +119,14 @@ func New(serverURL *url.URL) *Client {
118119
type Client struct {
119120
// mu protects the fields sessionToken, logger, and logBodies. These
120121
// need to be safe for concurrent access.
121-
mu sync.RWMutex
122-
sessionToken string
123-
logger slog.Logger
124-
logBodies bool
122+
mu sync.RWMutex
123+
SessionTokenProvider SessionTokenProvider
124+
logger slog.Logger
125+
logBodies bool
125126

126127
HTTPClient *http.Client
127128
URL *url.URL
128129

129-
// SessionTokenHeader is an optional custom header to use for setting tokens. By
130-
// default 'Coder-Session-Token' is used.
131-
SessionTokenHeader string
132-
133130
// PlainLogger may be set to log HTTP traffic in a human-readable form.
134131
// It uses the LogBodies option.
135132
PlainLogger io.Writer
@@ -176,14 +173,19 @@ func (c *Client) SetLogBodies(logBodies bool) {
176173
func (c *Client) SessionToken() string {
177174
c.mu.RLock()
178175
defer c.mu.RUnlock()
179-
return c.sessionToken
176+
return c.SessionTokenProvider.GetSessionToken()
180177
}
181178

182-
// SetSessionToken returns the currently set token for the client.
179+
// SetSessionToken sets a fixed token for the client.
183180
func (c *Client) SetSessionToken(token string) {
181+
c.SetSessionTokenProvider(FixedSessionTokenProvider{SessionToken: token})
182+
}
183+
184+
// SetSessionTokenProvider sets the session token provider for the client.
185+
func (c *Client) SetSessionTokenProvider(provider SessionTokenProvider) {
184186
c.mu.Lock()
185187
defer c.mu.Unlock()
186-
c.sessionToken = token
188+
c.SessionTokenProvider = provider
187189
}
188190

189191
func prefixLines(prefix, s []byte) []byte {
@@ -199,6 +201,11 @@ func prefixLines(prefix, s []byte) []byte {
199201
// Request performs a HTTP request with the body provided. The caller is
200202
// responsible for closing the response body.
201203
func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
204+
opts = append([]RequestOption{c.SessionTokenProvider.AsRequestOption()}, opts...)
205+
return c.RequestNoSessionToken(ctx, method, path, body, opts...)
206+
}
207+
208+
func (c *Client) RequestNoSessionToken(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
202209
if ctx == nil {
203210
return nil, xerrors.Errorf("context should not be nil")
204211
}
@@ -248,12 +255,6 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
248255
return nil, xerrors.Errorf("create request: %w", err)
249256
}
250257

251-
tokenHeader := c.SessionTokenHeader
252-
if tokenHeader == "" {
253-
tokenHeader = SessionTokenHeader
254-
}
255-
req.Header.Set(tokenHeader, c.SessionToken())
256-
257258
if r != nil {
258259
req.Header.Set("Content-Type", "application/json")
259260
}
@@ -345,20 +346,10 @@ func (c *Client) Dial(ctx context.Context, path string, opts *websocket.DialOpti
345346
return nil, err
346347
}
347348

348-
tokenHeader := c.SessionTokenHeader
349-
if tokenHeader == "" {
350-
tokenHeader = SessionTokenHeader
351-
}
352-
353349
if opts == nil {
354350
opts = &websocket.DialOptions{}
355351
}
356-
if opts.HTTPHeader == nil {
357-
opts.HTTPHeader = http.Header{}
358-
}
359-
if opts.HTTPHeader.Get(tokenHeader) == "" {
360-
opts.HTTPHeader.Set(tokenHeader, c.SessionToken())
361-
}
352+
c.SessionTokenProvider.SetDialOption(opts)
362353

363354
conn, resp, err := websocket.Dial(ctx, u.String(), opts)
364355
if resp != nil && resp.Body != nil {

codersdk/credentials.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package codersdk
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/coder/websocket"
7+
)
8+
9+
// SessionTokenProvider provides the session token to access the Coder service (coderd).
10+
// @typescript-ignore SessionTokenProvider
11+
type SessionTokenProvider interface {
12+
// AsRequestOption returns a request option that attaches the session token to an HTTP request.
13+
AsRequestOption() RequestOption
14+
// SetDialOption sets the session token on a websocket request via DialOptions
15+
SetDialOption(options *websocket.DialOptions)
16+
// GetSessionToken returns the session token as a string.
17+
GetSessionToken() string
18+
}
19+
20+
// FixedSessionTokenProvider provides a given, fixed, session token. E.g. one read from file or environment variable
21+
// at the program start.
22+
// @typescript-ignore FixedSessionTokenProvider
23+
type FixedSessionTokenProvider struct {
24+
SessionToken string
25+
// SessionTokenHeader is an optional custom header to use for setting tokens. By
26+
// default, 'Coder-Session-Token' is used.
27+
SessionTokenHeader string
28+
}
29+
30+
func (f FixedSessionTokenProvider) AsRequestOption() RequestOption {
31+
return func(req *http.Request) {
32+
tokenHeader := f.SessionTokenHeader
33+
if tokenHeader == "" {
34+
tokenHeader = SessionTokenHeader
35+
}
36+
req.Header.Set(tokenHeader, f.SessionToken)
37+
}
38+
}
39+
40+
func (f FixedSessionTokenProvider) GetSessionToken() string {
41+
return f.SessionToken
42+
}
43+
44+
func (f FixedSessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
45+
tokenHeader := f.SessionTokenHeader
46+
if tokenHeader == "" {
47+
tokenHeader = SessionTokenHeader
48+
}
49+
if opts.HTTPHeader == nil {
50+
opts.HTTPHeader = http.Header{}
51+
}
52+
if opts.HTTPHeader.Get(tokenHeader) == "" {
53+
opts.HTTPHeader.Set(tokenHeader, f.SessionToken)
54+
}
55+
}

codersdk/workspacesdk/workspacesdk.go

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,12 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
215215
options.BlockEndpoints = true
216216
}
217217

218-
headers := make(http.Header)
219-
tokenHeader := codersdk.SessionTokenHeader
220-
if c.client.SessionTokenHeader != "" {
221-
tokenHeader = c.client.SessionTokenHeader
218+
wsOptions := &websocket.DialOptions{
219+
HTTPClient: c.client.HTTPClient,
220+
// Need to disable compression to avoid a data-race.
221+
CompressionMode: websocket.CompressionDisabled,
222222
}
223-
headers.Set(tokenHeader, c.client.SessionToken())
223+
c.client.SessionTokenProvider.SetDialOption(wsOptions)
224224

225225
// New context, separate from dialCtx. We don't want to cancel the
226226
// connection if dialCtx is canceled.
@@ -236,12 +236,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
236236
return nil, xerrors.Errorf("parse url: %w", err)
237237
}
238238

239-
dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
240-
HTTPClient: c.client.HTTPClient,
241-
HTTPHeader: headers,
242-
// Need to disable compression to avoid a data-race.
243-
CompressionMode: websocket.CompressionDisabled,
244-
})
239+
dialer := NewWebsocketDialer(options.Logger, coordinateURL, wsOptions)
245240
clk := quartz.NewReal()
246241
controller := tailnet.NewController(options.Logger, dialer)
247242
controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk)

enterprise/coderd/workspaceproxy_test.go

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
312312
})
313313
require.NoError(t, err)
314314

315-
proxyClient := wsproxysdk.New(client.URL)
316-
proxyClient.SetSessionToken(createRes.ProxyToken)
315+
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
317316

318317
// Register
319318
req := wsproxysdk.RegisterWorkspaceProxyRequest{
@@ -427,8 +426,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
427426
})
428427
require.NoError(t, err)
429428

430-
proxyClient := wsproxysdk.New(client.URL)
431-
proxyClient.SetSessionToken(createRes.ProxyToken)
429+
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
432430

433431
req := wsproxysdk.RegisterWorkspaceProxyRequest{
434432
AccessURL: "https://proxy.coder.test",
@@ -472,8 +470,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
472470
})
473471
require.NoError(t, err)
474472

475-
proxyClient := wsproxysdk.New(client.URL)
476-
proxyClient.SetSessionToken(createRes.ProxyToken)
473+
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
477474

478475
err = proxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{
479476
ReplicaID: uuid.New(),
@@ -501,8 +498,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
501498

502499
// Register a replica on proxy 2. This shouldn't be returned by replicas
503500
// for proxy 1.
504-
proxyClient2 := wsproxysdk.New(client.URL)
505-
proxyClient2.SetSessionToken(createRes2.ProxyToken)
501+
proxyClient2 := wsproxysdk.New(client.URL, createRes2.ProxyToken)
506502
_, err = proxyClient2.RegisterWorkspaceProxy(ctx, wsproxysdk.RegisterWorkspaceProxyRequest{
507503
AccessURL: "https://other.proxy.coder.test",
508504
WildcardHostname: "*.other.proxy.coder.test",
@@ -516,8 +512,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
516512
require.NoError(t, err)
517513

518514
// Register replica 1.
519-
proxyClient1 := wsproxysdk.New(client.URL)
520-
proxyClient1.SetSessionToken(createRes1.ProxyToken)
515+
proxyClient1 := wsproxysdk.New(client.URL, createRes1.ProxyToken)
521516
req1 := wsproxysdk.RegisterWorkspaceProxyRequest{
522517
AccessURL: "https://one.proxy.coder.test",
523518
WildcardHostname: "*.one.proxy.coder.test",
@@ -574,8 +569,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
574569
})
575570
require.NoError(t, err)
576571

577-
proxyClient := wsproxysdk.New(client.URL)
578-
proxyClient.SetSessionToken(createRes.ProxyToken)
572+
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
579573

580574
for i := 0; i < 100; i++ {
581575
ok := false
@@ -652,8 +646,7 @@ func TestIssueSignedAppToken(t *testing.T) {
652646

653647
t.Run("BadAppRequest", func(t *testing.T) {
654648
t.Parallel()
655-
proxyClient := wsproxysdk.New(client.URL)
656-
proxyClient.SetSessionToken(proxyRes.ProxyToken)
649+
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
657650

658651
ctx := testutil.Context(t, testutil.WaitLong)
659652
_, err := proxyClient.IssueSignedAppToken(ctx, workspaceapps.IssueTokenRequest{
@@ -674,8 +667,7 @@ func TestIssueSignedAppToken(t *testing.T) {
674667
}
675668
t.Run("OK", func(t *testing.T) {
676669
t.Parallel()
677-
proxyClient := wsproxysdk.New(client.URL)
678-
proxyClient.SetSessionToken(proxyRes.ProxyToken)
670+
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
679671

680672
ctx := testutil.Context(t, testutil.WaitLong)
681673
_, err := proxyClient.IssueSignedAppToken(ctx, goodRequest)
@@ -684,8 +676,7 @@ func TestIssueSignedAppToken(t *testing.T) {
684676

685677
t.Run("OKHTML", func(t *testing.T) {
686678
t.Parallel()
687-
proxyClient := wsproxysdk.New(client.URL)
688-
proxyClient.SetSessionToken(proxyRes.ProxyToken)
679+
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
689680

690681
rw := httptest.NewRecorder()
691682
ctx := testutil.Context(t, testutil.WaitLong)
@@ -1032,8 +1023,7 @@ func TestGetCryptoKeys(t *testing.T) {
10321023
Name: testutil.GetRandomName(t),
10331024
})
10341025

1035-
client := wsproxysdk.New(cclient.URL)
1036-
client.SetSessionToken(cclient.SessionToken())
1026+
client := wsproxysdk.New(cclient.URL, cclient.SessionToken())
10371027

10381028
_, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
10391029
require.Error(t, err)

enterprise/wsproxy/wsproxy.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
163163
return nil, err
164164
}
165165

166-
client := wsproxysdk.New(opts.DashboardURL)
167-
err := client.SetSessionToken(opts.ProxySessionToken)
168-
if err != nil {
169-
return nil, xerrors.Errorf("set client token: %w", err)
170-
}
166+
client := wsproxysdk.New(opts.DashboardURL, opts.ProxySessionToken)
171167

172168
// Use the configured client if provided.
173169
if opts.HTTPClient != nil {

enterprise/wsproxy/wsproxy_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -577,8 +577,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) {
577577
t.Cleanup(srv.Close)
578578

579579
// Register a proxy.
580-
wsproxyClient := wsproxysdk.New(primaryAccessURL)
581-
wsproxyClient.SetSessionToken(token)
580+
wsproxyClient := wsproxysdk.New(primaryAccessURL, token)
582581
hostname, err := cryptorand.String(6)
583582
require.NoError(t, err)
584583
replicaID := uuid.New()
@@ -879,8 +878,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) {
879878
require.Contains(t, respJSON.Warnings[0], "High availability networking")
880879

881880
// Deregister the other replica.
882-
wsproxyClient := wsproxysdk.New(api.AccessURL)
883-
wsproxyClient.SetSessionToken(proxy.Options.ProxySessionToken)
881+
wsproxyClient := wsproxysdk.New(api.AccessURL, proxy.Options.ProxySessionToken)
884882
err = wsproxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{
885883
ReplicaID: otherReplicaID,
886884
})

0 commit comments

Comments
 (0)