Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/scripts/files_changed.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ cd "$(git rev-parse --show-toplevel)"

if [[ $(git ls-files --other --modified --exclude-standard) ]]; then
echo "Files have changed:"
git ls-files --other --modified --exclude-standard
git -c color.ui=never status
exit 1
fi
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ require (
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
github.com/rjeczalik/notify v0.9.2
github.com/spf13/cobra v1.2.1
golang.org/x/net v0.0.0-20210614182718-04defd469f4e
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ coder agent start --coder-url https://my-coder.com --token xxxx-xxxx
}
}

listener, err := wsnet.Listen(context.Background(), wsnet.ListenEndpoint(u, token))
listener, err := wsnet.Listen(context.Background(), wsnet.ListenEndpoint(u, token), wsnet.TURNProxyWebSocket(u, token))
if err != nil {
return xerrors.Errorf("listen: %w", err)
}
Expand Down
32 changes: 7 additions & 25 deletions internal/cmd/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cmd

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand All @@ -12,7 +11,6 @@ import (

"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/pion/webrtc/v3"
"github.com/spf13/cobra"
"golang.org/x/xerrors"

Expand Down Expand Up @@ -104,30 +102,14 @@ type tunnneler struct {
}

func (c *tunnneler) start(ctx context.Context) error {
username, password, err := wsnet.TURNCredentials(c.token)
if err != nil {
return xerrors.Errorf("failed to parse credentials from token")
}
server := webrtc.ICEServer{
URLs: []string{wsnet.TURNEndpoint(c.brokerAddr)},
Username: username,
Credential: password,
CredentialType: webrtc.ICECredentialTypePassword,
}

err = wsnet.DialICE(server, nil)
if errors.Is(err, wsnet.ErrInvalidCredentials) {
return xerrors.Errorf("failed to authenticate your user for this workspace")
}
if errors.Is(err, wsnet.ErrMismatchedProtocol) {
return xerrors.Errorf("your TURN server is configured incorrectly. check TLS settings")
}
if err != nil {
return xerrors.Errorf("dial ice: %w", err)
}

c.log.Debug(ctx, "Connecting to workspace...")
wd, err := wsnet.DialWebsocket(ctx, wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), []webrtc.ICEServer{server})
wd, err := wsnet.DialWebsocket(
ctx,
wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token),
&wsnet.DialOptions{
TURNProxy: wsnet.TURNProxyWebSocket(c.brokerAddr, c.token),
},
)
if err != nil {
return xerrors.Errorf("creating workspace dialer: %w", err)
}
Expand Down
22 changes: 0 additions & 22 deletions wsnet/auth.go

This file was deleted.

109 changes: 89 additions & 20 deletions wsnet/conn.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package wsnet

import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"

"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"
"golang.org/x/net/proxy"
"nhooyr.io/websocket"

"cdr.dev/coder-cli/coder-sdk"
)

