Skip to content

chore(codersdk/toolsdk): improve static analyzability of toolsdk.Tools #17562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 29, 2025
Prev Previous commit
Next Next commit
fix(codersdk/toolsdk): address type incompatibility issues
  • Loading branch information
johnstcn committed Apr 29, 2025
commit 9edd5f7ab10fa57afadc79f16fd9ea1441a1dc67
25 changes: 8 additions & 17 deletions cli/exp_mcp.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -697,7 +698,7 @@ func getAgentToken(fs afero.Fs) (string, error) {

// mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool.
// It assumes that the tool responds with a valid JSON object.
func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTool {
func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool {
// NOTE: some clients will silently refuse to use tools if there is an issue
// with the tool's schema or configuration.
if sdkTool.Schema.Properties == nil {
Expand All @@ -714,27 +715,17 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTo
},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := sdkTool.Handler(ctx, tb, request.Params.Arguments)
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(request.Params.Arguments); err != nil {
return nil, xerrors.Errorf("failed to encode request arguments: %w", err)
}
result, err := sdkTool.Handler(ctx, tb, buf.Bytes())
if err != nil {
return nil, err
}
var sb strings.Builder
if err := json.NewEncoder(&sb).Encode(result); err == nil {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.NewTextContent(sb.String()),
},
}, nil
}
// If the result is not JSON, return it as a string.
// This is a fallback for tools that return non-JSON data.
resultStr, ok := result.(string)
if !ok {
return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.NewTextContent(resultStr),
mcp.NewTextContent(string(result)),
},
}, nil
},
Expand Down
24 changes: 18 additions & 6 deletions cli/exp_mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ func TestExpMcpServer(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)
cmdDone := make(chan struct{})
cancelCtx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
t.Cleanup(func() {
cancel()
<-cmdDone
})

// Given: a running coder deployment
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
owner := coderdtest.CreateFirstUser(t, client)

