diff --git a/agent/agent_test.go b/agent/agent_test.go index 31f1448f34018..6cef939d47a4c 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1547,32 +1547,33 @@ func TestAgent_Dial(t *testing.T) { t.Run(c.name, func(t *testing.T) { t.Parallel() - // Setup listener + // The purpose of this test is to ensure that a client can dial a + // listener in the workspace over tailnet. l := c.setup(t) - defer l.Close() - go func() { - for { - c, err := l.Accept() - if err != nil { - return - } + done := make(chan struct{}) + defer func() { + l.Close() + <-done + }() - go testAccept(t, c) - } + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + go func() { + defer close(done) + c, err := l.Accept() + assert.NoError(t, err, "accept connection") + defer c.Close() + testAccept(ctx, t, c) }() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) - require.True(t, conn.AwaitReachable(context.Background())) - conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) + agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + require.True(t, agentConn.AwaitReachable(ctx)) + conn, err := agentConn.DialContext(ctx, l.Addr().Network(), l.Addr().String()) require.NoError(t, err) - defer conn1.Close() - conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) - require.NoError(t, err) - defer conn2.Close() - testDial(t, conn2) - testDial(t, conn1) - time.Sleep(150 * time.Millisecond) + defer conn.Close() + testDial(ctx, t, conn) }) } } @@ -2002,22 +2003,41 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati var dialTestPayload = []byte("dean-was-here123") -func testDial(t *testing.T, c net.Conn) { +func testDial(ctx context.Context, t *testing.T, c net.Conn) { t.Helper() + if deadline, ok := ctx.Deadline(); ok { + err := c.SetDeadline(deadline) + assert.NoError(t, err) + defer func() { + err := c.SetDeadline(time.Time{}) + assert.NoError(t, err) + }() + } + assertWritePayload(t, c, dialTestPayload) assertReadPayload(t, c, dialTestPayload) } -func testAccept(t *testing.T, c net.Conn) { +func testAccept(ctx context.Context, t *testing.T, c net.Conn) { t.Helper() defer c.Close() + if deadline, ok := ctx.Deadline(); ok { + err := c.SetDeadline(deadline) + assert.NoError(t, err) + defer func() { + err := c.SetDeadline(time.Time{}) + assert.NoError(t, err) + }() + } + assertReadPayload(t, c, dialTestPayload) assertWritePayload(t, c, dialTestPayload) } func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { + t.Helper() b := make([]byte, len(payload)+16) n, err := r.Read(b) assert.NoError(t, err, "read payload") @@ -2026,6 +2046,7 @@ func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { } func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { + t.Helper() n, err := w.Write(payload) assert.NoError(t, err, "write payload") assert.Equal(t, len(payload), n, "payload length does not match")