diff --git a/cli/exp_scaletest.go b/cli/exp_scaletest.go index a844a7e8c6258..4580ff3e1bc8a 100644 --- a/cli/exp_scaletest.go +++ b/cli/exp_scaletest.go @@ -864,6 +864,7 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command { tickInterval time.Duration bytesPerTick int64 ssh bool + disableDirect bool useHostLogin bool app string template string @@ -1023,15 +1024,16 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command { // Setup our workspace agent connection. config := workspacetraffic.Config{ - AgentID: agent.ID, - BytesPerTick: bytesPerTick, - Duration: strategy.timeout, - TickInterval: tickInterval, - ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agent.Name), - WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agent.Name), - SSH: ssh, - Echo: ssh, - App: appConfig, + AgentID: agent.ID, + BytesPerTick: bytesPerTick, + Duration: strategy.timeout, + TickInterval: tickInterval, + ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agent.Name), + WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agent.Name), + SSH: ssh, + DisableDirect: disableDirect, + Echo: ssh, + App: appConfig, } if webClient != nil { @@ -1117,6 +1119,13 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command { Description: "Send traffic over SSH, cannot be used with --app.", Value: serpent.BoolOf(&ssh), }, + { + Flag: "disable-direct", + Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_DISABLE_DIRECT_CONNECTIONS", + Default: "false", + Description: "Disable direct connections for SSH traffic to workspaces. Does nothing if `--ssh` is not also set.", + Value: serpent.BoolOf(&disableDirect), + }, { Flag: "app", Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_APP", diff --git a/cli/exp_task_status_test.go b/cli/exp_task_status_test.go index 6631980ac1fbd..b520d2728804e 100644 --- a/cli/exp_task_status_test.go +++ b/cli/exp_task_status_test.go @@ -243,13 +243,12 @@ STATE CHANGED STATUS STATE MESSAGE ctx = testutil.Context(t, testutil.WaitShort) now = time.Now().UTC() // TODO: replace with quartz srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now))) - client = new(codersdk.Client) + client = codersdk.New(testutil.MustURL(t, srv.URL)) sb = strings.Builder{} args = []string{"exp", "task", "status", "--watch-interval", testutil.IntervalFast.String()} ) t.Cleanup(srv.Close) - client.URL = testutil.MustURL(t, srv.URL) args = append(args, tc.args...) inv, root := clitest.New(t, args...) inv.Stdout = &sb diff --git a/cli/exp_taskcreate_test.go b/cli/exp_taskcreate_test.go index 121f22eb525f6..f49c2fee1194a 100644 --- a/cli/exp_taskcreate_test.go +++ b/cli/exp_taskcreate_test.go @@ -5,14 +5,12 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" "strings" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" @@ -236,7 +234,7 @@ func TestTaskCreate(t *testing.T) { var ( ctx = testutil.Context(t, testutil.WaitShort) srv = httptest.NewServer(tt.handler(t, ctx)) - client = new(codersdk.Client) + client = codersdk.New(testutil.MustURL(t, srv.URL)) args = []string{"exp", "task", "create"} sb strings.Builder err error @@ -244,9 +242,6 @@ func TestTaskCreate(t *testing.T) { t.Cleanup(srv.Close) - client.URL, err = url.Parse(srv.URL) - require.NoError(t, err) - inv, root := clitest.New(t, append(args, tt.args...)...) inv.Environ = serpent.ParseEnviron(tt.env, "") inv.Stdout = &sb diff --git a/cli/root.go b/cli/root.go index b3e67a46ad463..ed6869b6a1c49 100644 --- a/cli/root.go +++ b/cli/root.go @@ -635,6 +635,9 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod } func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error { + if client.SessionTokenProvider == nil { + client.SessionTokenProvider = codersdk.FixedSessionTokenProvider{} + } transport := http.DefaultTransport transport = wrapTransportWithTelemetryHeader(transport, inv) if !r.noVersionCheck { diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index c7f7d35937198..a76f6447dcabd 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -641,7 +641,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken // ExternalLogin does the oauth2 flow for external auth providers. This requires // an authenticated coder client. -func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) { +func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...codersdk.RequestOption) { coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID)) require.NoError(t, err) f.SetRedirect(t, coderOauthURL.String()) @@ -660,11 +660,7 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil) require.NoError(t, err) // External auth flow requires the user be authenticated. - headerName := client.SessionTokenHeader - if headerName == "" { - headerName = codersdk.SessionTokenHeader - } - req.Header.Set(headerName, client.SessionToken()) + opts = append([]codersdk.RequestOption{client.SessionTokenProvider.AsRequestOption()}, opts...) if cli.Jar == nil { cli.Jar, err = cookiejar.New(nil) require.NoError(t, err, "failed to create cookie jar") diff --git a/coderd/mcp/mcp_test.go b/coderd/mcp/mcp_test.go index 0c53c899b9830..b7b5a714780d9 100644 --- a/coderd/mcp/mcp_test.go +++ b/coderd/mcp/mcp_test.go @@ -115,7 +115,7 @@ func TestMCPHTTP_ToolRegistration(t *testing.T) { require.Contains(t, err.Error(), "client cannot be nil", "Should reject nil client with appropriate error message") // Test registering tools with valid client should succeed - client := &codersdk.Client{} + client := codersdk.New(testutil.MustURL(t, "http://not-used")) err = server.RegisterTools(client) require.NoError(t, err) diff --git a/codersdk/client.go b/codersdk/client.go index 105c8437f841b..b6f10465e3a07 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -108,8 +108,9 @@ var loggableMimeTypes = map[string]struct{}{ // New creates a Coder client for the provided URL. func New(serverURL *url.URL) *Client { return &Client{ - URL: serverURL, - HTTPClient: &http.Client{}, + URL: serverURL, + HTTPClient: &http.Client{}, + SessionTokenProvider: FixedSessionTokenProvider{}, } } @@ -118,18 +119,14 @@ func New(serverURL *url.URL) *Client { type Client struct { // mu protects the fields sessionToken, logger, and logBodies. These // need to be safe for concurrent access. - mu sync.RWMutex - sessionToken string - logger slog.Logger - logBodies bool + mu sync.RWMutex + SessionTokenProvider SessionTokenProvider + logger slog.Logger + logBodies bool HTTPClient *http.Client URL *url.URL - // SessionTokenHeader is an optional custom header to use for setting tokens. By - // default 'Coder-Session-Token' is used. - SessionTokenHeader string - // PlainLogger may be set to log HTTP traffic in a human-readable form. // It uses the LogBodies option. PlainLogger io.Writer @@ -176,14 +173,20 @@ func (c *Client) SetLogBodies(logBodies bool) { func (c *Client) SessionToken() string { c.mu.RLock() defer c.mu.RUnlock() - return c.sessionToken + return c.SessionTokenProvider.GetSessionToken() } -// SetSessionToken returns the currently set token for the client. +// SetSessionToken sets a fixed token for the client. +// Deprecated: Create a new client instead of changing the token after creation. func (c *Client) SetSessionToken(token string) { + c.SetSessionTokenProvider(FixedSessionTokenProvider{SessionToken: token}) +} + +// SetSessionTokenProvider sets the session token provider for the client. +func (c *Client) SetSessionTokenProvider(provider SessionTokenProvider) { c.mu.Lock() defer c.mu.Unlock() - c.sessionToken = token + c.SessionTokenProvider = provider } func prefixLines(prefix, s []byte) []byte { @@ -199,6 +202,14 @@ func prefixLines(prefix, s []byte) []byte { // Request performs a HTTP request with the body provided. The caller is // responsible for closing the response body. func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) { + opts = append([]RequestOption{c.SessionTokenProvider.AsRequestOption()}, opts...) + return c.RequestWithoutSessionToken(ctx, method, path, body, opts...) +} + +// RequestWithoutSessionToken performs a HTTP request. It is similar to Request, but does not set +// the session token in the request header, nor does it make a call to the SessionTokenProvider. +// This allows session token providers to call this method without causing reentrancy issues. +func (c *Client) RequestWithoutSessionToken(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) { if ctx == nil { return nil, xerrors.Errorf("context should not be nil") } @@ -248,12 +259,6 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac return nil, xerrors.Errorf("create request: %w", err) } - tokenHeader := c.SessionTokenHeader - if tokenHeader == "" { - tokenHeader = SessionTokenHeader - } - req.Header.Set(tokenHeader, c.SessionToken()) - if r != nil { req.Header.Set("Content-Type", "application/json") } @@ -345,20 +350,10 @@ func (c *Client) Dial(ctx context.Context, path string, opts *websocket.DialOpti return nil, err } - tokenHeader := c.SessionTokenHeader - if tokenHeader == "" { - tokenHeader = SessionTokenHeader - } - if opts == nil { opts = &websocket.DialOptions{} } - if opts.HTTPHeader == nil { - opts.HTTPHeader = http.Header{} - } - if opts.HTTPHeader.Get(tokenHeader) == "" { - opts.HTTPHeader.Set(tokenHeader, c.SessionToken()) - } + c.SessionTokenProvider.SetDialOption(opts) conn, resp, err := websocket.Dial(ctx, u.String(), opts) if resp != nil && resp.Body != nil { diff --git a/codersdk/credentials.go b/codersdk/credentials.go new file mode 100644 index 0000000000000..06dc8cc22a114 --- /dev/null +++ b/codersdk/credentials.go @@ -0,0 +1,55 @@ +package codersdk + +import ( + "net/http" + + "github.com/coder/websocket" +) + +// SessionTokenProvider provides the session token to access the Coder service (coderd). +// @typescript-ignore SessionTokenProvider +type SessionTokenProvider interface { + // AsRequestOption returns a request option that attaches the session token to an HTTP request. + AsRequestOption() RequestOption + // SetDialOption sets the session token on a websocket request via DialOptions + SetDialOption(options *websocket.DialOptions) + // GetSessionToken returns the session token as a string. + GetSessionToken() string +} + +// FixedSessionTokenProvider provides a given, fixed, session token. E.g. one read from file or environment variable +// at the program start. +// @typescript-ignore FixedSessionTokenProvider +type FixedSessionTokenProvider struct { + SessionToken string + // SessionTokenHeader is an optional custom header to use for setting tokens. By + // default, 'Coder-Session-Token' is used. + SessionTokenHeader string +} + +func (f FixedSessionTokenProvider) AsRequestOption() RequestOption { + return func(req *http.Request) { + tokenHeader := f.SessionTokenHeader + if tokenHeader == "" { + tokenHeader = SessionTokenHeader + } + req.Header.Set(tokenHeader, f.SessionToken) + } +} + +func (f FixedSessionTokenProvider) GetSessionToken() string { + return f.SessionToken +} + +func (f FixedSessionTokenProvider) SetDialOption(opts *websocket.DialOptions) { + tokenHeader := f.SessionTokenHeader + if tokenHeader == "" { + tokenHeader = SessionTokenHeader + } + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + if opts.HTTPHeader.Get(tokenHeader) == "" { + opts.HTTPHeader.Set(tokenHeader, f.SessionToken) + } +} diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index ddaec06388238..29ddbd1f53094 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -215,12 +215,12 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * options.BlockEndpoints = true } - headers := make(http.Header) - tokenHeader := codersdk.SessionTokenHeader - if c.client.SessionTokenHeader != "" { - tokenHeader = c.client.SessionTokenHeader + wsOptions := &websocket.DialOptions{ + HTTPClient: c.client.HTTPClient, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, } - headers.Set(tokenHeader, c.client.SessionToken()) + c.client.SessionTokenProvider.SetDialOption(wsOptions) // New context, separate from dialCtx. We don't want to cancel the // connection if dialCtx is canceled. @@ -236,12 +236,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * return nil, xerrors.Errorf("parse url: %w", err) } - dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{ - HTTPClient: c.client.HTTPClient, - HTTPHeader: headers, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) + dialer := NewWebsocketDialer(options.Logger, coordinateURL, wsOptions) clk := quartz.NewReal() controller := tailnet.NewController(options.Logger, dialer) controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk) diff --git a/enterprise/coderd/workspaceproxy_test.go b/enterprise/coderd/workspaceproxy_test.go index 7024ad2366423..28d46c0137b0d 100644 --- a/enterprise/coderd/workspaceproxy_test.go +++ b/enterprise/coderd/workspaceproxy_test.go @@ -312,8 +312,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) // Register req := wsproxysdk.RegisterWorkspaceProxyRequest{ @@ -427,8 +426,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) req := wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: "https://proxy.coder.test", @@ -472,8 +470,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) err = proxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{ ReplicaID: uuid.New(), @@ -501,8 +498,7 @@ func TestProxyRegisterDeregister(t *testing.T) { // Register a replica on proxy 2. This shouldn't be returned by replicas // for proxy 1. - proxyClient2 := wsproxysdk.New(client.URL) - proxyClient2.SetSessionToken(createRes2.ProxyToken) + proxyClient2 := wsproxysdk.New(client.URL, createRes2.ProxyToken) _, err = proxyClient2.RegisterWorkspaceProxy(ctx, wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: "https://other.proxy.coder.test", WildcardHostname: "*.other.proxy.coder.test", @@ -516,8 +512,7 @@ func TestProxyRegisterDeregister(t *testing.T) { require.NoError(t, err) // Register replica 1. - proxyClient1 := wsproxysdk.New(client.URL) - proxyClient1.SetSessionToken(createRes1.ProxyToken) + proxyClient1 := wsproxysdk.New(client.URL, createRes1.ProxyToken) req1 := wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: "https://one.proxy.coder.test", WildcardHostname: "*.one.proxy.coder.test", @@ -574,8 +569,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) for i := 0; i < 100; i++ { ok := false @@ -652,8 +646,7 @@ func TestIssueSignedAppToken(t *testing.T) { t.Run("BadAppRequest", func(t *testing.T) { t.Parallel() - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(proxyRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken) ctx := testutil.Context(t, testutil.WaitLong) _, err := proxyClient.IssueSignedAppToken(ctx, workspaceapps.IssueTokenRequest{ @@ -674,8 +667,7 @@ func TestIssueSignedAppToken(t *testing.T) { } t.Run("OK", func(t *testing.T) { t.Parallel() - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(proxyRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken) ctx := testutil.Context(t, testutil.WaitLong) _, err := proxyClient.IssueSignedAppToken(ctx, goodRequest) @@ -684,8 +676,7 @@ func TestIssueSignedAppToken(t *testing.T) { t.Run("OKHTML", func(t *testing.T) { t.Parallel() - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(proxyRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken) rw := httptest.NewRecorder() ctx := testutil.Context(t, testutil.WaitLong) @@ -1032,8 +1023,7 @@ func TestGetCryptoKeys(t *testing.T) { Name: testutil.GetRandomName(t), }) - client := wsproxysdk.New(cclient.URL) - client.SetSessionToken(cclient.SessionToken()) + client := wsproxysdk.New(cclient.URL, cclient.SessionToken()) _, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) require.Error(t, err) diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index c2ac1baf2db4e..6e1da2f25853d 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -163,11 +163,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { return nil, err } - client := wsproxysdk.New(opts.DashboardURL) - err := client.SetSessionToken(opts.ProxySessionToken) - if err != nil { - return nil, xerrors.Errorf("set client token: %w", err) - } + client := wsproxysdk.New(opts.DashboardURL, opts.ProxySessionToken) // Use the configured client if provided. if opts.HTTPClient != nil { diff --git a/enterprise/wsproxy/wsproxy_test.go b/enterprise/wsproxy/wsproxy_test.go index 523d429476243..0e8e61af88995 100644 --- a/enterprise/wsproxy/wsproxy_test.go +++ b/enterprise/wsproxy/wsproxy_test.go @@ -577,8 +577,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) { t.Cleanup(srv.Close) // Register a proxy. - wsproxyClient := wsproxysdk.New(primaryAccessURL) - wsproxyClient.SetSessionToken(token) + wsproxyClient := wsproxysdk.New(primaryAccessURL, token) hostname, err := cryptorand.String(6) require.NoError(t, err) replicaID := uuid.New() @@ -879,8 +878,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) { require.Contains(t, respJSON.Warnings[0], "High availability networking") // Deregister the other replica. - wsproxyClient := wsproxysdk.New(api.AccessURL) - wsproxyClient.SetSessionToken(proxy.Options.ProxySessionToken) + wsproxyClient := wsproxysdk.New(api.AccessURL, proxy.Options.ProxySessionToken) err = wsproxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{ ReplicaID: otherReplicaID, }) diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 72f5a4291c40e..15400a2d33c16 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -33,15 +33,20 @@ type Client struct { // New creates a external proxy client for the provided primary coder server // URL. -func New(serverURL *url.URL) *Client { +func New(serverURL *url.URL, sessionToken string) *Client { sdkClient := codersdk.New(serverURL) - sdkClient.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader - + sdkClient.SessionTokenProvider = codersdk.FixedSessionTokenProvider{ + SessionToken: sessionToken, + SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader, + } sdkClientIgnoreRedirects := codersdk.New(serverURL) sdkClientIgnoreRedirects.HTTPClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse } - sdkClientIgnoreRedirects.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader + sdkClientIgnoreRedirects.SessionTokenProvider = codersdk.FixedSessionTokenProvider{ + SessionToken: sessionToken, + SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader, + } return &Client{ SDKClient: sdkClient, @@ -49,14 +54,6 @@ func New(serverURL *url.URL) *Client { } } -// SetSessionToken sets the session token for the client. An error is returned -// if the session token is not in the correct format for external proxies. -func (c *Client) SetSessionToken(token string) error { - c.SDKClient.SetSessionToken(token) - c.sdkClientIgnoreRedirects.SetSessionToken(token) - return nil -} - // SessionToken returns the currently set token for the client. func (c *Client) SessionToken() string { return c.SDKClient.SessionToken() @@ -506,17 +503,12 @@ func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) { if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } - coordinateHeaders := make(http.Header) - tokenHeader := codersdk.SessionTokenHeader - if c.SDKClient.SessionTokenHeader != "" { - tokenHeader = c.SDKClient.SessionTokenHeader + wsOptions := &websocket.DialOptions{ + HTTPClient: c.SDKClient.HTTPClient, } - coordinateHeaders.Set(tokenHeader, c.SessionToken()) + c.SDKClient.SessionTokenProvider.SetDialOption(wsOptions) - return workspacesdk.NewWebsocketDialer(logger, coordinateURL, &websocket.DialOptions{ - HTTPClient: c.SDKClient.HTTPClient, - HTTPHeader: coordinateHeaders, - }), nil + return workspacesdk.NewWebsocketDialer(logger, coordinateURL, wsOptions), nil } type CryptoKeysResponse struct { diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go index aada23da9dc12..6b4da6831c9bf 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go @@ -60,8 +60,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) { u, err := url.Parse(srv.URL) require.NoError(t, err) - client := wsproxysdk.New(u) - client.SetSessionToken(expectedProxyToken) + client := wsproxysdk.New(u, expectedProxyToken) ctx := testutil.Context(t, testutil.WaitLong) @@ -111,8 +110,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) { u, err := url.Parse(srv.URL) require.NoError(t, err) - client := wsproxysdk.New(u) - _ = client.SetSessionToken(expectedProxyToken) + client := wsproxysdk.New(u, expectedProxyToken) ctx := testutil.Context(t, testutil.WaitLong) diff --git a/go.mod b/go.mod index f111e6e6260d7..dd8109b35bcf0 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ replace github.com/tcnksm/go-httpstat => github.com/coder/go-httpstat v0.0.0-202 // There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here: // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main -replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20250729141742-067f1e5d9716 +replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20250829055706-6eafe0f9199e // This is replaced to include // 1. a fix for a data race: c.f. https://github.com/tailscale/wireguard-go/pull/25 @@ -530,7 +530,7 @@ require ( github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/tmaxmax/go-sse v0.10.0 // indirect - github.com/ulikunitz/xz v0.5.12 // indirect + github.com/ulikunitz/xz v0.5.15 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.36.0 // indirect diff --git a/go.sum b/go.sum index ba73e2228f398..b0ec2563d5dbf 100644 --- a/go.sum +++ b/go.sum @@ -928,8 +928,8 @@ github.com/coder/serpent v0.10.0 h1:ofVk9FJXSek+SmL3yVE3GoArP83M+1tX+H7S4t8BSuM= github.com/coder/serpent v0.10.0/go.mod h1:cZFW6/fP+kE9nd/oRkEHJpG6sXCtQ+AX7WMMEHv0Y3Q= github.com/coder/ssh v0.0.0-20231128192721-70855dedb788 h1:YoUSJ19E8AtuUFVYBpXuOD6a/zVP3rcxezNsoDseTUw= github.com/coder/ssh v0.0.0-20231128192721-70855dedb788/go.mod h1:aGQbuCLyhRLMzZF067xc84Lh7JDs1FKwCmF1Crl9dxQ= -github.com/coder/tailscale v1.1.1-0.20250729141742-067f1e5d9716 h1:hi7o0sA+RPBq8Rvvz+hNrC/OTL2897OKREMIRIuQeTs= -github.com/coder/tailscale v1.1.1-0.20250729141742-067f1e5d9716/go.mod h1:l7ml5uu7lFh5hY28lGYM4b/oFSmuPHYX6uk4RAu23Lc= +github.com/coder/tailscale v1.1.1-0.20250829055706-6eafe0f9199e h1:9RKGKzGLHtTvVBQublzDGtCtal3cXP13diCHoAIGPeI= +github.com/coder/tailscale v1.1.1-0.20250829055706-6eafe0f9199e/go.mod h1:jU9T1vEs+DOs8NtGp1F2PT0/TOGVwtg/JCCKYRgvMOs= github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e h1:JNLPDi2P73laR1oAclY6jWzAbucf70ASAvf5mh2cME0= github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e/go.mod h1:Gz/z9Hbn+4KSp8A2FBtNszfLSdT2Tn/uAKGuVqqWmDI= github.com/coder/terraform-provider-coder/v2 v2.10.0 h1:cGPMfARGHKb80kZsbDX/t/YKwMOwI5zkIyVCQziHR2M= @@ -1828,8 +1828,8 @@ github.com/u-root/u-root v0.14.0/go.mod h1:hAyZorapJe4qzbLWlAkmSVCJGbfoU9Pu4jpJ1 github.com/u-root/uio v0.0.0-20240209044354-b3d14b93376a h1:BH1SOPEvehD2kVrndDnGJiUF0TrBpNs+iyYocu6h0og= github.com/u-root/uio v0.0.0-20240209044354-b3d14b93376a/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= -github.com/ulikunitz/xz v0.5.12 h1:37Nm15o69RwBkXM0J6A5OlE67RZTfzUxTj8fB3dfcsc= -github.com/ulikunitz/xz v0.5.12/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= +github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU= github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= diff --git a/scaletest/workspacetraffic/config.go b/scaletest/workspacetraffic/config.go index 6ef0760ff3013..0948d35ea7dbb 100644 --- a/scaletest/workspacetraffic/config.go +++ b/scaletest/workspacetraffic/config.go @@ -28,6 +28,9 @@ type Config struct { SSH bool `json:"ssh"` + // Ignored unless SSH is true. + DisableDirect bool `json:"ssh_disable_direct"` + // Echo controls whether the agent should echo the data it receives. // If false, the agent will discard the data. Note that setting this // to true will double the amount of data read from the agent for diff --git a/scaletest/workspacetraffic/conn.go b/scaletest/workspacetraffic/conn.go index 7640203e6c224..3b516c6347225 100644 --- a/scaletest/workspacetraffic/conn.go +++ b/scaletest/workspacetraffic/conn.go @@ -6,7 +6,6 @@ import ( "errors" "io" "net" - "net/http" "sync" "time" @@ -144,7 +143,7 @@ func (c *rptyConn) Close() (err error) { } //nolint:revive // Ignore requestPTY control flag. -func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, cmd string, requestPTY bool) (rwc *countReadWriteCloser, err error) { +func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, cmd string, requestPTY bool, blockEndpoints bool) (rwc *countReadWriteCloser, err error) { var closers []func() error defer func() { if err != nil { @@ -156,7 +155,9 @@ func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, } }() - agentConn, err := workspacesdk.New(client).DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{}) + agentConn, err := workspacesdk.New(client).DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{ + BlockEndpoints: blockEndpoints, + }) if err != nil { return nil, xerrors.Errorf("dial workspace agent: %w", err) } @@ -267,18 +268,13 @@ func (w *wrappedSSHConn) Write(p []byte) (n int, err error) { } func appClientConn(ctx context.Context, client *codersdk.Client, url string) (*countReadWriteCloser, error) { - headers := http.Header{} - tokenHeader := codersdk.SessionTokenHeader - if client.SessionTokenHeader != "" { - tokenHeader = client.SessionTokenHeader + wsOptions := &websocket.DialOptions{ + HTTPClient: client.HTTPClient, } - headers.Set(tokenHeader, client.SessionToken()) + client.SessionTokenProvider.SetDialOption(wsOptions) //nolint:bodyclose // The websocket conn manages the body. - conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ - HTTPClient: client.HTTPClient, - HTTPHeader: headers, - }) + conn, _, err := websocket.Dial(ctx, url, wsOptions) if err != nil { return nil, xerrors.Errorf("websocket dial: %w", err) } diff --git a/scaletest/workspacetraffic/run.go b/scaletest/workspacetraffic/run.go index cad6a9d51c6ce..7dd7cb6803695 100644 --- a/scaletest/workspacetraffic/run.go +++ b/scaletest/workspacetraffic/run.go @@ -111,7 +111,7 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error) // If echo is enabled, disable PTY to avoid double echo and // reduce CPU usage. requestPTY := !r.cfg.Echo - conn, err = connectSSH(ctx, r.client, agentID, command, requestPTY) + conn, err = connectSSH(ctx, r.client, agentID, command, requestPTY, r.cfg.DisableDirect) if err != nil { logger.Error(ctx, "connect to workspace agent via ssh", slog.Error(err)) return xerrors.Errorf("connect to workspace via ssh: %w", err) diff --git a/scaletest/workspacetraffic/run_test.go b/scaletest/workspacetraffic/run_test.go index 59801e68d8f62..dd84747886456 100644 --- a/scaletest/workspacetraffic/run_test.go +++ b/scaletest/workspacetraffic/run_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "runtime" "slices" "strings" @@ -313,9 +314,7 @@ func TestRun(t *testing.T) { readMetrics = &testMetrics{} writeMetrics = &testMetrics{} ) - client := &codersdk.Client{ - HTTPClient: &http.Client{}, - } + client := codersdk.New(&url.URL{}) runner := workspacetraffic.NewRunner(client, workspacetraffic.Config{ BytesPerTick: int64(bytesPerTick), TickInterval: tickInterval,