Skip to content

refactor(agent): add agenttest.New helper function #9812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 26, 2023
Prev Previous commit
Next Next commit
address code review comments
  • Loading branch information
johnstcn committed Sep 21, 2023
commit 3316b99d1b884a9fc422b3d6a84fd9b8b4141776
109 changes: 30 additions & 79 deletions agent/agenttest/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"sync"
"testing"

"github.com/stretchr/testify/require"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"

Expand All @@ -20,36 +18,22 @@ import (
"github.com/coder/coder/v2/codersdk/agentsdk"
)

// Options are options for creating a new test agent.
type Options struct {
// AgentOptions are the options to use for the agent.
AgentOptions agent.Options

// AgentToken is the token to use for the agent.
AgentToken string
// URL is the URL to which the agent should connect.
URL *url.URL
// WorkspaceID is the ID of the workspace to which the agent should connect.
WorkspaceID uuid.UUID
// Logger is the logger to use for the agent.
// Defaults to a new test logger if not specified.
Logger *slog.Logger
}

// Agent is a small wrapper around an agent for use in tests.
type Agent struct {
waitOnce sync.Once
agent agent.Agent
agentClient *agentsdk.Client
resources []codersdk.WorkspaceResource
waiter func(*codersdk.Client) []codersdk.WorkspaceResource
waiter func(*codersdk.Client, uuid.UUID, ...string) []codersdk.WorkspaceResource
waitOnce sync.Once
}

// Wait waits for the agent to connect to the workspace and returns the
// resources for the connected workspace.
func (a *Agent) Wait(client *codersdk.Client) []codersdk.WorkspaceResource {
// Calls coderdtest.AwaitWorkspaceAgents under the hood.
// Multiple calls to Wait() are idempotent.
func (a *Agent) Wait(client *codersdk.Client, workspaceID uuid.UUID, agentNames ...string) []codersdk.WorkspaceResource {
a.waitOnce.Do(func() {
a.resources = a.waiter(client)
a.resources = a.waiter(client, workspaceID, agentNames...)
})
return a.resources
}
Expand All @@ -64,84 +48,51 @@ func (a *Agent) Agent() agent.Agent {
return a.agent
}

// OptFunc is a function that modifies the given options.
type OptFunc func(*Options)

func WithAgentToken(token string) OptFunc {
return func(o *Options) {
o.AgentToken = token
}
}

func WithURL(u *url.URL) OptFunc {
return func(o *Options) {
o.URL = u
}
}

func WithWorkspaceID(id uuid.UUID) OptFunc {
return func(o *Options) {
o.WorkspaceID = id
}
}

// New starts a new agent for use in tests.
// Returns a wrapper around the agent that can be used to wait for the agent to
// connect to the workspace.
// The agent will use the provided coder URL and session token.
// The options passed to agent.New() can be modified by passing an optional
// variadic func(*agent.Options).
// Returns a wrapper that can be used to wait for the agent to connect to the
// workspace by calling Wait(). The arguments to Wait() are passed to
// coderdtest.AwaitWorkspaceAgents.
// Closing the agent is handled by the test cleanup.
func New(t testing.TB, opts ...OptFunc) *Agent {
func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent.Options)) *Agent {
t.Helper()

var o Options
var o agent.Options
log := slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("agent")
o.Logger = log

for _, opt := range opts {
opt(&o)
}

if o.URL == nil {
require.Fail(t, "must specify URL for agent")
}
agentClient := agentsdk.New(o.URL)

if o.AgentToken == "" {
o.AgentToken = uuid.NewString()
if o.Client == nil {
agentClient := agentsdk.New(coderURL)
agentClient.SetSessionToken(agentToken)
o.Client = agentClient
}
agentClient.SetSessionToken(o.AgentToken)

if o.AgentOptions.Client == nil {
o.AgentOptions.Client = agentClient
}

if o.AgentOptions.ExchangeToken == nil {
o.AgentOptions.ExchangeToken = func(_ context.Context) (string, error) {
return o.AgentToken, nil
if o.ExchangeToken == nil {
o.ExchangeToken = func(_ context.Context) (string, error) {
return agentToken, nil
}
}

if o.AgentOptions.LogDir == "" {
o.AgentOptions.LogDir = t.TempDir()
if o.LogDir == "" {
o.LogDir = t.TempDir()
}

if o.Logger == nil {
log := slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("agent")
o.Logger = &log
}

o.AgentOptions.Logger = *o.Logger

agentCloser := agent.New(o.AgentOptions)
agentCloser := agent.New(o)
t.Cleanup(func() {
assert.NoError(t, agentCloser.Close(), "failed to close agent during cleanup")
})

return &Agent{
agent: agentCloser,
agentClient: agentClient,
waiter: func(c *codersdk.Client) []codersdk.WorkspaceResource {
if o.WorkspaceID == uuid.Nil {
require.FailNow(t, "must specify workspace ID for agent in order to wait")
return nil // unreachable
}
return coderdtest.AwaitWorkspaceAgents(t, c, o.WorkspaceID)
agentClient: o.Client.(*agentsdk.Client), // nolint:forcetypeassert
waiter: func(c *codersdk.Client, workspaceID uuid.UUID, agentNames ...string) []codersdk.WorkspaceResource {
return coderdtest.AwaitWorkspaceAgents(t, c, workspaceID, agentNames...)
},
}
}
6 changes: 1 addition & 5 deletions cli/configssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,7 @@ func TestConfigSSH(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
resources := agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(authToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
resources := agenttest.New(t, client.URL, authToken).Wait(client, workspace.ID)
agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer agentConn.Close()
Expand Down
8 changes: 2 additions & 6 deletions cli/gitssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,8 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*agentsdk.Client, str
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)

// start workspace agent
agt := agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
)
agt.Wait(client)
agt := agenttest.New(t, client.URL, agentToken)
agt.Wait(client, workspace.ID)
return agt.Client(), agentToken, pubkey
}

Expand Down
6 changes: 1 addition & 5 deletions cli/ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@ func TestPing(t *testing.T) {
inv.Stderr = pty.Output()
inv.Stdout = pty.Output()

_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
Expand Down
12 changes: 5 additions & 7 deletions cli/portforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"

"github.com/coder/coder/v2/cli/clitest"
Expand Down Expand Up @@ -315,14 +316,11 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) codersdk.
workspace := coderdtest.CreateWorkspace(t, client, orgID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)

agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
func(o *agenttest.Options) {
o.AgentOptions.SSHMaxTimeout = 60 * time.Second
agenttest.New(t, client.URL, agentToken,
func(o *agent.Options) {
o.SSHMaxTimeout = 60 * time.Second
},
).Wait(client)
).Wait(client, workspace.ID)

return workspace
}
Expand Down
6 changes: 1 addition & 5 deletions cli/speedtest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ func TestSpeedtest(t *testing.T) {
t.Skip("This test takes a minimum of 5ms per a hardcoded value in Tailscale!")
}
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
Expand Down
59 changes: 14 additions & 45 deletions cli/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"

"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/cli/cliui"
Expand Down Expand Up @@ -102,11 +103,7 @@ func TestSSH(t *testing.T) {
})
pty.ExpectMatch("Waiting")

_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)

// Shells on Mac, Windows, and Linux all exit shells with the "exit" command.
pty.WriteLine("exit")
Expand Down Expand Up @@ -162,11 +159,7 @@ func TestSSH(t *testing.T) {
})
pty.ExpectMatch("Waiting")

