Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 0d2f06b

Browse files
authored
fix: Remove active connections when RTC connection is lost (#379)
* fix: Remove active connections when RTC connection is lost * Move close injection * Fix linting * Fix linting
1 parent 8be51b6 commit 0d2f06b

File tree

5 files changed

+98
-19
lines changed

5 files changed

+98
-19
lines changed

wsnet/dial.go

+28-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"net"
10+
"sync"
1011
"time"
1112

1213
"github.com/pion/datachannel"
@@ -81,9 +82,10 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) {
8182
flushCandidates()
8283

8384
dialer := &Dialer{
84-
conn: conn,
85-
ctrl: ctrl,
86-
rtc: rtc,
85+
conn: conn,
86+
ctrl: ctrl,
87+
rtc: rtc,
88+
connClosers: make([]io.Closer, 0),
8789
}
8890

8991
return dialer, dialer.negotiate()
@@ -97,6 +99,9 @@ type Dialer struct {
9799
ctrl *webrtc.DataChannel
98100
ctrlrw datachannel.ReadWriteCloser
99101
rtc *webrtc.PeerConnection
102+
103+
connClosers []io.Closer
104+
connClosersMut sync.Mutex
100105
}
101106

102107
func (d *Dialer) negotiate() (err error) {
@@ -111,16 +116,27 @@ func (d *Dialer) negotiate() (err error) {
111116

112117
go func() {
113118
defer close(errCh)
119+
defer func() {
120+
_ = d.conn.Close()
121+
}()
114122
err := waitForConnectionOpen(context.Background(), d.rtc)
115123
if err != nil {
116-
_ = d.conn.Close()
117124
errCh <- err
118125
return
119126
}
120-
go func() {
121-
// Closing this connection took 30ms+.
122-
_ = d.conn.Close()
123-
}()
127+
d.rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
128+
if pcs == webrtc.PeerConnectionStateConnected {
129+
return
130+
}
131+
132+
// Close connections opened while the RTC was alive.
133+
d.connClosersMut.Lock()
134+
defer d.connClosersMut.Unlock()
135+
for _, connCloser := range d.connClosers {
136+
_ = connCloser.Close()
137+
}
138+
d.connClosers = make([]io.Closer, 0)
139+
})
124140
}()
125141

126142
for {
@@ -210,6 +226,10 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
210226
if err != nil {
211227
return nil, fmt.Errorf("create data channel: %w", err)
212228
}
229+
d.connClosersMut.Lock()
230+
d.connClosers = append(d.connClosers, dc)
231+
d.connClosersMut.Unlock()
232+
213233
err = waitForDataChannelOpen(ctx, dc)
214234
if err != nil {
215235
return nil, fmt.Errorf("wait for open: %w", err)

wsnet/dial_test.go

+45
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import (
55
"context"
66
"crypto/rand"
77
"errors"
8+
"fmt"
89
"io"
910
"net"
1011
"strconv"
1112
"testing"
1213

14+
"github.com/pion/ice/v2"
1315
"github.com/pion/webrtc/v3"
1416
)
1517

@@ -44,6 +46,7 @@ func ExampleDial_basic() {
4446
// You now have access to the proxied remote port in `conn`.
4547
}
4648

49+
// nolint:gocognit
4750
func TestDial(t *testing.T) {
4851
t.Run("Ping", func(t *testing.T) {
4952
connectAddr, listenAddr := createDumbBroker(t)
@@ -184,6 +187,48 @@ func TestDial(t *testing.T) {
184187
t.Error(err)
185188
}
186189
})
190+
191+
t.Run("Disconnect DialContext", func(t *testing.T) {
192+
tcpListener, err := net.Listen("tcp", "0.0.0.0:0")
193+
if err != nil {
194+
t.Error(err)
195+
return
196+
}
197+
go func() {
198+
_, _ = tcpListener.Accept()
199+
}()
200+
201+
connectAddr, listenAddr := createDumbBroker(t)
202+
_, err = Listen(context.Background(), listenAddr)
203+
if err != nil {
204+
t.Error(err)
205+
return
206+
}
207+
turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN)
208+
dialer, err := DialWebsocket(context.Background(), connectAddr, []webrtc.ICEServer{{
209+
URLs: []string{fmt.Sprintf("turn:%s", turnAddr)},
210+
Username: "example",
211+
Credential: testPass,
212+
CredentialType: webrtc.ICECredentialTypePassword,
213+
}})
214+
if err != nil {
215+
t.Error(err)
216+
return
217+
}
218+
conn, err := dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String())
219+
if err != nil {
220+
t.Error(err)
221+
return
222+
}
223+
// Close the TURN server before reading...
224+
// WebRTC connections take a few seconds to timeout.
225+
closeTurn()
226+
_, err = conn.Read(make([]byte, 16))
227+
if err != io.EOF {
228+
t.Error(err)
229+
return
230+
}
231+
})
187232
}
188233

189234
func BenchmarkThroughput(b *testing.B) {

wsnet/rtc.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro
164164
lf.DefaultLogLevel = logging.LogLevelDisabled
165165
se.LoggerFactory = lf
166166

167+
transportPolicy := webrtc.ICETransportPolicyAll
168+
167169
// If one server is provided and we know it's TURN, we can set the
168170
// relay acceptable so the connection starts immediately.
169171
if len(servers) == 1 {
@@ -174,12 +176,18 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro
174176
se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6})
175177
se.SetRelayAcceptanceMinWait(0)
176178
}
179+
if err == nil && (url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeTURNS) {
180+
// Local peers will connect if they discover they live on the same host.
181+
// For testing purposes, it's simpler if they cannot peer on the same host.
182+
transportPolicy = webrtc.ICETransportPolicyRelay
183+
}
177184
}
178185
}
179186
api := webrtc.NewAPI(webrtc.WithSettingEngine(se))
180187

