From 88ec60a4c5c8d650ff1bf469f943bc508e8b4cd2 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 25 Nov 2024 13:54:37 +0400 Subject: [PATCH] feat: opportunistically listen on IPv6 in port-forward --- cli/portforward.go | 121 ++++++++++++++++++------------- cli/portforward_internal_test.go | 32 ++++---- cli/portforward_test.go | 83 +++++++++++++++++++++ 3 files changed, 169 insertions(+), 67 deletions(-) diff --git a/cli/portforward.go b/cli/portforward.go index 68e1eb3d800ef..e6ef2eb11bca8 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -25,6 +25,14 @@ import ( "github.com/coder/serpent" ) +var ( + // noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify + // when the local address is not specified in port-forward flags. + noAddr netip.Addr + ipv6Loopback = netip.MustParseAddr("::1") + ipv4Loopback = netip.MustParseAddr("127.0.0.1") +) + func (r *RootCmd) portForward() *serpent.Command { var ( tcpForwards []string // : @@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command { // Start all listeners. var ( wg = new(sync.WaitGroup) - listeners = make([]net.Listener, len(specs)) + listeners = make([]net.Listener, 0, len(specs)*2) closeAllListeners = func() { logger.Debug(ctx, "closing all listeners") for _, l := range listeners { @@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command { ) defer closeAllListeners() - for i, spec := range specs { + for _, spec := range specs { + if spec.listenHost == noAddr { + // first, opportunistically try to listen on IPv6 + spec6 := spec + spec6.listenHost = ipv6Loopback + l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger) + if err6 != nil { + logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6)) + } else { + listeners = append(listeners, l6) + } + spec.listenHost = ipv4Loopback + } l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger) if err != nil { logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err)) return err } - listeners[i] = l + listeners = append(listeners, l) } stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID) @@ -206,12 +226,19 @@ func listenAndPortForward( spec portForwardSpec, logger slog.Logger, ) (net.Listener, error) { - logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress)) - _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) + logger = logger.With( + slog.F("network", spec.network), + slog.F("listen_host", spec.listenHost), + slog.F("listen_port", spec.listenPort), + ) + listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort) + dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort) + _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n", + spec.network, listenAddress, spec.network, dialAddress) - l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress) + l, err := inv.Net.Listen(spec.network, listenAddress.String()) if err != nil { - return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err) + return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err) } logger.Debug(ctx, "listening") @@ -226,24 +253,31 @@ func listenAndPortForward( logger.Debug(ctx, "listener closed") return } - _, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err) + _, _ = fmt.Fprintf(inv.Stderr, + "Error accepting connection from '%s://%s': %v\n", + spec.network, listenAddress.String(), err) _, _ = fmt.Fprintln(inv.Stderr, "Killing listener") return } - logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr())) + logger.Debug(ctx, "accepted connection", + slog.F("remote_addr", netConn.RemoteAddr())) go func(netConn net.Conn) { defer netConn.Close() - remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress) + remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress) if err != nil { - _, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) + _, _ = fmt.Fprintf(inv.Stderr, + "Failed to dial '%s://%s' in workspace: %s\n", + spec.network, dialAddress, err) return } defer remoteConn.Close() - logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr())) + logger.Debug(ctx, + "dialed remote", slog.F("remote_addr", netConn.RemoteAddr())) agentssh.Bicopy(ctx, netConn, remoteConn) - logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr())) + logger.Debug(ctx, + "connection closing", slog.F("remote_addr", netConn.RemoteAddr())) }(netConn) } }(spec) @@ -252,11 +286,9 @@ func listenAndPortForward( } type portForwardSpec struct { - listenNetwork string // tcp, udp - listenAddress string // : or path - - dialNetwork string // tcp, udp - dialAddress string // : or path + network string // tcp, udp + listenHost netip.Addr + listenPort, dialPort uint16 } func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) { @@ -264,36 +296,28 @@ func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) { for _, specEntry := range tcpSpecs { for _, spec := range strings.Split(specEntry, ",") { - ports, err := parseSrcDestPorts(strings.TrimSpace(spec)) + pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec)) if err != nil { return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err) } - for _, port := range ports { - specs = append(specs, portForwardSpec{ - listenNetwork: "tcp", - listenAddress: port.local.String(), - dialNetwork: "tcp", - dialAddress: port.remote.String(), - }) + for _, pfSpec := range pfSpecs { + pfSpec.network = "tcp" + specs = append(specs, pfSpec) } } } for _, specEntry := range udpSpecs { for _, spec := range strings.Split(specEntry, ",") { - ports, err := parseSrcDestPorts(strings.TrimSpace(spec)) + pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec)) if err != nil { return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err) } - for _, port := range ports { - specs = append(specs, portForwardSpec{ - listenNetwork: "udp", - listenAddress: port.local.String(), - dialNetwork: "udp", - dialAddress: port.remote.String(), - }) + for _, pfSpec := range pfSpecs { + pfSpec.network = "udp" + specs = append(specs, pfSpec) } } } @@ -301,9 +325,9 @@ func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) { // Check for duplicate entries. locals := map[string]struct{}{} for _, spec := range specs { - localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress) + localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort) if _, ok := locals[localStr]; ok { - return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress) + return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort) } locals[localStr] = struct{}{} } @@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) { return uint16(port), nil } -type parsedSrcDestPort struct { - local, remote netip.AddrPort -} - // specRegexp matches port specs. It handles all the following formats: // // 8000 @@ -347,21 +367,19 @@ type parsedSrcDestPort struct { // 9: end or remote port range var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`) -func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) { - var ( - err error - localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) - remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) - ) +func parseSrcDestPorts(in string) ([]portForwardSpec, error) { groups := specRegexp.FindStringSubmatch(in) if len(groups) == 0 { return nil, xerrors.Errorf("invalid port specification %q", in) } + + var localAddr netip.Addr if groups[2] != "" { - localAddr, err = netip.ParseAddr(strings.Trim(groups[2], "[]")) + parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]")) if err != nil { return nil, xerrors.Errorf("invalid IP address %q", groups[2]) } + localAddr = parsedAddr } local, err := parsePortRange(groups[3], groups[5]) @@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) { if len(local) != len(remote) { return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote)) } - var out []parsedSrcDestPort + var out []portForwardSpec for i := range local { - out = append(out, parsedSrcDestPort{ - local: netip.AddrPortFrom(localAddr, local[i]), - remote: netip.AddrPortFrom(remoteAddr, remote[i]), + out = append(out, portForwardSpec{ + listenHost: localAddr, + listenPort: local[i], + dialPort: remote[i], }) } return out, nil diff --git a/cli/portforward_internal_test.go b/cli/portforward_internal_test.go index 08d03f5a95db1..0d1259713dac9 100644 --- a/cli/portforward_internal_test.go +++ b/cli/portforward_internal_test.go @@ -29,15 +29,15 @@ func Test_parsePortForwards(t *testing.T) { }, }, want: []portForwardSpec{ - {"tcp", "127.0.0.1:8000", "tcp", "127.0.0.1:8000"}, - {"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"}, - {"tcp", "127.0.0.1:9000", "tcp", "127.0.0.1:9000"}, - {"tcp", "127.0.0.1:9001", "tcp", "127.0.0.1:9001"}, - {"tcp", "127.0.0.1:9002", "tcp", "127.0.0.1:9002"}, - {"tcp", "127.0.0.1:9003", "tcp", "127.0.0.1:9005"}, - {"tcp", "127.0.0.1:9004", "tcp", "127.0.0.1:9006"}, - {"tcp", "127.0.0.1:10000", "tcp", "127.0.0.1:10000"}, - {"tcp", "127.0.0.1:4444", "tcp", "127.0.0.1:4444"}, + {"tcp", noAddr, 8000, 8000}, + {"tcp", noAddr, 8080, 8081}, + {"tcp", noAddr, 9000, 9000}, + {"tcp", noAddr, 9001, 9001}, + {"tcp", noAddr, 9002, 9002}, + {"tcp", noAddr, 9003, 9005}, + {"tcp", noAddr, 9004, 9006}, + {"tcp", noAddr, 10000, 10000}, + {"tcp", noAddr, 4444, 4444}, }, }, { @@ -46,7 +46,7 @@ func Test_parsePortForwards(t *testing.T) { tcpSpecs: []string{"127.0.0.1:8080:8081"}, }, want: []portForwardSpec{ - {"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"}, + {"tcp", ipv4Loopback, 8080, 8081}, }, }, { @@ -55,7 +55,7 @@ func Test_parsePortForwards(t *testing.T) { tcpSpecs: []string{"[::1]:8080:8081"}, }, want: []portForwardSpec{ - {"tcp", "[::1]:8080", "tcp", "127.0.0.1:8081"}, + {"tcp", ipv6Loopback, 8080, 8081}, }, }, { @@ -64,9 +64,9 @@ func Test_parsePortForwards(t *testing.T) { udpSpecs: []string{"8000,8080-8081"}, }, want: []portForwardSpec{ - {"udp", "127.0.0.1:8000", "udp", "127.0.0.1:8000"}, - {"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8080"}, - {"udp", "127.0.0.1:8081", "udp", "127.0.0.1:8081"}, + {"udp", noAddr, 8000, 8000}, + {"udp", noAddr, 8080, 8080}, + {"udp", noAddr, 8081, 8081}, }, }, { @@ -75,7 +75,7 @@ func Test_parsePortForwards(t *testing.T) { udpSpecs: []string{"127.0.0.1:8080:8081"}, }, want: []portForwardSpec{ - {"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8081"}, + {"udp", ipv4Loopback, 8080, 8081}, }, }, { @@ -84,7 +84,7 @@ func Test_parsePortForwards(t *testing.T) { udpSpecs: []string{"[::1]:8080:8081"}, }, want: []portForwardSpec{ - {"udp", "[::1]:8080", "udp", "127.0.0.1:8081"}, + {"udp", ipv6Loopback, 8080, 8081}, }, }, { diff --git a/cli/portforward_test.go b/cli/portforward_test.go index 916b57105828a..e1672a5927047 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -67,6 +67,17 @@ func TestPortForward(t *testing.T) { }, localAddress: []string{"127.0.0.1:5555", "127.0.0.1:6666"}, }, + { + name: "TCP-opportunistic-ipv6", + network: "tcp", + flag: []string{"--tcp=5566:%v", "--tcp=6655:%v"}, + setupRemote: func(t *testing.T) net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "create TCP listener") + return l + }, + localAddress: []string{"[::1]:5566", "[::1]:6655"}, + }, { name: "UDP", network: "udp", @@ -82,6 +93,21 @@ func TestPortForward(t *testing.T) { }, localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"}, }, + { + name: "UDP-opportunistic-ipv6", + network: "udp", + flag: []string{"--udp=7788:%v", "--udp=8877:%v"}, + setupRemote: func(t *testing.T) net.Listener { + addr := net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + l, err := udp.Listen("udp", &addr) + require.NoError(t, err, "create UDP listener") + return l + }, + localAddress: []string{"[::1]:7788", "[::1]:8877"}, + }, { name: "TCPWithAddress", network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"}, @@ -295,6 +321,63 @@ func TestPortForward(t *testing.T) { require.NoError(t, err) require.Greater(t, updated.LastUsedAt, workspace.LastUsedAt) }) + + t.Run("IPv6Busy", func(t *testing.T) { + t.Parallel() + + remoteLis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "create TCP listener") + p1 := setupTestListener(t, remoteLis) + + // Create a flag that forwards from local 5555 to remote listener port. + flag := fmt.Sprintf("--tcp=5555:%v", p1) + + // Launch port-forward in a goroutine so we can start dialing + // the "local" listener. + inv, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag) + clitest.SetupConfig(t, member, root) + pty := ptytest.New(t) + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() + + iNet := newInProcNet() + inv.Net = iNet + + // listen on port 5555 on IPv6 so it's busy when we try to port forward + busyLis, err := iNet.Listen("tcp", "[::1]:5555") + require.NoError(t, err) + defer busyLis.Close() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + errC := make(chan error) + go func() { + err := inv.WithContext(ctx).Run() + t.Logf("command complete; err=%s", err.Error()) + errC <- err + }() + pty.ExpectMatchContext(ctx, "Ready!") + + // Test IPv4 still works + dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort) + defer dialCtxCancel() + c1, err := iNet.dial(dialCtx, addr{"tcp", "127.0.0.1:5555"}) + require.NoError(t, err, "open connection 1 to 'local' listener") + defer c1.Close() + testDial(t, c1) + + cancel() + err = <-errC + require.ErrorIs(t, err, context.Canceled) + + flushCtx := testutil.Context(t, testutil.WaitShort) + testutil.RequireSendCtx(flushCtx, t, wuTick, dbtime.Now()) + _ = testutil.RequireRecvCtx(flushCtx, t, wuFlush) + updated, err := client.Workspace(context.Background(), workspace.ID) + require.NoError(t, err) + require.Greater(t, updated.LastUsedAt, workspace.LastUsedAt) + }) } // runAgent creates a fake workspace and starts an agent locally for that