Skip to content

feat: add port-forward subcommand #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 18, 2022
114 changes: 114 additions & 0 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net"
"net/url"
"os"
"os/exec"
"os/user"
Expand Down Expand Up @@ -211,6 +212,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go a.sshServer.HandleConn(channel.NetConn())
case "reconnecting-pty":
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
case "dial":
go a.handleDial(ctx, channel.Label(), channel.NetConn())
default:
a.logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
Expand Down Expand Up @@ -617,6 +620,70 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
}
}

// dialResponse is written to datachannels with protocol "dial" by the agent as
// the first packet to signify whether the dial succeeded or failed.
type dialResponse struct {
Error string `json:"error,omitempty"`
}

func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
defer conn.Close()

writeError := func(responseError error) error {
msg := ""
if responseError != nil {
msg = responseError.Error()
if !xerrors.Is(responseError, io.EOF) {
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
}
}
b, err := json.Marshal(dialResponse{
Error: msg,
})
if err != nil {
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
}

_, err = conn.Write(b)
return err
}

u, err := url.Parse(label)
if err != nil {
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
return
}

network := u.Scheme
addr := u.Host + u.Path
if strings.HasPrefix(network, "unix") {
if runtime.GOOS == "windows" {
_ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces"))
return
}
addr, err = ExpandRelativeHomePath(addr)
if err != nil {
_ = writeError(xerrors.Errorf("expand path %q: %w", addr, err))
return
}
}

d := net.Dialer{Timeout: 3 * time.Second}
nconn, err := d.DialContext(ctx, network, addr)
if err != nil {
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
return
}

err = writeError(nil)
if err != nil {
return
}

Bicopy(ctx, conn, nconn)
}

// isClosed returns whether the API is closed or not.
func (a *agent) isClosed() bool {
select {
Expand Down Expand Up @@ -662,3 +729,50 @@ func (r *reconnectingPTY) Close() {
r.circularBuffer.Reset()
r.timeout.Stop()
}

// Bicopy copies all of the data between the two connections and will close them
// after one or both of them are done writing. If the context is canceled, both
// of the connections will be closed.
func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
defer c1.Close()
defer c2.Close()

var wg sync.WaitGroup
copyFunc := func(dst io.WriteCloser, src io.Reader) {
defer wg.Done()
_, _ = io.Copy(dst, src)
}

wg.Add(2)
go copyFunc(c1, c2)
go copyFunc(c2, c1)

// Convert waitgroup to a channel so we can also wait on the context.
done := make(chan struct{})
go func() {
defer close(done)
wg.Wait()
}()

select {
case <-ctx.Done():
case <-done:
}
}
Comment on lines +733 to +761
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just close the other connection if one finishes? Because the dialer already has a context, all connections will be closed when it exists.

go func() {
  defer c1.Close()
  _, _ = io.Copy(c1, c2)
}()
defer c2.Close()
_, _ = io.Copy(c2, c1)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once connected the context passed to DialContext does not affect the lifetime of the connection, so if the context is cancelled the net.Conn won't be closed.


// ExpandRelativeHomePath expands the tilde at the beginning of a path to the
// current user's home directory and returns a full absolute path.
func ExpandRelativeHomePath(in string) (string, error) {
usr, err := user.Current()
if err != nil {
return "", xerrors.Errorf("get current user details: %w", err)
}

if in == "~" {
in = usr.HomeDir
} else if strings.HasPrefix(in, "~/") {
in = filepath.Join(usr.HomeDir, in[2:])
}

return filepath.Abs(in)
}
138 changes: 138 additions & 0 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/pion/udp"
"github.com/pion/webrtc/v3"
"github.com/pkg/sftp"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -234,6 +235,112 @@ func TestAgent(t *testing.T) {
findEcho()
findEcho()
})

