Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

feat: Add DialCache for key-based connection caching #391

Merged
merged 7 commits into from
Jul 21, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/cmd/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func (c *tunnneler) start(ctx context.Context) error {
TURNProxyURL: c.brokerAddr,
ICEServers: []webrtc.ICEServer{wsnet.TURNProxyICECandidate()},
},
nil,
)
if err != nil {
return xerrors.Errorf("creating workspace dialer: %w", err)
Expand Down
163 changes: 163 additions & 0 deletions wsnet/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package wsnet

import (
"context"
"errors"
"sync"
"time"

"golang.org/x/sync/singleflight"
)

// DialCache constructs a new DialerCache.
// The cache clears connections that:
// 1. Are older than the TTL and have no active user-created connections.
// 2. Have been closed.
func DialCache(ttl time.Duration) *DialerCache {
dc := &DialerCache{
ttl: ttl,
closed: make(chan struct{}),
flightGroup: &singleflight.Group{},
mut: sync.RWMutex{},
dialers: make(map[string]*Dialer),
atime: make(map[string]time.Time),
}
go dc.init()
return dc
}

type DialerCache struct {
ttl time.Duration
flightGroup *singleflight.Group
closed chan struct{}
mut sync.RWMutex

// Key is the "key" of a dialer.
dialers map[string]*Dialer
atime map[string]time.Time
}

// init starts the ticker for evicting connections.
func (d *DialerCache) init() {
ticker := time.NewTicker(time.Second * 30)
defer ticker.Stop()
for {
select {
case <-d.closed:
return
case <-ticker.C:
d.evict()
}
}
}

// evict removes lost/broken/expired connections from the cache.
func (d *DialerCache) evict() {
var wg sync.WaitGroup
d.mut.RLock()
for key, dialer := range d.dialers {
wg.Add(1)
key := key
dialer := dialer
go func() {
defer wg.Done()

evict := false
select {
case <-dialer.Closed():
evict = true
default:
}
if dialer.ActiveConnections() == 0 && time.Since(d.atime[key]) >= d.ttl {
evict = true
}
// If we're already evicting there's no point in trying to ping.
if !evict {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
err := dialer.Ping(ctx)
if err != nil {
evict = true
}
}

if evict {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use else instead of two opposite ifs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh the other one is to ping only if can't evict.

_ = dialer.Close()
d.mut.Lock()
delete(d.atime, key)
delete(d.dialers, key)
d.mut.Unlock()
}
}()
}
d.mut.RUnlock()
wg.Wait()
}

// Dial returns a Dialer from the cache if one exists with the key provided,
// or dials a new connection using the dialerFunc.
// The bool returns whether the connection was found in the cache or not.
func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*Dialer, error)) (*Dialer, bool, error) {
select {
case <-d.closed:
return nil, false, errors.New("cache closed")
default:
}

d.mut.RLock()
dialer, ok := d.dialers[key]
d.mut.RUnlock()
if ok {
closed := false
select {
case <-dialer.Closed():
closed = true
default:
}
if !closed {
d.mut.Lock()
d.atime[key] = time.Now()
d.mut.Unlock()

return dialer, true, nil
}
}

rawDialer, err, _ := d.flightGroup.Do(key, func() (interface{}, error) {
dialer, err := dialerFunc()
if err != nil {
return nil, err
}
d.mut.Lock()
d.dialers[key] = dialer
d.atime[key] = time.Now()
d.mut.Unlock()

return dialer, nil
})
if err != nil {
return nil, false, err
}
select {
case <-d.closed:
return nil, false, errors.New("cache closed")
default:
}

return rawDialer.(*Dialer), false, nil
}

// Close closes all cached dialers.
func (d *DialerCache) Close() error {
d.mut.Lock()
defer d.mut.Unlock()

for _, dialer := range d.dialers {
err := dialer.Close()
if err != nil {
return err
}
}
close(d.closed)
return nil
}
70 changes: 70 additions & 0 deletions wsnet/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package wsnet

import (
"context"
"testing"
"time"

"cdr.dev/slog/sloggers/slogtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCache(t *testing.T) {
dialFunc := func(connectAddr string) func() (*Dialer, error) {
return func() (*Dialer, error) {
return DialWebsocket(context.Background(), connectAddr, nil, nil)
}
}

t.Run("Caches", func(t *testing.T) {
connectAddr, listenAddr := createDumbBroker(t)
l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "")
require.NoError(t, err)
defer l.Close()

cache := DialCache(time.Hour)
c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
require.NoError(t, err)
require.Equal(t, cached, false)
c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
require.NoError(t, err)
require.Equal(t, cached, true)
assert.Same(t, c1, c2)
})

