Skip to content

Commit 9bbd2c7

Browse files
committed
test workspace claim pubsub
1 parent a22b414 commit 9bbd2c7

File tree

4 files changed

+256
-27
lines changed

4 files changed

+256
-27
lines changed

cli/agent.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
331331
}
332332

333333
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
334+
334335
var (
335336
lastErr error
336337
mustExit bool
@@ -379,7 +380,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
379380
mustExit = true
380381
case event := <-reinitEvents:
381382
logger.Warn(ctx, "agent received instruction to reinitialize",
382-
slog.F("message", event.Message), slog.F("reason", event.Reason))
383+
slog.F("user_id", event.UserID), slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
383384
}
384385

385386
lastErr = agnt.Close()

coderd/prebuilds/claim.go

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ package prebuilds
22

33
import (
44
"context"
5-
"fmt"
65
"net/http"
6+
"sync"
77

88
"github.com/google/uuid"
99
"golang.org/x/xerrors"
@@ -15,41 +15,81 @@ import (
1515
"github.com/coder/coder/v2/codersdk/agentsdk"
1616
)
1717

18-
func PublishWorkspaceClaim(ctx context.Context, ps pubsub.Pubsub, workspaceID, userID uuid.UUID) error {
19-
channel := agentsdk.PrebuildClaimedChannel(workspaceID)
20-
if err := ps.Publish(channel, []byte(userID.String())); err != nil {
18+
type WorkspaceClaimPublisher interface {
19+
PublishWorkspaceClaim(agentsdk.ReinitializationEvent)
20+
}
21+
22+
func NewPubsubWorkspaceClaimPublisher(ps pubsub.Pubsub) *PubsubWorkspaceClaimPublisher {
23+
return &PubsubWorkspaceClaimPublisher{ps: ps}
24+
}
25+
26+
type PubsubWorkspaceClaimPublisher struct {
27+
ps pubsub.Pubsub
28+
}
29+
30+
func (p PubsubWorkspaceClaimPublisher) PublishWorkspaceClaim(claim agentsdk.ReinitializationEvent) error {
31+
channel := agentsdk.PrebuildClaimedChannel(claim.WorkspaceID)
32+
if err := p.ps.Publish(channel, []byte(claim.UserID.String())); err != nil {
2133
return xerrors.Errorf("failed to trigger prebuilt workspace agent reinitialization: %w", err)
2234
}
2335
return nil
2436
}
2537

26-
func ListenForWorkspaceClaims(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, workspaceID uuid.UUID) (func(), <-chan agentsdk.ReinitializationEvent, error) {
27-
reinitEvents := make(chan agentsdk.ReinitializationEvent, 1)
28-
cancelSub, err := ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, id []byte) {
38+
type WorkspaceClaimListener interface {
39+
ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID) (func(), <-chan agentsdk.ReinitializationEvent, error)
40+
}
41+
42+
func NewPubsubWorkspaceClaimListener(ps pubsub.Pubsub, logger slog.Logger) *PubsubWorkspaceClaimListener {
43+
return &PubsubWorkspaceClaimListener{ps: ps, logger: logger}
44+
}
45+
46+
type PubsubWorkspaceClaimListener struct {
47+
logger slog.Logger
48+
ps pubsub.Pubsub
49+
}
50+
51+
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID) (func(), <-chan agentsdk.ReinitializationEvent, error) {
52+
workspaceClaims := make(chan agentsdk.ReinitializationEvent, 1)
53+
cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, id []byte) {
54+
claimantID, err := uuid.ParseBytes(id)
55+
if err != nil {
56+
p.logger.Error(ctx, "invalid prebuild claimed channel payload", slog.F("input", string(id)))
57+
return
58+
}
59+
claim := agentsdk.ReinitializationEvent{
60+
UserID: claimantID,
61+
WorkspaceID: workspaceID,
62+
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
63+
}
2964
select {
3065
case <-ctx.Done():
3166
return
3267
case <-inner.Done():
3368
return
69+
case workspaceClaims <- claim:
3470
default:
3571
}
36-
37-
claimantID, err := uuid.ParseBytes(id)
38-
if err != nil {
39-
logger.Error(ctx, "invalid prebuild claimed channel payload", slog.F("input", string(id)))
40-
return
41-
}
42-
// TODO: turn this into a <- uuid.UUID
43-
reinitEvents <- agentsdk.ReinitializationEvent{
44-
Message: fmt.Sprintf("prebuild claimed by user: %s", claimantID),
45-
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
46-
}
4772
})
73+
4874
if err != nil {
75+
close(workspaceClaims)
4976
return func() {}, nil, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
5077
}
51-
defer cancelSub()
52-
return func() { cancelSub() }, reinitEvents, nil
78+
79+
var once sync.Once
80+
cancel := func() {
81+
once.Do(func() {
82+
cancelSub()
83+
close(workspaceClaims)
84+
})
85+
}
86+
87+
go func() {
88+
<-ctx.Done()
89+
cancel()
90+
}()
91+
92+
return cancel, workspaceClaims, nil
5393
}
5494

