Skip to content

feat: add port-forward subcommand #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 18, 2022
Prev Previous commit
Next Next commit
chore: Kyle PR review comments
  • Loading branch information
deansheather committed May 9, 2022
commit 1972629c1f61a5742191cb40b5cdfb440cabed4f
72 changes: 72 additions & 0 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net"
"net/url"
"os"
"os/exec"
"os/user"
Expand Down Expand Up @@ -619,6 +620,57 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
}
}

// dialResponse is written to datachannels with protocol "dial" by the agent as
// the first packet to signify whether the dial succeeded or failed.
type dialResponse struct {
Error string `json:"error,omitempty"`
}

func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
defer conn.Close()

writeError := func(responseError error) error {
msg := ""
if responseError != nil {
msg = responseError.Error()
if !xerrors.Is(responseError, io.EOF) {
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
}
}
b, err := json.Marshal(dialResponse{
Error: msg,
})
if err != nil {
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
}

_, err = conn.Write(b)
return err
}

u, err := url.Parse(label)
if err != nil {
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
return
}

network := u.Scheme
addr := u.Host + u.Path
nconn, err := net.Dial(network, addr)
if err != nil {
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
return
}

err = writeError(nil)
if err != nil {
return
}

bicopy(ctx, conn, nconn)
}

// isClosed returns whether the API is closed or not.
func (a *agent) isClosed() bool {
select {
Expand Down Expand Up @@ -664,3 +716,23 @@ func (r *reconnectingPTY) Close() {
r.circularBuffer.Reset()
r.timeout.Stop()
}

// bicopy copies all of the data between the two connections and will close them
// after one or both of them are done writing. If the context is canceled, both
// of the connections will be closed.
func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
defer c1.Close()
defer c2.Close()

ctx, cancel := context.WithCancel(ctx)

copyFunc := func(dst io.WriteCloser, src io.Reader) {
defer cancel()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it lead to problems if c1 finishes early, resulting in c2.Close() being called before c2 has finished copying?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched this to a WaitGroup to avoid this case... not sure if it could've happened in practice but better to be safe than sorry

_, _ = io.Copy(dst, src)
}

go copyFunc(c1, c2)
go copyFunc(c2, c1)

<-ctx.Done()
}
6 changes: 3 additions & 3 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,13 @@ func TestAgent(t *testing.T) {

// Try to dial the listener over WebRTC
conn := setupAgent(t, agent.Metadata{}, 0)
conn1, err := conn.Dial(l.Addr().Network(), l.Addr().String())
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn1.Close()
testDial(t, conn1)

// Dial again using the same WebRTC client
conn2, err := conn.Dial(l.Addr().Network(), l.Addr().String())
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn2.Close()
testDial(t, conn2)
Expand All @@ -337,7 +337,7 @@ func TestAgent(t *testing.T) {

// Try to dial the non-existent Unix socket over WebRTC
conn := setupAgent(t, agent.Metadata{}, 0)
netConn, err := conn.Dial("unix", filepath.Join(tmpDir, "test.sock"))
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
require.Error(t, err)
require.ErrorContains(t, err, "remote dial error")
require.ErrorContains(t, err, "no such file")
Expand Down
42 changes: 40 additions & 2 deletions agent/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package agent

import (
"context"
"encoding/json"
"fmt"
"net"
"net/url"
"strings"

"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
Expand Down Expand Up @@ -32,7 +35,7 @@ type Conn struct {
// ReconnectingPTY returns a connection serving a TTY that can
// be reconnected to via ID.
func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) {
channel, err := c.OpenChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
Protocol: "reconnecting-pty",
})
if err != nil {
Expand All @@ -43,7 +46,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error

// SSH dials the built-in SSH server.
func (c *Conn) SSH() (net.Conn, error) {
channel, err := c.OpenChannel(context.Background(), "ssh", &peer.ChannelOptions{
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
Expand Down Expand Up @@ -71,6 +74,41 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
return ssh.NewClient(sshConn, channels, requests), nil
}

// DialContext dials an arbitrary protocol+address from inside the workspace and
// proxies it through the provided net.Conn.
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
u := &url.URL{
Scheme: network,
}
if strings.HasPrefix(network, "unix") {
u.Path = addr
} else {
u.Host = addr
}

channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
Protocol: "dial",
})
if err != nil {
return nil, xerrors.Errorf("create datachannel: %w", err)
}

// The first message written from the other side is a JSON payload
// containing the dial error.
dec := json.NewDecoder(channel)
var res dialResponse
err = dec.Decode(&res)
if err != nil {
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
}
if res.Error != "" {
_ = channel.Close()
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
}

return channel.NetConn(), nil
}

func (c *Conn) Close() error {
_ = c.Negotiator.DRPCConn().Close()
return c.Conn.Close()
Expand Down
135 changes: 0 additions & 135 deletions agent/dial.go

This file was deleted.

4 changes: 2 additions & 2 deletions peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,8 @@ func (c *Conn) Accept(ctx context.Context) (*Channel, error) {
return newChannel(c, dataChannel, &ChannelOptions{}), nil
}

// OpenChannel creates a new DataChannel.
func (c *Conn) OpenChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
// CreateChannel creates a new DataChannel.
func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
if opts == nil {
opts = &ChannelOptions{}
}
Expand Down
14 changes: 7 additions & 7 deletions peer/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestConn(t *testing.T) {
_, err := server.Ping()
require.NoError(t, err)
// Create a channel that closes on disconnect.
channel, err := server.OpenChannel(context.Background(), "wow", nil)
channel, err := server.CreateChannel(context.Background(), "wow", nil)
assert.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
Expand All @@ -108,7 +108,7 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{})
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err)

sch, err := server.Accept(context.Background())
Expand All @@ -124,7 +124,7 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, wan := createPair(t)
exchange(t, client, server)
cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{})
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
Expand All @@ -141,7 +141,7 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{})
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestConn(t *testing.T) {
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
var cch *peer.Channel
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
cch, err = client.OpenChannel(ctx, "hello", &peer.ChannelOptions{})
cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestConn(t *testing.T) {
require.NoError(t, err)
expectedErr := xerrors.New("wow")
_ = conn.CloseWithError(expectedErr)
_, err = conn.OpenChannel(context.Background(), "", nil)
_, err = conn.CreateChannel(context.Background(), "", nil)
require.ErrorIs(t, err, expectedErr)
})

Expand Down Expand Up @@ -274,7 +274,7 @@ func TestConn(t *testing.T) {
client, server, _ := createPair(t)
exchange(t, client, server)
go func() {
channel, err := client.OpenChannel(context.Background(), "test", nil)
channel, err := client.CreateChannel(context.Background(), "test", nil)
require.NoError(t, err)
_, err = channel.Write([]byte{1, 2})
require.NoError(t, err)
Expand Down