diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index f2573f44a1be6..1144d9265aa15 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -77,15 +77,12 @@ func TestCoderTools(t *testing.T) { pty.WriteLine(ctr) _ = pty.ReadLine(ctx) // skip the echo - templates, err := memberClient.Templates(ctx, codersdk.TemplateFilter{}) + // Then: the response is a list of expected visible to the user. + expected, err := memberClient.Templates(ctx, codersdk.TemplateFilter{}) require.NoError(t, err) - templatesJSON, err := json.Marshal(templates) - require.NoError(t, err) - - // Then: the response is a list of templates visible to the user. - expected := makeJSONRPCTextResponse(t, string(templatesJSON)) - actual := pty.ReadLine(ctx) - testutil.RequireJSONEq(t, expected, actual) + actual := unmarshalFromCallToolResult[[]codersdk.Template](t, pty.ReadLine(ctx)) + require.Len(t, actual, 1) + require.Equal(t, expected[0].ID, actual[0].ID) }) t.Run("coder_report_task", func(t *testing.T) { @@ -111,20 +108,16 @@ func TestCoderTools(t *testing.T) { t.Run("coder_whoami", func(t *testing.T) { // When: the coder_whoami tool is called - me, err := memberClient.User(ctx, codersdk.Me) - require.NoError(t, err) - meJSON, err := json.Marshal(me) - require.NoError(t, err) - ctr := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{}) pty.WriteLine(ctr) _ = pty.ReadLine(ctx) // skip the echo // Then: the response is a valid JSON respresentation of the calling user. - expected := makeJSONRPCTextResponse(t, string(meJSON)) - actual := pty.ReadLine(ctx) - testutil.RequireJSONEq(t, expected, actual) + expected, err := memberClient.User(ctx, codersdk.Me) + require.NoError(t, err) + actual := unmarshalFromCallToolResult[codersdk.User](t, pty.ReadLine(ctx)) + require.Equal(t, expected.ID, actual.ID) }) t.Run("coder_list_workspaces", func(t *testing.T) { @@ -138,15 +131,10 @@ func TestCoderTools(t *testing.T) { pty.WriteLine(ctr) _ = pty.ReadLine(ctx) // skip the echo - ws, err := memberClient.Workspaces(ctx, codersdk.WorkspaceFilter{}) - require.NoError(t, err) - wsJSON, err := json.Marshal(ws) - require.NoError(t, err) - // Then: the response is a valid JSON respresentation of the calling user's workspaces. - expected := makeJSONRPCTextResponse(t, string(wsJSON)) - actual := pty.ReadLine(ctx) - testutil.RequireJSONEq(t, expected, actual) + actual := unmarshalFromCallToolResult[codersdk.WorkspacesResponse](t, pty.ReadLine(ctx)) + require.Len(t, actual.Workspaces, 1, "expected 1 workspace") + require.Equal(t, r.Workspace.ID, actual.Workspaces[0].ID, "expected the workspace to be the one we created in setup") }) t.Run("coder_get_workspace", func(t *testing.T) { @@ -161,15 +149,12 @@ func TestCoderTools(t *testing.T) { pty.WriteLine(ctr) _ = pty.ReadLine(ctx) // skip the echo - ws, err := memberClient.Workspace(ctx, r.Workspace.ID) - require.NoError(t, err) - wsJSON, err := json.Marshal(ws) + expected, err := memberClient.Workspace(ctx, r.Workspace.ID) require.NoError(t, err) // Then: the response is a valid JSON respresentation of the workspace. - expected := makeJSONRPCTextResponse(t, string(wsJSON)) - actual := pty.ReadLine(ctx) - testutil.RequireJSONEq(t, expected, actual) + actual := unmarshalFromCallToolResult[codersdk.Workspace](t, pty.ReadLine(ctx)) + require.Equal(t, expected.ID, actual.ID) }) // NOTE: this test runs after the list_workspaces tool is called. @@ -322,6 +307,25 @@ func makeJSONRPCTextResponse(t *testing.T, text string) string { return string(bs) } +func unmarshalFromCallToolResult[T any](t *testing.T, raw string) T { + t.Helper() + + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(raw), &resp), "failed to unmarshal JSON RPC response") + res, ok := resp["result"].(map[string]any) + require.True(t, ok, "expected a result field in the response") + ct, ok := res["content"].([]any) + require.True(t, ok, "expected a content field in the result") + require.Len(t, ct, 1, "expected a single content item in the result") + ct0, ok := ct[0].(map[string]any) + require.True(t, ok, "expected a content item in the result") + txt, ok := ct0["text"].(string) + require.True(t, ok, "expected a text field in the content item") + var actual T + require.NoError(t, json.Unmarshal([]byte(txt), &actual), "failed to unmarshal content") + return actual +} + // startTestMCPServer is a helper function that starts a MCP server listening on // a pty. It is the responsibility of the caller to close the server. func startTestMCPServer(ctx context.Context, t testing.TB, stdin io.Reader, stdout io.Writer) (*server.MCPServer, func() error) {