t.Run("Dial", func(t *testing.T) {
t.Parallel()

cases := []struct {
name string
setup func(t *testing.T) net.Listener
}{
Comment on lines +242 to +245
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are done beautifully! 😍

{
name: "TCP",
setup: func(t *testing.T) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
return l
},
},
{
name: "UDP",
setup: func(t *testing.T) net.Listener {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
},
},
{
name: "Unix",
setup: func(t *testing.T) net.Listener {
if runtime.GOOS == "windows" {
t.Skip("Unix socket forwarding isn't supported on Windows")
}

tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir for unix listener")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})

l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
require.NoError(t, err, "create UDP listener")
return l
},
},
}

for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()

// Setup listener
l := c.setup(t)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}

go testAccept(t, c)
}
}()

// Dial the listener over WebRTC twice and test out of order
conn := setupAgent(t, agent.Metadata{}, 0)
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn1.Close()
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn2.Close()
testDial(t, conn2)
testDial(t, conn1)
})
}
})

t.Run("DialError", func(t *testing.T) {
t.Parallel()

if runtime.GOOS == "windows" {
// This test uses Unix listeners so we can very easily ensure that
// no other tests decide to listen on the same random port we
// picked.
t.Skip("this test is unsupported on Windows")
return
}

tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})

// Try to dial the non-existent Unix socket over WebRTC
conn := setupAgent(t, agent.Metadata{}, 0)
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
require.Error(t, err)
require.ErrorContains(t, err, "remote dial error")
require.ErrorContains(t, err, "no such file")
require.Nil(t, netConn)
})
}

func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
Expand Down Expand Up @@ -303,3 +410,34 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
Conn: conn,
}
}

var dialTestPayload = []byte("dean-was-here123")

func testDial(t *testing.T, c net.Conn) {
t.Helper()

assertWritePayload(t, c, dialTestPayload)
assertReadPayload(t, c, dialTestPayload)
}

func testAccept(t *testing.T, c net.Conn) {
t.Helper()
defer c.Close()

assertReadPayload(t, c, dialTestPayload)
assertWritePayload(t, c, dialTestPayload)
}

func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
b := make([]byte, len(payload)+16)
n, err := r.Read(b)
require.NoError(t, err, "read payload")
require.Equal(t, len(payload), n, "read payload length does not match")
require.Equal(t, payload, b[:n])
}

func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
n, err := w.Write(payload)
require.NoError(t, err, "write payload")
require.Equal(t, len(payload), n, "payload length does not match")
}
43 changes: 41 additions & 2 deletions agent/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package agent

import (
"context"
"encoding/json"
"fmt"
"net"
"net/url"
"strings"

"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
Expand Down Expand Up @@ -32,7 +35,7 @@ type Conn struct {
// ReconnectingPTY returns a connection serving a TTY that can
// be reconnected to via ID.
func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) {
channel, err := c.Dial(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
Protocol: "reconnecting-pty",
})
if err != nil {
Expand All @@ -43,7 +46,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error

// SSH dials the built-in SSH server.
func (c *Conn) SSH() (net.Conn, error) {
channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
Expand Down Expand Up @@ -71,6 +74,42 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
return ssh.NewClient(sshConn, channels, requests), nil
}

// DialContext dials an arbitrary protocol+address from inside the workspace and
// proxies it through the provided net.Conn.
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
u := &url.URL{
Scheme: network,
}
if strings.HasPrefix(network, "unix") {
u.Path = addr
} else {
u.Host = addr
}

channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
Protocol: "dial",
Unordered: strings.HasPrefix(network, "udp"),
})
if err != nil {
return nil, xerrors.Errorf("create datachannel: %w", err)
}

// The first message written from the other side is a JSON payload
// containing the dial error.
dec := json.NewDecoder(channel)
var res dialResponse
err = dec.Decode(&res)
if err != nil {
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
}
if res.Error != "" {
_ = channel.Close()
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
}

return channel.NetConn(), nil
}

func (c *Conn) Close() error {
_ = c.Negotiator.DRPCConn().Close()
return c.Conn.Close()
Expand Down
Loading