Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ require (
github.com/pion/datachannel v1.4.21
github.com/pion/dtls/v2 v2.0.9
github.com/pion/ice/v2 v2.1.7
github.com/pion/logging v0.2.2
github.com/pion/turn/v2 v2.0.5
github.com/pion/webrtc/v3 v3.0.29
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
Expand Down
55 changes: 54 additions & 1 deletion wsnet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@ import (
"fmt"
"net"
"net/url"
"sync"
"time"

"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"
)

const (
httpScheme = "http"

bufferedAmountLowThreshold uint64 = 512 * 1024 // 512 KB
maxBufferedAmount uint64 = 1024 * 1024 // 1 MB
// For some reason messages larger just don't work...
// This shouldn't be a huge deal for real-world usage.
// See: https://github.com/pion/datachannel/issues/59
maxMessageLength = 32 * 1024 // 32 KB
)

// TURNEndpoint returns the TURN address for a Coder baseURL.
Expand Down Expand Up @@ -43,19 +52,63 @@ func ConnectEndpoint(baseURL *url.URL, workspace, token string) string {

type conn struct {
addr *net.UnixAddr
dc *webrtc.DataChannel
rw datachannel.ReadWriteCloser

sendMore chan struct{}
closedMutex sync.RWMutex
closed bool

writeMutex sync.Mutex
}

func (c *conn) init() {
c.sendMore = make(chan struct{}, 1)
c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
c.dc.OnBufferedAmountLow(func() {
c.closedMutex.RLock()
defer c.closedMutex.RUnlock()
if c.closed {
return
}
select {
case c.sendMore <- struct{}{}:
default:
}
})
}

func (c *conn) Read(b []byte) (n int, err error) {
return c.rw.Read(b)
}

func (c *conn) Write(b []byte) (n int, err error) {
c.writeMutex.Lock()
defer c.writeMutex.Unlock()
if len(b) > maxMessageLength {
return 0, fmt.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength)
}
if c.dc.BufferedAmount()+uint64(len(b)) >= maxBufferedAmount {
<-c.sendMore
}
// TODO (@kyle): There's an obvious race-condition here.
// This is an edge-case, as most-frequently data won't
// be pooled so synchronously, but is definitely possible.
//
// See: https://github.com/pion/sctp/issues/181
time.Sleep(time.Microsecond)

return c.rw.Write(b)
}

func (c *conn) Close() error {
return c.rw.Close()
c.closedMutex.Lock()
defer c.closedMutex.Unlock()
if !c.closed {
c.closed = true
close(c.sendMore)
}
return c.dc.Close()
}

func (c *conn) LocalAddr() net.Addr {
Expand Down
7 changes: 5 additions & 2 deletions wsnet/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,14 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
return nil, ctx.Err()
}

return &conn{
c := &conn{
addr: &net.UnixAddr{
Name: address,
Net: network,
},
dc: dc,
rw: rw,
}, nil
}
c.init()
return c, nil
}
69 changes: 69 additions & 0 deletions wsnet/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package wsnet
import (
"bytes"
"context"
"crypto/rand"
"errors"
"io"
"net"
"strconv"
"testing"

"github.com/pion/webrtc/v3"
Expand Down Expand Up @@ -160,3 +162,70 @@ func TestDial(t *testing.T) {
}
})
}

func BenchmarkThroughput(b *testing.B) {
sizes := []int64{
4,
16,
128,
256,
1024,
4096,
16384,
32768,
}

listener, err := net.Listen("tcp", "0.0.0.0:0")
if err != nil {
b.Error(err)
return
}
go func() {
for {
conn, err := listener.Accept()
if err != nil {
b.Error(err)
return
}
go func() {
_, _ = io.Copy(io.Discard, conn)
}()
}
}()
connectAddr, listenAddr := createDumbBroker(b)
_, err = Listen(context.Background(), listenAddr)
if err != nil {
b.Error(err)
return
}

dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
if err != nil {
b.Error(err)
return
}
for _, size := range sizes {
size := size
bytes := make([]byte, size)
_, _ = rand.Read(bytes)
b.Run("Rand"+strconv.Itoa(int(size)), func(b *testing.B) {
b.SetBytes(size)
b.ReportAllocs()

conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String())
if err != nil {
b.Error(err)
return
}
defer conn.Close()

for i := 0; i < b.N; i++ {
_, err := conn.Write(bytes)
if err != nil {
b.Error(err)
break
}
}
})
}
}
18 changes: 13 additions & 5 deletions wsnet/listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) {
return
}

conn, err := net.Dial(network, addr)
nc, err := net.Dial(network, addr)
if err != nil {
init.Code = CodeDialErr
init.Err = err.Error()
Expand All @@ -324,13 +324,21 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) {
if init.Err != "" {
return
}
defer conn.Close()
defer dc.Close()
// Must wrap the data channel inside this connection
// for buffering from the dialed endpoint to the client.
co := &conn{
addr: nil,
dc: dc,
rw: rw,
}
co.init()
defer co.Close()
defer nc.Close()

go func() {
_, _ = io.Copy(rw, conn)
_, _ = io.Copy(co, nc)
}()
_, _ = io.Copy(conn, rw)
_, _ = io.Copy(nc, co)
})
}
}
Expand Down
4 changes: 4 additions & 0 deletions wsnet/rtc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/pion/dtls/v2"
"github.com/pion/ice/v2"
"github.com/pion/logging"
"github.com/pion/turn/v2"
"github.com/pion/webrtc/v3"
)
Expand Down Expand Up @@ -159,6 +160,9 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro
se.SetSrflxAcceptanceMinWait(0)
se.DetachDataChannels()
se.SetICETimeouts(time.Second*5, time.Second*5, time.Second*2)
lf := logging.NewDefaultLoggerFactory()
lf.DefaultLogLevel = logging.LogLevelDisabled
se.LoggerFactory = lf

// If one server is provided and we know it's TURN, we can set the
// relay acceptable so the connection starts immediately.
Expand Down
6 changes: 5 additions & 1 deletion wsnet/wsnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ import (
"cdr.dev/slog/sloggers/slogtest/assert"
"github.com/hashicorp/yamux"
"github.com/pion/ice/v2"
"github.com/pion/logging"
"github.com/pion/turn/v2"
"nhooyr.io/websocket"
)

// createDumbBroker proxies sockets between /listen and /connect
// to emulate an authenticated WebSocket pair.
func createDumbBroker(t *testing.T) (connectAddr string, listenAddr string) {
func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) {
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Error(err)
Expand Down Expand Up @@ -128,13 +129,16 @@ func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string {
}}
}

lf := logging.NewDefaultLoggerFactory()
lf.DefaultLogLevel = logging.LogLevelDisabled
srv, err := turn.NewServer(turn.ServerConfig{
PacketConnConfigs: pcListeners,
ListenerConfigs: listeners,
Realm: "coder",
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
return turn.GenerateAuthKey(username, realm, pass), true
},
LoggerFactory: lf,
})
if err != nil {
t.Error(err)
Expand Down