5595
func StreamAgentReinitEvents(ctx context.Context, logger slog.Logger, rw http.ResponseWriter, r *http.Request, reinitEvents <-chan agentsdk.ReinitializationEvent) {

coderd/prebuilds/claim_test.go

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package prebuilds_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
"golang.org/x/xerrors"
12+
13+
"cdr.dev/slog/sloggers/slogtest"
14+
"github.com/coder/coder/v2/coderd/database/pubsub"
15+
"github.com/coder/coder/v2/coderd/prebuilds"
16+
"github.com/coder/coder/v2/codersdk/agentsdk"
17+
"github.com/coder/coder/v2/testutil"
18+
)
19+
20+
func TestPubsubWorkspaceClaimPublisher(t *testing.T) {
21+
t.Parallel()
22+
t.Run("publish claim", func(t *testing.T) {
23+
t.Parallel()
24+
25+
ps := pubsub.NewInMemory()
26+
publisher := prebuilds.NewPubsubWorkspaceClaimPublisher(ps)
27+
28+
workspaceID := uuid.New()
29+
userID := uuid.New()
30+
31+
userIDCh := make(chan uuid.UUID, 1)
32+
channel := agentsdk.PrebuildClaimedChannel(workspaceID)
33+
cancel, err := ps.Subscribe(channel, func(ctx context.Context, message []byte) {
34+
userIDCh <- uuid.MustParse(string(message))
35+
})
36+
require.NoError(t, err)
37+
defer cancel()
38+
39+
claim := agentsdk.ReinitializationEvent{
40+
UserID: userID,
41+
WorkspaceID: workspaceID,
42+
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
43+
}
44+
err = publisher.PublishWorkspaceClaim(claim)
45+
require.NoError(t, err)
46+
47+
// Verify the message was published
48+
select {
49+
case gotUserID := <-userIDCh:
50+
require.Equal(t, userID, gotUserID)
51+
case <-time.After(testutil.WaitShort):
52+
t.Fatal("timeout waiting for claim")
53+
}
54+
})
55+
56+
t.Run("fail to publish claim", func(t *testing.T) {
57+
t.Parallel()
58+
59+
ps := &brokenPubsub{}
60+
61+
publisher := prebuilds.NewPubsubWorkspaceClaimPublisher(ps)
62+
claim := agentsdk.ReinitializationEvent{
63+
UserID: uuid.New(),
64+
WorkspaceID: uuid.New(),
65+
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
66+
}
67+
68+
err := publisher.PublishWorkspaceClaim(claim)
69+
require.Error(t, err)
70+
})
71+
}
72+
73+
func TestPubsubWorkspaceClaimListener(t *testing.T) {
74+
t.Parallel()
75+
t.Run("stops listening if context canceled", func(t *testing.T) {
76+
t.Parallel()
77+
78+
ps := pubsub.NewInMemory()
79+
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
80+
81+
ctx, cancel := context.WithCancel(context.Background())
82+
cancel()
83+
84+
cancelFunc, claims, err := listener.ListenForWorkspaceClaims(ctx, uuid.New())
85+
require.NoError(t, err)
86+
defer cancelFunc()
87+
88+
// Channel should be closed immediately due to context cancellation
89+
select {
90+
case _, ok := <-claims:
91+
assert.False(t, ok)
92+
case <-time.After(testutil.WaitShort):
93+
t.Fatal("timeout waiting for closed channel")
94+
}
95+
})
96+
97+
t.Run("stops listening if cancel func is called", func(t *testing.T) {
98+
t.Parallel()
99+
100+
ps := pubsub.NewInMemory()
101+
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
102+
103+
cancelFunc, claims, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New())
104+
require.NoError(t, err)
105+
106+
cancelFunc()
107+
select {
108+
case _, ok := <-claims:
109+
assert.False(t, ok)
110+
case <-time.After(testutil.WaitShort):
111+
t.Fatal("timeout waiting for closed channel")
112+
}
113+
})
114+
115+
t.Run("finds claim events for its workspace", func(t *testing.T) {
116+
t.Parallel()
117+
118+
ps := pubsub.NewInMemory()
119+
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
120+
121+
workspaceID := uuid.New()
122+
userID := uuid.New()
123+
cancelFunc, claims, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID)
124+
require.NoError(t, err)
125+
defer cancelFunc()
126+
127+
// Publish a claim
128+
channel := agentsdk.PrebuildClaimedChannel(workspaceID)
129+
err = ps.Publish(channel, []byte(userID.String()))
130+
require.NoError(t, err)
131+
132+
// Verify we receive the claim
133+
select {
134+
case claim := <-claims:
135+
assert.Equal(t, userID, claim.UserID)
136+
assert.Equal(t, workspaceID, claim.WorkspaceID)
137+
assert.Equal(t, agentsdk.ReinitializeReasonPrebuildClaimed, claim.Reason)
138+
case <-time.After(time.Second):
139+
t.Fatal("timeout waiting for claim")
140+
}
141+
})
142+
143+
t.Run("ignores claim events for other workspaces", func(t *testing.T) {
144+
t.Parallel()
145+
146+
ps := pubsub.NewInMemory()
147+
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
148+
149+
workspaceID := uuid.New()
150+
otherWorkspaceID := uuid.New()
151+
cancelFunc, claims, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID)
152+
require.NoError(t, err)
153+
defer cancelFunc()
154+
155+
// Publish a claim for a different workspace
156+
channel := agentsdk.PrebuildClaimedChannel(otherWorkspaceID)
157+
err = ps.Publish(channel, []byte(uuid.New().String()))
158+
require.NoError(t, err)
159+
160+
// Verify we don't receive the claim
161+
select {
162+
case <-claims:
163+
t.Fatal("received claim for wrong workspace")
164+
case <-time.After(100 * time.Millisecond):
165+
// Expected - no claim received
166+
}
167+
})
168+
169+
t.Run("communicates the error if it can't subscribe", func(t *testing.T) {
170+
t.Parallel()
171+
172+
ps := &brokenPubsub{}
173+
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
174+
175+
_, _, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New())
176+
require.Error(t, err)
177+
assert.Contains(t, err.Error(), "failed to subscribe to prebuild claimed channel")
178+
})
179+
}
180+
181+
type brokenPubsub struct {
182+
pubsub.Pubsub
183+
}
184+
185+
func (brokenPubsub) Subscribe(_ string, _ pubsub.Listener) (func(), error) {
186+
return nil, xerrors.New("broken")
187+
}
188+
189+
func (brokenPubsub) Publish(_ string, _ []byte) error {
190+
return xerrors.New("broken")
191+
}

