Skip to content

Commit 6ad169d

Browse files
committed
update provisionerd test
1 parent 905cd1a commit 6ad169d

File tree

2 files changed

+89
-31
lines changed

2 files changed

+89
-31
lines changed

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,18 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
195195
}
196196
}
197197

198-
sessionToken, err := server.regenerateSessionToken(ctx, owner, workspace)
199-
if err != nil {
200-
return nil, failJob(fmt.Sprintf("regenerate session token: %s", err))
198+
var sessionToken string
199+
switch workspaceBuild.Transition {
200+
case database.WorkspaceTransitionStart:
201+
sessionToken, err = server.regenerateSessionToken(ctx, owner, workspace)
202+
if err != nil {
203+
return nil, failJob(fmt.Sprintf("regenerate session token: %s", err))
204+
}
205+
case database.WorkspaceTransitionStop, database.WorkspaceTransitionDelete:
206+
err = server.deleteSessionToken(ctx, workspace)
207+
if err != nil {
208+
return nil, failJob(fmt.Sprintf("delete session token: %s", err))
209+
}
201210
}
202211

203212
// Compute parameters for the workspace to consume.
@@ -1434,35 +1443,35 @@ func (server *Server) regenerateSessionToken(ctx context.Context, user database.
14341443
return "", xerrors.Errorf("generate API key: %w", err)
14351444
}
14361445

1437-
err = server.Database.InTx(
1438-
func(tx database.Store) error {
1439-
key, err := tx.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{
1440-
UserID: workspace.OwnerID,
1441-
TokenName: workspaceSessionTokenName(workspace),
1442-
})
1443-
if err == nil {
1444-
err = tx.DeleteAPIKeyByID(ctx, key.ID)
1445-
if err != nil {
1446-
return xerrors.Errorf("delete api key: %w", err)
1447-
}
1448-
}
1449-
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
1450-
return xerrors.Errorf("get api key by name: %w", err)
1451-
}
1452-
1453-
_, err = tx.InsertAPIKey(ctx, newkey)
1454-
if err != nil {
1455-
return xerrors.Errorf("insert API key: %w", err)
1456-
}
1446+
err = server.deleteSessionToken(ctx, workspace)
1447+
if err != nil {
1448+
return "", xerrors.Errorf("delete session token: %w", err)
1449+
}
14571450

1458-
return nil
1459-
}, nil)
1451+
_, err = server.Database.InsertAPIKey(ctx, newkey)
14601452
if err != nil {
1461-
return "", xerrors.Errorf("regenerate API key: %w", err)
1453+
return "", xerrors.Errorf("insert API key: %w", err)
14621454
}
1455+
14631456
return secret, nil
14641457
}
14651458

1459+
func (server *Server) deleteSessionToken(ctx context.Context, workspace database.Workspace) error {
1460+
key, err := server.Database.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{
1461+
UserID: workspace.OwnerID,
1462+
TokenName: workspaceSessionTokenName(workspace),
1463+
})
1464+
if err == nil {
1465+
err = server.Database.DeleteAPIKeyByID(ctx, key.ID)
1466+
}
1467+
1468+
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
1469+
return xerrors.Errorf("get api key by name: %w", err)
1470+
}
1471+
1472+
return nil
1473+
}
1474+
14661475
// obtainOIDCAccessToken returns a valid OpenID Connect access token
14671476
// for the user if it's able to obtain one, otherwise it returns an empty string.
14681477
func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig httpmw.OAuth2Config, userID uuid.UUID) (string, error) {

coderd/provisionerdserver/provisionerdserver_test.go

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,16 @@ func TestAcquireJob(t *testing.T) {
199199
})),
200200
})
201201

202-
published := make(chan struct{})
203-
closeSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
204-
close(published)
202+
startPublished := make(chan struct{})
203+
var closed bool
204+
closeStartSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
205+
if !closed {
206+
close(startPublished)
207+
closed = true
208+
}
205209
})
206210
require.NoError(t, err)
207-
defer closeSubscribe()
211+
defer closeStartSubscribe()
208212

209213
var job *proto.AcquiredJob
210214

@@ -218,7 +222,7 @@ func TestAcquireJob(t *testing.T) {
218222
}
219223
}
220224

221-
<-published
225+
<-startPublished
222226

223227
got, err := json.Marshal(job.Type)
224228
require.NoError(t, err)
@@ -271,7 +275,52 @@ func TestAcquireJob(t *testing.T) {
271275
require.NoError(t, err)
272276

273277
require.JSONEq(t, string(want), string(got))
278+
279+
// Assert that we delete the session token whenever
280+
// a stop is issued.
281+
stopbuild := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{
282+
WorkspaceID: workspace.ID,
283+
BuildNumber: 2,
284+
JobID: uuid.New(),
285+
TemplateVersionID: version.ID,
286+
Transition: database.WorkspaceTransitionStop,
287+
Reason: database.BuildReasonInitiator,
288+
})
289+
_ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{
290+
ID: stopbuild.ID,
291+
InitiatorID: user.ID,
292+
Provisioner: database.ProvisionerTypeEcho,
293+
StorageMethod: database.ProvisionerStorageMethodFile,
294+
FileID: file.ID,
295+
Type: database.ProvisionerJobTypeWorkspaceBuild,
296+
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
297+
WorkspaceBuildID: stopbuild.ID,
298+
})),
299+
})
300+
301+
stopPublished := make(chan struct{})
302+
closeStopSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
303+
close(stopPublished)
304+
})
305+
require.NoError(t, err)
306+
defer closeStopSubscribe()
307+
308+
// Grab jobs until we find the workspace build job. There is also
309+
// an import version job that we need to ignore.
310+
job, err = srv.AcquireJob(ctx, nil)
311+
require.NoError(t, err)
312+
_, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_)
313+
require.True(t, ok, "acquired job not a workspace build?")
314+
315+
<-stopPublished
316+
317+
// Validate that a session token is deleted during a stop job.
318+
sessionToken = job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.CoderSessionToken
319+
require.Empty(t, sessionToken)
320+
_, err = srv.Database.GetAPIKeyByID(ctx, key.ID)
321+
require.ErrorIs(t, err, sql.ErrNoRows)
274322
})
323+
275324
t.Run("TemplateVersionDryRun", func(t *testing.T) {
276325
t.Parallel()
277326
srv := setup(t, false)

0 commit comments

Comments
 (0)