// Given: we run the exp mcp command with allowed tools set
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user")
Expand All @@ -48,7 +52,6 @@ func TestExpMcpServer(t *testing.T) {
// nolint: gocritic // not the focus of this test
clitest.SetupConfig(t, client, root)

cmdDone := make(chan struct{})
go func() {
defer close(cmdDone)
err := inv.Run()
Expand All @@ -61,9 +64,6 @@ func TestExpMcpServer(t *testing.T) {
_ = pty.ReadLine(ctx) // ignore echoed output
output := pty.ReadLine(ctx)

cancel()
<-cmdDone

// Then: we should only see the allowed tools in the response
var toolsResponse struct {
Result struct {
Expand All @@ -81,6 +81,18 @@ func TestExpMcpServer(t *testing.T) {
}
slices.Sort(foundTools)
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)

// Call the tool and ensure it works.
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echoed output
output = pty.ReadLine(ctx)
require.NotEmpty(t, output, "should have received a response from the tool")
// Ensure it's valid JSON
_, err = json.Marshal(output)
require.NoError(t, err, "should have received a valid JSON response from the tool")
// Ensure the tool returns the expected user
require.Contains(t, output, owner.UserID.String(), "should have received the expected user ID")
})

t.Run("OK", func(t *testing.T) {
Expand Down
102 changes: 66 additions & 36 deletions codersdk/toolsdk/toolsdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package toolsdk

import (
"archive/tar"
"bytes"
"context"
"encoding/json"
"io"

"github.com/google/uuid"
Expand All @@ -20,28 +22,49 @@ type Deps struct {
AppStatusSlug string
}

// HandlerFunc is a function that handles a tool call.
// HandlerFunc is a typed function that handles a tool call.
type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error)

// Tool consists of an aisdk.Tool and a corresponding typed handler function.
type Tool[Arg, Ret any] struct {
aisdk.Tool
Handler HandlerFunc[Arg, Ret]
}

// Generic returns a type-erased version of the Tool.
func (t Tool[Arg, Ret]) Generic() Tool[any, any] {
return Tool[any, any]{
// Generic returns a type-erased version of a Tool where the arguments and
// return values are converted to/from json.RawMessage.
// This allows the tool to be referenced without knowing the concrete arguments
// or return values. The original HandlerFunc is wrapped to handle type
// conversion.
func (t Tool[Arg, Ret]) Generic() GenericTool {
return GenericTool{
Tool: t.Tool,
Handler: func(ctx context.Context, tb Deps, args any) (any, error) {
typedArg, ok := args.(Arg)
if !ok {
return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name)
Handler: wrap(func(ctx context.Context, tb Deps, args json.RawMessage) (json.RawMessage, error) {
var typedArgs Arg
if err := json.Unmarshal(args, &typedArgs); err != nil {
return nil, xerrors.Errorf("failed to unmarshal args: %w", err)
}
return t.Handler(ctx, tb, typedArg)
},
ret, err := t.Handler(ctx, tb, typedArgs)
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(ret); err != nil {
return json.RawMessage{}, err
}
return buf.Bytes(), err
}, WithCleanContext, WithRecover),
}
}

// GenericTool is a type-erased wrapper for Tool.
// This allows referencing the tool without knowing the concrete argument or
// return type. The Handler function allows calling the tool with known types.
type GenericTool struct {
aisdk.Tool
Handler GenericHandlerFunc
}

// GenericHandlerFunc is a function that handles a tool call.
type GenericHandlerFunc func(context.Context, Deps, json.RawMessage) (json.RawMessage, error)

type NoArgs struct{}

type ReportTaskArgs struct {
Expand Down Expand Up @@ -114,8 +137,8 @@ type UploadTarFileArgs struct {
}

// WithRecover wraps a GenericHandlerFunc to recover from panics and return an error.
func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
return func(ctx context.Context, tb Deps, args Arg) (ret Ret, err error) {
func WithRecover(h GenericHandlerFunc) GenericHandlerFunc {
return func(ctx context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) {
defer func() {
if r := recover(); r != nil {
err = xerrors.Errorf("tool handler panic: %v", r)
Expand All @@ -129,8 +152,8 @@ func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
// This ensures that no data is passed using context.Value.
// If a deadline is set on the parent context, it will be passed to the child
// context.
func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
return func(parent context.Context, tb Deps, args Arg) (ret Ret, err error) {
func WithCleanContext(h GenericHandlerFunc) GenericHandlerFunc {
return func(parent context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) {
child, childCancel := context.WithCancel(context.Background())
defer childCancel()
// Ensure that cancellation propagates from the parent context to the child context.
Expand All @@ -153,19 +176,18 @@ func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Re
}
}

// wrapAll wraps all provided tools with the given middleware function.
func wrapAll(mw func(HandlerFunc[any, any]) HandlerFunc[any, any], tools ...Tool[any, any]) []Tool[any, any] {
for i, t := range tools {
t.Handler = mw(t.Handler)
tools[i] = t
// wrap wraps the provided GenericHandlerFunc with the provided middleware functions.
func wrap(hf GenericHandlerFunc, mw ...func(GenericHandlerFunc) GenericHandlerFunc) GenericHandlerFunc {
for _, m := range mw {
hf = m(hf)
}
return tools
return hf
}

var (
// All is a list of all tools that can be used in the Coder CLI.
// When you add a new tool, be sure to include it here!
All = wrapAll(WithCleanContext, wrapAll(WithRecover,
All = []GenericTool{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be useful to have different lists depending on whether you have an agent client and slug.

CreateTemplate.Generic(),
CreateTemplateVersion.Generic(),
CreateWorkspace.Generic(),
Expand All @@ -182,9 +204,9 @@ var (
ReportTask.Generic(),
UploadTarFile.Generic(),
UpdateTemplateActiveVersion.Generic(),
)...)
}

ReportTask = Tool[ReportTaskArgs, string]{
ReportTask = Tool[ReportTaskArgs, codersdk.Response]{
Tool: aisdk.Tool{
Name: "coder_report_task",
Description: "Report progress on a user task in Coder.",
Expand All @@ -211,22 +233,24 @@ var (
Required: []string{"summary", "link", "state"},
},
},
Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (string, error) {
Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (codersdk.Response, error) {
if tb.AgentClient == nil {
return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
}
if tb.AppStatusSlug == "" {
return "", xerrors.New("workspace app status slug not found in toolbox")
return codersdk.Response{}, xerrors.New("workspace app status slug not found in toolbox")
}
if err := tb.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: tb.AppStatusSlug,
Message: args.Summary,
URI: args.Link,
State: codersdk.WorkspaceAppStatusState(args.State),
}); err != nil {
return "", err
return codersdk.Response{}, err
}
return "Thanks for reporting!", nil
return codersdk.Response{
Message: "Thanks for reporting!",
}, nil
},
}

Expand Down Expand Up @@ -934,9 +958,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t
if err != nil {
return codersdk.TemplateVersion{}, xerrors.Errorf("file_id must be a valid UUID: %w", err)
}
templateID, err := uuid.Parse(args.TemplateID)
if err != nil {
return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
var templateID uuid.UUID
if args.TemplateID != "" {
tid, err := uuid.Parse(args.TemplateID)
if err != nil {
return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
}
templateID = tid
}
templateVersion, err := tb.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{
Message: "Created by AI",
Expand Down Expand Up @@ -1183,7 +1211,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
},
}

DeleteTemplate = Tool[DeleteTemplateArgs, string]{
DeleteTemplate = Tool[DeleteTemplateArgs, codersdk.Response]{
Tool: aisdk.Tool{
Name: "coder_delete_template",
Description: "Delete a template. This is irreversible.",
Expand All @@ -1195,16 +1223,18 @@ The file_id provided is a reference to a tar file you have uploaded containing t
},
},
},
Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (string, error) {
Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (codersdk.Response, error) {
templateID, err := uuid.Parse(args.TemplateID)
if err != nil {
return "", xerrors.Errorf("template_id must be a valid UUID: %w", err)
return codersdk.Response{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
}
err = tb.CoderClient.DeleteTemplate(ctx, templateID)
if err != nil {
return "", err
return codersdk.Response{}, err
}
return "Successfully deleted template!", nil
return codersdk.Response{
Message: "Template deleted successfully.",
}, nil
},
}
)
Expand Down
Loading