diff --git a/ci/scripts/files_changed.sh b/ci/scripts/files_changed.sh index 759c68d3..490cb5ad 100755 --- a/ci/scripts/files_changed.sh +++ b/ci/scripts/files_changed.sh @@ -6,6 +6,7 @@ cd "$(git rev-parse --show-toplevel)" if [[ $(git ls-files --other --modified --exclude-standard) ]]; then echo "Files have changed:" + git ls-files --other --modified --exclude-standard git -c color.ui=never status exit 1 fi diff --git a/go.mod b/go.mod index 960091bb..46a9d7ce 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 github.com/rjeczalik/notify v0.9.2 github.com/spf13/cobra v1.2.1 + golang.org/x/net v0.0.0-20210614182718-04defd469f4e golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 diff --git a/internal/cmd/agent.go b/internal/cmd/agent.go index 38853dd1..c19bdfee 100644 --- a/internal/cmd/agent.go +++ b/internal/cmd/agent.go @@ -73,7 +73,7 @@ coder agent start --coder-url https://my-coder.com --token xxxx-xxxx } } - listener, err := wsnet.Listen(context.Background(), wsnet.ListenEndpoint(u, token)) + listener, err := wsnet.Listen(context.Background(), wsnet.ListenEndpoint(u, token), wsnet.TURNProxyWebSocket(u, token)) if err != nil { return xerrors.Errorf("listen: %w", err) } diff --git a/internal/cmd/tunnel.go b/internal/cmd/tunnel.go index ae7fe14c..7b14cf33 100644 --- a/internal/cmd/tunnel.go +++ b/internal/cmd/tunnel.go @@ -2,7 +2,6 @@ package cmd import ( "context" - "errors" "fmt" "io" "net" @@ -12,7 +11,6 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" - "github.com/pion/webrtc/v3" "github.com/spf13/cobra" "golang.org/x/xerrors" @@ -104,30 +102,14 @@ type tunnneler struct { } func (c *tunnneler) start(ctx context.Context) error { - username, password, err := wsnet.TURNCredentials(c.token) - if err != nil { - return xerrors.Errorf("failed to parse credentials from token") - } - server := webrtc.ICEServer{ - URLs: []string{wsnet.TURNEndpoint(c.brokerAddr)}, - Username: username, - Credential: password, - CredentialType: webrtc.ICECredentialTypePassword, - } - - err = wsnet.DialICE(server, nil) - if errors.Is(err, wsnet.ErrInvalidCredentials) { - return xerrors.Errorf("failed to authenticate your user for this workspace") - } - if errors.Is(err, wsnet.ErrMismatchedProtocol) { - return xerrors.Errorf("your TURN server is configured incorrectly. check TLS settings") - } - if err != nil { - return xerrors.Errorf("dial ice: %w", err) - } - c.log.Debug(ctx, "Connecting to workspace...") - wd, err := wsnet.DialWebsocket(ctx, wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), []webrtc.ICEServer{server}) + wd, err := wsnet.DialWebsocket( + ctx, + wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), + &wsnet.DialOptions{ + TURNProxy: wsnet.TURNProxyWebSocket(c.brokerAddr, c.token), + }, + ) if err != nil { return xerrors.Errorf("creating workspace dialer: %w", err) } diff --git a/wsnet/auth.go b/wsnet/auth.go deleted file mode 100644 index a5daf45e..00000000 --- a/wsnet/auth.go +++ /dev/null @@ -1,22 +0,0 @@ -package wsnet - -import ( - "crypto/sha256" - "encoding/base64" - "errors" - "strings" -) - -// TURNCredentials returns a username and password pair -// for a Coder token. -func TURNCredentials(token string) (username, password string, err error) { - str := strings.SplitN(token, "-", 2) - if len(str) != 2 { - err = errors.New("invalid token format") - return - } - username = str[0] - hash := sha256.Sum256([]byte(str[1])) - password = base64.StdEncoding.EncodeToString(hash[:]) - return -} diff --git a/wsnet/conn.go b/wsnet/conn.go index b5dea0a5..608c5c70 100644 --- a/wsnet/conn.go +++ b/wsnet/conn.go @@ -1,14 +1,20 @@ package wsnet import ( + "context" "fmt" "net" + "net/http" "net/url" "sync" "time" "github.com/pion/datachannel" "github.com/pion/webrtc/v3" + "golang.org/x/net/proxy" + "nhooyr.io/websocket" + + "cdr.dev/coder-cli/coder-sdk" ) const ( @@ -22,16 +28,6 @@ const ( maxMessageLength = 32 * 1024 // 32 KB ) -// TURNEndpoint returns the TURN address for a Coder baseURL. -func TURNEndpoint(baseURL *url.URL) string { - turnScheme := "turns" - if baseURL.Scheme == httpScheme { - turnScheme = "turn" - } - - return fmt.Sprintf("%s:%s:5349?transport=tcp", turnScheme, baseURL.Hostname()) -} - // ListenEndpoint returns the Coder endpoint to listen for workspace connections. func ListenEndpoint(baseURL *url.URL, token string) string { wsScheme := "wss" @@ -50,7 +46,80 @@ func ConnectEndpoint(baseURL *url.URL, workspace, token string) string { return fmt.Sprintf("%s://%s%s%s%s%s", wsScheme, baseURL.Host, "/api/private/envagent/", workspace, "/connect?session_token=", token) } -type conn struct { +// TURNWebSocketICECandidate returns a valid relay ICEServer that can be used to +// trigger a TURNWebSocketDialer. +func TURNProxyICECandidate() webrtc.ICEServer { + return webrtc.ICEServer{ + URLs: []string{"turn:127.0.0.1:3478?transport=tcp"}, + Username: "~magicalusername~", + Credential: "~magicalpassword~", + CredentialType: webrtc.ICECredentialTypePassword, + } +} + +// TURNWebSocketDialer proxies all TURN traffic through a WebSocket. +func TURNProxyWebSocket(baseURL *url.URL, token string) proxy.Dialer { + return &turnProxyDialer{ + baseURL: baseURL, + token: token, + } +} + +// Proxies all TURN ICEServer traffic through this dialer. +// References Coder APIs with a specific token. +type turnProxyDialer struct { + baseURL *url.URL + token string +} + +func (t *turnProxyDialer) Dial(network, addr string) (c net.Conn, err error) { + headers := http.Header{} + headers.Set("Session-Token", t.token) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + + // Copy the baseURL so we can adjust path. + url := *t.baseURL + url.Scheme = "wss" + if url.Scheme == httpScheme { + url.Scheme = "ws" + } + url.Path = "/api/private/turn" + conn, resp, err := websocket.Dial(ctx, url.String(), &websocket.DialOptions{ + HTTPHeader: headers, + }) + if err != nil { + if resp != nil { + defer resp.Body.Close() + return nil, coder.NewHTTPError(resp) + } + return nil, fmt.Errorf("dial: %w", err) + } + + return &turnProxyConn{ + websocket.NetConn(context.Background(), conn, websocket.MessageBinary), + }, nil +} + +// turnProxyConn is a net.Conn wrapper that returns a TCPAddr for the +// LocalAddr function. pion/ice unsafely checks the types. See: +// https://github.com/pion/ice/blob/e78f26fb435987420546c70369ade5d713beca39/gather.go#L448 +type turnProxyConn struct { + net.Conn +} + +// The LocalAddr specified here doesn't really matter, +// it just has to be of type "TCPAddr". +func (*turnProxyConn) LocalAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, + } +} + +// Properly buffers data for data channel connections. +type dataChannelConn struct { addr *net.UnixAddr dc *webrtc.DataChannel rw datachannel.ReadWriteCloser @@ -62,7 +131,7 @@ type conn struct { writeMutex sync.Mutex } -func (c *conn) init() { +func (c *dataChannelConn) init() { c.sendMore = make(chan struct{}, 1) c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) c.dc.OnBufferedAmountLow(func() { @@ -78,11 +147,11 @@ func (c *conn) init() { }) } -func (c *conn) Read(b []byte) (n int, err error) { +func (c *dataChannelConn) Read(b []byte) (n int, err error) { return c.rw.Read(b) } -func (c *conn) Write(b []byte) (n int, err error) { +func (c *dataChannelConn) Write(b []byte) (n int, err error) { c.writeMutex.Lock() defer c.writeMutex.Unlock() if len(b) > maxMessageLength { @@ -101,7 +170,7 @@ func (c *conn) Write(b []byte) (n int, err error) { return c.rw.Write(b) } -func (c *conn) Close() error { +func (c *dataChannelConn) Close() error { c.closedMutex.Lock() defer c.closedMutex.Unlock() if !c.closed { @@ -111,22 +180,22 @@ func (c *conn) Close() error { return c.dc.Close() } -func (c *conn) LocalAddr() net.Addr { +func (c *dataChannelConn) LocalAddr() net.Addr { return c.addr } -func (c *conn) RemoteAddr() net.Addr { +func (c *dataChannelConn) RemoteAddr() net.Addr { return c.addr } -func (c *conn) SetDeadline(t time.Time) error { +func (c *dataChannelConn) SetDeadline(t time.Time) error { return nil } -func (c *conn) SetReadDeadline(t time.Time) error { +func (c *dataChannelConn) SetReadDeadline(t time.Time) error { return nil } -func (c *conn) SetWriteDeadline(t time.Time) error { +func (c *dataChannelConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/wsnet/dial.go b/wsnet/dial.go index 0beb2232..362bbab9 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -12,13 +12,26 @@ import ( "github.com/pion/datachannel" "github.com/pion/webrtc/v3" + "golang.org/x/net/proxy" "nhooyr.io/websocket" "cdr.dev/coder-cli/coder-sdk" ) +// DialOptions are configurable options for a wsnet connection. +type DialOptions struct { + // ICEServers is an array of STUN or TURN servers to use for negotiation purposes. + // See: https://developer.mozilla.org/en-US/docs/Web/API/RTCConfiguration/iceServers + ICEServers []webrtc.ICEServer + + // TURNProxy is a function used to proxy all TURN traffic. + // If specified without ICEServers, `TURNProxyICECandidate` + // will be used. + TURNProxy proxy.Dialer +} + // DialWebsocket dials the broker with a WebSocket and negotiates a connection. -func DialWebsocket(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*Dialer, error) { +func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*Dialer, error) { conn, resp, err := websocket.Dial(ctx, broker, nil) if err != nil { if resp != nil { @@ -35,16 +48,24 @@ func DialWebsocket(ctx context.Context, broker string, iceServers []webrtc.ICESe // We should close the socket intentionally. _ = conn.Close(websocket.StatusInternalError, "an error occurred") }() - return Dial(nconn, iceServers) + return Dial(nconn, options) } // Dial negotiates a connection to a listener. -func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) { - if iceServers == nil { - iceServers = []webrtc.ICEServer{} +func Dial(conn net.Conn, options *DialOptions) (*Dialer, error) { + if options == nil { + options = &DialOptions{} + } + if options.ICEServers == nil { + options.ICEServers = []webrtc.ICEServer{} + } + // If the TURNProxy is specified and ICEServers aren't, + // it's safe to assume we can inject the default proxy candidate. + if len(options.ICEServers) == 0 && options.TURNProxy != nil { + options.ICEServers = []webrtc.ICEServer{TURNProxyICECandidate()} } - rtc, err := newPeerConnection(iceServers) + rtc, err := newPeerConnection(options.ICEServers, options.TURNProxy) if err != nil { return nil, fmt.Errorf("create peer connection: %w", err) } @@ -70,7 +91,7 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) { offerMessage, err := json.Marshal(&BrokerMessage{ Offer: &offer, - Servers: iceServers, + Servers: options.ICEServers, }) if err != nil { return nil, fmt.Errorf("marshal offer message: %w", err) @@ -287,7 +308,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. return nil, ctx.Err() } - c := &conn{ + c := &dataChannelConn{ addr: &net.UnixAddr{ Name: address, Net: network, diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 5d2e3884..6ad27866 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -35,7 +35,9 @@ func ExampleDial_basic() { } } - dialer, err := DialWebsocket(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", servers) + dialer, err := DialWebsocket(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", &DialOptions{ + ICEServers: servers, + }) if err != nil { // Do something... } @@ -53,7 +55,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr) + _, err := Listen(context.Background(), listenAddr, nil) if err != nil { t.Error(err) return @@ -73,18 +75,20 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr) + _, err := Listen(context.Background(), listenAddr, nil) if err != nil { t.Error(err) return } turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) - dialer, err := DialWebsocket(context.Background(), connectAddr, []webrtc.ICEServer{{ - URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, - Username: "example", - Credential: testPass, - CredentialType: webrtc.ICECredentialTypePassword, - }}) + dialer, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{ + ICEServers: []webrtc.ICEServer{{ + URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, + Username: "example", + Credential: testPass, + CredentialType: webrtc.ICECredentialTypePassword, + }}, + }) if err != nil { t.Error(err) return @@ -102,7 +106,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr) + _, err := Listen(context.Background(), listenAddr, nil) if err != nil { t.Error(err) return @@ -141,7 +145,7 @@ func TestDial(t *testing.T) { }() connectAddr, listenAddr := createDumbBroker(t) - _, err = Listen(context.Background(), listenAddr) + _, err = Listen(context.Background(), listenAddr, nil) if err != nil { t.Error(err) return @@ -180,7 +184,7 @@ func TestDial(t *testing.T) { _, _ = listener.Accept() }() connectAddr, listenAddr := createDumbBroker(t) - srv, err := Listen(context.Background(), listenAddr) + srv, err := Listen(context.Background(), listenAddr, nil) if err != nil { t.Error(err) return @@ -207,7 +211,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr) + _, err := Listen(context.Background(), listenAddr, nil) if err != nil { t.Error(err) return @@ -241,18 +245,20 @@ func TestDial(t *testing.T) { }() connectAddr, listenAddr := createDumbBroker(t) - _, err = Listen(context.Background(), listenAddr) + _, err = Listen(context.Background(), listenAddr, nil) if err != nil { t.Error(err) return } turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) - dialer, err := DialWebsocket(context.Background(), connectAddr, []webrtc.ICEServer{{ - URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, - Username: "example", - Credential: testPass, - CredentialType: webrtc.ICECredentialTypePassword, - }}) + dialer, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{ + ICEServers: []webrtc.ICEServer{{ + URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, + Username: "example", + Credential: testPass, + CredentialType: webrtc.ICECredentialTypePassword, + }}, + }) if err != nil { t.Error(err) return @@ -276,7 +282,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr) + _, err := Listen(context.Background(), listenAddr, nil) if err != nil { t.Error(err) return @@ -327,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) { } }() connectAddr, listenAddr := createDumbBroker(b) - _, err = Listen(context.Background(), listenAddr) + _, err = Listen(context.Background(), listenAddr, nil) if err != nil { b.Error(err) return diff --git a/wsnet/listen.go b/wsnet/listen.go index 3c7c3b3e..b29bbdb3 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/yamux" "github.com/pion/webrtc/v3" + "golang.org/x/net/proxy" "nhooyr.io/websocket" "cdr.dev/coder-cli/coder-sdk" @@ -39,10 +40,11 @@ type DialChannelResponse struct { // Listen connects to the broker proxies connections to the local net. // Close will end all RTC connections. -func Listen(ctx context.Context, broker string) (io.Closer, error) { +func Listen(ctx context.Context, broker string, tcpProxy proxy.Dialer) (io.Closer, error) { l := &listener{ broker: broker, connClosers: make([]io.Closer, 0), + tcpProxy: tcpProxy, } // We do a one-off dial outside of the loop to ensure the initial // connection is successful. If not, there's likely an error the @@ -83,7 +85,8 @@ func Listen(ctx context.Context, broker string) (io.Closer, error) { } type listener struct { - broker string + broker string + tcpProxy proxy.Dialer acceptError error ws *websocket.Conn @@ -186,13 +189,18 @@ func (l *listener) negotiate(conn net.Conn) { return } for _, server := range msg.Servers { + if server.Username == TURNProxyICECandidate().Username { + // This candidate is only used when proxying, + // so it will not validate. + continue + } err = DialICE(server, nil) if err != nil { closeError(fmt.Errorf("dial server %+v: %w", server.URLs, err)) return } } - rtc, err = newPeerConnection(msg.Servers) + rtc, err = newPeerConnection(msg.Servers, l.tcpProxy) if err != nil { closeError(err) return @@ -326,7 +334,7 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { } // Must wrap the data channel inside this connection // for buffering from the dialed endpoint to the client. - co := &conn{ + co := &dataChannelConn{ addr: nil, dc: dc, rw: rw, diff --git a/wsnet/listen_test.go b/wsnet/listen_test.go index 45519b92..47b856c3 100644 --- a/wsnet/listen_test.go +++ b/wsnet/listen_test.go @@ -45,7 +45,7 @@ func TestListen(t *testing.T) { addr := listener.Addr() broker := fmt.Sprintf("http://%s/", addr.String()) - _, err = Listen(context.Background(), broker) + _, err = Listen(context.Background(), broker, nil) if err != nil { t.Error(err) return diff --git a/wsnet/rtc.go b/wsnet/rtc.go index e8b5eab3..05c04f1b 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -17,6 +17,7 @@ import ( "github.com/pion/logging" "github.com/pion/turn/v2" "github.com/pion/webrtc/v3" + "golang.org/x/net/proxy" ) var ( @@ -154,7 +155,7 @@ func dialICEURL(server webrtc.ICEServer, rawURL string, options *DialICEOptions) } // Generalizes creating a new peer connection with consistent options. -func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, error) { +func newPeerConnection(servers []webrtc.ICEServer, dialer proxy.Dialer) (*webrtc.PeerConnection, error) { se := webrtc.SettingEngine{} se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeUDP4}) se.SetSrflxAcceptanceMinWait(0) @@ -164,15 +165,21 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro lf.DefaultLogLevel = logging.LogLevelDisabled se.LoggerFactory = lf + // Enables tunneling of TURN traffic through an arbitrary proxy. + // We proxy TURN over a WebSocket to reduce deployment complexity. + if dialer != nil { + se.SetICEProxyDialer(dialer) + } + transportPolicy := webrtc.ICETransportPolicyAll // If one server is provided and we know it's TURN, we can set the // relay acceptable so the connection starts immediately. if len(servers) == 1 { server := servers[0] - if server.Credential != nil && len(server.URLs) == 1 { + if len(server.URLs) == 1 { url, err := ice.ParseURL(server.URLs[0]) - if err == nil && url.Proto == ice.ProtoTypeTCP { + if err == nil && server.Credential != nil && url.Proto == ice.ProtoTypeTCP { se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6}) se.SetRelayAcceptanceMinWait(0) }