Skip to content

Commit 85f927b

Browse files
committed
chore: refactor instance identity to be a SessionTokenProvider
1 parent 53e1b76 commit 85f927b

33 files changed

+488
-401
lines changed

agent/agent.go

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ type Options struct {
7474
LogDir string
7575
TempDir string
7676
ScriptDataDir string
77-
ExchangeToken func(ctx context.Context) (string, error)
7877
Client Client
7978
ReconnectingPTYTimeout time.Duration
8079
EnvironmentVariables map[string]string
@@ -99,6 +98,7 @@ type Client interface {
9998
proto.DRPCAgentClient26, tailnetproto.DRPCTailnetClient26, error,
10099
)
101100
tailnet.DERPMapRewriter
101+
agentsdk.RefreshableSessionTokenProvider
102102
}
103103

104104
type Agent interface {
@@ -131,11 +131,6 @@ func New(options Options) Agent {
131131
}
132132
options.ScriptDataDir = options.TempDir
133133
}
134-
if options.ExchangeToken == nil {
135-
options.ExchangeToken = func(_ context.Context) (string, error) {
136-
return "", nil
137-
}
138-
}
139134
if options.ReportMetadataInterval == 0 {
140135
options.ReportMetadataInterval = time.Second
141136
}
@@ -172,7 +167,6 @@ func New(options Options) Agent {
172167
coordDisconnected: make(chan struct{}),
173168
environmentVariables: options.EnvironmentVariables,
174169
client: options.Client,
175-
exchangeToken: options.ExchangeToken,
176170
filesystem: options.Filesystem,
177171
logDir: options.LogDir,
178172
tempDir: options.TempDir,
@@ -203,7 +197,6 @@ func New(options Options) Agent {
203197
// coordinator during shut down.
204198
close(a.coordDisconnected)
205199
a.announcementBanners.Store(new([]codersdk.BannerConfig))
206-
a.sessionToken.Store(new(string))
207200
a.init()
208201
return a
209202
}
@@ -212,7 +205,6 @@ type agent struct {
212205
clock quartz.Clock
213206
logger slog.Logger
214207
client Client
215-
exchangeToken func(ctx context.Context) (string, error)
216208
tailnetListenPort uint16
217209
filesystem afero.Fs
218210
logDir string
@@ -254,7 +246,6 @@ type agent struct {
254246
scriptRunner *agentscripts.Runner
255247
announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated.
256248
announcementBannersRefreshInterval time.Duration
257-
sessionToken atomic.Pointer[string]
258249
sshServer *agentssh.Server
259250
sshMaxTimeout time.Duration
260251
blockFileTransfer bool
@@ -916,11 +907,10 @@ func (a *agent) run() (retErr error) {
916907
// This allows the agent to refresh its token if necessary.
917908
// For instance identity this is required, since the instance
918909
// may not have re-provisioned, but a new agent ID was created.
919-
sessionToken, err := a.exchangeToken(a.hardCtx)
910+
err := a.client.RefreshToken(a.hardCtx)
920911
if err != nil {
921-
return xerrors.Errorf("exchange token: %w", err)
912+
return xerrors.Errorf("refresh token: %w", err)
922913
}
923-
a.sessionToken.Store(&sessionToken)
924914

925915
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
926916
aAPI, tAPI, err := a.client.ConnectRPC26(a.hardCtx)
@@ -1359,7 +1349,7 @@ func (a *agent) updateCommandEnv(current []string) (updated []string, err error)
13591349
"CODER_WORKSPACE_OWNER_NAME": manifest.OwnerName,
13601350

13611351
// Specific Coder subcommands require the agent token exposed!
1362-
"CODER_AGENT_TOKEN": *a.sessionToken.Load(),
1352+
"CODER_AGENT_TOKEN": a.client.GetSessionToken(),
13631353

13641354
// Git on Windows resolves with UNIX-style paths.
13651355
// If using backslashes, it's unable to find the executable.

agent/agent_test.go

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"slices"
2323
"strconv"
2424
"strings"
25-
"sync/atomic"
2625
"testing"
2726
"time"
2827

@@ -2926,11 +2925,11 @@ func TestAgent_Speedtest(t *testing.T) {
29262925

29272926
func TestAgent_Reconnect(t *testing.T) {
29282927
t.Parallel()
2928+
ctx := testutil.Context(t, testutil.WaitShort)
29292929
logger := testutil.Logger(t)
29302930
// After the agent is disconnected from a coordinator, it's supposed
29312931
// to reconnect!
2932-
coordinator := tailnet.NewCoordinator(logger)
2933-
defer coordinator.Close()
2932+
fCoordinator := tailnettest.NewFakeCoordinator()
29342933

29352934
agentID := uuid.New()
29362935
statsCh := make(chan *proto.Stats, 50)
@@ -2942,27 +2941,21 @@ func TestAgent_Reconnect(t *testing.T) {
29422941
DERPMap: derpMap,
29432942
},
29442943
statsCh,
2945-
coordinator,
2944+
fCoordinator,
29462945
)
29472946
defer client.Close()
2948-
initialized := atomic.Int32{}
2947+
29492948
closer := agent.New(agent.Options{
2950-
ExchangeToken: func(ctx context.Context) (string, error) {
2951-
initialized.Add(1)
2952-
return "", nil
2953-
},
29542949
Client: client,
29552950
Logger: logger.Named("agent"),
29562951
})
29572952
defer closer.Close()
29582953

2959-
require.Eventually(t, func() bool {
2960-
return coordinator.Node(agentID) != nil
2961-
}, testutil.WaitShort, testutil.IntervalFast)
2962-
client.LastWorkspaceAgent()
2963-
require.Eventually(t, func() bool {
2964-
return initialized.Load() == 2
2965-
}, testutil.WaitShort, testutil.IntervalFast)
2954+
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
2955+
close(call1.Resps) // hang up
2956+
// expect reconnect
2957+
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
2958+
closer.Close()
29662959
}
29672960

29682961
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
@@ -2984,9 +2977,6 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
29842977
defer client.Close()
29852978
filesystem := afero.NewMemMapFs()
29862979
closer := agent.New(agent.Options{
2987-
ExchangeToken: func(ctx context.Context) (string, error) {
2988-
return "", nil
2989-
},
29902980
Client: client,
29912981
Logger: logger.Named("agent"),
29922982
Filesystem: filesystem,
@@ -3015,9 +3005,6 @@ func TestAgent_DebugServer(t *testing.T) {
30153005
conn, _, _, _, agnt := setupAgent(t, agentsdk.Manifest{
30163006
DERPMap: derpMap,
30173007
}, 0, func(c *agenttest.Client, o *agent.Options) {
3018-
o.ExchangeToken = func(context.Context) (string, error) {
3019-
return "token", nil
3020-
}
30213008
o.LogDir = logDir
30223009
})
30233010

agent/agenttest/agent.go

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package agenttest
22

33
import (
4-
"context"
54
"net/url"
65
"testing"
76

@@ -31,18 +30,11 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent
3130
}
3231

3332
if o.Client == nil {
34-
agentClient := agentsdk.New(coderURL)
35-
agentClient.SetSessionToken(agentToken)
33+
agentClient := agentsdk.New(coderURL, agentsdk.UsingFixedToken(agentToken))
3634
agentClient.SDK.SetLogger(log)
3735
o.Client = agentClient
3836
}
3937

40-
if o.ExchangeToken == nil {
41-
o.ExchangeToken = func(_ context.Context) (string, error) {
42-
return agentToken, nil
43-
}
44-
}
45-
4638
if o.LogDir == "" {
4739
o.LogDir = t.TempDir()
4840
}

agent/agenttest/client.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package agenttest
33
import (
44
"context"
55
"io"
6+
"net/http"
67
"slices"
78
"sync"
89
"sync/atomic"
@@ -28,6 +29,7 @@ import (
2829
"github.com/coder/coder/v2/tailnet"
2930
"github.com/coder/coder/v2/tailnet/proto"
3031
"github.com/coder/coder/v2/testutil"
32+
"github.com/coder/websocket"
3133
)
3234

3335
const statsInterval = 500 * time.Millisecond
@@ -92,6 +94,20 @@ type Client struct {
9294
derpMapOnce sync.Once
9395
}
9496

97+
func (*Client) AsRequestOption() codersdk.RequestOption {
98+
return func(_ *http.Request) {}
99+
}
100+
101+
func (*Client) SetDialOption(*websocket.DialOptions) {}
102+
103+
func (*Client) GetSessionToken() string {
104+
return "agenttest-token"
105+
}
106+
107+
func (*Client) RefreshToken(context.Context) error {
108+
return nil
109+
}
110+
95111
func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {}
96112

97113
func (c *Client) Close() {

cli/agent.go

Lines changed: 6 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"strings"
1616
"time"
1717

18-
"cloud.google.com/go/compute/metadata"
1918
"golang.org/x/xerrors"
2019
"gopkg.in/natefinch/lumberjack.v2"
2120

@@ -40,7 +39,6 @@ import (
4039

4140
func (r *RootCmd) workspaceAgent() *serpent.Command {
4241
var (
43-
auth string
4442
logDir string
4543
scriptDataDir string
4644
pprofAddress string
@@ -177,11 +175,13 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
177175
version := buildinfo.Version()
178176
logger.Info(ctx, "agent is starting now",
179177
slog.F("url", r.agentURL),
180-
slog.F("auth", auth),
178+
slog.F("auth", r.agentAuth),
181179
slog.F("version", version),
182180
)
183-
184-
client := agentsdk.New(r.agentURL)
181+
client, err := r.createAgentClient(ctx)
182+
if err != nil {
183+
return xerrors.Errorf("create agent client: %w", err)
184+
}
185185
client.SDK.SetLogger(logger)
186186
// Set a reasonable timeout so requests can't hang forever!
187187
// The timeout needs to be reasonably long, because requests
@@ -214,68 +214,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
214214
ignorePorts[port] = "debug"
215215
}
216216

217-
// exchangeToken returns a session token.
218-
// This is abstracted to allow for the same looping condition
219-
// regardless of instance identity auth type.
220-
var exchangeToken func(context.Context) (agentsdk.AuthenticateResponse, error)
221-
switch auth {
222-
case "token":
223-
token, _ := inv.ParsedFlags().GetString(varAgentToken)
224-
if token == "" {
225-
tokenFile, _ := inv.ParsedFlags().GetString(varAgentTokenFile)
226-
if tokenFile != "" {
227-
tokenBytes, err := os.ReadFile(tokenFile)
228-
if err != nil {
229-
return xerrors.Errorf("read token file %q: %w", tokenFile, err)
230-
}
231-
token = strings.TrimSpace(string(tokenBytes))
232-
}
233-
}
234-
if token == "" {
235-
return xerrors.Errorf("CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE must be set for token auth")
236-
}
237-
client.SetSessionToken(token)
238-
case "google-instance-identity":
239-
// This is *only* done for testing to mock client authentication.
240-
// This will never be set in a production scenario.
241-
var gcpClient *metadata.Client
242-
gcpClientRaw := ctx.Value("gcp-client")
243-
if gcpClientRaw != nil {
244-
gcpClient, _ = gcpClientRaw.(*metadata.Client)
245-
}
246-
exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
247-
return client.AuthGoogleInstanceIdentity(ctx, "", gcpClient)
248-
}
249-
case "aws-instance-identity":
250-
// This is *only* done for testing to mock client authentication.
251-
// This will never be set in a production scenario.
252-
var awsClient *http.Client
253-
awsClientRaw := ctx.Value("aws-client")
254-
if awsClientRaw != nil {
255-
awsClient, _ = awsClientRaw.(*http.Client)
256-
if awsClient != nil {
257-
client.SDK.HTTPClient = awsClient
258-
}
259-
}
260-
exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
261-
return client.AuthAWSInstanceIdentity(ctx)
262-
}
263-
case "azure-instance-identity":
264-
// This is *only* done for testing to mock client authentication.
265-
// This will never be set in a production scenario.
266-
var azureClient *http.Client
267-
azureClientRaw := ctx.Value("azure-client")
268-
if azureClientRaw != nil {
269-
azureClient, _ = azureClientRaw.(*http.Client)
270-
if azureClient != nil {
271-
client.SDK.HTTPClient = azureClient
272-
}
273-
}
274-
exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
275-
return client.AuthAzureInstanceIdentity(ctx)
276-
}
277-
}
278-
279217
executablePath, err := os.Executable()
280218
if err != nil {
281219
return xerrors.Errorf("getting os executable: %w", err)
@@ -343,18 +281,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
343281
LogDir: logDir,
344282
ScriptDataDir: scriptDataDir,
345283
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
346-
TailnetListenPort: uint16(tailnetListenPort),
347-
ExchangeToken: func(ctx context.Context) (string, error) {
348-
if exchangeToken == nil {
349-
return client.SDK.SessionToken(), nil
350-
}
351-
resp, err := exchangeToken(ctx)
352-
if err != nil {
353-
return "", err
354-
}
355-
client.SetSessionToken(resp.SessionToken)
356-
return resp.SessionToken, nil
357-
},
284+
TailnetListenPort: uint16(tailnetListenPort),
358285
EnvironmentVariables: environmentVariables,
359286
IgnorePorts: ignorePorts,
360287
SSHMaxTimeout: sshMaxTimeout,
@@ -400,13 +327,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
400327
}
401328

402329
cmd.Options = serpent.OptionSet{
403-
{
404-
Flag: "auth",
405-
Default: "token",
406-
Description: "Specify the authentication type to use for the agent.",
407-
Env: "CODER_AGENT_AUTH",
408-
Value: serpent.StringOf(&auth),
409-
},
410330
{
411331
Flag: "log-dir",
412332
Default: os.TempDir(),

cli/exp_mcp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command {
148148
binPath = testBinaryName
149149
}
150150
configureClaudeEnv := map[string]string{}
151-
agentClient, err := r.createAgentClient()
151+
agentClient, err := r.createAgentClient(inv.Context())
152152
if err != nil {
153153
cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err)
154154
} else {
@@ -494,7 +494,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
494494
}
495495

496496
// Try to create an agent client for status reporting. Not validated.
497-
agentClient, err := r.createAgentClient()
497+
agentClient, err := r.createAgentClient(inv.Context())
498498
if err == nil {
499499
cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String())
500500
srv.agentClient = agentClient

cli/externalauth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ fi
7575
return xerrors.Errorf("agent token not found")
7676
}
7777

78-
client, err := r.tryCreateAgentClient()
78+
client, err := r.createAgentClient(ctx)
7979
if err != nil {
8080
return xerrors.Errorf("create agent client: %w", err)
8181
}

cli/gitaskpass.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func (r *RootCmd) gitAskpass() *serpent.Command {
3333
return xerrors.Errorf("parse host: %w", err)
3434
}
3535

36-
client, err := r.tryCreateAgentClient()
36+
client, err := r.createAgentClient(ctx)
3737
if err != nil {
3838
return xerrors.Errorf("create agent client: %w", err)
3939
}

0 commit comments

Comments
 (0)