-
Notifications
You must be signed in to change notification settings - Fork 899
feat: add port-forward subcommand #1350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
f104e02
1972629
e3eb839
46092cc
f861ea8
6dfd2f6
fd7e32c
8eac40a
52142ba
6d9ed0c
f12ded0
ee25623
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ import ( | |
"fmt" | ||
"io" | ||
"net" | ||
"net/url" | ||
"os" | ||
"os/exec" | ||
"os/user" | ||
|
@@ -211,6 +212,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) { | |
go a.sshServer.HandleConn(channel.NetConn()) | ||
case "reconnecting-pty": | ||
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn()) | ||
case "dial": | ||
go a.handleDial(ctx, channel.Label(), channel.NetConn()) | ||
default: | ||
a.logger.Warn(ctx, "unhandled protocol from channel", | ||
slog.F("protocol", channel.Protocol()), | ||
|
@@ -617,6 +620,57 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne | |
} | ||
} | ||
|
||
// dialResponse is written to datachannels with protocol "dial" by the agent as | ||
// the first packet to signify whether the dial succeeded or failed. | ||
type dialResponse struct { | ||
Error string `json:"error,omitempty"` | ||
} | ||
|
||
func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { | ||
defer conn.Close() | ||
|
||
writeError := func(responseError error) error { | ||
msg := "" | ||
if responseError != nil { | ||
msg = responseError.Error() | ||
if !xerrors.Is(responseError, io.EOF) { | ||
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError)) | ||
} | ||
} | ||
b, err := json.Marshal(dialResponse{ | ||
Error: msg, | ||
}) | ||
if err != nil { | ||
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err)) | ||
return xerrors.Errorf("marshal agent webrtc dial response: %w", err) | ||
} | ||
|
||
_, err = conn.Write(b) | ||
return err | ||
} | ||
|
||
u, err := url.Parse(label) | ||
if err != nil { | ||
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err)) | ||
return | ||
} | ||
|
||
network := u.Scheme | ||
addr := u.Host + u.Path | ||
nconn, err := net.Dial(network, addr) | ||
if err != nil { | ||
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err)) | ||
return | ||
} | ||
|
||
err = writeError(nil) | ||
if err != nil { | ||
return | ||
} | ||
|
||
bicopy(ctx, conn, nconn) | ||
} | ||
|
||
// isClosed returns whether the API is closed or not. | ||
func (a *agent) isClosed() bool { | ||
select { | ||
|
@@ -662,3 +716,23 @@ func (r *reconnectingPTY) Close() { | |
r.circularBuffer.Reset() | ||
r.timeout.Stop() | ||
} | ||
|
||
// bicopy copies all of the data between the two connections and will close them | ||
// after one or both of them are done writing. If the context is canceled, both | ||
// of the connections will be closed. | ||
func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { | ||
defer c1.Close() | ||
defer c2.Close() | ||
|
||
ctx, cancel := context.WithCancel(ctx) | ||
|
||
copyFunc := func(dst io.WriteCloser, src io.Reader) { | ||
defer cancel() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could it lead to problems if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I switched this to a WaitGroup to avoid this case... not sure if it could've happened in practice but better to be safe than sorry |
||
_, _ = io.Copy(dst, src) | ||
} | ||
|
||
go copyFunc(c1, c2) | ||
go copyFunc(c2, c1) | ||
|
||
<-ctx.Done() | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ import ( | |
"time" | ||
|
||
"github.com/google/uuid" | ||
"github.com/pion/udp" | ||
"github.com/pion/webrtc/v3" | ||
"github.com/pkg/sftp" | ||
"github.com/stretchr/testify/require" | ||
|
@@ -234,6 +235,114 @@ func TestAgent(t *testing.T) { | |
findEcho() | ||
findEcho() | ||
}) | ||
|
||
t.Run("Dial", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
cases := []struct { | ||
name string | ||
setup func(t *testing.T) net.Listener | ||
}{ | ||
Comment on lines
+242
to
+245
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are done beautifully! 😍 |
||
{ | ||
name: "TCP", | ||
setup: func(t *testing.T) net.Listener { | ||
l, err := net.Listen("tcp", "127.0.0.1:0") | ||
require.NoError(t, err, "create TCP listener") | ||
return l | ||
}, | ||
}, | ||
{ | ||
name: "UDP", | ||
setup: func(t *testing.T) net.Listener { | ||
addr := net.UDPAddr{ | ||
IP: net.ParseIP("127.0.0.1"), | ||
Port: 0, | ||
} | ||
l, err := udp.Listen("udp", &addr) | ||
require.NoError(t, err, "create UDP listener") | ||
return l | ||
}, | ||
}, | ||
{ | ||
name: "Unix", | ||
setup: func(t *testing.T) net.Listener { | ||
if runtime.GOOS == "windows" { | ||
t.Skip("Unix socket forwarding isn't supported on Windows") | ||
} | ||
|
||
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") | ||
require.NoError(t, err, "create temp dir for unix listener") | ||
t.Cleanup(func() { | ||
_ = os.RemoveAll(tmpDir) | ||
}) | ||
|
||
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock")) | ||
require.NoError(t, err, "create UDP listener") | ||
return l | ||
}, | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
c := c | ||
t.Run(c.name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
// Setup listener | ||
l := c.setup(t) | ||
defer l.Close() | ||
go func() { | ||
for { | ||
c, err := l.Accept() | ||
if err != nil { | ||
return | ||
} | ||
|
||
testAccept(t, c) | ||
} | ||
}() | ||
|
||
// Try to dial the listener over WebRTC | ||
conn := setupAgent(t, agent.Metadata{}, 0) | ||
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) | ||
require.NoError(t, err) | ||
defer conn1.Close() | ||
testDial(t, conn1) | ||
|
||
// Dial again using the same WebRTC client | ||
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) | ||
require.NoError(t, err) | ||
defer conn2.Close() | ||
testDial(t, conn2) | ||
}) | ||
} | ||
}) | ||
|
||
t.Run("DialError", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
if runtime.GOOS == "windows" { | ||
// This test uses Unix listeners so we can very easily ensure that | ||
// no other tests decide to listen on the same random port we | ||
// picked. | ||
t.Skip("this test is unsupported on Windows") | ||
return | ||
} | ||
|
||
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") | ||
require.NoError(t, err, "create temp dir") | ||
t.Cleanup(func() { | ||
_ = os.RemoveAll(tmpDir) | ||
}) | ||
|
||
// Try to dial the non-existent Unix socket over WebRTC | ||
conn := setupAgent(t, agent.Metadata{}, 0) | ||
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock")) | ||
require.Error(t, err) | ||
require.ErrorContains(t, err, "remote dial error") | ||
require.ErrorContains(t, err, "no such file") | ||
require.Nil(t, netConn) | ||
}) | ||
} | ||
|
||
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd { | ||
|
@@ -303,3 +412,36 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) | |
Conn: conn, | ||
} | ||
} | ||
|
||
var dialTestPayload = []byte("dean-was-here123") | ||
|
||
func testDial(t *testing.T, c net.Conn) { | ||
t.Helper() | ||
|
||
// Dials will write and then expect echo back | ||
n, err := c.Write(dialTestPayload) | ||
require.NoError(t, err, "write test payload") | ||
require.Equal(t, len(dialTestPayload), n, "test payload length does not match") | ||
|
||
b := make([]byte, len(dialTestPayload)+16) | ||
n, err = c.Read(b) | ||
require.NoError(t, err, "read test payload") | ||
require.Equal(t, len(dialTestPayload), n, "read payload length does not match") | ||
require.Equal(t, dialTestPayload, b[:n]) | ||
} | ||
|
||
func testAccept(t *testing.T, c net.Conn) { | ||
t.Helper() | ||
defer c.Close() | ||
|
||
// Accepts will read then echo | ||
b := make([]byte, len(dialTestPayload)+16) | ||
n, err := c.Read(b) | ||
require.NoError(t, err, "read test payload") | ||
require.Equal(t, len(dialTestPayload), n, "read payload length does not match") | ||
require.Equal(t, dialTestPayload, b[:n]) | ||
|
||
n, err = c.Write(dialTestPayload) | ||
require.NoError(t, err, "write test payload") | ||
require.Equal(t, len(dialTestPayload), n, "test payload length does not match") | ||
} |
Uh oh!
There was an error while loading. Please reload this page.