Skip to content

Commit 65f378c

Browse files
committed
Reduce connection to a single negotiation channel
1 parent e155e31 commit 65f378c

File tree

6 files changed

+272
-318
lines changed

6 files changed

+272
-318
lines changed

peer/conn.go

Lines changed: 49 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,9 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
7878
dcFailedChannel: make(chan struct{}),
7979
// This channel needs to be bufferred otherwise slow consumers
8080
// of this will cause a connection failure.
81-
localCandidateChannel: make(chan webrtc.ICECandidateInit, 16),
82-
pendingCandidatesToSend: make([]webrtc.ICECandidateInit, 0),
83-
pendingCandidatesToAccept: make([]webrtc.ICECandidateInit, 0),
84-
localSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1),
85-
remoteSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1),
81+
localNegotiator: make(chan Negotiation, 8),
82+
remoteSessionDescription: make(chan webrtc.SessionDescription, 1),
83+
pendingCandidatesToSend: make([]webrtc.ICECandidateInit, 0),
8684
}
8785
if client {
8886
// If we're the client, we want to flip the echo and
@@ -126,15 +124,12 @@ type Conn struct {
126124
dcFailedListeners atomic.Uint32
127125
dcClosedWaitGroup sync.WaitGroup
128126

129-
localCandidateChannel chan webrtc.ICECandidateInit
130-
localSessionDescriptionChannel chan webrtc.SessionDescription
131-
remoteSessionDescriptionChannel chan webrtc.SessionDescription
127+
localNegotiator chan Negotiation
128+
remoteSessionDescription chan webrtc.SessionDescription
132129

133130
pendingCandidatesToSend []webrtc.ICECandidateInit
134131
pendingCandidatesToSendMutex sync.Mutex
135-
136-
pendingCandidatesToAccept []webrtc.ICECandidateInit
137-
pendingCandidatesToAcceptMutex sync.Mutex
132+
pendingCandidatesFlushed bool
138133

139134
pingChannelID uint16
140135
pingEchoChannelID uint16
@@ -148,6 +143,12 @@ type Conn struct {
148143
pingError error
149144
}
150145

146+
// Negotiation represents a handshake message between peer connections.
147+
type Negotiation struct {
148+
SessionDescription *webrtc.SessionDescription
149+
ICECandidates []webrtc.ICECandidateInit
150+
}
151+
151152
func (c *Conn) init() error {
152153
c.rtc.OnNegotiationNeeded(c.negotiate)
153154
c.rtc.OnICEConnectionStateChange(func(iceConnectionState webrtc.ICEConnectionState) {
@@ -181,7 +182,7 @@ func (c *Conn) init() error {
181182
slog.F("hash", c.hashCandidate(json)),
182183
slog.F("length", len(json.Candidate)),
183184
}
184-
if c.rtc.RemoteDescription() == nil {
185+
if !c.pendingCandidatesFlushed {
185186
c.pendingCandidatesToSend = append(c.pendingCandidatesToSend, json)
186187
c.opts.Logger.Debug(context.Background(), "buffering local candidate to send", fields...)
187188
return
@@ -190,7 +191,7 @@ func (c *Conn) init() error {
190191
select {
191192
case <-c.closed:
192193
break
193-
case c.localCandidateChannel <- json:
194+
case c.localNegotiator <- Negotiation{nil, []webrtc.ICECandidateInit{json}}:
194195
}
195196
})
196197
c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) {
@@ -341,20 +342,20 @@ func (c *Conn) negotiate() {
341342
select {
342343
case <-c.closed:
343344
return
344-
case c.localSessionDescriptionChannel <- offer:
345+
case c.localNegotiator <- Negotiation{&offer, nil}:
345346
}
346347
}
347348

348-
var remoteDescription webrtc.SessionDescription
349+
var sessionDescription webrtc.SessionDescription
349350
select {
350351
case <-c.closed:
351352
return
352-
case remoteDescription = <-c.remoteSessionDescriptionChannel:
353+
case sessionDescription = <-c.remoteSessionDescription:
353354
}
354355

355356
c.opts.Logger.Debug(context.Background(), "setting remote description")
356357
c.closeMutex.Lock()
357-
err := c.rtc.SetRemoteDescription(remoteDescription)
358+
err := c.rtc.SetRemoteDescription(sessionDescription)
358359
c.closeMutex.Unlock()
359360
if err != nil {
360361
_ = c.CloseWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err))
@@ -378,80 +379,58 @@ func (c *Conn) negotiate() {
378379
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
379380
return
380381
}
381-
382+
c.opts.Logger.Debug(context.Background(), "sending answer")
382383
select {
383384
case <-c.closed:
384385
return
385-
case c.localSessionDescriptionChannel <- answer:
386+
case c.localNegotiator <- Negotiation{&answer, nil}:
386387
}
387388
}
388389

389390
c.pendingCandidatesToSendMutex.Lock()
390-
for _, pendingCandidate := range c.pendingCandidatesToSend {
391-
c.opts.Logger.Debug(context.Background(), "sending buffered local candidate",
392-
slog.F("hash", c.hashCandidate(pendingCandidate)),
393-
slog.F("length", len(pendingCandidate.Candidate)),
394-
)
391+
defer c.pendingCandidatesToSendMutex.Unlock()
392+
if len(c.pendingCandidatesToSend) > 0 {
395393
select {
396394
case <-c.closed:
397395
return
398-
case c.localCandidateChannel <- pendingCandidate:
396+
case c.localNegotiator <- Negotiation{nil, c.pendingCandidatesToSend}:
399397
}
400398
}
401399
c.opts.Logger.Debug(context.Background(), "flushed buffered local candidates",
402400
slog.F("count", len(c.pendingCandidatesToSend)),
403401
)
404402
c.pendingCandidatesToSend = make([]webrtc.ICECandidateInit, 0)
405-
c.pendingCandidatesToSendMutex.Unlock()
406-
407-
c.pendingCandidatesToAcceptMutex.Lock()
408-
defer c.pendingCandidatesToAcceptMutex.Unlock()
409-
for _, pendingCandidate := range c.pendingCandidatesToAccept {
410-
c.opts.Logger.Debug(context.Background(), "adding buffered remote candidate",
411-
slog.F("hash", c.hashCandidate(pendingCandidate)),
412-
slog.F("length", len(pendingCandidate.Candidate)),
413-
)
414-
err = c.rtc.AddICECandidate(pendingCandidate)
415-
if err != nil {
416-
_ = c.CloseWithError(xerrors.Errorf("accept buffered remote candidate: %w", err))
417-
return
418-
}
419-
}
420-
c.opts.Logger.Debug(context.Background(), "flushed buffered remote candidates",
421-
slog.F("count", len(c.pendingCandidatesToAccept)),
422-
)
423-
c.pendingCandidatesToAccept = make([]webrtc.ICECandidateInit, 0)
403+
c.pendingCandidatesFlushed = true
424404
}
425405

426-
// LocalCandidate returns a channel that emits when a local candidate
427-
// needs to be exchanged with a remote connection.
428-
func (c *Conn) LocalCandidate() <-chan webrtc.ICECandidateInit {
429-
return c.localCandidateChannel
406+
func (c *Conn) LocalNegotiation() <-chan Negotiation {
407+
return c.localNegotiator
430408
}
431409

432-
// AddRemoteCandidate adds a remote candidate to the RTC connection.
433-
func (c *Conn) AddRemoteCandidate(iceCandidate webrtc.ICECandidateInit) error {
434-
c.pendingCandidatesToAcceptMutex.Lock()
435-
defer c.pendingCandidatesToAcceptMutex.Unlock()
436-
fields := []slog.Field{
437-
slog.F("hash", c.hashCandidate(iceCandidate)),
438-
slog.F("length", len(iceCandidate.Candidate)),
439-
}
440-
// The consumer doesn't need to set the session description before
441-
// adding remote candidates. This buffers it so an error doesn't occur.
442-
if c.rtc.RemoteDescription() == nil {
443-
c.opts.Logger.Debug(context.Background(), "buffering remote candidate to accept", fields...)
444-
c.pendingCandidatesToAccept = append(c.pendingCandidatesToAccept, iceCandidate)
445-
return nil
446-
}
447-
c.opts.Logger.Debug(context.Background(), "adding remote candidate", fields...)
448-
return c.rtc.AddICECandidate(iceCandidate)
449-
}
410+
func (c *Conn) AddRemoteNegotiation(negotiation Negotiation) error {
411+
if negotiation.SessionDescription != nil {
412+
c.opts.Logger.Debug(context.Background(), "adding remote negotiation with session description")
413+
select {
414+
case <-c.closed:
415+
return nil
416+
case c.remoteSessionDescription <- *negotiation.SessionDescription:
417+
}
418+
}
419+
420+
if len(negotiation.ICECandidates) > 0 {
421+
c.opts.Logger.Debug(context.Background(), "adding remote negotiation with ice candidates",
422+
slog.F("count", len(negotiation.ICECandidates)))
423+
c.closeMutex.Lock()
424+
defer c.closeMutex.Unlock()
425+
for _, iceCandidate := range negotiation.ICECandidates {
426+
err := c.rtc.AddICECandidate(iceCandidate)
427+
if err != nil {
428+
return err
429+
}
430+
}
431+
}
450432

451-
// LocalSessionDescription returns a channel that emits a session description
452-
// when one is required to be exchanged.
453-
func (c *Conn) LocalSessionDescription() <-chan webrtc.SessionDescription {
454-
return c.localSessionDescriptionChannel
433+
return nil
455434
}
456435

457436
// SetConfiguration applies options to the WebRTC connection.
@@ -460,19 +439,6 @@ func (c *Conn) SetConfiguration(configuration webrtc.Configuration) error {
460439
return c.rtc.SetConfiguration(configuration)
461440
}
462441

463-
// SetRemoteSessionDescription sets the remote description for the WebRTC connection.
464-
func (c *Conn) SetRemoteSessionDescription(sessionDescription webrtc.SessionDescription) {
465-
if c.isClosed() {
466-
return
467-
}
468-
c.closeMutex.Lock()
469-
defer c.closeMutex.Unlock()
470-
select {
471-
case <-c.closed:
472-
case c.remoteSessionDescriptionChannel <- sessionDescription:
473-
}
474-
}
475-
476442
// Accept blocks waiting for a channel to be opened.
477443
func (c *Conn) Accept(ctx context.Context) (*Channel, error) {
478444
var dataChannel *webrtc.DataChannel

peer/conn_test.go

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ var (
3535
// In CI resources are frequently contended, so increasing this value
3636
// results in less flakes.
3737
if os.Getenv("CI") == "true" {
38-
return 3 * time.Second
38+
return time.Second
3939
}
4040
return 100 * time.Millisecond
4141
}()
@@ -64,7 +64,6 @@ func TestConn(t *testing.T) {
6464
t.Run("Ping", func(t *testing.T) {
6565
t.Parallel()
6666
client, server, _ := createPair(t)
67-
exchange(client, server)
6867
_, err := client.Ping()
6968
require.NoError(t, err)
7069
_, err = server.Ping()
@@ -73,8 +72,7 @@ func TestConn(t *testing.T) {
7372

7473
t.Run("PingNetworkOffline", func(t *testing.T) {
7574
t.Parallel()
76-
client, server, wan := createPair(t)
77-
exchange(client, server)
75+
_, server, wan := createPair(t)
7876
_, err := server.Ping()
7977
require.NoError(t, err)
8078
err = wan.Stop()
@@ -85,8 +83,7 @@ func TestConn(t *testing.T) {
8583

8684
t.Run("PingReconnect", func(t *testing.T) {
8785
t.Parallel()
88-
client, server, wan := createPair(t)
89-
exchange(client, server)
86+
_, server, wan := createPair(t)
9087
_, err := server.Ping()
9188
require.NoError(t, err)
9289
// Create a channel that closes on disconnect.
@@ -107,7 +104,6 @@ func TestConn(t *testing.T) {
107104
t.Run("Accept", func(t *testing.T) {
108105
t.Parallel()
109106
client, server, _ := createPair(t)
110-
exchange(client, server)
111107
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
112108
require.NoError(t, err)
113109

@@ -123,7 +119,6 @@ func TestConn(t *testing.T) {
123119
t.Run("AcceptNetworkOffline", func(t *testing.T) {
124120
t.Parallel()
125121
client, server, wan := createPair(t)
126-
exchange(client, server)
127122
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
128123
require.NoError(t, err)
129124
sch, err := server.Accept(context.Background())
@@ -140,21 +135,22 @@ func TestConn(t *testing.T) {
140135
t.Run("Buffering", func(t *testing.T) {
141136
t.Parallel()
142137
client, server, _ := createPair(t)
143-
exchange(client, server)
144138
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
145139
require.NoError(t, err)
146140
sch, err := server.Accept(context.Background())
147141
require.NoError(t, err)
148142
defer sch.Close()
149143
go func() {
144+
bytes := make([]byte, 4096)
150145
for i := 0; i < 1024; i++ {
151-
_, err := cch.Write(make([]byte, 4096))
146+
_, err := cch.Write(bytes)
152147
require.NoError(t, err)
153148
}
154149
_ = cch.Close()
155150
}()
151+
bytes := make([]byte, 4096)
156152
for {
157-
_, err = sch.Read(make([]byte, 4096))
153+
_, err = sch.Read(bytes)
158154
if err != nil {
159155
require.ErrorIs(t, err, peer.ErrClosed)
160156
break
@@ -165,7 +161,6 @@ func TestConn(t *testing.T) {
165161
t.Run("NetConn", func(t *testing.T) {
166162
t.Parallel()
167163
client, server, _ := createPair(t)
168-
exchange(client, server)
169164
srv, err := net.Listen("tcp", "127.0.0.1:0")
170165
require.NoError(t, err)
171166
defer srv.Close()
@@ -218,7 +213,6 @@ func TestConn(t *testing.T) {
218213
t.Run("CloseBeforeNegotiate", func(t *testing.T) {
219214
t.Parallel()
220215
client, server, _ := createPair(t)
221-
exchange(client, server)
222216
err := client.Close()
223217
require.NoError(t, err)
224218
err = server.Close()
@@ -238,7 +232,6 @@ func TestConn(t *testing.T) {
238232
t.Run("PingConcurrent", func(t *testing.T) {
239233
t.Parallel()
240234
client, server, _ := createPair(t)
241-
exchange(client, server)
242235
var wg sync.WaitGroup
243236
wg.Add(2)
244237
go func() {
@@ -253,19 +246,6 @@ func TestConn(t *testing.T) {
253246
}()
254247
wg.Wait()
255248
})
256-
257-
t.Run("NegotiateOutOfOrder", func(t *testing.T) {
258-
t.Parallel()
259-
client, server, _ := createPair(t)
260-
server.SetRemoteSessionDescription(<-client.LocalSessionDescription())
261-
err := client.AddRemoteCandidate(<-server.LocalCandidate())
262-
require.NoError(t, err)
263-
client.SetRemoteSessionDescription(<-server.LocalSessionDescription())
264-
err = server.AddRemoteCandidate(<-client.LocalCandidate())
265-
require.NoError(t, err)
266-
_, err = client.Ping()
267-
require.NoError(t, err)
268-
})
269249
}
270250

271251
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {
@@ -324,18 +304,12 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R
324304
_ = wan.Stop()
325305
})
326306

327-
return channel1, channel2, wan
328-
}
329-
330-
func exchange(client *peer.Conn, server *peer.Conn) {
331307
go func() {
332308
for {
333309
select {
334-
case c := <-server.LocalCandidate():
335-
_ = client.AddRemoteCandidate(c)
336-
case c := <-server.LocalSessionDescription():
337-
client.SetRemoteSessionDescription(c)
338-
case <-server.Closed():
310+
case c := <-channel2.LocalNegotiation():
311+
_ = channel1.AddRemoteNegotiation(c)
312+
case <-channel2.Closed():
339313
return
340314
}
341315
}
@@ -344,13 +318,13 @@ func exchange(client *peer.Conn, server *peer.Conn) {
344318
go func() {
345319
for {
346320
select {
347-
case c := <-client.LocalCandidate():
348-
_ = server.AddRemoteCandidate(c)
349-
case c := <-client.LocalSessionDescription():
350-
server.SetRemoteSessionDescription(c)
351-
case <-client.Closed():
321+
case c := <-channel1.LocalNegotiation():
322+
_ = channel2.AddRemoteNegotiation(c)
323+
case <-channel1.Closed():
352324
return
353325
}
354326
}
355327
}()
328+
329+
return channel1, channel2, wan
356330
}

0 commit comments

Comments
 (0)