Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions cli/exp_scaletest.go
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
tickInterval time.Duration
bytesPerTick int64
ssh bool
disableDirect bool
useHostLogin bool
app string
template string
Expand Down Expand Up @@ -1023,15 +1024,16 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {

// Setup our workspace agent connection.
config := workspacetraffic.Config{
AgentID: agent.ID,
BytesPerTick: bytesPerTick,
Duration: strategy.timeout,
TickInterval: tickInterval,
ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agent.Name),
WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agent.Name),
SSH: ssh,
Echo: ssh,
App: appConfig,
AgentID: agent.ID,
BytesPerTick: bytesPerTick,
Duration: strategy.timeout,
TickInterval: tickInterval,
ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agent.Name),
WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agent.Name),
SSH: ssh,
DisableDirect: disableDirect,
Echo: ssh,
App: appConfig,
}

if webClient != nil {
Expand Down Expand Up @@ -1117,6 +1119,13 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
Description: "Send traffic over SSH, cannot be used with --app.",
Value: serpent.BoolOf(&ssh),
},
{
Flag: "disable-direct",
Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_DISABLE_DIRECT_CONNECTIONS",
Default: "false",
Description: "Disable direct connections for SSH traffic to workspaces. Does nothing if `--ssh` is not also set.",
Value: serpent.BoolOf(&disableDirect),
},
{
Flag: "app",
Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_APP",
Expand Down
3 changes: 1 addition & 2 deletions cli/exp_task_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,12 @@ STATE CHANGED STATUS STATE MESSAGE
ctx = testutil.Context(t, testutil.WaitShort)
now = time.Now().UTC() // TODO: replace with quartz
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now)))
client = new(codersdk.Client)
client = codersdk.New(testutil.MustURL(t, srv.URL))
sb = strings.Builder{}
args = []string{"exp", "task", "status", "--watch-interval", testutil.IntervalFast.String()}
)

t.Cleanup(srv.Close)
client.URL = testutil.MustURL(t, srv.URL)
args = append(args, tc.args...)
inv, root := clitest.New(t, args...)
inv.Stdout = &sb
Expand Down
7 changes: 1 addition & 6 deletions cli/exp_taskcreate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/cli/cliui"
Expand Down Expand Up @@ -236,17 +234,14 @@ func TestTaskCreate(t *testing.T) {
var (
ctx = testutil.Context(t, testutil.WaitShort)
srv = httptest.NewServer(tt.handler(t, ctx))
client = new(codersdk.Client)
client = codersdk.New(testutil.MustURL(t, srv.URL))
args = []string{"exp", "task", "create"}
sb strings.Builder
err error
)

t.Cleanup(srv.Close)

client.URL, err = url.Parse(srv.URL)
require.NoError(t, err)

inv, root := clitest.New(t, append(args, tt.args...)...)
inv.Environ = serpent.ParseEnviron(tt.env, "")
inv.Stdout = &sb
Expand Down
3 changes: 3 additions & 0 deletions cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,9 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod
}

