diff --git a/internal/cmd/tunnel.go b/internal/cmd/tunnel.go index 9c12dd37..956e9fd2 100644 --- a/internal/cmd/tunnel.go +++ b/internal/cmd/tunnel.go @@ -112,6 +112,7 @@ func (c *tunnneler) start(ctx context.Context) error { TURNProxyURL: c.brokerAddr, ICEServers: []webrtc.ICEServer{wsnet.TURNProxyICECandidate()}, }, + nil, ) if err != nil { return xerrors.Errorf("creating workspace dialer: %w", err) diff --git a/wsnet/cache.go b/wsnet/cache.go new file mode 100644 index 00000000..e62aa0a9 --- /dev/null +++ b/wsnet/cache.go @@ -0,0 +1,171 @@ +package wsnet + +import ( + "context" + "errors" + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +// DialCache constructs a new DialerCache. +// The cache clears connections that: +// 1. Are older than the TTL and have no active user-created connections. +// 2. Have been closed. +func DialCache(ttl time.Duration) *DialerCache { + dc := &DialerCache{ + ttl: ttl, + closed: make(chan struct{}), + flightGroup: &singleflight.Group{}, + mut: sync.RWMutex{}, + dialers: make(map[string]*Dialer), + atime: make(map[string]time.Time), + } + go dc.init() + return dc +} + +type DialerCache struct { + ttl time.Duration + flightGroup *singleflight.Group + closed chan struct{} + mut sync.RWMutex + + // Key is the "key" of a dialer, which is usually the workspace ID. + dialers map[string]*Dialer + atime map[string]time.Time +} + +// init starts the ticker for evicting connections. +func (d *DialerCache) init() { + ticker := time.NewTicker(time.Second * 30) + defer ticker.Stop() + for { + select { + case <-d.closed: + return + case <-ticker.C: + d.evict() + } + } +} + +// evict removes lost/broken/expired connections from the cache. +func (d *DialerCache) evict() { + var wg sync.WaitGroup + d.mut.RLock() + for key, dialer := range d.dialers { + wg.Add(1) + key := key + dialer := dialer + go func() { + defer wg.Done() + + evict := false + select { + case <-dialer.Closed(): + evict = true + default: + } + if dialer.ActiveConnections() == 0 && time.Since(d.atime[key]) >= d.ttl { + evict = true + } + // If we're already evicting there's no point in trying to ping. + if !evict { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + err := dialer.Ping(ctx) + if err != nil { + evict = true + } + } + + if !evict { + return + } + + _ = dialer.Close() + // Ensure after Ping and potential delays that we're still testing against + // the proper dialer. + if dialer != d.dialers[key] { + return + } + + d.mut.Lock() + defer d.mut.Unlock() + delete(d.atime, key) + delete(d.dialers, key) + }() + } + d.mut.RUnlock() + wg.Wait() +} + +// Dial returns a Dialer from the cache if one exists with the key provided, +// or dials a new connection using the dialerFunc. +// The bool returns whether the connection was found in the cache or not. +func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*Dialer, error)) (*Dialer, bool, error) { + select { + case <-d.closed: + return nil, false, errors.New("cache closed") + default: + } + + d.mut.RLock() + dialer, ok := d.dialers[key] + d.mut.RUnlock() + if ok { + closed := false + select { + case <-dialer.Closed(): + closed = true + default: + } + if !closed { + d.mut.Lock() + d.atime[key] = time.Now() + d.mut.Unlock() + + return dialer, true, nil + } + } + + rawDialer, err, _ := d.flightGroup.Do(key, func() (interface{}, error) { + dialer, err := dialerFunc() + if err != nil { + return nil, err + } + d.mut.Lock() + d.dialers[key] = dialer + d.atime[key] = time.Now() + d.mut.Unlock() + + return dialer, nil + }) + if err != nil { + return nil, false, err + } + select { + case <-d.closed: + return nil, false, errors.New("cache closed") + default: + } + + return rawDialer.(*Dialer), false, nil +} + +// Close closes all cached dialers. +func (d *DialerCache) Close() error { + d.mut.Lock() + defer d.mut.Unlock() + + for _, dialer := range d.dialers { + err := dialer.Close() + if err != nil { + return err + } + } + close(d.closed) + return nil +} diff --git a/wsnet/cache_test.go b/wsnet/cache_test.go new file mode 100644 index 00000000..44edb608 --- /dev/null +++ b/wsnet/cache_test.go @@ -0,0 +1,71 @@ +package wsnet + +import ( + "context" + "testing" + "time" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCache(t *testing.T) { + dialFunc := func(connectAddr string) func() (*Dialer, error) { + return func() (*Dialer, error) { + return DialWebsocket(context.Background(), connectAddr, nil, nil) + } + } + + t.Run("Caches", func(t *testing.T) { + connectAddr, listenAddr := createDumbBroker(t) + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + + cache := DialCache(time.Hour) + c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + require.NoError(t, err) + require.Equal(t, cached, false) + c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + require.NoError(t, err) + require.Equal(t, cached, true) + assert.Same(t, c1, c2) + }) + + t.Run("Create If Closed", func(t *testing.T) { + connectAddr, listenAddr := createDumbBroker(t) + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + + cache := DialCache(time.Hour) + + c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + require.NoError(t, err) + require.Equal(t, cached, false) + require.NoError(t, c1.Close()) + c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + require.NoError(t, err) + require.Equal(t, cached, false) + assert.NotSame(t, c1, c2) + }) + + t.Run("Evict No Connections", func(t *testing.T) { + connectAddr, listenAddr := createDumbBroker(t) + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + + cache := DialCache(0) + + c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + require.NoError(t, err) + require.Equal(t, cached, false) + cache.evict() + c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + require.NoError(t, err) + require.Equal(t, cached, false) + assert.NotSame(t, c1, c2) + }) +} diff --git a/wsnet/dial.go b/wsnet/dial.go index 050bc574..af4b422c 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -35,8 +35,8 @@ type DialOptions struct { } // DialWebsocket dials the broker with a WebSocket and negotiates a connection. -func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*Dialer, error) { - conn, resp, err := websocket.Dial(ctx, broker, nil) +func DialWebsocket(ctx context.Context, broker string, netOpts *DialOptions, wsOpts *websocket.DialOptions) (*Dialer, error) { + conn, resp, err := websocket.Dial(ctx, broker, wsOpts) if err != nil { if resp != nil { defer func() { @@ -52,7 +52,7 @@ func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*D // We should close the socket intentionally. _ = conn.Close(websocket.StatusInternalError, "an error occurred") }() - return Dial(nconn, options) + return Dial(nconn, netOpts) } // Dial negotiates a connection to a listener. @@ -246,6 +246,12 @@ func (d *Dialer) ActiveConnections() int { // Close closes the RTC connection. // All data channels dialed will be closed. func (d *Dialer) Close() error { + select { + case <-d.closedChan: + return nil + default: + } + close(d.closedChan) return d.rtc.Close() } diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index a5d33b96..8a6486ba 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -39,7 +39,7 @@ func ExampleDial_basic() { dialer, err := DialWebsocket(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", &DialOptions{ ICEServers: servers, - }) + }, nil) if err != nil { // Do something... } @@ -60,7 +60,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) err = dialer.Ping(context.Background()) @@ -83,7 +83,7 @@ func TestDial(t *testing.T) { Credential: testPass, CredentialType: webrtc.ICECredentialTypePassword, }}, - }) + }, nil) require.NoError(t, err) _ = dialer.Ping(context.Background()) @@ -100,7 +100,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) _, err = dialer.DialContext(context.Background(), "tcp", "localhost:100") @@ -130,7 +130,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) @@ -158,7 +158,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) @@ -178,7 +178,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) err = dialer.Close() @@ -210,7 +210,7 @@ func TestDial(t *testing.T) { Credential: testPass, CredentialType: webrtc.ICECredentialTypePassword, }}, - }) + }, nil) require.NoError(t, err) conn, err := dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String()) @@ -231,7 +231,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) go func() { _ = dialer.Close() @@ -261,7 +261,7 @@ func TestDial(t *testing.T) { t.Error(err) return } - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) if err != nil { t.Error(err) } @@ -314,7 +314,7 @@ func BenchmarkThroughput(b *testing.B) { } defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) if err != nil { b.Error(err) return