Skip to content

feat: add port-forward subcommand #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 18, 2022
Prev Previous commit
Next Next commit
wip
  • Loading branch information
deansheather committed May 17, 2022
commit 46092cc615fc41637974d19363423af995edace6
5 changes: 4 additions & 1 deletion agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,10 @@ func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
}
}

nconn, err := net.Dial(network, addr)
a.logger.Warn(ctx, "yeah", slog.F("network", network), slog.F("addr", addr))

d := net.Dialer{Timeout: 3 * time.Second}
nconn, err := d.DialContext(ctx, network, addr)
if err != nil {
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
return
Expand Down
44 changes: 20 additions & 24 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,22 +298,20 @@ func TestAgent(t *testing.T) {
return
}

testAccept(t, c)
go testAccept(t, c)
}
}()

// Try to dial the listener over WebRTC
// Dial the listener over WebRTC twice and test out of order
conn := setupAgent(t, agent.Metadata{}, 0)
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn1.Close()
testDial(t, conn1)

// Dial again using the same WebRTC client
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn2.Close()
testDial(t, conn2)
testDial(t, conn1)
})
}
})
Expand Down Expand Up @@ -418,30 +416,28 @@ var dialTestPayload = []byte("dean-was-here123")
func testDial(t *testing.T, c net.Conn) {
t.Helper()

// Dials will write and then expect echo back
n, err := c.Write(dialTestPayload)
require.NoError(t, err, "write test payload")
require.Equal(t, len(dialTestPayload), n, "test payload length does not match")

b := make([]byte, len(dialTestPayload)+16)
n, err = c.Read(b)
require.NoError(t, err, "read test payload")
require.Equal(t, len(dialTestPayload), n, "read payload length does not match")
require.Equal(t, dialTestPayload, b[:n])
assertWritePayload(t, c, dialTestPayload)
assertReadPayload(t, c, dialTestPayload)
}

func testAccept(t *testing.T, c net.Conn) {
t.Helper()
defer c.Close()

// Accepts will read then echo
b := make([]byte, len(dialTestPayload)+16)
n, err := c.Read(b)
require.NoError(t, err, "read test payload")
require.Equal(t, len(dialTestPayload), n, "read payload length does not match")
require.Equal(t, dialTestPayload, b[:n])
assertReadPayload(t, c, dialTestPayload)
assertWritePayload(t, c, dialTestPayload)
}

func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
b := make([]byte, len(payload)+16)
n, err := r.Read(b)
require.NoError(t, err, "read payload")
require.Equal(t, len(payload), n, "read payload length does not match")
require.Equal(t, payload, b[:n])
}

n, err = c.Write(dialTestPayload)
require.NoError(t, err, "write test payload")
require.Equal(t, len(dialTestPayload), n, "test payload length does not match")
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
n, err := w.Write(payload)
require.NoError(t, err, "write payload")
require.Equal(t, len(payload), n, "payload length does not match")
}
4 changes: 3 additions & 1 deletion agent/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker/proto"
"github.com/google/uuid"
)

// ReconnectingPTYRequest is sent from the client to the server
Expand Down Expand Up @@ -78,7 +79,8 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
// proxies it through the provided net.Conn.
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
u := &url.URL{
Scheme: network,
Scheme: network,
RawQuery: "test=" + uuid.Must(uuid.NewRandom()).String(),
}
if strings.HasPrefix(network, "unix") {
u.Path = addr
Expand Down
64 changes: 55 additions & 9 deletions cli/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import (
"context"
"fmt"
"net"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"sync"
"syscall"

"github.com/pion/udp"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -72,6 +75,13 @@ func portForward() *cobra.Command {
if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err)
}
if len(specs) == 0 {
err = cmd.Help()
if err != nil {
return xerrors.Errorf("generate help output: %w", err)
}
return xerrors.New("no port-forwards requested")
}

fmt.Println("SPECS:")
for _, spec := range specs {
Expand Down Expand Up @@ -118,13 +128,26 @@ func portForward() *cobra.Command {
defer conn.Close()

// Start all listeners
var wg sync.WaitGroup
for _, spec := range specs {
var (
ctx, cancel = context.WithCancel(cmd.Context())
wg sync.WaitGroup
listeners = make([]net.Listener, len(specs))
closeAllListeners = func() {
for _, l := range listeners {
if l == nil {
continue
}
_ = l.Close()
}
}
)
defer cancel()
for i, spec := range specs {
var (
l net.Listener
err error
)
fmt.Printf("Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
switch spec.listenNetwork {
case "tcp":
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
Expand All @@ -145,40 +168,63 @@ func portForward() *cobra.Command {
case "unix":
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
default:
closeAllListeners()
return xerrors.Errorf("unknown listen network %q", spec.listenNetwork)
}
if err != nil {
closeAllListeners()
return xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
}
listeners[i] = l

wg.Add(1)
go func(spec portForwardSpec) {
defer wg.Done()
for {
netConn, err := l.Accept()
if err != nil {
fmt.Printf("Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err)
fmt.Println("Killing listener")
fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err)
fmt.Fprintln(cmd.OutOrStderr(), "Killing listener")
return
}

go func(netConn net.Conn) {
defer netConn.Close()
remoteConn, err := conn.DialContext(cmd.Context(), spec.dialNetwork, spec.dialAddress)
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
if err != nil {
fmt.Printf("Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
return
}
defer remoteConn.Close()

coderagent.Bicopy(cmd.Context(), netConn, remoteConn)
coderagent.Bicopy(ctx, netConn, remoteConn)
}(netConn)
}
}(spec)
}

// Wait for the context to be canceled or for a signal and close
// all listeners.
var closeErr error
go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

select {
case <-ctx.Done():
closeErr = ctx.Err()
case <-sigs:
fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections")
closeErr = xerrors.New("signal received")
}

cancel()
closeAllListeners()
}()

fmt.Fprintln(cmd.OutOrStderr(), "Ready!")
wg.Wait()
return nil
return closeErr
},
}

Expand Down
Loading