Skip to content

fix: fix graceful disconnect in DialWorkspaceAgent #11993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 82 additions & 38 deletions codersdk/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@ import (
"net/netip"
"strconv"
"strings"
"sync"
"time"

"golang.org/x/sync/errgroup"

"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
Expand Down Expand Up @@ -360,6 +359,15 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
return agentConn, nil
}

// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
// declared as an interface, rather than using *tailnet.Conn directly, so that tests can
// substitute a fake implementation.
//
// @typescript-ignore tailnetConn
type tailnetConn interface {
// Coordinatee is embedded because the connector passes its conn to
// tailnet.NewRemoteCoordination as the coordinatee.
tailnet.Coordinatee
// SetDERPMap is called with each DERP map received from the StreamDERPMaps RPC.
SetDERPMap(derpMap *tailcfg.DERPMap)
}

// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
//
// 1) run the Coordinate API and pass node information back and forth
Expand All @@ -370,13 +378,20 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
//
// @typescript-ignore tailnetAPIConnector
type tailnetAPIConnector struct {
ctx context.Context
// We keep track of two contexts: the main context from the caller, and a "graceful" context
// that we keep open slightly longer than the main context to give a chance to send the
// Disconnect message to the coordinator. That tells the coordinator that we really meant to
// disconnect instead of just losing network connectivity.
ctx context.Context
gracefulCtx context.Context
cancelGracefulCtx context.CancelFunc

logger slog.Logger

agentID uuid.UUID
coordinateURL string
dialOptions *websocket.DialOptions
conn *tailnet.Conn
conn tailnetConn

connected chan error
isFirst bool
Expand All @@ -387,7 +402,7 @@ type tailnetAPIConnector struct {
func runTailnetAPIConnector(
ctx context.Context, logger slog.Logger,
agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions,
conn *tailnet.Conn,
conn tailnetConn,
) *tailnetAPIConnector {
tac := &tailnetAPIConnector{
ctx: ctx,
Expand All @@ -399,10 +414,23 @@ func runTailnetAPIConnector(
connected: make(chan error, 1),
closed: make(chan struct{}),
}
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
go tac.manageGracefulTimeout()
go tac.run()
return tac
}

// manageGracefulTimeout keeps gracefulCtx alive for up to 1 second after the main context is
// canceled, to allow a graceful disconnect. It cancels gracefulCtx sooner if the connector closes.
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
defer tac.cancelGracefulCtx()
<-tac.ctx.Done()
select {
case <-tac.closed:
case <-time.After(time.Second):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Non-blocking) I figure this is a best-effort situation, but will 1 second be enough? Does this need to be a configurable knob?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be plenty, even on a slow connection, because we're not waiting for a reply. I definitely don't want to plumb configuration through.

It is best-effort, as you say. The consequence of not doing this is that the agent on the other side will see it as "lost" and may still try to handshake with it for up to 15 minutes.

}
}

func (tac *tailnetAPIConnector) run() {
tac.isFirst = true
defer close(tac.closed)
Expand Down Expand Up @@ -437,7 +465,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
return nil, err
}
client, err := tailnet.NewDRPCClient(
websocket.NetConn(tac.ctx, ws, websocket.MessageBinary),
websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary),
tac.logger,
)
if err != nil {
Expand All @@ -464,65 +492,81 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
<-conn.Closed()
}
}()
eg, egCtx := errgroup.WithContext(tac.ctx)
eg.Go(func() error {
return tac.coordinate(egCtx, client)
})
eg.Go(func() error {
return tac.derpMap(egCtx, client)
})
err := eg.Wait()
if err != nil &&
!xerrors.Is(err, io.EOF) &&
!xerrors.Is(err, context.Canceled) &&
!xerrors.Is(err, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API")
}
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
tac.coordinate(client)
}()
go func() {
defer wg.Done()
dErr := tac.derpMap(client)
if dErr != nil && tac.ctx.Err() == nil {
// The main context is still active, meaning that we want the tailnet data plane to stay
// up, even though we hit some error getting DERP maps on the control plane. That means
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
// close the underlying connection. This will trigger a retry of the control plane in
// run().
client.DRPCConn().Close()
// Note that derpMap() logs its own errors, so we don't bother here.
}
}()
wg.Wait()
}

func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error {
coord, err := client.Coordinate(ctx)
func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
// we use the gracefulCtx here so that we'll have time to send the graceful disconnect
coord, err := client.Coordinate(tac.gracefulCtx)
if err != nil {
return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err)
tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err))
return
}
defer func() {
cErr := coord.Close()
if cErr != nil {
tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr))
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
}
}()
coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
tac.logger.Debug(ctx, "serving coordinator")
err = <-coordination.Error()
if err != nil &&
!xerrors.Is(err, io.EOF) &&
!xerrors.Is(err, context.Canceled) &&
!xerrors.Is(err, context.DeadlineExceeded) {
return xerrors.Errorf("remote coordination error: %w", err)
tac.logger.Debug(tac.ctx, "serving coordinator")
select {
case <-tac.ctx.Done():
	// The caller is done with this connection. gracefulCtx outlives the main context for a
	// short while, so closing the coordination here still has a chance to send the
	// Disconnect message and tell the coordinator that we meant to disconnect.
	tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
	crdErr := coordination.Close()
	if crdErr != nil {
		// Log the error returned by Close(). The outer err is always nil at this point,
		// because a failed Coordinate() call returns before we get here.
		tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(crdErr))
	}
case err = <-coordination.Error():
	// The coordination ended without the main context being canceled. EOF and context
	// errors are not logged at error level; anything else is.
	if err != nil &&
		!xerrors.Is(err, io.EOF) &&
		!xerrors.Is(err, context.Canceled) &&
		!xerrors.Is(err, context.DeadlineExceeded) {
		// The log message is not a format string, so attach the error as a structured
		// field instead of using a %w verb.
		tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err))
	}
}
return nil
}

