Skip to content

Commit f349edc

Browse files
refactor: create tasks in coderd instead of frontend (#19280)
Instead of creating tasks with a specialized call to `CreateWorkspace` on the frontend, we instead lift this to the backend and allow the frontend to simply call `CreateAITask`.
1 parent cda1a3a commit f349edc

File tree

14 files changed

+362
-4
lines changed

14 files changed

+362
-4
lines changed

coderd/aitasks.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
package coderd
22

33
import (
4+
"database/sql"
5+
"errors"
46
"fmt"
57
"net/http"
8+
"slices"
69
"strings"
710

811
"github.com/google/uuid"
912

13+
"github.com/coder/coder/v2/coderd/audit"
14+
"github.com/coder/coder/v2/coderd/database"
1015
"github.com/coder/coder/v2/coderd/httpapi"
16+
"github.com/coder/coder/v2/coderd/httpmw"
17+
"github.com/coder/coder/v2/coderd/rbac"
1118
"github.com/coder/coder/v2/codersdk"
1219
)
1320

@@ -61,3 +68,106 @@ func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) {
6168
Prompts: promptsByBuildID,
6269
})
6370
}
71+
72+
// This endpoint is experimental and not guaranteed to be stable, so we're not
73+
// generating public-facing documentation for it.
74+
func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
75+
var (
76+
ctx = r.Context()
77+
apiKey = httpmw.APIKey(r)
78+
auditor = api.Auditor.Load()
79+
mems = httpmw.OrganizationMembersParam(r)
80+
)
81+
82+
var req codersdk.CreateTaskRequest
83+
if !httpapi.Read(ctx, rw, r, &req) {
84+
return
85+
}
86+
87+
hasAITask, err := api.Database.GetTemplateVersionHasAITask(ctx, req.TemplateVersionID)
88+
if err != nil {
89+
if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) {
90+
httpapi.ResourceNotFound(rw)
91+
return
92+
}
93+
94+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
95+
Message: "Internal error fetching whether the template version has an AI task.",
96+
Detail: err.Error(),
97+
})
98+
return
99+
}
100+
if !hasAITask {
101+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
102+
Message: fmt.Sprintf(`Template does not have required parameter %q`, codersdk.AITaskPromptParameterName),
103+
})
104+
return
105+
}
106+
107+
createReq := codersdk.CreateWorkspaceRequest{
108+
Name: req.Name,
109+
TemplateVersionID: req.TemplateVersionID,
110+
TemplateVersionPresetID: req.TemplateVersionPresetID,
111+
RichParameterValues: []codersdk.WorkspaceBuildParameter{
112+
{Name: codersdk.AITaskPromptParameterName, Value: req.Prompt},
113+
},
114+
}
115+
116+
var owner workspaceOwner
117+
if mems.User != nil {
118+
// This user fetch is an optimization path for the most common case of creating a
119+
// task for 'Me'.
120+
//
121+
// This is also required to allow `owners` to create workspaces for users
122+
// that are not in an organization.
123+
owner = workspaceOwner{
124+
ID: mems.User.ID,
125+
Username: mems.User.Username,
126+
AvatarURL: mems.User.AvatarURL,
127+
}
128+
} else {
129+
// A task can still be created if the caller can read the organization
130+
// member. The organization is required, which can be sourced from the
131+
// template.
132+
//
133+
// TODO: This code gets called twice for each workspace build request.
134+
// This is inefficient and costs at most 2 extra RTTs to the DB.
135+
// This can be optimized. It exists as it is now for code simplicity.
136+
// The most common case is to create a workspace for 'Me'. Which does
137+
// not enter this code branch.
138+
template, ok := requestTemplate(ctx, rw, createReq, api.Database)
139+
if !ok {
140+
return
141+
}
142+
143+
// If the caller can find the organization membership in the same org
144+
// as the template, then they can continue.
145+
orgIndex := slices.IndexFunc(mems.Memberships, func(mem httpmw.OrganizationMember) bool {
146+
return mem.OrganizationID == template.OrganizationID
147+
})
148+
if orgIndex == -1 {
149+
httpapi.ResourceNotFound(rw)
150+
return
151+
}
152+
153+
member := mems.Memberships[orgIndex]
154+
owner = workspaceOwner{
155+
ID: member.UserID,
156+
Username: member.Username,
157+
AvatarURL: member.AvatarURL,
158+
}
159+
}
160+
161+
aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{
162+
Audit: *auditor,
163+
Log: api.Logger,
164+
Request: r,
165+
Action: database.AuditActionCreate,
166+
AdditionalFields: audit.AdditionalFields{
167+
WorkspaceOwner: owner.Username,
168+
},
169+
})
170+
171+
defer commitAudit()
172+
createWorkspace(ctx, aReq, apiKey.UserID, api, owner, createReq, rw, r)
173+
}

