Skip to content

Commit f06c549

Browse files
committed
feat: unify workflow execution in GitHub Actions tools
- Refactored the RunWorkflow tool to accept both numeric workflow IDs and filenames, enhancing flexibility for users. - Updated the corresponding tests to reflect changes in parameter handling and added assertions for workflow type in responses. - Removed the separate RunWorkflowByFileName tool to streamline functionality and improve code maintainability.
1 parent f7e1320 commit f06c549

File tree

3 files changed

+46
-118
lines changed

3 files changed

+46
-118
lines changed

pkg/github/actions.go

Lines changed: 16 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"net/http"
9+
"strconv"
910
"strings"
1011

1112
"github.com/github/github-mcp-server/pkg/translations"
@@ -236,10 +237,10 @@ func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFun
236237
}
237238
}
238239

239-
// RunWorkflow creates a tool to run an Actions workflow by workflow ID
240+
// RunWorkflow creates a tool to run an Actions workflow
240241
func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
241242
return mcp.NewTool("run_workflow",
242-
mcp.WithDescription(t("TOOL_RUN_WORKFLOW_DESCRIPTION", "Run an Actions workflow by workflow ID")),
243+
mcp.WithDescription(t("TOOL_RUN_WORKFLOW_DESCRIPTION", "Run an Actions workflow by workflow ID or filename")),
243244
mcp.WithToolAnnotation(mcp.ToolAnnotation{
244245
Title: t("TOOL_RUN_WORKFLOW_USER_TITLE", "Run workflow"),
245246
ReadOnlyHint: ToBoolPtr(false),
@@ -252,9 +253,9 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (t
252253
mcp.Required(),
253254
mcp.Description(DescriptionRepositoryName),
254255
),
255-
mcp.WithNumber("workflow_id",
256+
mcp.WithString("workflow_id",
256257
mcp.Required(),
257-
mcp.Description("The workflow ID (numeric identifier)"),
258+
mcp.Description("The workflow ID (numeric) or workflow file name (e.g., main.yml, ci.yaml)"),
258259
),
259260
mcp.WithString("ref",
260261
mcp.Required(),
@@ -273,11 +274,10 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (t
273274
if err != nil {
274275
return mcp.NewToolResultError(err.Error()), nil
275276
}
276-
workflowIDInt, err := RequiredInt(request, "workflow_id")
277+
workflowID, err := RequiredParam[string](request, "workflow_id")
277278
if err != nil {
278279
return mcp.NewToolResultError(err.Error()), nil
279280
}
280-
workflowID := int64(workflowIDInt)
281281
ref, err := RequiredParam[string](request, "ref")
282282
if err != nil {
283283
return mcp.NewToolResultError(err.Error()), nil
@@ -301,105 +301,26 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (t
301301
Inputs: inputs,
302302
}
303303

304-
// Convert workflow ID to string format for the API call
305-
workflowIDStr := fmt.Sprintf("%d", workflowID)
306-
resp, err := client.Actions.CreateWorkflowDispatchEventByFileName(ctx, owner, repo, workflowIDStr, event)
307-
if err != nil {
308-
return nil, fmt.Errorf("failed to run workflow: %w", err)
309-
}
310-
defer func() { _ = resp.Body.Close() }()
311-
312-
result := map[string]any{
313-
"message": "Workflow run has been queued",
314-
"workflow_id": workflowID,
315-
"ref": ref,
316-
"inputs": inputs,
317-
"status": resp.Status,
318-
"status_code": resp.StatusCode,
319-
}
320-
321-
r, err := json.Marshal(result)
322-
if err != nil {
323-
return nil, fmt.Errorf("failed to marshal response: %w", err)
324-
}
325-
326-
return mcp.NewToolResultText(string(r)), nil
327-
}
328-
}
329-
330-
// RunWorkflowByFileName creates a tool to run an Actions workflow by filename
331-
func RunWorkflowByFileName(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
332-
return mcp.NewTool("run_workflow_by_filename",
333-
mcp.WithDescription(t("TOOL_RUN_WORKFLOW_BY_FILENAME_DESCRIPTION", "Run an Actions workflow by workflow filename")),
334-
mcp.WithToolAnnotation(mcp.ToolAnnotation{
335-
Title: t("TOOL_RUN_WORKFLOW_BY_FILENAME_USER_TITLE", "Run workflow by filename"),
336-
ReadOnlyHint: ToBoolPtr(false),
337-
}),
338-
mcp.WithString("owner",
339-
mcp.Required(),
340-
mcp.Description(DescriptionRepositoryOwner),
341-
),
342-
mcp.WithString("repo",
343-
mcp.Required(),
344-
mcp.Description(DescriptionRepositoryName),
345-
),
346-
mcp.WithString("workflow_file",
347-
mcp.Required(),
348-
mcp.Description("The workflow file name (e.g., main.yml, ci.yaml)"),
349-
),
350-
mcp.WithString("ref",
351-
mcp.Required(),
352-
mcp.Description("The git reference for the workflow. The reference can be a branch or tag name."),
353-
),
354-
mcp.WithObject("inputs",
355-
mcp.Description("Inputs the workflow accepts"),
356-
),
357-
),
358-
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
359-
owner, err := RequiredParam[string](request, "owner")
360-
if err != nil {
361-
return mcp.NewToolResultError(err.Error()), nil
362-
}
363-
repo, err := RequiredParam[string](request, "repo")
364-
if err != nil {
365-
return mcp.NewToolResultError(err.Error()), nil
366-
}
367-
workflowFile, err := RequiredParam[string](request, "workflow_file")
368-
if err != nil {
369-
return mcp.NewToolResultError(err.Error()), nil
370-
}
371-
ref, err := RequiredParam[string](request, "ref")
372-
if err != nil {
373-
return mcp.NewToolResultError(err.Error()), nil
374-
}
375-
376-
// Get optional inputs parameter
377-
var inputs map[string]interface{}
378-
if requestInputs, ok := request.GetArguments()["inputs"]; ok {
379-
if inputsMap, ok := requestInputs.(map[string]interface{}); ok {
380-
inputs = inputsMap
381-
}
382-
}
383-
384-
client, err := getClient(ctx)
385-
if err != nil {
386-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
387-
}
304+
var resp *github.Response
305+
var workflowType string
388306

389-
event := github.CreateWorkflowDispatchEventRequest{
390-
Ref: ref,
391-
Inputs: inputs,
307+
if workflowIDInt, parseErr := strconv.ParseInt(workflowID, 10, 64); parseErr == nil {
308+
resp, err = client.Actions.CreateWorkflowDispatchEventByID(ctx, owner, repo, workflowIDInt, event)
309+
workflowType = "workflow_id"
310+
} else {
311+
resp, err = client.Actions.CreateWorkflowDispatchEventByFileName(ctx, owner, repo, workflowID, event)
312+
workflowType = "workflow_file"
392313
}
393314

394-
resp, err := client.Actions.CreateWorkflowDispatchEventByFileName(ctx, owner, repo, workflowFile, event)
395315
if err != nil {
396316
return nil, fmt.Errorf("failed to run workflow: %w", err)
397317
}
398318
defer func() { _ = resp.Body.Close() }()
399319

400320
result := map[string]any{
401321
"message": "Workflow run has been queued",
402-
"workflow_file": workflowFile,
322+
"workflow_type": workflowType,
323+
"workflow_id": workflowID,
403324
"ref": ref,
404325
"inputs": inputs,
405326
"status": resp.Status,

pkg/github/actions_test.go

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func Test_RunWorkflow(t *testing.T) {
159159
requestArgs: map[string]any{
160160
"owner": "owner",
161161
"repo": "repo",
162-
"workflow_id": float64(12345),
162+
"workflow_id": "12345",
163163
"ref": "main",
164164
},
165165
expectError: false,
@@ -205,24 +205,13 @@ func Test_RunWorkflow(t *testing.T) {
205205
err = json.Unmarshal([]byte(textContent.Text), &response)
206206
require.NoError(t, err)
207207
assert.Equal(t, "Workflow run has been queued", response["message"])
208+
assert.Contains(t, response, "workflow_type")
208209
})
209210
}
210211
}
211212

212-
func Test_RunWorkflowByFileName(t *testing.T) {
213-
// Verify tool definition once
214-
mockClient := github.NewClient(nil)
215-
tool, _ := RunWorkflowByFileName(stubGetClientFn(mockClient), translations.NullTranslationHelper)
216-
217-
assert.Equal(t, "run_workflow_by_filename", tool.Name)
218-
assert.NotEmpty(t, tool.Description)
219-
assert.Contains(t, tool.InputSchema.Properties, "owner")
220-
assert.Contains(t, tool.InputSchema.Properties, "repo")
221-
assert.Contains(t, tool.InputSchema.Properties, "workflow_file")
222-
assert.Contains(t, tool.InputSchema.Properties, "ref")
223-
assert.Contains(t, tool.InputSchema.Properties, "inputs")
224-
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "workflow_file", "ref"})
225-
213+
func Test_RunWorkflow_WithFilename(t *testing.T) {
214+
// Test the unified RunWorkflow function with filenames
226215
tests := []struct {
227216
name string
228217
mockedClient *http.Client
@@ -241,31 +230,49 @@ func Test_RunWorkflowByFileName(t *testing.T) {
241230
),
242231
),
243232
requestArgs: map[string]any{
244-
"owner": "owner",
245-
"repo": "repo",
246-
"workflow_file": "ci.yml",
247-
"ref": "main",
233+
"owner": "owner",
234+
"repo": "repo",
235+
"workflow_id": "ci.yml",
236+
"ref": "main",
248237
},
249238
expectError: false,
250239
},
251240
{
252-
name: "missing required parameter workflow_file",
241+
name: "successful workflow run by numeric ID as string",
242+
mockedClient: mock.NewMockedHTTPClient(
243+
mock.WithRequestMatchHandler(
244+
mock.PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowId,
245+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
246+
w.WriteHeader(http.StatusNoContent)
247+
}),
248+
),
249+
),
250+
requestArgs: map[string]any{
251+
"owner": "owner",
252+
"repo": "repo",
253+
"workflow_id": "12345",
254+
"ref": "main",
255+
},
256+
expectError: false,
257+
},
258+
{
259+
name: "missing required parameter workflow_id",
253260
mockedClient: mock.NewMockedHTTPClient(),
254261
requestArgs: map[string]any{
255262
"owner": "owner",
256263
"repo": "repo",
257264
"ref": "main",
258265
},
259266
expectError: true,
260-
expectedErrMsg: "missing required parameter: workflow_file",
267+
expectedErrMsg: "missing required parameter: workflow_id",
261268
},
262269
}
263270

264271
for _, tc := range tests {
265272
t.Run(tc.name, func(t *testing.T) {
266273
// Setup client with mock
267274
client := github.NewClient(tc.mockedClient)
268-
_, handler := RunWorkflowByFileName(stubGetClientFn(client), translations.NullTranslationHelper)
275+
_, handler := RunWorkflow(stubGetClientFn(client), translations.NullTranslationHelper)
269276

270277
// Create call request
271278
request := createMCPRequest(tc.requestArgs)
@@ -289,6 +296,7 @@ func Test_RunWorkflowByFileName(t *testing.T) {
289296
err = json.Unmarshal([]byte(textContent.Text), &response)
290297
require.NoError(t, err)
291298
assert.Equal(t, "Workflow run has been queued", response["message"])
299+
assert.Contains(t, response, "workflow_type")
292300
})
293301
}
294302
}

pkg/github/tools.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG
125125
).
126126
AddWriteTools(
127127
toolsets.NewServerTool(RunWorkflow(getClient, t)),
128-
toolsets.NewServerTool(RunWorkflowByFileName(getClient, t)),
129128
toolsets.NewServerTool(RerunWorkflowRun(getClient, t)),
130129
toolsets.NewServerTool(RerunFailedJobs(getClient, t)),
131130
toolsets.NewServerTool(CancelWorkflowRun(getClient, t)),

0 commit comments

Comments
 (0)