Skip to content

feat: add WorkspaceUpdates tailnet RPC #14847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
review
  • Loading branch information
ethanndickson committed Nov 1, 2024
commit b5090994b9e86e9be341742f900b520019cf7994
7 changes: 0 additions & 7 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,13 +728,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
options.Database = dbmetrics.NewDBMetrics(options.Database, options.Logger, options.PrometheusRegistry)
}

wsUpdates, err := coderd.NewUpdatesProvider(logger.Named("workspace_updates"), options.Pubsub, options.Database, options.Authorizer)
if err != nil {
return xerrors.Errorf("create workspace updates provider: %w", err)
}
options.WorkspaceUpdatesProvider = wsUpdates
defer wsUpdates.Close()

var deploymentID string
err = options.Database.InTx(func(tx database.Store) error {
// This will block until the lock is acquired, and will be
Expand Down
10 changes: 7 additions & 3 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,6 @@ type Options struct {

WorkspaceAppsStatsCollectorOptions workspaceapps.StatsCollectorOptions

WorkspaceUpdatesProvider tailnet.WorkspaceUpdatesProvider

// This janky function is used in telemetry to parse fields out of the raw
// JWT. It needs to be passed through like this because license parsing is
// under the enterprise license, and can't be imported into AGPL.
Expand Down Expand Up @@ -495,6 +493,8 @@ func New(options *Options) *API {
}
}

updatesProvider := NewUpdatesProvider(options.Logger.Named("workspace_updates"), options.Pubsub, options.Database, options.Authorizer)

// Start a background process that rotates keys. We intentionally start this after the caches
// are created to force initial requests for a key to populate the caches. This helps catch
// bugs that may only occur when a key isn't precached in tests and the latency cost is minimal.
Expand Down Expand Up @@ -525,6 +525,7 @@ func New(options *Options) *API {
metricsCache: metricsCache,
Auditor: atomic.Pointer[audit.Auditor]{},
TailnetCoordinator: atomic.Pointer[tailnet.Coordinator]{},
UpdatesProvider: updatesProvider,
TemplateScheduleStore: options.TemplateScheduleStore,
UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore,
AccessControlStore: options.AccessControlStore,
Expand Down Expand Up @@ -660,7 +661,7 @@ func New(options *Options) *API {
DERPMapFn: api.DERPMap,
NetworkTelemetryHandler: api.NetworkTelemetryBatcher.Handler,
ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider,
WorkspaceUpdatesProvider: api.Options.WorkspaceUpdatesProvider,
WorkspaceUpdatesProvider: api.UpdatesProvider,
})
if err != nil {
api.Logger.Fatal(context.Background(), "failed to initialize tailnet client service", slog.Error(err))
Expand Down Expand Up @@ -1415,6 +1416,8 @@ type API struct {
AccessControlStore *atomic.Pointer[dbauthz.AccessControlStore]
PortSharer atomic.Pointer[portsharing.PortSharer]

UpdatesProvider tailnet.WorkspaceUpdatesProvider

HTTPAuth *HTTPAuthorizer

// APIHandler serves "/api/v2"
Expand Down Expand Up @@ -1496,6 +1499,7 @@ func (api *API) Close() error {
_ = api.OIDCConvertKeyCache.Close()
_ = api.AppSigningKeyCache.Close()
_ = api.AppEncryptionKeyCache.Close()
_ = api.UpdatesProvider.Close()
return nil
}

Expand Down
17 changes: 0 additions & 17 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ type Options struct {
APIKeyEncryptionCache cryptokeys.EncryptionKeycache
OIDCConvertKeyCache cryptokeys.SigningKeycache
Clock quartz.Clock

WorkspaceUpdatesProvider tailnet.WorkspaceUpdatesProvider
}

// New constructs a codersdk client connected to an in-memory API instance.
Expand Down Expand Up @@ -256,20 +254,6 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
options.NotificationsEnqueuer = new(testutil.FakeNotificationsEnqueuer)
}

if options.WorkspaceUpdatesProvider == nil {
var err error
options.WorkspaceUpdatesProvider, err = coderd.NewUpdatesProvider(
options.Logger.Named("workspace_updates"),
options.Pubsub,
options.Database,
options.Authorizer,
)
require.NoError(t, err)
t.Cleanup(func() {
_ = options.WorkspaceUpdatesProvider.Close()
})
}

accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
accessControlStore.Store(&acs)
Expand Down Expand Up @@ -547,7 +531,6 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
HealthcheckTimeout: options.HealthcheckTimeout,
HealthcheckRefresh: options.HealthcheckRefresh,
StatsBatcher: options.StatsBatcher,
WorkspaceUpdatesProvider: options.WorkspaceUpdatesProvider,
WorkspaceAppsStatsCollectorOptions: options.WorkspaceAppsStatsCollectorOptions,
AllowWorkspaceRenames: options.AllowWorkspaceRenames,
NewTicker: options.NewTicker,
Expand Down
36 changes: 18 additions & 18 deletions coderd/workspaceupdates.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
Expand All @@ -23,7 +22,8 @@ import (
)

type UpdatesQuerier interface {
GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID, prep rbac.PreparedAuthorized) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error)
// GetAuthorizedWorkspacesAndAgentsByOwnerID requires a context with an actor set
GetWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error)
GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error)
}

