Skip to content

Commit 5804201

Browse files
committed
add tests for agent reinitialization
1 parent 9bbd2c7 commit 5804201

File tree

6 files changed

+231
-88
lines changed

6 files changed

+231
-88
lines changed

coderd/prebuilds/claim.go

Lines changed: 6 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,16 @@ package prebuilds
22

33
import (
44
"context"
5-
"net/http"
65
"sync"
76

87
"github.com/google/uuid"
98
"golang.org/x/xerrors"
109

1110
"cdr.dev/slog"
1211
"github.com/coder/coder/v2/coderd/database/pubsub"
13-
"github.com/coder/coder/v2/coderd/httpapi"
14-
"github.com/coder/coder/v2/codersdk"
1512
"github.com/coder/coder/v2/codersdk/agentsdk"
1613
)
1714

18-
type WorkspaceClaimPublisher interface {
19-
PublishWorkspaceClaim(agentsdk.ReinitializationEvent)
20-
}
21-
2215
func NewPubsubWorkspaceClaimPublisher(ps pubsub.Pubsub) *PubsubWorkspaceClaimPublisher {
2316
return &PubsubWorkspaceClaimPublisher{ps: ps}
2417
}
@@ -35,10 +28,6 @@ func (p PubsubWorkspaceClaimPublisher) PublishWorkspaceClaim(claim agentsdk.Rein
3528
return nil
3629
}
3730

38-
type WorkspaceClaimListener interface {
39-
ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID) (func(), <-chan agentsdk.ReinitializationEvent, error)
40-
}
41-
4231
func NewPubsubWorkspaceClaimListener(ps pubsub.Pubsub, logger slog.Logger) *PubsubWorkspaceClaimListener {
4332
return &PubsubWorkspaceClaimListener{ps: ps, logger: logger}
4433
}
@@ -49,6 +38,12 @@ type PubsubWorkspaceClaimListener struct {
4938
}
5039

5140
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID) (func(), <-chan agentsdk.ReinitializationEvent, error) {
41+
select {
42+
case <-ctx.Done():
43+
return func() {}, nil, ctx.Err()
44+
default:
45+
}
46+
5247
workspaceClaims := make(chan agentsdk.ReinitializationEvent, 1)
5348
cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, id []byte) {
5449
claimantID, err := uuid.ParseBytes(id)
@@ -91,52 +86,3 @@ func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Conte
9186

9287
return cancel, workspaceClaims, nil
9388
}
94-
95-
func StreamAgentReinitEvents(ctx context.Context, logger slog.Logger, rw http.ResponseWriter, r *http.Request, reinitEvents <-chan agentsdk.ReinitializationEvent) {
96-
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r)
97-
if err != nil {
98-
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
99-
Message: "Internal error setting up server-sent events.",
100-
Detail: err.Error(),
101-
})
102-
return
103-
}
104-
// Prevent handler from returning until the sender is closed.
105-
defer func() {
106-
<-sseSenderClosed
107-
}()
108-
109-
// An initial ping signals to the requester that the server is now ready
110-
// and the client can begin servicing a channel with data.
111-
_ = sseSendEvent(codersdk.ServerSentEvent{
112-
Type: codersdk.ServerSentEventTypePing,
113-
})
114-
115-
for {
116-
select {
117-
case <-ctx.Done():
118-
return
119-
case reinitEvent := <-reinitEvents:
120-
err = sseSendEvent(codersdk.ServerSentEvent{
121-
Type: codersdk.ServerSentEventTypeData,
122-
Data: reinitEvent,
123-
})
124-
if err != nil {
125-
logger.Warn(ctx, "failed to send SSE response to trigger reinit", slog.Error(err))
126-
}
127-
}
128-
}
129-
}
130-
131-
type MockClaimCoordinator interface{}
132-
133-
type ClaimListener interface{}
134-
type PostgresClaimListener struct{}
135-
136-
type AgentReinitializer interface{}
137-
type SSEAgentReinitializer struct{}
138-
139-
type ClaimCoordinator interface {
140-
ClaimListener
141-
AgentReinitializer
142-
}

coderd/prebuilds/claim_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
7979
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
8080

8181
ctx, cancel := context.WithCancel(context.Background())
82-
cancel()
8382

8483
cancelFunc, claims, err := listener.ListenForWorkspaceClaims(ctx, uuid.New())
8584
require.NoError(t, err)
8685
defer cancelFunc()
8786

87+
cancel()
8888
// Channel should be closed immediately due to context cancellation
8989
select {
9090
case _, ok := <-claims:

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import (
4444
"github.com/coder/coder/v2/coderd/tracing"
4545
"github.com/coder/coder/v2/coderd/wspubsub"
4646
"github.com/coder/coder/v2/codersdk"
47+
"github.com/coder/coder/v2/codersdk/agentsdk"
4748
"github.com/coder/coder/v2/codersdk/drpc"
4849
"github.com/coder/coder/v2/provisioner"
4950
"github.com/coder/coder/v2/provisionerd/proto"
@@ -1749,7 +1750,12 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
17491750
slog.F("user", input.PrebuildClaimedByUser.String()),
17501751
slog.F("workspace_id", workspace.ID))
17511752

1752-
if err := prebuilds.PublishWorkspaceClaim(ctx, s.Pubsub, workspace.ID, input.PrebuildClaimedByUser); err != nil {
1753+
err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
1754+
UserID: input.PrebuildClaimedByUser,
1755+
WorkspaceID: workspace.ID,
1756+
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
1757+
})
1758+
if err != nil {
17531759
s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
17541760
}
17551761
}

coderd/workspaceagents.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,14 +1181,24 @@ func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
11811181

11821182
log.Info(ctx, "agent waiting for reinit instruction")
11831183

