Skip to content

fix: use fake local network for port-forward tests #11119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cli/clibase/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ type Invocation struct {
Stderr io.Writer
Stdin io.Reader
Logger slog.Logger
Net Net

// testing
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
Expand All @@ -203,6 +204,7 @@ func (inv *Invocation) WithOS() *Invocation {
i.Stdin = os.Stdin
i.Args = os.Args[1:]
i.Environ = ParseEnviron(os.Environ(), "")
i.Net = osNet{}
})
}

Expand Down
50 changes: 50 additions & 0 deletions cli/clibase/net.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package clibase

import (
"net"
"strconv"

"github.com/pion/udp"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This package is no longer maintained (repo archived). We should consider switching to github.com/pion/transport/v2/udp if we still need the library.

"golang.org/x/xerrors"
)

// Net abstracts CLI commands interacting with the operating system networking.
//
// At present, it covers opening local listening sockets, since doing this
// in testing is a challenge without flakes, since it's hard to pick a port we
// know a priori will be free.
type Net interface {
// Listen has the same semantics as `net.Listen` but also supports `udp`
Listen(network, address string) (net.Listener, error)
}

// osNet is an implementation that call the real OS for networking.
type osNet struct{}

func (osNet) Listen(network, address string) (net.Listener, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would we handle situations where you might need to configure dialing via net.Dialer?

Maybe we don't have such use-cases at this time but I'm concerned since we are introducing the field on clibase, which suggests it's the only way one should establish network connections (compared to injecting it into the specific command that needs it).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a GetDialer(options) net.Dialer method to the Net interface, or somesuch.

(compared to injecting it into the specific command that needs it)

I'm not sure what you mean here. How would that suggestion be different than what I've implemented?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea here is that Net is like the stdio and environment fields on Invocation. When you actually run the command from a command line, you get the OS-provided functions, but in testing we can hook various things up to intercept the command interacting with the OS, just like how today we hook stdio into the tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a GetDialer(options) net.Dialer method to the Net interface, or somesuch.

I suppose that could work, injecting resolver/dialer for the in-mem stuff 👍🏻.

I'm not sure what you mean here. How would that suggestion be different than what I've implemented?

With that parenthesis, I basically meant not putting the field in clibase/inv until there are more use-cases, i.e. keeping it port-forwarding only for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how we could keep it port-forwarding only without a major refactor. The command handler only gets an Invocation, which is also the only thing we have access to from clitest.New(). There literally isn't anything else to hang it on besides Invocation (module-level vars will wreak havoc in parallel testing).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's also context.Context, but I hear you. That's perhaps a bit hidden away.

switch network {
case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
return net.Listen(network, address)
case "udp":
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, xerrors.Errorf("split %q: %w", address, err)
}

var portInt int
portInt, err = strconv.Atoi(port)
if err != nil {
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, address, err)
}

// Use pion here so that we get a stream-style net.Conn listener, instead
// of a packet-oriented connection that can read and write to multiple
// addresses.
return udp.Listen(network, &net.UDPAddr{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure we want to use pion for UDP in all CLI usage going forward? (I understand this is mainly used by port-forward, so this is mainly a future concern about making assumptions based on one use-case).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, no. We can always refactor later if this doesn't meet the future use case.

Putting UDP into its own method makes faking this interface harder/more complex.

IP: net.ParseIP(host),
Port: portInt,
})
default:
return nil, xerrors.Errorf("unknown listen network %q", network)
}
}
36 changes: 7 additions & 29 deletions cli/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"sync"
"syscall"

"github.com/pion/udp"
"golang.org/x/xerrors"

"cdr.dev/slog"
Expand Down Expand Up @@ -121,6 +120,7 @@ func (r *RootCmd) portForward() *clibase.Cmd {
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
closeAllListeners = func() {
logger.Debug(ctx, "closing all listeners")
for _, l := range listeners {
if l == nil {
continue
Expand All @@ -134,6 +134,7 @@ func (r *RootCmd) portForward() *clibase.Cmd {
for i, spec := range specs {
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
if err != nil {
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
return err
}
listeners[i] = l
Expand All @@ -151,8 +152,10 @@ func (r *RootCmd) portForward() *clibase.Cmd {

select {
case <-ctx.Done():
logger.Debug(ctx, "command context expired waiting for signal", slog.Error(ctx.Err()))
closeErr = ctx.Err()
case <-sigs:
case sig := <-sigs:
logger.Debug(ctx, "received signal", slog.F("signal", sig))
_, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections")
}

Expand All @@ -161,6 +164,7 @@ func (r *RootCmd) portForward() *clibase.Cmd {
}()

conn.AwaitReachable(ctx)
logger.Debug(ctx, "read to accept connections to forward")
_, _ = fmt.Fprintln(inv.Stderr, "Ready!")
wg.Wait()
return closeErr
Expand Down Expand Up @@ -198,33 +202,7 @@ func listenAndPortForward(
logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress))
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)

var (
l net.Listener
err error
)
switch spec.listenNetwork {
case "tcp":
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
case "udp":
var host, port string
host, port, err = net.SplitHostPort(spec.listenAddress)
if err != nil {
return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err)
}

var portInt int
portInt, err = strconv.Atoi(port)
if err != nil {
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err)
}

l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{
IP: net.ParseIP(host),
Port: portInt,
})
default:
return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork)
}
l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress)
if err != nil {
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
}
Expand Down
Loading