Expand All @@ -42,14 +42,14 @@ func (w ownedWorkspace) Equal(other ownedWorkspace) bool {
}

type sub struct {
// ALways contains an actor
ctx context.Context
cancelFn context.CancelFunc

mu sync.RWMutex
userID uuid.UUID
ch chan *proto.WorkspaceUpdate
prev workspacesByID
readPrep rbac.PreparedAuthorized
mu sync.RWMutex
userID uuid.UUID
ch chan *proto.WorkspaceUpdate
prev workspacesByID

db UpdatesQuerier
ps pubsub.Pubsub
Expand All @@ -76,7 +76,8 @@ func (s *sub) handleEvent(ctx context.Context, event wspubsub.WorkspaceEvent, er
}
}

rows, err := s.db.GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx, s.userID, s.readPrep)
// Use context containing actor
rows, err := s.db.GetWorkspacesAndAgentsByOwnerID(s.ctx, s.userID)
if err != nil {
s.logger.Warn(ctx, "failed to get workspaces and agents by owner ID", slog.Error(err))
return
Expand All @@ -97,7 +98,7 @@ func (s *sub) handleEvent(ctx context.Context, event wspubsub.WorkspaceEvent, er
}

func (s *sub) start(ctx context.Context) (err error) {
rows, err := s.db.GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx, s.userID, s.readPrep)
rows, err := s.db.GetWorkspacesAndAgentsByOwnerID(ctx, s.userID)
if err != nil {
return xerrors.Errorf("get workspaces and agents by owner ID: %w", err)
}
Expand Down Expand Up @@ -150,7 +151,7 @@ func NewUpdatesProvider(
ps pubsub.Pubsub,
db UpdatesQuerier,
auth rbac.Authorizer,
) (tailnet.WorkspaceUpdatesProvider, error) {
) tailnet.WorkspaceUpdatesProvider {
ctx, cancel := context.WithCancel(context.Background())
out := &updatesProvider{
auth: auth,
Expand All @@ -160,25 +161,25 @@ func NewUpdatesProvider(
ctx: ctx,
cancelFn: cancel,
}
return out, nil
return out
}

func (u *updatesProvider) Close() error {
u.cancelFn()
return nil
}

// Subscribe subscribes to workspace updates for a user, for the workspaces
// that user is authorized to `ActionRead` on. The provided context must have
// a dbauthz actor set.
func (u *updatesProvider) Subscribe(ctx context.Context, userID uuid.UUID) (tailnet.Subscription, error) {
actor, ok := dbauthz.ActorFromContext(ctx)
if !ok {
return nil, xerrors.Errorf("actor not found in context")
}
readPrep, err := u.auth.Prepare(ctx, actor, policy.ActionRead, rbac.ResourceWorkspace.Type)
if err != nil {
return nil, xerrors.Errorf("prepare read action: %w", err)
}
ctx, cancel := context.WithCancel(u.ctx)
ctx = dbauthz.As(ctx, actor)
ch := make(chan *proto.WorkspaceUpdate, 1)
ctx, cancel := context.WithCancel(ctx)
sub := &sub{
ctx: ctx,
cancelFn: cancel,
Expand All @@ -188,9 +189,8 @@ func (u *updatesProvider) Subscribe(ctx context.Context, userID uuid.UUID) (tail
ps: u.ps,
logger: u.logger.Named(fmt.Sprintf("workspace_updates_subscriber_%s", userID)),
prev: workspacesByID{},
readPrep: readPrep,
}
err = sub.start(ctx)
err := sub.start(ctx)
if err != nil {
_ = sub.Close()
return nil, err
Expand Down
78 changes: 52 additions & 26 deletions coderd/workspaceupdates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,23 @@ import (

func TestWorkspaceUpdates(t *testing.T) {
t.Parallel()
ctx := context.Background()

ws1ID := uuid.New()
ws1ID := uuid.UUID{0x01}
ws1IDSlice := tailnet.UUIDToByteSlice(ws1ID)
agent1ID := uuid.New()
agent1ID := uuid.UUID{0x02}
agent1IDSlice := tailnet.UUIDToByteSlice(agent1ID)
ws2ID := uuid.New()
ws2ID := uuid.UUID{0x03}
ws2IDSlice := tailnet.UUIDToByteSlice(ws2ID)
ws3ID := uuid.New()
ws3ID := uuid.UUID{0x04}
ws3IDSlice := tailnet.UUIDToByteSlice(ws3ID)
agent2ID := uuid.New()
agent2ID := uuid.UUID{0x05}
agent2IDSlice := tailnet.UUIDToByteSlice(agent2ID)
ws4ID := uuid.New()
ws4ID := uuid.UUID{0x06}
ws4IDSlice := tailnet.UUIDToByteSlice(ws4ID)
agent3ID := uuid.UUID{0x07}
agent3IDSlice := tailnet.UUIDToByteSlice(agent3ID)

ownerID := uuid.New()
ownerID := uuid.UUID{0x08}
memberRole, err := rbac.RoleByName(rbac.RoleMember())
require.NoError(t, err)
ownerSubject := rbac.Subject{
Expand All @@ -53,9 +54,11 @@ func TestWorkspaceUpdates(t *testing.T) {
t.Run("Basic", func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)

db := &mockWorkspaceStore{
orderedRows: []database.GetWorkspacesAndAgentsByOwnerIDRow{
// Gains a new agent
// Gains agent2
{
ID: ws1ID,
Name: "ws1",
Expand All @@ -81,6 +84,12 @@ func TestWorkspaceUpdates(t *testing.T) {
Name: "ws3",
JobStatus: database.ProvisionerJobStatusSucceeded,
Transition: database.WorkspaceTransitionStop,
Agents: []database.AgentIDNamePair{
{
ID: agent3ID,
Name: "agent3",
},
},
},
},
}
Expand All @@ -89,21 +98,24 @@ func TestWorkspaceUpdates(t *testing.T) {
cbs: map[string]pubsub.ListenerWithErr{},
}

updateProvider, err := coderd.NewUpdatesProvider(slogtest.Make(t, nil), ps, db, &mockAuthorizer{})
require.NoError(t, err)
updateProvider := coderd.NewUpdatesProvider(slogtest.Make(t, nil), ps, db, &mockAuthorizer{})
t.Cleanup(func() {
_ = updateProvider.Close()
})

sub, err := updateProvider.Subscribe(dbauthz.As(ctx, ownerSubject), ownerID)
require.NoError(t, err)
ch := sub.Updates()
t.Cleanup(func() {
_ = sub.Close()
})

update, ok := <-ch
require.True(t, ok)
update := testutil.RequireRecvCtx(ctx, t, sub.Updates())
slices.SortFunc(update.UpsertedWorkspaces, func(a, b *proto.Workspace) int {
return strings.Compare(a.Name, b.Name)
})
slices.SortFunc(update.UpsertedAgents, func(a, b *proto.Agent) int {
return strings.Compare(a.Name, b.Name)
})
require.Equal(t, &proto.WorkspaceUpdate{
UpsertedWorkspaces: []*proto.Workspace{
{
Expand All @@ -128,6 +140,11 @@ func TestWorkspaceUpdates(t *testing.T) {
Name: "agent1",
WorkspaceId: ws1IDSlice,
},
{
Id: agent3IDSlice,
Name: "agent3",
WorkspaceId: ws3IDSlice,
},
},
DeletedWorkspaces: []*proto.Workspace{},
DeletedAgents: []*proto.Agent{},
Expand Down Expand Up @@ -169,8 +186,7 @@ func TestWorkspaceUpdates(t *testing.T) {
WorkspaceID: ws1ID,
})

update, ok = <-ch
require.True(t, ok)
update = testutil.RequireRecvCtx(ctx, t, sub.Updates())
slices.SortFunc(update.UpsertedWorkspaces, func(a, b *proto.Workspace) int {
return strings.Compare(a.Name, b.Name)
})
Expand Down Expand Up @@ -203,13 +219,21 @@ func TestWorkspaceUpdates(t *testing.T) {
Status: proto.Workspace_STOPPED,
},
},
DeletedAgents: []*proto.Agent{},
DeletedAgents: []*proto.Agent{
{
Id: agent3IDSlice,
Name: "agent3",
WorkspaceId: ws3IDSlice,
},
},
}, update)
})

t.Run("Resubscribe", func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)

db := &mockWorkspaceStore{
orderedRows: []database.GetWorkspacesAndAgentsByOwnerIDRow{
{
Expand All @@ -231,15 +255,16 @@ func TestWorkspaceUpdates(t *testing.T) {
cbs: map[string]pubsub.ListenerWithErr{},
}

updateProvider, err := coderd.NewUpdatesProvider(slogtest.Make(t, nil), ps, db, &mockAuthorizer{})
require.NoError(t, err)
updateProvider := coderd.NewUpdatesProvider(slogtest.Make(t, nil), ps, db, &mockAuthorizer{})
t.Cleanup(func() {
_ = updateProvider.Close()
})

sub, err := updateProvider.Subscribe(dbauthz.As(ctx, ownerSubject), ownerID)
require.NoError(t, err)
ch := sub.Updates()
t.Cleanup(func() {
_ = sub.Close()
})

expected := &proto.WorkspaceUpdate{
UpsertedWorkspaces: []*proto.Workspace{
Expand All @@ -260,18 +285,19 @@ func TestWorkspaceUpdates(t *testing.T) {
DeletedAgents: []*proto.Agent{},
}

update := testutil.RequireRecvCtx(ctx, t, ch)
update := testutil.RequireRecvCtx(ctx, t, sub.Updates())
slices.SortFunc(update.UpsertedWorkspaces, func(a, b *proto.Workspace) int {
return strings.Compare(a.Name, b.Name)
})
require.Equal(t, expected, update)

resub, err := updateProvider.Subscribe(dbauthz.As(ctx, ownerSubject), ownerID)
require.NoError(t, err)
sub, err = updateProvider.Subscribe(dbauthz.As(ctx, ownerSubject), ownerID)
require.NoError(t, err)
ch = sub.Updates()
t.Cleanup(func() {
_ = resub.Close()
})

update = testutil.RequireRecvCtx(ctx, t, ch)
update = testutil.RequireRecvCtx(ctx, t, resub.Updates())
slices.SortFunc(update.UpsertedWorkspaces, func(a, b *proto.Workspace) int {
return strings.Compare(a.Name, b.Name)
})
Expand All @@ -290,7 +316,7 @@ type mockWorkspaceStore struct {
}

// GetAuthorizedWorkspacesAndAgentsByOwnerID implements coderd.UpdatesQuerier.
func (m *mockWorkspaceStore) GetAuthorizedWorkspacesAndAgentsByOwnerID(context.Context, uuid.UUID, rbac.PreparedAuthorized) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) {
func (m *mockWorkspaceStore) GetWorkspacesAndAgentsByOwnerID(context.Context, uuid.UUID) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) {
return m.orderedRows, nil
}

Expand Down
Loading
Loading