Skip to content

Commit 18aabca

Browse files
committed
fix: Prevent agentConn use before ready via AwaitReachable
1 parent 03b0917 commit 18aabca

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

coderd/coderd_test.go

+6
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,13 @@ func TestDERP(t *testing.T) {
7878
DERPMap: derpMap,
7979
})
8080
require.NoError(t, err)
81+
w2Ready := make(chan struct{}, 1)
8182
w1.SetNodeCallback(func(node *tailnet.Node) {
8283
w2.UpdateNodes([]*tailnet.Node{node})
84+
select {
85+
case w2Ready <- struct{}{}:
86+
default:
87+
}
8388
})
8489
w2.SetNodeCallback(func(node *tailnet.Node) {
8590
w1.UpdateNodes([]*tailnet.Node{node})
@@ -98,6 +103,7 @@ func TestDERP(t *testing.T) {
98103
}()
99104

100105
<-conn
106+
<-w2Ready
101107
nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565))
102108
require.NoError(t, err)
103109
_ = nc.Close()

coderd/workspaceagents_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,8 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) {
469469
t.Parallel()
470470

471471
setup := func(t *testing.T, apps []*proto.App) (*codersdk.Client, uint16, uuid.UUID) {
472+
t.Helper()
473+
472474
client := coderdtest.New(t, &coderdtest.Options{
473475
IncludeProvisionerDaemon: true,
474476
})

coderd/wsconncache/wsconncache_test.go

+11
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/coder/coder/codersdk/agentsdk"
3030
"github.com/coder/coder/tailnet"
3131
"github.com/coder/coder/tailnet/tailnettest"
32+
"github.com/coder/coder/testutil"
3233
)
3334

3435
func TestMain(m *testing.M) {
@@ -131,6 +132,14 @@ func TestCache(t *testing.T) {
131132
return
132133
}
133134
defer release()
135+
136+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
137+
defer cancel()
138+
if !conn.AwaitReachable(ctx) {
139+
t.Error("agent not reachable")
140+
return
141+
}
142+
134143
transport := conn.HTTPTransport()
135144
defer transport.CloseIdleConnections()
136145
proxy.Transport = transport
@@ -146,6 +155,8 @@ func TestCache(t *testing.T) {
146155
}
147156

148157
func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
158+
t.Helper()
159+
149160
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
150161

151162
coordinator := tailnet.NewCoordinator()

codersdk/workspaceagentconn.go

+18-3
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ type ReconnectingPTYRequest struct {
176176
func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
177177
ctx, span := tracing.StartSpan(ctx)
178178
defer span.End()
179-
179+
if !c.AwaitReachable(ctx) {
180+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
181+
}
180182
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort))
181183
if err != nil {
182184
return nil, err
@@ -207,6 +209,9 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID,
207209
func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
208210
ctx, span := tracing.StartSpan(ctx)
209211
defer span.End()
212+
if !c.AwaitReachable(ctx) {
213+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
214+
}
210215
return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort))
211216
}
212217

@@ -235,6 +240,9 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error)
235240
func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
236241
ctx, span := tracing.StartSpan(ctx)
237242
defer span.End()
243+
if !c.AwaitReachable(ctx) {
244+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
245+
}
238246
speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort))
239247
if err != nil {
240248
return nil, xerrors.Errorf("dial speedtest: %w", err)
@@ -257,6 +265,9 @@ func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, ad
257265
_, rawPort, _ := net.SplitHostPort(addr)
258266
port, _ := strconv.ParseUint(rawPort, 10, 16)
259267
ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port))
268+
if !c.AwaitReachable(ctx) {
269+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
270+
}
260271
if network == "udp" {
261272
return c.Conn.DialContextUDP(ctx, ipp)
262273
}
@@ -317,7 +328,7 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
317328
// Disable keep alives as we're usually only making a single
318329
// request, and this triggers goleak in tests
319330
DisableKeepAlives: true,
320-
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
331+
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
321332
if network != "tcp" {
322333
return nil, xerrors.Errorf("network must be tcp")
323334
}
@@ -331,7 +342,11 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
331342
return nil, xerrors.Errorf("request %q does not appear to be for http api", addr)
332343
}
333344

334-
conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
345+
if !c.AwaitReachable(ctx) {
346+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
347+
}
348+
349+
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
335350
if err != nil {
336351
return nil, xerrors.Errorf("dial http api: %w", err)
337352
}

0 commit comments

Comments
 (0)