Skip to content

feat: add port-forward subcommand #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 18, 2022
Next Next commit
feat: add agent dial handler
  • Loading branch information
deansheather committed May 9, 2022
commit f104e02b7e4a2da4d3818b0b5f0aa92ad795015c
2 changes: 2 additions & 0 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go a.sshServer.HandleConn(channel.NetConn())
case "reconnecting-pty":
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
case "dial":
go a.handleDial(ctx, channel.Label(), channel.NetConn())
default:
a.logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
Expand Down
142 changes: 142 additions & 0 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/pion/udp"
"github.com/pion/webrtc/v3"
"github.com/pkg/sftp"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -234,6 +235,114 @@ func TestAgent(t *testing.T) {
findEcho()
findEcho()
})

t.Run("Dial", func(t *testing.T) {
t.Parallel()

cases := []struct {
name string
setup func(t *testing.T) net.Listener
}{
Comment on lines +242 to +245
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are done beautifully! 😍

{
name: "TCP",
setup: func(t *testing.T) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
return l
},
},
{
name: "UDP",
setup: func(t *testing.T) net.Listener {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
},
},
{
name: "Unix",
setup: func(t *testing.T) net.Listener {
if runtime.GOOS == "windows" {
t.Skip("Unix socket forwarding isn't supported on Windows")
}

tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir for unix listener")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})

l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
require.NoError(t, err, "create UDP listener")
return l
},
},
}

for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()

// Setup listener
l := c.setup(t)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}

testAccept(t, c)
}
}()

// Try to dial the listener over WebRTC
conn := setupAgent(t, agent.Metadata{}, 0)
conn1, err := conn.Dial(l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn1.Close()
testDial(t, conn1)

// Dial again using the same WebRTC client
conn2, err := conn.Dial(l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn2.Close()
testDial(t, conn2)
})
}
})

t.Run("DialError", func(t *testing.T) {
t.Parallel()

if runtime.GOOS == "windows" {
// This test uses Unix listeners so we can very easily ensure that
// no other tests decide to listen on the same random port we
// picked.
t.Skip("this test is unsupported on Windows")
return
}

tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})

// Try to dial the non-existent Unix socket over WebRTC
conn := setupAgent(t, agent.Metadata{}, 0)
netConn, err := conn.Dial("unix", filepath.Join(tmpDir, "test.sock"))
require.Error(t, err)
require.ErrorContains(t, err, "remote dial error")
require.ErrorContains(t, err, "no such file")
require.Nil(t, netConn)
})
}

func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
Expand Down Expand Up @@ -303,3 +412,36 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
Conn: conn,
}
}

var dialTestPayload = []byte("dean-was-here123")

func testDial(t *testing.T, c net.Conn) {
t.Helper()

// Dials will write and then expect echo back
n, err := c.Write(dialTestPayload)
require.NoError(t, err, "write test payload")
require.Equal(t, len(dialTestPayload), n, "test payload length does not match")

b := make([]byte, len(dialTestPayload)+16)
n, err = c.Read(b)
require.NoError(t, err, "read test payload")
require.Equal(t, len(dialTestPayload), n, "read payload length does not match")
require.Equal(t, dialTestPayload, b[:n])
}

