diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml deleted file mode 100644 index fb83c3a9..00000000 --- a/.github/FUNDING.yml +++ /dev/null @@ -1 +0,0 @@ -github: nhooyr diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..fb0a4558 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,24 @@ +version: 2 +updates: + # Track in case we ever add dependencies. + - package-ecosystem: 'gomod' + directory: '/' + schedule: + interval: 'weekly' + commit-message: + prefix: 'chore' + + # Keep example and test/benchmark deps up-to-date. + - package-ecosystem: 'gomod' + directories: + - '/internal/examples' + - '/internal/thirdparty' + schedule: + interval: 'monthly' + commit-message: + prefix: 'chore' + labels: [] + groups: + internal-deps: + patterns: + - '*' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 13ddbf3e..9f7aed46 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,11 @@ name: ci -on: [push, pull_request] +on: + push: + branches: + - master + pull_request: + branches: + - master concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} cancel-in-progress: true @@ -20,17 +26,25 @@ jobs: - uses: actions/checkout@v4 - run: go version - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod - run: ./ci/lint.sh test: runs-on: ubuntu-latest steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: ./ci/test.sh - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: coverage.html path: ./ci/out/coverage.html diff --git a/.github/workflows/daily.yml b/.github/workflows/daily.yml index 2ba9ce34..0eac94cc 100644 --- a/.github/workflows/daily.yml +++ b/.github/workflows/daily.yml @@ -19,12 +19,18 @@ jobs: test: runs-on: ubuntu-latest steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/test.sh - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: coverage.html path: ./ci/out/coverage.html @@ -41,6 +47,12 @@ jobs: test-dev: runs-on: ubuntu-latest steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns - uses: actions/checkout@v4 with: ref: dev @@ -48,7 +60,7 @@ jobs: with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/test.sh - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: coverage-dev.html path: ./ci/out/coverage.html diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml new file mode 100644 index 00000000..6ea76ab6 --- /dev/null +++ b/.github/workflows/static.yml @@ -0,0 +1,52 @@ +name: static + +on: + push: + branches: ['master'] + workflow_dispatch: + +# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages. +permissions: + contents: read + pages: write + id-token: write + +concurrency: + group: pages + cancel-in-progress: true + +jobs: + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Pages + uses: actions/configure-pages@v5 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - name: Generate coverage and badge + run: | + ./ci/test.sh + mkdir -p ./ci/out/static + cp ./ci/out/coverage.html ./ci/out/static/coverage.html + percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%') + wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success" + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: ./ci/out/static/ + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/README.md b/README.md index c74b79dd..80d2b3cc 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # websocket [![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket) -[![Go Coverage](https://img.shields.io/badge/coverage-91%25-success)](https://github.com/coder/websocket/coverage.html) +[![Go Coverage](https://coder.github.io/websocket/coverage.svg)](https://coder.github.io/websocket/coverage.html) websocket is a minimal and idiomatic WebSocket library for Go. @@ -63,7 +63,9 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { } defer c.CloseNow() - ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + // Set the context as needed. Use of r.Context() is not recommended + // to avoid surprising behavior (see http.Hijacker). + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() var v interface{} diff --git a/accept.go b/accept.go index f672a730..f45fdd0b 100644 --- a/accept.go +++ b/accept.go @@ -5,6 +5,7 @@ package websocket import ( "bytes" + "context" "crypto/sha1" "encoding/base64" "errors" @@ -14,7 +15,7 @@ import ( "net/http" "net/textproto" "net/url" - "path/filepath" + "path" "strings" "github.com/coder/websocket/internal/errd" @@ -41,8 +42,8 @@ type AcceptOptions struct { // One would set this field to []string{"example.com"} to authorize example.com to connect. // // Each pattern is matched case insensitively against the request origin host - // with filepath.Match. - // See https://golang.org/pkg/path/filepath/#Match + // with path.Match. + // See https://golang.org/pkg/path/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. @@ -62,6 +63,22 @@ type AcceptOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { @@ -79,6 +96,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // // Accept will write a response to w on all errors. +// +// Note that using the http.Request Context after Accept returns may lead to +// unexpected behavior (see http.Hijacker). func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return accept(w, r, opts) } @@ -96,7 +116,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { - if errors.Is(err, filepath.ErrBadPattern) { + if errors.Is(err, path.ErrBadPattern) { log.Printf("websocket: %v", err) err = errors.New(http.StatusText(http.StatusForbidden)) } @@ -105,7 +125,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } } - hj, ok := w.(http.Hijacker) + hj, ok := hijacker(w) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) @@ -153,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con client: false, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: brw.Reader, bw: brw.Writer, @@ -221,7 +243,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { for _, hostPattern := range originHosts { matched, err := match(hostPattern, u.Host) if err != nil { - return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) + return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err) } if matched { return nil @@ -234,7 +256,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { } func match(pattern, s string) (bool, error) { - return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) + return path.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { diff --git a/accept_test.go b/accept_test.go index 4f799126..3b45ac5c 100644 --- a/accept_test.go +++ b/accept_test.go @@ -143,6 +143,33 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) + + t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + _, err := Accept(w, r, nil) + assert.Contains(t, err, "failed to hijack connection") + }) + t.Run("closeRace", func(t *testing.T) { t.Parallel() @@ -534,3 +561,14 @@ var _ http.Hijacker = mockHijacker{} func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return mj.hijack() } + +type mockUnwrapper struct { + http.ResponseWriter + unwrap func() http.ResponseWriter +} + +var _ rwUnwrapper = mockUnwrapper{} + +func (mu mockUnwrapper) Unwrap() http.ResponseWriter { + return mu.unwrap() +} diff --git a/autobahn_test.go b/autobahn_test.go index b1b3a7e9..cd0cc9bb 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -92,7 +92,7 @@ func TestAutobahn(t *testing.T) { } }) - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) + c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil) assert.Success(t, err) c.Close(websocket.StatusNormalClosure, "") diff --git a/ci/fmt.sh b/ci/fmt.sh index 31d0c15d..e319a1e4 100755 --- a/ci/fmt.sh +++ b/ci/fmt.sh @@ -2,22 +2,24 @@ set -eu cd -- "$(dirname "$0")/.." +# Pin golang.org/x/tools, the go.mod of v0.25.0 is incompatible with Go 1.19. +X_TOOLS_VERSION=v0.24.0 + go mod tidy (cd ./internal/thirdparty && go mod tidy) (cd ./internal/examples && go mod tidy) gofmt -w -s . -go run golang.org/x/tools/cmd/goimports@latest -w "-local=$(go list -m)" . +go run golang.org/x/tools/cmd/goimports@${X_TOOLS_VERSION} -w "-local=$(go list -m)" . -npx prettier@3.0.3 \ - --write \ +git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html" | xargs npx prettier@3.3.3 \ + --check \ --log-level=warn \ --print-width=90 \ --no-semi \ --single-quote \ - --arrow-parens=avoid \ - $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") + --arrow-parens=avoid -go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go +go run golang.org/x/tools/cmd/stringer@${X_TOOLS_VERSION} -type=opcode,MessageType,StatusCode -output=stringer.go if [ "${CI-}" ]; then git diff --exit-code diff --git a/ci/lint.sh b/ci/lint.sh index 3cf8eee4..cf9d1abd 100755 --- a/ci/lint.sh +++ b/ci/lint.sh @@ -1,11 +1,12 @@ #!/bin/sh +set -x set -eu cd -- "$(dirname "$0")/.." go vet ./... GOOS=js GOARCH=wasm go vet ./... -go install honnef.co/go/tools/cmd/staticcheck@latest +go install honnef.co/go/tools/cmd/staticcheck@v0.4.7 staticcheck ./... GOOS=js GOARCH=wasm staticcheck ./... @@ -15,7 +16,7 @@ govulncheck() { cat "$tmpf" fi } -go install golang.org/x/vuln/cmd/govulncheck@latest +go install golang.org/x/vuln/cmd/govulncheck@v1.1.1 govulncheck ./... GOOS=js GOARCH=wasm govulncheck ./... diff --git a/ci/test.sh b/ci/test.sh index a3007614..cc3c22d7 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -24,7 +24,7 @@ cd -- "$(dirname "$0")/.." ) -go install github.com/agnivade/wasmbrowsertest@latest +go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... sed -i.bak '/stringer\.go/d' ci/out/coverage.prof sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof diff --git a/close.go b/close.go index ff2e878a..f94951dc 100644 --- a/close.go +++ b/close.go @@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode { func (c *Conn) Close(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - if !c.casClosing() { + if c.casClosing() { err = c.waitGoroutines() if err != nil { return err @@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) { func (c *Conn) CloseNow() (err error) { defer errd.Wrap(&err, "failed to immediately close WebSocket") - if !c.casClosing() { + if c.casClosing() { err = c.waitGoroutines() if err != nil { return err @@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) { } func (c *Conn) casClosing() bool { - c.closeMu.Lock() - defer c.closeMu.Unlock() - if !c.closing { - c.closing = true - return true - } - return false + return c.closing.Swap(true) } func (c *Conn) isClosed() bool { diff --git a/conn.go b/conn.go index 8690fb3b..42fe89fe 100644 --- a/conn.go +++ b/conn.go @@ -69,17 +69,25 @@ type Conn struct { writeHeaderBuf [8]byte writeHeader header + // Close handshake state. + closeStateMu sync.RWMutex + closeReceivedErr error + closeSentErr error + + // CloseRead state. closeReadMu sync.Mutex closeReadCtx context.Context closeReadDone chan struct{} + closing atomic.Bool + closeMu sync.Mutex // Protects following. closed chan struct{} - closeMu sync.Mutex - closing bool - pingCounter int32 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} + pingCounter atomic.Int64 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) } type connConfig struct { @@ -88,6 +96,8 @@ type connConfig struct { client bool copts *compressionOptions flateThreshold int + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) br *bufio.Reader bw *bufio.Writer @@ -108,8 +118,10 @@ func newConn(cfg connConfig) *Conn { writeTimeout: make(chan context.Context), timeoutLoopDone: make(chan struct{}), - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + onPingReceived: cfg.onPingReceived, + onPongReceived: cfg.onPongReceived, } c.readMu = newMu(c) @@ -200,9 +212,9 @@ func (c *Conn) flate() bool { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) + p := c.pingCounter.Add(1) - err := c.ping(ctx, strconv.Itoa(int(p))) + err := c.ping(ctx, strconv.FormatInt(p, 10)) if err != nil { return fmt.Errorf("failed to ping: %w", err) } diff --git a/conn_test.go b/conn_test.go index be7c9983..45bb75be 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "os" @@ -96,6 +97,85 @@ func TestConn(t *testing.T) { assert.Contains(t, err, "failed to wait for pong") }) + t.Run("pingReceivedPongReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Success(t, err) + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2) + assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1)) + }) + + t.Run("pingReceivedPongNotReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Contains(t, err, "failed to wait for pong") + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1)) + }) + t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) @@ -364,7 +444,7 @@ func TestWasm(t *testing.T) { defer cancel() cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) + cmd.Env = append(cleanEnv(os.Environ()), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() if err != nil { @@ -372,6 +452,18 @@ func TestWasm(t *testing.T) { } } +func cleanEnv(env []string) (out []string) { + for _, e := range env { + // Filter out GITHUB envs and anything with token in it, + // especially GITHUB_TOKEN in CI as it breaks TestWasm. + if strings.HasPrefix(e, "GITHUB") || strings.Contains(e, "TOKEN") { + continue + } + out = append(out, e) + } + return out +} + func assertCloseStatus(exp websocket.StatusCode, err error) error { if websocket.CloseStatus(err) == -1 { return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) @@ -448,7 +540,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } func BenchmarkConn(b *testing.B) { - var benchCases = []struct { + benchCases := []struct { name string mode websocket.CompressionMode }{ @@ -613,3 +705,149 @@ func TestConcurrentClosePing(t *testing.T) { }() } } + +func TestConnClosePropagation(t *testing.T) { + t.Parallel() + + want := []byte("hello") + keepWriting := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + err := c.Write(context.Background(), websocket.MessageText, want) + if err != nil { + return err + } + } + }) + } + keepReading := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + _, got, err := c.Read(context.Background()) + if err != nil { + return err + } + if !bytes.Equal(want, got) { + return fmt.Errorf("unexpected message: want %q, got %q", want, got) + } + } + }) + } + checkReadErr := func(t *testing.T, err error) { + // Check read error (output depends on when read is called in relation to connection closure). + var ce websocket.CloseError + if errors.As(err, &ce) { + assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code) + } else { + assert.ErrorIs(t, net.ErrClosed, err) + } + } + checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) { + for _, c := range conn { + // Check write error. + err := c.Write(context.Background(), websocket.MessageText, want) + assert.ErrorIs(t, net.ErrClosed, err) + + _, _, err = c.Read(context.Background()) + checkReadErr(t, err) + } + } + + t.Run("CloseOtherSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + + _, got, err := other.Read(tt.ctx) + assert.Success(t, err) + assert.Equal(t, "msg", want, got) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + otherReadErr := keepReading(other) + + err := this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseOtherSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = other.CloseRead(tt.ctx) + errs := keepReading(this) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-errs: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + thisReadErr := keepReading(this) + otherReadErr := keepReading(other) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) +} diff --git a/dial.go b/dial.go index ad61a35d..0b11ecbb 100644 --- a/dial.go +++ b/dial.go @@ -48,6 +48,22 @@ type DialOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { @@ -163,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( client: true, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil diff --git a/hijack.go b/hijack.go new file mode 100644 index 00000000..9cce45ca --- /dev/null +++ b/hijack.go @@ -0,0 +1,33 @@ +//go:build !js + +package websocket + +import ( + "net/http" +) + +type rwUnwrapper interface { + Unwrap() http.ResponseWriter +} + +// hijacker returns the Hijacker interface of the http.ResponseWriter. +// It follows the Unwrap method of the http.ResponseWriter if available, +// matching the behavior of http.ResponseController. If the Hijacker +// interface is not found, it returns false. +// +// Since the http.ResponseController is not available in Go 1.19, and +// does not support checking the presence of the Hijacker interface, +// this function is used to provide a consistent way to check for the +// Hijacker interface across Go versions. +func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) { + for { + switch t := rw.(type) { + case http.Hijacker: + return t, true + case rwUnwrapper: + rw = t.Unwrap() + default: + return nil, false + } + } +} diff --git a/hijack_go120_test.go b/hijack_go120_test.go new file mode 100644 index 00000000..0f0673a9 --- /dev/null +++ b/hijack_go120_test.go @@ -0,0 +1,38 @@ +//go:build !js && go1.20 + +package websocket + +import ( + "bufio" + "errors" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/coder/websocket/internal/test/assert" +) + +func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + _, _, err := http.NewResponseController(w).Hijack() + assert.Contains(t, err, "haha") + hj, ok := hijacker(w) + assert.Equal(t, "hijacker found", ok, true) + _, _, err = hj.Hijack() + assert.Contains(t, err, "haha") +} diff --git a/internal/bpool/bpool.go b/internal/bpool/bpool.go index aa826fba..12cf577a 100644 --- a/internal/bpool/bpool.go +++ b/internal/bpool/bpool.go @@ -5,15 +5,16 @@ import ( "sync" ) -var bpool sync.Pool +var bpool = sync.Pool{ + New: func() any { + return &bytes.Buffer{} + }, +} // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { b := bpool.Get() - if b == nil { - return &bytes.Buffer{} - } return b.(*bytes.Buffer) } diff --git a/internal/examples/chat/chat.go b/internal/examples/chat/chat.go index 3cb1e021..29f304b7 100644 --- a/internal/examples/chat/chat.go +++ b/internal/examples/chat/chat.go @@ -70,7 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { - err := cs.subscribe(r.Context(), w, r) + err := cs.subscribe(w, r) if errors.Is(err, context.Canceled) { return } @@ -111,7 +111,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. -func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (cs *chatServer) subscribe(w http.ResponseWriter, r *http.Request) error { var mu sync.Mutex var c *websocket.Conn var closed bool @@ -142,7 +142,7 @@ func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *h mu.Unlock() defer c.CloseNow() - ctx = c.CloseRead(ctx) + ctx := c.CloseRead(context.Background()) for { select { diff --git a/internal/examples/chat/chat_test.go b/internal/examples/chat/chat_test.go index 8eb72051..dcada0b2 100644 --- a/internal/examples/chat/chat_test.go +++ b/internal/examples/chat/chat_test.go @@ -52,7 +52,7 @@ func Test_chatServer(t *testing.T) { // 10 clients are started that send 128 different // messages of max 128 bytes concurrently. // - // The test verifies that every message is seen by ever client + // The test verifies that every message is seen by every client // and no errors occur anywhere. t.Run("concurrency", func(t *testing.T) { t.Parallel() diff --git a/internal/examples/echo/server.go b/internal/examples/echo/server.go index a44d20b5..37e2f2c4 100644 --- a/internal/examples/echo/server.go +++ b/internal/examples/echo/server.go @@ -37,7 +37,7 @@ func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) for { - err = echo(r.Context(), c, l) + err = echo(c, l) if websocket.CloseStatus(err) == websocket.StatusNormalClosure { return } @@ -51,8 +51,8 @@ func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // echo reads from the WebSocket connection and then writes // the received message back to it. // The entire function has 10s to complete. -func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { - ctx, cancel := context.WithTimeout(ctx, time.Second*10) +func echo(c *websocket.Conn, l *rate.Limiter) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() err := l.Wait(ctx) diff --git a/internal/examples/go.mod b/internal/examples/go.mod index 4f7a8a70..2aa1ee02 100644 --- a/internal/examples/go.mod +++ b/internal/examples/go.mod @@ -6,5 +6,5 @@ replace github.com/coder/websocket => ../.. require ( github.com/coder/websocket v0.0.0-00010101000000-000000000000 - golang.org/x/time v0.3.0 + golang.org/x/time v0.7.0 ) diff --git a/internal/examples/go.sum b/internal/examples/go.sum index f8a07e82..60aa8f9a 100644 --- a/internal/examples/go.sum +++ b/internal/examples/go.sum @@ -1,2 +1,2 @@ -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/internal/thirdparty/go.mod b/internal/thirdparty/go.mod index d946ffae..e060ce67 100644 --- a/internal/thirdparty/go.mod +++ b/internal/thirdparty/go.mod @@ -6,38 +6,40 @@ replace github.com/coder/websocket => ../.. require ( github.com/coder/websocket v0.0.0-00010101000000-000000000000 - github.com/gin-gonic/gin v1.9.1 - github.com/gobwas/ws v1.3.0 - github.com/gorilla/websocket v1.5.0 - github.com/lesismal/nbio v1.3.18 + github.com/gin-gonic/gin v1.10.0 + github.com/gobwas/ws v1.4.0 + github.com/gorilla/websocket v1.5.3 + github.com/lesismal/nbio v1.5.12 ) require ( - github.com/bytedance/sonic v1.9.1 // indirect - github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect - github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/cpuid/v2 v2.2.4 // indirect - github.com/leodido/go-urn v1.2.4 // indirect - github.com/lesismal/llib v1.1.12 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/lesismal/llib v1.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.11 // indirect - golang.org/x/arch v0.3.0 // indirect - golang.org/x/crypto v0.9.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.17.0 // indirect - golang.org/x/text v0.9.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect + golang.org/x/arch v0.8.0 // indirect + golang.org/x/crypto v0.23.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.15.0 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/internal/thirdparty/go.sum b/internal/thirdparty/go.sum index 1f542103..2352ac75 100644 --- a/internal/thirdparty/go.sum +++ b/internal/thirdparty/go.sum @@ -1,129 +1,107 @@ -github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= -github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= -github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= -github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= -github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= +github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= -github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0= -github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= +github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= +github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= -github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= -github.com/lesismal/llib v1.1.12 h1:KJFB8bL02V+QGIvILEw/w7s6bKj9Ps9Px97MZP2EOk0= -github.com/lesismal/llib v1.1.12/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= -github.com/lesismal/nbio v1.3.18 h1:kmJZlxjQpVfuCPYcXdv0Biv9LHVViJZet5K99Xs3RAs= -github.com/lesismal/nbio v1.3.18/go.mod h1:KWlouFT5cgDdW5sMX8RsHASUMGniea9X0XIellZ0B38= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lesismal/llib v1.1.13 h1:+w1+t0PykXpj2dXQck0+p6vdC9/mnbEXHgUy/HXDGfE= +github.com/lesismal/llib v1.1.13/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= +github.com/lesismal/nbio v1.5.12 h1:YcUjjmOvmKEANs6Oo175JogXvHy8CuE7i6ccjM2/tv4= +github.com/lesismal/nbio v1.5.12/go.mod h1:QsxE0fKFe1PioyjuHVDn2y8ktYK7xv9MFbpkoRFj8vI= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= -github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= -github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= -golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/xsync/int64.go b/internal/xsync/int64.go deleted file mode 100644 index a0c40204..00000000 --- a/internal/xsync/int64.go +++ /dev/null @@ -1,23 +0,0 @@ -package xsync - -import ( - "sync/atomic" -) - -// Int64 represents an atomic int64. -type Int64 struct { - // We do not use atomic.Load/StoreInt64 since it does not - // work on 32 bit computers but we need 64 bit integers. - i atomic.Value -} - -// Load loads the int64. -func (v *Int64) Load() int64 { - i, _ := v.i.Load().(int64) - return i -} - -// Store stores the int64. -func (v *Int64) Store(i int64) { - v.i.Store(i) -} diff --git a/netconn.go b/netconn.go index 86f7dadb..b118e4d3 100644 --- a/netconn.go +++ b/netconn.go @@ -68,7 +68,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.writeMu.unlock() // Prevents future writes from writing until the deadline is reset. - atomic.StoreInt64(&nc.writeExpired, 1) + nc.writeExpired.Store(1) }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C @@ -84,7 +84,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.readMu.unlock() // Prevents future reads from reading until the deadline is reset. - atomic.StoreInt64(&nc.readExpired, 1) + nc.readExpired.Store(1) }) if !nc.readTimer.Stop() { <-nc.readTimer.C @@ -94,25 +94,22 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { } type netConn struct { - // These must be first to be aligned on 32 bit platforms. - // https://github.com/nhooyr/websocket/pull/438 - readExpired int64 - writeExpired int64 - c *Conn msgType MessageType - writeTimer *time.Timer - writeMu *mu - writeCtx context.Context - writeCancel context.CancelFunc - - readTimer *time.Timer - readMu *mu - readCtx context.Context - readCancel context.CancelFunc - readEOFed bool - reader io.Reader + writeTimer *time.Timer + writeMu *mu + writeExpired atomic.Int64 + writeCtx context.Context + writeCancel context.CancelFunc + + readTimer *time.Timer + readMu *mu + readExpired atomic.Int64 + readCtx context.Context + readCancel context.CancelFunc + readEOFed bool + reader io.Reader } var _ net.Conn = &netConn{} @@ -129,7 +126,7 @@ func (nc *netConn) Write(p []byte) (int, error) { nc.writeMu.forceLock() defer nc.writeMu.unlock() - if atomic.LoadInt64(&nc.writeExpired) == 1 { + if nc.writeExpired.Load() == 1 { return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) } @@ -157,7 +154,7 @@ func (nc *netConn) Read(p []byte) (int, error) { } func (nc *netConn) read(p []byte) (int, error) { - if atomic.LoadInt64(&nc.readExpired) == 1 { + if nc.readExpired.Load() == 1 { return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) } @@ -209,7 +206,7 @@ func (nc *netConn) SetDeadline(t time.Time) error { } func (nc *netConn) SetWriteDeadline(t time.Time) error { - atomic.StoreInt64(&nc.writeExpired, 0) + nc.writeExpired.Store(0) if t.IsZero() { nc.writeTimer.Stop() } else { @@ -223,7 +220,7 @@ func (nc *netConn) SetWriteDeadline(t time.Time) error { } func (nc *netConn) SetReadDeadline(t time.Time) error { - atomic.StoreInt64(&nc.readExpired, 0) + nc.readExpired.Store(0) if t.IsZero() { nc.readTimer.Stop() } else { diff --git a/read.go b/read.go index 1b9404b8..2db22435 100644 --- a/read.go +++ b/read.go @@ -11,11 +11,11 @@ import ( "io" "net" "strings" + "sync/atomic" "time" "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" - "github.com/coder/websocket/internal/xsync" ) // Reader reads from the connection until there is a WebSocket @@ -217,57 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } } -func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { +// prepareRead sets the readTimeout context and returns a done function +// to be called after the read is done. It also returns an error if the +// connection is closed. The reference to the error is used to assign +// an error depending on if the connection closed or the context timed +// out during use. Typically the referenced error is a named return +// variable of the function calling this method. +func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { select { case <-c.closed: - return header{}, net.ErrClosed + return nil, net.ErrClosed case c.readTimeout <- ctx: } - h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) - if err != nil { + done := func() { select { case <-c.closed: - return header{}, net.ErrClosed - case <-ctx.Done(): - return header{}, ctx.Err() - default: - return header{}, err + if *err != nil { + *err = net.ErrClosed + } + case c.readTimeout <- context.Background(): + } + if *err != nil && ctx.Err() != nil { + *err = ctx.Err() } } - select { - case <-c.closed: - return header{}, net.ErrClosed - case c.readTimeout <- context.Background(): + c.closeStateMu.Lock() + closeReceivedErr := c.closeReceivedErr + c.closeStateMu.Unlock() + if closeReceivedErr != nil { + defer done() + return nil, closeReceivedErr } - return h, nil + return done, nil } -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { - select { - case <-c.closed: - return 0, net.ErrClosed - case c.readTimeout <- ctx: +func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return header{}, err } + defer readDone() - n, err := io.ReadFull(c.br, p) + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { - select { - case <-c.closed: - return n, net.ErrClosed - case <-ctx.Done(): - return n, ctx.Err() - default: - return n, fmt.Errorf("failed to read frame payload: %w", err) - } + return header{}, err } - select { - case <-c.closed: - return n, net.ErrClosed - case c.readTimeout <- context.Background(): + return h, nil +} + +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return 0, err + } + defer readDone() + + n, err := io.ReadFull(c.br, p) + if err != nil { + return n, fmt.Errorf("failed to read frame payload: %w", err) } return n, err @@ -301,8 +312,16 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { switch h.opcode { case opPing: + if c.onPingReceived != nil { + if !c.onPingReceived(ctx, b) { + return nil + } + } return c.writeControl(ctx, opPong, b) case opPong: + if c.onPongReceived != nil { + c.onPongReceived(ctx, b) + } c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] c.activePingsMu.Unlock() @@ -325,9 +344,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } err = fmt.Errorf("received close frame: %w", ce) - c.writeClose(ce.Code, ce.Reason) - c.readMu.unlock() - c.close() + c.closeStateMu.Lock() + c.closeReceivedErr = err + closeSent := c.closeSentErr != nil + c.closeStateMu.Unlock() + + // Only unlock readMu if this connection is being closed becaue + // c.close will try to acquire the readMu lock. We unlock for + // writeClose as well because it may also call c.close. + if !closeSent { + c.readMu.unlock() + _ = c.writeClose(ce.Code, ce.Reason) + } + if !c.casClosing() { + c.readMu.unlock() + _ = c.close() + } return err } @@ -465,7 +497,7 @@ func (mr *msgReader) read(p []byte) (int, error) { type limitReader struct { c *Conn r io.Reader - limit xsync.Int64 + limit atomic.Int64 n int64 } diff --git a/write.go b/write.go index e294a680..7324de74 100644 --- a/write.go +++ b/write.go @@ -5,6 +5,7 @@ package websocket import ( "bufio" + "compress/flate" "context" "crypto/rand" "encoding/binary" @@ -14,8 +15,6 @@ import ( "net" "time" - "compress/flate" - "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" ) @@ -249,22 +248,36 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } defer c.writeFrameMu.unlock() + defer func() { + if c.isClosed() && opcode == opClose { + err = nil + } + if err != nil { + if ctx.Err() != nil { + err = ctx.Err() + } else if c.isClosed() { + err = net.ErrClosed + } + err = fmt.Errorf("failed to write frame: %w", err) + } + }() + + c.closeStateMu.Lock() + closeSentErr := c.closeSentErr + c.closeStateMu.Unlock() + if closeSentErr != nil { + return 0, net.ErrClosed + } + select { case <-c.closed: return 0, net.ErrClosed case c.writeTimeout <- ctx: } - defer func() { - if err != nil { - select { - case <-c.closed: - err = net.ErrClosed - case <-ctx.Done(): - err = ctx.Err() - default: - } - err = fmt.Errorf("failed to write frame: %w", err) + select { + case <-c.closed: + case c.writeTimeout <- context.Background(): } }() @@ -303,13 +316,16 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } } - select { - case <-c.closed: - if opcode == opClose { - return n, nil + if opcode == opClose { + c.closeStateMu.Lock() + c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed) + closeReceived := c.closeReceivedErr != nil + c.closeStateMu.Unlock() + + if closeReceived && !c.casClosing() { + c.writeFrameMu.unlock() + _ = c.close() } - return n, net.ErrClosed - case c.writeTimeout <- context.Background(): } return n, nil diff --git a/ws_js.go b/ws_js.go index a8de0c63..5e324c47 100644 --- a/ws_js.go +++ b/ws_js.go @@ -12,11 +12,11 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "syscall/js" "github.com/coder/websocket/internal/bpool" "github.com/coder/websocket/internal/wsjs" - "github.com/coder/websocket/internal/xsync" ) // opcode represents a WebSocket opcode. @@ -45,7 +45,7 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit xsync.Int64 + msgReadLimit atomic.Int64 closeReadMu sync.Mutex closeReadCtx context.Context