Skip to content

Commit 11504b4

Browse files
committed
Extract connecting to agent from MCP
This will be used in multiple tools.
1 parent 8083d9d commit 11504b4

File tree

4 files changed

+122
-116
lines changed

4 files changed

+122
-116
lines changed

codersdk/toolsdk/bash.go

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717

1818
"github.com/coder/coder/v2/cli/cliui"
1919
"github.com/coder/coder/v2/codersdk"
20-
"github.com/coder/coder/v2/codersdk/workspacesdk"
2120
)
2221

2322
type WorkspaceBashArgs struct {
@@ -94,42 +93,12 @@ Examples:
9493
ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min"))
9594
defer cancel()
9695

97-
// Normalize workspace input to handle various formats
98-
workspaceName := NormalizeWorkspaceInput(args.Workspace)
99-
100-
// Find workspace and agent
101-
_, workspaceAgent, err := findWorkspaceAndAgent(ctx, deps.coderClient, workspaceName)
102-
if err != nil {
103-
return WorkspaceBashResult{}, xerrors.Errorf("failed to find workspace: %w", err)
104-
}
105-
106-
// Wait for agent to be ready
107-
if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{
108-
FetchInterval: 0,
109-
Fetch: deps.coderClient.WorkspaceAgent,
110-
FetchLogs: deps.coderClient.WorkspaceAgentLogsAfter,
111-
Wait: true, // Always wait for startup scripts
112-
}); err != nil {
113-
return WorkspaceBashResult{}, xerrors.Errorf("agent not ready: %w", err)
114-
}
115-
116-
// Create workspace SDK client for agent connection
117-
wsClient := workspacesdk.New(deps.coderClient)
118-
119-
// Dial agent
120-
conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
121-
BlockEndpoints: false,
122-
})
96+
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
12397
if err != nil {
124-
return WorkspaceBashResult{}, xerrors.Errorf("failed to dial agent: %w", err)
98+
return WorkspaceBashResult{}, err
12599
}
126100
defer conn.Close()
127101

128-
// Wait for connection to be reachable
129-
if !conn.AwaitReachable(ctx) {
130-
return WorkspaceBashResult{}, xerrors.New("agent connection not reachable")
131-
}
132-
133102
// Create SSH client
134103
sshClient, err := conn.SSHClient(ctx)
135104
if err != nil {
@@ -323,32 +292,6 @@ func namedWorkspace(ctx context.Context, client *codersdk.Client, identifier str
323292
return client.WorkspaceByOwnerAndName(ctx, owner, workspaceName, codersdk.WorkspaceOptions{})
324293
}
325294

326-
// NormalizeWorkspaceInput converts workspace name input to standard format.
327-
// Handles the following input formats:
328-
// - workspace → workspace
329-
// - workspace.agent → workspace.agent
330-
// - owner/workspace → owner/workspace
331-
// - owner--workspace → owner/workspace
332-
// - owner/workspace.agent → owner/workspace.agent
333-
// - owner--workspace.agent → owner/workspace.agent
334-
// - agent.workspace.owner → owner/workspace.agent (Coder Connect format)
335-
func NormalizeWorkspaceInput(input string) string {
336-
// Handle the special Coder Connect format: agent.workspace.owner
337-
// This format uses only dots and has exactly 3 parts
338-
if strings.Count(input, ".") == 2 && !strings.Contains(input, "/") && !strings.Contains(input, "--") {
339-
parts := strings.Split(input, ".")
340-
if len(parts) == 3 {
341-
// Convert agent.workspace.owner → owner/workspace.agent
342-
return fmt.Sprintf("%s/%s.%s", parts[2], parts[1], parts[0])
343-
}
344-
}
345-
346-
// Convert -- separator to / separator for consistency
347-
normalized := strings.ReplaceAll(input, "--", "/")
348-
349-
return normalized
350-
}
351-
352295
// executeCommandWithTimeout executes a command with timeout support
353296
func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) {
354297
// Set up pipes to capture output

codersdk/toolsdk/bash_test.go

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -99,63 +99,6 @@ func TestWorkspaceBash(t *testing.T) {
9999
})
100100
}
101101

