Skip to content

chore: move port forwarding out of cli package #19092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 58 additions & 278 deletions cli/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,22 @@ package cli
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"sync"
"syscall"

"golang.org/x/xerrors"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"

"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/portforward"
"github.com/coder/serpent"
)

var (
// noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
// when the local address is not specified in port-forward flags.
noAddr netip.Addr
ipv6Loopback = netip.MustParseAddr("::1")
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
)

func (r *RootCmd) portForward() *serpent.Command {
var (
tcpForwards []string // <port>:<port>
Expand Down Expand Up @@ -76,7 +62,7 @@ func (r *RootCmd) portForward() *serpent.Command {
ctx, cancel := context.WithCancel(inv.Context())
defer cancel()

specs, err := parsePortForwards(tcpForwards, udpForwards)
specs, err := portforward.ParseSpecs(tcpForwards, udpForwards)
if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err)
}
Expand Down Expand Up @@ -127,74 +113,79 @@ func (r *RootCmd) portForward() *serpent.Command {
}
defer conn.Close()

// Start all listeners.
var (
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, 0, len(specs)*2)
closeAllListeners = func() {
logger.Debug(ctx, "closing all listeners")
for _, l := range listeners {
if l == nil {
continue
}
_ = l.Close()
}
// Create port forwarding manager
pfOpts := portforward.Options{
Logger: logger,
Dialer: conn,
Listener: inv.Net,
}
manager := portforward.NewManager(pfOpts)
defer func() {
if stopErr := manager.Stop(); stopErr != nil {
logger.Error(ctx, "failed to stop port forwarding manager", slog.Error(stopErr))
}
)
defer closeAllListeners()
}()

// Create a signal handler for graceful shutdown
shutdownCh := make(chan struct{})
go func() {
defer close(shutdownCh)

// Wait until context is canceled (Ctrl+C, etc.)
<-ctx.Done()
}()

for _, spec := range specs {
if spec.listenHost == noAddr {
if spec.ListenHost == portforward.NoAddr {
// first, opportunistically try to listen on IPv6
spec6 := spec
spec6.listenHost = ipv6Loopback
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
if err6 != nil {
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
spec6.ListenHost = portforward.IPv6Loopback
_, err := manager.Add(spec6)
if err != nil {
logger.Info(ctx, "failed to opportunistically add IPv6 forwarder", slog.F("spec", spec6), slog.Error(err))
} else {
listeners = append(listeners, l6)
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://[%s]:%d' locally to '%s://127.0.0.1:%d' in the workspace\n",
spec6.Network, spec6.ListenHost, spec6.ListenPort, spec6.Network, spec6.DialPort)
}
spec.listenHost = ipv4Loopback
spec.ListenHost = portforward.IPv4Loopback
}
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)

_, err := manager.Add(spec)
if err != nil {
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
logger.Error(ctx, "failed to add forwarder", slog.F("spec", spec), slog.Error(err))
return err
}
listeners = append(listeners, l)
}

stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID)
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s:%d' locally to '%s://127.0.0.1:%d' in the workspace\n",
spec.Network, spec.ListenHost, spec.ListenPort, spec.Network, spec.DialPort)
}

// Wait for the context to be canceled or for a signal and close
// all listeners.
var closeErr error
wg.Add(1)
go func() {
defer wg.Done()
// Start all forwarders at once
err = manager.Start(ctx)
if err != nil {
return xerrors.Errorf("start port forwarding: %w", err)
}

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
conn.AwaitReachable(ctx)
logger.Debug(ctx, "ready to accept connections to forward")
_, _ = fmt.Fprintln(inv.Stderr, "Ready!")

select {
case <-ctx.Done():
logger.Debug(ctx, "command context expired waiting for signal", slog.Error(ctx.Err()))
closeErr = ctx.Err()
case sig := <-sigs:
logger.Debug(ctx, "received signal", slog.F("signal", sig))
_, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections")
}
stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID)
defer stopUpdating()

cancel()
stopUpdating()
closeAllListeners()
}()
// Wait for shutdown signal or context cancellation
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

conn.AwaitReachable(ctx)
logger.Debug(ctx, "read to accept connections to forward")
_, _ = fmt.Fprintln(inv.Stderr, "Ready!")
wg.Wait()
return closeErr
select {
case <-shutdownCh:
logger.Debug(ctx, "context canceled")
return ctx.Err()
case sig := <-sigs:
logger.Debug(ctx, "received signal", slog.F("signal", sig))
_, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections")
return nil
}
},
}

Expand All @@ -217,214 +208,3 @@ func (r *RootCmd) portForward() *serpent.Command {

return cmd
}

func listenAndPortForward(
ctx context.Context,
inv *serpent.Invocation,
conn *workspacesdk.AgentConn,
wg *sync.WaitGroup,
spec portForwardSpec,
logger slog.Logger,
) (net.Listener, error) {
logger = logger.With(
slog.F("network", spec.network),
slog.F("listen_host", spec.listenHost),
slog.F("listen_port", spec.listenPort),
)
listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort)
dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort)
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n",
spec.network, listenAddress, spec.network, dialAddress)

