From 8d5db5d9adc6c63db5679f45e69e655666019219 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 7 Jul 2021 21:54:25 +0000 Subject: [PATCH 1/4] fix: Remove active connections when RTC connection is lost --- wsnet/dial.go | 28 +++++++++++++++++++++++++--- wsnet/dial_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ wsnet/rtc.go | 10 +++++++++- wsnet/rtc_test.go | 6 +++--- wsnet/wsnet_test.go | 9 +++++---- 5 files changed, 84 insertions(+), 11 deletions(-) diff --git a/wsnet/dial.go b/wsnet/dial.go index eaab8938..fe327efe 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "sync" "time" "github.com/pion/datachannel" @@ -81,9 +82,10 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) { flushCandidates() dialer := &Dialer{ - conn: conn, - ctrl: ctrl, - rtc: rtc, + conn: conn, + ctrl: ctrl, + rtc: rtc, + connClosers: make([]io.Closer, 0), } return dialer, dialer.negotiate() @@ -97,6 +99,9 @@ type Dialer struct { ctrl *webrtc.DataChannel ctrlrw datachannel.ReadWriteCloser rtc *webrtc.PeerConnection + + connClosers []io.Closer + connClosersMut sync.Mutex } func (d *Dialer) negotiate() (err error) { @@ -117,6 +122,19 @@ func (d *Dialer) negotiate() (err error) { errCh <- err return } + d.rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { + if pcs == webrtc.PeerConnectionStateConnected { + return + } + + // Close connections opened while the RTC was alive. + d.connClosersMut.Lock() + defer d.connClosersMut.Unlock() + for _, connCloser := range d.connClosers { + _ = connCloser.Close() + } + d.connClosers = make([]io.Closer, 0) + }) go func() { // Closing this connection took 30ms+. _ = d.conn.Close() @@ -252,6 +270,10 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. return nil, ctx.Err() } + d.connClosersMut.Lock() + defer d.connClosersMut.Unlock() + d.connClosers = append(d.connClosers, rw) + c := &conn{ addr: &net.UnixAddr{ Name: address, diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 3e52f17b..d888f811 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -5,11 +5,13 @@ import ( "context" "crypto/rand" "errors" + "fmt" "io" "net" "strconv" "testing" + "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" ) @@ -184,6 +186,46 @@ func TestDial(t *testing.T) { t.Error(err) } }) + + t.Run("Disconnect DialContext", func(t *testing.T) { + tcpListener, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + t.Error(err) + return + } + go tcpListener.Accept() + + connectAddr, listenAddr := createDumbBroker(t) + _, err = Listen(context.Background(), listenAddr) + if err != nil { + t.Error(err) + return + } + turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN, "test") + dialer, err := DialWebsocket(context.Background(), connectAddr, []webrtc.ICEServer{{ + URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, + Username: "example", + Credential: "test", + CredentialType: webrtc.ICECredentialTypePassword, + }}) + if err != nil { + t.Error(err) + return + } + conn, err := dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String()) + if err != nil { + t.Error(err) + return + } + // Close the TURN server before reading... + // WebRTC connections take a few seconds to timeout. + closeTurn() + _, err = conn.Read(make([]byte, 16)) + if err != io.EOF { + t.Error(err) + return + } + }) } func BenchmarkThroughput(b *testing.B) { diff --git a/wsnet/rtc.go b/wsnet/rtc.go index 0605fa29..4d454311 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -164,6 +164,8 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro lf.DefaultLogLevel = logging.LogLevelDisabled se.LoggerFactory = lf + transportPolicy := webrtc.ICETransportPolicyAll + // If one server is provided and we know it's TURN, we can set the // relay acceptable so the connection starts immediately. if len(servers) == 1 { @@ -174,12 +176,18 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6}) se.SetRelayAcceptanceMinWait(0) } + if err == nil && (url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeTURNS) { + // Local peers will connect if they discover they live on the same host. + // For testing purposes, it's simpler if they cannot peer on the same host. + transportPolicy = webrtc.ICETransportPolicyRelay + } } } api := webrtc.NewAPI(webrtc.WithSettingEngine(se)) return api.NewPeerConnection(webrtc.Configuration{ - ICEServers: servers, + ICEServers: servers, + ICETransportPolicy: transportPolicy, }) } diff --git a/wsnet/rtc_test.go b/wsnet/rtc_test.go index 14bdd846..52fd3a3a 100644 --- a/wsnet/rtc_test.go +++ b/wsnet/rtc_test.go @@ -16,7 +16,7 @@ func TestDialICE(t *testing.T) { t.Run("TURN with TLS", func(t *testing.T) { t.Parallel() - addr := createTURNServer(t, ice.SchemeTypeTURNS, "test") + addr, _ := createTURNServer(t, ice.SchemeTypeTURNS, "test") err := DialICE(webrtc.ICEServer{ URLs: []string{fmt.Sprintf("turns:%s", addr)}, Username: "example", @@ -34,7 +34,7 @@ func TestDialICE(t *testing.T) { t.Run("Protocol mismatch", func(t *testing.T) { t.Parallel() - addr := createTURNServer(t, ice.SchemeTypeTURNS, "test") + addr, _ := createTURNServer(t, ice.SchemeTypeTURNS, "test") err := DialICE(webrtc.ICEServer{ URLs: []string{fmt.Sprintf("turn:%s", addr)}, Username: "example", @@ -52,7 +52,7 @@ func TestDialICE(t *testing.T) { t.Run("Invalid auth", func(t *testing.T) { t.Parallel() - addr := createTURNServer(t, ice.SchemeTypeTURNS, "test") + addr, _ := createTURNServer(t, ice.SchemeTypeTURNS, "test") err := DialICE(webrtc.ICEServer{ URLs: []string{fmt.Sprintf("turns:%s", addr)}, Username: "example", diff --git a/wsnet/wsnet_test.go b/wsnet/wsnet_test.go index fc14cd3c..122b91b1 100644 --- a/wsnet/wsnet_test.go +++ b/wsnet/wsnet_test.go @@ -86,7 +86,7 @@ func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) { } // createTURNServer allocates a TURN server and returns the address. -func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string { +func createTURNServer(t *testing.T, server ice.SchemeType, pass string) (string, func()) { var ( listeners []turn.ListenerConfig pcListeners []turn.PacketConnConfig @@ -143,7 +143,7 @@ func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string { if err != nil { t.Error(err) } - t.Cleanup(func() { + closeFunc := func() { for _, l := range listeners { l.Listener.Close() } @@ -151,9 +151,10 @@ func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string { l.PacketConn.Close() } srv.Close() - }) + } + t.Cleanup(closeFunc) - return listenAddr.String() + return listenAddr.String(), closeFunc } func generateTLSConfig(t testing.TB) *tls.Config { From e9d49710150794c56d84bd1768d4b6f300dc5d0c Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 8 Jul 2021 13:56:48 +0000 Subject: [PATCH 2/4] Move close injection --- wsnet/dial.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/wsnet/dial.go b/wsnet/dial.go index fe327efe..637bc5fd 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -116,9 +116,11 @@ func (d *Dialer) negotiate() (err error) { go func() { defer close(errCh) + defer func() { + _ = d.conn.Close() + }() err := waitForConnectionOpen(context.Background(), d.rtc) if err != nil { - _ = d.conn.Close() errCh <- err return } @@ -135,10 +137,6 @@ func (d *Dialer) negotiate() (err error) { } d.connClosers = make([]io.Closer, 0) }) - go func() { - // Closing this connection took 30ms+. - _ = d.conn.Close() - }() }() for { @@ -228,6 +226,10 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. if err != nil { return nil, fmt.Errorf("create data channel: %w", err) } + d.connClosersMut.Lock() + d.connClosers = append(d.connClosers, dc) + d.connClosersMut.Unlock() + err = waitForDataChannelOpen(ctx, dc) if err != nil { return nil, fmt.Errorf("wait for open: %w", err) @@ -270,10 +272,6 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. return nil, ctx.Err() } - d.connClosersMut.Lock() - defer d.connClosersMut.Unlock() - d.connClosers = append(d.connClosers, rw) - c := &conn{ addr: &net.UnixAddr{ Name: address, From 48d6dcc636c9b79a6b6caa3598231de86c04e38e Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 8 Jul 2021 14:00:43 +0000 Subject: [PATCH 3/4] Fix linting --- wsnet/dial_test.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index d888f811..8af31060 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -46,6 +46,7 @@ func ExampleDial_basic() { // You now have access to the proxied remote port in `conn`. } +// nolint:gocognit,nestif func TestDial(t *testing.T) { t.Run("Ping", func(t *testing.T) { connectAddr, listenAddr := createDumbBroker(t) @@ -193,7 +194,9 @@ func TestDial(t *testing.T) { t.Error(err) return } - go tcpListener.Accept() + go func() { + _, _ = tcpListener.Accept() + }() connectAddr, listenAddr := createDumbBroker(t) _, err = Listen(context.Background(), listenAddr) @@ -201,11 +204,12 @@ func TestDial(t *testing.T) { t.Error(err) return } - turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN, "test") + pass := "test" + turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN, pass) dialer, err := DialWebsocket(context.Background(), connectAddr, []webrtc.ICEServer{{ URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, Username: "example", - Credential: "test", + Credential: pass, CredentialType: webrtc.ICECredentialTypePassword, }}) if err != nil { From 9771524ae70066f5b8e71fd34395e0454622a7bf Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 8 Jul 2021 14:03:49 +0000 Subject: [PATCH 4/4] Fix linting --- wsnet/dial_test.go | 7 +++---- wsnet/rtc_test.go | 10 +++++----- wsnet/wsnet_test.go | 9 +++++++-- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 8af31060..71fdc8c7 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -46,7 +46,7 @@ func ExampleDial_basic() { // You now have access to the proxied remote port in `conn`. } -// nolint:gocognit,nestif +// nolint:gocognit func TestDial(t *testing.T) { t.Run("Ping", func(t *testing.T) { connectAddr, listenAddr := createDumbBroker(t) @@ -204,12 +204,11 @@ func TestDial(t *testing.T) { t.Error(err) return } - pass := "test" - turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN, pass) + turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) dialer, err := DialWebsocket(context.Background(), connectAddr, []webrtc.ICEServer{{ URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, Username: "example", - Credential: pass, + Credential: testPass, CredentialType: webrtc.ICECredentialTypePassword, }}) if err != nil { diff --git a/wsnet/rtc_test.go b/wsnet/rtc_test.go index 52fd3a3a..73d1af2f 100644 --- a/wsnet/rtc_test.go +++ b/wsnet/rtc_test.go @@ -16,11 +16,11 @@ func TestDialICE(t *testing.T) { t.Run("TURN with TLS", func(t *testing.T) { t.Parallel() - addr, _ := createTURNServer(t, ice.SchemeTypeTURNS, "test") + addr, _ := createTURNServer(t, ice.SchemeTypeTURNS) err := DialICE(webrtc.ICEServer{ URLs: []string{fmt.Sprintf("turns:%s", addr)}, Username: "example", - Credential: "test", + Credential: testPass, CredentialType: webrtc.ICECredentialTypePassword, }, &DialICEOptions{ Timeout: time.Millisecond, @@ -34,11 +34,11 @@ func TestDialICE(t *testing.T) { t.Run("Protocol mismatch", func(t *testing.T) { t.Parallel() - addr, _ := createTURNServer(t, ice.SchemeTypeTURNS, "test") + addr, _ := createTURNServer(t, ice.SchemeTypeTURNS) err := DialICE(webrtc.ICEServer{ URLs: []string{fmt.Sprintf("turn:%s", addr)}, Username: "example", - Credential: "test", + Credential: testPass, CredentialType: webrtc.ICECredentialTypePassword, }, &DialICEOptions{ Timeout: time.Millisecond, @@ -52,7 +52,7 @@ func TestDialICE(t *testing.T) { t.Run("Invalid auth", func(t *testing.T) { t.Parallel() - addr, _ := createTURNServer(t, ice.SchemeTypeTURNS, "test") + addr, _ := createTURNServer(t, ice.SchemeTypeTURNS) err := DialICE(webrtc.ICEServer{ URLs: []string{fmt.Sprintf("turns:%s", addr)}, Username: "example", diff --git a/wsnet/wsnet_test.go b/wsnet/wsnet_test.go index 122b91b1..ad9ac381 100644 --- a/wsnet/wsnet_test.go +++ b/wsnet/wsnet_test.go @@ -25,6 +25,11 @@ import ( "nhooyr.io/websocket" ) +const ( + // Password used connecting to the test TURN server. + testPass = "test" +) + // createDumbBroker proxies sockets between /listen and /connect // to emulate an authenticated WebSocket pair. func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) { @@ -86,7 +91,7 @@ func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) { } // createTURNServer allocates a TURN server and returns the address. -func createTURNServer(t *testing.T, server ice.SchemeType, pass string) (string, func()) { +func createTURNServer(t *testing.T, server ice.SchemeType) (string, func()) { var ( listeners []turn.ListenerConfig pcListeners []turn.PacketConnConfig @@ -136,7 +141,7 @@ func createTURNServer(t *testing.T, server ice.SchemeType, pass string) (string, ListenerConfigs: listeners, Realm: "coder", AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - return turn.GenerateAuthKey(username, realm, pass), true + return turn.GenerateAuthKey(username, realm, testPass), true }, LoggerFactory: lf, })