Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 08fe800

Browse files
authored
feat: Add DialCache for key-based connection caching (#391)
* feat: Add DialCache for key-based connection caching * Remove DialOptions * Move DialFunc to Dial * Add WS options to dial * Requested changes * Add comment * Fixup
1 parent f337897 commit 08fe800

File tree

5 files changed

+263
-14
lines changed

5 files changed

+263
-14
lines changed

internal/cmd/tunnel.go

+1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ func (c *tunnneler) start(ctx context.Context) error {
112112
TURNProxyURL: c.brokerAddr,
113113
ICEServers: []webrtc.ICEServer{wsnet.TURNProxyICECandidate()},
114114
},
115+
nil,
115116
)
116117
if err != nil {
117118
return xerrors.Errorf("creating workspace dialer: %w", err)

wsnet/cache.go

+171
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
package wsnet
2+
3+
import (
4+
"context"
5+
"errors"
6+
"sync"
7+
"time"
8+
9+
"golang.org/x/sync/singleflight"
10+
)
11+
12+
// DialCache constructs a new DialerCache.
13+
// The cache clears connections that:
14+
// 1. Are older than the TTL and have no active user-created connections.
15+
// 2. Have been closed.
16+
func DialCache(ttl time.Duration) *DialerCache {
17+
dc := &DialerCache{
18+
ttl: ttl,
19+
closed: make(chan struct{}),
20+
flightGroup: &singleflight.Group{},
21+
mut: sync.RWMutex{},
22+
dialers: make(map[string]*Dialer),
23+
atime: make(map[string]time.Time),
24+
}
25+
go dc.init()
26+
return dc
27+
}
28+
29+
type DialerCache struct {
30+
ttl time.Duration
31+
flightGroup *singleflight.Group
32+
closed chan struct{}
33+
mut sync.RWMutex
34+
35+
// Key is the "key" of a dialer, which is usually the workspace ID.
36+
dialers map[string]*Dialer
37+
atime map[string]time.Time
38+
}
39+
40+
// init starts the ticker for evicting connections.
41+
func (d *DialerCache) init() {
42+
ticker := time.NewTicker(time.Second * 30)
43+
defer ticker.Stop()
44+
for {
45+
select {
46+
case <-d.closed:
47+
return
48+
case <-ticker.C:
49+
d.evict()
50+
}
51+
}
52+
}
53+
54+
// evict removes lost/broken/expired connections from the cache.
55+
func (d *DialerCache) evict() {
56+
var wg sync.WaitGroup
57+
d.mut.RLock()
58+
for key, dialer := range d.dialers {
59+
wg.Add(1)
60+
key := key
61+
dialer := dialer
62+
go func() {
63+
defer wg.Done()
64+
65+
evict := false
66+
select {
67+
case <-dialer.Closed():
68+
evict = true
69+
default:
70+
}
71+
if dialer.ActiveConnections() == 0 && time.Since(d.atime[key]) >= d.ttl {
72+
evict = true
73+
}
74+
// If we're already evicting there's no point in trying to ping.
75+
if !evict {
76+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
77+
defer cancel()
78+
err := dialer.Ping(ctx)
79+
if err != nil {
80+
evict = true
81+
}
82+
}
83+
84+
if !evict {
85+
return
86+
}
87+
88+
_ = dialer.Close()
89+
// Ensure after Ping and potential delays that we're still testing against
90+
// the proper dialer.
91+
if dialer != d.dialers[key] {
92+
return
93+
}
94+
95+
d.mut.Lock()
96+
defer d.mut.Unlock()
97+
delete(d.atime, key)
98+
delete(d.dialers, key)
99+
}()
100+
}
101+
d.mut.RUnlock()
102+
wg.Wait()
103+
}
104+
105+
// Dial returns a Dialer from the cache if one exists with the key provided,
106+
// or dials a new connection using the dialerFunc.
107+
// The bool returns whether the connection was found in the cache or not.
108+
func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*Dialer, error)) (*Dialer, bool, error) {
109+
select {
110+
case <-d.closed:
111+
return nil, false, errors.New("cache closed")
112+
default:
113+
}
114+
115+
d.mut.RLock()
116+
dialer, ok := d.dialers[key]
117+
d.mut.RUnlock()
118+
if ok {
119+
closed := false
120+
select {
121+
case <-dialer.Closed():
122+
closed = true
123+
default:
124+
}
125+
if !closed {
126+
d.mut.Lock()
127+
d.atime[key] = time.Now()
128+
d.mut.Unlock()
129+
130+
return dialer, true, nil
131+
}
132+
}
133+
134+
rawDialer, err, _ := d.flightGroup.Do(key, func() (interface{}, error) {
135+
dialer, err := dialerFunc()
136+
if err != nil {
137+
return nil, err
138+
}
139+
d.mut.Lock()
140+
d.dialers[key] = dialer
141+
d.atime[key] = time.Now()
142+
d.mut.Unlock()
143+
144+
return dialer, nil
145+
})
146+
if err != nil {
147+
return nil, false, err
148+
}
149+
select {
150+
case <-d.closed:
151+
return nil, false, errors.New("cache closed")
152+
default:
153+
}
154+
155+
return rawDialer.(*Dialer), false, nil
156+
}
157+
158+
// Close closes all cached dialers.
159+
func (d *DialerCache) Close() error {
160+
d.mut.Lock()
161+
defer d.mut.Unlock()
162+
163+
for _, dialer := range d.dialers {
164+
err := dialer.Close()
165+
if err != nil {
166+
return err
167+
}
168+
}
169+
close(d.closed)
170+
return nil
171+
}

