Skip to content

Commit d9e7e08

Browse files
fix: return error if template does not have an "AI Prompt" parameter
1 parent 0cd6fde commit d9e7e08

File tree

9 files changed

+124
-5
lines changed

9 files changed

+124
-5
lines changed

coderd/aitasks.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,21 @@ func (api *API) aiTasksCreate(rw http.ResponseWriter, r *http.Request) {
9292
return
9393
}
9494

95+
hasAIPrompt, err := api.Database.GetTemplateVersionHasAIPrompt(ctx, req.TemplateVersionID)
96+
if err != nil {
97+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
98+
Message: "Internal error fetching if template version has ai prompt.",
99+
Detail: err.Error(),
100+
})
101+
return
102+
}
103+
if !hasAIPrompt {
104+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
105+
Message: `Template does not have required parameter "AI Prompt"`,
106+
})
107+
return
108+
}
109+
95110
createReq := codersdk.CreateWorkspaceRequest{
96111
Name: req.Name,
97112
TemplateVersionID: req.TemplateVersionID,

coderd/aitasks_test.go

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package coderd_test
22

33
import (
4+
"net/http"
45
"testing"
56

67
"github.com/google/uuid"
8+
"github.com/stretchr/testify/assert"
79
"github.com/stretchr/testify/require"
810

911
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -166,7 +168,7 @@ func TestAITasksCreate(t *testing.T) {
166168
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
167169
user := coderdtest.CreateFirstUser(t, client)
168170

169-
// Given: A template with an "AI Prompt"
171+
// Given: A template with an "AI Prompt" parameter
170172
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, makeEchoResponses([]*proto.RichParameter{
171173
{Name: "AI Prompt", Type: "string"},
172174
}))
@@ -185,14 +187,46 @@ func TestAITasksCreate(t *testing.T) {
185187
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
186188

187189
// Then: We expect a workspace to have been created.
188-
require.Equal(t, taskName, workspace.Name)
189-
require.Equal(t, template.ID, workspace.TemplateID)
190+
assert.Equal(t, taskName, workspace.Name)
191+
assert.Equal(t, template.ID, workspace.TemplateID)
190192

191193
// And: We expect it to have the "AI Prompt" parameter correctly set.
192194
parameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
193195
require.NoError(t, err)
194196
require.Len(t, parameters, 1)
195-
require.Equal(t, "AI Prompt", parameters[0].Name)
196-
require.Equal(t, taskPrompt, parameters[0].Value)
197+
assert.Equal(t, "AI Prompt", parameters[0].Name)
198+
assert.Equal(t, taskPrompt, parameters[0].Value)
199+
})
200+
201+
t.Run("FailsOnNonTaskTemplate", func(t *testing.T) {
202+
var (
203+
ctx = testutil.Context(t, testutil.WaitShort)
204+
205+
taskName = "task-foo-bar-baz"
206+
taskPrompt = "Some task prompt"
207+
)
208+
209+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
210+
user := coderdtest.CreateFirstUser(t, client)
211+
212+
// Given: A template without an "AI Prompt" parameter
213+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
214+
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
215+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
216+
217+
expClient := codersdk.NewExperimentalClient(client)
218+
219+
// When: We attempt to create a Task.
220+
_, err := expClient.AITasksCreate(ctx, codersdk.CreateAITasksRequest{
221+
Name: taskName,
222+
TemplateVersionID: template.ActiveVersionID,
223+
Prompt: taskPrompt,
224+
})
225+
226+
// Then: We expect it to fail.
227+
var sdkErr *codersdk.Error
228+
require.Error(t, err)
229+
require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error")
230+
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
197231
})
198232
}

coderd/database/dbauthz/dbauthz.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2874,6 +2874,17 @@ func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg
28742874
return tv, nil
28752875
}
28762876

2877+
func (q *querier) GetTemplateVersionHasAIPrompt(ctx context.Context, id uuid.UUID) (bool, error) {
2878+
// If we can successfully call `GetTemplateVersionByID`, then
2879+
// we know the actor has sufficient permissions to know if the
2880+
// template has an AI Prompt.
2881+
if _, err := q.GetTemplateVersionByID(ctx, id); err != nil {
2882+
return false, err
2883+
}
2884+
2885+
return q.db.GetTemplateVersionHasAIPrompt(ctx, id)
2886+
}
2887+
28772888
func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
28782889
// An actor can read template version parameters if they can read the related template.
28792890
tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID)

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,20 @@ func (s *MethodTestSuite) TestTemplate() {
14431443
})
14441444
check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), policy.ActionRead)
14451445
}))
1446+
s.Run("GetTemplateVersionHasAIPrompt", s.Subtest(func(db database.Store, check *expects) {
1447+
o := dbgen.Organization(s.T(), db, database.Organization{})
1448+
u := dbgen.User(s.T(), db, database.User{})
1449+
t := dbgen.Template(s.T(), db, database.Template{
1450+
OrganizationID: o.ID,
1451+
CreatedBy: u.ID,
1452+
})
1453+
tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
1454+
OrganizationID: o.ID,
1455+
TemplateID: uuid.NullUUID{UUID: t.ID, Valid: true},
1456+
CreatedBy: u.ID,
1457+
})
1458+
check.Args(tv.ID).Asserts(t, policy.ActionRead)
1459+
}))
14461460
s.Run("GetTemplatesWithFilter", s.Subtest(func(db database.Store, check *expects) {
14471461
o := dbgen.Organization(s.T(), db, database.Organization{})
14481462
u := dbgen.User(s.T(), db, database.User{})

coderd/database/dbmetrics/querymetrics.go

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/templateversions.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,10 @@ FROM
234234
WHERE
235235
template_versions.id IN (archived_versions.id)
236236
RETURNING template_versions.id;
237+
238+
-- name: GetTemplateVersionHasAIPrompt :one
239+
SELECT EXISTS (
240+
SELECT 1
241+
FROM template_versions
242+
WHERE id = $1 AND has_ai_task = TRUE
243+
);

0 commit comments

Comments
 (0)