Skip to content

Commit 4a9ab00

Browse files
committed
fix(codersdk/toolsdk): address type incompatibility issues
1 parent f10c081 commit 4a9ab00

File tree

4 files changed

+162
-119
lines changed

4 files changed

+162
-119
lines changed

cli/exp_mcp.go

+8-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cli
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"errors"
@@ -697,7 +698,7 @@ func getAgentToken(fs afero.Fs) (string, error) {
697698

698699
// mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool.
699700
// It assumes that the tool responds with a valid JSON object.
700-
func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTool {
701+
func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool {
701702
// NOTE: some clients will silently refuse to use tools if there is an issue
702703
// with the tool's schema or configuration.
703704
if sdkTool.Schema.Properties == nil {
@@ -714,27 +715,17 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTo
714715
},
715716
},
716717
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
717-
result, err := sdkTool.Handler(ctx, tb, request.Params.Arguments)
718+
var buf bytes.Buffer
719+
if err := json.NewEncoder(&buf).Encode(request.Params.Arguments); err != nil {
720+
return nil, xerrors.Errorf("failed to encode request arguments: %w", err)
721+
}
722+
result, err := sdkTool.Handler(ctx, tb, buf.Bytes())
718723
if err != nil {
719724
return nil, err
720725
}
721-
var sb strings.Builder
722-
if err := json.NewEncoder(&sb).Encode(result); err == nil {
723-
return &mcp.CallToolResult{
724-
Content: []mcp.Content{
725-
mcp.NewTextContent(sb.String()),
726-
},
727-
}, nil
728-
}
729-
// If the result is not JSON, return it as a string.
730-
// This is a fallback for tools that return non-JSON data.
731-
resultStr, ok := result.(string)
732-
if !ok {
733-
return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result)
734-
}
735726
return &mcp.CallToolResult{
736727
Content: []mcp.Content{
737-
mcp.NewTextContent(resultStr),
728+
mcp.NewTextContent(string(result)),
738729
},
739730
}, nil
740731
},

cli/exp_mcp_test.go

