diff --git a/wsnet/dial.go b/wsnet/dial.go index eaab8938..637bc5fd 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "sync" "time" "github.com/pion/datachannel" @@ -81,9 +82,10 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) { flushCandidates() dialer := &Dialer{ - conn: conn, - ctrl: ctrl, - rtc: rtc, + conn: conn, + ctrl: ctrl, + rtc: rtc, + connClosers: make([]io.Closer, 0), } return dialer, dialer.negotiate() @@ -97,6 +99,9 @@ type Dialer struct { ctrl *webrtc.DataChannel ctrlrw datachannel.ReadWriteCloser rtc *webrtc.PeerConnection + + connClosers []io.Closer + connClosersMut sync.Mutex } func (d *Dialer) negotiate() (err error) { @@ -111,16 +116,27 @@ func (d *Dialer) negotiate() (err error) { go func() { defer close(errCh) + defer func() { + _ = d.conn.Close() + }() err := waitForConnectionOpen(context.Background(), d.rtc) if err != nil { - _ = d.conn.Close() errCh <- err return } - go func() { - // Closing this connection took 30ms+. - _ = d.conn.Close() - }() + d.rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { + if pcs == webrtc.PeerConnectionStateConnected { + return + } + + // Close connections opened while the RTC was alive. + d.connClosersMut.Lock() + defer d.connClosersMut.Unlock() + for _, connCloser := range d.connClosers { + _ = connCloser.Close() + } + d.connClosers = make([]io.Closer, 0) + }) }() for { @@ -210,6 +226,10 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. if err != nil { return nil, fmt.Errorf("create data channel: %w", err) } + d.connClosersMut.Lock() + d.connClosers = append(d.connClosers, dc) + d.connClosersMut.Unlock() + err = waitForDataChannelOpen(ctx, dc) if err != nil { return nil, fmt.Errorf("wait for open: %w", err) diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 3e52f17b..71fdc8c7 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -5,11 +5,13 @@ import ( "context" "crypto/rand" "errors" + "fmt" "io" "net" "strconv" "testing" + "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" ) @@ -44,6 +46,7 @@ func ExampleDial_basic() { // You now have access to the proxied remote port in `conn`. } +// nolint:gocognit func TestDial(t *testing.T) { t.Run("Ping", func(t *testing.T) { connectAddr, listenAddr := createDumbBroker(t) @@ -184,6 +187,48 @@ func TestDial(t *testing.T) { t.Error(err) } }) + + t.Run("Disconnect DialContext", func(t *testing.T) { + tcpListener, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + t.Error(err) + return + } + go func() { + _, _ = tcpListener.Accept() + }() + + connectAddr, listenAddr := createDumbBroker(t) + _, err = Listen(context.Background(), listenAddr) + if err != nil { + t.Error(err) + return + } + turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) + dialer, err := DialWebsocket(context.Background(), connectAddr, []webrtc.ICEServer{{ + URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, + Username: "example", + Credential: testPass, + CredentialType: webrtc.ICECredentialTypePassword, + }}) + if err != nil { + t.Error(err) + return + } + conn, err := dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String()) + if err != nil { + t.Error(err) + return + } + // Close the TURN server before reading... + // WebRTC connections take a few seconds to timeout. + closeTurn() + _, err = conn.Read(make([]byte, 16)) + if err != io.EOF { + t.Error(err) + return + } + }) } func BenchmarkThroughput(b *testing.B) { diff --git a/wsnet/rtc.go b/wsnet/rtc.go index 0605fa29..4d454311 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -164,6 +164,8 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro lf.DefaultLogLevel = logging.LogLevelDisabled se.LoggerFactory = lf + transportPolicy := webrtc.ICETransportPolicyAll + // If one server is provided and we know it's TURN, we can set the // relay acceptable so the connection starts immediately. if len(servers) == 1 { @@ -174,12 +176,18 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6}) se.SetRelayAcceptanceMinWait(0) } + if err == nil && (url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeTURNS) { + // Local peers will connect if they discover they live on the same host. + // For testing purposes, it's simpler if they cannot peer on the same host. + transportPolicy = webrtc.ICETransportPolicyRelay + } } } api := webrtc.NewAPI(webrtc.WithSettingEngine(se)) return api.NewPeerConnection(webrtc.Configuration{ - ICEServers: servers, + ICEServers: servers, + ICETransportPolicy: transportPolicy, }) } diff --git a/wsnet/rtc_test.go b/wsnet/rtc_test.go index 14bdd846..73d1af2f 100644 --- a/wsnet/rtc_test.go +++ b/wsnet/rtc_test.go @@ -16,11 +16,11 @@ func TestDialICE(t *testing.T) { t.Run("TURN with TLS", func(t *testing.T) { t.Parallel() - addr := createTURNServer(t, ice.SchemeTypeTURNS, "test") + addr, _ := createTURNServer(t, ice.SchemeTypeTURNS) err := DialICE(webrtc.ICEServer{ URLs: []string{fmt.Sprintf("turns:%s", addr)}, Username: "example", - Credential: "test", + Credential: testPass, CredentialType: webrtc.ICECredentialTypePassword, }, &DialICEOptions{ Timeout: time.Millisecond, @@ -34,11 +34,11 @@ func TestDialICE(t *testing.T) { t.Run("Protocol mismatch", func(t *testing.T) { t.Parallel() - addr := createTURNServer(t, ice.SchemeTypeTURNS, "test") + addr, _ := createTURNServer(t, ice.SchemeTypeTURNS) err := DialICE(webrtc.ICEServer{ URLs: []string{fmt.Sprintf("turn:%s", addr)}, Username: "example", - Credential: "test", + Credential: testPass, CredentialType: webrtc.ICECredentialTypePassword, }, &DialICEOptions{ Timeout: time.Millisecond, @@ -52,7 +52,7 @@ func TestDialICE(t *testing.T) { t.Run("Invalid auth", func(t *testing.T) { t.Parallel() - addr := createTURNServer(t, ice.SchemeTypeTURNS, "test") + addr, _ := createTURNServer(t, ice.SchemeTypeTURNS) err := DialICE(webrtc.ICEServer{ URLs: []string{fmt.Sprintf("turns:%s", addr)}, Username: "example", diff --git a/wsnet/wsnet_test.go b/wsnet/wsnet_test.go index fc14cd3c..ad9ac381 100644 --- a/wsnet/wsnet_test.go +++ b/wsnet/wsnet_test.go @@ -25,6 +25,11 @@ import ( "nhooyr.io/websocket" ) +const ( + // Password used connecting to the test TURN server. + testPass = "test" +) + // createDumbBroker proxies sockets between /listen and /connect // to emulate an authenticated WebSocket pair. func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) { @@ -86,7 +91,7 @@ func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) { } // createTURNServer allocates a TURN server and returns the address. -func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string { +func createTURNServer(t *testing.T, server ice.SchemeType) (string, func()) { var ( listeners []turn.ListenerConfig pcListeners []turn.PacketConnConfig @@ -136,14 +141,14 @@ func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string { ListenerConfigs: listeners, Realm: "coder", AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - return turn.GenerateAuthKey(username, realm, pass), true + return turn.GenerateAuthKey(username, realm, testPass), true }, LoggerFactory: lf, }) if err != nil { t.Error(err) } - t.Cleanup(func() { + closeFunc := func() { for _, l := range listeners { l.Listener.Close() } @@ -151,9 +156,10 @@ func createTURNServer(t *testing.T, server ice.SchemeType, pass string) string { l.PacketConn.Close() } srv.Close() - }) + } + t.Cleanup(closeFunc) - return listenAddr.String() + return listenAddr.String(), closeFunc } func generateTLSConfig(t testing.TB) *tls.Config {