diff --git a/wsnet/cache.go b/wsnet/cache.go index e62aa0a9..b16950ca 100644 --- a/wsnet/cache.go +++ b/wsnet/cache.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/pion/webrtc/v3" "golang.org/x/sync/singleflight" ) @@ -39,7 +40,7 @@ type DialerCache struct { // init starts the ticker for evicting connections. func (d *DialerCache) init() { - ticker := time.NewTicker(time.Second * 30) + ticker := time.NewTicker(time.Second * 5) defer ticker.Stop() for { select { @@ -62,17 +63,11 @@ func (d *DialerCache) evict() { go func() { defer wg.Done() - evict := false - select { - case <-dialer.Closed(): + // If we're no longer signaling, the connection is pending close. + evict := dialer.rtc.SignalingState() == webrtc.SignalingStateClosed + if dialer.activeConnections() == 0 && time.Since(d.atime[key]) >= d.ttl { evict = true - default: - } - if dialer.ActiveConnections() == 0 && time.Since(d.atime[key]) >= d.ttl { - evict = true - } - // If we're already evicting there's no point in trying to ping. - if !evict { + } else { ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() err := dialer.Ping(ctx) @@ -116,17 +111,12 @@ func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (* dialer, ok := d.dialers[key] d.mut.RUnlock() if ok { - closed := false - select { - case <-dialer.Closed(): - closed = true - default: - } - if !closed { - d.mut.Lock() - d.atime[key] = time.Now() - d.mut.Unlock() + d.mut.Lock() + d.atime[key] = time.Now() + d.mut.Unlock() + // The connection is pending close here... + if dialer.rtc.SignalingState() != webrtc.SignalingStateClosed { return dialer, true, nil } } diff --git a/wsnet/conn.go b/wsnet/conn.go index 5b863f04..de67c3c4 100644 --- a/wsnet/conn.go +++ b/wsnet/conn.go @@ -2,6 +2,7 @@ package wsnet import ( "context" + "errors" "fmt" "net" "net/http" @@ -73,9 +74,13 @@ func (t *turnProxyDialer) Dial(network, addr string) (c net.Conn, err error) { // Copy the baseURL so we can adjust path. url := *t.baseURL - url.Scheme = "wss" - if url.Scheme == httpScheme { + switch url.Scheme { + case "http": url.Scheme = "ws" + case "https": + url.Scheme = "wss" + default: + return nil, errors.New("invalid turn url addr scheme provided") } url.Path = "/api/private/turn" conn, resp, err := websocket.Dial(ctx, url.String(), &websocket.DialOptions{ diff --git a/wsnet/dial.go b/wsnet/dial.go index 283bedf4..394e6a9f 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -118,7 +118,6 @@ func Dial(conn net.Conn, options *DialOptions) (*Dialer, error) { conn: conn, ctrl: ctrl, rtc: rtc, - closedChan: make(chan struct{}), connClosers: []io.Closer{ctrl}, } @@ -134,7 +133,6 @@ type Dialer struct { ctrlrw datachannel.ReadWriteCloser rtc *webrtc.PeerConnection - closedChan chan struct{} connClosers []io.Closer connClosersMut sync.Mutex pingMut sync.Mutex @@ -161,25 +159,17 @@ func (d *Dialer) negotiate() (err error) { return } d.rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { - if pcs != webrtc.PeerConnectionStateDisconnected { + if pcs == webrtc.PeerConnectionStateConnected { return } - // Close connections opened while the RTC was alive. + // Close connections opened when RTC was alive. d.connClosersMut.Lock() defer d.connClosersMut.Unlock() for _, connCloser := range d.connClosers { _ = connCloser.Close() } d.connClosers = make([]io.Closer, 0) - - select { - case <-d.closedChan: - return - default: - } - close(d.closedChan) - _ = d.rtc.Close() }) }() @@ -228,15 +218,9 @@ func (d *Dialer) negotiate() (err error) { return <-errCh } -// Closed returns a channel that closes when -// the connection is closed. -func (d *Dialer) Closed() <-chan struct{} { - return d.closedChan -} - // ActiveConnections returns the amount of active connections. // DialContext opens a connection, and close will end it. -func (d *Dialer) ActiveConnections() int { +func (d *Dialer) activeConnections() int { stats, ok := d.rtc.GetStats().GetConnectionStats(d.rtc) if !ok { return -1 @@ -248,12 +232,6 @@ func (d *Dialer) ActiveConnections() int { // Close closes the RTC connection. // All data channels dialed will be closed. func (d *Dialer) Close() error { - select { - case <-d.closedChan: - return nil - default: - } - close(d.closedChan) return d.rtc.Close() } diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 8a6486ba..5dd11b58 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -9,7 +9,6 @@ import ( "net" "strconv" "testing" - "time" "cdr.dev/slog/sloggers/slogtest" "github.com/pion/ice/v2" @@ -223,27 +222,6 @@ func TestDial(t *testing.T) { assert.ErrorIs(t, err, io.EOF) }) - t.Run("Closed", func(t *testing.T) { - t.Parallel() - - connectAddr, listenAddr := createDumbBroker(t) - l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - require.NoError(t, err) - defer l.Close() - - dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) - require.NoError(t, err) - go func() { - _ = dialer.Close() - }() - - select { - case <-dialer.Closed(): - case <-time.NewTimer(time.Second).C: - t.Error("didn't close in time") - } - }) - t.Run("Active Connections", func(t *testing.T) { t.Parallel() @@ -266,14 +244,14 @@ func TestDial(t *testing.T) { t.Error(err) } conn, _ := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) - assert.Equal(t, 1, dialer.ActiveConnections()) + assert.Equal(t, 1, dialer.activeConnections()) _ = conn.Close() - assert.Equal(t, 0, dialer.ActiveConnections()) + assert.Equal(t, 0, dialer.activeConnections()) _, _ = dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) conn, _ = dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) - assert.Equal(t, 2, dialer.ActiveConnections()) + assert.Equal(t, 2, dialer.activeConnections()) _ = conn.Close() - assert.Equal(t, 1, dialer.ActiveConnections()) + assert.Equal(t, 1, dialer.activeConnections()) }) }