diff --git a/.vscode/settings.json b/.vscode/settings.json index 34ed9fbae2c42..d9b2b88f1798c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -57,6 +57,7 @@ "tfexec", "tfstate", "unconvert", + "webrtc", "xerrors", "yamux" ] diff --git a/peer/channel.go b/peer/channel.go index b01154bcfaa25..d1f4930fe31f7 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -1,6 +1,7 @@ package peer import ( + "bufio" "context" "io" "net" @@ -78,7 +79,8 @@ type Channel struct { dc *webrtc.DataChannel // This field can be nil. It becomes set after the DataChannel // has been opened and is detached. - rwc datachannel.ReadWriteCloser + rwc datachannel.ReadWriteCloser + reader io.Reader closed chan struct{} closeMutex sync.Mutex @@ -130,6 +132,21 @@ func (c *Channel) init() { _ = c.closeWithError(xerrors.Errorf("detach: %w", err)) return } + // pion/webrtc will return an io.ErrShortBuffer when a read + // is triggerred with a buffer size less than the chunks written. + // + // This makes sense when considering UDP connections, because + // bufferring of data that has no transmit guarantees is likely + // to cause unexpected behavior. + // + // When ordered, this adds a bufio.Reader. This ensures additional + // data on TCP-like connections can be read in parts, while still + // being bufferred. + if c.opts.Unordered { + c.reader = c.rwc + } else { + c.reader = bufio.NewReader(c.rwc) + } close(c.opened) }) @@ -181,7 +198,7 @@ func (c *Channel) Read(bytes []byte) (int, error) { } } - bytesRead, err := c.rwc.Read(bytes) + bytesRead, err := c.reader.Read(bytes) if err != nil { if c.isClosed() { return 0, c.closeError diff --git a/peer/conn_test.go b/peer/conn_test.go index 519e5f3b743db..644390ba2ea68 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -267,6 +267,27 @@ func TestConn(t *testing.T) { _, err := client.Ping() require.NoError(t, err) }) + + t.Run("ShortBuffer", func(t *testing.T) { + t.Parallel() + client, server, _ := createPair(t) + exchange(client, server) + go func() { + channel, err := client.Dial(context.Background(), "test", nil) + require.NoError(t, err) + _, err = channel.Write([]byte{1, 2}) + require.NoError(t, err) + }() + channel, err := server.Accept(context.Background()) + require.NoError(t, err) + data := make([]byte, 1) + _, err = channel.Read(data) + require.NoError(t, err) + require.Equal(t, uint8(0x1), data[0]) + _, err = channel.Read(data) + require.NoError(t, err) + require.Equal(t, uint8(0x2), data[0]) + }) } func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {