Skip to content

Commit d1668ce

Browse files
deansheatherkylecarbs
authored andcommitted
feat: add port-forward subcommand (#1350)
1 parent 8594936 commit d1668ce

15 files changed

+1403
-119
lines changed

agent/agent.go

+114
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"fmt"
1010
"io"
1111
"net"
12+
"net/url"
1213
"os"
1314
"os/exec"
1415
"os/user"
@@ -211,6 +212,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
211212
go a.sshServer.HandleConn(channel.NetConn())
212213
case "reconnecting-pty":
213214
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
215+
case "dial":
216+
go a.handleDial(ctx, channel.Label(), channel.NetConn())
214217
default:
215218
a.logger.Warn(ctx, "unhandled protocol from channel",
216219
slog.F("protocol", channel.Protocol()),
@@ -617,6 +620,70 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
617620
}
618621
}
619622

623+
// dialResponse is written to datachannels with protocol "dial" by the agent as
624+
// the first packet to signify whether the dial succeeded or failed.
625+
type dialResponse struct {
626+
Error string `json:"error,omitempty"`
627+
}
628+
629+
func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
630+
defer conn.Close()
631+
632+
writeError := func(responseError error) error {
633+
msg := ""
634+
if responseError != nil {
635+
msg = responseError.Error()
636+
if !xerrors.Is(responseError, io.EOF) {
637+
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
638+
}
639+
}
640+
b, err := json.Marshal(dialResponse{
641+
Error: msg,
642+
})
643+
if err != nil {
644+
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
645+
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
646+
}
647+
648+
_, err = conn.Write(b)
649+
return err
650+
}
651+
652+
u, err := url.Parse(label)
653+
if err != nil {
654+
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
655+
return
656+
}
657+
658+
network := u.Scheme
659+
addr := u.Host + u.Path
660+
if strings.HasPrefix(network, "unix") {
661+
if runtime.GOOS == "windows" {
662+
_ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces"))
663+
return
664+
}
665+
addr, err = ExpandRelativeHomePath(addr)
666+
if err != nil {
667+
_ = writeError(xerrors.Errorf("expand path %q: %w", addr, err))
668+
return
669+
}
670+
}
671+
672+
d := net.Dialer{Timeout: 3 * time.Second}
673+
nconn, err := d.DialContext(ctx, network, addr)
674+
if err != nil {
675+
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
676+
return
677+
}
678+
679+
err = writeError(nil)
680+
if err != nil {
681+
return
682+
}
683+
684+
Bicopy(ctx, conn, nconn)
685+
}
686+
620687
// isClosed returns whether the API is closed or not.
621688
func (a *agent) isClosed() bool {
622689
select {
@@ -662,3 +729,50 @@ func (r *reconnectingPTY) Close() {
662729
r.circularBuffer.Reset()
663730
r.timeout.Stop()
664731
}
732+
733+
// Bicopy copies all of the data between the two connections and will close them
734+
// after one or both of them are done writing. If the context is canceled, both
735+
// of the connections will be closed.
736+
func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
737+
defer c1.Close()
738+
defer c2.Close()
739+
740+
var wg sync.WaitGroup
741+
copyFunc := func(dst io.WriteCloser, src io.Reader) {
742+
defer wg.Done()
743+
_, _ = io.Copy(dst, src)
744+
}
745+
746+
wg.Add(2)
747+
go copyFunc(c1, c2)
748+
go copyFunc(c2, c1)
749+
750+
// Convert waitgroup to a channel so we can also wait on the context.
751+
done := make(chan struct{})
752+
go func() {
753+
defer close(done)
754+
wg.Wait()
755+
}()
756+
757+
select {
758+
case <-ctx.Done():
759+
case <-done:
760+
}
761+
}
762+
763+
// ExpandRelativeHomePath expands the tilde at the beginning of a path to the
764+
// current user's home directory and returns a full absolute path.
765+
func ExpandRelativeHomePath(in string) (string, error) {
766+
usr, err := user.Current()
767+
if err != nil {
768+
return "", xerrors.Errorf("get current user details: %w", err)
769+
}
770+
771+
if in == "~" {
772+
in = usr.HomeDir
773+
} else if strings.HasPrefix(in, "~/") {
774+
in = filepath.Join(usr.HomeDir, in[2:])
775+
}
776+
777+
return filepath.Abs(in)
778+
}

agent/agent_test.go

+138
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"time"
1717

1818
"github.com/google/uuid"
19+
"github.com/pion/udp"
1920
"github.com/pion/webrtc/v3"
2021
"github.com/pkg/sftp"
2122
"github.com/stretchr/testify/require"
@@ -234,6 +235,112 @@ func TestAgent(t *testing.T) {
234235
findEcho()
235236
findEcho()
236237
})
238+
239+
t.Run("Dial", func(t *testing.T) {
240+
t.Parallel()
241+
242+
cases := []struct {
243+
name string
244+
setup func(t *testing.T) net.Listener
245+
}{
246+
{
247+
name: "TCP",
248+
setup: func(t *testing.T) net.Listener {
249+
l, err := net.Listen("tcp", "127.0.0.1:0")
250+
require.NoError(t, err, "create TCP listener")
251+
return l
252+
},
253+
},
254+
{
255+
name: "UDP",
256+
setup: func(t *testing.T) net.Listener {
257+
addr := net.UDPAddr{
258+
IP: net.ParseIP("127.0.0.1"),
259+
Port: 0,
260+
}
261+
l, err := udp.Listen("udp", &addr)
262+
require.NoError(t, err, "create UDP listener")
263+
return l
264+
},
265+
},
266+
{
267+
name: "Unix",
268+
setup: func(t *testing.T) net.Listener {
269+
if runtime.GOOS == "windows" {
270+
t.Skip("Unix socket forwarding isn't supported on Windows")
271+
}
272+
273+
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
274+
require.NoError(t, err, "create temp dir for unix listener")
275+
t.Cleanup(func() {
276+
_ = os.RemoveAll(tmpDir)
277+
})
278+
279+
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
280+
require.NoError(t, err, "create UDP listener")
281+
return l
282+
},
283+
},
284+
}
285+
286+
for _, c := range cases {
287+
c := c
288+
t.Run(c.name, func(t *testing.T) {
289+
t.Parallel()
290+
291+
// Setup listener
292+
l := c.setup(t)
293+
defer l.Close()
294+
go func() {
295+
for {
296+
c, err := l.Accept()
297+
if err != nil {
298+
return
299+
}
300+
301+
go testAccept(t, c)
302+
}
303+
}()
304+
305+
// Dial the listener over WebRTC twice and test out of order
306+
conn := setupAgent(t, agent.Metadata{}, 0)
307+
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
308+
require.NoError(t, err)
309+
defer conn1.Close()
310+
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
311+
require.NoError(t, err)
312+
defer conn2.Close()
313+
testDial(t, conn2)
314+
testDial(t, conn1)
315+
})
316+
}
317+
})
318+
319+
t.Run("DialError", func(t *testing.T) {
320+
t.Parallel()
321+
322+
if runtime.GOOS == "windows" {
323+
// This test uses Unix listeners so we can very easily ensure that
324+
// no other tests decide to listen on the same random port we
325+
// picked.
326+
t.Skip("this test is unsupported on Windows")
327+
return
328+
}
329+
330+
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
331+
require.NoError(t, err, "create temp dir")
332+
t.Cleanup(func() {
333+
_ = os.RemoveAll(tmpDir)
334+
})
335+
336+
// Try to dial the non-existent Unix socket over WebRTC
337+
conn := setupAgent(t, agent.Metadata{}, 0)
338+
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
339+
require.Error(t, err)
340+
require.ErrorContains(t, err, "remote dial error")
341+
require.ErrorContains(t, err, "no such file")
342+
require.Nil(t, netConn)
343+
})
237344
}
238345

