From 486da4315fa3183c0dcbd3431c3a12593c2b1083 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 27 Aug 2025 13:30:02 +0400 Subject: [PATCH 1/5] chore: refactor instance identity to be a SessionTokenProvider --- agent/agent.go | 18 +- agent/agent_test.go | 31 +-- agent/agenttest/agent.go | 10 +- agent/agenttest/client.go | 16 ++ cli/agent.go | 92 +------- cli/exp_mcp.go | 4 +- cli/externalauth.go | 2 +- cli/gitaskpass.go | 2 +- cli/gitaskpass_test.go | 8 +- cli/gitssh.go | 2 +- cli/gitssh_test.go | 3 +- cli/root.go | 116 +++++++--- cli/testdata/coder_agent_--help.golden | 3 - coderd/externalauth_test.go | 18 +- coderd/gitsshkey_test.go | 6 +- coderd/insights_test.go | 6 +- .../insights/metricscollector_test.go | 3 +- .../prometheusmetrics_test.go | 3 +- coderd/workspaceagents_test.go | 48 ++--- coderd/workspaceagentsrpc_test.go | 6 +- coderd/workspaceapps/apptest/setup.go | 3 +- coderd/workspaceresourceauth_test.go | 34 ++- codersdk/agentsdk/agentsdk.go | 204 +++++++----------- codersdk/agentsdk/agentsdk_test.go | 2 +- codersdk/agentsdk/aws.go | 97 +++++++++ codersdk/agentsdk/azure.go | 60 ++++++ codersdk/agentsdk/google.go | 71 ++++++ codersdk/toolsdk/toolsdk_test.go | 3 +- enterprise/coderd/appearance_test.go | 6 +- enterprise/coderd/gitsshkey_test.go | 3 +- enterprise/coderd/workspaceagents_test.go | 3 +- scaletest/createworkspaces/run_test.go | 3 +- scaletest/workspacebuild/run_test.go | 3 +- 33 files changed, 488 insertions(+), 401 deletions(-) create mode 100644 codersdk/agentsdk/aws.go create mode 100644 codersdk/agentsdk/azure.go create mode 100644 codersdk/agentsdk/google.go diff --git a/agent/agent.go b/agent/agent.go index e4d7ab60e076b..aed6652de612c 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -74,7 +74,6 @@ type Options struct { LogDir string TempDir string ScriptDataDir string - ExchangeToken func(ctx context.Context) (string, error) Client Client ReconnectingPTYTimeout time.Duration EnvironmentVariables map[string]string @@ -99,6 +98,7 @@ type Client interface { proto.DRPCAgentClient26, tailnetproto.DRPCTailnetClient26, error, ) tailnet.DERPMapRewriter + agentsdk.RefreshableSessionTokenProvider } type Agent interface { @@ -131,11 +131,6 @@ func New(options Options) Agent { } options.ScriptDataDir = options.TempDir } - if options.ExchangeToken == nil { - options.ExchangeToken = func(_ context.Context) (string, error) { - return "", nil - } - } if options.ReportMetadataInterval == 0 { options.ReportMetadataInterval = time.Second } @@ -172,7 +167,6 @@ func New(options Options) Agent { coordDisconnected: make(chan struct{}), environmentVariables: options.EnvironmentVariables, client: options.Client, - exchangeToken: options.ExchangeToken, filesystem: options.Filesystem, logDir: options.LogDir, tempDir: options.TempDir, @@ -203,7 +197,6 @@ func New(options Options) Agent { // coordinator during shut down. close(a.coordDisconnected) a.announcementBanners.Store(new([]codersdk.BannerConfig)) - a.sessionToken.Store(new(string)) a.init() return a } @@ -212,7 +205,6 @@ type agent struct { clock quartz.Clock logger slog.Logger client Client - exchangeToken func(ctx context.Context) (string, error) tailnetListenPort uint16 filesystem afero.Fs logDir string @@ -254,7 +246,6 @@ type agent struct { scriptRunner *agentscripts.Runner announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated. announcementBannersRefreshInterval time.Duration - sessionToken atomic.Pointer[string] sshServer *agentssh.Server sshMaxTimeout time.Duration blockFileTransfer bool @@ -916,11 +907,10 @@ func (a *agent) run() (retErr error) { // This allows the agent to refresh its token if necessary. // For instance identity this is required, since the instance // may not have re-provisioned, but a new agent ID was created. - sessionToken, err := a.exchangeToken(a.hardCtx) + err := a.client.RefreshToken(a.hardCtx) if err != nil { - return xerrors.Errorf("exchange token: %w", err) + return xerrors.Errorf("refresh token: %w", err) } - a.sessionToken.Store(&sessionToken) // ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs aAPI, tAPI, err := a.client.ConnectRPC26(a.hardCtx) @@ -1359,7 +1349,7 @@ func (a *agent) updateCommandEnv(current []string) (updated []string, err error) "CODER_WORKSPACE_OWNER_NAME": manifest.OwnerName, // Specific Coder subcommands require the agent token exposed! - "CODER_AGENT_TOKEN": *a.sessionToken.Load(), + "CODER_AGENT_TOKEN": a.client.GetSessionToken(), // Git on Windows resolves with UNIX-style paths. // If using backslashes, it's unable to find the executable. diff --git a/agent/agent_test.go b/agent/agent_test.go index d80f5d1982b74..72219cbe16fa2 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -22,7 +22,6 @@ import ( "slices" "strconv" "strings" - "sync/atomic" "testing" "time" @@ -2926,11 +2925,11 @@ func TestAgent_Speedtest(t *testing.T) { func TestAgent_Reconnect(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) logger := testutil.Logger(t) // After the agent is disconnected from a coordinator, it's supposed // to reconnect! - coordinator := tailnet.NewCoordinator(logger) - defer coordinator.Close() + fCoordinator := tailnettest.NewFakeCoordinator() agentID := uuid.New() statsCh := make(chan *proto.Stats, 50) @@ -2942,27 +2941,21 @@ func TestAgent_Reconnect(t *testing.T) { DERPMap: derpMap, }, statsCh, - coordinator, + fCoordinator, ) defer client.Close() - initialized := atomic.Int32{} + closer := agent.New(agent.Options{ - ExchangeToken: func(ctx context.Context) (string, error) { - initialized.Add(1) - return "", nil - }, Client: client, Logger: logger.Named("agent"), }) defer closer.Close() - require.Eventually(t, func() bool { - return coordinator.Node(agentID) != nil - }, testutil.WaitShort, testutil.IntervalFast) - client.LastWorkspaceAgent() - require.Eventually(t, func() bool { - return initialized.Load() == 2 - }, testutil.WaitShort, testutil.IntervalFast) + call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls) + close(call1.Resps) // hang up + // expect reconnect + testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls) + closer.Close() } func TestAgent_WriteVSCodeConfigs(t *testing.T) { @@ -2984,9 +2977,6 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { defer client.Close() filesystem := afero.NewMemMapFs() closer := agent.New(agent.Options{ - ExchangeToken: func(ctx context.Context) (string, error) { - return "", nil - }, Client: client, Logger: logger.Named("agent"), Filesystem: filesystem, @@ -3015,9 +3005,6 @@ func TestAgent_DebugServer(t *testing.T) { conn, _, _, _, agnt := setupAgent(t, agentsdk.Manifest{ DERPMap: derpMap, }, 0, func(c *agenttest.Client, o *agent.Options) { - o.ExchangeToken = func(context.Context) (string, error) { - return "token", nil - } o.LogDir = logDir }) diff --git a/agent/agenttest/agent.go b/agent/agenttest/agent.go index d25170dfc2183..8a2a6260c291c 100644 --- a/agent/agenttest/agent.go +++ b/agent/agenttest/agent.go @@ -1,7 +1,6 @@ package agenttest import ( - "context" "net/url" "testing" @@ -31,18 +30,11 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent } if o.Client == nil { - agentClient := agentsdk.New(coderURL) - agentClient.SetSessionToken(agentToken) + agentClient := agentsdk.New(coderURL, agentsdk.UsingFixedToken(agentToken)) agentClient.SDK.SetLogger(log) o.Client = agentClient } - if o.ExchangeToken == nil { - o.ExchangeToken = func(_ context.Context) (string, error) { - return agentToken, nil - } - } - if o.LogDir == "" { o.LogDir = t.TempDir() } diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index 5d78dfe697c93..3e9f025f18a7e 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -3,6 +3,7 @@ package agenttest import ( "context" "io" + "net/http" "slices" "sync" "sync/atomic" @@ -28,6 +29,7 @@ import ( "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" ) const statsInterval = 500 * time.Millisecond @@ -92,6 +94,20 @@ type Client struct { derpMapOnce sync.Once } +func (*Client) AsRequestOption() codersdk.RequestOption { + return func(_ *http.Request) {} +} + +func (*Client) SetDialOption(*websocket.DialOptions) {} + +func (*Client) GetSessionToken() string { + return "agenttest-token" +} + +func (*Client) RefreshToken(context.Context) error { + return nil +} + func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {} func (c *Client) Close() { diff --git a/cli/agent.go b/cli/agent.go index c192d4429ccaf..99caf773de8b4 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "cloud.google.com/go/compute/metadata" "golang.org/x/xerrors" "gopkg.in/natefinch/lumberjack.v2" @@ -40,7 +39,6 @@ import ( func (r *RootCmd) workspaceAgent() *serpent.Command { var ( - auth string logDir string scriptDataDir string pprofAddress string @@ -177,11 +175,13 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { version := buildinfo.Version() logger.Info(ctx, "agent is starting now", slog.F("url", r.agentURL), - slog.F("auth", auth), + slog.F("auth", r.agentAuth), slog.F("version", version), ) - - client := agentsdk.New(r.agentURL) + client, err := r.createAgentClient(ctx) + if err != nil { + return xerrors.Errorf("create agent client: %w", err) + } client.SDK.SetLogger(logger) // Set a reasonable timeout so requests can't hang forever! // The timeout needs to be reasonably long, because requests @@ -214,68 +214,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { ignorePorts[port] = "debug" } - // exchangeToken returns a session token. - // This is abstracted to allow for the same looping condition - // regardless of instance identity auth type. - var exchangeToken func(context.Context) (agentsdk.AuthenticateResponse, error) - switch auth { - case "token": - token, _ := inv.ParsedFlags().GetString(varAgentToken) - if token == "" { - tokenFile, _ := inv.ParsedFlags().GetString(varAgentTokenFile) - if tokenFile != "" { - tokenBytes, err := os.ReadFile(tokenFile) - if err != nil { - return xerrors.Errorf("read token file %q: %w", tokenFile, err) - } - token = strings.TrimSpace(string(tokenBytes)) - } - } - if token == "" { - return xerrors.Errorf("CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE must be set for token auth") - } - client.SetSessionToken(token) - case "google-instance-identity": - // This is *only* done for testing to mock client authentication. - // This will never be set in a production scenario. - var gcpClient *metadata.Client - gcpClientRaw := ctx.Value("gcp-client") - if gcpClientRaw != nil { - gcpClient, _ = gcpClientRaw.(*metadata.Client) - } - exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) { - return client.AuthGoogleInstanceIdentity(ctx, "", gcpClient) - } - case "aws-instance-identity": - // This is *only* done for testing to mock client authentication. - // This will never be set in a production scenario. - var awsClient *http.Client - awsClientRaw := ctx.Value("aws-client") - if awsClientRaw != nil { - awsClient, _ = awsClientRaw.(*http.Client) - if awsClient != nil { - client.SDK.HTTPClient = awsClient - } - } - exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) { - return client.AuthAWSInstanceIdentity(ctx) - } - case "azure-instance-identity": - // This is *only* done for testing to mock client authentication. - // This will never be set in a production scenario. - var azureClient *http.Client - azureClientRaw := ctx.Value("azure-client") - if azureClientRaw != nil { - azureClient, _ = azureClientRaw.(*http.Client) - if azureClient != nil { - client.SDK.HTTPClient = azureClient - } - } - exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) { - return client.AuthAzureInstanceIdentity(ctx) - } - } - executablePath, err := os.Executable() if err != nil { return xerrors.Errorf("getting os executable: %w", err) @@ -343,18 +281,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { LogDir: logDir, ScriptDataDir: scriptDataDir, // #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535) - TailnetListenPort: uint16(tailnetListenPort), - ExchangeToken: func(ctx context.Context) (string, error) { - if exchangeToken == nil { - return client.SDK.SessionToken(), nil - } - resp, err := exchangeToken(ctx) - if err != nil { - return "", err - } - client.SetSessionToken(resp.SessionToken) - return resp.SessionToken, nil - }, + TailnetListenPort: uint16(tailnetListenPort), EnvironmentVariables: environmentVariables, IgnorePorts: ignorePorts, SSHMaxTimeout: sshMaxTimeout, @@ -400,13 +327,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { } cmd.Options = serpent.OptionSet{ - { - Flag: "auth", - Default: "token", - Description: "Specify the authentication type to use for the agent.", - Env: "CODER_AGENT_AUTH", - Value: serpent.StringOf(&auth), - }, { Flag: "log-dir", Default: os.TempDir(), diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index d5ea26739085b..6ee4a362be4eb 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -148,7 +148,7 @@ func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command { binPath = testBinaryName } configureClaudeEnv := map[string]string{} - agentClient, err := r.createAgentClient() + agentClient, err := r.createAgentClient(inv.Context()) if err != nil { cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err) } else { @@ -494,7 +494,7 @@ func (r *RootCmd) mcpServer() *serpent.Command { } // Try to create an agent client for status reporting. Not validated. - agentClient, err := r.createAgentClient() + agentClient, err := r.createAgentClient(inv.Context()) if err == nil { cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String()) srv.agentClient = agentClient diff --git a/cli/externalauth.go b/cli/externalauth.go index 98bd853992da7..3910d6b01afd0 100644 --- a/cli/externalauth.go +++ b/cli/externalauth.go @@ -75,7 +75,7 @@ fi return xerrors.Errorf("agent token not found") } - client, err := r.tryCreateAgentClient() + client, err := r.createAgentClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } diff --git a/cli/gitaskpass.go b/cli/gitaskpass.go index e54d93478d8a8..f41b0e152b3e3 100644 --- a/cli/gitaskpass.go +++ b/cli/gitaskpass.go @@ -33,7 +33,7 @@ func (r *RootCmd) gitAskpass() *serpent.Command { return xerrors.Errorf("parse host: %w", err) } - client, err := r.tryCreateAgentClient() + client, err := r.createAgentClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } diff --git a/cli/gitaskpass_test.go b/cli/gitaskpass_test.go index 8e51411de9587..584e003427c4d 100644 --- a/cli/gitaskpass_test.go +++ b/cli/gitaskpass_test.go @@ -16,6 +16,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/pty/ptytest" + "github.com/coder/coder/v2/testutil" ) func TestGitAskpass(t *testing.T) { @@ -32,6 +33,7 @@ func TestGitAskpass(t *testing.T) { url := srv.URL inv, _ := clitest.New(t, "--agent-url", url, "Username for 'https://github.com':") inv.Environ.Set("GIT_PREFIX", "/") + inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token") pty := ptytest.New(t) inv.Stdout = pty.Output() clitest.Start(t, inv) @@ -39,6 +41,7 @@ func TestGitAskpass(t *testing.T) { inv, _ = clitest.New(t, "--agent-url", url, "Password for 'https://potato@github.com':") inv.Environ.Set("GIT_PREFIX", "/") + inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token") pty = ptytest.New(t) inv.Stdout = pty.Output() clitest.Start(t, inv) @@ -56,6 +59,7 @@ func TestGitAskpass(t *testing.T) { url := srv.URL inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") inv.Environ.Set("GIT_PREFIX", "/") + inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token") pty := ptytest.New(t) inv.Stderr = pty.Output() err := inv.Run() @@ -65,6 +69,7 @@ func TestGitAskpass(t *testing.T) { t.Run("Poll", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) resp := atomic.Pointer[agentsdk.ExternalAuthResponse]{} resp.Store(&agentsdk.ExternalAuthResponse{ URL: "https://something.org", @@ -86,6 +91,7 @@ func TestGitAskpass(t *testing.T) { inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") inv.Environ.Set("GIT_PREFIX", "/") + inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token") stdout := ptytest.New(t) inv.Stdout = stdout.Output() stderr := ptytest.New(t) @@ -94,7 +100,7 @@ func TestGitAskpass(t *testing.T) { err := inv.Run() assert.NoError(t, err) }() - <-poll + testutil.RequireReceive(ctx, t, poll) stderr.ExpectMatch("Open the following URL to authenticate") resp.Store(&agentsdk.ExternalAuthResponse{ Username: "username", diff --git a/cli/gitssh.go b/cli/gitssh.go index 566d3cc6f171f..59cc0299e1d22 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -38,7 +38,7 @@ func (r *RootCmd) gitssh() *serpent.Command { return err } - client, err := r.tryCreateAgentClient() + client, err := r.createAgentClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } diff --git a/cli/gitssh_test.go b/cli/gitssh_test.go index 6d574ae651aec..85f24a9ca1aab 100644 --- a/cli/gitssh_test.go +++ b/cli/gitssh_test.go @@ -54,8 +54,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*agentsdk.Client, str }).WithAgent().Do() // start workspace agent - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { o.Client = agentClient }) diff --git a/cli/root.go b/cli/root.go index ed6869b6a1c49..d42946819c127 100644 --- a/cli/root.go +++ b/cli/root.go @@ -24,6 +24,7 @@ import ( "text/tabwriter" "time" + "cloud.google.com/go/compute/metadata" "github.com/mattn/go-isatty" "github.com/mitchellh/go-wordwrap" "golang.org/x/mod/semver" @@ -62,6 +63,7 @@ const ( varAgentToken = "agent-token" varAgentTokenFile = "agent-token-file" varAgentURL = "agent-url" + varAgentAuth = "auth" varHeader = "header" varHeaderCommand = "header-command" varNoOpen = "no-open" @@ -82,6 +84,7 @@ const ( //nolint:gosec envAgentTokenFile = "CODER_AGENT_TOKEN_FILE" envAgentURL = "CODER_AGENT_URL" + envAgentAuth = "CODER_AGENT_AUTH" envURL = "CODER_URL" ) @@ -405,6 +408,15 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err Hidden: true, Group: globalGroup, }, + { + Flag: varAgentAuth, + Env: envAgentAuth, + Default: "token", + Description: "Specify the authentication type to use for the agent.", + Value: serpent.StringOf(&r.agentAuth), + Hidden: true, + Group: globalGroup, + }, { Flag: varNoVersionCheck, Env: envNoVersionCheck, @@ -502,20 +514,24 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err // RootCmd contains parameters and helpers useful to all commands. type RootCmd struct { - clientURL *url.URL - token string - globalConfig string - header []string - headerCommand string + clientURL *url.URL + token string + globalConfig string + header []string + headerCommand string + + // Agent Client config agentToken string agentTokenFile string agentURL *url.URL - forceTTY bool - noOpen bool - verbose bool - versionFlag bool - disableDirect bool - debugHTTP bool + agentAuth string + + forceTTY bool + noOpen bool + verbose bool + versionFlag bool + disableDirect bool + debugHTTP bool disableNetworkTelemetry bool noVersionCheck bool @@ -674,36 +690,68 @@ func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *ur // createAgentClient returns a new client from the command context. It works // just like InitClient, but uses the agent token and URL instead. -func (r *RootCmd) createAgentClient() (*agentsdk.Client, error) { +func (r *RootCmd) createAgentClient(ctx context.Context) (*agentsdk.Client, error) { agentURL := r.agentURL if agentURL == nil || agentURL.String() == "" { return nil, xerrors.Errorf("%s must be set", envAgentURL) } - token := r.agentToken - if token == "" { - if r.agentTokenFile == "" { - return nil, xerrors.Errorf("Either %s or %s must be set", envAgentToken, envAgentTokenFile) + + switch r.agentAuth { + case "token": + token := r.agentToken + if token == "" { + if r.agentTokenFile == "" { + return nil, xerrors.Errorf("Either %s or %s must be set", envAgentToken, envAgentTokenFile) + } + tokenBytes, err := os.ReadFile(r.agentTokenFile) + if err != nil { + return nil, xerrors.Errorf("read token file %q: %w", r.agentTokenFile, err) + } + token = strings.TrimSpace(string(tokenBytes)) } - tokenBytes, err := os.ReadFile(r.agentTokenFile) - if err != nil { - return nil, xerrors.Errorf("read token file %q: %w", r.agentTokenFile, err) + if token == "" { + return nil, xerrors.Errorf("CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE must be set for token auth") + } + return agentsdk.New(r.agentURL, agentsdk.UsingFixedToken(token)), nil + case "google-instance-identity": + + // This is *only* done for testing to mock client authentication. + // This will never be set in a production scenario. + var gcpClient *metadata.Client + gcpClientRaw := ctx.Value("gcp-client") + if gcpClientRaw != nil { + gcpClient, _ = gcpClientRaw.(*metadata.Client) } - token = strings.TrimSpace(string(tokenBytes)) + return agentsdk.New(r.agentURL, agentsdk.UsingGoogleInstanceIdentity("", gcpClient)), nil + case "aws-instance-identity": + client := agentsdk.New(r.agentURL, agentsdk.UsingAWSInstanceIdentity()) + // This is *only* done for testing to mock client authentication. + // This will never be set in a production scenario. + var awsClient *http.Client + awsClientRaw := ctx.Value("aws-client") + if awsClientRaw != nil { + awsClient, _ = awsClientRaw.(*http.Client) + if awsClient != nil { + client.SDK.HTTPClient = awsClient + } + } + return client, nil + case "azure-instance-identity": + client := agentsdk.New(r.agentURL, agentsdk.UsingAzureInstanceIdentity()) + // This is *only* done for testing to mock client authentication. + // This will never be set in a production scenario. + var azureClient *http.Client + azureClientRaw := ctx.Value("azure-client") + if azureClientRaw != nil { + azureClient, _ = azureClientRaw.(*http.Client) + if azureClient != nil { + client.SDK.HTTPClient = azureClient + } + } + return client, nil + default: + return nil, xerrors.Errorf("unknown agent auth type: %s", r.agentAuth) } - client := agentsdk.New(agentURL) - client.SetSessionToken(token) - return client, nil -} - -// tryCreateAgentClient returns a new client from the command context. It works -// just like tryCreateAgentClient, but does not error. -func (r *RootCmd) tryCreateAgentClient() (*agentsdk.Client, error) { - // TODO: Why does this not actually return any errors despite the function - // signature? Could we just use createAgentClient instead, or is it expected - // that we return a client in some cases even without a valid URL or token? - client := agentsdk.New(r.agentURL) - client.SetSessionToken(r.agentToken) - return client, nil } type OrganizationContext struct { diff --git a/cli/testdata/coder_agent_--help.golden b/cli/testdata/coder_agent_--help.golden index c6d75705a6eb4..0541210c00824 100644 --- a/cli/testdata/coder_agent_--help.golden +++ b/cli/testdata/coder_agent_--help.golden @@ -24,9 +24,6 @@ OPTIONS: requests. The command must output each header as `key=value` on its own line. - --auth string, $CODER_AGENT_AUTH (default: token) - Specify the authentication type to use for the agent. - --block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false) Block file transfer using known applications: nc,rsync,scp,sftp. diff --git a/coderd/externalauth_test.go b/coderd/externalauth_test.go index c9ba4911214de..2e648fd8d8024 100644 --- a/coderd/externalauth_test.go +++ b/coderd/externalauth_test.go @@ -432,8 +432,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) _, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com", }) @@ -464,8 +463,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) token, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", }) @@ -565,8 +563,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) resp := coderdtest.RequestExternalAuthCallback(t, "github", client) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) @@ -627,8 +624,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) token, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", @@ -674,8 +670,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) token, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", @@ -740,8 +735,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) token, err := agentClient.ExternalAuth(t.Context(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", diff --git a/coderd/gitsshkey_test.go b/coderd/gitsshkey_test.go index abd18508ce018..448705c0fedbf 100644 --- a/coderd/gitsshkey_test.go +++ b/coderd/gitsshkey_test.go @@ -118,8 +118,7 @@ func TestAgentGitSSHKey(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, project.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() @@ -157,8 +156,7 @@ func TestAgentGitSSHKey_APIKeyScopes(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, project.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() diff --git a/coderd/insights_test.go b/coderd/insights_test.go index cf5f63065df99..f4f3272a80261 100644 --- a/coderd/insights_test.go +++ b/coderd/insights_test.go @@ -585,8 +585,7 @@ func TestTemplateInsights_Golden(t *testing.T) { continue } authToken := uuid.New() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken.String()) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken.String())) workspace.agentClient = agentClient var apps []*proto.App @@ -1494,8 +1493,7 @@ func TestUserActivityInsights_Golden(t *testing.T) { continue } authToken := uuid.New() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken.String()) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken.String())) workspace.agentClient = agentClient var apps []*proto.App diff --git a/coderd/prometheusmetrics/insights/metricscollector_test.go b/coderd/prometheusmetrics/insights/metricscollector_test.go index 5c18ec6d1a60f..cbd082bdfd920 100644 --- a/coderd/prometheusmetrics/insights/metricscollector_test.go +++ b/coderd/prometheusmetrics/insights/metricscollector_test.go @@ -90,8 +90,7 @@ func TestCollectInsights(t *testing.T) { // Start an agent so that we can generate stats. var agentClients []agentproto.DRPCAgentClient for i, agent := range []database.WorkspaceAgent{agent1, agent2} { - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(agent.AuthToken.String()) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(agent.AuthToken.String())) agentClient.SDK.SetLogger(logger.Leveled(slog.LevelDebug).Named(fmt.Sprintf("agent%d", i+1))) conn, err := agentClient.ConnectRPC(context.Background()) require.NoError(t, err) diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index 3d8704f92460d..dd5b3f5db74cf 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -875,8 +875,7 @@ func prepareWorkspaceAndAgent(ctx context.Context, t *testing.T, client *codersd }) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - ac := agentsdk.New(client.URL) - ac.SetSessionToken(authToken) + ac := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) conn, err := ac.ConnectRPC(ctx) require.NoError(t, err) agentAPI := agentproto.NewDRPCAgentClient(conn) diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 6a817966f4ff5..92fbed2cf8421 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -228,8 +228,7 @@ func TestWorkspaceAgentLogs(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ Logs: []agentsdk.Log{ { @@ -269,8 +268,7 @@ func TestWorkspaceAgentLogs(t *testing.T) { OrganizationID: user.OrganizationID, OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ Logs: []agentsdk.Log{ { @@ -314,8 +312,7 @@ func TestWorkspaceAgentLogs(t *testing.T) { updates, err := client.WatchWorkspace(ctx, r.Workspace.ID) require.NoError(t, err) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) err = agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ Logs: []agentsdk.Log{{ CreatedAt: dbtime.Now(), @@ -360,8 +357,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) { return a }).Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) @@ -542,8 +538,7 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) { require.NoError(t, err) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, stopBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) _, err = agentClient.ConnectRPC(ctx) require.Error(t, err) @@ -568,8 +563,7 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) { ) require.NoError(t, err) // Then: the agent token should no longer be valid - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(wsb.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken((wsb.AgentToken))) _, err = agentClient.ConnectRPC(ctx) require.Error(t, err) var sdkErr *codersdk.Error @@ -890,8 +884,7 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) // Verify that the manifest has DisableDirectConnections set to true. - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) rpc, err := agentClient.ConnectRPC(ctx) require.NoError(t, err) defer func() { @@ -1742,8 +1735,7 @@ func TestWorkspaceAgentAppHealth(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) conn, err := agentClient.ConnectRPC(ctx) require.NoError(t, err) defer func() { @@ -1818,8 +1810,7 @@ func TestWorkspaceAgentPostLogSource(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) req := agentsdk.PostLogSourceRequest{ ID: uuid.New(), @@ -1867,8 +1858,7 @@ func TestWorkspaceAgent_LifecycleState(t *testing.T) { } } - ac := agentsdk.New(client.URL) - ac.SetSessionToken(r.AgentToken) + ac := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) conn, err := ac.ConnectRPC(ctx) require.NoError(t, err) defer func() { @@ -1965,8 +1955,7 @@ func TestWorkspaceAgent_Metadata(t *testing.T) { } } - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitMedium) conn, err := agentClient.ConnectRPC(ctx) @@ -2229,8 +2218,7 @@ func TestWorkspaceAgent_Metadata_CatchMemoryLeak(t *testing.T) { } } - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitSuperLong) conn, err := agentClient.ConnectRPC(ctx) @@ -2335,8 +2323,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) { OrganizationID: user.OrganizationID, OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitMedium) @@ -2382,8 +2369,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitMedium) @@ -2547,8 +2533,7 @@ func TestWorkspaceAgentExternalAuthListen(t *testing.T) { return agents }).Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) // We need to include an invalid oauth token that is not expired. dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{ @@ -3028,8 +3013,7 @@ func TestReinit(t *testing.T) { pubsubSpy.Unlock() agentCtx := testutil.Context(t, testutil.WaitShort) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent) go func() { diff --git a/coderd/workspaceagentsrpc_test.go b/coderd/workspaceagentsrpc_test.go index 5175f80b0b723..221a04d4fcb68 100644 --- a/coderd/workspaceagentsrpc_test.go +++ b/coderd/workspaceagentsrpc_test.go @@ -68,8 +68,7 @@ func TestWorkspaceAgentReportStats(t *testing.T) { }, ).Do() - ac := agentsdk.New(client.URL) - ac.SetSessionToken(r.AgentToken) + ac := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) conn, err := ac.ConnectRPC(context.Background()) require.NoError(t, err) defer func() { @@ -155,8 +154,7 @@ func TestAgentAPI_LargeManifest(t *testing.T) { agents[0].ApiKeyScope = string(tc.apiKeyScope) return agents }).Do() - ac := agentsdk.New(client.URL) - ac.SetSessionToken(r.AgentToken) + ac := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) conn, err := ac.ConnectRPC(ctx) defer func() { _ = conn.Close() diff --git a/coderd/workspaceapps/apptest/setup.go b/coderd/workspaceapps/apptest/setup.go index 296934591e873..ebef0375f6959 100644 --- a/coderd/workspaceapps/apptest/setup.go +++ b/coderd/workspaceapps/apptest/setup.go @@ -482,8 +482,7 @@ func createWorkspaceWithApps(t *testing.T, client *codersdk.Client, orgID uuid.U require.Equal(t, appURL.String(), app.SubdomainName) } - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) // TODO (@dean): currently, the primary app host is used when generating // the port URL we tell the agent to use. We don't have any plans to change diff --git a/coderd/workspaceresourceauth_test.go b/coderd/workspaceresourceauth_test.go index 8c1b64feaf59a..53b2f33d47747 100644 --- a/coderd/workspaceresourceauth_test.go +++ b/coderd/workspaceresourceauth_test.go @@ -51,11 +51,9 @@ func TestPostWorkspaceAuthAzureInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - client.HTTPClient = metadataClient - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthAzureInstanceIdentity(ctx) + agentClient := agentsdk.New(client.URL, agentsdk.UsingAzureInstanceIdentity()) + agentClient.SDK.HTTPClient = metadataClient + err := agentClient.RefreshToken(ctx) require.NoError(t, err) } @@ -97,11 +95,9 @@ func TestPostWorkspaceAuthAWSInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - client.HTTPClient = metadataClient - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthAWSInstanceIdentity(ctx) + agentClient := agentsdk.New(client.URL, agentsdk.UsingAWSInstanceIdentity()) + agentClient.SDK.HTTPClient = metadataClient + err := agentClient.RefreshToken(ctx) require.NoError(t, err) }) } @@ -119,10 +115,8 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthGoogleInstanceIdentity(ctx, "", metadata) + agentClient := agentsdk.New(client.URL, agentsdk.UsingGoogleInstanceIdentity("", metadata)) + err := agentClient.RefreshToken(ctx) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode()) @@ -139,10 +133,8 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthGoogleInstanceIdentity(ctx, "", metadata) + agentClient := agentsdk.New(client.URL, agentsdk.UsingGoogleInstanceIdentity("", metadata)) + err := agentClient.RefreshToken(ctx) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusNotFound, apiErr.StatusCode()) @@ -184,10 +176,8 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthGoogleInstanceIdentity(ctx, "", metadata) + agentClient := agentsdk.New(client.URL, agentsdk.UsingGoogleInstanceIdentity("", metadata)) + err := agentClient.RefreshToken(ctx) require.NoError(t, err) }) } diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index 5bd0030456757..a0739cac13956 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -8,9 +8,9 @@ import ( "net/http" "net/http/cookiejar" "net/url" + "sync" "time" - "cloud.google.com/go/compute/metadata" "github.com/google/uuid" "github.com/hashicorp/yamux" "golang.org/x/xerrors" @@ -37,24 +37,31 @@ import ( // log-source. This should be removed in the future. var ExternalLogSourceID = uuid.MustParse("3b579bf4-1ed8-4b99-87a8-e9a1e3410410") -// New returns a client that is used to interact with the -// Coder API from a workspace agent. -func New(serverURL *url.URL) *Client { +// SessionTokenSetup is a function that creates the token provider while setting up the workspace agent. We do it this +// way because cloud instance identity (AWS, Azure, Google, etc.) requires interacting with coderd to exchange tokens. +// This means that the token providers need a codersdk.Client. However, the SessionTokenProvider is itself used by +// the client to authenticate requests. Thus, the dependency is bidirectional. Functions of this type are used in +// New() to ensure that things are set up correctly so there is only one instance of the codersdk.Client created. +// @typescript-ignore SessionTokenSetup +type SessionTokenSetup func(client *codersdk.Client) RefreshableSessionTokenProvider + +func New(serverURL *url.URL, setup SessionTokenSetup) *Client { + c := codersdk.New(serverURL) + provider := setup(c) + c.SessionTokenProvider = provider return &Client{ - SDK: codersdk.New(serverURL), + SDK: c, + RefreshableSessionTokenProvider: provider, } } // Client wraps `codersdk.Client` with specific functions // scoped to a workspace agent. type Client struct { + RefreshableSessionTokenProvider SDK *codersdk.Client } -func (c *Client) SetSessionToken(token string) { - c.SDK.SetSessionToken(token) -} - type GitSSHKey struct { PublicKey string `json:"public_key"` PrivateKey string `json:"private_key"` @@ -326,146 +333,91 @@ type AuthenticateResponse struct { SessionToken string `json:"session_token"` } -type GoogleInstanceIdentityToken struct { - JSONWebToken string `json:"json_web_token" validate:"required"` +// RefreshableSessionTokenProvider is a SessionTokenProvider that can be refreshed, for example, via token exchange. +// @typescript-ignore RefreshableSessionTokenProvider +type RefreshableSessionTokenProvider interface { + codersdk.SessionTokenProvider + RefreshToken(ctx context.Context) error } -// AuthWorkspaceGoogleInstanceIdentity uses the Google Compute Engine Metadata API to -// fetch a signed JWT, and exchange it for a session token for a workspace agent. -// -// The requesting instance must be registered as a resource in the latest history for a workspace. -func (c *Client) AuthGoogleInstanceIdentity(ctx context.Context, serviceAccount string, gcpClient *metadata.Client) (AuthenticateResponse, error) { - if serviceAccount == "" { - // This is the default name specified by Google. - serviceAccount = "default" - } - if gcpClient == nil { - gcpClient = metadata.NewClient(c.SDK.HTTPClient) - } - // "format=full" is required, otherwise the responding payload will be missing "instance_id". - jwt, err := gcpClient.Get(fmt.Sprintf("instance/service-accounts/%s/identity?audience=coder&format=full", serviceAccount)) - if err != nil { - return AuthenticateResponse{}, xerrors.Errorf("get metadata identity: %w", err) - } - res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/google-instance-identity", GoogleInstanceIdentityToken{ - JSONWebToken: jwt, - }) - if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) - } - var resp AuthenticateResponse - return resp, json.NewDecoder(res.Body).Decode(&resp) +// instanceIdentitySessionTokenProvider implements RefreshableSessionTokenProvider via token exchange for a cloud +// compute instance identity. +// @typescript-ignore instanceIdentitySessionTokenProvider +type instanceIdentitySessionTokenProvider struct { + tokenExchanger tokenExchanger + logger slog.Logger + + // cache so we don't request each time + mu sync.Mutex + sessionToken string } -type AWSInstanceIdentityToken struct { - Signature string `json:"signature" validate:"required"` - Document string `json:"document" validate:"required"` +// tokenExchanger obtains a session token by exchanging a cloud instance identity credential for a Coder session token. +// @typescript-ignore tokenExchanger +type tokenExchanger interface { + exchange(ctx context.Context) (AuthenticateResponse, error) } -// AuthWorkspaceAWSInstanceIdentity uses the Amazon Metadata API to -// fetch a signed payload, and exchange it for a session token for a workspace agent. -// -// The requesting instance must be registered as a resource in the latest history for a workspace. -func (c *Client) AuthAWSInstanceIdentity(ctx context.Context) (AuthenticateResponse, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil) - if err != nil { - return AuthenticateResponse{}, nil - } - req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600") - res, err := c.SDK.HTTPClient.Do(req) - if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - token, err := io.ReadAll(res.Body) - if err != nil { - return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) +func (i *instanceIdentitySessionTokenProvider) AsRequestOption() codersdk.RequestOption { + t := i.GetSessionToken() + return func(req *http.Request) { + req.Header.Set(codersdk.SessionTokenHeader, t) } +} - req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/signature", nil) - if err != nil { - return AuthenticateResponse{}, nil - } - req.Header.Set("X-aws-ec2-metadata-token", string(token)) - res, err = c.SDK.HTTPClient.Do(req) - if err != nil { - return AuthenticateResponse{}, err +func (i *instanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.DialOptions) { + t := i.GetSessionToken() + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} } - defer res.Body.Close() - signature, err := io.ReadAll(res.Body) - if err != nil { - return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) + if opts.HTTPHeader.Get(codersdk.SessionTokenHeader) == "" { + opts.HTTPHeader.Set(codersdk.SessionTokenHeader, t) } +} - req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil) - if err != nil { - return AuthenticateResponse{}, nil +func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string { + i.mu.Lock() + defer i.mu.Unlock() + if i.sessionToken != "" { + return i.sessionToken } - req.Header.Set("X-aws-ec2-metadata-token", string(token)) - res, err = c.SDK.HTTPClient.Do(req) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + resp, err := i.tokenExchanger.exchange(ctx) if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - document, err := io.ReadAll(res.Body) - if err != nil { - return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) + i.logger.Error(ctx, "failed to exchange session token: %v", err) + return "" } + i.sessionToken = resp.SessionToken + return i.sessionToken +} - res, err = c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", AWSInstanceIdentityToken{ - Signature: string(signature), - Document: string(document), - }) +func (i *instanceIdentitySessionTokenProvider) RefreshToken(ctx context.Context) error { + i.mu.Lock() + defer i.mu.Unlock() + resp, err := i.tokenExchanger.exchange(ctx) if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) + return err } - var resp AuthenticateResponse - return resp, json.NewDecoder(res.Body).Decode(&resp) + i.sessionToken = resp.SessionToken + return nil } -type AzureInstanceIdentityToken struct { - Signature string `json:"signature" validate:"required"` - Encoding string `json:"encoding" validate:"required"` +// FixedSessionTokenProvider wraps the codersdk variant to add a no-op RefreshToken method to satisfy the +// RefreshableSessionTokenProvider interface. +// @typescript-ignore FixedSessionTokenProvider +type FixedSessionTokenProvider struct { + codersdk.FixedSessionTokenProvider } -// AuthWorkspaceAzureInstanceIdentity uses the Azure Instance Metadata Service to -// fetch a signed payload, and exchange it for a session token for a workspace agent. -func (c *Client) AuthAzureInstanceIdentity(ctx context.Context) (AuthenticateResponse, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/metadata/attested/document?api-version=2020-09-01", nil) - if err != nil { - return AuthenticateResponse{}, nil - } - req.Header.Set("Metadata", "true") - res, err := c.SDK.HTTPClient.Do(req) - if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - - var token AzureInstanceIdentityToken - err = json.NewDecoder(res.Body).Decode(&token) - if err != nil { - return AuthenticateResponse{}, err - } +func (FixedSessionTokenProvider) RefreshToken(_ context.Context) error { + return nil +} - res, err = c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/azure-instance-identity", token) - if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) +func UsingFixedToken(token string) SessionTokenSetup { + return func(_ *codersdk.Client) RefreshableSessionTokenProvider { + return FixedSessionTokenProvider{FixedSessionTokenProvider: codersdk.FixedSessionTokenProvider{SessionToken: token}} } - var resp AuthenticateResponse - return resp, json.NewDecoder(res.Body).Decode(&resp) } // Stats records the Agent's network connection statistics for use in diff --git a/codersdk/agentsdk/agentsdk_test.go b/codersdk/agentsdk/agentsdk_test.go index e6ea6838dd9b2..748b6203aaea0 100644 --- a/codersdk/agentsdk/agentsdk_test.go +++ b/codersdk/agentsdk/agentsdk_test.go @@ -141,7 +141,7 @@ func TestRewriteDERPMap(t *testing.T) { } parsed, err := url.Parse("https://coconuts.org:44558") require.NoError(t, err) - client := agentsdk.New(parsed) + client := agentsdk.New(parsed, agentsdk.UsingFixedToken("unused")) client.RewriteDERPMap(dm) region := dm.Regions[1] require.True(t, region.EmbeddedRelay) diff --git a/codersdk/agentsdk/aws.go b/codersdk/agentsdk/aws.go new file mode 100644 index 0000000000000..622657ea92211 --- /dev/null +++ b/codersdk/agentsdk/aws.go @@ -0,0 +1,97 @@ +package agentsdk + +import ( + "context" + "encoding/json" + "io" + "net/http" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" +) + +type AWSInstanceIdentityToken struct { + Signature string `json:"signature" validate:"required"` + Document string `json:"document" validate:"required"` +} + +// awsSessionTokenExchanger exchanges AWS instance metadata for a Coder session token. +// @typescript-ignore awsSessionTokenExchanger +type awsSessionTokenExchanger struct { + client *codersdk.Client +} + +func UsingAWSInstanceIdentity() SessionTokenSetup { + return func(client *codersdk.Client) RefreshableSessionTokenProvider { + return &instanceIdentitySessionTokenProvider{ + tokenExchanger: &awsSessionTokenExchanger{client: client}, + } + } +} + +// exchange uses the Amazon Metadata API to fetch a signed payload, and exchange it for a session token for a workspace +// agent. +// +// The requesting instance must be registered as a resource in the latest history for a workspace. +func (a *awsSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil) + if err != nil { + return AuthenticateResponse{}, nil + } + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600") + res, err := a.client.HTTPClient.Do(req) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + token, err := io.ReadAll(res.Body) + if err != nil { + return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) + } + + req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/signature", nil) + if err != nil { + return AuthenticateResponse{}, nil + } + req.Header.Set("X-aws-ec2-metadata-token", string(token)) + res, err = a.client.HTTPClient.Do(req) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + signature, err := io.ReadAll(res.Body) + if err != nil { + return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) + } + + req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil) + if err != nil { + return AuthenticateResponse{}, nil + } + req.Header.Set("X-aws-ec2-metadata-token", string(token)) + res, err = a.client.HTTPClient.Do(req) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + document, err := io.ReadAll(res.Body) + if err != nil { + return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) + } + + // request without the token to avoid re-entering this function + res, err = a.client.RequestNoSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", AWSInstanceIdentityToken{ + Signature: string(signature), + Document: string(document), + }) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) + } + var resp AuthenticateResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/codersdk/agentsdk/azure.go b/codersdk/agentsdk/azure.go new file mode 100644 index 0000000000000..42ec2f9b0d836 --- /dev/null +++ b/codersdk/agentsdk/azure.go @@ -0,0 +1,60 @@ +package agentsdk + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/coder/coder/v2/codersdk" +) + +type AzureInstanceIdentityToken struct { + Signature string `json:"signature" validate:"required"` + Encoding string `json:"encoding" validate:"required"` +} + +// azureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token. +// @typescript-ignore azureSessionTokenExchanger +type azureSessionTokenExchanger struct { + client *codersdk.Client +} + +func UsingAzureInstanceIdentity() SessionTokenSetup { + return func(client *codersdk.Client) RefreshableSessionTokenProvider { + return &instanceIdentitySessionTokenProvider{ + tokenExchanger: &azureSessionTokenExchanger{client: client}, + } + } +} + +// AuthWorkspaceAzureInstanceIdentity uses the Azure Instance Metadata Service to +// fetch a signed payload, and exchange it for a session token for a workspace agent. +func (a *azureSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/metadata/attested/document?api-version=2020-09-01", nil) + if err != nil { + return AuthenticateResponse{}, nil + } + req.Header.Set("Metadata", "true") + res, err := a.client.HTTPClient.Do(req) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + + var token AzureInstanceIdentityToken + err = json.NewDecoder(res.Body).Decode(&token) + if err != nil { + return AuthenticateResponse{}, err + } + + res, err = a.client.RequestNoSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/azure-instance-identity", token) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) + } + var resp AuthenticateResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/codersdk/agentsdk/google.go b/codersdk/agentsdk/google.go new file mode 100644 index 0000000000000..8b67b16ac63e1 --- /dev/null +++ b/codersdk/agentsdk/google.go @@ -0,0 +1,71 @@ +package agentsdk + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "cloud.google.com/go/compute/metadata" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" +) + +type GoogleInstanceIdentityToken struct { + JSONWebToken string `json:"json_web_token" validate:"required"` +} + +// googleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token. +// @typescript-ignore googleSessionTokenExchanger +type googleSessionTokenExchanger struct { + serviceAccount string + gcpClient *metadata.Client + client *codersdk.Client +} + +func UsingGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup { + return func(client *codersdk.Client) RefreshableSessionTokenProvider { + return &instanceIdentitySessionTokenProvider{ + tokenExchanger: &googleSessionTokenExchanger{ + client: client, + gcpClient: gcpClient, + serviceAccount: serviceAccount, + }, + } + } +} + +// exchange uses the Google Compute Engine Metadata API to fetch a signed JWT, and exchange it for a session token for a +// workspace agent. +// +// The requesting instance must be registered as a resource in the latest history for a workspace. +func (g *googleSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) { + if g.serviceAccount == "" { + // This is the default name specified by Google. + g.serviceAccount = "default" + } + gcpClient := metadata.NewClient(g.client.HTTPClient) + if g.gcpClient != nil { + gcpClient = g.gcpClient + } + + // "format=full" is required, otherwise the responding payload will be missing "instance_id". + jwt, err := gcpClient.Get(fmt.Sprintf("instance/service-accounts/%s/identity?audience=coder&format=full", g.serviceAccount)) + if err != nil { + return AuthenticateResponse{}, xerrors.Errorf("get metadata identity: %w", err) + } + // request without the token to avoid re-entering this function + res, err := g.client.RequestNoSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/google-instance-identity", GoogleInstanceIdentityToken{ + JSONWebToken: jwt, + }) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) + } + var resp AuthenticateResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index fb321e90e7dee..299c6fe5f3519 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -75,8 +75,7 @@ func TestTools(t *testing.T) { }).Do() // Given: a client configured with the agent token. - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) // Get the agent ID from the API. Overriding it in dbfake doesn't work. ws, err := client.Workspace(setupCtx, r.Workspace.ID) require.NoError(t, err) diff --git a/enterprise/coderd/appearance_test.go b/enterprise/coderd/appearance_test.go index 8550f13904e2d..0995f98f6d6ca 100644 --- a/enterprise/coderd/appearance_test.go +++ b/enterprise/coderd/appearance_test.go @@ -153,15 +153,13 @@ func TestAnnouncementBanners(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) banners := requireGetAnnouncementBanners(ctx, t, agentClient) require.Equal(t, cfg.AnnouncementBanners, banners) // Create an AGPL Coderd against the same database agplClient := coderdtest.New(t, &coderdtest.Options{Database: store, Pubsub: ps}) - agplAgentClient := agentsdk.New(agplClient.URL) - agplAgentClient.SetSessionToken(r.AgentToken) + agplAgentClient := agentsdk.New(agplClient.URL, agentsdk.UsingFixedToken(r.AgentToken)) banners = requireGetAnnouncementBanners(ctx, t, agplAgentClient) require.Equal(t, []codersdk.BannerConfig{}, banners) diff --git a/enterprise/coderd/gitsshkey_test.go b/enterprise/coderd/gitsshkey_test.go index a4978ac8fdad3..6b6d46f9f58a3 100644 --- a/enterprise/coderd/gitsshkey_test.go +++ b/enterprise/coderd/gitsshkey_test.go @@ -69,8 +69,7 @@ func TestAgentGitSSHKeyCustomRoles(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, project.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() diff --git a/enterprise/coderd/workspaceagents_test.go b/enterprise/coderd/workspaceagents_test.go index c9d44e667c212..a308279295a58 100644 --- a/enterprise/coderd/workspaceagents_test.go +++ b/enterprise/coderd/workspaceagents_test.go @@ -319,7 +319,7 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) agentClient.SDK.HTTPClient = &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -328,7 +328,6 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr }, }, } - agentClient.SetSessionToken(authToken) agnt := agent.New(agent.Options{ Client: agentClient, Logger: testutil.Logger(t).Named("agent"), diff --git a/scaletest/createworkspaces/run_test.go b/scaletest/createworkspaces/run_test.go index c63854ff8a1fd..ce832947177ff 100644 --- a/scaletest/createworkspaces/run_test.go +++ b/scaletest/createworkspaces/run_test.go @@ -561,8 +561,7 @@ func goEventuallyStartFakeAgent(ctx context.Context, t *testing.T, client *coder coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(agentToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(agentToken)) agentCloser := agent.New(agent.Options{ Client: agentClient, Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}). diff --git a/scaletest/workspacebuild/run_test.go b/scaletest/workspacebuild/run_test.go index 5949f04d5bccd..977de6bbd573f 100644 --- a/scaletest/workspacebuild/run_test.go +++ b/scaletest/workspacebuild/run_test.go @@ -134,8 +134,7 @@ func Test_Runner(t *testing.T) { for i, authToken := range []string{authToken1, authToken2, authToken3} { i := i + 1 - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) agentCloser := agent.New(agent.Options{ Client: agentClient, Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}). From f7a8557877a70582df878700de335ecca1190f33 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 28 Aug 2025 12:44:31 +0000 Subject: [PATCH 2/5] rename to RequestWithoutSessionToken --- codersdk/agentsdk/aws.go | 2 +- codersdk/agentsdk/azure.go | 2 +- codersdk/agentsdk/google.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/codersdk/agentsdk/aws.go b/codersdk/agentsdk/aws.go index 622657ea92211..8865ca2708cf9 100644 --- a/codersdk/agentsdk/aws.go +++ b/codersdk/agentsdk/aws.go @@ -81,7 +81,7 @@ func (a *awsSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateRe } // request without the token to avoid re-entering this function - res, err = a.client.RequestNoSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", AWSInstanceIdentityToken{ + res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", AWSInstanceIdentityToken{ Signature: string(signature), Document: string(document), }) diff --git a/codersdk/agentsdk/azure.go b/codersdk/agentsdk/azure.go index 42ec2f9b0d836..32a5793f025e8 100644 --- a/codersdk/agentsdk/azure.go +++ b/codersdk/agentsdk/azure.go @@ -47,7 +47,7 @@ func (a *azureSessionTokenExchanger) exchange(ctx context.Context) (Authenticate return AuthenticateResponse{}, err } - res, err = a.client.RequestNoSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/azure-instance-identity", token) + res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/azure-instance-identity", token) if err != nil { return AuthenticateResponse{}, err } diff --git a/codersdk/agentsdk/google.go b/codersdk/agentsdk/google.go index 8b67b16ac63e1..e3cf12c5ecdc1 100644 --- a/codersdk/agentsdk/google.go +++ b/codersdk/agentsdk/google.go @@ -56,7 +56,7 @@ func (g *googleSessionTokenExchanger) exchange(ctx context.Context) (Authenticat return AuthenticateResponse{}, xerrors.Errorf("get metadata identity: %w", err) } // request without the token to avoid re-entering this function - res, err := g.client.RequestNoSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/google-instance-identity", GoogleInstanceIdentityToken{ + res, err := g.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/google-instance-identity", GoogleInstanceIdentityToken{ JSONWebToken: jwt, }) if err != nil { From 09c2cd4b085e98cb6a9c48995ce956c9c303433e Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 29 Aug 2025 12:19:58 +0000 Subject: [PATCH 3/5] add test for agent calling RefreshToken() --- agent/agent_test.go | 3 +++ agent/agenttest/client.go | 20 +++++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index 72219cbe16fa2..e8b3b99a95387 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -2952,9 +2952,12 @@ func TestAgent_Reconnect(t *testing.T) { defer closer.Close() call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls) + require.Equal(t, client.GetNumRefreshTokenCalls(), 1) close(call1.Resps) // hang up // expect reconnect testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls) + // Check that the agent refreshes the token when it reconnects. + require.Equal(t, client.GetNumRefreshTokenCalls(), 2) closer.Close() } diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index 3e9f025f18a7e..ff601a7d08393 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -88,10 +88,11 @@ type Client struct { fakeAgentAPI *FakeAgentAPI LastWorkspaceAgent func() - mu sync.Mutex // Protects following. - logs []agentsdk.Log - derpMapUpdates chan *tailcfg.DERPMap - derpMapOnce sync.Once + mu sync.Mutex // Protects following. + logs []agentsdk.Log + derpMapUpdates chan *tailcfg.DERPMap + derpMapOnce sync.Once + refreshTokenCalls int } func (*Client) AsRequestOption() codersdk.RequestOption { @@ -104,10 +105,19 @@ func (*Client) GetSessionToken() string { return "agenttest-token" } -func (*Client) RefreshToken(context.Context) error { +func (c *Client) RefreshToken(context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + c.refreshTokenCalls++ return nil } +func (c *Client) GetNumRefreshTokenCalls() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.refreshTokenCalls +} + func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {} func (c *Client) Close() { From 813a69e9e35572b29b0c0763a7329ad608f383f0 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 2 Sep 2025 10:04:50 +0000 Subject: [PATCH 4/5] refactor agent auth options into it's own type --- agent/agenttest/agent.go | 2 +- cli/agent.go | 15 ++- cli/exp_mcp.go | 15 ++- cli/externalauth.go | 23 ++-- cli/gitaskpass.go | 8 +- cli/gitssh.go | 7 +- cli/gitssh_test.go | 2 +- cli/root.go | 125 +++++++++--------- cli/testdata/coder_agent_--help.golden | 12 ++ ...r_external-auth_access-token_--help.golden | 12 ++ coderd/externalauth_test.go | 12 +- coderd/gitsshkey_test.go | 4 +- coderd/insights_test.go | 4 +- .../insights/metricscollector_test.go | 2 +- .../prometheusmetrics_test.go | 2 +- coderd/workspaceagents_test.go | 32 ++--- coderd/workspaceagentsrpc_test.go | 4 +- coderd/workspaceapps/apptest/setup.go | 2 +- coderd/workspaceresourceauth_test.go | 10 +- codersdk/agentsdk/agentsdk.go | 2 +- codersdk/agentsdk/agentsdk_test.go | 2 +- codersdk/agentsdk/aws.go | 2 +- codersdk/agentsdk/azure.go | 2 +- codersdk/agentsdk/google.go | 2 +- codersdk/toolsdk/toolsdk_test.go | 2 +- .../cli/external-auth_access-token.md | 37 ++++++ enterprise/coderd/appearance_test.go | 4 +- enterprise/coderd/gitsshkey_test.go | 2 +- enterprise/coderd/workspaceagents_test.go | 2 +- scaletest/createworkspaces/run_test.go | 2 +- scaletest/workspacebuild/run_test.go | 2 +- 31 files changed, 213 insertions(+), 141 deletions(-) diff --git a/agent/agenttest/agent.go b/agent/agenttest/agent.go index 8a2a6260c291c..a6356e6e2503d 100644 --- a/agent/agenttest/agent.go +++ b/agent/agenttest/agent.go @@ -30,7 +30,7 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent } if o.Client == nil { - agentClient := agentsdk.New(coderURL, agentsdk.UsingFixedToken(agentToken)) + agentClient := agentsdk.New(coderURL, agentsdk.WithFixedToken(agentToken)) agentClient.SDK.SetLogger(log) o.Client = agentClient } diff --git a/cli/agent.go b/cli/agent.go index 99caf773de8b4..342522ad057f7 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -37,7 +37,7 @@ import ( "github.com/coder/coder/v2/codersdk/agentsdk" ) -func (r *RootCmd) workspaceAgent() *serpent.Command { +func workspaceAgent() *serpent.Command { var ( logDir string scriptDataDir string @@ -57,6 +57,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { devcontainerProjectDiscovery bool devcontainerDiscoveryAutostart bool ) + agentAuth := NewAgentAuth() cmd := &serpent.Command{ Use: "agent", Short: `Starts the Coder workspace agent.`, @@ -174,11 +175,11 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { version := buildinfo.Version() logger.Info(ctx, "agent is starting now", - slog.F("url", r.agentURL), - slog.F("auth", r.agentAuth), + slog.F("url", agentAuth.agentURL), + slog.F("auth", agentAuth.agentAuth), slog.F("version", version), ) - client, err := r.createAgentClient(ctx) + client, err := agentAuth.CreateClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } @@ -190,7 +191,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { client.SDK.HTTPClient.Timeout = 30 * time.Second // Attach header transport so we process --agent-header and // --agent-header-command flags - headerTransport, err := headerTransport(ctx, r.agentURL, agentHeader, agentHeaderCommand) + headerTransport, err := headerTransport(ctx, agentAuth.agentURL, agentHeader, agentHeaderCommand) if err != nil { return xerrors.Errorf("configure header transport: %w", err) } @@ -292,7 +293,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { Execer: execer, Devcontainers: devcontainers, DevcontainerAPIOptions: []agentcontainers.Option{ - agentcontainers.WithSubAgentURL(r.agentURL.String()), + agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()), agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery), agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart), }, @@ -449,7 +450,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { Value: serpent.BoolOf(&devcontainerDiscoveryAutostart), }, } - + agentAuth.AttachOptions(cmd, false) return cmd } diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index 6ee4a362be4eb..59dab808a5472 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -56,7 +56,7 @@ func (r *RootCmd) mcpConfigure() *serpent.Command { }, Children: []*serpent.Command{ r.mcpConfigureClaudeDesktop(), - r.mcpConfigureClaudeCode(), + mcpConfigureClaudeCode(), r.mcpConfigureCursor(), }, } @@ -117,7 +117,7 @@ func (*RootCmd) mcpConfigureClaudeDesktop() *serpent.Command { return cmd } -func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command { +func mcpConfigureClaudeCode() *serpent.Command { var ( claudeAPIKey string claudeConfigPath string @@ -131,6 +131,7 @@ func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command { deprecatedCoderMCPClaudeAPIKey string ) + agentAuth := NewAgentAuth() cmd := &serpent.Command{ Use: "claude-code ", Short: "Configure the Claude Code server. You will need to run this command for each project you want to use. Specify the project directory as the first argument.", @@ -148,7 +149,7 @@ func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command { binPath = testBinaryName } configureClaudeEnv := map[string]string{} - agentClient, err := r.createAgentClient(inv.Context()) + agentClient, err := agentAuth.CreateClient(inv.Context()) if err != nil { cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err) } else { @@ -292,6 +293,7 @@ func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command { }, }, } + agentAuth.AttachOptions(cmd, false) return cmd } @@ -403,7 +405,8 @@ func (r *RootCmd) mcpServer() *serpent.Command { appStatusSlug string aiAgentAPIURL url.URL ) - return &serpent.Command{ + agentAuth := NewAgentAuth() + cmd := &serpent.Command{ Use: "server", Handler: func(inv *serpent.Invocation) error { var lastReport taskReport @@ -494,7 +497,7 @@ func (r *RootCmd) mcpServer() *serpent.Command { } // Try to create an agent client for status reporting. Not validated. - agentClient, err := r.createAgentClient(inv.Context()) + agentClient, err := agentAuth.CreateClient(inv.Context()) if err == nil { cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String()) srv.agentClient = agentClient @@ -579,6 +582,8 @@ func (r *RootCmd) mcpServer() *serpent.Command { }, }, } + agentAuth.AttachOptions(cmd, false) + return cmd } func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation) { diff --git a/cli/externalauth.go b/cli/externalauth.go index 3910d6b01afd0..4bea37466063d 100644 --- a/cli/externalauth.go +++ b/cli/externalauth.go @@ -2,19 +2,16 @@ package cli import ( "encoding/json" - "fmt" - - "golang.org/x/xerrors" "github.com/tidwall/gjson" + "golang.org/x/xerrors" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk/agentsdk" - "github.com/coder/pretty" "github.com/coder/serpent" ) -func (r *RootCmd) externalAuth() *serpent.Command { +func externalAuth() *serpent.Command { return &serpent.Command{ Use: "external-auth", Short: "Manage external authentication", @@ -23,14 +20,15 @@ func (r *RootCmd) externalAuth() *serpent.Command { return i.Command.HelpHandler(i) }, Children: []*serpent.Command{ - r.externalAuthAccessToken(), + externalAuthAccessToken(), }, } } -func (r *RootCmd) externalAuthAccessToken() *serpent.Command { +func externalAuthAccessToken() *serpent.Command { var extra string - return &serpent.Command{ + agentAuth := NewAgentAuth() + cmd := &serpent.Command{ Use: "access-token ", Short: "Print auth for an external provider", Long: "Print an access-token for an external auth provider. " + @@ -70,12 +68,7 @@ fi ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) defer stop() - if r.agentToken == "" { - _, _ = fmt.Fprint(inv.Stderr, pretty.Sprintf(headLineStyle(), "No agent token found, this command must be run from inside a running workspace.\n")) - return xerrors.Errorf("agent token not found") - } - - client, err := r.createAgentClient(ctx) + client, err := agentAuth.CreateClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } @@ -115,4 +108,6 @@ fi return nil }, } + agentAuth.AttachOptions(cmd, false) + return cmd } diff --git a/cli/gitaskpass.go b/cli/gitaskpass.go index f41b0e152b3e3..4729b333ae154 100644 --- a/cli/gitaskpass.go +++ b/cli/gitaskpass.go @@ -18,8 +18,8 @@ import ( // gitAskpass is used by the Coder agent to automatically authenticate // with Git providers based on a hostname. -func (r *RootCmd) gitAskpass() *serpent.Command { - return &serpent.Command{ +func gitAskpass(agentAuth *AgentAuth) *serpent.Command { + cmd := &serpent.Command{ Use: "gitaskpass", Hidden: true, Handler: func(inv *serpent.Invocation) error { @@ -33,7 +33,7 @@ func (r *RootCmd) gitAskpass() *serpent.Command { return xerrors.Errorf("parse host: %w", err) } - client, err := r.createAgentClient(ctx) + client, err := agentAuth.CreateClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } @@ -90,4 +90,6 @@ func (r *RootCmd) gitAskpass() *serpent.Command { return nil }, } + agentAuth.AttachOptions(cmd, false) + return cmd } diff --git a/cli/gitssh.go b/cli/gitssh.go index 59cc0299e1d22..4b42a85c27447 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -18,7 +18,8 @@ import ( "github.com/coder/serpent" ) -func (r *RootCmd) gitssh() *serpent.Command { +func gitssh() *serpent.Command { + agentAuth := NewAgentAuth() cmd := &serpent.Command{ Use: "gitssh", Hidden: true, @@ -38,7 +39,7 @@ func (r *RootCmd) gitssh() *serpent.Command { return err } - client, err := r.createAgentClient(ctx) + client, err := agentAuth.CreateClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } @@ -108,7 +109,7 @@ func (r *RootCmd) gitssh() *serpent.Command { return nil }, } - + agentAuth.AttachOptions(cmd, false) return cmd } diff --git a/cli/gitssh_test.go b/cli/gitssh_test.go index 85f24a9ca1aab..8ff32363e986b 100644 --- a/cli/gitssh_test.go +++ b/cli/gitssh_test.go @@ -54,7 +54,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*agentsdk.Client, str }).WithAgent().Do() // start workspace agent - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { o.Client = agentClient }) diff --git a/cli/root.go b/cli/root.go index d42946819c127..2fbbadbdd16df 100644 --- a/cli/root.go +++ b/cli/root.go @@ -93,7 +93,7 @@ func (r *RootCmd) CoreSubcommands() []*serpent.Command { return []*serpent.Command{ r.completion(), r.dotfiles(), - r.externalAuth(), + externalAuth(), r.login(), r.logout(), r.netcheck(), @@ -133,11 +133,11 @@ func (r *RootCmd) CoreSubcommands() []*serpent.Command { // Hidden r.connectCmd(), r.expCmd(), - r.gitssh(), + gitssh(), r.support(), r.vpnDaemon(), r.vscodeSSH(), - r.workspaceAgent(), + workspaceAgent(), } } @@ -201,6 +201,7 @@ func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) { func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, error) { fmtLong := `Coder %s — A tool for provisioning self-hosted development environments with Terraform. ` + hiddenAgentAuth := NewAgentAuth() cmd := &serpent.Command{ Use: "coder [global-flags] ", Long: fmt.Sprintf(fmtLong, buildinfo.Version()) + FormatExamples( @@ -223,7 +224,7 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err // with a `gitaskpass` subcommand, we override the entrypoint // to check if the command was invoked. if gitauth.CheckCommand(i.Args, i.Environ.ToOS()) { - return r.gitAskpass().Handler(i) + return gitAskpass(hiddenAgentAuth).Handler(i) } return i.Command.HelpHandler(i) }, @@ -352,9 +353,6 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err } }) - if r.agentURL == nil { - r.agentURL = new(url.URL) - } if r.clientURL == nil { r.clientURL = new(url.URL) } @@ -384,39 +382,6 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err Value: serpent.StringOf(&r.token), Group: globalGroup, }, - { - Flag: varAgentToken, - Env: envAgentToken, - Description: "An agent authentication token.", - Value: serpent.StringOf(&r.agentToken), - Hidden: true, - Group: globalGroup, - }, - { - Flag: varAgentTokenFile, - Env: envAgentTokenFile, - Description: "A file containing an agent authentication token.", - Value: serpent.StringOf(&r.agentTokenFile), - Hidden: true, - Group: globalGroup, - }, - { - Flag: varAgentURL, - Env: envAgentURL, - Description: "URL for an agent to access your deployment.", - Value: serpent.URLOf(r.agentURL), - Hidden: true, - Group: globalGroup, - }, - { - Flag: varAgentAuth, - Env: envAgentAuth, - Default: "token", - Description: "Specify the authentication type to use for the agent.", - Value: serpent.StringOf(&r.agentAuth), - Hidden: true, - Group: globalGroup, - }, { Flag: varNoVersionCheck, Env: envNoVersionCheck, @@ -508,6 +473,7 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err Hidden: true, }, } + hiddenAgentAuth.AttachOptions(cmd, true) return cmd, nil } @@ -520,12 +486,6 @@ type RootCmd struct { header []string headerCommand string - // Agent Client config - agentToken string - agentTokenFile string - agentURL *url.URL - agentAuth string - forceTTY bool noOpen bool verbose bool @@ -688,31 +648,78 @@ func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *ur return &client, err } -// createAgentClient returns a new client from the command context. It works +type AgentAuth struct { + // Agent Client config + agentToken string + agentTokenFile string + agentURL *url.URL + agentAuth string +} + +func NewAgentAuth() *AgentAuth { + return &AgentAuth{ + agentURL: new(url.URL), + } +} + +func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) { + cmd.Options = append(cmd.Options, serpent.Option{ + Name: "Agent Token", + Description: "An agent authentication token.", + Flag: varAgentToken, + Env: envAgentToken, + Value: serpent.StringOf(&a.agentToken), + Hidden: hidden, + }, serpent.Option{ + Name: "Agent Token File", + Description: "A file containing an agent authentication token.", + Flag: varAgentTokenFile, + Env: envAgentTokenFile, + Value: serpent.StringOf(&a.agentTokenFile), + Hidden: hidden, + }, serpent.Option{ + Name: "Agent URL", + Description: "URL for an agent to access your deployment.", + Flag: varAgentURL, + Env: envAgentURL, + Value: serpent.URLOf(a.agentURL), + Hidden: hidden, + }, serpent.Option{ + Name: "Agent Auth", + Description: "Specify the authentication type to use for the agent.", + Flag: varAgentAuth, + Env: envAgentAuth, + Default: "token", + Value: serpent.StringOf(&a.agentAuth), + Hidden: hidden, + }) +} + +// CreateClient returns a new agent client from the command context. It works // just like InitClient, but uses the agent token and URL instead. -func (r *RootCmd) createAgentClient(ctx context.Context) (*agentsdk.Client, error) { - agentURL := r.agentURL +func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) { + agentURL := a.agentURL if agentURL == nil || agentURL.String() == "" { return nil, xerrors.Errorf("%s must be set", envAgentURL) } - switch r.agentAuth { + switch a.agentAuth { case "token": - token := r.agentToken + token := a.agentToken if token == "" { - if r.agentTokenFile == "" { + if a.agentTokenFile == "" { return nil, xerrors.Errorf("Either %s or %s must be set", envAgentToken, envAgentTokenFile) } - tokenBytes, err := os.ReadFile(r.agentTokenFile) + tokenBytes, err := os.ReadFile(a.agentTokenFile) if err != nil { - return nil, xerrors.Errorf("read token file %q: %w", r.agentTokenFile, err) + return nil, xerrors.Errorf("read token file %q: %w", a.agentTokenFile, err) } token = strings.TrimSpace(string(tokenBytes)) } if token == "" { return nil, xerrors.Errorf("CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE must be set for token auth") } - return agentsdk.New(r.agentURL, agentsdk.UsingFixedToken(token)), nil + return agentsdk.New(a.agentURL, agentsdk.WithFixedToken(token)), nil case "google-instance-identity": // This is *only* done for testing to mock client authentication. @@ -722,9 +729,9 @@ func (r *RootCmd) createAgentClient(ctx context.Context) (*agentsdk.Client, erro if gcpClientRaw != nil { gcpClient, _ = gcpClientRaw.(*metadata.Client) } - return agentsdk.New(r.agentURL, agentsdk.UsingGoogleInstanceIdentity("", gcpClient)), nil + return agentsdk.New(a.agentURL, agentsdk.WithGoogleInstanceIdentity("", gcpClient)), nil case "aws-instance-identity": - client := agentsdk.New(r.agentURL, agentsdk.UsingAWSInstanceIdentity()) + client := agentsdk.New(a.agentURL, agentsdk.WithAWSInstanceIdentity()) // This is *only* done for testing to mock client authentication. // This will never be set in a production scenario. var awsClient *http.Client @@ -737,7 +744,7 @@ func (r *RootCmd) createAgentClient(ctx context.Context) (*agentsdk.Client, erro } return client, nil case "azure-instance-identity": - client := agentsdk.New(r.agentURL, agentsdk.UsingAzureInstanceIdentity()) + client := agentsdk.New(a.agentURL, agentsdk.WithAzureInstanceIdentity()) // This is *only* done for testing to mock client authentication. // This will never be set in a production scenario. var azureClient *http.Client @@ -750,7 +757,7 @@ func (r *RootCmd) createAgentClient(ctx context.Context) (*agentsdk.Client, erro } return client, nil default: - return nil, xerrors.Errorf("unknown agent auth type: %s", r.agentAuth) + return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth) } } diff --git a/cli/testdata/coder_agent_--help.golden b/cli/testdata/coder_agent_--help.golden index 0541210c00824..1f25fc6941ea1 100644 --- a/cli/testdata/coder_agent_--help.golden +++ b/cli/testdata/coder_agent_--help.golden @@ -6,6 +6,18 @@ USAGE: Starts the Coder workspace agent. OPTIONS: + --auth string, $CODER_AGENT_AUTH (default: token) + Specify the authentication type to use for the agent. + + --agent-token string, $CODER_AGENT_TOKEN + An agent authentication token. + + --agent-token-file string, $CODER_AGENT_TOKEN_FILE + A file containing an agent authentication token. + + --agent-url url, $CODER_AGENT_URL + URL for an agent to access your deployment. + --log-human string, $CODER_AGENT_LOGGING_HUMAN (default: /dev/stderr) Output human-readable logs to a given file. diff --git a/cli/testdata/coder_external-auth_access-token_--help.golden b/cli/testdata/coder_external-auth_access-token_--help.golden index e4693a6fb9a6d..234cca5d4f917 100644 --- a/cli/testdata/coder_external-auth_access-token_--help.golden +++ b/cli/testdata/coder_external-auth_access-token_--help.golden @@ -25,6 +25,18 @@ USAGE: $ coder external-auth access-token slack --extra "authed_user.id" OPTIONS: + --auth string, $CODER_AGENT_AUTH (default: token) + Specify the authentication type to use for the agent. + + --agent-token string, $CODER_AGENT_TOKEN + An agent authentication token. + + --agent-token-file string, $CODER_AGENT_TOKEN_FILE + A file containing an agent authentication token. + + --agent-url url, $CODER_AGENT_URL + URL for an agent to access your deployment. + --extra string Extract a field from the "extra" properties of the OAuth token. diff --git a/coderd/externalauth_test.go b/coderd/externalauth_test.go index 2e648fd8d8024..68244bf3a49c4 100644 --- a/coderd/externalauth_test.go +++ b/coderd/externalauth_test.go @@ -432,7 +432,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) _, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com", }) @@ -463,7 +463,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) token, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", }) @@ -563,7 +563,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) resp := coderdtest.RequestExternalAuthCallback(t, "github", client) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) @@ -624,7 +624,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) token, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", @@ -670,7 +670,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) token, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", @@ -735,7 +735,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) token, err := agentClient.ExternalAuth(t.Context(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", diff --git a/coderd/gitsshkey_test.go b/coderd/gitsshkey_test.go index 448705c0fedbf..27f9121bd39b4 100644 --- a/coderd/gitsshkey_test.go +++ b/coderd/gitsshkey_test.go @@ -118,7 +118,7 @@ func TestAgentGitSSHKey(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, project.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() @@ -156,7 +156,7 @@ func TestAgentGitSSHKey_APIKeyScopes(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, project.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() diff --git a/coderd/insights_test.go b/coderd/insights_test.go index f4f3272a80261..99bf9b9a667b9 100644 --- a/coderd/insights_test.go +++ b/coderd/insights_test.go @@ -585,7 +585,7 @@ func TestTemplateInsights_Golden(t *testing.T) { continue } authToken := uuid.New() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken.String())) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken.String())) workspace.agentClient = agentClient var apps []*proto.App @@ -1493,7 +1493,7 @@ func TestUserActivityInsights_Golden(t *testing.T) { continue } authToken := uuid.New() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken.String())) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken.String())) workspace.agentClient = agentClient var apps []*proto.App diff --git a/coderd/prometheusmetrics/insights/metricscollector_test.go b/coderd/prometheusmetrics/insights/metricscollector_test.go index cbd082bdfd920..560a601992140 100644 --- a/coderd/prometheusmetrics/insights/metricscollector_test.go +++ b/coderd/prometheusmetrics/insights/metricscollector_test.go @@ -90,7 +90,7 @@ func TestCollectInsights(t *testing.T) { // Start an agent so that we can generate stats. var agentClients []agentproto.DRPCAgentClient for i, agent := range []database.WorkspaceAgent{agent1, agent2} { - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(agent.AuthToken.String())) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(agent.AuthToken.String())) agentClient.SDK.SetLogger(logger.Leveled(slog.LevelDebug).Named(fmt.Sprintf("agent%d", i+1))) conn, err := agentClient.ConnectRPC(context.Background()) require.NoError(t, err) diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index dd5b3f5db74cf..e75f86e51b55c 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -875,7 +875,7 @@ func prepareWorkspaceAndAgent(ctx context.Context, t *testing.T, client *codersd }) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - ac := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + ac := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) conn, err := ac.ConnectRPC(ctx) require.NoError(t, err) agentAPI := agentproto.NewDRPCAgentClient(conn) diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 92fbed2cf8421..e950f970755bb 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -228,7 +228,7 @@ func TestWorkspaceAgentLogs(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ Logs: []agentsdk.Log{ { @@ -268,7 +268,7 @@ func TestWorkspaceAgentLogs(t *testing.T) { OrganizationID: user.OrganizationID, OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ Logs: []agentsdk.Log{ { @@ -312,7 +312,7 @@ func TestWorkspaceAgentLogs(t *testing.T) { updates, err := client.WatchWorkspace(ctx, r.Workspace.ID) require.NoError(t, err) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) err = agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ Logs: []agentsdk.Log{{ CreatedAt: dbtime.Now(), @@ -357,7 +357,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) { return a }).Do() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) @@ -538,7 +538,7 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) { require.NoError(t, err) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, stopBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) _, err = agentClient.ConnectRPC(ctx) require.Error(t, err) @@ -563,7 +563,7 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) { ) require.NoError(t, err) // Then: the agent token should no longer be valid - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken((wsb.AgentToken))) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken((wsb.AgentToken))) _, err = agentClient.ConnectRPC(ctx) require.Error(t, err) var sdkErr *codersdk.Error @@ -884,7 +884,7 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) // Verify that the manifest has DisableDirectConnections set to true. - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) rpc, err := agentClient.ConnectRPC(ctx) require.NoError(t, err) defer func() { @@ -1735,7 +1735,7 @@ func TestWorkspaceAgentAppHealth(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) conn, err := agentClient.ConnectRPC(ctx) require.NoError(t, err) defer func() { @@ -1810,7 +1810,7 @@ func TestWorkspaceAgentPostLogSource(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) req := agentsdk.PostLogSourceRequest{ ID: uuid.New(), @@ -1858,7 +1858,7 @@ func TestWorkspaceAgent_LifecycleState(t *testing.T) { } } - ac := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + ac := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) conn, err := ac.ConnectRPC(ctx) require.NoError(t, err) defer func() { @@ -1955,7 +1955,7 @@ func TestWorkspaceAgent_Metadata(t *testing.T) { } } - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitMedium) conn, err := agentClient.ConnectRPC(ctx) @@ -2218,7 +2218,7 @@ func TestWorkspaceAgent_Metadata_CatchMemoryLeak(t *testing.T) { } } - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitSuperLong) conn, err := agentClient.ConnectRPC(ctx) @@ -2323,7 +2323,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) { OrganizationID: user.OrganizationID, OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitMedium) @@ -2369,7 +2369,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitMedium) @@ -2533,7 +2533,7 @@ func TestWorkspaceAgentExternalAuthListen(t *testing.T) { return agents }).Do() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) // We need to include an invalid oauth token that is not expired. dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{ @@ -3013,7 +3013,7 @@ func TestReinit(t *testing.T) { pubsubSpy.Unlock() agentCtx := testutil.Context(t, testutil.WaitShort) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent) go func() { diff --git a/coderd/workspaceagentsrpc_test.go b/coderd/workspaceagentsrpc_test.go index 221a04d4fcb68..525b8a981dbb5 100644 --- a/coderd/workspaceagentsrpc_test.go +++ b/coderd/workspaceagentsrpc_test.go @@ -68,7 +68,7 @@ func TestWorkspaceAgentReportStats(t *testing.T) { }, ).Do() - ac := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + ac := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) conn, err := ac.ConnectRPC(context.Background()) require.NoError(t, err) defer func() { @@ -154,7 +154,7 @@ func TestAgentAPI_LargeManifest(t *testing.T) { agents[0].ApiKeyScope = string(tc.apiKeyScope) return agents }).Do() - ac := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + ac := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) conn, err := ac.ConnectRPC(ctx) defer func() { _ = conn.Close() diff --git a/coderd/workspaceapps/apptest/setup.go b/coderd/workspaceapps/apptest/setup.go index ebef0375f6959..05bfb66219088 100644 --- a/coderd/workspaceapps/apptest/setup.go +++ b/coderd/workspaceapps/apptest/setup.go @@ -482,7 +482,7 @@ func createWorkspaceWithApps(t *testing.T, client *codersdk.Client, orgID uuid.U require.Equal(t, appURL.String(), app.SubdomainName) } - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) // TODO (@dean): currently, the primary app host is used when generating // the port URL we tell the agent to use. We don't have any plans to change diff --git a/coderd/workspaceresourceauth_test.go b/coderd/workspaceresourceauth_test.go index 53b2f33d47747..73524a63ade62 100644 --- a/coderd/workspaceresourceauth_test.go +++ b/coderd/workspaceresourceauth_test.go @@ -51,7 +51,7 @@ func TestPostWorkspaceAuthAzureInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := agentsdk.New(client.URL, agentsdk.UsingAzureInstanceIdentity()) + agentClient := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity()) agentClient.SDK.HTTPClient = metadataClient err := agentClient.RefreshToken(ctx) require.NoError(t, err) @@ -95,7 +95,7 @@ func TestPostWorkspaceAuthAWSInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := agentsdk.New(client.URL, agentsdk.UsingAWSInstanceIdentity()) + agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity()) agentClient.SDK.HTTPClient = metadataClient err := agentClient.RefreshToken(ctx) require.NoError(t, err) @@ -115,7 +115,7 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := agentsdk.New(client.URL, agentsdk.UsingGoogleInstanceIdentity("", metadata)) + agentClient := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata)) err := agentClient.RefreshToken(ctx) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) @@ -133,7 +133,7 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := agentsdk.New(client.URL, agentsdk.UsingGoogleInstanceIdentity("", metadata)) + agentClient := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata)) err := agentClient.RefreshToken(ctx) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) @@ -176,7 +176,7 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := agentsdk.New(client.URL, agentsdk.UsingGoogleInstanceIdentity("", metadata)) + agentClient := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata)) err := agentClient.RefreshToken(ctx) require.NoError(t, err) }) diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index a0739cac13956..d13f600a03e0a 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -414,7 +414,7 @@ func (FixedSessionTokenProvider) RefreshToken(_ context.Context) error { return nil } -func UsingFixedToken(token string) SessionTokenSetup { +func WithFixedToken(token string) SessionTokenSetup { return func(_ *codersdk.Client) RefreshableSessionTokenProvider { return FixedSessionTokenProvider{FixedSessionTokenProvider: codersdk.FixedSessionTokenProvider{SessionToken: token}} } diff --git a/codersdk/agentsdk/agentsdk_test.go b/codersdk/agentsdk/agentsdk_test.go index 748b6203aaea0..4f3d7d838b524 100644 --- a/codersdk/agentsdk/agentsdk_test.go +++ b/codersdk/agentsdk/agentsdk_test.go @@ -141,7 +141,7 @@ func TestRewriteDERPMap(t *testing.T) { } parsed, err := url.Parse("https://coconuts.org:44558") require.NoError(t, err) - client := agentsdk.New(parsed, agentsdk.UsingFixedToken("unused")) + client := agentsdk.New(parsed, agentsdk.WithFixedToken("unused")) client.RewriteDERPMap(dm) region := dm.Regions[1] require.True(t, region.EmbeddedRelay) diff --git a/codersdk/agentsdk/aws.go b/codersdk/agentsdk/aws.go index 8865ca2708cf9..b4f30ec4e95e5 100644 --- a/codersdk/agentsdk/aws.go +++ b/codersdk/agentsdk/aws.go @@ -22,7 +22,7 @@ type awsSessionTokenExchanger struct { client *codersdk.Client } -func UsingAWSInstanceIdentity() SessionTokenSetup { +func WithAWSInstanceIdentity() SessionTokenSetup { return func(client *codersdk.Client) RefreshableSessionTokenProvider { return &instanceIdentitySessionTokenProvider{ tokenExchanger: &awsSessionTokenExchanger{client: client}, diff --git a/codersdk/agentsdk/azure.go b/codersdk/agentsdk/azure.go index 32a5793f025e8..eb66e21097cf4 100644 --- a/codersdk/agentsdk/azure.go +++ b/codersdk/agentsdk/azure.go @@ -19,7 +19,7 @@ type azureSessionTokenExchanger struct { client *codersdk.Client } -func UsingAzureInstanceIdentity() SessionTokenSetup { +func WithAzureInstanceIdentity() SessionTokenSetup { return func(client *codersdk.Client) RefreshableSessionTokenProvider { return &instanceIdentitySessionTokenProvider{ tokenExchanger: &azureSessionTokenExchanger{client: client}, diff --git a/codersdk/agentsdk/google.go b/codersdk/agentsdk/google.go index e3cf12c5ecdc1..e462ba24049cc 100644 --- a/codersdk/agentsdk/google.go +++ b/codersdk/agentsdk/google.go @@ -24,7 +24,7 @@ type googleSessionTokenExchanger struct { client *codersdk.Client } -func UsingGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup { +func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup { return func(client *codersdk.Client) RefreshableSessionTokenProvider { return &instanceIdentitySessionTokenProvider{ tokenExchanger: &googleSessionTokenExchanger{ diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 299c6fe5f3519..6d4031e22ac49 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -75,7 +75,7 @@ func TestTools(t *testing.T) { }).Do() // Given: a client configured with the agent token. - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) // Get the agent ID from the API. Overriding it in dbfake doesn't work. ws, err := client.Workspace(setupCtx, r.Workspace.ID) require.NoError(t, err) diff --git a/docs/reference/cli/external-auth_access-token.md b/docs/reference/cli/external-auth_access-token.md index 2303e8f076da8..7fb022077ac9f 100644 --- a/docs/reference/cli/external-auth_access-token.md +++ b/docs/reference/cli/external-auth_access-token.md @@ -40,3 +40,40 @@ fi | Type | string | Extract a field from the "extra" properties of the OAuth token. + +### --agent-token + +| | | +|-------------|---------------------------------| +| Type | string | +| Environment | $CODER_AGENT_TOKEN | + +An agent authentication token. + +### --agent-token-file + +| | | +|-------------|--------------------------------------| +| Type | string | +| Environment | $CODER_AGENT_TOKEN_FILE | + +A file containing an agent authentication token. + +### --agent-url + +| | | +|-------------|-------------------------------| +| Type | url | +| Environment | $CODER_AGENT_URL | + +URL for an agent to access your deployment. + +### --auth + +| | | +|-------------|--------------------------------| +| Type | string | +| Environment | $CODER_AGENT_AUTH | +| Default | token | + +Specify the authentication type to use for the agent. diff --git a/enterprise/coderd/appearance_test.go b/enterprise/coderd/appearance_test.go index 0995f98f6d6ca..81ba7eddc7354 100644 --- a/enterprise/coderd/appearance_test.go +++ b/enterprise/coderd/appearance_test.go @@ -153,13 +153,13 @@ func TestAnnouncementBanners(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) banners := requireGetAnnouncementBanners(ctx, t, agentClient) require.Equal(t, cfg.AnnouncementBanners, banners) // Create an AGPL Coderd against the same database agplClient := coderdtest.New(t, &coderdtest.Options{Database: store, Pubsub: ps}) - agplAgentClient := agentsdk.New(agplClient.URL, agentsdk.UsingFixedToken(r.AgentToken)) + agplAgentClient := agentsdk.New(agplClient.URL, agentsdk.WithFixedToken(r.AgentToken)) banners = requireGetAnnouncementBanners(ctx, t, agplAgentClient) require.Equal(t, []codersdk.BannerConfig{}, banners) diff --git a/enterprise/coderd/gitsshkey_test.go b/enterprise/coderd/gitsshkey_test.go index 6b6d46f9f58a3..7045c8dd860fe 100644 --- a/enterprise/coderd/gitsshkey_test.go +++ b/enterprise/coderd/gitsshkey_test.go @@ -69,7 +69,7 @@ func TestAgentGitSSHKeyCustomRoles(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, project.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() diff --git a/enterprise/coderd/workspaceagents_test.go b/enterprise/coderd/workspaceagents_test.go index a308279295a58..917d44dff2d48 100644 --- a/enterprise/coderd/workspaceagents_test.go +++ b/enterprise/coderd/workspaceagents_test.go @@ -319,7 +319,7 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) agentClient.SDK.HTTPClient = &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ diff --git a/scaletest/createworkspaces/run_test.go b/scaletest/createworkspaces/run_test.go index ce832947177ff..edade6b79ed9a 100644 --- a/scaletest/createworkspaces/run_test.go +++ b/scaletest/createworkspaces/run_test.go @@ -561,7 +561,7 @@ func goEventuallyStartFakeAgent(ctx context.Context, t *testing.T, client *coder coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(agentToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(agentToken)) agentCloser := agent.New(agent.Options{ Client: agentClient, Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}). diff --git a/scaletest/workspacebuild/run_test.go b/scaletest/workspacebuild/run_test.go index 977de6bbd573f..f813019d0f6a0 100644 --- a/scaletest/workspacebuild/run_test.go +++ b/scaletest/workspacebuild/run_test.go @@ -134,7 +134,7 @@ func Test_Runner(t *testing.T) { for i, authToken := range []string{authToken1, authToken2, authToken3} { i := i + 1 - agentClient := agentsdk.New(client.URL, agentsdk.UsingFixedToken(authToken)) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) agentCloser := agent.New(agent.Options{ Client: agentClient, Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}). From 68c9194579bceada816cf167497706c6cbaef6eb Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 3 Sep 2025 06:14:58 +0000 Subject: [PATCH 5/5] make URL not a pointer in AgentAuth --- cli/agent.go | 4 ++-- cli/exp_mcp.go | 4 ++-- cli/externalauth.go | 2 +- cli/gitssh.go | 2 +- cli/root.go | 34 ++++++++++++---------------------- 5 files changed, 18 insertions(+), 28 deletions(-) diff --git a/cli/agent.go b/cli/agent.go index 342522ad057f7..2b8efad55bcfb 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -57,7 +57,7 @@ func workspaceAgent() *serpent.Command { devcontainerProjectDiscovery bool devcontainerDiscoveryAutostart bool ) - agentAuth := NewAgentAuth() + agentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "agent", Short: `Starts the Coder workspace agent.`, @@ -191,7 +191,7 @@ func workspaceAgent() *serpent.Command { client.SDK.HTTPClient.Timeout = 30 * time.Second // Attach header transport so we process --agent-header and // --agent-header-command flags - headerTransport, err := headerTransport(ctx, agentAuth.agentURL, agentHeader, agentHeaderCommand) + headerTransport, err := headerTransport(ctx, &agentAuth.agentURL, agentHeader, agentHeaderCommand) if err != nil { return xerrors.Errorf("configure header transport: %w", err) } diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index 59dab808a5472..8388a5a4c71ad 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -131,7 +131,7 @@ func mcpConfigureClaudeCode() *serpent.Command { deprecatedCoderMCPClaudeAPIKey string ) - agentAuth := NewAgentAuth() + agentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "claude-code ", Short: "Configure the Claude Code server. You will need to run this command for each project you want to use. Specify the project directory as the first argument.", @@ -405,7 +405,7 @@ func (r *RootCmd) mcpServer() *serpent.Command { appStatusSlug string aiAgentAPIURL url.URL ) - agentAuth := NewAgentAuth() + agentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "server", Handler: func(inv *serpent.Invocation) error { diff --git a/cli/externalauth.go b/cli/externalauth.go index 4bea37466063d..4aaa72c19759d 100644 --- a/cli/externalauth.go +++ b/cli/externalauth.go @@ -27,7 +27,7 @@ func externalAuth() *serpent.Command { func externalAuthAccessToken() *serpent.Command { var extra string - agentAuth := NewAgentAuth() + agentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "access-token ", Short: "Print auth for an external provider", diff --git a/cli/gitssh.go b/cli/gitssh.go index 4b42a85c27447..043049b7e8a97 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -19,7 +19,7 @@ import ( ) func gitssh() *serpent.Command { - agentAuth := NewAgentAuth() + agentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "gitssh", Hidden: true, diff --git a/cli/root.go b/cli/root.go index 2fbbadbdd16df..a18401e253038 100644 --- a/cli/root.go +++ b/cli/root.go @@ -60,10 +60,6 @@ var ( const ( varURL = "url" varToken = "token" - varAgentToken = "agent-token" - varAgentTokenFile = "agent-token-file" - varAgentURL = "agent-url" - varAgentAuth = "auth" varHeader = "header" varHeaderCommand = "header-command" varNoOpen = "no-open" @@ -201,7 +197,7 @@ func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) { func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, error) { fmtLong := `Coder %s — A tool for provisioning self-hosted development environments with Terraform. ` - hiddenAgentAuth := NewAgentAuth() + hiddenAgentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "coder [global-flags] ", Long: fmt.Sprintf(fmtLong, buildinfo.Version()) + FormatExamples( @@ -652,42 +648,36 @@ type AgentAuth struct { // Agent Client config agentToken string agentTokenFile string - agentURL *url.URL + agentURL url.URL agentAuth string } -func NewAgentAuth() *AgentAuth { - return &AgentAuth{ - agentURL: new(url.URL), - } -} - func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) { cmd.Options = append(cmd.Options, serpent.Option{ Name: "Agent Token", Description: "An agent authentication token.", - Flag: varAgentToken, + Flag: "agent-token", Env: envAgentToken, Value: serpent.StringOf(&a.agentToken), Hidden: hidden, }, serpent.Option{ Name: "Agent Token File", Description: "A file containing an agent authentication token.", - Flag: varAgentTokenFile, + Flag: "agent-token-file", Env: envAgentTokenFile, Value: serpent.StringOf(&a.agentTokenFile), Hidden: hidden, }, serpent.Option{ Name: "Agent URL", Description: "URL for an agent to access your deployment.", - Flag: varAgentURL, + Flag: "agent-url", Env: envAgentURL, - Value: serpent.URLOf(a.agentURL), + Value: serpent.URLOf(&a.agentURL), Hidden: hidden, }, serpent.Option{ Name: "Agent Auth", Description: "Specify the authentication type to use for the agent.", - Flag: varAgentAuth, + Flag: "auth", Env: envAgentAuth, Default: "token", Value: serpent.StringOf(&a.agentAuth), @@ -699,7 +689,7 @@ func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) { // just like InitClient, but uses the agent token and URL instead. func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) { agentURL := a.agentURL - if agentURL == nil || agentURL.String() == "" { + if agentURL.String() == "" { return nil, xerrors.Errorf("%s must be set", envAgentURL) } @@ -719,7 +709,7 @@ func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) if token == "" { return nil, xerrors.Errorf("CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE must be set for token auth") } - return agentsdk.New(a.agentURL, agentsdk.WithFixedToken(token)), nil + return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil case "google-instance-identity": // This is *only* done for testing to mock client authentication. @@ -729,9 +719,9 @@ func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) if gcpClientRaw != nil { gcpClient, _ = gcpClientRaw.(*metadata.Client) } - return agentsdk.New(a.agentURL, agentsdk.WithGoogleInstanceIdentity("", gcpClient)), nil + return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", gcpClient)), nil case "aws-instance-identity": - client := agentsdk.New(a.agentURL, agentsdk.WithAWSInstanceIdentity()) + client := agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()) // This is *only* done for testing to mock client authentication. // This will never be set in a production scenario. var awsClient *http.Client @@ -744,7 +734,7 @@ func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) } return client, nil case "azure-instance-identity": - client := agentsdk.New(a.agentURL, agentsdk.WithAzureInstanceIdentity()) + client := agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()) // This is *only* done for testing to mock client authentication. // This will never be set in a production scenario. var azureClient *http.Client