coderd/aitasks_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package coderd_test
22

33
import (
4+
"net/http"
45
"testing"
56

67
"github.com/google/uuid"
8+
"github.com/stretchr/testify/assert"
79
"github.com/stretchr/testify/require"
810

911
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -139,3 +141,125 @@ func TestAITasksPrompts(t *testing.T) {
139141
require.Empty(t, prompts.Prompts)
140142
})
141143
}
144+
145+
func TestTaskCreate(t *testing.T) {
146+
t.Parallel()
147+
148+
t.Run("OK", func(t *testing.T) {
149+
t.Parallel()
150+
151+
var (
152+
ctx = testutil.Context(t, testutil.WaitShort)
153+
154+
taskName = "task-foo-bar-baz"
155+
taskPrompt = "Some task prompt"
156+
)
157+
158+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
159+
user := coderdtest.CreateFirstUser(t, client)
160+
161+
// Given: A template with an "AI Prompt" parameter
162+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
163+
Parse: echo.ParseComplete,
164+
ProvisionApply: echo.ApplyComplete,
165+
ProvisionPlan: []*proto.Response{
166+
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
167+
Parameters: []*proto.RichParameter{{Name: "AI Prompt", Type: "string"}},
168+
HasAiTasks: true,
169+
}}},
170+
},
171+
})
172+
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
173+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
174+
175+
expClient := codersdk.NewExperimentalClient(client)
176+
177+
// When: We attempt to create a Task.
178+
workspace, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
179+
Name: taskName,
180+
TemplateVersionID: template.ActiveVersionID,
181+
Prompt: taskPrompt,
182+
})
183+
require.NoError(t, err)
184+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
185+
186+
// Then: We expect a workspace to have been created.
187+
assert.Equal(t, taskName, workspace.Name)
188+
assert.Equal(t, template.ID, workspace.TemplateID)
189+
190+
// And: We expect it to have the "AI Prompt" parameter correctly set.
191+
parameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
192+
require.NoError(t, err)
193+
require.Len(t, parameters, 1)
194+
assert.Equal(t, codersdk.AITaskPromptParameterName, parameters[0].Name)
195+
assert.Equal(t, taskPrompt, parameters[0].Value)
196+
})
197+
198+
t.Run("FailsOnNonTaskTemplate", func(t *testing.T) {
199+
t.Parallel()
200+
201+
var (
202+
ctx = testutil.Context(t, testutil.WaitShort)
203+
204+
taskName = "task-foo-bar-baz"
205+
taskPrompt = "Some task prompt"
206+
)
207+
208+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
209+
user := coderdtest.CreateFirstUser(t, client)
210+
211+
// Given: A template without an "AI Prompt" parameter
212+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
213+
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
214+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
215+
216+
expClient := codersdk.NewExperimentalClient(client)
217+
218+
// When: We attempt to create a Task.
219+
_, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
220+
Name: taskName,
221+
TemplateVersionID: template.ActiveVersionID,
222+
Prompt: taskPrompt,
223+
})
224+
225+
// Then: We expect it to fail.
226+
var sdkErr *codersdk.Error
227+
require.Error(t, err)
228+
require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error")
229+
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
230+
})
231+
232+
t.Run("FailsOnInvalidTemplate", func(t *testing.T) {
233+
t.Parallel()
234+
235+
var (
236+
ctx = testutil.Context(t, testutil.WaitShort)
237+
238+
taskName = "task-foo-bar-baz"
239+
taskPrompt = "Some task prompt"
240+
)
241+
242+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
243+
user := coderdtest.CreateFirstUser(t, client)
244+
245+
// Given: A template
246+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
247+
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
248+
_ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
249+
250+
expClient := codersdk.NewExperimentalClient(client)
251+
252+
// When: We attempt to create a Task with an invalid template version ID.
253+
_, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
254+
Name: taskName,
255+
TemplateVersionID: uuid.New(),
256+
Prompt: taskPrompt,
257+
})
258+
259+
// Then: We expect it to fail.
260+
var sdkErr *codersdk.Error
261+
require.Error(t, err)
262+
require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error")
263+
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
264+
})
265+
}

