Skip to content

Commit 8102c71

Browse files
committed
move core impl to coderd
1 parent 9faa940 commit 8102c71

File tree

8 files changed

+91
-66
lines changed

8 files changed

+91
-66
lines changed

cli/server.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,12 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
719719
options.Database = dbmetrics.New(options.Database, options.PrometheusRegistry)
720720
}
721721

722+
wsUpdates, err := coderd.NewUpdatesProvider(ctx, options.Database, options.Pubsub)
723+
if err != nil {
724+
return xerrors.Errorf("create workspace updates provider: %w", err)
725+
}
726+
options.WorkspaceUpdatesProvider = wsUpdates
727+
722728
var deploymentID string
723729
err = options.Database.InTx(func(tx database.Store) error {
724730
// This will block until the lock is acquired, and will be

coderd/coderd.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ type Options struct {
228228

229229
WorkspaceAppsStatsCollectorOptions workspaceapps.StatsCollectorOptions
230230

231+
WorkspaceUpdatesProvider tailnet.WorkspaceUpdatesProvider
232+
231233
// This janky function is used in telemetry to parse fields out of the raw
232234
// JWT. It needs to be passed through like this because license parsing is
233235
// under the enterprise license, and can't be imported into AGPL.
@@ -591,12 +593,13 @@ func New(options *Options) *API {
591593
panic("CoordinatorResumeTokenProvider is nil")
592594
}
593595
api.TailnetClientService, err = tailnet.NewClientService(tailnet.ClientServiceOptions{
594-
Logger: api.Logger.Named("tailnetclient"),
595-
CoordPtr: &api.TailnetCoordinator,
596-
DERPMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
597-
DERPMapFn: api.DERPMap,
598-
NetworkTelemetryHandler: api.NetworkTelemetryBatcher.Handler,
599-
ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider,
596+
Logger: api.Logger.Named("tailnetclient"),
597+
CoordPtr: &api.TailnetCoordinator,
598+
DERPMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
599+
DERPMapFn: api.DERPMap,
600+
NetworkTelemetryHandler: api.NetworkTelemetryBatcher.Handler,
601+
ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider,
602+
WorkspaceUpdatesProvider: api.Options.WorkspaceUpdatesProvider,
600603
})
601604
if err != nil {
602605
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))

coderd/coderdtest/coderdtest.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ type Options struct {
159159
WorkspaceUsageTrackerFlush chan int
160160
WorkspaceUsageTrackerTick chan time.Time
161161

162+
WorkspaceUpdatesProvider tailnet.WorkspaceUpdatesProvider
163+
162164
NotificationsEnqueuer notifications.Enqueuer
163165
}
164166

@@ -251,6 +253,14 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
251253
options.NotificationsEnqueuer = new(testutil.FakeNotificationsEnqueuer)
252254
}
253255

256+
if options.WorkspaceUpdatesProvider == nil {
257+
var err error
258+
ctx, cancel := context.WithCancel(context.Background())
259+
options.WorkspaceUpdatesProvider, err = coderd.NewUpdatesProvider(ctx, options.Database, options.Pubsub)
260+
require.NoError(t, err)
261+
t.Cleanup(cancel)
262+
}
263+
254264
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
255265
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
256266
accessControlStore.Store(&acs)
@@ -524,6 +534,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
524534
HealthcheckTimeout: options.HealthcheckTimeout,
525535
HealthcheckRefresh: options.HealthcheckRefresh,
526536
StatsBatcher: options.StatsBatcher,
537+
WorkspaceUpdatesProvider: options.WorkspaceUpdatesProvider,
527538
WorkspaceAppsStatsCollectorOptions: options.WorkspaceAppsStatsCollectorOptions,
528539
AllowWorkspaceRenames: options.AllowWorkspaceRenames,
529540
NewTicker: options.NewTicker,

coderd/workspaces.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2146,14 +2146,28 @@ func (api *API) tailnet(rw http.ResponseWriter, r *http.Request) {
21462146

21472147
go httpapi.Heartbeat(ctx, conn)
21482148
err = api.TailnetClientService.ServeUserClient(ctx, version, wsNetConn, tailnet.ServeUserClientOptions{
2149-
PeerID: peerID,
2150-
UserID: owner.ID,
2151-
Subject: &ownerRoles,
2152-
Authz: api.Authorizer,
2153-
Database: api.Database,
2149+
PeerID: peerID,
2150+
UserID: owner.ID,
2151+
AuthFn: api.authAgentFn(&ownerRoles),
21542152
})
21552153
if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, context.Canceled) {
21562154
_ = conn.Close(websocket.StatusInternalError, err.Error())
21572155
return
21582156
}
21592157
}
2158+
2159+
// authAgentFn accepts a subject, and returns a function that authorizes against
2160+
// passed agent IDs.
2161+
func (api *API) authAgentFn(owner *rbac.Subject) func(context.Context, uuid.UUID) error {
2162+
return func(ctx context.Context, agentID uuid.UUID) error {
2163+
ws, err := api.Database.GetWorkspaceByAgentID(ctx, agentID)
2164+
if err != nil {
2165+
return xerrors.Errorf("get workspace by agent id: %w", err)
2166+
}
2167+
err = api.Authorizer.Authorize(ctx, *owner, policy.ActionSSH, ws.RBACObject())
2168+
if err != nil {
2169+
return xerrors.Errorf("workspace agent not found or you do not have permission: %w", sql.ErrNoRows)
2170+
}
2171+
return nil
2172+
}
2173+
}