239346
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
@@ -303,3 +410,34 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
303410
Conn: conn,
304411
}
305412
}
413+
414+
var dialTestPayload = []byte("dean-was-here123")
415+
416+
func testDial(t *testing.T, c net.Conn) {
417+
t.Helper()
418+
419+
assertWritePayload(t, c, dialTestPayload)
420+
assertReadPayload(t, c, dialTestPayload)
421+
}
422+
423+
func testAccept(t *testing.T, c net.Conn) {
424+
t.Helper()
425+
defer c.Close()
426+
427+
assertReadPayload(t, c, dialTestPayload)
428+
assertWritePayload(t, c, dialTestPayload)
429+
}
430+
431+
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
432+
b := make([]byte, len(payload)+16)
433+
n, err := r.Read(b)
434+
require.NoError(t, err, "read payload")
435+
require.Equal(t, len(payload), n, "read payload length does not match")
436+
require.Equal(t, payload, b[:n])
437+
}
438+
439+
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
440+
n, err := w.Write(payload)
441+
require.NoError(t, err, "write payload")
442+
require.Equal(t, len(payload), n, "payload length does not match")
443+
}

agent/conn.go

+41-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package agent
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"net"
8+
"net/url"
9+
"strings"
710

811
"golang.org/x/crypto/ssh"
912
"golang.org/x/xerrors"
@@ -32,7 +35,7 @@ type Conn struct {
3235
// ReconnectingPTY returns a connection serving a TTY that can
3336
// be reconnected to via ID.
3437
func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) {
35-
channel, err := c.Dial(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
38+
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
3639
Protocol: "reconnecting-pty",
3740
})
3841
if err != nil {
@@ -43,7 +46,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error
4346

4447
// SSH dials the built-in SSH server.
4548
func (c *Conn) SSH() (net.Conn, error) {
46-
channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{
49+
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
4750
Protocol: "ssh",
4851
})
4952
if err != nil {
@@ -71,6 +74,42 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
7174
return ssh.NewClient(sshConn, channels, requests), nil
7275
}
7376

77+
// DialContext dials an arbitrary protocol+address from inside the workspace and
78+
// proxies it through the provided net.Conn.
79+
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
80+
u := &url.URL{
81+
Scheme: network,
82+
}
83+
if strings.HasPrefix(network, "unix") {
84+
u.Path = addr
85+
} else {
86+
u.Host = addr
87+
}
88+
89+
channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
90+
Protocol: "dial",
91+
Unordered: strings.HasPrefix(network, "udp"),
92+
})
93+
if err != nil {
94+
return nil, xerrors.Errorf("create datachannel: %w", err)
95+
}
96+
97+
// The first message written from the other side is a JSON payload
98+
// containing the dial error.
99+
dec := json.NewDecoder(channel)
100+
var res dialResponse
101+
err = dec.Decode(&res)
102+
if err != nil {
103+
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
104+
}
105+
if res.Error != "" {
106+
_ = channel.Close()
107+
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
108+
}
109+
110+
return channel.NetConn(), nil
111+
}
112+
74113
func (c *Conn) Close() error {
75114
_ = c.Negotiator.DRPCConn().Close()
76115
return c.Conn.Close()

0 commit comments

Comments
 (0)