Skip to content

Commit eac7d46

Browse files
Support client_max_window_bits and server_max_window_bits compression opts
1 parent 91013c1 commit eac7d46

File tree

6 files changed

+75
-13
lines changed

6 files changed

+75
-13
lines changed

accept.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"net/textproto"
1616
"net/url"
1717
"path"
18+
"strconv"
1819
"strings"
1920

2021
"github.com/coder/websocket/internal/errd"
@@ -298,15 +299,34 @@ func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOp
298299
case "server_no_context_takeover":
299300
copts.serverNoContextTakeover = true
300301
continue
301-
case "client_max_window_bits",
302-
"server_max_window_bits=15":
302+
case "client_max_window_bits":
303+
copts.clientMaxWindowBits = 15 // default
304+
continue
305+
case "server_max_window_bits":
306+
copts.serverMaxWindowBits = 15 // default
303307
continue
304308
}
305309

306310
if strings.HasPrefix(p, "client_max_window_bits=") {
307-
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
311+
// We don't need to change decoder settings; larger window decoder can read smaller windows.
312+
if v, err := strconv.Atoi(strings.TrimPrefix(p, "client_max_window_bits=")); err == nil {
313+
if v >= 8 && v <= 15 {
314+
copts.clientMaxWindowBits = v
315+
}
316+
}
308317
continue
309318
}
319+
320+
if strings.HasPrefix(p, "server_max_window_bits=") {
321+
vstr := strings.TrimPrefix(p, "server_max_window_bits=")
322+
v, err := strconv.Atoi(vstr)
323+
if err != nil || v < 8 || v > 15 {
324+
return nil, false // invalid per RFC
325+
}
326+
copts.serverMaxWindowBits = v
327+
continue
328+
}
329+
310330
return nil, false
311331
}
312332
return copts, true

accept_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,22 @@ func Test_selectDeflate(t *testing.T) {
515515
expCopts: &compressionOptions{
516516
clientNoContextTakeover: true,
517517
serverNoContextTakeover: true,
518+
519+
clientMaxWindowBits: 15,
520+
serverMaxWindowBits: 0,
521+
},
522+
expOK: true,
523+
},
524+
{
525+
name: "permessage-deflate/custom-client-window-bits",
526+
mode: CompressionNoContextTakeover,
527+
header: "permessage-deflate; client_max_window_bits=12",
528+
expCopts: &compressionOptions{
529+
clientNoContextTakeover: true,
530+
serverNoContextTakeover: true,
531+
532+
clientMaxWindowBits: 12,
533+
serverMaxWindowBits: 0,
518534
},
519535
expOK: true,
520536
},
@@ -531,6 +547,9 @@ func Test_selectDeflate(t *testing.T) {
531547
expCopts: &compressionOptions{
532548
clientNoContextTakeover: true,
533549
serverNoContextTakeover: true,
550+
551+
clientMaxWindowBits: 15,
552+
serverMaxWindowBits: 0,
534553
},
535554
expOK: true,
536555
},

compress.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
package websocket
44

55
import (
6-
"compress/flate"
6+
"strconv"
7+
8+
"github.com/klauspost/compress/flate"
9+
710
"io"
811
"sync"
912
)
@@ -53,12 +56,18 @@ func (m CompressionMode) opts() *compressionOptions {
5356
return &compressionOptions{
5457
clientNoContextTakeover: m == CompressionNoContextTakeover,
5558
serverNoContextTakeover: m == CompressionNoContextTakeover,
59+
60+
serverMaxWindowBits: 0,
61+
clientMaxWindowBits: 0,
5662
}
5763
}
5864

5965
type compressionOptions struct {
6066
clientNoContextTakeover bool
6167
serverNoContextTakeover bool
68+
69+
serverMaxWindowBits int
70+
clientMaxWindowBits int
6271
}
6372

6473
func (copts *compressionOptions) String() string {
@@ -69,6 +78,11 @@ func (copts *compressionOptions) String() string {
6978
if copts.serverNoContextTakeover {
7079
s += "; server_no_context_takeover"
7180
}
81+
82+
if copts.clientMaxWindowBits != 0 {
83+
s += "; client_max_window_bits=" + strconv.Itoa(copts.clientMaxWindowBits)
84+
}
85+
7286
return s
7387
}
7488

@@ -147,20 +161,24 @@ func putFlateReader(fr io.Reader) {
147161
flateReaderPool.Put(fr)
148162
}
149163

150-
var flateWriterPool sync.Pool
164+
var flateWriterPool [16]sync.Pool
151165

152-
func getFlateWriter(w io.Writer) *flate.Writer {
153-
fw, ok := flateWriterPool.Get().(*flate.Writer)
166+
func getFlateWriter(w io.Writer, bits int) *flate.Writer {
167+
fw, ok := flateWriterPool[bits].Get().(*flate.Writer)
154168
if !ok {
155-
fw, _ = flate.NewWriter(w, flate.BestSpeed)
169+
if bits == 0 {
170+
fw, _ = flate.NewWriter(w, flate.BestCompression)
171+
} else {
172+
fw, _ = flate.NewWriterWindow(w, 1<<bits)
173+
}
156174
return fw
157175
}
158176
fw.Reset(w)
159177
return fw
160178
}
161179

162-
func putFlateWriter(w *flate.Writer) {
163-
flateWriterPool.Put(w)
180+
func putFlateWriter(w *flate.Writer, bits int) {
181+
flateWriterPool[bits].Put(w)
164182
}
165183

166184
type slidingWindow struct {

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
module github.com/coder/websocket
22

33
go 1.23
4+
5+
require github.com/klauspost/compress v1.18.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
2+
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=

write.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package websocket
44

55
import (
66
"bufio"
7-
"compress/flate"
87
"context"
98
"crypto/rand"
109
"encoding/binary"
@@ -14,6 +13,8 @@ import (
1413
"net"
1514
"time"
1615

16+
"github.com/klauspost/compress/flate"
17+
1718
"github.com/coder/websocket/internal/errd"
1819
"github.com/coder/websocket/internal/util"
1920
)
@@ -79,7 +80,7 @@ func (mw *msgWriter) ensureFlate() {
7980
}
8081

8182
if mw.flateWriter == nil {
82-
mw.flateWriter = getFlateWriter(mw.trimWriter)
83+
mw.flateWriter = getFlateWriter(mw.trimWriter, mw.c.copts.serverMaxWindowBits)
8384
}
8485
mw.flate = true
8586
}
@@ -137,7 +138,7 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
137138

138139
func (mw *msgWriter) putFlateWriter() {
139140
if mw.flateWriter != nil {
140-
putFlateWriter(mw.flateWriter)
141+
putFlateWriter(mw.flateWriter, mw.c.copts.serverMaxWindowBits)
141142
mw.flateWriter = nil
142143
}
143144
}

0 commit comments

Comments
 (0)