Skip to content

Commit 96f963c

Browse files
committed
Add test for workspace resources
1 parent dfa73b3 commit 96f963c

File tree

2 files changed

+267
-5
lines changed

2 files changed

+267
-5
lines changed

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
489489
return nil, xerrors.Errorf("get job by id: %w", err)
490490
}
491491
if job.WorkerID.UUID.String() != server.ID.String() {
492-
return nil, xerrors.Errorf("you don't have permission to update this job")
492+
return nil, xerrors.Errorf("you don't own this job")
493493
}
494494

495495
telemetrySnapshot := &telemetry.Snapshot{}
@@ -509,7 +509,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
509509
slog.F("resource_type", resource.Type),
510510
slog.F("transition", transition))
511511

512-
err = insertWorkspaceResource(ctx, server.Database, jobID, transition, resource, telemetrySnapshot)
512+
err = InsertWorkspaceResource(ctx, server.Database, jobID, transition, resource, telemetrySnapshot)
513513
if err != nil {
514514
return nil, xerrors.Errorf("insert resource: %w", err)
515515
}
@@ -578,7 +578,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
578578
}
579579
// This could be a bulk insert to improve performance.
580580
for _, protoResource := range jobType.WorkspaceBuild.Resources {
581-
err = insertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot)
581+
err = InsertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot)
582582
if err != nil {
583583
return xerrors.Errorf("insert provisioner job: %w", err)
584584
}
@@ -614,7 +614,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
614614
slog.F("resource_name", resource.Name),
615615
slog.F("resource_type", resource.Type))
616616

617-
err = insertWorkspaceResource(ctx, server.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot)
617+
err = InsertWorkspaceResource(ctx, server.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot)
618618
if err != nil {
619619
return nil, xerrors.Errorf("insert resource: %w", err)
620620
}
@@ -637,6 +637,9 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
637637
}
638638

639639
default:
640+
if completed.Type == nil {
641+
return nil, xerrors.Errorf("type payload must be provided")
642+
}
640643
return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match",
641644
reflect.TypeOf(completed.Type).String())
642645
}
@@ -655,7 +658,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
655658
return &proto.Empty{}, nil
656659
}
657660

658-
func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.UUID, transition database.WorkspaceTransition, protoResource *sdkproto.Resource, snapshot *telemetry.Snapshot) error {
661+
func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.UUID, transition database.WorkspaceTransition, protoResource *sdkproto.Resource, snapshot *telemetry.Snapshot) error {
659662
resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{
660663
ID: uuid.New(),
661664
CreatedAt: database.Now(),

coderd/provisionerdserver/provisionerdserver_test.go

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"cdr.dev/slog/sloggers/slogtest"
1414
"github.com/coder/coder/coderd/database"
15+
"github.com/coder/coder/coderd/database/databasefake"
1516
"github.com/coder/coder/coderd/database/dbtestutil"
1617
"github.com/coder/coder/coderd/provisionerdserver"
1718
"github.com/coder/coder/coderd/telemetry"
@@ -498,6 +499,264 @@ func TestFailJob(t *testing.T) {
498499
})
499500
}
500501

