Skip to content

Commit 2acf0ad

Browse files
authored
chore(codersdk/toolsdk): improve static analyzability of toolsdk.Tools (#17562)
* Refactors toolsdk.Tools to remove opaque `map[string]any` argument in favour of typed args structs. * Refactors toolsdk.Tools to remove opaque passing of dependencies via `context.Context` in favour of a tool dependencies struct. * Adds panic recovery and clean context middleware to all tools. * Adds `GenericTool` implementation to allow keeping `toolsdk.All` with uniform type signature while maintaining type information in handlers. * Adds stricter checks to `patchWorkspaceAgentAppStatus` handler.
1 parent 1fc74f6 commit 2acf0ad

File tree

6 files changed

+1139
-865
lines changed

6 files changed

+1139
-865
lines changed

cli/exp_mcp.go

+21-25
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cli
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"errors"
@@ -427,22 +428,27 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
427428
server.WithInstructions(instructions),
428429
)
429430

430-
// Create a new context for the tools with all relevant information.
431-
clientCtx := toolsdk.WithClient(ctx, client)
432431
// Get the workspace agent token from the environment.
432+
toolOpts := make([]func(*toolsdk.Deps), 0)
433433
var hasAgentClient bool
434434
if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" {
435435
hasAgentClient = true
436436
agentClient := agentsdk.New(client.URL)
437437
agentClient.SetSessionToken(agentToken)
438-
clientCtx = toolsdk.WithAgentClient(clientCtx, agentClient)
438+
toolOpts = append(toolOpts, toolsdk.WithAgentClient(agentClient))
439439
} else {
440440
cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available")
441441
}
442-
if appStatusSlug == "" {
443-
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")
442+
443+
if appStatusSlug != "" {
444+
toolOpts = append(toolOpts, toolsdk.WithAppStatusSlug(appStatusSlug))
444445
} else {
445-
clientCtx = toolsdk.WithWorkspaceAppStatusSlug(clientCtx, appStatusSlug)
446+
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")
447+
}
448+
449+
toolDeps, err := toolsdk.NewDeps(client, toolOpts...)
450+
if err != nil {
451+
return xerrors.Errorf("failed to initialize tool dependencies: %w", err)
446452
}
447453

448454
// Register tools based on the allowlist (if specified)
@@ -455,15 +461,15 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
455461
if len(allowedTools) == 0 || slices.ContainsFunc(allowedTools, func(t string) bool {
456462
return t == tool.Tool.Name
457463
}) {
458-
mcpSrv.AddTools(mcpFromSDK(tool))
464+
mcpSrv.AddTools(mcpFromSDK(tool, toolDeps))
459465
}
460466
}
461467

462468
srv := server.NewStdioServer(mcpSrv)
463469
done := make(chan error)
464470
go func() {
465471
defer close(done)
466-
srvErr := srv.Listen(clientCtx, invStdin, invStdout)
472+
srvErr := srv.Listen(ctx, invStdin, invStdout)
467473
done <- srvErr
468474
}()
469475