wsnet/cache_test.go

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package wsnet
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"cdr.dev/slog/sloggers/slogtest"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestCache(t *testing.T) {
14+
dialFunc := func(connectAddr string) func() (*Dialer, error) {
15+
return func() (*Dialer, error) {
16+
return DialWebsocket(context.Background(), connectAddr, nil, nil)
17+
}
18+
}
19+
20+
t.Run("Caches", func(t *testing.T) {
21+
connectAddr, listenAddr := createDumbBroker(t)
22+
l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "")
23+
require.NoError(t, err)
24+
defer l.Close()
25+
26+
cache := DialCache(time.Hour)
27+
c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
28+
require.NoError(t, err)
29+
require.Equal(t, cached, false)
30+
c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
31+
require.NoError(t, err)
32+
require.Equal(t, cached, true)
33+
assert.Same(t, c1, c2)
34+
})
35+
36+
t.Run("Create If Closed", func(t *testing.T) {
37+
connectAddr, listenAddr := createDumbBroker(t)
38+
l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "")
39+
require.NoError(t, err)
40+
defer l.Close()
41+
42+
cache := DialCache(time.Hour)
43+
44+
c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
45+
require.NoError(t, err)
46+
require.Equal(t, cached, false)
47+
require.NoError(t, c1.Close())
48+
c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
49+
require.NoError(t, err)
50+
require.Equal(t, cached, false)
51+
assert.NotSame(t, c1, c2)
52+
})
53+
54+
t.Run("Evict No Connections", func(t *testing.T) {
55+
connectAddr, listenAddr := createDumbBroker(t)
56+
l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "")
57+
require.NoError(t, err)
58+
defer l.Close()
59+
60+
cache := DialCache(0)
61+
62+
c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
63+
require.NoError(t, err)
64+
require.Equal(t, cached, false)
65+
cache.evict()
66+
c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
67+
require.NoError(t, err)
68+
require.Equal(t, cached, false)
69+
assert.NotSame(t, c1, c2)
70+
})
71+
}

wsnet/dial.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ type DialOptions struct {
3535
}
3636