func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error {
s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
s, err := client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
if err != nil {
return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
}
defer func() {
cErr := s.Close()
if cErr != nil {
tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
}
}()
for {
dmp, err := s.Recv()
if err != nil {
if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return nil
}
return xerrors.Errorf("error receiving DERP Map: %w", err)
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
return err
}
tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp))
tac.logger.Debug(tac.ctx, "got new DERP Map", slog.F("derp_map", dmp))
dm := tailnet.DERPMapFromProto(dmp)
tac.conn.SetDERPMap(dm)
}
Expand Down
106 changes: 106 additions & 0 deletions codersdk/workspaceagents_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package codersdk

import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)

// TestTailnetAPIConnector_Disconnects exercises the two ways a tailnetAPIConnector's control
// plane connection can end:
//
//  1. A DERP map failure while the main context is still active must hang up the Coordinate
//     call WITHOUT sending a Disconnect, and the connector must then reconnect.
//  2. Canceling the main context must send a Disconnect message to the coordinator.
func TestTailnetAPIConnector_Disconnects(t *testing.T) {
	t.Parallel()
	testCtx := testutil.Context(t, testutil.WaitShort)
	// ctx is the connector's main context. testCtx stays alive after cancel() so that we can
	// still wait for the Disconnect message.
	ctx, cancel := context.WithCancel(testCtx)
	logger := slogtest.Make(t, &slogtest.Options{
		// we get EOF when we simulate a DERPMap error
		IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, io.EOF),
	}).Leveled(slog.LevelDebug)
	agentID := uuid.UUID{0x55}
	clientID := uuid.UUID{0x66}
	fCoord := tailnettest.NewFakeCoordinator()
	var coord tailnet.Coordinator = fCoord
	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
	coordPtr.Store(&coord)
	// The service's DERP map function blocks on this channel, so the test controls exactly
	// when a DERP map (or a nil one) is produced.
	derpMapCh := make(chan *tailcfg.DERPMap)
	defer close(derpMapCh)
	svc, err := tailnet.NewClientService(
		logger, &coordPtr,
		time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
	)
	require.NoError(t, err)

	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sws, err := websocket.Accept(w, r, nil)
		if !assert.NoError(t, err) {
			return
		}
		ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
		err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
			Name: "client",
			ID:   clientID,
			Auth: tailnet.ClientTunnelAuth{AgentID: agentID},
		})
		assert.NoError(t, err)
	}))
	// Release the test server's listener and connections when the test ends.
	defer svr.Close()

	fConn := newFakeTailnetConn()

	uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)

	call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
	reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.NotNil(t, reqTun.AddTunnel)

	_ = testutil.RequireRecvCtx(ctx, t, uut.connected)

	// simulate a problem with DERPMaps by sending nil
	testutil.RequireSendCtx(ctx, t, derpMapCh, nil)

	// this should cause the coordinate call to hang up WITHOUT disconnecting
	reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.Nil(t, reqNil)

	// ...and then reconnect
	call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
	reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.NotNil(t, reqTun.AddTunnel)

	// canceling the context should trigger the disconnect message
	cancel()
	// ctx is now canceled, so wait on testCtx instead.
	reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
	require.NotNil(t, reqDisc)
	require.NotNil(t, reqDisc.Disconnect)
}

// fakeTailnetConn is a stateless test double that is passed to runTailnetAPIConnector in place
// of a real tailnet connection.
type fakeTailnetConn struct{}

// UpdatePeers is not implemented by the fake. It panics if called, so that an unexpected call
// shows up with a stack trace pointing at the caller.
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
// TODO implement me
panic("implement me")
}
Comment on lines +93 to +96
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call t.Fail() instead of just panicking?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

panic is nice because it gives you a stack trace.

changing to t.Fail() won't give a stack trace, so you'll have to manually chase down how the function could have been called by your test.


// SetAllPeersLost is a no-op in the fake.
func (*fakeTailnetConn) SetAllPeersLost() {}

// SetNodeCallback is a no-op in the fake; the supplied callback is discarded and never invoked.
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}

// SetDERPMap is a no-op in the fake; the supplied DERP map is discarded.
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}

// newFakeTailnetConn returns a pointer to a new, zero-valued fakeTailnetConn.
func newFakeTailnetConn() *fakeTailnetConn {
	return new(fakeTailnetConn)
}
Comment on lines +91 to +106
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this live in tailnettest as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, because the interface we are faking lives in codersdk, even though the "real" object we are faking lives in tailnet.

3 changes: 2 additions & 1 deletion tailnet/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func (c *remoteCoordination) Close() (retErr error) {
if err != nil {
return xerrors.Errorf("send disconnect: %w", err)
}
c.logger.Debug(context.Background(), "sent disconnect")
return nil
}

Expand Down Expand Up @@ -167,7 +168,7 @@ func (c *remoteCoordination) respLoop() {
}
}

// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinee (usually a
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
// a client---agents should NOT set this!).
func NewRemoteCoordination(logger slog.Logger,
Expand Down