Skip to content

feat: add support for X11 forwarding #7205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ type agent struct {
}

func (a *agent) init(ctx context.Context) {
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.sshMaxTimeout)
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.filesystem, a.sshMaxTimeout, "")
if err != nil {
panic(err)
}
Expand Down
32 changes: 25 additions & 7 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/gliderlabs/ssh"
"github.com/pkg/sftp"
"github.com/spf13/afero"
"go.uber.org/atomic"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
Expand Down Expand Up @@ -48,6 +49,7 @@ const (

type Server struct {
mu sync.RWMutex // Protects following.
fs afero.Fs
listeners map[net.Listener]struct{}
conns map[net.Conn]struct{}
sessions map[ssh.Session]struct{}
Expand All @@ -56,8 +58,9 @@ type Server struct {
// a lock on mu but protected by closing.
wg sync.WaitGroup

logger slog.Logger
srv *ssh.Server
logger slog.Logger
srv *ssh.Server
x11SocketDir string

Env map[string]string
AgentToken func() string
Expand All @@ -68,7 +71,7 @@ type Server struct {
connCountSSHSession atomic.Int64
}

func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) {
func NewServer(ctx context.Context, logger slog.Logger, fs afero.Fs, maxTimeout time.Duration, x11SocketDir string) (*Server, error) {
// Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
Expand All @@ -80,15 +83,20 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
if err != nil {
return nil, err
}
if x11SocketDir == "" {
x11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
}

forwardHandler := &ssh.ForwardedTCPHandler{}
unixForwardHandler := &forwardedUnixHandler{log: logger}

s := &Server{
listeners: make(map[net.Listener]struct{}),
conns: make(map[net.Conn]struct{}),
sessions: make(map[ssh.Session]struct{}),
logger: logger,
listeners: make(map[net.Listener]struct{}),
fs: fs,
conns: make(map[net.Conn]struct{}),
sessions: make(map[ssh.Session]struct{}),
logger: logger,
x11SocketDir: x11SocketDir,
}

s.srv = &ssh.Server{
Expand Down Expand Up @@ -125,6 +133,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
},
X11Callback: s.x11Callback,
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return &gossh.ServerConfig{
NoClientAuth: true,
Expand Down Expand Up @@ -163,6 +172,15 @@ func (s *Server) sessionHandler(session ssh.Session) {

ctx := session.Context()

x11, hasX11 := session.X11()
if hasX11 {
handled := s.x11Handler(session.Context(), x11)
if !handled {
_ = session.Exit(1)
return
}
}

switch ss := session.Subsystem(); ss {
case "":
case "sftp":
Expand Down
6 changes: 4 additions & 2 deletions agent/agentssh/agentssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"testing"

"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
Expand All @@ -32,7 +33,7 @@ func TestNewServer_ServeClient(t *testing.T) {

ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, 0)
s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
require.NoError(t, err)

// The assumption is that these are set before serving SSH connections.
Expand All @@ -50,6 +51,7 @@ func TestNewServer_ServeClient(t *testing.T) {
}()

c := sshClient(t, ln.Addr().String())

var b bytes.Buffer
sess, err := c.NewSession()
sess.Stdout = &b
Expand All @@ -72,7 +74,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {

ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
s, err := agentssh.NewServer(ctx, logger, 0)
s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
require.NoError(t, err)

// The assumption is that these are set before serving SSH connections.
Expand Down
190 changes: 190 additions & 0 deletions agent/agentssh/x11.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
package agentssh

import (
"context"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"time"

"github.com/gliderlabs/ssh"
"github.com/gofrs/flock"
"github.com/spf13/afero"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"

"cdr.dev/slog"
)

// x11Callback is called when the client requests X11 forwarding.
// It adds an Xauthority entry to the Xauthority file.
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
hostname, err := os.Hostname()
if err != nil {
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
return false
}

err = s.fs.MkdirAll(s.x11SocketDir, 0o700)
if err != nil {
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.x11SocketDir), slog.Error(err))
return false
}

err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
if err != nil {
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
return false
}
return true
}

// x11Handler is called when a session has requested X11 forwarding.
// It listens for X11 connections and forwards them to the client.
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
if !valid {
s.logger.Warn(ctx, "failed to get server connection")
return false
}
listener, err := net.Listen("unix", filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber)))
if err != nil {
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
return false
}
s.trackListener(listener, true)

go func() {
defer listener.Close()
defer s.trackListener(listener, false)
handledFirstConnection := false

for {
conn, err := listener.Accept()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x11.SingleConnection is not honored

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err))
return
}
if x11.SingleConnection && handledFirstConnection {
s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled")
_ = conn.Close()
continue
}
handledFirstConnection = true

unixConn, ok := conn.(*net.UnixConn)
if !ok {
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
return
}
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
if !ok {
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
return
}

channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
OriginatorAddress string
OriginatorPort uint32
}{
OriginatorAddress: unixAddr.Name,
OriginatorPort: 0,
}))
if err != nil {
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
return
}
go gossh.DiscardRequests(reqs)
go Bicopy(ctx, conn, channel)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to track that conn and channel are actually closed by the time (*Server).Close completes. Or essentially, that bicopy has exited, but I think we can keep it like this unless we start seeing goleak errors in the tests.

}
}()
return true
}

