Skip to content

Commit 233f9da

Browse files
committed
address more PR comments
1 parent 2462f76 commit 233f9da

File tree

3 files changed

+77
-36
lines changed

3 files changed

+77
-36
lines changed

cli/exp_mcp.go

+12-9
Original file line numberDiff line numberDiff line change
@@ -400,24 +400,27 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
400400
server.WithInstructions(instructions),
401401
)
402402

403-
// Create a new context for the tools with all relevant information.
404-
tb := toolsdk.Deps{
405-
CoderClient: client,
406-
}
407403
// Get the workspace agent token from the environment.
404+
toolOpts := make([]func(*toolsdk.Deps), 0)
408405
var hasAgentClient bool
409406
if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" {
410407
hasAgentClient = true
411408
agentClient := agentsdk.New(client.URL)
412409
agentClient.SetSessionToken(agentToken)
413-
tb.AgentClient = agentClient
410+
toolOpts = append(toolOpts, toolsdk.WithAgentClient(agentClient))
414411
} else {
415412
cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available")
416413
}
417-
if appStatusSlug == "" {
418-
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")
414+
415+
if appStatusSlug != "" {
416+
toolOpts = append(toolOpts, toolsdk.WithAppStatusSlug(appStatusSlug))
419417
} else {
420-
tb.AppStatusSlug = appStatusSlug
418+
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")
419+
}
420+
421+
toolDeps, err := toolsdk.NewDeps(client, toolOpts...)
422+
if err != nil {
423+
return xerrors.Errorf("failed to initialize tool dependencies: %w", err)
421424
}
422425

423426
// Register tools based on the allowlist (if specified)
@@ -430,7 +433,7 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
430433
if len(allowedTools) == 0 || slices.ContainsFunc(allowedTools, func(t string) bool {
431434
return t == tool.Tool.Name
432435
}) {
433-
mcpSrv.AddTools(mcpFromSDK(tool, tb))
436+
mcpSrv.AddTools(mcpFromSDK(tool, toolDeps))
434437
}
435438
}
436439

codersdk/toolsdk/toolsdk.go

+26-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,31 @@ import (
1515
"github.com/coder/coder/v2/codersdk/agentsdk"
1616
)
1717

