diff --git a/go.mod b/go.mod index 3823f149..3891770e 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/pion/datachannel v1.4.21 github.com/pion/dtls/v2 v2.0.9 github.com/pion/ice/v2 v2.1.7 + github.com/pion/logging v0.2.2 github.com/pion/turn/v2 v2.0.5 github.com/pion/webrtc/v3 v3.0.29 github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 diff --git a/wsnet/conn.go b/wsnet/conn.go index 7e18723b..b5dea0a5 100644 --- a/wsnet/conn.go +++ b/wsnet/conn.go @@ -4,13 +4,22 @@ import ( "fmt" "net" "net/url" + "sync" "time" "github.com/pion/datachannel" + "github.com/pion/webrtc/v3" ) const ( httpScheme = "http" + + bufferedAmountLowThreshold uint64 = 512 * 1024 // 512 KB + maxBufferedAmount uint64 = 1024 * 1024 // 1 MB + // For some reason messages larger just don't work... + // This shouldn't be a huge deal for real-world usage. + // See: https://github.com/pion/datachannel/issues/59 + maxMessageLength = 32 * 1024 // 32 KB ) // TURNEndpoint returns the TURN address for a Coder baseURL. @@ -43,7 +52,30 @@ func ConnectEndpoint(baseURL *url.URL, workspace, token string) string { type conn struct { addr *net.UnixAddr + dc *webrtc.DataChannel rw datachannel.ReadWriteCloser + + sendMore chan struct{} + closedMutex sync.RWMutex + closed bool + + writeMutex sync.Mutex +} + +func (c *conn) init() { + c.sendMore = make(chan struct{}, 1) + c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) + c.dc.OnBufferedAmountLow(func() { + c.closedMutex.RLock() + defer c.closedMutex.RUnlock() + if c.closed { + return + } + select { + case c.sendMore <- struct{}{}: + default: + } + }) } func (c *conn) Read(b []byte) (n int, err error) { @@ -51,11 +83,32 @@ func (c *conn) Read(b []byte) (n int, err error) { } func (c *conn) Write(b []byte) (n int, err error) { + c.writeMutex.Lock() + defer c.writeMutex.Unlock() + if len(b) > maxMessageLength { + return 0, fmt.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength) + } + if c.dc.BufferedAmount()+uint64(len(b)) >= maxBufferedAmount { + <-c.sendMore + } + // TODO (@kyle): There's an obvious race-condition here. + // This is an edge-case, as most-frequently data won't + // be pooled so synchronously, but is definitely possible. + // + // See: https://github.com/pion/sctp/issues/181 + time.Sleep(time.Microsecond) + return c.rw.Write(b) } func (c *conn) Close() error { - return c.rw.Close() + c.closedMutex.Lock() + defer c.closedMutex.Unlock() + if !c.closed { + c.closed = true + close(c.sendMore) + } + return c.dc.Close() } func (c *conn) LocalAddr() net.Addr { diff --git a/wsnet/dial.go b/wsnet/dial.go index 23581eaf..9debf47a 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -249,11 +249,14 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. return nil, ctx.Err() } - return &conn{ + c := &conn{ addr: &net.UnixAddr{ Name: address, Net: network, }, + dc: dc, rw: rw, - }, nil + } + c.init() + return c, nil } diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 584c37d5..50bdd938 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -3,9 +3,11 @@ package wsnet import ( "bytes" "context" + "crypto/rand" "errors" "io" "net" + "strconv" "testing" "github.com/pion/webrtc/v3" @@ -160,3 +162,70 @@ func TestDial(t *testing.T) { } }) } + +func BenchmarkThroughput(b *testing.B) { + sizes := []int64{ + 4, + 16, + 128, + 256, + 1024, + 4096, + 16384, + 32768, + } + + listener, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + b.Error(err) + return + } + go func() { + for { + conn, err := listener.Accept() + if err != nil { + b.Error(err) + return + } + go func() { + _, _ = io.Copy(io.Discard, conn) + }() + } + }() + connectAddr, listenAddr := createDumbBroker(b) + _, err = Listen(context.Background(), listenAddr) + if err != nil { + b.Error(err) + return + } + + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + if err != nil { + b.Error(err) + return + } + for _, size := range sizes { + size := size + bytes := make([]byte, size) + _, _ = rand.Read(bytes) + b.Run("Rand"+strconv.Itoa(int(size)), func(b *testing.B) { + b.SetBytes(size) + b.ReportAllocs() + + conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) + if err != nil { + b.Error(err) + return + } + defer conn.Close() + + for i := 0; i < b.N; i++ { + _, err := conn.Write(bytes) + if err != nil { + b.Error(err) + break + } + } + }) + } +} diff --git a/wsnet/listen.go b/wsnet/listen.go index 1496e19c..3a6735f0 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -311,7 +311,7 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { return } - conn, err := net.Dial(network, addr) + nc, err := net.Dial(network, addr) if err != nil { init.Code = CodeDialErr init.Err = err.Error() @@ -324,13 +324,21 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { if init.Err != "" { return } - defer conn.Close() - defer dc.Close() + // Must wrap the data channel inside this connection + // for buffering from the dialed endpoint to the client. + co := &conn{ + addr: nil, + dc: dc, + rw: rw, + } + co.init() + defer co.Close() + defer nc.Close() go func() { - _, _ = io.Copy(rw, conn) + _, _ = io.Copy(co, nc) }() - _, _ = io.Copy(conn, rw) + _, _ = io.Copy(nc, co) }) } } diff --git a/wsnet/rtc.go b/wsnet/rtc.go index f5c7c5f3..0605fa29 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -14,6 +14,7 @@ import ( "github.com/pion/dtls/v2" "github.com/pion/ice/v2" + "github.com/pion/logging" "github.com/pion/turn/v2" "github.com/pion/webrtc/v3" ) @@ -159,6 +160,9 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro se.SetSrflxAcceptanceMinWait(0) se.DetachDataChannels() se.SetICETimeouts(time.Second*5, time.Second*5, time.Second*2) + lf := logging.NewDefaultLoggerFactory() + lf.DefaultLogLevel = logging.LogLevelDisabled + se.LoggerFactory = lf // If one server is provided and we know it's TURN, we can set the // relay acceptable so the connection starts immediately. diff --git a/wsnet/wsnet_test.go b/wsnet/wsnet_test.go index 8452015d..fc14cd3c 100644 --- a/wsnet/wsnet_test.go +++ b/wsnet/wsnet_test.go @@ -20,13 +20,14 @@ import ( "cdr.dev/slog/sloggers/slogtest/assert" "github.com/hashicorp/yamux" "github.com/pion/ice/v2" + "github.com/pion/logging" "github.com/pion/turn/v2" "nhooyr.io/websocket" ) // createDumbBroker proxies sockets between /listen and /connect // to emulate an authenticated WebSocket pair. -func createDumbBroker(t *testing.T) (connectAddr string, listenAddr string) { +func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) { listener, err := net.Listen("tcp4", "127.0.0.1:0") if err != nil { t.Error(err) @@ -128,6 +129,8 @@ func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string { }} } + lf := logging.NewDefaultLoggerFactory() + lf.DefaultLogLevel = logging.LogLevelDisabled srv, err := turn.NewServer(turn.ServerConfig{ PacketConnConfigs: pcListeners, ListenerConfigs: listeners, @@ -135,6 +138,7 @@ func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string { AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { return turn.GenerateAuthKey(username, realm, pass), true }, + LoggerFactory: lf, }) if err != nil { t.Error(err)