@@ -726,7 +732,7 @@ func getAgentToken(fs afero.Fs) (string, error) {
726732

727733
// mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool.
728734
// It assumes that the tool responds with a valid JSON object.
729-
func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool {
735+
func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool {
730736
// NOTE: some clients will silently refuse to use tools if there is an issue
731737
// with the tool's schema or configuration.
732738
if sdkTool.Schema.Properties == nil {
@@ -743,27 +749,17 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool {
743749
},
744750
},
745751
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
746-
result, err := sdkTool.Handler(ctx, request.Params.Arguments)
752+
var buf bytes.Buffer
753+
if err := json.NewEncoder(&buf).Encode(request.Params.Arguments); err != nil {
754+
return nil, xerrors.Errorf("failed to encode request arguments: %w", err)
755+
}
756+
result, err := sdkTool.Handler(ctx, tb, buf.Bytes())
747757
if err != nil {
748758
return nil, err
749759
}
750-
var sb strings.Builder
751-
if err := json.NewEncoder(&sb).Encode(result); err == nil {
752-
return &mcp.CallToolResult{
753-
Content: []mcp.Content{
754-
mcp.NewTextContent(sb.String()),
755-
},
756-
}, nil
757-
}
758-
// If the result is not JSON, return it as a string.
759-
// This is a fallback for tools that return non-JSON data.
760-
resultStr, ok := result.(string)
761-
if !ok {
762-
return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result)
763-
}
764760
return &mcp.CallToolResult{
765761
Content: []mcp.Content{
766-
mcp.NewTextContent(resultStr),
762+
mcp.NewTextContent(string(result)),
767763
},
768764
}, nil
769765
},

cli/exp_mcp_test.go

+16-6
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ func TestExpMcpServer(t *testing.T) {
3131
t.Parallel()
3232

3333
ctx := testutil.Context(t, testutil.WaitShort)
34+
cmdDone := make(chan struct{})
3435
cancelCtx, cancel := context.WithCancel(ctx)
35-
t.Cleanup(cancel)
3636

3737
// Given: a running coder deployment
3838
client := coderdtest.New(t, nil)
39-
_ = coderdtest.CreateFirstUser(t, client)
39+
owner := coderdtest.CreateFirstUser(t, client)
4040

4141
// Given: we run the exp mcp command with allowed tools set
4242
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user")
@@ -48,7 +48,6 @@ func TestExpMcpServer(t *testing.T) {
4848
// nolint: gocritic // not the focus of this test
4949
clitest.SetupConfig(t, client, root)
5050

51-
cmdDone := make(chan struct{})
5251
go func() {
5352
defer close(cmdDone)
5453
err := inv.Run()
@@ -61,9 +60,6 @@ func TestExpMcpServer(t *testing.T) {
6160
_ = pty.ReadLine(ctx) // ignore echoed output
6261
output := pty.ReadLine(ctx)
6362

64-
cancel()
65-
<-cmdDone
66-
6763
// Then: we should only see the allowed tools in the response
6864
var toolsResponse struct {
6965
Result struct {
@@ -81,6 +77,20 @@ func TestExpMcpServer(t *testing.T) {
8177
}
8278
slices.Sort(foundTools)
8379
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
80+
81+
// Call the tool and ensure it works.
82+
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
83+
pty.WriteLine(toolPayload)
84+
_ = pty.ReadLine(ctx) // ignore echoed output
85+
output = pty.ReadLine(ctx)
86+
require.NotEmpty(t, output, "should have received a response from the tool")
87+
// Ensure it's valid JSON
88+
_, err = json.Marshal(output)
89+
require.NoError(t, err, "should have received a valid JSON response from the tool")
90+
// Ensure the tool returns the expected user
91+
require.Contains(t, output, owner.UserID.String(), "should have received the expected user ID")
92+
cancel()
93+
<-cmdDone
8494
})
8595

8696
t.Run("OK", func(t *testing.T) {

coderd/workspaceagents.go

+26-2
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,33 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
338338
Slug: req.AppSlug,
339339
})
340340
if err != nil {
341-
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
341+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
342342
Message: "Failed to get workspace app.",
343-
Detail: err.Error(),
343+
Detail: fmt.Sprintf("No app found with slug %q", req.AppSlug),
344+
})
345+
return
346+
}
347+
348+
if len(req.Message) > 160 {
349+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
350+
Message: "Message is too long.",
351+
Detail: "Message must be less than 160 characters.",
352+
Validations: []codersdk.ValidationError{
353+
{Field: "message", Detail: "Message must be less than 160 characters."},
354+
},
355+
})
356+
return
357+
}
358+
359+
switch req.State {
360+
case codersdk.WorkspaceAppStatusStateComplete, codersdk.WorkspaceAppStatusStateFailure, codersdk.WorkspaceAppStatusStateWorking: // valid states
361+
default:
362+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
363+
Message: "Invalid state provided.",
364+
Detail: fmt.Sprintf("invalid state: %q", req.State),
365+
Validations: []codersdk.ValidationError{
366+
{Field: "state", Detail: "State must be one of: complete, failure, working."},
367+
},
344368
})
345369
return
346370
}