// addXauthEntry adds an Xauthority entry to the Xauthority file.
// The Xauthority file is located at ~/.Xauthority.
func addXauthEntry(ctx context.Context, fs afero.Fs, host string, display string, authProtocol string, authCookie string) error {
// Get the Xauthority file path
homeDir, err := os.UserHomeDir()
if err != nil {
return xerrors.Errorf("failed to get user home directory: %w", err)
}

xauthPath := filepath.Join(homeDir, ".Xauthority")

lock := flock.New(xauthPath)
defer lock.Close()
ok, err := lock.TryLockContext(ctx, 100*time.Millisecond)
if !ok {
return xerrors.Errorf("failed to lock Xauthority file: %w", err)
}
if err != nil {
return xerrors.Errorf("failed to lock Xauthority file: %w", err)
}

// Open or create the Xauthority file
file, err := fs.OpenFile(xauthPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o600)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we consider flock-ing the file so we're the only one writing to it? Not sure if xauth respects that though.

How about multiple SSH connections with X11 forwarding?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how this works either actually... I suppose we should flock?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a flock!

if err != nil {
return xerrors.Errorf("failed to open Xauthority file: %w", err)
}
defer file.Close()

// Convert the authCookie from hex string to byte slice
authCookieBytes, err := hex.DecodeString(authCookie)
if err != nil {
return xerrors.Errorf("failed to decode auth cookie: %w", err)
}

// Write Xauthority entry
family := uint16(0x0100) // FamilyLocal
err = binary.Write(file, binary.BigEndian, family)
if err != nil {
return xerrors.Errorf("failed to write family: %w", err)
}

err = binary.Write(file, binary.BigEndian, uint16(len(host)))
if err != nil {
return xerrors.Errorf("failed to write host length: %w", err)
}
_, err = file.WriteString(host)
if err != nil {
return xerrors.Errorf("failed to write host: %w", err)
}

err = binary.Write(file, binary.BigEndian, uint16(len(display)))
if err != nil {
return xerrors.Errorf("failed to write display length: %w", err)
}
_, err = file.WriteString(display)
if err != nil {
return xerrors.Errorf("failed to write display: %w", err)
}

err = binary.Write(file, binary.BigEndian, uint16(len(authProtocol)))
if err != nil {
return xerrors.Errorf("failed to write auth protocol length: %w", err)
}
_, err = file.WriteString(authProtocol)
if err != nil {
return xerrors.Errorf("failed to write auth protocol: %w", err)
}

err = binary.Write(file, binary.BigEndian, uint16(len(authCookieBytes)))
if err != nil {
return xerrors.Errorf("failed to write auth cookie length: %w", err)
}
_, err = file.Write(authCookieBytes)
if err != nil {
return xerrors.Errorf("failed to write auth cookie: %w", err)
}

return nil
}
Loading