Skip to content

Commit 964bfe9

Browse files
committed
review
1 parent b11895f commit 964bfe9

File tree

10 files changed

+215
-96
lines changed

10 files changed

+215
-96
lines changed

cli/server.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -722,13 +722,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
722722
options.Database = dbmetrics.NewDBMetrics(options.Database, options.Logger, options.PrometheusRegistry)
723723
}
724724

725-
wsUpdates, err := coderd.NewUpdatesProvider(logger.Named("workspace_updates"), options.Pubsub, options.Database, options.Authorizer)
726-
if err != nil {
727-
return xerrors.Errorf("create workspace updates provider: %w", err)
728-
}
729-
options.WorkspaceUpdatesProvider = wsUpdates
730-
defer wsUpdates.Close()
731-
732725
var deploymentID string
733726
err = options.Database.InTx(func(tx database.Store) error {
734727
// This will block until the lock is acquired, and will be

coderd/coderd.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,13 @@ func New(options *Options) *API {
495495
}
496496
}
497497

498+
if options.WorkspaceUpdatesProvider == nil {
499+
options.WorkspaceUpdatesProvider, err = NewUpdatesProvider(options.Logger.Named("workspace_updates"), options.Pubsub, options.Database, options.Authorizer)
500+
if err != nil {
501+
options.Logger.Critical(ctx, "failed to properly instantiate workspace updates provider", slog.Error(err))
502+
}
503+
}
504+
498505
// Start a background process that rotates keys. We intentionally start this after the caches
499506
// are created to force initial requests for a key to populate the caches. This helps catch
500507
// bugs that may only occur when a key isn't precached in tests and the latency cost is minimal.
@@ -1495,6 +1502,7 @@ func (api *API) Close() error {
14951502
_ = api.OIDCConvertKeyCache.Close()
14961503
_ = api.AppSigningKeyCache.Close()
14971504
_ = api.AppEncryptionKeyCache.Close()
1505+
_ = api.WorkspaceUpdatesProvider.Close()
14981506
return nil
14991507
}
15001508

coderd/workspaceupdates.go

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"github.com/coder/coder/v2/coderd/database/dbauthz"
1515
"github.com/coder/coder/v2/coderd/database/pubsub"
1616
"github.com/coder/coder/v2/coderd/rbac"
17-
"github.com/coder/coder/v2/coderd/rbac/policy"
1817
"github.com/coder/coder/v2/coderd/util/slice"
1918
"github.com/coder/coder/v2/coderd/wspubsub"
2019
"github.com/coder/coder/v2/codersdk"
@@ -23,7 +22,8 @@ import (
2322
)
2423

2524
type UpdatesQuerier interface {
26-
GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID, prep rbac.PreparedAuthorized) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error)
25+
// GetAuthorizedWorkspacesAndAgentsByOwnerID requires a context with an actor set
26+
GetWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error)
2727
GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error)
2828
}
2929

@@ -45,11 +45,10 @@ type sub struct {
4545
ctx context.Context
4646
cancelFn context.CancelFunc
4747

48-
mu sync.RWMutex
49-
userID uuid.UUID
50-
ch chan *proto.WorkspaceUpdate
51-
prev workspacesByID
52-
readPrep rbac.PreparedAuthorized
48+
mu sync.RWMutex
49+
userID uuid.UUID
50+
ch chan *proto.WorkspaceUpdate
51+
prev workspacesByID
5352

5453
db UpdatesQuerier
5554
ps pubsub.Pubsub
@@ -76,7 +75,7 @@ func (s *sub) handleEvent(ctx context.Context, event wspubsub.WorkspaceEvent, er
7675
}
7776
}
7877