1184-
cancel, reinitEvents, err := prebuilds.ListenForWorkspaceClaims(ctx, log, api.Pubsub, workspace.ID)
1184+
cancel, reinitEvents, err := prebuilds.NewPubsubWorkspaceClaimListener(api.Pubsub, log).ListenForWorkspaceClaims(ctx, workspace.ID)
11851185
if err != nil {
11861186
log.Error(ctx, "failed to subscribe to prebuild claimed channel", slog.Error(err))
11871187
httpapi.InternalServerError(rw, xerrors.New("failed to subscribe to prebuild claimed channel"))
11881188
return
11891189
}
11901190
defer cancel()
1191-
prebuilds.StreamAgentReinitEvents(ctx, log, rw, r, reinitEvents)
1191+
1192+
transmitter := agentsdk.NewSSEAgentReinitTransmitter(log, rw, r)
1193+
1194+
err = transmitter.Transmit(ctx, reinitEvents)
1195+
if err != nil {
1196+
log.Error(ctx, "failed to stream agent reinit events", slog.Error(err))
1197+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
1198+
Message: "Internal error streaming agent reinitialization events.",
1199+
Detail: err.Error(),
1200+
})
1201+
}
11921202
}
11931203

11941204
// convertProvisionedApps converts applications that are in the middle of provisioning process.

codersdk/agentsdk/agentsdk.go

Lines changed: 80 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424

2525
"github.com/coder/coder/v2/agent/proto"
2626
"github.com/coder/coder/v2/apiversion"
27+
"github.com/coder/coder/v2/coderd/httpapi"
2728
"github.com/coder/coder/v2/codersdk"
2829
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
2930
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
@@ -730,8 +731,86 @@ func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, err
730731
return nil, codersdk.ReadBodyAsError(res)
731732
}
732733

733-
nextEvent := codersdk.ServerSentEventReader(ctx, res.Body)
734+
reinitEvent, err := NewSSEAgentReinitReceiver(res.Body).Receive(ctx)
735+
if err != nil {
736+
return nil, xerrors.Errorf("listening for reinitialization events: %w", err)
737+
}
738+
return reinitEvent, nil
739+
}
740+
741+
func WaitForReinitLoop(ctx context.Context, logger slog.Logger, client *Client) <-chan ReinitializationEvent {
742+
reinitEvents := make(chan ReinitializationEvent)
743+
744+
go func() {
745+
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
746+
logger.Debug(ctx, "waiting for agent reinitialization instructions")
747+
reinitEvent, err := client.WaitForReinit(ctx)
748+
if err != nil {
749+
logger.Error(ctx, "failed to wait for agent reinitialization instructions", slog.Error(err))
750+
continue
751+
}
752+
select {
753+
case <-ctx.Done():
754+
close(reinitEvents)
755+
return
756+
case reinitEvents <- *reinitEvent:
757+
}
758+
}
759+
}()
760+
761+
return reinitEvents
762+
}
763+
764+
func NewSSEAgentReinitTransmitter(logger slog.Logger, rw http.ResponseWriter, r *http.Request) *SSEAgentReinitTransmitter {
765+
return &SSEAgentReinitTransmitter{logger: logger, rw: rw, r: r}
766+
}
767+
768+
type SSEAgentReinitTransmitter struct {
769+
rw http.ResponseWriter
770+
r *http.Request
771+
logger slog.Logger
772+
}
773+
774+
func (s *SSEAgentReinitTransmitter) Transmit(ctx context.Context, reinitEvents <-chan ReinitializationEvent) error {
775+
select {
776+
case <-ctx.Done():
777+
return ctx.Err()
778+
default:
779+
}
780+
781+
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(s.rw, s.r)
782+
if err != nil {
783+
return xerrors.Errorf("failed to create sse transmitter: %w", err)
784+
}
734785

786+
for {
787+
select {
788+
case <-ctx.Done():
789+
return ctx.Err()
790+
case <-sseSenderClosed:
791+
return xerrors.New("sse connection closed")
792+
case reinitEvent := <-reinitEvents:
793+
err := sseSendEvent(codersdk.ServerSentEvent{
794+
Type: codersdk.ServerSentEventTypeData,
795+
Data: reinitEvent,
796+
})
797+
if err != nil {
798+
s.logger.Warn(ctx, "failed to send SSE response to trigger reinit", slog.Error(err))
799+
}
800+
}
801+
}
802+
}
803+
804+
func NewSSEAgentReinitReceiver(r io.ReadCloser) *SSEAgentReinitReceiver {
805+
return &SSEAgentReinitReceiver{r: r}
806+
}
807+
808+
type SSEAgentReinitReceiver struct {
809+
r io.ReadCloser
810+
}
811+
812+
func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*ReinitializationEvent, error) {
813+
nextEvent := codersdk.ServerSentEventReader(ctx, s.r)
735814
for {
736815
select {
737816
case <-ctx.Done():
@@ -763,26 +842,3 @@ func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, err
763842
}
764843
}
765844
}
766-
767-
func WaitForReinitLoop(ctx context.Context, logger slog.Logger, client *Client) <-chan ReinitializationEvent {
768-
reinitEvents := make(chan ReinitializationEvent)
769-
770-
go func() {
771-
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
772-
logger.Debug(ctx, "waiting for agent reinitialization instructions")
773-
reinitEvent, err := client.WaitForReinit(ctx)
774-
if err != nil {
775-
logger.Error(ctx, "failed to wait for agent reinitialization instructions", slog.Error(err))
776-
continue
777-
}
778-
select {
779-
case <-ctx.Done():
780-
close(reinitEvents)
781-
return
782-
case reinitEvents <- *reinitEvent:
783-
}
784-
}
785-
}()
786-
787-
return reinitEvents
788-
}

0 commit comments

Comments
 (0)