diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 3108b0bac880c..3d3e7aaa061d6 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -445,13 +445,20 @@ func TestWorkspaceAgentTailnet(t *testing.T) { _ = agenttest.New(t, client.URL, authToken) resources := coderdtest.AwaitWorkspaceAgents(t, client, ws.ID) - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{ - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), - }) + conn, err := func() (*codersdk.WorkspaceAgentConn, error) { + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() // Connection should remain open even if the dial context is canceled. + + return client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{ + Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + }) + }() require.NoError(t, err) defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) session, err := sshClient.NewSession() @@ -1416,12 +1423,20 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { agentID := resources[0].Agents[0].ID // Connect from a client. - ctx := testutil.Context(t, testutil.WaitLong) - conn1, err := client.DialWorkspaceAgent(ctx, agentID, &codersdk.DialWorkspaceAgentOptions{ - Logger: logger.Named("client1"), - }) + conn1, err := func() (*codersdk.WorkspaceAgentConn, error) { + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() // Connection should remain open even if the dial context is canceled. + + return client.DialWorkspaceAgent(ctx, agentID, &codersdk.DialWorkspaceAgentOptions{ + Logger: logger.Named("client1"), + }) + }() require.NoError(t, err) defer conn1.Close() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + ok := conn1.AwaitReachable(ctx) require.True(t, ok) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index e020fd579a417..ac3f28aa28324 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -258,12 +258,12 @@ type DialWorkspaceAgentOptions struct { BlockEndpoints bool } -func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (agentConn *WorkspaceAgentConn, err error) { +func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (agentConn *WorkspaceAgentConn, err error) { if options == nil { options = &DialWorkspaceAgentOptions{} } - connInfo, err := c.WorkspaceAgentConnectionInfo(ctx, agentID) + connInfo, err := c.WorkspaceAgentConnectionInfo(dialCtx, agentID) if err != nil { return nil, xerrors.Errorf("get connection info: %w", err) } @@ -302,7 +302,10 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti tokenHeader = c.SessionTokenHeader } headers.Set(tokenHeader, c.SessionToken()) - ctx, cancel := context.WithCancel(ctx) + + // New context, separate from dialCtx. We don't want to cancel the + // connection if dialCtx is canceled. + ctx, cancel := context.WithCancel(context.Background()) defer func() { if err != nil { cancel() @@ -314,7 +317,9 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti return nil, xerrors.Errorf("parse url: %w", err) } closedCoordinator := make(chan struct{}) - firstCoordinator := make(chan error) + // Must only ever be used once, send error OR close to avoid + // reassignment race. Buffered so we don't hang in goroutine. + firstCoordinator := make(chan error, 1) go func() { defer close(closedCoordinator) isFirst := true @@ -366,7 +371,9 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti return nil, xerrors.Errorf("parse url: %w", err) } closedDerpMap := make(chan struct{}) - firstDerpMap := make(chan error) + // Must only ever be used once, send error OR close to avoid + // reassignment race. Buffered so we don't hang in goroutine. + firstDerpMap := make(chan error, 1) go func() { defer close(closedDerpMap) isFirst := true @@ -420,13 +427,21 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti } }() - err = <-firstCoordinator - if err != nil { - return nil, err - } - err = <-firstDerpMap - if err != nil { - return nil, err + for firstCoordinator != nil || firstDerpMap != nil { + select { + case <-dialCtx.Done(): + return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err()) + case err = <-firstCoordinator: + if err != nil { + return nil, xerrors.Errorf("start coordinator: %w", err) + } + firstCoordinator = nil + case err = <-firstDerpMap: + if err != nil { + return nil, xerrors.Errorf("receive derp map: %w", err) + } + firstDerpMap = nil + } } agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{ @@ -444,9 +459,9 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti }, }) - if !agentConn.AwaitReachable(ctx) { + if !agentConn.AwaitReachable(dialCtx) { _ = agentConn.Close() - return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err()) + return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", dialCtx.Err()) } return agentConn, nil