diff --git a/wsnet/listen.go b/wsnet/listen.go index 5a159e52..859ba17f 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -18,7 +18,7 @@ import ( "cdr.dev/coder-cli/coder-sdk" ) -var keepAliveInterval = 5 * time.Second +var connectionRetryInterval = time.Second // Listen connects to the broker proxies connections to the local net. // Close will end all RTC connections. @@ -41,8 +41,19 @@ func Listen(ctx context.Context, broker string) (io.Closer, error) { // If we hit an EOF, then the connection to the broker // was interrupted. We'll take a short break then dial // again. - time.Sleep(time.Second) - ch, err = l.dial(ctx) + ticker := time.NewTicker(connectionRetryInterval) + for { + select { + case <-ticker.C: + ch, err = l.dial(ctx) + case <-ctx.Done(): + err = ctx.Err() + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + break + } + } + ticker.Stop() } if err != nil { l.acceptError = err @@ -79,7 +90,6 @@ func (l *listener) dial(ctx context.Context) (<-chan error, error) { l.ws = conn nconn := websocket.NetConn(ctx, conn, websocket.MessageBinary) config := yamux.DefaultConfig() - config.KeepAliveInterval = keepAliveInterval config.LogOutput = io.Discard session, err := yamux.Server(nconn, config) if err != nil { diff --git a/wsnet/listen_test.go b/wsnet/listen_test.go index d228bd09..45519b92 100644 --- a/wsnet/listen_test.go +++ b/wsnet/listen_test.go @@ -11,24 +11,27 @@ import ( "nhooyr.io/websocket" ) +func init() { + // We override this value to make tests faster. + connectionRetryInterval = 10 * time.Millisecond +} + func TestListen(t *testing.T) { t.Run("Reconnect", func(t *testing.T) { - keepAliveInterval = 50 * time.Millisecond - var ( - connCh = make(chan interface{}) + connCh = make(chan *websocket.Conn) mux = http.NewServeMux() srv = http.Server{ Handler: mux, } ) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - _, err := websocket.Accept(w, r, nil) + ws, err := websocket.Accept(w, r, nil) if err != nil { t.Error(err) return } - connCh <- struct{}{} + connCh <- ws }) listener, err := net.Listen("tcp4", "127.0.0.1:0") @@ -47,8 +50,15 @@ func TestListen(t *testing.T) { t.Error(err) return } - <-connCh + conn := <-connCh _ = listener.Close() + // We need to close the connection too... closing a TCP + // listener does not close active local connections. + _ = conn.Close(websocket.StatusGoingAway, "") + + // At least a few retry attempts should be had... + time.Sleep(connectionRetryInterval * 5) + listener, err = net.Listen("tcp4", addr.String()) if err != nil { t.Error(err)