181188
return api.NewPeerConnection(webrtc.Configuration{
182-
ICEServers: servers,
189+
ICEServers: servers,
190+
ICETransportPolicy: transportPolicy,
183191
})
184192
}
185193

wsnet/rtc_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ func TestDialICE(t *testing.T) {
1616
t.Run("TURN with TLS", func(t *testing.T) {
1717
t.Parallel()
1818

19-
addr := createTURNServer(t, ice.SchemeTypeTURNS, "test")
19+
addr, _ := createTURNServer(t, ice.SchemeTypeTURNS)
2020
err := DialICE(webrtc.ICEServer{
2121
URLs: []string{fmt.Sprintf("turns:%s", addr)},
2222
Username: "example",
23-
Credential: "test",
23+
Credential: testPass,
2424
CredentialType: webrtc.ICECredentialTypePassword,
2525
}, &DialICEOptions{
2626
Timeout: time.Millisecond,
@@ -34,11 +34,11 @@ func TestDialICE(t *testing.T) {
3434
t.Run("Protocol mismatch", func(t *testing.T) {
3535
t.Parallel()
3636

37-
addr := createTURNServer(t, ice.SchemeTypeTURNS, "test")
37+
addr, _ := createTURNServer(t, ice.SchemeTypeTURNS)
3838
err := DialICE(webrtc.ICEServer{
3939
URLs: []string{fmt.Sprintf("turn:%s", addr)},
4040
Username: "example",
41-
Credential: "test",
41+
Credential: testPass,
4242
CredentialType: webrtc.ICECredentialTypePassword,
4343
}, &DialICEOptions{
4444
Timeout: time.Millisecond,
@@ -52,7 +52,7 @@ func TestDialICE(t *testing.T) {
5252
t.Run("Invalid auth", func(t *testing.T) {
5353
t.Parallel()
5454

55-
addr := createTURNServer(t, ice.SchemeTypeTURNS, "test")
55+
addr, _ := createTURNServer(t, ice.SchemeTypeTURNS)
5656
err := DialICE(webrtc.ICEServer{
5757
URLs: []string{fmt.Sprintf("turns:%s", addr)},
5858
Username: "example",

wsnet/wsnet_test.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ import (
2525
"nhooyr.io/websocket"
2626
)
2727

28+
const (
29+
// Password used connecting to the test TURN server.
30+
testPass = "test"
31+
)
32+
2833
// createDumbBroker proxies sockets between /listen and /connect
2934
// to emulate an authenticated WebSocket pair.
3035
func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) {
@@ -86,7 +91,7 @@ func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) {
8691
}
8792

8893
// createTURNServer allocates a TURN server and returns the address.
89-
func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string {
94+
func createTURNServer(t *testing.T, server ice.SchemeType) (string, func()) {
9095
var (
9196
listeners []turn.ListenerConfig
9297
pcListeners []turn.PacketConnConfig
@@ -136,24 +141,25 @@ func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string {
136141
ListenerConfigs: listeners,
137142
Realm: "coder",
138143
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
139-
return turn.GenerateAuthKey(username, realm, pass), true
144+
return turn.GenerateAuthKey(username, realm, testPass), true
140145
},
141146
LoggerFactory: lf,
142147
})
143148
if err != nil {
144149
t.Error(err)
145150
}
146-
t.Cleanup(func() {
151+
closeFunc := func() {
147152
for _, l := range listeners {
148153
l.Listener.Close()
149154
}
150155
for _, l := range pcListeners {
151156
l.PacketConn.Close()
152157
}
153158
srv.Close()
154-
})
159+
}
160+
t.Cleanup(closeFunc)
155161

156-
return listenAddr.String()
162+
return listenAddr.String(), closeFunc
157163
}
158164

159165
func generateTLSConfig(t testing.TB) *tls.Config {

0 commit comments

Comments
 (0)