l, err := inv.Net.Listen(spec.network, listenAddress.String())
if err != nil {
return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err)
}
logger.Debug(ctx, "listening")

wg.Add(1)
go func(spec portForwardSpec) {
defer wg.Done()
for {
netConn, err := l.Accept()
if err != nil {
// Silently ignore net.ErrClosed errors.
if xerrors.Is(err, net.ErrClosed) {
logger.Debug(ctx, "listener closed")
return
}
_, _ = fmt.Fprintf(inv.Stderr,
"Error accepting connection from '%s://%s': %v\n",
spec.network, listenAddress.String(), err)
_, _ = fmt.Fprintln(inv.Stderr, "Killing listener")
return
}
logger.Debug(ctx, "accepted connection",
slog.F("remote_addr", netConn.RemoteAddr()))

go func(netConn net.Conn) {
defer netConn.Close()
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
if err != nil {
_, _ = fmt.Fprintf(inv.Stderr,
"Failed to dial '%s://%s' in workspace: %s\n",
spec.network, dialAddress, err)
return
}
defer remoteConn.Close()
logger.Debug(ctx,
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))

agentssh.Bicopy(ctx, netConn, remoteConn)
logger.Debug(ctx,
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
}(netConn)
}
}(spec)

return l, nil
}

type portForwardSpec struct {
network string // tcp, udp
listenHost netip.Addr
listenPort, dialPort uint16
}

func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
specs := []portForwardSpec{}

for _, specEntry := range tcpSpecs {
for _, spec := range strings.Split(specEntry, ",") {
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
if err != nil {
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
}

for _, pfSpec := range pfSpecs {
pfSpec.network = "tcp"
specs = append(specs, pfSpec)
}
}
}

for _, specEntry := range udpSpecs {
for _, spec := range strings.Split(specEntry, ",") {
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
if err != nil {
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
}

for _, pfSpec := range pfSpecs {
pfSpec.network = "udp"
specs = append(specs, pfSpec)
}
}
}

// Check for duplicate entries.
locals := map[string]struct{}{}
for _, spec := range specs {
localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort)
if _, ok := locals[localStr]; ok {
return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort)
}
locals[localStr] = struct{}{}
}

return specs, nil
}

func parsePort(in string) (uint16, error) {
port, err := strconv.ParseUint(strings.TrimSpace(in), 10, 16)
if err != nil {
return 0, xerrors.Errorf("parse port %q: %w", in, err)
}
if port == 0 {
return 0, xerrors.New("port cannot be 0")
}

return uint16(port), nil
}

// specRegexp matches port specs. It handles all the following formats:
//
// 8000
// 8888:9999
// 1-5:6-10
// 8000-8005
// 127.0.0.1:4000:4000
// [::1]:8080:8081
// 127.0.0.1:4000-4005
// [::1]:4000-4001:5000-5001
//
// Important capturing groups:
//
// 2: local IP address (including [] for IPv6)
// 3: local port, or start of local port range
// 5: end of local port range
// 7: remote port, or start of remote port range
// 9: end or remote port range
var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`)

func parseSrcDestPorts(in string) ([]portForwardSpec, error) {
groups := specRegexp.FindStringSubmatch(in)
if len(groups) == 0 {
return nil, xerrors.Errorf("invalid port specification %q", in)
}

var localAddr netip.Addr
if groups[2] != "" {
parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]"))
if err != nil {
return nil, xerrors.Errorf("invalid IP address %q", groups[2])
}
localAddr = parsedAddr
}

local, err := parsePortRange(groups[3], groups[5])
if err != nil {
return nil, xerrors.Errorf("parse local port range from %q: %w", in, err)
}
remote := local
if groups[7] != "" {
remote, err = parsePortRange(groups[7], groups[9])
if err != nil {
return nil, xerrors.Errorf("parse remote port range from %q: %w", in, err)
}
}
if len(local) != len(remote) {
return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote))
}
var out []portForwardSpec
for i := range local {
out = append(out, portForwardSpec{
listenHost: localAddr,
listenPort: local[i],
dialPort: remote[i],
})
}
return out, nil
}

func parsePortRange(s, e string) ([]uint16, error) {
start, err := parsePort(s)
if err != nil {
return nil, xerrors.Errorf("parse range start port from %q: %w", s, err)
}
end := start
if len(e) != 0 {
end, err = parsePort(e)
if err != nil {
return nil, xerrors.Errorf("parse range end port from %q: %w", e, err)
}
}
if end < start {
return nil, xerrors.Errorf("range end port %v is less than start port %v", end, start)
}
var ports []uint16
for i := start; i <= end; i++ {
ports = append(ports, i)
}
return ports, nil
}
Loading
Loading