+18-6
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,16 @@ func TestExpMcpServer(t *testing.T) {
3131
t.Parallel()
3232

3333
ctx := testutil.Context(t, testutil.WaitShort)
34+
cmdDone := make(chan struct{})
3435
cancelCtx, cancel := context.WithCancel(ctx)
35-
t.Cleanup(cancel)
36+
t.Cleanup(func() {
37+
cancel()
38+
<-cmdDone
39+
})
3640

3741
// Given: a running coder deployment
3842
client := coderdtest.New(t, nil)
39-
_ = coderdtest.CreateFirstUser(t, client)
43+
owner := coderdtest.CreateFirstUser(t, client)
4044

4145
// Given: we run the exp mcp command with allowed tools set
4246
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user")
@@ -48,7 +52,6 @@ func TestExpMcpServer(t *testing.T) {
4852
// nolint: gocritic // not the focus of this test
4953
clitest.SetupConfig(t, client, root)
5054

51-
cmdDone := make(chan struct{})
5255
go func() {
5356
defer close(cmdDone)
5457
err := inv.Run()
@@ -61,9 +64,6 @@ func TestExpMcpServer(t *testing.T) {
6164
_ = pty.ReadLine(ctx) // ignore echoed output
6265
output := pty.ReadLine(ctx)
6366

64-
cancel()
65-
<-cmdDone
66-
6767
// Then: we should only see the allowed tools in the response
6868
var toolsResponse struct {
6969
Result struct {
@@ -81,6 +81,18 @@ func TestExpMcpServer(t *testing.T) {
8181
}
8282
slices.Sort(foundTools)
8383
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
84+
85+
// Call the tool and ensure it works.
86+
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
87+
pty.WriteLine(toolPayload)
88+
_ = pty.ReadLine(ctx) // ignore echoed output
89+
output = pty.ReadLine(ctx)
90+
require.NotEmpty(t, output, "should have received a response from the tool")
91+
// Ensure it's valid JSON
92+
_, err = json.Marshal(output)
93+
require.NoError(t, err, "should have received a valid JSON response from the tool")
94+
// Ensure the tool returns the expected user
95+
require.Contains(t, output, owner.UserID.String(), "should have received the expected user ID")
8496
})
8597

8698
t.Run("OK", func(t *testing.T) {

codersdk/toolsdk/toolsdk.go

+66-36
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package toolsdk
22

33
import (
44
"archive/tar"
5+
"bytes"
56
"context"
7+
"encoding/json"
68
"io"
79

810
"github.com/google/uuid"
@@ -20,28 +22,49 @@ type Deps struct {
2022
AppStatusSlug string
2123
}
2224

23-
// HandlerFunc is a function that handles a tool call.
25+
// HandlerFunc is a typed function that handles a tool call.
2426
type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error)
2527

28+
// Tool consists of an aisdk.Tool and a corresponding typed handler function.
2629
type Tool[Arg, Ret any] struct {
2730
aisdk.Tool
2831
Handler HandlerFunc[Arg, Ret]
2932
}
3033

31-
// Generic returns a type-erased version of the Tool.
32-
func (t Tool[Arg, Ret]) Generic() Tool[any, any] {
33-
return Tool[any, any]{
34+
// Generic returns a type-erased version of a TypedTool where the arguments and
35+
// return values are converted to/from json.RawMessage.
36+
// This allows the tool to be referenced without knowing the concrete arguments
37+
// or return values. The original TypedHandlerFunc is wrapped to handle type
38+
// conversion.
39+
func (t Tool[Arg, Ret]) Generic() GenericTool {
40+
return GenericTool{
3441
Tool: t.Tool,
35-
Handler: func(ctx context.Context, tb Deps, args any) (any, error) {
36-
typedArg, ok := args.(Arg)
37-
if !ok {
38-
return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name)
42+
Handler: wrap(func(ctx context.Context, tb Deps, args json.RawMessage) (json.RawMessage, error) {
43+
var typedArgs Arg
44+
if err := json.Unmarshal(args, &typedArgs); err != nil {
45+
return nil, xerrors.Errorf("failed to unmarshal args: %w", err)
3946
}
40-
return t.Handler(ctx, tb, typedArg)
41-
},
47+
ret, err := t.Handler(ctx, tb, typedArgs)
48+
var buf bytes.Buffer
49+
if err := json.NewEncoder(&buf).Encode(ret); err != nil {
50+
return json.RawMessage{}, err
51+
}
52+
return buf.Bytes(), err
53+
}, WithCleanContext, WithRecover),
4254
}
4355
}
4456

57+
// GenericTool is a type-erased wrapper for GenericTool.
58+
// This allows referencing the tool without knowing the concrete argument or
59+
// return type. The Handler function allows calling the tool with known types.
60+
type GenericTool struct {
61+
aisdk.Tool
62+
Handler GenericHandlerFunc
63+
}
64+
65+
// GenericHandlerFunc is a function that handles a tool call.
66+
type GenericHandlerFunc func(context.Context, Deps, json.RawMessage) (json.RawMessage, error)
67+
4568
type NoArgs struct{}
4669

4770
type ReportTaskArgs struct {
@@ -114,8 +137,8 @@ type UploadTarFileArgs struct {
114137
}
115138

116139
// WithRecover wraps a HandlerFunc to recover from panics and return an error.
117-
func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
118-
return func(ctx context.Context, tb Deps, args Arg) (ret Ret, err error) {
140+
func WithRecover(h GenericHandlerFunc) GenericHandlerFunc {
141+
return func(ctx context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) {
119142
defer func() {
120143
if r := recover(); r != nil {
121144
err = xerrors.Errorf("tool handler panic: %v", r)
@@ -129,8 +152,8 @@ func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
129152
// This ensures that no data is passed using context.Value.
130153
// If a deadline is set on the parent context, it will be passed to the child
131154
// context.
132-
func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
133-
return func(parent context.Context, tb Deps, args Arg) (ret Ret, err error) {
155+
func WithCleanContext(h GenericHandlerFunc) GenericHandlerFunc {
156+
return func(parent context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) {
134157
child, childCancel := context.WithCancel(context.Background())
135158
defer childCancel()
136159
// Ensure that cancellation propagates from the parent context to the child context.
@@ -153,19 +176,18 @@ func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Re
153176
}
154177
}
155178

156-
// wrapAll wraps all provided tools with the given middleware function.
157-
func wrapAll(mw func(HandlerFunc[any, any]) HandlerFunc[any, any], tools ...Tool[any, any]) []Tool[any, any] {
158-
for i, t := range tools {
159-
t.Handler = mw(t.Handler)
160-
tools[i] = t
179+
// wrap wraps the provided GenericHandlerFunc with the provided middleware functions.
180+
func wrap(hf GenericHandlerFunc, mw ...func(GenericHandlerFunc) GenericHandlerFunc) GenericHandlerFunc {
181+
for _, m := range mw {
182+
hf = m(hf)
161183
}
162-
return tools
184+
return hf
163185
}
164186

165187
var (
166188
// All is a list of all tools that can be used in the Coder CLI.
167189
// When you add a new tool, be sure to include it here!
168-
All = wrapAll(WithCleanContext, wrapAll(WithRecover,
190+
All = []GenericTool{
169191
CreateTemplate.Generic(),
170192
CreateTemplateVersion.Generic(),
171193
CreateWorkspace.Generic(),
@@ -182,9 +204,9 @@ var (
182204
ReportTask.Generic(),
183205
UploadTarFile.Generic(),
184206
UpdateTemplateActiveVersion.Generic(),
185-
)...)
207+
}
186208

187-
ReportTask = Tool[ReportTaskArgs, string]{
209+
ReportTask = Tool[ReportTaskArgs, codersdk.Response]{
188210
Tool: aisdk.Tool{
189211
Name: "coder_report_task",
190212
Description: "Report progress on a user task in Coder.",
@@ -211,22 +233,24 @@ var (
211233
Required: []string{"summary", "link", "state"},
212234
},
213235
},
214-
Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (string, error) {
236+
Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (codersdk.Response, error) {
215237
if tb.AgentClient == nil {
216-
return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
238+
return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
217239
}
218240
if tb.AppStatusSlug == "" {
219-
return "", xerrors.New("workspace app status slug not found in toolbox")
241+
return codersdk.Response{}, xerrors.New("workspace app status slug not found in toolbox")
220242
}
221243
if err := tb.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
222244
AppSlug: tb.AppStatusSlug,
223245
Message: args.Summary,
224246
URI: args.Link,
225247
State: codersdk.WorkspaceAppStatusState(args.State),
226248
}); err != nil {
227-
return "", err
249+
return codersdk.Response{}, err
228250
}
229-
return "Thanks for reporting!", nil
251+
return codersdk.Response{
252+
Message: "Thanks for reporting!",
253+
}, nil
230254
},
231255
}
232256

@@ -934,9 +958,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t
934958
if err != nil {
935959
return codersdk.TemplateVersion{}, xerrors.Errorf("file_id must be a valid UUID: %w", err)
936960
}
937-
templateID, err := uuid.Parse(args.TemplateID)
938-
if err != nil {
939-
return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
961+
var templateID uuid.UUID
962+
if args.TemplateID != "" {
963+
tid, err := uuid.Parse(args.TemplateID)
964+
if err != nil {
965+
return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
966+
}
967+
templateID = tid
940968
}
941969
templateVersion, err := tb.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{
942970
Message: "Created by AI",
@@ -1183,7 +1211,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
11831211
},
11841212
}
11851213

1186-
DeleteTemplate = Tool[DeleteTemplateArgs, string]{
1214+
DeleteTemplate = Tool[DeleteTemplateArgs, codersdk.Response]{
11871215
Tool: aisdk.Tool{
11881216
Name: "coder_delete_template",
11891217
Description: "Delete a template. This is irreversible.",
@@ -1195,16 +1223,18 @@ The file_id provided is a reference to a tar file you have uploaded containing t
11951223
},
11961224
},
11971225
},
1198-
Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (string, error) {
1226+
Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (codersdk.Response, error) {
11991227
templateID, err := uuid.Parse(args.TemplateID)
12001228
if err != nil {
1201-
return "", xerrors.Errorf("template_id must be a valid UUID: %w", err)
1229+
return codersdk.Response{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
12021230
}
12031231
err = tb.CoderClient.DeleteTemplate(ctx, templateID)
12041232
if err != nil {
1205-
return "", err
1233+
return codersdk.Response{}, err
12061234
}
1207-
return "Successfully deleted template!", nil
1235+
return codersdk.Response{
1236+
Message: "Template deleted successfully.",
1237+
}, nil
12081238
},
12091239
}
12101240
)

0 commit comments

Comments
 (0)