From b87b4c6b3ab68be02c2364cef5a15d532a132681 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 16 Jul 2021 19:51:53 +0000 Subject: [PATCH 1/4] Add logger to wsnet listener --- internal/cmd/agent.go | 24 ++++++++++++++++-------- wsnet/dial_test.go | 19 ++++++++++--------- wsnet/listen.go | 6 +++++- wsnet/listen_test.go | 3 ++- 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/internal/cmd/agent.go b/internal/cmd/agent.go index 7563ee4b..29f65b21 100644 --- a/internal/cmd/agent.go +++ b/internal/cmd/agent.go @@ -1,13 +1,15 @@ package cmd import ( - "context" - "log" "net/url" "os" "os/signal" "syscall" + // We use slog here since agent runs in the background and we can benefit + // from structured logging. + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" "github.com/spf13/cobra" "golang.org/x/xerrors" @@ -46,7 +48,10 @@ coder agent start coder agent start --coder-url https://my-coder.com --token xxxx-xxxx `, RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() + var ( + ctx = cmd.Context() + log = slog.Make(sloghuman.Sink(os.Stderr)).Leveled(slog.LevelDebug) + ) if coderURL == "" { var ok bool coderURL, ok = os.LookupEnv("CODER_URL") @@ -73,20 +78,23 @@ coder agent start --coder-url https://my-coder.com --token xxxx-xxxx } } - listener, err := wsnet.Listen(context.Background(), wsnet.ListenEndpoint(u, token), token) + log.Info(ctx, "starting wsnet listener", slog.F("coder_access_url", u.String())) + listener, err := wsnet.Listen(ctx, log, wsnet.ListenEndpoint(u, token), token) if err != nil { return xerrors.Errorf("listen: %w", err) } + defer func() { + err := listener.Close() + if err != nil { + log.Error(ctx, "close listener", slog.Error(err)) + } + }() // Block until user sends SIGINT or SIGTERM sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = listener.Close(); err != nil { - log.Panic(err) - } - return nil }, } diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 91b7d0a2..be52685c 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "cdr.dev/slog/sloggers/slogtest" "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" ) @@ -55,7 +56,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, "") + _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return @@ -75,7 +76,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, "") + _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return @@ -106,7 +107,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, "") + _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return @@ -145,7 +146,7 @@ func TestDial(t *testing.T) { }() connectAddr, listenAddr := createDumbBroker(t) - _, err = Listen(context.Background(), listenAddr, "") + _, err = Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return @@ -184,7 +185,7 @@ func TestDial(t *testing.T) { _, _ = listener.Accept() }() connectAddr, listenAddr := createDumbBroker(t) - srv, err := Listen(context.Background(), listenAddr, "") + srv, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return @@ -211,7 +212,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, "") + _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return @@ -245,7 +246,7 @@ func TestDial(t *testing.T) { }() connectAddr, listenAddr := createDumbBroker(t) - _, err = Listen(context.Background(), listenAddr, "") + _, err = Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return @@ -282,7 +283,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, "") + _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return @@ -333,7 +334,7 @@ func BenchmarkThroughput(b *testing.B) { } }() connectAddr, listenAddr := createDumbBroker(b) - _, err = Listen(context.Background(), listenAddr, "") + _, err = Listen(context.Background(), slogtest.Make(b, nil), listenAddr, "") if err != nil { b.Error(err) return diff --git a/wsnet/listen.go b/wsnet/listen.go index e159b6e3..ea72cecf 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -17,6 +17,7 @@ import ( "nhooyr.io/websocket" "cdr.dev/coder-cli/coder-sdk" + "cdr.dev/slog" ) // Codes for DialChannelResponse. @@ -41,12 +42,14 @@ type DialChannelResponse struct { // Listen connects to the broker proxies connections to the local net. // Close will end all RTC connections. -func Listen(ctx context.Context, broker string, turnProxyAuthToken string) (io.Closer, error) { +func Listen(ctx context.Context, log slog.Logger, broker string, turnProxyAuthToken string) (io.Closer, error) { l := &listener{ + log: log, broker: broker, connClosers: make([]io.Closer, 0), turnProxyAuthToken: turnProxyAuthToken, } + // We do a one-off dial outside of the loop to ensure the initial // connection is successful. If not, there's likely an error the // user needs to act on. @@ -89,6 +92,7 @@ type listener struct { broker string turnProxyAuthToken string + log slog.Logger acceptError error ws *websocket.Conn connClosers []io.Closer diff --git a/wsnet/listen_test.go b/wsnet/listen_test.go index 2c5ba35f..55efb0bc 100644 --- a/wsnet/listen_test.go +++ b/wsnet/listen_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "cdr.dev/slog/sloggers/slogtest" "nhooyr.io/websocket" ) @@ -45,7 +46,7 @@ func TestListen(t *testing.T) { addr := listener.Addr() broker := fmt.Sprintf("http://%s/", addr.String()) - _, err = Listen(context.Background(), broker, "") + _, err = Listen(context.Background(), slogtest.Make(t, nil), broker, "") if err != nil { t.Error(err) return From 6af71930785cb7f55414c6061570db1a1cebad46 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 20 Jul 2021 01:41:30 +0000 Subject: [PATCH 2/4] Add more logging to wsnet listener --- wsnet/listen.go | 68 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 6 deletions(-) diff --git a/wsnet/listen.go b/wsnet/listen.go index ea72cecf..b2404c17 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -9,6 +9,7 @@ import ( "net" "net/url" "sync" + "sync/atomic" "time" "github.com/hashicorp/yamux" @@ -61,6 +62,8 @@ func Listen(ctx context.Context, log slog.Logger, broker string, turnProxyAuthTo for { err := <-ch if errors.Is(err, io.EOF) || errors.Is(err, yamux.ErrKeepAliveTimeout) { + l.log.Warn(ctx, "disconnected from broker", slog.Error(err)) + // If we hit an EOF, then the connection to the broker // was interrupted. We'll take a short break then dial // again. @@ -97,12 +100,16 @@ type listener struct { ws *websocket.Conn connClosers []io.Closer connClosersMut sync.Mutex + + nextConnNumber int64 } func (l *listener) dial(ctx context.Context) (<-chan error, error) { + l.log.Info(ctx, "connecting to broker", slog.F("broker_url", l.broker)) if l.ws != nil { _ = l.ws.Close(websocket.StatusNormalClosure, "new connection inbound") } + conn, resp, err := websocket.Dial(ctx, l.broker, nil) if err != nil { if resp != nil { @@ -111,6 +118,7 @@ func (l *listener) dial(ctx context.Context) (<-chan error, error) { return nil, err } l.ws = conn + nconn := websocket.NetConn(ctx, conn, websocket.MessageBinary) config := yamux.DefaultConfig() config.LogOutput = io.Discard @@ -118,6 +126,8 @@ func (l *listener) dial(ctx context.Context) (<-chan error, error) { if err != nil { return nil, fmt.Errorf("create multiplex: %w", err) } + + l.log.Info(ctx, "broker connection established") errCh := make(chan error) go func() { defer close(errCh) @@ -127,9 +137,10 @@ func (l *listener) dial(ctx context.Context) (<-chan error, error) { errCh <- err break } - go l.negotiate(conn) + go l.negotiate(ctx, conn) } }() + return errCh, nil } @@ -137,9 +148,10 @@ func (l *listener) dial(ctx context.Context) (<-chan error, error) { // This functions control-flow is important to readability, // so the cognitive overload linter has been disabled. // nolint:gocognit,nestif -func (l *listener) negotiate(conn net.Conn) { +func (l *listener) negotiate(ctx context.Context, conn net.Conn) { var ( err error + id = atomic.AddInt64(&l.nextConnNumber, 1) decoder = json.NewDecoder(conn) rtc *webrtc.PeerConnection // If candidates are sent before an offer, we place them here. @@ -149,6 +161,8 @@ func (l *listener) negotiate(conn net.Conn) { // Sends the error provided then closes the connection. // If RTC isn't connected, we'll close it. closeError = func(err error) { + l.log.Warn(ctx, "negotiation error, closing connection", slog.Error(err)) + d, _ := json.Marshal(&BrokerMessage{ Error: err.Error(), }) @@ -163,6 +177,9 @@ func (l *listener) negotiate(conn net.Conn) { } ) + ctx = slog.With(ctx, slog.F("conn_id", id)) + l.log.Info(ctx, "accepted new session from broker connection, negotiating") + for { var msg BrokerMessage err = decoder.Decode(&msg) @@ -170,6 +187,7 @@ func (l *listener) negotiate(conn net.Conn) { closeError(err) return } + l.log.Debug(ctx, "received broker message", slog.F("msg", msg)) if msg.Candidate != "" { c := webrtc.ICECandidateInit{ @@ -181,6 +199,7 @@ func (l *listener) negotiate(conn net.Conn) { continue } + l.log.Debug(ctx, "adding ICE candidate", slog.F("c", c)) err = rtc.AddICECandidate(c) if err != nil { closeError(fmt.Errorf("accept ice candidate: %w", err)) @@ -199,12 +218,15 @@ func (l *listener) negotiate(conn net.Conn) { // so it will not validate. continue } + + l.log.Debug(ctx, "validating ICE server", slog.F("s", server)) err = DialICE(server, nil) if err != nil { closeError(fmt.Errorf("dial server %+v: %w", server.URLs, err)) return } } + var turnProxy proxy.Dialer if msg.TURNProxyURL != "" { u, err := url.Parse(msg.TURNProxyURL) @@ -223,26 +245,33 @@ func (l *listener) negotiate(conn net.Conn) { return } rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { + l.log.Debug(ctx, "connection state change", slog.F("state", pcs.String())) if pcs == webrtc.PeerConnectionStateConnecting { return } _ = conn.Close() }) + flushCandidates := proxyICECandidates(rtc, conn) l.connClosersMut.Lock() l.connClosers = append(l.connClosers, rtc) l.connClosersMut.Unlock() - rtc.OnDataChannel(l.handle(msg)) + rtc.OnDataChannel(l.handle(ctx, msg)) + + l.log.Debug(ctx, "set remote description", slog.F("offer", *msg.Offer)) err = rtc.SetRemoteDescription(*msg.Offer) if err != nil { closeError(fmt.Errorf("apply offer: %w", err)) return } + answer, err := rtc.CreateAnswer(nil) if err != nil { closeError(fmt.Errorf("create answer: %w", err)) return } + + l.log.Debug(ctx, "set local description", slog.F("answer", answer)) err = rtc.SetLocalDescription(answer) if err != nil { closeError(fmt.Errorf("set local answer: %w", err)) @@ -250,13 +279,16 @@ func (l *listener) negotiate(conn net.Conn) { } flushCandidates() - data, err := json.Marshal(&BrokerMessage{ + bmsg := &BrokerMessage{ Answer: rtc.LocalDescription(), - }) + } + data, err := json.Marshal(bmsg) if err != nil { closeError(fmt.Errorf("marshal: %w", err)) return } + + l.log.Debug(ctx, "writing message", slog.F("msg", bmsg)) _, err = conn.Write(data) if err != nil { closeError(fmt.Errorf("write: %w", err)) @@ -264,6 +296,7 @@ func (l *listener) negotiate(conn net.Conn) { } for _, candidate := range pendingCandidates { + l.log.Debug(ctx, "adding pending ICE candidate", slog.F("c", candidate)) err = rtc.AddICECandidate(candidate) if err != nil { closeError(fmt.Errorf("add pending candidate: %w", err)) @@ -275,11 +308,13 @@ func (l *listener) negotiate(conn net.Conn) { } } -func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { +// nolint:gocognit +func (l *listener) handle(ctx context.Context, msg BrokerMessage) func(dc *webrtc.DataChannel) { return func(dc *webrtc.DataChannel) { if dc.Protocol() == controlChannel { // The control channel handles pings. dc.OnOpen(func() { + l.log.Debug(ctx, "control channel open") rw, err := dc.Detach() if err != nil { return @@ -287,7 +322,11 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { // We'll read and write back a single byte for ping/pongin'. d := make([]byte, 1) for { + l.log.Debug(ctx, "sending ping") _, err = rw.Read(d) + if err != nil { + l.log.Debug(ctx, "reading ping response failed", slog.Error(err)) + } if errors.Is(err, io.EOF) { return } @@ -300,7 +339,14 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { return } + ctx = slog.With(ctx, + slog.F("dc_id", dc.ID()), + slog.F("dc_label", dc.Label()), + slog.F("dc_proto", dc.Protocol()), + ) + dc.OnOpen(func() { + l.log.Info(ctx, "data channel opened") rw, err := dc.Detach() if err != nil { return @@ -308,17 +354,21 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { var init DialChannelResponse sendInitMessage := func() { + l.log.Debug(ctx, "sending dc init message", slog.F("msg", init)) initData, err := json.Marshal(&init) if err != nil { + l.log.Debug(ctx, "failed to marshal dc init message", slog.Error(err)) rw.Close() return } _, err = rw.Write(initData) if err != nil { + l.log.Debug(ctx, "failed to write dc init message", slog.Error(err)) return } if init.Err != "" { // If an error occurred, we're safe to close the connection. + l.log.Debug(ctx, "closing data channel due to error", slog.F("msg", init.Err)) dc.Close() return } @@ -336,8 +386,10 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { return } + l.log.Debug(ctx, "dialing remote address", slog.F("network", network), slog.F("addr", addr)) nc, err := net.Dial(network, addr) if err != nil { + l.log.Debug(ctx, "failed to dial remote address") init.Code = CodeDialErr init.Err = err.Error() if op, ok := err.(*net.OpError); ok { @@ -349,8 +401,10 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { if init.Err != "" { return } + // Must wrap the data channel inside this connection // for buffering from the dialed endpoint to the client. + l.log.Debug(ctx, "data channel initialized, tunnelling") co := &dataChannelConn{ addr: nil, dc: dc, @@ -370,6 +424,8 @@ func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { // Close closes the broker socket and all created RTC connections. func (l *listener) Close() error { + l.log.Info(context.Background(), "listener closed") + l.connClosersMut.Lock() for _, connCloser := range l.connClosers { // We can ignore the error here... it doesn't From 86e1c79b3fa536cc4730dc1c194061783a9340cd Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 20 Jul 2021 01:56:30 +0000 Subject: [PATCH 3/4] fixup! Add more logging to wsnet listener --- wsnet/listen.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wsnet/listen.go b/wsnet/listen.go index b2404c17..071c5b9c 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -17,8 +17,9 @@ import ( "golang.org/x/net/proxy" "nhooyr.io/websocket" - "cdr.dev/coder-cli/coder-sdk" "cdr.dev/slog" + + "cdr.dev/coder-cli/coder-sdk" ) // Codes for DialChannelResponse. From eb2200c456c00d7884ffd8025a451f8ae6d6b31a Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 20 Jul 2021 02:31:44 +0000 Subject: [PATCH 4/4] Fix tests --- go.mod | 1 + wsnet/dial_test.go | 220 ++++++++++++++++--------------------------- wsnet/listen.go | 25 ++++- wsnet/listen_test.go | 44 +++------ 4 files changed, 115 insertions(+), 175 deletions(-) diff --git a/go.mod b/go.mod index 46a9d7ce..33895cc2 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 github.com/rjeczalik/notify v0.9.2 github.com/spf13/cobra v1.2.1 + github.com/stretchr/testify v1.7.0 golang.org/x/net v0.0.0-20210614182718-04defd469f4e golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index be52685c..fcd00ac7 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -1,7 +1,6 @@ package wsnet import ( - "bytes" "context" "crypto/rand" "errors" @@ -15,6 +14,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func ExampleDial_basic() { @@ -50,37 +51,30 @@ func ExampleDial_basic() { // You now have access to the proxied remote port in `conn`. } -// nolint:gocognit,gocyclo func TestDial(t *testing.T) { t.Run("Ping", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - if err != nil { - t.Error(err) - return - } + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + err = dialer.Ping(context.Background()) - if err != nil { - t.Error(err) - } + require.NoError(t, err) }) t.Run("Ping Close", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - if err != nil { - t.Error(err) - return - } + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) dialer, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{ ICEServers: []webrtc.ICEServer{{ @@ -90,86 +84,63 @@ func TestDial(t *testing.T) { CredentialType: webrtc.ICECredentialTypePassword, }}, }) - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + _ = dialer.Ping(context.Background()) closeTurn() err = dialer.Ping(context.Background()) - if err != io.EOF { - t.Error(err) - return - } + assert.ErrorIs(t, err, io.EOF) }) t.Run("OPError", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - if err != nil { - t.Error(err) - return - } + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) - if err != nil { - t.Error(err) - } + require.NoError(t, err) + _, err = dialer.DialContext(context.Background(), "tcp", "localhost:100") - if err == nil { - t.Error("should have gotten err") - return - } - _, ok := err.(*net.OpError) - if !ok { - t.Error("invalid error type returned") - return - } + assert.Error(t, err) + + // Double pointer intended. + netErr := &net.OpError{} + assert.ErrorAs(t, err, &netErr) }) t.Run("Proxy", func(t *testing.T) { t.Parallel() listener, err := net.Listen("tcp", "0.0.0.0:0") - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + msg := []byte("Hello!") go func() { conn, err := listener.Accept() - if err != nil { - t.Error(err) - } + require.NoError(t, err) + _, _ = conn.Write(msg) }() connectAddr, listenAddr := createDumbBroker(t) - _, err = Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - if err != nil { - t.Error(err) - return - } + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + rec := make([]byte, len(msg)) _, err = conn.Read(rec) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(msg, rec) { - t.Error("bytes were different", string(msg), string(rec)) - } + require.NoError(t, err) + + assert.Equal(t, msg, rec) }) // Expect that we'd get an EOF on the server closing. @@ -177,80 +148,60 @@ func TestDial(t *testing.T) { t.Parallel() listener, err := net.Listen("tcp", "0.0.0.0:0") - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) go func() { _, _ = listener.Accept() }() + connectAddr, listenAddr := createDumbBroker(t) - srv, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - if err != nil { - t.Error(err) - return - } + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) - if err != nil { - t.Error(err) - } + require.NoError(t, err) + conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) - if err != nil { - t.Error(err) - return - } - go srv.Close() + require.NoError(t, err) + + go l.Close() rec := make([]byte, 16) _, err = conn.Read(rec) - if !errors.Is(err, io.EOF) { - t.Error(err) - return - } + assert.ErrorIs(t, err, io.EOF) }) t.Run("Disconnect", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - if err != nil { - t.Error(err) - return - } + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + err = dialer.Close() - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + err = dialer.Ping(context.Background()) - if err != webrtc.ErrConnectionClosed { - t.Error(err) - } + assert.ErrorIs(t, err, webrtc.ErrConnectionClosed) }) t.Run("Disconnect DialContext", func(t *testing.T) { t.Parallel() tcpListener, err := net.Listen("tcp", "0.0.0.0:0") - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) go func() { _, _ = tcpListener.Accept() }() connectAddr, listenAddr := createDumbBroker(t) - _, err = Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - if err != nil { - t.Error(err) - return - } + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) dialer, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{ ICEServers: []webrtc.ICEServer{{ @@ -260,42 +211,32 @@ func TestDial(t *testing.T) { CredentialType: webrtc.ICECredentialTypePassword, }}, }) - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + conn, err := dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String()) - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + // Close the TURN server before reading... // WebRTC connections take a few seconds to timeout. closeTurn() _, err = conn.Read(make([]byte, 16)) - if err != io.EOF { - t.Error(err) - return - } + assert.ErrorIs(t, err, io.EOF) }) t.Run("Closed", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - if err != nil { - t.Error(err) - return - } + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) go func() { _ = dialer.Close() }() + select { case <-dialer.Closed(): case <-time.NewTimer(time.Second).C: @@ -334,11 +275,12 @@ func BenchmarkThroughput(b *testing.B) { } }() connectAddr, listenAddr := createDumbBroker(b) - _, err = Listen(context.Background(), slogtest.Make(b, nil), listenAddr, "") + l, err := Listen(context.Background(), slogtest.Make(b, nil), listenAddr, "") if err != nil { b.Error(err) return } + defer l.Close() dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { diff --git a/wsnet/listen.go b/wsnet/listen.go index 071c5b9c..02f13f41 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -49,6 +49,7 @@ func Listen(ctx context.Context, log slog.Logger, broker string, turnProxyAuthTo log: log, broker: broker, connClosers: make([]io.Closer, 0), + closed: make(chan struct{}, 1), turnProxyAuthToken: turnProxyAuthToken, } @@ -62,6 +63,14 @@ func Listen(ctx context.Context, log slog.Logger, broker string, turnProxyAuthTo go func() { for { err := <-ch + select { + case _, ok := <-l.closed: + if !ok { + return + } + default: + } + if errors.Is(err, io.EOF) || errors.Is(err, yamux.ErrKeepAliveTimeout) { l.log.Warn(ctx, "disconnected from broker", slog.Error(err)) @@ -101,7 +110,7 @@ type listener struct { ws *websocket.Conn connClosers []io.Closer connClosersMut sync.Mutex - + closed chan struct{} nextConnNumber int64 } @@ -340,7 +349,7 @@ func (l *listener) handle(ctx context.Context, msg BrokerMessage) func(dc *webrt return } - ctx = slog.With(ctx, + ctx := slog.With(ctx, slog.F("dc_id", dc.ID()), slog.F("dc_label", dc.Label()), slog.F("dc_proto", dc.Protocol()), @@ -428,12 +437,22 @@ func (l *listener) Close() error { l.log.Info(context.Background(), "listener closed") l.connClosersMut.Lock() + defer l.connClosersMut.Unlock() + + select { + case _, ok := <-l.closed: + if !ok { + return errors.New("already closed") + } + default: + } + close(l.closed) + for _, connCloser := range l.connClosers { // We can ignore the error here... it doesn't // really matter if these fail to close. _ = connCloser.Close() } - l.connClosersMut.Unlock() return l.ws.Close(websocket.StatusNormalClosure, "") } diff --git a/wsnet/listen_test.go b/wsnet/listen_test.go index 55efb0bc..78b56691 100644 --- a/wsnet/listen_test.go +++ b/wsnet/listen_test.go @@ -2,13 +2,13 @@ package wsnet import ( "context" - "fmt" - "net" "net/http" + "net/http/httptest" "testing" "time" "cdr.dev/slog/sloggers/slogtest" + "github.com/stretchr/testify/require" "nhooyr.io/websocket" ) @@ -22,9 +22,6 @@ func TestListen(t *testing.T) { var ( connCh = make(chan *websocket.Conn) mux = http.NewServeMux() - srv = http.Server{ - Handler: mux, - } ) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { ws, err := websocket.Accept(w, r, nil) @@ -35,39 +32,20 @@ func TestListen(t *testing.T) { connCh <- ws }) - listener, err := net.Listen("tcp4", "127.0.0.1:0") - if err != nil { - t.Error(err) - return - } - go func() { - _ = srv.Serve(listener) - }() - addr := listener.Addr() - broker := fmt.Sprintf("http://%s/", addr.String()) + s := httptest.NewServer(mux) + defer s.Close() - _, err = Listen(context.Background(), slogtest.Make(t, nil), broker, "") - if err != nil { - t.Error(err) - return - } + l, err := Listen(context.Background(), slogtest.Make(t, nil), s.URL, "") + require.NoError(t, err) + defer l.Close() conn := <-connCh - _ = listener.Close() - // We need to close the connection too... closing a TCP - // listener does not close active local connections. - _ = conn.Close(websocket.StatusGoingAway, "") + + // Kill the server connection. + err = conn.Close(websocket.StatusGoingAway, "") + require.NoError(t, err) // At least a few retry attempts should be had... time.Sleep(connectionRetryInterval * 5) - - listener, err = net.Listen("tcp4", addr.String()) - if err != nil { - t.Error(err) - return - } - go func() { - _ = srv.Serve(listener) - }() <-connCh }) }