Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions cli/cliui/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ var errAgentShuttingDown = xerrors.New("agent is shutting down")

type AgentOptions struct {
FetchInterval time.Duration
Fetch func(context.Context) (codersdk.WorkspaceAgent, error)
Fetch func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm glad you made this change too (to pass agent ID)!

FetchLogs func(ctx context.Context, agentID uuid.UUID, after int64, follow bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error)
Wait bool // If true, wait for the agent to be ready (startup script).
}

// Agent displays a spinning indicator that waits for a workspace agent to connect.
func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentOptions) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

if opts.FetchInterval == 0 {
opts.FetchInterval = 500 * time.Millisecond
}
Expand All @@ -47,7 +50,7 @@ func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
case <-ctx.Done():
return
case <-t.C:
agent, err := opts.Fetch(ctx)
agent, err := opts.Fetch(ctx, agentID)
select {
case <-fetchedAgent:
default:
Expand Down
81 changes: 56 additions & 25 deletions cli/cliui/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"io"
"strings"
"sync/atomic"
"testing"
"time"

Expand All @@ -16,17 +17,14 @@ import (
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)

func TestAgent(t *testing.T) {
t.Parallel()

ptrTime := func(t time.Time) *time.Time {
return &t
}

for _, tc := range []struct {
name string
iter []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error
Expand All @@ -47,7 +45,7 @@ func TestAgent(t *testing.T) {
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
close(logs)
return nil
},
Expand All @@ -69,7 +67,7 @@ func TestAgent(t *testing.T) {
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnecting
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
agent.StartedAt = ptrTime(time.Now())
agent.StartedAt = ptr.Ref(time.Now())
return nil
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
Expand All @@ -78,9 +76,9 @@ func TestAgent(t *testing.T) {
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
agent.ReadyAt = ptrTime(time.Now())
agent.ReadyAt = ptr.Ref(time.Now())
close(logs)
return nil
},
Expand All @@ -102,17 +100,17 @@ func TestAgent(t *testing.T) {
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentDisconnected
agent.FirstConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute))
agent.LastConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute))
agent.DisconnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
agent.LastConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
agent.DisconnectedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
agent.StartedAt = ptrTime(time.Now().Add(-1 * time.Minute))
agent.ReadyAt = ptrTime(time.Now())
agent.StartedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
agent.ReadyAt = ptr.Ref(time.Now())
return nil
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.LastConnectedAt = ptrTime(time.Now())
agent.LastConnectedAt = ptr.Ref(time.Now())
return nil
},
func(_ context.Context, _ *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
Expand All @@ -136,9 +134,9 @@ func TestAgent(t *testing.T) {
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
agent.StartedAt = ptrTime(time.Now())
agent.StartedAt = ptr.Ref(time.Now())
logs <- []codersdk.WorkspaceAgentStartupLog{
{
CreatedAt: time.Now(),
Expand All @@ -149,7 +147,7 @@ func TestAgent(t *testing.T) {
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
agent.ReadyAt = ptrTime(time.Now())
agent.ReadyAt = ptr.Ref(time.Now())
logs <- []codersdk.WorkspaceAgentStartupLog{
{
CreatedAt: time.Now(),
Expand All @@ -176,10 +174,10 @@ func TestAgent(t *testing.T) {
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.StartedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
agent.StartedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStartError
agent.ReadyAt = ptrTime(time.Now())
agent.ReadyAt = ptr.Ref(time.Now())
logs <- []codersdk.WorkspaceAgentStartupLog{
{
CreatedAt: time.Now(),
Expand Down Expand Up @@ -222,9 +220,9 @@ func TestAgent(t *testing.T) {
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
agent.StartedAt = ptrTime(time.Now())
agent.StartedAt = ptr.Ref(time.Now())
logs <- []codersdk.WorkspaceAgentStartupLog{
{
CreatedAt: time.Now(),
Expand All @@ -234,7 +232,7 @@ func TestAgent(t *testing.T) {
return nil
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.ReadyAt = ptrTime(time.Now())
agent.ReadyAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleShuttingDown
close(logs)
return nil
Expand Down Expand Up @@ -310,7 +308,7 @@ func TestAgent(t *testing.T) {

cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
tc.opts.Fetch = func(_ context.Context) (codersdk.WorkspaceAgent, error) {
tc.opts.Fetch = func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) {
var err error
if len(tc.iter) > 0 {
err = tc.iter[0](ctx, &agent, logs)
Expand All @@ -321,7 +319,7 @@ func TestAgent(t *testing.T) {
tc.opts.FetchLogs = func(_ context.Context, _ uuid.UUID, _ int64, _ bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error) {
return logs, closeFunc(func() error { return nil }), nil
}
err := cliui.Agent(inv.Context(), &buf, tc.opts)
err := cliui.Agent(inv.Context(), &buf, uuid.Nil, tc.opts)
return err
},
}
Expand Down Expand Up @@ -350,4 +348,37 @@ func TestAgent(t *testing.T) {
}
})
}

t.Run("NotInfinite", func(t *testing.T) {
t.Parallel()
var fetchCalled uint64

cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
buf := bytes.Buffer{}
err := cliui.Agent(inv.Context(), &buf, uuid.Nil, cliui.AgentOptions{
FetchInterval: 10 * time.Millisecond,
Fetch: func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error) {
atomic.AddUint64(&fetchCalled, 1)

return codersdk.WorkspaceAgent{
Status: codersdk.WorkspaceAgentConnected,
LifecycleState: codersdk.WorkspaceAgentLifecycleReady,
}, nil
},
})
if err != nil {
return err
}

require.Never(t, func() bool {
called := atomic.LoadUint64(&fetchCalled)
return called > 5 || called == 0
}, time.Second, 100*time.Millisecond)

return nil
},
}
require.NoError(t, cmd.Invoke().Run())
})
Comment on lines +352 to +383
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:chefs-kiss:

}
8 changes: 3 additions & 5 deletions cli/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,9 @@ func (r *RootCmd) portForward() *clibase.Cmd {
}
}

err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
},
Wait: false,
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
Fetch: client.WorkspaceAgent,
Wait: false,
})
if err != nil {
return xerrors.Errorf("await agent: %w", err)
Expand Down
8 changes: 3 additions & 5 deletions cli/speedtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,9 @@ func (r *RootCmd) speedtest() *clibase.Cmd {
return err
}

err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
},
Wait: false,
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
Fetch: client.WorkspaceAgent,
Wait: false,
})
if err != nil {
return xerrors.Errorf("await agent: %w", err)
Expand Down
6 changes: 2 additions & 4 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {

// OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
},
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
Fetch: client.WorkspaceAgent,
FetchLogs: client.WorkspaceAgentStartupLogsAfter,
Wait: wait,
})
Expand Down
4 changes: 2 additions & 2 deletions cmd/cliui/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,10 @@ func main() {
agent.LastConnectedAt = &lastConnectedAt
},
}
err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
err := cliui.Agent(inv.Context(), inv.Stdout, uuid.Nil, cliui.AgentOptions{
FetchInterval: 100 * time.Millisecond,
Wait: true,
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
Fetch: func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) {
if len(fetchSteps) == 0 {
return agent, nil
}
Expand Down