const (
Expand All @@ -22,16 +28,6 @@ const (
maxMessageLength = 32 * 1024 // 32 KB
)

// TURNEndpoint returns the TURN address for a Coder baseURL.
func TURNEndpoint(baseURL *url.URL) string {
turnScheme := "turns"
if baseURL.Scheme == httpScheme {
turnScheme = "turn"
}

return fmt.Sprintf("%s:%s:5349?transport=tcp", turnScheme, baseURL.Hostname())
}

// ListenEndpoint returns the Coder endpoint to listen for workspace connections.
func ListenEndpoint(baseURL *url.URL, token string) string {
wsScheme := "wss"
Expand All @@ -50,7 +46,80 @@ func ConnectEndpoint(baseURL *url.URL, workspace, token string) string {
return fmt.Sprintf("%s://%s%s%s%s%s", wsScheme, baseURL.Host, "/api/private/envagent/", workspace, "/connect?session_token=", token)
}

type conn struct {
// TURNWebSocketICECandidate returns a valid relay ICEServer that can be used to
// trigger a TURNWebSocketDialer.
func TURNProxyICECandidate() webrtc.ICEServer {
return webrtc.ICEServer{
URLs: []string{"turn:127.0.0.1:3478?transport=tcp"},
Username: "~magicalusername~",
Credential: "~magicalpassword~",
CredentialType: webrtc.ICECredentialTypePassword,
}
}

// TURNWebSocketDialer proxies all TURN traffic through a WebSocket.
func TURNProxyWebSocket(baseURL *url.URL, token string) proxy.Dialer {
return &turnProxyDialer{
baseURL: baseURL,
token: token,
}
}

// Proxies all TURN ICEServer traffic through this dialer.
// References Coder APIs with a specific token.
type turnProxyDialer struct {
baseURL *url.URL
token string
}

func (t *turnProxyDialer) Dial(network, addr string) (c net.Conn, err error) {
headers := http.Header{}
headers.Set("Session-Token", t.token)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()

// Copy the baseURL so we can adjust path.
url := *t.baseURL
url.Scheme = "wss"
if url.Scheme == httpScheme {
url.Scheme = "ws"
}
url.Path = "/api/private/turn"
conn, resp, err := websocket.Dial(ctx, url.String(), &websocket.DialOptions{
HTTPHeader: headers,
})
if err != nil {
if resp != nil {
defer resp.Body.Close()
return nil, coder.NewHTTPError(resp)
}
return nil, fmt.Errorf("dial: %w", err)
}

return &turnProxyConn{
websocket.NetConn(context.Background(), conn, websocket.MessageBinary),
}, nil
}

// turnProxyConn is a net.Conn wrapper that returns a TCPAddr for the
// LocalAddr function. pion/ice unsafely checks the types. See:
// https://github.com/pion/ice/blob/e78f26fb435987420546c70369ade5d713beca39/gather.go#L448
type turnProxyConn struct {
net.Conn
}

// The LocalAddr specified here doesn't really matter,
// it just has to be of type "TCPAddr".
func (*turnProxyConn) LocalAddr() net.Addr {
return &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 0,
}
}

// Properly buffers data for data channel connections.
type dataChannelConn struct {
addr *net.UnixAddr
dc *webrtc.DataChannel
rw datachannel.ReadWriteCloser
Expand All @@ -62,7 +131,7 @@ type conn struct {
writeMutex sync.Mutex
}

func (c *conn) init() {
func (c *dataChannelConn) init() {
c.sendMore = make(chan struct{}, 1)
c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
c.dc.OnBufferedAmountLow(func() {
Expand All @@ -78,11 +147,11 @@ func (c *conn) init() {
})
}

func (c *conn) Read(b []byte) (n int, err error) {
func (c *dataChannelConn) Read(b []byte) (n int, err error) {
return c.rw.Read(b)
}

func (c *conn) Write(b []byte) (n int, err error) {
func (c *dataChannelConn) Write(b []byte) (n int, err error) {
c.writeMutex.Lock()
defer c.writeMutex.Unlock()
if len(b) > maxMessageLength {
Expand All @@ -101,7 +170,7 @@ func (c *conn) Write(b []byte) (n int, err error) {
return c.rw.Write(b)
}

func (c *conn) Close() error {
func (c *dataChannelConn) Close() error {
c.closedMutex.Lock()
defer c.closedMutex.Unlock()
if !c.closed {
Expand All @@ -111,22 +180,22 @@ func (c *conn) Close() error {
return c.dc.Close()
}

func (c *conn) LocalAddr() net.Addr {
func (c *dataChannelConn) LocalAddr() net.Addr {
return c.addr
}

func (c *conn) RemoteAddr() net.Addr {
func (c *dataChannelConn) RemoteAddr() net.Addr {
return c.addr
}

func (c *conn) SetDeadline(t time.Time) error {
func (c *dataChannelConn) SetDeadline(t time.Time) error {
return nil
}

func (c *conn) SetReadDeadline(t time.Time) error {
func (c *dataChannelConn) SetReadDeadline(t time.Time) error {
return nil
}

func (c *conn) SetWriteDeadline(t time.Time) error {
func (c *dataChannelConn) SetWriteDeadline(t time.Time) error {
return nil
}
37 changes: 29 additions & 8 deletions wsnet/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,26 @@ import (

"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"
"golang.org/x/net/proxy"
"nhooyr.io/websocket"

"cdr.dev/coder-cli/coder-sdk"
)

// DialOptions are configurable options for a wsnet connection.
type DialOptions struct {
// ICEServers is an array of STUN or TURN servers to use for negotiation purposes.
// See: https://developer.mozilla.org/en-US/docs/Web/API/RTCConfiguration/iceServers
ICEServers []webrtc.ICEServer

// TURNProxy is a function used to proxy all TURN traffic.
// If specified without ICEServers, `TURNProxyICECandidate`
// will be used.
TURNProxy proxy.Dialer
}

// DialWebsocket dials the broker with a WebSocket and negotiates a connection.
func DialWebsocket(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*Dialer, error) {
func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*Dialer, error) {
conn, resp, err := websocket.Dial(ctx, broker, nil)
if err != nil {
if resp != nil {
Expand All @@ -35,16 +48,24 @@ func DialWebsocket(ctx context.Context, broker string, iceServers []webrtc.ICESe
// We should close the socket intentionally.
_ = conn.Close(websocket.StatusInternalError, "an error occurred")
}()
return Dial(nconn, iceServers)
return Dial(nconn, options)
}

// Dial negotiates a connection to a listener.
func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) {
if iceServers == nil {
iceServers = []webrtc.ICEServer{}
func Dial(conn net.Conn, options *DialOptions) (*Dialer, error) {
if options == nil {
options = &DialOptions{}
}
if options.ICEServers == nil {
options.ICEServers = []webrtc.ICEServer{}
}
// If the TURNProxy is specified and ICEServers aren't,
// it's safe to assume we can inject the default proxy candidate.
if len(options.ICEServers) == 0 && options.TURNProxy != nil {
options.ICEServers = []webrtc.ICEServer{TURNProxyICECandidate()}
}

rtc, err := newPeerConnection(iceServers)
rtc, err := newPeerConnection(options.ICEServers, options.TURNProxy)
if err != nil {
return nil, fmt.Errorf("create peer connection: %w", err)
}
Expand All @@ -70,7 +91,7 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) {

offerMessage, err := json.Marshal(&BrokerMessage{
Offer: &offer,
Servers: iceServers,
Servers: options.ICEServers,
})
if err != nil {
return nil, fmt.Errorf("marshal offer message: %w", err)
Expand Down Expand Up @@ -287,7 +308,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
return nil, ctx.Err()
}

c := &conn{
c := &dataChannelConn{
addr: &net.UnixAddr{
Name: address,
Net: network,
Expand Down
Loading