Skip to content

Commit 7690732

Browse files
committed
chore: refactor CLI agent auth tests as unit tests
1 parent 1354d84 commit 7690732

File tree

12 files changed

+129
-237
lines changed

12 files changed

+129
-237
lines changed

cli/agent.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ func workspaceAgent() *serpent.Command {
179179
slog.F("auth", agentAuth.agentAuth),
180180
slog.F("version", version),
181181
)
182-
client, err := agentAuth.CreateClient(ctx)
182+
client, err := agentAuth.CreateClient()
183183
if err != nil {
184184
return xerrors.Errorf("create agent client: %w", err)
185185
}

cli/agent_test.go

Lines changed: 0 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package cli_test
22

33
import (
4-
"context"
54
"fmt"
65
"net/http"
76
"os"
@@ -11,7 +10,6 @@ import (
1110
"sync/atomic"
1211
"testing"
1312

14-
"github.com/google/uuid"
1513
"github.com/stretchr/testify/assert"
1614
"github.com/stretchr/testify/require"
1715

@@ -21,10 +19,7 @@ import (
2119
"github.com/coder/coder/v2/coderd/coderdtest"
2220
"github.com/coder/coder/v2/coderd/database"
2321
"github.com/coder/coder/v2/coderd/database/dbfake"
24-
"github.com/coder/coder/v2/coderd/database/dbtestutil"
2522
"github.com/coder/coder/v2/codersdk"
26-
"github.com/coder/coder/v2/codersdk/workspacesdk"
27-
"github.com/coder/coder/v2/provisionersdk/proto"
2823
"github.com/coder/coder/v2/testutil"
2924
)
3025

@@ -64,158 +59,6 @@ func TestWorkspaceAgent(t *testing.T) {
6459
}, testutil.WaitLong, testutil.IntervalMedium)
6560
})
6661

67-
t.Run("Azure", func(t *testing.T) {
68-
t.Parallel()
69-
instanceID := "instanceidentifier"
70-
certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
71-
db, ps := dbtestutil.NewDB(t,
72-
dbtestutil.WithDumpOnFailure(),
73-
)
74-
client := coderdtest.New(t, &coderdtest.Options{
75-
Database: db,
76-
Pubsub: ps,
77-
AzureCertificates: certificates,
78-
})
79-
user := coderdtest.CreateFirstUser(t, client)
80-
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
81-
OrganizationID: user.OrganizationID,
82-
OwnerID: user.UserID,
83-
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
84-
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
85-
return agents
86-
}).Do()
87-
88-
inv, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
89-
inv = inv.WithContext(
90-
//nolint:revive,staticcheck
91-
context.WithValue(inv.Context(), "azure-client", metadataClient),
92-
)
93-
94-
ctx := inv.Context()
95-
clitest.Start(t, inv)
96-
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
97-
MatchResources(matchAgentWithVersion).Wait()
98-
workspace, err := client.Workspace(ctx, r.Workspace.ID)
99-
require.NoError(t, err)
100-
resources := workspace.LatestBuild.Resources
101-
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
102-
assert.NotEmpty(t, resources[0].Agents[0].Version)
103-
}
104-
dialer, err := workspacesdk.New(client).
105-
DialAgent(ctx, resources[0].Agents[0].ID, nil)
106-
require.NoError(t, err)
107-
defer dialer.Close()
108-
require.True(t, dialer.AwaitReachable(ctx))
109-
})
110-
111-
t.Run("AWS", func(t *testing.T) {
112-
t.Parallel()
113-
instanceID := "instanceidentifier"
114-
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
115-
db, ps := dbtestutil.NewDB(t,
116-
dbtestutil.WithDumpOnFailure(),
117-
)
118-
client := coderdtest.New(t, &coderdtest.Options{
119-
Database: db,
120-
Pubsub: ps,
121-
AWSCertificates: certificates,
122-
})
123-
user := coderdtest.CreateFirstUser(t, client)
124-
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
125-
OrganizationID: user.OrganizationID,
126-
OwnerID: user.UserID,
127-
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
128-
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
129-
return agents
130-
}).Do()
131-
132-
inv, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
133-
inv = inv.WithContext(
134-
//nolint:revive,staticcheck
135-
context.WithValue(inv.Context(), "aws-client", metadataClient),
136-
)
137-
138-
clitest.Start(t, inv)
139-
ctx := inv.Context()
140-
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
141-
MatchResources(matchAgentWithVersion).
142-
Wait()
143-
workspace, err := client.Workspace(ctx, r.Workspace.ID)
144-
require.NoError(t, err)
145-
resources := workspace.LatestBuild.Resources
146-
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
147-
assert.NotEmpty(t, resources[0].Agents[0].Version)
148-
}
149-
dialer, err := workspacesdk.New(client).
150-
DialAgent(ctx, resources[0].Agents[0].ID, nil)
151-
require.NoError(t, err)
152-
defer dialer.Close()
153-
require.True(t, dialer.AwaitReachable(ctx))
154-
})
155-
156-
t.Run("GoogleCloud", func(t *testing.T) {
157-
t.Parallel()
158-
instanceID := "instanceidentifier"
159-
validator, metadataClient := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
160-
db, ps := dbtestutil.NewDB(t,
161-
dbtestutil.WithDumpOnFailure(),
162-
)
163-
client := coderdtest.New(t, &coderdtest.Options{
164-
Database: db,
165-
Pubsub: ps,
166-
GoogleTokenValidator: validator,
167-
})
168-
owner := coderdtest.CreateFirstUser(t, client)
169-
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
170-
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
171-
OrganizationID: owner.OrganizationID,
172-
OwnerID: memberUser.ID,
173-
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
174-
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
175-
return agents
176-
}).Do()
177-
178-
inv, cfg := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
179-
clitest.SetupConfig(t, member, cfg)
180-
181-
clitest.Start(t,
182-
inv.WithContext(
183-
//nolint:revive,staticcheck
184-
context.WithValue(inv.Context(), "gcp-client", metadataClient),
185-
),
186-
)
187-
188-
ctx := inv.Context()
189-
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
190-
MatchResources(matchAgentWithVersion).
191-
Wait()
192-
workspace, err := client.Workspace(ctx, r.Workspace.ID)
193-
require.NoError(t, err)
194-
resources := workspace.LatestBuild.Resources
195-
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
196-
assert.NotEmpty(t, resources[0].Agents[0].Version)
197-
}
198-
dialer, err := workspacesdk.New(client).DialAgent(ctx, resources[0].Agents[0].ID, nil)
199-
require.NoError(t, err)
200-
defer dialer.Close()
201-
require.True(t, dialer.AwaitReachable(ctx))
202-
sshClient, err := dialer.SSHClient(ctx)
203-
require.NoError(t, err)
204-
defer sshClient.Close()
205-
session, err := sshClient.NewSession()
206-
require.NoError(t, err)
207-
defer session.Close()
208-
key := "CODER_AGENT_TOKEN"
209-
command := "sh -c 'echo $" + key + "'"
210-
if runtime.GOOS == "windows" {
211-
command = "cmd.exe /c echo %" + key + "%"
212-
}
213-
token, err := session.CombinedOutput(command)
214-
require.NoError(t, err)
215-
_, err = uuid.Parse(strings.TrimSpace(string(token)))
216-
require.NoError(t, err)
217-
})
218-
21962
t.Run("PostStartup", func(t *testing.T) {
22063
t.Parallel()
22164

cli/exp_mcp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func mcpConfigureClaudeCode() *serpent.Command {
149149
binPath = testBinaryName
150150
}
151151
configureClaudeEnv := map[string]string{}
152-
agentClient, err := agentAuth.CreateClient(inv.Context())
152+
agentClient, err := agentAuth.CreateClient()
153153
if err != nil {
154154
cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err)
155155
} else {
@@ -497,7 +497,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
497497
}
498498

499499
// Try to create an agent client for status reporting. Not validated.
500-
agentClient, err := agentAuth.CreateClient(inv.Context())
500+
agentClient, err := agentAuth.CreateClient()
501501
if err == nil {
502502
cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String())
503503
srv.agentClient = agentClient

cli/externalauth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ fi
6868
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
6969
defer stop()
7070

71-
client, err := agentAuth.CreateClient(ctx)
71+
client, err := agentAuth.CreateClient()
7272
if err != nil {
7373
return xerrors.Errorf("create agent client: %w", err)
7474
}

cli/gitaskpass.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
3333
return xerrors.Errorf("parse host: %w", err)
3434
}
3535

