From 44fd5381def6c425cb3488dcb8e401f45f0fc4a3 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 6 Nov 2024 12:42:42 +0400 Subject: [PATCH] feat: add support for multiple tunnel destinations in tailnet --- agent/agent_test.go | 6 +- codersdk/workspacesdk/workspacesdk.go | 4 +- tailnet/controllers.go | 157 +++++++++- tailnet/controllers_test.go | 383 +++++++++++++++++++++++- tailnet/test/integration/integration.go | 3 +- 5 files changed, 530 insertions(+), 23 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index a1e6af43e042f..277702faebf02 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1918,7 +1918,8 @@ func TestAgent_UpdatedDERP(t *testing.T) { testCtx, testCtxCancel := context.WithCancel(context.Background()) t.Cleanup(testCtxCancel) clientID := uuid.New() - ctrl := tailnet.NewSingleDestController(logger, conn, agentID) + ctrl := tailnet.NewTunnelSrcCoordController(logger, conn) + ctrl.AddDestination(agentID) auth := tailnet.ClientCoordinateeAuth{AgentID: agentID} coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, coordinator)) t.Cleanup(func() { @@ -2408,7 +2409,8 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati testCtx, testCtxCancel := context.WithCancel(context.Background()) t.Cleanup(testCtxCancel) clientID := uuid.New() - ctrl := tailnet.NewSingleDestController(logger, conn, metadata.AgentID) + ctrl := tailnet.NewTunnelSrcCoordController(logger, conn) + ctrl.AddDestination(metadata.AgentID) auth := tailnet.ClientCoordinateeAuth{AgentID: metadata.AgentID} coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient( logger, clientID, auth, coordinator)) diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index 5ce0c06065173..2e0214fdc1010 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -268,7 +268,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * _ = conn.Close() } }() - controller.CoordCtrl = tailnet.NewSingleDestController(options.Logger, conn, agentID) + coordCtrl := tailnet.NewTunnelSrcCoordController(options.Logger, conn) + coordCtrl.AddDestination(agentID) + controller.CoordCtrl = coordCtrl controller.DERPCtrl = tailnet.NewBasicDERPController(options.Logger, conn) controller.Run(ctx) diff --git a/tailnet/controllers.go b/tailnet/controllers.go index bc2b814b91368..7a96ddf18e662 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "maps" "math" "strings" "sync" @@ -239,7 +240,8 @@ func (c *BasicCoordination) respLoop() { defer func() { cErr := c.Client.Close() if cErr != nil { - c.logger.Debug(context.Background(), "failed to close coordinate client after respLoop exit", slog.Error(cErr)) + c.logger.Debug(context.Background(), + "failed to close coordinate client after respLoop exit", slog.Error(cErr)) } c.coordinatee.SetAllPeersLost() close(c.respLoopDone) @@ -247,7 +249,8 @@ func (c *BasicCoordination) respLoop() { for { resp, err := c.Client.Recv() if err != nil { - c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err)) + c.logger.Debug(context.Background(), + "failed to read from protocol", slog.Error(err)) c.SendErr(xerrors.Errorf("read: %w", err)) return } @@ -278,7 +281,8 @@ func (c *BasicCoordination) respLoop() { ReadyForHandshake: rfh, }) if err != nil { - c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err)) + c.logger.Debug(context.Background(), + "failed to send ready for handshake", slog.Error(err)) c.SendErr(xerrors.Errorf("send: %w", err)) return } @@ -287,37 +291,158 @@ func (c *BasicCoordination) respLoop() { } } -type singleDestController struct { +type TunnelSrcCoordController struct { *BasicCoordinationController - dest uuid.UUID + + mu sync.Mutex + dests map[uuid.UUID]struct{} + coordination *BasicCoordination } -// NewSingleDestController creates a CoordinationController for Coder clients that connect to a -// single tunnel destination, e.g. `coder ssh`, which connects to a single workspace Agent. -func NewSingleDestController(logger slog.Logger, coordinatee Coordinatee, dest uuid.UUID) CoordinationController { - coordinatee.SetTunnelDestination(dest) - return &singleDestController{ +// NewTunnelSrcCoordController creates a CoordinationController for peers that are exclusively +// tunnel sources (that is, they create tunnel --- Coder clients not workspaces). +func NewTunnelSrcCoordController( + logger slog.Logger, coordinatee Coordinatee, +) *TunnelSrcCoordController { + return &TunnelSrcCoordController{ BasicCoordinationController: &BasicCoordinationController{ Logger: logger, Coordinatee: coordinatee, SendAcks: false, }, - dest: dest, + dests: make(map[uuid.UUID]struct{}), } } -func (c *singleDestController) New(client CoordinatorClient) CloserWaiter { +func (c *TunnelSrcCoordController) New(client CoordinatorClient) CloserWaiter { + c.mu.Lock() + defer c.mu.Unlock() b := c.BasicCoordinationController.NewCoordination(client) - err := client.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: c.dest[:]}}) - if err != nil { - b.SendErr(err) + c.coordination = b + // resync destinations on reconnect + for dest := range c.dests { + err := client.Send(&proto.CoordinateRequest{ + AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, + }) + if err != nil { + b.SendErr(err) + c.coordination = nil + cErr := client.Close() + if cErr != nil { + c.Logger.Debug( + context.Background(), + "failed to close coordinator client after add tunnel failure", + slog.Error(cErr), + ) + } + break + } } return b } +func (c *TunnelSrcCoordController) AddDestination(dest uuid.UUID) { + c.mu.Lock() + defer c.mu.Unlock() + c.Coordinatee.SetTunnelDestination(dest) // this prepares us for an ack + c.dests[dest] = struct{}{} + if c.coordination == nil { + return + } + err := c.coordination.Client.Send( + &proto.CoordinateRequest{ + AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, + }) + if err != nil { + c.coordination.SendErr(err) + cErr := c.coordination.Client.Close() // close the client so we don't gracefully disconnect + if cErr != nil { + c.Logger.Debug(context.Background(), + "failed to close coordinator client after add tunnel failure", + slog.Error(cErr)) + } + c.coordination = nil + } +} + +func (c *TunnelSrcCoordController) RemoveDestination(dest uuid.UUID) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.dests, dest) + if c.coordination == nil { + return + } + err := c.coordination.Client.Send( + &proto.CoordinateRequest{ + RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, + }) + if err != nil { + c.coordination.SendErr(err) + cErr := c.coordination.Client.Close() // close the client so we don't gracefully disconnect + if cErr != nil { + c.Logger.Debug(context.Background(), + "failed to close coordinator client after remove tunnel failure", + slog.Error(cErr)) + } + c.coordination = nil + } +} + +func (c *TunnelSrcCoordController) SyncDestinations(destinations []uuid.UUID) { + c.mu.Lock() + defer c.mu.Unlock() + toAdd := make(map[uuid.UUID]struct{}) + toRemove := maps.Clone(c.dests) + all := make(map[uuid.UUID]struct{}) + for _, dest := range destinations { + all[dest] = struct{}{} + delete(toRemove, dest) + if _, ok := c.dests[dest]; !ok { + toAdd[dest] = struct{}{} + } + } + c.dests = all + if c.coordination == nil { + return + } + var err error + defer func() { + if err != nil { + c.coordination.SendErr(err) + cErr := c.coordination.Client.Close() // don't gracefully disconnect + if cErr != nil { + c.Logger.Debug(context.Background(), + "failed to close coordinator client during sync destinations", + slog.Error(cErr)) + } + c.coordination = nil + } + }() + for dest := range toAdd { + err = c.coordination.Client.Send( + &proto.CoordinateRequest{ + AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, + }) + if err != nil { + return + } + } + for dest := range toRemove { + err = c.coordination.Client.Send( + &proto.CoordinateRequest{ + RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, + }) + if err != nil { + return + } + } +} + // NewAgentCoordinationController creates a CoordinationController for Coder Agents, which never // create tunnels and always send ReadyToHandshake acknowledgements. -func NewAgentCoordinationController(logger slog.Logger, coordinatee Coordinatee) CoordinationController { +func NewAgentCoordinationController( + logger slog.Logger, coordinatee Coordinatee, +) CoordinationController { return &BasicCoordinationController{ Logger: logger, Coordinatee: coordinatee, diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go index fea8bc0e02ab3..90aa4c7f9bc48 100644 --- a/tailnet/controllers_test.go +++ b/tailnet/controllers_test.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "slices" "sync" "sync/atomic" "testing" @@ -46,7 +47,8 @@ func TestInMemoryCoordination(t *testing.T) { mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), auth). Times(1).Return(reqs, resps) - ctrl := tailnet.NewSingleDestController(logger, fConn, agentID) + ctrl := tailnet.NewTunnelSrcCoordController(logger, fConn) + ctrl.AddDestination(agentID) uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, mCoord)) defer uut.Close(ctx) @@ -57,7 +59,7 @@ func TestInMemoryCoordination(t *testing.T) { require.ErrorIs(t, err, io.EOF) } -func TestSingleDestController(t *testing.T) { +func TestTunnelSrcCoordController_Mainline(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) @@ -102,7 +104,8 @@ func TestSingleDestController(t *testing.T) { protocol, err := client.Coordinate(ctx) require.NoError(t, err) - ctrl := tailnet.NewSingleDestController(logger.Named("coordination"), fConn, agentID) + ctrl := tailnet.NewTunnelSrcCoordController(logger.Named("coordination"), fConn) + ctrl.AddDestination(agentID) uut := ctrl.New(protocol) defer uut.Close(ctx) @@ -113,6 +116,284 @@ func TestSingleDestController(t *testing.T) { require.ErrorIs(t, err, io.EOF) } +func TestTunnelSrcCoordController_AddDestination(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + fConn := &fakeCoordinatee{} + uut := tailnet.NewTunnelSrcCoordController(logger, fConn) + + // GIVEN: client already connected + client1 := newFakeCoordinatorClient(ctx, t) + cw1 := uut.New(client1) + + // WHEN: we add 2 destinations + dest1 := uuid.UUID{1} + dest2 := uuid.UUID{2} + addDone := make(chan struct{}) + go func() { + defer close(addDone) + uut.AddDestination(dest1) + uut.AddDestination(dest2) + }() + + // THEN: Controller sends AddTunnel for the destinations + for i := range 2 { + b0 := byte(i + 1) + call := testutil.RequireRecvCtx(ctx, t, client1.reqs) + require.Equal(t, b0, call.req.GetAddTunnel().GetId()[0]) + testutil.RequireSendCtx(ctx, t, call.err, nil) + } + _ = testutil.RequireRecvCtx(ctx, t, addDone) + + // THEN: Controller sets destinations on Coordinatee + require.Contains(t, fConn.tunnelDestinations, dest1) + require.Contains(t, fConn.tunnelDestinations, dest2) + + // WHEN: Closed from server side and reconnects + respCall := testutil.RequireRecvCtx(ctx, t, client1.resps) + testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF) + closeCall := testutil.RequireRecvCtx(ctx, t, client1.close) + testutil.RequireSendCtx(ctx, t, closeCall, nil) + err := testutil.RequireRecvCtx(ctx, t, cw1.Wait()) + require.ErrorIs(t, err, io.EOF) + client2 := newFakeCoordinatorClient(ctx, t) + cws := make(chan tailnet.CloserWaiter) + go func() { + cws <- uut.New(client2) + }() + + // THEN: should immediately send both destinations + var dests []byte + for range 2 { + call := testutil.RequireRecvCtx(ctx, t, client2.reqs) + dests = append(dests, call.req.GetAddTunnel().GetId()[0]) + testutil.RequireSendCtx(ctx, t, call.err, nil) + } + slices.Sort(dests) + require.Equal(t, dests, []byte{1, 2}) + + cw2 := testutil.RequireRecvCtx(ctx, t, cws) + + // close client2 + respCall = testutil.RequireRecvCtx(ctx, t, client2.resps) + testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF) + closeCall = testutil.RequireRecvCtx(ctx, t, client2.close) + testutil.RequireSendCtx(ctx, t, closeCall, nil) + err = testutil.RequireRecvCtx(ctx, t, cw2.Wait()) + require.ErrorIs(t, err, io.EOF) +} + +func TestTunnelSrcCoordController_RemoveDestination(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + fConn := &fakeCoordinatee{} + uut := tailnet.NewTunnelSrcCoordController(logger, fConn) + + // GIVEN: 1 destination + dest1 := uuid.UUID{1} + uut.AddDestination(dest1) + + // GIVEN: client already connected + client1 := newFakeCoordinatorClient(ctx, t) + cws := make(chan tailnet.CloserWaiter) + go func() { + cws <- uut.New(client1) + }() + call := testutil.RequireRecvCtx(ctx, t, client1.reqs) + testutil.RequireSendCtx(ctx, t, call.err, nil) + cw1 := testutil.RequireRecvCtx(ctx, t, cws) + + // WHEN: we remove one destination + removeDone := make(chan struct{}) + go func() { + defer close(removeDone) + uut.RemoveDestination(dest1) + }() + + // THEN: Controller sends RemoveTunnel for the destination + call = testutil.RequireRecvCtx(ctx, t, client1.reqs) + require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId()) + testutil.RequireSendCtx(ctx, t, call.err, nil) + _ = testutil.RequireRecvCtx(ctx, t, removeDone) + + // WHEN: Closed from server side and reconnect + respCall := testutil.RequireRecvCtx(ctx, t, client1.resps) + testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF) + closeCall := testutil.RequireRecvCtx(ctx, t, client1.close) + testutil.RequireSendCtx(ctx, t, closeCall, nil) + err := testutil.RequireRecvCtx(ctx, t, cw1.Wait()) + require.ErrorIs(t, err, io.EOF) + + client2 := newFakeCoordinatorClient(ctx, t) + go func() { + cws <- uut.New(client2) + }() + + // THEN: should immediately resolve without sending anything + cw2 := testutil.RequireRecvCtx(ctx, t, cws) + + // close client2 + respCall = testutil.RequireRecvCtx(ctx, t, client2.resps) + testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF) + closeCall = testutil.RequireRecvCtx(ctx, t, client2.close) + testutil.RequireSendCtx(ctx, t, closeCall, nil) + err = testutil.RequireRecvCtx(ctx, t, cw2.Wait()) + require.ErrorIs(t, err, io.EOF) +} + +func TestTunnelSrcCoordController_RemoveDestination_Error(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + fConn := &fakeCoordinatee{} + uut := tailnet.NewTunnelSrcCoordController(logger, fConn) + + // GIVEN: 3 destination + dest1 := uuid.UUID{1} + dest2 := uuid.UUID{2} + dest3 := uuid.UUID{3} + uut.AddDestination(dest1) + uut.AddDestination(dest2) + uut.AddDestination(dest3) + + // GIVEN: client already connected + client1 := newFakeCoordinatorClient(ctx, t) + cws := make(chan tailnet.CloserWaiter) + go func() { + cws <- uut.New(client1) + }() + for range 3 { + call := testutil.RequireRecvCtx(ctx, t, client1.reqs) + testutil.RequireSendCtx(ctx, t, call.err, nil) + } + cw1 := testutil.RequireRecvCtx(ctx, t, cws) + + // WHEN: we remove all destinations + removeDone := make(chan struct{}) + go func() { + defer close(removeDone) + uut.RemoveDestination(dest1) + uut.RemoveDestination(dest2) + uut.RemoveDestination(dest3) + }() + + // WHEN: first RemoveTunnel call fails + theErr := xerrors.New("a bad thing happened") + call := testutil.RequireRecvCtx(ctx, t, client1.reqs) + require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId()) + testutil.RequireSendCtx(ctx, t, call.err, theErr) + + // THEN: we disconnect and do not send remaining RemoveTunnel messages + closeCall := testutil.RequireRecvCtx(ctx, t, client1.close) + testutil.RequireSendCtx(ctx, t, closeCall, nil) + _ = testutil.RequireRecvCtx(ctx, t, removeDone) + + // shut down + respCall := testutil.RequireRecvCtx(ctx, t, client1.resps) + testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF) + // triggers second close call + closeCall = testutil.RequireRecvCtx(ctx, t, client1.close) + testutil.RequireSendCtx(ctx, t, closeCall, nil) + err := testutil.RequireRecvCtx(ctx, t, cw1.Wait()) + require.ErrorIs(t, err, theErr) +} + +func TestTunnelSrcCoordController_Sync(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + fConn := &fakeCoordinatee{} + uut := tailnet.NewTunnelSrcCoordController(logger, fConn) + dest1 := uuid.UUID{1} + dest2 := uuid.UUID{2} + dest3 := uuid.UUID{3} + + // GIVEN: dest1 & dest2 already added + uut.AddDestination(dest1) + uut.AddDestination(dest2) + + // GIVEN: client already connected + client1 := newFakeCoordinatorClient(ctx, t) + cws := make(chan tailnet.CloserWaiter) + go func() { + cws <- uut.New(client1) + }() + for range 2 { + call := testutil.RequireRecvCtx(ctx, t, client1.reqs) + testutil.RequireSendCtx(ctx, t, call.err, nil) + } + cw1 := testutil.RequireRecvCtx(ctx, t, cws) + + // WHEN: we sync dest2 & dest3 + syncDone := make(chan struct{}) + go func() { + defer close(syncDone) + uut.SyncDestinations([]uuid.UUID{dest2, dest3}) + }() + + // THEN: we get an add for dest3 and remove for dest1 + call := testutil.RequireRecvCtx(ctx, t, client1.reqs) + require.Equal(t, dest3[:], call.req.GetAddTunnel().GetId()) + testutil.RequireSendCtx(ctx, t, call.err, nil) + call = testutil.RequireRecvCtx(ctx, t, client1.reqs) + require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId()) + testutil.RequireSendCtx(ctx, t, call.err, nil) + + // shut down + respCall := testutil.RequireRecvCtx(ctx, t, client1.resps) + testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF) + closeCall := testutil.RequireRecvCtx(ctx, t, client1.close) + testutil.RequireSendCtx(ctx, t, closeCall, nil) + err := testutil.RequireRecvCtx(ctx, t, cw1.Wait()) + require.ErrorIs(t, err, io.EOF) +} + +func TestTunnelSrcCoordController_AddDestination_Error(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + fConn := &fakeCoordinatee{} + uut := tailnet.NewTunnelSrcCoordController(logger, fConn) + + // GIVEN: client already connected + client1 := newFakeCoordinatorClient(ctx, t) + cw1 := uut.New(client1) + + // WHEN: we add a destination, and the AddTunnel fails + dest1 := uuid.UUID{1} + addDone := make(chan struct{}) + go func() { + defer close(addDone) + uut.AddDestination(dest1) + }() + theErr := xerrors.New("a bad thing happened") + call := testutil.RequireRecvCtx(ctx, t, client1.reqs) + testutil.RequireSendCtx(ctx, t, call.err, theErr) + + // THEN: Client is closed and exits + closeCall := testutil.RequireRecvCtx(ctx, t, client1.close) + testutil.RequireSendCtx(ctx, t, closeCall, nil) + + // close the resps, since the client has closed + resp := testutil.RequireRecvCtx(ctx, t, client1.resps) + testutil.RequireSendCtx(ctx, t, resp.err, net.ErrClosed) + // this triggers a second Close() call on the client + closeCall = testutil.RequireRecvCtx(ctx, t, client1.close) + testutil.RequireSendCtx(ctx, t, closeCall, nil) + + err := testutil.RequireRecvCtx(ctx, t, cw1.Wait()) + require.ErrorIs(t, err, theErr) + + _ = testutil.RequireRecvCtx(ctx, t, addDone) +} + func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) @@ -885,3 +1166,99 @@ func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (t Telemetry: client, }, nil } + +type fakeCoordinatorClient struct { + ctx context.Context + t testing.TB + reqs chan *coordReqCall + resps chan *coordRespCall + close chan chan<- error +} + +func (f fakeCoordinatorClient) Close() error { + f.t.Helper() + errs := make(chan error) + select { + case <-f.ctx.Done(): + f.t.Error("timed out waiting to send close call") + return f.ctx.Err() + case f.close <- errs: + // OK + } + select { + case <-f.ctx.Done(): + f.t.Error("timed out waiting for close call response") + return f.ctx.Err() + case err := <-errs: + return err + } +} + +func (f fakeCoordinatorClient) Send(request *proto.CoordinateRequest) error { + f.t.Helper() + errs := make(chan error) + call := &coordReqCall{ + req: request, + err: errs, + } + select { + case <-f.ctx.Done(): + f.t.Error("timed out waiting to send call") + return f.ctx.Err() + case f.reqs <- call: + // OK + } + select { + case <-f.ctx.Done(): + f.t.Error("timed out waiting for send call response") + return f.ctx.Err() + case err := <-errs: + return err + } +} + +func (f fakeCoordinatorClient) Recv() (*proto.CoordinateResponse, error) { + f.t.Helper() + resps := make(chan *proto.CoordinateResponse) + errs := make(chan error) + call := &coordRespCall{ + resp: resps, + err: errs, + } + select { + case <-f.ctx.Done(): + f.t.Error("timed out waiting to send Recv() call") + return nil, f.ctx.Err() + case f.resps <- call: + // OK + } + select { + case <-f.ctx.Done(): + f.t.Error("timed out waiting for Recv() call response") + return nil, f.ctx.Err() + case err := <-errs: + return nil, err + case resp := <-resps: + return resp, nil + } +} + +func newFakeCoordinatorClient(ctx context.Context, t testing.TB) *fakeCoordinatorClient { + return &fakeCoordinatorClient{ + ctx: ctx, + t: t, + reqs: make(chan *coordReqCall), + resps: make(chan *coordRespCall), + close: make(chan chan<- error), + } +} + +type coordReqCall struct { + req *proto.CoordinateRequest + err chan<- error +} + +type coordRespCall struct { + resp chan<- *proto.CoordinateResponse + err chan<- error +} diff --git a/tailnet/test/integration/integration.go b/tailnet/test/integration/integration.go index 232e7ab027d72..62825973e75a0 100644 --- a/tailnet/test/integration/integration.go +++ b/tailnet/test/integration/integration.go @@ -467,7 +467,8 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me _ = conn.Close() }) - ctrl := tailnet.NewSingleDestController(logger, conn, peer.ID) + ctrl := tailnet.NewTunnelSrcCoordController(logger, conn) + ctrl.AddDestination(peer.ID) coordination := ctrl.New(coord) t.Cleanup(func() { cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)