102-
func TestNormalizeWorkspaceInput(t *testing.T) {
103-
t.Parallel()
104-
if runtime.GOOS == "windows" {
105-
t.Skip("Skipping on Windows: Workspace MCP bash tools rely on a Unix-like shell (bash) and POSIX/SSH semantics. Use Linux/macOS or WSL for these tests.")
106-
}
107-
108-
testCases := []struct {
109-
name string
110-
input string
111-
expected string
112-
}{
113-
{
114-
name: "SimpleWorkspace",
115-
input: "workspace",
116-
expected: "workspace",
117-
},
118-
{
119-
name: "WorkspaceWithAgent",
120-
input: "workspace.agent",
121-
expected: "workspace.agent",
122-
},
123-
{
124-
name: "OwnerAndWorkspace",
125-
input: "owner/workspace",
126-
expected: "owner/workspace",
127-
},
128-
{
129-
name: "OwnerDashWorkspace",
130-
input: "owner--workspace",
131-
expected: "owner/workspace",
132-
},
133-
{
134-
name: "OwnerWorkspaceAgent",
135-
input: "owner/workspace.agent",
136-
expected: "owner/workspace.agent",
137-
},
138-
{
139-
name: "OwnerDashWorkspaceAgent",
140-
input: "owner--workspace.agent",
141-
expected: "owner/workspace.agent",
142-
},
143-
{
144-
name: "CoderConnectFormat",
145-
input: "agent.workspace.owner", // Special Coder Connect reverse format
146-
expected: "owner/workspace.agent",
147-
},
148-
}
149-
150-
for _, tc := range testCases {
151-
t.Run(tc.name, func(t *testing.T) {
152-
t.Parallel()
153-
result := toolsdk.NormalizeWorkspaceInput(tc.input)
154-
require.Equal(t, tc.expected, result, "Input %q should normalize to %q but got %q", tc.input, tc.expected, result)
155-
})
156-
}
157-
}
158-
159102
func TestAllToolsIncludesBash(t *testing.T) {
160103
t.Parallel()
161104
if runtime.GOOS == "windows" {

codersdk/toolsdk/toolsdk.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import (
55
"bytes"
66
"context"
77
"encoding/json"
8+
"fmt"
89
"io"
910
"runtime/debug"
11+
"strings"
1012

1113
"github.com/google/uuid"
1214
"golang.org/x/xerrors"
@@ -1360,3 +1362,64 @@ type MinimalTemplate struct {
13601362
ActiveVersionID uuid.UUID `json:"active_version_id"`
13611363
ActiveUserCount int `json:"active_user_count"`
13621364
}
1365+
1366+
// NormalizeWorkspaceInput converts workspace name input to standard format.
1367+
// Handles the following input formats:
1368+
// - workspace → workspace
1369+
// - workspace.agent → workspace.agent
1370+
// - owner/workspace → owner/workspace
1371+
// - owner--workspace → owner/workspace
1372+
// - owner/workspace.agent → owner/workspace.agent
1373+
// - owner--workspace.agent → owner/workspace.agent
1374+
// - agent.workspace.owner → owner/workspace.agent (Coder Connect format)
1375+
func NormalizeWorkspaceInput(input string) string {
1376+
// Handle the special Coder Connect format: agent.workspace.owner
1377+
// This format uses only dots and has exactly 3 parts
1378+
if strings.Count(input, ".") == 2 && !strings.Contains(input, "/") && !strings.Contains(input, "--") {
1379+
parts := strings.Split(input, ".")
1380+
if len(parts) == 3 {
1381+
// Convert agent.workspace.owner → owner/workspace.agent
1382+
return fmt.Sprintf("%s/%s.%s", parts[2], parts[1], parts[0])
1383+
}
1384+
}
1385+
1386+
// Convert -- separator to / separator for consistency
1387+
normalized := strings.ReplaceAll(input, "--", "/")
1388+
1389+
return normalized
1390+
}
1391+
1392+
// newAgentConn returns a connection to the agent specified by the workspace,
1393+
// which must be in the format [owner/]workspace[.agent].
1394+
func newAgentConn(ctx context.Context, client *codersdk.Client, workspace string) (workspacesdk.AgentConn, error) {
1395+
workspaceName := NormalizeWorkspaceInput(workspace)
1396+
_, workspaceAgent, err := findWorkspaceAndAgent(ctx, client, workspaceName)
1397+
if err != nil {
1398+
return nil, xerrors.Errorf("failed to find workspace: %w", err)
1399+
}
1400+
1401+
// Wait for agent to be ready.
1402+
if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{
1403+
FetchInterval: 0,
1404+
Fetch: client.WorkspaceAgent,
1405+
FetchLogs: client.WorkspaceAgentLogsAfter,
1406+
Wait: true, // Always wait for startup scripts
1407+
}); err != nil {
1408+
return nil, xerrors.Errorf("agent not ready: %w", err)
1409+
}
1410+
1411+
wsClient := workspacesdk.New(client)
1412+
1413+
conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
1414+
BlockEndpoints: false,
1415+
})
1416+
if err != nil {
1417+
return nil, xerrors.Errorf("failed to dial agent: %w", err)
1418+
}
1419+
1420+
if !conn.AwaitReachable(ctx) {
1421+
conn.Close()
1422+
return nil, xerrors.New("agent connection not reachable")
1423+
}
1424+
return conn, nil
1425+
}