coderd/coderd.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,15 @@ func New(options *Options) *API {
995995
r.Route("/aitasks", func(r chi.Router) {
996996
r.Get("/prompts", api.aiTasksPrompts)
997997
})
998+
r.Route("/tasks", func(r chi.Router) {
999+
r.Use(apiRateLimiter)
1000+
1001+
r.Route("/{user}", func(r chi.Router) {
1002+
r.Use(httpmw.ExtractOrganizationMembersParam(options.Database, api.HTTPAuth.Authorize))
1003+
1004+
r.Post("/", api.tasksCreate)
1005+
})
1006+
})
9981007
r.Route("/mcp", func(r chi.Router) {
9991008
r.Use(
10001009
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),

coderd/database/dbauthz/dbauthz.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,6 +2863,17 @@ func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg
28632863
return tv, nil
28642864
}
28652865

2866+
func (q *querier) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) {
2867+
// If we can successfully call `GetTemplateVersionByID`, then
2868+
// we know the actor has sufficient permissions to know if the
2869+
// template has an AI task.
2870+
if _, err := q.GetTemplateVersionByID(ctx, id); err != nil {
2871+
return false, err
2872+
}
2873+
2874+
return q.db.GetTemplateVersionHasAITask(ctx, id)
2875+
}
2876+
28662877
func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
28672878
// An actor can read template version parameters if they can read the related template.
28682879
tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID)

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,20 @@ func (s *MethodTestSuite) TestTemplate() {
14431443
})
14441444
check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), policy.ActionRead)
14451445
}))
1446+
s.Run("GetTemplateVersionHasAITask", s.Subtest(func(db database.Store, check *expects) {
1447+
o := dbgen.Organization(s.T(), db, database.Organization{})
1448+
u := dbgen.User(s.T(), db, database.User{})
1449+
t := dbgen.Template(s.T(), db, database.Template{
1450+
OrganizationID: o.ID,
1451+
CreatedBy: u.ID,
1452+
})
1453+
tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
1454+
OrganizationID: o.ID,
1455+
TemplateID: uuid.NullUUID{UUID: t.ID, Valid: true},
1456+
CreatedBy: u.ID,
1457+
})
1458+
check.Args(tv.ID).Asserts(t, policy.ActionRead)
1459+
}))
14461460
s.Run("GetTemplatesWithFilter", s.Subtest(func(db database.Store, check *expects) {
14471461
o := dbgen.Organization(s.T(), db, database.Organization{})
14481462
u := dbgen.User(s.T(), db, database.User{})

coderd/database/dbmetrics/querymetrics.go

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/templateversions.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,10 @@ FROM
234234
WHERE
235235
template_versions.id IN (archived_versions.id)
236236
RETURNING template_versions.id;
237+
238+
-- name: GetTemplateVersionHasAITask :one
239+
SELECT EXISTS (
240+
SELECT 1
241+
FROM template_versions
242+
WHERE id = $1 AND has_ai_task = TRUE
243+
);

0 commit comments

Comments
 (0)