diff --git a/cli/portforward.go b/cli/portforward.go index 7a7723213f760..c243f6d35228a 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -3,14 +3,8 @@ package cli import ( "context" "fmt" - "net" - "net/netip" "os" "os/signal" - "regexp" - "strconv" - "strings" - "sync" "syscall" "golang.org/x/xerrors" @@ -18,21 +12,13 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" - "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/portforward" "github.com/coder/serpent" ) -var ( - // noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify - // when the local address is not specified in port-forward flags. - noAddr netip.Addr - ipv6Loopback = netip.MustParseAddr("::1") - ipv4Loopback = netip.MustParseAddr("127.0.0.1") -) - func (r *RootCmd) portForward() *serpent.Command { var ( tcpForwards []string // : @@ -76,7 +62,7 @@ func (r *RootCmd) portForward() *serpent.Command { ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - specs, err := parsePortForwards(tcpForwards, udpForwards) + specs, err := portforward.ParseSpecs(tcpForwards, udpForwards) if err != nil { return xerrors.Errorf("parse port-forward specs: %w", err) } @@ -127,74 +113,79 @@ func (r *RootCmd) portForward() *serpent.Command { } defer conn.Close() - // Start all listeners. - var ( - wg = new(sync.WaitGroup) - listeners = make([]net.Listener, 0, len(specs)*2) - closeAllListeners = func() { - logger.Debug(ctx, "closing all listeners") - for _, l := range listeners { - if l == nil { - continue - } - _ = l.Close() - } + // Create port forwarding manager + pfOpts := portforward.Options{ + Logger: logger, + Dialer: conn, + Listener: inv.Net, + } + manager := portforward.NewManager(pfOpts) + defer func() { + if stopErr := manager.Stop(); stopErr != nil { + logger.Error(ctx, "failed to stop port forwarding manager", slog.Error(stopErr)) } - ) - defer closeAllListeners() + }() + + // Create a signal handler for graceful shutdown + shutdownCh := make(chan struct{}) + go func() { + defer close(shutdownCh) + + // Wait until context is canceled (Ctrl+C, etc.) + <-ctx.Done() + }() for _, spec := range specs { - if spec.listenHost == noAddr { + if spec.ListenHost == portforward.NoAddr { // first, opportunistically try to listen on IPv6 spec6 := spec - spec6.listenHost = ipv6Loopback - l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger) - if err6 != nil { - logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6)) + spec6.ListenHost = portforward.IPv6Loopback + _, err := manager.Add(spec6) + if err != nil { + logger.Info(ctx, "failed to opportunistically add IPv6 forwarder", slog.F("spec", spec6), slog.Error(err)) } else { - listeners = append(listeners, l6) + _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://[%s]:%d' locally to '%s://127.0.0.1:%d' in the workspace\n", + spec6.Network, spec6.ListenHost, spec6.ListenPort, spec6.Network, spec6.DialPort) } - spec.listenHost = ipv4Loopback + spec.ListenHost = portforward.IPv4Loopback } - l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger) + + _, err := manager.Add(spec) if err != nil { - logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err)) + logger.Error(ctx, "failed to add forwarder", slog.F("spec", spec), slog.Error(err)) return err } - listeners = append(listeners, l) - } - stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID) + _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s:%d' locally to '%s://127.0.0.1:%d' in the workspace\n", + spec.Network, spec.ListenHost, spec.ListenPort, spec.Network, spec.DialPort) + } - // Wait for the context to be canceled or for a signal and close - // all listeners. - var closeErr error - wg.Add(1) - go func() { - defer wg.Done() + // Start all forwarders at once + err = manager.Start(ctx) + if err != nil { + return xerrors.Errorf("start port forwarding: %w", err) + } - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + conn.AwaitReachable(ctx) + logger.Debug(ctx, "ready to accept connections to forward") + _, _ = fmt.Fprintln(inv.Stderr, "Ready!") - select { - case <-ctx.Done(): - logger.Debug(ctx, "command context expired waiting for signal", slog.Error(ctx.Err())) - closeErr = ctx.Err() - case sig := <-sigs: - logger.Debug(ctx, "received signal", slog.F("signal", sig)) - _, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections") - } + stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID) + defer stopUpdating() - cancel() - stopUpdating() - closeAllListeners() - }() + // Wait for shutdown signal or context cancellation + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - conn.AwaitReachable(ctx) - logger.Debug(ctx, "read to accept connections to forward") - _, _ = fmt.Fprintln(inv.Stderr, "Ready!") - wg.Wait() - return closeErr + select { + case <-shutdownCh: + logger.Debug(ctx, "context canceled") + return ctx.Err() + case sig := <-sigs: + logger.Debug(ctx, "received signal", slog.F("signal", sig)) + _, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections") + return nil + } }, } @@ -217,214 +208,3 @@ func (r *RootCmd) portForward() *serpent.Command { return cmd } - -func listenAndPortForward( - ctx context.Context, - inv *serpent.Invocation, - conn *workspacesdk.AgentConn, - wg *sync.WaitGroup, - spec portForwardSpec, - logger slog.Logger, -) (net.Listener, error) { - logger = logger.With( - slog.F("network", spec.network), - slog.F("listen_host", spec.listenHost), - slog.F("listen_port", spec.listenPort), - ) - listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort) - dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort) - _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n", - spec.network, listenAddress, spec.network, dialAddress) - - l, err := inv.Net.Listen(spec.network, listenAddress.String()) - if err != nil { - return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err) - } - logger.Debug(ctx, "listening") - - wg.Add(1) - go func(spec portForwardSpec) { - defer wg.Done() - for { - netConn, err := l.Accept() - if err != nil { - // Silently ignore net.ErrClosed errors. - if xerrors.Is(err, net.ErrClosed) { - logger.Debug(ctx, "listener closed") - return - } - _, _ = fmt.Fprintf(inv.Stderr, - "Error accepting connection from '%s://%s': %v\n", - spec.network, listenAddress.String(), err) - _, _ = fmt.Fprintln(inv.Stderr, "Killing listener") - return - } - logger.Debug(ctx, "accepted connection", - slog.F("remote_addr", netConn.RemoteAddr())) - - go func(netConn net.Conn) { - defer netConn.Close() - remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress) - if err != nil { - _, _ = fmt.Fprintf(inv.Stderr, - "Failed to dial '%s://%s' in workspace: %s\n", - spec.network, dialAddress, err) - return - } - defer remoteConn.Close() - logger.Debug(ctx, - "dialed remote", slog.F("remote_addr", netConn.RemoteAddr())) - - agentssh.Bicopy(ctx, netConn, remoteConn) - logger.Debug(ctx, - "connection closing", slog.F("remote_addr", netConn.RemoteAddr())) - }(netConn) - } - }(spec) - - return l, nil -} - -type portForwardSpec struct { - network string // tcp, udp - listenHost netip.Addr - listenPort, dialPort uint16 -} - -func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) { - specs := []portForwardSpec{} - - for _, specEntry := range tcpSpecs { - for _, spec := range strings.Split(specEntry, ",") { - pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec)) - if err != nil { - return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err) - } - - for _, pfSpec := range pfSpecs { - pfSpec.network = "tcp" - specs = append(specs, pfSpec) - } - } - } - - for _, specEntry := range udpSpecs { - for _, spec := range strings.Split(specEntry, ",") { - pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec)) - if err != nil { - return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err) - } - - for _, pfSpec := range pfSpecs { - pfSpec.network = "udp" - specs = append(specs, pfSpec) - } - } - } - - // Check for duplicate entries. - locals := map[string]struct{}{} - for _, spec := range specs { - localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort) - if _, ok := locals[localStr]; ok { - return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort) - } - locals[localStr] = struct{}{} - } - - return specs, nil -} - -func parsePort(in string) (uint16, error) { - port, err := strconv.ParseUint(strings.TrimSpace(in), 10, 16) - if err != nil { - return 0, xerrors.Errorf("parse port %q: %w", in, err) - } - if port == 0 { - return 0, xerrors.New("port cannot be 0") - } - - return uint16(port), nil -} - -// specRegexp matches port specs. It handles all the following formats: -// -// 8000 -// 8888:9999 -// 1-5:6-10 -// 8000-8005 -// 127.0.0.1:4000:4000 -// [::1]:8080:8081 -// 127.0.0.1:4000-4005 -// [::1]:4000-4001:5000-5001 -// -// Important capturing groups: -// -// 2: local IP address (including [] for IPv6) -// 3: local port, or start of local port range -// 5: end of local port range -// 7: remote port, or start of remote port range -// 9: end or remote port range -var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`) - -func parseSrcDestPorts(in string) ([]portForwardSpec, error) { - groups := specRegexp.FindStringSubmatch(in) - if len(groups) == 0 { - return nil, xerrors.Errorf("invalid port specification %q", in) - } - - var localAddr netip.Addr - if groups[2] != "" { - parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]")) - if err != nil { - return nil, xerrors.Errorf("invalid IP address %q", groups[2]) - } - localAddr = parsedAddr - } - - local, err := parsePortRange(groups[3], groups[5]) - if err != nil { - return nil, xerrors.Errorf("parse local port range from %q: %w", in, err) - } - remote := local - if groups[7] != "" { - remote, err = parsePortRange(groups[7], groups[9]) - if err != nil { - return nil, xerrors.Errorf("parse remote port range from %q: %w", in, err) - } - } - if len(local) != len(remote) { - return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote)) - } - var out []portForwardSpec - for i := range local { - out = append(out, portForwardSpec{ - listenHost: localAddr, - listenPort: local[i], - dialPort: remote[i], - }) - } - return out, nil -} - -func parsePortRange(s, e string) ([]uint16, error) { - start, err := parsePort(s) - if err != nil { - return nil, xerrors.Errorf("parse range start port from %q: %w", s, err) - } - end := start - if len(e) != 0 { - end, err = parsePort(e) - if err != nil { - return nil, xerrors.Errorf("parse range end port from %q: %w", e, err) - } - } - if end < start { - return nil, xerrors.Errorf("range end port %v is less than start port %v", end, start) - } - var ports []uint16 - for i := start; i <= end; i++ { - ports = append(ports, i) - } - return ports, nil -} diff --git a/portforward/forwarder.go b/portforward/forwarder.go new file mode 100644 index 0000000000000..10bfc28815737 --- /dev/null +++ b/portforward/forwarder.go @@ -0,0 +1,272 @@ +package portforward + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "sync/atomic" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/v2/agent/agentssh" +) + +// Spec represents a port forwarding specification. +type Spec struct { + Network string // tcp, udp + ListenHost netip.Addr // Local address to bind to + ListenPort uint16 // Local port to listen on + DialPort uint16 // Remote port to connect to +} + +// Forwarder handles a single port forward. +type Forwarder interface { + // Start begins the port forwarding operation. + Start(ctx context.Context) error + // Stop stops the port forwarding operation. + Stop() error + // IsActive returns true if the forwarder is currently active. + IsActive() bool + // Spec returns the port forwarding specification. + Spec() Spec +} + +// Manager manages multiple port forwards. +type Manager interface { + // Add adds a new port forward. + Add(spec Spec) (Forwarder, error) + // Remove removes an existing port forward. + Remove(spec Spec) error + // List returns all active port forwards. + List() []Forwarder + // Start starts all port forwards. + Start(ctx context.Context) error + // Stop stops all port forwards. + Stop() error +} + +// Dialer provides network dialing capabilities. +type Dialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// Listener provides network listening capabilities. +type Listener interface { + Listen(network, address string) (net.Listener, error) +} + +// Options configures port forwarding behavior. +type Options struct { + Logger slog.Logger + Dialer Dialer + Listener Listener +} + +// LocalForwarder implements a single port forward from local to remote. +type LocalForwarder struct { + spec Spec + opts Options + listener net.Listener + active atomic.Bool + cancel context.CancelFunc + wg sync.WaitGroup +} + +// NewLocal creates a new local port forwarder. +func NewLocal(spec Spec, opts Options) *LocalForwarder { + return &LocalForwarder{ + spec: spec, + opts: opts, + } +} + +func (f *LocalForwarder) Start(ctx context.Context) error { + if f.active.Load() { + return xerrors.New("forwarder is already active") + } + + ctx, cancel := context.WithCancel(ctx) + f.cancel = cancel + + logger := f.opts.Logger.With( + slog.F("network", f.spec.Network), + slog.F("listen_host", f.spec.ListenHost), + slog.F("listen_port", f.spec.ListenPort), + ) + + listenAddress := netip.AddrPortFrom(f.spec.ListenHost, f.spec.ListenPort) + dialAddress := fmt.Sprintf("127.0.0.1:%d", f.spec.DialPort) + + l, err := f.opts.Listener.Listen(f.spec.Network, listenAddress.String()) + if err != nil { + cancel() + return xerrors.Errorf("listen '%s://%s': %w", f.spec.Network, listenAddress.String(), err) + } + f.listener = l + logger.Debug(ctx, "listening") + + f.active.Store(true) + + f.wg.Add(1) + go func() { + defer func() { + f.wg.Done() + f.active.Store(false) + }() + + for { + netConn, err := l.Accept() + if err != nil { + // Silently ignore net.ErrClosed errors. + if xerrors.Is(err, net.ErrClosed) { + logger.Debug(ctx, "listener closed") + return + } + logger.Error(ctx, "error accepting connection", + slog.F("listen_address", listenAddress.String()), + slog.Error(err)) + return + } + logger.Debug(ctx, "accepted connection", + slog.F("remote_addr", netConn.RemoteAddr())) + + go func(netConn net.Conn) { + defer netConn.Close() + remoteConn, err := f.opts.Dialer.DialContext(ctx, f.spec.Network, dialAddress) + if err != nil { + logger.Error(ctx, "failed to dial remote", + slog.F("dial_address", dialAddress), + slog.Error(err)) + return + } + defer remoteConn.Close() + logger.Debug(ctx, "dialed remote", + slog.F("remote_addr", netConn.RemoteAddr())) + + agentssh.Bicopy(ctx, netConn, remoteConn) + logger.Debug(ctx, "connection closing", + slog.F("remote_addr", netConn.RemoteAddr())) + }(netConn) + } + }() + + return nil +} + +func (f *LocalForwarder) Stop() error { + if !f.active.Load() { + return nil + } + + if f.cancel != nil { + f.cancel() + } + if f.listener != nil { + _ = f.listener.Close() + } + f.wg.Wait() + return nil +} + +func (f *LocalForwarder) IsActive() bool { + return f.active.Load() +} + +func (f *LocalForwarder) Spec() Spec { + return f.spec +} + +// manager implements the Manager interface. +type manager struct { + forwarders map[string]Forwarder + opts Options + mu sync.RWMutex +} + +// NewManager creates a new port forwarding manager. +func NewManager(opts Options) Manager { + return &manager{ + forwarders: make(map[string]Forwarder), + opts: opts, + } +} + +func (m *manager) Add(spec Spec) (Forwarder, error) { + m.mu.Lock() + defer m.mu.Unlock() + + key := fmt.Sprintf("%s:%s:%d", spec.Network, spec.ListenHost, spec.ListenPort) + if _, exists := m.forwarders[key]; exists { + return nil, xerrors.Errorf("forwarder already exists for %s", key) + } + + // Test if we can actually bind to the port before adding the forwarder + listenAddress := netip.AddrPortFrom(spec.ListenHost, spec.ListenPort) + testListener, err := m.opts.Listener.Listen(spec.Network, listenAddress.String()) + if err != nil { + return nil, xerrors.Errorf("cannot bind to '%s://%s': %w", spec.Network, listenAddress.String(), err) + } + // Close the test listener immediately since we just wanted to verify we can bind + _ = testListener.Close() + + forwarder := NewLocal(spec, m.opts) + m.forwarders[key] = forwarder + return forwarder, nil +} + +func (m *manager) Remove(spec Spec) error { + m.mu.Lock() + defer m.mu.Unlock() + + key := fmt.Sprintf("%s:%s:%d", spec.Network, spec.ListenHost, spec.ListenPort) + forwarder, exists := m.forwarders[key] + if !exists { + return xerrors.Errorf("forwarder not found for %s", key) + } + + err := forwarder.Stop() + delete(m.forwarders, key) + return err +} + +func (m *manager) List() []Forwarder { + m.mu.RLock() + defer m.mu.RUnlock() + + forwarders := make([]Forwarder, 0, len(m.forwarders)) + for _, f := range m.forwarders { + forwarders = append(forwarders, f) + } + return forwarders +} + +func (m *manager) Start(ctx context.Context) error { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, forwarder := range m.forwarders { + if !forwarder.IsActive() { + if err := forwarder.Start(ctx); err != nil { + return err + } + } + } + return nil +} + +func (m *manager) Stop() error { + m.mu.RLock() + defer m.mu.RUnlock() + + var lastErr error + for _, forwarder := range m.forwarders { + if err := forwarder.Stop(); err != nil { + lastErr = err + } + } + return lastErr +} diff --git a/portforward/parse.go b/portforward/parse.go new file mode 100644 index 0000000000000..d0f4a49713e56 --- /dev/null +++ b/portforward/parse.go @@ -0,0 +1,158 @@ +package portforward + +import ( + "fmt" + "net/netip" + "regexp" + "strconv" + "strings" + + "golang.org/x/xerrors" +) + +// Constants for default addresses. +var ( + // NoAddr is the zero-value of netip.Addr, used when no local address is specified. + NoAddr = netip.Addr{} + IPv6Loopback = netip.MustParseAddr("::1") + IPv4Loopback = netip.MustParseAddr("127.0.0.1") +) + +// specRegexp matches port specs. It handles all the following formats: +// +// 8000 +// 8888:9999 +// 1-5:6-10 +// 8000-8005 +// 127.0.0.1:4000:4000 +// [::1]:8080:8081 +// 127.0.0.1:4000-4005 +// [::1]:4000-4001:5000-5001 +// +// Important capturing groups: +// +// 2: local IP address (including [] for IPv6) +// 3: local port, or start of local port range +// 5: end of local port range +// 7: remote port, or start of remote port range +// 9: end or remote port range +var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`) + +// ParseSpecs parses TCP and UDP port forwarding specifications. +func ParseSpecs(tcpSpecs, udpSpecs []string) ([]Spec, error) { + specs := []Spec{} + + for _, specEntry := range tcpSpecs { + for _, spec := range strings.Split(specEntry, ",") { + pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec)) + if err != nil { + return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err) + } + + for _, pfSpec := range pfSpecs { + pfSpec.Network = "tcp" + specs = append(specs, pfSpec) + } + } + } + + for _, specEntry := range udpSpecs { + for _, spec := range strings.Split(specEntry, ",") { + pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec)) + if err != nil { + return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err) + } + + for _, pfSpec := range pfSpecs { + pfSpec.Network = "udp" + specs = append(specs, pfSpec) + } + } + } + + // Check for duplicate entries. + locals := map[string]struct{}{} + for _, spec := range specs { + localStr := fmt.Sprintf("%s:%s:%d", spec.Network, spec.ListenHost, spec.ListenPort) + if _, ok := locals[localStr]; ok { + return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.Network, spec.ListenHost, spec.ListenPort) + } + locals[localStr] = struct{}{} + } + + return specs, nil +} + +func parsePort(in string) (uint16, error) { + port, err := strconv.ParseUint(strings.TrimSpace(in), 10, 16) + if err != nil { + return 0, xerrors.Errorf("parse port %q: %w", in, err) + } + if port == 0 { + return 0, xerrors.New("port cannot be 0") + } + + return uint16(port), nil +} + +func parseSrcDestPorts(in string) ([]Spec, error) { + groups := specRegexp.FindStringSubmatch(in) + if len(groups) == 0 { + return nil, xerrors.Errorf("invalid port specification %q", in) + } + + var localAddr netip.Addr + if groups[2] != "" { + parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]")) + if err != nil { + return nil, xerrors.Errorf("invalid IP address %q", groups[2]) + } + localAddr = parsedAddr + } + + local, err := parsePortRange(groups[3], groups[5]) + if err != nil { + return nil, xerrors.Errorf("parse local port range from %q: %w", in, err) + } + remote := local + if groups[7] != "" { + remote, err = parsePortRange(groups[7], groups[9]) + if err != nil { + return nil, xerrors.Errorf("parse remote port range from %q: %w", in, err) + } + } + if len(local) != len(remote) { + return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote)) + } + var out []Spec + for i := range local { + out = append(out, Spec{ + ListenHost: localAddr, + ListenPort: local[i], + DialPort: remote[i], + }) + } + return out, nil +} + +func parsePortRange(s, e string) ([]uint16, error) { + start, err := parsePort(s) + if err != nil { + return nil, xerrors.Errorf("parse range start port from %q: %w", s, err) + } + end := start + if len(e) != 0 { + end, err = parsePort(e) + if err != nil { + return nil, xerrors.Errorf("parse range end port from %q: %w", e, err) + } + } + if end < start { + return nil, xerrors.Errorf("range end port %v is less than start port %v", end, start) + } + var ports []uint16 + for i := start; i <= end; i++ { + ports = append(ports, i) + } + return ports, nil +} diff --git a/cli/portforward_internal_test.go b/portforward/parse_internal_test.go similarity index 58% rename from cli/portforward_internal_test.go rename to portforward/parse_internal_test.go index 5698363f95e5e..c63df7aa3bae0 100644 --- a/cli/portforward_internal_test.go +++ b/portforward/parse_internal_test.go @@ -1,4 +1,4 @@ -package cli +package portforward import ( "testing" @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" ) -func Test_parsePortForwards(t *testing.T) { +func Test_ParseSpecs(t *testing.T) { t.Parallel() type args struct { @@ -16,7 +16,7 @@ func Test_parsePortForwards(t *testing.T) { tests := []struct { name string args args - want []portForwardSpec + want []Spec wantErr bool }{ { @@ -28,16 +28,16 @@ func Test_parsePortForwards(t *testing.T) { "4444-4444", }, }, - want: []portForwardSpec{ - {"tcp", noAddr, 8000, 8000}, - {"tcp", noAddr, 8080, 8081}, - {"tcp", noAddr, 9000, 9000}, - {"tcp", noAddr, 9001, 9001}, - {"tcp", noAddr, 9002, 9002}, - {"tcp", noAddr, 9003, 9005}, - {"tcp", noAddr, 9004, 9006}, - {"tcp", noAddr, 10000, 10000}, - {"tcp", noAddr, 4444, 4444}, + want: []Spec{ + {"tcp", NoAddr, 8000, 8000}, + {"tcp", NoAddr, 8080, 8081}, + {"tcp", NoAddr, 9000, 9000}, + {"tcp", NoAddr, 9001, 9001}, + {"tcp", NoAddr, 9002, 9002}, + {"tcp", NoAddr, 9003, 9005}, + {"tcp", NoAddr, 9004, 9006}, + {"tcp", NoAddr, 10000, 10000}, + {"tcp", NoAddr, 4444, 4444}, }, }, { @@ -45,8 +45,8 @@ func Test_parsePortForwards(t *testing.T) { args: args{ tcpSpecs: []string{"127.0.0.1:8080:8081"}, }, - want: []portForwardSpec{ - {"tcp", ipv4Loopback, 8080, 8081}, + want: []Spec{ + {"tcp", IPv4Loopback, 8080, 8081}, }, }, { @@ -54,8 +54,8 @@ func Test_parsePortForwards(t *testing.T) { args: args{ tcpSpecs: []string{"[::1]:8080:8081"}, }, - want: []portForwardSpec{ - {"tcp", ipv6Loopback, 8080, 8081}, + want: []Spec{ + {"tcp", IPv6Loopback, 8080, 8081}, }, }, { @@ -63,10 +63,10 @@ func Test_parsePortForwards(t *testing.T) { args: args{ udpSpecs: []string{"8000,8080-8081"}, }, - want: []portForwardSpec{ - {"udp", noAddr, 8000, 8000}, - {"udp", noAddr, 8080, 8080}, - {"udp", noAddr, 8081, 8081}, + want: []Spec{ + {"udp", NoAddr, 8000, 8000}, + {"udp", NoAddr, 8080, 8080}, + {"udp", NoAddr, 8081, 8081}, }, }, { @@ -74,8 +74,8 @@ func Test_parsePortForwards(t *testing.T) { args: args{ udpSpecs: []string{"127.0.0.1:8080:8081"}, }, - want: []portForwardSpec{ - {"udp", ipv4Loopback, 8080, 8081}, + want: []Spec{ + {"udp", IPv4Loopback, 8080, 8081}, }, }, { @@ -83,8 +83,8 @@ func Test_parsePortForwards(t *testing.T) { args: args{ udpSpecs: []string{"[::1]:8080:8081"}, }, - want: []portForwardSpec{ - {"udp", ipv6Loopback, 8080, 8081}, + want: []Spec{ + {"udp", IPv6Loopback, 8080, 8081}, }, }, { @@ -106,9 +106,9 @@ func Test_parsePortForwards(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := parsePortForwards(tt.args.tcpSpecs, tt.args.udpSpecs) + got, err := ParseSpecs(tt.args.tcpSpecs, tt.args.udpSpecs) if (err != nil) != tt.wantErr { - t.Fatalf("parsePortForwards() error = %v, wantErr %v", err, tt.wantErr) + t.Fatalf("ParseSpecs() error = %v, wantErr %v", err, tt.wantErr) return } require.Equal(t, tt.want, got)