func testAccept(t *testing.T, c net.Conn) {
t.Helper()
defer c.Close()

// Accepts will read then echo
b := make([]byte, len(dialTestPayload)+16)
n, err := c.Read(b)
require.NoError(t, err, "read test payload")
require.Equal(t, len(dialTestPayload), n, "read payload length does not match")
require.Equal(t, dialTestPayload, b[:n])

n, err = c.Write(dialTestPayload)
require.NoError(t, err, "write test payload")
require.Equal(t, len(dialTestPayload), n, "test payload length does not match")
}
4 changes: 2 additions & 2 deletions agent/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type Conn struct {
// ReconnectingPTY returns a connection serving a TTY that can
// be reconnected to via ID.
func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) {
channel, err := c.Dial(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
channel, err := c.OpenChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
Protocol: "reconnecting-pty",
})
if err != nil {
Expand All @@ -43,7 +43,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error

// SSH dials the built-in SSH server.
func (c *Conn) SSH() (net.Conn, error) {
channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{
channel, err := c.OpenChannel(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
Expand Down
135 changes: 135 additions & 0 deletions agent/dial.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package agent

import (
"context"
"encoding/json"
"io"
"net"
"net/url"
"strings"

"github.com/google/uuid"
"golang.org/x/xerrors"

"github.com/coder/coder/peer"
)

// DialResponse is written to datachannels with protocol "dial" by the agent as
// the first packet to signify whether the dial succeeded or failed.
type DialResponse struct {
Error string `json:"error,omitempty"`
}

// Dial dials an arbitrary protocol+address from inside the workspace and
// proxies it through the provided net.Conn.
func (c *Conn) Dial(network string, addr string) (net.Conn, error) {
// Force unique URL by including a random UUID.
id, err := uuid.NewRandom()
if err != nil {
return nil, xerrors.Errorf("generate random UUID: %w", err)
}

host := ""
path := ""
if strings.HasPrefix(network, "unix") {
path = addr
} else {
host = addr
}

label := (&url.URL{
Scheme: network,
Host: host,
Path: path,
RawQuery: (url.Values{
"id": []string{id.String()},
}).Encode(),
}).String()

channel, err := c.OpenChannel(context.Background(), label, &peer.ChannelOptions{
Protocol: "dial",
})
if err != nil {
return nil, xerrors.Errorf("pty: %w", err)
}

// The first message written from the other side is a JSON payload
// containing the dial error.
dec := json.NewDecoder(channel)
var res DialResponse
err = dec.Decode(&res)
if err != nil {
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
}
if res.Error != "" {
_ = channel.Close()
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
}

return channel.NetConn(), nil
}

func (*agent) handleDial(ctx context.Context, label string, conn net.Conn) {
defer conn.Close()

writeError := func(responseError error) error {
msg := ""
if responseError != nil {
msg = responseError.Error()
}
b, err := json.Marshal(DialResponse{
Error: msg,
})
if err != nil {
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
}

_, err = conn.Write(b)
return err
}

u, err := url.Parse(label)
if err != nil {
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
return
}

network := u.Scheme
addr := u.Host + u.Path
nconn, err := net.Dial(network, addr)
if err != nil {
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
return
}

err = writeError(nil)
if err != nil {
return
}

bicopy(ctx, conn, nconn)
}

// bicopy copies all of the data between the two connections
// and will close them after one or both of them are done writing.
// If the context is canceled, both of the connections will be
// closed.
//
// NOTE: This function will block until the copying is done or the
// context is canceled.
func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
defer c1.Close()
defer c2.Close()

ctx, cancel := context.WithCancel(ctx)

copyFunc := func(dst io.WriteCloser, src io.Reader) {
defer cancel()
_, _ = io.Copy(dst, src)
}

go copyFunc(c1, c2)
go copyFunc(c2, c1)

<-ctx.Done()
}
4 changes: 2 additions & 2 deletions peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,8 @@ func (c *Conn) Accept(ctx context.Context) (*Channel, error) {
return newChannel(c, dataChannel, &ChannelOptions{}), nil
}

// Dial creates a new DataChannel.
func (c *Conn) Dial(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
// OpenChannel creates a new DataChannel.
func (c *Conn) OpenChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
if opts == nil {
opts = &ChannelOptions{}
}
Expand Down
14 changes: 7 additions & 7 deletions peer/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestConn(t *testing.T) {
_, err := server.Ping()
require.NoError(t, err)
// Create a channel that closes on disconnect.
channel, err := server.Dial(context.Background(), "wow", nil)
channel, err := server.OpenChannel(context.Background(), "wow", nil)
assert.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
Expand All @@ -108,7 +108,7 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err)

sch, err := server.Accept(context.Background())
Expand All @@ -124,7 +124,7 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, wan := createPair(t)
exchange(t, client, server)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
Expand All @@ -141,7 +141,7 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestConn(t *testing.T) {
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
var cch *peer.Channel
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
cch, err = client.Dial(ctx, "hello", &peer.ChannelOptions{})
cch, err = client.OpenChannel(ctx, "hello", &peer.ChannelOptions{})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestConn(t *testing.T) {
require.NoError(t, err)
expectedErr := xerrors.New("wow")
_ = conn.CloseWithError(expectedErr)
_, err = conn.Dial(context.Background(), "", nil)
_, err = conn.OpenChannel(context.Background(), "", nil)
require.ErrorIs(t, err, expectedErr)
})

Expand Down Expand Up @@ -274,7 +274,7 @@ func TestConn(t *testing.T) {
client, server, _ := createPair(t)
exchange(t, client, server)
go func() {
channel, err := client.Dial(context.Background(), "test", nil)
channel, err := client.OpenChannel(context.Background(), "test", nil)
require.NoError(t, err)
_, err = channel.Write([]byte{1, 2})
require.NoError(t, err)
Expand Down