diff --git a/cli/clibase/cmd.go b/cli/clibase/cmd.go index ba359faae6246..c21bc38684618 100644 --- a/cli/clibase/cmd.go +++ b/cli/clibase/cmd.go @@ -189,6 +189,7 @@ type Invocation struct { Stderr io.Writer Stdin io.Reader Logger slog.Logger + Net Net // testing signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) @@ -203,6 +204,7 @@ func (inv *Invocation) WithOS() *Invocation { i.Stdin = os.Stdin i.Args = os.Args[1:] i.Environ = ParseEnviron(os.Environ(), "") + i.Net = osNet{} }) } diff --git a/cli/clibase/net.go b/cli/clibase/net.go new file mode 100644 index 0000000000000..583343407b45f --- /dev/null +++ b/cli/clibase/net.go @@ -0,0 +1,50 @@ +package clibase + +import ( + "net" + "strconv" + + "github.com/pion/udp" + "golang.org/x/xerrors" +) + +// Net abstracts CLI commands interacting with the operating system networking. +// +// At present, it covers opening local listening sockets, since doing this +// in testing is a challenge without flakes, since it's hard to pick a port we +// know a priori will be free. +type Net interface { + // Listen has the same semantics as `net.Listen` but also supports `udp` + Listen(network, address string) (net.Listener, error) +} + +// osNet is an implementation that call the real OS for networking. +type osNet struct{} + +func (osNet) Listen(network, address string) (net.Listener, error) { + switch network { + case "tcp", "tcp4", "tcp6", "unix", "unixpacket": + return net.Listen(network, address) + case "udp": + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, xerrors.Errorf("split %q: %w", address, err) + } + + var portInt int + portInt, err = strconv.Atoi(port) + if err != nil { + return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, address, err) + } + + // Use pion here so that we get a stream-style net.Conn listener, instead + // of a packet-oriented connection that can read and write to multiple + // addresses. + return udp.Listen(network, &net.UDPAddr{ + IP: net.ParseIP(host), + Port: portInt, + }) + default: + return nil, xerrors.Errorf("unknown listen network %q", network) + } +} diff --git a/cli/portforward.go b/cli/portforward.go index a42765e3f918d..c26c12d75166f 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -12,7 +12,6 @@ import ( "sync" "syscall" - "github.com/pion/udp" "golang.org/x/xerrors" "cdr.dev/slog" @@ -121,6 +120,7 @@ func (r *RootCmd) portForward() *clibase.Cmd { wg = new(sync.WaitGroup) listeners = make([]net.Listener, len(specs)) closeAllListeners = func() { + logger.Debug(ctx, "closing all listeners") for _, l := range listeners { if l == nil { continue @@ -134,6 +134,7 @@ func (r *RootCmd) portForward() *clibase.Cmd { for i, spec := range specs { l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger) if err != nil { + logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err)) return err } listeners[i] = l @@ -151,8 +152,10 @@ func (r *RootCmd) portForward() *clibase.Cmd { select { case <-ctx.Done(): + logger.Debug(ctx, "command context expired waiting for signal", slog.Error(ctx.Err())) closeErr = ctx.Err() - case <-sigs: + case sig := <-sigs: + logger.Debug(ctx, "received signal", slog.F("signal", sig)) _, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections") } @@ -161,6 +164,7 @@ func (r *RootCmd) portForward() *clibase.Cmd { }() conn.AwaitReachable(ctx) + logger.Debug(ctx, "read to accept connections to forward") _, _ = fmt.Fprintln(inv.Stderr, "Ready!") wg.Wait() return closeErr @@ -198,33 +202,7 @@ func listenAndPortForward( logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress)) _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) - var ( - l net.Listener - err error - ) - switch spec.listenNetwork { - case "tcp": - l, err = net.Listen(spec.listenNetwork, spec.listenAddress) - case "udp": - var host, port string - host, port, err = net.SplitHostPort(spec.listenAddress) - if err != nil { - return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err) - } - - var portInt int - portInt, err = strconv.Atoi(port) - if err != nil { - return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err) - } - - l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ - IP: net.ParseIP(host), - Port: portInt, - }) - default: - return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork) - } + l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress) if err != nil { return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err) } diff --git a/cli/portforward_test.go b/cli/portforward_test.go index 38971a0c89ba4..b211a840dd870 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -13,6 +13,7 @@ import ( "github.com/pion/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent/agenttest" @@ -45,47 +46,35 @@ func TestPortForward_None(t *testing.T) { pty.ExpectMatch("port-forward ") } -//nolint:tparallel,paralleltest // Subtests require setup that must not be done in parallel. func TestPortForward(t *testing.T) { + t.Parallel() cases := []struct { name string network string - // The flag to pass to `coder port-forward X` to port-forward this type - // of connection. Has two format args (both strings), the first is the - // local address and the second is the remote address. - flag string + // The flag(s) to pass to `coder port-forward X` to port-forward this type + // of connection. Has one format arg (string) for the remote address. + flag []string // setupRemote creates a "remote" listener to emulate a service in the // workspace. setupRemote func(t *testing.T) net.Listener - // setupLocal returns an available port that the - // port-forward command will listen on "locally". Returns the address - // you pass to net.Dial, and the port/path you pass to `coder - // port-forward`. - setupLocal func(t *testing.T) (string, string) + // the local address(es) to "dial" + localAddress []string }{ { name: "TCP", network: "tcp", - flag: "--tcp=%v:%v", + flag: []string{"--tcp=5555:%v", "--tcp=6666:%v"}, setupRemote: func(t *testing.T) net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "create TCP listener") return l }, - setupLocal: func(t *testing.T) (string, string) { - l, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err, "create TCP listener to generate random port") - defer l.Close() - - _, port, err := net.SplitHostPort(l.Addr().String()) - require.NoErrorf(t, err, "split TCP address %q", l.Addr().String()) - return l.Addr().String(), port - }, + localAddress: []string{"127.0.0.1:5555", "127.0.0.1:6666"}, }, { name: "UDP", network: "udp", - flag: "--udp=%v:%v", + flag: []string{"--udp=7777:%v", "--udp=8888:%v"}, setupRemote: func(t *testing.T) net.Listener { addr := net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), @@ -95,38 +84,17 @@ func TestPortForward(t *testing.T) { require.NoError(t, err, "create UDP listener") return l }, - setupLocal: func(t *testing.T) (string, string) { - addr := net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 0, - } - l, err := udp.Listen("udp", &addr) - require.NoError(t, err, "create UDP listener to generate random port") - defer l.Close() - - _, port, err := net.SplitHostPort(l.Addr().String()) - require.NoErrorf(t, err, "split UDP address %q", l.Addr().String()) - return l.Addr().String(), port - }, + localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"}, }, { name: "TCPWithAddress", - network: "tcp", - flag: "--tcp=%v:%v", + network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"}, setupRemote: func(t *testing.T) net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "create TCP listener") return l }, - setupLocal: func(t *testing.T) (string, string) { - l, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err, "create TCP listener to generate random port") - defer l.Close() - - _, port, err := net.SplitHostPort(l.Addr().String()) - require.NoErrorf(t, err, "split TCP address %q", l.Addr().String()) - return l.Addr().String(), fmt.Sprint("0.0.0.0:", port) - }, + localAddress: []string{"10.10.10.99:9999", "10.10.10.10:1010"}, }, } @@ -141,16 +109,12 @@ func TestPortForward(t *testing.T) { for _, c := range cases { c := c - // No parallel tests here because setupLocal reserves - // a free open port which is not guaranteed to be free - // between the listener closing and port-forward ready. - //nolint:tparallel,paralleltest t.Run(c.name+"_OnePort", func(t *testing.T) { + t.Parallel() p1 := setupTestListener(t, c.setupRemote(t)) // Create a flag that forwards from local to listener 1. - localAddress, localFlag := c.setupLocal(t) - flag := fmt.Sprintf(c.flag, localFlag, p1) + flag := fmt.Sprintf(c.flag[0], p1) // Launch port-forward in a goroutine so we can start dialing // the "local" listener. @@ -160,21 +124,27 @@ func TestPortForward(t *testing.T) { inv.Stdin = pty.Input() inv.Stdout = pty.Output() inv.Stderr = pty.Output() + + iNet := newInProcNet() + inv.Net = iNet ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() errC := make(chan error) go func() { - errC <- inv.WithContext(ctx).Run() + err := inv.WithContext(ctx).Run() + t.Logf("command complete; err=%s", err.Error()) + errC <- err }() pty.ExpectMatchContext(ctx, "Ready!") // Open two connections simultaneously and test them out of // sync. - d := net.Dialer{Timeout: testutil.WaitShort} - c1, err := d.DialContext(ctx, c.network, localAddress) + dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort) + defer dialCtxCancel() + c1, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[0]}) require.NoError(t, err, "open connection 1 to 'local' listener") defer c1.Close() - c2, err := d.DialContext(ctx, c.network, localAddress) + c2, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[0]}) require.NoError(t, err, "open connection 2 to 'local' listener") defer c2.Close() testDial(t, c2) @@ -185,21 +155,16 @@ func TestPortForward(t *testing.T) { require.ErrorIs(t, err, context.Canceled) }) - // No parallel tests here because setupLocal reserves - // a free open port which is not guaranteed to be free - // between the listener closing and port-forward ready. - //nolint:tparallel,paralleltest t.Run(c.name+"_TwoPorts", func(t *testing.T) { + t.Parallel() var ( p1 = setupTestListener(t, c.setupRemote(t)) p2 = setupTestListener(t, c.setupRemote(t)) ) // Create a flags for listener 1 and listener 2. - localAddress1, localFlag1 := c.setupLocal(t) - localAddress2, localFlag2 := c.setupLocal(t) - flag1 := fmt.Sprintf(c.flag, localFlag1, p1) - flag2 := fmt.Sprintf(c.flag, localFlag2, p2) + flag1 := fmt.Sprintf(c.flag[0], p1) + flag2 := fmt.Sprintf(c.flag[1], p2) // Launch port-forward in a goroutine so we can start dialing // the "local" listeners. @@ -209,6 +174,9 @@ func TestPortForward(t *testing.T) { inv.Stdin = pty.Input() inv.Stdout = pty.Output() inv.Stderr = pty.Output() + + iNet := newInProcNet() + inv.Net = iNet ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() errC := make(chan error) @@ -219,11 +187,12 @@ func TestPortForward(t *testing.T) { // Open a connection to both listener 1 and 2 simultaneously and // then test them out of order. - d := net.Dialer{Timeout: testutil.WaitShort} - c1, err := d.DialContext(ctx, c.network, localAddress1) + dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort) + defer dialCtxCancel() + c1, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[0]}) require.NoError(t, err, "open connection 1 to 'local' listener 1") defer c1.Close() - c2, err := d.DialContext(ctx, c.network, localAddress2) + c2, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[1]}) require.NoError(t, err, "open connection 2 to 'local' listener 2") defer c2.Close() testDial(t, c2) @@ -235,12 +204,8 @@ func TestPortForward(t *testing.T) { }) } - // Test doing TCP and UDP at the same time. - // No parallel tests here because setupLocal reserves - // a free open port which is not guaranteed to be free - // between the listener closing and port-forward ready. - //nolint:tparallel,paralleltest t.Run("All", func(t *testing.T) { + t.Parallel() var ( dials = []addr{} flags = []string{} @@ -250,12 +215,11 @@ func TestPortForward(t *testing.T) { for _, c := range cases { p := setupTestListener(t, c.setupRemote(t)) - localAddress, localFlag := c.setupLocal(t) dials = append(dials, addr{ network: c.network, - addr: localAddress, + addr: c.localAddress[0], }) - flags = append(flags, fmt.Sprintf(c.flag, localFlag, p)) + flags = append(flags, fmt.Sprintf(c.flag[0], p)) } // Launch port-forward in a goroutine so we can start dialing @@ -264,6 +228,9 @@ func TestPortForward(t *testing.T) { clitest.SetupConfig(t, member, root) pty := ptytest.New(t).Attach(inv) inv.Stderr = pty.Output() + + iNet := newInProcNet() + inv.Net = iNet ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() errC := make(chan error) @@ -274,11 +241,12 @@ func TestPortForward(t *testing.T) { // Open connections to all items in the "dial" array. var ( - d = net.Dialer{Timeout: testutil.WaitShort} - conns = make([]net.Conn, len(dials)) + dialCtx, dialCtxCancel = context.WithTimeout(ctx, testutil.WaitShort) + conns = make([]net.Conn, len(dials)) ) + defer dialCtxCancel() for i, a := range dials { - c, err := d.DialContext(ctx, a.network, a.addr) + c, err := iNet.dial(dialCtx, a) require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1) t.Cleanup(func() { _ = c.Close() @@ -396,3 +364,90 @@ type addr struct { network string addr string } + +func (a addr) Network() string { + return a.network +} + +func (a addr) Address() string { + return a.addr +} + +func (a addr) String() string { + return a.network + "|" + a.addr +} + +type inProcNet struct { + sync.Mutex + + listeners map[addr]*inProcListener +} + +type inProcListener struct { + c chan net.Conn + n *inProcNet + a addr + o sync.Once +} + +func newInProcNet() *inProcNet { + return &inProcNet{listeners: make(map[addr]*inProcListener)} +} + +func (n *inProcNet) Listen(network, address string) (net.Listener, error) { + a := addr{network, address} + n.Lock() + defer n.Unlock() + if _, ok := n.listeners[a]; ok { + return nil, xerrors.New("busy") + } + l := newInProcListener(n, a) + n.listeners[a] = l + return l, nil +} + +func (n *inProcNet) dial(ctx context.Context, a addr) (net.Conn, error) { + n.Lock() + defer n.Unlock() + l, ok := n.listeners[a] + if !ok { + return nil, xerrors.Errorf("nothing listening on %s", a) + } + x, y := net.Pipe() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case l.c <- x: + return y, nil + } +} + +func newInProcListener(n *inProcNet, a addr) *inProcListener { + return &inProcListener{ + c: make(chan net.Conn), + n: n, + a: a, + } +} + +func (l *inProcListener) Accept() (net.Conn, error) { + c, ok := <-l.c + if !ok { + return nil, net.ErrClosed + } + return c, nil +} + +func (l *inProcListener) Close() error { + l.o.Do(func() { + l.n.Lock() + defer l.n.Unlock() + delete(l.n.listeners, l.a) + close(l.c) + }) + return nil +} + +func (l *inProcListener) Addr() net.Addr { + return l.a +}