t.Run("Create If Closed", func(t *testing.T) {
connectAddr, listenAddr := createDumbBroker(t)
l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "")
require.NoError(t, err)
defer l.Close()

cache := DialCache(time.Hour)

c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
require.NoError(t, err)
require.Equal(t, cached, false)
require.NoError(t, c1.Close())
c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
require.NoError(t, err)
require.Equal(t, cached, false)
assert.NotSame(t, c1, c2)
})

t.Run("Evict No Connections", func(t *testing.T) {
connectAddr, listenAddr := createDumbBroker(t)
l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "")
require.NoError(t, err)
defer l.Close()

cache := DialCache(0)

_, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
require.NoError(t, err)
require.Equal(t, cached, false)
cache.evict()
_, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr))
require.NoError(t, err)
require.Equal(t, cached, false)
})
}
12 changes: 9 additions & 3 deletions wsnet/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ type DialOptions struct {
}

// DialWebsocket dials the broker with a WebSocket and negotiates a connection.
func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*Dialer, error) {
conn, resp, err := websocket.Dial(ctx, broker, nil)
func DialWebsocket(ctx context.Context, broker string, netOpts *DialOptions, wsOpts *websocket.DialOptions) (*Dialer, error) {
conn, resp, err := websocket.Dial(ctx, broker, wsOpts)
if err != nil {
if resp != nil {
defer func() {
Expand All @@ -52,7 +52,7 @@ func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*D
// We should close the socket intentionally.
_ = conn.Close(websocket.StatusInternalError, "an error occurred")
}()
return Dial(nconn, options)
return Dial(nconn, netOpts)
}

// Dial negotiates a connection to a listener.
Expand Down Expand Up @@ -246,6 +246,12 @@ func (d *Dialer) ActiveConnections() int {
// Close closes the RTC connection.
// All data channels dialed will be closed.
func (d *Dialer) Close() error {
select {
case <-d.closedChan:
return nil
default:
}
close(d.closedChan)
return d.rtc.Close()
}

Expand Down
22 changes: 11 additions & 11 deletions wsnet/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func ExampleDial_basic() {

dialer, err := DialWebsocket(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", &DialOptions{
ICEServers: servers,
})
}, nil)
if err != nil {
// Do something...
}
Expand All @@ -60,7 +60,7 @@ func TestDial(t *testing.T) {
require.NoError(t, err)
defer l.Close()

dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
require.NoError(t, err)

err = dialer.Ping(context.Background())
Expand All @@ -83,7 +83,7 @@ func TestDial(t *testing.T) {
Credential: testPass,
CredentialType: webrtc.ICECredentialTypePassword,
}},
})
}, nil)
require.NoError(t, err)

_ = dialer.Ping(context.Background())
Expand All @@ -100,7 +100,7 @@ func TestDial(t *testing.T) {
require.NoError(t, err)
defer l.Close()

dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
require.NoError(t, err)

_, err = dialer.DialContext(context.Background(), "tcp", "localhost:100")
Expand Down Expand Up @@ -130,7 +130,7 @@ func TestDial(t *testing.T) {
require.NoError(t, err)
defer l.Close()

dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
require.NoError(t, err)

conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String())
Expand Down Expand Up @@ -158,7 +158,7 @@ func TestDial(t *testing.T) {
require.NoError(t, err)
defer l.Close()

dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
require.NoError(t, err)

conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String())
Expand All @@ -178,7 +178,7 @@ func TestDial(t *testing.T) {
require.NoError(t, err)
defer l.Close()

dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
require.NoError(t, err)

err = dialer.Close()
Expand Down Expand Up @@ -210,7 +210,7 @@ func TestDial(t *testing.T) {
Credential: testPass,
CredentialType: webrtc.ICECredentialTypePassword,
}},
})
}, nil)
require.NoError(t, err)

conn, err := dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String())
Expand All @@ -231,7 +231,7 @@ func TestDial(t *testing.T) {
require.NoError(t, err)
defer l.Close()

dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
require.NoError(t, err)
go func() {
_ = dialer.Close()
Expand Down Expand Up @@ -261,7 +261,7 @@ func TestDial(t *testing.T) {
t.Error(err)
return
}
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -314,7 +314,7 @@ func BenchmarkThroughput(b *testing.B) {
}
defer l.Close()

dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
if err != nil {
b.Error(err)
return
Expand Down