codersdk/agentsdk/agentsdk.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -696,8 +696,9 @@ const (
696696
)
697697

698698
type ReinitializationEvent struct {
699-
Message string `json:"message"`
700-
Reason ReinitializationReason `json:"reason"`
699+
WorkspaceID uuid.UUID
700+
UserID uuid.UUID
701+
Reason ReinitializationReason `json:"reason"`
701702
}
702703

703704
func PrebuildClaimedChannel(id uuid.UUID) string {
@@ -707,7 +708,6 @@ func PrebuildClaimedChannel(id uuid.UUID) string {
707708
// WaitForReinit polls a SSE endpoint, and receives an event back under the following conditions:
708709
// - ping: ignored, keepalive
709710
// - prebuild claimed: a prebuilt workspace is claimed, so the agent must reinitialize.
710-
// NOTE: the caller is responsible for closing the events chan.
711711
func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, error) {
712712
// TODO: allow configuring httpclient
713713
c.SDK.HTTPClient.Timeout = time.Hour * 24
@@ -733,9 +733,6 @@ func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, err
733733
nextEvent := codersdk.ServerSentEventReader(ctx, res.Body)
734734

735735
for {
736-
// TODO (Sasswart): I don't like that we do this select at the start and at the end.
737-
// nextEvent should return an error if the context is canceled, but that feels like a larger refactor.
738-
// if it did, we'd only have the select at the end of the loop.
739736
select {
740737
case <-ctx.Done():
741738
return nil, ctx.Err()

0 commit comments

Comments
 (0)