36-
client, err := agentAuth.CreateClient(ctx)
36+
client, err := agentAuth.CreateClient()
3737
if err != nil {
3838
return xerrors.Errorf("create agent client: %w", err)
3939
}

cli/gitssh.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func gitssh() *serpent.Command {
3939
return err
4040
}
4141

42-
client, err := agentAuth.CreateClient(ctx)
42+
client, err := agentAuth.CreateClient()
4343
if err != nil {
4444
return xerrors.Errorf("create agent client: %w", err)
4545
}

cli/root.go

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"text/tabwriter"
2525
"time"
2626

27-
"cloud.google.com/go/compute/metadata"
2827
"github.com/mattn/go-isatty"
2928
"github.com/mitchellh/go-wordwrap"
3029
"golang.org/x/mod/semver"
@@ -687,7 +686,7 @@ func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) {
687686

688687
// CreateClient returns a new agent client from the command context. It works
689688
// just like InitClient, but uses the agent token and URL instead.
690-
func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) {
689+
func (a *AgentAuth) CreateClient() (*agentsdk.Client, error) {
691690
agentURL := a.agentURL
692691
if agentURL.String() == "" {
693692
return nil, xerrors.Errorf("%s must be set", envAgentURL)
@@ -711,41 +710,11 @@ func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error)
711710
}
712711
return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil
713712
case "google-instance-identity":
714-
715-
// This is *only* done for testing to mock client authentication.
716-
// This will never be set in a production scenario.
717-
var gcpClient *metadata.Client
718-
gcpClientRaw := ctx.Value("gcp-client")
719-
if gcpClientRaw != nil {
720-
gcpClient, _ = gcpClientRaw.(*metadata.Client)
721-
}
722-
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", gcpClient)), nil
713+
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", nil)), nil
723714
case "aws-instance-identity":
724-
client := agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity())
725-
// This is *only* done for testing to mock client authentication.
726-
// This will never be set in a production scenario.
727-
var awsClient *http.Client
728-
awsClientRaw := ctx.Value("aws-client")
729-
if awsClientRaw != nil {
730-
awsClient, _ = awsClientRaw.(*http.Client)
731-
if awsClient != nil {
732-
client.SDK.HTTPClient = awsClient
733-
}
734-
}
735-
return client, nil
715+
return agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()), nil
736716
case "azure-instance-identity":
737-
client := agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity())
738-
// This is *only* done for testing to mock client authentication.
739-
// This will never be set in a production scenario.
740-
var azureClient *http.Client
741-
azureClientRaw := ctx.Value("azure-client")
742-
if azureClientRaw != nil {
743-
azureClient, _ = azureClientRaw.(*http.Client)
744-
if azureClient != nil {
745-
client.SDK.HTTPClient = azureClient
746-
}
747-
}
748-
return client, nil
717+
return agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()), nil
749718
default:
750719
return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth)
751720
}

0 commit comments

Comments
 (0)