From 9893031fea414e7b10054ca655135e962819735b Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 11 Sep 2024 15:42:25 +0400 Subject: [PATCH] fix: fix flake in TestWorkspaceAgentClientCoordinate_ResumeToken --- coderd/workspaceagents.go | 2 + coderd/workspaceagents_test.go | 82 +++++++++++++++++----------------- tailnet/service.go | 2 + 3 files changed, 44 insertions(+), 42 deletions(-) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 75f2a06045af7..eaf06a643a3e0 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -864,6 +864,8 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R }) return } + api.Logger.Debug(ctx, "accepted coordinate resume token for peer", + slog.F("peer_id", peerID.String())) } api.WebsocketWaitMutex.Lock() diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index b80efc36b9e19..906333456ae70 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -513,30 +513,42 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) { require.Equal(t, "version", sdkErr.Validations[0].Field) } -type resumeTokenTestFakeCoordinator struct { - tailnet.Coordinator - t testing.TB - peerIDCh chan uuid.UUID +type resumeTokenRecordingProvider struct { + tailnet.ResumeTokenProvider + t testing.TB + generateCalls chan uuid.UUID + verifyCalls chan string } -var _ tailnet.Coordinator = &resumeTokenTestFakeCoordinator{} +var _ tailnet.ResumeTokenProvider = &resumeTokenRecordingProvider{} -func (c *resumeTokenTestFakeCoordinator) storeID(id uuid.UUID) { - select { - case c.peerIDCh <- id: - default: - c.t.Fatal("peer ID channel full") +func newResumeTokenRecordingProvider(t testing.TB, underlying tailnet.ResumeTokenProvider) *resumeTokenRecordingProvider { + return &resumeTokenRecordingProvider{ + ResumeTokenProvider: underlying, + t: t, + generateCalls: make(chan uuid.UUID, 1), + verifyCalls: make(chan string, 1), } } -func (c *resumeTokenTestFakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agentID uuid.UUID) error { - c.storeID(id) - return c.Coordinator.ServeClient(conn, id, agentID) +func (r *resumeTokenRecordingProvider) GenerateResumeToken(peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) { + select { + case r.generateCalls <- peerID: + return r.ResumeTokenProvider.GenerateResumeToken(peerID) + default: + r.t.Error("generateCalls full") + return nil, xerrors.New("generateCalls full") + } } -func (c *resumeTokenTestFakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *tailnetproto.CoordinateRequest, <-chan *tailnetproto.CoordinateResponse) { - c.storeID(id) - return c.Coordinator.Coordinate(ctx, id, name, a) +func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUID, error) { + select { + case r.verifyCalls <- token: + return r.ResumeTokenProvider.VerifyResumeToken(token) + default: + r.t.Error("verifyCalls full") + return uuid.Nil, xerrors.New("verifyCalls full") + } } func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { @@ -546,15 +558,12 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { clock := quartz.NewMock(t) resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() require.NoError(t, err) - resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour) - coordinator := &resumeTokenTestFakeCoordinator{ - Coordinator: tailnet.NewCoordinator(logger), - t: t, - peerIDCh: make(chan uuid.UUID, 1), - } - defer close(coordinator.peerIDCh) + resumeTokenProvider := newResumeTokenRecordingProvider( + t, + tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour), + ) client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ - Coordinator: coordinator, + Coordinator: tailnet.NewCoordinator(logger), CoordinatorResumeTokenProvider: resumeTokenProvider, }) defer closer.Close() @@ -576,7 +585,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { // random value. originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "") require.NoError(t, err) - originalPeerID := testutil.RequireRecvCtx(ctx, t, coordinator.peerIDCh) + originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) require.NotEqual(t, originalPeerID, uuid.Nil) // Connect with a valid resume token, and ensure that the peer ID is set to @@ -584,7 +593,9 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { clock.Advance(time.Second) newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken) require.NoError(t, err) - newPeerID := testutil.RequireRecvCtx(ctx, t, coordinator.peerIDCh) + verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) + require.Equal(t, originalResumeToken, verifiedToken) + newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) require.Equal(t, originalPeerID, newPeerID) require.NotEqual(t, originalResumeToken, newResumeToken) @@ -598,9 +609,11 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode()) require.Len(t, sdkErr.Validations, 1) require.Equal(t, "resume_token", sdkErr.Validations[0].Field) + verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) + require.Equal(t, "invalid", verifiedToken) select { - case <-coordinator.peerIDCh: + case <-resumeTokenProvider.generateCalls: t.Fatal("unexpected peer ID in channel") default: } @@ -646,21 +659,6 @@ func connectToCoordinatorAndFetchResumeToken(ctx context.Context, logger slog.Lo return "", xerrors.Errorf("new dRPC client: %w", err) } - // Send an empty coordination request. This will do nothing on the server, - // but ensures our wrapped coordinator can record the peer ID. - coordinateClient, err := rpcClient.Coordinate(ctx) - if err != nil { - return "", xerrors.Errorf("coordinate: %w", err) - } - err = coordinateClient.Send(&tailnetproto.CoordinateRequest{}) - if err != nil { - return "", xerrors.Errorf("send empty coordination request: %w", err) - } - err = coordinateClient.Close() - if err != nil { - return "", xerrors.Errorf("close coordination request: %w", err) - } - // Fetch a resume token. newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{}) if err != nil { diff --git a/tailnet/service.go b/tailnet/service.go index ebb5f7e9163a0..22111ce2fe9c9 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -119,6 +119,8 @@ func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID return xerrors.Errorf("yamux init failed: %w", err) } ctx = WithStreamID(ctx, streamID) + s.Logger.Debug(ctx, "serving dRPC tailnet v2 API session", + slog.F("peer_id", streamID.ID.String())) return s.drpc.Serve(ctx, session) }