Skip to content

Commit a22b414

Browse files
committed
introduce unit testable abstraction layers
1 parent b117b5c commit a22b414

File tree

4 files changed

+110
-71
lines changed

4 files changed

+110
-71
lines changed

coderd/prebuilds/claim.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package prebuilds
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
8+
"github.com/google/uuid"
9+
"golang.org/x/xerrors"
10+
11+
"cdr.dev/slog"
12+
"github.com/coder/coder/v2/coderd/database/pubsub"
13+
"github.com/coder/coder/v2/coderd/httpapi"
14+
"github.com/coder/coder/v2/codersdk"
15+
"github.com/coder/coder/v2/codersdk/agentsdk"
16+
)
17+
18+
func PublishWorkspaceClaim(ctx context.Context, ps pubsub.Pubsub, workspaceID, userID uuid.UUID) error {
19+
channel := agentsdk.PrebuildClaimedChannel(workspaceID)
20+
if err := ps.Publish(channel, []byte(userID.String())); err != nil {
21+
return xerrors.Errorf("failed to trigger prebuilt workspace agent reinitialization: %w", err)
22+
}
23+
return nil
24+
}
25+
26+
func ListenForWorkspaceClaims(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, workspaceID uuid.UUID) (func(), <-chan agentsdk.ReinitializationEvent, error) {
27+
reinitEvents := make(chan agentsdk.ReinitializationEvent, 1)
28+
cancelSub, err := ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, id []byte) {
29+
select {
30+
case <-ctx.Done():
31+
return
32+
case <-inner.Done():
33+
return
34+
default:
35+
}
36+
37+
claimantID, err := uuid.ParseBytes(id)
38+
if err != nil {
39+
logger.Error(ctx, "invalid prebuild claimed channel payload", slog.F("input", string(id)))
40+
return
41+
}
42+
// TODO: turn this into a <- uuid.UUID
43+
reinitEvents <- agentsdk.ReinitializationEvent{
44+
Message: fmt.Sprintf("prebuild claimed by user: %s", claimantID),
45+
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
46+
}
47+
})
48+
if err != nil {
49+
return func() {}, nil, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
50+
}
51+
defer cancelSub()
52+
return func() { cancelSub() }, reinitEvents, nil
53+
}
54+
55+
func StreamAgentReinitEvents(ctx context.Context, logger slog.Logger, rw http.ResponseWriter, r *http.Request, reinitEvents <-chan agentsdk.ReinitializationEvent) {
56+
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r)
57+
if err != nil {
58+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
59+
Message: "Internal error setting up server-sent events.",
60+
Detail: err.Error(),
61+
})
62+
return
63+
}
64+
// Prevent handler from returning until the sender is closed.
65+
defer func() {
66+
<-sseSenderClosed
67+
}()
68+
69+
// An initial ping signals to the requester that the server is now ready
70+
// and the client can begin servicing a channel with data.
71+
_ = sseSendEvent(codersdk.ServerSentEvent{
72+
Type: codersdk.ServerSentEventTypePing,
73+
})
74+
75+
for {
76+
select {
77+
case <-ctx.Done():
78+
return
79+
case reinitEvent := <-reinitEvents:
80+
err = sseSendEvent(codersdk.ServerSentEvent{
81+
Type: codersdk.ServerSentEventTypeData,
82+
Data: reinitEvent,
83+
})
84+
if err != nil {
85+
logger.Warn(ctx, "failed to send SSE response to trigger reinit", slog.Error(err))
86+
}
87+
}
88+
}
89+
}
90+
91+
type MockClaimCoordinator interface{}
92+
93+
type ClaimListener interface{}
94+
type PostgresClaimListener struct{}
95+
96+
type AgentReinitializer interface{}
97+
type SSEAgentReinitializer struct{}
98+
99+
type ClaimCoordinator interface {
100+
ClaimListener
101+
AgentReinitializer
102+
}

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ import (
1616
"sync/atomic"
1717
"time"
1818

19-
"github.com/coder/coder/v2/codersdk/agentsdk"
20-
2119
"github.com/google/uuid"
2220
"github.com/sqlc-dev/pqtype"
2321
semconv "go.opentelemetry.io/otel/semconv/v1.14.0"
@@ -39,6 +37,7 @@ import (
3937
"github.com/coder/coder/v2/coderd/database/pubsub"
4038
"github.com/coder/coder/v2/coderd/externalauth"
4139
"github.com/coder/coder/v2/coderd/notifications"
40+
"github.com/coder/coder/v2/coderd/prebuilds"
4241
"github.com/coder/coder/v2/coderd/promoauth"
4342
"github.com/coder/coder/v2/coderd/schedule"
4443
"github.com/coder/coder/v2/coderd/telemetry"
@@ -1750,9 +1749,8 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
17501749
slog.F("user", input.PrebuildClaimedByUser.String()),
17511750
slog.F("workspace_id", workspace.ID))
17521751

1753-
channel := agentsdk.PrebuildClaimedChannel(workspace.ID)
1754-
if err := s.Pubsub.Publish(channel, []byte(input.PrebuildClaimedByUser.String())); err != nil {
1755-
s.Logger.Error(ctx, "failed to trigger prebuilt workspace agent reinitialization", slog.Error(err))
1752+
if err := prebuilds.PublishWorkspaceClaim(ctx, s.Pubsub, workspace.ID, input.PrebuildClaimedByUser); err != nil {
1753+
s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
17561754
}
17571755
}
17581756
case *proto.CompletedJob_TemplateDryRun_:

coderd/workspaceagents.go

Lines changed: 4 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"github.com/coder/coder/v2/coderd/httpmw"
3636
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
3737
"github.com/coder/coder/v2/coderd/jwtutils"
38+
"github.com/coder/coder/v2/coderd/prebuilds"
3839
"github.com/coder/coder/v2/coderd/rbac"
3940
"github.com/coder/coder/v2/coderd/rbac/policy"
4041
"github.com/coder/coder/v2/coderd/telemetry"
@@ -1180,76 +1181,14 @@ func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
11801181

11811182
log.Info(ctx, "agent waiting for reinit instruction")
11821183

1183-
prebuildClaims := make(chan uuid.UUID, 1)
1184-
cancelSub, err := api.Pubsub.Subscribe(agentsdk.PrebuildClaimedChannel(workspace.ID), func(inner context.Context, id []byte) {
1185-
select {
1186-
case <-ctx.Done():
1187-
return
1188-
case <-inner.Done():
1189-
return
1190-
default:
1191-
}
1192-
1193-
parsed, err := uuid.ParseBytes(id)
1194-
if err != nil {
1195-
log.Error(ctx, "invalid prebuild claimed channel payload", slog.F("input", string(id)))
1196-
return
1197-
}
1198-
prebuildClaims <- parsed
1199-
})
1184+
cancel, reinitEvents, err := prebuilds.ListenForWorkspaceClaims(ctx, log, api.Pubsub, workspace.ID)
12001185
if err != nil {
12011186
log.Error(ctx, "failed to subscribe to prebuild claimed channel", slog.Error(err))
12021187
httpapi.InternalServerError(rw, xerrors.New("failed to subscribe to prebuild claimed channel"))
12031188
return
12041189
}
1205-
defer cancelSub()
1206-
1207-
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r)
1208-
if err != nil {
1209-
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
1210-
Message: "Internal error setting up server-sent events.",
1211-
Detail: err.Error(),
1212-
})
1213-
return
1214-
}
1215-
// Prevent handler from returning until the sender is closed.
1216-
defer func() {
1217-
cancel()
1218-
<-sseSenderClosed
1219-
}()
1220-
// Synchronize cancellation from SSE -> context, this lets us simplify the
1221-
// cancellation logic.
1222-
go func() {
1223-
select {
1224-
case <-ctx.Done():
1225-
case <-sseSenderClosed:
1226-
cancel()
1227-
}
1228-
}()
1229-
1230-
// An initial ping signals to the request that the server is now ready
1231-
// and the client can begin servicing a channel with data.
1232-
_ = sseSendEvent(codersdk.ServerSentEvent{
1233-
Type: codersdk.ServerSentEventTypePing,
1234-
})
1235-
1236-
for {
1237-
select {
1238-
case <-ctx.Done():
1239-
return
1240-
case user := <-prebuildClaims:
1241-
err = sseSendEvent(codersdk.ServerSentEvent{
1242-
Type: codersdk.ServerSentEventTypeData,
1243-
Data: agentsdk.ReinitializationEvent{
1244-
Message: fmt.Sprintf("prebuild claimed by user: %s", user),
1245-
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
1246-
},
1247-
})
1248-
if err != nil {
1249-
log.Warn(ctx, "failed to send SSE response to trigger reinit", slog.Error(err))
1250-
}
1251-
}
1252-
}
1190+
defer cancel()
1191+
prebuilds.StreamAgentReinitEvents(ctx, log, rw, r, reinitEvents)
12531192
}
12541193

12551194
// convertProvisionedApps converts applications that are in the middle of provisioning process.

codersdk/agentsdk/agentsdk.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,8 +776,8 @@ func WaitForReinitLoop(ctx context.Context, logger slog.Logger, client *Client)
776776
reinitEvent, err := client.WaitForReinit(ctx)
777777
if err != nil {
778778
logger.Error(ctx, "failed to wait for agent reinitialization instructions", slog.Error(err))
779+
continue
779780
}
780-
reinitEvents <- *reinitEvent
781781
select {
782782
case <-ctx.Done():
783783
close(reinitEvents)

0 commit comments

Comments
 (0)