Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Remote forward
  • Loading branch information
mtojek committed Jul 18, 2023
commit 9f5e5003be7994a756286a792cbf381b5af96094
104 changes: 104 additions & 0 deletions cli/remoteforward.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package cli

import (
"context"
"fmt"
"io"
"net"
"regexp"
"strconv"

gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"

"github.com/coder/coder/agent/agentssh"
)

// cookieAddr is a special net.Addr accepted by sshRemoteForward() which includes a
// cookie which is written to the connection before forwarding.
type cookieAddr struct {
net.Addr
cookie []byte
}

var remoteForwardRegex = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
Copy link
Member

@johnstcn johnstcn Jul 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add a comment to clarify what this should match

This could also be done with a simple strings.Split()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add a comment to clarify what this should match

👍

This could also be done with a simple strings.Split()

Right, but with regexp we can check if ports are numbers :)


func validateRemoteForward(flag string) bool {
return remoteForwardRegex.MatchString(flag)
}

func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
matches := remoteForwardRegex.FindStringSubmatch(flag)

// Format:
// remote_port:local_address:local_port
remotePort, err := strconv.Atoi(matches[1])
if err != nil {
return nil, nil, xerrors.Errorf("remote port is invalid: %w", err)
}
localAddress, err := net.ResolveIPAddr("ip", matches[2])
if err != nil {
return nil, nil, xerrors.Errorf("local address is invalid: %w", err)
}
localPort, err := strconv.Atoi(matches[3])
if err != nil {
return nil, nil, xerrors.Errorf("local port is invalid: %w", err)
}

localAddr := &net.TCPAddr{
IP: localAddress.IP,
Port: localPort,
}

remoteAddr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: remotePort,
}
return localAddr, remoteAddr, nil
}

// sshRemoteForward starts forwarding connections from a remote listener to a
// local address via SSH in a goroutine.
//
// Accepts a `cookieAddr` as the local address.
func sshRemoteForward(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for reviewers: moved from ssh.go

listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
if err != nil {
return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
}

go func() {
for {
remoteConn, err := listener.Accept()
if err != nil {
if ctx.Err() == nil {
_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
}
return
}

go func() {
defer remoteConn.Close()

localConn, err := net.Dial(localAddr.Network(), localAddr.String())
if err != nil {
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
return
}
defer localConn.Close()

if c, ok := localAddr.(cookieAddr); ok {
_, err = localConn.Write(c.cookie)
if err != nil {
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
return
}
}

agentssh.Bicopy(ctx, localConn, remoteConn)
}()
}
}()

return listener, nil
}
87 changes: 3 additions & 84 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@ import (
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"os/exec"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"time"
Expand All @@ -29,7 +26,6 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"

"github.com/coder/coder/agent/agentssh"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/autobuild/notify"
Expand All @@ -42,8 +38,6 @@ import (
var (
workspacePollInterval = time.Minute
autostopNotifyCountdown = []time.Duration{30 * time.Minute}

remoteForwardRegex = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
)

//nolint:gocyclo
Expand Down Expand Up @@ -128,7 +122,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
}

if remoteForward != "" {
isValid := remoteForwardRegex.MatchString(remoteForward)
isValid := validateRemoteForward(remoteForward)
if !isValid {
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
}
Expand Down Expand Up @@ -317,31 +311,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
}

if remoteForward != "" {
matches := remoteForwardRegex.FindStringSubmatch(remoteForward)

// Format:
// remote_port:local_address:local_port
remotePort, err := strconv.Atoi(matches[1])
if err != nil {
return xerrors.Errorf("remote port is invalid: %w", err)
}
localAddress, err := net.ResolveIPAddr("ip", matches[2])
if err != nil {
return xerrors.Errorf("local address is invalid: %w", err)
}
localPort, err := strconv.Atoi(matches[3])
localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
if err != nil {
return xerrors.Errorf("local port is invalid: %w", err)
}

localAddr := &net.TCPAddr{
IP: localAddress.IP,
Port: localPort,
}

remoteAddr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: remotePort,
return err
}

closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
Expand Down Expand Up @@ -817,56 +789,3 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {

return string(bytes.TrimSpace(remoteSocket)), nil
}

// cookieAddr is a special net.Addr accepted by sshRemoteForward() which includes a
// cookie which is written to the connection before forwarding.
type cookieAddr struct {
net.Addr
cookie []byte
}

// sshRemoteForward starts forwarding connections from a remote listener to a
// local address via SSH in a goroutine.
//
// Accepts a `cookieAddr` as the local address.
func sshRemoteForward(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
if err != nil {
return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
}

go func() {
for {
remoteConn, err := listener.Accept()
if err != nil {
if ctx.Err() == nil {
_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
}
return
}

go func() {
defer remoteConn.Close()

localConn, err := net.Dial(localAddr.Network(), localAddr.String())
if err != nil {
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
return
}
defer localConn.Close()

if c, ok := localAddr.(cookieAddr); ok {
_, err = localConn.Write(c.cookie)
if err != nil {
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
return
}
}

agentssh.Bicopy(ctx, localConn, remoteConn)
}()
}
}()

return listener, nil
}