Skip to content

chore(codersdk/toolsdk): improve static analyzability of toolsdk.Tools #17562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 29, 2025
Prev Previous commit
Next Next commit
add WithCleanContext middleware func
  • Loading branch information
johnstcn committed Apr 29, 2025
commit c1057d930089681e5ef934b4f0f64e962394a191
4 changes: 2 additions & 2 deletions cli/exp_mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,8 +713,8 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTo
Required: sdkTool.Schema.Required,
},
},
Handler: func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := sdkTool.Handler(tb, request.Params.Arguments)
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := sdkTool.Handler(ctx, tb, request.Params.Arguments)
if err != nil {
return nil, err
}
Expand Down
110 changes: 69 additions & 41 deletions codersdk/toolsdk/toolsdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type Deps struct {
}

// HandlerFunc is a function that handles a tool call.
type HandlerFunc[Arg, Ret any] func(tb Deps, args Arg) (Ret, error)
type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error)

type Tool[Arg, Ret any] struct {
aisdk.Tool
Expand All @@ -32,12 +32,12 @@ type Tool[Arg, Ret any] struct {
func (t Tool[Arg, Ret]) Generic() Tool[any, any] {
return Tool[any, any]{
Tool: t.Tool,
Handler: func(tb Deps, args any) (any, error) {
Handler: func(ctx context.Context, tb Deps, args any) (any, error) {
typedArg, ok := args.(Arg)
if !ok {
return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name)
}
return t.Handler(tb, typedArg)
return t.Handler(ctx, tb, typedArg)
},
}
}
Expand Down Expand Up @@ -115,13 +115,41 @@ type UploadTarFileArgs struct {

// WithRecover wraps a HandlerFunc to recover from panics and return an error.
func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
return func(tb Deps, args Arg) (ret Ret, err error) {
return func(ctx context.Context, tb Deps, args Arg) (ret Ret, err error) {
defer func() {
if r := recover(); r != nil {
err = xerrors.Errorf("tool handler panic: %v", r)
}
}()
return h(tb, args)
return h(ctx, tb, args)
}
}

// WithCleanContext wraps a HandlerFunc to provide it with a new context.
// This ensures that no data is passed using context.Value.
// If a deadline is set on the parent context, it will be passed to the child
// context.
func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
return func(parent context.Context, tb Deps, args Arg) (ret Ret, err error) {
child, childCancel := context.WithCancel(context.Background())
defer childCancel()
// Ensure that cancellation propagates from the parent context to the child context.
go func() {
select {
case <-child.Done():
return
case <-parent.Done():
childCancel()
}
}()
// Also ensure that the child context has the same deadline as the parent
// context.
if deadline, ok := parent.Deadline(); ok {
deadlineCtx, deadlineCancel := context.WithDeadline(child, deadline)
defer deadlineCancel()
child = deadlineCtx
}
return h(child, tb, args)
}
}

Expand All @@ -137,7 +165,7 @@ func wrapAll(mw func(HandlerFunc[any, any]) HandlerFunc[any, any], tools ...Tool
var (
// All is a list of all tools that can be used in the Coder CLI.
// When you add a new tool, be sure to include it here!
All = wrapAll(WithRecover,
All = wrapAll(WithCleanContext, wrapAll(WithRecover,
CreateTemplate.Generic(),
CreateTemplateVersion.Generic(),
CreateWorkspace.Generic(),
Expand All @@ -154,7 +182,7 @@ var (
ReportTask.Generic(),
UploadTarFile.Generic(),
UpdateTemplateActiveVersion.Generic(),
)
)...)

ReportTask = Tool[ReportTaskArgs, string]{
Tool: aisdk.Tool{
Expand Down Expand Up @@ -183,14 +211,14 @@ var (
Required: []string{"summary", "link", "state"},
},
},
Handler: func(tb Deps, args ReportTaskArgs) (string, error) {
Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (string, error) {
if tb.AgentClient == nil {
return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
}
if tb.AppStatusSlug == "" {
return "", xerrors.New("workspace app status slug not found in toolbox")
}
if err := tb.AgentClient.PatchAppStatus(context.TODO(), agentsdk.PatchAppStatus{
if err := tb.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: tb.AppStatusSlug,
Message: args.Summary,
URI: args.Link,
Expand All @@ -217,12 +245,12 @@ This returns more data than list_workspaces to reduce token usage.`,
Required: []string{"workspace_id"},
},
},
Handler: func(tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) {
Handler: func(ctx context.Context, tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) {
wsID, err := uuid.Parse(args.WorkspaceID)
if err != nil {
return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID")
}
return tb.CoderClient.Workspace(context.TODO(), wsID)
return tb.CoderClient.Workspace(ctx, wsID)
},
}

Expand Down Expand Up @@ -257,7 +285,7 @@ is provisioned correctly and the agent can connect to the control plane.
Required: []string{"user", "template_version_id", "name", "rich_parameters"},
},
},
Handler: func(tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) {
Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) {
tvID, err := uuid.Parse(args.TemplateVersionID)
if err != nil {
return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID")
Expand All @@ -272,7 +300,7 @@ is provisioned correctly and the agent can connect to the control plane.
Value: v,
})
}
workspace, err := tb.CoderClient.CreateUserWorkspace(context.TODO(), args.User, codersdk.CreateWorkspaceRequest{
workspace, err := tb.CoderClient.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{
TemplateVersionID: tvID,
Name: args.Name,
RichParameterValues: buildParams,
Expand All @@ -297,12 +325,12 @@ is provisioned correctly and the agent can connect to the control plane.
},
},
},
Handler: func(tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) {
Handler: func(ctx context.Context, tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) {
owner := args.Owner
if owner == "" {
owner = codersdk.Me
}
workspaces, err := tb.CoderClient.Workspaces(context.TODO(), codersdk.WorkspaceFilter{
workspaces, err := tb.CoderClient.Workspaces(ctx, codersdk.WorkspaceFilter{
Owner: owner,
})
if err != nil {
Expand Down Expand Up @@ -334,8 +362,8 @@ is provisioned correctly and the agent can connect to the control plane.
Required: []string{},
},
},
Handler: func(tb Deps, _ NoArgs) ([]MinimalTemplate, error) {
templates, err := tb.CoderClient.Templates(context.TODO(), codersdk.TemplateFilter{})
Handler: func(ctx context.Context, tb Deps, _ NoArgs) ([]MinimalTemplate, error) {
templates, err := tb.CoderClient.Templates(ctx, codersdk.TemplateFilter{})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -367,12 +395,12 @@ is provisioned correctly and the agent can connect to the control plane.
Required: []string{"template_version_id"},
},
},
Handler: func(tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) {
Handler: func(ctx context.Context, tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) {
templateVersionID, err := uuid.Parse(args.TemplateVersionID)
if err != nil {
return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err)
}
parameters, err := tb.CoderClient.TemplateVersionRichParameters(context.TODO(), templateVersionID)
parameters, err := tb.CoderClient.TemplateVersionRichParameters(ctx, templateVersionID)
if err != nil {
return nil, err
}
Expand All @@ -389,8 +417,8 @@ is provisioned correctly and the agent can connect to the control plane.
Required: []string{},
},
},
Handler: func(tb Deps, _ NoArgs) (codersdk.User, error) {
return tb.CoderClient.User(context.TODO(), "me")
Handler: func(ctx context.Context, tb Deps, _ NoArgs) (codersdk.User, error) {
return tb.CoderClient.User(ctx, "me")
},
}

Expand All @@ -416,7 +444,7 @@ is provisioned correctly and the agent can connect to the control plane.
Required: []string{"workspace_id", "transition"},
},
},
Handler: func(tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) {
Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) {
workspaceID, err := uuid.Parse(args.WorkspaceID)
if err != nil {
return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err)
Expand All @@ -435,7 +463,7 @@ is provisioned correctly and the agent can connect to the control plane.
if templateVersionID != uuid.Nil {
cbr.TemplateVersionID = templateVersionID
}
return tb.CoderClient.CreateWorkspaceBuild(context.TODO(), workspaceID, cbr)
return tb.CoderClient.CreateWorkspaceBuild(ctx, workspaceID, cbr)
},
}

Expand Down Expand Up @@ -897,8 +925,8 @@ The file_id provided is a reference to a tar file you have uploaded containing t
Required: []string{"file_id"},
},
},
Handler: func(tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) {
me, err := tb.CoderClient.User(context.TODO(), "me")
Handler: func(ctx context.Context, tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) {
me, err := tb.CoderClient.User(ctx, "me")
if err != nil {
return codersdk.TemplateVersion{}, err
}
Expand All @@ -910,7 +938,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
if err != nil {
return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
}
templateVersion, err := tb.CoderClient.CreateTemplateVersion(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{
templateVersion, err := tb.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{
Message: "Created by AI",
StorageMethod: codersdk.ProvisionerStorageMethodFile,
FileID: fileID,
Expand Down Expand Up @@ -939,12 +967,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t
Required: []string{"workspace_agent_id"},
},
},
Handler: func(tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) {
Handler: func(ctx context.Context, tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) {
workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID)
if err != nil {
return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err)
}
logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(context.TODO(), workspaceAgentID, 0, false)
logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -974,12 +1002,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t
Required: []string{"workspace_build_id"},
},
},
Handler: func(tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) {
Handler: func(ctx context.Context, tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) {
workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID)
if err != nil {
return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err)
}
logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(context.TODO(), workspaceBuildID, 0)
logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0)
if err != nil {
return nil, err
}
Expand All @@ -1005,13 +1033,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t
Required: []string{"template_version_id"},
},
},
Handler: func(tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) {
Handler: func(ctx context.Context, tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) {
templateVersionID, err := uuid.Parse(args.TemplateVersionID)
if err != nil {
return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err)
}

logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(context.TODO(), templateVersionID, 0)
logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(ctx, templateVersionID, 0)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1040,7 +1068,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
Required: []string{"template_id", "template_version_id"},
},
},
Handler: func(tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) {
Handler: func(ctx context.Context, tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) {
templateID, err := uuid.Parse(args.TemplateID)
if err != nil {
return "", xerrors.Errorf("template_id must be a valid UUID: %w", err)
Expand All @@ -1049,7 +1077,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
if err != nil {
return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err)
}
err = tb.CoderClient.UpdateActiveTemplateVersion(context.TODO(), templateID, codersdk.UpdateActiveTemplateVersion{
err = tb.CoderClient.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{
ID: templateVersionID,
})
if err != nil {
Expand All @@ -1073,7 +1101,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
Required: []string{"mime_type", "files"},
},
},
Handler: func(tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) {
Handler: func(ctx context.Context, tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) {
pipeReader, pipeWriter := io.Pipe()
go func() {
defer pipeWriter.Close()
Expand All @@ -1098,7 +1126,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
}
}()

resp, err := tb.CoderClient.Upload(context.TODO(), codersdk.ContentTypeTar, pipeReader)
resp, err := tb.CoderClient.Upload(ctx, codersdk.ContentTypeTar, pipeReader)
if err != nil {
return codersdk.UploadResponse{}, err
}
Expand Down Expand Up @@ -1133,16 +1161,16 @@ The file_id provided is a reference to a tar file you have uploaded containing t
Required: []string{"name", "display_name", "description", "version_id"},
},
},
Handler: func(tb Deps, args CreateTemplateArgs) (codersdk.Template, error) {
me, err := tb.CoderClient.User(context.TODO(), "me")
Handler: func(ctx context.Context, tb Deps, args CreateTemplateArgs) (codersdk.Template, error) {
me, err := tb.CoderClient.User(ctx, "me")
if err != nil {
return codersdk.Template{}, err
}
versionID, err := uuid.Parse(args.VersionID)
if err != nil {
return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err)
}
template, err := tb.CoderClient.CreateTemplate(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateRequest{
template, err := tb.CoderClient.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{
Name: args.Name,
DisplayName: args.DisplayName,
Description: args.Description,
Expand All @@ -1167,12 +1195,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t
},
},
},
Handler: func(tb Deps, args DeleteTemplateArgs) (string, error) {
Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (string, error) {
templateID, err := uuid.Parse(args.TemplateID)
if err != nil {
return "", xerrors.Errorf("template_id must be a valid UUID: %w", err)
}
err = tb.CoderClient.DeleteTemplate(context.TODO(), templateID)
err = tb.CoderClient.DeleteTemplate(ctx, templateID)
if err != nil {
return "", err
}
Expand Down
Loading