Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ type Options struct {
LogDir string
TempDir string
ScriptDataDir string
ExchangeToken func(ctx context.Context) (string, error)
Client Client
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Expand All @@ -99,6 +98,7 @@ type Client interface {
proto.DRPCAgentClient26, tailnetproto.DRPCTailnetClient26, error,
)
tailnet.DERPMapRewriter
agentsdk.RefreshableSessionTokenProvider
}

type Agent interface {
Expand Down Expand Up @@ -131,11 +131,6 @@ func New(options Options) Agent {
}
options.ScriptDataDir = options.TempDir
}
if options.ExchangeToken == nil {
options.ExchangeToken = func(_ context.Context) (string, error) {
return "", nil
}
}
if options.ReportMetadataInterval == 0 {
options.ReportMetadataInterval = time.Second
}
Expand Down Expand Up @@ -172,7 +167,6 @@ func New(options Options) Agent {
coordDisconnected: make(chan struct{}),
environmentVariables: options.EnvironmentVariables,
client: options.Client,
exchangeToken: options.ExchangeToken,
filesystem: options.Filesystem,
logDir: options.LogDir,
tempDir: options.TempDir,
Expand Down Expand Up @@ -203,7 +197,6 @@ func New(options Options) Agent {
// coordinator during shut down.
close(a.coordDisconnected)
a.announcementBanners.Store(new([]codersdk.BannerConfig))
a.sessionToken.Store(new(string))
a.init()
return a
}
Expand All @@ -212,7 +205,6 @@ type agent struct {
clock quartz.Clock
logger slog.Logger
client Client
exchangeToken func(ctx context.Context) (string, error)
tailnetListenPort uint16
filesystem afero.Fs
logDir string
Expand Down Expand Up @@ -254,7 +246,6 @@ type agent struct {
scriptRunner *agentscripts.Runner
announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated.
announcementBannersRefreshInterval time.Duration
sessionToken atomic.Pointer[string]
sshServer *agentssh.Server
sshMaxTimeout time.Duration
blockFileTransfer bool
Expand Down Expand Up @@ -916,11 +907,10 @@ func (a *agent) run() (retErr error) {
// This allows the agent to refresh its token if necessary.
// For instance identity this is required, since the instance
// may not have re-provisioned, but a new agent ID was created.
sessionToken, err := a.exchangeToken(a.hardCtx)
err := a.client.RefreshToken(a.hardCtx)
if err != nil {
return xerrors.Errorf("exchange token: %w", err)
return xerrors.Errorf("refresh token: %w", err)
}
a.sessionToken.Store(&sessionToken)

// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
aAPI, tAPI, err := a.client.ConnectRPC26(a.hardCtx)
Expand Down Expand Up @@ -1359,7 +1349,7 @@ func (a *agent) updateCommandEnv(current []string) (updated []string, err error)
"CODER_WORKSPACE_OWNER_NAME": manifest.OwnerName,

// Specific Coder subcommands require the agent token exposed!
"CODER_AGENT_TOKEN": *a.sessionToken.Load(),
"CODER_AGENT_TOKEN": a.client.GetSessionToken(),

// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
Expand Down
34 changes: 12 additions & 22 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"slices"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -2926,11 +2925,11 @@ func TestAgent_Speedtest(t *testing.T) {

func TestAgent_Reconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
// After the agent is disconnected from a coordinator, it's supposed
// to reconnect!
coordinator := tailnet.NewCoordinator(logger)
defer coordinator.Close()
fCoordinator := tailnettest.NewFakeCoordinator()

agentID := uuid.New()
statsCh := make(chan *proto.Stats, 50)
Expand All @@ -2942,27 +2941,24 @@ func TestAgent_Reconnect(t *testing.T) {
DERPMap: derpMap,
},
statsCh,
coordinator,
fCoordinator,
)
defer client.Close()
initialized := atomic.Int32{}

closer := agent.New(agent.Options{
ExchangeToken: func(ctx context.Context) (string, error) {
initialized.Add(1)
return "", nil
},
Client: client,
Logger: logger.Named("agent"),
})
defer closer.Close()

require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
client.LastWorkspaceAgent()
require.Eventually(t, func() bool {
return initialized.Load() == 2
}, testutil.WaitShort, testutil.IntervalFast)
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
require.Equal(t, client.GetNumRefreshTokenCalls(), 1)
close(call1.Resps) // hang up
// expect reconnect
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
// Check that the agent refreshes the token when it reconnects.
require.Equal(t, client.GetNumRefreshTokenCalls(), 2)
closer.Close()
}

func TestAgent_WriteVSCodeConfigs(t *testing.T) {
Expand All @@ -2984,9 +2980,6 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
defer client.Close()
filesystem := afero.NewMemMapFs()
closer := agent.New(agent.Options{
ExchangeToken: func(ctx context.Context) (string, error) {
return "", nil
},
Client: client,
Logger: logger.Named("agent"),
Filesystem: filesystem,
Expand Down Expand Up @@ -3015,9 +3008,6 @@ func TestAgent_DebugServer(t *testing.T) {
conn, _, _, _, agnt := setupAgent(t, agentsdk.Manifest{
DERPMap: derpMap,
}, 0, func(c *agenttest.Client, o *agent.Options) {
o.ExchangeToken = func(context.Context) (string, error) {
return "token", nil
}
o.LogDir = logDir
})

Expand Down
10 changes: 1 addition & 9 deletions agent/agenttest/agent.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package agenttest

import (
"context"
"net/url"
"testing"

Expand Down Expand Up @@ -31,18 +30,11 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent
}

if o.Client == nil {
agentClient := agentsdk.New(coderURL)
agentClient.SetSessionToken(agentToken)
agentClient := agentsdk.New(coderURL, agentsdk.WithFixedToken(agentToken))
agentClient.SDK.SetLogger(log)
o.Client = agentClient
}

if o.ExchangeToken == nil {
o.ExchangeToken = func(_ context.Context) (string, error) {
return agentToken, nil
}
}

if o.LogDir == "" {
o.LogDir = t.TempDir()
}
Expand Down
34 changes: 30 additions & 4 deletions agent/agenttest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package agenttest
import (
"context"
"io"
"net/http"
"slices"
"sync"
"sync/atomic"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
)

const statsInterval = 500 * time.Millisecond
Expand Down Expand Up @@ -86,10 +88,34 @@ type Client struct {
fakeAgentAPI *FakeAgentAPI
LastWorkspaceAgent func()

mu sync.Mutex // Protects following.
logs []agentsdk.Log
derpMapUpdates chan *tailcfg.DERPMap
derpMapOnce sync.Once
mu sync.Mutex // Protects following.
logs []agentsdk.Log
derpMapUpdates chan *tailcfg.DERPMap
derpMapOnce sync.Once
refreshTokenCalls int
}

func (*Client) AsRequestOption() codersdk.RequestOption {
return func(_ *http.Request) {}
}

func (*Client) SetDialOption(*websocket.DialOptions) {}

func (*Client) GetSessionToken() string {
return "agenttest-token"
}

func (c *Client) RefreshToken(context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
c.refreshTokenCalls++
return nil
}

func (c *Client) GetNumRefreshTokenCalls() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.refreshTokenCalls
}

func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {}
Expand Down
Loading
Loading