func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error {
if client.SessionTokenProvider == nil {
client.SessionTokenProvider = codersdk.FixedSessionTokenProvider{}
}
transport := http.DefaultTransport
transport = wrapTransportWithTelemetryHeader(transport, inv)
if !r.noVersionCheck {
Expand Down
8 changes: 2 additions & 6 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken

// ExternalLogin does the oauth2 flow for external auth providers. This requires
// an authenticated coder client.
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) {
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...codersdk.RequestOption) {
coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID))
require.NoError(t, err)
f.SetRedirect(t, coderOauthURL.String())
Expand All @@ -660,11 +660,7 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil)
require.NoError(t, err)
// External auth flow requires the user be authenticated.
headerName := client.SessionTokenHeader
if headerName == "" {
headerName = codersdk.SessionTokenHeader
}
req.Header.Set(headerName, client.SessionToken())
opts = append([]codersdk.RequestOption{client.SessionTokenProvider.AsRequestOption()}, opts...)
if cli.Jar == nil {
cli.Jar, err = cookiejar.New(nil)
require.NoError(t, err, "failed to create cookie jar")
Expand Down
2 changes: 1 addition & 1 deletion coderd/mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestMCPHTTP_ToolRegistration(t *testing.T) {
require.Contains(t, err.Error(), "client cannot be nil", "Should reject nil client with appropriate error message")

// Test registering tools with valid client should succeed
client := &codersdk.Client{}
client := codersdk.New(testutil.MustURL(t, "http://not-used"))
err = server.RegisterTools(client)
require.NoError(t, err)

Expand Down
55 changes: 25 additions & 30 deletions codersdk/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ var loggableMimeTypes = map[string]struct{}{
// New creates a Coder client for the provided URL.
func New(serverURL *url.URL) *Client {
return &Client{
URL: serverURL,
HTTPClient: &http.Client{},
URL: serverURL,
HTTPClient: &http.Client{},
SessionTokenProvider: FixedSessionTokenProvider{},
}
}

Expand All @@ -118,18 +119,14 @@ func New(serverURL *url.URL) *Client {
type Client struct {
// mu protects the fields sessionToken, logger, and logBodies. These
// need to be safe for concurrent access.
mu sync.RWMutex
sessionToken string
logger slog.Logger
logBodies bool
mu sync.RWMutex
SessionTokenProvider SessionTokenProvider
logger slog.Logger
logBodies bool

HTTPClient *http.Client
URL *url.URL

// SessionTokenHeader is an optional custom header to use for setting tokens. By
// default 'Coder-Session-Token' is used.
SessionTokenHeader string

// PlainLogger may be set to log HTTP traffic in a human-readable form.
// It uses the LogBodies option.
PlainLogger io.Writer
Expand Down Expand Up @@ -176,14 +173,20 @@ func (c *Client) SetLogBodies(logBodies bool) {
func (c *Client) SessionToken() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.sessionToken
return c.SessionTokenProvider.GetSessionToken()
}

// SetSessionToken returns the currently set token for the client.
// SetSessionToken sets a fixed token for the client.
// Deprecated: Create a new client instead of changing the token after creation.
func (c *Client) SetSessionToken(token string) {
c.SetSessionTokenProvider(FixedSessionTokenProvider{SessionToken: token})
}

// SetSessionTokenProvider sets the session token provider for the client.
func (c *Client) SetSessionTokenProvider(provider SessionTokenProvider) {
c.mu.Lock()
defer c.mu.Unlock()
c.sessionToken = token
c.SessionTokenProvider = provider
}

func prefixLines(prefix, s []byte) []byte {
Expand All @@ -199,6 +202,14 @@ func prefixLines(prefix, s []byte) []byte {
// Request performs a HTTP request with the body provided. The caller is
// responsible for closing the response body.
func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
opts = append([]RequestOption{c.SessionTokenProvider.AsRequestOption()}, opts...)
return c.RequestWithoutSessionToken(ctx, method, path, body, opts...)
}

// RequestWithoutSessionToken performs a HTTP request. It is similar to Request, but does not set
// the session token in the request header, nor does it make a call to the SessionTokenProvider.
// This allows session token providers to call this method without causing reentrancy issues.
func (c *Client) RequestWithoutSessionToken(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
if ctx == nil {
return nil, xerrors.Errorf("context should not be nil")
}
Expand Down Expand Up @@ -248,12 +259,6 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
return nil, xerrors.Errorf("create request: %w", err)
}

tokenHeader := c.SessionTokenHeader
if tokenHeader == "" {
tokenHeader = SessionTokenHeader
}
req.Header.Set(tokenHeader, c.SessionToken())

if r != nil {
req.Header.Set("Content-Type", "application/json")
}
Expand Down Expand Up @@ -345,20 +350,10 @@ func (c *Client) Dial(ctx context.Context, path string, opts *websocket.DialOpti
return nil, err
}

tokenHeader := c.SessionTokenHeader
if tokenHeader == "" {
tokenHeader = SessionTokenHeader
}

if opts == nil {
opts = &websocket.DialOptions{}
}
if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{}
}
if opts.HTTPHeader.Get(tokenHeader) == "" {
opts.HTTPHeader.Set(tokenHeader, c.SessionToken())
}
c.SessionTokenProvider.SetDialOption(opts)

conn, resp, err := websocket.Dial(ctx, u.String(), opts)
if resp != nil && resp.Body != nil {
Expand Down
55 changes: 55 additions & 0 deletions codersdk/credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package codersdk

import (
"net/http"

"github.com/coder/websocket"
)

// SessionTokenProvider provides the session token to access the Coder service (coderd).
// @typescript-ignore SessionTokenProvider
type SessionTokenProvider interface {
// AsRequestOption returns a request option that attaches the session token to an HTTP request.
AsRequestOption() RequestOption
// SetDialOption sets the session token on a websocket request via DialOptions
SetDialOption(options *websocket.DialOptions)
// GetSessionToken returns the session token as a string.
GetSessionToken() string
}

// FixedSessionTokenProvider provides a given, fixed, session token. E.g. one read from file or environment variable
// at the program start.
// @typescript-ignore FixedSessionTokenProvider
type FixedSessionTokenProvider struct {
SessionToken string
// SessionTokenHeader is an optional custom header to use for setting tokens. By
// default, 'Coder-Session-Token' is used.
SessionTokenHeader string
}

func (f FixedSessionTokenProvider) AsRequestOption() RequestOption {
return func(req *http.Request) {
tokenHeader := f.SessionTokenHeader
if tokenHeader == "" {
tokenHeader = SessionTokenHeader
}
req.Header.Set(tokenHeader, f.SessionToken)
}
}

func (f FixedSessionTokenProvider) GetSessionToken() string {
return f.SessionToken
}

func (f FixedSessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
tokenHeader := f.SessionTokenHeader
if tokenHeader == "" {
tokenHeader = SessionTokenHeader
}
if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{}
}
if opts.HTTPHeader.Get(tokenHeader) == "" {
opts.HTTPHeader.Set(tokenHeader, f.SessionToken)
}
}
17 changes: 6 additions & 11 deletions codersdk/workspacesdk/workspacesdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,12 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
options.BlockEndpoints = true
}

headers := make(http.Header)
tokenHeader := codersdk.SessionTokenHeader
if c.client.SessionTokenHeader != "" {
tokenHeader = c.client.SessionTokenHeader
wsOptions := &websocket.DialOptions{
HTTPClient: c.client.HTTPClient,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
}
headers.Set(tokenHeader, c.client.SessionToken())
c.client.SessionTokenProvider.SetDialOption(wsOptions)

// New context, separate from dialCtx. We don't want to cancel the
// connection if dialCtx is canceled.
Expand All @@ -236,12 +236,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
return nil, xerrors.Errorf("parse url: %w", err)
}

dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
HTTPClient: c.client.HTTPClient,
HTTPHeader: headers,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
dialer := NewWebsocketDialer(options.Logger, coordinateURL, wsOptions)
clk := quartz.NewReal()
controller := tailnet.NewController(options.Logger, dialer)
controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk)
Expand Down
Loading
Loading