feat: add resume support to coordinator connections #14234

Merged
merged 7 commits into from
Aug 20, 2024
Changes from 1 commit
PR comments 3
deansheather committed Aug 18, 2024
commit 3db51e070fb062c83fcd30e1676bd5e0240b0e75
232 changes: 111 additions & 121 deletions coderd/workspaceagents_test.go
@@ -6,12 +6,9 @@ import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -21,6 +18,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"

"cdr.dev/slog"
@@ -43,6 +41,8 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
@@ -512,111 +512,138 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
require.Equal(t, "version", sdkErr.Validations[0].Field)
}

type resumeTokenTestFakeCoordinator struct {
tailnet.Coordinator
lastPeerID uuid.UUID
}

var _ tailnet.Coordinator = &resumeTokenTestFakeCoordinator{}

func (c *resumeTokenTestFakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agentID uuid.UUID) error {
c.lastPeerID = id
return c.Coordinator.ServeClient(conn, id, agentID)
}

func (c *resumeTokenTestFakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *tailnetproto.CoordinateRequest, <-chan *tailnetproto.CoordinateResponse) {
c.lastPeerID = id
return c.Coordinator.Coordinate(ctx, id, name, a)
}

func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
t.Parallel()

logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

// We block direct in this test to ensure that even if there's no direct
// connection, no shenanigans happen with the peer IDs on either side.
dv := coderdtest.DeploymentValues(t)
err := dv.DERP.Config.BlockDirect.Set("true")
require.NoError(t, err)
coordinator := &resumeTokenTestFakeCoordinator{
Coordinator: tailnet.NewCoordinator(logger),
}
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
DeploymentValues: dv,
Coordinator: coordinator,
})
defer closer.Close()
user := coderdtest.CreateFirstUser(t, client)

// Change the DERP mapper to our custom one.
var currentDerpMap atomic.Pointer[tailcfg.DERPMap]
originalDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
currentDerpMap.Store(originalDerpMap)
derpMapFn := func(_ *tailcfg.DERPMap) *tailcfg.DERPMap {
return currentDerpMap.Load().Clone()
}
api.DERPMapper.Store(&derpMapFn)

// Start workspace a workspace agent.
// Create a workspace with an agent. No need to connect it since clients can
// still connect to the coordinator while the agent isn't connected.
r := dbfake.WorkspaceBuild(t, api.Database, database.Workspace{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()

agentCloser := agenttest.New(t, client.URL, r.AgentToken)
resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID)
agentID := resources[0].Agents[0].ID

// Create a new "proxy" server that we can use to kill the connection
// whenever we want.
l, err := netListenDroppable("tcp", "localhost:0")
agentTokenUUID, err := uuid.Parse(r.AgentToken)
require.NoError(t, err)
defer l.Close()
srv := &httptest.Server{
Listener: l,
//nolint:gosec
Config: &http.Server{Handler: api.RootHandler},
}
srv.Start()
proxyURL, err := url.Parse(srv.URL)
ctx := testutil.Context(t, testutil.WaitLong)
agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
require.NoError(t, err)
proxyClient := codersdk.New(proxyURL)
proxyClient.SetSessionToken(client.SessionToken())

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

// Connect from a client.
conn, err := workspacesdk.New(proxyClient).
DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{
Logger: logger.Named("client"),
})
// Connect with no resume token, and ensure that the peer ID is set to a
// random value.
coordinator.lastPeerID = uuid.Nil
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
require.NoError(t, err)
defer conn.Close()
originalPeerID := coordinator.lastPeerID
require.NotEqual(t, originalPeerID, uuid.Nil)

ok := conn.AwaitReachable(ctx)
require.True(t, ok)
originalAgentPeers := agentCloser.TailnetConn().GetKnownPeerIDs()
// Connect with a valid resume token, and ensure that the peer ID is set to
// the stored value.
coordinator.lastPeerID = uuid.Nil
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
require.NoError(t, err)
require.Equal(t, originalPeerID, coordinator.lastPeerID)
require.NotEqual(t, originalResumeToken, newResumeToken)

// Drop client conn's coordinator connection.
l.DropAllConns()
// Connect with an invalid resume token, and ensure that the request is
// rejected.
coordinator.lastPeerID = uuid.Nil
_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
require.Len(t, sdkErr.Validations, 1)
require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
require.Equal(t, uuid.Nil, coordinator.lastPeerID)
}

// HACK: Change the DERP map and add a second "marker" region so we know
// when the client has reconnected to the coordinator.
//
// With some refactoring of the client connection to expose the
// coordinator connection status, this wouldn't be needed, but this
// also works.
derpMap := currentDerpMap.Load().Clone()
newDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
derpMap.Regions[2] = newDerpMap.Regions[1]
currentDerpMap.Store(derpMap)
// connectToCoordinatorAndFetchResumeToken connects to the tailnet coordinator
// with a given resume token. It returns an error if the connection is rejected.
// If the connection is accepted, it is immediately closed and no error is
// returned.
func connectToCoordinatorAndFetchResumeToken(ctx context.Context, logger slog.Logger, sdkClient *codersdk.Client, agentID uuid.UUID, resumeToken string) (string, error) {
u, err := sdkClient.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID))
if err != nil {
return "", xerrors.Errorf("parse URL: %w", err)
}
q := u.Query()
q.Set("version", "2.0")
if resumeToken != "" {
q.Set("resume_token", resumeToken)
}
u.RawQuery = q.Encode()

// Wait for the agent's DERP map to be updated.
require.Eventually(t, func() bool {
conn := agentCloser.TailnetConn()
if conn == nil {
return false
//nolint:bodyclose
wsConn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
HTTPHeader: http.Header{
"Coder-Session-Token": []string{sdkClient.SessionToken()},
},
})
if err != nil {
if resp.StatusCode != http.StatusSwitchingProtocols {
err = codersdk.ReadBodyAsError(resp)
}
regionIDs := conn.DERPMap().RegionIDs()
return len(regionIDs) == 2 && regionIDs[1] == 2
}, testutil.WaitLong, testutil.IntervalFast)

// Wait for the DERP map to be updated on the client. This means that the
// client has reconnected to the coordinator.
require.Eventually(t, func() bool {
regionIDs := conn.Conn.DERPMap().RegionIDs()
return len(regionIDs) == 2 && regionIDs[1] == 2
}, testutil.WaitLong, testutil.IntervalFast)
return "", xerrors.Errorf("websocket dial: %w", err)
}
defer wsConn.Close(websocket.StatusNormalClosure, "done")

// Send a request to the server to ensure that we're plumbed all the way
// through.
rpcClient, err := tailnet.NewDRPCClient(
websocket.NetConn(ctx, wsConn, websocket.MessageBinary),
logger,
)
if err != nil {
return "", xerrors.Errorf("new dRPC client: %w", err)
}

// The first client should still be able to reach the agent.
ok = conn.AwaitReachable(ctx)
require.True(t, ok)
_, err = conn.ListeningPorts(ctx)
require.NoError(t, err)
// Send an empty coordination request. This will do nothing on the server,
// but ensures our wrapped coordinator can record the peer ID.
coordinateClient, err := rpcClient.Coordinate(ctx)
if err != nil {
return "", xerrors.Errorf("coordinate: %w", err)
}
err = coordinateClient.Send(&tailnetproto.CoordinateRequest{})
if err != nil {
return "", xerrors.Errorf("send empty coordination request: %w", err)
}
err = coordinateClient.Close()
if err != nil {
return "", xerrors.Errorf("close coordination request: %w", err)
}

// The agent should not see any new peers.
require.ElementsMatch(t, originalAgentPeers, agentCloser.TailnetConn().GetKnownPeerIDs())
// Fetch a resume token.
newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{})
if err != nil {
return "", xerrors.Errorf("fetch resume token: %w", err)
}
return newResumeToken.Token, nil
}

func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
@@ -1832,40 +1859,3 @@ func postStartup(ctx context.Context, t testing.TB, client agent.Client, startup
_, err = aAPI.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{Startup: startup})
return err
}

type droppableTCPListener struct {
net.Listener
mu sync.Mutex
conns []net.Conn
}

var _ net.Listener = &droppableTCPListener{}

func netListenDroppable(network, addr string) (*droppableTCPListener, error) {
l, err := net.Listen(network, addr)
if err != nil {
return nil, err
}
return &droppableTCPListener{Listener: l}, nil
}

func (l *droppableTCPListener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}

l.mu.Lock()
defer l.mu.Unlock()
l.conns = append(l.conns, conn)
return conn, nil
}

func (l *droppableTCPListener) DropAllConns() {
l.mu.Lock()
defer l.mu.Unlock()
for _, c := range l.conns {
_ = c.Close()
}
l.conns = nil
}
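
A note on the mechanism for reviewers: a resume token is an opaque credential the coordinator issues to a connected client so that, after the connection drops, the client can reconnect and be assigned the same peer ID rather than a fresh one, which is exactly what the test above asserts via the wrapped coordinator's lastPeerID. The sketch below is only an illustration of that idea, assuming a simple HMAC-signed token that binds a peer ID to an expiry; it is not the token implementation used by the coder coordinator, and the package and function names are invented.

// Hypothetical illustration only; not the coordinator's actual token format.
package resumetoken

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/google/uuid"
)

// Issue returns "<peerID>.<expiryUnix>.<signature>", binding the peer ID to a deadline.
func Issue(key []byte, peerID uuid.UUID, ttl time.Duration) string {
	payload := fmt.Sprintf("%s.%d", peerID, time.Now().Add(ttl).Unix())
	mac := hmac.New(sha256.New, key)
	_, _ = mac.Write([]byte(payload))
	return payload + "." + base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

// Verify checks the signature and expiry and returns the peer ID the client may resume as.
func Verify(key []byte, token string) (uuid.UUID, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return uuid.Nil, fmt.Errorf("malformed resume token")
	}
	mac := hmac.New(sha256.New, key)
	_, _ = mac.Write([]byte(parts[0] + "." + parts[1]))
	want := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	if !hmac.Equal([]byte(want), []byte(parts[2])) {
		return uuid.Nil, fmt.Errorf("bad signature")
	}
	expiry, err := strconv.ParseInt(parts[1], 10, 64)
	if err != nil || time.Now().Unix() > expiry {
		return uuid.Nil, fmt.Errorf("expired resume token")
	}
	return uuid.Parse(parts[0])
}

Under a scheme like this the server needs no per-session state beyond the signing key: a valid token on reconnect is enough to recover the peer ID, while an invalid or expired one is rejected with a validation error on the resume_token field, which is the behaviour the test exercises.
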
31 changes: 20 additions & 11 deletions codersdk/workspacesdk/connector.go
@@ -25,6 +25,7 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
"github.com/coder/retry"
)

@@ -62,6 +63,7 @@ type tailnetAPIConnector struct {

agentID uuid.UUID
coordinateURL string
clock quartz.Clock
dialOptions *websocket.DialOptions
conn tailnetConn
customDialFn func() (proto.DRPCTailnetClient, error)
@@ -70,7 +72,7 @@ type tailnetAPIConnector struct {
client proto.DRPCTailnetClient

connected chan error
resumeToken atomic.Pointer[proto.RefreshResumeTokenResponse]
resumeToken *proto.RefreshResumeTokenResponse
isFirst bool
closed chan struct{}

@@ -80,12 +82,13 @@ type tailnetAPIConnector struct {
}

// Create a new tailnetAPIConnector without running it
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, clock quartz.Clock, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
return &tailnetAPIConnector{
ctx: ctx,
logger: logger,
agentID: agentID,
coordinateURL: coordinateURL,
clock: clock,
dialOptions: dialOptions,
conn: nil,
connected: make(chan error, 1),
@@ -98,7 +101,7 @@ func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uui
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
defer tac.cancelGracefulCtx()
<-tac.ctx.Done()
timer := time.NewTimer(tailnetConnectorGracefulTimeout)
timer := tac.clock.NewTimer(tailnetConnectorGracefulTimeout, "tailnetAPIClient", "gracefulTimeout")
defer timer.Stop()
select {
case <-tac.closed:
@@ -114,6 +117,8 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
go func() {
tac.isFirst = true
defer close(tac.closed)
// Sadly retry doesn't support quartz.Clock yet so this is not
// influenced by the configured clock.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
tailnetClient, err := tac.dial()
if err != nil {
@@ -145,12 +150,11 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
if err != nil {
return nil, xerrors.Errorf("parse URL %q: %w", tac.coordinateURL, err)
}
resumeToken := tac.resumeToken.Load()
if resumeToken != nil {
if tac.resumeToken != nil {
q := u.Query()
q.Set("resume_token", resumeToken.Token)
q.Set("resume_token", tac.resumeToken.Token)
u.RawQuery = q.Encode()
tac.logger.Debug(tac.ctx, "using resume token", slog.F("resume_token", resumeToken))
tac.logger.Debug(tac.ctx, "using resume token", slog.F("resume_token", tac.resumeToken))
}

coordinateURL := u.String()
@@ -186,7 +190,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
if v.Field == "resume_token" {
// Unset the resume token for the next attempt
tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
tac.resumeToken.Store(nil)
tac.resumeToken = nil
didLog = true
}
}
@@ -317,7 +321,7 @@ func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
}

func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.DRPCTailnetClient) {
ticker := time.NewTicker(15 * time.Second)
ticker := tac.clock.NewTicker(15*time.Second, "tailnetAPIConnector", "refreshToken")
defer ticker.Stop()

initialCh := make(chan struct{}, 1)
@@ -341,8 +345,13 @@ func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.D
return
}
tac.logger.Debug(tac.ctx, "refreshed coordinator resume token", slog.F("resume_token", res))
tac.resumeToken.Store(res)
ticker.Reset(res.RefreshIn.AsDuration())
tac.resumeToken = res
dur := res.RefreshIn.AsDuration()
if dur <= 0 {
// A sensible delay to refresh again.
dur = 30 * time.Minute
}
ticker.Reset(dur, "tailnetAPIConnector", "refreshToken", "reset")
}
}

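
Putting the connector changes together, the client side of the feature is roughly the flow below. This is a simplified sketch of the logic in tailnetAPIConnector, not the literal code: the connector struct, its dial field, and isInvalidResumeTokenError are illustrative stand-ins, while the resume_token query parameter, the RefreshResumeToken RPC, and the quartz ticker tags mirror the diff above.

// Sketch only: a trimmed-down connector illustrating how the resume token is
// carried across reconnects. The real type is tailnetAPIConnector; the dial
// field stands in for the websocket + dRPC dialing done in connector.go.
package resumesketch

import (
	"context"
	"errors"
	"time"

	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/tailnet/proto"
	"github.com/coder/quartz"
)

type connector struct {
	coordinateURL string
	clock         quartz.Clock
	resumeToken   *proto.RefreshResumeTokenResponse

	// dial stands in for the real websocket/dRPC dial.
	dial func(ctx context.Context, url string) (proto.DRPCTailnetClient, error)
}

// dialCoordinator presents the stored resume token, if any, and discards it
// when the server rejects it.
func (c *connector) dialCoordinator(ctx context.Context) (proto.DRPCTailnetClient, error) {
	u := c.coordinateURL
	if c.resumeToken != nil {
		// Present the previous token so the coordinator re-assigns our peer ID.
		u += "?resume_token=" + c.resumeToken.Token
	}
	client, err := c.dial(ctx, u)
	if err != nil && isInvalidResumeTokenError(err) {
		// 401 with a "resume_token" validation error: reconnect as a new peer.
		c.resumeToken = nil
		return c.dial(ctx, c.coordinateURL)
	}
	return client, err
}

// refreshResumeToken keeps the stored token fresh while the connection is up.
func (c *connector) refreshResumeToken(ctx context.Context, client proto.DRPCTailnetClient) {
	ticker := c.clock.NewTicker(15*time.Second, "connector", "refreshToken")
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
		res, err := client.RefreshResumeToken(ctx, &proto.RefreshResumeTokenRequest{})
		if err != nil {
			continue // keep the previous token; retry on the next tick
		}
		c.resumeToken = res
		if d := res.RefreshIn.AsDuration(); d > 0 {
			// Only adjust the period when the server suggests a positive interval.
			ticker.Reset(d, "connector", "refreshToken", "reset")
		}
	}
}

// isInvalidResumeTokenError reports whether the server rejected the token.
func isInvalidResumeTokenError(err error) bool {
	var sdkErr *codersdk.Error
	if !errors.As(err, &sdkErr) {
		return false
	}
	for _, v := range sdkErr.Validations {
		if v.Field == "resume_token" {
			return true
		}
	}
	return false
}

The full implementation additionally triggers an initial refresh shortly after connecting and substitutes a 30-minute interval when the server returns a non-positive RefreshIn, but the lifecycle is the same: refresh periodically, present the token on every dial, and discard it if the server rejects it.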