Skip to content

chore(codersdk/toolsdk): improve static analyzability of toolsdk.Tools #17562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 29, 2025
46 changes: 21 additions & 25 deletions cli/exp_mcp.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -427,22 +428,27 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
server.WithInstructions(instructions),
)

// Create a new context for the tools with all relevant information.
clientCtx := toolsdk.WithClient(ctx, client)
// Get the workspace agent token from the environment.
toolOpts := make([]func(*toolsdk.Deps), 0)
var hasAgentClient bool
if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" {
hasAgentClient = true
agentClient := agentsdk.New(client.URL)
agentClient.SetSessionToken(agentToken)
clientCtx = toolsdk.WithAgentClient(clientCtx, agentClient)
toolOpts = append(toolOpts, toolsdk.WithAgentClient(agentClient))
} else {
cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available")
}
if appStatusSlug == "" {
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")

if appStatusSlug != "" {
toolOpts = append(toolOpts, toolsdk.WithAppStatusSlug(appStatusSlug))
} else {
clientCtx = toolsdk.WithWorkspaceAppStatusSlug(clientCtx, appStatusSlug)
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")
}

toolDeps, err := toolsdk.NewDeps(client, toolOpts...)
if err != nil {
return xerrors.Errorf("failed to initialize tool dependencies: %w", err)
}

// Register tools based on the allowlist (if specified)
Expand All @@ -455,15 +461,15 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
if len(allowedTools) == 0 || slices.ContainsFunc(allowedTools, func(t string) bool {
return t == tool.Tool.Name
}) {
mcpSrv.AddTools(mcpFromSDK(tool))
mcpSrv.AddTools(mcpFromSDK(tool, toolDeps))
}
}

srv := server.NewStdioServer(mcpSrv)
done := make(chan error)
go func() {
defer close(done)
srvErr := srv.Listen(clientCtx, invStdin, invStdout)
srvErr := srv.Listen(ctx, invStdin, invStdout)
done <- srvErr
}()

Expand Down Expand Up @@ -726,7 +732,7 @@ func getAgentToken(fs afero.Fs) (string, error) {

// mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool.
// It assumes that the tool responds with a valid JSON object.
func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool {
func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool {
// NOTE: some clients will silently refuse to use tools if there is an issue
// with the tool's schema or configuration.
if sdkTool.Schema.Properties == nil {
Expand All @@ -743,27 +749,17 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool {
},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := sdkTool.Handler(ctx, request.Params.Arguments)
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(request.Params.Arguments); err != nil {
return nil, xerrors.Errorf("failed to encode request arguments: %w", err)
}
result, err := sdkTool.Handler(ctx, tb, buf.Bytes())
if err != nil {
return nil, err
}
var sb strings.Builder
if err := json.NewEncoder(&sb).Encode(result); err == nil {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.NewTextContent(sb.String()),
},
}, nil
}
// If the result is not JSON, return it as a string.
// This is a fallback for tools that return non-JSON data.
resultStr, ok := result.(string)
if !ok {
return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.NewTextContent(resultStr),
mcp.NewTextContent(string(result)),
},
}, nil
},
Expand Down
22 changes: 16 additions & 6 deletions cli/exp_mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ func TestExpMcpServer(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)
cmdDone := make(chan struct{})
cancelCtx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)

// Given: a running coder deployment
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
owner := coderdtest.CreateFirstUser(t, client)

// Given: we run the exp mcp command with allowed tools set
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user")
Expand All @@ -48,7 +48,6 @@ func TestExpMcpServer(t *testing.T) {
// nolint: gocritic // not the focus of this test
clitest.SetupConfig(t, client, root)

cmdDone := make(chan struct{})
go func() {
defer close(cmdDone)
err := inv.Run()
Expand All @@ -61,9 +60,6 @@ func TestExpMcpServer(t *testing.T) {
_ = pty.ReadLine(ctx) // ignore echoed output
output := pty.ReadLine(ctx)

cancel()
<-cmdDone

// Then: we should only see the allowed tools in the response
var toolsResponse struct {
Result struct {
Expand All @@ -81,6 +77,20 @@ func TestExpMcpServer(t *testing.T) {
}
slices.Sort(foundTools)
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)

// Call the tool and ensure it works.
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echoed output
output = pty.ReadLine(ctx)
require.NotEmpty(t, output, "should have received a response from the tool")
// Ensure it's valid JSON
_, err = json.Marshal(output)
require.NoError(t, err, "should have received a valid JSON response from the tool")
// Ensure the tool returns the expected user
require.Contains(t, output, owner.UserID.String(), "should have received the expected user ID")
cancel()
<-cmdDone
})

t.Run("OK", func(t *testing.T) {
Expand Down
28 changes: 26 additions & 2 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,33 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
Slug: req.AppSlug,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace app.",
Detail: err.Error(),
Detail: fmt.Sprintf("No app found with slug %q", req.AppSlug),
})
return
}

if len(req.Message) > 160 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Message is too long.",
Detail: "Message must be less than 160 characters.",
Validations: []codersdk.ValidationError{
{Field: "message", Detail: "Message must be less than 160 characters."},
},
})
return
}

switch req.State {
case codersdk.WorkspaceAppStatusStateComplete, codersdk.WorkspaceAppStatusStateFailure, codersdk.WorkspaceAppStatusStateWorking: // valid states
default:
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid state provided.",
Detail: fmt.Sprintf("invalid state: %q", req.State),
Validations: []codersdk.ValidationError{
{Field: "state", Detail: "State must be one of: complete, failure, working."},
},
})
return
}
Expand Down
83 changes: 64 additions & 19 deletions coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,27 +341,27 @@ func TestWorkspaceAgentLogs(t *testing.T) {

func TestWorkspaceAgentAppStatus(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)

r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user2.ID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{
{
Slug: "vscode",
},
}
return a
}).Do()
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user2.ID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{
{
Slug: "vscode",
},
}
return a
}).Do()

agentClient := agentsdk.New(client.URL)
agentClient.SetSessionToken(r.AgentToken)
agentClient := agentsdk.New(client.URL)
agentClient.SetSessionToken(r.AgentToken)
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "vscode",
Message: "testing",
Expand All @@ -382,6 +382,51 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
require.Empty(t, agent.Apps[0].Statuses[0].Icon)
require.False(t, agent.Apps[0].Statuses[0].NeedsUserAttention)
})

t.Run("FailUnknownApp", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "unknown",
Message: "testing",
URI: "https://example.com",
State: codersdk.WorkspaceAppStatusStateComplete,
})
require.ErrorContains(t, err, "No app found with slug")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})

t.Run("FailUnknownState", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "vscode",
Message: "testing",
URI: "https://example.com",
State: "unknown",
})
require.ErrorContains(t, err, "Invalid state")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})

t.Run("FailTooLong", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "vscode",
Message: strings.Repeat("a", 161),
URI: "https://example.com",
State: codersdk.WorkspaceAppStatusStateComplete,
})
require.ErrorContains(t, err, "Message is too long")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
}

func TestWorkspaceAgentConnectRPC(t *testing.T) {
Expand Down
Loading
Loading