From 540d7eb040dcd719b0e0f44c198fe0f46f84c1d3 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 13 Sep 2024 09:01:12 +0400 Subject: [PATCH] fix: fix tailnet remoteCoordination to wait for server --- agent/agent.go | 2 +- agent/agent_test.go | 8 ++- codersdk/workspacesdk/connector.go | 2 +- .../workspacesdk/connector_internal_test.go | 6 +- tailnet/coordinator.go | 57 +++++++++++-------- tailnet/coordinator_test.go | 42 +++++++------- tailnet/test/integration/integration.go | 4 +- 7 files changed, 72 insertions(+), 49 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 98e294320b856..2194e04dd1820 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1357,7 +1357,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai defer close(errCh) select { case <-ctx.Done(): - err := coordination.Close() + err := coordination.Close(a.hardCtx) if err != nil { a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err)) } diff --git a/agent/agent_test.go b/agent/agent_test.go index e4aac04e0eedd..91e7c1c34e0c0 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1896,7 +1896,9 @@ func TestAgent_UpdatedDERP(t *testing.T) { coordinator, conn) t.Cleanup(func() { t.Logf("closing coordination %s", name) - err := coordination.Close() + cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort) + defer ccancel() + err := coordination.Close(cctx) if err != nil { t.Logf("error closing in-memory coordination: %s", err.Error()) } @@ -2384,7 +2386,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati clientID, metadata.AgentID, coordinator, conn) t.Cleanup(func() { - err := coordination.Close() + cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort) + defer ccancel() + err := coordination.Close(cctx) if err != nil { t.Logf("error closing in-mem coordination: %s", err.Error()) } diff --git a/codersdk/workspacesdk/connector.go b/codersdk/workspacesdk/connector.go index c761c92ae3e51..780478e91a55f 100644 --- a/codersdk/workspacesdk/connector.go +++ b/codersdk/workspacesdk/connector.go @@ -277,7 +277,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) { select { case <-tac.ctx.Done(): tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect") - crdErr := coordination.Close() + crdErr := coordination.Close(tac.gracefulCtx) if crdErr != nil { tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err)) } diff --git a/codersdk/workspacesdk/connector_internal_test.go b/codersdk/workspacesdk/connector_internal_test.go index d56f45b4821b7..7a339a0079ba2 100644 --- a/codersdk/workspacesdk/connector_internal_test.go +++ b/codersdk/workspacesdk/connector_internal_test.go @@ -57,7 +57,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) { derpMapCh := make(chan *tailcfg.DERPMap) defer close(derpMapCh) svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ - Logger: logger, + Logger: logger.Named("svc"), CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Millisecond, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, @@ -82,7 +82,8 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) { fConn := newFakeTailnetConn() - uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{}) + uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL, + quartz.NewReal(), &websocket.DialOptions{}) uut.runConnector(fConn) call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls) @@ -108,6 +109,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) { reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs) require.NotNil(t, reqDisc) require.NotNil(t, reqDisc.Disconnect) + close(call.Resps) } func TestTailnetAPIConnector_UplevelVersion(t *testing.T) { diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index cc50c792f16ea..54ce868df9316 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -91,7 +91,7 @@ type Coordinatee interface { } type Coordination interface { - io.Closer + Close(context.Context) error Error() <-chan error } @@ -106,7 +106,10 @@ type remoteCoordination struct { respLoopDone chan struct{} } -func (c *remoteCoordination) Close() (retErr error) { +// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and +// waiting for the server to hang up the coordination. If the provided context expires, we stop +// waiting for the server and close the coordination stream from our end. +func (c *remoteCoordination) Close(ctx context.Context) (retErr error) { c.Lock() defer c.Unlock() if c.closed { @@ -114,6 +117,18 @@ func (c *remoteCoordination) Close() (retErr error) { } c.closed = true defer func() { + // We shouldn't just close the protocol right away, because the way dRPC streams work is + // that if you close them, that could take effect immediately, even before the Disconnect + // message is processed. Coordinators are supposed to hang up on us once they get a + // Disconnect message, so we should wait around for that until the context expires. + select { + case <-c.respLoopDone: + c.logger.Debug(ctx, "responses closed after disconnect") + return + case <-ctx.Done(): + c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close") + } + // forcefully close the stream protoErr := c.protocol.Close() <-c.respLoopDone if retErr == nil { @@ -240,7 +255,6 @@ type inMemoryCoordination struct { ctx context.Context errChan chan error closed bool - closedCh chan struct{} respLoopDone chan struct{} coordinatee Coordinatee logger slog.Logger @@ -280,7 +294,6 @@ func NewInMemoryCoordination( errChan: make(chan error, 1), coordinatee: coordinatee, logger: logger, - closedCh: make(chan struct{}), respLoopDone: make(chan struct{}), } @@ -328,24 +341,15 @@ func (c *inMemoryCoordination) respLoop() { c.coordinatee.SetAllPeersLost() close(c.respLoopDone) }() - for { - select { - case <-c.closedCh: - c.logger.Debug(context.Background(), "in-memory coordination closed") + for resp := range c.resps { + c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp)) + err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates()) + if err != nil { + c.sendErr(xerrors.Errorf("failed to update peers: %w", err)) return - case resp, ok := <-c.resps: - if !ok { - c.logger.Debug(context.Background(), "in-memory response channel closed") - return - } - c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp)) - err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates()) - if err != nil { - c.sendErr(xerrors.Errorf("failed to update peers: %w", err)) - return - } } } + c.logger.Debug(context.Background(), "in-memory response channel closed") } func (*inMemoryCoordination) AwaitAck() <-chan struct{} { @@ -355,7 +359,10 @@ func (*inMemoryCoordination) AwaitAck() <-chan struct{} { return ch } -func (c *inMemoryCoordination) Close() error { +// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and +// waiting for the server to hang up the coordination. If the provided context expires, we stop +// waiting for the server and close the coordination stream from our end. +func (c *inMemoryCoordination) Close(ctx context.Context) error { c.Lock() defer c.Unlock() c.logger.Debug(context.Background(), "closing in-memory coordination") @@ -364,13 +371,17 @@ func (c *inMemoryCoordination) Close() error { } defer close(c.reqs) c.closed = true - close(c.closedCh) - <-c.respLoopDone select { - case <-c.ctx.Done(): + case <-ctx.Done(): return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err()) case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}: c.logger.Debug(context.Background(), "sent graceful disconnect in-memory") + } + + select { + case <-ctx.Done(): + return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err()) + case <-c.respLoopDone: return nil } } diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 400084fafab8e..99b4724e3577f 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -2,6 +2,7 @@ package tailnet_test import ( "context" + "io" "net" "net/netip" "sync" @@ -284,7 +285,7 @@ func TestInMemoryCoordination(t *testing.T) { Times(1).Return(reqs, resps) uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn) - defer uut.Close() + defer uut.Close(ctx) coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) @@ -336,16 +337,13 @@ func TestRemoteCoordination(t *testing.T) { require.NoError(t, err) uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID) - defer uut.Close() + defer uut.Close(ctx) coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) - select { - case err := <-uut.Error(): - require.ErrorContains(t, err, "stream terminated by sending close") - default: - // OK! - } + // Recv loop should be terminated by the server hanging up after Disconnect + err = testutil.RequireRecvCtx(ctx, t, uut.Error()) + require.ErrorIs(t, err, io.EOF) } func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) { @@ -388,7 +386,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) { require.NoError(t, err) uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{}) - defer uut.Close() + defer uut.Close(ctx) nk, err := key.NewNode().Public().MarshalBinary() require.NoError(t, err) @@ -411,14 +409,15 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) { require.Len(t, rfh.ReadyForHandshake, 1) require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id) - require.NoError(t, uut.Close()) + go uut.Close(ctx) + dis := testutil.RequireRecvCtx(ctx, t, reqs) + require.NotNil(t, dis) + require.NotNil(t, dis.Disconnect) + close(resps) - select { - case err := <-uut.Error(): - require.ErrorContains(t, err, "stream terminated by sending close") - default: - // OK! - } + // Recv loop should be terminated by the server hanging up after Disconnect + err = testutil.RequireRecvCtx(ctx, t, uut.Error()) + require.ErrorIs(t, err, io.EOF) } // coordinationTest tests that a coordination behaves correctly @@ -464,13 +463,18 @@ func coordinationTest( require.Len(t, fConn.updates[0], 1) require.Equal(t, agentID[:], fConn.updates[0][0].Id) - err = uut.Close() - require.NoError(t, err) - uut.Error() + errCh := make(chan error, 1) + go func() { + errCh <- uut.Close(ctx) + }() // When we close, it should gracefully disconnect req = testutil.RequireRecvCtx(ctx, t, reqs) require.NotNil(t, req.Disconnect) + close(resps) + + err = testutil.RequireRecvCtx(ctx, t, errCh) + require.NoError(t, err) // It should set all peers lost on the coordinatee require.Equal(t, 1, fConn.setAllPeersLostCalls) diff --git a/tailnet/test/integration/integration.go b/tailnet/test/integration/integration.go index 41326caaa7e4e..0d3956cf44d3e 100644 --- a/tailnet/test/integration/integration.go +++ b/tailnet/test/integration/integration.go @@ -469,7 +469,9 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID) t.Cleanup(func() { - _ = coordination.Close() + cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + _ = coordination.Close(cctx) }) return conn