From 30b40081bc0132e6b8a37f8e2a7f8e6868c07d70 Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 09:40:35 +0000 Subject: [PATCH 01/15] refactor: lift task creation to coderd Instead of creating tasks with a specialised call to `CreateWorkspace` on the frontend, we instead lift this to the backend and allow the frontend to simply call `CreateAITask`. --- coderd/aitasks.go | 59 ++++++++++++++++++++++++++ coderd/aitasks_test.go | 55 ++++++++++++++++++++++++ coderd/coderd.go | 1 + codersdk/aitasks.go | 26 ++++++++++++ site/src/api/api.ts | 11 +++++ site/src/api/typesGenerated.ts | 8 ++++ site/src/pages/TasksPage/TasksPage.tsx | 10 ++--- 7 files changed, 163 insertions(+), 7 deletions(-) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index a982ccc39b26b..d9bc812ab48c9 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -7,7 +7,10 @@ import ( "github.com/google/uuid" + "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" ) @@ -61,3 +64,59 @@ func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) { Prompts: promptsByBuildID, }) } + +// This endpoint is experimental and not guaranteed to be stable, so we're not +// generating public-facing documentation for it. +func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + apiKey = httpmw.APIKey(r) + auditor = api.Auditor.Load() + ) + + var req codersdk.CreateAITasksRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + user, err := api.Database.GetUserByID(ctx, apiKey.UserID) + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching user.", + Detail: err.Error(), + }) + return + } + + createReq := codersdk.CreateWorkspaceRequest{ + Name: req.Name, + TemplateVersionID: req.TemplateVersionID, + TemplateVersionPresetID: req.TemplateVersionPresetID, + RichParameterValues: []codersdk.WorkspaceBuildParameter{ + {Name: "AI Prompt", Value: req.Prompt}, + }, + } + + owner := workspaceOwner{ + ID: user.ID, + Username: user.Username, + AvatarURL: user.AvatarURL, + } + + aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{ + Audit: *auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, + AdditionalFields: audit.AdditionalFields{ + WorkspaceOwner: owner.Username, + }, + }) + + defer commitAudit() + createWorkspace(ctx, aReq, apiKey.UserID, api, owner, createReq, rw, r) +} diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index 53f0174d6f03d..ca3378d16d3a9 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -139,3 +139,58 @@ func TestAITasksPrompts(t *testing.T) { require.Empty(t, prompts.Prompts) }) } + +func TestAITasksCreate(t *testing.T) { + t.Parallel() + + makeEchoResponses := func(parameters []*proto.RichParameter) *echo.Responses { + return &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionApply: echo.ApplyComplete, + ProvisionPlan: []*proto.Response{ + {Type: &proto.Response_Plan{Plan: &proto.PlanComplete{Parameters: parameters}}}, + }, + } + } + + t.Run("OK", func(t *testing.T) { + var ( + ctx = testutil.Context(t, testutil.WaitShort) + + taskName = "task-foo-bar-baz" + taskPrompt = "Some task prompt" + ) + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Given: A template with an "AI Prompt" + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, makeEchoResponses([]*proto.RichParameter{ + {Name: "AI Prompt", Type: "string"}, + })) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + expClient := codersdk.NewExperimentalClient(client) + + // When: We attempt to create a Task. + workspace, err := expClient.AITasksCreate(ctx, codersdk.CreateAITasksRequest{ + Name: taskName, + TemplateVersionID: template.ActiveVersionID, + Prompt: taskPrompt, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + + // Then: We expect a workspace to have been created. + require.Equal(t, taskName, workspace.Name) + require.Equal(t, template.ID, workspace.TemplateID) + + // And: We expect it to have the "AI Prompt" parameter correctly set. + parameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) + require.NoError(t, err) + require.Len(t, parameters, 1) + require.Equal(t, "AI Prompt", parameters[0].Name) + require.Equal(t, taskPrompt, parameters[0].Value) + }) +} diff --git a/coderd/coderd.go b/coderd/coderd.go index 78ae849fd1894..85c5e83da3fcd 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -994,6 +994,7 @@ func New(options *Options) *API { r.Use(apiKeyMiddleware) r.Route("/aitasks", func(r chi.Router) { r.Get("/prompts", api.aiTasksPrompts) + r.Post("/", api.aiTasksCreate) }) r.Route("/mcp", func(r chi.Router) { r.Use( diff --git a/codersdk/aitasks.go b/codersdk/aitasks.go index 89ca9c948f272..9598fa117a664 100644 --- a/codersdk/aitasks.go +++ b/codersdk/aitasks.go @@ -44,3 +44,29 @@ func (c *ExperimentalClient) AITaskPrompts(ctx context.Context, buildIDs []uuid. var prompts AITasksPromptsResponse return prompts, json.NewDecoder(res.Body).Decode(&prompts) } + +type CreateAITasksRequest struct { + Name string `json:"name"` + TemplateVersionID uuid.UUID `json:"template_version_id"` + TemplateVersionPresetID uuid.UUID `json:"template_version_preset_id,omitempty"` + Prompt string `json:"prompt"` +} + +func (c *ExperimentalClient) AITasksCreate(ctx context.Context, request CreateAITasksRequest) (Workspace, error) { + res, err := c.Request(ctx, http.MethodPost, "/api/experimental/aitasks/", request) + if err != nil { + return Workspace{}, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusCreated { + return Workspace{}, ReadBodyAsError(res) + } + + var workspace Workspace + if err := json.NewDecoder(res.Body).Decode(&workspace); err != nil { + return Workspace{}, err + } + + return workspace, nil +} diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 2b21ddf1e8a08..0185eea6940cc 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -2665,6 +2665,17 @@ class ExperimentalApiMethods { return response.data; }; + + createAITask = async ( + req: TypesGen.CreateAITasksRequest, + ): Promise => { + const response = await this.axios.post( + "/api/experimental/aitasks", + req, + ); + + return response.data; + }; } // This is a hard coded CSRF token/cookie pair for local development. In prod, diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 52fdb1d6effc4..d47a486db3eb4 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -422,6 +422,14 @@ export interface ConvertLoginRequest { readonly password: string; } +// From codersdk/aitasks.go +export interface CreateAITasksRequest { + readonly name: string; + readonly template_version_id: string; + readonly template_version_preset_id?: string; + readonly prompt: string; +} + // From codersdk/users.go export interface CreateFirstUserRequest { readonly email: string; diff --git a/site/src/pages/TasksPage/TasksPage.tsx b/site/src/pages/TasksPage/TasksPage.tsx index ce6ddea380046..c40216e46432e 100644 --- a/site/src/pages/TasksPage/TasksPage.tsx +++ b/site/src/pages/TasksPage/TasksPage.tsx @@ -232,7 +232,6 @@ type TaskFormProps = { }; const TaskForm: FC = ({ templates, onSuccess }) => { - const { user } = useAuthenticated(); const queryClient = useQueryClient(); const [selectedTemplateId, setSelectedTemplateId] = useState( templates[0].id, @@ -293,7 +292,7 @@ const TaskForm: FC = ({ templates, onSuccess }) => { templateVersionId, presetId, }: CreateTaskMutationFnProps) => - data.createTask(prompt, user.id, templateVersionId, presetId), + data.createTask(prompt, templateVersionId, presetId), onSuccess: async (task) => { await queryClient.invalidateQueries({ queryKey: ["tasks"], @@ -727,7 +726,6 @@ export const data = { async createTask( prompt: string, - userId: string, templateVersionId: string, presetId: string | null = null, ): Promise { @@ -741,13 +739,11 @@ export const data = { } } - const workspace = await API.createWorkspace(userId, { + const workspace = await API.experimental.createAITask({ name: `task-${generateWorkspaceName()}`, template_version_id: templateVersionId, template_version_preset_id: preset_id || undefined, - rich_parameter_values: [ - { name: AI_PROMPT_PARAMETER_NAME, value: prompt }, - ], + prompt, }); return { From 16eee744ac591ad1fa9c1fe5a8e02c9686dedc87 Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 10:49:20 +0000 Subject: [PATCH 02/15] fix: add parallel to test --- coderd/aitasks_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index ca3378d16d3a9..1a74f66d41da3 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -154,6 +154,8 @@ func TestAITasksCreate(t *testing.T) { } t.Run("OK", func(t *testing.T) { + t.Parallel() + var ( ctx = testutil.Context(t, testutil.WaitShort) From 0cd6fde0c5d7e04a36d5576c354fbd4b2f34586d Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 10:49:38 +0000 Subject: [PATCH 03/15] fix: add `format:"uuid"` to `uuid.UUID` fields --- codersdk/aitasks.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codersdk/aitasks.go b/codersdk/aitasks.go index 9598fa117a664..000acf44cac7b 100644 --- a/codersdk/aitasks.go +++ b/codersdk/aitasks.go @@ -47,8 +47,8 @@ func (c *ExperimentalClient) AITaskPrompts(ctx context.Context, buildIDs []uuid. type CreateAITasksRequest struct { Name string `json:"name"` - TemplateVersionID uuid.UUID `json:"template_version_id"` - TemplateVersionPresetID uuid.UUID `json:"template_version_preset_id,omitempty"` + TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid"` + TemplateVersionPresetID uuid.UUID `json:"template_version_preset_id,omitempty" format:"uuid"` Prompt string `json:"prompt"` } From d9e7e08fffc61392ae6f45b81916e438e62b586b Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 11:36:13 +0000 Subject: [PATCH 04/15] fix: return error if template does not have an "AI Prompt" parameter --- coderd/aitasks.go | 15 +++++++ coderd/aitasks_test.go | 44 +++++++++++++++++--- coderd/database/dbauthz/dbauthz.go | 11 +++++ coderd/database/dbauthz/dbauthz_test.go | 14 +++++++ coderd/database/dbmetrics/querymetrics.go | 7 ++++ coderd/database/dbmock/dbmock.go | 15 +++++++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 15 +++++++ coderd/database/queries/templateversions.sql | 7 ++++ 9 files changed, 124 insertions(+), 5 deletions(-) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index d9bc812ab48c9..a9330547dcbe4 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -92,6 +92,21 @@ func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { return } + hasAIPrompt, err := api.Database.GetTemplateVersionHasAIPrompt(ctx, req.TemplateVersionID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching if template version has ai prompt.", + Detail: err.Error(), + }) + return + } + if !hasAIPrompt { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: `Template does not have required parameter "AI Prompt"`, + }) + return + } + createReq := codersdk.CreateWorkspaceRequest{ Name: req.Name, TemplateVersionID: req.TemplateVersionID, diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index 1a74f66d41da3..39fa686b63815 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -1,9 +1,11 @@ package coderd_test import ( + "net/http" "testing" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest" @@ -166,7 +168,7 @@ func TestAITasksCreate(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) user := coderdtest.CreateFirstUser(t, client) - // Given: A template with an "AI Prompt" + // Given: A template with an "AI Prompt" parameter version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, makeEchoResponses([]*proto.RichParameter{ {Name: "AI Prompt", Type: "string"}, })) @@ -185,14 +187,46 @@ func TestAITasksCreate(t *testing.T) { coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) // Then: We expect a workspace to have been created. - require.Equal(t, taskName, workspace.Name) - require.Equal(t, template.ID, workspace.TemplateID) + assert.Equal(t, taskName, workspace.Name) + assert.Equal(t, template.ID, workspace.TemplateID) // And: We expect it to have the "AI Prompt" parameter correctly set. parameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) require.NoError(t, err) require.Len(t, parameters, 1) - require.Equal(t, "AI Prompt", parameters[0].Name) - require.Equal(t, taskPrompt, parameters[0].Value) + assert.Equal(t, "AI Prompt", parameters[0].Name) + assert.Equal(t, taskPrompt, parameters[0].Value) + }) + + t.Run("FailsOnNonTaskTemplate", func(t *testing.T) { + var ( + ctx = testutil.Context(t, testutil.WaitShort) + + taskName = "task-foo-bar-baz" + taskPrompt = "Some task prompt" + ) + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Given: A template without an "AI Prompt" parameter + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + expClient := codersdk.NewExperimentalClient(client) + + // When: We attempt to create a Task. + _, err := expClient.AITasksCreate(ctx, codersdk.CreateAITasksRequest{ + Name: taskName, + TemplateVersionID: template.ActiveVersionID, + Prompt: taskPrompt, + }) + + // Then: We expect it to fail. + var sdkErr *codersdk.Error + require.Error(t, err) + require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error") + assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) }) } diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index d5cc334f5ff7f..a69475bcab6d8 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2874,6 +2874,17 @@ func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg return tv, nil } +func (q *querier) GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) { + // If we can successfully call `GetTemplateVersionByID`, then + // we know the actor has sufficient permissions to know if the + // template has an AI Prompt. + if _, err := q.GetTemplateVersionByID(ctx, id); err != nil { + return false, err + } + + return q.db.GetTemplateVersionHasAIPrompt(ctx, id) +} + func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { // An actor can read template version parameters if they can read the related template. tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index a55f9c37aa4f5..b026c2a09abab 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1443,6 +1443,20 @@ func (s *MethodTestSuite) TestTemplate() { }) check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), policy.ActionRead) })) + s.Run("GetTemplateVersionHasAIPrompt", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u := dbgen.User(s.T(), db, database.User{}) + t := dbgen.Template(s.T(), db, database.Template{ + OrganizationID: o.ID, + CreatedBy: u.ID, + }) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + OrganizationID: o.ID, + TemplateID: uuid.NullUUID{UUID: t.ID, Valid: true}, + CreatedBy: u.ID, + }) + check.Args(tv.ID).Asserts(t, policy.ActionRead) + })) s.Run("GetTemplatesWithFilter", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) u := dbgen.User(s.T(), db, database.User{}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index e0606f9e40665..7aaf99894e116 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1538,6 +1538,13 @@ func (m queryMetricsStore) GetTemplateVersionByTemplateIDAndName(ctx context.Con return version, err } +func (m queryMetricsStore) GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) { + start := time.Now() + r0, r1 := m.s.GetTemplateVersionHasAIPrompt(ctx, id) + m.queryLatencies.WithLabelValues("GetTemplateVersionHasAIPrompt").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { start := time.Now() parameters, err := m.s.GetTemplateVersionParameters(ctx, templateVersionID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 22807f0e3569d..f51b795f547f9 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -3271,6 +3271,21 @@ func (mr *MockStoreMockRecorder) GetTemplateVersionByTemplateIDAndName(ctx, arg return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionByTemplateIDAndName", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionByTemplateIDAndName), ctx, arg) } +// GetTemplateVersionHasAIPrompt mocks base method. +func (m *MockStore) GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTemplateVersionHasAIPrompt", ctx, id) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTemplateVersionHasAIPrompt indicates an expected call of GetTemplateVersionHasAIPrompt. +func (mr *MockStoreMockRecorder) GetTemplateVersionHasAIPrompt(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionHasAIPrompt", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionHasAIPrompt), ctx, id) +} + // GetTemplateVersionParameters mocks base method. func (m *MockStore) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index a0f265e9658ce..e840df95c1a64 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -355,6 +355,7 @@ type sqlcQuerier interface { GetTemplateVersionByID(ctx context.Context, id uuid.UUID) (TemplateVersion, error) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (TemplateVersion, error) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg GetTemplateVersionByTemplateIDAndNameParams) (TemplateVersion, error) + GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]TemplateVersionParameter, error) GetTemplateVersionTerraformValues(ctx context.Context, templateVersionID uuid.UUID) (TemplateVersionTerraformValue, error) GetTemplateVersionVariables(ctx context.Context, templateVersionID uuid.UUID) ([]TemplateVersionVariable, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 74cefd09359b0..ec7f63aae0f17 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -12870,6 +12870,21 @@ func (q *sqlQuerier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, return i, err } +const getTemplateVersionHasAIPrompt = `-- name: GetTemplateVersionHasAIPrompt :one +SELECT EXISTS ( + SELECT 1 + FROM template_versions + WHERE id = $1 AND has_ai_task = TRUE +) +` + +func (q *sqlQuerier) GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) { + row := q.db.QueryRowContext(ctx, getTemplateVersionHasAIPrompt, id) + var exists bool + err := row.Scan(&exists) + return exists, err +} + const getTemplateVersionsByIDs = `-- name: GetTemplateVersionsByIDs :many SELECT id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, has_ai_task, created_by_avatar_url, created_by_username, created_by_name diff --git a/coderd/database/queries/templateversions.sql b/coderd/database/queries/templateversions.sql index 5cf59fab30272..6cc39beeb6b97 100644 --- a/coderd/database/queries/templateversions.sql +++ b/coderd/database/queries/templateversions.sql @@ -234,3 +234,10 @@ FROM WHERE template_versions.id IN (archived_versions.id) RETURNING template_versions.id; + +-- name: GetTemplateVersionHasAIPrompt :one +SELECT EXISTS ( + SELECT 1 + FROM template_versions + WHERE id = $1 AND has_ai_task = TRUE +); From e4695fff55cfc16d365b6144b4e1709b1475312e Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 11:44:57 +0000 Subject: [PATCH 05/15] fix: add parallel to test --- coderd/aitasks_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index 39fa686b63815..cbd8d124fb587 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -199,6 +199,8 @@ func TestAITasksCreate(t *testing.T) { }) t.Run("FailsOnNonTaskTemplate", func(t *testing.T) { + t.Parallel() + var ( ctx = testutil.Context(t, testutil.WaitShort) From 0ad6ce39fdbc7088972e2b179f3837a35af40687 Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 12:39:36 +0000 Subject: [PATCH 06/15] fix: broken test --- coderd/aitasks_test.go | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index cbd8d124fb587..5e54b50fd287d 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -145,16 +145,6 @@ func TestAITasksPrompts(t *testing.T) { func TestAITasksCreate(t *testing.T) { t.Parallel() - makeEchoResponses := func(parameters []*proto.RichParameter) *echo.Responses { - return &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionApply: echo.ApplyComplete, - ProvisionPlan: []*proto.Response{ - {Type: &proto.Response_Plan{Plan: &proto.PlanComplete{Parameters: parameters}}}, - }, - } - } - t.Run("OK", func(t *testing.T) { t.Parallel() @@ -169,9 +159,16 @@ func TestAITasksCreate(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) // Given: A template with an "AI Prompt" parameter - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, makeEchoResponses([]*proto.RichParameter{ - {Name: "AI Prompt", Type: "string"}, - })) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionApply: echo.ApplyComplete, + ProvisionPlan: []*proto.Response{ + {Type: &proto.Response_Plan{Plan: &proto.PlanComplete{ + Parameters: []*proto.RichParameter{{Name: "AI Prompt", Type: "string"}}, + HasAiTasks: true, + }}}, + }, + }) coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) From 66a4fd79514ef204f341472ba85200f7fdde7796 Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 14:17:27 +0000 Subject: [PATCH 07/15] test: fail on an invalid template --- coderd/aitasks.go | 8 ++++++++ coderd/aitasks_test.go | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index a9330547dcbe4..493011961c988 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -1,6 +1,8 @@ package coderd import ( + "database/sql" + "errors" "fmt" "net/http" "strings" @@ -11,6 +13,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" ) @@ -94,6 +97,11 @@ func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { hasAIPrompt, err := api.Database.GetTemplateVersionHasAIPrompt(ctx, req.TemplateVersionID) if err != nil { + if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching if template version has ai prompt.", Detail: err.Error(), diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index 5e54b50fd287d..d87b45a9ca786 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -228,4 +228,36 @@ func TestAITasksCreate(t *testing.T) { require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error") assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) }) + + t.Run("FailsOnInvalidTemplate", func(t *testing.T) { + var ( + ctx = testutil.Context(t, testutil.WaitShort) + + taskName = "task-foo-bar-baz" + taskPrompt = "Some task prompt" + ) + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Given: A template + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + _ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + expClient := codersdk.NewExperimentalClient(client) + + // When: We attempt to create a Task with an invalid template version ID. + _, err := expClient.AITasksCreate(ctx, codersdk.CreateAITasksRequest{ + Name: taskName, + TemplateVersionID: uuid.New(), + Prompt: taskPrompt, + }) + + // Then: We expect it to fail. + var sdkErr *codersdk.Error + require.Error(t, err) + require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error") + assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) } From 6e14061eb86e1cb92a709efda59fdd07c395fb13 Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 14:22:08 +0000 Subject: [PATCH 08/15] refactor: remove magic strings --- coderd/aitasks.go | 4 ++-- coderd/aitasks_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index 493011961c988..e514d38479a79 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -110,7 +110,7 @@ func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { } if !hasAIPrompt { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: `Template does not have required parameter "AI Prompt"`, + Message: `Template does not have required parameter "` + codersdk.AITaskPromptParameterName + `"`, }) return } @@ -120,7 +120,7 @@ func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { TemplateVersionID: req.TemplateVersionID, TemplateVersionPresetID: req.TemplateVersionPresetID, RichParameterValues: []codersdk.WorkspaceBuildParameter{ - {Name: "AI Prompt", Value: req.Prompt}, + {Name: codersdk.AITaskPromptParameterName, Value: req.Prompt}, }, } diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index d87b45a9ca786..9b98a30dc05f4 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -191,7 +191,7 @@ func TestAITasksCreate(t *testing.T) { parameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) require.NoError(t, err) require.Len(t, parameters, 1) - assert.Equal(t, "AI Prompt", parameters[0].Name) + assert.Equal(t, codersdk.AITaskPromptParameterName, parameters[0].Name) assert.Equal(t, taskPrompt, parameters[0].Value) }) From ae212d06e4954216c26beb0efbee219bcd278834 Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 14:27:09 +0000 Subject: [PATCH 09/15] fix: add parallel to test (yet again) --- coderd/aitasks_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index 9b98a30dc05f4..a73deefcce579 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -230,6 +230,8 @@ func TestAITasksCreate(t *testing.T) { }) t.Run("FailsOnInvalidTemplate", func(t *testing.T) { + t.Parallel() + var ( ctx = testutil.Context(t, testutil.WaitShort) From 5a671e841f349b638cca7da13c5034c4e2eb4084 Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 14:35:42 +0000 Subject: [PATCH 10/15] fix: add api rate limiter to ai tasks create --- coderd/coderd.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 85c5e83da3fcd..6e9e94335609d 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -994,7 +994,8 @@ func New(options *Options) *API { r.Use(apiKeyMiddleware) r.Route("/aitasks", func(r chi.Router) { r.Get("/prompts", api.aiTasksPrompts) - r.Post("/", api.aiTasksCreate) + + r.With(apiRateLimiter).Post("/", api.aiTasksCreate) }) r.Route("/mcp", func(r chi.Router) { r.Use( From 08cab33360af1e6caadc8567ba4111dc51591bac Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Mon, 11 Aug 2025 14:45:59 +0000 Subject: [PATCH 11/15] refactor: names used --- coderd/aitasks.go | 6 +++--- coderd/database/dbauthz/dbauthz.go | 6 +++--- coderd/database/dbauthz/dbauthz_test.go | 2 +- coderd/database/dbmetrics/querymetrics.go | 6 +++--- coderd/database/dbmock/dbmock.go | 12 ++++++------ coderd/database/querier.go | 2 +- coderd/database/queries.sql.go | 6 +++--- coderd/database/queries/templateversions.sql | 2 +- codersdk/aitasks.go | 2 +- 9 files changed, 22 insertions(+), 22 deletions(-) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index e514d38479a79..c4bf6d16c4e54 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -95,7 +95,7 @@ func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { return } - hasAIPrompt, err := api.Database.GetTemplateVersionHasAIPrompt(ctx, req.TemplateVersionID) + hasAITask, err := api.Database.GetTemplateVersionHasAITask(ctx, req.TemplateVersionID) if err != nil { if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) { httpapi.ResourceNotFound(rw) @@ -103,12 +103,12 @@ func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { } httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching if template version has ai prompt.", + Message: "Internal error fetching whether the template version has an AI task.", Detail: err.Error(), }) return } - if !hasAIPrompt { + if !hasAITask { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: `Template does not have required parameter "` + codersdk.AITaskPromptParameterName + `"`, }) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a69475bcab6d8..69041d2aac859 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2874,15 +2874,15 @@ func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg return tv, nil } -func (q *querier) GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) { +func (q *querier) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) { // If we can successfully call `GetTemplateVersionByID`, then // we know the actor has sufficient permissions to know if the - // template has an AI Prompt. + // template has an AI task. if _, err := q.GetTemplateVersionByID(ctx, id); err != nil { return false, err } - return q.db.GetTemplateVersionHasAIPrompt(ctx, id) + return q.db.GetTemplateVersionHasAITask(ctx, id) } func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index b026c2a09abab..c3ecfbf8e4502 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1443,7 +1443,7 @@ func (s *MethodTestSuite) TestTemplate() { }) check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), policy.ActionRead) })) - s.Run("GetTemplateVersionHasAIPrompt", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetTemplateVersionHasAITask", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) u := dbgen.User(s.T(), db, database.User{}) t := dbgen.Template(s.T(), db, database.Template{ diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 7aaf99894e116..cc852113d0af0 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1538,10 +1538,10 @@ func (m queryMetricsStore) GetTemplateVersionByTemplateIDAndName(ctx context.Con return version, err } -func (m queryMetricsStore) GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) { +func (m queryMetricsStore) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) { start := time.Now() - r0, r1 := m.s.GetTemplateVersionHasAIPrompt(ctx, id) - m.queryLatencies.WithLabelValues("GetTemplateVersionHasAIPrompt").Observe(time.Since(start).Seconds()) + r0, r1 := m.s.GetTemplateVersionHasAITask(ctx, id) + m.queryLatencies.WithLabelValues("GetTemplateVersionHasAITask").Observe(time.Since(start).Seconds()) return r0, r1 } diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index f51b795f547f9..24f57ffffb6f8 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -3271,19 +3271,19 @@ func (mr *MockStoreMockRecorder) GetTemplateVersionByTemplateIDAndName(ctx, arg return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionByTemplateIDAndName", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionByTemplateIDAndName), ctx, arg) } -// GetTemplateVersionHasAIPrompt mocks base method. -func (m *MockStore) GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) { +// GetTemplateVersionHasAITask mocks base method. +func (m *MockStore) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTemplateVersionHasAIPrompt", ctx, id) + ret := m.ctrl.Call(m, "GetTemplateVersionHasAITask", ctx, id) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetTemplateVersionHasAIPrompt indicates an expected call of GetTemplateVersionHasAIPrompt. -func (mr *MockStoreMockRecorder) GetTemplateVersionHasAIPrompt(ctx, id any) *gomock.Call { +// GetTemplateVersionHasAITask indicates an expected call of GetTemplateVersionHasAITask. +func (mr *MockStoreMockRecorder) GetTemplateVersionHasAITask(ctx, id any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionHasAIPrompt", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionHasAIPrompt), ctx, id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionHasAITask", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionHasAITask), ctx, id) } // GetTemplateVersionParameters mocks base method. diff --git a/coderd/database/querier.go b/coderd/database/querier.go index e840df95c1a64..fe81863d72ac0 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -355,7 +355,7 @@ type sqlcQuerier interface { GetTemplateVersionByID(ctx context.Context, id uuid.UUID) (TemplateVersion, error) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (TemplateVersion, error) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg GetTemplateVersionByTemplateIDAndNameParams) (TemplateVersion, error) - GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) + GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]TemplateVersionParameter, error) GetTemplateVersionTerraformValues(ctx context.Context, templateVersionID uuid.UUID) (TemplateVersionTerraformValue, error) GetTemplateVersionVariables(ctx context.Context, templateVersionID uuid.UUID) ([]TemplateVersionVariable, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index ec7f63aae0f17..c03e028a7a4e6 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -12870,7 +12870,7 @@ func (q *sqlQuerier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, return i, err } -const getTemplateVersionHasAIPrompt = `-- name: GetTemplateVersionHasAIPrompt :one +const getTemplateVersionHasAITask = `-- name: GetTemplateVersionHasAITask :one SELECT EXISTS ( SELECT 1 FROM template_versions @@ -12878,8 +12878,8 @@ SELECT EXISTS ( ) ` -func (q *sqlQuerier) GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) { - row := q.db.QueryRowContext(ctx, getTemplateVersionHasAIPrompt, id) +func (q *sqlQuerier) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) { + row := q.db.QueryRowContext(ctx, getTemplateVersionHasAITask, id) var exists bool err := row.Scan(&exists) return exists, err diff --git a/coderd/database/queries/templateversions.sql b/coderd/database/queries/templateversions.sql index 6cc39beeb6b97..97fb6bd9ecc08 100644 --- a/coderd/database/queries/templateversions.sql +++ b/coderd/database/queries/templateversions.sql @@ -235,7 +235,7 @@ WHERE template_versions.id IN (archived_versions.id) RETURNING template_versions.id; --- name: GetTemplateVersionHasAIPrompt :one +-- name: GetTemplateVersionHasAITask :one SELECT EXISTS ( SELECT 1 FROM template_versions diff --git a/codersdk/aitasks.go b/codersdk/aitasks.go index 000acf44cac7b..c3e09f432f72b 100644 --- a/codersdk/aitasks.go +++ b/codersdk/aitasks.go @@ -53,7 +53,7 @@ type CreateAITasksRequest struct { } func (c *ExperimentalClient) AITasksCreate(ctx context.Context, request CreateAITasksRequest) (Workspace, error) { - res, err := c.Request(ctx, http.MethodPost, "/api/experimental/aitasks/", request) + res, err := c.Request(ctx, http.MethodPost, "/api/experimental/aitasks", request) if err != nil { return Workspace{}, err } From 27b58230e070d17de99855857514940d1671cc9d Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Tue, 12 Aug 2025 08:15:35 +0000 Subject: [PATCH 12/15] chore: `aiTasksCreate` -> `tasksCreate` --- coderd/aitasks.go | 22 +++++----------------- coderd/coderd.go | 5 ++++- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index c4bf6d16c4e54..2e8b3e2036b63 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -70,11 +70,12 @@ func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) { // This endpoint is experimental and not guaranteed to be stable, so we're not // generating public-facing documentation for it. -func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { +func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() apiKey = httpmw.APIKey(r) auditor = api.Auditor.Load() + member = httpmw.OrganizationMemberParam(r) ) var req codersdk.CreateAITasksRequest @@ -82,19 +83,6 @@ func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { return } - user, err := api.Database.GetUserByID(ctx, apiKey.UserID) - if httpapi.Is404Error(err) { - httpapi.ResourceNotFound(rw) - return - } - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching user.", - Detail: err.Error(), - }) - return - } - hasAITask, err := api.Database.GetTemplateVersionHasAITask(ctx, req.TemplateVersionID) if err != nil { if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) { @@ -125,9 +113,9 @@ func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) { } owner := workspaceOwner{ - ID: user.ID, - Username: user.Username, - AvatarURL: user.AvatarURL, + ID: member.UserID, + Username: member.Username, + AvatarURL: member.AvatarURL, } aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{ diff --git a/coderd/coderd.go b/coderd/coderd.go index 6e9e94335609d..34784a1fbb67e 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -994,8 +994,11 @@ func New(options *Options) *API { r.Use(apiKeyMiddleware) r.Route("/aitasks", func(r chi.Router) { r.Get("/prompts", api.aiTasksPrompts) + }) + r.Route("/tasks", func(r chi.Router) { + r.Use(apiRateLimiter) - r.With(apiRateLimiter).Post("/", api.aiTasksCreate) + r.Post("/{user}", api.tasksCreate) }) r.Route("/mcp", func(r chi.Router) { r.Use( From 9b5a99ce2979c44d6857f6a6a5216d2ddb62a00c Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Tue, 12 Aug 2025 08:42:50 +0000 Subject: [PATCH 13/15] chore: finish making that change --- coderd/aitasks.go | 52 +++++++++++++++++++++++--- coderd/aitasks_test.go | 8 ++-- coderd/coderd.go | 6 ++- codersdk/aitasks.go | 7 ++-- site/src/api/api.ts | 7 ++-- site/src/api/typesGenerated.ts | 16 ++++---- site/src/pages/TasksPage/TasksPage.tsx | 6 ++- 7 files changed, 75 insertions(+), 27 deletions(-) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index 2e8b3e2036b63..83f6e24f0f2d6 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "slices" "strings" "github.com/google/uuid" @@ -75,10 +76,10 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) { ctx = r.Context() apiKey = httpmw.APIKey(r) auditor = api.Auditor.Load() - member = httpmw.OrganizationMemberParam(r) + mems = httpmw.OrganizationMembersParam(r) ) - var req codersdk.CreateAITasksRequest + var req codersdk.CreateTaskRequest if !httpapi.Read(ctx, rw, r, &req) { return } @@ -112,10 +113,49 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) { }, } - owner := workspaceOwner{ - ID: member.UserID, - Username: member.Username, - AvatarURL: member.AvatarURL, + var owner workspaceOwner + if mems.User != nil { + // This user fetch is an optimization path for the most common case of creating a + // task for 'Me'. + // + // This is also required to allow `owners` to create workspaces for users + // that are not in an organization. + owner = workspaceOwner{ + ID: mems.User.ID, + Username: mems.User.Username, + AvatarURL: mems.User.AvatarURL, + } + } else { + // A task can still be created if the caller can read the organization + // member. The organization is required, which can be sourced from the + // template. + // + // TODO: This code gets called twice for each workspace build request. + // This is inefficient and costs at most 2 extra RTTs to the DB. + // This can be optimized. It exists as it is now for code simplicity. + // The most common case is to create a workspace for 'Me'. Which does + // not enter this code branch. + template, ok := requestTemplate(ctx, rw, createReq, api.Database) + if !ok { + return + } + + // If the caller can find the organization membership in the same org + // as the template, then they can continue. + orgIndex := slices.IndexFunc(mems.Memberships, func(mem httpmw.OrganizationMember) bool { + return mem.OrganizationID == template.OrganizationID + }) + if orgIndex == -1 { + httpapi.ResourceNotFound(rw) + return + } + + member := mems.Memberships[orgIndex] + owner = workspaceOwner{ + ID: member.UserID, + Username: member.Username, + AvatarURL: member.AvatarURL, + } } aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{ diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index a73deefcce579..8d12dd3a5ec95 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -142,7 +142,7 @@ func TestAITasksPrompts(t *testing.T) { }) } -func TestAITasksCreate(t *testing.T) { +func TestTaskCreate(t *testing.T) { t.Parallel() t.Run("OK", func(t *testing.T) { @@ -175,7 +175,7 @@ func TestAITasksCreate(t *testing.T) { expClient := codersdk.NewExperimentalClient(client) // When: We attempt to create a Task. - workspace, err := expClient.AITasksCreate(ctx, codersdk.CreateAITasksRequest{ + workspace, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{ Name: taskName, TemplateVersionID: template.ActiveVersionID, Prompt: taskPrompt, @@ -216,7 +216,7 @@ func TestAITasksCreate(t *testing.T) { expClient := codersdk.NewExperimentalClient(client) // When: We attempt to create a Task. - _, err := expClient.AITasksCreate(ctx, codersdk.CreateAITasksRequest{ + _, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{ Name: taskName, TemplateVersionID: template.ActiveVersionID, Prompt: taskPrompt, @@ -250,7 +250,7 @@ func TestAITasksCreate(t *testing.T) { expClient := codersdk.NewExperimentalClient(client) // When: We attempt to create a Task with an invalid template version ID. - _, err := expClient.AITasksCreate(ctx, codersdk.CreateAITasksRequest{ + _, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{ Name: taskName, TemplateVersionID: uuid.New(), Prompt: taskPrompt, diff --git a/coderd/coderd.go b/coderd/coderd.go index 34784a1fbb67e..2aa30c9d7a45c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -998,7 +998,11 @@ func New(options *Options) *API { r.Route("/tasks", func(r chi.Router) { r.Use(apiRateLimiter) - r.Post("/{user}", api.tasksCreate) + r.Route("/{user}", func(r chi.Router) { + r.Use(httpmw.ExtractOrganizationMembersParam(options.Database, api.HTTPAuth.Authorize)) + + r.Post("/", api.tasksCreate) + }) }) r.Route("/mcp", func(r chi.Router) { r.Use( diff --git a/codersdk/aitasks.go b/codersdk/aitasks.go index c3e09f432f72b..49d89bf5e2656 100644 --- a/codersdk/aitasks.go +++ b/codersdk/aitasks.go @@ -3,6 +3,7 @@ package codersdk import ( "context" "encoding/json" + "fmt" "net/http" "strings" @@ -45,15 +46,15 @@ func (c *ExperimentalClient) AITaskPrompts(ctx context.Context, buildIDs []uuid. return prompts, json.NewDecoder(res.Body).Decode(&prompts) } -type CreateAITasksRequest struct { +type CreateTaskRequest struct { Name string `json:"name"` TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid"` TemplateVersionPresetID uuid.UUID `json:"template_version_preset_id,omitempty" format:"uuid"` Prompt string `json:"prompt"` } -func (c *ExperimentalClient) AITasksCreate(ctx context.Context, request CreateAITasksRequest) (Workspace, error) { - res, err := c.Request(ctx, http.MethodPost, "/api/experimental/aitasks", request) +func (c *ExperimentalClient) CreateTask(ctx context.Context, user string, request CreateTaskRequest) (Workspace, error) { + res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/tasks/%s", user), request) if err != nil { return Workspace{}, err } diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 0185eea6940cc..b9d5f06924519 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -2666,11 +2666,12 @@ class ExperimentalApiMethods { return response.data; }; - createAITask = async ( - req: TypesGen.CreateAITasksRequest, + createTask = async ( + user: string, + req: TypesGen.CreateTaskRequest, ): Promise => { const response = await this.axios.post( - "/api/experimental/aitasks", + `/api/experimental/tasks/${user}`, req, ); diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index d47a486db3eb4..6f5ab307a2fa8 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -422,14 +422,6 @@ export interface ConvertLoginRequest { readonly password: string; } -// From codersdk/aitasks.go -export interface CreateAITasksRequest { - readonly name: string; - readonly template_version_id: string; - readonly template_version_preset_id?: string; - readonly prompt: string; -} - // From codersdk/users.go export interface CreateFirstUserRequest { readonly email: string; @@ -484,6 +476,14 @@ export interface CreateProvisionerKeyResponse { readonly key: string; } +// From codersdk/aitasks.go +export interface CreateTaskRequest { + readonly name: string; + readonly template_version_id: string; + readonly template_version_preset_id?: string; + readonly prompt: string; +} + // From codersdk/organizations.go export interface CreateTemplateRequest { readonly name: string; diff --git a/site/src/pages/TasksPage/TasksPage.tsx b/site/src/pages/TasksPage/TasksPage.tsx index c40216e46432e..2f6405e796134 100644 --- a/site/src/pages/TasksPage/TasksPage.tsx +++ b/site/src/pages/TasksPage/TasksPage.tsx @@ -232,6 +232,7 @@ type TaskFormProps = { }; const TaskForm: FC = ({ templates, onSuccess }) => { + const { user } = useAuthenticated(); const queryClient = useQueryClient(); const [selectedTemplateId, setSelectedTemplateId] = useState( templates[0].id, @@ -292,7 +293,7 @@ const TaskForm: FC = ({ templates, onSuccess }) => { templateVersionId, presetId, }: CreateTaskMutationFnProps) => - data.createTask(prompt, templateVersionId, presetId), + data.createTask(prompt, user.id, templateVersionId, presetId), onSuccess: async (task) => { await queryClient.invalidateQueries({ queryKey: ["tasks"], @@ -726,6 +727,7 @@ export const data = { async createTask( prompt: string, + userId: string, templateVersionId: string, presetId: string | null = null, ): Promise { @@ -739,7 +741,7 @@ export const data = { } } - const workspace = await API.experimental.createAITask({ + const workspace = await API.experimental.createTask(userId, { name: `task-${generateWorkspaceName()}`, template_version_id: templateVersionId, template_version_preset_id: preset_id || undefined, From e1d9fd4830b29bde689f1221b4d12496c0b102ce Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Tue, 12 Aug 2025 08:43:51 +0000 Subject: [PATCH 14/15] chore: sprintf --- coderd/aitasks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index 83f6e24f0f2d6..5c1df349d7bc2 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -99,7 +99,7 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) { } if !hasAITask { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: `Template does not have required parameter "` + codersdk.AITaskPromptParameterName + `"`, + Message: fmt.Sprintf(`Template does not have required parameter "%s"`, codersdk.AITaskPromptParameterName), }) return } From dc041de40bfa0212435d083216b4c1ccfabbd92f Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Tue, 12 Aug 2025 08:44:27 +0000 Subject: [PATCH 15/15] chore: oops --- coderd/aitasks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/aitasks.go b/coderd/aitasks.go index 5c1df349d7bc2..e1d72f264a025 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -99,7 +99,7 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) { } if !hasAITask { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: fmt.Sprintf(`Template does not have required parameter "%s"`, codersdk.AITaskPromptParameterName), + Message: fmt.Sprintf(`Template does not have required parameter %q`, codersdk.AITaskPromptParameterName), }) return }