tailnet/workspaceupdates.go renamed to coderd/workspaceupdates.go

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package tailnet
1+
package coderd
22

33
import (
44
"context"
@@ -12,6 +12,7 @@ import (
1212
"github.com/coder/coder/v2/coderd/database/pubsub"
1313
"github.com/coder/coder/v2/coderd/util/slice"
1414
"github.com/coder/coder/v2/codersdk"
15+
"github.com/coder/coder/v2/tailnet"
1516
"github.com/coder/coder/v2/tailnet/proto"
1617
)
1718

@@ -54,7 +55,7 @@ func convertRows(v []database.GetWorkspacesAndAgentsRow) workspacesByOwner {
5455

5556
func convertStatus(status database.ProvisionerJobStatus, trans database.WorkspaceTransition) proto.Workspace_Status {
5657
wsStatus := codersdk.ConvertWorkspaceStatus(codersdk.ProvisionerJobStatus(status), codersdk.WorkspaceTransition(trans))
57-
return WorkspaceStatusToProto(wsStatus)
58+
return tailnet.WorkspaceStatusToProto(wsStatus)
5859
}
5960

6061
type sub struct {
@@ -75,30 +76,23 @@ func (s *sub) send(all workspacesByOwner) {
7576
s.tx <- update
7677
}
7778

78-
type WorkspaceUpdatesProvider interface {
79-
Subscribe(peerID uuid.UUID, userID uuid.UUID) (<-chan *proto.WorkspaceUpdate, error)
80-
Unsubscribe(peerID uuid.UUID)
81-
Stop()
82-
}
83-
84-
type WorkspaceStore interface {
85-
GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.GetWorkspaceByAgentIDRow, error)
79+
type UpdateQuerier interface {
8680
GetWorkspacesAndAgents(ctx context.Context) ([]database.GetWorkspacesAndAgentsRow, error)
8781
}
8882

8983
type updatesProvider struct {
9084
mu sync.RWMutex
91-
db WorkspaceStore
85+
db UpdateQuerier
9286
ps pubsub.Pubsub
9387
// Peer ID -> subscription
9488
subs map[uuid.UUID]*sub
9589
latest workspacesByOwner
9690
cancelFn func()
9791
}
9892

99-
var _ WorkspaceUpdatesProvider = (*updatesProvider)(nil)
93+
var _ tailnet.WorkspaceUpdatesProvider = (*updatesProvider)(nil)
10094

101-
func NewUpdatesProvider(ctx context.Context, db WorkspaceStore, ps pubsub.Pubsub) (WorkspaceUpdatesProvider, error) {
95+
func NewUpdatesProvider(ctx context.Context, db UpdateQuerier, ps pubsub.Pubsub) (tailnet.WorkspaceUpdatesProvider, error) {
10296
rows, err := db.GetWorkspacesAndAgents(ctx)
10397
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
10498
return nil, err
@@ -176,7 +170,6 @@ func (u *updatesProvider) Unsubscribe(peerID uuid.UUID) {
176170
}
177171
close(sub.tx)
178172
delete(u.subs, peerID)
179-
return
180173
}
181174

182175
func produceUpdate(old, new workspacesByID) *proto.WorkspaceUpdate {
@@ -192,23 +185,23 @@ func produceUpdate(old, new workspacesByID) *proto.WorkspaceUpdate {
192185
// Upsert both workspace and agents if the workspace is new
193186
if !exists {
194187
out.UpsertedWorkspaces = append(out.UpsertedWorkspaces, &proto.Workspace{
195-
Id: UUIDToByteSlice(wsID),
188+
Id: tailnet.UUIDToByteSlice(wsID),
196189
Name: newWorkspace.WorkspaceName,
197190
Status: convertStatus(newWorkspace.JobStatus, newWorkspace.Transition),
198191
})
199192
for _, agent := range newWorkspace.Agents {
200193
out.UpsertedAgents = append(out.UpsertedAgents, &proto.Agent{
201-
Id: UUIDToByteSlice(agent.ID),
194+
Id: tailnet.UUIDToByteSlice(agent.ID),
202195
Name: agent.Name,
203-
WorkspaceId: UUIDToByteSlice(wsID),
196+
WorkspaceId: tailnet.UUIDToByteSlice(wsID),
204197
})
205198
}
206199
continue
207200
}
208201
// Upsert workspace if the workspace is updated
209202
if !newWorkspace.Equal(oldWorkspace) {
210203
out.UpsertedWorkspaces = append(out.UpsertedWorkspaces, &proto.Workspace{
211-
Id: UUIDToByteSlice(wsID),
204+
Id: tailnet.UUIDToByteSlice(wsID),
212205
Name: newWorkspace.WorkspaceName,
213206
Status: convertStatus(newWorkspace.JobStatus, newWorkspace.Transition),
214207
})
@@ -217,16 +210,16 @@ func produceUpdate(old, new workspacesByID) *proto.WorkspaceUpdate {
217210
add, remove := slice.SymmetricDifference(oldWorkspace.Agents, newWorkspace.Agents)
218211
for _, agent := range add {
219212
out.UpsertedAgents = append(out.UpsertedAgents, &proto.Agent{
220-
Id: UUIDToByteSlice(agent.ID),
213+
Id: tailnet.UUIDToByteSlice(agent.ID),
221214
Name: agent.Name,
222-
WorkspaceId: UUIDToByteSlice(wsID),
215+
WorkspaceId: tailnet.UUIDToByteSlice(wsID),
223216
})
224217
}
225218
for _, agent := range remove {
226219
out.DeletedAgents = append(out.DeletedAgents, &proto.Agent{
227-
Id: UUIDToByteSlice(agent.ID),
220+
Id: tailnet.UUIDToByteSlice(agent.ID),
228221
Name: agent.Name,
229-
WorkspaceId: UUIDToByteSlice(wsID),
222+
WorkspaceId: tailnet.UUIDToByteSlice(wsID),
230223
})
231224
}
232225
}
@@ -235,15 +228,15 @@ func produceUpdate(old, new workspacesByID) *proto.WorkspaceUpdate {
235228
for wsID, oldWorkspace := range old {
236229
if _, exists := new[wsID]; !exists {
237230
out.DeletedWorkspaces = append(out.DeletedWorkspaces, &proto.Workspace{
238-
Id: UUIDToByteSlice(wsID),
231+
Id: tailnet.UUIDToByteSlice(wsID),
239232
Name: oldWorkspace.WorkspaceName,
240233
Status: convertStatus(oldWorkspace.JobStatus, oldWorkspace.Transition),
241234
})
242235
for _, agent := range oldWorkspace.Agents {
243236
out.DeletedAgents = append(out.DeletedAgents, &proto.Agent{
244-
Id: UUIDToByteSlice(agent.ID),
237+
Id: tailnet.UUIDToByteSlice(agent.ID),
245238
Name: agent.Name,
246-
WorkspaceId: UUIDToByteSlice(wsID),
239+
WorkspaceId: tailnet.UUIDToByteSlice(wsID),
247240
})
248241
}
249242
}

tailnet/workspaceupdates_test.go renamed to coderd/workspaceupdates_test.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package tailnet_test
1+
package coderd_test
22

33
import (
44
"context"
@@ -9,8 +9,10 @@ import (
99
"github.com/google/uuid"
1010
"github.com/stretchr/testify/require"
1111

12+
"github.com/coder/coder/v2/coderd"
1213
"github.com/coder/coder/v2/coderd/database"
1314
"github.com/coder/coder/v2/coderd/database/pubsub"
15+
"github.com/coder/coder/v2/coderd/rbac"
1416
"github.com/coder/coder/v2/codersdk"
1517
"github.com/coder/coder/v2/tailnet"
1618
"github.com/coder/coder/v2/tailnet/proto"
@@ -78,7 +80,7 @@ func TestWorkspaceUpdates(t *testing.T) {
7880
cbs: map[string]pubsub.Listener{},
7981
}
8082

81-
updateProvider, err := tailnet.NewUpdatesProvider(ctx, db, ps)
83+
updateProvider, err := coderd.NewUpdatesProvider(ctx, db, ps)
8284
require.NoError(t, err)
8385

8486
ch, err := updateProvider.Subscribe(peerID, ownerid)
@@ -217,7 +219,7 @@ func TestWorkspaceUpdates(t *testing.T) {
217219
cbs: map[string]pubsub.Listener{},
218220
}
219221

220-
updateProvider, err := tailnet.NewUpdatesProvider(ctx, db, ps)
222+
updateProvider, err := coderd.NewUpdatesProvider(ctx, db, ps)
221223
require.NoError(t, err)
222224

223225
ch, err := updateProvider.Subscribe(peerID, ownerid)
@@ -257,15 +259,18 @@ type mockWorkspaceStore struct {
257259
orderedRows []database.GetWorkspacesAndAgentsRow
258260
}
259261

260-
var _ tailnet.WorkspaceStore = (*mockWorkspaceStore)(nil)
261-
262-
func (*mockWorkspaceStore) GetWorkspaceByAgentID(context.Context, uuid.UUID) (database.GetWorkspaceByAgentIDRow, error) {
263-
return database.GetWorkspaceByAgentIDRow{}, nil
262+
// GetWorkspaceRBACByAgentID implements tailnet.UpdateQuerier.
263+
func (*mockWorkspaceStore) GetWorkspaceRBACByAgentID(context.Context, uuid.UUID) (rbac.Objecter, error) {
264+
panic("unimplemented")
264265
}
265-
func (db *mockWorkspaceStore) GetWorkspacesAndAgents(context.Context) ([]database.GetWorkspacesAndAgentsRow, error) {
266-
return db.orderedRows, nil
266+
267+
// GetWorkspacesAndAgents implements tailnet.UpdateQuerier.
268+
func (m *mockWorkspaceStore) GetWorkspacesAndAgents(context.Context) ([]database.GetWorkspacesAndAgentsRow, error) {
269+
return m.orderedRows, nil
267270
}
268271

272+
var _ coderd.UpdateQuerier = (*mockWorkspaceStore)(nil)
273+
269274
type mockPubsub struct {
270275
cbs map[string]pubsub.Listener
271276
}

tailnet/service.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717

1818
"cdr.dev/slog"
1919
"github.com/coder/coder/v2/apiversion"
20-
"github.com/coder/coder/v2/coderd/rbac"
2120
"github.com/coder/coder/v2/tailnet/proto"
2221
"github.com/coder/quartz"
2322
)
@@ -40,6 +39,12 @@ func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
4039
return context.WithValue(ctx, streamIDContextKey{}, streamID)
4140
}
4241

42+
type WorkspaceUpdatesProvider interface {
43+
Subscribe(peerID uuid.UUID, userID uuid.UUID) (<-chan *proto.WorkspaceUpdate, error)
44+
Unsubscribe(peerID uuid.UUID)
45+
Stop()
46+
}
47+
4348
type ClientServiceOptions struct {
4449
Logger slog.Logger
4550
CoordPtr *atomic.Pointer[Coordinator]
@@ -114,11 +119,9 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
114119
}
115120

116121
type ServeUserClientOptions struct {
117-
PeerID uuid.UUID
118-
UserID uuid.UUID
119-
Subject *rbac.Subject
120-
Authz rbac.Authorizer
121-
Database WorkspaceStore
122+
PeerID uuid.UUID
123+
UserID uuid.UUID
124+
AuthFn func(context.Context, uuid.UUID) error
122125
}
123126

124127
func (s *ClientService) ServeUserClient(ctx context.Context, version string, conn net.Conn, opts ServeUserClientOptions) error {
@@ -130,10 +133,8 @@ func (s *ClientService) ServeUserClient(ctx context.Context, version string, con
130133
switch major {
131134
case 2:
132135
auth := ClientUserCoordinateeAuth{
133-
UserID: opts.UserID,
134-
RBACSubject: opts.Subject,
135-
Authz: opts.Authz,
136-
Database: opts.Database,
136+
UserID: opts.UserID,
137+
AuthFn: opts.AuthFn,
137138
}
138139
streamID := StreamID{
139140
Name: "client",

tailnet/tunnel.go

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ import (
88
"github.com/google/uuid"
99
"golang.org/x/xerrors"
1010

11-
"github.com/coder/coder/v2/coderd/rbac"
12-
"github.com/coder/coder/v2/coderd/rbac/policy"
1311
"github.com/coder/coder/v2/tailnet/proto"
1412
)
1513

@@ -95,10 +93,8 @@ func (a AgentCoordinateeAuth) Authorize(_ context.Context, req *proto.Coordinate
9593
}
9694

9795
type ClientUserCoordinateeAuth struct {
98-
UserID uuid.UUID
99-
RBACSubject *rbac.Subject
100-
Authz rbac.Authorizer
101-
Database WorkspaceStore
96+
UserID uuid.UUID
97+
AuthFn func(context.Context, uuid.UUID) error
10298
}
10399

104100
func (a ClientUserCoordinateeAuth) Authorize(ctx context.Context, req *proto.CoordinateRequest) error {
@@ -107,11 +103,7 @@ func (a ClientUserCoordinateeAuth) Authorize(ctx context.Context, req *proto.Coo
107103
if err != nil {
108104
return xerrors.Errorf("parse add tunnel id: %w", err)
109105
}
110-
row, err := a.Database.GetWorkspaceByAgentID(ctx, uid)
111-
if err != nil {
112-
return xerrors.Errorf("get workspace by agent id: %w", err)
113-
}
114-
err = a.Authz.Authorize(ctx, *a.RBACSubject, policy.ActionSSH, row.Workspace.RBACObject())
106+
err = a.AuthFn(ctx, uid)
115107
if err != nil {
116108
return xerrors.Errorf("workspace agent not found or you do not have permission: %w", sql.ErrNoRows)
117109
}

0 commit comments

Comments
 (0)