Skip to content

feat: add port-forward subcommand #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 18, 2022
74 changes: 74 additions & 0 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net"
"net/url"
"os"
"os/exec"
"os/user"
Expand Down Expand Up @@ -211,6 +212,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go a.sshServer.HandleConn(channel.NetConn())
case "reconnecting-pty":
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
case "dial":
go a.handleDial(ctx, channel.Label(), channel.NetConn())
default:
a.logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
Expand Down Expand Up @@ -617,6 +620,57 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
}
}

// dialResponse is written to datachannels with protocol "dial" by the agent as
// the first packet to signify whether the dial succeeded or failed.
type dialResponse struct {
Error string `json:"error,omitempty"`
}

func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
defer conn.Close()

writeError := func(responseError error) error {
msg := ""
if responseError != nil {
msg = responseError.Error()
if !xerrors.Is(responseError, io.EOF) {
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
}
}
b, err := json.Marshal(dialResponse{
Error: msg,
})
if err != nil {
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
}

_, err = conn.Write(b)
return err
}

u, err := url.Parse(label)
if err != nil {
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
return
}

network := u.Scheme
addr := u.Host + u.Path
nconn, err := net.Dial(network, addr)
if err != nil {
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
return
}

err = writeError(nil)
if err != nil {
return
}

bicopy(ctx, conn, nconn)
}

// isClosed returns whether the API is closed or not.
func (a *agent) isClosed() bool {
select {
Expand Down Expand Up @@ -662,3 +716,23 @@ func (r *reconnectingPTY) Close() {
r.circularBuffer.Reset()
r.timeout.Stop()
}

// bicopy copies all of the data between the two connections and will close them
// after one or both of them are done writing. If the context is canceled, both
// of the connections will be closed.
func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
defer c1.Close()
defer c2.Close()

ctx, cancel := context.WithCancel(ctx)

copyFunc := func(dst io.WriteCloser, src io.Reader) {
defer cancel()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it lead to problems if c1 finishes early, resulting in c2.Close() being called before c2 has finished copying?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched this to a WaitGroup to avoid this case... not sure if it could've happened in practice but better to be safe than sorry

_, _ = io.Copy(dst, src)
}

go copyFunc(c1, c2)
go copyFunc(c2, c1)

<-ctx.Done()
}
142 changes: 142 additions & 0 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/pion/udp"
"github.com/pion/webrtc/v3"
"github.com/pkg/sftp"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -234,6 +235,114 @@ func TestAgent(t *testing.T) {
findEcho()
findEcho()
})

t.Run("Dial", func(t *testing.T) {
t.Parallel()

cases := []struct {
name string
setup func(t *testing.T) net.Listener
}{
Comment on lines +242 to +245
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are done beautifully! 😍

{
name: "TCP",
setup: func(t *testing.T) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
return l
},
},
{
name: "UDP",
setup: func(t *testing.T) net.Listener {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
},
},
{
name: "Unix",
setup: func(t *testing.T) net.Listener {
if runtime.GOOS == "windows" {
t.Skip("Unix socket forwarding isn't supported on Windows")
}

tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir for unix listener")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})

l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
require.NoError(t, err, "create UDP listener")
return l
},
},
}

for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()

// Setup listener
l := c.setup(t)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}

testAccept(t, c)
}
}()

// Try to dial the listener over WebRTC
conn := setupAgent(t, agent.Metadata{}, 0)
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn1.Close()
testDial(t, conn1)

// Dial again using the same WebRTC client
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn2.Close()
testDial(t, conn2)
})
}
})

t.Run("DialError", func(t *testing.T) {
t.Parallel()

if runtime.GOOS == "windows" {
// This test uses Unix listeners so we can very easily ensure that
// no other tests decide to listen on the same random port we
// picked.
t.Skip("this test is unsupported on Windows")
return
}

tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})

// Try to dial the non-existent Unix socket over WebRTC
conn := setupAgent(t, agent.Metadata{}, 0)
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
require.Error(t, err)
require.ErrorContains(t, err, "remote dial error")
require.ErrorContains(t, err, "no such file")
require.Nil(t, netConn)
})
}

func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
Expand Down Expand Up @@ -303,3 +412,36 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
Conn: conn,
}
}

var dialTestPayload = []byte("dean-was-here123")

func testDial(t *testing.T, c net.Conn) {
t.Helper()

// Dials will write and then expect echo back
n, err := c.Write(dialTestPayload)
require.NoError(t, err, "write test payload")
require.Equal(t, len(dialTestPayload), n, "test payload length does not match")

b := make([]byte, len(dialTestPayload)+16)
n, err = c.Read(b)
require.NoError(t, err, "read test payload")
require.Equal(t, len(dialTestPayload), n, "read payload length does not match")
require.Equal(t, dialTestPayload, b[:n])
}

func testAccept(t *testing.T, c net.Conn) {
t.Helper()
defer c.Close()

// Accepts will read then echo
b := make([]byte, len(dialTestPayload)+16)
n, err := c.Read(b)
require.NoError(t, err, "read test payload")
require.Equal(t, len(dialTestPayload), n, "read payload length does not match")
require.Equal(t, dialTestPayload, b[:n])

n, err = c.Write(dialTestPayload)
require.NoError(t, err, "write test payload")
require.Equal(t, len(dialTestPayload), n, "test payload length does not match")
}
42 changes: 40 additions & 2 deletions agent/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package agent

import (
"context"
"encoding/json"
"fmt"
"net"
"net/url"
"strings"

"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
Expand Down Expand Up @@ -32,7 +35,7 @@ type Conn struct {
// ReconnectingPTY returns a connection serving a TTY that can
// be reconnected to via ID.
func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) {
channel, err := c.Dial(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
Protocol: "reconnecting-pty",
})
if err != nil {
Expand All @@ -43,7 +46,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error

// SSH dials the built-in SSH server.
func (c *Conn) SSH() (net.Conn, error) {
channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
Expand Down Expand Up @@ -71,6 +74,41 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
return ssh.NewClient(sshConn, channels, requests), nil
}

// DialContext dials an arbitrary protocol+address from inside the workspace and
// proxies it through the provided net.Conn.
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
u := &url.URL{
Scheme: network,
}
if strings.HasPrefix(network, "unix") {
u.Path = addr
} else {
u.Host = addr
}

channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
Protocol: "dial",
})
if err != nil {
return nil, xerrors.Errorf("create datachannel: %w", err)
}

// The first message written from the other side is a JSON payload
// containing the dial error.
dec := json.NewDecoder(channel)
var res dialResponse
err = dec.Decode(&res)
if err != nil {
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
}
if res.Error != "" {
_ = channel.Close()
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
}

return channel.NetConn(), nil
}

func (c *Conn) Close() error {
_ = c.Negotiator.DRPCConn().Close()
return c.Conn.Close()
Expand Down
4 changes: 2 additions & 2 deletions peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,8 @@ func (c *Conn) Accept(ctx context.Context) (*Channel, error) {
return newChannel(c, dataChannel, &ChannelOptions{}), nil
}

// Dial creates a new DataChannel.
func (c *Conn) Dial(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
// CreateChannel creates a new DataChannel.
func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
if opts == nil {
opts = &ChannelOptions{}
}
Expand Down
Loading