Skip to content

fix: fix tailnet remoteCoordination to wait for server #14666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
fix: fix tailnet remoteCoordination to wait for server
  • Loading branch information
spikecurtis committed Sep 13, 2024
commit 540d7eb040dcd719b0e0f44c198fe0f46f84c1d3
2 changes: 1 addition & 1 deletion agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
defer close(errCh)
select {
case <-ctx.Done():
err := coordination.Close()
err := coordination.Close(a.hardCtx)
if err != nil {
a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
}
Expand Down
8 changes: 6 additions & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1896,7 +1896,9 @@ func TestAgent_UpdatedDERP(t *testing.T) {
coordinator, conn)
t.Cleanup(func() {
t.Logf("closing coordination %s", name)
err := coordination.Close()
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
defer ccancel()
err := coordination.Close(cctx)
if err != nil {
t.Logf("error closing in-memory coordination: %s", err.Error())
}
Expand Down Expand Up @@ -2384,7 +2386,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
clientID, metadata.AgentID,
coordinator, conn)
t.Cleanup(func() {
err := coordination.Close()
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
defer ccancel()
err := coordination.Close(cctx)
if err != nil {
t.Logf("error closing in-mem coordination: %s", err.Error())
}
Expand Down
2 changes: 1 addition & 1 deletion codersdk/workspacesdk/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
select {
case <-tac.ctx.Done():
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
crdErr := coordination.Close()
crdErr := coordination.Close(tac.gracefulCtx)
if crdErr != nil {
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
}
Expand Down
6 changes: 4 additions & 2 deletions codersdk/workspacesdk/connector_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
Logger: logger.Named("svc"),
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
Expand All @@ -82,7 +82,8 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {

fConn := newFakeTailnetConn()

uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL,
quartz.NewReal(), &websocket.DialOptions{})
uut.runConnector(fConn)

call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
Expand All @@ -108,6 +109,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
require.NotNil(t, reqDisc)
require.NotNil(t, reqDisc.Disconnect)
close(call.Resps)
}

func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
Expand Down
57 changes: 34 additions & 23 deletions tailnet/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ type Coordinatee interface {
}

type Coordination interface {
io.Closer
Close(context.Context) error
Error() <-chan error
}

Expand All @@ -106,14 +106,29 @@ type remoteCoordination struct {
respLoopDone chan struct{}
}

func (c *remoteCoordination) Close() (retErr error) {
// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
// waiting for the server to hang up the coordination. If the provided context expires, we stop
// waiting for the server and close the coordination stream from our end.
func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
defer func() {
// We shouldn't just close the protocol right away, because the way dRPC streams work is
// that if you close them, that could take effect immediately, even before the Disconnect
// message is processed. Coordinators are supposed to hang up on us once they get a
// Disconnect message, so we should wait around for that until the context expires.
select {
case <-c.respLoopDone:
c.logger.Debug(ctx, "responses closed after disconnect")
return
case <-ctx.Done():
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
}
// forcefully close the stream
protoErr := c.protocol.Close()
<-c.respLoopDone
if retErr == nil {
Expand Down Expand Up @@ -240,7 +255,6 @@ type inMemoryCoordination struct {
ctx context.Context
errChan chan error
closed bool
closedCh chan struct{}
respLoopDone chan struct{}
coordinatee Coordinatee
logger slog.Logger
Expand Down Expand Up @@ -280,7 +294,6 @@ func NewInMemoryCoordination(
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
closedCh: make(chan struct{}),
respLoopDone: make(chan struct{}),
}

Expand Down Expand Up @@ -328,24 +341,15 @@ func (c *inMemoryCoordination) respLoop() {
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for {
select {
case <-c.closedCh:
c.logger.Debug(context.Background(), "in-memory coordination closed")
for resp := range c.resps {
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
return
case resp, ok := <-c.resps:
if !ok {
c.logger.Debug(context.Background(), "in-memory response channel closed")
return
}
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
return
}
}
}
c.logger.Debug(context.Background(), "in-memory response channel closed")
}

func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
Expand All @@ -355,7 +359,10 @@ func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
return ch
}

func (c *inMemoryCoordination) Close() error {
// Close attempts to gracefully close the inMemoryCoordination by sending a Disconnect message on
// the request channel and waiting for the response loop to finish. If the provided context expires
// before the Disconnect is sent or before the response loop finishes, Close stops waiting and
// returns an error.
func (c *inMemoryCoordination) Close(ctx context.Context) error {
c.Lock()
defer c.Unlock()
c.logger.Debug(context.Background(), "closing in-memory coordination")
Expand All @@ -364,13 +371,17 @@ func (c *inMemoryCoordination) Close() error {
}
defer close(c.reqs)
c.closed = true
close(c.closedCh)
<-c.respLoopDone
select {
case <-c.ctx.Done():
case <-ctx.Done():
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
}

select {
case <-ctx.Done():
return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err())
case <-c.respLoopDone:
return nil
}
}
Expand Down
42 changes: 23 additions & 19 deletions tailnet/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tailnet_test

import (
"context"
"io"
"net"
"net/netip"
"sync"
Expand Down Expand Up @@ -284,7 +285,7 @@ func TestInMemoryCoordination(t *testing.T) {
Times(1).Return(reqs, resps)

uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
defer uut.Close()
defer uut.Close(ctx)

coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

Expand Down Expand Up @@ -336,16 +337,13 @@ func TestRemoteCoordination(t *testing.T) {
require.NoError(t, err)

uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
defer uut.Close()
defer uut.Close(ctx)

coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

select {
case err := <-uut.Error():
require.ErrorContains(t, err, "stream terminated by sending close")
default:
// OK!
}
// Recv loop should be terminated by the server hanging up after Disconnect
err = testutil.RequireRecvCtx(ctx, t, uut.Error())
require.ErrorIs(t, err, io.EOF)
}

func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
Expand Down Expand Up @@ -388,7 +386,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
require.NoError(t, err)

uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
defer uut.Close()
defer uut.Close(ctx)

nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
Expand All @@ -411,14 +409,15 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
require.Len(t, rfh.ReadyForHandshake, 1)
require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)

require.NoError(t, uut.Close())
go uut.Close(ctx)
dis := testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, dis)
require.NotNil(t, dis.Disconnect)
close(resps)

select {
case err := <-uut.Error():
require.ErrorContains(t, err, "stream terminated by sending close")
default:
// OK!
}
// Recv loop should be terminated by the server hanging up after Disconnect
err = testutil.RequireRecvCtx(ctx, t, uut.Error())
require.ErrorIs(t, err, io.EOF)
}

// coordinationTest tests that a coordination behaves correctly
Expand Down Expand Up @@ -464,13 +463,18 @@ func coordinationTest(
require.Len(t, fConn.updates[0], 1)
require.Equal(t, agentID[:], fConn.updates[0][0].Id)

err = uut.Close()
require.NoError(t, err)
uut.Error()
errCh := make(chan error, 1)
go func() {
errCh <- uut.Close(ctx)
}()

// When we close, it should gracefully disconnect
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)
close(resps)

err = testutil.RequireRecvCtx(ctx, t, errCh)
require.NoError(t, err)

// It should set all peers lost on the coordinatee
require.Equal(t, 1, fConn.setAllPeersLostCalls)
Expand Down
4 changes: 3 additions & 1 deletion tailnet/test/integration/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me

coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
t.Cleanup(func() {
_ = coordination.Close()
cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
_ = coordination.Close(cctx)
})

return conn
Expand Down
Loading