_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)

// Ensure the agent is connected.
pty.WriteLine("echo hell'o'")
Expand All @@ -187,11 +180,7 @@ func TestSSH(t *testing.T) {
_, _ = tGoContext(t, func(ctx context.Context) {
// Run this async so the SSH command has to wait for
// the build and agent to connect!
_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)
<-ctx.Done()
})

Expand Down Expand Up @@ -253,11 +242,7 @@ func TestSSH(t *testing.T) {
_, _ = tGoContext(t, func(ctx context.Context) {
// Run this async so the SSH command has to wait for
// the build and agent to connect.
_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)
<-ctx.Done()
})

Expand Down Expand Up @@ -320,11 +305,7 @@ func TestSSH(t *testing.T) {

client, workspace, agentToken := setupWorkspaceForAgent(t, nil)

_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)

// Generate private key.
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
Expand Down Expand Up @@ -409,11 +390,7 @@ func TestSSH(t *testing.T) {

client, workspace, agentToken := setupWorkspaceForAgent(t, nil)

_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
Expand Down Expand Up @@ -458,11 +435,7 @@ func TestSSH(t *testing.T) {

pty.ExpectMatch("Waiting")

agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)

// Shells on Mac, Windows, and Linux all exit shells with the "exit" command.
pty.WriteLine("exit")
Expand Down Expand Up @@ -630,16 +603,12 @@ Expire-Date: 0

client, workspace, agentToken := setupWorkspaceForAgent(t, nil)

_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
func(o *agenttest.Options) {
o.AgentOptions.EnvironmentVariables = map[string]string{
"GNUPGHOME": gnupgHomeWorkspace,
}
},
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken, func(o *agent.Options) {
o.EnvironmentVariables = map[string]string{
"GNUPGHOME": gnupgHomeWorkspace,
}
},
).Wait(client, workspace.ID)

inv, root := clitest.New(t,
"ssh",
Expand Down
6 changes: 1 addition & 5 deletions cli/vscodessh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@ func TestVSCodeSSH(t *testing.T) {
user, err := client.User(ctx, codersdk.Me)
require.NoError(t, err)

_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)

fs := afero.NewMemMapFs()
err = afero.WriteFile(fs, "/url", []byte(client.URL.String()), 0o600)
Expand Down
6 changes: 1 addition & 5 deletions coderd/activitybump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,7 @@ func TestWorkspaceActivityBump(t *testing.T) {
require.NoError(t, err)
}

_ = agenttest.New(t,
agenttest.WithURL(client.URL),
agenttest.WithAgentToken(agentToken),
agenttest.WithWorkspaceID(workspace.ID),
).Wait(client)
_ = agenttest.New(t, client.URL, agentToken).Wait(client, workspace.ID)

// Sanity-check that deadline is near.
workspace, err := client.Workspace(ctx, workspace.ID)
Expand Down
Loading