diff --git a/tailnet/controllers.go b/tailnet/controllers.go
index 7a96ddf18e662..d250422160ea9 100644
--- a/tailnet/controllers.go
+++ b/tailnet/controllers.go
@@ -26,13 +26,14 @@ import (
 // A Controller connects to the tailnet control plane, and then uses the control protocols to
 // program a tailnet.Conn in production (in test it could be an interface simulating the Conn). It
 // delegates this task to sub-controllers responsible for the main areas of the tailnet control
-// protocol: coordination, DERP map updates, resume tokens, and telemetry.
+// protocol: coordination, DERP map updates, resume tokens, telemetry, and workspace updates.
 type Controller struct {
-	Dialer          ControlProtocolDialer
-	CoordCtrl       CoordinationController
-	DERPCtrl        DERPController
-	ResumeTokenCtrl ResumeTokenController
-	TelemetryCtrl   TelemetryController
+	Dialer               ControlProtocolDialer
+	CoordCtrl            CoordinationController
+	DERPCtrl             DERPController
+	ResumeTokenCtrl      ResumeTokenController
+	TelemetryCtrl        TelemetryController
+	WorkspaceUpdatesCtrl WorkspaceUpdatesController
 
 	ctx         context.Context
 	gracefulCtx context.Context
@@ -94,15 +95,25 @@ type TelemetryController interface {
 	New(TelemetryClient)
 }
 
+type WorkspaceUpdatesClient interface {
+	Close() error
+	Recv() (*proto.WorkspaceUpdate, error)
+}
+
+type WorkspaceUpdatesController interface {
+	New(WorkspaceUpdatesClient) CloserWaiter
+}
+
 // ControlProtocolClients represents an abstract interface to the tailnet control plane via a set
 // of protocol clients. The Closer should close all the clients (e.g. by closing the underlying
 // connection).
 type ControlProtocolClients struct {
-	Closer      io.Closer
-	Coordinator CoordinatorClient
-	DERP        DERPClient
-	ResumeToken ResumeTokenClient
-	Telemetry   TelemetryClient
+	Closer           io.Closer
+	Coordinator      CoordinatorClient
+	DERP             DERPClient
+	ResumeToken      ResumeTokenClient
+	Telemetry        TelemetryClient
+	WorkspaceUpdates WorkspaceUpdatesClient
 }
 
 type ControlProtocolDialer interface {
@@ -419,6 +430,7 @@ func (c *TunnelSrcCoordController) SyncDestinations(destinations []uuid.UUID) {
 		}
 	}()
 	for dest := range toAdd {
+		c.Coordinatee.SetTunnelDestination(dest)
 		err = c.coordination.Client.Send(
 			&proto.CoordinateRequest{
 				AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
@@ -822,6 +834,213 @@ func (r *basicResumeTokenRefresher) refresh() {
 	r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh")
 }
 
+type tunnelAllWorkspaceUpdatesController struct {
+	coordCtrl *TunnelSrcCoordController
+	logger    slog.Logger
+}
+
+type workspace struct {
+	id     uuid.UUID
+	name   string
+	agents map[uuid.UUID]agent
+}
+
+type agent struct {
+	id   uuid.UUID
+	name string
+}
+
+func (t *tunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient) CloserWaiter {
+	updater := &tunnelUpdater{
+		client:       client,
+		errChan:      make(chan error, 1),
+		logger:       t.logger,
+		coordCtrl:    t.coordCtrl,
+		recvLoopDone: make(chan struct{}),
+		workspaces:   make(map[uuid.UUID]*workspace),
+	}
+	go updater.recvLoop()
+	return updater
+}
+
+type tunnelUpdater struct {
+	errChan      chan error
+	logger       slog.Logger
+	client       WorkspaceUpdatesClient
+	coordCtrl    *TunnelSrcCoordController
+	recvLoopDone chan struct{}
+
+	// don't need the mutex since only manipulated by the recvLoop
+	workspaces map[uuid.UUID]*workspace
+
+	sync.Mutex
+	closed bool
+}
+
+func (t *tunnelUpdater) Close(ctx context.Context) error {
+	t.Lock()
+	defer t.Unlock()
+	if t.closed {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-t.recvLoopDone:
+			return nil
+		}
+	}
+	t.closed = true
+	cErr := t.client.Close()
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-t.recvLoopDone:
+		return cErr
+	}
+}
+
+func (t *tunnelUpdater) Wait() <-chan error {
+	return t.errChan
+}
+
+func (t *tunnelUpdater) recvLoop() {
+	t.logger.Debug(context.Background(), "tunnel updater recvLoop started")
+	defer t.logger.Debug(context.Background(), "tunnel updater recvLoop done")
+	defer close(t.recvLoopDone)
+	for {
+		update, err := t.client.Recv()
+		if err != nil {
+			t.logger.Debug(context.Background(), "failed to receive workspace Update", slog.Error(err))
+			select {
+			case t.errChan <- err:
+			default:
+			}
+			return
+		}
+		t.logger.Debug(context.Background(), "got workspace update",
+			slog.F("workspace_update", update),
+		)
+		err = t.handleUpdate(update)
+		if err != nil {
+			t.logger.Critical(context.Background(), "failed to handle workspace Update", slog.Error(err))
+			cErr := t.client.Close()
+			if cErr != nil {
+				t.logger.Warn(context.Background(), "failed to close client", slog.Error(cErr))
+			}
+			select {
+			case t.errChan <- err:
+			default:
+			}
+			return
+		}
+	}
+}
+
+func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate) error {
+	for _, uw := range update.UpsertedWorkspaces {
+		workspaceID, err := uuid.FromBytes(uw.Id)
+		if err != nil {
+			return xerrors.Errorf("failed to parse workspace ID: %w", err)
+		}
+		w := workspace{
+			id:     workspaceID,
+			name:   uw.Name,
+			agents: make(map[uuid.UUID]agent),
+		}
+		t.upsertWorkspace(w)
+	}
+
+	// delete agents before deleting workspaces, since the agents have workspace ID references
+	for _, da := range update.DeletedAgents {
+		agentID, err := uuid.FromBytes(da.Id)
+		if err != nil {
+			return xerrors.Errorf("failed to parse agent ID: %w", err)
+		}
+		workspaceID, err := uuid.FromBytes(da.WorkspaceId)
+		if err != nil {
+			return xerrors.Errorf("failed to parse workspace ID: %w", err)
+		}
+		err = t.deleteAgent(workspaceID, agentID)
+		if err != nil {
+			return xerrors.Errorf("failed to delete agent: %w", err)
+		}
+	}
+	for _, dw := range update.DeletedWorkspaces {
+		workspaceID, err := uuid.FromBytes(dw.Id)
+		if err != nil {
+			return xerrors.Errorf("failed to parse workspace ID: %w", err)
+		}
+		t.deleteWorkspace(workspaceID)
+	}
+
+	// upsert agents last, after all workspaces have been added and deleted, since agents reference
+	// workspace ID.
+	for _, ua := range update.UpsertedAgents {
+		agentID, err := uuid.FromBytes(ua.Id)
+		if err != nil {
+			return xerrors.Errorf("failed to parse agent ID: %w", err)
+		}
+		workspaceID, err := uuid.FromBytes(ua.WorkspaceId)
+		if err != nil {
+			return xerrors.Errorf("failed to parse workspace ID: %w", err)
+		}
+		a := agent{name: ua.Name, id: agentID}
+		err = t.upsertAgent(workspaceID, a)
+		if err != nil {
+			return xerrors.Errorf("failed to upsert agent: %w", err)
+		}
+	}
+	allAgents := t.allAgentIDs()
+	t.coordCtrl.SyncDestinations(allAgents)
+	return nil
+}
+
+func (t *tunnelUpdater) upsertWorkspace(w workspace) {
+	old, ok := t.workspaces[w.id]
+	if !ok {
+		t.workspaces[w.id] = &w
+		return
+	}
+	old.name = w.name
+}
+
+func (t *tunnelUpdater) deleteWorkspace(id uuid.UUID) {
+	delete(t.workspaces, id)
+}
+
+func (t *tunnelUpdater) upsertAgent(workspaceID uuid.UUID, a agent) error {
+	w, ok := t.workspaces[workspaceID]
+	if !ok {
+		return xerrors.Errorf("workspace %s not found", workspaceID)
+	}
+	w.agents[a.id] = a
+	return nil
+}
+
+func (t *tunnelUpdater) deleteAgent(workspaceID, id uuid.UUID) error {
+	w, ok := t.workspaces[workspaceID]
+	if !ok {
+		return xerrors.Errorf("workspace %s not found", workspaceID)
+	}
+	delete(w.agents, id)
+	return nil
+}
+
+func (t *tunnelUpdater) allAgentIDs() []uuid.UUID {
+	out := make([]uuid.UUID, 0, len(t.workspaces))
+	for _, w := range t.workspaces {
+		for id := range w.agents {
+			out = append(out, id)
+		}
+	}
+	return out
+}
+
+func NewTunnelAllWorkspaceUpdatesController(
+	logger slog.Logger, c *TunnelSrcCoordController,
+) WorkspaceUpdatesController {
+	return &tunnelAllWorkspaceUpdatesController{logger: logger, coordCtrl: c}
+}
+
 // NewController creates a new Controller without running it
 func NewController(logger slog.Logger, dialer ControlProtocolDialer, opts ...ControllerOpt) *Controller {
 	c := &Controller{
diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go
index 90aa4c7f9bc48..26b8286eb3d7e 100644
--- a/tailnet/controllers_test.go
+++ b/tailnet/controllers_test.go
@@ -2,6 +2,7 @@ package tailnet_test
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"net"
 	"slices"
@@ -345,6 +346,10 @@ func TestTunnelSrcCoordController_Sync(t *testing.T) {
 	require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
 	testutil.RequireSendCtx(ctx, t, call.err, nil)
 
+	testutil.RequireRecvCtx(ctx, t, syncDone)
+	// dest3 should be added to coordinatee
+	require.Contains(t, fConn.tunnelDestinations, dest3)
+
 	// shut down
 	respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
 	testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
@@ -1262,3 +1267,312 @@ type coordRespCall struct {
 	resp chan<- *proto.CoordinateResponse
 	err  chan<- error
 }
+
+type fakeWorkspaceUpdateClient struct {
+	ctx   context.Context
+	t     testing.TB
+	recv  chan *updateRecvCall
+	close chan chan<- error
+}
+
+func (f *fakeWorkspaceUpdateClient) Close() error {
+	f.t.Helper()
+	errs := make(chan error)
+	select {
+	case <-f.ctx.Done():
+		f.t.Error("timed out waiting to send close call")
+		return f.ctx.Err()
+	case f.close <- errs:
+		// OK
+	}
+	select {
+	case <-f.ctx.Done():
+		f.t.Error("timed out waiting for close call response")
+		return f.ctx.Err()
+	case err := <-errs:
+		return err
+	}
+}
+
+func (f *fakeWorkspaceUpdateClient) Recv() (*proto.WorkspaceUpdate, error) {
+	f.t.Helper()
+	resps := make(chan *proto.WorkspaceUpdate)
+	errs := make(chan error)
+	call := &updateRecvCall{
+		resp: resps,
+		err:  errs,
+	}
+	select {
+	case <-f.ctx.Done():
+		f.t.Error("timed out waiting to send Recv() call")
+		return nil, f.ctx.Err()
+	case f.recv <- call:
+		// OK
+	}
+	select {
+	case <-f.ctx.Done():
+		f.t.Error("timed out waiting for Recv() call response")
+		return nil, f.ctx.Err()
+	case err := <-errs:
+		return nil, err
+	case resp := <-resps:
+		return resp, nil
+	}
+}
+
+func newFakeWorkspaceUpdateClient(ctx context.Context, t testing.TB) *fakeWorkspaceUpdateClient {
+	return &fakeWorkspaceUpdateClient{
+		ctx:   ctx,
+		t:     t,
+		recv:  make(chan *updateRecvCall),
+		close: make(chan chan<- error),
+	}
+}
+
+type updateRecvCall struct {
+	resp chan<- *proto.WorkspaceUpdate
+	err  chan<- error
+}
+
+// testUUID returns a UUID with bytes set as b, but shifted 6 bytes so that service prefixes don't
+// overwrite them.
+func testUUID(b ...byte) uuid.UUID {
+	o := uuid.UUID{}
+	for i := range b {
+		o[i+6] = b[i]
+	}
+	return o
+}
+
+func setupConnectedAllWorkspaceUpdatesController(
+	ctx context.Context, t testing.TB, logger slog.Logger,
+) (
+	*fakeCoordinatorClient, *fakeWorkspaceUpdateClient,
+) {
+	fConn := &fakeCoordinatee{}
+	tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
+	uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc)
+
+	// connect up a coordinator client, to track adding and removing tunnels
+	coordC := newFakeCoordinatorClient(ctx, t)
+	coordCW := tsc.New(coordC)
+	t.Cleanup(func() {
+		// hang up coord client
+		coordRecv := testutil.RequireRecvCtx(ctx, t, coordC.resps)
+		testutil.RequireSendCtx(ctx, t, coordRecv.err, io.EOF)
+		// sends close on client
+		cCall := testutil.RequireRecvCtx(ctx, t, coordC.close)
+		testutil.RequireSendCtx(ctx, t, cCall, nil)
+		err := testutil.RequireRecvCtx(ctx, t, coordCW.Wait())
+		require.ErrorIs(t, err, io.EOF)
+	})
+
+	// connect up the updates client
+	updateC := newFakeWorkspaceUpdateClient(ctx, t)
+	updateCW := uut.New(updateC)
+	t.Cleanup(func() {
+		// hang up WorkspaceUpdates client
+		upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
+		testutil.RequireSendCtx(ctx, t, upRecvCall.err, io.EOF)
+		err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
+		require.ErrorIs(t, err, io.EOF)
+	})
+	return coordC, updateC
+}
+
+func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+
+	coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger)
+
+	// Initial update contains 2 workspaces with 1 & 2 agents, respectively
+	w1ID := testUUID(1)
+	w2ID := testUUID(2)
+	w1a1ID := testUUID(1, 1)
+	w2a1ID := testUUID(2, 1)
+	w2a2ID := testUUID(2, 2)
+	initUp := &proto.WorkspaceUpdate{
+		UpsertedWorkspaces: []*proto.Workspace{
+			{Id: w1ID[:], Name: "w1"},
+			{Id: w2ID[:], Name: "w2"},
+		},
+		UpsertedAgents: []*proto.Agent{
+			{Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]},
+			{Id: w2a1ID[:], Name: "w2a1", WorkspaceId: w2ID[:]},
+			{Id: w2a2ID[:], Name: "w2a2", WorkspaceId: w2ID[:]},
+		},
+	}
+
+	upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
+	testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp)
+
+	// This should trigger AddTunnel for each agent
+	var adds []uuid.UUID
+	for range 3 {
+		coordCall := testutil.RequireRecvCtx(ctx, t, coordC.reqs)
+		adds = append(adds, uuid.Must(uuid.FromBytes(coordCall.req.GetAddTunnel().GetId())))
+		testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
+	}
+	require.Contains(t, adds, w1a1ID)
+	require.Contains(t, adds, w2a1ID)
+	require.Contains(t, adds, w2a2ID)
+}
+
+func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+
+	coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger)
+
+	w1ID := testUUID(1)
+	w1a1ID := testUUID(1, 1)
+	w1a2ID := testUUID(1, 2)
+	initUp := &proto.WorkspaceUpdate{
+		UpsertedWorkspaces: []*proto.Workspace{
+			{Id: w1ID[:], Name: "w1"},
+		},
+		UpsertedAgents: []*proto.Agent{
+			{Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]},
+		},
+	}
+
+	upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
+	testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp)
+
+	// Add for w1a1
+	coordCall := testutil.RequireRecvCtx(ctx, t, coordC.reqs)
+	require.Equal(t, w1a1ID[:], coordCall.req.GetAddTunnel().GetId())
+	testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
+
+	// Send update that removes w1a1 and adds w1a2
+	agentUpdate := &proto.WorkspaceUpdate{
+		UpsertedAgents: []*proto.Agent{
+			{Id: w1a2ID[:], Name: "w1a2", WorkspaceId: w1ID[:]},
+		},
+		DeletedAgents: []*proto.Agent{
+			{Id: w1a1ID[:], WorkspaceId: w1ID[:]},
+		},
+	}
+	upRecvCall = testutil.RequireRecvCtx(ctx, t, updateC.recv)
+	testutil.RequireSendCtx(ctx, t, upRecvCall.resp, agentUpdate)
+
+	// Add for w1a2
+	coordCall = testutil.RequireRecvCtx(ctx, t, coordC.reqs)
+	require.Equal(t, w1a2ID[:], coordCall.req.GetAddTunnel().GetId())
+	testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
+
+	// Remove for w1a1
+	coordCall = testutil.RequireRecvCtx(ctx, t, coordC.reqs)
+	require.Equal(t, w1a1ID[:], coordCall.req.GetRemoveTunnel().GetId())
+	testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
+}
+
+func TestTunnelAllWorkspaceUpdatesController_HandleErrors(t *testing.T) {
+	t.Parallel()
+	validWorkspaceID := testUUID(1)
+	validAgentID := testUUID(1, 1)
+
+	testCases := []struct {
+		name          string
+		update        *proto.WorkspaceUpdate
+		errorContains string
+	}{
+		{
+			name: "unparsableUpsertWorkspaceID",
+			update: &proto.WorkspaceUpdate{
+				UpsertedWorkspaces: []*proto.Workspace{
+					{Id: []byte{2, 2}, Name: "bander"},
+				},
+			},
+			errorContains: "failed to parse workspace ID",
+		},
+		{
+			name: "unparsableDeleteWorkspaceID",
+			update: &proto.WorkspaceUpdate{
+				DeletedWorkspaces: []*proto.Workspace{
+					{Id: []byte{2, 2}, Name: "bander"},
+				},
+			},
+			errorContains: "failed to parse workspace ID",
+		},
+		{
+			name: "unparsableDeleteAgentWorkspaceID",
+			update: &proto.WorkspaceUpdate{
+				DeletedAgents: []*proto.Agent{
+					{Id: validAgentID[:], Name: "devo", WorkspaceId: []byte{2, 2}},
+				},
+			},
+			errorContains: "failed to parse workspace ID",
+		},
+		{
+			name: "unparsableUpsertAgentWorkspaceID",
+			update: &proto.WorkspaceUpdate{
+				UpsertedAgents: []*proto.Agent{
+					{Id: validAgentID[:], Name: "devo", WorkspaceId: []byte{2, 2}},
+				},
+			},
+			errorContains: "failed to parse workspace ID",
+		},
+		{
+			name: "unparsableDeleteAgentID",
+			update: &proto.WorkspaceUpdate{
+				DeletedAgents: []*proto.Agent{
+					{Id: []byte{2, 2}, Name: "devo", WorkspaceId: validWorkspaceID[:]},
+				},
+			},
+			errorContains: "failed to parse agent ID",
+		},
+		{
+			name: "unparsableUpsertAgentID",
+			update: &proto.WorkspaceUpdate{
+				UpsertedAgents: []*proto.Agent{
+					{Id: []byte{2, 2}, Name: "devo", WorkspaceId: validWorkspaceID[:]},
+				},
+			},
+			errorContains: "failed to parse agent ID",
+		},
+		{
+			name: "upsertAgentMissingWorkspace",
+			update: &proto.WorkspaceUpdate{
+				UpsertedAgents: []*proto.Agent{
+					{Id: validAgentID[:], Name: "devo", WorkspaceId: validWorkspaceID[:]},
+				},
+			},
+			errorContains: fmt.Sprintf("workspace %s not found", validWorkspaceID.String()),
+		},
+		{
+			name: "deleteAgentMissingWorkspace",
+			update: &proto.WorkspaceUpdate{
+				DeletedAgents: []*proto.Agent{
+					{Id: validAgentID[:], Name: "devo", WorkspaceId: validWorkspaceID[:]},
+				},
+			},
+			errorContains: fmt.Sprintf("workspace %s not found", validWorkspaceID.String()),
+		},
+	}
+	// nolint: paralleltest // no longer need to reinitialize loop vars in go 1.22
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			ctx := testutil.Context(t, testutil.WaitShort)
+			logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+
+			fConn := &fakeCoordinatee{}
+			tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
+			uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc)
+			updateC := newFakeWorkspaceUpdateClient(ctx, t)
+			updateCW := uut.New(updateC)
+
+			recvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
+			testutil.RequireSendCtx(ctx, t, recvCall.resp, tc.update)
+			closeCall := testutil.RequireRecvCtx(ctx, t, updateC.close)
+			testutil.RequireSendCtx(ctx, t, closeCall, nil)
+
+			err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
+			require.ErrorContains(t, err, tc.errorContains)
+		})
+	}
+}
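
Usage sketch (not part of the patch): one way the new WorkspaceUpdatesCtrl might be wired up alongside the coordination controller, assuming a tailnet.Coordinatee (for example a *tailnet.Conn) and a ControlProtocolDialer are already in hand. The package name and helper function below are illustrative only; the field and constructor names come from the diff above.

package example // illustrative package, not part of this change

import (
	"cdr.dev/slog"

	"github.com/coder/coder/v2/tailnet"
)

// newControllerWithWorkspaceUpdates is a hypothetical helper sketching the wiring.
// The workspace-updates controller and the coordination controller share the same
// TunnelSrcCoordController, because tunnelUpdater.handleUpdate pushes the full set
// of agent IDs into it via SyncDestinations after every update it processes.
func newControllerWithWorkspaceUpdates(
	logger slog.Logger,
	dialer tailnet.ControlProtocolDialer,
	coordinatee tailnet.Coordinatee,
) *tailnet.Controller {
	ctrl := tailnet.NewController(logger, dialer)
	coordCtrl := tailnet.NewTunnelSrcCoordController(logger, coordinatee)
	ctrl.CoordCtrl = coordCtrl
	ctrl.WorkspaceUpdatesCtrl = tailnet.NewTunnelAllWorkspaceUpdatesController(logger, coordCtrl)
	return ctrl
}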