diff --git a/cli/portforward.go b/cli/portforward.go index 3af3a1ca8411f..68e1eb3d800ef 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -7,6 +7,7 @@ import ( "net/netip" "os" "os/signal" + "regexp" "strconv" "strings" "sync" @@ -263,7 +264,7 @@ func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) { for _, specEntry := range tcpSpecs { for _, spec := range strings.Split(specEntry, ",") { - ports, err := parseSrcDestPorts(spec) + ports, err := parseSrcDestPorts(strings.TrimSpace(spec)) if err != nil { return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err) } @@ -281,7 +282,7 @@ func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) { for _, specEntry := range udpSpecs { for _, spec := range strings.Split(specEntry, ",") { - ports, err := parseSrcDestPorts(spec) + ports, err := parseSrcDestPorts(strings.TrimSpace(spec)) if err != nil { return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err) } @@ -326,63 +327,53 @@ type parsedSrcDestPort struct { local, remote netip.AddrPort } +// specRegexp matches port specs. It handles all the following formats: +// +// 8000 +// 8888:9999 +// 1-5:6-10 +// 8000-8005 +// 127.0.0.1:4000:4000 +// [::1]:8080:8081 +// 127.0.0.1:4000-4005 +// [::1]:4000-4001:5000-5001 +// +// Important capturing groups: +// +// 2: local IP address (including [] for IPv6) +// 3: local port, or start of local port range +// 5: end of local port range +// 7: remote port, or start of remote port range +// 9: end or remote port range +var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`) + func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) { var ( err error - parts = strings.Split(in, ":") localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) ) - - switch len(parts) { - case 1: - // Duplicate the single part - parts = append(parts, parts[0]) - case 2: - // Check to see if the first part is an IP address. - _localAddr, err := netip.ParseAddr(parts[0]) - if err != nil { - break - } - // The first part is the local address, so duplicate the port. - localAddr = _localAddr - parts = []string{parts[1], parts[1]} - - case 3: - _localAddr, err := netip.ParseAddr(parts[0]) - if err != nil { - return nil, xerrors.Errorf("invalid port specification %q; invalid ip %q: %w", in, parts[0], err) - } - localAddr = _localAddr - parts = parts[1:] - - default: + groups := specRegexp.FindStringSubmatch(in) + if len(groups) == 0 { return nil, xerrors.Errorf("invalid port specification %q", in) } - - if !strings.Contains(parts[0], "-") { - localPort, err := parsePort(parts[0]) + if groups[2] != "" { + localAddr, err = netip.ParseAddr(strings.Trim(groups[2], "[]")) if err != nil { - return nil, xerrors.Errorf("parse local port from %q: %w", in, err) + return nil, xerrors.Errorf("invalid IP address %q", groups[2]) } - remotePort, err := parsePort(parts[1]) - if err != nil { - return nil, xerrors.Errorf("parse remote port from %q: %w", in, err) - } - - return []parsedSrcDestPort{{ - local: netip.AddrPortFrom(localAddr, localPort), - remote: netip.AddrPortFrom(remoteAddr, remotePort), - }}, nil } - local, err := parsePortRange(parts[0]) + local, err := parsePortRange(groups[3], groups[5]) if err != nil { return nil, xerrors.Errorf("parse local port range from %q: %w", in, err) } - remote, err := parsePortRange(parts[1]) - if err != nil { - return nil, xerrors.Errorf("parse remote port range from %q: %w", in, err) + remote := local + if groups[7] != "" { + remote, err = parsePortRange(groups[7], groups[9]) + if err != nil { + return nil, xerrors.Errorf("parse remote port range from %q: %w", in, err) + } } if len(local) != len(remote) { return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote)) @@ -397,18 +388,17 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) { return out, nil } -func parsePortRange(in string) ([]uint16, error) { - parts := strings.Split(in, "-") - if len(parts) != 2 { - return nil, xerrors.Errorf("invalid port range specification %q", in) - } - start, err := parsePort(parts[0]) +func parsePortRange(s, e string) ([]uint16, error) { + start, err := parsePort(s) if err != nil { - return nil, xerrors.Errorf("parse range start port from %q: %w", in, err) + return nil, xerrors.Errorf("parse range start port from %q: %w", s, err) } - end, err := parsePort(parts[1]) - if err != nil { - return nil, xerrors.Errorf("parse range end port from %q: %w", in, err) + end := start + if len(e) != 0 { + end, err = parsePort(e) + if err != nil { + return nil, xerrors.Errorf("parse range end port from %q: %w", e, err) + } } if end < start { return nil, xerrors.Errorf("range end port %v is less than start port %v", end, start) diff --git a/cli/portforward_internal_test.go b/cli/portforward_internal_test.go index ad083b8cf0705..08d03f5a95db1 100644 --- a/cli/portforward_internal_test.go +++ b/cli/portforward_internal_test.go @@ -1,8 +1,6 @@ package cli import ( - "fmt" - "strings" "testing" "github.com/stretchr/testify/require" @@ -11,13 +9,6 @@ import ( func Test_parsePortForwards(t *testing.T) { t.Parallel() - portForwardSpecToString := func(v []portForwardSpec) (out []string) { - for _, p := range v { - require.Equal(t, p.listenNetwork, p.dialNetwork) - out = append(out, fmt.Sprintf("%s:%s", strings.Replace(p.listenAddress, "127.0.0.1:", "", 1), strings.Replace(p.dialAddress, "127.0.0.1:", "", 1))) - } - return out - } type args struct { tcpSpecs []string udpSpecs []string @@ -25,7 +16,7 @@ func Test_parsePortForwards(t *testing.T) { tests := []struct { name string args args - want []string + want []portForwardSpec wantErr bool }{ { @@ -34,17 +25,37 @@ func Test_parsePortForwards(t *testing.T) { tcpSpecs: []string{ "8000,8080:8081,9000-9002,9003-9004:9005-9006", "10000", + "4444-4444", }, }, - want: []string{ - "8000:8000", - "8080:8081", - "9000:9000", - "9001:9001", - "9002:9002", - "9003:9005", - "9004:9006", - "10000:10000", + want: []portForwardSpec{ + {"tcp", "127.0.0.1:8000", "tcp", "127.0.0.1:8000"}, + {"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"}, + {"tcp", "127.0.0.1:9000", "tcp", "127.0.0.1:9000"}, + {"tcp", "127.0.0.1:9001", "tcp", "127.0.0.1:9001"}, + {"tcp", "127.0.0.1:9002", "tcp", "127.0.0.1:9002"}, + {"tcp", "127.0.0.1:9003", "tcp", "127.0.0.1:9005"}, + {"tcp", "127.0.0.1:9004", "tcp", "127.0.0.1:9006"}, + {"tcp", "127.0.0.1:10000", "tcp", "127.0.0.1:10000"}, + {"tcp", "127.0.0.1:4444", "tcp", "127.0.0.1:4444"}, + }, + }, + { + name: "TCP IPv4 local", + args: args{ + tcpSpecs: []string{"127.0.0.1:8080:8081"}, + }, + want: []portForwardSpec{ + {"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"}, + }, + }, + { + name: "TCP IPv6 local", + args: args{ + tcpSpecs: []string{"[::1]:8080:8081"}, + }, + want: []portForwardSpec{ + {"tcp", "[::1]:8080", "tcp", "127.0.0.1:8081"}, }, }, { @@ -52,10 +63,28 @@ func Test_parsePortForwards(t *testing.T) { args: args{ udpSpecs: []string{"8000,8080-8081"}, }, - want: []string{ - "8000:8000", - "8080:8080", - "8081:8081", + want: []portForwardSpec{ + {"udp", "127.0.0.1:8000", "udp", "127.0.0.1:8000"}, + {"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8080"}, + {"udp", "127.0.0.1:8081", "udp", "127.0.0.1:8081"}, + }, + }, + { + name: "UDP IPv4 local", + args: args{ + udpSpecs: []string{"127.0.0.1:8080:8081"}, + }, + want: []portForwardSpec{ + {"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8081"}, + }, + }, + { + name: "UDP IPv6 local", + args: args{ + udpSpecs: []string{"[::1]:8080:8081"}, + }, + want: []portForwardSpec{ + {"udp", "[::1]:8080", "udp", "127.0.0.1:8081"}, }, }, { @@ -83,8 +112,7 @@ func Test_parsePortForwards(t *testing.T) { t.Fatalf("parsePortForwards() error = %v, wantErr %v", err, tt.wantErr) return } - gotStrings := portForwardSpecToString(got) - require.Equal(t, tt.want, gotStrings) + require.Equal(t, tt.want, got) }) } } diff --git a/cli/portforward_test.go b/cli/portforward_test.go index 29fccafb20ac1..916b57105828a 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -92,6 +92,16 @@ func TestPortForward(t *testing.T) { }, localAddress: []string{"10.10.10.99:9999", "10.10.10.10:1010"}, }, + { + name: "TCP-IPv6", + network: "tcp", flag: []string{"--tcp=[fe80::99]:9999:%v", "--tcp=[fe80::10]:1010:%v"}, + setupRemote: func(t *testing.T) net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "create TCP listener") + return l + }, + localAddress: []string{"[fe80::99]:9999", "[fe80::10]:1010"}, + }, } // Setup agent once to be shared between test-cases (avoid expensive