Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 0af7789

Browse files
committed
Move DialFunc to Dial
1 parent 857a743 commit 0af7789

File tree

2 files changed

+19
-21
lines changed

2 files changed

+19
-21
lines changed

wsnet/cache.go

+4-6
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,15 @@ import (
99
)
1010

1111
// dialerFunc is used to reference a dialer returned for caching.
12-
type dialerFunc func(ctx context.Context, key string) (*Dialer, error)
12+
type dialerFunc func() (*Dialer, error)
1313

1414
// DialCache constructs a new DialerCache.
1515
// The cache clears connections that:
1616
// 1. Are older than the TTL and have no active user-created connections.
1717
// 2. Have been closed.
18-
func DialCache(ttl time.Duration, dialer dialerFunc) *DialerCache {
18+
func DialCache(ttl time.Duration) *DialerCache {
1919
dc := &DialerCache{
2020
ttl: ttl,
21-
dialerFunc: dialer,
2221
closed: make(chan struct{}),
2322
flightGroup: &singleflight.Group{},
2423
mut: sync.RWMutex{},
@@ -30,7 +29,6 @@ func DialCache(ttl time.Duration, dialer dialerFunc) *DialerCache {
3029
}
3130

3231
type DialerCache struct {
33-
dialerFunc dialerFunc
3432
ttl time.Duration
3533
flightGroup *singleflight.Group
3634

@@ -98,7 +96,7 @@ func (d *DialerCache) evict() {
9896

9997
// Dial returns a Dialer from the cache if one exists with the key provided,
10098
// or dials a new connection using the dialerFunc.
101-
func (d *DialerCache) Dial(ctx context.Context, key string) (*Dialer, bool, error) {
99+
func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*Dialer, error)) (*Dialer, bool, error) {
102100
d.mut.RLock()
103101
if dialer, ok := d.dialers[key]; ok {
104102
closed := false
@@ -119,7 +117,7 @@ func (d *DialerCache) Dial(ctx context.Context, key string) (*Dialer, bool, erro
119117
d.mut.RUnlock()
120118

121119
dialer, err, _ := d.flightGroup.Do(key, func() (interface{}, error) {
122-
dialer, err := d.dialerFunc(ctx, key)
120+
dialer, err := dialerFunc()
123121
if err != nil {
124122
return nil, err
125123
}

wsnet/cache_test.go

+15-15
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,23 @@ import (
1010
)
1111

1212
func TestCache(t *testing.T) {
13+
dialFunc := func(connectAddr string) func() (*Dialer, error) {
14+
return func() (*Dialer, error) {
15+
return DialWebsocket(context.Background(), connectAddr, nil)
16+
}
17+
}
18+
1319
t.Run("Caches", func(t *testing.T) {
1420
connectAddr, listenAddr := createDumbBroker(t)
1521
l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "")
1622
require.NoError(t, err)
1723
defer l.Close()
1824

19-
cache := DialCache(time.Hour, func(ctx context.Context, key string) (*Dialer, error) {
20-
return DialWebsocket(ctx, connectAddr, nil)
21-
})
22-
_, cached, err := cache.Dial(context.Background(), "example")
25+
cache := DialCache(time.Hour)
26+
_, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
2327
require.NoError(t, err)
2428
require.Equal(t, cached, false)
25-
_, cached, err = cache.Dial(context.Background(), "example")
29+
_, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr))
2630
require.NoError(t, err)
2731
require.Equal(t, cached, true)
2832
})
@@ -33,15 +37,13 @@ func TestCache(t *testing.T) {
3337
require.NoError(t, err)
3438
defer l.Close()
3539

36-
cache := DialCache(time.Hour, func(ctx context.Context, key string) (*Dialer, error) {
37-
return DialWebsocket(ctx, connectAddr, nil)
38-
})
40+
cache := DialCache(time.Hour)
3941

40-
conn, cached, err := cache.Dial(context.Background(), "example")
42+
conn, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
4143
require.NoError(t, err)
4244
require.Equal(t, cached, false)
4345
require.NoError(t, conn.Close())
44-
_, cached, err = cache.Dial(context.Background(), "example")
46+
_, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr))
4547
require.NoError(t, err)
4648
require.Equal(t, cached, false)
4749
})
@@ -52,15 +54,13 @@ func TestCache(t *testing.T) {
5254
require.NoError(t, err)
5355
defer l.Close()
5456

55-
cache := DialCache(0, func(ctx context.Context, key string) (*Dialer, error) {
56-
return DialWebsocket(ctx, connectAddr, nil)
57-
})
57+
cache := DialCache(0)
5858

59-
_, cached, err := cache.Dial(context.Background(), "example")
59+
_, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
6060
require.NoError(t, err)
6161
require.Equal(t, cached, false)
6262
cache.evict()
63-
_, cached, err = cache.Dial(context.Background(), "example")
63+
_, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr))
6464
require.NoError(t, err)
6565
require.Equal(t, cached, false)
6666
})

0 commit comments

Comments
 (0)