Skip to content

Commit 1354d84

Browse files
authored
chore: refactor instance identity to be a SessionTokenProvider (#19566)
Refactors Agent instance identity to be a SessionTokenProvider. Refactors the CLI to create Agent clients via a centralized function, rather than ad hoc in individual command handlers and their flags. This allows commands other than `coder agent` that still use the agent identity to support instance identity authentication. Fixes #19111 by unifying all API requests to go through the SessionTokenProvider for auth credentials.
1 parent ee35ad3 commit 1354d84

35 files changed

+632
-470
lines changed

agent/agent.go

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ type Options struct {
7474
LogDir string
7575
TempDir string
7676
ScriptDataDir string
77-
ExchangeToken func(ctx context.Context) (string, error)
7877
Client Client
7978
ReconnectingPTYTimeout time.Duration
8079
EnvironmentVariables map[string]string
@@ -99,6 +98,7 @@ type Client interface {
9998
proto.DRPCAgentClient26, tailnetproto.DRPCTailnetClient26, error,
10099
)
101100
tailnet.DERPMapRewriter
101+
agentsdk.RefreshableSessionTokenProvider
102102
}
103103

104104
type Agent interface {
@@ -131,11 +131,6 @@ func New(options Options) Agent {
131131
}
132132
options.ScriptDataDir = options.TempDir
133133
}
134-
if options.ExchangeToken == nil {
135-
options.ExchangeToken = func(_ context.Context) (string, error) {
136-
return "", nil
137-
}
138-
}
139134
if options.ReportMetadataInterval == 0 {
140135
options.ReportMetadataInterval = time.Second
141136
}
@@ -172,7 +167,6 @@ func New(options Options) Agent {
172167
coordDisconnected: make(chan struct{}),
173168
environmentVariables: options.EnvironmentVariables,
174169
client: options.Client,
175-
exchangeToken: options.ExchangeToken,
176170
filesystem: options.Filesystem,
177171
logDir: options.LogDir,
178172
tempDir: options.TempDir,
@@ -203,7 +197,6 @@ func New(options Options) Agent {
203197
// coordinator during shut down.
204198
close(a.coordDisconnected)
205199
a.announcementBanners.Store(new([]codersdk.BannerConfig))
206-
a.sessionToken.Store(new(string))
207200
a.init()
208201
return a
209202
}
@@ -212,7 +205,6 @@ type agent struct {
212205
clock quartz.Clock
213206
logger slog.Logger
214207
client Client
215-
exchangeToken func(ctx context.Context) (string, error)
216208
tailnetListenPort uint16
217209
filesystem afero.Fs
218210
logDir string
@@ -254,7 +246,6 @@ type agent struct {
254246
scriptRunner *agentscripts.Runner
255247
announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated.
256248
announcementBannersRefreshInterval time.Duration
257-
sessionToken atomic.Pointer[string]
258249
sshServer *agentssh.Server
259250
sshMaxTimeout time.Duration
260251
blockFileTransfer bool
@@ -916,11 +907,10 @@ func (a *agent) run() (retErr error) {
916907
// This allows the agent to refresh its token if necessary.
917908
// For instance identity this is required, since the instance
918909
// may not have re-provisioned, but a new agent ID was created.
919-
sessionToken, err := a.exchangeToken(a.hardCtx)
910+
err := a.client.RefreshToken(a.hardCtx)
920911
if err != nil {
921-
return xerrors.Errorf("exchange token: %w", err)
912+
return xerrors.Errorf("refresh token: %w", err)
922913
}
923-
a.sessionToken.Store(&sessionToken)
924914

925915
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
926916
aAPI, tAPI, err := a.client.ConnectRPC26(a.hardCtx)
@@ -1359,7 +1349,7 @@ func (a *agent) updateCommandEnv(current []string) (updated []string, err error)
13591349
"CODER_WORKSPACE_OWNER_NAME": manifest.OwnerName,
13601350

13611351
// Specific Coder subcommands require the agent token exposed!
1362-
"CODER_AGENT_TOKEN": *a.sessionToken.Load(),
1352+
"CODER_AGENT_TOKEN": a.client.GetSessionToken(),
13631353

13641354
// Git on Windows resolves with UNIX-style paths.
13651355
// If using backslashes, it's unable to find the executable.

agent/agent_test.go

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"slices"
2323
"strconv"
2424
"strings"
25-
"sync/atomic"
2625
"testing"
2726
"time"
2827

@@ -2926,11 +2925,11 @@ func TestAgent_Speedtest(t *testing.T) {
29262925

29272926
func TestAgent_Reconnect(t *testing.T) {
29282927
t.Parallel()
2928+
ctx := testutil.Context(t, testutil.WaitShort)
29292929
logger := testutil.Logger(t)
29302930
// After the agent is disconnected from a coordinator, it's supposed
29312931
// to reconnect!
2932-
coordinator := tailnet.NewCoordinator(logger)
2933-
defer coordinator.Close()
2932+
fCoordinator := tailnettest.NewFakeCoordinator()
29342933

29352934
agentID := uuid.New()
29362935
statsCh := make(chan *proto.Stats, 50)
@@ -2942,27 +2941,24 @@ func TestAgent_Reconnect(t *testing.T) {
29422941
DERPMap: derpMap,
29432942
},
29442943
statsCh,
2945-
coordinator,
2944+
fCoordinator,
29462945
)
29472946
defer client.Close()
2948-
initialized := atomic.Int32{}
2947+
29492948
closer := agent.New(agent.Options{
2950-
ExchangeToken: func(ctx context.Context) (string, error) {
2951-
initialized.Add(1)
2952-
return "", nil
2953-
},
29542949
Client: client,
29552950
Logger: logger.Named("agent"),
29562951
})
29572952
defer closer.Close()
29582953

2959-
require.Eventually(t, func() bool {
2960-
return coordinator.Node(agentID) != nil
2961-
}, testutil.WaitShort, testutil.IntervalFast)
2962-
client.LastWorkspaceAgent()
2963-
require.Eventually(t, func() bool {
2964-
return initialized.Load() == 2
2965-
}, testutil.WaitShort, testutil.IntervalFast)
2954+
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
2955+
require.Equal(t, client.GetNumRefreshTokenCalls(), 1)
2956+
close(call1.Resps) // hang up
2957+
// expect reconnect
2958+
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
2959+
// Check that the agent refreshes the token when it reconnects.
2960+
require.Equal(t, client.GetNumRefreshTokenCalls(), 2)
2961+
closer.Close()
29662962
}
29672963

29682964
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
@@ -2984,9 +2980,6 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
29842980
defer client.Close()
29852981
filesystem := afero.NewMemMapFs()
29862982
closer := agent.New(agent.Options{
2987-
ExchangeToken: func(ctx context.Context) (string, error) {
2988-
return "", nil
2989-
},
29902983
Client: client,
29912984
Logger: logger.Named("agent"),
29922985
Filesystem: filesystem,
@@ -3015,9 +3008,6 @@ func TestAgent_DebugServer(t *testing.T) {
30153008
conn, _, _, _, agnt := setupAgent(t, agentsdk.Manifest{
30163009
DERPMap: derpMap,
30173010
}, 0, func(c *agenttest.Client, o *agent.Options) {
3018-
o.ExchangeToken = func(context.Context) (string, error) {
3019-
return "token", nil
3020-
}
30213011
o.LogDir = logDir
30223012
})
30233013

agent/agenttest/agent.go

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package agenttest
22

33
import (
4-
"context"
54
"net/url"
65
"testing"
76

@@ -31,18 +30,11 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent
3130
}
3231

3332
if o.Client == nil {
34-
agentClient := agentsdk.New(coderURL)
35-
agentClient.SetSessionToken(agentToken)
33+
agentClient := agentsdk.New(coderURL, agentsdk.WithFixedToken(agentToken))
3634
agentClient.SDK.SetLogger(log)
3735
o.Client = agentClient
3836
}
3937

40-
if o.ExchangeToken == nil {
41-
o.ExchangeToken = func(_ context.Context) (string, error) {
42-
return agentToken, nil
43-
}
44-
}
45-
4638
if o.LogDir == "" {
4739
o.LogDir = t.TempDir()
4840
}

agent/agenttest/client.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package agenttest
33
import (
44
"context"
55
"io"
6+
"net/http"
67
"slices"
78
"sync"
89
"sync/atomic"
@@ -28,6 +29,7 @@ import (
2829
"github.com/coder/coder/v2/tailnet"
2930
"github.com/coder/coder/v2/tailnet/proto"
3031
"github.com/coder/coder/v2/testutil"
32+
"github.com/coder/websocket"
3133
)
3234

3335
const statsInterval = 500 * time.Millisecond
@@ -86,10 +88,34 @@ type Client struct {
8688
fakeAgentAPI *FakeAgentAPI
8789
LastWorkspaceAgent func()
8890

89-
mu sync.Mutex // Protects following.
90-
logs []agentsdk.Log
91-
derpMapUpdates chan *tailcfg.DERPMap
92-
derpMapOnce sync.Once
91+
mu sync.Mutex // Protects following.
92+
logs []agentsdk.Log
93+
derpMapUpdates chan *tailcfg.DERPMap
94+
derpMapOnce sync.Once
95+
refreshTokenCalls int
96+
}
97+
98+
func (*Client) AsRequestOption() codersdk.RequestOption {
99+
return func(_ *http.Request) {}
100+
}
101+
102+
func (*Client) SetDialOption(*websocket.DialOptions) {}
103+
104+
func (*Client) GetSessionToken() string {
105+
return "agenttest-token"
106+
}
107+
108+
func (c *Client) RefreshToken(context.Context) error {
109+
c.mu.Lock()
110+
defer c.mu.Unlock()
111+
c.refreshTokenCalls++
112+
return nil
113+
}
114+
115+
func (c *Client) GetNumRefreshTokenCalls() int {
116+
c.mu.Lock()
117+
defer c.mu.Unlock()
118+
return c.refreshTokenCalls
93119
}
94120

95121
func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {}

0 commit comments

Comments
 (0)