coderd/workspaceagents_test.go

+64-19
Original file line numberDiff line numberDiff line change
@@ -340,27 +340,27 @@ func TestWorkspaceAgentLogs(t *testing.T) {
340340

341341
func TestWorkspaceAgentAppStatus(t *testing.T) {
342342
t.Parallel()
343-
t.Run("Success", func(t *testing.T) {
344-
t.Parallel()
345-
ctx := testutil.Context(t, testutil.WaitMedium)
346-
client, db := coderdtest.NewWithDatabase(t, nil)
347-
user := coderdtest.CreateFirstUser(t, client)
348-
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
343+
client, db := coderdtest.NewWithDatabase(t, nil)
344+
user := coderdtest.CreateFirstUser(t, client)
345+
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
349346

350-
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
351-
OrganizationID: user.OrganizationID,
352-
OwnerID: user2.ID,
353-
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
354-
a[0].Apps = []*proto.App{
355-
{
356-
Slug: "vscode",
357-
},
358-
}
359-
return a
360-
}).Do()
347+
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
348+
OrganizationID: user.OrganizationID,
349+
OwnerID: user2.ID,
350+
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
351+
a[0].Apps = []*proto.App{
352+
{
353+
Slug: "vscode",
354+
},
355+
}
356+
return a
357+
}).Do()
361358

362-
agentClient := agentsdk.New(client.URL)
363-
agentClient.SetSessionToken(r.AgentToken)
359+
agentClient := agentsdk.New(client.URL)
360+
agentClient.SetSessionToken(r.AgentToken)
361+
t.Run("Success", func(t *testing.T) {
362+
t.Parallel()
363+
ctx := testutil.Context(t, testutil.WaitShort)
364364
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
365365
AppSlug: "vscode",
366366
Message: "testing",
@@ -381,6 +381,51 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
381381
require.Empty(t, agent.Apps[0].Statuses[0].Icon)
382382
require.False(t, agent.Apps[0].Statuses[0].NeedsUserAttention)
383383
})
384+
385+
t.Run("FailUnknownApp", func(t *testing.T) {
386+
t.Parallel()
387+
ctx := testutil.Context(t, testutil.WaitShort)
388+
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
389+
AppSlug: "unknown",
390+
Message: "testing",
391+
URI: "https://example.com",
392+
State: codersdk.WorkspaceAppStatusStateComplete,
393+
})
394+
require.ErrorContains(t, err, "No app found with slug")
395+
var sdkErr *codersdk.Error
396+
require.ErrorAs(t, err, &sdkErr)
397+
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
398+
})
399+
400+
t.Run("FailUnknownState", func(t *testing.T) {
401+
t.Parallel()
402+
ctx := testutil.Context(t, testutil.WaitShort)
403+
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
404+
AppSlug: "vscode",
405+
Message: "testing",
406+
URI: "https://example.com",
407+
State: "unknown",
408+
})
409+
require.ErrorContains(t, err, "Invalid state")
410+
var sdkErr *codersdk.Error
411+
require.ErrorAs(t, err, &sdkErr)
412+
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
413+
})
414+
415+
t.Run("FailTooLong", func(t *testing.T) {
416+
t.Parallel()
417+
ctx := testutil.Context(t, testutil.WaitShort)
418+
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
419+
AppSlug: "vscode",
420+
Message: strings.Repeat("a", 161),
421+
URI: "https://example.com",
422+
State: codersdk.WorkspaceAppStatusStateComplete,
423+
})
424+
require.ErrorContains(t, err, "Message is too long")
425+
var sdkErr *codersdk.Error
426+
require.ErrorAs(t, err, &sdkErr)
427+
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
428+
})
384429
}
385430

386431
func TestWorkspaceAgentConnectRPC(t *testing.T) {

0 commit comments

Comments
 (0)