Skip to content

Commit 0238f29

Browse files
authored
feat: persist AI task state in template imports & workspace builds (#18449)
1 parent 6cc4cfa commit 0238f29

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2773
-530
lines changed

cli/testdata/coder_list_--output_json.golden

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@
6868
"available": 0,
6969
"most_recently_seen": null
7070
},
71-
"template_version_preset_id": null
71+
"template_version_preset_id": null,
72+
"has_ai_task": false
7273
},
7374
"latest_app_status": null,
7475
"outdated": false,

coderd/database/dbauthz/dbauthz.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4518,6 +4518,28 @@ func (q *querier) UpdateTemplateScheduleByID(ctx context.Context, arg database.U
45184518
return update(q.log, q.auth, fetch, q.db.UpdateTemplateScheduleByID)(ctx, arg)
45194519
}
45204520

4521+
func (q *querier) UpdateTemplateVersionAITaskByJobID(ctx context.Context, arg database.UpdateTemplateVersionAITaskByJobIDParams) error {
4522+
// An actor is allowed to update the template version AI task flag if they are authorized to update the template.
4523+
tv, err := q.db.GetTemplateVersionByJobID(ctx, arg.JobID)
4524+
if err != nil {
4525+
return err
4526+
}
4527+
var obj rbac.Objecter
4528+
if !tv.TemplateID.Valid {
4529+
obj = rbac.ResourceTemplate.InOrg(tv.OrganizationID)
4530+
} else {
4531+
tpl, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID)
4532+
if err != nil {
4533+
return err
4534+
}
4535+
obj = tpl
4536+
}
4537+
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
4538+
return err
4539+
}
4540+
return q.db.UpdateTemplateVersionAITaskByJobID(ctx, arg)
4541+
}
4542+
45214543
func (q *querier) UpdateTemplateVersionByID(ctx context.Context, arg database.UpdateTemplateVersionByIDParams) error {
45224544
// An actor is allowed to update the template version if they are authorized to update the template.
45234545
tv, err := q.db.GetTemplateVersionByID(ctx, arg.ID)
@@ -4874,6 +4896,24 @@ func (q *querier) UpdateWorkspaceAutostart(ctx context.Context, arg database.Upd
48744896
return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceAutostart)(ctx, arg)
48754897
}
48764898