18+
func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
19+
d := Deps{
20+
CoderClient: client,
21+
}
22+
for _, opt := range opts {
23+
opt(&d)
24+
}
25+
if d.CoderClient == nil {
26+
return Deps{}, xerrors.New("developer error: coder client may not be nil")
27+
}
28+
return d, nil
29+
}
30+
31+
func WithAgentClient(client *agentsdk.Client) func(*Deps) {
32+
return func(d *Deps) {
33+
d.AgentClient = client
34+
}
35+
}
36+
37+
func WithAppStatusSlug(slug string) func(*Deps) {
38+
return func(d *Deps) {
39+
d.AppStatusSlug = slug
40+
}
41+
}
42+
1843
// Deps provides access to tool dependencies.
1944
type Deps struct {
2045
CoderClient *codersdk.Client
@@ -175,7 +200,7 @@ var ReportTask = Tool[ReportTaskArgs, codersdk.Response]{
175200
return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
176201
}
177202
if deps.AppStatusSlug == "" {
178-
return codersdk.Response{}, xerrors.New("workspace app status slug not found in toolbox")
203+
return codersdk.Response{}, xerrors.New("tool unavailable as CODER_MCP_APP_STATUS_SLUG is not set")
179204
}
180205
if len(args.Summary) > 160 {
181206
return codersdk.Response{}, xerrors.New("summary must be less than 160 characters")

codersdk/toolsdk/toolsdk_test.go

+39-26
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,9 @@ func TestTools(t *testing.T) {
7272
})
7373

7474
t.Run("ReportTask", func(t *testing.T) {
75-
tb := toolsdk.Deps{
76-
CoderClient: memberClient,
77-
AgentClient: agentClient,
78-
AppStatusSlug: "some-agent-app",
79-
}
80-
_, err := testTool(t, toolsdk.ReportTask, tb, toolsdk.ReportTaskArgs{
75+
tb, err := toolsdk.NewDeps(memberClient, toolsdk.WithAgentClient(agentClient), toolsdk.WithAppStatusSlug("some-agent-app"))
76+
require.NoError(t, err)
77+
_, err = testTool(t, toolsdk.ReportTask, tb, toolsdk.ReportTaskArgs{
8178
Summary: "test summary",
8279
State: "complete",
8380
Link: "https://example.com",
@@ -86,7 +83,8 @@ func TestTools(t *testing.T) {
8683
})
8784

8885
t.Run("GetWorkspace", func(t *testing.T) {
89-
tb := toolsdk.Deps{CoderClient: memberClient}
86+
tb, err := toolsdk.NewDeps(memberClient)
87+
require.NoError(t, err)
9088
result, err := testTool(t, toolsdk.GetWorkspace, tb, toolsdk.GetWorkspaceArgs{
9189
WorkspaceID: r.Workspace.ID.String(),
9290
})
@@ -96,7 +94,8 @@ func TestTools(t *testing.T) {
9694
})
9795

9896
t.Run("ListTemplates", func(t *testing.T) {
99-
tb := toolsdk.Deps{CoderClient: memberClient}
97+
tb, err := toolsdk.NewDeps(memberClient)
98+
require.NoError(t, err)
10099
// Get the templates directly for comparison
101100
expected, err := memberClient.Templates(context.Background(), codersdk.TemplateFilter{})
102101
require.NoError(t, err)
@@ -119,7 +118,8 @@ func TestTools(t *testing.T) {
119118
})
120119

121120
t.Run("Whoami", func(t *testing.T) {
122-
tb := toolsdk.Deps{CoderClient: memberClient}
121+
tb, err := toolsdk.NewDeps(memberClient)
122+
require.NoError(t, err)
123123
result, err := testTool(t, toolsdk.GetAuthenticatedUser, tb, toolsdk.NoArgs{})
124124

125125
require.NoError(t, err)
@@ -128,7 +128,8 @@ func TestTools(t *testing.T) {
128128
})
129129

130130
t.Run("ListWorkspaces", func(t *testing.T) {
131-
tb := toolsdk.Deps{CoderClient: memberClient}
131+
tb, err := toolsdk.NewDeps(memberClient)
132+
require.NoError(t, err)
132133
result, err := testTool(t, toolsdk.ListWorkspaces, tb, toolsdk.ListWorkspacesArgs{})
133134

134135
require.NoError(t, err)
@@ -140,7 +141,8 @@ func TestTools(t *testing.T) {
140141
t.Run("CreateWorkspaceBuild", func(t *testing.T) {
141142
t.Run("Stop", func(t *testing.T) {
142143
ctx := testutil.Context(t, testutil.WaitShort)
143-
tb := toolsdk.Deps{CoderClient: memberClient}
144+
tb, err := toolsdk.NewDeps(memberClient)
145+
require.NoError(t, err)
144146
result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{
145147
WorkspaceID: r.Workspace.ID.String(),
146148
Transition: "stop",
@@ -159,7 +161,8 @@ func TestTools(t *testing.T) {
159161

160162
t.Run("Start", func(t *testing.T) {
161163
ctx := testutil.Context(t, testutil.WaitShort)
162-
tb := toolsdk.Deps{CoderClient: memberClient}
164+
tb, err := toolsdk.NewDeps(memberClient)
165+
require.NoError(t, err)
163166
result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{
164167
WorkspaceID: r.Workspace.ID.String(),
165168
Transition: "start",
@@ -178,7 +181,8 @@ func TestTools(t *testing.T) {
178181

179182
t.Run("TemplateVersionChange", func(t *testing.T) {
180183
ctx := testutil.Context(t, testutil.WaitShort)
181-
tb := toolsdk.Deps{CoderClient: memberClient}
184+
tb, err := toolsdk.NewDeps(memberClient)
185+
require.NoError(t, err)
182186
// Get the current template version ID before updating
183187
workspace, err := memberClient.Workspace(ctx, r.Workspace.ID)
184188
require.NoError(t, err)
@@ -222,7 +226,8 @@ func TestTools(t *testing.T) {
222226
})
223227

224228
t.Run("ListTemplateVersionParameters", func(t *testing.T) {
225-
tb := toolsdk.Deps{CoderClient: memberClient}
229+
tb, err := toolsdk.NewDeps(memberClient)
230+
require.NoError(t, err)
226231
params, err := testTool(t, toolsdk.ListTemplateVersionParameters, tb, toolsdk.ListTemplateVersionParametersArgs{
227232
TemplateVersionID: r.TemplateVersion.ID.String(),
228233
})
@@ -232,7 +237,8 @@ func TestTools(t *testing.T) {
232237
})
233238

234239
t.Run("GetWorkspaceAgentLogs", func(t *testing.T) {
235-
tb := toolsdk.Deps{CoderClient: client}
240+
tb, err := toolsdk.NewDeps(memberClient)
241+
require.NoError(t, err)
236242
logs, err := testTool(t, toolsdk.GetWorkspaceAgentLogs, tb, toolsdk.GetWorkspaceAgentLogsArgs{
237243
WorkspaceAgentID: agentID.String(),
238244
})
@@ -242,7 +248,8 @@ func TestTools(t *testing.T) {
242248
})
243249

244250
t.Run("GetWorkspaceBuildLogs", func(t *testing.T) {
245-
tb := toolsdk.Deps{CoderClient: memberClient}
251+
tb, err := toolsdk.NewDeps(memberClient)
252+
require.NoError(t, err)
246253
logs, err := testTool(t, toolsdk.GetWorkspaceBuildLogs, tb, toolsdk.GetWorkspaceBuildLogsArgs{
247254
WorkspaceBuildID: r.Build.ID.String(),
248255
})
@@ -252,7 +259,8 @@ func TestTools(t *testing.T) {
252259
})
253260

254261
t.Run("GetTemplateVersionLogs", func(t *testing.T) {
255-
tb := toolsdk.Deps{CoderClient: memberClient}
262+
tb, err := toolsdk.NewDeps(memberClient)
263+
require.NoError(t, err)
256264
logs, err := testTool(t, toolsdk.GetTemplateVersionLogs, tb, toolsdk.GetTemplateVersionLogsArgs{
257265
TemplateVersionID: r.TemplateVersion.ID.String(),
258266
})
@@ -262,7 +270,8 @@ func TestTools(t *testing.T) {
262270
})
263271

264272
t.Run("UpdateTemplateActiveVersion", func(t *testing.T) {
265-
tb := toolsdk.Deps{CoderClient: client}
273+
tb, err := toolsdk.NewDeps(client)
274+
require.NoError(t, err)
266275
result, err := testTool(t, toolsdk.UpdateTemplateActiveVersion, tb, toolsdk.UpdateTemplateActiveVersionArgs{
267276
TemplateID: r.Template.ID.String(),
268277
TemplateVersionID: r.TemplateVersion.ID.String(),
@@ -273,8 +282,9 @@ func TestTools(t *testing.T) {
273282
})
274283

275284
t.Run("DeleteTemplate", func(t *testing.T) {
276-
tb := toolsdk.Deps{CoderClient: client}
277-
_, err := testTool(t, toolsdk.DeleteTemplate, tb, toolsdk.DeleteTemplateArgs{
285+
tb, err := toolsdk.NewDeps(client)
286+
require.NoError(t, err)
287+
_, err = testTool(t, toolsdk.DeleteTemplate, tb, toolsdk.DeleteTemplateArgs{
278288
TemplateID: r.Template.ID.String(),
279289
})
280290

@@ -283,10 +293,11 @@ func TestTools(t *testing.T) {
283293
})
284294

285295
t.Run("UploadTarFile", func(t *testing.T) {
286-
tb := toolsdk.Deps{CoderClient: client}
287296
files := map[string]string{
288297
"main.tf": `resource "null_resource" "example" {}`,
289298
}
299+
tb, err := toolsdk.NewDeps(memberClient)
300+
require.NoError(t, err)
290301

291302
result, err := testTool(t, toolsdk.UploadTarFile, tb, toolsdk.UploadTarFileArgs{
292303
Files: files,
@@ -297,7 +308,8 @@ func TestTools(t *testing.T) {
297308
})
298309

299310
t.Run("CreateTemplateVersion", func(t *testing.T) {
300-
tb := toolsdk.Deps{CoderClient: client}
311+
tb, err := toolsdk.NewDeps(client)
312+
require.NoError(t, err)
301313
// nolint:gocritic // This is in a test package and does not end up in the build
302314
file := dbgen.File(t, store, database.File{})
303315
t.Run("WithoutTemplateID", func(t *testing.T) {
@@ -308,7 +320,6 @@ func TestTools(t *testing.T) {
308320
require.NotEmpty(t, tv)
309321
})
310322
t.Run("WithTemplateID", func(t *testing.T) {
311-
tb := toolsdk.Deps{CoderClient: client}
312323
tv, err := testTool(t, toolsdk.CreateTemplateVersion, tb, toolsdk.CreateTemplateVersionArgs{
313324
FileID: file.ID.String(),
314325
TemplateID: r.Template.ID.String(),
@@ -319,15 +330,16 @@ func TestTools(t *testing.T) {
319330
})
320331

321332
t.Run("CreateTemplate", func(t *testing.T) {
322-
tb := toolsdk.Deps{CoderClient: client}
333+
tb, err := toolsdk.NewDeps(client)
334+
require.NoError(t, err)
323335
// Create a new template version for use here.
324336
tv := dbfake.TemplateVersion(t, store).
325337
// nolint:gocritic // This is in a test package and does not end up in the build
326338
Seed(database.TemplateVersion{OrganizationID: owner.OrganizationID, CreatedBy: owner.UserID}).
327339
SkipCreateTemplate().Do()
328340

329341
// We're going to re-use the pre-existing template version
330-
_, err := testTool(t, toolsdk.CreateTemplate, tb, toolsdk.CreateTemplateArgs{
342+
_, err = testTool(t, toolsdk.CreateTemplate, tb, toolsdk.CreateTemplateArgs{
331343
Name: testutil.GetRandomNameHyphenated(t),
332344
DisplayName: "Test Template",
333345
Description: "This is a test template",
@@ -338,7 +350,8 @@ func TestTools(t *testing.T) {
338350
})
339351

340352
t.Run("CreateWorkspace", func(t *testing.T) {
341-
tb := toolsdk.Deps{CoderClient: memberClient}
353+
tb, err := toolsdk.NewDeps(client)
354+
require.NoError(t, err)
342355
// We need a template version ID to create a workspace
343356
res, err := testTool(t, toolsdk.CreateWorkspace, tb, toolsdk.CreateWorkspaceArgs{
344357
User: "me",

0 commit comments

Comments
 (0)