Skip to content

Commit 1c3bfac

Browse files
authored
fix(cli): ensure cliui.Agent doesn't fetch infinitely (#8446)
1 parent 14caa9b commit 1c3bfac

File tree

6 files changed

+72
-44
lines changed

6 files changed

+72
-44
lines changed

cli/cliui/agent.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@ var errAgentShuttingDown = xerrors.New("agent is shutting down")
1515

1616
type AgentOptions struct {
1717
FetchInterval time.Duration
18-
Fetch func(context.Context) (codersdk.WorkspaceAgent, error)
18+
Fetch func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error)
1919
FetchLogs func(ctx context.Context, agentID uuid.UUID, after int64, follow bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error)
2020
Wait bool // If true, wait for the agent to be ready (startup script).
2121
}
2222

2323
// Agent displays a spinning indicator that waits for a workspace agent to connect.
24-
func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
24+
func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentOptions) error {
25+
ctx, cancel := context.WithCancel(ctx)
26+
defer cancel()
27+
2528
if opts.FetchInterval == 0 {
2629
opts.FetchInterval = 500 * time.Millisecond
2730
}
@@ -47,7 +50,7 @@ func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
4750
case <-ctx.Done():
4851
return
4952
case <-t.C:
50-
agent, err := opts.Fetch(ctx)
53+
agent, err := opts.Fetch(ctx, agentID)
5154
select {
5255
case <-fetchedAgent:
5356
default:

cli/cliui/agent_test.go

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"context"
77
"io"
88
"strings"
9+
"sync/atomic"
910
"testing"
1011
"time"
1112

@@ -16,17 +17,14 @@ import (
1617
"github.com/coder/coder/cli/clibase"
1718
"github.com/coder/coder/cli/clitest"
1819
"github.com/coder/coder/cli/cliui"
20+
"github.com/coder/coder/coderd/util/ptr"
1921
"github.com/coder/coder/codersdk"
2022
"github.com/coder/coder/testutil"
2123
)
2224

2325
func TestAgent(t *testing.T) {
2426
t.Parallel()
2527

26-
ptrTime := func(t time.Time) *time.Time {
27-
return &t
28-
}
29-
3028
for _, tc := range []struct {
3129
name string
3230
iter []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error
@@ -47,7 +45,7 @@ func TestAgent(t *testing.T) {
4745
},
4846
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
4947
agent.Status = codersdk.WorkspaceAgentConnected
50-
agent.FirstConnectedAt = ptrTime(time.Now())
48+
agent.FirstConnectedAt = ptr.Ref(time.Now())
5149
close(logs)
5250
return nil
5351
},
@@ -69,7 +67,7 @@ func TestAgent(t *testing.T) {
6967
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
7068
agent.Status = codersdk.WorkspaceAgentConnecting
7169
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
72-
agent.StartedAt = ptrTime(time.Now())
70+
agent.StartedAt = ptr.Ref(time.Now())
7371
return nil
7472
},
7573
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
@@ -78,9 +76,9 @@ func TestAgent(t *testing.T) {
7876
},
7977
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
8078
agent.Status = codersdk.WorkspaceAgentConnected
81-
agent.FirstConnectedAt = ptrTime(time.Now())
79+
agent.FirstConnectedAt = ptr.Ref(time.Now())
8280
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
83-
agent.ReadyAt = ptrTime(time.Now())
81+
agent.ReadyAt = ptr.Ref(time.Now())
8482
close(logs)
8583
return nil
8684
},
@@ -102,17 +100,17 @@ func TestAgent(t *testing.T) {
102100
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
103101
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
104102
agent.Status = codersdk.WorkspaceAgentDisconnected
105-
agent.FirstConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute))
106-
agent.LastConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute))
107-
agent.DisconnectedAt = ptrTime(time.Now())
103+
agent.FirstConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
104+
agent.LastConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
105+
agent.DisconnectedAt = ptr.Ref(time.Now())
108106
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
109-
agent.StartedAt = ptrTime(time.Now().Add(-1 * time.Minute))
110-
agent.ReadyAt = ptrTime(time.Now())
107+
agent.StartedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
108+
agent.ReadyAt = ptr.Ref(time.Now())
111109
return nil
112110
},
113111
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
114112
agent.Status = codersdk.WorkspaceAgentConnected
115-
agent.LastConnectedAt = ptrTime(time.Now())
113+
agent.LastConnectedAt = ptr.Ref(time.Now())
116114
return nil
117115
},
118116
func(_ context.Context, _ *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
@@ -136,9 +134,9 @@ func TestAgent(t *testing.T) {
136134
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
137135
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
138136
agent.Status = codersdk.WorkspaceAgentConnected
139-
agent.FirstConnectedAt = ptrTime(time.Now())
137+
agent.FirstConnectedAt = ptr.Ref(time.Now())
140138
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
141-
agent.StartedAt = ptrTime(time.Now())
139+
agent.StartedAt = ptr.Ref(time.Now())
142140
logs <- []codersdk.WorkspaceAgentStartupLog{
143141
{
144142
CreatedAt: time.Now(),
@@ -149,7 +147,7 @@ func TestAgent(t *testing.T) {
149147
},
150148
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
151149
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
152-
agent.ReadyAt = ptrTime(time.Now())
150+
agent.ReadyAt = ptr.Ref(time.Now())
153151
logs <- []codersdk.WorkspaceAgentStartupLog{
154152
{
155153
CreatedAt: time.Now(),
@@ -176,10 +174,10 @@ func TestAgent(t *testing.T) {
176174
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
177175
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
178176
agent.Status = codersdk.WorkspaceAgentConnected
179-
agent.FirstConnectedAt = ptrTime(time.Now())
180-
agent.StartedAt = ptrTime(time.Now())
177+
agent.FirstConnectedAt = ptr.Ref(time.Now())
178+
agent.StartedAt = ptr.Ref(time.Now())
181179
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStartError
182-
agent.ReadyAt = ptrTime(time.Now())
180+
agent.ReadyAt = ptr.Ref(time.Now())
183181
logs <- []codersdk.WorkspaceAgentStartupLog{
184182
{
185183
CreatedAt: time.Now(),
@@ -222,9 +220,9 @@ func TestAgent(t *testing.T) {
222220
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
223221
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
224222
agent.Status = codersdk.WorkspaceAgentConnected
225-
agent.FirstConnectedAt = ptrTime(time.Now())
223+
agent.FirstConnectedAt = ptr.Ref(time.Now())
226224
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
227-
agent.StartedAt = ptrTime(time.Now())
225+
agent.StartedAt = ptr.Ref(time.Now())
228226
logs <- []codersdk.WorkspaceAgentStartupLog{
229227
{
230228
CreatedAt: time.Now(),
@@ -234,7 +232,7 @@ func TestAgent(t *testing.T) {
234232
return nil
235233
},
236234
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
237-
agent.ReadyAt = ptrTime(time.Now())
235+
agent.ReadyAt = ptr.Ref(time.Now())
238236
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleShuttingDown
239237
close(logs)
240238
return nil
@@ -310,7 +308,7 @@ func TestAgent(t *testing.T) {
310308

311309
cmd := &clibase.Cmd{
312310
Handler: func(inv *clibase.Invocation) error {
313-
tc.opts.Fetch = func(_ context.Context) (codersdk.WorkspaceAgent, error) {
311+
tc.opts.Fetch = func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) {
314312
var err error
315313
if len(tc.iter) > 0 {
316314
err = tc.iter[0](ctx, &agent, logs)
@@ -321,7 +319,7 @@ func TestAgent(t *testing.T) {
321319
tc.opts.FetchLogs = func(_ context.Context, _ uuid.UUID, _ int64, _ bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error) {
322320
return logs, closeFunc(func() error { return nil }), nil
323321
}
324-
err := cliui.Agent(inv.Context(), &buf, tc.opts)
322+
err := cliui.Agent(inv.Context(), &buf, uuid.Nil, tc.opts)
325323
return err
326324
},
327325
}
@@ -350,4 +348,37 @@ func TestAgent(t *testing.T) {
350348
}
351349
})
352350
}
351+
352+
t.Run("NotInfinite", func(t *testing.T) {
353+
t.Parallel()
354+
var fetchCalled uint64
355+
356+
cmd := &clibase.Cmd{
357+
Handler: func(inv *clibase.Invocation) error {
358+
buf := bytes.Buffer{}
359+
err := cliui.Agent(inv.Context(), &buf, uuid.Nil, cliui.AgentOptions{
360+
FetchInterval: 10 * time.Millisecond,
361+
Fetch: func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error) {
362+
atomic.AddUint64(&fetchCalled, 1)
363+
364+
return codersdk.WorkspaceAgent{
365+
Status: codersdk.WorkspaceAgentConnected,
366+
LifecycleState: codersdk.WorkspaceAgentLifecycleReady,
367+
}, nil
368+
},
369+
})
370+
if err != nil {
371+
return err
372+
}
373+
374+
require.Never(t, func() bool {
375+
called := atomic.LoadUint64(&fetchCalled)
376+
return called > 5 || called == 0
377+
}, time.Second, 100*time.Millisecond)
378+
379+
return nil
380+
},
381+
}
382+
require.NoError(t, cmd.Invoke().Run())
383+
})
353384
}

cli/portforward.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,9 @@ func (r *RootCmd) portForward() *clibase.Cmd {
9090
}
9191
}
9292

93-
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
94-
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
95-
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
96-
},
97-
Wait: false,
93+
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
94+
Fetch: client.WorkspaceAgent,
95+
Wait: false,
9896
})
9997
if err != nil {
10098
return xerrors.Errorf("await agent: %w", err)

cli/speedtest.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,9 @@ func (r *RootCmd) speedtest() *clibase.Cmd {
4040
return err
4141
}
4242

43-
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
44-
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
45-
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
46-
},
47-
Wait: false,
43+
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
44+
Fetch: client.WorkspaceAgent,
45+
Wait: false,
4846
})
4947
if err != nil {
5048
return xerrors.Errorf("await agent: %w", err)

cli/ssh.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
175175

176176
// OpenSSH passes stderr directly to the calling TTY.
177177
// This is required in "stdio" mode so a connecting indicator can be displayed.
178-
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
179-
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
180-
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
181-
},
178+
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
179+
Fetch: client.WorkspaceAgent,
182180
FetchLogs: client.WorkspaceAgentStartupLogsAfter,
183181
Wait: wait,
184182
})

cmd/cliui/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,10 @@ func main() {
214214
agent.LastConnectedAt = &lastConnectedAt
215215
},
216216
}
217-
err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
217+
err := cliui.Agent(inv.Context(), inv.Stdout, uuid.Nil, cliui.AgentOptions{
218218
FetchInterval: 100 * time.Millisecond,
219219
Wait: true,
220-
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
220+
Fetch: func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) {
221221
if len(fetchSteps) == 0 {
222222
return agent, nil
223223
}

0 commit comments

Comments
 (0)