4899+
func (q *querier) UpdateWorkspaceBuildAITaskByID(ctx context.Context, arg database.UpdateWorkspaceBuildAITaskByIDParams) error {
4900+
build, err := q.db.GetWorkspaceBuildByID(ctx, arg.ID)
4901+
if err != nil {
4902+
return err
4903+
}
4904+
4905+
workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID)
4906+
if err != nil {
4907+
return err
4908+
}
4909+
4910+
err = q.authorizeContext(ctx, policy.ActionUpdate, workspace.RBACObject())
4911+
if err != nil {
4912+
return err
4913+
}
4914+
return q.db.UpdateWorkspaceBuildAITaskByID(ctx, arg)
4915+
}
4916+
48774917
// UpdateWorkspaceBuildCostByID is used by the provisioning system to update the cost of a workspace build.
48784918
func (q *querier) UpdateWorkspaceBuildCostByID(ctx context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) error {
48794919
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,24 @@ func (s *MethodTestSuite) TestTemplate() {
13911391
ID: t1.ID,
13921392
}).Asserts(t1, policy.ActionUpdate)
13931393
}))
1394+
s.Run("UpdateTemplateVersionAITaskByJobID", s.Subtest(func(db database.Store, check *expects) {
1395+
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
1396+
o := dbgen.Organization(s.T(), db, database.Organization{})
1397+
u := dbgen.User(s.T(), db, database.User{})
1398+
_ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{OrganizationID: o.ID, UserID: u.ID})
1399+
t := dbgen.Template(s.T(), db, database.Template{OrganizationID: o.ID, CreatedBy: u.ID})
1400+
job := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{OrganizationID: o.ID})
1401+
_ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
1402+
OrganizationID: o.ID,
1403+
CreatedBy: u.ID,
1404+
JobID: job.ID,
1405+
TemplateID: uuid.NullUUID{UUID: t.ID, Valid: true},
1406+
})
1407+
check.Args(database.UpdateTemplateVersionAITaskByJobIDParams{
1408+
JobID: job.ID,
1409+
HasAITask: sql.NullBool{Bool: true, Valid: true},
1410+
}).Asserts(t, policy.ActionUpdate)
1411+
}))
13941412
s.Run("UpdateTemplateWorkspacesLastUsedAt", s.Subtest(func(db database.Store, check *expects) {
13951413
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
13961414
t1 := dbgen.Template(s.T(), db, database.Template{})
@@ -3050,6 +3068,40 @@ func (s *MethodTestSuite) TestWorkspace() {
30503068
Deadline: b.Deadline,
30513069
}).Asserts(w, policy.ActionUpdate)
30523070
}))
3071+
s.Run("UpdateWorkspaceBuildAITaskByID", s.Subtest(func(db database.Store, check *expects) {
3072+
u := dbgen.User(s.T(), db, database.User{})
3073+
o := dbgen.Organization(s.T(), db, database.Organization{})
3074+
tpl := dbgen.Template(s.T(), db, database.Template{
3075+
OrganizationID: o.ID,
3076+
CreatedBy: u.ID,
3077+
})
3078+
tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
3079+
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
3080+
OrganizationID: o.ID,
3081+
CreatedBy: u.ID,
3082+
})
3083+
w := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
3084+
TemplateID: tpl.ID,
3085+
OrganizationID: o.ID,
3086+
OwnerID: u.ID,
3087+
})
3088+
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
3089+
Type: database.ProvisionerJobTypeWorkspaceBuild,
3090+
})
3091+
b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{
3092+
JobID: j.ID,
3093+
WorkspaceID: w.ID,
3094+
TemplateVersionID: tv.ID,
3095+
})
3096+
res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: b.JobID})
3097+
agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID})
3098+
app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID})
3099+
check.Args(database.UpdateWorkspaceBuildAITaskByIDParams{
3100+
HasAITask: sql.NullBool{Bool: true, Valid: true},
3101+
SidebarAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
3102+
ID: b.ID,
3103+
}).Asserts(w, policy.ActionUpdate)
3104+
}))
30533105
s.Run("SoftDeleteWorkspaceByID", s.Subtest(func(db database.Store, check *expects) {
30543106
u := dbgen.User(s.T(), db, database.User{})
30553107
o := dbgen.Organization(s.T(), db, database.Organization{})

coderd/database/dbgen/dbgen.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,9 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
390390
t.Helper()
391391

392392
buildID := takeFirst(orig.ID, uuid.New())
393+
jobID := takeFirst(orig.JobID, uuid.New())
394+
hasAITask := takeFirst(orig.HasAITask, sql.NullBool{})
395+
sidebarAppID := takeFirst(orig.AITaskSidebarAppID, uuid.NullUUID{})
393396
var build database.WorkspaceBuild
394397
err := db.InTx(func(db database.Store) error {
395398
err := db.InsertWorkspaceBuild(genCtx, database.InsertWorkspaceBuildParams{
@@ -401,7 +404,7 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
401404
BuildNumber: takeFirst(orig.BuildNumber, 1),
402405
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
403406
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
404-
JobID: takeFirst(orig.JobID, uuid.New()),
407+
JobID: jobID,
405408
ProvisionerState: takeFirstSlice(orig.ProvisionerState, []byte{}),
406409
Deadline: takeFirst(orig.Deadline, dbtime.Now().Add(time.Hour)),
407410
MaxDeadline: takeFirst(orig.MaxDeadline, time.Time{}),
@@ -410,7 +413,6 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
410413
UUID: uuid.UUID{},
411414
Valid: false,
412415
}),
413-
HasAITask: orig.HasAITask,
414416
})
415417
if err != nil {
416418
return err
@@ -424,6 +426,15 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
424426
require.NoError(t, err)
425427
}
426428

429+
if hasAITask.Valid {
430+
require.NoError(t, db.UpdateWorkspaceBuildAITaskByID(genCtx, database.UpdateWorkspaceBuildAITaskByIDParams{
431+
HasAITask: hasAITask,
432+
SidebarAppID: sidebarAppID,
433+
UpdatedAt: dbtime.Now(),
434+
ID: buildID,
435+
}))
436+
}
437+
427438
build, err = db.GetWorkspaceBuildByID(genCtx, buildID)
428439
if err != nil {
429440
return err
@@ -971,6 +982,8 @@ func ExternalAuthLink(t testing.TB, db database.Store, orig database.ExternalAut
971982

972983
func TemplateVersion(t testing.TB, db database.Store, orig database.TemplateVersion) database.TemplateVersion {
973984
var version database.TemplateVersion
985+
hasAITask := takeFirst(orig.HasAITask, sql.NullBool{})
986+
jobID := takeFirst(orig.JobID, uuid.New())
974987
err := db.InTx(func(db database.Store) error {
975988
versionID := takeFirst(orig.ID, uuid.New())
976989
err := db.InsertTemplateVersion(genCtx, database.InsertTemplateVersionParams{
@@ -982,15 +995,22 @@ func TemplateVersion(t testing.TB, db database.Store, orig database.TemplateVers
982995
Name: takeFirst(orig.Name, testutil.GetRandomName(t)),
983996
Message: orig.Message,
984997
Readme: takeFirst(orig.Readme, testutil.GetRandomName(t)),
985-
JobID: takeFirst(orig.JobID, uuid.New()),
998+
JobID: jobID,
986999
CreatedBy: takeFirst(orig.CreatedBy, uuid.New()),
9871000
SourceExampleID: takeFirst(orig.SourceExampleID, sql.NullString{}),
988-
HasAITask: orig.HasAITask,
9891001
})
9901002
if err != nil {
9911003
return err
9921004
}
9931005

1006+
if hasAITask.Valid {
1007+
require.NoError(t, db.UpdateTemplateVersionAITaskByJobID(genCtx, database.UpdateTemplateVersionAITaskByJobIDParams{
1008+
JobID: jobID,
1009+
HasAITask: hasAITask,
1010+
UpdatedAt: dbtime.Now(),
1011+
}))
1012+
}
1013+
9941014
version, err = db.GetTemplateVersionByID(genCtx, versionID)
9951015
if err != nil {
9961016
return err

coderd/database/dbmem/dbmem.go

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9465,7 +9465,6 @@ func (q *FakeQuerier) InsertTemplateVersion(_ context.Context, arg database.Inse
94659465
JobID: arg.JobID,
94669466
CreatedBy: arg.CreatedBy,
94679467
SourceExampleID: arg.SourceExampleID,
9468-
HasAITask: arg.HasAITask,
94699468
}
94709469
q.templateVersions = append(q.templateVersions, version)
94719470
return nil
@@ -10103,7 +10102,6 @@ func (q *FakeQuerier) InsertWorkspaceBuild(_ context.Context, arg database.Inser
1010310102
MaxDeadline: arg.MaxDeadline,
1010410103
Reason: arg.Reason,
1010510104
TemplateVersionPresetID: arg.TemplateVersionPresetID,
10106-
HasAITask: arg.HasAITask,
1010710105
}
1010810106
q.workspaceBuilds = append(q.workspaceBuilds, workspaceBuild)
1010910107
return nil
@@ -11308,6 +11306,26 @@ func (q *FakeQuerier) UpdateTemplateScheduleByID(_ context.Context, arg database
1130811306
return sql.ErrNoRows
1130911307
}
1131011308

11309+
func (q *FakeQuerier) UpdateTemplateVersionAITaskByJobID(_ context.Context, arg database.UpdateTemplateVersionAITaskByJobIDParams) error {
11310+
if err := validateDatabaseType(arg); err != nil {
11311+
return err
11312+
}
11313+
11314+
q.mutex.Lock()
11315+
defer q.mutex.Unlock()
11316+
11317+
for index, templateVersion := range q.templateVersions {
11318+
if templateVersion.JobID != arg.JobID {
11319+
continue
11320+
}
11321+
templateVersion.HasAITask = arg.HasAITask
11322+
templateVersion.UpdatedAt = arg.UpdatedAt
11323+
q.templateVersions[index] = templateVersion
11324+
return nil
11325+
}
11326+
return sql.ErrNoRows
11327+
}
11328+
1131111329
func (q *FakeQuerier) UpdateTemplateVersionByID(_ context.Context, arg database.UpdateTemplateVersionByIDParams) error {
1131211330
if err := validateDatabaseType(arg); err != nil {
1131311331
return err
@@ -12003,6 +12021,35 @@ func (q *FakeQuerier) UpdateWorkspaceAutostart(_ context.Context, arg database.U
1200312021
return sql.ErrNoRows
1200412022
}
1200512023

12024+
func (q *FakeQuerier) UpdateWorkspaceBuildAITaskByID(_ context.Context, arg database.UpdateWorkspaceBuildAITaskByIDParams) error {
12025+
if arg.HasAITask.Bool && !arg.SidebarAppID.Valid {
12026+
return xerrors.Errorf("ai_task_sidebar_app_id is required when has_ai_task is true")
12027+
}
12028+
if !arg.HasAITask.Valid && arg.SidebarAppID.Valid {
12029+
return xerrors.Errorf("ai_task_sidebar_app_id is can only be set when has_ai_task is true")
12030+
}
12031+
12032+
err := validateDatabaseType(arg)
12033+
if err != nil {
12034+
return err
12035+
}
12036+
12037+
q.mutex.Lock()
12038+
defer q.mutex.Unlock()
12039+
12040+
for index, workspaceBuild := range q.workspaceBuilds {
12041+
if workspaceBuild.ID != arg.ID {
12042+
continue
12043+
}
12044+
workspaceBuild.HasAITask = arg.HasAITask
12045+
workspaceBuild.AITaskSidebarAppID = arg.SidebarAppID
12046+
workspaceBuild.UpdatedAt = dbtime.Now()
12047+
q.workspaceBuilds[index] = workspaceBuild
12048+
return nil
12049+
}
12050+
return sql.ErrNoRows
12051+
}
12052+
1200612053
func (q *FakeQuerier) UpdateWorkspaceBuildCostByID(_ context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) error {
1200712054
if err := validateDatabaseType(arg); err != nil {
1200812055
return err

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dump.sql

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/foreign_key_constraint.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)