502+
func TestCompleteJob(t *testing.T) {
503+
t.Parallel()
504+
ctx := context.Background()
505+
t.Run("NotFound", func(t *testing.T) {
506+
t.Parallel()
507+
srv := setup(t)
508+
_, err := srv.CompleteJob(ctx, &proto.CompletedJob{
509+
JobId: "hello",
510+
})
511+
require.ErrorContains(t, err, "invalid UUID")
512+
513+
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
514+
JobId: uuid.NewString(),
515+
})
516+
require.ErrorContains(t, err, "no rows in result set")
517+
})
518+
// This test prevents runners from updating jobs they don't own!
519+
t.Run("NotOwner", func(t *testing.T) {
520+
t.Parallel()
521+
srv := setup(t)
522+
job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
523+
ID: uuid.New(),
524+
Provisioner: database.ProvisionerTypeEcho,
525+
})
526+
require.NoError(t, err)
527+
_, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
528+
WorkerID: uuid.NullUUID{
529+
UUID: uuid.New(),
530+
Valid: true,
531+
},
532+
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
533+
})
534+
require.NoError(t, err)
535+
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
536+
JobId: job.ID.String(),
537+
})
538+
require.ErrorContains(t, err, "you don't own this job")
539+
})
540+
t.Run("TemplateImport", func(t *testing.T) {
541+
t.Parallel()
542+
srv := setup(t)
543+
job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
544+
ID: uuid.New(),
545+
Provisioner: database.ProvisionerTypeEcho,
546+
})
547+
require.NoError(t, err)
548+
_, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
549+
WorkerID: uuid.NullUUID{
550+
UUID: srv.ID,
551+
Valid: true,
552+
},
553+
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
554+
})
555+
require.NoError(t, err)
556+
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
557+
JobId: job.ID.String(),
558+
Type: &proto.CompletedJob_TemplateImport_{
559+
TemplateImport: &proto.CompletedJob_TemplateImport{
560+
StartResources: []*sdkproto.Resource{{
561+
Name: "hello",
562+
Type: "aws_instance",
563+
}},
564+
StopResources: []*sdkproto.Resource{},
565+
},
566+
},
567+
})
568+
require.NoError(t, err)
569+
})
570+
t.Run("WorkspaceBuild", func(t *testing.T) {
571+
t.Parallel()
572+
srv := setup(t)
573+
workspace, err := srv.Database.InsertWorkspace(ctx, database.InsertWorkspaceParams{
574+
ID: uuid.New(),
575+
})
576+
require.NoError(t, err)
577+
build, err := srv.Database.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{
578+
ID: uuid.New(),
579+
WorkspaceID: workspace.ID,
580+
Transition: database.WorkspaceTransitionDelete,
581+
})
582+
require.NoError(t, err)
583+
input, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{
584+
WorkspaceBuildID: build.ID,
585+
})
586+
require.NoError(t, err)
587+
job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
588+
ID: uuid.New(),
589+
Provisioner: database.ProvisionerTypeEcho,
590+
Input: input,
591+
})
592+
require.NoError(t, err)
593+
_, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
594+
WorkerID: uuid.NullUUID{
595+
UUID: srv.ID,
596+
Valid: true,
597+
},
598+
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
599+
})
600+
require.NoError(t, err)
601+
602+
publishedWorkspace := make(chan struct{})
603+
closeWorkspaceSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) {
604+
close(publishedWorkspace)
605+
})
606+
require.NoError(t, err)
607+
defer closeWorkspaceSubscribe()
608+
publishedLogs := make(chan struct{})
609+
closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
610+
close(publishedLogs)
611+
})
612+
require.NoError(t, err)
613+
defer closeLogsSubscribe()
614+
615+
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
616+
JobId: job.ID.String(),
617+
Type: &proto.CompletedJob_WorkspaceBuild_{
618+
WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{
619+
State: []byte{},
620+
Resources: []*sdkproto.Resource{{
621+
Name: "example",
622+
Type: "aws_instance",
623+
}},
624+
},
625+
},
626+
})
627+
require.NoError(t, err)
628+
629+
<-publishedWorkspace
630+
<-publishedLogs
631+
632+
workspace, err = srv.Database.GetWorkspaceByID(ctx, workspace.ID)
633+
require.NoError(t, err)
634+
require.True(t, workspace.Deleted)
635+
})
636+
637+
t.Run("TemplateDryRun", func(t *testing.T) {
638+
t.Parallel()
639+
srv := setup(t)
640+
job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
641+
ID: uuid.New(),
642+
Provisioner: database.ProvisionerTypeEcho,
643+
})
644+
require.NoError(t, err)
645+
_, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
646+
WorkerID: uuid.NullUUID{
647+
UUID: srv.ID,
648+
Valid: true,
649+
},
650+
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
651+
})
652+
require.NoError(t, err)
653+
654+
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
655+
JobId: job.ID.String(),
656+
Type: &proto.CompletedJob_TemplateDryRun_{
657+
TemplateDryRun: &proto.CompletedJob_TemplateDryRun{
658+
Resources: []*sdkproto.Resource{{
659+
Name: "something",
660+
Type: "aws_instance",
661+
}},
662+
},
663+
},
664+
})
665+
require.NoError(t, err)
666+
})
667+
}
668+
669+
func TestInsertWorkspaceResource(t *testing.T) {
670+
t.Parallel()
671+
ctx := context.Background()
672+
insert := func(db database.Store, jobID uuid.UUID, resource *sdkproto.Resource) error {
673+
return provisionerdserver.InsertWorkspaceResource(ctx, db, jobID, database.WorkspaceTransitionStart, resource, &telemetry.Snapshot{})
674+
}
675+
t.Run("NoAgents", func(t *testing.T) {
676+
t.Parallel()
677+
db, _ := dbtestutil.NewDB(t)
678+
job := uuid.New()
679+
err := insert(db, job, &sdkproto.Resource{
680+
Name: "something",
681+
Type: "aws_instance",
682+
})
683+
require.NoError(t, err)
684+
resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
685+
require.NoError(t, err)
686+
require.Len(t, resources, 1)
687+
})
688+
t.Run("InvalidAgentToken", func(t *testing.T) {
689+
t.Parallel()
690+
err := insert(databasefake.New(), uuid.New(), &sdkproto.Resource{
691+
Name: "something",
692+
Type: "aws_instance",
693+
Agents: []*sdkproto.Agent{{
694+
Auth: &sdkproto.Agent_Token{
695+
Token: "bananas",
696+
},
697+
}},
698+
})
699+
require.ErrorContains(t, err, "invalid UUID length")
700+
})
701+
t.Run("DuplicateApps", func(t *testing.T) {
702+
t.Parallel()
703+
err := insert(databasefake.New(), uuid.New(), &sdkproto.Resource{
704+
Name: "something",
705+
Type: "aws_instance",
706+
Agents: []*sdkproto.Agent{{
707+
Apps: []*sdkproto.App{{
708+
Slug: "a",
709+
}, {
710+
Slug: "a",
711+
}},
712+
}},
713+
})
714+
require.ErrorContains(t, err, "duplicate app slug")
715+
})
716+
t.Run("Success", func(t *testing.T) {
717+
t.Parallel()
718+
db, _ := dbtestutil.NewDB(t)
719+
job := uuid.New()
720+
err := insert(db, job, &sdkproto.Resource{
721+
Name: "something",
722+
Type: "aws_instance",
723+
Agents: []*sdkproto.Agent{{
724+
Name: "dev",
725+
Env: map[string]string{
726+
"something": "test",
727+
},
728+
StartupScript: "value",
729+
OperatingSystem: "linux",
730+
Architecture: "amd64",
731+
Auth: &sdkproto.Agent_Token{
732+
Token: uuid.NewString(),
733+
},
734+
Apps: []*sdkproto.App{{
735+
Slug: "a",
736+
}},
737+
}},
738+
})
739+
require.NoError(t, err)
740+
resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
741+
require.NoError(t, err)
742+
require.Len(t, resources, 1)
743+
agents, err := db.GetWorkspaceAgentsByResourceIDs(ctx, []uuid.UUID{resources[0].ID})
744+
require.NoError(t, err)
745+
require.Len(t, agents, 1)
746+
agent := agents[0]
747+
require.Equal(t, "amd64", agent.Architecture)
748+
require.Equal(t, "linux", agent.OperatingSystem)
749+
require.Equal(t, "value", agent.StartupScript.String)
750+
want, err := json.Marshal(map[string]string{
751+
"something": "test",
752+
})
753+
require.NoError(t, err)
754+
got, err := agent.EnvironmentVariables.RawMessage.MarshalJSON()
755+
require.NoError(t, err)
756+
require.Equal(t, want, got)
757+
})
758+
}
759+
501760
func setup(t *testing.T) *provisionerdserver.Server {
502761
t.Helper()
503762
db, pubsub := dbtestutil.NewDB(t)

0 commit comments

Comments
 (0)