79-
rows, err := s.db.GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx, s.userID, s.readPrep)
78+
rows, err := s.db.GetWorkspacesAndAgentsByOwnerID(ctx, s.userID)
8079
if err != nil {
8180
s.logger.Warn(ctx, "failed to get workspaces and agents by owner ID", slog.Error(err))
8281
return
@@ -97,7 +96,7 @@ func (s *sub) handleEvent(ctx context.Context, event wspubsub.WorkspaceEvent, er
9796
}
9897

9998
func (s *sub) start(ctx context.Context) (err error) {
100-
rows, err := s.db.GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx, s.userID, s.readPrep)
99+
rows, err := s.db.GetWorkspacesAndAgentsByOwnerID(ctx, s.userID)
101100
if err != nil {
102101
return xerrors.Errorf("get workspaces and agents by owner ID: %w", err)
103102
}
@@ -168,17 +167,17 @@ func (u *updatesProvider) Close() error {
168167
return nil
169168
}
170169

170+
// Subscribe subscribes to workspace updates for a user, for the workspaces
171+
// that user is authorized to `ActionRead` on. The provided context must have
172+
// a dbauthz actor set.
171173
func (u *updatesProvider) Subscribe(ctx context.Context, userID uuid.UUID) (tailnet.Subscription, error) {
172174
actor, ok := dbauthz.ActorFromContext(ctx)
173175
if !ok {
174176
return nil, xerrors.Errorf("actor not found in context")
175177
}
176-
readPrep, err := u.auth.Prepare(ctx, actor, policy.ActionRead, rbac.ResourceWorkspace.Type)
177-
if err != nil {
178-
return nil, xerrors.Errorf("prepare read action: %w", err)
179-
}
178+
ctx, cancel := context.WithCancel(u.ctx)
179+
ctx = dbauthz.As(ctx, actor)
180180
ch := make(chan *proto.WorkspaceUpdate, 1)
181-
ctx, cancel := context.WithCancel(ctx)
182181
sub := &sub{
183182
ctx: ctx,
184183
cancelFn: cancel,
@@ -188,9 +187,8 @@ func (u *updatesProvider) Subscribe(ctx context.Context, userID uuid.UUID) (tail
188187
ps: u.ps,
189188
logger: u.logger.Named(fmt.Sprintf("workspace_updates_subscriber_%s", userID)),
190189
prev: workspacesByID{},
191-
readPrep: readPrep,
192190
}
193-
err = sub.start(ctx)
191+
err := sub.start(ctx)
194192
if err != nil {
195193
_ = sub.Close()
196194
return nil, err

coderd/workspaceupdates_test.go

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,23 @@ import (
2525

2626
func TestWorkspaceUpdates(t *testing.T) {
2727
t.Parallel()
28-
ctx := context.Background()
2928

30-
ws1ID := uuid.New()
29+
ws1ID := uuid.UUID{0x01}
3130
ws1IDSlice := tailnet.UUIDToByteSlice(ws1ID)
32-
agent1ID := uuid.New()
31+
agent1ID := uuid.UUID{0x02}
3332
agent1IDSlice := tailnet.UUIDToByteSlice(agent1ID)
34-
ws2ID := uuid.New()
33+
ws2ID := uuid.UUID{0x03}
3534
ws2IDSlice := tailnet.UUIDToByteSlice(ws2ID)
36-
ws3ID := uuid.New()
35+
ws3ID := uuid.UUID{0x04}
3736
ws3IDSlice := tailnet.UUIDToByteSlice(ws3ID)
38-
agent2ID := uuid.New()
37+
agent2ID := uuid.UUID{0x05}
3938
agent2IDSlice := tailnet.UUIDToByteSlice(agent2ID)
40-
ws4ID := uuid.New()
39+
ws4ID := uuid.UUID{0x06}
4140
ws4IDSlice := tailnet.UUIDToByteSlice(ws4ID)
41+
agent3ID := uuid.UUID{0x07}
42+
agent3IDSlice := tailnet.UUIDToByteSlice(agent3ID)
4243

43-
ownerID := uuid.New()
44+
ownerID := uuid.UUID{0x08}
4445
memberRole, err := rbac.RoleByName(rbac.RoleMember())
4546
require.NoError(t, err)
4647
ownerSubject := rbac.Subject{
@@ -53,9 +54,11 @@ func TestWorkspaceUpdates(t *testing.T) {
5354
t.Run("Basic", func(t *testing.T) {
5455
t.Parallel()
5556

57+
ctx := testutil.Context(t, testutil.WaitShort)
58+
5659
db := &mockWorkspaceStore{
5760
orderedRows: []database.GetWorkspacesAndAgentsByOwnerIDRow{
58-
// Gains a new agent
61+
// Gains agent2
5962
{
6063
ID: ws1ID,
6164
Name: "ws1",
@@ -81,6 +84,12 @@ func TestWorkspaceUpdates(t *testing.T) {
8184
Name: "ws3",
8285
JobStatus: database.ProvisionerJobStatusSucceeded,
8386
Transition: database.WorkspaceTransitionStop,
87+
Agents: []database.AgentIDNamePair{
88+
{
89+
ID: agent3ID,
90+
Name: "agent3",
91+
},
92+
},
8493
},
8594
},
8695
}
@@ -97,13 +106,15 @@ func TestWorkspaceUpdates(t *testing.T) {
97106

98107
sub, err := updateProvider.Subscribe(dbauthz.As(ctx, ownerSubject), ownerID)
99108
require.NoError(t, err)
100-
ch := sub.Updates()
109+
defer sub.Close()
101110

102-
update, ok := <-ch
103-
require.True(t, ok)
111+
update := testutil.RequireRecvCtx(ctx, t, sub.Updates())
104112
slices.SortFunc(update.UpsertedWorkspaces, func(a, b *proto.Workspace) int {
105113
return strings.Compare(a.Name, b.Name)
106114
})
115+
slices.SortFunc(update.UpsertedAgents, func(a, b *proto.Agent) int {
116+
return strings.Compare(a.Name, b.Name)
117+
})
107118
require.Equal(t, &proto.WorkspaceUpdate{
108119
UpsertedWorkspaces: []*proto.Workspace{
109120
{
@@ -128,6 +139,11 @@ func TestWorkspaceUpdates(t *testing.T) {
128139
Name: "agent1",
129140
WorkspaceId: ws1IDSlice,
130141
},
142+
{
143+
Id: agent3IDSlice,
144+
Name: "agent3",
145+
WorkspaceId: ws3IDSlice,
146+
},
131147
},
132148
DeletedWorkspaces: []*proto.Workspace{},
133149
DeletedAgents: []*proto.Agent{},
@@ -169,8 +185,7 @@ func TestWorkspaceUpdates(t *testing.T) {
169185
WorkspaceID: ws1ID,
170186
})
171187

172-
update, ok = <-ch
173-
require.True(t, ok)
188+
update = testutil.RequireRecvCtx(ctx, t, sub.Updates())
174189
slices.SortFunc(update.UpsertedWorkspaces, func(a, b *proto.Workspace) int {
175190
return strings.Compare(a.Name, b.Name)
176191
})
@@ -203,13 +218,21 @@ func TestWorkspaceUpdates(t *testing.T) {
203218
Status: proto.Workspace_STOPPED,
204219
},
205220
},
206-
DeletedAgents: []*proto.Agent{},
221+
DeletedAgents: []*proto.Agent{
222+
{
223+
Id: agent3IDSlice,
224+
Name: "agent3",
225+
WorkspaceId: ws3IDSlice,
226+
},
227+
},
207228
}, update)
208229
})
209230

210231
t.Run("Resubscribe", func(t *testing.T) {
211232
t.Parallel()
212233

234+
ctx := testutil.Context(t, testutil.WaitShort)
235+
213236
db := &mockWorkspaceStore{
214237
orderedRows: []database.GetWorkspacesAndAgentsByOwnerIDRow{
215238
{
@@ -290,7 +313,7 @@ type mockWorkspaceStore struct {
290313
}
291314

292315
// GetAuthorizedWorkspacesAndAgentsByOwnerID implements coderd.UpdatesQuerier.
293-
func (m *mockWorkspaceStore) GetAuthorizedWorkspacesAndAgentsByOwnerID(context.Context, uuid.UUID, rbac.PreparedAuthorized) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) {
316+
func (m *mockWorkspaceStore) GetWorkspacesAndAgentsByOwnerID(context.Context, uuid.UUID) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) {
294317
return m.orderedRows, nil
295318
}
296319

enterprise/tailnet/connio.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ var errDisconnect = xerrors.New("graceful disconnect")
133133

134134
func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
135135
c.logger.Debug(c.peerCtx, "got request")
136-
err := c.auth.Authorize(c.coordCtx, req)
136+
err := c.auth.Authorize(c.peerCtx, req)
137137
if err != nil {
138138
c.logger.Warn(c.peerCtx, "unauthorized request", slog.Error(err))
139139
return xerrors.Errorf("authorize request: %w", err)

enterprise/tailnet/pgcoord_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,42 @@ func TestPGCoordinatorDual_PeerReconnect(t *testing.T) {
913913
p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
914914
}
915915

916+
// TestPGCoordinatorPropogatedPeerContext tests that the context for a specific peer
917+
// is propogated through to the `Authorize` method of the coordinatee auth
918+
func TestPGCoordinatorPropogatedPeerContext(t *testing.T) {
919+
t.Parallel()
920+
921+
if !dbtestutil.WillUsePostgres() {
922+
t.Skip("test only with postgres")
923+
}
924+
925+
ctx := testutil.Context(t, testutil.WaitShort)
926+
store, ps := dbtestutil.NewDB(t)
927+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
928+
929+
peerCtx := context.WithValue(ctx, agpltest.FakeSubjectKey{}, struct{}{})
930+
peerID := uuid.UUID{0x01}
931+
agentID := uuid.UUID{0x02}
932+
933+
c1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
934+
require.NoError(t, err)
935+
defer func() {
936+
err := c1.Close()
937+
require.NoError(t, err)
938+
}()
939+
940+
ch := make(chan struct{})
941+
auth := agpltest.FakeCoordinateeAuth{
942+
Chan: ch,
943+
}
944+
945+
reqs, _ := c1.Coordinate(peerCtx, peerID, "peer1", auth)
946+
947+
testutil.RequireSendCtx(ctx, t, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agpl.UUIDToByteSlice(agentID)}})
948+
949+
_ = testutil.RequireRecvCtx(ctx, t, ch)
950+
}
951+
916952
func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) {
917953
t.Helper()
918954
assert.Eventually(t, func() bool {

tailnet/coordinator_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,3 +529,36 @@ func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
529529
defer f.Unlock()
530530
f.callback = callback
531531
}
532+
533+
// TestCoordinatorPropogatedPeerContext tests that the context for a specific peer
534+
// is propogated through to the `Authorize“ method of the coordinatee auth
535+
func TestCoordinatorPropogatedPeerContext(t *testing.T) {
536+
t.Parallel()
537+
538+
ctx := testutil.Context(t, testutil.WaitShort)
539+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
540+
541+
peerCtx := context.WithValue(ctx, test.FakeSubjectKey{}, struct{}{})
542+
peerCtx, peerCtxCancel := context.WithCancel(peerCtx)
543+
peerID := uuid.UUID{0x01}
544+
agentID := uuid.UUID{0x02}
545+
546+
c1 := tailnet.NewCoordinator(logger)
547+
t.Cleanup(func() {
548+
err := c1.Close()
549+
require.NoError(t, err)
550+
})
551+
552+
ch := make(chan struct{})
553+
auth := test.FakeCoordinateeAuth{
554+
Chan: ch,
555+
}
556+
557+
reqs, _ := c1.Coordinate(peerCtx, peerID, "peer1", auth)
558+
559+
testutil.RequireSendCtx(ctx, t, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(agentID)}})
560+
_ = testutil.RequireRecvCtx(ctx, t, ch)
561+
// If we don't cancel the context, the coordinator close will wait until the
562+
// peer request loop finishes, which will be after the timeout
563+
peerCtxCancel()
564+
}

tailnet/service.go

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -220,28 +220,15 @@ func (s *DRPCService) WorkspaceUpdates(req *proto.WorkspaceUpdatesRequest, strea
220220
defer stream.Close()
221221

222222
ctx := stream.Context()
223-
streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
224-
if !ok {
225-
return xerrors.New("no Stream ID")
226-
}
227223

228224
ownerID, err := uuid.FromBytes(req.WorkspaceOwnerId)
229225
if err != nil {
230226
return xerrors.Errorf("parse workspace owner ID: %w", err)
231227
}
232228

233-
var sub Subscription
234-
switch auth := streamID.Auth.(type) {
235-
case ClientUserCoordinateeAuth:
236-
sub, err = s.WorkspaceUpdatesProvider.Subscribe(ctx, ownerID)
237-
if err != nil {
238-
err = xerrors.Errorf("subscribe to workspace updates: %w", err)
239-
}
240-
default:
241-
err = xerrors.Errorf("workspace updates not supported by auth name %T", auth)
242-
}
229+
sub, err := s.WorkspaceUpdatesProvider.Subscribe(ctx, ownerID)
243230
if err != nil {
244-
return err
231+
return xerrors.Errorf("subscribe to workspace updates: %w", err)
245232
}
246233
defer sub.Close()
247234

0 commit comments

Comments
 (0)