From f104e02b7e4a2da4d3818b0b5f0aa92ad795015c Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 9 May 2022 20:35:19 +0000 Subject: [PATCH 01/10] feat: add agent dial handler --- agent/agent.go | 2 + agent/agent_test.go | 142 ++++++++++++++++++++++++++++++++++++++++++++ agent/conn.go | 4 +- agent/dial.go | 135 +++++++++++++++++++++++++++++++++++++++++ peer/conn.go | 4 +- peer/conn_test.go | 14 ++--- 6 files changed, 290 insertions(+), 11 deletions(-) create mode 100644 agent/dial.go diff --git a/agent/agent.go b/agent/agent.go index b946166056532..05d7e1799519a 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -211,6 +211,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) { go a.sshServer.HandleConn(channel.NetConn()) case "reconnecting-pty": go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn()) + case "dial": + go a.handleDial(ctx, channel.Label(), channel.NetConn()) default: a.logger.Warn(ctx, "unhandled protocol from channel", slog.F("protocol", channel.Protocol()), diff --git a/agent/agent_test.go b/agent/agent_test.go index bd26fae7f0a69..959cb04b1bbe6 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/google/uuid" + "github.com/pion/udp" "github.com/pion/webrtc/v3" "github.com/pkg/sftp" "github.com/stretchr/testify/require" @@ -234,6 +235,114 @@ func TestAgent(t *testing.T) { findEcho() findEcho() }) + + t.Run("Dial", func(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + setup func(t *testing.T) net.Listener + }{ + { + name: "TCP", + setup: func(t *testing.T) net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "create TCP listener") + return l + }, + }, + { + name: "UDP", + setup: func(t *testing.T) net.Listener { + addr := net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + l, err := udp.Listen("udp", &addr) + require.NoError(t, err, "create UDP listener") + return l + }, + }, + { + name: "Unix", + setup: func(t *testing.T) net.Listener { + if runtime.GOOS == "windows" { + t.Skip("Unix socket forwarding isn't supported on Windows") + } + + tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") + require.NoError(t, err, "create temp dir for unix listener") + t.Cleanup(func() { + _ = os.RemoveAll(tmpDir) + }) + + l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock")) + require.NoError(t, err, "create UDP listener") + return l + }, + }, + } + + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + // Setup listener + l := c.setup(t) + defer l.Close() + go func() { + for { + c, err := l.Accept() + if err != nil { + return + } + + testAccept(t, c) + } + }() + + // Try to dial the listener over WebRTC + conn := setupAgent(t, agent.Metadata{}, 0) + conn1, err := conn.Dial(l.Addr().Network(), l.Addr().String()) + require.NoError(t, err) + defer conn1.Close() + testDial(t, conn1) + + // Dial again using the same WebRTC client + conn2, err := conn.Dial(l.Addr().Network(), l.Addr().String()) + require.NoError(t, err) + defer conn2.Close() + testDial(t, conn2) + }) + } + }) + + t.Run("DialError", func(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + // This test uses Unix listeners so we can very easily ensure that + // no other tests decide to listen on the same random port we + // picked. + t.Skip("this test is unsupported on Windows") + return + } + + tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") + require.NoError(t, err, "create temp dir") + t.Cleanup(func() { + _ = os.RemoveAll(tmpDir) + }) + + // Try to dial the non-existent Unix socket over WebRTC + conn := setupAgent(t, agent.Metadata{}, 0) + netConn, err := conn.Dial("unix", filepath.Join(tmpDir, "test.sock")) + require.Error(t, err) + require.ErrorContains(t, err, "remote dial error") + require.ErrorContains(t, err, "no such file") + require.Nil(t, netConn) + }) } func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd { @@ -303,3 +412,36 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) Conn: conn, } } + +var dialTestPayload = []byte("dean-was-here123") + +func testDial(t *testing.T, c net.Conn) { + t.Helper() + + // Dials will write and then expect echo back + n, err := c.Write(dialTestPayload) + require.NoError(t, err, "write test payload") + require.Equal(t, len(dialTestPayload), n, "test payload length does not match") + + b := make([]byte, len(dialTestPayload)+16) + n, err = c.Read(b) + require.NoError(t, err, "read test payload") + require.Equal(t, len(dialTestPayload), n, "read payload length does not match") + require.Equal(t, dialTestPayload, b[:n]) +} + +func testAccept(t *testing.T, c net.Conn) { + t.Helper() + defer c.Close() + + // Accepts will read then echo + b := make([]byte, len(dialTestPayload)+16) + n, err := c.Read(b) + require.NoError(t, err, "read test payload") + require.Equal(t, len(dialTestPayload), n, "read payload length does not match") + require.Equal(t, dialTestPayload, b[:n]) + + n, err = c.Write(dialTestPayload) + require.NoError(t, err, "write test payload") + require.Equal(t, len(dialTestPayload), n, "test payload length does not match") +} diff --git a/agent/conn.go b/agent/conn.go index 81a6315af26de..3fb28b5a74a5b 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -32,7 +32,7 @@ type Conn struct { // ReconnectingPTY returns a connection serving a TTY that can // be reconnected to via ID. func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) { - channel, err := c.Dial(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{ + channel, err := c.OpenChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{ Protocol: "reconnecting-pty", }) if err != nil { @@ -43,7 +43,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error // SSH dials the built-in SSH server. func (c *Conn) SSH() (net.Conn, error) { - channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{ + channel, err := c.OpenChannel(context.Background(), "ssh", &peer.ChannelOptions{ Protocol: "ssh", }) if err != nil { diff --git a/agent/dial.go b/agent/dial.go new file mode 100644 index 0000000000000..f457a2c77c179 --- /dev/null +++ b/agent/dial.go @@ -0,0 +1,135 @@ +package agent + +import ( + "context" + "encoding/json" + "io" + "net" + "net/url" + "strings" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/peer" +) + +// DialResponse is written to datachannels with protocol "dial" by the agent as +// the first packet to signify whether the dial succeeded or failed. +type DialResponse struct { + Error string `json:"error,omitempty"` +} + +// Dial dials an arbitrary protocol+address from inside the workspace and +// proxies it through the provided net.Conn. +func (c *Conn) Dial(network string, addr string) (net.Conn, error) { + // Force unique URL by including a random UUID. + id, err := uuid.NewRandom() + if err != nil { + return nil, xerrors.Errorf("generate random UUID: %w", err) + } + + host := "" + path := "" + if strings.HasPrefix(network, "unix") { + path = addr + } else { + host = addr + } + + label := (&url.URL{ + Scheme: network, + Host: host, + Path: path, + RawQuery: (url.Values{ + "id": []string{id.String()}, + }).Encode(), + }).String() + + channel, err := c.OpenChannel(context.Background(), label, &peer.ChannelOptions{ + Protocol: "dial", + }) + if err != nil { + return nil, xerrors.Errorf("pty: %w", err) + } + + // The first message written from the other side is a JSON payload + // containing the dial error. + dec := json.NewDecoder(channel) + var res DialResponse + err = dec.Decode(&res) + if err != nil { + return nil, xerrors.Errorf("failed to decode initial packet: %w", err) + } + if res.Error != "" { + _ = channel.Close() + return nil, xerrors.Errorf("remote dial error: %v", res.Error) + } + + return channel.NetConn(), nil +} + +func (*agent) handleDial(ctx context.Context, label string, conn net.Conn) { + defer conn.Close() + + writeError := func(responseError error) error { + msg := "" + if responseError != nil { + msg = responseError.Error() + } + b, err := json.Marshal(DialResponse{ + Error: msg, + }) + if err != nil { + return xerrors.Errorf("marshal agent webrtc dial response: %w", err) + } + + _, err = conn.Write(b) + return err + } + + u, err := url.Parse(label) + if err != nil { + _ = writeError(xerrors.Errorf("parse URL %q: %w", label, err)) + return + } + + network := u.Scheme + addr := u.Host + u.Path + nconn, err := net.Dial(network, addr) + if err != nil { + _ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err)) + return + } + + err = writeError(nil) + if err != nil { + return + } + + bicopy(ctx, conn, nconn) +} + +// bicopy copies all of the data between the two connections +// and will close them after one or both of them are done writing. +// If the context is canceled, both of the connections will be +// closed. +// +// NOTE: This function will block until the copying is done or the +// context is canceled. +func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + defer c1.Close() + defer c2.Close() + + ctx, cancel := context.WithCancel(ctx) + + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer cancel() + _, _ = io.Copy(dst, src) + } + + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + <-ctx.Done() +} diff --git a/peer/conn.go b/peer/conn.go index e9126443b8eca..c08c9f6b63203 100644 --- a/peer/conn.go +++ b/peer/conn.go @@ -469,8 +469,8 @@ func (c *Conn) Accept(ctx context.Context) (*Channel, error) { return newChannel(c, dataChannel, &ChannelOptions{}), nil } -// Dial creates a new DataChannel. -func (c *Conn) Dial(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { +// OpenChannel creates a new DataChannel. +func (c *Conn) OpenChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { if opts == nil { opts = &ChannelOptions{} } diff --git a/peer/conn_test.go b/peer/conn_test.go index 960ec34cfafba..e8a6f56f735f5 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -90,7 +90,7 @@ func TestConn(t *testing.T) { _, err := server.Ping() require.NoError(t, err) // Create a channel that closes on disconnect. - channel, err := server.Dial(context.Background(), "wow", nil) + channel, err := server.OpenChannel(context.Background(), "wow", nil) assert.NoError(t, err) err = wan.Stop() require.NoError(t, err) @@ -108,7 +108,7 @@ func TestConn(t *testing.T) { t.Parallel() client, server, _ := createPair(t) exchange(t, client, server) - cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{}) + cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{}) require.NoError(t, err) sch, err := server.Accept(context.Background()) @@ -124,7 +124,7 @@ func TestConn(t *testing.T) { t.Parallel() client, server, wan := createPair(t) exchange(t, client, server) - cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{}) + cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{}) require.NoError(t, err) sch, err := server.Accept(context.Background()) require.NoError(t, err) @@ -141,7 +141,7 @@ func TestConn(t *testing.T) { t.Parallel() client, server, _ := createPair(t) exchange(t, client, server) - cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{}) + cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{}) require.NoError(t, err) sch, err := server.Accept(context.Background()) require.NoError(t, err) @@ -196,7 +196,7 @@ func TestConn(t *testing.T) { defaultTransport := http.DefaultTransport.(*http.Transport).Clone() var cch *peer.Channel defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - cch, err = client.Dial(ctx, "hello", &peer.ChannelOptions{}) + cch, err = client.OpenChannel(ctx, "hello", &peer.ChannelOptions{}) if err != nil { return nil, err } @@ -234,7 +234,7 @@ func TestConn(t *testing.T) { require.NoError(t, err) expectedErr := xerrors.New("wow") _ = conn.CloseWithError(expectedErr) - _, err = conn.Dial(context.Background(), "", nil) + _, err = conn.OpenChannel(context.Background(), "", nil) require.ErrorIs(t, err, expectedErr) }) @@ -274,7 +274,7 @@ func TestConn(t *testing.T) { client, server, _ := createPair(t) exchange(t, client, server) go func() { - channel, err := client.Dial(context.Background(), "test", nil) + channel, err := client.OpenChannel(context.Background(), "test", nil) require.NoError(t, err) _, err = channel.Write([]byte{1, 2}) require.NoError(t, err) From 1972629c1f61a5742191cb40b5cdfb440cabed4f Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 9 May 2022 21:13:41 +0000 Subject: [PATCH 02/10] chore: Kyle PR review comments --- agent/agent.go | 72 +++++++++++++++++++++++ agent/agent_test.go | 6 +- agent/conn.go | 42 +++++++++++++- agent/dial.go | 135 -------------------------------------------- peer/conn.go | 4 +- peer/conn_test.go | 14 ++--- 6 files changed, 124 insertions(+), 149 deletions(-) delete mode 100644 agent/dial.go diff --git a/agent/agent.go b/agent/agent.go index 05d7e1799519a..7745c96d6e8bb 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net" + "net/url" "os" "os/exec" "os/user" @@ -619,6 +620,57 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne } } +// dialResponse is written to datachannels with protocol "dial" by the agent as +// the first packet to signify whether the dial succeeded or failed. +type dialResponse struct { + Error string `json:"error,omitempty"` +} + +func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { + defer conn.Close() + + writeError := func(responseError error) error { + msg := "" + if responseError != nil { + msg = responseError.Error() + if !xerrors.Is(responseError, io.EOF) { + a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError)) + } + } + b, err := json.Marshal(dialResponse{ + Error: msg, + }) + if err != nil { + a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err)) + return xerrors.Errorf("marshal agent webrtc dial response: %w", err) + } + + _, err = conn.Write(b) + return err + } + + u, err := url.Parse(label) + if err != nil { + _ = writeError(xerrors.Errorf("parse URL %q: %w", label, err)) + return + } + + network := u.Scheme + addr := u.Host + u.Path + nconn, err := net.Dial(network, addr) + if err != nil { + _ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err)) + return + } + + err = writeError(nil) + if err != nil { + return + } + + bicopy(ctx, conn, nconn) +} + // isClosed returns whether the API is closed or not. func (a *agent) isClosed() bool { select { @@ -664,3 +716,23 @@ func (r *reconnectingPTY) Close() { r.circularBuffer.Reset() r.timeout.Stop() } + +// bicopy copies all of the data between the two connections and will close them +// after one or both of them are done writing. If the context is canceled, both +// of the connections will be closed. +func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + defer c1.Close() + defer c2.Close() + + ctx, cancel := context.WithCancel(ctx) + + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer cancel() + _, _ = io.Copy(dst, src) + } + + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + <-ctx.Done() +} diff --git a/agent/agent_test.go b/agent/agent_test.go index 959cb04b1bbe6..cbf0f4a66ab23 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -304,13 +304,13 @@ func TestAgent(t *testing.T) { // Try to dial the listener over WebRTC conn := setupAgent(t, agent.Metadata{}, 0) - conn1, err := conn.Dial(l.Addr().Network(), l.Addr().String()) + conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) defer conn1.Close() testDial(t, conn1) // Dial again using the same WebRTC client - conn2, err := conn.Dial(l.Addr().Network(), l.Addr().String()) + conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) defer conn2.Close() testDial(t, conn2) @@ -337,7 +337,7 @@ func TestAgent(t *testing.T) { // Try to dial the non-existent Unix socket over WebRTC conn := setupAgent(t, agent.Metadata{}, 0) - netConn, err := conn.Dial("unix", filepath.Join(tmpDir, "test.sock")) + netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock")) require.Error(t, err) require.ErrorContains(t, err, "remote dial error") require.ErrorContains(t, err, "no such file") diff --git a/agent/conn.go b/agent/conn.go index 3fb28b5a74a5b..747b34ddc7206 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -2,8 +2,11 @@ package agent import ( "context" + "encoding/json" "fmt" "net" + "net/url" + "strings" "golang.org/x/crypto/ssh" "golang.org/x/xerrors" @@ -32,7 +35,7 @@ type Conn struct { // ReconnectingPTY returns a connection serving a TTY that can // be reconnected to via ID. func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) { - channel, err := c.OpenChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{ + channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{ Protocol: "reconnecting-pty", }) if err != nil { @@ -43,7 +46,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error // SSH dials the built-in SSH server. func (c *Conn) SSH() (net.Conn, error) { - channel, err := c.OpenChannel(context.Background(), "ssh", &peer.ChannelOptions{ + channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{ Protocol: "ssh", }) if err != nil { @@ -71,6 +74,41 @@ func (c *Conn) SSHClient() (*ssh.Client, error) { return ssh.NewClient(sshConn, channels, requests), nil } +// DialContext dials an arbitrary protocol+address from inside the workspace and +// proxies it through the provided net.Conn. +func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + u := &url.URL{ + Scheme: network, + } + if strings.HasPrefix(network, "unix") { + u.Path = addr + } else { + u.Host = addr + } + + channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{ + Protocol: "dial", + }) + if err != nil { + return nil, xerrors.Errorf("create datachannel: %w", err) + } + + // The first message written from the other side is a JSON payload + // containing the dial error. + dec := json.NewDecoder(channel) + var res dialResponse + err = dec.Decode(&res) + if err != nil { + return nil, xerrors.Errorf("failed to decode initial packet: %w", err) + } + if res.Error != "" { + _ = channel.Close() + return nil, xerrors.Errorf("remote dial error: %v", res.Error) + } + + return channel.NetConn(), nil +} + func (c *Conn) Close() error { _ = c.Negotiator.DRPCConn().Close() return c.Conn.Close() diff --git a/agent/dial.go b/agent/dial.go deleted file mode 100644 index f457a2c77c179..0000000000000 --- a/agent/dial.go +++ /dev/null @@ -1,135 +0,0 @@ -package agent - -import ( - "context" - "encoding/json" - "io" - "net" - "net/url" - "strings" - - "github.com/google/uuid" - "golang.org/x/xerrors" - - "github.com/coder/coder/peer" -) - -// DialResponse is written to datachannels with protocol "dial" by the agent as -// the first packet to signify whether the dial succeeded or failed. -type DialResponse struct { - Error string `json:"error,omitempty"` -} - -// Dial dials an arbitrary protocol+address from inside the workspace and -// proxies it through the provided net.Conn. -func (c *Conn) Dial(network string, addr string) (net.Conn, error) { - // Force unique URL by including a random UUID. - id, err := uuid.NewRandom() - if err != nil { - return nil, xerrors.Errorf("generate random UUID: %w", err) - } - - host := "" - path := "" - if strings.HasPrefix(network, "unix") { - path = addr - } else { - host = addr - } - - label := (&url.URL{ - Scheme: network, - Host: host, - Path: path, - RawQuery: (url.Values{ - "id": []string{id.String()}, - }).Encode(), - }).String() - - channel, err := c.OpenChannel(context.Background(), label, &peer.ChannelOptions{ - Protocol: "dial", - }) - if err != nil { - return nil, xerrors.Errorf("pty: %w", err) - } - - // The first message written from the other side is a JSON payload - // containing the dial error. - dec := json.NewDecoder(channel) - var res DialResponse - err = dec.Decode(&res) - if err != nil { - return nil, xerrors.Errorf("failed to decode initial packet: %w", err) - } - if res.Error != "" { - _ = channel.Close() - return nil, xerrors.Errorf("remote dial error: %v", res.Error) - } - - return channel.NetConn(), nil -} - -func (*agent) handleDial(ctx context.Context, label string, conn net.Conn) { - defer conn.Close() - - writeError := func(responseError error) error { - msg := "" - if responseError != nil { - msg = responseError.Error() - } - b, err := json.Marshal(DialResponse{ - Error: msg, - }) - if err != nil { - return xerrors.Errorf("marshal agent webrtc dial response: %w", err) - } - - _, err = conn.Write(b) - return err - } - - u, err := url.Parse(label) - if err != nil { - _ = writeError(xerrors.Errorf("parse URL %q: %w", label, err)) - return - } - - network := u.Scheme - addr := u.Host + u.Path - nconn, err := net.Dial(network, addr) - if err != nil { - _ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err)) - return - } - - err = writeError(nil) - if err != nil { - return - } - - bicopy(ctx, conn, nconn) -} - -// bicopy copies all of the data between the two connections -// and will close them after one or both of them are done writing. -// If the context is canceled, both of the connections will be -// closed. -// -// NOTE: This function will block until the copying is done or the -// context is canceled. -func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { - defer c1.Close() - defer c2.Close() - - ctx, cancel := context.WithCancel(ctx) - - copyFunc := func(dst io.WriteCloser, src io.Reader) { - defer cancel() - _, _ = io.Copy(dst, src) - } - - go copyFunc(c1, c2) - go copyFunc(c2, c1) - - <-ctx.Done() -} diff --git a/peer/conn.go b/peer/conn.go index c08c9f6b63203..949468bfa7064 100644 --- a/peer/conn.go +++ b/peer/conn.go @@ -469,8 +469,8 @@ func (c *Conn) Accept(ctx context.Context) (*Channel, error) { return newChannel(c, dataChannel, &ChannelOptions{}), nil } -// OpenChannel creates a new DataChannel. -func (c *Conn) OpenChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { +// CreateChannel creates a new DataChannel. +func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { if opts == nil { opts = &ChannelOptions{} } diff --git a/peer/conn_test.go b/peer/conn_test.go index e8a6f56f735f5..46bcea980e5f3 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -90,7 +90,7 @@ func TestConn(t *testing.T) { _, err := server.Ping() require.NoError(t, err) // Create a channel that closes on disconnect. - channel, err := server.OpenChannel(context.Background(), "wow", nil) + channel, err := server.CreateChannel(context.Background(), "wow", nil) assert.NoError(t, err) err = wan.Stop() require.NoError(t, err) @@ -108,7 +108,7 @@ func TestConn(t *testing.T) { t.Parallel() client, server, _ := createPair(t) exchange(t, client, server) - cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{}) + cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{}) require.NoError(t, err) sch, err := server.Accept(context.Background()) @@ -124,7 +124,7 @@ func TestConn(t *testing.T) { t.Parallel() client, server, wan := createPair(t) exchange(t, client, server) - cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{}) + cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{}) require.NoError(t, err) sch, err := server.Accept(context.Background()) require.NoError(t, err) @@ -141,7 +141,7 @@ func TestConn(t *testing.T) { t.Parallel() client, server, _ := createPair(t) exchange(t, client, server) - cch, err := client.OpenChannel(context.Background(), "hello", &peer.ChannelOptions{}) + cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{}) require.NoError(t, err) sch, err := server.Accept(context.Background()) require.NoError(t, err) @@ -196,7 +196,7 @@ func TestConn(t *testing.T) { defaultTransport := http.DefaultTransport.(*http.Transport).Clone() var cch *peer.Channel defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - cch, err = client.OpenChannel(ctx, "hello", &peer.ChannelOptions{}) + cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) if err != nil { return nil, err } @@ -234,7 +234,7 @@ func TestConn(t *testing.T) { require.NoError(t, err) expectedErr := xerrors.New("wow") _ = conn.CloseWithError(expectedErr) - _, err = conn.OpenChannel(context.Background(), "", nil) + _, err = conn.CreateChannel(context.Background(), "", nil) require.ErrorIs(t, err, expectedErr) }) @@ -274,7 +274,7 @@ func TestConn(t *testing.T) { client, server, _ := createPair(t) exchange(t, client, server) go func() { - channel, err := client.OpenChannel(context.Background(), "test", nil) + channel, err := client.CreateChannel(context.Background(), "test", nil) require.NoError(t, err) _, err = channel.Write([]byte{1, 2}) require.NoError(t, err) From e3eb83968777c3afdc0ffcf28d3d359ad8841864 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 17 May 2022 01:53:46 +0000 Subject: [PATCH 03/10] feat: add port-forward subcommand --- agent/agent.go | 31 +++- agent/conn.go | 3 +- cli/portforward.go | 342 +++++++++++++++++++++++++++++++++++++++++++++ cli/root.go | 2 +- cli/ssh.go | 89 ++++++------ cli/templates.go | 9 +- cli/tunnel.go | 12 -- go.mod | 2 +- peer/channel.go | 4 +- 9 files changed, 431 insertions(+), 63 deletions(-) create mode 100644 cli/portforward.go delete mode 100644 cli/tunnel.go diff --git a/agent/agent.go b/agent/agent.go index 7745c96d6e8bb..86c8b5cc28de3 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -657,6 +657,14 @@ func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { network := u.Scheme addr := u.Host + u.Path + if strings.HasPrefix(network, "unix") { + addr, err = ExpandPath(addr) + if err != nil { + _ = writeError(xerrors.Errorf("expand path %q: %w", addr, err)) + return + } + } + nconn, err := net.Dial(network, addr) if err != nil { _ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err)) @@ -668,7 +676,7 @@ func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { return } - bicopy(ctx, conn, nconn) + Bicopy(ctx, conn, nconn) } // isClosed returns whether the API is closed or not. @@ -717,10 +725,10 @@ func (r *reconnectingPTY) Close() { r.timeout.Stop() } -// bicopy copies all of the data between the two connections and will close them +// Bicopy copies all of the data between the two connections and will close them // after one or both of them are done writing. If the context is canceled, both // of the connections will be closed. -func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { +func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { defer c1.Close() defer c2.Close() @@ -736,3 +744,20 @@ func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { <-ctx.Done() } + +// ExpandPath expands the tilde at the beggining of a path to the current user's +// home directory and returns a full absolute path. +func ExpandPath(in string) (string, error) { + usr, err := user.Current() + if err != nil { + return "", xerrors.Errorf("get current user details: %w", err) + } + + if in == "~" { + in = usr.HomeDir + } else if strings.HasPrefix(in, "~/") { + in = filepath.Join(usr.HomeDir, in[2:]) + } + + return filepath.Abs(in) +} diff --git a/agent/conn.go b/agent/conn.go index 747b34ddc7206..56d3d42ea1784 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -87,7 +87,8 @@ func (c *Conn) DialContext(ctx context.Context, network string, addr string) (ne } channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{ - Protocol: "dial", + Protocol: "dial", + Unordered: strings.HasPrefix(network, "udp"), }) if err != nil { return nil, xerrors.Errorf("create datachannel: %w", err) diff --git a/cli/portforward.go b/cli/portforward.go new file mode 100644 index 0000000000000..0de9206a90629 --- /dev/null +++ b/cli/portforward.go @@ -0,0 +1,342 @@ +package cli + +import ( + "context" + "fmt" + "net" + "runtime" + "strconv" + "strings" + "sync" + + "github.com/pion/udp" + "github.com/spf13/cobra" + "golang.org/x/xerrors" + + "github.com/coder/coder/agent" + coderagent "github.com/coder/coder/agent" + "github.com/coder/coder/cli/cliui" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/codersdk" +) + +func portForward() *cobra.Command { + var ( + tcpForwards []string // : + udpForwards []string // : + unixForwards []string // : OR : + ) + cmd := &cobra.Command{ + Use: "port-forward ", + Aliases: []string{"tunnel"}, + Args: cobra.ExactArgs(1), + Example: ` + - Port forward a single TCP port from 1234 in the workspace to port 5678 on + your local machine + + ` + cliui.Styles.Code.Render("$ coder port-forward --tcp 5678:1234") + ` + + - Port forward a single UDP port from port 9000 to port 9000 on your local + machine + + ` + cliui.Styles.Code.Render("$ coder port-forward --udp 9000") + ` + + - Forward a Unix socket in the workspace to a local Unix socket + + ` + cliui.Styles.Code.Render("$ coder port-forward --unix ./local.sock:~/remote.sock") + ` + + - Forward a Unix socket in the workspace to a local TCP port + + ` + cliui.Styles.Code.Render("$ coder port-forward --unix 8080:~/remote.sock") + ` + + - Port forward multiple TCP ports and a UDP port + + ` + cliui.Styles.Code.Render("$ coder port-forward --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53"), + RunE: func(cmd *cobra.Command, args []string) error { + // TODO: remove parsing debug + fmt.Println("TCP:") + for _, tcp := range tcpForwards { + fmt.Println("\t", tcp) + } + fmt.Println("UDP:") + for _, udp := range udpForwards { + fmt.Println("\t", udp) + } + fmt.Println("Unix:") + for _, unix := range unixForwards { + fmt.Println("\t", unix) + } + fmt.Println() + + specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards) + if err != nil { + return xerrors.Errorf("parse port-forward specs: %w", err) + } + + fmt.Println("SPECS:") + for _, spec := range specs { + fmt.Printf("\t%+v\n", spec) + } + + client, err := createClient(cmd) + if err != nil { + return err + } + organization, err := currentOrganization(cmd, client) + if err != nil { + return err + } + + workspace, agent, err := getWorkspaceAndAgent(cmd.Context(), client, organization.ID, codersdk.Me, args[0]) + if err != nil { + return err + } + if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart { + return xerrors.New("workspace must be in start transition to port-forward") + } + if workspace.LatestBuild.Job.CompletedAt == nil { + err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) + if err != nil { + return err + } + } + + err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ + WorkspaceName: workspace.Name, + Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { + return client.WorkspaceAgent(ctx, agent.ID) + }, + }) + if err != nil { + return xerrors.Errorf("await agent: %w", err) + } + + conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil) + if err != nil { + return xerrors.Errorf("dial workspace agent: %w", err) + } + defer conn.Close() + + // Start all listeners + var wg sync.WaitGroup + for _, spec := range specs { + var ( + l net.Listener + err error + ) + fmt.Printf("Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) + switch spec.listenNetwork { + case "tcp": + l, err = net.Listen(spec.listenNetwork, spec.listenAddress) + case "udp": + host, port, err := net.SplitHostPort(spec.listenAddress) + if err != nil { + return xerrors.Errorf("split %q: %w", spec.listenAddress, err) + } + portInt, err := strconv.Atoi(port) + if err != nil { + return xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err) + } + + l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ + IP: net.ParseIP(host), + Port: portInt, + }) + case "unix": + l, err = net.Listen(spec.listenNetwork, spec.listenAddress) + default: + return xerrors.Errorf("unknown listen network %q", spec.listenNetwork) + } + if err != nil { + return xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err) + } + + wg.Add(1) + go func(spec portForwardSpec) { + defer wg.Done() + for { + netConn, err := l.Accept() + if err != nil { + fmt.Printf("Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err) + fmt.Println("Killing listener") + return + } + + go func(netConn net.Conn) { + defer netConn.Close() + remoteConn, err := conn.DialContext(cmd.Context(), spec.dialNetwork, spec.dialAddress) + if err != nil { + fmt.Printf("Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) + return + } + defer remoteConn.Close() + + coderagent.Bicopy(cmd.Context(), netConn, remoteConn) + }(netConn) + } + }(spec) + } + + wg.Wait() + return nil + }, + } + + cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine") + cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols") + cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port") + + return cmd +} + +type portForwardSpec struct { + listenNetwork string // tcp, udp, unix + listenAddress string // : or path + + dialNetwork string // tcp, udp, unix + dialAddress string // : or path +} + +func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) { + specs := []portForwardSpec{} + + for _, spec := range tcp { + local, remote, err := parsePortPort(spec) + if err != nil { + return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec) + } + + specs = append(specs, portForwardSpec{ + listenNetwork: "tcp", + listenAddress: fmt.Sprintf("127.0.0.1:%v", local), + dialNetwork: "tcp", + dialAddress: fmt.Sprintf("127.0.0.1:%v", remote), + }) + } + + for _, spec := range udp { + local, remote, err := parsePortPort(spec) + if err != nil { + return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec) + } + + specs = append(specs, portForwardSpec{ + listenNetwork: "udp", + listenAddress: fmt.Sprintf("127.0.0.1:%v", local), + dialNetwork: "udp", + dialAddress: fmt.Sprintf("127.0.0.1:%v", remote), + }) + } + + for _, specStr := range unix { + localPath, localTCP, remotePath, err := parseUnixUnix(specStr) + if err != nil { + return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr) + } + + spec := portForwardSpec{ + dialNetwork: "unix", + dialAddress: remotePath, + } + if localPath == "" { + spec.listenNetwork = "tcp" + spec.listenAddress = fmt.Sprintf("127.0.0.1:%v", localTCP) + } else { + if runtime.GOOS == "windows" { + return nil, xerrors.Errorf("Unix port-forwarding is not supported on Windows") + } + spec.listenNetwork = "unix" + spec.listenAddress = localPath + } + specs = append(specs, spec) + } + + // Check for duplicate entries. + locals := map[string]struct{}{} + for _, spec := range specs { + localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress) + if _, ok := locals[localStr]; ok { + return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress) + } + locals[localStr] = struct{}{} + } + + return specs, nil +} + +func parsePort(in string) (uint16, error) { + port, err := strconv.ParseUint(strings.TrimSpace(in), 10, 16) + if err != nil { + return 0, xerrors.Errorf("parse port %q: %w", in, err) + } + if port == 0 { + return 0, xerrors.New("port cannot be 0") + } + + return uint16(port), nil +} + +func parseUnixPath(in string) (string, error) { + path, err := agent.ExpandPath(strings.TrimSpace(in)) + if err != nil { + return "", xerrors.Errorf("tidy path %q: %w", in, err) + } + + return path, nil +} + +func parsePortPort(in string) (uint16, uint16, error) { + parts := strings.Split(in, ":") + if len(parts) > 2 { + return 0, 0, xerrors.Errorf("invalid port specification %q", in) + } + if len(parts) == 1 { + // Duplicate the single part + parts = append(parts, parts[0]) + } + + local, err := parsePort(parts[0]) + if err != nil { + return 0, 0, xerrors.Errorf("parse local port from %q: %w", in, err) + } + remote, err := parsePort(parts[1]) + if err != nil { + return 0, 0, xerrors.Errorf("parse remote port from %q: %w", in, err) + } + + return uint16(local), uint16(remote), nil +} + +func parsePortOrUnixPath(in string) (string, uint16, error) { + port, err := parsePort(in) + if err == nil { + return "", port, nil + } + + path, err := parseUnixPath(in) + if err != nil { + return "", 0, xerrors.Errorf("could not parse port or unix path %q: %w", in, err) + } + + return path, 0, nil +} + +func parseUnixUnix(in string) (string, uint16, string, error) { + parts := strings.Split(in, ":") + if len(parts) > 2 { + return "", 0, "", xerrors.Errorf("invalid port-forward specification %q", in) + } + if len(parts) == 1 { + // Duplicate the single part + parts = append(parts, parts[0]) + } + + localPath, localPort, err := parsePortOrUnixPath(parts[0]) + if err != nil { + return "", 0, "", xerrors.Errorf("parse local part of spec %q: %w", in, err) + } + + // We don't really touch the remote path at all since it gets cleaned + // up/expanded on the remote. + return localPath, localPort, parts[1], nil +} diff --git a/cli/root.go b/cli/root.go index 4a38d5895b079..ff64aa1baff7b 100644 --- a/cli/root.go +++ b/cli/root.go @@ -86,7 +86,7 @@ func Root() *cobra.Command { templates(), update(), users(), - tunnel(), + portForward(), workspaceAgent(), ) diff --git a/cli/ssh.go b/cli/ssh.go index 559514396810e..dd91ab4b9716b 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -38,16 +38,13 @@ func ssh() *cobra.Command { return err } - workspaceParts := strings.Split(args[0], ".") - workspace, err := client.WorkspaceByOwnerAndName(cmd.Context(), organization.ID, codersdk.Me, workspaceParts[0]) + workspace, agent, err := getWorkspaceAndAgent(cmd.Context(), client, organization.ID, codersdk.Me, args[0]) if err != nil { return err } - if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart { return xerrors.New("workspace must be in start transition to ssh") } - if workspace.LatestBuild.Job.CompletedAt == nil { err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) if err != nil { @@ -55,41 +52,6 @@ func ssh() *cobra.Command { } } - if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete { - return xerrors.New("workspace is deleting...") - } - - resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID) - if err != nil { - return err - } - - agents := make([]codersdk.WorkspaceAgent, 0) - for _, resource := range resources { - agents = append(agents, resource.Agents...) - } - if len(agents) == 0 { - return xerrors.New("workspace has no agents") - } - var agent codersdk.WorkspaceAgent - if len(workspaceParts) >= 2 { - for _, otherAgent := range agents { - if otherAgent.Name != workspaceParts[1] { - continue - } - agent = otherAgent - break - } - if agent.ID == uuid.Nil { - return xerrors.Errorf("agent not found by name %q", workspaceParts[1]) - } - } - if agent.ID == uuid.Nil { - if len(agents) > 1 { - return xerrors.New("you must specify the name of an agent") - } - agent = agents[0] - } // OpenSSH passes stderr directly to the calling TTY. // This is required in "stdio" mode so a connecting indicator can be displayed. err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ @@ -180,6 +142,55 @@ func ssh() *cobra.Command { return cmd } +func getWorkspaceAndAgent(ctx context.Context, client *codersdk.Client, orgID uuid.UUID, userID uuid.UUID, in string) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { + workspaceParts := strings.Split(in, ".") + workspace, err := client.WorkspaceByOwnerAndName(ctx, orgID, userID, workspaceParts[0]) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("get workspace %q: %w", workspaceParts[0], err) + } + + if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is being deleted", workspace.Name) + } + + resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("fetch workspace resources: %w", err) + } + + agents := make([]codersdk.WorkspaceAgent, 0) + for _, resource := range resources { + agents = append(agents, resource.Agents...) + } + if len(agents) == 0 { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name) + } + + var agent *codersdk.WorkspaceAgent + if len(workspaceParts) >= 2 { + for _, otherAgent := range agents { + if otherAgent.Name != workspaceParts[1] { + continue + } + agent = &otherAgent + break + } + + if agent == nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q", workspaceParts[1]) + } + } + + if agent == nil { + if len(agents) > 1 { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent") + } + agent = &agents[0] + } + + return workspace, *agent, nil +} + type stdioConn struct { io.Reader io.Writer diff --git a/cli/templates.go b/cli/templates.go index 9ec6dbef10db0..4863ff372f606 100644 --- a/cli/templates.go +++ b/cli/templates.go @@ -1,8 +1,9 @@ package cli import ( - "github.com/fatih/color" "github.com/spf13/cobra" + + "github.com/coder/coder/cli/cliui" ) func templates() *cobra.Command { @@ -12,15 +13,15 @@ func templates() *cobra.Command { Example: ` - Create a template for developers to create workspaces - ` + color.New(color.FgHiMagenta).Sprint("$ coder templates create") + ` + ` + cliui.Styles.Code.Render("$ coder templates create") + ` - Make changes to your template, and plan the changes - ` + color.New(color.FgHiMagenta).Sprint("$ coder templates plan ") + ` + ` + cliui.Styles.Code.Render("$ coder templates plan ") + ` - Update the template. Your developers can update their workspaces - ` + color.New(color.FgHiMagenta).Sprint("$ coder templates update "), + ` + cliui.Styles.Code.Render("$ coder templates update "), } cmd.AddCommand( templateCreate(), diff --git a/cli/tunnel.go b/cli/tunnel.go deleted file mode 100644 index 7f50abcd7d582..0000000000000 --- a/cli/tunnel.go +++ /dev/null @@ -1,12 +0,0 @@ -package cli - -import "github.com/spf13/cobra" - -func tunnel() *cobra.Command { - return &cobra.Command{ - Use: "tunnel", - RunE: func(cmd *cobra.Command, args []string) error { - return nil - }, - } -} diff --git a/go.mod b/go.mod index 128a705b78d79..27f7d71565bf8 100644 --- a/go.mod +++ b/go.mod @@ -87,6 +87,7 @@ require ( github.com/pion/logging v0.2.2 github.com/pion/transport v0.13.0 github.com/pion/turn/v2 v2.0.8 + github.com/pion/udp v0.1.1 github.com/pion/webrtc/v3 v3.1.34 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/pkg/sftp v1.13.4 @@ -211,7 +212,6 @@ require ( github.com/pion/sdp/v3 v3.0.4 // indirect github.com/pion/srtp/v2 v2.0.5 // indirect github.com/pion/stun v0.3.5 // indirect - github.com/pion/udp v0.1.1 // indirect github.com/pires/go-proxyproto v0.6.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/peer/channel.go b/peer/channel.go index 5a4424f8bf91c..7db76d984f815 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -53,8 +53,8 @@ type ChannelOptions struct { // Arbitrary string that can be parsed on `Accept`. Protocol string - // Ordered determines whether the channel acts like - // a TCP connection. Defaults to false. + // Unordered determines whether the channel acts like + // a UDP connection. Defaults to false. Unordered bool // Whether the channel will be left open on disconnect or not. From 46092cc615fc41637974d19363423af995edace6 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 17 May 2022 13:48:04 +0000 Subject: [PATCH 04/10] wip --- agent/agent.go | 5 +- agent/agent_test.go | 44 +++--- agent/conn.go | 4 +- cli/portforward.go | 64 ++++++-- cli/portforward_test.go | 321 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 403 insertions(+), 35 deletions(-) create mode 100644 cli/portforward_test.go diff --git a/agent/agent.go b/agent/agent.go index 86c8b5cc28de3..0357aa3c2b4a2 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -665,7 +665,10 @@ func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { } } - nconn, err := net.Dial(network, addr) + a.logger.Warn(ctx, "yeah", slog.F("network", network), slog.F("addr", addr)) + + d := net.Dialer{Timeout: 3 * time.Second} + nconn, err := d.DialContext(ctx, network, addr) if err != nil { _ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err)) return diff --git a/agent/agent_test.go b/agent/agent_test.go index cbf0f4a66ab23..2a8a2224abec6 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -298,22 +298,20 @@ func TestAgent(t *testing.T) { return } - testAccept(t, c) + go testAccept(t, c) } }() - // Try to dial the listener over WebRTC + // Dial the listener over WebRTC twice and test out of order conn := setupAgent(t, agent.Metadata{}, 0) conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) defer conn1.Close() - testDial(t, conn1) - - // Dial again using the same WebRTC client conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) defer conn2.Close() testDial(t, conn2) + testDial(t, conn1) }) } }) @@ -418,30 +416,28 @@ var dialTestPayload = []byte("dean-was-here123") func testDial(t *testing.T, c net.Conn) { t.Helper() - // Dials will write and then expect echo back - n, err := c.Write(dialTestPayload) - require.NoError(t, err, "write test payload") - require.Equal(t, len(dialTestPayload), n, "test payload length does not match") - - b := make([]byte, len(dialTestPayload)+16) - n, err = c.Read(b) - require.NoError(t, err, "read test payload") - require.Equal(t, len(dialTestPayload), n, "read payload length does not match") - require.Equal(t, dialTestPayload, b[:n]) + assertWritePayload(t, c, dialTestPayload) + assertReadPayload(t, c, dialTestPayload) } func testAccept(t *testing.T, c net.Conn) { t.Helper() defer c.Close() - // Accepts will read then echo - b := make([]byte, len(dialTestPayload)+16) - n, err := c.Read(b) - require.NoError(t, err, "read test payload") - require.Equal(t, len(dialTestPayload), n, "read payload length does not match") - require.Equal(t, dialTestPayload, b[:n]) + assertReadPayload(t, c, dialTestPayload) + assertWritePayload(t, c, dialTestPayload) +} + +func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { + b := make([]byte, len(payload)+16) + n, err := r.Read(b) + require.NoError(t, err, "read payload") + require.Equal(t, len(payload), n, "read payload length does not match") + require.Equal(t, payload, b[:n]) +} - n, err = c.Write(dialTestPayload) - require.NoError(t, err, "write test payload") - require.Equal(t, len(dialTestPayload), n, "test payload length does not match") +func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { + n, err := w.Write(payload) + require.NoError(t, err, "write payload") + require.Equal(t, len(payload), n, "payload length does not match") } diff --git a/agent/conn.go b/agent/conn.go index 56d3d42ea1784..7fd285540c35d 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -13,6 +13,7 @@ import ( "github.com/coder/coder/peer" "github.com/coder/coder/peerbroker/proto" + "github.com/google/uuid" ) // ReconnectingPTYRequest is sent from the client to the server @@ -78,7 +79,8 @@ func (c *Conn) SSHClient() (*ssh.Client, error) { // proxies it through the provided net.Conn. func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { u := &url.URL{ - Scheme: network, + Scheme: network, + RawQuery: "test=" + uuid.Must(uuid.NewRandom()).String(), } if strings.HasPrefix(network, "unix") { u.Path = addr diff --git a/cli/portforward.go b/cli/portforward.go index 0de9206a90629..1cc6a18662609 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -4,10 +4,13 @@ import ( "context" "fmt" "net" + "os" + "os/signal" "runtime" "strconv" "strings" "sync" + "syscall" "github.com/pion/udp" "github.com/spf13/cobra" @@ -72,6 +75,13 @@ func portForward() *cobra.Command { if err != nil { return xerrors.Errorf("parse port-forward specs: %w", err) } + if len(specs) == 0 { + err = cmd.Help() + if err != nil { + return xerrors.Errorf("generate help output: %w", err) + } + return xerrors.New("no port-forwards requested") + } fmt.Println("SPECS:") for _, spec := range specs { @@ -118,13 +128,26 @@ func portForward() *cobra.Command { defer conn.Close() // Start all listeners - var wg sync.WaitGroup - for _, spec := range specs { + var ( + ctx, cancel = context.WithCancel(cmd.Context()) + wg sync.WaitGroup + listeners = make([]net.Listener, len(specs)) + closeAllListeners = func() { + for _, l := range listeners { + if l == nil { + continue + } + _ = l.Close() + } + } + ) + defer cancel() + for i, spec := range specs { var ( l net.Listener err error ) - fmt.Printf("Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) + fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) switch spec.listenNetwork { case "tcp": l, err = net.Listen(spec.listenNetwork, spec.listenAddress) @@ -145,11 +168,14 @@ func portForward() *cobra.Command { case "unix": l, err = net.Listen(spec.listenNetwork, spec.listenAddress) default: + closeAllListeners() return xerrors.Errorf("unknown listen network %q", spec.listenNetwork) } if err != nil { + closeAllListeners() return xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err) } + listeners[i] = l wg.Add(1) go func(spec portForwardSpec) { @@ -157,28 +183,48 @@ func portForward() *cobra.Command { for { netConn, err := l.Accept() if err != nil { - fmt.Printf("Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err) - fmt.Println("Killing listener") + fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err) + fmt.Fprintln(cmd.OutOrStderr(), "Killing listener") return } go func(netConn net.Conn) { defer netConn.Close() - remoteConn, err := conn.DialContext(cmd.Context(), spec.dialNetwork, spec.dialAddress) + remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress) if err != nil { - fmt.Printf("Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) + fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) return } defer remoteConn.Close() - coderagent.Bicopy(cmd.Context(), netConn, remoteConn) + coderagent.Bicopy(ctx, netConn, remoteConn) }(netConn) } }(spec) } + // Wait for the context to be canceled or for a signal and close + // all listeners. + var closeErr error + go func() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + + select { + case <-ctx.Done(): + closeErr = ctx.Err() + case <-sigs: + fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections") + closeErr = xerrors.New("signal received") + } + + cancel() + closeAllListeners() + }() + + fmt.Fprintln(cmd.OutOrStderr(), "Ready!") wg.Wait() - return nil + return closeErr }, } diff --git a/cli/portforward_test.go b/cli/portforward_test.go new file mode 100644 index 0000000000000..a1d3c49b06c95 --- /dev/null +++ b/cli/portforward_test.go @@ -0,0 +1,321 @@ +package cli_test + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/pion/udp" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/provisioner/echo" + "github.com/coder/coder/provisionersdk/proto" +) + +func TestPortForward(t *testing.T) { + t.Parallel() + + t.Run("None", func(t *testing.T) { + client := coderdtest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + + cmd, root := clitest.New(t, "port-forward", "blah") + clitest.SetupConfig(t, client, root) + buf := new(bytes.Buffer) + cmd.SetOut(buf) + + err := cmd.Execute() + require.Error(t, err) + require.ErrorContains(t, err, "no port-forwards") + + // Check that the help was printed. + require.Contains(t, buf.String(), "port-forward ") + }) + + cases := []struct { + name string + network string + // The flag to pass to `coder port-forward X` to port-forward this type + // of connection. Has two format args (both strings), the first is the + // local address and the second is the remote address. + flag string + // setupRemote creates a "remote" listener to emulate a service in the + // workspace. + setupRemote func(t *testing.T) net.Listener + // setupLocal returns an available port or Unix socket path that the + // port-forward command will listen on "locally". Returns the address + // you pass to net.Dial, and the port/path you pass to `coder + // port-forward`. + setupLocal func(t *testing.T) (string, string) + }{ + { + name: "TCP", + network: "tcp", + flag: "--tcp=%v:%v", + setupRemote: func(t *testing.T) net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "create TCP listener") + return l + }, + setupLocal: func(t *testing.T) (string, string) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "create TCP listener to generate random port") + defer l.Close() + + _, port, err := net.SplitHostPort(l.Addr().String()) + require.NoErrorf(t, err, "split TCP address %q", l.Addr().String()) + return l.Addr().String(), port + }, + }, + { + name: "UDP", + network: "udp", + flag: "--udp=%v:%v", + setupRemote: func(t *testing.T) net.Listener { + addr := net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + l, err := udp.Listen("udp", &addr) + require.NoError(t, err, "create UDP listener") + return l + }, + setupLocal: func(t *testing.T) (string, string) { + addr := net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + l, err := udp.Listen("udp", &addr) + require.NoError(t, err, "create UDP listener to generate random port") + defer l.Close() + + _, port, err := net.SplitHostPort(l.Addr().String()) + require.NoErrorf(t, err, "split UDP address %q", l.Addr().String()) + return l.Addr().String(), port + }, + }, + { + name: "Unix", + network: "unix", + flag: "--unix=%v:%v", + setupRemote: func(t *testing.T) net.Listener { + if runtime.GOOS == "windows" { + t.Skip("Unix socket forwarding isn't supported on Windows") + } + + tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") + require.NoError(t, err, "create temp dir for unix listener") + t.Cleanup(func() { + _ = os.RemoveAll(tmpDir) + }) + + l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock")) + require.NoError(t, err, "create UDP listener") + return l + }, + setupLocal: func(t *testing.T) (string, string) { + tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") + require.NoError(t, err, "create temp dir for unix listener") + t.Cleanup(func() { + _ = os.RemoveAll(tmpDir) + }) + + path := filepath.Join(tmpDir, "test.sock") + return path, path + }, + }, + } + + for _, c := range cases { + if c.name != "Unix" { + continue + } + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + t.Run("One", func(t *testing.T) { + t.Parallel() + var ( + client = coderdtest.New(t, nil) + user = coderdtest.CreateFirstUser(t, client) + _, workspace = runAgent(t, client, user.UserID) + l1, p1 = setupTestListener(t, c.setupRemote(t)) + ) + t.Cleanup(func() { + _ = l1.Close() + }) + + // Create a flag that forwards from local to listener 1. + localAddress, localFlag := c.setupLocal(t) + flag := fmt.Sprintf(c.flag, localFlag, p1) + + // Launch port-forward in a goroutine so we can start dialing + // the "local" listener. + cmd, root := clitest.New(t, "port-forward", workspace.Name, flag) + clitest.SetupConfig(t, client, root) + buf := new(bytes.Buffer) + cmd.SetOut(io.MultiWriter(buf, os.Stderr)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := cmd.ExecuteContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }() + waitForPortForwardReady(t, buf) + + // Open two connections simultaneously and test them out of + // sync. + d := net.Dialer{Timeout: 3 * time.Second} + c1, err := d.DialContext(ctx, c.network, localAddress) + require.NoError(t, err, "open connection 1 to 'local' listener") + defer c1.Close() + c2, err := d.DialContext(ctx, c.network, localAddress) + require.NoError(t, err, "open connection 2 to 'local' listener") + defer c2.Close() + testDial(t, c2) + testDial(t, c1) + }) + }) + } +} + +// runAgent creates a fake workspace and starts an agent locally for that +// workspace. The agent will be cleaned up on test completion. +func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]codersdk.WorkspaceResource, codersdk.Workspace) { + ctx := context.Background() + user, err := client.User(ctx, userID) + require.NoError(t, err, "specified user does not exist") + require.Greater(t, len(user.OrganizationIDs), 0, "user has no organizations") + orgID := user.OrganizationIDs[0] + + // Setup echo provisioner + agentToken := uuid.NewString() + coderdtest.NewProvisionerDaemon(t, client) + version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "somename", + Type: "someinstance", + Agents: []*proto.Agent{{ + Auth: &proto.Agent_Token{ + Token: agentToken, + }, + }}, + }}, + }, + }, + }}, + }) + + // Create template and workspace + template := coderdtest.CreateTemplate(t, client, orgID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, orgID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + // Start workspace agent in a goroutine + cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) + agentClient := &*client + clitest.SetupConfig(t, agentClient, root) + agentCtx, agentCancel := context.WithCancel(ctx) + t.Cleanup(agentCancel) + go func() { + err := cmd.ExecuteContext(agentCtx) + require.NoError(t, err) + }() + + coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) + resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID) + require.NoError(t, err) + + return resources, workspace +} + +// setupTestListener starts accepting connections and echoing a single packet. +// Returns the listener and the listen port or Unix path. +func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) { + t.Cleanup(func() { + _ = l.Close() + }) + go func() { + for { + c, err := l.Accept() + if err != nil { + return + } + + go testAccept(t, c) + } + }() + + addr := l.Addr().String() + if !strings.HasPrefix(l.Addr().Network(), "unix") { + _, port, err := net.SplitHostPort(addr) + require.NoErrorf(t, err, "split non-Unix listen path %q", addr) + addr = port + } + + return l, addr +} + +var dialTestPayload = []byte("dean-was-here123") + +func testDial(t *testing.T, c net.Conn) { + t.Helper() + + assertWritePayload(t, c, dialTestPayload) + assertReadPayload(t, c, dialTestPayload) +} + +func testAccept(t *testing.T, c net.Conn) { + t.Helper() + defer c.Close() + + assertReadPayload(t, c, dialTestPayload) + assertWritePayload(t, c, dialTestPayload) +} + +func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { + b := make([]byte, len(payload)+16) + n, err := r.Read(b) + require.NoError(t, err, "read payload") + require.Equal(t, len(payload), n, "read payload length does not match") + require.Equal(t, payload, b[:n]) +} + +func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { + n, err := w.Write(payload) + require.NoError(t, err, "write payload") + require.Equal(t, len(payload), n, "payload length does not match") +} + +func waitForPortForwardReady(t *testing.T, output *bytes.Buffer) { + for i := 0; i < 100; i++ { + time.Sleep(250 * time.Millisecond) + + data := output.String() + if strings.Contains(data, "Ready!") { + return + } + } + + t.Fatal("port-forward command did not become ready in time") +} From f861ea8e3d1acea69c187e63e95c810c1407fe68 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 17 May 2022 15:14:53 +0000 Subject: [PATCH 05/10] fix: avoid dropped datachannels in quick succession --- agent/conn.go | 4 +- cli/portforward.go | 147 +++++++++++++++------------------ cli/portforward_test.go | 177 +++++++++++++++++++++++++++++++++++++++- peer/conn.go | 15 ++-- 4 files changed, 250 insertions(+), 93 deletions(-) diff --git a/agent/conn.go b/agent/conn.go index 7fd285540c35d..56d3d42ea1784 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -13,7 +13,6 @@ import ( "github.com/coder/coder/peer" "github.com/coder/coder/peerbroker/proto" - "github.com/google/uuid" ) // ReconnectingPTYRequest is sent from the client to the server @@ -79,8 +78,7 @@ func (c *Conn) SSHClient() (*ssh.Client, error) { // proxies it through the provided net.Conn. func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { u := &url.URL{ - Scheme: network, - RawQuery: "test=" + uuid.Must(uuid.NewRandom()).String(), + Scheme: network, } if strings.HasPrefix(network, "unix") { u.Path = addr diff --git a/cli/portforward.go b/cli/portforward.go index 1cc6a18662609..faaea7da1a171 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -56,21 +56,6 @@ func portForward() *cobra.Command { ` + cliui.Styles.Code.Render("$ coder port-forward --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53"), RunE: func(cmd *cobra.Command, args []string) error { - // TODO: remove parsing debug - fmt.Println("TCP:") - for _, tcp := range tcpForwards { - fmt.Println("\t", tcp) - } - fmt.Println("UDP:") - for _, udp := range udpForwards { - fmt.Println("\t", udp) - } - fmt.Println("Unix:") - for _, unix := range unixForwards { - fmt.Println("\t", unix) - } - fmt.Println() - specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards) if err != nil { return xerrors.Errorf("parse port-forward specs: %w", err) @@ -83,11 +68,6 @@ func portForward() *cobra.Command { return xerrors.New("no port-forwards requested") } - fmt.Println("SPECS:") - for _, spec := range specs { - fmt.Printf("\t%+v\n", spec) - } - client, err := createClient(cmd) if err != nil { return err @@ -127,10 +107,10 @@ func portForward() *cobra.Command { } defer conn.Close() - // Start all listeners + // Start all listeners. var ( ctx, cancel = context.WithCancel(cmd.Context()) - wg sync.WaitGroup + wg = new(sync.WaitGroup) listeners = make([]net.Listener, len(specs)) closeAllListeners = func() { for _, l := range listeners { @@ -143,64 +123,12 @@ func portForward() *cobra.Command { ) defer cancel() for i, spec := range specs { - var ( - l net.Listener - err error - ) - fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) - switch spec.listenNetwork { - case "tcp": - l, err = net.Listen(spec.listenNetwork, spec.listenAddress) - case "udp": - host, port, err := net.SplitHostPort(spec.listenAddress) - if err != nil { - return xerrors.Errorf("split %q: %w", spec.listenAddress, err) - } - portInt, err := strconv.Atoi(port) - if err != nil { - return xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err) - } - - l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ - IP: net.ParseIP(host), - Port: portInt, - }) - case "unix": - l, err = net.Listen(spec.listenNetwork, spec.listenAddress) - default: - closeAllListeners() - return xerrors.Errorf("unknown listen network %q", spec.listenNetwork) - } + l, err := listenAndPortForward(ctx, cmd, conn, wg, spec) if err != nil { closeAllListeners() - return xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err) + return err } listeners[i] = l - - wg.Add(1) - go func(spec portForwardSpec) { - defer wg.Done() - for { - netConn, err := l.Accept() - if err != nil { - fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err) - fmt.Fprintln(cmd.OutOrStderr(), "Killing listener") - return - } - - go func(netConn net.Conn) { - defer netConn.Close() - remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress) - if err != nil { - fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) - return - } - defer remoteConn.Close() - - coderagent.Bicopy(ctx, netConn, remoteConn) - }(netConn) - } - }(spec) } // Wait for the context to be canceled or for a signal and close @@ -235,6 +163,67 @@ func portForward() *cobra.Command { return cmd } +func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { + fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) + + var ( + l net.Listener + err error + ) + switch spec.listenNetwork { + case "tcp": + l, err = net.Listen(spec.listenNetwork, spec.listenAddress) + case "udp": + host, port, err := net.SplitHostPort(spec.listenAddress) + if err != nil { + return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err) + } + portInt, err := strconv.Atoi(port) + if err != nil { + return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err) + } + + l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ + IP: net.ParseIP(host), + Port: portInt, + }) + case "unix": + l, err = net.Listen(spec.listenNetwork, spec.listenAddress) + default: + return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork) + } + if err != nil { + return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err) + } + + wg.Add(1) + go func(spec portForwardSpec) { + defer wg.Done() + for { + netConn, err := l.Accept() + if err != nil { + fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err) + fmt.Fprintln(cmd.OutOrStderr(), "Killing listener") + return + } + + go func(netConn net.Conn) { + defer netConn.Close() + remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress) + if err != nil { + fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) + return + } + defer remoteConn.Close() + + coderagent.Bicopy(ctx, netConn, remoteConn) + }(netConn) + } + }(spec) + + return l, nil +} + type portForwardSpec struct { listenNetwork string // tcp, udp, unix listenAddress string // : or path @@ -249,7 +238,7 @@ func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) { for _, spec := range tcp { local, remote, err := parsePortPort(spec) if err != nil { - return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec) + return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err) } specs = append(specs, portForwardSpec{ @@ -263,7 +252,7 @@ func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) { for _, spec := range udp { local, remote, err := parsePortPort(spec) if err != nil { - return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec) + return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err) } specs = append(specs, portForwardSpec{ @@ -277,7 +266,7 @@ func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) { for _, specStr := range unix { localPath, localTCP, remotePath, err := parseUnixUnix(specStr) if err != nil { - return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr) + return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr, err) } spec := portForwardSpec{ diff --git a/cli/portforward_test.go b/cli/portforward_test.go index a1d3c49b06c95..8b5fe6ed4083e 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -139,14 +139,11 @@ func TestPortForward(t *testing.T) { } for _, c := range cases { - if c.name != "Unix" { - continue - } c := c t.Run(c.name, func(t *testing.T) { t.Parallel() - t.Run("One", func(t *testing.T) { + t.Run("OnePort", func(t *testing.T) { t.Parallel() var ( client = coderdtest.New(t, nil) @@ -189,8 +186,175 @@ func TestPortForward(t *testing.T) { testDial(t, c2) testDial(t, c1) }) + + t.Run("TwoPorts", func(t *testing.T) { + t.Parallel() + var ( + client = coderdtest.New(t, nil) + user = coderdtest.CreateFirstUser(t, client) + _, workspace = runAgent(t, client, user.UserID) + l1, p1 = setupTestListener(t, c.setupRemote(t)) + l2, p2 = setupTestListener(t, c.setupRemote(t)) + ) + t.Cleanup(func() { + _ = l1.Close() + _ = l2.Close() + }) + + // Create a flags for listener 1 and listener 2. + localAddress1, localFlag1 := c.setupLocal(t) + localAddress2, localFlag2 := c.setupLocal(t) + flag1 := fmt.Sprintf(c.flag, localFlag1, p1) + flag2 := fmt.Sprintf(c.flag, localFlag2, p2) + + // Launch port-forward in a goroutine so we can start dialing + // the "local" listeners. + cmd, root := clitest.New(t, "port-forward", workspace.Name, flag1, flag2) + clitest.SetupConfig(t, client, root) + buf := new(bytes.Buffer) + cmd.SetOut(io.MultiWriter(buf, os.Stderr)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := cmd.ExecuteContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }() + waitForPortForwardReady(t, buf) + + // Open a connection to both listener 1 and 2 simultaneously and + // then test them out of order. + d := net.Dialer{Timeout: 3 * time.Second} + c1, err := d.DialContext(ctx, c.network, localAddress1) + require.NoError(t, err, "open connection 1 to 'local' listener 1") + defer c1.Close() + c2, err := d.DialContext(ctx, c.network, localAddress2) + require.NoError(t, err, "open connection 2 to 'local' listener 2") + defer c2.Close() + testDial(t, c2) + testDial(t, c1) + }) }) } + + // Test doing a TCP -> Unix forward. + t.Run("TCP2Unix", func(t *testing.T) { + t.Parallel() + var ( + client = coderdtest.New(t, nil) + user = coderdtest.CreateFirstUser(t, client) + _, workspace = runAgent(t, client, user.UserID) + + // Find the TCP and Unix cases so we can use their setupLocal and + // setupRemote methods respectively. + tcpCase = cases[0] + unixCase = cases[2] + + // Setup remote Unix listener. + l1, p1 = setupTestListener(t, unixCase.setupRemote(t)) + ) + t.Cleanup(func() { + _ = l1.Close() + }) + + // Create a flag that forwards from local TCP to Unix listener 1. + // Notably this is a --unix flag. + localAddress, localFlag := tcpCase.setupLocal(t) + flag := fmt.Sprintf(unixCase.flag, localFlag, p1) + + // Launch port-forward in a goroutine so we can start dialing + // the "local" listener. + cmd, root := clitest.New(t, "port-forward", workspace.Name, flag) + clitest.SetupConfig(t, client, root) + buf := new(bytes.Buffer) + cmd.SetOut(io.MultiWriter(buf, os.Stderr)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := cmd.ExecuteContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }() + waitForPortForwardReady(t, buf) + + // Open two connections simultaneously and test them out of + // sync. + d := net.Dialer{Timeout: 3 * time.Second} + c1, err := d.DialContext(ctx, tcpCase.network, localAddress) + require.NoError(t, err, "open connection 1 to 'local' listener") + defer c1.Close() + c2, err := d.DialContext(ctx, tcpCase.network, localAddress) + require.NoError(t, err, "open connection 2 to 'local' listener") + defer c2.Close() + testDial(t, c2) + testDial(t, c1) + }) + + // Test doing TCP, UDP and Unix at the same time. + t.Run("All", func(t *testing.T) { + t.Parallel() + var ( + client = coderdtest.New(t, nil) + user = coderdtest.CreateFirstUser(t, client) + _, workspace = runAgent(t, client, user.UserID) + // These aren't fixed size because we exclude Unix on Windows. + dials = []addr{} + flags = []string{} + ) + + // Start listeners and populate arrays with the cases. + for _, c := range cases { + if strings.HasPrefix(c.network, "unix") && runtime.GOOS == "windows" { + // Unix isn't supported on Windows, but we can still + // test other protocols together. + continue + } + + l, p := setupTestListener(t, c.setupRemote(t)) + t.Cleanup(func() { + _ = l.Close() + }) + + localAddress, localFlag := c.setupLocal(t) + dials = append(dials, addr{ + network: c.network, + addr: localAddress, + }) + flags = append(flags, fmt.Sprintf(c.flag, localFlag, p)) + } + + // Launch port-forward in a goroutine so we can start dialing + // the "local" listeners. + cmd, root := clitest.New(t, append([]string{"port-forward", workspace.Name}, flags...)...) + clitest.SetupConfig(t, client, root) + buf := new(bytes.Buffer) + cmd.SetOut(io.MultiWriter(buf, os.Stderr)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := cmd.ExecuteContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }() + waitForPortForwardReady(t, buf) + + // Open connections to all items in the "dial" array. + var ( + d = net.Dialer{Timeout: 3 * time.Second} + conns = make([]net.Conn, len(dials)) + ) + for i, a := range dials { + c, err := d.DialContext(ctx, a.network, a.addr) + require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1) + defer c.Close() + conns[i] = c + } + + // Test each connection in reverse order. + for i := len(conns) - 1; i >= 0; i-- { + testDial(t, conns[i]) + } + }) } // runAgent creates a fake workspace and starts an agent locally for that @@ -319,3 +483,8 @@ func waitForPortForwardReady(t *testing.T, output *bytes.Buffer) { t.Fatal("port-forward command did not become ready in time") } + +type addr struct { + network string + addr string +} diff --git a/peer/conn.go b/peer/conn.go index 949468bfa7064..c81b29d0bbd38 100644 --- a/peer/conn.go +++ b/peer/conn.go @@ -68,7 +68,7 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp closed: make(chan struct{}), closedRTC: make(chan struct{}), closedICE: make(chan struct{}), - dcOpenChannel: make(chan *webrtc.DataChannel), + dcOpenChannel: make(chan *webrtc.DataChannel, 8), dcDisconnectChannel: make(chan struct{}), dcFailedChannel: make(chan struct{}), localCandidateChannel: make(chan webrtc.ICECandidateInit), @@ -264,12 +264,13 @@ func (c *Conn) init() error { }() }) c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) { - select { - case <-c.closed: - return - case c.dcOpenChannel <- dc: - default: - } + go func() { + select { + case <-c.closed: + return + case c.dcOpenChannel <- dc: + } + }() }) _, err := c.pingChannel() if err != nil { From 6dfd2f689e9e073375037be404175b2f925555de Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 17 May 2022 15:33:21 +0000 Subject: [PATCH 06/10] chore: fix lint errors --- agent/agent.go | 20 +++++++++++++++----- cli/portforward.go | 35 +++++++++++++++++------------------ cli/portforward_test.go | 8 ++++++-- cli/ssh.go | 18 ++++++++++++------ 4 files changed, 50 insertions(+), 31 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 0357aa3c2b4a2..826886b9823f4 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -735,20 +735,30 @@ func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { defer c1.Close() defer c2.Close() - ctx, cancel := context.WithCancel(ctx) - + var wg sync.WaitGroup copyFunc := func(dst io.WriteCloser, src io.Reader) { - defer cancel() + defer wg.Done() _, _ = io.Copy(dst, src) } + wg.Add(2) go copyFunc(c1, c2) go copyFunc(c2, c1) - <-ctx.Done() + // Convert waitgroup to a channel so we can also wait on the context. + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-ctx.Done(): + case <-done: + } } -// ExpandPath expands the tilde at the beggining of a path to the current user's +// ExpandPath expands the tilde at the beginning of a path to the current user's // home directory and returns a full absolute path. func ExpandPath(in string) (string, error) { usr, err := user.Current() diff --git a/cli/portforward.go b/cli/portforward.go index faaea7da1a171..449c334532c54 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -16,7 +16,6 @@ import ( "github.com/spf13/cobra" "golang.org/x/xerrors" - "github.com/coder/coder/agent" coderagent "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/database" @@ -142,7 +141,7 @@ func portForward() *cobra.Command { case <-ctx.Done(): closeErr = ctx.Err() case <-sigs: - fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections") + _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections") closeErr = xerrors.New("signal received") } @@ -150,7 +149,7 @@ func portForward() *cobra.Command { closeAllListeners() }() - fmt.Fprintln(cmd.OutOrStderr(), "Ready!") + _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!") wg.Wait() return closeErr }, @@ -163,8 +162,8 @@ func portForward() *cobra.Command { return cmd } -func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { - fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) +func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *coderagent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) var ( l net.Listener @@ -183,7 +182,7 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.C return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err) } - l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ + l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ //nolint:ineffassign IP: net.ParseIP(host), Port: portInt, }) @@ -202,8 +201,8 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.C for { netConn, err := l.Accept() if err != nil { - fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err) - fmt.Fprintln(cmd.OutOrStderr(), "Killing listener") + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err) + _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener") return } @@ -211,7 +210,7 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.C defer netConn.Close() remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress) if err != nil { - fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) return } defer remoteConn.Close() @@ -232,10 +231,10 @@ type portForwardSpec struct { dialAddress string // : or path } -func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) { +func parsePortForwards(tcpSpecs, udpSpecs, unixSpecs []string) ([]portForwardSpec, error) { specs := []portForwardSpec{} - for _, spec := range tcp { + for _, spec := range tcpSpecs { local, remote, err := parsePortPort(spec) if err != nil { return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err) @@ -249,7 +248,7 @@ func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) { }) } - for _, spec := range udp { + for _, spec := range udpSpecs { local, remote, err := parsePortPort(spec) if err != nil { return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err) @@ -263,7 +262,7 @@ func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) { }) } - for _, specStr := range unix { + for _, specStr := range unixSpecs { localPath, localTCP, remotePath, err := parseUnixUnix(specStr) if err != nil { return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr, err) @@ -312,7 +311,7 @@ func parsePort(in string) (uint16, error) { } func parseUnixPath(in string) (string, error) { - path, err := agent.ExpandPath(strings.TrimSpace(in)) + path, err := coderagent.ExpandPath(strings.TrimSpace(in)) if err != nil { return "", xerrors.Errorf("tidy path %q: %w", in, err) } @@ -320,7 +319,7 @@ func parseUnixPath(in string) (string, error) { return path, nil } -func parsePortPort(in string) (uint16, uint16, error) { +func parsePortPort(in string) (local uint16, remote uint16, err error) { parts := strings.Split(in, ":") if len(parts) > 2 { return 0, 0, xerrors.Errorf("invalid port specification %q", in) @@ -330,16 +329,16 @@ func parsePortPort(in string) (uint16, uint16, error) { parts = append(parts, parts[0]) } - local, err := parsePort(parts[0]) + local, err = parsePort(parts[0]) if err != nil { return 0, 0, xerrors.Errorf("parse local port from %q: %w", in, err) } - remote, err := parsePort(parts[1]) + remote, err = parsePort(parts[1]) if err != nil { return 0, 0, xerrors.Errorf("parse remote port from %q: %w", in, err) } - return uint16(local), uint16(remote), nil + return local, remote, nil } func parsePortOrUnixPath(in string) (string, uint16, error) { diff --git a/cli/portforward_test.go b/cli/portforward_test.go index 8b5fe6ed4083e..5dde53ac5a833 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -28,6 +28,8 @@ func TestPortForward(t *testing.T) { t.Parallel() t.Run("None", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) @@ -138,7 +140,7 @@ func TestPortForward(t *testing.T) { }, } - for _, c := range cases { + for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter c := c t.Run(c.name, func(t *testing.T) { t.Parallel() @@ -346,7 +348,9 @@ func TestPortForward(t *testing.T) { for i, a := range dials { c, err := d.DialContext(ctx, a.network, a.addr) require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1) - defer c.Close() + t.Cleanup(func() { + _ = c.Close() + }) conns[i] = c } diff --git a/cli/ssh.go b/cli/ssh.go index dd91ab4b9716b..cc5b08d2fc3f6 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -166,29 +166,35 @@ func getWorkspaceAndAgent(ctx context.Context, client *codersdk.Client, orgID uu return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name) } - var agent *codersdk.WorkspaceAgent + var ( + // We can't use a pointer because linters are mad about using pointers + // from loop variables + agent codersdk.WorkspaceAgent + agentOK bool + ) if len(workspaceParts) >= 2 { for _, otherAgent := range agents { if otherAgent.Name != workspaceParts[1] { continue } - agent = &otherAgent + agent = otherAgent + agentOK = true break } - if agent == nil { + if !agentOK { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q", workspaceParts[1]) } } - if agent == nil { + if !agentOK { if len(agents) > 1 { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent") } - agent = &agents[0] + agent = agents[0] } - return workspace, *agent, nil + return workspace, agent, nil } type stdioConn struct { From fd7e32c9e2d07b2ca7786460c1b611de87578367 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 17 May 2022 15:39:59 +0000 Subject: [PATCH 07/10] chore: block unix forwarding on windows --- agent/agent.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 826886b9823f4..2494e43fe041a 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -658,6 +658,10 @@ func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { network := u.Scheme addr := u.Host + u.Path if strings.HasPrefix(network, "unix") { + if runtime.GOOS == "windows" { + _ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces")) + return + } addr, err = ExpandPath(addr) if err != nil { _ = writeError(xerrors.Errorf("expand path %q: %w", addr, err)) @@ -665,8 +669,6 @@ func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { } } - a.logger.Warn(ctx, "yeah", slog.F("network", network), slog.F("addr", addr)) - d := net.Dialer{Timeout: 3 * time.Second} nconn, err := d.DialContext(ctx, network, addr) if err != nil { From 52142ba6153e069fd3130a401038c7e5ca308527 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 17 May 2022 16:39:58 +0000 Subject: [PATCH 08/10] fix problems --- cli/portforward.go | 11 ++++++---- cli/portforward_test.go | 3 +-- cli/ssh.go | 48 ++++++++++++++--------------------------- 3 files changed, 24 insertions(+), 38 deletions(-) diff --git a/cli/portforward.go b/cli/portforward.go index 449c334532c54..6c23577f1601a 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -76,7 +76,7 @@ func portForward() *cobra.Command { return err } - workspace, agent, err := getWorkspaceAndAgent(cmd.Context(), client, organization.ID, codersdk.Me, args[0]) + workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0]) if err != nil { return err } @@ -173,16 +173,19 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *coderag case "tcp": l, err = net.Listen(spec.listenNetwork, spec.listenAddress) case "udp": - host, port, err := net.SplitHostPort(spec.listenAddress) + var host, port string + host, port, err = net.SplitHostPort(spec.listenAddress) if err != nil { return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err) } - portInt, err := strconv.Atoi(port) + + var portInt int + portInt, err = strconv.Atoi(port) if err != nil { return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err) } - l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ //nolint:ineffassign + l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ IP: net.ParseIP(host), Port: portInt, }) diff --git a/cli/portforward_test.go b/cli/portforward_test.go index cc5d9c36a653f..472b336de4af8 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -401,8 +401,7 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders // Start workspace agent in a goroutine cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) - agentClient := &*client - clitest.SetupConfig(t, agentClient, root) + clitest.SetupConfig(t, client, root) agentCtx, agentCancel := context.WithCancel(ctx) t.Cleanup(agentCancel) go func() { diff --git a/cli/ssh.go b/cli/ssh.go index 74338994add99..6e9d91f197fc9 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -48,19 +48,10 @@ func ssh() *cobra.Command { return err } - workspace, agent, err := getWorkspaceAndAgent(cmd.Context(), client, organization.ID, codersdk.Me, args[0]) + workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0]) if err != nil { return err } - if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart { - return xerrors.New("workspace must be in start transition to ssh") - } - if workspace.LatestBuild.Job.CompletedAt == nil { - err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) - if err != nil { - return err - } - } // OpenSSH passes stderr directly to the calling TTY. // This is required in "stdio" mode so a connecting indicator can be displayed. @@ -155,13 +146,24 @@ func ssh() *cobra.Command { return cmd } -func getWorkspaceAndAgent(ctx context.Context, client *codersdk.Client, orgID uuid.UUID, userID string, in string) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { +func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, orgID uuid.UUID, userID string, in string) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { + ctx := cmd.Context() + workspaceParts := strings.Split(in, ".") workspace, err := client.WorkspaceByOwnerAndName(ctx, orgID, userID, workspaceParts[0]) if err != nil { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("get workspace %q: %w", workspaceParts[0], err) } + if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("workspace must be in start transition to ssh") + } + if workspace.LatestBuild.Job.CompletedAt == nil { + err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + } + } if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is being deleted", workspace.Name) } @@ -178,29 +180,20 @@ func getWorkspaceAndAgent(ctx context.Context, client *codersdk.Client, orgID uu if len(agents) == 0 { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name) } - - var ( - // We can't use a pointer because linters are mad about using pointers - // from loop variables - agent codersdk.WorkspaceAgent - agentOK bool - ) + var agent codersdk.WorkspaceAgent if len(workspaceParts) >= 2 { for _, otherAgent := range agents { if otherAgent.Name != workspaceParts[1] { continue } agent = otherAgent - agentOK = true break } - - if !agentOK { + if agent.ID == uuid.Nil { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q", workspaceParts[1]) } } - - if !agentOK { + if agent.ID == uuid.Nil { if len(agents) > 1 { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent") } @@ -210,15 +203,6 @@ func getWorkspaceAndAgent(ctx context.Context, client *codersdk.Client, orgID uu return workspace, agent, nil } -type stdioConn struct { - io.Reader - io.Writer -} - -func (*stdioConn) Close() (err error) { - return nil -} - // Attempt to poll workspace autostop. We write a per-workspace lockfile to // avoid spamming the user with notifications in case of multiple instances // of the CLI running simultaneously. From 6d9ed0c280468ca1fc941db7d200815c62c8a725 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 17 May 2022 17:00:02 +0000 Subject: [PATCH 09/10] chore: fix data race in port-forward test --- cli/portforward_test.go | 51 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/cli/portforward_test.go b/cli/portforward_test.go index 472b336de4af8..0c0d3ddc5fa08 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "strings" + "sync" "testing" "time" @@ -35,7 +36,7 @@ func TestPortForward(t *testing.T) { cmd, root := clitest.New(t, "port-forward", "blah") clitest.SetupConfig(t, client, root) - buf := new(bytes.Buffer) + buf := newThreadSafeBuffer() cmd.SetOut(buf) err := cmd.Execute() @@ -165,7 +166,7 @@ func TestPortForward(t *testing.T) { // the "local" listener. cmd, root := clitest.New(t, "port-forward", workspace.Name, flag) clitest.SetupConfig(t, client, root) - buf := new(bytes.Buffer) + buf := newThreadSafeBuffer() cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -213,7 +214,7 @@ func TestPortForward(t *testing.T) { // the "local" listeners. cmd, root := clitest.New(t, "port-forward", workspace.Name, flag1, flag2) clitest.SetupConfig(t, client, root) - buf := new(bytes.Buffer) + buf := newThreadSafeBuffer() cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -268,7 +269,7 @@ func TestPortForward(t *testing.T) { // the "local" listener. cmd, root := clitest.New(t, "port-forward", workspace.Name, flag) clitest.SetupConfig(t, client, root) - buf := new(bytes.Buffer) + buf := newThreadSafeBuffer() cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -329,7 +330,7 @@ func TestPortForward(t *testing.T) { // the "local" listeners. cmd, root := clitest.New(t, append([]string{"port-forward", workspace.Name}, flags...)...) clitest.SetupConfig(t, client, root) - buf := new(bytes.Buffer) + buf := newThreadSafeBuffer() cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -474,7 +475,7 @@ func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { require.Equal(t, len(payload), n, "payload length does not match") } -func waitForPortForwardReady(t *testing.T, output *bytes.Buffer) { +func waitForPortForwardReady(t *testing.T, output *threadSafeBuffer) { for i := 0; i < 100; i++ { time.Sleep(250 * time.Millisecond) @@ -491,3 +492,41 @@ type addr struct { network string addr string } + +type threadSafeBuffer struct { + b *bytes.Buffer + mut *sync.RWMutex +} + +func newThreadSafeBuffer() *threadSafeBuffer { + return &threadSafeBuffer{ + b: bytes.NewBuffer(nil), + mut: new(sync.RWMutex), + } +} + +var _ io.Reader = &threadSafeBuffer{} +var _ io.Writer = &threadSafeBuffer{} + +// Read implements io.Reader. +func (b *threadSafeBuffer) Read(p []byte) (int, error) { + b.mut.RLock() + defer b.mut.RUnlock() + + return b.b.Read(p) +} + +// Write implements io.Writer. +func (b *threadSafeBuffer) Write(p []byte) (int, error) { + b.mut.Lock() + defer b.mut.Unlock() + + return b.b.Write(p) +} + +func (b *threadSafeBuffer) String() string { + b.mut.RLock() + defer b.mut.RUnlock() + + return b.b.String() +} From f12ded0cdfc5bd9a8c120824a6fa0d13cd77e34f Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 18 May 2022 13:01:12 +0000 Subject: [PATCH 10/10] chore: rename ExpandPath to ExpandRelativeHomePath --- agent/agent.go | 8 ++++---- cli/portforward.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 2494e43fe041a..75787b4cfc5e1 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -662,7 +662,7 @@ func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { _ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces")) return } - addr, err = ExpandPath(addr) + addr, err = ExpandRelativeHomePath(addr) if err != nil { _ = writeError(xerrors.Errorf("expand path %q: %w", addr, err)) return @@ -760,9 +760,9 @@ func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { } } -// ExpandPath expands the tilde at the beginning of a path to the current user's -// home directory and returns a full absolute path. -func ExpandPath(in string) (string, error) { +// ExpandRelativeHomePath expands the tilde at the beginning of a path to the +// current user's home directory and returns a full absolute path. +func ExpandRelativeHomePath(in string) (string, error) { usr, err := user.Current() if err != nil { return "", xerrors.Errorf("get current user details: %w", err) diff --git a/cli/portforward.go b/cli/portforward.go index 6c23577f1601a..776e873cdaa5b 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -314,7 +314,7 @@ func parsePort(in string) (uint16, error) { } func parseUnixPath(in string) (string, error) { - path, err := coderagent.ExpandPath(strings.TrimSpace(in)) + path, err := coderagent.ExpandRelativeHomePath(strings.TrimSpace(in)) if err != nil { return "", xerrors.Errorf("tidy path %q: %w", in, err) }