Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func workspaceAgent() *serpent.Command {
slog.F("auth", agentAuth.agentAuth),
slog.F("version", version),
)
client, err := agentAuth.CreateClient(ctx)
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
Expand Down
157 changes: 0 additions & 157 deletions cli/agent_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cli_test

import (
"context"
"fmt"
"net/http"
"os"
Expand All @@ -11,7 +10,6 @@ import (
"sync/atomic"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -21,10 +19,7 @@ import (
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)

Expand Down Expand Up @@ -64,158 +59,6 @@ func TestWorkspaceAgent(t *testing.T) {
}, testutil.WaitLong, testutil.IntervalMedium)
})

t.Run("Azure", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
AzureCertificates: certificates,
})
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()

inv, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
inv = inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "azure-client", metadataClient),
)

ctx := inv.Context()
clitest.Start(t, inv)
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
})

t.Run("AWS", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
AWSCertificates: certificates,
})
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()

inv, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
inv = inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "aws-client", metadataClient),
)

clitest.Start(t, inv)
ctx := inv.Context()
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
})

t.Run("GoogleCloud", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
validator, metadataClient := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
GoogleTokenValidator: validator,
})
owner := coderdtest.CreateFirstUser(t, client)
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: memberUser.ID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()

inv, cfg := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
clitest.SetupConfig(t, member, cfg)

clitest.Start(t,
inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "gcp-client", metadataClient),
),
)

ctx := inv.Context()
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
sshClient, err := dialer.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
key := "CODER_AGENT_TOKEN"
command := "sh -c 'echo $" + key + "'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %" + key + "%"
}
token, err := session.CombinedOutput(command)
require.NoError(t, err)
_, err = uuid.Parse(strings.TrimSpace(string(token)))
require.NoError(t, err)
})

t.Run("PostStartup", func(t *testing.T) {
t.Parallel()

Expand Down
4 changes: 2 additions & 2 deletions cli/exp_mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func mcpConfigureClaudeCode() *serpent.Command {
binPath = testBinaryName
}
configureClaudeEnv := map[string]string{}
agentClient, err := agentAuth.CreateClient(inv.Context())
agentClient, err := agentAuth.CreateClient()
if err != nil {
cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err)
} else {
Expand Down Expand Up @@ -497,7 +497,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
}

// Try to create an agent client for status reporting. Not validated.
agentClient, err := agentAuth.CreateClient(inv.Context())
agentClient, err := agentAuth.CreateClient()
if err == nil {
cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String())
srv.agentClient = agentClient
Expand Down
2 changes: 1 addition & 1 deletion cli/externalauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ fi
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
defer stop()

client, err := agentAuth.CreateClient(ctx)
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cli/gitaskpass.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("parse host: %w", err)
}

client, err := agentAuth.CreateClient(ctx)
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cli/gitssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func gitssh() *serpent.Command {
return err
}

client, err := agentAuth.CreateClient(ctx)
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
Expand Down
39 changes: 4 additions & 35 deletions cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"text/tabwriter"
"time"

"cloud.google.com/go/compute/metadata"
"github.com/mattn/go-isatty"
"github.com/mitchellh/go-wordwrap"
"golang.org/x/mod/semver"
Expand Down Expand Up @@ -687,7 +686,7 @@ func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) {

// CreateClient returns a new agent client from the command context. It works
// just like InitClient, but uses the agent token and URL instead.
func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) {
func (a *AgentAuth) CreateClient() (*agentsdk.Client, error) {
agentURL := a.agentURL
if agentURL.String() == "" {
return nil, xerrors.Errorf("%s must be set", envAgentURL)
Expand All @@ -711,41 +710,11 @@ func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error)
}
return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil
case "google-instance-identity":

// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var gcpClient *metadata.Client
gcpClientRaw := ctx.Value("gcp-client")
if gcpClientRaw != nil {
gcpClient, _ = gcpClientRaw.(*metadata.Client)
}
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", gcpClient)), nil
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", nil)), nil
case "aws-instance-identity":
client := agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity())
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var awsClient *http.Client
awsClientRaw := ctx.Value("aws-client")
if awsClientRaw != nil {
awsClient, _ = awsClientRaw.(*http.Client)
if awsClient != nil {
client.SDK.HTTPClient = awsClient
}
}
return client, nil
return agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()), nil
case "azure-instance-identity":
client := agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity())
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var azureClient *http.Client
azureClientRaw := ctx.Value("azure-client")
if azureClientRaw != nil {
azureClient, _ = azureClientRaw.(*http.Client)
if azureClient != nil {
client.SDK.HTTPClient = azureClient
}
}
return client, nil
return agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()), nil
default:
return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth)
}
Expand Down
Loading
Loading