3737
// DialWebsocket dials the broker with a WebSocket and negotiates a connection.
38-
func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*Dialer, error) {
39-
conn, resp, err := websocket.Dial(ctx, broker, nil)
38+
func DialWebsocket(ctx context.Context, broker string, netOpts *DialOptions, wsOpts *websocket.DialOptions) (*Dialer, error) {
39+
conn, resp, err := websocket.Dial(ctx, broker, wsOpts)
4040
if err != nil {
4141
if resp != nil {
4242
defer func() {
@@ -52,7 +52,7 @@ func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*D
5252
// We should close the socket intentionally.
5353
_ = conn.Close(websocket.StatusInternalError, "an error occurred")
5454
}()
55-
return Dial(nconn, options)
55+
return Dial(nconn, netOpts)
5656
}
5757

5858
// Dial negotiates a connection to a listener.
@@ -246,6 +246,12 @@ func (d *Dialer) ActiveConnections() int {
246246
// Close closes the RTC connection.
247247
// All data channels dialed will be closed.
248248
func (d *Dialer) Close() error {
249+
select {
250+
case <-d.closedChan:
251+
return nil
252+
default:
253+
}
254+
close(d.closedChan)
249255
return d.rtc.Close()
250256
}
251257

wsnet/dial_test.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func ExampleDial_basic() {
3939

4040
dialer, err := DialWebsocket(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", &DialOptions{
4141
ICEServers: servers,
42-
})
42+
}, nil)
4343
if err != nil {
4444
// Do something...
4545
}
@@ -60,7 +60,7 @@ func TestDial(t *testing.T) {
6060
require.NoError(t, err)
6161
defer l.Close()
6262

63-
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
63+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
6464
require.NoError(t, err)
6565

6666
err = dialer.Ping(context.Background())
@@ -83,7 +83,7 @@ func TestDial(t *testing.T) {
8383
Credential: testPass,
8484
CredentialType: webrtc.ICECredentialTypePassword,
8585
}},
86-
})
86+
}, nil)
8787
require.NoError(t, err)
8888

8989
_ = dialer.Ping(context.Background())
@@ -100,7 +100,7 @@ func TestDial(t *testing.T) {
100100
require.NoError(t, err)
101101
defer l.Close()
102102

103-
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
103+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
104104
require.NoError(t, err)
105105

106106
_, err = dialer.DialContext(context.Background(), "tcp", "localhost:100")
@@ -130,7 +130,7 @@ func TestDial(t *testing.T) {
130130
require.NoError(t, err)
131131
defer l.Close()
132132

133-
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
133+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
134134
require.NoError(t, err)
135135

136136
conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String())
@@ -158,7 +158,7 @@ func TestDial(t *testing.T) {
158158
require.NoError(t, err)
159159
defer l.Close()
160160

161-
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
161+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
162162
require.NoError(t, err)
163163

164164
conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String())
@@ -178,7 +178,7 @@ func TestDial(t *testing.T) {
178178
require.NoError(t, err)
179179
defer l.Close()
180180

181-
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
181+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
182182
require.NoError(t, err)
183183

184184
err = dialer.Close()
@@ -210,7 +210,7 @@ func TestDial(t *testing.T) {
210210
Credential: testPass,
211211
CredentialType: webrtc.ICECredentialTypePassword,
212212
}},
213-
})
213+
}, nil)
214214
require.NoError(t, err)
215215

216216
conn, err := dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String())
@@ -231,7 +231,7 @@ func TestDial(t *testing.T) {
231231
require.NoError(t, err)
232232
defer l.Close()
233233

234-
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
234+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
235235
require.NoError(t, err)
236236
go func() {
237237
_ = dialer.Close()
@@ -261,7 +261,7 @@ func TestDial(t *testing.T) {
261261
t.Error(err)
262262
return
263263
}
264-
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
264+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
265265
if err != nil {
266266
t.Error(err)
267267
}
@@ -314,7 +314,7 @@ func BenchmarkThroughput(b *testing.B) {
314314
}
315315
defer l.Close()
316316

317-
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
317+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil)
318318
if err != nil {
319319
b.Error(err)
320320
return

0 commit comments

Comments
 (0)