codersdk/toolsdk/toolsdk_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,3 +748,60 @@ func TestReportTaskWithReporter(t *testing.T) {
748748
// Verify response
749749
require.Equal(t, "Thanks for reporting!", result.Message)
750750
}
751+
752+
func TestNormalizeWorkspaceInput(t *testing.T) {
753+
t.Parallel()
754+
if runtime.GOOS == "windows" {
755+
t.Skip("Skipping on Windows: Workspace MCP bash tools rely on a Unix-like shell (bash) and POSIX/SSH semantics. Use Linux/macOS or WSL for these tests.")
756+
}
757+
758+
testCases := []struct {
759+
name string
760+
input string
761+
expected string
762+
}{
763+
{
764+
name: "SimpleWorkspace",
765+
input: "workspace",
766+
expected: "workspace",
767+
},
768+
{
769+
name: "WorkspaceWithAgent",
770+
input: "workspace.agent",
771+
expected: "workspace.agent",
772+
},
773+
{
774+
name: "OwnerAndWorkspace",
775+
input: "owner/workspace",
776+
expected: "owner/workspace",
777+
},
778+
{
779+
name: "OwnerDashWorkspace",
780+
input: "owner--workspace",
781+
expected: "owner/workspace",
782+
},
783+
{
784+
name: "OwnerWorkspaceAgent",
785+
input: "owner/workspace.agent",
786+
expected: "owner/workspace.agent",
787+
},
788+
{
789+
name: "OwnerDashWorkspaceAgent",
790+
input: "owner--workspace.agent",
791+
expected: "owner/workspace.agent",
792+
},
793+
{
794+
name: "CoderConnectFormat",
795+
input: "agent.workspace.owner", // Special Coder Connect reverse format
796+
expected: "owner/workspace.agent",
797+
},
798+
}
799+
800+
for _, tc := range testCases {
801+
t.Run(tc.name, func(t *testing.T) {
802+
t.Parallel()
803+
result := toolsdk.NormalizeWorkspaceInput(tc.input)
804+
require.Equal(t, tc.expected, result, "Input %q should normalize to %q but got %q", tc.input, tc.expected, result)
805+
})
806+
}
807+
}

0 commit comments

Comments
 (0)