Commit 3db51e0

PR comments 3
1 parent f49fe6b commit 3db51e0

6 files changed: +215 additions, -177 deletions

coderd/workspaceagents_test.go

Lines changed: 111 additions & 121 deletions
@@ -6,12 +6,9 @@ import (
 	"fmt"
 	"net"
 	"net/http"
-	"net/http/httptest"
-	"net/url"
 	"runtime"
 	"strconv"
 	"strings"
-	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -21,6 +18,7 @@ import (
 	"github.com/stretchr/testify/require"
 	"golang.org/x/xerrors"
 	"google.golang.org/protobuf/types/known/timestamppb"
+	"nhooyr.io/websocket"
 	"tailscale.com/tailcfg"
 
 	"cdr.dev/slog"
@@ -43,6 +41,8 @@ import (
 	"github.com/coder/coder/v2/codersdk/workspacesdk"
 	"github.com/coder/coder/v2/provisioner/echo"
 	"github.com/coder/coder/v2/provisionersdk/proto"
+	"github.com/coder/coder/v2/tailnet"
+	tailnetproto "github.com/coder/coder/v2/tailnet/proto"
 	"github.com/coder/coder/v2/tailnet/tailnettest"
 	"github.com/coder/coder/v2/testutil"
 )
@@ -512,111 +512,138 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
 	require.Equal(t, "version", sdkErr.Validations[0].Field)
 }
 
+type resumeTokenTestFakeCoordinator struct {
+	tailnet.Coordinator
+	lastPeerID uuid.UUID
+}
+
+var _ tailnet.Coordinator = &resumeTokenTestFakeCoordinator{}
+
+func (c *resumeTokenTestFakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agentID uuid.UUID) error {
+	c.lastPeerID = id
+	return c.Coordinator.ServeClient(conn, id, agentID)
+}
+
+func (c *resumeTokenTestFakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *tailnetproto.CoordinateRequest, <-chan *tailnetproto.CoordinateResponse) {
+	c.lastPeerID = id
+	return c.Coordinator.Coordinate(ctx, id, name, a)
+}
+
 func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
 	t.Parallel()
 
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-
-	// We block direct in this test to ensure that even if there's no direct
-	// connection, no shenanigans happen with the peer IDs on either side.
-	dv := coderdtest.DeploymentValues(t)
-	err := dv.DERP.Config.BlockDirect.Set("true")
-	require.NoError(t, err)
+	coordinator := &resumeTokenTestFakeCoordinator{
+		Coordinator: tailnet.NewCoordinator(logger),
+	}
 	client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
-		DeploymentValues: dv,
+		Coordinator: coordinator,
 	})
 	defer closer.Close()
 	user := coderdtest.CreateFirstUser(t, client)
 
-	// Change the DERP mapper to our custom one.
-	var currentDerpMap atomic.Pointer[tailcfg.DERPMap]
-	originalDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
-	currentDerpMap.Store(originalDerpMap)
-	derpMapFn := func(_ *tailcfg.DERPMap) *tailcfg.DERPMap {
-		return currentDerpMap.Load().Clone()
-	}
-	api.DERPMapper.Store(&derpMapFn)
-
-	// Start workspace a workspace agent.
+	// Create a workspace with an agent. No need to connect it since clients can
+	// still connect to the coordinator while the agent isn't connected.
 	r := dbfake.WorkspaceBuild(t, api.Database, database.Workspace{
 		OrganizationID: user.OrganizationID,
 		OwnerID:        user.UserID,
 	}).WithAgent().Do()
-
-	agentCloser := agenttest.New(t, client.URL, r.AgentToken)
-	resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID)
-	agentID := resources[0].Agents[0].ID
-
-	// Create a new "proxy" server that we can use to kill the connection
-	// whenever we want.
-	l, err := netListenDroppable("tcp", "localhost:0")
+	agentTokenUUID, err := uuid.Parse(r.AgentToken)
 	require.NoError(t, err)
-	defer l.Close()
-	srv := &httptest.Server{
-		Listener: l,
-		//nolint:gosec
-		Config: &http.Server{Handler: api.RootHandler},
-	}
-	srv.Start()
-	proxyURL, err := url.Parse(srv.URL)
+	ctx := testutil.Context(t, testutil.WaitLong)
+	agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
 	require.NoError(t, err)
-	proxyClient := codersdk.New(proxyURL)
-	proxyClient.SetSessionToken(client.SessionToken())
-
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
-	defer cancel()
 
-	// Connect from a client.
-	conn, err := workspacesdk.New(proxyClient).
-		DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{
-			Logger: logger.Named("client"),
-		})
+	// Connect with no resume token, and ensure that the peer ID is set to a
+	// random value.
+	coordinator.lastPeerID = uuid.Nil
+	originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
 	require.NoError(t, err)
-	defer conn.Close()
+	originalPeerID := coordinator.lastPeerID
+	require.NotEqual(t, originalPeerID, uuid.Nil)
 
-	ok := conn.AwaitReachable(ctx)
-	require.True(t, ok)
-	originalAgentPeers := agentCloser.TailnetConn().GetKnownPeerIDs()
+	// Connect with a valid resume token, and ensure that the peer ID is set to
+	// the stored value.
+	coordinator.lastPeerID = uuid.Nil
+	newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
+	require.NoError(t, err)
+	require.Equal(t, originalPeerID, coordinator.lastPeerID)
+	require.NotEqual(t, originalResumeToken, newResumeToken)
 
-	// Drop client conn's coordinator connection.
-	l.DropAllConns()
+	// Connect with an invalid resume token, and ensure that the request is
+	// rejected.
+	coordinator.lastPeerID = uuid.Nil
+	_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
+	require.Error(t, err)
+	var sdkErr *codersdk.Error
+	require.ErrorAs(t, err, &sdkErr)
+	require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
+	require.Len(t, sdkErr.Validations, 1)
+	require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
+	require.Equal(t, uuid.Nil, coordinator.lastPeerID)
+}
 
-	// HACK: Change the DERP map and add a second "marker" region so we know
-	// when the client has reconnected to the coordinator.
-	//
-	// With some refactoring of the client connection to expose the
-	// coordinator connection status, this wouldn't be needed, but this
-	// also works.
-	derpMap := currentDerpMap.Load().Clone()
-	newDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
-	derpMap.Regions[2] = newDerpMap.Regions[1]
-	currentDerpMap.Store(derpMap)
+// connectToCoordinatorAndFetchResumeToken connects to the tailnet coordinator
+// with a given resume token. It returns an error if the connection is rejected.
+// If the connection is accepted, it is immediately closed and no error is
+// returned.
+func connectToCoordinatorAndFetchResumeToken(ctx context.Context, logger slog.Logger, sdkClient *codersdk.Client, agentID uuid.UUID, resumeToken string) (string, error) {
+	u, err := sdkClient.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID))
+	if err != nil {
+		return "", xerrors.Errorf("parse URL: %w", err)
+	}
+	q := u.Query()
+	q.Set("version", "2.0")
+	if resumeToken != "" {
+		q.Set("resume_token", resumeToken)
+	}
+	u.RawQuery = q.Encode()
 
-	// Wait for the agent's DERP map to be updated.
-	require.Eventually(t, func() bool {
-		conn := agentCloser.TailnetConn()
-		if conn == nil {
-			return false
+	//nolint:bodyclose
+	wsConn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
+		HTTPHeader: http.Header{
+			"Coder-Session-Token": []string{sdkClient.SessionToken()},
+		},
+	})
+	if err != nil {
+		if resp.StatusCode != http.StatusSwitchingProtocols {
+			err = codersdk.ReadBodyAsError(resp)
 		}
-		regionIDs := conn.DERPMap().RegionIDs()
-		return len(regionIDs) == 2 && regionIDs[1] == 2
-	}, testutil.WaitLong, testutil.IntervalFast)
-
-	// Wait for the DERP map to be updated on the client. This means that the
-	// client has reconnected to the coordinator.
-	require.Eventually(t, func() bool {
-		regionIDs := conn.Conn.DERPMap().RegionIDs()
-		return len(regionIDs) == 2 && regionIDs[1] == 2
-	}, testutil.WaitLong, testutil.IntervalFast)
+		return "", xerrors.Errorf("websocket dial: %w", err)
+	}
+	defer wsConn.Close(websocket.StatusNormalClosure, "done")
+
+	// Send a request to the server to ensure that we're plumbed all the way
+	// through.
+	rpcClient, err := tailnet.NewDRPCClient(
+		websocket.NetConn(ctx, wsConn, websocket.MessageBinary),
+		logger,
+	)
+	if err != nil {
+		return "", xerrors.Errorf("new dRPC client: %w", err)
+	}
 
-	// The first client should still be able to reach the agent.
-	ok = conn.AwaitReachable(ctx)
-	require.True(t, ok)
-	_, err = conn.ListeningPorts(ctx)
-	require.NoError(t, err)
+	// Send an empty coordination request. This will do nothing on the server,
+	// but ensures our wrapped coordinator can record the peer ID.
+	coordinateClient, err := rpcClient.Coordinate(ctx)
+	if err != nil {
+		return "", xerrors.Errorf("coordinate: %w", err)
+	}
+	err = coordinateClient.Send(&tailnetproto.CoordinateRequest{})
+	if err != nil {
+		return "", xerrors.Errorf("send empty coordination request: %w", err)
+	}
+	err = coordinateClient.Close()
+	if err != nil {
+		return "", xerrors.Errorf("close coordination request: %w", err)
+	}
 
-	// The agent should not see any new peers.
-	require.ElementsMatch(t, originalAgentPeers, agentCloser.TailnetConn().GetKnownPeerIDs())
+	// Fetch a resume token.
+	newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{})
+	if err != nil {
+		return "", xerrors.Errorf("fetch resume token: %w", err)
+	}
+	return newResumeToken.Token, nil
 }
 
 func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
@@ -1832,40 +1859,3 @@ func postStartup(ctx context.Context, t testing.TB, client agent.Client, startup
 	_, err = aAPI.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{Startup: startup})
 	return err
 }
-
-type droppableTCPListener struct {
-	net.Listener
-	mu    sync.Mutex
-	conns []net.Conn
-}
-
-var _ net.Listener = &droppableTCPListener{}
-
-func netListenDroppable(network, addr string) (*droppableTCPListener, error) {
-	l, err := net.Listen(network, addr)
-	if err != nil {
-		return nil, err
-	}
-	return &droppableTCPListener{Listener: l}, nil
-}
-
-func (l *droppableTCPListener) Accept() (net.Conn, error) {
-	conn, err := l.Listener.Accept()
-	if err != nil {
-		return nil, err
-	}
-
-	l.mu.Lock()
-	defer l.mu.Unlock()
-	l.conns = append(l.conns, conn)
-	return conn, nil
-}
-
-func (l *droppableTCPListener) DropAllConns() {
-	l.mu.Lock()
-	defer l.mu.Unlock()
-	for _, c := range l.conns {
-		_ = c.Close()
-	}
-	l.conns = nil
-}

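Note: the fake coordinator in the test above relies on Go's interface-embedding pattern: embed the real tailnet.Coordinator, override only the methods you want to observe, and let every other method delegate to the wrapped value. A minimal, self-contained sketch of the same pattern (Greeter and friends are illustrative names, not from the Coder codebase):

package main

import "fmt"

type Greeter interface {
	Greet(name string) string
}

type realGreeter struct{}

func (realGreeter) Greet(name string) string { return "hello " + name }

// recordingGreeter satisfies Greeter via embedding; only Greet is
// overridden, to record its argument before delegating.
type recordingGreeter struct {
	Greeter
	lastName string
}

func (g *recordingGreeter) Greet(name string) string {
	g.lastName = name
	return g.Greeter.Greet(name)
}

func main() {
	g := &recordingGreeter{Greeter: realGreeter{}}
	fmt.Println(g.Greet("world")) // hello world
	fmt.Println(g.lastName)       // world
}
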
codersdk/workspacesdk/connector.go

Lines changed: 20 additions & 11 deletions
@@ -25,6 +25,7 @@ import (
 	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/tailnet"
 	"github.com/coder/coder/v2/tailnet/proto"
+	"github.com/coder/quartz"
 	"github.com/coder/retry"
 )
 
@@ -62,6 +63,7 @@ type tailnetAPIConnector struct {
 
 	agentID       uuid.UUID
 	coordinateURL string
+	clock         quartz.Clock
 	dialOptions   *websocket.DialOptions
 	conn          tailnetConn
 	customDialFn  func() (proto.DRPCTailnetClient, error)
@@ -70,7 +72,7 @@ type tailnetAPIConnector struct {
 	client proto.DRPCTailnetClient
 
 	connected   chan error
-	resumeToken atomic.Pointer[proto.RefreshResumeTokenResponse]
+	resumeToken *proto.RefreshResumeTokenResponse
 	isFirst     bool
 	closed      chan struct{}
 
@@ -80,12 +82,13 @@ type tailnetAPIConnector struct {
 }
 
 // Create a new tailnetAPIConnector without running it
-func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
+func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, clock quartz.Clock, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
 	return &tailnetAPIConnector{
 		ctx:           ctx,
 		logger:        logger,
 		agentID:       agentID,
 		coordinateURL: coordinateURL,
+		clock:         clock,
 		dialOptions:   dialOptions,
 		conn:          nil,
 		connected:     make(chan error, 1),
@@ -98,7 +101,7 @@ func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uui
 func (tac *tailnetAPIConnector) manageGracefulTimeout() {
	defer tac.cancelGracefulCtx()
 	<-tac.ctx.Done()
-	timer := time.NewTimer(tailnetConnectorGracefulTimeout)
+	timer := tac.clock.NewTimer(tailnetConnectorGracefulTimeout, "tailnetAPIClient", "gracefulTimeout")
 	defer timer.Stop()
 	select {
 	case <-tac.closed:
@@ -114,6 +117,8 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
 	go func() {
 		tac.isFirst = true
 		defer close(tac.closed)
+		// Sadly retry doesn't support quartz.Clock yet so this is not
+		// influenced by the configured clock.
 		for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
 			tailnetClient, err := tac.dial()
 			if err != nil {
@@ -145,12 +150,11 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
 	if err != nil {
 		return nil, xerrors.Errorf("parse URL %q: %w", tac.coordinateURL, err)
 	}
-	resumeToken := tac.resumeToken.Load()
-	if resumeToken != nil {
+	if tac.resumeToken != nil {
 		q := u.Query()
-		q.Set("resume_token", resumeToken.Token)
+		q.Set("resume_token", tac.resumeToken.Token)
 		u.RawQuery = q.Encode()
-		tac.logger.Debug(tac.ctx, "using resume token", slog.F("resume_token", resumeToken))
+		tac.logger.Debug(tac.ctx, "using resume token", slog.F("resume_token", tac.resumeToken))
 	}
 
 	coordinateURL := u.String()
@@ -186,7 +190,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
 			if v.Field == "resume_token" {
 				// Unset the resume token for the next attempt
 				tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
-				tac.resumeToken.Store(nil)
+				tac.resumeToken = nil
 				didLog = true
 			}
 		}
@@ -317,7 +321,7 @@ func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
 }
 
 func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.DRPCTailnetClient) {
-	ticker := time.NewTicker(15 * time.Second)
+	ticker := tac.clock.NewTicker(15*time.Second, "tailnetAPIConnector", "refreshToken")
 	defer ticker.Stop()
 
 	initialCh := make(chan struct{}, 1)
@@ -341,8 +345,13 @@ func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.D
 			return
 		}
 		tac.logger.Debug(tac.ctx, "refreshed coordinator resume token", slog.F("resume_token", res))
-		tac.resumeToken.Store(res)
-		ticker.Reset(res.RefreshIn.AsDuration())
+		tac.resumeToken = res
+		dur := res.RefreshIn.AsDuration()
+		if dur <= 0 {
+			// A sensible delay to refresh again.
+			dur = 30 * time.Minute
+		}
+		ticker.Reset(dur, "tailnetAPIConnector", "refreshToken", "reset")
 	}
 }

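Note: injecting quartz.Clock above is what makes the 15-second refresh ticker controllable from tests. A minimal sketch of the same pattern, assuming the quartz API as it appears in this diff (NewTicker and Reset take trailing tag strings) and that quartz's Ticker exposes a C channel like time.Ticker does; refreshEvery and its tags are illustrative, not part of the Coder codebase:

package refresh

import (
	"context"
	"time"

	"github.com/coder/quartz"
)

// refreshEvery runs fn on a ticker owned by the injected clock.
// Production callers would pass a real clock; tests can pass a mock
// (e.g. quartz.NewMock(t)) and advance time deterministically.
func refreshEvery(ctx context.Context, clock quartz.Clock, fn func() time.Duration) {
	ticker := clock.NewTicker(15*time.Second, "example", "refresh")
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Ask fn for the next interval, guarding against a
			// non-positive duration the way the diff does.
			next := fn()
			if next <= 0 {
				next = 30 * time.Minute
			}
			ticker.Reset(next, "example", "refresh", "reset")
		}
	}
}
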