",
}
for _, src := range srcs {
// The next line shouldn't infinite-loop.
@@ -477,7 +477,7 @@ func TestParseFragmentForeignContentTemplates(t *testing.T) {
}
func BenchmarkParser(b *testing.B) {
- buf, err := ioutil.ReadFile("testdata/go1.html")
+ buf, err := os.ReadFile("testdata/go1.html")
if err != nil {
b.Fatalf("could not read testdata/go1.html: %v", err)
}
diff --git a/html/token.go b/html/token.go
index de67f938a1..3c57880d69 100644
--- a/html/token.go
+++ b/html/token.go
@@ -910,9 +910,6 @@ func (z *Tokenizer) readTagAttrKey() {
return
}
switch c {
- case ' ', '\n', '\r', '\t', '\f', '/':
- z.pendingAttr[0].end = z.raw.end - 1
- return
case '=':
if z.pendingAttr[0].start+1 == z.raw.end {
// WHATWG 13.2.5.32, if we see an equals sign before the attribute name
@@ -920,7 +917,9 @@ func (z *Tokenizer) readTagAttrKey() {
continue
}
fallthrough
- case '>':
+ case ' ', '\n', '\r', '\t', '\f', '/', '>':
+ // WHATWG 13.2.5.33 Attribute name state
+ // We need to reconsume the char in the after attribute name state to support the / character
z.raw.end--
z.pendingAttr[0].end = z.raw.end
return
@@ -939,6 +938,11 @@ func (z *Tokenizer) readTagAttrVal() {
if z.err != nil {
return
}
+ if c == '/' {
+ // WHATWG 13.2.5.34 After attribute name state
+ // U+002F SOLIDUS (/) - Switch to the self-closing start tag state.
+ return
+ }
if c != '=' {
z.raw.end--
return
diff --git a/html/token_test.go b/html/token_test.go
index b2383a951c..a36d112d74 100644
--- a/html/token_test.go
+++ b/html/token_test.go
@@ -7,7 +7,7 @@ package html
import (
"bytes"
"io"
- "io/ioutil"
+ "os"
"reflect"
"runtime"
"strings"
@@ -601,6 +601,21 @@ var tokenTests = []tokenTest{
``,
`
`,
},
+ {
+ "forward slash before attribute name",
+ `
`,
+ `
`,
+ },
+ {
+ "forward slash before attribute name with spaces around",
+ `
`,
+ `
`,
+ },
+ {
+ "forward slash after attribute name followed by a character",
+ `
`,
+ `
`,
+ },
}
func TestTokenizer(t *testing.T) {
@@ -665,7 +680,7 @@ tests:
}
}
// Anything tokenized along with untokenized input or data left in the reader.
- assembled, err := ioutil.ReadAll(io.MultiReader(&tokenized, bytes.NewReader(z.Buffered()), r))
+ assembled, err := io.ReadAll(io.MultiReader(&tokenized, bytes.NewReader(z.Buffered()), r))
if err != nil {
t.Errorf("%s: ReadAll: %v", test.desc, err)
continue tests
@@ -851,7 +866,7 @@ const (
)
func benchmarkTokenizer(b *testing.B, level int) {
- buf, err := ioutil.ReadFile("testdata/go1.html")
+ buf, err := os.ReadFile("testdata/go1.html")
if err != nil {
b.Fatalf("could not read testdata/go1.html: %v", err)
}
diff --git a/http/httpguts/httplex.go b/http/httpguts/httplex.go
index 6e071e8524..9b4de94019 100644
--- a/http/httpguts/httplex.go
+++ b/http/httpguts/httplex.go
@@ -12,7 +12,7 @@ import (
"golang.org/x/net/idna"
)
-var isTokenTable = [127]bool{
+var isTokenTable = [256]bool{
'!': true,
'#': true,
'$': true,
@@ -93,12 +93,7 @@ var isTokenTable = [127]bool{
}
func IsTokenRune(r rune) bool {
- i := int(r)
- return i < len(isTokenTable) && isTokenTable[i]
-}
-
-func isNotToken(r rune) bool {
- return !IsTokenRune(r)
+ return r < utf8.RuneSelf && isTokenTable[byte(r)]
}
// HeaderValuesContainsToken reports whether any string in values
@@ -202,8 +197,8 @@ func ValidHeaderFieldName(v string) bool {
if len(v) == 0 {
return false
}
- for _, r := range v {
- if !IsTokenRune(r) {
+ for i := 0; i < len(v); i++ {
+ if !isTokenTable[v[i]] {
return false
}
}
diff --git a/http/httpguts/httplex_test.go b/http/httpguts/httplex_test.go
index a2c57f3927..791440b1a7 100644
--- a/http/httpguts/httplex_test.go
+++ b/http/httpguts/httplex_test.go
@@ -20,7 +20,7 @@ func isSeparator(c rune) bool {
return false
}
-func TestIsToken(t *testing.T) {
+func TestIsTokenRune(t *testing.T) {
for i := 0; i <= 130; i++ {
r := rune(i)
expected := isChar(r) && !isCtl(r) && !isSeparator(r)
@@ -30,6 +30,15 @@ func TestIsToken(t *testing.T) {
}
}
+func BenchmarkIsTokenRune(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var r rune
+ for ; r < 1024; r++ {
+ IsTokenRune(r)
+ }
+ }
+}
+
func TestHeaderValuesContainsToken(t *testing.T) {
tests := []struct {
vals []string
@@ -100,6 +109,44 @@ func TestHeaderValuesContainsToken(t *testing.T) {
}
}
+func TestValidHeaderFieldName(t *testing.T) {
+ tests := []struct {
+ in string
+ want bool
+ }{
+ {"", false},
+ {"Accept Charset", false},
+ {"Accept-Charset", true},
+ {"AccepT-EncodinG", true},
+ {"CONNECTION", true},
+ {"résumé", false},
+ }
+ for _, tt := range tests {
+ got := ValidHeaderFieldName(tt.in)
+ if tt.want != got {
+ t.Errorf("ValidHeaderFieldName(%q) = %t; want %t", tt.in, got, tt.want)
+ }
+ }
+}
+
+func BenchmarkValidHeaderFieldName(b *testing.B) {
+ names := []string{
+ "",
+ "Accept Charset",
+ "Accept-Charset",
+ "AccepT-EncodinG",
+ "CONNECTION",
+ "résumé",
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ for _, name := range names {
+ ValidHeaderFieldName(name)
+ }
+ }
+}
+
func TestPunycodeHostPort(t *testing.T) {
tests := []struct {
in, want string
diff --git a/http/httpproxy/go19_test.go b/http/httpproxy/go19_test.go
index 5f6e3d7ff1..5fca5ac454 100644
--- a/http/httpproxy/go19_test.go
+++ b/http/httpproxy/go19_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build go1.9
-// +build go1.9
package httpproxy_test
diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go
index c3bd9a1eeb..d89c257ae7 100644
--- a/http/httpproxy/proxy.go
+++ b/http/httpproxy/proxy.go
@@ -14,6 +14,7 @@ import (
"errors"
"fmt"
"net"
+ "net/netip"
"net/url"
"os"
"strings"
@@ -149,10 +150,7 @@ func parseProxy(proxy string) (*url.URL, error) {
}
proxyURL, err := url.Parse(proxy)
- if err != nil ||
- (proxyURL.Scheme != "http" &&
- proxyURL.Scheme != "https" &&
- proxyURL.Scheme != "socks5") {
+ if err != nil || proxyURL.Scheme == "" || proxyURL.Host == "" {
// proxy was bogus. Try prepending "http://" to it and
// see if that parses correctly. If not, we fall
// through and complain about the original one.
@@ -180,8 +178,10 @@ func (cfg *config) useProxy(addr string) bool {
if host == "localhost" {
return false
}
- ip := net.ParseIP(host)
- if ip != nil {
+ nip, err := netip.ParseAddr(host)
+ var ip net.IP
+ if err == nil {
+ ip = net.IP(nip.AsSlice())
if ip.IsLoopback() {
return false
}
@@ -363,6 +363,9 @@ type domainMatch struct {
}
func (m domainMatch) match(host, port string, ip net.IP) bool {
+ if ip != nil {
+ return false
+ }
if strings.HasSuffix(host, m.host) || (m.matchHost && host == m.host[1:]) {
return m.port == "" || m.port == port
}
diff --git a/http/httpproxy/proxy_test.go b/http/httpproxy/proxy_test.go
index d763732950..a1dd2e83fd 100644
--- a/http/httpproxy/proxy_test.go
+++ b/http/httpproxy/proxy_test.go
@@ -68,6 +68,12 @@ var proxyForURLTests = []proxyForURLTest{{
HTTPProxy: "cache.corp.example.com",
},
want: "http://cache.corp.example.com",
+}, {
+ // single label domain is recognized as scheme by url.Parse
+ cfg: httpproxy.Config{
+ HTTPProxy: "localhost",
+ },
+ want: "http://localhost",
}, {
cfg: httpproxy.Config{
HTTPProxy: "https://cache.corp.example.com",
@@ -88,6 +94,12 @@ var proxyForURLTests = []proxyForURLTest{{
HTTPProxy: "socks5://127.0.0.1",
},
want: "socks5://127.0.0.1",
+}, {
+ // Preserve unknown schemes.
+ cfg: httpproxy.Config{
+ HTTPProxy: "foo://host",
+ },
+ want: "foo://host",
}, {
// Don't use secure for http
cfg: httpproxy.Config{
@@ -199,6 +211,13 @@ var proxyForURLTests = []proxyForURLTest{{
},
req: "http://www.xn--fsq092h.com",
want: "",
+}, {
+ cfg: httpproxy.Config{
+ NoProxy: "example.com",
+ HTTPProxy: "proxy",
+ },
+ req: "http://[1000::%25.example.com]:123",
+ want: "http://proxy",
},
}
diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go
index 780968d6c1..e81b73e6a7 100644
--- a/http2/client_conn_pool.go
+++ b/http2/client_conn_pool.go
@@ -8,8 +8,8 @@ package http2
import (
"context"
- "crypto/tls"
"errors"
+ "net"
"net/http"
"sync"
)
@@ -158,7 +158,7 @@ func (c *dialCall) dial(ctx context.Context, addr string) {
// This code decides which ones live or die.
// The return value used is whether c was used.
// c is never closed.
-func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c *tls.Conn) (used bool, err error) {
+func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) {
p.mu.Lock()
for _, cc := range p.conns[key] {
if cc.CanTakeNewRequest() {
@@ -194,8 +194,8 @@ type addConnCall struct {
err error
}
-func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
- cc, err := t.NewClientConn(tc)
+func (c *addConnCall) run(t *Transport, key string, nc net.Conn) {
+ cc, err := t.NewClientConn(nc)
p := c.p
p.mu.Lock()
diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go
new file mode 100644
index 0000000000..f9e9a2fdaa
--- /dev/null
+++ b/http2/clientconn_test.go
@@ -0,0 +1,595 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Infrastructure for testing ClientConn.RoundTrip.
+// Put actual tests in transport_test.go.
+
+package http2
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "net/http"
+ "reflect"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "golang.org/x/net/http2/hpack"
+ "golang.org/x/net/internal/gate"
+)
+
+// TestTestClientConn demonstrates usage of testClientConn.
+func TestTestClientConn(t *testing.T) {
+ // newTestClientConn creates a *ClientConn and surrounding test infrastructure.
+ tc := newTestClientConn(t)
+
+ // tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames,
+ // and sends a SETTINGS frame to the client.
+ //
+ // Additional settings may be provided as optional parameters to greet.
+ tc.greet()
+
+ // Request bodies must either be constant (bytes.Buffer, strings.Reader)
+ // or created with newRequestBody.
+ body := tc.newRequestBody()
+ body.writeBytes(10) // 10 arbitrary bytes...
+ body.closeWithError(io.EOF) // ...followed by EOF.
+
+ // tc.roundTrip calls RoundTrip, but does not wait for it to return.
+ // It returns a testRoundTrip.
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
+ rt := tc.roundTrip(req)
+
+ // tc has a number of methods to check for expected frames sent.
+ // Here, we look for headers and the request body.
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: false,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"PUT"},
+ ":path": []string{"/"},
+ },
+ })
+ // Expect 10 bytes of request body in DATA frames.
+ tc.wantData(wantData{
+ streamID: rt.streamID(),
+ endStream: true,
+ size: 10,
+ multiple: true,
+ })
+
+ // tc.writeHeaders sends a HEADERS frame back to the client.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+
+ // Now that we've received headers, RoundTrip has finished.
+ // testRoundTrip has various methods to examine the response,
+ // or to fetch the response and/or error returned by RoundTrip
+ rt.wantStatus(200)
+ rt.wantBody(nil)
+}
+
+// A testClientConn allows testing ClientConn.RoundTrip against a fake server.
+//
+// A test using testClientConn consists of:
+// - actions on the client (calling RoundTrip, making data available to Request.Body);
+// - validation of frames sent by the client to the server; and
+// - providing frames from the server to the client.
+//
+// testClientConn manages synchronization, so tests can generally be written as
+// a linear sequence of actions and validations without additional synchronization.
+type testClientConn struct {
+ t *testing.T
+
+ tr *Transport
+ fr *Framer
+ cc *ClientConn
+ group *synctestGroup
+ testConnFramer
+
+ encbuf bytes.Buffer
+ enc *hpack.Encoder
+
+ roundtrips []*testRoundTrip
+
+ netconn *synctestNetConn
+}
+
+func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn {
+ tc := &testClientConn{
+ t: t,
+ tr: cc.t,
+ cc: cc,
+ group: cc.t.transportTestHooks.group.(*synctestGroup),
+ }
+
+ // srv is the side controlled by the test.
+ var srv *synctestNetConn
+ if cc.tconn == nil {
+ // If cc.tconn is nil, we're being called with a new conn created by the
+ // Transport's client pool. This path skips dialing the server, and we
+ // create a test connection pair here.
+ cc.tconn, srv = synctestNetPipe(tc.group)
+ } else {
+ // If cc.tconn is non-nil, we're in a test which provides a conn to the
+ // Transport via a TLSNextProto hook. Extract the test connection pair.
+ if tc, ok := cc.tconn.(*tls.Conn); ok {
+ // Unwrap any *tls.Conn to the underlying net.Conn,
+ // to avoid dealing with encryption in tests.
+ cc.tconn = tc.NetConn()
+ }
+ srv = cc.tconn.(*synctestNetConn).peer
+ }
+
+ srv.SetReadDeadline(tc.group.Now())
+ srv.autoWait = true
+ tc.netconn = srv
+ tc.enc = hpack.NewEncoder(&tc.encbuf)
+ tc.fr = NewFramer(srv, srv)
+ tc.testConnFramer = testConnFramer{
+ t: t,
+ fr: tc.fr,
+ dec: hpack.NewDecoder(initialHeaderTableSize, nil),
+ }
+ tc.fr.SetMaxReadFrameSize(10 << 20)
+ t.Cleanup(func() {
+ tc.closeWrite()
+ })
+
+ return tc
+}
+
+func (tc *testClientConn) readClientPreface() {
+ tc.t.Helper()
+ // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames.
+ buf := make([]byte, len(clientPreface))
+ if _, err := io.ReadFull(tc.netconn, buf); err != nil {
+ tc.t.Fatalf("reading preface: %v", err)
+ }
+ if !bytes.Equal(buf, clientPreface) {
+ tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface)
+ }
+}
+
+func newTestClientConn(t *testing.T, opts ...any) *testClientConn {
+ t.Helper()
+
+ tt := newTestTransport(t, opts...)
+ const singleUse = false
+ _, err := tt.tr.newClientConn(nil, singleUse)
+ if err != nil {
+ t.Fatalf("newClientConn: %v", err)
+ }
+
+ return tt.getConn()
+}
+
+// sync waits for the ClientConn under test to reach a stable state,
+// with all goroutines blocked on some input.
+func (tc *testClientConn) sync() {
+ tc.group.Wait()
+}
+
+// advance advances synthetic time by a duration.
+func (tc *testClientConn) advance(d time.Duration) {
+ tc.group.AdvanceTime(d)
+ tc.sync()
+}
+
+// hasFrame reports whether a frame is available to be read.
+func (tc *testClientConn) hasFrame() bool {
+ return len(tc.netconn.Peek()) > 0
+}
+
+// isClosed reports whether the peer has closed the connection.
+func (tc *testClientConn) isClosed() bool {
+ return tc.netconn.IsClosedByPeer()
+}
+
+// closeWrite causes the net.Conn used by the ClientConn to return an error
+// from Read calls.
+func (tc *testClientConn) closeWrite() {
+ tc.netconn.Close()
+}
+
+// testRequestBody is a Request.Body for use in tests.
+type testRequestBody struct {
+ tc *testClientConn
+ gate gate.Gate
+
+ // At most one of buf or bytes can be set at any given time:
+ buf bytes.Buffer // specific bytes to read from the body
+ bytes int // body contains this many arbitrary bytes
+
+ err error // read error (comes after any available bytes)
+}
+
+func (tc *testClientConn) newRequestBody() *testRequestBody {
+ b := &testRequestBody{
+ tc: tc,
+ gate: gate.New(false),
+ }
+ return b
+}
+
+func (b *testRequestBody) unlock() {
+ b.gate.Unlock(b.buf.Len() > 0 || b.bytes > 0 || b.err != nil)
+}
+
+// Read is called by the ClientConn to read from a request body.
+func (b *testRequestBody) Read(p []byte) (n int, _ error) {
+ if err := b.gate.WaitAndLock(context.Background()); err != nil {
+ return 0, err
+ }
+ defer b.unlock()
+ switch {
+ case b.buf.Len() > 0:
+ return b.buf.Read(p)
+ case b.bytes > 0:
+ if len(p) > b.bytes {
+ p = p[:b.bytes]
+ }
+ b.bytes -= len(p)
+ for i := range p {
+ p[i] = 'A'
+ }
+ return len(p), nil
+ default:
+ return 0, b.err
+ }
+}
+
+// Close is called by the ClientConn when it is done reading from a request body.
+func (b *testRequestBody) Close() error {
+ return nil
+}
+
+// writeBytes adds n arbitrary bytes to the body.
+func (b *testRequestBody) writeBytes(n int) {
+ defer b.tc.sync()
+ b.gate.Lock()
+ defer b.unlock()
+ b.bytes += n
+ b.checkWrite()
+ b.tc.sync()
+}
+
+// Write adds bytes to the body.
+func (b *testRequestBody) Write(p []byte) (int, error) {
+ defer b.tc.sync()
+ b.gate.Lock()
+ defer b.unlock()
+ n, err := b.buf.Write(p)
+ b.checkWrite()
+ return n, err
+}
+
+func (b *testRequestBody) checkWrite() {
+ if b.bytes > 0 && b.buf.Len() > 0 {
+ b.tc.t.Fatalf("can't interleave Write and writeBytes on request body")
+ }
+ if b.err != nil {
+ b.tc.t.Fatalf("can't write to request body after closeWithError")
+ }
+}
+
+// closeWithError sets an error which will be returned by Read.
+func (b *testRequestBody) closeWithError(err error) {
+ defer b.tc.sync()
+ b.gate.Lock()
+ defer b.unlock()
+ b.err = err
+}
+
+// roundTrip starts a RoundTrip call.
+//
+// (Note that the RoundTrip won't complete until response headers are received,
+// the request times out, or some other terminal condition is reached.)
+func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
+ rt := &testRoundTrip{
+ t: tc.t,
+ donec: make(chan struct{}),
+ }
+ tc.roundtrips = append(tc.roundtrips, rt)
+ go func() {
+ tc.group.Join()
+ defer close(rt.donec)
+ rt.resp, rt.respErr = tc.cc.roundTrip(req, func(cs *clientStream) {
+ rt.id.Store(cs.ID)
+ })
+ }()
+ tc.sync()
+
+ tc.t.Cleanup(func() {
+ if !rt.done() {
+ return
+ }
+ res, _ := rt.result()
+ if res != nil {
+ res.Body.Close()
+ }
+ })
+
+ return rt
+}
+
+func (tc *testClientConn) greet(settings ...Setting) {
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.writeSettings(settings...)
+ tc.writeSettingsAck()
+ tc.wantFrameType(FrameSettings) // acknowledgement
+}
+
+// makeHeaderBlockFragment encodes headers in a form suitable for inclusion
+// in a HEADERS or CONTINUATION frame.
+//
+// It takes a list of alternating names and values.
+func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte {
+ if len(s)%2 != 0 {
+ tc.t.Fatalf("uneven list of header name/value pairs")
+ }
+ tc.encbuf.Reset()
+ for i := 0; i < len(s); i += 2 {
+ tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]})
+ }
+ return tc.encbuf.Bytes()
+}
+
+// inflowWindow returns the amount of inbound flow control available for a stream,
+// or for the connection if streamID is 0.
+func (tc *testClientConn) inflowWindow(streamID uint32) int32 {
+ tc.cc.mu.Lock()
+ defer tc.cc.mu.Unlock()
+ if streamID == 0 {
+ return tc.cc.inflow.avail + tc.cc.inflow.unsent
+ }
+ cs := tc.cc.streams[streamID]
+ if cs == nil {
+ tc.t.Errorf("no stream with id %v", streamID)
+ return -1
+ }
+ return cs.inflow.avail + cs.inflow.unsent
+}
+
+// testRoundTrip manages a RoundTrip in progress.
+type testRoundTrip struct {
+ t *testing.T
+ resp *http.Response
+ respErr error
+ donec chan struct{}
+ id atomic.Uint32
+}
+
+// streamID returns the HTTP/2 stream ID of the request.
+func (rt *testRoundTrip) streamID() uint32 {
+ id := rt.id.Load()
+ if id == 0 {
+ panic("stream ID unknown")
+ }
+ return id
+}
+
+// done reports whether RoundTrip has returned.
+func (rt *testRoundTrip) done() bool {
+ select {
+ case <-rt.donec:
+ return true
+ default:
+ return false
+ }
+}
+
+// result returns the result of the RoundTrip.
+func (rt *testRoundTrip) result() (*http.Response, error) {
+ t := rt.t
+ t.Helper()
+ select {
+ case <-rt.donec:
+ default:
+ t.Fatalf("RoundTrip is not done; want it to be")
+ }
+ return rt.resp, rt.respErr
+}
+
+// response returns the response of a successful RoundTrip.
+// If the RoundTrip unexpectedly failed, it calls t.Fatal.
+func (rt *testRoundTrip) response() *http.Response {
+ t := rt.t
+ t.Helper()
+ resp, err := rt.result()
+ if err != nil {
+ t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
+ }
+ if resp == nil {
+ t.Fatalf("RoundTrip returned nil *Response and nil error")
+ }
+ return resp
+}
+
+// err returns the (possibly nil) error result of RoundTrip.
+func (rt *testRoundTrip) err() error {
+ t := rt.t
+ t.Helper()
+ _, err := rt.result()
+ return err
+}
+
+// wantStatus indicates the expected response StatusCode.
+func (rt *testRoundTrip) wantStatus(want int) {
+ t := rt.t
+ t.Helper()
+ if got := rt.response().StatusCode; got != want {
+ t.Fatalf("got response status %v, want %v", got, want)
+ }
+}
+
+// readBody reads the contents of the response body.
+func (rt *testRoundTrip) readBody() ([]byte, error) {
+ t := rt.t
+ t.Helper()
+ return io.ReadAll(rt.response().Body)
+}
+
+// wantBody indicates the expected response body.
+// (Note that this consumes the body.)
+func (rt *testRoundTrip) wantBody(want []byte) {
+ t := rt.t
+ t.Helper()
+ got, err := rt.readBody()
+ if err != nil {
+ t.Fatalf("unexpected error reading response body: %v", err)
+ }
+ if !bytes.Equal(got, want) {
+ t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want)
+ }
+}
+
+// wantHeaders indicates the expected response headers.
+func (rt *testRoundTrip) wantHeaders(want http.Header) {
+ t := rt.t
+ t.Helper()
+ res := rt.response()
+ if diff := diffHeaders(res.Header, want); diff != "" {
+ t.Fatalf("unexpected response headers:\n%v", diff)
+ }
+}
+
+// wantTrailers indicates the expected response trailers.
+func (rt *testRoundTrip) wantTrailers(want http.Header) {
+ t := rt.t
+ t.Helper()
+ res := rt.response()
+ if diff := diffHeaders(res.Trailer, want); diff != "" {
+ t.Fatalf("unexpected response trailers:\n%v", diff)
+ }
+}
+
+func diffHeaders(got, want http.Header) string {
+ // nil and 0-length non-nil are equal.
+ if len(got) == 0 && len(want) == 0 {
+ return ""
+ }
+ // We could do a more sophisticated diff here.
+ // DeepEqual is good enough for now.
+ if reflect.DeepEqual(got, want) {
+ return ""
+ }
+ return fmt.Sprintf("got: %v\nwant: %v", got, want)
+}
+
+// A testTransport allows testing Transport.RoundTrip against fake servers.
+// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling
+// should use testClientConn instead.
+type testTransport struct {
+ t *testing.T
+ tr *Transport
+ group *synctestGroup
+
+ ccs []*testClientConn
+}
+
+func newTestTransport(t *testing.T, opts ...any) *testTransport {
+ tt := &testTransport{
+ t: t,
+ group: newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)),
+ }
+ tt.group.Join()
+
+ tr := &Transport{}
+ for _, o := range opts {
+ switch o := o.(type) {
+ case func(*http.Transport):
+ if tr.t1 == nil {
+ tr.t1 = &http.Transport{}
+ }
+ o(tr.t1)
+ case func(*Transport):
+ o(tr)
+ case *Transport:
+ tr = o
+ }
+ }
+ tt.tr = tr
+
+ tr.transportTestHooks = &transportTestHooks{
+ group: tt.group,
+ newclientconn: func(cc *ClientConn) {
+ tc := newTestClientConnFromClientConn(t, cc)
+ tt.ccs = append(tt.ccs, tc)
+ },
+ }
+
+ t.Cleanup(func() {
+ tt.sync()
+ if len(tt.ccs) > 0 {
+ t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs))
+ }
+ tt.group.Close(t)
+ })
+
+ return tt
+}
+
+func (tt *testTransport) sync() {
+ tt.group.Wait()
+}
+
+func (tt *testTransport) advance(d time.Duration) {
+ tt.group.AdvanceTime(d)
+ tt.sync()
+}
+
+func (tt *testTransport) hasConn() bool {
+ return len(tt.ccs) > 0
+}
+
+func (tt *testTransport) getConn() *testClientConn {
+ tt.t.Helper()
+ if len(tt.ccs) == 0 {
+ tt.t.Fatalf("no new ClientConns created; wanted one")
+ }
+ tc := tt.ccs[0]
+ tt.ccs = tt.ccs[1:]
+ tc.sync()
+ tc.readClientPreface()
+ tc.sync()
+ return tc
+}
+
+func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip {
+ rt := &testRoundTrip{
+ t: tt.t,
+ donec: make(chan struct{}),
+ }
+ go func() {
+ tt.group.Join()
+ defer close(rt.donec)
+ rt.resp, rt.respErr = tt.tr.RoundTrip(req)
+ }()
+ tt.sync()
+
+ tt.t.Cleanup(func() {
+ if !rt.done() {
+ return
+ }
+ res, _ := rt.result()
+ if res != nil {
+ res.Body.Close()
+ }
+ })
+
+ return rt
+}
diff --git a/http2/config.go b/http2/config.go
new file mode 100644
index 0000000000..ca645d9a1a
--- /dev/null
+++ b/http2/config.go
@@ -0,0 +1,122 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "math"
+ "net/http"
+ "time"
+)
+
+// http2Config is a package-internal version of net/http.HTTP2Config.
+//
+// http.HTTP2Config was added in Go 1.24.
+// When running with a version of net/http that includes HTTP2Config,
+// we merge the configuration with the fields in Transport or Server
+// to produce an http2Config.
+//
+// Zero valued fields in http2Config are interpreted as in the
+// net/http.HTTP2Config documentation.
+//
+// Precedence order for reconciling configurations is:
+//
+// - Use the net/http.{Server,Transport}.HTTP2Config value, when non-zero.
+// - Otherwise use the http2.{Server,Transport} value.
+// - If the resulting value is zero or out of range, use a default.
+type http2Config struct {
+ MaxConcurrentStreams uint32
+ MaxDecoderHeaderTableSize uint32
+ MaxEncoderHeaderTableSize uint32
+ MaxReadFrameSize uint32
+ MaxUploadBufferPerConnection int32
+ MaxUploadBufferPerStream int32
+ SendPingTimeout time.Duration
+ PingTimeout time.Duration
+ WriteByteTimeout time.Duration
+ PermitProhibitedCipherSuites bool
+ CountError func(errType string)
+}
+
+// configFromServer merges configuration settings from
+// net/http.Server.HTTP2Config and http2.Server.
+func configFromServer(h1 *http.Server, h2 *Server) http2Config {
+ conf := http2Config{
+ MaxConcurrentStreams: h2.MaxConcurrentStreams,
+ MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
+ MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
+ MaxReadFrameSize: h2.MaxReadFrameSize,
+ MaxUploadBufferPerConnection: h2.MaxUploadBufferPerConnection,
+ MaxUploadBufferPerStream: h2.MaxUploadBufferPerStream,
+ SendPingTimeout: h2.ReadIdleTimeout,
+ PingTimeout: h2.PingTimeout,
+ WriteByteTimeout: h2.WriteByteTimeout,
+ PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites,
+ CountError: h2.CountError,
+ }
+ fillNetHTTPServerConfig(&conf, h1)
+ setConfigDefaults(&conf, true)
+ return conf
+}
+
+// configFromTransport merges configuration settings from h2 and h2.t1.HTTP2
+// (the net/http Transport).
+func configFromTransport(h2 *Transport) http2Config {
+ conf := http2Config{
+ MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
+ MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
+ MaxReadFrameSize: h2.MaxReadFrameSize,
+ SendPingTimeout: h2.ReadIdleTimeout,
+ PingTimeout: h2.PingTimeout,
+ WriteByteTimeout: h2.WriteByteTimeout,
+ }
+
+ // Unlike most config fields, where out-of-range values revert to the default,
+ // Transport.MaxReadFrameSize clips.
+ if conf.MaxReadFrameSize < minMaxFrameSize {
+ conf.MaxReadFrameSize = minMaxFrameSize
+ } else if conf.MaxReadFrameSize > maxFrameSize {
+ conf.MaxReadFrameSize = maxFrameSize
+ }
+
+ if h2.t1 != nil {
+ fillNetHTTPTransportConfig(&conf, h2.t1)
+ }
+ setConfigDefaults(&conf, false)
+ return conf
+}
+
+func setDefault[T ~int | ~int32 | ~uint32 | ~int64](v *T, minval, maxval, defval T) {
+ if *v < minval || *v > maxval {
+ *v = defval
+ }
+}
+
+func setConfigDefaults(conf *http2Config, server bool) {
+ setDefault(&conf.MaxConcurrentStreams, 1, math.MaxUint32, defaultMaxStreams)
+ setDefault(&conf.MaxEncoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
+ setDefault(&conf.MaxDecoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
+ if server {
+ setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, 1<<20)
+ } else {
+ setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, transportDefaultConnFlow)
+ }
+ if server {
+ setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, 1<<20)
+ } else {
+ setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, transportDefaultStreamFlow)
+ }
+ setDefault(&conf.MaxReadFrameSize, minMaxFrameSize, maxFrameSize, defaultMaxReadFrameSize)
+ setDefault(&conf.PingTimeout, 1, math.MaxInt64, 15*time.Second)
+}
+
+// adjustHTTP1MaxHeaderSize converts a limit in bytes on the size of an HTTP/1 header
+// to an HTTP/2 MAX_HEADER_LIST_SIZE value.
+func adjustHTTP1MaxHeaderSize(n int64) int64 {
+ // http2's count is in a slightly different unit and includes 32 bytes per pair.
+ // So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
+ const perFieldOverhead = 32 // per http2 spec
+ const typicalHeaders = 10 // conservative
+ return n + typicalHeaders*perFieldOverhead
+}
diff --git a/http2/config_go124.go b/http2/config_go124.go
new file mode 100644
index 0000000000..5b516c55ff
--- /dev/null
+++ b/http2/config_go124.go
@@ -0,0 +1,61 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http2
+
+import "net/http"
+
+// fillNetHTTPServerConfig sets fields in conf from srv.HTTP2.
+func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {
+ fillNetHTTPConfig(conf, srv.HTTP2)
+}
+
+// fillNetHTTPTransportConfig sets fields in conf from tr.HTTP2.
+func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {
+ fillNetHTTPConfig(conf, tr.HTTP2)
+}
+
+func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) {
+ if h2 == nil {
+ return
+ }
+ if h2.MaxConcurrentStreams != 0 {
+ conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
+ }
+ if h2.MaxEncoderHeaderTableSize != 0 {
+ conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize)
+ }
+ if h2.MaxDecoderHeaderTableSize != 0 {
+ conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize)
+ }
+ if h2.MaxConcurrentStreams != 0 {
+ conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
+ }
+ if h2.MaxReadFrameSize != 0 {
+ conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize)
+ }
+ if h2.MaxReceiveBufferPerConnection != 0 {
+ conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection)
+ }
+ if h2.MaxReceiveBufferPerStream != 0 {
+ conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream)
+ }
+ if h2.SendPingTimeout != 0 {
+ conf.SendPingTimeout = h2.SendPingTimeout
+ }
+ if h2.PingTimeout != 0 {
+ conf.PingTimeout = h2.PingTimeout
+ }
+ if h2.WriteByteTimeout != 0 {
+ conf.WriteByteTimeout = h2.WriteByteTimeout
+ }
+ if h2.PermitProhibitedCipherSuites {
+ conf.PermitProhibitedCipherSuites = true
+ }
+ if h2.CountError != nil {
+ conf.CountError = h2.CountError
+ }
+}
diff --git a/http2/config_pre_go124.go b/http2/config_pre_go124.go
new file mode 100644
index 0000000000..060fd6c64c
--- /dev/null
+++ b/http2/config_pre_go124.go
@@ -0,0 +1,16 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !go1.24
+
+package http2
+
+import "net/http"
+
+// Pre-Go 1.24 fallback.
+// The Server.HTTP2 and Transport.HTTP2 config fields were added in Go 1.24.
+
+func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {}
+
+func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {}
diff --git a/http2/config_test.go b/http2/config_test.go
new file mode 100644
index 0000000000..b8e7a7b043
--- /dev/null
+++ b/http2/config_test.go
@@ -0,0 +1,95 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http2
+
+import (
+ "net/http"
+ "testing"
+ "time"
+)
+
+func TestConfigServerSettings(t *testing.T) {
+ config := &http.HTTP2Config{
+ MaxConcurrentStreams: 1,
+ MaxDecoderHeaderTableSize: 1<<20 + 2,
+ MaxEncoderHeaderTableSize: 1<<20 + 3,
+ MaxReadFrameSize: 1<<20 + 4,
+ MaxReceiveBufferPerConnection: 64<<10 + 5,
+ MaxReceiveBufferPerStream: 64<<10 + 6,
+ }
+ const maxHeaderBytes = 4096 + 7
+ st := newServerTester(t, nil, func(s *http.Server) {
+ s.MaxHeaderBytes = maxHeaderBytes
+ s.HTTP2 = config
+ })
+ st.writePreface()
+ st.writeSettings()
+ st.wantSettings(map[SettingID]uint32{
+ SettingMaxConcurrentStreams: uint32(config.MaxConcurrentStreams),
+ SettingHeaderTableSize: uint32(config.MaxDecoderHeaderTableSize),
+ SettingInitialWindowSize: uint32(config.MaxReceiveBufferPerStream),
+ SettingMaxFrameSize: uint32(config.MaxReadFrameSize),
+ SettingMaxHeaderListSize: maxHeaderBytes + (32 * 10),
+ })
+}
+
+func TestConfigTransportSettings(t *testing.T) {
+ config := &http.HTTP2Config{
+ MaxConcurrentStreams: 1, // ignored by Transport
+ MaxDecoderHeaderTableSize: 1<<20 + 2,
+ MaxEncoderHeaderTableSize: 1<<20 + 3,
+ MaxReadFrameSize: 1<<20 + 4,
+ MaxReceiveBufferPerConnection: 64<<10 + 5,
+ MaxReceiveBufferPerStream: 64<<10 + 6,
+ }
+ const maxHeaderBytes = 4096 + 7
+ tc := newTestClientConn(t, func(tr *http.Transport) {
+ tr.HTTP2 = config
+ tr.MaxResponseHeaderBytes = maxHeaderBytes
+ })
+ tc.wantSettings(map[SettingID]uint32{
+ SettingHeaderTableSize: uint32(config.MaxDecoderHeaderTableSize),
+ SettingInitialWindowSize: uint32(config.MaxReceiveBufferPerStream),
+ SettingMaxFrameSize: uint32(config.MaxReadFrameSize),
+ SettingMaxHeaderListSize: maxHeaderBytes + (32 * 10),
+ })
+ tc.wantWindowUpdate(0, uint32(config.MaxReceiveBufferPerConnection))
+}
+
+func TestConfigPingTimeoutServer(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(s *Server) {
+ s.ReadIdleTimeout = 2 * time.Second
+ s.PingTimeout = 3 * time.Second
+ })
+ st.greet()
+
+ st.advance(2 * time.Second)
+ _ = readFrame[*PingFrame](t, st)
+ st.advance(3 * time.Second)
+ st.wantClosed()
+}
+
+func TestConfigPingTimeoutTransport(t *testing.T) {
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.ReadIdleTimeout = 2 * time.Second
+ tr.PingTimeout = 3 * time.Second
+ })
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+
+ tc.advance(2 * time.Second)
+ tc.wantFrameType(FramePing)
+ tc.advance(3 * time.Second)
+ err := rt.err()
+ if err == nil {
+ t.Fatalf("expected connection to close")
+ }
+}
diff --git a/http2/connframes_test.go b/http2/connframes_test.go
new file mode 100644
index 0000000000..2c4532571a
--- /dev/null
+++ b/http2/connframes_test.go
@@ -0,0 +1,431 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "os"
+ "reflect"
+ "slices"
+ "testing"
+
+ "golang.org/x/net/http2/hpack"
+)
+
+type testConnFramer struct {
+ t testing.TB
+ fr *Framer
+ dec *hpack.Decoder
+}
+
+// readFrame reads the next frame.
+// It returns nil if the conn is closed or no frames are available.
+func (tf *testConnFramer) readFrame() Frame {
+ tf.t.Helper()
+ fr, err := tf.fr.ReadFrame()
+ if err == io.EOF || err == os.ErrDeadlineExceeded {
+ return nil
+ }
+ if err != nil {
+ tf.t.Fatalf("ReadFrame: %v", err)
+ }
+ return fr
+}
+
+type readFramer interface {
+ readFrame() Frame
+}
+
+// readFrame reads a frame of a specific type.
+func readFrame[T any](t testing.TB, framer readFramer) T {
+ t.Helper()
+ var v T
+ fr := framer.readFrame()
+ if fr == nil {
+ t.Fatalf("got no frame, want frame %T", v)
+ }
+ v, ok := fr.(T)
+ if !ok {
+ t.Fatalf("got frame %T, want %T", fr, v)
+ }
+ return v
+}
+
+// wantFrameType reads the next frame.
+// It produces an error if the frame type is not the expected value.
+func (tf *testConnFramer) wantFrameType(want FrameType) {
+ tf.t.Helper()
+ fr := tf.readFrame()
+ if fr == nil {
+ tf.t.Fatalf("got no frame, want frame %v", want)
+ }
+ if got := fr.Header().Type; got != want {
+ tf.t.Fatalf("got frame %v, want %v", got, want)
+ }
+}
+
+// wantUnorderedFrames reads frames until every condition in want has been satisfied.
+//
+// want is a list of func(*SomeFrame) bool.
+// wantUnorderedFrames will call each func with frames of the appropriate type
+// until the func returns true.
+// It calls t.Fatal if an unexpected frame is received (no func has that frame type,
+// or all funcs with that type have returned true), or if the framer runs out of frames
+// with unsatisfied funcs.
+//
+// Example:
+//
+// // Read a SETTINGS frame, and any number of DATA frames for a stream.
+// // The SETTINGS frame may appear anywhere in the sequence.
+// // The last DATA frame must indicate the end of the stream.
+// tf.wantUnorderedFrames(
+// func(f *SettingsFrame) bool {
+// return true
+// },
+// func(f *DataFrame) bool {
+// return f.StreamEnded()
+// },
+// )
+func (tf *testConnFramer) wantUnorderedFrames(want ...any) {
+ tf.t.Helper()
+ want = slices.Clone(want)
+ seen := 0
+frame:
+ for seen < len(want) && !tf.t.Failed() {
+ fr := tf.readFrame()
+ if fr == nil {
+ break
+ }
+ for i, f := range want {
+ if f == nil {
+ continue
+ }
+ typ := reflect.TypeOf(f)
+ if typ.Kind() != reflect.Func ||
+ typ.NumIn() != 1 ||
+ typ.NumOut() != 1 ||
+ typ.Out(0) != reflect.TypeOf(true) {
+ tf.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
+ }
+ if typ.In(0) == reflect.TypeOf(fr) {
+ out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
+ if out[0].Bool() {
+ want[i] = nil
+ seen++
+ }
+ continue frame
+ }
+ }
+ tf.t.Errorf("got unexpected frame type %T", fr)
+ }
+ if seen < len(want) {
+ for _, f := range want {
+ if f == nil {
+ continue
+ }
+ tf.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
+ }
+ tf.t.Fatalf("did not see %v expected frame types", len(want)-seen)
+ }
+}
+
+type wantHeader struct {
+ streamID uint32
+ endStream bool
+ header http.Header
+}
+
+// wantHeaders reads a HEADERS frame and potential CONTINUATION frames,
+// and asserts that they contain the expected headers.
+func (tf *testConnFramer) wantHeaders(want wantHeader) {
+ tf.t.Helper()
+
+ hf := readFrame[*HeadersFrame](tf.t, tf)
+ if got, want := hf.StreamID, want.streamID; got != want {
+ tf.t.Fatalf("got stream ID %v, want %v", got, want)
+ }
+ if got, want := hf.StreamEnded(), want.endStream; got != want {
+ tf.t.Fatalf("got stream ended %v, want %v", got, want)
+ }
+
+ gotHeader := make(http.Header)
+ tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
+ gotHeader[hf.Name] = append(gotHeader[hf.Name], hf.Value)
+ })
+ defer tf.dec.SetEmitFunc(nil)
+ if _, err := tf.dec.Write(hf.HeaderBlockFragment()); err != nil {
+ tf.t.Fatalf("decoding HEADERS frame: %v", err)
+ }
+ headersEnded := hf.HeadersEnded()
+ for !headersEnded {
+ cf := readFrame[*ContinuationFrame](tf.t, tf)
+ if cf == nil {
+ tf.t.Fatalf("got end of frames, want CONTINUATION")
+ }
+ if _, err := tf.dec.Write(cf.HeaderBlockFragment()); err != nil {
+ tf.t.Fatalf("decoding CONTINUATION frame: %v", err)
+ }
+ headersEnded = cf.HeadersEnded()
+ }
+ if err := tf.dec.Close(); err != nil {
+ tf.t.Fatalf("hpack decoding error: %v", err)
+ }
+
+ for k, v := range want.header {
+ if !reflect.DeepEqual(v, gotHeader[k]) {
+ tf.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k])
+ }
+ }
+}
+
+// decodeHeader supports some older server tests.
+// TODO: rewrite those tests to use newer, more convenient test APIs.
+func (tf *testConnFramer) decodeHeader(headerBlock []byte) (pairs [][2]string) {
+ tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
+ if hf.Name == "date" {
+ return
+ }
+ pairs = append(pairs, [2]string{hf.Name, hf.Value})
+ })
+ defer tf.dec.SetEmitFunc(nil)
+ if _, err := tf.dec.Write(headerBlock); err != nil {
+ tf.t.Fatalf("hpack decoding error: %v", err)
+ }
+ if err := tf.dec.Close(); err != nil {
+ tf.t.Fatalf("hpack decoding error: %v", err)
+ }
+ return pairs
+}
+
+type wantData struct {
+ streamID uint32
+ endStream bool
+ size int
+ data []byte
+ multiple bool // data may be spread across multiple DATA frames
+}
+
+// wantData reads zero or more DATA frames, and asserts that they match the expectation.
+func (tf *testConnFramer) wantData(want wantData) {
+ tf.t.Helper()
+ gotSize := 0
+ gotEndStream := false
+ if want.data != nil {
+ want.size = len(want.data)
+ }
+ var gotData []byte
+ for {
+ fr := tf.readFrame()
+ if fr == nil {
+ break
+ }
+ data, ok := fr.(*DataFrame)
+ if !ok {
+ tf.t.Fatalf("got frame %T, want DataFrame", fr)
+ }
+ if want.data != nil {
+ gotData = append(gotData, data.Data()...)
+ }
+ gotSize += len(data.Data())
+ if data.StreamEnded() {
+ gotEndStream = true
+ break
+ }
+ if !want.endStream && gotSize >= want.size {
+ break
+ }
+ if !want.multiple {
+ break
+ }
+ }
+ if gotSize != want.size {
+ tf.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
+ }
+ if gotEndStream != want.endStream {
+ tf.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
+ }
+ if want.data != nil && !bytes.Equal(gotData, want.data) {
+ tf.t.Fatalf("got data %q, want %q", gotData, want.data)
+ }
+}
+
+func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) {
+ tf.t.Helper()
+ fr := readFrame[*RSTStreamFrame](tf.t, tf)
+ if fr.StreamID != streamID || fr.ErrCode != code {
+ tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", summarizeFrame(fr), streamID, code)
+ }
+}
+
+func (tf *testConnFramer) wantSettings(want map[SettingID]uint32) {
+ fr := readFrame[*SettingsFrame](tf.t, tf)
+ if fr.Header().Flags.Has(FlagSettingsAck) {
+ tf.t.Errorf("got SETTINGS frame with ACK set, want no ACK")
+ }
+ for wantID, wantVal := range want {
+ gotVal, ok := fr.Value(wantID)
+ if !ok {
+ tf.t.Errorf("SETTINGS: %v is not set, want %v", wantID, wantVal)
+ } else if gotVal != wantVal {
+ tf.t.Errorf("SETTINGS: %v is %v, want %v", wantID, gotVal, wantVal)
+ }
+ }
+ if tf.t.Failed() {
+ tf.t.Fatalf("%v", fr)
+ }
+}
+
+func (tf *testConnFramer) wantSettingsAck() {
+ tf.t.Helper()
+ fr := readFrame[*SettingsFrame](tf.t, tf)
+ if !fr.Header().Flags.Has(FlagSettingsAck) {
+ tf.t.Fatal("Settings Frame didn't have ACK set")
+ }
+}
+
+func (tf *testConnFramer) wantGoAway(maxStreamID uint32, code ErrCode) {
+ tf.t.Helper()
+ fr := readFrame[*GoAwayFrame](tf.t, tf)
+ if fr.LastStreamID != maxStreamID || fr.ErrCode != code {
+ tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", summarizeFrame(fr), maxStreamID, code)
+ }
+}
+
+func (tf *testConnFramer) wantWindowUpdate(streamID, incr uint32) {
+ tf.t.Helper()
+ wu := readFrame[*WindowUpdateFrame](tf.t, tf)
+ if wu.FrameHeader.StreamID != streamID {
+ tf.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
+ }
+ if wu.Increment != incr {
+ tf.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
+ }
+}
+
+func (tf *testConnFramer) wantClosed() {
+ tf.t.Helper()
+ fr, err := tf.fr.ReadFrame()
+ if err == nil {
+ tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
+ }
+ if err == os.ErrDeadlineExceeded {
+ tf.t.Fatalf("connection is not closed; want it to be")
+ }
+}
+
+func (tf *testConnFramer) wantIdle() {
+ tf.t.Helper()
+ fr, err := tf.fr.ReadFrame()
+ if err == nil {
+ tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
+ }
+ if err != os.ErrDeadlineExceeded {
+ tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
+ }
+}
+
+func (tf *testConnFramer) writeSettings(settings ...Setting) {
+ tf.t.Helper()
+ if err := tf.fr.WriteSettings(settings...); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+func (tf *testConnFramer) writeSettingsAck() {
+ tf.t.Helper()
+ if err := tf.fr.WriteSettingsAck(); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+func (tf *testConnFramer) writeData(streamID uint32, endStream bool, data []byte) {
+ tf.t.Helper()
+ if err := tf.fr.WriteData(streamID, endStream, data); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+func (tf *testConnFramer) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
+ tf.t.Helper()
+ if err := tf.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+func (tf *testConnFramer) writeHeaders(p HeadersFrameParam) {
+ tf.t.Helper()
+ if err := tf.fr.WriteHeaders(p); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+// writeHeadersMode writes header frames, as modified by mode:
+//
+// - noHeader: Don't write the header.
+// - oneHeader: Write a single HEADERS frame.
+// - splitHeader: Write a HEADERS frame and CONTINUATION frame.
+func (tf *testConnFramer) writeHeadersMode(mode headerType, p HeadersFrameParam) {
+ tf.t.Helper()
+ switch mode {
+ case noHeader:
+ case oneHeader:
+ tf.writeHeaders(p)
+ case splitHeader:
+ if len(p.BlockFragment) < 2 {
+ panic("too small")
+ }
+ contData := p.BlockFragment[1:]
+ contEnd := p.EndHeaders
+ p.BlockFragment = p.BlockFragment[:1]
+ p.EndHeaders = false
+ tf.writeHeaders(p)
+ tf.writeContinuation(p.StreamID, contEnd, contData)
+ default:
+ panic("bogus mode")
+ }
+}
+
+func (tf *testConnFramer) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
+ tf.t.Helper()
+ if err := tf.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+func (tf *testConnFramer) writePriority(id uint32, p PriorityParam) {
+ if err := tf.fr.WritePriority(id, p); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+func (tf *testConnFramer) writeRSTStream(streamID uint32, code ErrCode) {
+ tf.t.Helper()
+ if err := tf.fr.WriteRSTStream(streamID, code); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+func (tf *testConnFramer) writePing(ack bool, data [8]byte) {
+ tf.t.Helper()
+ if err := tf.fr.WritePing(ack, data); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+func (tf *testConnFramer) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
+ tf.t.Helper()
+ if err := tf.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
+ tf.t.Fatal(err)
+ }
+}
+
+func (tf *testConnFramer) writeWindowUpdate(streamID, incr uint32) {
+ tf.t.Helper()
+ if err := tf.fr.WriteWindowUpdate(streamID, incr); err != nil {
+ tf.t.Fatal(err)
+ }
+}
diff --git a/http2/databuffer.go b/http2/databuffer.go
index a3067f8de7..e6f55cbd16 100644
--- a/http2/databuffer.go
+++ b/http2/databuffer.go
@@ -20,41 +20,44 @@ import (
// TODO: Benchmark to determine if the pools are necessary. The GC may have
// improved enough that we can instead allocate chunks like this:
// make([]byte, max(16<<10, expectedBytesRemaining))
-var (
- dataChunkSizeClasses = []int{
- 1 << 10,
- 2 << 10,
- 4 << 10,
- 8 << 10,
- 16 << 10,
- }
- dataChunkPools = [...]sync.Pool{
- {New: func() interface{} { return make([]byte, 1<<10) }},
- {New: func() interface{} { return make([]byte, 2<<10) }},
- {New: func() interface{} { return make([]byte, 4<<10) }},
- {New: func() interface{} { return make([]byte, 8<<10) }},
- {New: func() interface{} { return make([]byte, 16<<10) }},
- }
-)
+var dataChunkPools = [...]sync.Pool{
+ {New: func() interface{} { return new([1 << 10]byte) }},
+ {New: func() interface{} { return new([2 << 10]byte) }},
+ {New: func() interface{} { return new([4 << 10]byte) }},
+ {New: func() interface{} { return new([8 << 10]byte) }},
+ {New: func() interface{} { return new([16 << 10]byte) }},
+}
func getDataBufferChunk(size int64) []byte {
- i := 0
- for ; i < len(dataChunkSizeClasses)-1; i++ {
- if size <= int64(dataChunkSizeClasses[i]) {
- break
- }
+ switch {
+ case size <= 1<<10:
+ return dataChunkPools[0].Get().(*[1 << 10]byte)[:]
+ case size <= 2<<10:
+ return dataChunkPools[1].Get().(*[2 << 10]byte)[:]
+ case size <= 4<<10:
+ return dataChunkPools[2].Get().(*[4 << 10]byte)[:]
+ case size <= 8<<10:
+ return dataChunkPools[3].Get().(*[8 << 10]byte)[:]
+ default:
+ return dataChunkPools[4].Get().(*[16 << 10]byte)[:]
}
- return dataChunkPools[i].Get().([]byte)
}
func putDataBufferChunk(p []byte) {
- for i, n := range dataChunkSizeClasses {
- if len(p) == n {
- dataChunkPools[i].Put(p)
- return
- }
+ switch len(p) {
+ case 1 << 10:
+ dataChunkPools[0].Put((*[1 << 10]byte)(p))
+ case 2 << 10:
+ dataChunkPools[1].Put((*[2 << 10]byte)(p))
+ case 4 << 10:
+ dataChunkPools[2].Put((*[4 << 10]byte)(p))
+ case 8 << 10:
+ dataChunkPools[3].Put((*[8 << 10]byte)(p))
+ case 16 << 10:
+ dataChunkPools[4].Put((*[16 << 10]byte)(p))
+ default:
+ panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
}
- panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
}
// dataBuffer is an io.ReadWriter backed by a list of data chunks.
diff --git a/http2/frame.go b/http2/frame.go
index c1f6b90dc3..81faec7e75 100644
--- a/http2/frame.go
+++ b/http2/frame.go
@@ -490,6 +490,9 @@ func terminalReadFrameError(err error) bool {
// returned error is ErrFrameTooLarge. Other errors may be of type
// ConnectionError, StreamError, or anything else from the underlying
// reader.
+//
+// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID
+// indicates the stream responsible for the error.
func (fr *Framer) ReadFrame() (Frame, error) {
fr.errDetail = nil
if fr.lastFrame != nil {
@@ -1487,7 +1490,7 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
pf := mh.PseudoFields()
for i, hf := range pf {
switch hf.Name {
- case ":method", ":path", ":scheme", ":authority":
+ case ":method", ":path", ":scheme", ":authority", ":protocol":
isRequest = true
case ":status":
isResponse = true
@@ -1495,7 +1498,7 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
return pseudoHeaderError(hf.Name)
}
// Check for duplicates.
- // This would be a bad algorithm, but N is 4.
+ // This would be a bad algorithm, but N is 5.
// And this doesn't allocate.
for _, hf2 := range pf[:i] {
if hf.Name == hf2.Name {
@@ -1510,19 +1513,18 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
}
func (fr *Framer) maxHeaderStringLen() int {
- v := fr.maxHeaderListSize()
- if uint32(int(v)) == v {
- return int(v)
+ v := int(fr.maxHeaderListSize())
+ if v < 0 {
+ // If maxHeaderListSize overflows an int, use no limit (0).
+ return 0
}
- // They had a crazy big number for MaxHeaderBytes anyway,
- // so give them unlimited header lengths:
- return 0
+ return v
}
// readMetaFrame returns 0 or more CONTINUATION frames from fr and
// merge them into the provided hf and returns a MetaHeadersFrame
// with the decoded hpack values.
-func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
+func (fr *Framer) readMetaFrame(hf *HeadersFrame) (Frame, error) {
if fr.AllowIllegalReads {
return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders")
}
@@ -1565,6 +1567,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
if size > remainSize {
hdec.SetEmitEnabled(false)
mh.Truncated = true
+ remainSize = 0
return
}
remainSize -= size
@@ -1577,8 +1580,38 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
var hc headersOrContinuation = hf
for {
frag := hc.HeaderBlockFragment()
+
+ // Avoid parsing large amounts of headers that we will then discard.
+ // If the sender exceeds the max header list size by too much,
+ // skip parsing the fragment and close the connection.
+ //
+ // "Too much" is either any CONTINUATION frame after we've already
+ // exceeded the max header list size (in which case remainSize is 0),
+ // or a frame whose encoded size is more than twice the remaining
+ // header list bytes we're willing to accept.
+ if int64(len(frag)) > int64(2*remainSize) {
+ if VerboseLogs {
+ log.Printf("http2: header list too large")
+ }
+ // It would be nice to send a RST_STREAM before sending the GOAWAY,
+ // but the structure of the server's frame writer makes this difficult.
+ return mh, ConnectionError(ErrCodeProtocol)
+ }
+
+ // Also close the connection after any CONTINUATION frame following an
+ // invalid header, since we stop tracking the size of the headers after
+ // an invalid one.
+ if invalid != nil {
+ if VerboseLogs {
+ log.Printf("http2: invalid header: %v", invalid)
+ }
+ // It would be nice to send a RST_STREAM before sending the GOAWAY,
+ // but the structure of the server's frame writer makes this difficult.
+ return mh, ConnectionError(ErrCodeProtocol)
+ }
+
if _, err := hdec.Write(frag); err != nil {
- return nil, ConnectionError(ErrCodeCompression)
+ return mh, ConnectionError(ErrCodeCompression)
}
if hc.HeadersEnded() {
@@ -1595,7 +1628,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
mh.HeadersFrame.invalidate()
if err := hdec.Close(); err != nil {
- return nil, ConnectionError(ErrCodeCompression)
+ return mh, ConnectionError(ErrCodeCompression)
}
if invalid != nil {
fr.errDetail = invalid
diff --git a/http2/go111.go b/http2/go111.go
deleted file mode 100644
index 5bf62b032e..0000000000
--- a/http2/go111.go
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2018 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.11
-// +build go1.11
-
-package http2
-
-import (
- "net/http/httptrace"
- "net/textproto"
-)
-
-func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
- return trace != nil && trace.WroteHeaderField != nil
-}
-
-func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
- if trace != nil && trace.WroteHeaderField != nil {
- trace.WroteHeaderField(k, []string{v})
- }
-}
-
-func traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
- if trace != nil {
- return trace.Got1xxResponse
- }
- return nil
-}
diff --git a/http2/go115.go b/http2/go115.go
deleted file mode 100644
index 908af1ab93..0000000000
--- a/http2/go115.go
+++ /dev/null
@@ -1,27 +0,0 @@
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.15
-// +build go1.15
-
-package http2
-
-import (
- "context"
- "crypto/tls"
-)
-
-// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
-// connection.
-func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
- dialer := &tls.Dialer{
- Config: cfg,
- }
- cn, err := dialer.DialContext(ctx, network, addr)
- if err != nil {
- return nil, err
- }
- tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
- return tlsCn, nil
-}
diff --git a/http2/go118.go b/http2/go118.go
deleted file mode 100644
index aca4b2b31a..0000000000
--- a/http2/go118.go
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.18
-// +build go1.18
-
-package http2
-
-import (
- "crypto/tls"
- "net"
-)
-
-func tlsUnderlyingConn(tc *tls.Conn) net.Conn {
- return tc.NetConn()
-}
diff --git a/http2/h2c/h2c_test.go b/http2/h2c/h2c_test.go
index 038cbc3649..3e78f29135 100644
--- a/http2/h2c/h2c_test.go
+++ b/http2/h2c/h2c_test.go
@@ -9,7 +9,6 @@ import (
"crypto/tls"
"fmt"
"io"
- "io/ioutil"
"log"
"net"
"net/http"
@@ -68,7 +67,7 @@ func TestContext(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- _, err = ioutil.ReadAll(resp.Body)
+ _, err = io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
@@ -162,7 +161,7 @@ func TestMaxBytesHandler(t *testing.T) {
t.Fatal(err)
}
defer resp.Body.Close()
- _, err = ioutil.ReadAll(resp.Body)
+ _, err = io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
diff --git a/http2/h2i/h2i.go b/http2/h2i/h2i.go
index 901f6ca79a..ee7020dd9b 100644
--- a/http2/h2i/h2i.go
+++ b/http2/h2i/h2i.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows
/*
The h2i command is an interactive HTTP/2 console.
diff --git a/http2/hpack/gen.go b/http2/hpack/gen.go
index de14ab0ec0..0efa8e558c 100644
--- a/http2/hpack/gen.go
+++ b/http2/hpack/gen.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package main
@@ -11,7 +10,6 @@ import (
"bytes"
"fmt"
"go/format"
- "io/ioutil"
"os"
"sort"
@@ -177,7 +175,7 @@ func genFile(name string, buf *bytes.Buffer) {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
- if err := ioutil.WriteFile(name, b, 0644); err != nil {
+ if err := os.WriteFile(name, b, 0644); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
diff --git a/http2/http2.go b/http2/http2.go
index 6f2df28187..6c18ea230b 100644
--- a/http2/http2.go
+++ b/http2/http2.go
@@ -17,15 +17,18 @@ package http2 // import "golang.org/x/net/http2"
import (
"bufio"
+ "context"
"crypto/tls"
+ "errors"
"fmt"
- "io"
+ "net"
"net/http"
"os"
"sort"
"strconv"
"strings"
"sync"
+ "time"
"golang.org/x/net/http/httpguts"
)
@@ -35,6 +38,15 @@ var (
logFrameWrites bool
logFrameReads bool
inTests bool
+
+	// Enabling extended CONNECT causes browsers to attempt to use
+ // WebSockets-over-HTTP/2. This results in problems when the server's websocket
+ // package doesn't support extended CONNECT.
+ //
+ // Disable extended CONNECT by default for now.
+ //
+ // Issue #71128.
+ disableExtendedConnectProtocol = true
)
func init() {
@@ -47,6 +59,9 @@ func init() {
logFrameWrites = true
logFrameReads = true
}
+ if strings.Contains(e, "http2xconnect=1") {
+ disableExtendedConnectProtocol = false
+ }
}
const (
@@ -138,6 +153,10 @@ func (s Setting) Valid() error {
if s.Val < 16384 || s.Val > 1<<24-1 {
return ConnectionError(ErrCodeProtocol)
}
+ case SettingEnableConnectProtocol:
+ if s.Val != 1 && s.Val != 0 {
+ return ConnectionError(ErrCodeProtocol)
+ }
}
return nil
}
@@ -147,21 +166,23 @@ func (s Setting) Valid() error {
type SettingID uint16
const (
- SettingHeaderTableSize SettingID = 0x1
- SettingEnablePush SettingID = 0x2
- SettingMaxConcurrentStreams SettingID = 0x3
- SettingInitialWindowSize SettingID = 0x4
- SettingMaxFrameSize SettingID = 0x5
- SettingMaxHeaderListSize SettingID = 0x6
+ SettingHeaderTableSize SettingID = 0x1
+ SettingEnablePush SettingID = 0x2
+ SettingMaxConcurrentStreams SettingID = 0x3
+ SettingInitialWindowSize SettingID = 0x4
+ SettingMaxFrameSize SettingID = 0x5
+ SettingMaxHeaderListSize SettingID = 0x6
+ SettingEnableConnectProtocol SettingID = 0x8
)
var settingName = map[SettingID]string{
- SettingHeaderTableSize: "HEADER_TABLE_SIZE",
- SettingEnablePush: "ENABLE_PUSH",
- SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
- SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
- SettingMaxFrameSize: "MAX_FRAME_SIZE",
- SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
+ SettingHeaderTableSize: "HEADER_TABLE_SIZE",
+ SettingEnablePush: "ENABLE_PUSH",
+ SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
+ SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
+ SettingMaxFrameSize: "MAX_FRAME_SIZE",
+ SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
+ SettingEnableConnectProtocol: "ENABLE_CONNECT_PROTOCOL",
}
func (s SettingID) String() string {
@@ -210,12 +231,6 @@ type stringWriter interface {
WriteString(s string) (n int, err error)
}
-// A gate lets two goroutines coordinate their activities.
-type gate chan struct{}
-
-func (g gate) Done() { g <- struct{}{} }
-func (g gate) Wait() { <-g }
-
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
type closeWaiter chan struct{}
@@ -241,13 +256,19 @@ func (cw closeWaiter) Wait() {
// Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections.
type bufferedWriter struct {
- _ incomparable
- w io.Writer // immutable
- bw *bufio.Writer // non-nil when data is buffered
+ _ incomparable
+ group synctestGroupInterface // immutable
+ conn net.Conn // immutable
+ bw *bufio.Writer // non-nil when data is buffered
+ byteTimeout time.Duration // immutable, WriteByteTimeout
}
-func newBufferedWriter(w io.Writer) *bufferedWriter {
- return &bufferedWriter{w: w}
+func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter {
+ return &bufferedWriter{
+ group: group,
+ conn: conn,
+ byteTimeout: timeout,
+ }
}
// bufWriterPoolBufferSize is the size of bufio.Writer's
@@ -274,7 +295,7 @@ func (w *bufferedWriter) Available() int {
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer)
- bw.Reset(w.w)
+ bw.Reset((*bufferedWriterTimeoutWriter)(w))
w.bw = bw
}
return w.bw.Write(p)
@@ -292,6 +313,38 @@ func (w *bufferedWriter) Flush() error {
return err
}
+type bufferedWriterTimeoutWriter bufferedWriter
+
+func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
+ return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p)
+}
+
+// writeWithByteTimeout writes to conn.
+// If more than timeout passes without any bytes being written to the connection,
+// the write fails.
+func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
+ if timeout <= 0 {
+ return conn.Write(p)
+ }
+ for {
+ var now time.Time
+ if group == nil {
+ now = time.Now()
+ } else {
+ now = group.Now()
+ }
+ conn.SetWriteDeadline(now.Add(timeout))
+ nn, err := conn.Write(p[n:])
+ n += nn
+ if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
+ // Either we finished the write, made no progress, or hit the deadline.
+ // Whichever it is, we're done now.
+ conn.SetWriteDeadline(time.Time{})
+ return n, err
+ }
+ }
+}
+
func mustUint31(v int32) uint32 {
if v < 0 || v > 2147483647 {
panic("out of range")
@@ -362,24 +415,18 @@ func (s *sorter) SortStrings(ss []string) {
s.v = save
}
-// validPseudoPath reports whether v is a valid :path pseudo-header
-// value. It must be either:
-//
-// - a non-empty string starting with '/'
-// - the string '*', for OPTIONS requests.
-//
-// For now this is only used a quick check for deciding when to clean
-// up Opaque URLs before sending requests from the Transport.
-// See golang.org/issue/16847
-//
-// We used to enforce that the path also didn't start with "//", but
-// Google's GFE accepts such paths and Chrome sends them, so ignore
-// that part of the spec. See golang.org/issue/19103.
-func validPseudoPath(v string) bool {
- return (len(v) > 0 && v[0] == '/') || v == "*"
-}
-
// incomparable is a zero-width, non-comparable type. Adding it to a struct
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).
type incomparable [0]func()
+
+// synctestGroupInterface is the methods of synctestGroup used by Server and Transport.
+// It's defined as an interface here to let us keep synctestGroup entirely test-only
+// and not a part of non-test builds.
+type synctestGroupInterface interface {
+ Join()
+ Now() time.Time
+ NewTimer(d time.Duration) timer
+ AfterFunc(d time.Duration, f func()) timer
+ ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc)
+}
diff --git a/http2/http2_test.go b/http2/http2_test.go
index a16774b7ff..c7774133a7 100644
--- a/http2/http2_test.go
+++ b/http2/http2_test.go
@@ -8,7 +8,6 @@ import (
"bytes"
"flag"
"fmt"
- "io/ioutil"
"net/http"
"os"
"path/filepath"
@@ -266,7 +265,7 @@ func TestNoUnicodeStrings(t *testing.T) {
return nil
}
- contents, err := ioutil.ReadFile(path)
+ contents, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
@@ -284,3 +283,20 @@ func TestNoUnicodeStrings(t *testing.T) {
t.Fatal(err)
}
}
+
+// setForTest sets *p = v, and restores its original value in t.Cleanup.
+func setForTest[T any](t *testing.T, p *T, v T) {
+ orig := *p
+ t.Cleanup(func() {
+ *p = orig
+ })
+ *p = v
+}
+
+// must returns v if err is nil, or panics otherwise.
+func must[T any](v T, err error) T {
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
diff --git a/http2/netconn_test.go b/http2/netconn_test.go
new file mode 100644
index 0000000000..5a1759579e
--- /dev/null
+++ b/http2/netconn_test.go
@@ -0,0 +1,356 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "math"
+ "net"
+ "net/netip"
+ "os"
+ "sync"
+ "time"
+)
+
+// synctestNetPipe creates an in-memory, full duplex network connection.
+// Read and write timeouts are managed by the synctest group.
+//
+// Unlike net.Pipe, the connection is not synchronous.
+// Writes are made to a buffer, and return immediately.
+// By default, the buffer size is unlimited.
+func synctestNetPipe(group *synctestGroup) (r, w *synctestNetConn) {
+ s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000"))
+ s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001"))
+ s1 := newSynctestNetConnHalf(s1addr)
+ s2 := newSynctestNetConnHalf(s2addr)
+ r = &synctestNetConn{group: group, loc: s1, rem: s2}
+ w = &synctestNetConn{group: group, loc: s2, rem: s1}
+ r.peer = w
+ w.peer = r
+ return r, w
+}
+
+// A synctestNetConn is one endpoint of the connection created by synctestNetPipe.
+type synctestNetConn struct {
+ group *synctestGroup
+
+ // local and remote connection halves.
+ // Each half contains a buffer.
+ // Reads pull from the local buffer, and writes push to the remote buffer.
+ loc, rem *synctestNetConnHalf
+
+ // When set, group.Wait is automatically called before reads and after writes.
+ autoWait bool
+
+ // peer is the other endpoint.
+ peer *synctestNetConn
+}
+
+// Read reads data from the connection.
+func (c *synctestNetConn) Read(b []byte) (n int, err error) {
+ if c.autoWait {
+ c.group.Wait()
+ }
+ return c.loc.read(b)
+}
+
+// Peek returns the available unread read buffer,
+// without consuming its contents.
+func (c *synctestNetConn) Peek() []byte {
+ if c.autoWait {
+ c.group.Wait()
+ }
+ return c.loc.peek()
+}
+
+// Write writes data to the connection.
+func (c *synctestNetConn) Write(b []byte) (n int, err error) {
+ if c.autoWait {
+ defer c.group.Wait()
+ }
+ return c.rem.write(b)
+}
+
+// IsClosedByPeer reports whether the peer has closed its end of the connection.
+func (c *synctestNetConn) IsClosedByPeer() bool {
+ if c.autoWait {
+ c.group.Wait()
+ }
+ return c.loc.isClosedByPeer()
+}
+
+// Close closes the connection.
+func (c *synctestNetConn) Close() error {
+ c.loc.setWriteError(errors.New("connection closed by peer"))
+ c.rem.setReadError(io.EOF)
+ if c.autoWait {
+ c.group.Wait()
+ }
+ return nil
+}
+
+// LocalAddr returns the (fake) local network address.
+func (c *synctestNetConn) LocalAddr() net.Addr {
+ return c.loc.addr
+}
+
+// RemoteAddr returns the (fake) remote network address.
+func (c *synctestNetConn) RemoteAddr() net.Addr {
+ return c.rem.addr
+}
+
+// SetDeadline sets the read and write deadlines for the connection.
+func (c *synctestNetConn) SetDeadline(t time.Time) error {
+ c.SetReadDeadline(t)
+ c.SetWriteDeadline(t)
+ return nil
+}
+
+// SetReadDeadline sets the read deadline for the connection.
+func (c *synctestNetConn) SetReadDeadline(t time.Time) error {
+ c.loc.rctx.setDeadline(c.group, t)
+ return nil
+}
+
+// SetWriteDeadline sets the write deadline for the connection.
+func (c *synctestNetConn) SetWriteDeadline(t time.Time) error {
+ c.rem.wctx.setDeadline(c.group, t)
+ return nil
+}
+
+// SetReadBufferSize sets the read buffer limit for the connection.
+// Writes by the peer will block so long as the buffer is full.
+func (c *synctestNetConn) SetReadBufferSize(size int) {
+ c.loc.setReadBufferSize(size)
+}
+
+// synctestNetConnHalf is one data flow in the connection created by synctestNetPipe.
+// Each half contains a buffer. Writes to the half push to the buffer, and reads pull from it.
+type synctestNetConnHalf struct {
+ addr net.Addr
+
+ // Read and write timeouts.
+ rctx, wctx deadlineContext
+
+ // A half can be readable and/or writable.
+ //
+ // These four channels act as a lock,
+ // and allow waiting for readability/writability.
+ // When the half is unlocked, exactly one channel contains a value.
+ // When the half is locked, all channels are empty.
+ lockr chan struct{} // readable
+ lockw chan struct{} // writable
+ lockrw chan struct{} // readable and writable
+ lockc chan struct{} // neither readable nor writable
+
+ bufMax int // maximum buffer size
+ buf bytes.Buffer
+ readErr error // error returned by reads
+ writeErr error // error returned by writes
+}
+
+func newSynctestNetConnHalf(addr net.Addr) *synctestNetConnHalf {
+ h := &synctestNetConnHalf{
+ addr: addr,
+ lockw: make(chan struct{}, 1),
+ lockr: make(chan struct{}, 1),
+ lockrw: make(chan struct{}, 1),
+ lockc: make(chan struct{}, 1),
+ bufMax: math.MaxInt, // unlimited
+ }
+ h.unlock()
+ return h
+}
+
+func (h *synctestNetConnHalf) lock() {
+ select {
+ case <-h.lockw:
+ case <-h.lockr:
+ case <-h.lockrw:
+ case <-h.lockc:
+ }
+}
+
+func (h *synctestNetConnHalf) unlock() {
+ canRead := h.readErr != nil || h.buf.Len() > 0
+ canWrite := h.writeErr != nil || h.bufMax > h.buf.Len()
+ switch {
+ case canRead && canWrite:
+ h.lockrw <- struct{}{}
+ case canRead:
+ h.lockr <- struct{}{}
+ case canWrite:
+ h.lockw <- struct{}{}
+ default:
+ h.lockc <- struct{}{}
+ }
+}
+
+func (h *synctestNetConnHalf) readWaitAndLock() error {
+ select {
+ case <-h.lockr:
+ return nil
+ case <-h.lockrw:
+ return nil
+ default:
+ }
+ ctx := h.rctx.context()
+ select {
+ case <-h.lockr:
+ return nil
+ case <-h.lockrw:
+ return nil
+ case <-ctx.Done():
+ return context.Cause(ctx)
+ }
+}
+
+func (h *synctestNetConnHalf) writeWaitAndLock() error {
+ select {
+ case <-h.lockw:
+ return nil
+ case <-h.lockrw:
+ return nil
+ default:
+ }
+ ctx := h.wctx.context()
+ select {
+ case <-h.lockw:
+ return nil
+ case <-h.lockrw:
+ return nil
+ case <-ctx.Done():
+ return context.Cause(ctx)
+ }
+}
+
+func (h *synctestNetConnHalf) peek() []byte {
+ h.lock()
+ defer h.unlock()
+ return h.buf.Bytes()
+}
+
+func (h *synctestNetConnHalf) isClosedByPeer() bool {
+ h.lock()
+ defer h.unlock()
+ return h.readErr != nil
+}
+
+func (h *synctestNetConnHalf) read(b []byte) (n int, err error) {
+ if err := h.readWaitAndLock(); err != nil {
+ return 0, err
+ }
+ defer h.unlock()
+ if h.buf.Len() == 0 && h.readErr != nil {
+ return 0, h.readErr
+ }
+ return h.buf.Read(b)
+}
+
+func (h *synctestNetConnHalf) setReadBufferSize(size int) {
+ h.lock()
+ defer h.unlock()
+ h.bufMax = size
+}
+
+func (h *synctestNetConnHalf) write(b []byte) (n int, err error) {
+ for n < len(b) {
+ nn, err := h.writePartial(b[n:])
+ n += nn
+ if err != nil {
+ return n, err
+ }
+ }
+ return n, nil
+}
+
+func (h *synctestNetConnHalf) writePartial(b []byte) (n int, err error) {
+ if err := h.writeWaitAndLock(); err != nil {
+ return 0, err
+ }
+ defer h.unlock()
+ if h.writeErr != nil {
+ return 0, h.writeErr
+ }
+ writeMax := h.bufMax - h.buf.Len()
+ if writeMax < len(b) {
+ b = b[:writeMax]
+ }
+ return h.buf.Write(b)
+}
+
+func (h *synctestNetConnHalf) setReadError(err error) {
+ h.lock()
+ defer h.unlock()
+ if h.readErr == nil {
+ h.readErr = err
+ }
+}
+
+func (h *synctestNetConnHalf) setWriteError(err error) {
+ h.lock()
+ defer h.unlock()
+ if h.writeErr == nil {
+ h.writeErr = err
+ }
+}
+
+// deadlineContext converts a changeable deadline (as in net.Conn.SetDeadline) into a Context.
+type deadlineContext struct {
+ mu sync.Mutex
+ ctx context.Context
+ cancel context.CancelCauseFunc
+ timer timer
+}
+
+// context returns a Context which expires when the deadline does.
+func (t *deadlineContext) context() context.Context {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.ctx == nil {
+ t.ctx, t.cancel = context.WithCancelCause(context.Background())
+ }
+ return t.ctx
+}
+
+// setDeadline sets the current deadline.
+func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // If t.ctx is non-nil and t.cancel is nil, then t.ctx was canceled
+ // and we should create a new one.
+ if t.ctx == nil || t.cancel == nil {
+ t.ctx, t.cancel = context.WithCancelCause(context.Background())
+ }
+ // Stop any existing deadline from expiring.
+ if t.timer != nil {
+ t.timer.Stop()
+ }
+ if deadline.IsZero() {
+ // No deadline.
+ return
+ }
+ if !deadline.After(group.Now()) {
+ // Deadline has already expired.
+ t.cancel(os.ErrDeadlineExceeded)
+ t.cancel = nil
+ return
+ }
+ if t.timer != nil {
+ // Reuse existing deadline timer.
+ t.timer.Reset(deadline.Sub(group.Now()))
+ return
+ }
+ // Create a new timer to cancel the context at the deadline.
+ t.timer = group.AfterFunc(deadline.Sub(group.Now()), func() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.cancel(os.ErrDeadlineExceeded)
+ t.cancel = nil
+ })
+}
diff --git a/http2/not_go111.go b/http2/not_go111.go
deleted file mode 100644
index cc0baa8197..0000000000
--- a/http2/not_go111.go
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2018 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !go1.11
-// +build !go1.11
-
-package http2
-
-import (
- "net/http/httptrace"
- "net/textproto"
-)
-
-func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { return false }
-
-func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {}
-
-func traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
- return nil
-}
diff --git a/http2/not_go115.go b/http2/not_go115.go
deleted file mode 100644
index e6c04cf7ac..0000000000
--- a/http2/not_go115.go
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !go1.15
-// +build !go1.15
-
-package http2
-
-import (
- "context"
- "crypto/tls"
-)
-
-// dialTLSWithContext opens a TLS connection.
-func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
- cn, err := tls.Dial(network, addr, cfg)
- if err != nil {
- return nil, err
- }
- if err := cn.Handshake(); err != nil {
- return nil, err
- }
- if cfg.InsecureSkipVerify {
- return cn, nil
- }
- if err := cn.VerifyHostname(cfg.ServerName); err != nil {
- return nil, err
- }
- return cn, nil
-}
diff --git a/http2/not_go118.go b/http2/not_go118.go
deleted file mode 100644
index eab532c96b..0000000000
--- a/http2/not_go118.go
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !go1.18
-// +build !go1.18
-
-package http2
-
-import (
- "crypto/tls"
- "net"
-)
-
-func tlsUnderlyingConn(tc *tls.Conn) net.Conn {
- return nil
-}
diff --git a/http2/pipe.go b/http2/pipe.go
index 684d984fd9..3b9f06b962 100644
--- a/http2/pipe.go
+++ b/http2/pipe.go
@@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) {
}
}
-var errClosedPipeWrite = errors.New("write on closed buffer")
+var (
+ errClosedPipeWrite = errors.New("write on closed buffer")
+ errUninitializedPipeWrite = errors.New("write on uninitialized buffer")
+)
// Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold.
@@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) {
if p.err != nil || p.breakErr != nil {
return 0, errClosedPipeWrite
}
+ // pipe.setBuffer is never invoked, leaving the buffer uninitialized.
+ // We shouldn't try to write to an uninitialized pipe,
+ // but returning an error is better than panicking.
+ if p.b == nil {
+ return 0, errUninitializedPipeWrite
+ }
return p.b.Write(d)
}
diff --git a/http2/pipe_test.go b/http2/pipe_test.go
index 67562a92a1..326b94deb5 100644
--- a/http2/pipe_test.go
+++ b/http2/pipe_test.go
@@ -8,7 +8,6 @@ import (
"bytes"
"errors"
"io"
- "io/ioutil"
"testing"
)
@@ -85,7 +84,7 @@ func TestPipeCloseWithError(t *testing.T) {
io.WriteString(p, body)
a := errors.New("test error")
p.CloseWithError(a)
- all, err := ioutil.ReadAll(p)
+ all, err := io.ReadAll(p)
if string(all) != body {
t.Errorf("read bytes = %q; want %q", all, body)
}
@@ -112,7 +111,7 @@ func TestPipeBreakWithError(t *testing.T) {
io.WriteString(p, "foo")
a := errors.New("test err")
p.BreakWithError(a)
- all, err := ioutil.ReadAll(p)
+ all, err := io.ReadAll(p)
if string(all) != "" {
t.Errorf("read bytes = %q; want empty string", all)
}
diff --git a/http2/server.go b/http2/server.go
index 02c88b6b3e..b640deb0e0 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -29,6 +29,7 @@ import (
"bufio"
"bytes"
"context"
+ "crypto/rand"
"crypto/tls"
"errors"
"fmt"
@@ -49,13 +50,18 @@ import (
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
+ "golang.org/x/net/internal/httpcommon"
)
const (
- prefaceTimeout = 10 * time.Second
- firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
- handlerChunkWriteSize = 4 << 10
- defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
+ prefaceTimeout = 10 * time.Second
+ firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
+ handlerChunkWriteSize = 4 << 10
+ defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
+
+ // maxQueuedControlFrames is the maximum number of control frames like
+ // SETTINGS, PING and RST_STREAM that will be queued for writing before
+ // the connection is closed to prevent memory exhaustion attacks.
maxQueuedControlFrames = 10000
)
@@ -124,8 +130,25 @@ type Server struct {
// IdleTimeout specifies how long until idle clients should be
// closed with a GOAWAY frame. PING frames are not considered
// activity for the purposes of IdleTimeout.
+ // If zero or negative, there is no timeout.
IdleTimeout time.Duration
+ // ReadIdleTimeout is the timeout after which a health check using a ping
+ // frame will be carried out if no frame is received on the connection.
+ // If zero, no health check is performed.
+ ReadIdleTimeout time.Duration
+
+ // PingTimeout is the timeout after which the connection will be closed
+ // if a response to a ping is not received.
+ // If zero, a default of 15 seconds is used.
+ PingTimeout time.Duration
+
+ // WriteByteTimeout is the timeout after which a connection will be
+ // closed if no data can be written to it. The timeout begins when data is
+ // available to write, and is extended whenever any bytes are written.
+ // If zero or negative, there is no timeout.
+ WriteByteTimeout time.Duration
+
// MaxUploadBufferPerConnection is the size of the initial flow
// control window for each connections. The HTTP/2 spec does not
// allow this to be smaller than 65535 or larger than 2^32-1.
@@ -153,57 +176,39 @@ type Server struct {
// so that we don't embed a Mutex in this struct, which will make the
// struct non-copyable, which might break some callers.
state *serverInternalState
-}
-func (s *Server) initialConnRecvWindowSize() int32 {
- if s.MaxUploadBufferPerConnection >= initialWindowSize {
- return s.MaxUploadBufferPerConnection
- }
- return 1 << 20
+ // Synchronization group used for testing.
+ // Outside of tests, this is nil.
+ group synctestGroupInterface
}
-func (s *Server) initialStreamRecvWindowSize() int32 {
- if s.MaxUploadBufferPerStream > 0 {
- return s.MaxUploadBufferPerStream
+func (s *Server) markNewGoroutine() {
+ if s.group != nil {
+ s.group.Join()
}
- return 1 << 20
}
-func (s *Server) maxReadFrameSize() uint32 {
- if v := s.MaxReadFrameSize; v >= minMaxFrameSize && v <= maxFrameSize {
- return v
+func (s *Server) now() time.Time {
+ if s.group != nil {
+ return s.group.Now()
}
- return defaultMaxReadFrameSize
+ return time.Now()
}
-func (s *Server) maxConcurrentStreams() uint32 {
- if v := s.MaxConcurrentStreams; v > 0 {
- return v
+// newTimer creates a new time.Timer, or a synthetic timer in tests.
+func (s *Server) newTimer(d time.Duration) timer {
+ if s.group != nil {
+ return s.group.NewTimer(d)
}
- return defaultMaxStreams
+ return timeTimer{time.NewTimer(d)}
}
-func (s *Server) maxDecoderHeaderTableSize() uint32 {
- if v := s.MaxDecoderHeaderTableSize; v > 0 {
- return v
+// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
+func (s *Server) afterFunc(d time.Duration, f func()) timer {
+ if s.group != nil {
+ return s.group.AfterFunc(d, f)
}
- return initialHeaderTableSize
-}
-
-func (s *Server) maxEncoderHeaderTableSize() uint32 {
- if v := s.MaxEncoderHeaderTableSize; v > 0 {
- return v
- }
- return initialHeaderTableSize
-}
-
-// maxQueuedControlFrames is the maximum number of control frames like
-// SETTINGS, PING and RST_STREAM that will be queued for writing before
-// the connection is closed to prevent memory exhaustion attacks.
-func (s *Server) maxQueuedControlFrames() int {
- // TODO: if anybody asks, add a Server field, and remember to define the
- // behavior of negative values.
- return maxQueuedControlFrames
+ return timeTimer{time.AfterFunc(d, f)}
}
type serverInternalState struct {
@@ -302,7 +307,7 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if s.TLSNextProto == nil {
s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){}
}
- protoHandler := func(hs *http.Server, c *tls.Conn, h http.Handler) {
+ protoHandler := func(hs *http.Server, c net.Conn, h http.Handler, sawClientPreface bool) {
if testHookOnConn != nil {
testHookOnConn()
}
@@ -319,12 +324,31 @@ func ConfigureServer(s *http.Server, conf *Server) error {
ctx = bc.BaseContext()
}
conf.ServeConn(c, &ServeConnOpts{
- Context: ctx,
- Handler: h,
- BaseConfig: hs,
+ Context: ctx,
+ Handler: h,
+ BaseConfig: hs,
+ SawClientPreface: sawClientPreface,
})
}
- s.TLSNextProto[NextProtoTLS] = protoHandler
+ s.TLSNextProto[NextProtoTLS] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
+ protoHandler(hs, c, h, false)
+ }
+ // The "unencrypted_http2" TLSNextProto key is used to pass off non-TLS HTTP/2 conns.
+ //
+ // A connection passed in this method has already had the HTTP/2 preface read from it.
+ s.TLSNextProto[nextProtoUnencryptedHTTP2] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
+ nc, err := unencryptedNetConnFromTLSConn(c)
+ if err != nil {
+ if lg := hs.ErrorLog; lg != nil {
+ lg.Print(err)
+ } else {
+ log.Print(err)
+ }
+ go c.Close()
+ return
+ }
+ protoHandler(hs, nc, h, true)
+ }
return nil
}
@@ -399,16 +423,22 @@ func (o *ServeConnOpts) handler() http.Handler {
//
// The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
+ s.serveConn(c, opts, nil)
+}
+
+func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) {
baseCtx, cancel := serverConnBaseContext(c, opts)
defer cancel()
+ http1srv := opts.baseConfig()
+ conf := configFromServer(http1srv, s)
sc := &serverConn{
srv: s,
- hs: opts.baseConfig(),
+ hs: http1srv,
conn: c,
baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(),
- bw: newBufferedWriter(c),
+ bw: newBufferedWriter(s.group, c, conf.WriteByteTimeout),
handler: opts.handler(),
streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult),
@@ -418,13 +448,19 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}),
clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
- advMaxStreams: s.maxConcurrentStreams(),
+ advMaxStreams: conf.MaxConcurrentStreams,
initialStreamSendWindowSize: initialWindowSize,
+ initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
maxFrameSize: initialMaxFrameSize,
+ pingTimeout: conf.PingTimeout,
+ countErrorFunc: conf.CountError,
serveG: newGoroutineLock(),
pushEnabled: true,
sawClientPreface: opts.SawClientPreface,
}
+ if newf != nil {
+ newf(sc)
+ }
s.state.registerConn(sc)
defer s.state.unregisterConn(sc)
@@ -434,7 +470,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// passes the connection off to us with the deadline already set.
// Write deadlines are set per stream in serverConn.newStream.
// Disarm the net.Conn write deadline here.
- if sc.hs.WriteTimeout != 0 {
+ if sc.hs.WriteTimeout > 0 {
sc.conn.SetWriteDeadline(time.Time{})
}
@@ -450,15 +486,15 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc.flow.add(initialWindowSize)
sc.inflow.init(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
- sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize())
+ sc.hpackEncoder.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)
fr := NewFramer(sc.bw, c)
- if s.CountError != nil {
- fr.countError = s.CountError
+ if conf.CountError != nil {
+ fr.countError = conf.CountError
}
- fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil)
+ fr.ReadMetaHeaders = hpack.NewDecoder(conf.MaxDecoderHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize()
- fr.SetMaxReadFrameSize(s.maxReadFrameSize())
+ fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
sc.framer = fr
if tc, ok := c.(connectionStater); ok {
@@ -491,7 +527,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// So for now, do nothing here again.
}
- if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
+ if !conf.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
// "Endpoints MAY choose to generate a connection error
// (Section 5.4.1) of type INADEQUATE_SECURITY if one of
// the prohibited cipher suites are negotiated."
@@ -528,7 +564,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
opts.UpgradeRequest = nil
}
- sc.serve()
+ sc.serve(conf)
}
func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) {
@@ -568,6 +604,7 @@ type serverConn struct {
tlsState *tls.ConnectionState // shared by all handlers, like net/http
remoteAddrStr string
writeSched WriteScheduler
+ countErrorFunc func(errType string)
// Everything following is owned by the serve loop; use serveG.check():
serveG goroutineLock // used to verify funcs are on serve()
@@ -587,6 +624,7 @@ type serverConn struct {
streams map[uint32]*stream
unstartedHandlers []unstartedHandler
initialStreamSendWindowSize int32
+ initialStreamRecvWindowSize int32
maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
@@ -597,9 +635,14 @@ type serverConn struct {
inGoAway bool // we've started to or sent GOAWAY
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write
+ pingSent bool
+ sentPingData [8]byte
goAwayCode ErrCode
- shutdownTimer *time.Timer // nil until used
- idleTimer *time.Timer // nil if unused
+ shutdownTimer timer // nil until used
+ idleTimer timer // nil if unused
+ readIdleTimeout time.Duration
+ pingTimeout time.Duration
+ readIdleTimer timer // nil if unused
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
@@ -614,11 +657,7 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
if n <= 0 {
n = http.DefaultMaxHeaderBytes
}
- // http2's count is in a slightly different unit and includes 32 bytes per pair.
- // So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
- const perFieldOverhead = 32 // per http2 spec
- const typicalHeaders = 10 // conservative
- return uint32(n + typicalHeaders*perFieldOverhead)
+ return uint32(adjustHTTP1MaxHeaderSize(int64(n)))
}
func (sc *serverConn) curOpenStreams() uint32 {
@@ -648,12 +687,12 @@ type stream struct {
flow outflow // limits writing from Handler to client
inflow inflow // what the client is allowed to POST/etc to us
state streamState
- resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
- gotTrailerHeader bool // HEADER frame for trailers was seen
- wroteHeaders bool // whether we wrote headers (not status 100)
- readDeadline *time.Timer // nil if unused
- writeDeadline *time.Timer // nil if unused
- closeErr error // set before cw is closed
+ resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
+ gotTrailerHeader bool // HEADER frame for trailers was seen
+ wroteHeaders bool // whether we wrote headers (not status 100)
+ readDeadline timer // nil if unused
+ writeDeadline timer // nil if unused
+ closeErr error // set before cw is closed
trailer http.Header // accumulated trailers
reqTrailer http.Header // handler's Request.Trailer
@@ -731,11 +770,7 @@ func isClosedConnError(err error) bool {
return false
}
- // TODO: remove this string search and be more like the Windows
- // case below. That might involve modifying the standard library
- // to return better error types.
- str := err.Error()
- if strings.Contains(str, "use of closed network connection") {
+ if errors.Is(err, net.ErrClosed) {
return true
}
@@ -778,8 +813,7 @@ const maxCachedCanonicalHeadersKeysSize = 2048
func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check()
- buildCommonHeaderMapsOnce()
- cv, ok := commonCanonHeader[v]
+ cv, ok := httpcommon.CachedCanonicalHeader(v)
if ok {
return cv
}
@@ -814,8 +848,9 @@ type readFrameResult struct {
// consumer is done with the frame.
// It's run on its own goroutine.
func (sc *serverConn) readFrames() {
- gate := make(gate)
- gateDone := gate.Done
+ sc.srv.markNewGoroutine()
+ gate := make(chan struct{})
+ gateDone := func() { gate <- struct{}{} }
for {
f, err := sc.framer.ReadFrame()
select {
@@ -846,6 +881,7 @@ type frameWriteResult struct {
// At most one goroutine can be running writeFrameAsync at a time per
// serverConn.
func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) {
+ sc.srv.markNewGoroutine()
var err error
if wd == nil {
err = wr.write.writeFrame(sc)
@@ -884,7 +920,7 @@ func (sc *serverConn) notePanic() {
}
}
-func (sc *serverConn) serve() {
+func (sc *serverConn) serve(conf http2Config) {
sc.serveG.check()
defer sc.notePanic()
defer sc.conn.Close()
@@ -896,20 +932,24 @@ func (sc *serverConn) serve() {
sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
}
+ settings := writeSettings{
+ {SettingMaxFrameSize, conf.MaxReadFrameSize},
+ {SettingMaxConcurrentStreams, sc.advMaxStreams},
+ {SettingMaxHeaderListSize, sc.maxHeaderListSize()},
+ {SettingHeaderTableSize, conf.MaxDecoderHeaderTableSize},
+ {SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)},
+ }
+ if !disableExtendedConnectProtocol {
+ settings = append(settings, Setting{SettingEnableConnectProtocol, 1})
+ }
sc.writeFrame(FrameWriteRequest{
- write: writeSettings{
- {SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
- {SettingMaxConcurrentStreams, sc.advMaxStreams},
- {SettingMaxHeaderListSize, sc.maxHeaderListSize()},
- {SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()},
- {SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())},
- },
+ write: settings,
})
sc.unackedSettings++
// Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens.
- if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 {
+ if diff := conf.MaxUploadBufferPerConnection - initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff))
}
@@ -924,16 +964,23 @@ func (sc *serverConn) serve() {
sc.setConnState(http.StateActive)
sc.setConnState(http.StateIdle)
- if sc.srv.IdleTimeout != 0 {
- sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
+ if sc.srv.IdleTimeout > 0 {
+ sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop()
}
+ if conf.SendPingTimeout > 0 {
+ sc.readIdleTimeout = conf.SendPingTimeout
+ sc.readIdleTimer = sc.srv.afterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
+ defer sc.readIdleTimer.Stop()
+ }
+
go sc.readFrames() // closed by defer sc.conn.Close above
- settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
+ settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop()
+ lastFrameTime := sc.srv.now()
loopNum := 0
for {
loopNum++
@@ -947,6 +994,7 @@ func (sc *serverConn) serve() {
case res := <-sc.wroteFrameCh:
sc.wroteFrame(res)
case res := <-sc.readFrameCh:
+ lastFrameTime = sc.srv.now()
// Process any written frames before reading new frames from the client since a
// written frame could have triggered a new stream to be started.
if sc.writingFrameAsync {
@@ -978,6 +1026,8 @@ func (sc *serverConn) serve() {
case idleTimerMsg:
sc.vlogf("connection is idle")
sc.goAway(ErrCodeNo)
+ case readIdleTimerMsg:
+ sc.handlePingTimer(lastFrameTime)
case shutdownTimerMsg:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
@@ -1000,7 +1050,7 @@ func (sc *serverConn) serve() {
// If the peer is causing us to generate a lot of control frames,
// but not reading them from us, assume they are trying to make us
// run out of memory.
- if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() {
+ if sc.queuedControlFrames > maxQueuedControlFrames {
sc.vlogf("http2: too many control frames in send queue, closing connection")
return
}
@@ -1016,12 +1066,39 @@ func (sc *serverConn) serve() {
}
}
+func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
+ if sc.pingSent {
+ sc.vlogf("timeout waiting for PING response")
+ sc.conn.Close()
+ return
+ }
+
+ pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
+ now := sc.srv.now()
+ if pingAt.After(now) {
+ // We received frames since arming the ping timer.
+ // Reset it for the next possible timeout.
+ sc.readIdleTimer.Reset(pingAt.Sub(now))
+ return
+ }
+
+ sc.pingSent = true
+ // Ignore crypto/rand.Read errors: It generally can't fail, and worse case if it does
+ // is we send a PING frame containing 0s.
+ _, _ = rand.Read(sc.sentPingData[:])
+ sc.writeFrame(FrameWriteRequest{
+ write: &writePing{data: sc.sentPingData},
+ })
+ sc.readIdleTimer.Reset(sc.pingTimeout)
+}
+
type serverMessage int
// Message values sent to serveMsgCh.
var (
settingsTimerMsg = new(serverMessage)
idleTimerMsg = new(serverMessage)
+ readIdleTimerMsg = new(serverMessage)
shutdownTimerMsg = new(serverMessage)
gracefulShutdownMsg = new(serverMessage)
handlerDoneMsg = new(serverMessage)
@@ -1029,6 +1106,7 @@ var (
func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) }
+func (sc *serverConn) onReadIdleTimer() { sc.sendServeMsg(readIdleTimerMsg) }
func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) }
func (sc *serverConn) sendServeMsg(msg interface{}) {
@@ -1060,10 +1138,10 @@ func (sc *serverConn) readPreface() error {
errc <- nil
}
}()
- timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server?
+ timer := sc.srv.newTimer(prefaceTimeout) // TODO: configurable on *Server?
defer timer.Stop()
select {
- case <-timer.C:
+ case <-timer.C():
return errPrefaceTimeout
case err := <-errc:
if err == nil {
@@ -1281,6 +1359,10 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
sc.writingFrame = false
sc.writingFrameAsync = false
+ if res.err != nil {
+ sc.conn.Close()
+ }
+
wr := res.wr
if writeEndsStream(wr.write) {
@@ -1428,7 +1510,7 @@ func (sc *serverConn) goAway(code ErrCode) {
func (sc *serverConn) shutDownIn(d time.Duration) {
sc.serveG.check()
- sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer)
+ sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer)
}
func (sc *serverConn) resetStream(se StreamError) {
@@ -1481,6 +1563,11 @@ func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
sc.goAway(ErrCodeFlowControl)
return true
case ConnectionError:
+ if res.f != nil {
+ if id := res.f.Header().StreamID; id > sc.maxClientStreamID {
+ sc.maxClientStreamID = id
+ }
+ }
sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
sc.goAway(ErrCode(ev))
return true // goAway will handle shutdown
@@ -1550,6 +1637,11 @@ func (sc *serverConn) processFrame(f Frame) error {
func (sc *serverConn) processPing(f *PingFrame) error {
sc.serveG.check()
if f.IsAck() {
+ if sc.pingSent && sc.sentPingData == f.Data {
+ // This is a response to a PING we sent.
+ sc.pingSent = false
+ sc.readIdleTimer.Reset(sc.readIdleTimeout)
+ }
// 6.7 PING: " An endpoint MUST NOT respond to PING frames
// containing this flag."
return nil
@@ -1637,7 +1729,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id)
if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle)
- if sc.srv.IdleTimeout != 0 {
+ if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil {
sc.idleTimer.Reset(sc.srv.IdleTimeout)
}
if h1ServerKeepAlivesDisabled(sc.hs) {
@@ -1659,6 +1751,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
}
}
st.closeErr = err
+ st.cancelCtx()
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.CloseStream(st.id)
}
@@ -1712,6 +1805,9 @@ func (sc *serverConn) processSetting(s Setting) error {
sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31
case SettingMaxHeaderListSize:
sc.peerMaxHeaderListSize = s.Val
+ case SettingEnableConnectProtocol:
+ // Receipt of this parameter by a server does not
+ // have any impact
default:
// Unknown setting: "An endpoint that receives a SETTINGS
// frame with any unknown or unsupported identifier MUST
@@ -2017,9 +2113,9 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// similar to how the http1 server works. Here it's
// technically more like the http1 Server's ReadHeaderTimeout
// (in Go 1.8), though. That's a more sane option anyway.
- if sc.hs.ReadTimeout != 0 {
+ if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
- st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
+ st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
return sc.scheduleHandler(id, rw, req, handler)
@@ -2038,7 +2134,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
// Disable any read deadline set by the net/http package
// prior to the upgrade.
- if sc.hs.ReadTimeout != 0 {
+ if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
}
@@ -2115,9 +2211,9 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.cw.Init()
st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize)
- st.inflow.init(sc.srv.initialStreamRecvWindowSize())
- if sc.hs.WriteTimeout != 0 {
- st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
+ st.inflow.init(sc.initialStreamRecvWindowSize)
+ if sc.hs.WriteTimeout > 0 {
+ st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
sc.streams[id] = st
@@ -2137,19 +2233,25 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
sc.serveG.check()
- rp := requestParam{
- method: f.PseudoValue("method"),
- scheme: f.PseudoValue("scheme"),
- authority: f.PseudoValue("authority"),
- path: f.PseudoValue("path"),
+ rp := httpcommon.ServerRequestParam{
+ Method: f.PseudoValue("method"),
+ Scheme: f.PseudoValue("scheme"),
+ Authority: f.PseudoValue("authority"),
+ Path: f.PseudoValue("path"),
+ Protocol: f.PseudoValue("protocol"),
+ }
+
+ // extended connect is disabled, so we should not see :protocol
+ if disableExtendedConnectProtocol && rp.Protocol != "" {
+ return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
}
- isConnect := rp.method == "CONNECT"
+ isConnect := rp.Method == "CONNECT"
if isConnect {
- if rp.path != "" || rp.scheme != "" || rp.authority == "" {
+ if rp.Protocol == "" && (rp.Path != "" || rp.Scheme != "" || rp.Authority == "") {
return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
}
- } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
+ } else if rp.Method == "" || rp.Path == "" || (rp.Scheme != "https" && rp.Scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses:
//
// Malformed requests or responses that are detected
@@ -2163,12 +2265,16 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol))
}
- rp.header = make(http.Header)
+ header := make(http.Header)
+ rp.Header = header
for _, hf := range f.RegularFields() {
- rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+ header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+ }
+ if rp.Authority == "" {
+ rp.Authority = header.Get("Host")
}
- if rp.authority == "" {
- rp.authority = rp.header.Get("Host")
+ if rp.Protocol != "" {
+ header.Set(":protocol", rp.Protocol)
}
rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
@@ -2177,7 +2283,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
}
bodyOpen := !f.StreamEnded()
if bodyOpen {
- if vv, ok := rp.header["Content-Length"]; ok {
+ if vv, ok := rp.Header["Content-Length"]; ok {
if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
req.ContentLength = int64(cl)
} else {
@@ -2193,83 +2299,38 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
return rw, req, nil
}
-type requestParam struct {
- method string
- scheme, authority, path string
- header http.Header
-}
-
-func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) {
+func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *http.Request, error) {
sc.serveG.check()
var tlsState *tls.ConnectionState // nil if not scheme https
- if rp.scheme == "https" {
+ if rp.Scheme == "https" {
tlsState = sc.tlsState
}
- needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue")
- if needsContinue {
- rp.header.Del("Expect")
- }
- // Merge Cookie headers into one "; "-delimited value.
- if cookies := rp.header["Cookie"]; len(cookies) > 1 {
- rp.header.Set("Cookie", strings.Join(cookies, "; "))
- }
-
- // Setup Trailers
- var trailer http.Header
- for _, v := range rp.header["Trailer"] {
- for _, key := range strings.Split(v, ",") {
- key = http.CanonicalHeaderKey(textproto.TrimString(key))
- switch key {
- case "Transfer-Encoding", "Trailer", "Content-Length":
- // Bogus. (copy of http1 rules)
- // Ignore.
- default:
- if trailer == nil {
- trailer = make(http.Header)
- }
- trailer[key] = nil
- }
- }
- }
- delete(rp.header, "Trailer")
-
- var url_ *url.URL
- var requestURI string
- if rp.method == "CONNECT" {
- url_ = &url.URL{Host: rp.authority}
- requestURI = rp.authority // mimic HTTP/1 server behavior
- } else {
- var err error
- url_, err = url.ParseRequestURI(rp.path)
- if err != nil {
- return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol))
- }
- requestURI = rp.path
+ res := httpcommon.NewServerRequest(rp)
+ if res.InvalidReason != "" {
+ return nil, nil, sc.countError(res.InvalidReason, streamError(st.id, ErrCodeProtocol))
}
body := &requestBody{
conn: sc,
stream: st,
- needsContinue: needsContinue,
+ needsContinue: res.NeedsContinue,
}
- req := &http.Request{
- Method: rp.method,
- URL: url_,
+ req := (&http.Request{
+ Method: rp.Method,
+ URL: res.URL,
RemoteAddr: sc.remoteAddrStr,
- Header: rp.header,
- RequestURI: requestURI,
+ Header: rp.Header,
+ RequestURI: res.RequestURI,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
TLS: tlsState,
- Host: rp.authority,
+ Host: rp.Authority,
Body: body,
- Trailer: trailer,
- }
- req = req.WithContext(st.ctx)
-
+ Trailer: res.Trailer,
+ }).WithContext(st.ctx)
rw := sc.newResponseWriter(st, req)
return rw, req, nil
}
@@ -2341,6 +2402,7 @@ func (sc *serverConn) handlerDone() {
// Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
+ sc.srv.markNewGoroutine()
defer sc.sendServeMsg(handlerDoneMsg)
didPanic := true
defer func() {
@@ -2549,7 +2611,6 @@ type responseWriterState struct {
wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
sentHeader bool // have we sent the header frame?
handlerDone bool // handler has finished
- dirty bool // a Write failed; don't reuse this responseWriterState
sentContentLen int64 // non-zero if handler set a Content-Length header
wroteBytes int64
@@ -2638,7 +2699,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
var date string
if _, ok := rws.snapHeader["Date"]; !ok {
// TODO(bradfitz): be faster here, like net/http? measure.
- date = time.Now().UTC().Format(http.TimeFormat)
+ date = rws.conn.srv.now().UTC().Format(http.TimeFormat)
}
for _, v := range rws.snapHeader["Trailer"] {
@@ -2669,7 +2730,6 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
date: date,
})
if err != nil {
- rws.dirty = true
return 0, err
}
if endStream {
@@ -2690,7 +2750,6 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream.
if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
- rws.dirty = true
return 0, err
}
}
@@ -2702,9 +2761,6 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
trailers: rws.trailers,
endStream: true,
})
- if err != nil {
- rws.dirty = true
- }
return len(p), err
}
return len(p), nil
@@ -2765,7 +2821,7 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() {
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
st := w.rws.stream
- if !deadline.IsZero() && deadline.Before(time.Now()) {
+ if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onReadTimeout()
@@ -2781,9 +2837,9 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.readDeadline = nil
} else if st.readDeadline == nil {
- st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
+ st.readDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onReadTimeout)
} else {
- st.readDeadline.Reset(deadline.Sub(time.Now()))
+ st.readDeadline.Reset(deadline.Sub(sc.srv.now()))
}
})
return nil
@@ -2791,7 +2847,7 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
st := w.rws.stream
- if !deadline.IsZero() && deadline.Before(time.Now()) {
+ if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onWriteTimeout()
@@ -2807,14 +2863,19 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.writeDeadline = nil
} else if st.writeDeadline == nil {
- st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
+ st.writeDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onWriteTimeout)
} else {
- st.writeDeadline.Reset(deadline.Sub(time.Now()))
+ st.writeDeadline.Reset(deadline.Sub(sc.srv.now()))
}
})
return nil
}
+func (w *responseWriter) EnableFullDuplex() error {
+ // We always support full duplex responses, so this is a no-op.
+ return nil
+}
+
func (w *responseWriter) Flush() {
w.FlushError()
}
@@ -2920,14 +2981,12 @@ func (rws *responseWriterState) writeHeader(code int) {
h.Del("Transfer-Encoding")
}
- if rws.conn.writeHeaders(rws.stream, &writeResHeaders{
+ rws.conn.writeHeaders(rws.stream, &writeResHeaders{
streamID: rws.stream.id,
httpResCode: code,
h: h,
endStream: rws.handlerDone && !rws.hasTrailers(),
- }) != nil {
- rws.dirty = true
- }
+ })
return
}
@@ -2992,19 +3051,10 @@ func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int,
func (w *responseWriter) handlerDone() {
rws := w.rws
- dirty := rws.dirty
rws.handlerDone = true
w.Flush()
w.rws = nil
- if !dirty {
- // Only recycle the pool if all prior Write calls to
- // the serverConn goroutine completed successfully. If
- // they returned earlier due to resets from the peer
- // there might still be write goroutines outstanding
- // from the serverConn referencing the rws memory. See
- // issue 20704.
- responseWriterStatePool.Put(rws)
- }
+ responseWriterStatePool.Put(rws)
}
// Push errors.
@@ -3175,18 +3225,19 @@ func (sc *serverConn) startPush(msg *startPushRequest) {
// we start in "half closed (remote)" for simplicity.
// See further comments at the definition of stateHalfClosedRemote.
promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote)
- rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{
- method: msg.method,
- scheme: msg.url.Scheme,
- authority: msg.url.Host,
- path: msg.url.RequestURI(),
- header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE
+ rw, req, err := sc.newWriterAndRequestNoBody(promised, httpcommon.ServerRequestParam{
+ Method: msg.method,
+ Scheme: msg.url.Scheme,
+ Authority: msg.url.Host,
+ Path: msg.url.RequestURI(),
+ Header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE
})
if err != nil {
// Should not happen, since we've already validated msg.url.
panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
}
+ sc.curHandlers++
go sc.runHandler(rw, req, sc.handler.ServeHTTP)
return promisedID, nil
}
@@ -3271,7 +3322,7 @@ func (sc *serverConn) countError(name string, err error) error {
if sc == nil || sc.srv == nil {
return err
}
- f := sc.srv.CountError
+ f := sc.countErrorFunc
if f == nil {
return err
}
diff --git a/http2/server_push_test.go b/http2/server_push_test.go
index 6e57de0b7c..69e4c3b12d 100644
--- a/http2/server_push_test.go
+++ b/http2/server_push_test.go
@@ -8,9 +8,9 @@ import (
"errors"
"fmt"
"io"
- "io/ioutil"
"net/http"
"reflect"
+ "runtime"
"strconv"
"sync"
"testing"
@@ -39,7 +39,7 @@ func TestServer_Push_Success(t *testing.T) {
if r.Body == nil {
return fmt.Errorf("nil Body")
}
- if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
+ if buf, err := io.ReadAll(r.Body); err != nil || len(buf) != 0 {
return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
}
return nil
@@ -105,7 +105,7 @@ func TestServer_Push_Success(t *testing.T) {
errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
}
})
- stURL = st.ts.URL
+ stURL = "https://" + st.authority()
// Send one request, which should push two responses.
st.greet()
@@ -169,7 +169,7 @@ func TestServer_Push_Success(t *testing.T) {
return checkPushPromise(f, 2, [][2]string{
{":method", "GET"},
{":scheme", "https"},
- {":authority", st.ts.Listener.Addr().String()},
+ {":authority", st.authority()},
{":path", "/pushed?get"},
{"user-agent", userAgent},
})
@@ -178,7 +178,7 @@ func TestServer_Push_Success(t *testing.T) {
return checkPushPromise(f, 4, [][2]string{
{":method", "HEAD"},
{":scheme", "https"},
- {":authority", st.ts.Listener.Addr().String()},
+ {":authority", st.authority()},
{":path", "/pushed?head"},
{"cookie", cookie},
{"user-agent", userAgent},
@@ -218,12 +218,12 @@ func TestServer_Push_Success(t *testing.T) {
consumed := map[uint32]int{}
for k := 0; len(expected) > 0; k++ {
- f, err := st.readFrame()
- if err != nil {
+ f := st.readFrame()
+ if f == nil {
for id, left := range expected {
t.Errorf("stream %d: missing %d frames", id, len(left))
}
- t.Fatalf("readFrame %d: %v", k, err)
+ break
}
id := f.Header().StreamID
label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
@@ -339,10 +339,10 @@ func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher,
t.Error(err)
}
// Should not get a PUSH_PROMISE frame.
- hf := st.wantHeaders()
- if !hf.StreamEnded() {
- t.Error("stream should end after headers")
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
}
func TestServer_Push_RejectIfDisabled(t *testing.T) {
@@ -459,7 +459,7 @@ func TestServer_Push_StateTransitions(t *testing.T) {
}
getSlash(st)
// After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote.
- st.wantPushPromise()
+ _ = readFrame[*PushPromiseFrame](t, st)
if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
t.Fatalf("streamState(2)=%v, want %v", got, want)
}
@@ -468,10 +468,10 @@ func TestServer_Push_StateTransitions(t *testing.T) {
// the stream before we check st.streamState(2) -- should that happen, we'll
// see stateClosed and fail the above check.
close(gotPromise)
- st.wantHeaders()
- if df := st.wantData(); !df.StreamEnded() {
- t.Fatal("expected END_STREAM flag on DATA")
- }
+ st.wantHeaders(wantHeader{
+ streamID: 2,
+ endStream: false,
+ })
if got, want := st.streamState(2), stateClosed; got != want {
t.Fatalf("streamState(2)=%v, want %v", got, want)
}
@@ -483,11 +483,7 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) {
ready := make(chan struct{})
errc := make(chan error, 2)
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- select {
- case <-ready:
- case <-time.After(5 * time.Second):
- errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed")
- }
+ <-ready
if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
errc <- fmt.Errorf("Push()=%v, want %v", got, want)
}
@@ -505,6 +501,10 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) {
case <-ready:
return
default:
+ if runtime.GOARCH == "wasm" {
+ // Work around https://go.dev/issue/65178 to avoid goroutine starvation.
+ runtime.Gosched()
+ }
}
st.sc.serveMsgCh <- func(loopNum int) {
if !st.sc.pushEnabled {
@@ -517,3 +517,55 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) {
t.Error(err)
}
}
+
+func TestServer_Push_Underflow(t *testing.T) {
+ // Test for #63511: Send several requests which generate PUSH_PROMISE responses,
+ // verify they all complete successfully.
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.RequestURI() {
+ case "/":
+ opt := &http.PushOptions{
+ Header: http.Header{"User-Agent": {"testagent"}},
+ }
+ if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
+ t.Errorf("error pushing: %v", err)
+ }
+ w.WriteHeader(200)
+ case "/pushed":
+ r.Header.Set("User-Agent", "newagent")
+ r.Header.Set("Cookie", "cookie")
+ w.WriteHeader(200)
+ default:
+ t.Errorf("unknown RequestURL %q", r.URL.RequestURI())
+ }
+ })
+ // Send several requests.
+ st.greet()
+ const numRequests = 4
+ for i := 0; i < numRequests; i++ {
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: uint32(1 + i*2), // clients send odd numbers
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
+ })
+ }
+ // Each request should result in one PUSH_PROMISE and two responses.
+ numPushPromises := 0
+ numHeaders := 0
+ for numHeaders < numRequests*2 || numPushPromises < numRequests {
+ f := st.readFrame()
+ if f == nil {
+ st.t.Fatal("conn is idle, want frame")
+ }
+ switch f := f.(type) {
+ case *HeadersFrame:
+ if !f.Flags.Has(FlagHeadersEndStream) {
+ t.Fatalf("got HEADERS frame with no END_STREAM, expected END_STREAM: %v", f)
+ }
+ numHeaders++
+ case *PushPromiseFrame:
+ numPushPromises++
+ }
+ }
+}
diff --git a/http2/server_test.go b/http2/server_test.go
index 22657cbfe4..b27a127a5e 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -14,8 +14,8 @@ import (
"flag"
"fmt"
"io"
- "io/ioutil"
"log"
+ "math"
"net"
"net/http"
"net/http/httptest"
@@ -38,7 +38,7 @@ func stderrv() io.Writer {
return os.Stderr
}
- return ioutil.Discard
+ return io.Discard
}
type safeBuffer struct {
@@ -65,16 +65,16 @@ func (sb *safeBuffer) Len() int {
}
type serverTester struct {
- cc net.Conn // client conn
- t testing.TB
- ts *httptest.Server
- fr *Framer
- serverLogBuf safeBuffer // logger for httptest.Server
- logFilter []string // substrings to filter out
- scMu sync.Mutex // guards sc
- sc *serverConn
- hpackDec *hpack.Decoder
- decodedHeaders [][2]string
+ cc net.Conn // client conn
+ t testing.TB
+ group *synctestGroup
+ h1server *http.Server
+ h2server *Server
+ serverLogBuf safeBuffer // logger for httptest.Server
+ logFilter []string // substrings to filter out
+ scMu sync.Mutex // guards sc
+ sc *serverConn
+ testConnFramer
// If http2debug!=2, then we capture Frame debug logs that will be written
// to t.Log after a test fails. The read and write logs use separate locks
@@ -101,23 +101,153 @@ func resetHooks() {
testHookOnPanicMu.Unlock()
}
+func newTestServer(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *httptest.Server {
+ ts := httptest.NewUnstartedServer(handler)
+ ts.EnableHTTP2 = true
+ ts.Config.ErrorLog = log.New(twriter{t: t}, "", log.LstdFlags)
+ h2server := new(Server)
+ for _, opt := range opts {
+ switch v := opt.(type) {
+ case func(*httptest.Server):
+ v(ts)
+ case func(*http.Server):
+ v(ts.Config)
+ case func(*Server):
+ v(h2server)
+ default:
+ t.Fatalf("unknown newTestServer option type %T", v)
+ }
+ }
+ ConfigureServer(ts.Config, h2server)
+
+ // ConfigureServer populates ts.Config.TLSConfig.
+ // Copy it to ts.TLS as well.
+ ts.TLS = ts.Config.TLSConfig
+
+ // Go 1.22 changes the default minimum TLS version to TLS 1.2,
+ // in order to properly test cases where we want to reject low
+ // TLS versions, we need to explicitly configure the minimum
+ // version here.
+ ts.Config.TLSConfig.MinVersion = tls.VersionTLS10
+
+ ts.StartTLS()
+ t.Cleanup(func() {
+ ts.CloseClientConnections()
+ ts.Close()
+ })
+
+ return ts
+}
+
type serverTesterOpt string
-var optOnlyServer = serverTesterOpt("only_server")
-var optQuiet = serverTesterOpt("quiet_logging")
var optFramerReuseFrames = serverTesterOpt("frame_reuse_frames")
+var optQuiet = func(server *http.Server) {
+ server.ErrorLog = log.New(io.Discard, "", 0)
+}
+
func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
+ t.Helper()
+ g := newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))
+ t.Cleanup(func() {
+ g.Close(t)
+ })
+
+ h1server := &http.Server{}
+ h2server := &Server{
+ group: g,
+ }
+ tlsState := tls.ConnectionState{
+ Version: tls.VersionTLS13,
+ ServerName: "go.dev",
+ CipherSuite: tls.TLS_AES_128_GCM_SHA256,
+ }
+ for _, opt := range opts {
+ switch v := opt.(type) {
+ case func(*Server):
+ v(h2server)
+ case func(*http.Server):
+ v(h1server)
+ case func(*tls.ConnectionState):
+ v(&tlsState)
+ default:
+ t.Fatalf("unknown newServerTester option type %T", v)
+ }
+ }
+ ConfigureServer(h1server, h2server)
+
+ cli, srv := synctestNetPipe(g)
+ cli.SetReadDeadline(g.Now())
+ cli.autoWait = true
+
+ st := &serverTester{
+ t: t,
+ cc: cli,
+ group: g,
+ h1server: h1server,
+ h2server: h2server,
+ }
+ st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
+ if h1server.ErrorLog == nil {
+ h1server.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
+ }
+
+ t.Cleanup(func() {
+ st.Close()
+ g.AdvanceTime(goAwayTimeout) // give server time to shut down
+ })
+
+ connc := make(chan *serverConn)
+ go func() {
+ g.Join()
+ h2server.serveConn(&netConnWithConnectionState{
+ Conn: srv,
+ state: tlsState,
+ }, &ServeConnOpts{
+ Handler: handler,
+ BaseConfig: h1server,
+ }, func(sc *serverConn) {
+ connc <- sc
+ })
+ }()
+ st.sc = <-connc
+
+ st.fr = NewFramer(st.cc, st.cc)
+ st.testConnFramer = testConnFramer{
+ t: t,
+ fr: NewFramer(st.cc, st.cc),
+ dec: hpack.NewDecoder(initialHeaderTableSize, nil),
+ }
+ g.Wait()
+ return st
+}
+
+type netConnWithConnectionState struct {
+ net.Conn
+ state tls.ConnectionState
+}
+
+func (c *netConnWithConnectionState) ConnectionState() tls.ConnectionState {
+ return c.state
+}
+
+// newServerTesterWithRealConn creates a test server listening on a localhost port.
+// Mostly superseded by newServerTester, which creates a test server using a fake
+// net.Conn and synthetic time. This function is still around because some benchmarks
+// rely on it; new tests should use newServerTester.
+func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
resetHooks()
ts := httptest.NewUnstartedServer(handler)
+ t.Cleanup(ts.Close)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{NextProtoTLS},
}
- var onlyServer, quiet, framerReuseFrames bool
+ var framerReuseFrames bool
h2server := new(Server)
for _, opt := range opts {
switch v := opt.(type) {
@@ -125,14 +255,12 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
v(tlsConfig)
case func(*httptest.Server):
v(ts)
+ case func(*http.Server):
+ v(ts.Config)
case func(*Server):
v(h2server)
case serverTesterOpt:
switch v {
- case optOnlyServer:
- onlyServer = true
- case optQuiet:
- quiet = true
case optFramerReuseFrames:
framerReuseFrames = true
}
@@ -145,17 +273,19 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
ConfigureServer(ts.Config, h2server)
+ // Go 1.22 changes the default minimum TLS version to TLS 1.2,
+ // in order to properly test cases where we want to reject low
+ // TLS versions, we need to explicitly configure the minimum
+ // version here.
+ ts.Config.TLSConfig.MinVersion = tls.VersionTLS10
+
st := &serverTester{
- t: t,
- ts: ts,
+ t: t,
}
st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
- st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
- if quiet {
- ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
- } else {
+ if ts.Config.ErrorLog == nil {
ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
}
ts.StartTLS()
@@ -169,36 +299,54 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
st.sc = v
}
log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
- if !onlyServer {
- cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
- if err != nil {
- t.Fatal(err)
+ cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ st.cc = cc
+ st.testConnFramer = testConnFramer{
+ t: t,
+ fr: NewFramer(st.cc, st.cc),
+ dec: hpack.NewDecoder(initialHeaderTableSize, nil),
+ }
+ if framerReuseFrames {
+ st.fr.SetReuseFrames()
+ }
+ if !logFrameReads && !logFrameWrites {
+ st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
+ m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
+ st.frameReadLogMu.Lock()
+ fmt.Fprintf(&st.frameReadLogBuf, m, v...)
+ st.frameReadLogMu.Unlock()
}
- st.cc = cc
- st.fr = NewFramer(cc, cc)
- if framerReuseFrames {
- st.fr.SetReuseFrames()
- }
- if !logFrameReads && !logFrameWrites {
- st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
- m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
- st.frameReadLogMu.Lock()
- fmt.Fprintf(&st.frameReadLogBuf, m, v...)
- st.frameReadLogMu.Unlock()
- }
- st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
- m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
- st.frameWriteLogMu.Lock()
- fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
- st.frameWriteLogMu.Unlock()
- }
- st.fr.logReads = true
- st.fr.logWrites = true
+ st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
+ m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
+ st.frameWriteLogMu.Lock()
+ fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
+ st.frameWriteLogMu.Unlock()
}
+ st.fr.logReads = true
+ st.fr.logWrites = true
}
return st
}
+// sync waits for all goroutines to idle.
+func (st *serverTester) sync() {
+ if st.group != nil {
+ st.group.Wait()
+ }
+}
+
+// advance advances synthetic time by a duration.
+func (st *serverTester) advance(d time.Duration) {
+ st.group.AdvanceTime(d)
+}
+
+func (st *serverTester) authority() string {
+ return "dummy.tld"
+}
+
func (st *serverTester) closeConn() {
st.scMu.Lock()
defer st.scMu.Unlock()
@@ -274,7 +422,6 @@ func (st *serverTester) Close() {
st.cc.Close()
}
}
- st.ts.Close()
if st.cc != nil {
st.cc.Close()
}
@@ -284,13 +431,16 @@ func (st *serverTester) Close() {
// greet initiates the client's HTTP/2 connection into a state where
// frames may be sent.
func (st *serverTester) greet() {
+ st.t.Helper()
st.greetAndCheckSettings(func(Setting) error { return nil })
}
func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) {
+ st.t.Helper()
st.writePreface()
- st.writeInitialSettings()
- st.wantSettings().ForeachSetting(checkSetting)
+ st.writeSettings()
+ st.sync()
+ readFrame[*SettingsFrame](st.t, st).ForeachSetting(checkSetting)
st.writeSettingsAck()
// The initial WINDOW_UPDATE and SETTINGS ACK can come in any order.
@@ -298,9 +448,9 @@ func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error
var gotWindowUpdate bool
for i := 0; i < 2; i++ {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatal(err)
+ f := st.readFrame()
+ if f == nil {
+ st.t.Fatal("wanted a settings ACK and window update, got none")
}
switch f := f.(type) {
case *SettingsFrame:
@@ -313,7 +463,8 @@ func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error
if f.FrameHeader.StreamID != 0 {
st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
}
- incr := uint32(st.sc.srv.initialConnRecvWindowSize() - initialWindowSize)
+ conf := configFromServer(st.sc.hs, st.sc.srv)
+ incr := uint32(conf.MaxUploadBufferPerConnection - initialWindowSize)
if f.Increment != incr {
st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr)
}
@@ -342,34 +493,6 @@ func (st *serverTester) writePreface() {
}
}
-func (st *serverTester) writeInitialSettings() {
- if err := st.fr.WriteSettings(); err != nil {
- if runtime.GOOS == "openbsd" && strings.HasSuffix(err.Error(), "write: broken pipe") {
- st.t.Logf("Error writing initial SETTINGS frame from client to server: %v", err)
- st.t.Skipf("Skipping test with known OpenBSD failure mode. (See https://go.dev/issue/52208.)")
- }
- st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
- }
-}
-
-func (st *serverTester) writeSettingsAck() {
- if err := st.fr.WriteSettingsAck(); err != nil {
- st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err)
- }
-}
-
-func (st *serverTester) writeHeaders(p HeadersFrameParam) {
- if err := st.fr.WriteHeaders(p); err != nil {
- st.t.Fatalf("Error writing HEADERS: %v", err)
- }
-}
-
-func (st *serverTester) writePriority(id uint32, p PriorityParam) {
- if err := st.fr.WritePriority(id, p); err != nil {
- st.t.Fatalf("Error writing PRIORITY: %v", err)
- }
-}
-
func (st *serverTester) encodeHeaderField(k, v string) {
err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
if err != nil {
@@ -403,7 +526,7 @@ func (st *serverTester) encodeHeader(headers ...string) []byte {
}
st.headerBuf.Reset()
- defaultAuthority := st.ts.Listener.Addr().String()
+ defaultAuthority := st.authority()
if len(headers) == 0 {
// Fast path, mostly for benchmarks, so test code doesn't pollute
@@ -468,150 +591,13 @@ func (st *serverTester) bodylessReq1(headers ...string) {
})
}
-func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
- if err := st.fr.WriteData(streamID, endStream, data); err != nil {
- st.t.Fatalf("Error writing DATA: %v", err)
- }
-}
-
-func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
- if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
- st.t.Fatalf("Error writing DATA: %v", err)
- }
-}
-
-// writeReadPing sends a PING and immediately reads the PING ACK.
-// It will fail if any other unread data was pending on the connection.
-func (st *serverTester) writeReadPing() {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- if err := st.fr.WritePing(false, data); err != nil {
- st.t.Fatalf("Error writing PING: %v", err)
- }
- p := st.wantPing()
- if p.Flags&FlagPingAck == 0 {
- st.t.Fatalf("got a PING, want a PING ACK")
- }
- if p.Data != data {
- st.t.Fatalf("got PING data = %x, want %x", p.Data, data)
- }
-}
-
-func (st *serverTester) readFrame() (Frame, error) {
- return st.fr.ReadFrame()
-}
-
-func (st *serverTester) wantHeaders() *HeadersFrame {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatalf("Error while expecting a HEADERS frame: %v", err)
- }
- hf, ok := f.(*HeadersFrame)
- if !ok {
- st.t.Fatalf("got a %T; want *HeadersFrame", f)
- }
- return hf
-}
-
-func (st *serverTester) wantContinuation() *ContinuationFrame {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err)
- }
- cf, ok := f.(*ContinuationFrame)
- if !ok {
- st.t.Fatalf("got a %T; want *ContinuationFrame", f)
- }
- return cf
-}
-
-func (st *serverTester) wantData() *DataFrame {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatalf("Error while expecting a DATA frame: %v", err)
- }
- df, ok := f.(*DataFrame)
- if !ok {
- st.t.Fatalf("got a %T; want *DataFrame", f)
- }
- return df
-}
-
-func (st *serverTester) wantSettings() *SettingsFrame {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
- }
- sf, ok := f.(*SettingsFrame)
- if !ok {
- st.t.Fatalf("got a %T; want *SettingsFrame", f)
- }
- return sf
-}
-
-func (st *serverTester) wantPing() *PingFrame {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatalf("Error while expecting a PING frame: %v", err)
- }
- pf, ok := f.(*PingFrame)
- if !ok {
- st.t.Fatalf("got a %T; want *PingFrame", f)
- }
- return pf
-}
-
-func (st *serverTester) wantGoAway() *GoAwayFrame {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err)
- }
- gf, ok := f.(*GoAwayFrame)
- if !ok {
- st.t.Fatalf("got a %T; want *GoAwayFrame", f)
- }
- return gf
-}
-
-func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
- }
- rs, ok := f.(*RSTStreamFrame)
- if !ok {
- st.t.Fatalf("got a %T; want *RSTStreamFrame", f)
- }
- if rs.FrameHeader.StreamID != streamID {
- st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
- }
- if rs.ErrCode != errCode {
- st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
- }
-}
-
-func (st *serverTester) wantWindowUpdate(streamID, incr uint32) {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err)
- }
- wu, ok := f.(*WindowUpdateFrame)
- if !ok {
- st.t.Fatalf("got a %T; want *WindowUpdateFrame", f)
- }
- if wu.FrameHeader.StreamID != streamID {
- st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
- }
- if wu.Increment != incr {
- st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
- }
-}
-
func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) {
+ conf := configFromServer(st.sc.hs, st.sc.srv)
var initial int32
if streamID == 0 {
- initial = st.sc.srv.initialConnRecvWindowSize()
+ initial = conf.MaxUploadBufferPerConnection
} else {
- initial = st.sc.srv.initialStreamRecvWindowSize()
+ initial = conf.MaxUploadBufferPerStream
}
donec := make(chan struct{})
st.sc.sendServeMsg(func(sc *serverConn) {
@@ -628,32 +614,6 @@ func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) {
<-donec
}
-func (st *serverTester) wantSettingsAck() {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatal(err)
- }
- sf, ok := f.(*SettingsFrame)
- if !ok {
- st.t.Fatalf("Wanting a settings ACK, received a %T", f)
- }
- if !sf.Header().Flags.Has(FlagSettingsAck) {
- st.t.Fatal("Settings Frame didn't have ACK set")
- }
-}
-
-func (st *serverTester) wantPushPromise() *PushPromiseFrame {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatal(err)
- }
- ppf, ok := f.(*PushPromiseFrame)
- if !ok {
- st.t.Fatalf("Wanted PushPromise, received %T", ppf)
- }
- return ppf
-}
-
func TestServer(t *testing.T) {
gotReq := make(chan bool, 1)
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
@@ -662,12 +622,6 @@ func TestServer(t *testing.T) {
})
defer st.Close()
- covers("3.5", `
- The server connection preface consists of a potentially empty
- SETTINGS frame ([SETTINGS]) that MUST be the first frame the
- server sends in the HTTP/2 connection.
- `)
-
st.greet()
st.writeHeaders(HeadersFrameParam{
StreamID: 1, // clients send odd numbers
@@ -860,7 +814,7 @@ func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, wr
if r.ContentLength != wantContentLength {
t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
}
- all, err := ioutil.ReadAll(r.Body)
+ all, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
@@ -881,7 +835,7 @@ func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError s
if r.ContentLength != wantContentLength {
t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
}
- all, err := ioutil.ReadAll(r.Body)
+ all, err := io.ReadAll(r.Body)
if err == nil {
t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.",
wantReadError, all)
@@ -1078,6 +1032,26 @@ func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
})
}
+func TestServer_Request_Reject_Authority_Userinfo(t *testing.T) {
+ // "':authority' MUST NOT include the deprecated userinfo subcomponent
+ // for "http" or "https" schemed URIs."
+ // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8
+ testRejectRequest(t, func(st *serverTester) {
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "userinfo@example.tld"})
+ enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
+ enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
+ enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1, // clients send odd numbers
+ BlockFragment: buf.Bytes(),
+ EndStream: true,
+ EndHeaders: true,
+ })
+ })
+}
+
func testRejectRequest(t *testing.T, send func(*serverTester)) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
t.Error("server request made it to handler; should've been rejected")
@@ -1089,37 +1063,32 @@ func testRejectRequest(t *testing.T, send func(*serverTester)) {
st.wantRSTStream(1, ErrCodeProtocol)
}
-func testRejectRequestWithProtocolError(t *testing.T, send func(*serverTester)) {
+func newServerTesterForError(t *testing.T) *serverTester {
+ t.Helper()
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
t.Error("server request made it to handler; should've been rejected")
}, optQuiet)
- defer st.Close()
-
st.greet()
- send(st)
- gf := st.wantGoAway()
- if gf.ErrCode != ErrCodeProtocol {
- t.Errorf("err code = %v; want %v", gf.ErrCode, ErrCodeProtocol)
- }
+ return st
}
// Section 5.1, on idle connections: "Receiving any frame other than
// HEADERS or PRIORITY on a stream in this state MUST be treated as a
// connection error (Section 5.4.1) of type PROTOCOL_ERROR."
func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) {
- testRejectRequestWithProtocolError(t, func(st *serverTester) {
- st.fr.WriteWindowUpdate(123, 456)
- })
+ st := newServerTesterForError(t)
+ st.fr.WriteWindowUpdate(123, 456)
+ st.wantGoAway(123, ErrCodeProtocol)
}
func TestRejectFrameOnIdle_Data(t *testing.T) {
- testRejectRequestWithProtocolError(t, func(st *serverTester) {
- st.fr.WriteData(123, true, nil)
- })
+ st := newServerTesterForError(t)
+ st.fr.WriteData(123, true, nil)
+ st.wantGoAway(123, ErrCodeProtocol)
}
func TestRejectFrameOnIdle_RSTStream(t *testing.T) {
- testRejectRequestWithProtocolError(t, func(st *serverTester) {
- st.fr.WriteRSTStream(123, ErrCodeCancel)
- })
+ st := newServerTesterForError(t)
+ st.fr.WriteRSTStream(123, ErrCodeCancel)
+ st.wantGoAway(123, ErrCodeProtocol)
}
func TestServer_Request_Connect(t *testing.T) {
@@ -1193,7 +1162,7 @@ func TestServer_Ping(t *testing.T) {
t.Fatal(err)
}
- pf := st.wantPing()
+ pf := readFrame[*PingFrame](t, st)
if !pf.Flags.Has(FlagPingAck) {
t.Error("response ping doesn't have ACK set")
}
@@ -1216,38 +1185,36 @@ func (l *filterListener) Accept() (net.Conn, error) {
}
func TestServer_MaxQueuedControlFrames(t *testing.T) {
- if testing.Short() {
- t.Skip("skipping in short mode")
- }
+ // Goroutine debugging makes this test very slow.
+ disableGoroutineTracking(t)
- st := newServerTester(t, nil, func(ts *httptest.Server) {
- // TCP buffer sizes on test systems aren't under our control and can be large.
- // Create a conn that blocks after 10000 bytes written.
- ts.Listener = &filterListener{
- Listener: ts.Listener,
- accept: func(conn net.Conn) (net.Conn, error) {
- return newBlockingWriteConn(conn, 10000), nil
- },
- }
- })
- defer st.Close()
+ st := newServerTester(t, nil)
st.greet()
- const extraPings = 500000 // enough to fill the TCP buffers
+ st.cc.(*synctestNetConn).SetReadBufferSize(0) // all writes block
+ st.cc.(*synctestNetConn).autoWait = false // don't sync after every write
+ // Send maxQueuedControlFrames pings, plus a few extra
+ // to account for ones that enter the server's write buffer.
+ const extraPings = 2
for i := 0; i < maxQueuedControlFrames+extraPings; i++ {
pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- if err := st.fr.WritePing(false, pingData); err != nil {
- if i == 0 {
- t.Fatal(err)
- }
- // We expect the connection to get closed by the server when the TCP
- // buffer fills up and the write queue reaches MaxQueuedControlFrames.
- t.Logf("sent %d PING frames", i)
- return
+ st.fr.WritePing(false, pingData)
+ }
+ st.group.Wait()
+
+ // Unblock the server.
+ // It should have closed the connection after exceeding the control frame limit.
+ st.cc.(*synctestNetConn).SetReadBufferSize(math.MaxInt)
+
+ st.advance(goAwayTimeout)
+ // Some frames may have persisted in the server's buffers.
+ for i := 0; i < 10; i++ {
+ if st.readFrame() == nil {
+ break
}
}
- t.Errorf("unexpected success sending all PING frames")
+ st.wantClosed()
}
func TestServer_RejectsLargeFrames(t *testing.T) {
@@ -1263,15 +1230,9 @@ func TestServer_RejectsLargeFrames(t *testing.T) {
// will only read the first 9 bytes (the headre) and then disconnect.
st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1))
- gf := st.wantGoAway()
- if gf.ErrCode != ErrCodeFrameSize {
- t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFrameSize)
- }
- if st.serverLogBuf.Len() != 0 {
- // Previously we spun here for a bit until the GOAWAY disconnect
- // timer fired, logging while we fired.
- t.Errorf("unexpected server output: %.500s\n", st.serverLogBuf.Bytes())
- }
+ st.wantGoAway(0, ErrCodeFrameSize)
+ st.advance(goAwayTimeout)
+ st.wantClosed()
}
func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
@@ -1297,7 +1258,6 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
EndStream: false, // data coming
EndHeaders: true,
})
- st.writeReadPing()
// Write less than half the max window of data and consume it.
// The server doesn't return flow control yet, buffering the 1024 bytes to
@@ -1305,20 +1265,17 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
data := make([]byte, windowSize)
st.writeData(1, false, data[:1024])
puppet.do(readBodyHandler(t, string(data[:1024])))
- st.writeReadPing()
// Write up to the window limit.
// The server returns the buffered credit.
st.writeData(1, false, data[1024:])
st.wantWindowUpdate(0, 1024)
st.wantWindowUpdate(1, 1024)
- st.writeReadPing()
// The handler consumes the data and the server returns credit.
puppet.do(readBodyHandler(t, string(data[1024:])))
st.wantWindowUpdate(0, windowSize-1024)
st.wantWindowUpdate(1, windowSize-1024)
- st.writeReadPing()
}
// the version of the TestServer_Handler_Sends_WindowUpdate with padding.
@@ -1342,7 +1299,6 @@ func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
EndStream: false,
EndHeaders: true,
})
- st.writeReadPing()
// Write half a window of data, with some padding.
// The server doesn't return the padding yet, buffering the 5 bytes to combine
@@ -1350,7 +1306,6 @@ func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
data := make([]byte, windowSize/2)
pad := make([]byte, 4)
st.writeDataPadded(1, false, data, pad)
- st.writeReadPing()
// The handler consumes the body.
// The server returns flow control for the body and padding
@@ -1367,13 +1322,7 @@ func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
t.Fatal(err)
}
- gf := st.wantGoAway()
- if gf.ErrCode != ErrCodeFlowControl {
- t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFlowControl)
- }
- if gf.LastStreamID != 0 {
- t.Errorf("GOAWAY last stream ID = %v; want %v", gf.LastStreamID, 0)
- }
+ st.wantGoAway(0, ErrCodeFlowControl)
}
func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
@@ -1580,10 +1529,10 @@ func TestServer_StateTransitions(t *testing.T) {
st.writeData(1, true, nil)
leaveHandler <- true
- hf := st.wantHeaders()
- if !hf.StreamEnded() {
- t.Fatal("expected END_STREAM flag")
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
if got, want := st.streamState(1), stateClosed; got != want {
t.Errorf("at end, state is %v; want %v", got, want)
@@ -1595,97 +1544,101 @@ func TestServer_StateTransitions(t *testing.T) {
// test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
- testServerRejectsConn(t, func(st *serverTester) {
- st.writeHeaders(HeadersFrameParam{
- StreamID: 1,
- BlockFragment: st.encodeHeader(),
- EndStream: true,
- EndHeaders: false,
- })
- st.writeHeaders(HeadersFrameParam{ // Not a continuation.
- StreamID: 3, // different stream.
- BlockFragment: st.encodeHeader(),
- EndStream: true,
- EndHeaders: true,
- })
+ st := newServerTesterForError(t)
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: false,
+ })
+ st.writeHeaders(HeadersFrameParam{ // Not a continuation.
+ StreamID: 3, // different stream.
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
})
+ st.wantGoAway(0, ErrCodeProtocol)
}
// test HEADERS w/o EndHeaders + PING (should get rejected)
func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
- testServerRejectsConn(t, func(st *serverTester) {
- st.writeHeaders(HeadersFrameParam{
- StreamID: 1,
- BlockFragment: st.encodeHeader(),
- EndStream: true,
- EndHeaders: false,
- })
- if err := st.fr.WritePing(false, [8]byte{}); err != nil {
- t.Fatal(err)
- }
+ st := newServerTesterForError(t)
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: false,
})
+ if err := st.fr.WritePing(false, [8]byte{}); err != nil {
+ t.Fatal(err)
+ }
+ st.wantGoAway(0, ErrCodeProtocol)
}
// test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
- testServerRejectsConn(t, func(st *serverTester) {
- st.writeHeaders(HeadersFrameParam{
- StreamID: 1,
- BlockFragment: st.encodeHeader(),
- EndStream: true,
- EndHeaders: true,
- })
- st.wantHeaders()
- if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
- t.Fatal(err)
- }
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optQuiet)
+ st.greet()
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
})
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
+ t.Fatal(err)
+ }
+ st.wantGoAway(1, ErrCodeProtocol)
}
// test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID
func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
- testServerRejectsConn(t, func(st *serverTester) {
- st.writeHeaders(HeadersFrameParam{
- StreamID: 1,
- BlockFragment: st.encodeHeader(),
- EndStream: true,
- EndHeaders: false,
- })
- if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
- t.Fatal(err)
- }
+ st := newServerTesterForError(t)
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: false,
})
+ if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
+ t.Fatal(err)
+ }
+ st.wantGoAway(0, ErrCodeProtocol)
}
// No HEADERS on stream 0.
func TestServer_Rejects_Headers0(t *testing.T) {
- testServerRejectsConn(t, func(st *serverTester) {
- st.fr.AllowIllegalWrites = true
- st.writeHeaders(HeadersFrameParam{
- StreamID: 0,
- BlockFragment: st.encodeHeader(),
- EndStream: true,
- EndHeaders: true,
- })
+ st := newServerTesterForError(t)
+ st.fr.AllowIllegalWrites = true
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 0,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
})
+ st.wantGoAway(0, ErrCodeProtocol)
}
// No CONTINUATION on stream 0.
func TestServer_Rejects_Continuation0(t *testing.T) {
- testServerRejectsConn(t, func(st *serverTester) {
- st.fr.AllowIllegalWrites = true
- if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
- t.Fatal(err)
- }
- })
+ st := newServerTesterForError(t)
+ st.fr.AllowIllegalWrites = true
+ if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
+ t.Fatal(err)
+ }
+ st.wantGoAway(0, ErrCodeProtocol)
}
// No PRIORITY on stream 0.
func TestServer_Rejects_Priority0(t *testing.T) {
- testServerRejectsConn(t, func(st *serverTester) {
- st.fr.AllowIllegalWrites = true
- st.writePriority(0, PriorityParam{StreamDep: 1})
- })
+ st := newServerTesterForError(t)
+ st.fr.AllowIllegalWrites = true
+ st.writePriority(0, PriorityParam{StreamDep: 1})
+ st.wantGoAway(0, ErrCodeProtocol)
}
// No HEADERS frame with a self-dependence.
@@ -1711,36 +1664,15 @@ func TestServer_Rejects_PrioritySelfDependence(t *testing.T) {
}
func TestServer_Rejects_PushPromise(t *testing.T) {
- testServerRejectsConn(t, func(st *serverTester) {
- pp := PushPromiseParam{
- StreamID: 1,
- PromiseID: 3,
- }
- if err := st.fr.WritePushPromise(pp); err != nil {
- t.Fatal(err)
- }
- })
-}
-
-// testServerRejectsConn tests that the server hangs up with a GOAWAY
-// frame and a server close after the client does something
-// deserving a CONNECTION_ERROR.
-func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
- st.addLogFilter("connection error: PROTOCOL_ERROR")
- defer st.Close()
- st.greet()
- writeReq(st)
-
- st.wantGoAway()
-
- fr, err := st.fr.ReadFrame()
- if err == nil {
- t.Errorf("ReadFrame got frame of type %T; want io.EOF", fr)
+ st := newServerTesterForError(t)
+ pp := PushPromiseParam{
+ StreamID: 1,
+ PromiseID: 3,
}
- if err != io.EOF {
- t.Errorf("ReadFrame = %v; want io.EOF", err)
+ if err := st.fr.WritePushPromise(pp); err != nil {
+ t.Fatal(err)
}
+ st.wantGoAway(1, ErrCodeProtocol)
}
// testServerRejectsStream tests that the server sends a RST_STREAM with the provided
@@ -1780,13 +1712,10 @@ func TestServer_Response_NoData(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- if !hf.StreamEnded() {
- t.Fatal("want END_STREAM flag")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
})
}
@@ -1796,22 +1725,15 @@ func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- if !hf.StreamEnded() {
- t.Fatal("want END_STREAM flag")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"foo-bar", "some-value"},
- {"content-length", "0"},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ header: http.Header{
+ ":status": []string{"200"},
+ "foo-bar": []string{"some-value"},
+ "content-length": []string{"0"},
+ },
+ })
})
}
@@ -1856,15 +1778,14 @@ func TestServerIgnoresContentLengthSignWhenWritingChunks(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"content-length", tt.wantCL},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("For case %q, value %q, got = %q; want %q", tt.name, tt.cl, goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-length": []string{tt.wantCL},
+ },
+ })
})
}
}
@@ -1934,29 +1855,20 @@ func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("don't want END_STREAM, expecting data")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"content-type", "foo/bar"},
- {"content-length", strconv.Itoa(len(msg))},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
- df := st.wantData()
- if !df.StreamEnded() {
- t.Error("expected DATA to have END_STREAM flag")
- }
- if got := string(df.Data()); got != msg {
- t.Errorf("got DATA %q; want %q", got, msg)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{"foo/bar"},
+ "content-length": []string{strconv.Itoa(len(msg))},
+ },
+ })
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: true,
+ data: []byte(msg),
+ })
})
}
@@ -1968,16 +1880,15 @@ func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"content-type", "text/plain; charset=utf-8"},
- {"content-length", strconv.Itoa(len(msg))},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{"text/plain; charset=utf-8"},
+ "content-length": []string{strconv.Itoa(len(msg))},
+ },
+ })
})
}
@@ -1990,22 +1901,15 @@ func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("unexpected END_STREAM")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"content-type", "text/html; charset=utf-8"},
- {"content-length", strconv.Itoa(len(msg))},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{"text/html; charset=utf-8"},
+ "content-length": []string{strconv.Itoa(len(msg))},
+ },
+ })
})
}
@@ -2019,23 +1923,16 @@ func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("unexpected END_STREAM")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"foo", "proper value"},
- {"content-type", "text/html; charset=utf-8"},
- {"content-length", strconv.Itoa(len(msg))},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "foo": []string{"proper value"},
+ "content-type": []string{"text/html; charset=utf-8"},
+ "content-length": []string{strconv.Itoa(len(msg))},
+ },
+ })
})
}
@@ -2046,29 +1943,20 @@ func TestServer_Response_Data_SniffLenType(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("don't want END_STREAM, expecting data")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"content-type", "text/html; charset=utf-8"},
- {"content-length", strconv.Itoa(len(msg))},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
- df := st.wantData()
- if !df.StreamEnded() {
- t.Error("expected DATA to have END_STREAM flag")
- }
- if got := string(df.Data()); got != msg {
- t.Errorf("got DATA %q; want %q", got, msg)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{"text/html; charset=utf-8"},
+ "content-length": []string{strconv.Itoa(len(msg))},
+ },
+ })
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: true,
+ data: []byte(msg),
+ })
})
}
@@ -2082,40 +1970,25 @@ func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("unexpected END_STREAM flag")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"content-type", "text/html; charset=utf-8"}, // sniffed
- // and no content-length
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
- {
- df := st.wantData()
- if df.StreamEnded() {
- t.Error("unexpected END_STREAM flag")
- }
- if got := string(df.Data()); got != msg {
- t.Errorf("got DATA %q; want %q", got, msg)
- }
- }
- {
- df := st.wantData()
- if !df.StreamEnded() {
- t.Error("wanted END_STREAM flag on last data chunk")
- }
- if got := string(df.Data()); got != msg2 {
- t.Errorf("got DATA %q; want %q", got, msg2)
- }
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{"text/html; charset=utf-8"}, // sniffed
+ // and no content-length
+ },
+ })
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: false,
+ data: []byte(msg),
+ })
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: true,
+ data: []byte(msg2),
+ })
})
}
@@ -2151,25 +2024,18 @@ func TestServer_Response_LargeWrite(t *testing.T) {
if err := st.fr.WriteWindowUpdate(0, size); err != nil {
t.Fatal(err)
}
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("unexpected END_STREAM flag")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"content-type", "text/plain; charset=utf-8"}, // sniffed
- // and no content-length
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{"text/plain; charset=utf-8"}, // sniffed
+ // and no content-length
+ },
+ })
var bytes, frames int
for {
- df := st.wantData()
+ df := readFrame[*DataFrame](t, st)
bytes += len(df.Data())
frames++
for _, b := range df.Data() {
@@ -2220,27 +2086,26 @@ func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) {
getSlash(st) // make the single request
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("unexpected END_STREAM flag")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ })
- df := st.wantData()
- if got := len(df.Data()); got != reads[0] {
- t.Fatalf("Initial window size = %d but got DATA with %d bytes", reads[0], got)
- }
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: false,
+ size: reads[0],
+ })
- for _, quota := range reads[1:] {
+ for i, quota := range reads[1:] {
if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil {
t.Fatal(err)
}
- df := st.wantData()
- if int(quota) != len(df.Data()) {
- t.Fatalf("read %d bytes after giving %d quota", len(df.Data()), quota)
- }
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: i == len(reads[1:])-1,
+ size: quota,
+ })
}
})
}
@@ -2267,13 +2132,10 @@ func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) {
getSlash(st) // make the single request
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("unexpected END_STREAM flag")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ })
if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
t.Fatal(err)
@@ -2295,21 +2157,16 @@ func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) {
getSlash(st) // make the single request
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("unexpected END_STREAM flag")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ })
- df := st.wantData()
- if got := len(df.Data()); got != 0 {
- t.Fatalf("unexpected %d DATA bytes; want 0", got)
- }
- if !df.StreamEnded() {
- t.Fatal("DATA didn't have END_STREAM")
- }
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: true,
+ size: 0,
+ })
})
}
@@ -2334,49 +2191,33 @@ func TestServer_Response_Automatic100Continue(t *testing.T) {
EndStream: false,
EndHeaders: true,
})
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("unexpected END_STREAM flag")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "100"},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Fatalf("Got headers %v; want %v", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"100"},
+ },
+ })
// Okay, they sent status 100, so we can send our
// gigantic and/or sensitive "foo" payload now.
st.writeData(1, true, []byte(msg))
- hf = st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("expected data to follow")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- goth = st.decodeHeader(hf.HeaderBlockFragment())
- wanth = [][2]string{
- {":status", "200"},
- {"content-type", "text/plain; charset=utf-8"},
- {"content-length", strconv.Itoa(len(reply))},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{"text/plain; charset=utf-8"},
+ "content-length": []string{strconv.Itoa(len(reply))},
+ },
+ })
- df := st.wantData()
- if string(df.Data()) != reply {
- t.Errorf("Client read %q; want %q", df.Data(), reply)
- }
- if !df.StreamEnded() {
- t.Errorf("expect data stream end")
- }
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: true,
+ data: []byte(reply),
+ })
})
}
@@ -2398,13 +2239,10 @@ func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
EndStream: false,
EndHeaders: true,
})
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("unexpected END_STREAM flag")
- }
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ })
// Close the connection and wait for the handler to (hopefully) notice.
st.cc.Close()
_ = <-errc
@@ -2425,6 +2263,11 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
<-leaveHandler
})
defer st.Close()
+
+ // Automatically syncing after every write / before every read
+ // slows this test down substantially.
+ st.cc.(*synctestNetConn).autoWait = false
+
st.greet()
nextStreamID := uint32(1)
streamID := func() uint32 {
@@ -2464,11 +2307,16 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
t.Fatal(err)
}
+ st.sync()
st.wantRSTStream(rejectID, ErrCodeProtocol)
// But let a handler finish:
leaveHandler <- true
- st.wantHeaders()
+ st.sync()
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
// And now another stream should be able to start:
goodID := streamID()
@@ -2488,14 +2336,14 @@ func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
+ hf := readFrame[*HeadersFrame](t, st)
if hf.HeadersEnded() {
t.Fatal("got unwanted END_HEADERS flag")
}
n := 0
for {
n++
- cf := st.wantContinuation()
+ cf := readFrame[*ContinuationFrame](t, st)
if cf.HeadersEnded() {
break
}
@@ -2524,10 +2372,10 @@ func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
EndStream: false, // DATA is coming
EndHeaders: true,
})
- hf := st.wantHeaders()
- if !hf.HeadersEnded() || !hf.StreamEnded() {
- t.Fatalf("want END_HEADERS+END_STREAM, got %v", hf)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
// Sent when the a Handler closes while a client has
// indicated it's still sending DATA:
@@ -2582,79 +2430,51 @@ func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) }
func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) }
-func testRejectTLS(t *testing.T, max uint16) {
- st := newServerTester(t, nil, func(c *tls.Config) {
+func testRejectTLS(t *testing.T, version uint16) {
+ st := newServerTester(t, nil, func(state *tls.ConnectionState) {
// As of 1.18 the default minimum Go TLS version is
// 1.2. In order to test rejection of lower versions,
- // manually set the minimum version to 1.0
- c.MinVersion = tls.VersionTLS10
- c.MaxVersion = max
+ // manually set the version to 1.0
+ state.Version = version
})
defer st.Close()
- gf := st.wantGoAway()
- if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
- t.Errorf("Got error code %v; want %v", got, want)
- }
+ st.wantGoAway(0, ErrCodeInadequateSecurity)
}
func TestServer_Rejects_TLSBadCipher(t *testing.T) {
- st := newServerTester(t, nil, func(c *tls.Config) {
- // All TLS 1.3 ciphers are good. Test with TLS 1.2.
- c.MaxVersion = tls.VersionTLS12
- // Only list bad ones:
- c.CipherSuites = []uint16{
- tls.TLS_RSA_WITH_RC4_128_SHA,
- tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
- tls.TLS_RSA_WITH_AES_128_CBC_SHA,
- tls.TLS_RSA_WITH_AES_256_CBC_SHA,
- tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
- tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
- tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
- tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
- tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
- cipher_TLS_RSA_WITH_AES_128_CBC_SHA256,
- }
+ st := newServerTester(t, nil, func(state *tls.ConnectionState) {
+ state.Version = tls.VersionTLS12
+ state.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA
})
defer st.Close()
- gf := st.wantGoAway()
- if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
- t.Errorf("Got error code %v; want %v", got, want)
- }
+ st.wantGoAway(0, ErrCodeInadequateSecurity)
}
func TestServer_Advertises_Common_Cipher(t *testing.T) {
- const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
- st := newServerTester(t, nil, func(c *tls.Config) {
- // Have the client only support the one required by the spec.
- c.CipherSuites = []uint16{requiredSuite}
- }, func(ts *httptest.Server) {
- var srv *http.Server = ts.Config
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(srv *http.Server) {
// Have the server configured with no specific cipher suites.
// This tests that Go's defaults include the required one.
srv.TLSConfig = nil
})
- defer st.Close()
- st.greet()
-}
-func (st *serverTester) onHeaderField(f hpack.HeaderField) {
- if f.Name == "date" {
- return
- }
- st.decodedHeaders = append(st.decodedHeaders, [2]string{f.Name, f.Value})
-}
+ // Have the client only support the one required by the spec.
+ const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
+ tlsConfig := tlsConfigInsecure.Clone()
+ tlsConfig.MaxVersion = tls.VersionTLS12
+ tlsConfig.CipherSuites = []uint16{requiredSuite}
+ tr := &Transport{TLSClientConfig: tlsConfig}
+ defer tr.CloseIdleConnections()
-func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) {
- st.decodedHeaders = nil
- if _, err := st.hpackDec.Write(headerBlock); err != nil {
- st.t.Fatalf("hpack decoding error: %v", err)
+ req, err := http.NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
}
- if err := st.hpackDec.Close(); err != nil {
- st.t.Fatalf("hpack decoding error: %v", err)
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
}
- return st.decodedHeaders
+ res.Body.Close()
}
// testServerResponse sets up an idle HTTP/2 connection. The client function should
@@ -2798,19 +2618,15 @@ func TestServerDoS_MaxHeaderListSize(t *testing.T) {
st.fr.WriteContinuation(1, len(b) == 0, chunk)
}
- h := st.wantHeaders()
- if !h.HeadersEnded() {
- t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
- }
- headers := st.decodeHeader(h.HeaderBlockFragment())
- want := [][2]string{
- {":status", "431"},
- {"content-type", "text/html; charset=utf-8"},
- {"content-length", "63"},
- }
- if !reflect.DeepEqual(headers, want) {
- t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"431"},
+ "content-type": []string{"text/html; charset=utf-8"},
+ "content-length": []string{"63"},
+ },
+ })
}
func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) {
@@ -2819,17 +2635,15 @@ func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- if !hf.HeadersEnded() {
- t.Fatal("want END_HEADERS flag")
- }
- df := st.wantData()
- if len(df.data) != 0 {
- t.Fatal("did not want data")
- }
- if !df.StreamEnded() {
- t.Fatal("want END_STREAM flag")
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ })
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: true,
+ size: 0,
+ })
})
}
@@ -2838,8 +2652,8 @@ func TestCompressionErrorOnWrite(t *testing.T) {
var serverConfig *http.Server
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
// No response body.
- }, func(ts *httptest.Server) {
- serverConfig = ts.Config
+ }, func(s *http.Server) {
+ serverConfig = s
serverConfig.MaxHeaderBytes = maxStrLen
})
st.addLogFilter("connection error: COMPRESSION_ERROR")
@@ -2867,20 +2681,16 @@ func TestCompressionErrorOnWrite(t *testing.T) {
EndStream: true,
EndHeaders: true,
})
- h := st.wantHeaders()
- if !h.HeadersEnded() {
- t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
- }
- headers := st.decodeHeader(h.HeaderBlockFragment())
- want := [][2]string{
- {":status", "431"},
- {"content-type", "text/html; charset=utf-8"},
- {"content-length", "63"},
- }
- if !reflect.DeepEqual(headers, want) {
- t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
- }
- df := st.wantData()
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"431"},
+ "content-type": []string{"text/html; charset=utf-8"},
+ "content-length": []string{"63"},
+ },
+ })
+ df := readFrame[*DataFrame](t, st)
if !strings.Contains(string(df.Data()), "HTTP Error 431") {
t.Errorf("Unexpected data body: %q", df.Data())
}
@@ -2896,10 +2706,7 @@ func TestCompressionErrorOnWrite(t *testing.T) {
EndStream: true,
EndHeaders: true,
})
- ga := st.wantGoAway()
- if ga.ErrCode != ErrCodeCompression {
- t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
- }
+ st.wantGoAway(3, ErrCodeCompression)
}
func TestCompressionErrorOnClose(t *testing.T) {
@@ -2918,10 +2725,7 @@ func TestCompressionErrorOnClose(t *testing.T) {
EndStream: true,
EndHeaders: true,
})
- ga := st.wantGoAway()
- if ga.ErrCode != ErrCodeCompression {
- t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
- }
+ st.wantGoAway(1, ErrCodeCompression)
}
// test that a server handler can read trailers from a client
@@ -2956,7 +2760,7 @@ func TestServerReadsTrailers(t *testing.T) {
if !reflect.DeepEqual(r.Trailer, wantTrailer) {
t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
}
- slurp, err := ioutil.ReadAll(r.Body)
+ slurp, err := io.ReadAll(r.Body)
if string(slurp) != testBody {
t.Errorf("read body %q; want %q", slurp, testBody)
}
@@ -3010,67 +2814,54 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) {
w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2")
return nil
}, func(st *serverTester) {
- getSlash(st)
- hf := st.wantHeaders()
- if hf.StreamEnded() {
- t.Fatal("response HEADERS had END_STREAM")
- }
- if !hf.HeadersEnded() {
- t.Fatal("response HEADERS didn't have END_HEADERS")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"foo", "Bar"},
- {"trailer", "Server-Trailer-A, Server-Trailer-B"},
- {"trailer", "Server-Trailer-C"},
- {"trailer", "Transfer-Encoding, Content-Length, Trailer"},
- {"content-type", "text/plain; charset=utf-8"},
- {"content-length", "5"},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
- }
- df := st.wantData()
- if string(df.Data()) != "Hello" {
- t.Fatalf("Client read %q; want Hello", df.Data())
- }
- if df.StreamEnded() {
- t.Fatalf("data frame had STREAM_ENDED")
- }
- tf := st.wantHeaders() // for the trailers
- if !tf.StreamEnded() {
- t.Fatalf("trailers HEADERS lacked END_STREAM")
- }
- if !tf.HeadersEnded() {
- t.Fatalf("trailers HEADERS lacked END_HEADERS")
- }
- wanth = [][2]string{
- {"post-header-trailer", "hi1"},
- {"post-header-trailer2", "hi2"},
- {"server-trailer-a", "valuea"},
- {"server-trailer-c", "valuec"},
- }
- goth = st.decodeHeader(tf.HeaderBlockFragment())
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
- }
+ // Ignore errors from writing invalid trailers.
+ st.h1server.ErrorLog = log.New(io.Discard, "", 0)
+ getSlash(st)
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "foo": []string{"Bar"},
+ "trailer": []string{
+ "Server-Trailer-A, Server-Trailer-B",
+ "Server-Trailer-C",
+ "Transfer-Encoding, Content-Length, Trailer",
+ },
+ "content-type": []string{"text/plain; charset=utf-8"},
+ "content-length": []string{"5"},
+ },
+ })
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: false,
+ data: []byte("Hello"),
+ })
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ header: http.Header{
+ "post-header-trailer": []string{"hi1"},
+ "post-header-trailer2": []string{"hi2"},
+ "server-trailer-a": []string{"valuea"},
+ "server-trailer-c": []string{"valuec"},
+ },
+ })
})
}
func TestServerWritesUndeclaredTrailers(t *testing.T) {
const trailer = "Trailer-Header"
const value = "hi1"
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(http.TrailerPrefix+trailer, value)
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cl := &http.Client{Transport: tr}
- resp, err := cl.Get(st.ts.URL)
+ resp, err := cl.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -3093,31 +2884,24 @@ func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- if !hf.StreamEnded() {
- t.Error("response HEADERS lacked END_STREAM")
- }
- if !hf.HeadersEnded() {
- t.Fatal("response HEADERS didn't have END_HEADERS")
- }
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"ok1", "x"},
- {"content-length", "0"},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ header: http.Header{
+ ":status": []string{"200"},
+ "ok1": []string{"x"},
+ "content-length": []string{"0"},
+ },
+ })
})
}
func BenchmarkServerGets(b *testing.B) {
- defer disableGoroutineTracking()()
+ disableGoroutineTracking(b)
b.ReportAllocs()
const msg = "Hello, world"
- st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
+ st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, msg)
})
defer st.Close()
@@ -3136,24 +2920,23 @@ func BenchmarkServerGets(b *testing.B) {
EndStream: true,
EndHeaders: true,
})
- st.wantHeaders()
- df := st.wantData()
- if !df.StreamEnded() {
+ st.wantFrameType(FrameHeaders)
+ if df := readFrame[*DataFrame](b, st); !df.StreamEnded() {
b.Fatalf("DATA didn't have END_STREAM; got %v", df)
}
}
}
func BenchmarkServerPosts(b *testing.B) {
- defer disableGoroutineTracking()()
+ disableGoroutineTracking(b)
b.ReportAllocs()
const msg = "Hello, world"
- st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
+ st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
// Consume the (empty) body from th peer before replying, otherwise
// the server will sometimes (depending on scheduling) send the peer a
// a RST_STREAM with the CANCEL error code.
- if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
+ if n, err := io.Copy(io.Discard, r.Body); n != 0 || err != nil {
b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
}
io.WriteString(w, msg)
@@ -3175,9 +2958,8 @@ func BenchmarkServerPosts(b *testing.B) {
EndHeaders: true,
})
st.writeData(id, true, nil)
- st.wantHeaders()
- df := st.wantData()
- if !df.StreamEnded() {
+ st.wantFrameType(FrameHeaders)
+ if df := readFrame[*DataFrame](b, st); !df.StreamEnded() {
b.Fatalf("DATA didn't have END_STREAM; got %v", df)
}
}
@@ -3197,7 +2979,7 @@ func BenchmarkServerToClientStreamReuseFrames(b *testing.B) {
}
func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
- defer disableGoroutineTracking()()
+ disableGoroutineTracking(b)
b.ReportAllocs()
const msgLen = 1
// default window size
@@ -3213,11 +2995,11 @@ func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
return msg
}
- st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
+ st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
// Consume the (empty) body from th peer before replying, otherwise
// the server will sometimes (depending on scheduling) send the peer a
// a RST_STREAM with the CANCEL error code.
- if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
+ if n, err := io.Copy(io.Discard, r.Body); n != 0 || err != nil {
b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
}
for i := 0; i < b.N; i += 1 {
@@ -3238,18 +3020,22 @@ func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
})
st.writeData(id, true, nil)
- st.wantHeaders()
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ })
var pendingWindowUpdate = uint32(0)
for i := 0; i < b.N; i += 1 {
expected := nextMsg(i)
- df := st.wantData()
- if bytes.Compare(expected, df.data) != 0 {
- b.Fatalf("Bad message received; want %v; got %v", expected, df.data)
- }
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: false,
+ data: expected,
+ })
// try to send infrequent but large window updates so they don't overwhelm the test
- pendingWindowUpdate += uint32(len(df.data))
+ pendingWindowUpdate += uint32(len(expected))
if pendingWindowUpdate >= windowSize/2 {
if err := st.fr.WriteWindowUpdate(0, pendingWindowUpdate); err != nil {
b.Fatal(err)
@@ -3260,10 +3046,10 @@ func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
pendingWindowUpdate = 0
}
}
- df := st.wantData()
- if !df.StreamEnded() {
- b.Fatalf("DATA didn't have END_STREAM; got %v", df)
- }
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: true,
+ })
}
// go-fuzz bug, originally reported at https://github.com/bradfitz/http2/issues/53
@@ -3427,14 +3213,13 @@ func TestServerNoAutoContentLengthOnHead(t *testing.T) {
EndStream: true,
EndHeaders: true,
})
- h := st.wantHeaders()
- headers := st.decodeHeader(h.HeaderBlockFragment())
- want := [][2]string{
- {":status", "200"},
- }
- if !reflect.DeepEqual(headers, want) {
- t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ header: http.Header{
+ ":status": []string{"200"},
+ },
+ })
}
// golang.org/issue/13495
@@ -3451,16 +3236,15 @@ func TestServerNoDuplicateContentType(t *testing.T) {
EndStream: true,
EndHeaders: true,
})
- h := st.wantHeaders()
- headers := st.decodeHeader(h.HeaderBlockFragment())
- want := [][2]string{
- {":status", "200"},
- {"content-type", ""},
- {"content-length", "41"},
- }
- if !reflect.DeepEqual(headers, want) {
- t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{""},
+ "content-length": []string{"41"},
+ },
+ })
}
func TestServerContentLengthCanBeDisabled(t *testing.T) {
@@ -3476,29 +3260,28 @@ func TestServerContentLengthCanBeDisabled(t *testing.T) {
EndStream: true,
EndHeaders: true,
})
- h := st.wantHeaders()
- headers := st.decodeHeader(h.HeaderBlockFragment())
- want := [][2]string{
- {":status", "200"},
- {"content-type", "text/plain; charset=utf-8"},
- }
- if !reflect.DeepEqual(headers, want) {
- t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{"text/plain; charset=utf-8"},
+ },
+ })
}
-func disableGoroutineTracking() (restore func()) {
+func disableGoroutineTracking(t testing.TB) {
old := DebugGoroutines
DebugGoroutines = false
- return func() { DebugGoroutines = old }
+ t.Cleanup(func() { DebugGoroutines = old })
}
func BenchmarkServer_GetRequest(b *testing.B) {
- defer disableGoroutineTracking()()
+ disableGoroutineTracking(b)
b.ReportAllocs()
const msg = "Hello, world."
- st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
- n, err := io.Copy(ioutil.Discard, r.Body)
+ st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
+ n, err := io.Copy(io.Discard, r.Body)
if err != nil || n > 0 {
b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
}
@@ -3520,17 +3303,17 @@ func BenchmarkServer_GetRequest(b *testing.B) {
EndStream: true,
EndHeaders: true,
})
- st.wantHeaders()
- st.wantData()
+ st.wantFrameType(FrameHeaders)
+ st.wantFrameType(FrameData)
}
}
func BenchmarkServer_PostRequest(b *testing.B) {
- defer disableGoroutineTracking()()
+ disableGoroutineTracking(b)
b.ReportAllocs()
const msg = "Hello, world."
- st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
- n, err := io.Copy(ioutil.Discard, r.Body)
+ st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
+ n, err := io.Copy(io.Discard, r.Body)
if err != nil || n > 0 {
b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
}
@@ -3552,8 +3335,8 @@ func BenchmarkServer_PostRequest(b *testing.B) {
EndHeaders: true,
})
st.writeData(streamID, true, nil)
- st.wantHeaders()
- st.wantData()
+ st.wantFrameType(FrameHeaders)
+ st.wantFrameType(FrameData)
}
}
@@ -3604,7 +3387,7 @@ func TestServerHandleCustomConn(t *testing.T) {
EndStream: true,
EndHeaders: true,
})
- go io.Copy(ioutil.Discard, c2)
+ go io.Copy(io.Discard, c2)
<-handlerDone
}()
const testString = "my custom ConnectionState"
@@ -3638,17 +3421,16 @@ func TestServer_Rejects_ConnHeaders(t *testing.T) {
defer st.Close()
st.greet()
st.bodylessReq1("connection", "foo")
- hf := st.wantHeaders()
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "400"},
- {"content-type", "text/plain; charset=utf-8"},
- {"x-content-type-options", "nosniff"},
- {"content-length", "51"},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"400"},
+ "content-type": []string{"text/plain; charset=utf-8"},
+ "x-content-type-options": []string{"nosniff"},
+ "content-length": []string{"51"},
+ },
+ })
}
type hpackEncoder struct {
@@ -3725,7 +3507,7 @@ func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
doRead := make(chan bool, 1)
defer close(doRead) // fallback cleanup
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, msg)
w.(http.Flusher).Flush()
@@ -3734,14 +3516,12 @@ func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
r.Body.Read(make([]byte, 10))
io.WriteString(w, msg2)
-
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- req, _ := http.NewRequest("POST", st.ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
+ req, _ := http.NewRequest("POST", ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
req.Header.Set("Expect", "100-continue")
res, err := tr.RoundTrip(req)
@@ -3802,14 +3582,13 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) {
unblock := make(chan bool, 1)
defer close(unblock)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
// Don't read the 16KB request body. Wait until the client's
// done sending it and then return. This should cause the Server
// to then return those 16KB of flow control to the client.
tt.reqFn(r)
<-unblock
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
@@ -3827,7 +3606,7 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) {
return 0, io.EOF
}),
)
- req, _ := http.NewRequest("POST", st.ts.URL, body)
+ req, _ := http.NewRequest("POST", ts.URL, body)
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(tt.name, err)
@@ -3856,12 +3635,18 @@ func TestServerReturnsStreamAndConnFlowControlOnBodyClose(t *testing.T) {
BlockFragment: st.encodeHeader(),
EndHeaders: true,
})
- st.wantHeaders()
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ })
const size = inflowMinRefresh // enough to trigger flow control return
st.writeData(1, false, make([]byte, size))
st.wantWindowUpdate(0, size) // conn-level flow control is returned
unblockHandler <- struct{}{}
- st.wantData()
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: true,
+ })
}
func TestServerIdleTimeout(t *testing.T) {
@@ -3876,22 +3661,24 @@ func TestServerIdleTimeout(t *testing.T) {
defer st.Close()
st.greet()
- ga := st.wantGoAway()
- if ga.ErrCode != ErrCodeNo {
- t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
- }
+ st.advance(500 * time.Millisecond)
+ st.wantGoAway(0, ErrCodeNo)
}
func TestServerIdleTimeout_AfterRequest(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
- const timeout = 250 * time.Millisecond
+ const (
+ requestTimeout = 2 * time.Second
+ idleTimeout = 1 * time.Second
+ )
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- time.Sleep(timeout * 2)
+ var st *serverTester
+ st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ st.group.Sleep(requestTimeout)
}, func(h2s *Server) {
- h2s.IdleTimeout = timeout
+ h2s.IdleTimeout = idleTimeout
})
defer st.Close()
@@ -3900,14 +3687,16 @@ func TestServerIdleTimeout_AfterRequest(t *testing.T) {
// Send a request which takes twice the timeout. Verifies the
// idle timeout doesn't fire while we're in a request:
st.bodylessReq1()
- st.wantHeaders()
+ st.advance(requestTimeout)
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
// But the idle timeout should be rearmed after the request
// is done:
- ga := st.wantGoAway()
- if ga.ErrCode != ErrCodeNo {
- t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
- }
+ st.advance(idleTimeout)
+ st.wantGoAway(1, ErrCodeNo)
}
// grpc-go closes the Request.Body currently with a Read.
@@ -3943,22 +3732,21 @@ func TestIssue20704Race(t *testing.T) {
itemCount = 100
)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
for i := 0; i < itemCount; i++ {
_, err := w.Write(make([]byte, itemSize))
if err != nil {
return
}
}
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cl := &http.Client{Transport: tr}
for i := 0; i < 1000; i++ {
- resp, err := cl.Get(st.ts.URL)
+ resp, err := cl.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -3970,7 +3758,7 @@ func TestIssue20704Race(t *testing.T) {
func TestServer_Rejects_TooSmall(t *testing.T) {
testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
- ioutil.ReadAll(r.Body)
+ io.ReadAll(r.Body)
return nil
}, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
@@ -4010,13 +3798,10 @@ func TestServerHandlerConnectionClose(t *testing.T) {
var sawRes bool
var sawWindowUpdate bool
for {
- f, err := st.readFrame()
- if err == io.EOF {
+ f := st.readFrame()
+ if f == nil {
break
}
- if err != nil {
- t.Fatal(err)
- }
switch f := f.(type) {
case *GoAwayFrame:
sawGoAway = true
@@ -4068,6 +3853,8 @@ func TestServerHandlerConnectionClose(t *testing.T) {
}
sawWindowUpdate = true
unblockHandler <- true
+ st.sync()
+ st.advance(goAwayTimeout)
default:
t.Logf("unexpected frame: %v", summarizeFrame(f))
}
@@ -4133,20 +3920,9 @@ func TestServer_Headers_HalfCloseRemote(t *testing.T) {
}
func TestServerGracefulShutdown(t *testing.T) {
- var st *serverTester
handlerDone := make(chan struct{})
- st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- defer close(handlerDone)
- go st.ts.Config.Shutdown(context.Background())
-
- ga := st.wantGoAway()
- if ga.ErrCode != ErrCodeNo {
- t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
- }
- if ga.LastStreamID != 1 {
- t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
- }
-
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ <-handlerDone
w.Header().Set("x-foo", "bar")
})
defer st.Close()
@@ -4154,17 +3930,23 @@ func TestServerGracefulShutdown(t *testing.T) {
st.greet()
st.bodylessReq1()
- <-handlerDone
- hf := st.wantHeaders()
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"x-foo", "bar"},
- {"content-length", "0"},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
+ st.sync()
+ st.h1server.Shutdown(context.Background())
+
+ st.wantGoAway(1, ErrCodeNo)
+
+ close(handlerDone)
+ st.sync()
+
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ header: http.Header{
+ ":status": []string{"200"},
+ "x-foo": []string{"bar"},
+ "content-length": []string{"0"},
+ },
+ })
n, err := st.cc.Read([]byte{0})
if n != 0 || err == nil {
@@ -4235,26 +4017,25 @@ func TestContentEncodingNoSniffing(t *testing.T) {
for _, tt := range resps {
t.Run(tt.name, func(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
if tt.contentEncoding != nil {
w.Header().Set("Content-Encoding", tt.contentEncoding.(string))
}
w.Write(tt.body)
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req, _ := http.NewRequest("GET", ts.URL, nil)
res, err := tr.RoundTrip(req)
if err != nil {
- t.Fatalf("GET %s: %v", st.ts.URL, err)
+ t.Fatalf("GET %s: %v", ts.URL, err)
}
defer res.Body.Close()
g := res.Header.Get("Content-Encoding")
- t.Logf("%s: Content-Encoding: %s", st.ts.URL, g)
+ t.Logf("%s: Content-Encoding: %s", ts.URL, g)
if w := tt.contentEncoding; g != w {
if w != nil { // The case where contentEncoding was set explicitly.
@@ -4268,7 +4049,7 @@ func TestContentEncodingNoSniffing(t *testing.T) {
if w := tt.wantContentType; g != w {
t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
}
- t.Logf("%s: Content-Type: %s", st.ts.URL, g)
+ t.Logf("%s: Content-Type: %s", ts.URL, g)
})
}
}
@@ -4316,13 +4097,10 @@ func TestServerWindowUpdateOnBodyClose(t *testing.T) {
// Wait for flow control credit for the portion of the request written so far.
increments := windowSize / 2
for {
- f, err := st.readFrame()
- if err == io.EOF {
+ f := st.readFrame()
+ if f == nil {
break
}
- if err != nil {
- t.Fatal(err)
- }
if wu, ok := f.(*WindowUpdateFrame); ok && wu.StreamID == 0 {
increments -= int(wu.Increment)
if increments == 0 {
@@ -4356,24 +4134,16 @@ func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) {
EndStream: false,
EndHeaders: true,
})
- st.wantHeaders()
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
st.sc.startGracefulShutdown()
- for {
- f, err := st.readFrame()
- if err == io.EOF {
- st.t.Fatal("got a EOF; want *GoAwayFrame")
- }
- if err != nil {
- t.Fatal(err)
- }
- if gf, ok := f.(*GoAwayFrame); ok && gf.StreamID == 0 {
- break
- }
- }
+ st.wantRSTStream(1, ErrCodeNo)
+ st.wantGoAway(1, ErrCodeNo)
st.writeData(1, true, []byte(content))
- time.Sleep(200 * time.Millisecond)
st.Close()
if bytes.Contains(st.serverLogBuf.Bytes(), []byte("PROTOCOL_ERROR")) {
@@ -4389,27 +4159,22 @@ func TestServerSendsProcessing(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "102"},
- }
-
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got = %q; want %q", goth, wanth)
- }
-
- hf = st.wantHeaders()
- goth = st.decodeHeader(hf.HeaderBlockFragment())
- wanth = [][2]string{
- {":status", "200"},
- {"content-type", "text/plain; charset=utf-8"},
- {"content-length", "5"},
- }
-
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got = %q; want %q", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"102"},
+ },
+ })
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "content-type": []string{"text/plain; charset=utf-8"},
+ "content-length": []string{"5"},
+ },
+ })
})
}
@@ -4429,45 +4194,43 @@ func TestServerSendsEarlyHints(t *testing.T) {
return nil
}, func(st *serverTester) {
getSlash(st)
- hf := st.wantHeaders()
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "103"},
- {"link", "; rel=preload; as=style"},
- {"link", "; rel=preload; as=script"},
- }
-
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got = %q; want %q", goth, wanth)
- }
-
- hf = st.wantHeaders()
- goth = st.decodeHeader(hf.HeaderBlockFragment())
- wanth = [][2]string{
- {":status", "103"},
- {"link", "; rel=preload; as=style"},
- {"link", "; rel=preload; as=script"},
- {"link", "; rel=preload; as=script"},
- }
-
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got = %q; want %q", goth, wanth)
- }
-
- hf = st.wantHeaders()
- goth = st.decodeHeader(hf.HeaderBlockFragment())
- wanth = [][2]string{
- {":status", "200"},
- {"link", "; rel=preload; as=style"},
- {"link", "; rel=preload; as=script"},
- {"link", "; rel=preload; as=script"},
- {"content-type", "text/plain; charset=utf-8"},
- {"content-length", "123"},
- }
-
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got = %q; want %q", goth, wanth)
- }
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"103"},
+ "link": []string{
+ "; rel=preload; as=style",
+ "; rel=preload; as=script",
+ },
+ },
+ })
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"103"},
+ "link": []string{
+ "; rel=preload; as=style",
+ "; rel=preload; as=script",
+ "; rel=preload; as=script",
+ },
+ },
+ })
+ st.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: false,
+ header: http.Header{
+ ":status": []string{"200"},
+ "link": []string{
+ "; rel=preload; as=style",
+ "; rel=preload; as=script",
+ "; rel=preload; as=script",
+ },
+ "content-type": []string{"text/plain; charset=utf-8"},
+ "content-length": []string{"123"},
+ },
+ })
})
}
@@ -4489,7 +4252,6 @@ func TestProtocolErrorAfterGoAway(t *testing.T) {
EndHeaders: true,
})
st.writeData(1, false, []byte(content[:5]))
- st.writeReadPing()
// Send a GOAWAY with ErrCodeNo, followed by a bogus window update.
// The server should close the connection.
@@ -4500,14 +4262,9 @@ func TestProtocolErrorAfterGoAway(t *testing.T) {
t.Fatal(err)
}
- for {
- if _, err := st.readFrame(); err != nil {
- if err != io.EOF {
- t.Errorf("unexpected readFrame error: %v", err)
- }
- break
- }
- }
+ st.advance(goAwayTimeout)
+ st.wantGoAway(1, ErrCodeNo)
+ st.wantClosed()
}
func TestServerInitialFlowControlWindow(t *testing.T) {
@@ -4528,9 +4285,9 @@ func TestServerInitialFlowControlWindow(t *testing.T) {
}, func(s *Server) {
s.MaxUploadBufferPerConnection = want
})
- defer st.Close()
st.writePreface()
- st.writeInitialSettings()
+ st.writeSettings()
+ _ = readFrame[*SettingsFrame](t, st)
st.writeSettingsAck()
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
@@ -4541,10 +4298,7 @@ func TestServerInitialFlowControlWindow(t *testing.T) {
window := 65535
Frames:
for {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatal(err)
- }
+ f := st.readFrame()
switch f := f.(type) {
case *WindowUpdateFrame:
if f.FrameHeader.StreamID != 0 {
@@ -4554,6 +4308,8 @@ func TestServerInitialFlowControlWindow(t *testing.T) {
window += int(f.Increment)
case *HeadersFrame:
break Frames
+ case nil:
+ break Frames
default:
}
}
@@ -4572,13 +4328,16 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) {
sc := &serverConn{
serveG: newGoroutineLock(),
}
- const count = 1000
- for i := 0; i < count; i++ {
- h := fmt.Sprintf("%v-%v", base, i)
+ count := 0
+ added := 0
+ for added < 10*maxCachedCanonicalHeadersKeysSize {
+ h := fmt.Sprintf("%v-%v", base, count)
c := sc.canonicalHeader(h)
if len(h) != len(c) {
t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c)
}
+ count++
+ added += len(h)
}
total := 0
for k, v := range sc.canonHeader {
@@ -4590,14 +4349,14 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) {
}
}
-// TestServerWriteDoesNotRetainBufferAfterStreamClose checks for access to
+// TestServerWriteDoesNotRetainBufferAfterReturn checks for access to
// the slice passed to ResponseWriter.Write after Write returns.
//
// Terminating the request stream on the client causes Write to return.
// We should not access the slice after this point.
func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) {
donec := make(chan struct{})
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
defer close(donec)
buf := make([]byte, 1<<20)
var i byte
@@ -4611,13 +4370,12 @@ func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) {
return
}
}
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req, _ := http.NewRequest("GET", ts.URL, nil)
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
@@ -4633,7 +4391,7 @@ func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) {
// We should not access the slice after this point.
func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) {
donec := make(chan struct{}, 1)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
donec <- struct{}{}
defer close(donec)
buf := make([]byte, 1<<20)
@@ -4648,20 +4406,19 @@ func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) {
return
}
}
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req, _ := http.NewRequest("GET", ts.URL, nil)
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
<-donec
- st.ts.Config.Close()
+ ts.Config.Close()
<-donec
}
@@ -4688,9 +4445,7 @@ func TestServerMaxHandlerGoroutines(t *testing.T) {
})
defer st.Close()
- st.writePreface()
- st.writeInitialSettings()
- st.writeSettingsAck()
+ st.greet()
// Make maxHandlers concurrent requests.
// Reset them all, but only after the handler goroutines have started.
@@ -4757,23 +4512,244 @@ func TestServerMaxHandlerGoroutines(t *testing.T) {
st.fr.WriteRSTStream(streamID, ErrCodeCancel)
streamID += 2
}
-Frames:
+ fr := readFrame[*GoAwayFrame](t, st)
+ if fr.ErrCode != ErrCodeEnhanceYourCalm {
+ t.Errorf("err code = %v; want %v", fr.ErrCode, ErrCodeEnhanceYourCalm)
+ }
+
+ for _, s := range stops {
+ close(s)
+ }
+}
+
+func TestServerContinuationFlood(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ fmt.Println(r.Header)
+ }, func(s *http.Server) {
+ s.MaxHeaderBytes = 4096
+ })
+ defer st.Close()
+
+ st.greet()
+
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ })
+ for i := 0; i < 1000; i++ {
+ st.fr.WriteContinuation(1, false, st.encodeHeaderRaw(
+ fmt.Sprintf("x-%v", i), "1234567890",
+ ))
+ }
+ st.fr.WriteContinuation(1, true, st.encodeHeaderRaw(
+ "x-last-header", "1",
+ ))
+
for {
- f, err := st.readFrame()
- if err != nil {
- st.t.Fatal(err)
+ f := st.readFrame()
+ if f == nil {
+ break
}
switch f := f.(type) {
+ case *HeadersFrame:
+ t.Fatalf("received HEADERS frame; want GOAWAY and a closed connection")
case *GoAwayFrame:
- if f.ErrCode != ErrCodeEnhanceYourCalm {
- t.Errorf("err code = %v; want %v", f.ErrCode, ErrCodeEnhanceYourCalm)
+ // We might not see the GOAWAY (see below), but if we do it should
+ // indicate that the server processed this request so the client doesn't
+ // attempt to retry it.
+ if got, want := f.LastStreamID, uint32(1); got != want {
+ t.Errorf("received GOAWAY with LastStreamId %v, want %v", got, want)
}
- break Frames
- default:
+
+ }
+ }
+ // We expect to have seen a GOAWAY before the connection closes,
+ // but the server will close the connection after one second
+ // whether or not it has finished sending the GOAWAY. On windows-amd64-race
+ // builders, this fairly consistently results in the connection closing without
+ // the GOAWAY being sent.
+ //
+ // Since the server's behavior is inherently racy here and the important thing
+ // is that the connection is closed, don't check for the GOAWAY having been sent.
+}
+
+func TestServerContinuationAfterInvalidHeader(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ fmt.Println(r.Header)
+ })
+ defer st.Close()
+
+ st.greet()
+
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ })
+ st.fr.WriteContinuation(1, false, st.encodeHeaderRaw(
+ "x-invalid-header", "\x00",
+ ))
+ st.fr.WriteContinuation(1, true, st.encodeHeaderRaw(
+ "x-valid-header", "1",
+ ))
+
+ var sawGoAway bool
+ for {
+ f := st.readFrame()
+ if f == nil {
+ break
+ }
+ switch f.(type) {
+ case *GoAwayFrame:
+ sawGoAway = true
+ case *HeadersFrame:
+ t.Fatalf("received HEADERS frame; want GOAWAY")
}
}
+ if !sawGoAway {
+ t.Errorf("connection closed with no GOAWAY frame; want one")
+ }
+}
- for _, s := range stops {
- close(s)
+func TestServerUpgradeRequestPrefaceFailure(t *testing.T) {
+ // An h2c upgrade request fails when the client preface is not as expected.
+ s2 := &Server{
+ // Setting IdleTimeout triggers #67168.
+ IdleTimeout: 60 * time.Minute,
+ }
+ c1, c2 := net.Pipe()
+ donec := make(chan struct{})
+ go func() {
+ defer close(donec)
+ s2.ServeConn(c1, &ServeConnOpts{
+ UpgradeRequest: httptest.NewRequest("GET", "/", nil),
+ })
+ }()
+ // The server expects to see the HTTP/2 preface,
+ // but we close the connection instead.
+ c2.Close()
+ <-donec
+}
+
+// Issue 67036: A stream error should result in the handler's request context being canceled.
+func TestServerRequestCancelOnError(t *testing.T) {
+ recvc := make(chan struct{}) // handler has started
+ donec := make(chan struct{}) // handler has finished
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ close(recvc)
+ <-r.Context().Done()
+ close(donec)
+ })
+ defer st.Close()
+
+ st.greet()
+
+ // Client sends request headers, handler starts.
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
+ })
+ <-recvc
+
+ // Client sends an invalid second set of request headers.
+ // The stream is reset.
+ // The handler's context is canceled, and the handler exits.
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
+ })
+ <-donec
+}
+
+func TestServerSetReadWriteDeadlineRace(t *testing.T) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ ctl := http.NewResponseController(w)
+ ctl.SetReadDeadline(time.Now().Add(3600 * time.Second))
+ ctl.SetWriteDeadline(time.Now().Add(3600 * time.Second))
+ })
+ resp, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+}
+
+func TestServerWriteByteTimeout(t *testing.T) {
+ const timeout = 1 * time.Second
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ w.Write(make([]byte, 100))
+ }, func(s *Server) {
+ s.WriteByteTimeout = timeout
+ })
+ st.greet()
+
+ st.cc.(*synctestNetConn).SetReadBufferSize(1) // write one byte at a time
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
+ })
+
+ // Read a few bytes, staying just under WriteByteTimeout.
+ for i := 0; i < 10; i++ {
+ st.advance(timeout - 1)
+ if n, err := st.cc.Read(make([]byte, 1)); n != 1 || err != nil {
+ t.Fatalf("read %v: %v, %v; want 1, nil", i, n, err)
+ }
}
+
+ // Wait for WriteByteTimeout.
+ // The connection should close.
+ st.advance(1 * time.Second) // timeout after writing one byte
+ st.advance(1 * time.Second) // timeout after failing to write any more bytes
+ st.wantClosed()
+}
+
+func TestServerPingSent(t *testing.T) {
+ const readIdleTimeout = 15 * time.Second
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(s *Server) {
+ s.ReadIdleTimeout = readIdleTimeout
+ })
+ st.greet()
+
+ st.wantIdle()
+
+ st.advance(readIdleTimeout)
+ _ = readFrame[*PingFrame](t, st)
+ st.wantIdle()
+
+ st.advance(14 * time.Second)
+ st.wantIdle()
+ st.advance(1 * time.Second)
+ st.wantClosed()
+}
+
+func TestServerPingResponded(t *testing.T) {
+ const readIdleTimeout = 15 * time.Second
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(s *Server) {
+ s.ReadIdleTimeout = readIdleTimeout
+ })
+ st.greet()
+
+ st.wantIdle()
+
+ st.advance(readIdleTimeout)
+ pf := readFrame[*PingFrame](t, st)
+ st.wantIdle()
+
+ st.advance(14 * time.Second)
+ st.wantIdle()
+
+ st.writePing(true, pf.Data)
+
+ st.advance(2 * time.Second)
+ st.wantIdle()
}
diff --git a/http2/sync_test.go b/http2/sync_test.go
new file mode 100644
index 0000000000..6687202d2c
--- /dev/null
+++ b/http2/sync_test.go
@@ -0,0 +1,329 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "context"
+ "fmt"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+// A synctestGroup synchronizes between a set of cooperating goroutines.
+type synctestGroup struct {
+ mu sync.Mutex
+ gids map[int]bool
+ now time.Time
+ timers map[*fakeTimer]struct{}
+}
+
+type goroutine struct {
+ id int
+ parent int
+ state string
+ syscall bool
+}
+
+// newSynctest creates a new group with the synthetic clock set to the provided time.
+func newSynctest(now time.Time) *synctestGroup {
+ return &synctestGroup{
+ gids: map[int]bool{
+ currentGoroutine(): true,
+ },
+ now: now,
+ }
+}
+
+// Join adds the current goroutine to the group.
+func (g *synctestGroup) Join() {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ g.gids[currentGoroutine()] = true
+}
+
+// Count returns the number of goroutines in the group.
+func (g *synctestGroup) Count() int {
+ gs := stacks(true)
+ count := 0
+ for _, gr := range gs {
+ if !g.gids[gr.id] && !g.gids[gr.parent] {
+ continue
+ }
+ count++
+ }
+ return count
+}
+
+// Close calls t.Fatal if the group contains any running goroutines.
+func (g *synctestGroup) Close(t testing.TB) {
+ if count := g.Count(); count != 1 {
+ buf := make([]byte, 16*1024)
+ n := runtime.Stack(buf, true)
+ t.Logf("stacks:\n%s", buf[:n])
+ t.Fatalf("%v goroutines still running after test completed, expect 1", count)
+ }
+}
+
+// Wait blocks until every goroutine in the group and their direct children are idle.
+func (g *synctestGroup) Wait() {
+ for i := 0; ; i++ {
+ if g.idle() {
+ return
+ }
+ runtime.Gosched()
+ if runtime.GOOS == "js" {
+ // When GOOS=js, we appear to need to time.Sleep to make progress
+ // on some syscalls. In particular, without this sleep
+ // writing to stdout (including via t.Log) can block forever.
+ for range 10 {
+ time.Sleep(1)
+ }
+ }
+ }
+}
+
+func (g *synctestGroup) idle() bool {
+ gs := stacks(true)
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ for _, gr := range gs[1:] {
+ if !g.gids[gr.id] && !g.gids[gr.parent] {
+ continue
+ }
+ if gr.syscall {
+ return false
+ }
+ // From runtime/runtime2.go.
+ switch gr.state {
+ case "IO wait":
+ case "chan receive (nil chan)":
+ case "chan send (nil chan)":
+ case "select":
+ case "select (no cases)":
+ case "chan receive":
+ case "chan send":
+ case "sync.Cond.Wait":
+ default:
+ return false
+ }
+ }
+ return true
+}
+
+func currentGoroutine() int {
+ s := stacks(false)
+ return s[0].id
+}
+
+func stacks(all bool) []goroutine {
+ buf := make([]byte, 16*1024)
+ for {
+ n := runtime.Stack(buf, all)
+ if n < len(buf) {
+ buf = buf[:n]
+ break
+ }
+ buf = make([]byte, len(buf)*2)
+ }
+
+ var goroutines []goroutine
+ for _, gs := range strings.Split(string(buf), "\n\n") {
+ skip, rest, ok := strings.Cut(gs, "goroutine ")
+ if skip != "" || !ok {
+ panic(fmt.Errorf("1 unparsable goroutine stack:\n%s", gs))
+ }
+ ids, rest, ok := strings.Cut(rest, " [")
+ if !ok {
+ panic(fmt.Errorf("2 unparsable goroutine stack:\n%s", gs))
+ }
+ id, err := strconv.Atoi(ids)
+ if err != nil {
+ panic(fmt.Errorf("3 unparsable goroutine stack:\n%s", gs))
+ }
+ state, rest, ok := strings.Cut(rest, "]")
+ isSyscall := false
+ if strings.Contains(rest, "\nsyscall.") {
+ isSyscall = true
+ }
+ var parent int
+ _, rest, ok = strings.Cut(rest, "\ncreated by ")
+ if ok && strings.Contains(rest, " in goroutine ") {
+ _, rest, ok := strings.Cut(rest, " in goroutine ")
+ if !ok {
+ panic(fmt.Errorf("4 unparsable goroutine stack:\n%s", gs))
+ }
+ parents, rest, ok := strings.Cut(rest, "\n")
+ if !ok {
+ panic(fmt.Errorf("5 unparsable goroutine stack:\n%s", gs))
+ }
+ parent, err = strconv.Atoi(parents)
+ if err != nil {
+ panic(fmt.Errorf("6 unparsable goroutine stack:\n%s", gs))
+ }
+ }
+ goroutines = append(goroutines, goroutine{
+ id: id,
+ parent: parent,
+ state: state,
+ syscall: isSyscall,
+ })
+ }
+ return goroutines
+}
+
+// AdvanceTime advances the synthetic clock by d.
+func (g *synctestGroup) AdvanceTime(d time.Duration) {
+ defer g.Wait()
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ g.now = g.now.Add(d)
+ for tm := range g.timers {
+ if tm.when.After(g.now) {
+ continue
+ }
+ tm.run()
+ delete(g.timers, tm)
+ }
+}
+
+// Now returns the current synthetic time.
+func (g *synctestGroup) Now() time.Time {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ return g.now
+}
+
+// TimeUntilEvent returns the amount of time until the next scheduled timer.
+func (g *synctestGroup) TimeUntilEvent() (d time.Duration, scheduled bool) {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ for tm := range g.timers {
+ if dd := tm.when.Sub(g.now); !scheduled || dd < d {
+ d = dd
+ scheduled = true
+ }
+ }
+ return d, scheduled
+}
+
+// Sleep is time.Sleep, but using synthetic time.
+func (g *synctestGroup) Sleep(d time.Duration) {
+ tm := g.NewTimer(d)
+ <-tm.C()
+}
+
+// NewTimer is time.NewTimer, but using synthetic time.
+func (g *synctestGroup) NewTimer(d time.Duration) Timer {
+ return g.addTimer(d, &fakeTimer{
+ ch: make(chan time.Time),
+ })
+}
+
+// AfterFunc is time.AfterFunc, but using synthetic time.
+func (g *synctestGroup) AfterFunc(d time.Duration, f func()) Timer {
+ return g.addTimer(d, &fakeTimer{
+ f: f,
+ })
+}
+
+// ContextWithTimeout is context.WithTimeout, but using synthetic time.
+func (g *synctestGroup) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
+ ctx, cancel := context.WithCancel(ctx)
+ tm := g.AfterFunc(d, cancel)
+ return ctx, func() {
+ tm.Stop()
+ cancel()
+ }
+}
+
+func (g *synctestGroup) addTimer(d time.Duration, tm *fakeTimer) *fakeTimer {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ tm.g = g
+ tm.when = g.now.Add(d)
+ if g.timers == nil {
+ g.timers = make(map[*fakeTimer]struct{})
+ }
+ if tm.when.After(g.now) {
+ g.timers[tm] = struct{}{}
+ } else {
+ tm.run()
+ }
+ return tm
+}
+
+type Timer = interface {
+ C() <-chan time.Time
+ Reset(d time.Duration) bool
+ Stop() bool
+}
+
+type fakeTimer struct {
+ g *synctestGroup
+ when time.Time
+ ch chan time.Time
+ f func()
+}
+
+func (tm *fakeTimer) run() {
+ if tm.ch != nil {
+ tm.ch <- tm.g.now
+ } else {
+ go func() {
+ tm.g.Join()
+ tm.f()
+ }()
+ }
+}
+
+func (tm *fakeTimer) C() <-chan time.Time { return tm.ch }
+
+func (tm *fakeTimer) Reset(d time.Duration) bool {
+ tm.g.mu.Lock()
+ defer tm.g.mu.Unlock()
+ _, stopped := tm.g.timers[tm]
+ if d <= 0 {
+ delete(tm.g.timers, tm)
+ tm.run()
+ } else {
+ tm.when = tm.g.now.Add(d)
+ tm.g.timers[tm] = struct{}{}
+ }
+ return stopped
+}
+
+func (tm *fakeTimer) Stop() bool {
+ tm.g.mu.Lock()
+ defer tm.g.mu.Unlock()
+ _, stopped := tm.g.timers[tm]
+ delete(tm.g.timers, tm)
+ return stopped
+}
+
+// TestSynctestLogs verifies that t.Log works,
+// in particular that the GOOS=js workaround in synctestGroup.Wait is working.
+// (When GOOS=js, writing to stdout can hang indefinitely if some goroutine loops
+// calling runtime.Gosched; see Wait for the workaround.)
+func TestSynctestLogs(t *testing.T) {
+ g := newSynctest(time.Now())
+ donec := make(chan struct{})
+ go func() {
+ g.Join()
+ for range 100 {
+ t.Logf("logging a long line")
+ }
+ close(donec)
+ }()
+ g.Wait()
+ select {
+ case <-donec:
+ default:
+ panic("done")
+ }
+}
diff --git a/http2/testdata/draft-ietf-httpbis-http2.xml b/http2/testdata/draft-ietf-httpbis-http2.xml
deleted file mode 100644
index 39d756de7a..0000000000
--- a/http2/testdata/draft-ietf-httpbis-http2.xml
+++ /dev/null
@@ -1,5021 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- Hypertext Transfer Protocol version 2
-
-
- Twist
-
- mbelshe@chromium.org
-
-
-
-
- Google, Inc
-
- fenix@google.com
-
-
-
-
- Mozilla
-
-
- 331 E Evelyn Street
- Mountain View
- CA
- 94041
- US
-
- martin.thomson@gmail.com
-
-
-
-
- Applications
- HTTPbis
- HTTP
- SPDY
- Web
-
-
-
- This specification describes an optimized expression of the semantics of the Hypertext
- Transfer Protocol (HTTP). HTTP/2 enables a more efficient use of network resources and a
- reduced perception of latency by introducing header field compression and allowing multiple
- concurrent messages on the same connection. It also introduces unsolicited push of
- representations from servers to clients.
-
-
- This specification is an alternative to, but does not obsolete, the HTTP/1.1 message syntax.
- HTTP's existing semantics remain unchanged.
-
-
-
-
-
- Discussion of this draft takes place on the HTTPBIS working group mailing list
- (ietf-http-wg@w3.org), which is archived at .
-
-
- Working Group information can be found at ; that specific to HTTP/2 are at .
-
-
- The changes in this draft are summarized in .
-
-
-
-
-
-
-
-
-
- The Hypertext Transfer Protocol (HTTP) is a wildly successful protocol. However, the
- HTTP/1.1 message format ( ) has
- several characteristics that have a negative overall effect on application performance
- today.
-
-
- In particular, HTTP/1.0 allowed only one request to be outstanding at a time on a given
- TCP connection. HTTP/1.1 added request pipelining, but this only partially addressed
- request concurrency and still suffers from head-of-line blocking. Therefore, HTTP/1.1
- clients that need to make many requests typically use multiple connections to a server in
- order to achieve concurrency and thereby reduce latency.
-
-
- Furthermore, HTTP header fields are often repetitive and verbose, causing unnecessary
- network traffic, as well as causing the initial TCP congestion
- window to quickly fill. This can result in excessive latency when multiple requests are
- made on a new TCP connection.
-
-
- HTTP/2 addresses these issues by defining an optimized mapping of HTTP's semantics to an
- underlying connection. Specifically, it allows interleaving of request and response
- messages on the same connection and uses an efficient coding for HTTP header fields. It
- also allows prioritization of requests, letting more important requests complete more
- quickly, further improving performance.
-
-
- The resulting protocol is more friendly to the network, because fewer TCP connections can
- be used in comparison to HTTP/1.x. This means less competition with other flows, and
- longer-lived connections, which in turn leads to better utilization of available network
- capacity.
-
-
- Finally, HTTP/2 also enables more efficient processing of messages through use of binary
- message framing.
-
-
-
-
-
- HTTP/2 provides an optimized transport for HTTP semantics. HTTP/2 supports all of the core
- features of HTTP/1.1, but aims to be more efficient in several ways.
-
-
- The basic protocol unit in HTTP/2 is a frame . Each frame
- type serves a different purpose. For example, HEADERS and
- DATA frames form the basis of HTTP requests and
- responses ; other frame types like SETTINGS ,
- WINDOW_UPDATE , and PUSH_PROMISE are used in support of other
- HTTP/2 features.
-
-
- Multiplexing of requests is achieved by having each HTTP request-response exchange
- associated with its own stream . Streams are largely
- independent of each other, so a blocked or stalled request or response does not prevent
- progress on other streams.
-
-
- Flow control and prioritization ensure that it is possible to efficiently use multiplexed
- streams. Flow control helps to ensure that only data that
- can be used by a receiver is transmitted. Prioritization ensures that limited resources can be directed
- to the most important streams first.
-
-
- HTTP/2 adds a new interaction mode, whereby a server can push
- responses to a client . Server push allows a server to speculatively send a client
- data that the server anticipates the client will need, trading off some network usage
- against a potential latency gain. The server does this by synthesizing a request, which it
- sends as a PUSH_PROMISE frame. The server is then able to send a response to
- the synthetic request on a separate stream.
-
-
- Frames that contain HTTP header fields are compressed .
- HTTP requests can be highly redundant, so compression can reduce the size of requests and
- responses significantly.
-
-
-
-
- The HTTP/2 specification is split into four parts:
-
-
- Starting HTTP/2 covers how an HTTP/2 connection is
- initiated.
-
-
- The framing and streams layers describe the way HTTP/2 frames are
- structured and formed into multiplexed streams.
-
-
- Frame and error
- definitions include details of the frame and error types used in HTTP/2.
-
-
- HTTP mappings and additional
- requirements describe how HTTP semantics are expressed using frames and
- streams.
-
-
-
-
- While some of the frame and stream layer concepts are isolated from HTTP, this
- specification does not define a completely generic framing layer. The framing and streams
- layers are tailored to the needs of the HTTP protocol and server push.
-
-
-
-
-
- The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD
- NOT", "RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as
- described in RFC 2119 .
-
-
- All numeric values are in network byte order. Values are unsigned unless otherwise
- indicated. Literal values are provided in decimal or hexadecimal as appropriate.
- Hexadecimal literals are prefixed with 0x to distinguish them
- from decimal literals.
-
-
- The following terms are used:
-
-
- The endpoint initiating the HTTP/2 connection.
-
-
- A transport-layer connection between two endpoints.
-
-
- An error that affects the entire HTTP/2 connection.
-
-
- Either the client or server of the connection.
-
-
- The smallest unit of communication within an HTTP/2 connection, consisting of a header
- and a variable-length sequence of octets structured according to the frame type.
-
-
- An endpoint. When discussing a particular endpoint, "peer" refers to the endpoint
- that is remote to the primary subject of discussion.
-
-
- An endpoint that is receiving frames.
-
-
- An endpoint that is transmitting frames.
-
-
- The endpoint which did not initiate the HTTP/2 connection.
-
-
- A bi-directional flow of frames across a virtual channel within the HTTP/2 connection.
-
-
- An error on the individual HTTP/2 stream.
-
-
-
-
- Finally, the terms "gateway", "intermediary", "proxy", and "tunnel" are defined
- in .
-
-
-
-
-
-
- An HTTP/2 connection is an application layer protocol running on top of a TCP connection
- ( ). The client is the TCP connection initiator.
-
-
- HTTP/2 uses the same "http" and "https" URI schemes used by HTTP/1.1. HTTP/2 shares the same
- default port numbers: 80 for "http" URIs and 443 for "https" URIs. As a result,
- implementations processing requests for target resource URIs like http://example.org/foo or https://example.com/bar are required to first discover whether the
- upstream server (the immediate peer to which the client wishes to establish a connection)
- supports HTTP/2.
-
-
-
- The means by which support for HTTP/2 is determined is different for "http" and "https"
- URIs. Discovery for "http" URIs is described in . Discovery
- for "https" URIs is described in .
-
-
-
-
- The protocol defined in this document has two identifiers.
-
-
-
- The string "h2" identifies the protocol where HTTP/2 uses TLS . This identifier is used in the TLS application layer protocol negotiation extension (ALPN)
- field and any place that HTTP/2 over TLS is identified.
-
-
- The "h2" string is serialized into an ALPN protocol identifier as the two octet
- sequence: 0x68, 0x32.
-
-
-
-
- The string "h2c" identifies the protocol where HTTP/2 is run over cleartext TCP.
- This identifier is used in the HTTP/1.1 Upgrade header field and any place that
- HTTP/2 over TCP is identified.
-
-
-
-
-
- Negotiating "h2" or "h2c" implies the use of the transport, security, framing and message
- semantics described in this document.
-
-
- RFC Editor's Note: please remove the remainder of this section prior to the
- publication of a final version of this document.
-
-
- Only implementations of the final, published RFC can identify themselves as "h2" or "h2c".
- Until such an RFC exists, implementations MUST NOT identify themselves using these
- strings.
-
-
- Examples and text throughout the rest of this document use "h2" as a matter of
- editorial convenience only. Implementations of draft versions MUST NOT identify using
- this string.
-
-
- Implementations of draft versions of the protocol MUST add the string "-" and the
- corresponding draft number to the identifier. For example, draft-ietf-httpbis-http2-11
- over TLS is identified using the string "h2-11".
-
-
- Non-compatible experiments that are based on these draft versions MUST append the string
- "-" and an experiment name to the identifier. For example, an experimental implementation
- of packet mood-based encoding based on draft-ietf-httpbis-http2-09 might identify itself
- as "h2-09-emo". Note that any label MUST conform to the "token" syntax defined in
- . Experimenters are
- encouraged to coordinate their experiments on the ietf-http-wg@w3.org mailing list.
-
-
-
-
-
- A client that makes a request for an "http" URI without prior knowledge about support for
- HTTP/2 uses the HTTP Upgrade mechanism ( ). The client makes an HTTP/1.1 request that includes an Upgrade
- header field identifying HTTP/2 with the "h2c" token. The HTTP/1.1 request MUST include
- exactly one HTTP2-Settings header field.
-
-
- For example:
-
-
-]]>
-
-
- Requests that contain an entity body MUST be sent in their entirety before the client can
- send HTTP/2 frames. This means that a large request entity can block the use of the
- connection until it is completely sent.
-
-
- If concurrency of an initial request with subsequent requests is important, an OPTIONS
- request can be used to perform the upgrade to HTTP/2, at the cost of an additional
- round-trip.
-
-
- A server that does not support HTTP/2 can respond to the request as though the Upgrade
- header field were absent:
-
-
-
-HTTP/1.1 200 OK
-Content-Length: 243
-Content-Type: text/html
-
-...
-
-
-
- A server MUST ignore a "h2" token in an Upgrade header field. Presence of a token with
- "h2" implies HTTP/2 over TLS, which is instead negotiated as described in .
-
-
- A server that supports HTTP/2 can accept the upgrade with a 101 (Switching Protocols)
- response. After the empty line that terminates the 101 response, the server can begin
- sending HTTP/2 frames. These frames MUST include a response to the request that initiated
- the Upgrade.
-
-
-
-
- For example:
-
-
-HTTP/1.1 101 Switching Protocols
-Connection: Upgrade
-Upgrade: h2c
-
-[ HTTP/2 connection ...
-
-
-
- The first HTTP/2 frame sent by the server is a SETTINGS frame ( ) as the server connection preface ( ). Upon receiving the 101 response, the client sends a connection preface , which includes a
- SETTINGS frame.
-
-
- The HTTP/1.1 request that is sent prior to upgrade is assigned stream identifier 1 and is
- assigned default priority values . Stream 1 is
- implicitly half closed from the client toward the server, since the request is completed
- as an HTTP/1.1 request. After commencing the HTTP/2 connection, stream 1 is used for the
- response.
-
-
-
-
- A request that upgrades from HTTP/1.1 to HTTP/2 MUST include exactly one HTTP2-Settings header field. The HTTP2-Settings header field is a connection-specific header field
- that includes parameters that govern the HTTP/2 connection, provided in anticipation of
- the server accepting the request to upgrade.
-
-
-
-
-
- A server MUST NOT upgrade the connection to HTTP/2 if this header field is not present,
- or if more than one is present. A server MUST NOT send this header field.
-
-
-
- The content of the HTTP2-Settings header field is the
- payload of a SETTINGS frame ( ), encoded as a
- base64url string (that is, the URL- and filename-safe Base64 encoding described in , with any trailing '=' characters omitted). The
- ABNF production for token68 is
- defined in .
-
-
- Since the upgrade is only intended to apply to the immediate connection, a client
- sending HTTP2-Settings MUST also send HTTP2-Settings as a connection option in the Connection header field to prevent it from being forwarded
- downstream.
-
-
- A server decodes and interprets these values as it would any other
- SETTINGS frame. Acknowledgement of the
- SETTINGS parameters is not necessary, since a 101 response serves as implicit
- acknowledgment. Providing these values in the Upgrade request gives a client an
- opportunity to provide parameters prior to receiving any frames from the server.
-
-
-
-
-
-
- A client that makes a request to an "https" URI uses TLS
- with the application layer protocol negotiation extension .
-
-
- HTTP/2 over TLS uses the "h2" application token. The "h2c" token MUST NOT be sent by a
- client or selected by a server.
-
-
- Once TLS negotiation is complete, both the client and the server send a connection preface .
-
-
-
-
-
- A client can learn that a particular server supports HTTP/2 by other means. For example,
- describes a mechanism for advertising this capability.
-
-
- A client MAY immediately send HTTP/2 frames to a server that is known to support HTTP/2,
- after the connection preface ; a server can
- identify such a connection by the presence of the connection preface. This only affects
- the establishment of HTTP/2 connections over cleartext TCP; implementations that support
- HTTP/2 over TLS MUST use protocol negotiation in TLS .
-
-
- Without additional information, prior support for HTTP/2 is not a strong signal that a
- given server will support HTTP/2 for future connections. For example, it is possible for
- server configurations to change, for configurations to differ between instances in
- clustered servers, or for network conditions to change.
-
-
-
-
-
- Upon establishment of a TCP connection and determination that HTTP/2 will be used by both
- peers, each endpoint MUST send a connection preface as a final confirmation and to
- establish the initial SETTINGS parameters for the HTTP/2 connection. The client and
- server each send a different connection preface.
-
-
- The client connection preface starts with a sequence of 24 octets, which in hex notation
- are:
-
-
-
-
-
- (the string PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n ). This sequence
- is followed by a SETTINGS frame ( ). The
- SETTINGS frame MAY be empty. The client sends the client connection
- preface immediately upon receipt of a 101 Switching Protocols response (indicating a
- successful upgrade), or as the first application data octets of a TLS connection. If
- starting an HTTP/2 connection with prior knowledge of server support for the protocol, the
- client connection preface is sent upon connection establishment.
-
-
-
-
- The client connection preface is selected so that a large proportion of HTTP/1.1 or
- HTTP/1.0 servers and intermediaries do not attempt to process further frames. Note
- that this does not address the concerns raised in .
-
-
-
-
- The server connection preface consists of a potentially empty SETTINGS
- frame ( ) that MUST be the first frame the server sends in the
- HTTP/2 connection.
-
-
- The SETTINGS frames received from a peer as part of the connection preface
- MUST be acknowledged (see ) after sending the connection
- preface.
-
-
- To avoid unnecessary latency, clients are permitted to send additional frames to the
- server immediately after sending the client connection preface, without waiting to receive
- the server connection preface. It is important to note, however, that the server
- connection preface SETTINGS frame might include parameters that necessarily
- alter how a client is expected to communicate with the server. Upon receiving the
- SETTINGS frame, the client is expected to honor any parameters established.
- In some configurations, it is possible for the server to transmit SETTINGS
- before the client sends additional frames, providing an opportunity to avoid this issue.
-
-
- Clients and servers MUST treat an invalid connection preface as a connection error of type
- PROTOCOL_ERROR . A GOAWAY frame ( )
- MAY be omitted in this case, since an invalid preface indicates that the peer is not using
- HTTP/2.
-
-
-
-
-
-
- Once the HTTP/2 connection is established, endpoints can begin exchanging frames.
-
-
-
-
- All frames begin with a fixed 9-octet header followed by a variable-length payload.
-
-
-
-
-
- The fields of the frame header are defined as:
-
-
-
- The length of the frame payload expressed as an unsigned 24-bit integer. Values
-          greater than 2^14 (16,384) MUST NOT be sent unless the receiver has
- set a larger value for SETTINGS_MAX_FRAME_SIZE .
-
-
- The 9 octets of the frame header are not included in this value.
-
-
-
-
- The 8-bit type of the frame. The frame type determines the format and semantics of
- the frame. Implementations MUST ignore and discard any frame that has a type that
- is unknown.
-
-
-
-
- An 8-bit field reserved for frame-type specific boolean flags.
-
-
- Flags are assigned semantics specific to the indicated frame type. Flags that have
- no defined semantics for a particular frame type MUST be ignored, and MUST be left
- unset (0) when sending.
-
-
-
-
- A reserved 1-bit field. The semantics of this bit are undefined and the bit MUST
- remain unset (0) when sending and MUST be ignored when receiving.
-
-
-
-
- A 31-bit stream identifier (see ). The value 0 is
- reserved for frames that are associated with the connection as a whole as opposed to
- an individual stream.
-
-
-
-
-
- The structure and content of the frame payload is dependent entirely on the frame type.
-
-
-
-
-
- The size of a frame payload is limited by the maximum size that a receiver advertises in
- the SETTINGS_MAX_FRAME_SIZE setting. This setting can have any value
-          between 2^14 (16,384) and 2^24-1 (16,777,215) octets,
- inclusive.
-
-
- All implementations MUST be capable of receiving and minimally processing frames up to
-          2^14 octets in length, plus the 9 octet frame
- header . The size of the frame header is not included when describing frame sizes.
-
-
- Certain frame types, such as PING , impose additional limits
- on the amount of payload data allowed.
-
-
-
-
- If a frame size exceeds any defined limit, or is too small to contain mandatory frame
- data, the endpoint MUST send a FRAME_SIZE_ERROR error. A frame size error
- in a frame that could alter the state of the entire connection MUST be treated as a connection error ; this includes any frame carrying
- a header block (that is, HEADERS ,
- PUSH_PROMISE , and CONTINUATION ), SETTINGS ,
- and any WINDOW_UPDATE frame with a stream identifier of 0.
-
-
- Endpoints are not obligated to use all available space in a frame. Responsiveness can be
- improved by using frames that are smaller than the permitted maximum size. Sending large
-          frames can result in delays in sending time-sensitive frames (such as
- RST_STREAM , WINDOW_UPDATE , or PRIORITY )
- which if blocked by the transmission of a large frame, could affect performance.
-
-
-
-
-
- Just as in HTTP/1, a header field in HTTP/2 is a name with one or more associated values.
- They are used within HTTP request and response messages as well as server push operations
- (see ).
-
-
- Header lists are collections of zero or more header fields. When transmitted over a
- connection, a header list is serialized into a header block using HTTP Header Compression . The serialized header block is then
- divided into one or more octet sequences, called header block fragments, and transmitted
- within the payload of HEADERS , PUSH_PROMISE or CONTINUATION frames.
-
-
- The Cookie header field is treated specially by the HTTP
- mapping (see ).
-
-
- A receiving endpoint reassembles the header block by concatenating its fragments, then
- decompresses the block to reconstruct the header list.
-
-
- A complete header block consists of either:
-
-
- a single HEADERS or PUSH_PROMISE frame,
- with the END_HEADERS flag set, or
-
-
- a HEADERS or PUSH_PROMISE frame with the END_HEADERS
- flag cleared and one or more CONTINUATION frames,
- where the last CONTINUATION frame has the END_HEADERS flag set.
-
-
-
-
- Header compression is stateful. One compression context and one decompression context is
- used for the entire connection. Each header block is processed as a discrete unit.
- Header blocks MUST be transmitted as a contiguous sequence of frames, with no interleaved
- frames of any other type or from any other stream. The last frame in a sequence of
- HEADERS or CONTINUATION frames MUST have the END_HEADERS
- flag set. The last frame in a sequence of PUSH_PROMISE or
- CONTINUATION frames MUST have the END_HEADERS flag set. This allows a
- header block to be logically equivalent to a single frame.
-
-
- Header block fragments can only be sent as the payload of HEADERS ,
- PUSH_PROMISE or CONTINUATION frames, because these frames
- carry data that can modify the compression context maintained by a receiver. An endpoint
- receiving HEADERS , PUSH_PROMISE or
- CONTINUATION frames MUST reassemble header blocks and perform decompression
- even if the frames are to be discarded. A receiver MUST terminate the connection with a
- connection error of type
- COMPRESSION_ERROR if it does not decompress a header block.
-
-
-
-
-
-
- A "stream" is an independent, bi-directional sequence of frames exchanged between the client
- and server within an HTTP/2 connection. Streams have several important characteristics:
-
-
- A single HTTP/2 connection can contain multiple concurrently open streams, with either
- endpoint interleaving frames from multiple streams.
-
-
- Streams can be established and used unilaterally or shared by either the client or
- server.
-
-
- Streams can be closed by either endpoint.
-
-
- The order in which frames are sent on a stream is significant. Recipients process frames
- in the order they are received. In particular, the order of HEADERS ,
- and DATA frames is semantically significant.
-
-
- Streams are identified by an integer. Stream identifiers are assigned to streams by the
- endpoint initiating the stream.
-
-
-
-
-
-
- The lifecycle of a stream is shown in .
-
-
-
-
- | |<-----------' |
- | R | closed | R |
- `-------------------->| |<--------------------'
- +--------+
-
- H: HEADERS frame (with implied CONTINUATIONs)
- PP: PUSH_PROMISE frame (with implied CONTINUATIONs)
- ES: END_STREAM flag
- R: RST_STREAM frame
-]]>
-
-
-
-
- Note that this diagram shows stream state transitions and the frames and flags that affect
- those transitions only. In this regard, CONTINUATION frames do not result
- in state transitions; they are effectively part of the HEADERS or
- PUSH_PROMISE that they follow. For this purpose, the END_STREAM flag is
- processed as a separate event to the frame that bears it; a HEADERS frame
- with the END_STREAM flag set can cause two state transitions.
-
-
- Both endpoints have a subjective view of the state of a stream that could be different
- when frames are in transit. Endpoints do not coordinate the creation of streams; they are
- created unilaterally by either endpoint. The negative consequences of a mismatch in
- states are limited to the "closed" state after sending RST_STREAM , where
- frames might be received for some time after closing.
-
-
- Streams have the following states:
-
-
-
-
-
- All streams start in the "idle" state. In this state, no frames have been
- exchanged.
-
-
- The following transitions are valid from this state:
-
-
- Sending or receiving a HEADERS frame causes the stream to become
- "open". The stream identifier is selected as described in . The same HEADERS frame can also
- cause a stream to immediately become "half closed".
-
-
- Sending a PUSH_PROMISE frame marks the associated stream for
- later use. The stream state for the reserved stream transitions to "reserved
- (local)".
-
-
- Receiving a PUSH_PROMISE frame marks the associated stream as
- reserved by the remote peer. The state of the stream becomes "reserved
- (remote)".
-
-
-
-
- Receiving any frames other than HEADERS or
- PUSH_PROMISE on a stream in this state MUST be treated as a connection error of type
- PROTOCOL_ERROR .
-
-
-
-
-
-
- A stream in the "reserved (local)" state is one that has been promised by sending a
- PUSH_PROMISE frame. A PUSH_PROMISE frame reserves an
- idle stream by associating the stream with an open stream that was initiated by the
- remote peer (see ).
-
-
- In this state, only the following transitions are possible:
-
-
- The endpoint can send a HEADERS frame. This causes the stream to
- open in a "half closed (remote)" state.
-
-
- Either endpoint can send a RST_STREAM frame to cause the stream
- to become "closed". This releases the stream reservation.
-
-
-
-
- An endpoint MUST NOT send any type of frame other than HEADERS or
- RST_STREAM in this state.
-
-
- A PRIORITY frame MAY be received in this state. Receiving any type
- of frame other than RST_STREAM or PRIORITY on a stream
- in this state MUST be treated as a connection
- error of type PROTOCOL_ERROR .
-
-
-
-
-
-
- A stream in the "reserved (remote)" state has been reserved by a remote peer.
-
-
- In this state, only the following transitions are possible:
-
-
- Receiving a HEADERS frame causes the stream to transition to
- "half closed (local)".
-
-
- Either endpoint can send a RST_STREAM frame to cause the stream
- to become "closed". This releases the stream reservation.
-
-
-
-
- An endpoint MAY send a PRIORITY frame in this state to reprioritize
- the reserved stream. An endpoint MUST NOT send any type of frame other than
- RST_STREAM , WINDOW_UPDATE , or PRIORITY
- in this state.
-
-
- Receiving any type of frame other than HEADERS or
- RST_STREAM on a stream in this state MUST be treated as a connection error of type
- PROTOCOL_ERROR .
-
-
-
-
-
-
- A stream in the "open" state may be used by both peers to send frames of any type.
- In this state, sending peers observe advertised stream
- level flow control limits .
-
-
- From this state either endpoint can send a frame with an END_STREAM flag set, which
- causes the stream to transition into one of the "half closed" states: an endpoint
- sending an END_STREAM flag causes the stream state to become "half closed (local)";
- an endpoint receiving an END_STREAM flag causes the stream state to become "half
- closed (remote)".
-
-
- Either endpoint can send a RST_STREAM frame from this state, causing
- it to transition immediately to "closed".
-
-
-
-
-
-
- A stream that is in the "half closed (local)" state cannot be used for sending
- frames. Only WINDOW_UPDATE , PRIORITY and
- RST_STREAM frames can be sent in this state.
-
-
- A stream transitions from this state to "closed" when a frame that contains an
- END_STREAM flag is received, or when either peer sends a RST_STREAM
- frame.
-
-
- A receiver can ignore WINDOW_UPDATE frames in this state, which might
- arrive for a short period after a frame bearing the END_STREAM flag is sent.
-
-
- PRIORITY frames received in this state are used to reprioritize
- streams that depend on the current stream.
-
-
-
-
-
-
- A stream that is "half closed (remote)" is no longer being used by the peer to send
- frames. In this state, an endpoint is no longer obligated to maintain a receiver
- flow control window if it performs flow control.
-
-
- If an endpoint receives additional frames for a stream that is in this state, other
- than WINDOW_UPDATE , PRIORITY or
- RST_STREAM , it MUST respond with a stream error of type
- STREAM_CLOSED .
-
-
- A stream that is "half closed (remote)" can be used by the endpoint to send frames
- of any type. In this state, the endpoint continues to observe advertised stream level flow control limits .
-
-
- A stream can transition from this state to "closed" by sending a frame that contains
- an END_STREAM flag, or when either peer sends a RST_STREAM frame.
-
-
-
-
-
-
- The "closed" state is the terminal state.
-
-
- An endpoint MUST NOT send frames other than PRIORITY on a closed
- stream. An endpoint that receives any frame other than PRIORITY
- after receiving a RST_STREAM MUST treat that as a stream error of type
- STREAM_CLOSED . Similarly, an endpoint that receives any frames after
- receiving a frame with the END_STREAM flag set MUST treat that as a connection error of type
- STREAM_CLOSED , unless the frame is permitted as described below.
-
-
- WINDOW_UPDATE or RST_STREAM frames can be received in
- this state for a short period after a DATA or HEADERS
- frame containing an END_STREAM flag is sent. Until the remote peer receives and
- processes RST_STREAM or the frame bearing the END_STREAM flag, it
- might send frames of these types. Endpoints MUST ignore
- WINDOW_UPDATE or RST_STREAM frames received in this
- state, though endpoints MAY choose to treat frames that arrive a significant time
- after sending END_STREAM as a connection
- error of type PROTOCOL_ERROR .
-
-
- PRIORITY frames can be sent on closed streams to prioritize streams
- that are dependent on the closed stream. Endpoints SHOULD process
-          PRIORITY frames, though they can be ignored if the stream has been
- removed from the dependency tree (see ).
-
-
- If this state is reached as a result of sending a RST_STREAM frame,
- the peer that receives the RST_STREAM might have already sent - or
- enqueued for sending - frames on the stream that cannot be withdrawn. An endpoint
- MUST ignore frames that it receives on closed streams after it has sent a
- RST_STREAM frame. An endpoint MAY choose to limit the period over
- which it ignores frames and treat frames that arrive after this time as being in
- error.
-
-
- Flow controlled frames (i.e., DATA ) received after sending
- RST_STREAM are counted toward the connection flow control window.
- Even though these frames might be ignored, because they are sent before the sender
- receives the RST_STREAM , the sender will consider the frames to count
- against the flow control window.
-
-
- An endpoint might receive a PUSH_PROMISE frame after it sends
- RST_STREAM . PUSH_PROMISE causes a stream to become
- "reserved" even if the associated stream has been reset. Therefore, a
- RST_STREAM is needed to close an unwanted promised stream.
-
-
-
-
-
- In the absence of more specific guidance elsewhere in this document, implementations
- SHOULD treat the receipt of a frame that is not expressly permitted in the description of
- a state as a connection error of type
-        PROTOCOL_ERROR . Frames of unknown types are ignored.
-
-
- An example of the state transitions for an HTTP request/response exchange can be found in
- . An example of the state transitions for server push can be
- found in and .
-
-
-
-
- Streams are identified with an unsigned 31-bit integer. Streams initiated by a client
- MUST use odd-numbered stream identifiers; those initiated by the server MUST use
- even-numbered stream identifiers. A stream identifier of zero (0x0) is used for
- connection control messages; the stream identifier zero cannot be used to establish a
- new stream.
-
-
- HTTP/1.1 requests that are upgraded to HTTP/2 (see ) are
- responded to with a stream identifier of one (0x1). After the upgrade
- completes, stream 0x1 is "half closed (local)" to the client. Therefore, stream 0x1
- cannot be selected as a new stream identifier by a client that upgrades from HTTP/1.1.
-
-
- The identifier of a newly established stream MUST be numerically greater than all
- streams that the initiating endpoint has opened or reserved. This governs streams that
- are opened using a HEADERS frame and streams that are reserved using
- PUSH_PROMISE . An endpoint that receives an unexpected stream identifier
- MUST respond with a connection error of
- type PROTOCOL_ERROR .
-
-
- The first use of a new stream identifier implicitly closes all streams in the "idle"
- state that might have been initiated by that peer with a lower-valued stream identifier.
- For example, if a client sends a HEADERS frame on stream 7 without ever
- sending a frame on stream 5, then stream 5 transitions to the "closed" state when the
- first frame for stream 7 is sent or received.
-
-
- Stream identifiers cannot be reused. Long-lived connections can result in an endpoint
- exhausting the available range of stream identifiers. A client that is unable to
- establish a new stream identifier can establish a new connection for new streams. A
- server that is unable to establish a new stream identifier can send a
- GOAWAY frame so that the client is forced to open a new connection for
- new streams.
-
-
-
-
-
- A peer can limit the number of concurrently active streams using the
- SETTINGS_MAX_CONCURRENT_STREAMS parameter (see ) within a SETTINGS frame. The maximum concurrent
- streams setting is specific to each endpoint and applies only to the peer that receives
- the setting. That is, clients specify the maximum number of concurrent streams the
- server can initiate, and servers specify the maximum number of concurrent streams the
- client can initiate.
-
-
- Streams that are in the "open" state, or either of the "half closed" states count toward
- the maximum number of streams that an endpoint is permitted to open. Streams in any of
- these three states count toward the limit advertised in the
- SETTINGS_MAX_CONCURRENT_STREAMS setting. Streams in either of the
- "reserved" states do not count toward the stream limit.
-
-
- Endpoints MUST NOT exceed the limit set by their peer. An endpoint that receives a
- HEADERS frame that causes their advertised concurrent stream limit to be
- exceeded MUST treat this as a stream error . An
- endpoint that wishes to reduce the value of
- SETTINGS_MAX_CONCURRENT_STREAMS to a value that is below the current
- number of open streams can either close streams that exceed the new value or allow
- streams to complete.
-
-
-
-
-
-
- Using streams for multiplexing introduces contention over use of the TCP connection,
- resulting in blocked streams. A flow control scheme ensures that streams on the same
- connection do not destructively interfere with each other. Flow control is used for both
- individual streams and for the connection as a whole.
-
-
- HTTP/2 provides for flow control through use of the WINDOW_UPDATE frame .
-
-
-
-
- HTTP/2 stream flow control aims to allow a variety of flow control algorithms to be
- used without requiring protocol changes. Flow control in HTTP/2 has the following
- characteristics:
-
-
- Flow control is specific to a connection; i.e., it is "hop-by-hop", not
- "end-to-end".
-
-
- Flow control is based on window update frames. Receivers advertise how many octets
- they are prepared to receive on a stream and for the entire connection. This is a
- credit-based scheme.
-
-
- Flow control is directional with overall control provided by the receiver. A
- receiver MAY choose to set any window size that it desires for each stream and for
- the entire connection. A sender MUST respect flow control limits imposed by a
- receiver. Clients, servers and intermediaries all independently advertise their
- flow control window as a receiver and abide by the flow control limits set by
- their peer when sending.
-
-
- The initial value for the flow control window is 65,535 octets for both new streams
- and the overall connection.
-
-
- The frame type determines whether flow control applies to a frame. Of the frames
- specified in this document, only DATA frames are subject to flow
- control; all other frame types do not consume space in the advertised flow control
- window. This ensures that important control frames are not blocked by flow control.
-
-
- Flow control cannot be disabled.
-
-
- HTTP/2 defines only the format and semantics of the WINDOW_UPDATE
- frame ( ). This document does not stipulate how a
- receiver decides when to send this frame or the value that it sends, nor does it
- specify how a sender chooses to send packets. Implementations are able to select
- any algorithm that suits their needs.
-
-
-
-
- Implementations are also responsible for managing how requests and responses are sent
- based on priority; choosing how to avoid head of line blocking for requests; and
- managing the creation of new streams. Algorithm choices for these could interact with
- any flow control algorithm.
-
-
-
-
-
- Flow control is defined to protect endpoints that are operating under resource
- constraints. For example, a proxy needs to share memory between many connections, and
- also might have a slow upstream connection and a fast downstream one. Flow control
-        addresses cases where the receiver is unable to process data on one stream, yet wants to
- continue to process other streams in the same connection.
-
-
- Deployments that do not require this capability can advertise a flow control window of
- the maximum size, incrementing the available space when new data is received. This
- effectively disables flow control for that receiver. Conversely, a sender is always
- subject to the flow control window advertised by the receiver.
-
-
- Deployments with constrained resources (for example, memory) can employ flow control to
- limit the amount of memory a peer can consume. Note, however, that this can lead to
- suboptimal use of available network resources if flow control is enabled without
- knowledge of the bandwidth-delay product (see ).
-
-
- Even with full awareness of the current bandwidth-delay product, implementation of flow
- control can be difficult. When using flow control, the receiver MUST read from the TCP
- receive buffer in a timely fashion. Failure to do so could lead to a deadlock when
- critical frames, such as WINDOW_UPDATE , are not read and acted upon.
-
-
-
-
-
-
- A client can assign a priority for a new stream by including prioritization information in
- the HEADERS frame that opens the stream. For an existing
- stream, the PRIORITY frame can be used to change the
- priority.
-
-
- The purpose of prioritization is to allow an endpoint to express how it would prefer its
- peer allocate resources when managing concurrent streams. Most importantly, priority can
- be used to select streams for transmitting frames when there is limited capacity for
- sending.
-
-
- Streams can be prioritized by marking them as dependent on the completion of other streams
- ( ). Each dependency is assigned a relative weight, a number
- that is used to determine the relative proportion of available resources that are assigned
- to streams dependent on the same stream.
-
-
-
- Explicitly setting the priority for a stream is input to a prioritization process. It
- does not guarantee any particular processing or transmission order for the stream relative
- to any other stream. An endpoint cannot force a peer to process concurrent streams in a
- particular order using priority. Expressing priority is therefore only ever a suggestion.
-
-
- Providing prioritization information is optional, so default values are used if no
- explicit indicator is provided ( ).
-
-
-
-
- Each stream can be given an explicit dependency on another stream. Including a
- dependency expresses a preference to allocate resources to the identified stream rather
- than to the dependent stream.
-
-
- A stream that is not dependent on any other stream is given a stream dependency of 0x0.
- In other words, the non-existent stream 0 forms the root of the tree.
-
-
- A stream that depends on another stream is a dependent stream. The stream upon which a
- stream is dependent is a parent stream. A dependency on a stream that is not currently
- in the tree - such as a stream in the "idle" state - results in that stream being given
- a default priority .
-
-
- When assigning a dependency on another stream, the stream is added as a new dependency
- of the parent stream. Dependent streams that share the same parent are not ordered with
- respect to each other. For example, if streams B and C are dependent on stream A, and
- if stream D is created with a dependency on stream A, this results in a dependency order
- of A followed by B, C, and D in any order.
-
-
- /|\
- B C B D C
-]]>
-
-
- An exclusive flag allows for the insertion of a new level of dependencies. The
- exclusive flag causes the stream to become the sole dependency of its parent stream,
- causing other dependencies to become dependent on the exclusive stream. In the
- previous example, if stream D is created with an exclusive dependency on stream A, this
- results in D becoming the dependency parent of B and C.
-
-
- D
- B C / \
- B C
-]]>
-
-
- Inside the dependency tree, a dependent stream SHOULD only be allocated resources if all
- of the streams that it depends on (the chain of parent streams up to 0x0) are either
- closed, or it is not possible to make progress on them.
-
-
- A stream cannot depend on itself. An endpoint MUST treat this as a stream error of type PROTOCOL_ERROR .
-
-
-
-
-
- All dependent streams are allocated an integer weight between 1 and 256 (inclusive).
-
-
- Streams with the same parent SHOULD be allocated resources proportionally based on their
- weight. Thus, if stream B depends on stream A with weight 4, and C depends on stream A
- with weight 12, and if no progress can be made on A, stream B ideally receives one third
- of the resources allocated to stream C.
-
-
-
-
-
- Stream priorities are changed using the PRIORITY frame. Setting a
- dependency causes a stream to become dependent on the identified parent stream.
-
-
- Dependent streams move with their parent stream if the parent is reprioritized. Setting
- a dependency with the exclusive flag for a reprioritized stream moves all the
- dependencies of the new parent stream to become dependent on the reprioritized stream.
-
-
- If a stream is made dependent on one of its own dependencies, the formerly dependent
- stream is first moved to be dependent on the reprioritized stream's previous parent.
- The moved dependency retains its weight.
-
-
-
- For example, consider an original dependency tree where B and C depend on A, D and E
- depend on C, and F depends on D. If A is made dependent on D, then D takes the place
- of A. All other dependency relationships stay the same, except for F, which becomes
- dependent on A if the reprioritization is exclusive.
-
- F B C ==> F A OR A
- / \ | / \ /|\
- D E E B C B C F
- | | |
- F E E
- (intermediate) (non-exclusive) (exclusive)
-]]>
-
-
-
-
-
- When a stream is removed from the dependency tree, its dependencies can be moved to
- become dependent on the parent of the closed stream. The weights of new dependencies
- are recalculated by distributing the weight of the dependency of the closed stream
- proportionally based on the weights of its dependencies.
-
-
- Streams that are removed from the dependency tree cause some prioritization information
- to be lost. Resources are shared between streams with the same parent stream, which
- means that if a stream in that set closes or becomes blocked, any spare capacity
- allocated to a stream is distributed to the immediate neighbors of the stream. However,
- if the common dependency is removed from the tree, those streams share resources with
- streams at the next highest level.
-
-
- For example, assume streams A and B share a parent, and streams C and D both depend on
- stream A. Prior to the removal of stream A, if streams A and D are unable to proceed,
- then stream C receives all the resources dedicated to stream A. If stream A is removed
- from the tree, the weight of stream A is divided between streams C and D. If stream D
- is still unable to proceed, this results in stream C receiving a reduced proportion of
- resources. For equal starting weights, C receives one third, rather than one half, of
- available resources.
-
-
- It is possible for a stream to become closed while prioritization information that
- creates a dependency on that stream is in transit. If a stream identified in a
- dependency has no associated priority information, then the dependent stream is instead
- assigned a default priority . This potentially creates
- suboptimal prioritization, since the stream could be given a priority that is different
- to what is intended.
-
-
- To avoid these problems, an endpoint SHOULD retain stream prioritization state for a
- period after streams become closed. The longer state is retained, the lower the chance
- that streams are assigned incorrect or default priority values.
-
-
- This could create a large state burden for an endpoint, so this state MAY be limited.
- An endpoint MAY apply a fixed upper limit on the number of closed streams for which
- prioritization state is tracked to limit state exposure. The amount of additional state
- an endpoint maintains could be dependent on load; under high load, prioritization state
- can be discarded to limit resource commitments. In extreme cases, an endpoint could
- even discard prioritization state for active or reserved streams. If a fixed limit is
- applied, endpoints SHOULD maintain state for at least as many streams as allowed by
- their setting for SETTINGS_MAX_CONCURRENT_STREAMS .
-
-
- An endpoint receiving a PRIORITY frame that changes the priority of a
- closed stream SHOULD alter the dependencies of the streams that depend on it, if it has
- retained enough state to do so.
-
-
-
-
-
- Providing priority information is optional. Streams are assigned a non-exclusive
- dependency on stream 0x0 by default. Pushed streams
- initially depend on their associated stream. In both cases, streams are assigned a
- default weight of 16.
-
-
-
-
-
-
- HTTP/2 framing permits two classes of error:
-
-
- An error condition that renders the entire connection unusable is a connection error.
-
-
- An error in an individual stream is a stream error.
-
-
-
-
- A list of error codes is included in .
-
-
-
-
- A connection error is any error which prevents further processing of the framing layer,
- or which corrupts any connection state.
-
-
- An endpoint that encounters a connection error SHOULD first send a GOAWAY
- frame ( ) with the stream identifier of the last stream that it
- successfully received from its peer. The GOAWAY frame includes an error
- code that indicates why the connection is terminating. After sending the
- GOAWAY frame, the endpoint MUST close the TCP connection.
-
-
- It is possible that the GOAWAY will not be reliably received by the
- receiving endpoint (see ). In the event of a connection error,
- GOAWAY only provides a best effort attempt to communicate with the peer
- about why the connection is being terminated.
-
-
- An endpoint can end a connection at any time. In particular, an endpoint MAY choose to
- treat a stream error as a connection error. Endpoints SHOULD send a
- GOAWAY frame when ending a connection, providing that circumstances
- permit it.
-
-
-
-
-
- A stream error is an error related to a specific stream that does not affect processing
- of other streams.
-
-
- An endpoint that detects a stream error sends a RST_STREAM frame ( ) that contains the stream identifier of the stream where the error
- occurred. The RST_STREAM frame includes an error code that indicates the
- type of error.
-
-
- A RST_STREAM is the last frame that an endpoint can send on a stream.
- The peer that sends the RST_STREAM frame MUST be prepared to receive any
- frames that were sent or enqueued for sending by the remote peer. These frames can be
- ignored, except where they modify connection state (such as the state maintained for
- header compression , or flow control).
-
-
- Normally, an endpoint SHOULD NOT send more than one RST_STREAM frame for
- any stream. However, an endpoint MAY send additional RST_STREAM frames if
- it receives frames on a closed stream after more than a round-trip time. This behavior
- is permitted to deal with misbehaving implementations.
-
-
- An endpoint MUST NOT send a RST_STREAM in response to an
- RST_STREAM frame, to avoid looping.
-
-
-
-
-
- If the TCP connection is closed or reset while streams remain in open or half closed
- states, then the endpoint MUST assume that those streams were abnormally interrupted and
- could be incomplete.
-
-
-
-
-
-
- HTTP/2 permits extension of the protocol. Protocol extensions can be used to provide
- additional services or alter any aspect of the protocol, within the limitations described
- in this section. Extensions are effective only within the scope of a single HTTP/2
- connection.
-
-
- Extensions are permitted to use new frame types , new
- settings , or new error
- codes . Registries are established for managing these extension points: frame types , settings and
- error codes .
-
-
- Implementations MUST ignore unknown or unsupported values in all extensible protocol
- elements. Implementations MUST discard frames that have unknown or unsupported types.
- This means that any of these extension points can be safely used by extensions without
- prior arrangement or negotiation. However, extension frames that appear in the middle of
- a header block are not permitted; these MUST be treated
- as a connection error of type
- PROTOCOL_ERROR .
-
-
- However, extensions that could change the semantics of existing protocol components MUST
- be negotiated before being used. For example, an extension that changes the layout of the
- HEADERS frame cannot be used until the peer has given a positive signal
- that this is acceptable. In this case, it could also be necessary to coordinate when the
- revised layout comes into effect. Note that treating any frame other than
- DATA frames as flow controlled is such a change in semantics, and can only
- be done through negotiation.
-
-
- This document doesn't mandate a specific method for negotiating the use of an extension,
- but notes that a setting could be used for that
- purpose. If both peers set a value that indicates willingness to use the extension, then
- the extension can be used. If a setting is used for extension negotiation, the initial
- value MUST be defined so that the extension is initially disabled.
-
-
-
-
-
-
- This specification defines a number of frame types, each identified by a unique 8-bit type
- code. Each frame type serves a distinct purpose either in the establishment and management
- of the connection as a whole, or of individual streams.
-
-
- The transmission of specific frame types can alter the state of a connection. If endpoints
- fail to maintain a synchronized view of the connection state, successful communication
- within the connection will no longer be possible. Therefore, it is important that endpoints
- have a shared comprehension of how the state is affected by the use of any given frame.
-
-
-
-
- DATA frames (type=0x0) convey arbitrary, variable-length sequences of octets associated
- with a stream. One or more DATA frames are used, for instance, to carry HTTP request or
- response payloads.
-
-
- DATA frames MAY also contain arbitrary padding. Padding can be added to DATA frames to
- obscure the size of messages.
-
-
-
-
-
- The DATA frame contains the following fields:
-
-
- An 8-bit field containing the length of the frame padding in units of octets. This
- field is optional and is only present if the PADDED flag is set.
-
-
- Application data. The amount of data is the remainder of the frame payload after
- subtracting the length of the other fields that are present.
-
-
- Padding octets that contain no application semantic value. Padding octets MUST be set
- to zero when sending and ignored when receiving.
-
-
-
-
-
- The DATA frame defines the following flags:
-
-
- Bit 1 being set indicates that this frame is the last that the endpoint will send for
- the identified stream. Setting this flag causes the stream to enter one of the "half closed" states or the "closed" state .
-
-
- Bit 4 being set indicates that the Pad Length field and any padding that it describes
- is present.
-
-
-
-
- DATA frames MUST be associated with a stream. If a DATA frame is received whose stream
- identifier field is 0x0, the recipient MUST respond with a connection error of type
- PROTOCOL_ERROR .
-
-
- DATA frames are subject to flow control and can only be sent when a stream is in the
- "open" or "half closed (remote)" states. The entire DATA frame payload is included in flow
- control, including Pad Length and Padding fields if present. If a DATA frame is received
- whose stream is not in "open" or "half closed (local)" state, the recipient MUST respond
- with a stream error of type
- STREAM_CLOSED .
-
-
- The total number of padding octets is determined by the value of the Pad Length field. If
- the length of the padding is greater than the length of the frame payload, the recipient
- MUST treat this as a connection error of
- type PROTOCOL_ERROR .
-
-
- A frame can be increased in size by one octet by including a Pad Length field with a
- value of zero.
-
-
-
-
- Padding is a security feature; see .
-
-
-
-
-
- The HEADERS frame (type=0x1) is used to open a stream ,
- and additionally carries a header block fragment. HEADERS frames can be sent on a stream
- in the "open" or "half closed (remote)" states.
-
-
-
-
-
- The HEADERS frame payload has the following fields:
-
-
- An 8-bit field containing the length of the frame padding in units of octets. This
- field is only present if the PADDED flag is set.
-
-
- A single bit flag indicates that the stream dependency is exclusive, see . This field is only present if the PRIORITY flag is set.
-
-
- A 31-bit stream identifier for the stream that this stream depends on, see . This field is only present if the PRIORITY flag is set.
-
-
- An 8-bit weight for the stream, see . Add one to the
- value to obtain a weight between 1 and 256. This field is only present if the
- PRIORITY flag is set.
-
-
- A header block fragment .
-
-
- Padding octets that contain no application semantic value. Padding octets MUST be set
- to zero when sending and ignored when receiving.
-
-
-
-
-
- The HEADERS frame defines the following flags:
-
-
-
- Bit 1 being set indicates that the header block is
- the last that the endpoint will send for the identified stream. Setting this flag
- causes the stream to enter one of "half closed"
- states .
-
-
- A HEADERS frame carries the END_STREAM flag that signals the end of a stream.
- However, a HEADERS frame with the END_STREAM flag set can be followed by
- CONTINUATION frames on the same stream. Logically, the
- CONTINUATION frames are part of the HEADERS frame.
-
-
-
-
- Bit 3 being set indicates that this frame contains an entire header block and is not followed by any
- CONTINUATION frames.
-
-
- A HEADERS frame without the END_HEADERS flag set MUST be followed by a
- CONTINUATION frame for the same stream. A receiver MUST treat the
- receipt of any other type of frame or a frame on a different stream as a connection error of type
- PROTOCOL_ERROR .
-
-
-
-
- Bit 4 being set indicates that the Pad Length field and any padding that it
- describes is present.
-
-
-
-
- Bit 6 being set indicates that the Exclusive Flag (E), Stream Dependency, and Weight
- fields are present; see .
-
-
-
-
-
-
- The payload of a HEADERS frame contains a header block
- fragment . A header block that does not fit within a HEADERS frame is continued in
- a CONTINUATION frame .
-
-
-
- HEADERS frames MUST be associated with a stream. If a HEADERS frame is received whose
- stream identifier field is 0x0, the recipient MUST respond with a connection error of type
- PROTOCOL_ERROR .
-
-
-
- The HEADERS frame changes the connection state as described in .
-
-
-
- The HEADERS frame includes optional padding. Padding fields and flags are identical to
- those defined for DATA frames .
-
-
- Prioritization information in a HEADERS frame is logically equivalent to a separate
- PRIORITY frame, but inclusion in HEADERS avoids the potential for churn in
- stream prioritization when new streams are created. Prioritization fields in HEADERS frames
- subsequent to the first on a stream reprioritize the
- stream .
-
-
-
-
-
- The PRIORITY frame (type=0x2) specifies the sender-advised
- priority of a stream . It can be sent at any time for an existing stream, including
- closed streams. This enables reprioritization of existing streams.
-
-
-
-
-
- The payload of a PRIORITY frame contains the following fields:
-
-
- A single bit flag indicates that the stream dependency is exclusive, see .
-
-
- A 31-bit stream identifier for the stream that this stream depends on, see .
-
-
- An 8-bit weight for the identified stream dependency, see . Add one to the value to obtain a weight between 1 and 256.
-
-
-
-
-
- The PRIORITY frame does not define any flags.
-
-
-
- The PRIORITY frame is associated with an existing stream. If a PRIORITY frame is received
- with a stream identifier of 0x0, the recipient MUST respond with a connection error of type
- PROTOCOL_ERROR .
-
-
- The PRIORITY frame can be sent on a stream in any of the "reserved (remote)", "open",
- "half closed (local)", "half closed (remote)", or "closed" states, though it cannot be
- sent between consecutive frames that comprise a single header
- block . Note that this frame could arrive after processing or frame sending has
- completed, which would cause it to have no effect on the current stream. For a stream
- that is in the "half closed (remote)" or "closed" state, this frame can only affect
- processing of the current stream and not frame transmission.
-
-
- The PRIORITY frame is the only frame that can be sent for a stream in the "closed" state.
- This allows for the reprioritization of a group of dependent streams by altering the
- priority of a parent stream, which might be closed. However, a PRIORITY frame sent on a
- closed stream risks being ignored due to the peer having discarded priority state
- information for that stream.
-
-
-
-
-
- The RST_STREAM frame (type=0x3) allows for abnormal termination of a stream. When sent by
- the initiator of a stream, it indicates that they wish to cancel the stream or that an
- error condition has occurred. When sent by the receiver of a stream, it indicates that
- either the receiver is rejecting the stream, requesting that the stream be cancelled, or
- that an error condition has occurred.
-
-
-
-
-
-
- The RST_STREAM frame contains a single unsigned, 32-bit integer identifying the error code . The error code indicates why the stream is being
- terminated.
-
-
-
- The RST_STREAM frame does not define any flags.
-
-
-
- The RST_STREAM frame fully terminates the referenced stream and causes it to enter the
- closed state. After receiving a RST_STREAM on a stream, the receiver MUST NOT send
- additional frames for that stream, with the exception of PRIORITY . However,
- after sending the RST_STREAM, the sending endpoint MUST be prepared to receive and process
- additional frames sent on the stream that might have been sent by the peer prior to the
- arrival of the RST_STREAM.
-
-
-
- RST_STREAM frames MUST be associated with a stream. If a RST_STREAM frame is received
- with a stream identifier of 0x0, the recipient MUST treat this as a connection error of type
- PROTOCOL_ERROR .
-
-
-
- RST_STREAM frames MUST NOT be sent for a stream in the "idle" state. If a RST_STREAM
- frame identifying an idle stream is received, the recipient MUST treat this as a connection error of type
- PROTOCOL_ERROR .
-
-
-
-
-
-
- The SETTINGS frame (type=0x4) conveys configuration parameters that affect how endpoints
- communicate, such as preferences and constraints on peer behavior. The SETTINGS frame is
- also used to acknowledge the receipt of those parameters. Individually, a SETTINGS
- parameter can also be referred to as a "setting".
-
-
- SETTINGS parameters are not negotiated; they describe characteristics of the sending peer,
- which are used by the receiving peer. Different values for the same parameter can be
- advertised by each peer. For example, a client might set a high initial flow control
- window, whereas a server might set a lower value to conserve resources.
-
-
-
- A SETTINGS frame MUST be sent by both endpoints at the start of a connection, and MAY be
- sent at any other time by either endpoint over the lifetime of the connection.
- Implementations MUST support all of the parameters defined by this specification.
-
-
-
- Each parameter in a SETTINGS frame replaces any existing value for that parameter.
- Parameters are processed in the order in which they appear, and a receiver of a SETTINGS
- frame does not need to maintain any state other than the current value of its
- parameters. Therefore, the value of a SETTINGS parameter is the last value that is seen by
- a receiver.
-
-
- SETTINGS parameters are acknowledged by the receiving peer. To enable this, the SETTINGS
- frame defines the following flag:
-
-
- Bit 1 being set indicates that this frame acknowledges receipt and application of the
- peer's SETTINGS frame. When this bit is set, the payload of the SETTINGS frame MUST
- be empty. Receipt of a SETTINGS frame with the ACK flag set and a length field value
- other than 0 MUST be treated as a connection
- error of type FRAME_SIZE_ERROR . For more info, see Settings Synchronization .
-
-
-
-
- SETTINGS frames always apply to a connection, never a single stream. The stream
- identifier for a SETTINGS frame MUST be zero (0x0). If an endpoint receives a SETTINGS
- frame whose stream identifier field is anything other than 0x0, the endpoint MUST respond
- with a connection error of type
- PROTOCOL_ERROR .
-
-
- The SETTINGS frame affects connection state. A badly formed or incomplete SETTINGS frame
- MUST be treated as a connection error of type
- PROTOCOL_ERROR .
-
-
-
-
- The payload of a SETTINGS frame consists of zero or more parameters, each consisting of
- an unsigned 16-bit setting identifier and an unsigned 32-bit value.
-
-
-
-
-
-
-
-
-
- The following parameters are defined:
-
-
-
- Allows the sender to inform the remote endpoint of the maximum size of the header
- compression table used to decode header blocks, in octets. The encoder can select
- any size equal to or less than this value by using signaling specific to the
- header compression format inside a header block. The initial value is 4,096
- octets.
-
-
-
-
- This setting can be used to disable server
- push . An endpoint MUST NOT send a PUSH_PROMISE frame if it
- receives this parameter set to a value of 0. An endpoint that has both set this
- parameter to 0 and had it acknowledged MUST treat the receipt of a
- PUSH_PROMISE frame as a connection error of type
- PROTOCOL_ERROR .
-
-
- The initial value is 1, which indicates that server push is permitted. Any value
- other than 0 or 1 MUST be treated as a connection error of type
- PROTOCOL_ERROR .
-
-
-
-
- Indicates the maximum number of concurrent streams that the sender will allow.
- This limit is directional: it applies to the number of streams that the sender
- permits the receiver to create. Initially there is no limit to this value. It is
- recommended that this value be no smaller than 100, so as to not unnecessarily
- limit parallelism.
-
-
- A value of 0 for SETTINGS_MAX_CONCURRENT_STREAMS SHOULD NOT be treated as special
- by endpoints. A zero value does prevent the creation of new streams; however, this
- can also happen for any limit that is exhausted with active streams. Servers
- SHOULD only set a zero value for short durations; if a server does not wish to
- accept requests, closing the connection could be preferable.
-
-
-
-
- Indicates the sender's initial window size (in octets) for stream level flow
- control. The initial value is 216 -1 (65,535) octets.
-
-
- This setting affects the window size of all streams, including existing streams,
- see .
-
-
- Values above the maximum flow control window size of 231 -1 MUST
- be treated as a connection error of
- type FLOW_CONTROL_ERROR .
-
-
-
-
- Indicates the size of the largest frame payload that the sender is willing to
- receive, in octets.
-
-
- The initial value is 214 (16,384) octets. The value advertised by
- an endpoint MUST be between this initial value and the maximum allowed frame size
- (224 -1 or 16,777,215 octets), inclusive. Values outside this range
- MUST be treated as a connection error
- of type PROTOCOL_ERROR .
-
-
-
-
- This advisory setting informs a peer of the maximum size of header list that the
- sender is prepared to accept, in octets. The value is based on the uncompressed
- size of header fields, including the length of the name and value in octets plus
- an overhead of 32 octets for each header field.
-
-
- For any given request, a lower limit than what is advertised MAY be enforced. The
- initial value of this setting is unlimited.
-
-
-
-
-
- An endpoint that receives a SETTINGS frame with any unknown or unsupported identifier
- MUST ignore that setting.
-
-
-
-
-
- Most values in SETTINGS benefit from or require an understanding of when the peer has
- received and applied the changed parameter values. In order to provide
- such synchronization timepoints, the recipient of a SETTINGS frame in which the ACK flag
- is not set MUST apply the updated parameters as soon as possible upon receipt.
-
-
- The values in the SETTINGS frame MUST be processed in the order they appear, with no
- other frame processing between values. Unsupported parameters MUST be ignored. Once
- all values have been processed, the recipient MUST immediately emit a SETTINGS frame
- with the ACK flag set. Upon receiving a SETTINGS frame with the ACK flag set, the sender
- of the altered parameters can rely on the setting having been applied.
-
-
- If the sender of a SETTINGS frame does not receive an acknowledgement within a
- reasonable amount of time, it MAY issue a connection error of type
- SETTINGS_TIMEOUT .
-
-
-
-
-
-
- The PUSH_PROMISE frame (type=0x5) is used to notify the peer endpoint in advance of
- streams the sender intends to initiate. The PUSH_PROMISE frame includes the unsigned
- 31-bit identifier of the stream the endpoint plans to create along with a set of headers
- that provide additional context for the stream. contains a
- thorough description of the use of PUSH_PROMISE frames.
-
-
-
-
-
-
- The PUSH_PROMISE frame payload has the following fields:
-
-
- An 8-bit field containing the length of the frame padding in units of octets. This
- field is only present if the PADDED flag is set.
-
-
- A single reserved bit.
-
-
- An unsigned 31-bit integer that identifies the stream that is reserved by the
- PUSH_PROMISE. The promised stream identifier MUST be a valid choice for the next
- stream sent by the sender (see new stream
- identifier ).
-
-
- A header block fragment containing request header
- fields.
-
-
- Padding octets.
-
-
-
-
-
- The PUSH_PROMISE frame defines the following flags:
-
-
-
- Bit 3 being set indicates that this frame contains an entire header block and is not followed by any
- CONTINUATION frames.
-
-
- A PUSH_PROMISE frame without the END_HEADERS flag set MUST be followed by a
- CONTINUATION frame for the same stream. A receiver MUST treat the receipt of any
- other type of frame or a frame on a different stream as a connection error of type
- PROTOCOL_ERROR .
-
-
-
-
- Bit 4 being set indicates that the Pad Length field and any padding that it
- describes is present.
-
-
-
-
-
-
- PUSH_PROMISE frames MUST be associated with an existing, peer-initiated stream. The stream
- identifier of a PUSH_PROMISE frame indicates the stream it is associated with. If the
- stream identifier field specifies the value 0x0, a recipient MUST respond with a connection error of type
- PROTOCOL_ERROR .
-
-
-
- Promised streams are not required to be used in the order they are promised. The
- PUSH_PROMISE only reserves stream identifiers for later use.
-
-
-
- PUSH_PROMISE MUST NOT be sent if the SETTINGS_ENABLE_PUSH setting of the
- peer endpoint is set to 0. An endpoint that has set this setting and has received
- acknowledgement MUST treat the receipt of a PUSH_PROMISE frame as a connection error of type
- PROTOCOL_ERROR .
-
-
- Recipients of PUSH_PROMISE frames can choose to reject promised streams by returning a
- RST_STREAM referencing the promised stream identifier back to the sender of
- the PUSH_PROMISE.
-
-
-
- A PUSH_PROMISE frame modifies the connection state in two ways. The inclusion of a header block potentially modifies the state maintained for
- header compression. PUSH_PROMISE also reserves a stream for later use, causing the
- promised stream to enter the "reserved" state. A sender MUST NOT send a PUSH_PROMISE on a
- stream unless that stream is either "open" or "half closed (remote)"; the sender MUST
- ensure that the promised stream is a valid choice for a new stream identifier (that is, the promised stream MUST
- be in the "idle" state).
-
-
- Since PUSH_PROMISE reserves a stream, ignoring a PUSH_PROMISE frame causes the stream
- state to become indeterminate. A receiver MUST treat the receipt of a PUSH_PROMISE on a
- stream that is neither "open" nor "half closed (local)" as a connection error of type
- PROTOCOL_ERROR . However, an endpoint that has sent
- RST_STREAM on the associated stream MUST handle PUSH_PROMISE frames that
- might have been created before the RST_STREAM frame is received and
- processed.
-
-
- A receiver MUST treat the receipt of a PUSH_PROMISE that promises an illegal stream identifier (that is, an identifier for a
- stream that is not currently in the "idle" state) as a connection error of type
- PROTOCOL_ERROR .
-
-
-
- The PUSH_PROMISE frame includes optional padding. Padding fields and flags are identical
- to those defined for DATA frames .
-
-
-
-
-
- The PING frame (type=0x6) is a mechanism for measuring a minimal round trip time from the
- sender, as well as determining whether an idle connection is still functional. PING
- frames can be sent from any endpoint.
-
-
-
-
-
-
- In addition to the frame header, PING frames MUST contain 8 octets of data in the payload.
- A sender can include any value it chooses and use those bytes in any fashion.
-
-
- Receivers of a PING frame that does not include an ACK flag MUST send a PING frame with
- the ACK flag set in response, with an identical payload. PING responses SHOULD be given
- higher priority than any other frame.
-
-
-
- The PING frame defines the following flags:
-
-
- Bit 1 being set indicates that this PING frame is a PING response. An endpoint MUST
- set this flag in PING responses. An endpoint MUST NOT respond to PING frames
- containing this flag.
-
-
-
-
- PING frames are not associated with any individual stream. If a PING frame is received
- with a stream identifier field value other than 0x0, the recipient MUST respond with a
- connection error of type
- PROTOCOL_ERROR .
-
-
- Receipt of a PING frame with a length field value other than 8 MUST be treated as a connection error of type
- FRAME_SIZE_ERROR .
-
-
-
-
-
-
- The GOAWAY frame (type=0x7) informs the remote peer to stop creating streams on this
- connection. GOAWAY can be sent by either the client or the server. Once sent, the sender
- will ignore frames sent on any new streams with identifiers higher than the included last
- stream identifier. Receivers of a GOAWAY frame MUST NOT open additional streams on the
- connection, although a new connection can be established for new streams.
-
-
- The purpose of this frame is to allow an endpoint to gracefully stop accepting new
- streams, while still finishing processing of previously established streams. This enables
- administrative actions, like server maintenance.
-
-
- There is an inherent race condition between an endpoint starting new streams and the
- remote sending a GOAWAY frame. To deal with this case, the GOAWAY contains the stream
- identifier of the last peer-initiated stream which was or might be processed on the
- sending endpoint in this connection. For instance, if the server sends a GOAWAY frame,
- the identified stream is the highest numbered stream initiated by the client.
-
-
- If the receiver of the GOAWAY has sent data on streams with a higher stream identifier
- than what is indicated in the GOAWAY frame, those streams are not or will not be
- processed. The receiver of the GOAWAY frame can treat the streams as though they had
- never been created at all, thereby allowing those streams to be retried later on a new
- connection.
-
-
- Endpoints SHOULD always send a GOAWAY frame before closing a connection so that the remote
- can know whether a stream has been partially processed or not. For example, if an HTTP
- client sends a POST at the same time that a server closes a connection, the client cannot
- know if the server started to process that POST request if the server does not send a
- GOAWAY frame to indicate what streams it might have acted on.
-
-
- An endpoint might choose to close a connection without sending GOAWAY for misbehaving
- peers.
-
-
-
-
-
-
- The GOAWAY frame does not define any flags.
-
-
- The GOAWAY frame applies to the connection, not a specific stream. An endpoint MUST treat
- a GOAWAY frame with a stream identifier other than 0x0 as a connection error of type
- PROTOCOL_ERROR .
-
-
- The last stream identifier in the GOAWAY frame contains the highest numbered stream
- identifier for which the sender of the GOAWAY frame might have taken some action on, or
- might yet take action on. All streams up to and including the identified stream might
- have been processed in some way. The last stream identifier can be set to 0 if no streams
- were processed.
-
-
- In this context, "processed" means that some data from the stream was passed to some
- higher layer of software that might have taken some action as a result.
-
-
- If a connection terminates without a GOAWAY frame, the last stream identifier is
- effectively the highest possible stream identifier.
-
-
- On streams with lower or equal numbered identifiers that were not closed completely prior
- to the connection being closed, re-attempting requests, transactions, or any protocol
- activity is not possible, with the exception of idempotent actions like HTTP GET, PUT, or
- DELETE. Any protocol activity that uses higher numbered streams can be safely retried
- using a new connection.
-
-
- Activity on streams numbered lower or equal to the last stream identifier might still
- complete successfully. The sender of a GOAWAY frame might gracefully shut down a
- connection by sending a GOAWAY frame, maintaining the connection in an open state until
- all in-progress streams complete.
-
-
- An endpoint MAY send multiple GOAWAY frames if circumstances change. For instance, an
- endpoint that sends GOAWAY with NO_ERROR during graceful shutdown could
- subsequently encounter a condition that requires immediate termination of the connection.
- The last stream identifier from the last GOAWAY frame received indicates which streams
- could have been acted upon. Endpoints MUST NOT increase the value they send in the last
- stream identifier, since the peers might already have retried unprocessed requests on
- another connection.
-
-
- A client that is unable to retry requests loses all requests that are in flight when the
- server closes the connection. This is especially true for intermediaries that might
- not be serving clients using HTTP/2. A server that is attempting to gracefully shut down
- a connection SHOULD send an initial GOAWAY frame with the last stream identifier set to
- 231 -1 and a NO_ERROR code. This signals to the client that
- a shutdown is imminent and that no further requests can be initiated. After waiting at
- least one round trip time, the server can send another GOAWAY frame with an updated last
- stream identifier. This ensures that a connection can be cleanly shut down without losing
- requests.
-
-
-
- After sending a GOAWAY frame, the sender can discard frames for streams with identifiers
- higher than the identified last stream. However, any frames that alter connection state
- cannot be completely ignored. For instance, HEADERS ,
- PUSH_PROMISE and CONTINUATION frames MUST be minimally
- processed to ensure the state maintained for header compression is consistent (see ); similarly DATA frames MUST be counted toward the connection flow
- control window. Failure to process these frames can cause flow control or header
- compression state to become unsynchronized.
-
-
-
- The GOAWAY frame also contains a 32-bit error code that
- contains the reason for closing the connection.
-
-
- Endpoints MAY append opaque data to the payload of any GOAWAY frame. Additional debug
- data is intended for diagnostic purposes only and carries no semantic value. Debug
- information could contain security- or privacy-sensitive data. Logged or otherwise
- persistently stored debug data MUST have adequate safeguards to prevent unauthorized
- access.
-
-
-
-
-
- The WINDOW_UPDATE frame (type=0x8) is used to implement flow control; see for an overview.
-
-
- Flow control operates at two levels: on each individual stream and on the entire
- connection.
-
-
- Both types of flow control are hop-by-hop; that is, only between the two endpoints.
- Intermediaries do not forward WINDOW_UPDATE frames between dependent connections.
- However, throttling of data transfer by any receiver can indirectly cause the propagation
- of flow control information toward the original sender.
-
-
- Flow control only applies to frames that are identified as being subject to flow control.
- Of the frame types defined in this document, this includes only DATA frames.
- Frames that are exempt from flow control MUST be accepted and processed, unless the
- receiver is unable to assign resources to handling the frame. A receiver MAY respond with
- a stream error or connection error of type
- FLOW_CONTROL_ERROR if it is unable to accept a frame.
-
-
-
-
-
- The payload of a WINDOW_UPDATE frame is one reserved bit, plus an unsigned 31-bit integer
- indicating the number of octets that the sender can transmit in addition to the existing
- flow control window. The legal range for the increment to the flow control window is 1 to
- 231 -1 (0x7fffffff) octets.
-
-
- The WINDOW_UPDATE frame does not define any flags.
-
-
- The WINDOW_UPDATE frame can be specific to a stream or to the entire connection. In the
- former case, the frame's stream identifier indicates the affected stream; in the latter,
- the value "0" indicates that the entire connection is the subject of the frame.
-
-
- A receiver MUST treat the receipt of a WINDOW_UPDATE frame with a flow control window
- increment of 0 as a stream error of type
- PROTOCOL_ERROR ; errors on the connection flow control window MUST be
- treated as a connection error .
-
-
- WINDOW_UPDATE can be sent by a peer that has sent a frame bearing the END_STREAM flag.
- This means that a receiver could receive a WINDOW_UPDATE frame on a "half closed (remote)"
- or "closed" stream. A receiver MUST NOT treat this as an error, see .
-
-
- A receiver that receives a flow controlled frame MUST always account for its contribution
- against the connection flow control window, unless the receiver treats this as a connection error . This is necessary even if the
- frame is in error. Since the sender counts the frame toward the flow control window, if
- the receiver does not, the flow control window at sender and receiver can become
- different.
-
-
-
-
- Flow control in HTTP/2 is implemented using a window kept by each sender on every
- stream. The flow control window is a simple integer value that indicates how many octets
- of data the sender is permitted to transmit; as such, its size is a measure of the
- buffering capacity of the receiver.
-
-
- Two flow control windows are applicable: the stream flow control window and the
- connection flow control window. The sender MUST NOT send a flow controlled frame with a
- length that exceeds the space available in either of the flow control windows advertised
- by the receiver. Frames with zero length with the END_STREAM flag set (that is, an
- empty DATA frame) MAY be sent if there is no available space in either
- flow control window.
-
-
- For flow control calculations, the 9 octet frame header is not counted.
-
-
- After sending a flow controlled frame, the sender reduces the space available in both
- windows by the length of the transmitted frame.
-
-
- The receiver of a frame sends a WINDOW_UPDATE frame as it consumes data and frees up
- space in flow control windows. Separate WINDOW_UPDATE frames are sent for the stream
- and connection level flow control windows.
-
-
- A sender that receives a WINDOW_UPDATE frame updates the corresponding window by the
- amount specified in the frame.
-
-
- A sender MUST NOT allow a flow control window to exceed 231 -1 octets.
- If a sender receives a WINDOW_UPDATE that causes a flow control window to exceed this
- maximum it MUST terminate either the stream or the connection, as appropriate. For
- streams, the sender sends a RST_STREAM with the error code of
- FLOW_CONTROL_ERROR code; for the connection, a GOAWAY
- frame with a FLOW_CONTROL_ERROR code.
-
-
- Flow controlled frames from the sender and WINDOW_UPDATE frames from the receiver are
- completely asynchronous with respect to each other. This property allows a receiver to
- aggressively update the window size kept by the sender to prevent streams from stalling.
-
-
-
-
-
- When an HTTP/2 connection is first established, new streams are created with an initial
- flow control window size of 65,535 octets. The connection flow control window is 65,535
- octets. Both endpoints can adjust the initial window size for new streams by including
- a value for SETTINGS_INITIAL_WINDOW_SIZE in the SETTINGS
- frame that forms part of the connection preface. The connection flow control window can
- only be changed using WINDOW_UPDATE frames.
-
-
- Prior to receiving a SETTINGS frame that sets a value for
- SETTINGS_INITIAL_WINDOW_SIZE , an endpoint can only use the default
- initial window size when sending flow controlled frames. Similarly, the connection flow
- control window is set to the default initial window size until a WINDOW_UPDATE frame is
- received.
-
-
- A SETTINGS frame can alter the initial flow control window size for all
- current streams. When the value of SETTINGS_INITIAL_WINDOW_SIZE changes,
- a receiver MUST adjust the size of all stream flow control windows that it maintains by
- the difference between the new value and the old value.
-
-
- A change to SETTINGS_INITIAL_WINDOW_SIZE can cause the available space in
- a flow control window to become negative. A sender MUST track the negative flow control
- window, and MUST NOT send new flow controlled frames until it receives WINDOW_UPDATE
- frames that cause the flow control window to become positive.
-
-
- For example, if the client sends 60KB immediately on connection establishment, and the
- server sets the initial window size to be 16KB, the client will recalculate the
- available flow control window to be -44KB on receipt of the SETTINGS
- frame. The client retains a negative flow control window until WINDOW_UPDATE frames
- restore the window to being positive, after which the client can resume sending.
-
-
- A SETTINGS frame cannot alter the connection flow control window.
-
-
- An endpoint MUST treat a change to SETTINGS_INITIAL_WINDOW_SIZE that
- causes any flow control window to exceed the maximum size as a connection error of type
- FLOW_CONTROL_ERROR .
-
-
-
-
-
- A receiver that wishes to use a smaller flow control window than the current size can
- send a new SETTINGS frame. However, the receiver MUST be prepared to
- receive data that exceeds this window size, since the sender might send data that
- exceeds the lower limit prior to processing the SETTINGS frame.
-
-
- After sending a SETTINGS frame that reduces the initial flow control window size, a
- receiver has two options for handling streams that exceed flow control limits:
-
-
- The receiver can immediately send RST_STREAM with
- FLOW_CONTROL_ERROR error code for the affected streams.
-
-
- The receiver can accept the streams and tolerate the resulting head of line
- blocking, sending WINDOW_UPDATE frames as it consumes data.
-
-
-
-
-
-
-
-
- The CONTINUATION frame (type=0x9) is used to continue a sequence of header block fragments . Any number of CONTINUATION frames can
- be sent on an existing stream, as long as the preceding frame is on the same stream and is
- a HEADERS , PUSH_PROMISE or CONTINUATION frame without the
- END_HEADERS flag set.
-
-
-
-
-
-
- The CONTINUATION frame payload contains a header block
- fragment .
-
-
-
- The CONTINUATION frame defines the following flag:
-
-
-
- Bit 3 being set indicates that this frame ends a header
- block .
-
-
- If the END_HEADERS bit is not set, this frame MUST be followed by another
- CONTINUATION frame. A receiver MUST treat the receipt of any other type of frame or
- a frame on a different stream as a connection
- error of type PROTOCOL_ERROR .
-
-
-
-
-
-
- The CONTINUATION frame changes the connection state as defined in .
-
-
-
- CONTINUATION frames MUST be associated with a stream. If a CONTINUATION frame is received
- whose stream identifier field is 0x0, the recipient MUST respond with a connection error of type PROTOCOL_ERROR.
-
-
-
- A CONTINUATION frame MUST be preceded by a HEADERS ,
- PUSH_PROMISE or CONTINUATION frame without the END_HEADERS flag set. A
- recipient that observes violation of this rule MUST respond with a connection error of type
- PROTOCOL_ERROR .
-
-
-
-
-
-
- Error codes are 32-bit fields that are used in RST_STREAM and
- GOAWAY frames to convey the reasons for the stream or connection error.
-
-
-
- Error codes share a common code space. Some error codes apply only to either streams or the
- entire connection and have no defined semantics in the other context.
-
-
-
- The following error codes are defined:
-
-
- The associated condition is not as a result of an error. For example, a
- GOAWAY might include this code to indicate graceful shutdown of a
- connection.
-
-
- The endpoint detected an unspecific protocol error. This error is for use when a more
- specific error code is not available.
-
-
- The endpoint encountered an unexpected internal error.
-
-
- The endpoint detected that its peer violated the flow control protocol.
-
-
- The endpoint sent a SETTINGS frame, but did not receive a response in a
- timely manner. See Settings Synchronization .
-
-
- The endpoint received a frame after a stream was half closed.
-
-
- The endpoint received a frame with an invalid size.
-
-
- The endpoint refuses the stream prior to performing any application processing, see
- for details.
-
-
- Used by the endpoint to indicate that the stream is no longer needed.
-
-
- The endpoint is unable to maintain the header compression context for the connection.
-
-
- The connection established in response to a CONNECT
- request was reset or abnormally closed.
-
-
- The endpoint detected that its peer is exhibiting a behavior that might be generating
- excessive load.
-
-
- The underlying transport has properties that do not meet minimum security
- requirements (see ).
-
-
-
-
- Unknown or unsupported error codes MUST NOT trigger any special behavior. These MAY be
- treated by an implementation as being equivalent to INTERNAL_ERROR .
-
-
-
-
-
- HTTP/2 is intended to be as compatible as possible with current uses of HTTP. This means
- that, from the application perspective, the features of the protocol are largely
- unchanged. To achieve this, all request and response semantics are preserved, although the
- syntax of conveying those semantics has changed.
-
-
- Thus, the specification and requirements of HTTP/1.1 Semantics and Content , Conditional Requests , Range Requests , Caching and Authentication are applicable to HTTP/2. Selected portions of HTTP/1.1 Message Syntax
- and Routing , such as the HTTP and HTTPS URI schemes, are also
- applicable in HTTP/2, but the expression of those semantics for this protocol are defined
- in the sections below.
-
-
-
-
- A client sends an HTTP request on a new stream, using a previously unused stream identifier . A server sends an HTTP response on
- the same stream as the request.
-
-
- An HTTP message (request or response) consists of:
-
-
- for a response only, zero or more HEADERS frames (each followed by zero
- or more CONTINUATION frames) containing the message headers of
- informational (1xx) HTTP responses (see and ),
- and
-
-
- one HEADERS frame (followed by zero or more CONTINUATION
- frames) containing the message headers (see ), and
-
-
- zero or more DATA frames containing the message payload (see ), and
-
-
- optionally, one HEADERS frame, followed by zero or more
- CONTINUATION frames containing the trailer-part, if present (see ).
-
-
- The last frame in the sequence bears an END_STREAM flag, noting that a
- HEADERS frame bearing the END_STREAM flag can be followed by
- CONTINUATION frames that carry any remaining portions of the header block.
-
-
- Other frames (from any stream) MUST NOT occur between either HEADERS frame
- and any CONTINUATION frames that might follow.
-
-
-
- Trailing header fields are carried in a header block that also terminates the stream.
- That is, a sequence starting with a HEADERS frame, followed by zero or more
- CONTINUATION frames, where the HEADERS frame bears an
- END_STREAM flag. Header blocks after the first that do not terminate the stream are not
- part of an HTTP request or response.
-
-
- A HEADERS frame (and associated CONTINUATION frames) can
- only appear at the start or end of a stream. An endpoint that receives a
- HEADERS frame without the END_STREAM flag set after receiving a final
- (non-informational) status code MUST treat the corresponding request or response as malformed .
-
-
-
- An HTTP request/response exchange fully consumes a single stream. A request starts with
- the HEADERS frame that puts the stream into an "open" state. The request
- ends with a frame bearing END_STREAM, which causes the stream to become "half closed
- (local)" for the client and "half closed (remote)" for the server. A response starts with
- a HEADERS frame and ends with a frame bearing END_STREAM, which places the
- stream in the "closed" state.
-
-
-
-
-
- HTTP/2 removes support for the 101 (Switching Protocols) informational status code
- ( ).
-
-
- The semantics of 101 (Switching Protocols) aren't applicable to a multiplexed protocol.
- Alternative protocols are able to use the same mechanisms that HTTP/2 uses to negotiate
- their use (see ).
-
-
-
-
-
- HTTP header fields carry information as a series of key-value pairs. For a listing of
- registered HTTP headers, see the Message Header Field Registry maintained at .
-
-
-
-
- While HTTP/1.x used the message start-line (see ) to convey the target URI and method of the request, and the
- status code for the response, HTTP/2 uses special pseudo-header fields beginning with
- ':' character (ASCII 0x3a) for this purpose.
-
-
- Pseudo-header fields are not HTTP header fields. Endpoints MUST NOT generate
- pseudo-header fields other than those defined in this document.
-
-
- Pseudo-header fields are only valid in the context in which they are defined.
- Pseudo-header fields defined for requests MUST NOT appear in responses; pseudo-header
- fields defined for responses MUST NOT appear in requests. Pseudo-header fields MUST
- NOT appear in trailers. Endpoints MUST treat a request or response that contains
- undefined or invalid pseudo-header fields as malformed .
-
-
- Just as in HTTP/1.x, header field names are strings of ASCII characters that are
- compared in a case-insensitive fashion. However, header field names MUST be converted
- to lowercase prior to their encoding in HTTP/2. A request or response containing
- uppercase header field names MUST be treated as malformed .
-
-
- All pseudo-header fields MUST appear in the header block before regular header fields.
- Any request or response that contains a pseudo-header field that appears in a header
- block after a regular header field MUST be treated as malformed .
-
-
-
-
-
- HTTP/2 does not use the Connection header field to
- indicate connection-specific header fields; in this protocol, connection-specific
- metadata is conveyed by other means. An endpoint MUST NOT generate a HTTP/2 message
- containing connection-specific header fields; any message containing
- connection-specific header fields MUST be treated as malformed .
-
-
- This means that an intermediary transforming an HTTP/1.x message to HTTP/2 will need
- to remove any header fields nominated by the Connection header field, along with the
- Connection header field itself. Such intermediaries SHOULD also remove other
- connection-specific header fields, such as Keep-Alive, Proxy-Connection,
- Transfer-Encoding and Upgrade, even if they are not nominated by Connection.
-
-
- One exception to this is the TE header field, which MAY be present in an HTTP/2
- request, but when it is MUST NOT contain any value other than "trailers".
-
-
-
-
- HTTP/2 purposefully does not support upgrade to another protocol. The handshake
- methods described in are believed sufficient to
- negotiate the use of alternative protocols.
-
-
-
-
-
-
-
- The following pseudo-header fields are defined for HTTP/2 requests:
-
-
-
- The :method pseudo-header field includes the HTTP
- method ( ).
-
-
-
-
- The :scheme pseudo-header field includes the scheme
- portion of the target URI ( ).
-
-
- :scheme is not restricted to http and https schemed URIs. A
- proxy or gateway can translate requests for non-HTTP schemes, enabling the use
- of HTTP to interact with non-HTTP services.
-
-
-
-
- The :authority pseudo-header field includes the
- authority portion of the target URI ( ). The authority MUST NOT include the deprecated userinfo subcomponent for http
- or https schemed URIs.
-
-
- To ensure that the HTTP/1.1 request line can be reproduced accurately, this
- pseudo-header field MUST be omitted when translating from an HTTP/1.1 request
- that has a request target in origin or asterisk form (see ). Clients that generate
- HTTP/2 requests directly SHOULD use the :authority pseudo-header
- field instead of the Host header field. An
- intermediary that converts an HTTP/2 request to HTTP/1.1 MUST create a Host header field if one is not present in a request by
- copying the value of the :authority pseudo-header
- field.
-
-
-
-
- The :path pseudo-header field includes the path and
- query parts of the target URI (the path-absolute
- production from and optionally a '?' character
- followed by the query production, see and ). A request in asterisk form includes the value '*' for the
- :path pseudo-header field.
-
-
- This pseudo-header field MUST NOT be empty for http
- or https URIs; http or
- https URIs that do not contain a path component
- MUST include a value of '/'. The exception to this rule is an OPTIONS request
- for an http or https
- URI that does not include a path component; these MUST include a :path pseudo-header field with a value of '*' (see ).
-
-
-
-
-
- All HTTP/2 requests MUST include exactly one valid value for the :method , :scheme , and :path pseudo-header fields, unless it is a CONNECT request . An HTTP request that omits mandatory
- pseudo-header fields is malformed .
-
-
- HTTP/2 does not define a way to carry the version identifier that is included in the
- HTTP/1.1 request line.
-
-
-
-
-
- For HTTP/2 responses, a single :status pseudo-header
- field is defined that carries the HTTP status code field (see ). This pseudo-header field MUST be included in all
- responses, otherwise the response is malformed .
-
-
- HTTP/2 does not define a way to carry the version or reason phrase that is included in
- an HTTP/1.1 status line.
-
-
-
-
-
- The Cookie header field can carry a significant amount of
- redundant data.
-
-
- The Cookie header field uses a semi-colon (";") to delimit cookie-pairs (or "crumbs").
- This header field doesn't follow the list construction rules in HTTP (see ), which prevents cookie-pairs from
- being separated into different name-value pairs. This can significantly reduce
- compression efficiency as individual cookie-pairs are updated.
-
-
- To allow for better compression efficiency, the Cookie header field MAY be split into
- separate header fields, each with one or more cookie-pairs. If there are multiple
- Cookie header fields after decompression, these MUST be concatenated into a single
- octet string using the two octet delimiter of 0x3B, 0x20 (the ASCII string "; ")
- before being passed into a non-HTTP/2 context, such as an HTTP/1.1 connection, or a
- generic HTTP server application.
-
-
-
- Therefore, the following two lists of Cookie header fields are semantically
- equivalent.
-
-
-
-
-
-
-
- A malformed request or response is one that is an otherwise valid sequence of HTTP/2
- frames, but is otherwise invalid due to the presence of extraneous frames, prohibited
- header fields, the absence of mandatory header fields, or the inclusion of uppercase
- header field names.
-
-
- A request or response that includes an entity body can include a content-length header field. A request or response is also
- malformed if the value of a content-length header field
- does not equal the sum of the DATA frame payload lengths that form the
- body. A response that is defined to have no payload, as described in , can have a non-zero
- content-length header field, even though no content is
- included in DATA frames.
-
-
- Intermediaries that process HTTP requests or responses (i.e., any intermediary not
- acting as a tunnel) MUST NOT forward a malformed request or response. Malformed
- requests or responses that are detected MUST be treated as a stream error of type PROTOCOL_ERROR .
-
-
- For malformed requests, a server MAY send an HTTP response prior to closing or
- resetting the stream. Clients MUST NOT accept a malformed response. Note that these
- requirements are intended to protect against several types of common attacks against
- HTTP; they are deliberately strict, because being permissive can expose
- implementations to these vulnerabilities.
-
-
-
-
-
-
- This section shows HTTP/1.1 requests and responses, with illustrations of equivalent
- HTTP/2 requests and responses.
-
-
- An HTTP GET request includes request header fields and no body and is therefore
- transmitted as a single HEADERS frame, followed by zero or more
- CONTINUATION frames containing the serialized block of request header
- fields. The HEADERS frame in the following has both the END_HEADERS and
- END_STREAM flags set; no CONTINUATION frames are sent:
-
-
-
- + END_STREAM
- Accept: image/jpeg + END_HEADERS
- :method = GET
- :scheme = https
- :path = /resource
- host = example.org
- accept = image/jpeg
-]]>
-
-
-
- Similarly, a response that includes only response header fields is transmitted as a
- HEADERS frame (again, followed by zero or more
- CONTINUATION frames) containing the serialized block of response header
- fields.
-
-
-
- + END_STREAM
- Expires: Thu, 23 Jan ... + END_HEADERS
- :status = 304
- etag = "xyzzy"
- expires = Thu, 23 Jan ...
-]]>
-
-
-
- An HTTP POST request that includes request header fields and payload data is transmitted
- as one HEADERS frame, followed by zero or more
- CONTINUATION frames containing the request header fields, followed by one
- or more DATA frames, with the last CONTINUATION (or
- HEADERS ) frame having the END_HEADERS flag set and the final
- DATA frame having the END_STREAM flag set:
-
-
-
- - END_STREAM
- Content-Type: image/jpeg - END_HEADERS
- Content-Length: 123 :method = POST
- :path = /resource
- {binary data} :scheme = https
-
- CONTINUATION
- + END_HEADERS
- content-type = image/jpeg
- host = example.org
- content-length = 123
-
- DATA
- + END_STREAM
- {binary data}
-]]>
-
- Note that data contributing to any given header field could be spread between header
- block fragments. The allocation of header fields to frames in this example is
- illustrative only.
-
-
-
-
- A response that includes header fields and payload data is transmitted as a
- HEADERS frame, followed by zero or more CONTINUATION
- frames, followed by one or more DATA frames, with the last
- DATA frame in the sequence having the END_STREAM flag set:
-
-
-
- - END_STREAM
- Content-Length: 123 + END_HEADERS
- :status = 200
- {binary data} content-type = image/jpeg
- content-length = 123
-
- DATA
- + END_STREAM
- {binary data}
-]]>
-
-
-
- Trailing header fields are sent as a header block after both the request or response
- header block and all the DATA frames have been sent. The
- HEADERS frame starting the trailers header block has the END_STREAM flag
- set.
-
-
-
- - END_STREAM
- Transfer-Encoding: chunked + END_HEADERS
- Trailer: Foo :status = 200
- content-length = 123
- 123 content-type = image/jpeg
- {binary data} trailer = Foo
- 0
- Foo: bar DATA
- - END_STREAM
- {binary data}
-
- HEADERS
- + END_STREAM
- + END_HEADERS
- foo = bar
-]]>
-
-
-
-
-
- An informational response using a 1xx status code other than 101 is transmitted as a
- HEADERS frame, followed by zero or more CONTINUATION
- frames:
-
- - END_STREAM
- + END_HEADERS
- :status = 103
- extension-field = bar
-]]>
-
-
-
-
-
- In HTTP/1.1, an HTTP client is unable to retry a non-idempotent request when an error
- occurs, because there is no means to determine the nature of the error. It is possible
- that some server processing occurred prior to the error, which could result in
- undesirable effects if the request were reattempted.
-
-
- HTTP/2 provides two mechanisms for providing a guarantee to a client that a request has
- not been processed:
-
-
- The GOAWAY frame indicates the highest stream number that might have
- been processed. Requests on streams with higher numbers are therefore guaranteed to
- be safe to retry.
-
-
- The REFUSED_STREAM error code can be included in a
- RST_STREAM frame to indicate that the stream is being closed prior to
- any processing having occurred. Any request that was sent on the reset stream can
- be safely retried.
-
-
-
-
- Requests that have not been processed have not failed; clients MAY automatically retry
- them, even those with non-idempotent methods.
-
-
- A server MUST NOT indicate that a stream has not been processed unless it can guarantee
- that fact. If frames that are on a stream are passed to the application layer for any
- stream, then REFUSED_STREAM MUST NOT be used for that stream, and a
- GOAWAY frame MUST include a stream identifier that is greater than or
- equal to the given stream identifier.
-
-
- In addition to these mechanisms, the PING frame provides a way for a
- client to easily test a connection. Connections that remain idle can become broken as
- some middleboxes (for instance, network address translators, or load balancers) silently
- discard connection bindings. The PING frame allows a client to safely
- test whether a connection is still active without sending a request.
-
-
-
-
-
-
- HTTP/2 allows a server to pre-emptively send (or "push") responses (along with
- corresponding "promised" requests) to a client in association with a previous
- client-initiated request. This can be useful when the server knows the client will need
- to have those responses available in order to fully process the response to the original
- request.
-
-
-
- Pushing additional message exchanges in this fashion is optional, and is negotiated
- between individual endpoints. The SETTINGS_ENABLE_PUSH setting can be set
- to 0 to indicate that server push is disabled.
-
-
- Promised requests MUST be cacheable (see ), MUST be safe (see ) and MUST NOT include a request body. Clients that receive a
- promised request that is not cacheable, unsafe or that includes a request body MUST
- reset the stream with a stream error of type
- PROTOCOL_ERROR .
-
-
- Pushed responses that are cacheable (see ) can be stored by the client, if it implements a HTTP
- cache. Pushed responses are considered successfully validated on the origin server (e.g.,
- if the "no-cache" cache response directive is present) while the stream identified by the
- promised stream ID is still open.
-
-
- Pushed responses that are not cacheable MUST NOT be stored by any HTTP cache. They MAY
- be made available to the application separately.
-
-
- An intermediary can receive pushes from the server and choose not to forward them on to
- the client. In other words, how to make use of the pushed information is up to that
- intermediary. Equally, the intermediary might choose to make additional pushes to the
- client, without any action taken by the server.
-
-
- A client cannot push. Thus, servers MUST treat the receipt of a
- PUSH_PROMISE frame as a connection
- error of type PROTOCOL_ERROR . Clients MUST reject any attempt to
- change the SETTINGS_ENABLE_PUSH setting to a value other than 0 by treating
- the message as a connection error of type
- PROTOCOL_ERROR .
-
-
-
-
- Server push is semantically equivalent to a server responding to a request; however, in
- this case that request is also sent by the server, as a PUSH_PROMISE
- frame.
-
-
- The PUSH_PROMISE frame includes a header block that contains a complete
- set of request header fields that the server attributes to the request. It is not
- possible to push a response to a request that includes a request body.
-
-
-
- Pushed responses are always associated with an explicit request from the client. The
- PUSH_PROMISE frames sent by the server are sent on that explicit
- request's stream. The PUSH_PROMISE frame also includes a promised stream
- identifier, chosen from the stream identifiers available to the server (see ).
-
-
-
- The header fields in PUSH_PROMISE and any subsequent
- CONTINUATION frames MUST be a valid and complete set of request header fields . The server MUST include a method in
- the :method header field that is safe and cacheable. If a
- client receives a PUSH_PROMISE that does not include a complete and valid
- set of header fields, or the :method header field identifies
- a method that is not safe, it MUST respond with a stream error of type PROTOCOL_ERROR .
-
-
-
- The server SHOULD send PUSH_PROMISE ( )
- frames prior to sending any frames that reference the promised responses. This avoids a
- race where clients issue requests prior to receiving any PUSH_PROMISE
- frames.
-
-
- For example, if the server receives a request for a document containing embedded links
- to multiple image files, and the server chooses to push those additional images to the
- client, sending push promises before the DATA frames that contain the
- image links ensures that the client is able to see the promises before discovering
- embedded links. Similarly, if the server pushes responses referenced by the header block
- (for instance, in Link header fields), sending the push promises before sending the
- header block ensures that clients do not request them.
-
-
-
- PUSH_PROMISE frames MUST NOT be sent by the client.
-
-
- PUSH_PROMISE frames can be sent by the server in response to any
- client-initiated stream, but the stream MUST be in either the "open" or "half closed
- (remote)" state with respect to the server. PUSH_PROMISE frames are
- interspersed with the frames that comprise a response, though they cannot be
- interspersed with HEADERS and CONTINUATION frames that
- comprise a single header block.
-
-
- Sending a PUSH_PROMISE frame creates a new stream and puts the stream
- into the “reserved (local)” state for the server and the “reserved (remote)” state for
- the client.
-
-
-
-
-
- After sending the PUSH_PROMISE frame, the server can begin delivering the
- pushed response as a response on a server-initiated
- stream that uses the promised stream identifier. The server uses this stream to
- transmit an HTTP response, using the same sequence of frames as defined in . This stream becomes "half closed"
- to the client after the initial HEADERS frame is sent.
-
-
-
- Once a client receives a PUSH_PROMISE frame and chooses to accept the
- pushed response, the client SHOULD NOT issue any requests for the promised response
- until after the promised stream has closed.
-
-
-
- If the client determines, for any reason, that it does not wish to receive the pushed
- response from the server, or if the server takes too long to begin sending the promised
- response, the client can send an RST_STREAM frame, using either the
- CANCEL or REFUSED_STREAM codes, and referencing the pushed
- stream's identifier.
-
-
- A client can use the SETTINGS_MAX_CONCURRENT_STREAMS setting to limit the
- number of responses that can be concurrently pushed by a server. Advertising a
- SETTINGS_MAX_CONCURRENT_STREAMS value of zero disables server push by
- preventing the server from creating the necessary streams. This does not prohibit a
- server from sending PUSH_PROMISE frames; clients need to reset any
- promised streams that are not wanted.
-
-
-
- Clients receiving a pushed response MUST validate that either the server is
- authoritative (see ), or the proxy that provided the pushed
- response is configured for the corresponding request. For example, a server that offers
- a certificate for only the example.com DNS-ID or Common Name
- is not permitted to push a response for https://www.example.org/doc .
-
-
- The response for a PUSH_PROMISE stream begins with a
- HEADERS frame, which immediately puts the stream into the “half closed
- (remote)” state for the server and “half closed (local)” state for the client, and ends
- with a frame bearing END_STREAM, which places the stream in the "closed" state.
-
-
- The client never sends a frame with the END_STREAM flag for a server push.
-
-
-
-
-
-
-
-
-
- In HTTP/1.x, the pseudo-method CONNECT ( ) is used to convert an HTTP connection into a tunnel to a remote host.
- CONNECT is primarily used with HTTP proxies to establish a TLS session with an origin
- server for the purposes of interacting with https resources.
-
-
- In HTTP/2, the CONNECT method is used to establish a tunnel over a single HTTP/2 stream to
- a remote host, for similar purposes. The HTTP header field mapping works as defined in
- Request Header Fields , with a few
- differences. Specifically:
-
-
- The :method header field is set to CONNECT .
-
-
- The :scheme and :path header
- fields MUST be omitted.
-
-
- The :authority header field contains the host and port to
- connect to (equivalent to the authority-form of the request-target of CONNECT
- requests, see ).
-
-
-
-
- A proxy that supports CONNECT establishes a TCP connection to
- the server identified in the :authority header field. Once
- this connection is successfully established, the proxy sends a HEADERS
- frame containing a 2xx series status code to the client, as defined in .
-
-
- After the initial HEADERS frame sent by each peer, all subsequent
- DATA frames correspond to data sent on the TCP connection. The payload of
- any DATA frames sent by the client is transmitted by the proxy to the TCP
- server; data received from the TCP server is assembled into DATA frames by
- the proxy. Frame types other than DATA or stream management frames
- (RST_STREAM , WINDOW_UPDATE , and PRIORITY )
- MUST NOT be sent on a connected stream, and MUST be treated as a stream error if received.
-
-
- The TCP connection can be closed by either peer. The END_STREAM flag on a
- DATA frame is treated as being equivalent to the TCP FIN bit. A client is
- expected to send a DATA frame with the END_STREAM flag set after receiving
- a frame bearing the END_STREAM flag. A proxy that receives a DATA frame
- with the END_STREAM flag set sends the attached data with the FIN bit set on the last TCP
- segment. A proxy that receives a TCP segment with the FIN bit set sends a
- DATA frame with the END_STREAM flag set. Note that the final TCP segment
- or DATA frame could be empty.
-
-
- A TCP connection error is signaled with RST_STREAM . A proxy treats any
- error in the TCP connection, which includes receiving a TCP segment with the RST bit set,
- as a stream error of type
- CONNECT_ERROR . Correspondingly, a proxy MUST send a TCP segment with the
- RST bit set if it detects an error with the stream or the HTTP/2 connection.
-
-
-
-
-
-
- This section outlines attributes of the HTTP protocol that improve interoperability, reduce
- exposure to known security vulnerabilities, or reduce the potential for implementation
- variation.
-
-
-
-
- HTTP/2 connections are persistent. For best performance, it is expected clients will not
- close connections until it is determined that no further communication with a server is
- necessary (for example, when a user navigates away from a particular web page), or until
- the server closes the connection.
-
-
- Clients SHOULD NOT open more than one HTTP/2 connection to a given host and port pair,
- where host is derived from a URI, a selected alternative
- service , or a configured proxy.
-
-
- A client can create additional connections as replacements, either to replace connections
- that are near to exhausting the available stream
- identifier space , to refresh the keying material for a TLS connection, or to
- replace connections that have encountered errors .
-
-
- A client MAY open multiple connections to the same IP address and TCP port using different
- Server Name Indication values or to provide different TLS
- client certificates, but SHOULD avoid creating multiple connections with the same
- configuration.
-
-
- Servers are encouraged to maintain open connections for as long as possible, but are
- permitted to terminate idle connections if necessary. When either endpoint chooses to
- close the transport-layer TCP connection, the terminating endpoint SHOULD first send a
- GOAWAY ( ) frame so that both endpoints can reliably
- determine whether previously sent frames have been processed and gracefully complete or
- terminate any necessary remaining tasks.
-
-
-
-
- Connections that are made to an origin server, either directly or through a tunnel
- created using the CONNECT method, MAY be reused for
- requests with multiple different URI authority components. A connection can be reused
- as long as the origin server is authoritative . For
- http resources, this depends on the host having resolved to
- the same IP address.
-
-
- For https resources, connection reuse additionally depends
- on having a certificate that is valid for the host in the URI. An origin server might
- offer a certificate with multiple subjectAltName attributes,
- or names with wildcards, one of which is valid for the authority in the URI. For
- example, a certificate with a subjectAltName of *.example.com might permit the use of the same connection for
- requests to URIs starting with https://a.example.com/ and
- https://b.example.com/ .
-
-
- In some deployments, reusing a connection for multiple origins can result in requests
- being directed to the wrong origin server. For example, TLS termination might be
- performed by a middlebox that uses the TLS Server Name Indication
- (SNI) extension to select an origin server. This means that it is possible
- for clients to send confidential information to servers that might not be the intended
- target for the request, even though the server is otherwise authoritative.
-
-
- A server that does not wish clients to reuse connections can indicate that it is not
- authoritative for a request by sending a 421 (Misdirected Request) status code in response
- to the request (see ).
-
-
- A client that is configured to use a proxy over HTTP/2 directs requests to that proxy
- through a single connection. That is, all requests sent via a proxy reuse the
- connection to the proxy.
-
-
-
-
-
- The 421 (Misdirected Request) status code indicates that the request was directed at a
- server that is not able to produce a response. This can be sent by a server that is not
- configured to produce responses for the combination of scheme and authority that are
- included in the request URI.
-
-
- Clients receiving a 421 (Misdirected Request) response from a server MAY retry the
- request - whether the request method is idempotent or not - over a different connection.
- This is possible if a connection is reused ( ) or if an alternative
- service is selected ( ).
-
-
- This status code MUST NOT be generated by proxies.
-
-
- A 421 response is cacheable by default; i.e., unless otherwise indicated by the method
- definition or explicit cache controls (see ).
-
-
-
-
-
-
- Implementations of HTTP/2 MUST support TLS 1.2 for HTTP/2 over
- TLS. The general TLS usage guidance in SHOULD be followed, with
- some additional restrictions that are specific to HTTP/2.
-
-
-
- An implementation of HTTP/2 over TLS MUST use TLS 1.2 or higher with the restrictions on
- feature set and cipher suite described in this section. Due to implementation
- limitations, it might not be possible to fail TLS negotiation. An endpoint MUST
- immediately terminate an HTTP/2 connection that does not meet these minimum requirements
- with a connection error of type
- INADEQUATE_SECURITY .
-
-
-
-
- The TLS implementation MUST support the Server Name Indication
- (SNI) extension to TLS. HTTP/2 clients MUST indicate the target domain name when
- negotiating TLS.
-
-
- The TLS implementation MUST disable compression. TLS compression can lead to the
- exposure of information that would not otherwise be revealed .
- Generic compression is unnecessary since HTTP/2 provides compression features that are
- more aware of context and therefore likely to be more appropriate for use for
- performance, security or other reasons.
-
-
- The TLS implementation MUST disable renegotiation. An endpoint MUST treat a TLS
- renegotiation as a connection error of type
- PROTOCOL_ERROR . Note that disabling renegotiation can result in
- long-lived connections becoming unusable due to limits on the number of messages the
- underlying cipher suite can encipher.
-
-
- A client MAY use renegotiation to provide confidentiality protection for client
- credentials offered in the handshake, but any renegotiation MUST occur prior to sending
- the connection preface. A server SHOULD request a client certificate if it sees a
- renegotiation request immediately after establishing a connection.
-
-
- This effectively prevents the use of renegotiation in response to a request for a
- specific protected resource. A future specification might provide a way to support this
- use case.
-
-
-
-
-
- The set of TLS cipher suites that are permitted in HTTP/2 is restricted. HTTP/2 MUST
- only be used with cipher suites that have ephemeral key exchange, such as the ephemeral Diffie-Hellman (DHE) or the elliptic curve variant (ECDHE) . Ephemeral key exchange MUST
- have a minimum size of 2048 bits for DHE or security level of 128 bits for ECDHE.
- Clients MUST accept DHE sizes of up to 4096 bits. HTTP MUST NOT be used with cipher
- suites that use stream or block ciphers. Authenticated Encryption with Additional Data
- (AEAD) modes, such as the Galois/Counter Mode (GCM) mode for
- AES are acceptable.
-
-
- The effect of these restrictions is that TLS 1.2 implementations could have
- non-intersecting sets of available cipher suites, since these prevent the use of the
- cipher suite that TLS 1.2 makes mandatory. To avoid this problem, implementations of
- HTTP/2 that use TLS 1.2 MUST support TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 with P256 .
-
-
- Clients MAY advertise support of cipher suites that are prohibited by the above
- restrictions in order to allow for connection to servers that do not support HTTP/2.
- This enables a fallback to protocols without these constraints without the additional
- latency imposed by using a separate connection for fallback.
-
-
-
-
-
-
-
-
- HTTP/2 relies on the HTTP/1.1 definition of authority for determining whether a server is
- authoritative in providing a given response, see . This relies on local name resolution for the "http"
- URI scheme, and the authenticated server identity for the "https" scheme (see ).
-
-
-
-
-
- In a cross-protocol attack, an attacker causes a client to initiate a transaction in one
- protocol toward a server that understands a different protocol. An attacker might be able
- to cause the transaction to appear as a valid transaction in the second protocol. In
- combination with the capabilities of the web context, this can be used to interact with
- poorly protected servers in private networks.
-
-
- Completing a TLS handshake with an ALPN identifier for HTTP/2 can be considered sufficient
- protection against cross protocol attacks. ALPN provides a positive indication that a
- server is willing to proceed with HTTP/2, which prevents attacks on other TLS-based
- protocols.
-
-
- The encryption in TLS makes it difficult for attackers to control the data which could be
- used in a cross-protocol attack on a cleartext protocol.
-
-
- The cleartext version of HTTP/2 has minimal protection against cross-protocol attacks.
- The connection preface contains a string that is
- designed to confuse HTTP/1.1 servers, but no special protection is offered for other
- protocols. A server that is willing to ignore parts of an HTTP/1.1 request containing an
- Upgrade header field in addition to the client connection preface could be exposed to a
- cross-protocol attack.
-
-
-
-
-
- HTTP/2 header field names and values are encoded as sequences of octets with a length
- prefix. This enables HTTP/2 to carry any string of octets as the name or value of a
- header field. An intermediary that translates HTTP/2 requests or responses into HTTP/1.1
- directly could permit the creation of corrupted HTTP/1.1 messages. An attacker might
- exploit this behavior to cause the intermediary to create HTTP/1.1 messages with illegal
- header fields, extra header fields, or even new messages that are entirely falsified.
-
-
- Header field names or values that contain characters not permitted by HTTP/1.1, including
- carriage return (ASCII 0xd) or line feed (ASCII 0xa) MUST NOT be translated verbatim by an
- intermediary, as stipulated in .
-
-
- Translation from HTTP/1.x to HTTP/2 does not produce the same opportunity to an attacker.
- Intermediaries that perform translation to HTTP/2 MUST remove any instances of the obs-fold production from header field values.
-
-
-
-
-
- Pushed responses do not have an explicit request from the client; the request
- is provided by the server in the PUSH_PROMISE frame.
-
-
- Caching responses that are pushed is possible based on the guidance provided by the origin
- server in the Cache-Control header field. However, this can cause issues if a single
- server hosts more than one tenant. For example, a server might offer multiple users each
- a small portion of its URI space.
-
-
- Where multiple tenants share space on the same server, that server MUST ensure that
- tenants are not able to push representations of resources that they do not have authority
- over. Failure to enforce this would allow a tenant to provide a representation that would
- be served out of cache, overriding the actual representation that the authoritative tenant
- provides.
-
-
- Pushed responses for which an origin server is not authoritative (see
- ) are never cached or used.
-
-
-
-
-
- An HTTP/2 connection can demand a greater commitment of resources to operate than an
- HTTP/1.1 connection. The use of header compression and flow control depend on a
- commitment of resources for storing a greater amount of state. Settings for these
- features ensure that memory commitments for these features are strictly bounded.
-
-
- The number of PUSH_PROMISE frames is not constrained in the same fashion.
- A client that accepts server push SHOULD limit the number of streams it allows to be in
- the "reserved (remote)" state. An excessive number of server push streams can be treated as
- a stream error of type
- ENHANCE_YOUR_CALM .
-
-
- Processing capacity cannot be guarded as effectively as state capacity.
-
-
- The SETTINGS frame can be abused to cause a peer to expend additional
- processing time. This might be done by pointlessly changing SETTINGS parameters, setting
- multiple undefined parameters, or changing the same setting multiple times in the same
- frame. WINDOW_UPDATE or PRIORITY frames can be abused to
- cause an unnecessary waste of resources.
-
-
- Large numbers of small or empty frames can be abused to cause a peer to expend time
- processing frame headers. Note however that some uses are entirely legitimate, such as
- the sending of an empty DATA frame to end a stream.
-
-
- Header compression also offers some opportunities to waste processing resources; see for more details on potential abuses.
-
-
- Limits in SETTINGS parameters cannot be reduced instantaneously, which
- leaves an endpoint exposed to behavior from a peer that could exceed the new limits. In
- particular, immediately after establishing a connection, limits set by a server are not
- known to clients and could be exceeded without being an obvious protocol violation.
-
-
- All these features - i.e., SETTINGS changes, small frames, header
- compression - have legitimate uses. These features become a burden only when they are
- used unnecessarily or to excess.
-
-
- An endpoint that doesn't monitor this behavior exposes itself to a risk of denial of
- service attack. Implementations SHOULD track the use of these features and set limits on
- their use. An endpoint MAY treat activity that is suspicious as a connection error of type
- ENHANCE_YOUR_CALM .
-
-
-
-
- A large header block can cause an implementation to
- commit a large amount of state. Header fields that are critical for routing can appear
- toward the end of a header block, which prevents streaming of header fields to their
- ultimate destination. For this and other reasons, such as ensuring cache correctness,
- an endpoint might need to buffer the entire header block. Since there is no
- hard limit to the size of a header block, some endpoints could be forced to commit a large
- amount of available memory for header fields.
-
-
- An endpoint can use the SETTINGS_MAX_HEADER_LIST_SIZE to advise peers of
- limits that might apply on the size of header blocks. This setting is only advisory, so
- endpoints MAY choose to send header blocks that exceed this limit and risk having the
- request or response being treated as malformed. This setting is specific to a connection,
- so any request or response could encounter a hop with a lower, unknown limit. An
- intermediary can attempt to avoid this problem by passing on values presented by
- different peers, but they are not obligated to do so.
-
-
- A server that receives a larger header block than it is willing to handle can send an
- HTTP 431 (Request Header Fields Too Large) status code . A
- client can discard responses that it cannot process. The header block MUST be processed
- to ensure a consistent connection state, unless the connection is closed.
-
-
-
-
-
-
- HTTP/2 enables greater use of compression for both header fields ( ) and entity bodies. Compression can allow an attacker to recover
- secret data when it is compressed in the same context as data under attacker control.
-
-
- There are demonstrable attacks on compression that exploit the characteristics of the web
- (e.g., ). The attacker induces multiple requests containing
- varying plaintext, observing the length of the resulting ciphertext in each, which
- reveals a shorter length when a guess about the secret is correct.
-
-
- Implementations communicating on a secure channel MUST NOT compress content that includes
- both confidential and attacker-controlled data unless separate compression dictionaries
- are used for each source of data. Compression MUST NOT be used if the source of data
- cannot be reliably determined. Generic stream compression, such as that provided by TLS
- MUST NOT be used with HTTP/2 ( ).
-
-
- Further considerations regarding the compression of header fields are described in .
-
-
-
-
-
- Padding within HTTP/2 is not intended as a replacement for general purpose padding, such
- as might be provided by TLS . Redundant padding could even be
- counterproductive. Correct application can depend on having specific knowledge of the
- data that is being padded.
-
-
- To mitigate attacks that rely on compression, disabling or limiting compression might be
- preferable to padding as a countermeasure.
-
-
- Padding can be used to obscure the exact size of frame content, and is provided to
- mitigate specific attacks within HTTP. For example, attacks where compressed content
- includes both attacker-controlled plaintext and secret data (see for example, ).
-
-
- Use of padding can result in less protection than might seem immediately obvious. At
- best, padding only makes it more difficult for an attacker to infer length information by
- increasing the number of frames an attacker has to observe. Incorrectly implemented
- padding schemes can be easily defeated. In particular, randomized padding with a
- predictable distribution provides very little protection; similarly, padding payloads to a
- fixed size exposes information as payload sizes cross the fixed size boundary, which could
- be possible if an attacker can control plaintext.
-
-
- Intermediaries SHOULD retain padding for DATA frames, but MAY drop padding
- for HEADERS and PUSH_PROMISE frames. A valid reason for an
- intermediary to change the amount of padding of frames is to improve the protections that
- padding provides.
-
-
-
-
-
- Several characteristics of HTTP/2 provide an observer an opportunity to correlate actions
- of a single client or server over time. This includes the value of settings, the manner
- in which flow control windows are managed, the way priorities are allocated to streams,
- timing of reactions to stimulus, and handling of any optional features.
-
-
- As far as this creates observable differences in behavior, they could be used as a basis
- for fingerprinting a specific client, as defined in .
-
-
-
-
-
-
- A string for identifying HTTP/2 is entered into the "Application Layer Protocol Negotiation
- (ALPN) Protocol IDs" registry established in .
-
-
- This document establishes a registry for frame types, settings, and error codes. These new
- registries are entered into a new "Hypertext Transfer Protocol (HTTP) 2 Parameters" section.
-
-
- This document registers the HTTP2-Settings header field for
- use in HTTP; and the 421 (Misdirected Request) status code.
-
-
- This document registers the PRI method for use in HTTP, to avoid
- collisions with the connection preface .
-
-
-
-
- This document creates two registrations for the identification of HTTP/2 in the
- "Application Layer Protocol Negotiation (ALPN) Protocol IDs" registry established in .
-
-
- The "h2" string identifies HTTP/2 when used over TLS:
-
- HTTP/2 over TLS
- 0x68 0x32 ("h2")
- This document
-
-
-
- The "h2c" string identifies HTTP/2 when used over cleartext TCP:
-
- HTTP/2 over TCP
- 0x68 0x32 0x63 ("h2c")
- This document
-
-
-
-
-
-
- This document establishes a registry for HTTP/2 frame type codes. The "HTTP/2 Frame
- Type" registry manages an 8-bit space. The "HTTP/2 Frame Type" registry operates under
- either of the "IETF Review" or "IESG Approval" policies for
- values between 0x00 and 0xef, with values between 0xf0 and 0xff being reserved for
- experimental use.
-
-
- New entries in this registry require the following information:
-
-
- A name or label for the frame type.
-
-
- The 8-bit code assigned to the frame type.
-
-
- A reference to a specification that includes a description of the frame layout,
- its semantics and flags that the frame type uses, including any parts of the frame
- that are conditionally present based on the value of flags.
-
-
-
-
- The entries in the following table are registered by this document.
-
-
- Frame Type
- Code
- Section
- DATA 0x0
- HEADERS 0x1
- PRIORITY 0x2
- RST_STREAM 0x3
- SETTINGS 0x4
- PUSH_PROMISE 0x5
- PING 0x6
- GOAWAY 0x7
- WINDOW_UPDATE 0x8
- CONTINUATION 0x9
-
-
-
-
-
- This document establishes a registry for HTTP/2 settings. The "HTTP/2 Settings" registry
- manages a 16-bit space. The "HTTP/2 Settings" registry operates under the "Expert Review" policy for values in the range from 0x0000 to
- 0xefff, with values between 0xf000 and 0xffff being reserved for experimental use.
-
-
- New registrations are advised to provide the following information:
-
-
- A symbolic name for the setting. Specifying a setting name is optional.
-
-
- The 16-bit code assigned to the setting.
-
-
- An initial value for the setting.
-
-
- An optional reference to a specification that describes the use of the setting.
-
-
-
-
- An initial set of setting registrations can be found in .
-
-
- Name
- Code
- Initial Value
- Specification
- HEADER_TABLE_SIZE
- 0x1 4096
- ENABLE_PUSH
- 0x2 1
- MAX_CONCURRENT_STREAMS
- 0x3 (infinite)
- INITIAL_WINDOW_SIZE
- 0x4 65535
- MAX_FRAME_SIZE
- 0x5 16384
- MAX_HEADER_LIST_SIZE
- 0x6 (infinite)
-
-
-
-
-
-
- This document establishes a registry for HTTP/2 error codes. The "HTTP/2 Error Code"
- registry manages a 32-bit space. The "HTTP/2 Error Code" registry operates under the
- "Expert Review" policy .
-
-
- Registrations for error codes are required to include a description of the error code. An
- expert reviewer is advised to examine new registrations for possible duplication with
- existing error codes. Use of existing registrations is to be encouraged, but not
- mandated.
-
-
- New registrations are advised to provide the following information:
-
-
- A name for the error code. Specifying an error code name is optional.
-
-
- The 32-bit error code value.
-
-
- A brief description of the error code semantics, longer if no detailed specification
- is provided.
-
-
- An optional reference for a specification that defines the error code.
-
-
-
-
- The entries in the following table are registered by this document.
-
-
- Name
- Code
- Description
- Specification
- NO_ERROR 0x0
- Graceful shutdown
-
- PROTOCOL_ERROR 0x1
- Protocol error detected
-
- INTERNAL_ERROR 0x2
- Implementation fault
-
- FLOW_CONTROL_ERROR 0x3
- Flow control limits exceeded
-
- SETTINGS_TIMEOUT 0x4
- Settings not acknowledged
-
- STREAM_CLOSED 0x5
- Frame received for closed stream
-
- FRAME_SIZE_ERROR 0x6
- Frame size incorrect
-
- REFUSED_STREAM 0x7
- Stream not processed
-
- CANCEL 0x8
- Stream cancelled
-
- COMPRESSION_ERROR 0x9
- Compression state not updated
-
- CONNECT_ERROR 0xa
- TCP connection error for CONNECT method
-
- ENHANCE_YOUR_CALM 0xb
- Processing capacity exceeded
-
- INADEQUATE_SECURITY 0xc
- Negotiated TLS parameters not acceptable
-
-
-
-
-
-
-
- This section registers the HTTP2-Settings header field in the
- Permanent Message Header Field Registry .
-
-
- HTTP2-Settings
-
-
- http
-
-
- standard
-
-
- IETF
-
-
- of this document
-
-
- This header field is only used by an HTTP/2 client for Upgrade-based negotiation.
-
-
-
-
-
-
-
- This section registers the PRI method in the HTTP Method
- Registry ( ).
-
-
- PRI
-
-
- No
-
-
- No
-
-
- of this document
-
-
- This method is never used by an actual client. This method will appear to be used
- when an HTTP/1.1 server or intermediary attempts to parse an HTTP/2 connection
- preface.
-
-
-
-
-
-
-
- This document registers the 421 (Misdirected Request) HTTP Status code in the Hypertext
- Transfer Protocol (HTTP) Status Code Registry ( ).
-
-
-
-
- 421
-
-
- Misdirected Request
-
-
- of this document
-
-
-
-
-
-
-
-
-
- This document includes substantial input from the following individuals:
-
-
- Adam Langley, Wan-Teh Chang, Jim Morrison, Mark Nottingham, Alyssa Wilk, Costin
- Manolache, William Chan, Vitaliy Lvin, Joe Chan, Adam Barth, Ryan Hamilton, Gavin
- Peters, Kent Alstad, Kevin Lindsay, Paul Amer, Fan Yang, Jonathan Leighton (SPDY
- contributors).
-
-
- Gabriel Montenegro and Willy Tarreau (Upgrade mechanism).
-
-
- William Chan, Salvatore Loreto, Osama Mazahir, Gabriel Montenegro, Jitu Padhye, Roberto
- Peon, Rob Trace (Flow control).
-
-
- Mike Bishop (Extensibility).
-
-
- Mark Nottingham, Julian Reschke, James Snell, Jeff Pinner, Mike Bishop, Herve Ruellan
- (Substantial editorial contributions).
-
-
- Kari Hurtta, Tatsuhiro Tsujikawa, Greg Wilkins, Poul-Henning Kamp.
-
-
- Alexey Melnikov was an editor of this document during 2013.
-
-
- A substantial proportion of Martin's contribution was supported by Microsoft during his
- employment there.
-
-
-
-
-
-
-
-
-
-
- HPACK - Header Compression for HTTP/2
-
-
-
-
-
-
-
-
-
-
-
- Transmission Control Protocol
-
-
- University of Southern California (USC)/Information Sciences
- Institute
-
-
-
-
-
-
-
-
-
-
- Key words for use in RFCs to Indicate Requirement Levels
-
-
- Harvard University
- sob@harvard.edu
-
-
-
-
-
-
-
-
-
-
- HTTP Over TLS
-
-
-
-
-
-
-
-
-
- Uniform Resource Identifier (URI): Generic
- Syntax
-
-
-
-
-
-
-
-
-
-
-
- The Base16, Base32, and Base64 Data Encodings
-
-
-
-
-
-
-
-
- Guidelines for Writing an IANA Considerations Section in RFCs
-
-
-
-
-
-
-
-
-
-
- Augmented BNF for Syntax Specifications: ABNF
-
-
-
-
-
-
-
-
-
-
- The Transport Layer Security (TLS) Protocol Version 1.2
-
-
-
-
-
-
-
-
-
-
- Transport Layer Security (TLS) Extensions: Extension Definitions
-
-
-
-
-
-
-
-
-
- Transport Layer Security (TLS) Application-Layer Protocol Negotiation Extension
-
-
-
-
-
-
-
-
-
-
-
-
- TLS Elliptic Curve Cipher Suites with SHA-256/384 and AES Galois
- Counter Mode (GCM)
-
-
-
-
-
-
-
-
-
-
- Digital Signature Standard (DSS)
-
- NIST
-
-
-
-
-
-
-
-
- Hypertext Transfer Protocol (HTTP/1.1): Message Syntax and Routing
-
- Adobe Systems Incorporated
- fielding@gbiv.com
-
-
- greenbytes GmbH
- julian.reschke@greenbytes.de
-
-
-
-
-
-
-
-
-
- Hypertext Transfer Protocol (HTTP/1.1): Semantics and Content
-
- Adobe Systems Incorporated
- fielding@gbiv.com
-
-
- greenbytes GmbH
- julian.reschke@greenbytes.de
-
-
-
-
-
-
-
-
- Hypertext Transfer Protocol (HTTP/1.1): Conditional Requests
-
- Adobe Systems Incorporated
- fielding@gbiv.com
-
-
- greenbytes GmbH
- julian.reschke@greenbytes.de
-
-
-
-
-
-
-
- Hypertext Transfer Protocol (HTTP/1.1): Range Requests
-
- Adobe Systems Incorporated
- fielding@gbiv.com
-
-
- World Wide Web Consortium
- ylafon@w3.org
-
-
- greenbytes GmbH
- julian.reschke@greenbytes.de
-
-
-
-
-
-
-
- Hypertext Transfer Protocol (HTTP/1.1): Caching
-
- Adobe Systems Incorporated
- fielding@gbiv.com
-
-
- Akamai
- mnot@mnot.net
-
-
- greenbytes GmbH
- julian.reschke@greenbytes.de
-
-
-
-
-
-
-
-
- Hypertext Transfer Protocol (HTTP/1.1): Authentication
-
- Adobe Systems Incorporated
- fielding@gbiv.com
-
-
- greenbytes GmbH
- julian.reschke@greenbytes.de
-
-
-
-
-
-
-
-
-
- HTTP State Management Mechanism
-
-
-
-
-
-
-
-
-
-
-
- TCP Extensions for High Performance
-
-
-
-
-
-
-
-
-
-
-
- Transport Layer Security Protocol Compression Methods
-
-
-
-
-
-
-
-
- Additional HTTP Status Codes
-
-
-
-
-
-
-
-
-
-
- Elliptic Curve Cryptography (ECC) Cipher Suites for Transport Layer Security (TLS)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- AES Galois Counter Mode (GCM) Cipher Suites for TLS
-
-
-
-
-
-
-
-
-
-
-
- HTML5
-
-
-
-
-
-
-
-
-
-
- Latest version available at
- .
-
-
-
-
-
-
- Talking to Yourself for Fun and Profit
-
-
-
-
-
-
-
-
-
-
-
-
-
- BREACH: Reviving the CRIME Attack
-
-
-
-
-
-
-
-
-
-
- Registration Procedures for Message Header Fields
-
- Nine by Nine
- GK-IETF@ninebynine.org
-
-
- BEA Systems
- mnot@pobox.com
-
-
- HP Labs
- JeffMogul@acm.org
-
-
-
-
-
-
-
-
-
- Recommendations for Secure Use of TLS and DTLS
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- HTTP Alternative Services
-
-
- Akamai
-
-
- Mozilla
-
-
- greenbytes
-
-
-
-
-
-
-
-
-
-
- This section is to be removed by RFC Editor before publication.
-
-
-
-
- Renamed Not Authoritative status code to Misdirected Request.
-
-
-
-
-
- Pseudo-header fields are now required to appear strictly before regular ones.
-
-
- Restored 1xx series status codes, except 101.
-
-
- Changed frame length field 24-bits. Expanded frame header to 9 octets. Added a setting
- to limit the damage.
-
-
- Added a setting to advise peers of header set size limits.
-
-
- Removed segments.
-
-
- Made non-semantic-bearing HEADERS frames illegal in the HTTP mapping.
-
-
-
-
-
- Restored extensibility options.
-
-
- Restricting TLS cipher suites to AEAD only.
-
-
- Removing Content-Encoding requirements.
-
-
- Permitting the use of PRIORITY after stream close.
-
-
- Removed ALTSVC frame.
-
-
- Removed BLOCKED frame.
-
-
- Reducing the maximum padding size to 256 octets; removing padding from
- CONTINUATION frames.
-
-
- Removed per-frame GZIP compression.
-
-
-
-
-
- Added BLOCKED frame (at risk).
-
-
- Simplified priority scheme.
-
-
- Added DATA per-frame GZIP compression.
-
-
-
-
-
- Changed "connection header" to "connection preface" to avoid confusion.
-
-
- Added dependency-based stream prioritization.
-
-
- Added "h2c" identifier to distinguish between cleartext and secured HTTP/2.
-
-
- Adding missing padding to PUSH_PROMISE .
-
-
- Integrate ALTSVC frame and supporting text.
-
-
- Dropping requirement on "deflate" Content-Encoding.
-
-
- Improving security considerations around use of compression.
-
-
-
-
-
- Adding padding for data frames.
-
-
- Renumbering frame types, error codes, and settings.
-
-
- Adding INADEQUATE_SECURITY error code.
-
-
- Updating TLS usage requirements to 1.2; forbidding TLS compression.
-
-
- Removing extensibility for frames and settings.
-
-
- Changing setting identifier size.
-
-
- Removing the ability to disable flow control.
-
-
- Changing the protocol identification token to "h2".
-
-
- Changing the use of :authority to make it optional and to allow userinfo in non-HTTP
- cases.
-
-
- Allowing split on 0x0 for Cookie.
-
-
- Reserved PRI method in HTTP/1.1 to avoid possible future collisions.
-
-
-
-
-
- Added cookie crumbling for more efficient header compression.
-
-
- Added header field ordering with the value-concatenation mechanism.
-
-
-
-
-
- Marked draft for implementation.
-
-
-
-
-
- Adding definition for CONNECT method.
-
-
- Constraining the use of push to safe, cacheable methods with no request body.
-
-
- Changing from :host to :authority to remove any potential confusion.
-
-
- Adding setting for header compression table size.
-
-
- Adding settings acknowledgement.
-
-
- Removing unnecessary and potentially problematic flags from CONTINUATION.
-
-
- Added denial of service considerations.
-
-
-
-
- Marking the draft ready for implementation.
-
-
- Renumbering END_PUSH_PROMISE flag.
-
-
- Editorial clarifications and changes.
-
-
-
-
-
- Added CONTINUATION frame for HEADERS and PUSH_PROMISE.
-
-
- PUSH_PROMISE is no longer implicitly prohibited if SETTINGS_MAX_CONCURRENT_STREAMS is
- zero.
-
-
- Push expanded to allow all safe methods without a request body.
-
-
- Clarified the use of HTTP header fields in requests and responses. Prohibited HTTP/1.1
- hop-by-hop header fields.
-
-
- Requiring that intermediaries not forward requests with missing or illegal routing
- :-headers.
-
-
- Clarified requirements around handling different frames after stream close, stream reset
- and GOAWAY .
-
-
- Added more specific prohibitions for sending of different frame types in various stream
- states.
-
-
- Making the last received setting value the effective value.
-
-
- Clarified requirements on TLS version, extension and ciphers.
-
-
-
-
-
- Committed major restructuring atrocities.
-
-
- Added reference to first header compression draft.
-
-
- Added more formal description of frame lifecycle.
-
-
- Moved END_STREAM (renamed from FINAL) back to HEADERS /DATA .
-
-
- Removed HEADERS+PRIORITY, added optional priority to HEADERS frame.
-
-
- Added PRIORITY frame.
-
-
-
-
-
- Added continuations to frames carrying header blocks.
-
-
- Replaced use of "session" with "connection" to avoid confusion with other HTTP stateful
- concepts, like cookies.
-
-
- Removed "message".
-
-
- Switched to TLS ALPN from NPN.
-
-
- Editorial changes.
-
-
-
-
-
- Added IANA considerations section for frame types, error codes and settings.
-
-
- Removed data frame compression.
-
-
- Added PUSH_PROMISE .
-
-
- Added globally applicable flags to framing.
-
-
- Removed zlib-based header compression mechanism.
-
-
- Updated references.
-
-
- Clarified stream identifier reuse.
-
-
- Removed CREDENTIALS frame and associated mechanisms.
-
-
- Added advice against naive implementation of flow control.
-
-
- Added session header section.
-
-
- Restructured frame header. Removed distinction between data and control frames.
-
-
- Altered flow control properties to include session-level limits.
-
-
- Added note on cacheability of pushed resources and multiple tenant servers.
-
-
- Changed protocol label form based on discussions.
-
-
-
-
-
- Changed title throughout.
-
-
- Removed section on Incompatibilities with SPDY draft#2.
-
-
- Changed INTERNAL_ERROR on GOAWAY to have a value of 2 .
-
-
- Replaced abstract and introduction.
-
-
- Added section on starting HTTP/2.0, including upgrade mechanism.
-
-
- Removed unused references.
-
-
- Added flow control principles based on .
-
-
-
-
-
- Adopted as base for draft-ietf-httpbis-http2.
-
-
- Updated authors/editors list.
-
-
- Added status note.
-
-
-
-
-
-
-
diff --git a/http2/timer.go b/http2/timer.go
new file mode 100644
index 0000000000..0b1c17b812
--- /dev/null
+++ b/http2/timer.go
@@ -0,0 +1,20 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+package http2
+
+import "time"
+
+// A timer is a time.Timer, as an interface which can be replaced in tests.
+type timer = interface {
+ C() <-chan time.Time
+ Reset(d time.Duration) bool
+ Stop() bool
+}
+
+// timeTimer adapts a time.Timer to the timer interface.
+type timeTimer struct {
+ *time.Timer
+}
+
+func (t timeTimer) C() <-chan time.Time { return t.Timer.C }
diff --git a/http2/transport.go b/http2/transport.go
index 4515b22c4a..f26356b9cd 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -25,8 +25,6 @@ import (
"net/http"
"net/http/httptrace"
"net/textproto"
- "os"
- "sort"
"strconv"
"strings"
"sync"
@@ -36,6 +34,7 @@ import (
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
+ "golang.org/x/net/internal/httpcommon"
)
const (
@@ -147,6 +146,12 @@ type Transport struct {
// waiting for their turn.
StrictMaxConcurrentStreams bool
+ // IdleConnTimeout is the maximum amount of time an idle
+ // (keep-alive) connection will remain idle before closing
+ // itself.
+ // Zero means no limit.
+ IdleConnTimeout time.Duration
+
// ReadIdleTimeout is the timeout after which a health check using ping
// frame will be carried out if no frame is received on the connection.
// Note that a ping response will is considered a received frame, so if
@@ -178,41 +183,81 @@ type Transport struct {
connPoolOnce sync.Once
connPoolOrDef ClientConnPool // non-nil version of ConnPool
+
+ *transportTestHooks
}
-func (t *Transport) maxHeaderListSize() uint32 {
- if t.MaxHeaderListSize == 0 {
- return 10 << 20
+// Hook points used for testing.
+// Outside of tests, t.transportTestHooks is nil and these all have minimal implementations.
+// Inside tests, see the testSyncHooks function docs.
+
+type transportTestHooks struct {
+ newclientconn func(*ClientConn)
+ group synctestGroupInterface
+}
+
+func (t *Transport) markNewGoroutine() {
+ if t != nil && t.transportTestHooks != nil {
+ t.transportTestHooks.group.Join()
}
- if t.MaxHeaderListSize == 0xffffffff {
- return 0
+}
+
+func (t *Transport) now() time.Time {
+ if t != nil && t.transportTestHooks != nil {
+ return t.transportTestHooks.group.Now()
}
- return t.MaxHeaderListSize
+ return time.Now()
}
-func (t *Transport) maxFrameReadSize() uint32 {
- if t.MaxReadFrameSize == 0 {
- return 0 // use the default provided by the peer
+func (t *Transport) timeSince(when time.Time) time.Duration {
+ if t != nil && t.transportTestHooks != nil {
+ return t.now().Sub(when)
}
- if t.MaxReadFrameSize < minMaxFrameSize {
- return minMaxFrameSize
+ return time.Since(when)
+}
+
+// newTimer creates a new time.Timer, or a synthetic timer in tests.
+func (t *Transport) newTimer(d time.Duration) timer {
+ if t.transportTestHooks != nil {
+ return t.transportTestHooks.group.NewTimer(d)
}
- if t.MaxReadFrameSize > maxFrameSize {
- return maxFrameSize
+ return timeTimer{time.NewTimer(d)}
+}
+
+// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
+func (t *Transport) afterFunc(d time.Duration, f func()) timer {
+ if t.transportTestHooks != nil {
+ return t.transportTestHooks.group.AfterFunc(d, f)
}
- return t.MaxReadFrameSize
+ return timeTimer{time.AfterFunc(d, f)}
}
-func (t *Transport) disableCompression() bool {
- return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
+func (t *Transport) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
+ if t.transportTestHooks != nil {
+ return t.transportTestHooks.group.ContextWithTimeout(ctx, d)
+ }
+ return context.WithTimeout(ctx, d)
}
-func (t *Transport) pingTimeout() time.Duration {
- if t.PingTimeout == 0 {
- return 15 * time.Second
+func (t *Transport) maxHeaderListSize() uint32 {
+ n := int64(t.MaxHeaderListSize)
+ if t.t1 != nil && t.t1.MaxResponseHeaderBytes != 0 {
+ n = t.t1.MaxResponseHeaderBytes
+ if n > 0 {
+ n = adjustHTTP1MaxHeaderSize(n)
+ }
+ }
+ if n <= 0 {
+ return 10 << 20
+ }
+ if n >= 0xffffffff {
+ return 0
}
- return t.PingTimeout
+ return uint32(n)
+}
+func (t *Transport) disableCompression() bool {
+ return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}
// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
@@ -250,8 +295,8 @@ func configureTransports(t1 *http.Transport) (*Transport, error) {
if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
}
- upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
- addr := authorityAddr("https", authority)
+ upgradeFn := func(scheme, authority string, c net.Conn) http.RoundTripper {
+ addr := authorityAddr(scheme, authority)
if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
go c.Close()
return erringRoundTripper{err}
@@ -262,18 +307,37 @@ func configureTransports(t1 *http.Transport) (*Transport, error) {
// was unknown)
go c.Close()
}
+ if scheme == "http" {
+ return (*unencryptedTransport)(t2)
+ }
return t2
}
- if m := t1.TLSNextProto; len(m) == 0 {
- t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
- "h2": upgradeFn,
+ if t1.TLSNextProto == nil {
+ t1.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
+ }
+ t1.TLSNextProto[NextProtoTLS] = func(authority string, c *tls.Conn) http.RoundTripper {
+ return upgradeFn("https", authority, c)
+ }
+ // The "unencrypted_http2" TLSNextProto key is used to pass off non-TLS HTTP/2 conns.
+ t1.TLSNextProto[nextProtoUnencryptedHTTP2] = func(authority string, c *tls.Conn) http.RoundTripper {
+ nc, err := unencryptedNetConnFromTLSConn(c)
+ if err != nil {
+ go c.Close()
+ return erringRoundTripper{err}
}
- } else {
- m["h2"] = upgradeFn
+ return upgradeFn("http", authority, nc)
}
return t2, nil
}
+// unencryptedTransport is a Transport with a RoundTrip method that
+// always permits http:// URLs.
+type unencryptedTransport Transport
+
+func (t *unencryptedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ return (*Transport)(t).RoundTripOpt(req, RoundTripOpt{allowHTTP: true})
+}
+
func (t *Transport) connPool() ClientConnPool {
t.connPoolOnce.Do(t.initConnPool)
return t.connPoolOrDef
@@ -293,7 +357,7 @@ type ClientConn struct {
t *Transport
tconn net.Conn // usually *tls.Conn, except specialized impls
tlsState *tls.ConnectionState // nil only for specialized impls
- reused uint32 // whether conn is being reused; atomic
+ atomicReused uint32 // whether conn is being reused; atomic
singleUse bool // whether being used for a single http.Request
getConnCalled bool // used by clientConnPool
@@ -302,33 +366,57 @@ type ClientConn struct {
readerErr error // set before readerDone is closed
idleTimeout time.Duration // or 0 for never
- idleTimer *time.Timer
-
- mu sync.Mutex // guards following
- cond *sync.Cond // hold mu; broadcast on flow/closed changes
- flow outflow // our conn-level flow control quota (cs.outflow is per stream)
- inflow inflow // peer's conn-level flow control
- doNotReuse bool // whether conn is marked to not be reused for any future requests
- closing bool
- closed bool
- seenSettings bool // true if we've seen a settings frame, false otherwise
- wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
- goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
- goAwayDebug string // goAway frame's debug data, retained as a string
- streams map[uint32]*clientStream // client-initiated
- streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
- nextStreamID uint32
- pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
- pings map[[8]byte]chan struct{} // in flight ping data to notification channel
- br *bufio.Reader
- lastActive time.Time
- lastIdle time.Time // time last idle
+ idleTimer timer
+
+ mu sync.Mutex // guards following
+ cond *sync.Cond // hold mu; broadcast on flow/closed changes
+ flow outflow // our conn-level flow control quota (cs.outflow is per stream)
+ inflow inflow // peer's conn-level flow control
+ doNotReuse bool // whether conn is marked to not be reused for any future requests
+ closing bool
+ closed bool
+ closedOnIdle bool // true if conn was closed for idleness
+ seenSettings bool // true if we've seen a settings frame, false otherwise
+ seenSettingsChan chan struct{} // closed when seenSettings is true or frame reading fails
+ wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
+ goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
+ goAwayDebug string // goAway frame's debug data, retained as a string
+ streams map[uint32]*clientStream // client-initiated
+ streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
+ nextStreamID uint32
+ pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
+ pings map[[8]byte]chan struct{} // in flight ping data to notification channel
+ br *bufio.Reader
+ lastActive time.Time
+ lastIdle time.Time // time last idle
// Settings from peer: (also guarded by wmu)
- maxFrameSize uint32
- maxConcurrentStreams uint32
- peerMaxHeaderListSize uint64
- peerMaxHeaderTableSize uint32
- initialWindowSize uint32
+ maxFrameSize uint32
+ maxConcurrentStreams uint32
+ peerMaxHeaderListSize uint64
+ peerMaxHeaderTableSize uint32
+ initialWindowSize uint32
+ initialStreamRecvWindowSize int32
+ readIdleTimeout time.Duration
+ pingTimeout time.Duration
+ extendedConnectAllowed bool
+
+ // rstStreamPingsBlocked works around an unfortunate gRPC behavior.
+ // gRPC strictly limits the number of PING frames that it will receive.
+ // The default is two pings per two hours, but the limit resets every time
+ // the gRPC endpoint sends a HEADERS or DATA frame. See golang/go#70575.
+ //
+ // rstStreamPingsBlocked is set after receiving a response to a PING frame
+ // bundled with an RST_STREAM (see pendingResets below), and cleared after
+ // receiving a HEADERS or DATA frame.
+ rstStreamPingsBlocked bool
+
+ // pendingResets is the number of RST_STREAM frames we have sent to the peer,
+ // without confirming that the peer has received them. When we send a RST_STREAM,
+ // we bundle it with a PING frame, unless a PING is already in flight. We count
+ // the reset stream against the connection's concurrency limit until we get
+ // a PING response. This limits the number of requests we'll try to send to a
+ // completely unresponsive connection.
+ pendingResets int
// reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
// Write to reqHeaderMu to lock it, read from it to unlock.
@@ -386,12 +474,12 @@ type clientStream struct {
sentHeaders bool
// owned by clientConnReadLoop:
- firstByte bool // got the first response byte
- pastHeaders bool // got first MetaHeadersFrame (actual headers)
- pastTrailers bool // got optional second MetaHeadersFrame (trailers)
- num1xx uint8 // number of 1xx responses seen
- readClosed bool // peer sent an END_STREAM flag
- readAborted bool // read loop reset the stream
+ firstByte bool // got the first response byte
+ pastHeaders bool // got first MetaHeadersFrame (actual headers)
+ pastTrailers bool // got optional second MetaHeadersFrame (trailers)
+ readClosed bool // peer sent an END_STREAM flag
+ readAborted bool // read loop reset the stream
+ totalHeaderSize int64 // total size of 1xx headers seen
trailer http.Header // accumulated trailers
resTrailer *http.Header // client's Response.Trailer
@@ -446,12 +534,14 @@ func (cs *clientStream) closeReqBodyLocked() {
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
go func() {
+ cs.cc.t.markNewGoroutine()
cs.reqBody.Close()
close(reqBodyClosed)
}()
}
type stickyErrWriter struct {
+ group synctestGroupInterface
conn net.Conn
timeout time.Duration
err *error
@@ -461,22 +551,9 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
- for {
- if sew.timeout != 0 {
- sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
- }
- nn, err := sew.conn.Write(p[n:])
- n += nn
- if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
- // Keep extending the deadline so long as we're making progress.
- continue
- }
- if sew.timeout != 0 {
- sew.conn.SetWriteDeadline(time.Time{})
- }
- *sew.err = err
- return n, err
- }
+ n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p)
+ *sew.err = err
+ return n, err
}
// noCachedConnError is the concrete type of ErrNoCachedConn, which
@@ -507,6 +584,8 @@ type RoundTripOpt struct {
// no cached connection is available, RoundTripOpt
// will return ErrNoCachedConn.
OnlyCachedConn bool
+
+ allowHTTP bool // allow http:// URLs
}
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -537,18 +616,16 @@ func authorityAddr(scheme string, authority string) (addr string) {
return net.JoinHostPort(host, port)
}
-var retryBackoffHook func(time.Duration) *time.Timer
-
-func backoffNewTimer(d time.Duration) *time.Timer {
- if retryBackoffHook != nil {
- return retryBackoffHook(d)
- }
- return time.NewTimer(d)
-}
-
// RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
- if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
+ switch req.URL.Scheme {
+ case "https":
+ // Always okay.
+ case "http":
+ if !t.AllowHTTP && !opt.allowHTTP {
+ return nil, errors.New("http2: unencrypted HTTP/2 not enabled")
+ }
+ default:
return nil, errors.New("http2: unsupported scheme")
}
@@ -559,7 +636,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
return nil, err
}
- reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1)
+ reused := !atomic.CompareAndSwapUint32(&cc.atomicReused, 0, 1)
traceGotConn(req, cc, reused)
res, err := cc.RoundTrip(req)
if err != nil && retry <= 6 {
@@ -573,17 +650,33 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
- timer := backoffNewTimer(d)
+ tm := t.newTimer(d)
select {
- case <-timer.C:
+ case <-tm.C():
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
continue
case <-req.Context().Done():
- timer.Stop()
+ tm.Stop()
err = req.Context().Err()
}
}
}
+ if err == errClientConnNotEstablished {
+ // This ClientConn was created recently,
+ // this is the first request to use it,
+ // and the connection is closed and not usable.
+ //
+ // In this state, cc.idleTimer will remove the conn from the pool
+ // when it fires. Stop the timer and remove it here so future requests
+ // won't try to use this connection.
+ //
+ // If the timer has already fired and we're racing it, the redundant
+ // call to MarkDead is harmless.
+ if cc.idleTimer != nil {
+ cc.idleTimer.Stop()
+ }
+ t.connPool().MarkDead(cc)
+ }
if err != nil {
t.vlogf("RoundTrip failure: %v", err)
return nil, err
@@ -602,9 +695,10 @@ func (t *Transport) CloseIdleConnections() {
}
var (
- errClientConnClosed = errors.New("http2: client conn is closed")
- errClientConnUnusable = errors.New("http2: client conn not usable")
- errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
+ errClientConnClosed = errors.New("http2: client conn is closed")
+ errClientConnUnusable = errors.New("http2: client conn not usable")
+ errClientConnNotEstablished = errors.New("http2: client conn could not be established")
+ errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
)
// shouldRetryRequest is called by RoundTrip when a request fails to get
@@ -658,6 +752,9 @@ func canRetryError(err error) bool {
}
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
+ if t.transportTestHooks != nil {
+ return t.newClientConn(nil, singleUse)
+ }
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
@@ -717,43 +814,38 @@ func (t *Transport) expectContinueTimeout() time.Duration {
return t.t1.ExpectContinueTimeout
}
-func (t *Transport) maxDecoderHeaderTableSize() uint32 {
- if v := t.MaxDecoderHeaderTableSize; v > 0 {
- return v
- }
- return initialHeaderTableSize
-}
-
-func (t *Transport) maxEncoderHeaderTableSize() uint32 {
- if v := t.MaxEncoderHeaderTableSize; v > 0 {
- return v
- }
- return initialHeaderTableSize
-}
-
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives())
}
func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
+ conf := configFromTransport(t)
cc := &ClientConn{
- t: t,
- tconn: c,
- readerDone: make(chan struct{}),
- nextStreamID: 1,
- maxFrameSize: 16 << 10, // spec default
- initialWindowSize: 65535, // spec default
- maxConcurrentStreams: initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
- peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
- streams: make(map[uint32]*clientStream),
- singleUse: singleUse,
- wantSettingsAck: true,
- pings: make(map[[8]byte]chan struct{}),
- reqHeaderMu: make(chan struct{}, 1),
- }
- if d := t.idleConnTimeout(); d != 0 {
- cc.idleTimeout = d
- cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout)
+ t: t,
+ tconn: c,
+ readerDone: make(chan struct{}),
+ nextStreamID: 1,
+ maxFrameSize: 16 << 10, // spec default
+ initialWindowSize: 65535, // spec default
+ initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
+ maxConcurrentStreams: initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
+ peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
+ streams: make(map[uint32]*clientStream),
+ singleUse: singleUse,
+ seenSettingsChan: make(chan struct{}),
+ wantSettingsAck: true,
+ readIdleTimeout: conf.SendPingTimeout,
+ pingTimeout: conf.PingTimeout,
+ pings: make(map[[8]byte]chan struct{}),
+ reqHeaderMu: make(chan struct{}, 1),
+ lastActive: t.now(),
+ }
+ var group synctestGroupInterface
+ if t.transportTestHooks != nil {
+ t.markNewGoroutine()
+ t.transportTestHooks.newclientconn(cc)
+ c = cc.tconn
+ group = t.group
}
if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -765,30 +857,25 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
// TODO: adjust this writer size to account for frame size +
// MTU + crypto/tls record padding.
cc.bw = bufio.NewWriter(stickyErrWriter{
+ group: group,
conn: c,
- timeout: t.WriteByteTimeout,
+ timeout: conf.WriteByteTimeout,
err: &cc.werr,
})
cc.br = bufio.NewReader(c)
cc.fr = NewFramer(cc.bw, cc.br)
- if t.maxFrameReadSize() != 0 {
- cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize())
- }
+ cc.fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
if t.CountError != nil {
cc.fr.countError = t.CountError
}
- maxHeaderTableSize := t.maxDecoderHeaderTableSize()
+ maxHeaderTableSize := conf.MaxDecoderHeaderTableSize
cc.fr.ReadMetaHeaders = hpack.NewDecoder(maxHeaderTableSize, nil)
cc.fr.MaxHeaderListSize = t.maxHeaderListSize()
cc.henc = hpack.NewEncoder(&cc.hbuf)
- cc.henc.SetMaxDynamicTableSizeLimit(t.maxEncoderHeaderTableSize())
+ cc.henc.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)
cc.peerMaxHeaderTableSize = initialHeaderTableSize
- if t.AllowHTTP {
- cc.nextStreamID = 3
- }
-
if cs, ok := c.(connectionStater); ok {
state := cs.ConnectionState()
cc.tlsState = &state
@@ -796,11 +883,9 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
initialSettings := []Setting{
{ID: SettingEnablePush, Val: 0},
- {ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
- }
- if max := t.maxFrameReadSize(); max != 0 {
- initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: max})
+ {ID: SettingInitialWindowSize, Val: uint32(cc.initialStreamRecvWindowSize)},
}
+ initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: conf.MaxReadFrameSize})
if max := t.maxHeaderListSize(); max != 0 {
initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max})
}
@@ -810,23 +895,29 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
cc.bw.Write(clientPreface)
cc.fr.WriteSettings(initialSettings...)
- cc.fr.WriteWindowUpdate(0, transportDefaultConnFlow)
- cc.inflow.init(transportDefaultConnFlow + initialWindowSize)
+ cc.fr.WriteWindowUpdate(0, uint32(conf.MaxUploadBufferPerConnection))
+ cc.inflow.init(conf.MaxUploadBufferPerConnection + initialWindowSize)
cc.bw.Flush()
if cc.werr != nil {
cc.Close()
return nil, cc.werr
}
+ // Start the idle timer after the connection is fully initialized.
+ if d := t.idleConnTimeout(); d != 0 {
+ cc.idleTimeout = d
+ cc.idleTimer = t.afterFunc(d, cc.onIdleTimeout)
+ }
+
go cc.readLoop()
return cc, nil
}
func (cc *ClientConn) healthCheck() {
- pingTimeout := cc.t.pingTimeout()
+ pingTimeout := cc.pingTimeout
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
- ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+ ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
@@ -861,7 +952,20 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
}
last := f.LastStreamID
for streamID, cs := range cc.streams {
- if streamID > last {
+ if streamID <= last {
+ // The server's GOAWAY indicates that it received this stream.
+ // It will either finish processing it, or close the connection
+ // without doing so. Either way, leave the stream alone for now.
+ continue
+ }
+ if streamID == 1 && cc.goAway.ErrCode != ErrCodeNo {
+ // Don't retry the first stream on a connection if we get a non-NO error.
+ // If the server is sending an error on a new connection,
+ // retrying the request on a new one probably isn't going to work.
+ cs.abortStreamLocked(fmt.Errorf("http2: Transport received GOAWAY from server ErrCode:%v", cc.goAway.ErrCode))
+ } else {
+ // Aborting the stream with errClientConnGotGoAway indicates that
+ // the request should be retried on a new connection.
cs.abortStreamLocked(errClientConnGotGoAway)
}
}
@@ -938,7 +1042,7 @@ func (cc *ClientConn) State() ClientConnState {
return ClientConnState{
Closed: cc.closed,
Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil,
- StreamsActive: len(cc.streams),
+ StreamsActive: len(cc.streams) + cc.pendingResets,
StreamsReserved: cc.streamsReserved,
StreamsPending: cc.pendingRequests,
LastIdle: cc.lastIdle,
@@ -970,16 +1074,40 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
// writing it.
maxConcurrentOkay = true
} else {
- maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams)
+ // We can take a new request if the total of
+ // - active streams;
+ // - reservation slots for new streams; and
+ // - streams for which we have sent a RST_STREAM and a PING,
+ // but received no subsequent frame
+ // is less than the concurrency limit.
+ maxConcurrentOkay = cc.currentRequestCountLocked() < int(cc.maxConcurrentStreams)
}
st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
!cc.doNotReuse &&
int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
!cc.tooIdleLocked()
+
+ // If this connection has never been used for a request and is closed,
+ // then let it take a request (which will fail).
+ // If the conn was closed for idleness, we're racing the idle timer;
+ // don't try to use the conn. (Issue #70515.)
+ //
+ // This avoids a situation where an error early in a connection's lifetime
+ // goes unreported.
+ if cc.nextStreamID == 1 && cc.streamsReserved == 0 && cc.closed && !cc.closedOnIdle {
+ st.canTakeNewRequest = true
+ }
+
return
}
+// currentRequestCountLocked reports the number of concurrency slots currently in use,
+// including active streams, reserved slots, and reset streams waiting for acknowledgement.
+func (cc *ClientConn) currentRequestCountLocked() int {
+ return len(cc.streams) + cc.streamsReserved + cc.pendingResets
+}
+
func (cc *ClientConn) canTakeNewRequestLocked() bool {
st := cc.idleStateLocked()
return st.canTakeNewRequest
@@ -992,7 +1120,7 @@ func (cc *ClientConn) tooIdleLocked() bool {
// times are compared based on their wall time. We don't want
// to reuse a connection that's been sitting idle during
// VM/laptop suspend if monotonic time was also frozen.
- return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout
+ return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && cc.t.timeSince(cc.lastIdle.Round(0)) > cc.idleTimeout
}
// onIdleTimeout is called from a time.AfterFunc goroutine. It will
@@ -1018,7 +1146,7 @@ func (cc *ClientConn) forceCloseConn() {
if !ok {
return
}
- if nc := tlsUnderlyingConn(tc); nc != nil {
+ if nc := tc.NetConn(); nc != nil {
nc.Close()
}
}
@@ -1030,6 +1158,7 @@ func (cc *ClientConn) closeIfIdle() {
return
}
cc.closed = true
+ cc.closedOnIdle = true
nextID := cc.nextStreamID
// TODO: do clients send GOAWAY too? maybe? Just Close:
cc.mu.Unlock()
@@ -1057,6 +1186,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
done := make(chan struct{})
cancelled := false // guarded by cc.mu
go func() {
+ cc.t.markNewGoroutine()
cc.mu.Lock()
defer cc.mu.Unlock()
for {
@@ -1145,23 +1275,6 @@ func (cc *ClientConn) closeForLostPing() {
// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests.
var errRequestCanceled = errors.New("net/http: request canceled")
-func commaSeparatedTrailers(req *http.Request) (string, error) {
- keys := make([]string, 0, len(req.Trailer))
- for k := range req.Trailer {
- k = canonicalHeader(k)
- switch k {
- case "Transfer-Encoding", "Trailer", "Content-Length":
- return "", fmt.Errorf("invalid Trailer key %q", k)
- }
- keys = append(keys, k)
- }
- if len(keys) > 0 {
- sort.Strings(keys)
- return strings.Join(keys, ","), nil
- }
- return "", nil
-}
-
func (cc *ClientConn) responseHeaderTimeout() time.Duration {
if cc.t.t1 != nil {
return cc.t.t1.ResponseHeaderTimeout
@@ -1173,22 +1286,6 @@ func (cc *ClientConn) responseHeaderTimeout() time.Duration {
return 0
}
-// checkConnHeaders checks whether req has any invalid connection-level headers.
-// per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields.
-// Certain headers are special-cased as okay but not transmitted later.
-func checkConnHeaders(req *http.Request) error {
- if v := req.Header.Get("Upgrade"); v != "" {
- return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"])
- }
- if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
- return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv)
- }
- if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) {
- return fmt.Errorf("http2: invalid Connection request header: %q", vv)
- }
- return nil
-}
-
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
@@ -1215,6 +1312,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() {
}
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
+ return cc.roundTrip(req, nil)
+}
+
+func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) {
ctx := req.Context()
cs := &clientStream{
cc: cc,
@@ -1229,7 +1330,10 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}),
}
- go cs.doRequest(req)
+
+ cs.requestedGzip = httpcommon.IsRequestGzip(req.Method, req.Header, cc.t.disableCompression())
+
+ go cs.doRequest(req, streamf)
waitDone := func() error {
select {
@@ -1322,11 +1426,14 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
// doRequest runs for the duration of the request lifetime.
//
// It sends the request and performs post-request cleanup (closing Request.Body, etc.).
-func (cs *clientStream) doRequest(req *http.Request) {
- err := cs.writeRequest(req)
+func (cs *clientStream) doRequest(req *http.Request, streamf func(*clientStream)) {
+ cs.cc.t.markNewGoroutine()
+ err := cs.writeRequest(req, streamf)
cs.cleanupWriteRequest(err)
}
+var errExtendedConnectNotSupported = errors.New("net/http: extended connect not supported by peer")
+
// writeRequest sends a request.
//
// It returns nil after the request is written, the response read,
@@ -1334,12 +1441,15 @@ func (cs *clientStream) doRequest(req *http.Request) {
//
// It returns non-nil if the request ends otherwise.
// If the returned error is StreamError, the error Code may be used in resetting the stream.
-func (cs *clientStream) writeRequest(req *http.Request) (err error) {
+func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStream)) (err error) {
cc := cs.cc
ctx := cs.ctx
- if err := checkConnHeaders(req); err != nil {
- return err
+	// Extended CONNECT requires server opt-in via SETTINGS. A server can change
+	// the setting later, but we only wait for the first SETTINGS frame.
+ var isExtendedConnect bool
+ if req.Method == "CONNECT" && req.Header.Get(":protocol") != "" {
+ isExtendedConnect = true
}
// Acquire the new-request lock by writing to reqHeaderMu.
@@ -1348,6 +1458,18 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
+ if isExtendedConnect {
+ select {
+ case <-cs.reqCancel:
+ return errRequestCanceled
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cc.seenSettingsChan:
+ if !cc.extendedConnectAllowed {
+ return errExtendedConnectNotSupported
+ }
+ }
+ }
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
@@ -1372,24 +1494,8 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
}
cc.mu.Unlock()
- // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
- if !cc.t.disableCompression() &&
- req.Header.Get("Accept-Encoding") == "" &&
- req.Header.Get("Range") == "" &&
- !cs.isHead {
- // Request gzip only, not deflate. Deflate is ambiguous and
- // not as universally supported anyway.
- // See: https://zlib.net/zlib_faq.html#faq39
- //
- // Note that we don't request this for HEAD requests,
- // due to a bug in nginx:
- // http://trac.nginx.org/nginx/ticket/358
- // https://golang.org/issue/5522
- //
- // We don't request gzip if the request is for a range, since
- // auto-decoding a portion of a gzipped document will just fail
- // anyway. See https://golang.org/issue/8923
- cs.requestedGzip = true
+ if streamf != nil {
+ streamf(cs)
}
continueTimeout := cc.t.expectContinueTimeout()
@@ -1452,9 +1558,9 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 {
- timer := time.NewTimer(d)
+ timer := cc.t.newTimer(d)
defer timer.Stop()
- respHeaderTimer = timer.C
+ respHeaderTimer = timer.C()
respHeaderRecv = cs.respHeaderRecv
}
// Wait until the peer half-closes its end of the stream,
@@ -1502,26 +1608,39 @@ func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error {
// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
// sent by writeRequestBody below, along with any Trailers,
// again in form HEADERS{1}, CONTINUATION{0,})
- trailers, err := commaSeparatedTrailers(req)
+ cc.hbuf.Reset()
+ res, err := encodeRequestHeaders(req, cs.requestedGzip, cc.peerMaxHeaderListSize, func(name, value string) {
+ cc.writeHeader(name, value)
+ })
if err != nil {
- return err
- }
- hasTrailers := trailers != ""
- contentLen := actualContentLength(req)
- hasBody := contentLen != 0
- hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
- if err != nil {
- return err
+ return fmt.Errorf("http2: %w", err)
}
+ hdrs := cc.hbuf.Bytes()
// Write the request.
- endStream := !hasBody && !hasTrailers
+ endStream := !res.HasBody && !res.HasTrailers
cs.sentHeaders = true
err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
traceWroteHeaders(cs.trace)
return err
}
+func encodeRequestHeaders(req *http.Request, addGzipHeader bool, peerMaxHeaderListSize uint64, headerf func(name, value string)) (httpcommon.EncodeHeadersResult, error) {
+ return httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{
+ Request: httpcommon.Request{
+ Header: req.Header,
+ Trailer: req.Trailer,
+ URL: req.URL,
+ Host: req.Host,
+ Method: req.Method,
+ ActualContentLength: actualContentLength(req),
+ },
+ AddGzipHeader: addGzipHeader,
+ PeerMaxHeaderListSize: peerMaxHeaderListSize,
+ DefaultUserAgent: defaultUserAgent,
+ }, headerf)
+}
+
// cleanupWriteRequest performs post-request tasks.
//
// If err (the result of writeRequest) is non-nil and the stream is not closed,
@@ -1545,6 +1664,7 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
cs.reqBodyClosed = make(chan struct{})
}
bodyClosed := cs.reqBodyClosed
+ closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
cc.mu.Unlock()
if mustCloseBody {
cs.reqBody.Close()
@@ -1569,16 +1689,44 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
if cs.sentHeaders {
if se, ok := err.(StreamError); ok {
if se.Cause != errFromPeer {
- cc.writeStreamReset(cs.ID, se.Code, err)
+ cc.writeStreamReset(cs.ID, se.Code, false, err)
}
} else {
- cc.writeStreamReset(cs.ID, ErrCodeCancel, err)
+ // We're cancelling an in-flight request.
+ //
+ // This could be due to the server becoming unresponsive.
+ // To avoid sending too many requests on a dead connection,
+ // we let the request continue to consume a concurrency slot
+ // until we can confirm the server is still responding.
+ // We do this by sending a PING frame along with the RST_STREAM
+ // (unless a ping is already in flight).
+ //
+ // For simplicity, we don't bother tracking the PING payload:
+ // We reset cc.pendingResets any time we receive a PING ACK.
+ //
+ // We skip this if the conn is going to be closed on idle,
+ // because it's short lived and will probably be closed before
+ // we get the ping response.
+ ping := false
+ if !closeOnIdle {
+ cc.mu.Lock()
+ // rstStreamPingsBlocked works around a gRPC behavior:
+ // see comment on the field for details.
+ if !cc.rstStreamPingsBlocked {
+ if cc.pendingResets == 0 {
+ ping = true
+ }
+ cc.pendingResets++
+ }
+ cc.mu.Unlock()
+ }
+ cc.writeStreamReset(cs.ID, ErrCodeCancel, ping, err)
}
}
cs.bufPipe.CloseWithError(err) // no-op if already closed
} else {
if cs.sentHeaders && !cs.sentEndStream {
- cc.writeStreamReset(cs.ID, ErrCodeNo, nil)
+ cc.writeStreamReset(cs.ID, ErrCodeNo, false, nil)
}
cs.bufPipe.CloseWithError(errRequestCanceled)
}
@@ -1600,12 +1748,17 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
// Must hold cc.mu.
func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
for {
- cc.lastActive = time.Now()
+ if cc.closed && cc.nextStreamID == 1 && cc.streamsReserved == 0 {
+ // This is the very first request sent to this connection.
+ // Return a fatal error which aborts the retry loop.
+ return errClientConnNotEstablished
+ }
+ cc.lastActive = cc.t.now()
if cc.closed || !cc.canTakeNewRequestLocked() {
return errClientConnUnusable
}
cc.lastIdle = time.Time{}
- if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) {
+ if cc.currentRequestCountLocked() < int(cc.maxConcurrentStreams) {
return nil
}
cc.pendingRequests++
@@ -1875,203 +2028,6 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
}
}
-var errNilRequestURL = errors.New("http2: Request.URI is nil")
-
-// requires cc.wmu be held.
-func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
- cc.hbuf.Reset()
- if req.URL == nil {
- return nil, errNilRequestURL
- }
-
- host := req.Host
- if host == "" {
- host = req.URL.Host
- }
- host, err := httpguts.PunycodeHostPort(host)
- if err != nil {
- return nil, err
- }
- if !httpguts.ValidHostHeader(host) {
- return nil, errors.New("http2: invalid Host header")
- }
-
- var path string
- if req.Method != "CONNECT" {
- path = req.URL.RequestURI()
- if !validPseudoPath(path) {
- orig := path
- path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
- if !validPseudoPath(path) {
- if req.URL.Opaque != "" {
- return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
- } else {
- return nil, fmt.Errorf("invalid request :path %q", orig)
- }
- }
- }
- }
-
- // Check for any invalid headers and return an error before we
- // potentially pollute our hpack state. (We want to be able to
- // continue to reuse the hpack encoder for future requests)
- for k, vv := range req.Header {
- if !httpguts.ValidHeaderFieldName(k) {
- return nil, fmt.Errorf("invalid HTTP header name %q", k)
- }
- for _, v := range vv {
- if !httpguts.ValidHeaderFieldValue(v) {
- // Don't include the value in the error, because it may be sensitive.
- return nil, fmt.Errorf("invalid HTTP header value for header %q", k)
- }
- }
- }
-
- enumerateHeaders := func(f func(name, value string)) {
- // 8.1.2.3 Request Pseudo-Header Fields
- // The :path pseudo-header field includes the path and query parts of the
- // target URI (the path-absolute production and optionally a '?' character
- // followed by the query production, see Sections 3.3 and 3.4 of
- // [RFC3986]).
- f(":authority", host)
- m := req.Method
- if m == "" {
- m = http.MethodGet
- }
- f(":method", m)
- if req.Method != "CONNECT" {
- f(":path", path)
- f(":scheme", req.URL.Scheme)
- }
- if trailers != "" {
- f("trailer", trailers)
- }
-
- var didUA bool
- for k, vv := range req.Header {
- if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") {
- // Host is :authority, already sent.
- // Content-Length is automatic, set below.
- continue
- } else if asciiEqualFold(k, "connection") ||
- asciiEqualFold(k, "proxy-connection") ||
- asciiEqualFold(k, "transfer-encoding") ||
- asciiEqualFold(k, "upgrade") ||
- asciiEqualFold(k, "keep-alive") {
- // Per 8.1.2.2 Connection-Specific Header
- // Fields, don't send connection-specific
- // fields. We have already checked if any
- // are error-worthy so just ignore the rest.
- continue
- } else if asciiEqualFold(k, "user-agent") {
- // Match Go's http1 behavior: at most one
- // User-Agent. If set to nil or empty string,
- // then omit it. Otherwise if not mentioned,
- // include the default (below).
- didUA = true
- if len(vv) < 1 {
- continue
- }
- vv = vv[:1]
- if vv[0] == "" {
- continue
- }
- } else if asciiEqualFold(k, "cookie") {
- // Per 8.1.2.5 To allow for better compression efficiency, the
- // Cookie header field MAY be split into separate header fields,
- // each with one or more cookie-pairs.
- for _, v := range vv {
- for {
- p := strings.IndexByte(v, ';')
- if p < 0 {
- break
- }
- f("cookie", v[:p])
- p++
- // strip space after semicolon if any.
- for p+1 <= len(v) && v[p] == ' ' {
- p++
- }
- v = v[p:]
- }
- if len(v) > 0 {
- f("cookie", v)
- }
- }
- continue
- }
-
- for _, v := range vv {
- f(k, v)
- }
- }
- if shouldSendReqContentLength(req.Method, contentLength) {
- f("content-length", strconv.FormatInt(contentLength, 10))
- }
- if addGzipHeader {
- f("accept-encoding", "gzip")
- }
- if !didUA {
- f("user-agent", defaultUserAgent)
- }
- }
-
- // Do a first pass over the headers counting bytes to ensure
- // we don't exceed cc.peerMaxHeaderListSize. This is done as a
- // separate pass before encoding the headers to prevent
- // modifying the hpack state.
- hlSize := uint64(0)
- enumerateHeaders(func(name, value string) {
- hf := hpack.HeaderField{Name: name, Value: value}
- hlSize += uint64(hf.Size())
- })
-
- if hlSize > cc.peerMaxHeaderListSize {
- return nil, errRequestHeaderListSize
- }
-
- trace := httptrace.ContextClientTrace(req.Context())
- traceHeaders := traceHasWroteHeaderField(trace)
-
- // Header list size is ok. Write the headers.
- enumerateHeaders(func(name, value string) {
- name, ascii := lowerHeader(name)
- if !ascii {
- // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
- // field names have to be ASCII characters (just as in HTTP/1.x).
- return
- }
- cc.writeHeader(name, value)
- if traceHeaders {
- traceWroteHeaderField(trace, name, value)
- }
- })
-
- return cc.hbuf.Bytes(), nil
-}
-
-// shouldSendReqContentLength reports whether the http2.Transport should send
-// a "content-length" request header. This logic is basically a copy of the net/http
-// transferWriter.shouldSendContentLength.
-// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
-// -1 means unknown.
-func shouldSendReqContentLength(method string, contentLength int64) bool {
- if contentLength > 0 {
- return true
- }
- if contentLength < 0 {
- return false
- }
- // For zero bodies, whether we send a content-length depends on the method.
- // It also kinda doesn't matter for http2 either way, with END_STREAM.
- switch method {
- case "POST", "PUT", "PATCH":
- return true
- default:
- return false
- }
-}
-
// requires cc.wmu be held.
func (cc *ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) {
cc.hbuf.Reset()
@@ -2088,7 +2044,7 @@ func (cc *ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) {
}
for k, vv := range trailer {
- lowKey, ascii := lowerHeader(k)
+ lowKey, ascii := httpcommon.LowerHeader(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
@@ -2120,7 +2076,7 @@ type resAndError struct {
func (cc *ClientConn) addStreamLocked(cs *clientStream) {
cs.flow.add(int32(cc.initialWindowSize))
cs.flow.setConnFlow(&cc.flow)
- cs.inflow.init(transportDefaultStreamFlow)
+ cs.inflow.init(cc.initialStreamRecvWindowSize)
cs.ID = cc.nextStreamID
cc.nextStreamID += 2
cc.streams[cs.ID] = cs
@@ -2136,10 +2092,10 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
if len(cc.streams) != slen-1 {
panic("forgetting unknown stream id")
}
- cc.lastActive = time.Now()
+ cc.lastActive = cc.t.now()
if len(cc.streams) == 0 && cc.idleTimer != nil {
cc.idleTimer.Reset(cc.idleTimeout)
- cc.lastIdle = time.Now()
+ cc.lastIdle = cc.t.now()
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
@@ -2165,6 +2121,7 @@ type clientConnReadLoop struct {
// readLoop runs in its own goroutine and reads and dispatches frames.
func (cc *ClientConn) readLoop() {
+ cc.t.markNewGoroutine()
rl := &clientConnReadLoop{cc: cc}
defer rl.cleanup()
cc.readerErr = rl.run()
@@ -2198,7 +2155,6 @@ func isEOFOrNetReadError(err error) bool {
func (rl *clientConnReadLoop) cleanup() {
cc := rl.cc
- cc.t.connPool().MarkDead(cc)
defer cc.closeConn()
defer close(cc.readerDone)
@@ -2222,6 +2178,27 @@ func (rl *clientConnReadLoop) cleanup() {
}
cc.closed = true
+ // If the connection has never been used, and has been open for only a short time,
+ // leave it in the connection pool for a little while.
+ //
+ // This avoids a situation where new connections are constantly created,
+ // added to the pool, fail, and are removed from the pool, without any error
+ // being surfaced to the user.
+ unusedWaitTime := 5 * time.Second
+ if cc.idleTimeout > 0 && unusedWaitTime > cc.idleTimeout {
+ unusedWaitTime = cc.idleTimeout
+ }
+ idleTime := cc.t.now().Sub(cc.lastActive)
+ if atomic.LoadUint32(&cc.atomicReused) == 0 && idleTime < unusedWaitTime && !cc.closedOnIdle {
+ cc.idleTimer = cc.t.afterFunc(unusedWaitTime-idleTime, func() {
+ cc.t.connPool().MarkDead(cc)
+ })
+ } else {
+ cc.mu.Unlock() // avoid any deadlocks in MarkDead
+ cc.t.connPool().MarkDead(cc)
+ cc.mu.Lock()
+ }
+
for _, cs := range cc.streams {
select {
case <-cs.peerClosed:
@@ -2233,6 +2210,13 @@ func (rl *clientConnReadLoop) cleanup() {
}
cc.cond.Broadcast()
cc.mu.Unlock()
+
+ if !cc.seenSettings {
+ // If we have a pending request that wants extended CONNECT,
+ // let it continue and fail with the connection error.
+ cc.extendedConnectAllowed = true
+ close(cc.seenSettingsChan)
+ }
}
// countReadFrameError calls Transport.CountError with a string
@@ -2265,11 +2249,10 @@ func (cc *ClientConn) countReadFrameError(err error) {
func (rl *clientConnReadLoop) run() error {
cc := rl.cc
gotSettings := false
- readIdleTimeout := cc.t.ReadIdleTimeout
- var t *time.Timer
+ readIdleTimeout := cc.readIdleTimeout
+ var t timer
if readIdleTimeout != 0 {
- t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
- defer t.Stop()
+ t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck)
}
for {
f, err := cc.fr.ReadFrame()
@@ -2280,7 +2263,7 @@ func (rl *clientConnReadLoop) run() error {
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
}
if se, ok := err.(StreamError); ok {
- if cs := rl.streamByID(se.StreamID); cs != nil {
+ if cs := rl.streamByID(se.StreamID, notHeaderOrDataFrame); cs != nil {
if se.Cause == nil {
se.Cause = cc.fr.errDetail
}
@@ -2332,7 +2315,7 @@ func (rl *clientConnReadLoop) run() error {
}
func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
- cs := rl.streamByID(f.StreamID)
+ cs := rl.streamByID(f.StreamID, headerOrDataFrame)
if cs == nil {
// We'd get here if we canceled a request while the
// server had its response still in flight. So if this
@@ -2420,7 +2403,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
Status: status + " " + http.StatusText(statusCode),
}
for _, hf := range regularFields {
- key := canonicalHeader(hf.Name)
+ key := httpcommon.CanonicalHeader(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
@@ -2428,7 +2411,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
res.Trailer = t
}
foreachHeaderElement(hf.Value, func(v string) {
- t[canonicalHeader(v)] = nil
+ t[httpcommon.CanonicalHeader(v)] = nil
})
} else {
vv := header[key]
@@ -2450,15 +2433,34 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
if f.StreamEnded() {
return nil, errors.New("1xx informational response with END_STREAM flag")
}
- cs.num1xx++
- const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http
- if cs.num1xx > max1xxResponses {
- return nil, errors.New("http2: too many 1xx informational responses")
- }
if fn := cs.get1xxTraceFunc(); fn != nil {
+ // If the 1xx response is being delivered to the user,
+ // then they're responsible for limiting the number
+ // of responses.
if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil {
return nil, err
}
+ } else {
+ // If the user didn't examine the 1xx response, then we
+ // limit the size of all 1xx headers.
+ //
+ // This differs a bit from the HTTP/1 implementation, which
+ // limits the size of all 1xx headers plus the final response.
+ // Use the larger limit of MaxHeaderListSize and
+ // net/http.Transport.MaxResponseHeaderBytes.
+ limit := int64(cs.cc.t.maxHeaderListSize())
+ if t1 := cs.cc.t.t1; t1 != nil && t1.MaxResponseHeaderBytes > limit {
+ limit = t1.MaxResponseHeaderBytes
+ }
+ for _, h := range f.Fields {
+ cs.totalHeaderSize += int64(h.Size())
+ }
+ if cs.totalHeaderSize > limit {
+ if VerboseLogs {
+ log.Printf("http2: 1xx informational responses too large")
+ }
+ return nil, errors.New("header list too large")
+ }
}
if statusCode == 100 {
traceGot100Continue(cs.trace)
@@ -2533,7 +2535,7 @@ func (rl *clientConnReadLoop) processTrailers(cs *clientStream, f *MetaHeadersFr
trailer := make(http.Header)
for _, hf := range f.RegularFields() {
- key := canonicalHeader(hf.Name)
+ key := httpcommon.CanonicalHeader(hf.Name)
trailer[key] = append(trailer[key], hf.Value)
}
cs.trailer = trailer
@@ -2642,7 +2644,7 @@ func (b transportResponseBody) Close() error {
func (rl *clientConnReadLoop) processData(f *DataFrame) error {
cc := rl.cc
- cs := rl.streamByID(f.StreamID)
+ cs := rl.streamByID(f.StreamID, headerOrDataFrame)
data := f.Data()
if cs == nil {
cc.mu.Lock()
@@ -2684,7 +2686,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
})
return nil
}
- if !cs.firstByte {
+ if !cs.pastHeaders {
cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, StreamError{
StreamID: f.StreamID,
@@ -2777,9 +2779,22 @@ func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
cs.abortStream(err)
}
-func (rl *clientConnReadLoop) streamByID(id uint32) *clientStream {
+// Constants passed to streamByID for documentation purposes.
+const (
+ headerOrDataFrame = true
+ notHeaderOrDataFrame = false
+)
+
+// streamByID returns the stream with the given id, or nil if no stream has that id.
+// If headerOrData is true, it clears cc.rstStreamPingsBlocked.
+func (rl *clientConnReadLoop) streamByID(id uint32, headerOrData bool) *clientStream {
rl.cc.mu.Lock()
defer rl.cc.mu.Unlock()
+ if headerOrData {
+ // Work around an unfortunate gRPC behavior.
+ // See comment on ClientConn.rstStreamPingsBlocked for details.
+ rl.cc.rstStreamPingsBlocked = false
+ }
cs := rl.cc.streams[id]
if cs != nil && !cs.readAborted {
return cs
@@ -2873,6 +2888,21 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
case SettingHeaderTableSize:
cc.henc.SetMaxDynamicTableSize(s.Val)
cc.peerMaxHeaderTableSize = s.Val
+ case SettingEnableConnectProtocol:
+ if err := s.Valid(); err != nil {
+ return err
+ }
+ // If the peer wants to send us SETTINGS_ENABLE_CONNECT_PROTOCOL,
+ // we require that it do so in the first SETTINGS frame.
+ //
+ // When we attempt to use extended CONNECT, we wait for the first
+ // SETTINGS frame to see if the server supports it. If we let the
+ // server enable the feature with a later SETTINGS frame, then
+ // users will see inconsistent results depending on whether we've
+ // seen that frame or not.
+ if !cc.seenSettings {
+ cc.extendedConnectAllowed = s.Val == 1
+ }
default:
cc.vlogf("Unhandled Setting: %v", s)
}
@@ -2890,6 +2920,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
// connection can establish to our default.
cc.maxConcurrentStreams = defaultMaxConcurrentStreams
}
+ close(cc.seenSettingsChan)
cc.seenSettings = true
}
@@ -2898,7 +2929,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
cc := rl.cc
- cs := rl.streamByID(f.StreamID)
+ cs := rl.streamByID(f.StreamID, notHeaderOrDataFrame)
if f.StreamID != 0 && cs == nil {
return nil
}
@@ -2911,6 +2942,15 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
fl = &cs.flow
}
if !fl.add(int32(f.Increment)) {
+		// For a stream-level overflow, send RST_STREAM with FLOW_CONTROL_ERROR instead of killing the connection.
+ if cs != nil {
+ rl.endStreamError(cs, StreamError{
+ StreamID: f.StreamID,
+ Code: ErrCodeFlowControl,
+ })
+ return nil
+ }
+
return ConnectionError(ErrCodeFlowControl)
}
cc.cond.Broadcast()
@@ -2918,7 +2958,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
}
func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
- cs := rl.streamByID(f.StreamID)
+ cs := rl.streamByID(f.StreamID, notHeaderOrDataFrame)
if cs == nil {
// TODO: return error if server tries to RST_STREAM an idle stream
return nil
@@ -2955,24 +2995,26 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
}
cc.mu.Unlock()
}
- errc := make(chan error, 1)
+ var pingError error
+ errc := make(chan struct{})
go func() {
+ cc.t.markNewGoroutine()
cc.wmu.Lock()
defer cc.wmu.Unlock()
- if err := cc.fr.WritePing(false, p); err != nil {
- errc <- err
+ if pingError = cc.fr.WritePing(false, p); pingError != nil {
+ close(errc)
return
}
- if err := cc.bw.Flush(); err != nil {
- errc <- err
+ if pingError = cc.bw.Flush(); pingError != nil {
+ close(errc)
return
}
}()
select {
case <-c:
return nil
- case err := <-errc:
- return err
+ case <-errc:
+ return pingError
case <-ctx.Done():
return ctx.Err()
case <-cc.readerDone:
@@ -2991,6 +3033,12 @@ func (rl *clientConnReadLoop) processPing(f *PingFrame) error {
close(c)
delete(cc.pings, f.Data)
}
+ if cc.pendingResets > 0 {
+ // See clientStream.cleanupWriteRequest.
+ cc.pendingResets = 0
+ cc.rstStreamPingsBlocked = true
+ cc.cond.Broadcast()
+ }
return nil
}
cc := rl.cc
@@ -3013,20 +3061,27 @@ func (rl *clientConnReadLoop) processPushPromise(f *PushPromiseFrame) error {
return ConnectionError(ErrCodeProtocol)
}
-func (cc *ClientConn) writeStreamReset(streamID uint32, code ErrCode, err error) {
+// writeStreamReset sends a RST_STREAM frame.
+// When ping is true, it also sends a PING frame with a random payload.
+func (cc *ClientConn) writeStreamReset(streamID uint32, code ErrCode, ping bool, err error) {
// TODO: map err to more interesting error codes, once the
// HTTP community comes up with some. But currently for
// RST_STREAM there's no equivalent to GOAWAY frame's debug
// data, and the error codes are all pretty vague ("cancel").
cc.wmu.Lock()
cc.fr.WriteRSTStream(streamID, code)
+ if ping {
+ var payload [8]byte
+ rand.Read(payload[:])
+ cc.fr.WritePing(false, payload)
+ }
cc.bw.Flush()
cc.wmu.Unlock()
}
var (
errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
- errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit")
+ errRequestHeaderListSize = httpcommon.ErrRequestHeaderListSize
)
func (cc *ClientConn) logf(format string, args ...interface{}) {
@@ -3141,9 +3196,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err
}
func (t *Transport) idleConnTimeout() time.Duration {
+ // to keep things backwards compatible, we use non-zero values of
+ // IdleConnTimeout, followed by using the IdleConnTimeout on the underlying
+ // http1 transport, followed by 0
+ if t.IdleConnTimeout != 0 {
+ return t.IdleConnTimeout
+ }
+
if t.t1 != nil {
return t.t1.IdleConnTimeout
}
+
return 0
}
@@ -3165,7 +3228,7 @@ func traceGotConn(req *http.Request, cc *ClientConn, reused bool) {
cc.mu.Lock()
ci.WasIdle = len(cc.streams) == 0 && reused
if ci.WasIdle && !cc.lastActive.IsZero() {
- ci.IdleTime = time.Since(cc.lastActive)
+ ci.IdleTime = cc.t.timeSince(cc.lastActive)
}
cc.mu.Unlock()
@@ -3201,3 +3264,24 @@ func traceFirstResponseByte(trace *httptrace.ClientTrace) {
trace.GotFirstResponseByte()
}
}
+
+func traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
+ if trace != nil {
+ return trace.Got1xxResponse
+ }
+ return nil
+}
+
+// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
+// connection.
+func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
+ dialer := &tls.Dialer{
+ Config: cfg,
+ }
+ cn, err := dialer.DialContext(ctx, network, addr)
+ if err != nil {
+ return nil, err
+ }
+ tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
+ return tlsCn, nil
+}
diff --git a/http2/transport_go117_test.go b/http2/transport_go117_test.go
deleted file mode 100644
index f5d4e0c1a6..0000000000
--- a/http2/transport_go117_test.go
+++ /dev/null
@@ -1,169 +0,0 @@
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.17
-// +build go1.17
-
-package http2
-
-import (
- "context"
- "crypto/tls"
- "errors"
- "net/http"
- "net/http/httptest"
-
- "testing"
-)
-
-func TestTransportDialTLSContext(t *testing.T) {
- blockCh := make(chan struct{})
- serverTLSConfigFunc := func(ts *httptest.Server) {
- ts.Config.TLSConfig = &tls.Config{
- // Triggers the server to request the clients certificate
- // during TLS handshake.
- ClientAuth: tls.RequestClientCert,
- }
- }
- ts := newServerTester(t,
- func(w http.ResponseWriter, r *http.Request) {},
- optOnlyServer,
- serverTLSConfigFunc,
- )
- defer ts.Close()
- tr := &Transport{
- TLSClientConfig: &tls.Config{
- GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- // Tests that the context provided to `req` is
- // passed into this function.
- close(blockCh)
- <-cri.Context().Done()
- return nil, cri.Context().Err()
- },
- InsecureSkipVerify: true,
- },
- }
- defer tr.CloseIdleConnections()
- req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
- if err != nil {
- t.Fatal(err)
- }
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
- req = req.WithContext(ctx)
- errCh := make(chan error)
- go func() {
- defer close(errCh)
- res, err := tr.RoundTrip(req)
- if err != nil {
- errCh <- err
- return
- }
- res.Body.Close()
- }()
- // Wait for GetClientCertificate handler to be called
- <-blockCh
- // Cancel the context
- cancel()
- // Expect the cancellation error here
- err = <-errCh
- if err == nil {
- t.Fatal("cancelling context during client certificate fetch did not error as expected")
- return
- }
- if !errors.Is(err, context.Canceled) {
- t.Fatalf("unexpected error returned after cancellation: %v", err)
- }
-}
-
-// TestDialRaceResumesDial tests that, given two concurrent requests
-// to the same address, when the first Dial is interrupted because
-// the first request's context is cancelled, the second request
-// resumes the dial automatically.
-func TestDialRaceResumesDial(t *testing.T) {
- blockCh := make(chan struct{})
- serverTLSConfigFunc := func(ts *httptest.Server) {
- ts.Config.TLSConfig = &tls.Config{
- // Triggers the server to request the clients certificate
- // during TLS handshake.
- ClientAuth: tls.RequestClientCert,
- }
- }
- ts := newServerTester(t,
- func(w http.ResponseWriter, r *http.Request) {},
- optOnlyServer,
- serverTLSConfigFunc,
- )
- defer ts.Close()
- tr := &Transport{
- TLSClientConfig: &tls.Config{
- GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- select {
- case <-blockCh:
- // If we already errored, return without error.
- return &tls.Certificate{}, nil
- default:
- }
- close(blockCh)
- <-cri.Context().Done()
- return nil, cri.Context().Err()
- },
- InsecureSkipVerify: true,
- },
- }
- defer tr.CloseIdleConnections()
- req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
- if err != nil {
- t.Fatal(err)
- }
- // Create two requests with independent cancellation.
- ctx1, cancel1 := context.WithCancel(context.Background())
- defer cancel1()
- req1 := req.WithContext(ctx1)
- ctx2, cancel2 := context.WithCancel(context.Background())
- defer cancel2()
- req2 := req.WithContext(ctx2)
- errCh := make(chan error)
- go func() {
- res, err := tr.RoundTrip(req1)
- if err != nil {
- errCh <- err
- return
- }
- res.Body.Close()
- }()
- successCh := make(chan struct{})
- go func() {
- // Don't start request until first request
- // has initiated the handshake.
- <-blockCh
- res, err := tr.RoundTrip(req2)
- if err != nil {
- errCh <- err
- return
- }
- res.Body.Close()
- // Close successCh to indicate that the second request
- // made it to the server successfully.
- close(successCh)
- }()
- // Wait for GetClientCertificate handler to be called
- <-blockCh
- // Cancel the context first
- cancel1()
- // Expect the cancellation error here
- err = <-errCh
- if err == nil {
- t.Fatal("cancelling context during client certificate fetch did not error as expected")
- return
- }
- if !errors.Is(err, context.Canceled) {
- t.Fatalf("unexpected error returned after cancellation: %v", err)
- }
- select {
- case err := <-errCh:
- t.Fatalf("unexpected second error: %v", err)
- case <-successCh:
- }
-}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 99848485b9..1eeb76e06e 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -16,7 +16,6 @@ import (
"fmt"
"io"
"io/fs"
- "io/ioutil"
"log"
"math/rand"
"net"
@@ -95,6 +94,88 @@ func startH2cServer(t *testing.T) net.Listener {
return l
}
+func TestIdleConnTimeout(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ idleConnTimeout time.Duration
+ wait time.Duration
+ baseTransport *http.Transport
+ wantNewConn bool
+ }{{
+ name: "NoExpiry",
+ idleConnTimeout: 2 * time.Second,
+ wait: 1 * time.Second,
+ baseTransport: nil,
+ wantNewConn: false,
+ }, {
+ name: "H2TransportTimeoutExpires",
+ idleConnTimeout: 1 * time.Second,
+ wait: 2 * time.Second,
+ baseTransport: nil,
+ wantNewConn: true,
+ }, {
+ name: "H1TransportTimeoutExpires",
+ idleConnTimeout: 0 * time.Second,
+ wait: 1 * time.Second,
+ baseTransport: &http.Transport{
+ IdleConnTimeout: 2 * time.Second,
+ },
+ wantNewConn: false,
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ tt := newTestTransport(t, func(tr *Transport) {
+ tr.IdleConnTimeout = test.idleConnTimeout
+ })
+ var tc *testClientConn
+ for i := 0; i < 3; i++ {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // This request happens on a new conn if it's the first request
+ // (and there is no cached conn), or if the test timeout is long
+ // enough that old conns are being closed.
+ wantConn := i == 0 || test.wantNewConn
+ if has := tt.hasConn(); has != wantConn {
+ t.Fatalf("request %v: hasConn=%v, want %v", i, has, wantConn)
+ }
+ if wantConn {
+ tc = tt.getConn()
+ // Read client's SETTINGS and first WINDOW_UPDATE,
+ // send our SETTINGS.
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.writeSettings()
+ }
+ if tt.hasConn() {
+ t.Fatalf("request %v: Transport has more than one conn", i)
+ }
+
+ // Respond to the client's request.
+ hf := readFrame[*HeadersFrame](t, tc)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
+
+ // If this was a newly-accepted conn, read the SETTINGS ACK.
+ if wantConn {
+ tc.wantFrameType(FrameSettings) // ACK to our settings
+ }
+
+ tt.advance(test.wait)
+ if got, want := tc.isClosed(), test.wantNewConn; got != want {
+ t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want)
+ }
+ }
+ })
+ }
+}
+
func TestTransportH2c(t *testing.T) {
l := startH2cServer(t)
defer l.Close()
@@ -124,7 +205,7 @@ func TestTransportH2c(t *testing.T) {
if res.ProtoMajor != 2 {
t.Fatal("proto not h2c")
}
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -138,15 +219,14 @@ func TestTransportH2c(t *testing.T) {
func TestTransport(t *testing.T) {
const body = "sup"
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, body)
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- u, err := url.Parse(st.ts.URL)
+ u, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -182,7 +262,7 @@ func TestTransport(t *testing.T) {
if res.TLS == nil {
t.Errorf("%d: Response.TLS = nil; want non-nil", i)
}
- slurp, err := ioutil.ReadAll(res.Body)
+ slurp, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("%d: Body read: %v", i, err)
} else if string(slurp) != body {
@@ -193,26 +273,27 @@ func TestTransport(t *testing.T) {
}
func testTransportReusesConns(t *testing.T, useClient, wantSame bool, modReq func(*http.Request)) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.RemoteAddr)
- }, optOnlyServer, func(c net.Conn, st http.ConnState) {
- t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
+ }, func(ts *httptest.Server) {
+ ts.Config.ConnState = func(c net.Conn, st http.ConnState) {
+ t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
+ }
})
- defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
if useClient {
tr.ConnPool = noDialClientConnPool{new(clientConnPool)}
}
defer tr.CloseIdleConnections()
get := func() string {
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
modReq(req)
var res *http.Response
if useClient {
- c := st.ts.Client()
+ c := ts.Client()
ConfigureTransports(c.Transport.(*http.Transport))
res, err = c.Do(req)
} else {
@@ -222,7 +303,7 @@ func testTransportReusesConns(t *testing.T, useClient, wantSame bool, modReq fun
t.Fatal(err)
}
defer res.Body.Close()
- slurp, err := ioutil.ReadAll(res.Body)
+ slurp, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Body read: %v", err)
}
@@ -276,15 +357,12 @@ func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) {
func TestTransportGetGotConnHooks_Client(t *testing.T) { testTransportGetGotConnHooks(t, true) }
func testTransportGetGotConnHooks(t *testing.T, useClient bool) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.RemoteAddr)
- }, func(s *httptest.Server) {
- s.EnableHTTP2 = true
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
- client := st.ts.Client()
+ client := ts.Client()
ConfigureTransports(client.Transport.(*http.Transport))
var (
@@ -307,7 +385,7 @@ func testTransportGetGotConnHooks(t *testing.T, useClient bool) {
}
},
}
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
@@ -350,9 +428,8 @@ func (c *testNetConn) Close() error {
// Tests that the Transport only keeps one pending dial open per destination address.
// https://golang.org/issue/13397
func TestTransportGroupsPendingDials(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- }, optOnlyServer)
- defer st.Close()
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ })
var (
mu sync.Mutex
dialCount int
@@ -381,7 +458,7 @@ func TestTransportGroupsPendingDials(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Error(err)
return
@@ -404,35 +481,21 @@ func TestTransportGroupsPendingDials(t *testing.T) {
}
}
-func retry(tries int, delay time.Duration, fn func() error) error {
- var err error
- for i := 0; i < tries; i++ {
- err = fn()
- if err == nil {
- return nil
- }
- time.Sleep(delay)
- }
- return err
-}
-
func TestTransportAbortClosesPipes(t *testing.T) {
shutdown := make(chan struct{})
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush()
<-shutdown
},
- optOnlyServer,
)
- defer st.Close()
defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
errCh := make(chan error)
go func() {
defer close(errCh)
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
errCh <- err
return
@@ -443,8 +506,8 @@ func TestTransportAbortClosesPipes(t *testing.T) {
return
}
defer res.Body.Close()
- st.closeConn()
- _, err = ioutil.ReadAll(res.Body)
+ ts.CloseClientConnections()
+ _, err = io.ReadAll(res.Body)
if err == nil {
errCh <- errors.New("expected error from res.Body.Read")
return
@@ -466,13 +529,11 @@ func TestTransportAbortClosesPipes(t *testing.T) {
// could be a table-driven test with extra goodies.
func TestTransportPath(t *testing.T) {
gotc := make(chan *url.URL, 1)
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {
gotc <- r.URL
},
- optOnlyServer,
)
- defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
@@ -480,7 +541,7 @@ func TestTransportPath(t *testing.T) {
path = "/testpath"
query = "q=1"
)
- surl := st.ts.URL + path + "?" + query
+ surl := ts.URL + path + "?" + query
req, err := http.NewRequest("POST", surl, nil)
if err != nil {
t.Fatal(err)
@@ -574,18 +635,16 @@ func TestTransportBody(t *testing.T) {
err error
}
gotc := make(chan reqInfo, 1)
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {
- slurp, err := ioutil.ReadAll(r.Body)
+ slurp, err := io.ReadAll(r.Body)
if err != nil {
gotc <- reqInfo{err: err}
} else {
gotc <- reqInfo{req: r, slurp: slurp}
}
},
- optOnlyServer,
)
- defer st.Close()
for i, tt := range bodyTests {
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
@@ -595,7 +654,7 @@ func TestTransportBody(t *testing.T) {
if tt.noContentLen {
body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods
}
- req, err := http.NewRequest("POST", st.ts.URL, body)
+ req, err := http.NewRequest("POST", ts.URL, body)
if err != nil {
t.Fatalf("#%d: %v", i, err)
}
@@ -635,15 +694,13 @@ func TestTransportDialTLS(t *testing.T) {
var mu sync.Mutex // guards following
var gotReq, didDial bool
- ts := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
gotReq = true
mu.Unlock()
},
- optOnlyServer,
)
- defer ts.Close()
tr := &Transport{
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
mu.Lock()
@@ -659,7 +716,7 @@ func TestTransportDialTLS(t *testing.T) {
}
defer tr.CloseIdleConnections()
client := &http.Client{Transport: tr}
- res, err := client.Get(ts.ts.URL)
+ res, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -694,18 +751,17 @@ func TestConfigureTransport(t *testing.T) {
}
// And does it work?
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.Proto)
- }, optOnlyServer)
- defer st.Close()
+ })
t1.TLSClientConfig.InsecureSkipVerify = true
c := &http.Client{Transport: t1}
- res, err := c.Get(st.ts.URL)
+ res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
- slurp, err := ioutil.ReadAll(res.Body)
+ slurp, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -740,53 +796,6 @@ func (fw flushWriter) Write(p []byte) (n int, err error) {
return
}
-type clientTester struct {
- t *testing.T
- tr *Transport
- sc, cc net.Conn // server and client conn
- fr *Framer // server's framer
- settings *SettingsFrame
- client func() error
- server func() error
-}
-
-func newClientTester(t *testing.T) *clientTester {
- var dialOnce struct {
- sync.Mutex
- dialed bool
- }
- ct := &clientTester{
- t: t,
- }
- ct.tr = &Transport{
- TLSClientConfig: tlsConfigInsecure,
- DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
- dialOnce.Lock()
- defer dialOnce.Unlock()
- if dialOnce.dialed {
- return nil, errors.New("only one dial allowed in test mode")
- }
- dialOnce.dialed = true
- return ct.cc, nil
- },
- }
-
- ln := newLocalListener(t)
- cc, err := net.Dial("tcp", ln.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- sc, err := ln.Accept()
- if err != nil {
- t.Fatal(err)
- }
- ln.Close()
- ct.cc = cc
- ct.sc = sc
- ct.fr = NewFramer(sc, sc)
- return ct
-}
-
func newLocalListener(t *testing.T) net.Listener {
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err == nil {
@@ -799,302 +808,88 @@ func newLocalListener(t *testing.T) net.Listener {
return ln
}
-func (ct *clientTester) greet(settings ...Setting) {
- buf := make([]byte, len(ClientPreface))
- _, err := io.ReadFull(ct.sc, buf)
- if err != nil {
- ct.t.Fatalf("reading client preface: %v", err)
- }
- f, err := ct.fr.ReadFrame()
- if err != nil {
- ct.t.Fatalf("Reading client settings frame: %v", err)
- }
- var ok bool
- if ct.settings, ok = f.(*SettingsFrame); !ok {
- ct.t.Fatalf("Wanted client settings frame; got %v", f)
- }
- if err := ct.fr.WriteSettings(settings...); err != nil {
- ct.t.Fatal(err)
- }
- if err := ct.fr.WriteSettingsAck(); err != nil {
- ct.t.Fatal(err)
- }
-}
-
-func (ct *clientTester) readNonSettingsFrame() (Frame, error) {
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return nil, err
- }
- if _, ok := f.(*SettingsFrame); ok {
- continue
- }
- return f, nil
- }
-}
-
-// writeReadPing sends a PING and immediately reads the PING ACK.
-// It will fail if any other unread data was pending on the connection,
-// aside from SETTINGS frames.
-func (ct *clientTester) writeReadPing() error {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- if err := ct.fr.WritePing(false, data); err != nil {
- return fmt.Errorf("Error writing PING: %v", err)
- }
- f, err := ct.readNonSettingsFrame()
- if err != nil {
- return err
- }
- p, ok := f.(*PingFrame)
- if !ok {
- return fmt.Errorf("got a %v, want a PING ACK", f)
- }
- if p.Flags&FlagPingAck == 0 {
- return fmt.Errorf("got a PING, want a PING ACK")
- }
- if p.Data != data {
- return fmt.Errorf("got PING data = %x, want %x", p.Data, data)
- }
- return nil
-}
-
-func (ct *clientTester) inflowWindow(streamID uint32) int32 {
- pool := ct.tr.connPoolOrDef.(*clientConnPool)
- pool.mu.Lock()
- defer pool.mu.Unlock()
- if n := len(pool.keys); n != 1 {
- ct.t.Errorf("clientConnPool contains %v keys, expected 1", n)
- return -1
- }
- for cc := range pool.keys {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- if streamID == 0 {
- return cc.inflow.avail + cc.inflow.unsent
- }
- cs := cc.streams[streamID]
- if cs == nil {
- ct.t.Errorf("no stream with id %v", streamID)
- return -1
- }
- return cs.inflow.avail + cs.inflow.unsent
- }
- return -1
-}
-
-func (ct *clientTester) cleanup() {
- ct.tr.CloseIdleConnections()
-
- // close both connections, ignore the error if its already closed
- ct.sc.Close()
- ct.cc.Close()
-}
-
-func (ct *clientTester) run() {
- var errOnce sync.Once
- var wg sync.WaitGroup
-
- run := func(which string, fn func() error) {
- defer wg.Done()
- if err := fn(); err != nil {
- errOnce.Do(func() {
- ct.t.Errorf("%s: %v", which, err)
- ct.cleanup()
- })
- }
- }
-
- wg.Add(2)
- go run("client", ct.client)
- go run("server", ct.server)
- wg.Wait()
+func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
+func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
- errOnce.Do(ct.cleanup) // clean up if no error
-}
+func testTransportReqBodyAfterResponse(t *testing.T, status int) {
+ const bodySize = 1 << 10
+
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ body := tc.newRequestBody()
+ body.writeBytes(bodySize / 2)
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
+ rt := tc.roundTrip(req)
+
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: false,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"PUT"},
+ ":path": []string{"/"},
+ },
+ })
-func (ct *clientTester) readFrame() (Frame, error) {
- return ct.fr.ReadFrame()
-}
+ // Provide enough congestion window for the full request body.
+ tc.writeWindowUpdate(0, bodySize)
+ tc.writeWindowUpdate(rt.streamID(), bodySize)
-func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
- for {
- f, err := ct.readFrame()
- if err != nil {
- return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- }
- hf, ok := f.(*HeadersFrame)
- if !ok {
- return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
- }
- return hf, nil
- }
-}
+ tc.wantData(wantData{
+ streamID: rt.streamID(),
+ endStream: false,
+ size: bodySize / 2,
+ })
-type countingReader struct {
- n *int64
-}
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", strconv.Itoa(status),
+ ),
+ })
-func (r countingReader) Read(p []byte) (n int, err error) {
- for i := range p {
- p[i] = byte(i)
+ res := rt.response()
+ if res.StatusCode != status {
+ t.Fatalf("status code = %v; want %v", res.StatusCode, status)
}
- atomic.AddInt64(r.n, int64(len(p)))
- return len(p), err
-}
-func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
-func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
+ body.writeBytes(bodySize / 2)
+ body.closeWithError(io.EOF)
-func testTransportReqBodyAfterResponse(t *testing.T, status int) {
- const bodySize = 10 << 20
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- recvLen := make(chan int64, 1)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- defer close(clientDone)
-
- body := &pipe{b: new(bytes.Buffer)}
- io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
- req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
- if err != nil {
- return err
- }
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- if res.StatusCode != status {
- return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
- }
- io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
- body.CloseWithError(io.EOF)
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("Slurp: %v", err)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("unexpected body: %q", slurp)
- }
- res.Body.Close()
- if status == 200 {
- if got := <-recvLen; got != bodySize {
- return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
- }
- } else {
- if got := <-recvLen; got == 0 || got >= bodySize {
- return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
- }
- }
- return nil
+ if status == 200 {
+ // After a 200 response, client sends the remaining request body.
+ tc.wantData(wantData{
+ streamID: rt.streamID(),
+ endStream: true,
+ size: bodySize / 2,
+ multiple: true,
+ })
+ } else {
+ // After a 403 response, client gives up and resets the stream.
+ tc.wantFrameType(FrameRSTStream)
}
- ct.server = func() error {
- ct.greet()
- defer close(recvLen)
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- var dataRecv int64
- var closed bool
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- return nil
- default:
- return err
- }
- }
- //println(fmt.Sprintf("server got frame: %v", f))
- ended := false
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- if f.StreamEnded() {
- return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
- }
- case *DataFrame:
- dataLen := len(f.Data())
- if dataLen > 0 {
- if dataRecv == 0 {
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- }
- if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
- return err
- }
- if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
- return err
- }
- }
- dataRecv += int64(dataLen)
-
- if !closed && ((status != 200 && dataRecv > 0) ||
- (status == 200 && f.StreamEnded())) {
- closed = true
- if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil {
- return err
- }
- }
- if f.StreamEnded() {
- ended = true
- }
- case *RSTStreamFrame:
- if status == 200 {
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- ended = true
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- if ended {
- select {
- case recvLen <- dataRecv:
- default:
- }
- }
- }
- }
- ct.run()
+ rt.wantBody(nil)
}
// See golang.org/issue/13444
func TestTransportFullDuplex(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) // redundant but for clarity
w.(http.Flusher).Flush()
io.Copy(flushWriter{w}, capitalizeReader{r.Body})
fmt.Fprintf(w, "bye.\n")
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
c := &http.Client{Transport: tr}
pr, pw := io.Pipe()
- req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
+ req, err := http.NewRequest("PUT", ts.URL, io.NopCloser(pr))
if err != nil {
t.Fatal(err)
}
@@ -1132,12 +927,11 @@ func TestTransportFullDuplex(t *testing.T) {
func TestTransportConnectRequest(t *testing.T) {
gotc := make(chan *http.Request, 1)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
gotc <- r
- }, optOnlyServer)
- defer st.Close()
+ })
- u, err := url.Parse(st.ts.URL)
+ u, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -1257,121 +1051,74 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy
panic("invalid combination")
}
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
- if expect100Continue != noHeader {
- req.Header.Set("Expect", "100-continue")
- }
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return fmt.Errorf("status code = %v; want 200", res.StatusCode)
- }
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("Slurp: %v", err)
- }
- wantBody := resBody
- if !withData {
- wantBody = ""
- }
- if string(slurp) != wantBody {
- return fmt.Errorf("body = %q; want %q", slurp, wantBody)
- }
- if trailers == noHeader {
- if len(res.Trailer) > 0 {
- t.Errorf("Trailer = %v; want none", res.Trailer)
- }
- } else {
- want := http.Header{"Some-Trailer": {"some-value"}}
- if !reflect.DeepEqual(res.Trailer, want) {
- t.Errorf("Trailer = %v; want %v", res.Trailer, want)
- }
- }
- return nil
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
+ if expect100Continue != noHeader {
+ req.Header.Set("Expect", "100-continue")
}
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
+ rt := tc.roundTrip(req)
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- endStream := false
- send := func(mode headerType) {
- hbf := buf.Bytes()
- switch mode {
- case oneHeader:
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.Header().StreamID,
- EndHeaders: true,
- EndStream: endStream,
- BlockFragment: hbf,
- })
- case splitHeader:
- if len(hbf) < 2 {
- panic("too small")
- }
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.Header().StreamID,
- EndHeaders: false,
- EndStream: endStream,
- BlockFragment: hbf[:1],
- })
- ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:])
- default:
- panic("bogus mode")
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *DataFrame:
- if !f.StreamEnded() {
- // No need to send flow control tokens. The test request body is tiny.
- continue
- }
- // Response headers (1+ frames; 1 or 2 in this test, but never 0)
- {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
- enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
- if trailers != noHeader {
- enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
- }
- endStream = withData == false && trailers == noHeader
- send(resHeader)
- }
- if withData {
- endStream = trailers == noHeader
- ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
- }
- if trailers != noHeader {
- endStream = true
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
- send(trailers)
- }
- if endStream {
- return nil
- }
- case *HeadersFrame:
- if expect100Continue != noHeader {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
- send(expect100Continue)
- }
- }
- }
+ tc.wantFrameType(FrameHeaders)
+
+ // Possibly 100-continue, or skip when noHeader.
+ tc.writeHeadersMode(expect100Continue, HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "100",
+ ),
+ })
+
+ // Client sends request body.
+ tc.wantData(wantData{
+ streamID: rt.streamID(),
+ endStream: true,
+ size: len(reqBody),
+ })
+
+ hdr := []string{
+ ":status", "200",
+ "x-foo", "blah",
+ "x-bar", "more",
+ }
+ if trailers != noHeader {
+ hdr = append(hdr, "trailer", "some-trailer")
+ }
+ tc.writeHeadersMode(resHeader, HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: withData == false && trailers == noHeader,
+ BlockFragment: tc.makeHeaderBlockFragment(hdr...),
+ })
+ if withData {
+ endStream := trailers == noHeader
+ tc.writeData(rt.streamID(), endStream, []byte(resBody))
+ }
+ tc.writeHeadersMode(trailers, HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ "some-trailer", "some-value",
+ ),
+ })
+
+ rt.wantStatus(200)
+ if !withData {
+ rt.wantBody(nil)
+ } else {
+ rt.wantBody([]byte(resBody))
+ }
+ if trailers == noHeader {
+ rt.wantTrailers(nil)
+ } else {
+ rt.wantTrailers(http.Header{
+ "Some-Trailer": {"some-value"},
+ })
}
- ct.run()
}
// Issue 26189, Issue 17739: ignore unknown 1xx responses
@@ -1383,130 +1130,76 @@ func TestTransportUnknown1xx(t *testing.T) {
return nil
}
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 204 {
- return fmt.Errorf("status code = %v; want 204", res.StatusCode)
- }
- want := `code=110 header=map[Foo-Bar:[110]]
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ for i := 110; i <= 114; i++ {
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", fmt.Sprint(i),
+ "foo-bar", fmt.Sprint(i),
+ ),
+ })
+ }
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "204",
+ ),
+ })
+
+ res := rt.response()
+ if res.StatusCode != 204 {
+ t.Fatalf("status code = %v; want 204", res.StatusCode)
+ }
+ want := `code=110 header=map[Foo-Bar:[110]]
code=111 header=map[Foo-Bar:[111]]
code=112 header=map[Foo-Bar:[112]]
code=113 header=map[Foo-Bar:[113]]
code=114 header=map[Foo-Bar:[114]]
`
- if got := buf.String(); got != want {
- t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
-
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- for i := 110; i <= 114; i++ {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)})
- enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- }
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ if got := buf.String(); got != want {
+ t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
}
- ct.run()
-
}
func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return fmt.Errorf("status code = %v; want 200", res.StatusCode)
- }
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("body = %q; want nothing", slurp)
- }
- if _, ok := res.Trailer["Some-Trailer"]; !ok {
- return fmt.Errorf("expected Some-Trailer")
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
-
- var n int
- var hf *HeadersFrame
- for hf == nil && n < 10 {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- hf, _ = f.(*HeadersFrame)
- n++
- }
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
-
- // send headers without Trailer header
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ "some-trailer", "I'm an undeclared Trailer!",
+ ),
+ })
- // send trailers
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- ct.run()
+ rt.wantStatus(200)
+ rt.wantBody(nil)
+ rt.wantTrailers(http.Header{
+ "Some-Trailer": []string{"I'm an undeclared Trailer!"},
+ })
}
func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
@@ -1516,10 +1209,10 @@ func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
testTransportInvalidTrailer_Pseudo(t, splitHeader)
}
func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
- testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
- enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
- enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
- })
+ testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"),
+ ":colon", "foo",
+ "foo", "bar",
+ )
}
func TestTransportInvalidTrailer_Capital1(t *testing.T) {
@@ -1529,102 +1222,54 @@ func TestTransportInvalidTrailer_Capital2(t *testing.T) {
testTransportInvalidTrailer_Capital(t, splitHeader)
}
func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
- testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
- enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
- enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
- })
+ testInvalidTrailer(t, trailers, headerFieldNameError("Capital"),
+ "foo", "bar",
+ "Capital", "bad",
+ )
}
func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
- testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
- enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
- })
+ testInvalidTrailer(t, oneHeader, headerFieldNameError(""),
+ "", "bad",
+ )
}
func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
- testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) {
- enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
- })
+ testInvalidTrailer(t, oneHeader, headerFieldValueError("x"),
+ "x", "has\nnewline",
+ )
}
-func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return fmt.Errorf("status code = %v; want 200", res.StatusCode)
- }
- slurp, err := ioutil.ReadAll(res.Body)
- se, ok := err.(StreamError)
- if !ok || se.Cause != wantErr {
- return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("body = %q; want nothing", slurp)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
+func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) {
+ tc := newTestClientConn(t)
+ tc.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- var endStream bool
- send := func(mode headerType) {
- hbf := buf.Bytes()
- switch mode {
- case oneHeader:
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: endStream,
- BlockFragment: hbf,
- })
- case splitHeader:
- if len(hbf) < 2 {
- panic("too small")
- }
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: false,
- EndStream: endStream,
- BlockFragment: hbf[:1],
- })
- ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
- default:
- panic("bogus mode")
- }
- }
- // Response headers (1+ frames; 1 or 2 in this test, but never 0)
- {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"})
- endStream = false
- send(oneHeader)
- }
- // Trailers:
- {
- endStream = true
- buf.Reset()
- writeTrailer(enc)
- send(trailers)
- }
- return nil
- }
- }
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "trailer", "declared",
+ ),
+ })
+ tc.writeHeadersMode(mode, HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(trailers...),
+ })
+
+ rt.wantStatus(200)
+ body, err := rt.readBody()
+ se, ok := err.(StreamError)
+ if !ok || se.Cause != wantErr {
+ t.Fatalf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", body, err, wantErr, wantErr)
+ }
+ if len(body) > 0 {
+ t.Fatalf("body = %q; want nothing", body)
}
- ct.run()
}
// headerListSize returns the HTTP2 header list size of h.
@@ -1741,24 +1386,22 @@ func TestPadHeaders(t *testing.T) {
}
func TestTransportChecksRequestHeaderListSize(t *testing.T) {
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {
// Consume body & force client to send
// trailers before writing response.
- // ioutil.ReadAll returns non-nil err for
+ // io.ReadAll returns non-nil err for
// requests that attempt to send greater than
// maxHeaderListSize bytes of trailers, since
// those requests generate a stream reset.
- ioutil.ReadAll(r.Body)
+ io.ReadAll(r.Body)
r.Body.Close()
},
func(ts *httptest.Server) {
ts.Config.MaxHeaderBytes = 16 << 10
},
- optOnlyServer,
optQuiet,
)
- defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
@@ -1766,7 +1409,7 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
// Make an arbitrary request to ensure we get the server's
// settings frame and initialize peerMaxHeaderListSize.
- req0, err := http.NewRequest("GET", st.ts.URL, nil)
+ req0, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatalf("newRequest: NewRequest: %v", err)
}
@@ -1777,7 +1420,7 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
res0.Body.Close()
res, err := tr.RoundTrip(req)
- if err != wantErr {
+ if !errors.Is(err, wantErr) {
if res != nil {
res.Body.Close()
}
@@ -1800,26 +1443,14 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
}
}
headerListSizeForRequest := func(req *http.Request) (size uint64) {
- contentLen := actualContentLength(req)
- trailers, err := commaSeparatedTrailers(req)
- if err != nil {
- t.Fatalf("headerListSizeForRequest: %v", err)
- }
- cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
- cc.henc = hpack.NewEncoder(&cc.hbuf)
- cc.mu.Lock()
- hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
- cc.mu.Unlock()
- if err != nil {
- t.Fatalf("headerListSizeForRequest: %v", err)
- }
- hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
+ const addGzipHeader = true
+ const peerMaxHeaderListSize = 0xffffffffffffffff
+ _, err := encodeRequestHeaders(req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) {
+ hf := hpack.HeaderField{Name: name, Value: value}
size += uint64(hf.Size())
})
- if len(hdrs) > 0 {
- if _, err := hpackDec.Write(hdrs); err != nil {
- t.Fatalf("headerListSizeForRequest: %v", err)
- }
+ if err != nil {
+ t.Fatal(err)
}
return size
}
@@ -1829,13 +1460,29 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
newRequest := func() *http.Request {
// Body must be non-nil to enable writing trailers.
body := strings.NewReader("hello")
- req, err := http.NewRequest("POST", st.ts.URL, body)
+ req, err := http.NewRequest("POST", ts.URL, body)
if err != nil {
t.Fatalf("newRequest: NewRequest: %v", err)
}
return req
}
+ var (
+ scMu sync.Mutex
+ sc *serverConn
+ )
+ testHookGetServerConn = func(v *serverConn) {
+ scMu.Lock()
+ defer scMu.Unlock()
+ if sc != nil {
+ panic("testHookGetServerConn called multiple times")
+ }
+ sc = v
+ }
+ defer func() {
+ testHookGetServerConn = nil
+ }()
+
// Validate peerMaxHeaderListSize.
req := newRequest()
checkRoundTrip(req, nil, "Initial request")
@@ -1847,16 +1494,16 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
cc.mu.Lock()
peerSize := cc.peerMaxHeaderListSize
cc.mu.Unlock()
- st.scMu.Lock()
- wantSize := uint64(st.sc.maxHeaderListSize())
- st.scMu.Unlock()
+ scMu.Lock()
+ wantSize := uint64(sc.maxHeaderListSize())
+ scMu.Unlock()
if peerSize != wantSize {
t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize)
}
// Sanity check peerSize. (*serverConn) maxHeaderListSize adds
// 320 bytes of padding.
- wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320
+ wantHeaderBytes := uint64(ts.Config.MaxHeaderBytes) + 320
if peerSize != wantHeaderBytes {
t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes)
}
@@ -1900,115 +1547,80 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
}
func TestTransportChecksResponseHeaderListSize(t *testing.T) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if e, ok := err.(StreamError); ok {
- err = e.Cause
- }
- if err != errResponseHeaderListSize {
- size := int64(0)
- if res != nil {
- res.Body.Close()
- for k, vv := range res.Header {
- for _, v := range vv {
- size += int64(len(k)) + int64(len(v)) + 32
- }
- }
- }
- return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+
+ hdr := []string{":status", "200"}
+ large := strings.Repeat("a", 1<<10)
+ for i := 0; i < 5042; i++ {
+ hdr = append(hdr, large, large)
+ }
+ hbf := tc.makeHeaderBlockFragment(hdr...)
+ // Note: this number might change if our hpack implementation changes.
+ // That's fine. This is just a sanity check that our response can fit in a single
+ // header block fragment frame.
+ if size, want := len(hbf), 6329; size != want {
+ t.Fatalf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
+ }
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: hbf,
+ })
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- large := strings.Repeat("a", 1<<10)
- for i := 0; i < 5042; i++ {
- enc.WriteField(hpack.HeaderField{Name: large, Value: large})
- }
- if size, want := buf.Len(), 6329; size != want {
- // Note: this number might change if
- // our hpack implementation
- // changes. That's fine. This is
- // just a sanity check that our
- // response can fit in a single
- // header block fragment frame.
- return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
+ res, err := rt.result()
+ if e, ok := err.(StreamError); ok {
+ err = e.Cause
+ }
+ if err != errResponseHeaderListSize {
+ size := int64(0)
+ if res != nil {
+ res.Body.Close()
+ for k, vv := range res.Header {
+ for _, v := range vv {
+ size += int64(len(k)) + int64(len(v)) + 32
}
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
}
}
+ t.Fatalf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
}
- ct.run()
}
func TestTransportCookieHeaderSplit(t *testing.T) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- req.Header.Add("Cookie", "a=b;c=d; e=f;")
- req.Header.Add("Cookie", "e=f;g=h; ")
- req.Header.Add("Cookie", "i=j")
- _, err := ct.tr.RoundTrip(req)
- return err
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- dec := hpack.NewDecoder(initialHeaderTableSize, nil)
- hfs, err := dec.DecodeFull(f.HeaderBlockFragment())
- if err != nil {
- return err
- }
- got := []string{}
- want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}
- for _, hf := range hfs {
- if hf.Name == "cookie" {
- got = append(got, hf.Value)
- }
- }
- if !reflect.DeepEqual(got, want) {
- t.Errorf("Cookies = %#v, want %#v", got, want)
- }
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ req.Header.Add("Cookie", "a=b;c=d; e=f;")
+ req.Header.Add("Cookie", "e=f;g=h; ")
+ req.Header.Add("Cookie", "i=j")
+ rt := tc.roundTrip(req)
+
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: true,
+ header: http.Header{
+ "cookie": []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"},
+ },
+ })
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "204",
+ ),
+ })
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ if err := rt.err(); err != nil {
+ t.Fatalf("RoundTrip = %v, want success", err)
}
- ct.run()
}
// Test that the Transport returns a typed error from Response.Body.Read calls
@@ -2016,22 +1628,20 @@ func TestTransportCookieHeaderSplit(t *testing.T) {
// a stream error, but others like cancel should be similar)
func TestTransportBodyReadErrorType(t *testing.T) {
doPanic := make(chan bool, 1)
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush() // force headers out
<-doPanic
panic("boom")
},
- optOnlyServer,
optQuiet,
)
- defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
c := &http.Client{Transport: tr}
- res, err := c.Get(st.ts.URL)
+ res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -2055,7 +1665,7 @@ func TestTransportDoubleCloseOnWriteError(t *testing.T) {
conn net.Conn // to close if set
)
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
defer mu.Unlock()
@@ -2063,9 +1673,7 @@ func TestTransportDoubleCloseOnWriteError(t *testing.T) {
conn.Close()
}
},
- optOnlyServer,
)
- defer st.Close()
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
@@ -2082,20 +1690,18 @@ func TestTransportDoubleCloseOnWriteError(t *testing.T) {
}
defer tr.CloseIdleConnections()
c := &http.Client{Transport: tr}
- c.Get(st.ts.URL)
+ c.Get(ts.URL)
}
// Test that the http1 Transport.DisableKeepAlives option is respected
// and connections are closed as soon as idle.
// See golang.org/issue/14008
func TestTransportDisableKeepAlives(t *testing.T) {
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "hi")
},
- optOnlyServer,
)
- defer st.Close()
connClosed := make(chan struct{}) // closed on tls.Conn.Close
tr := &Transport{
@@ -2112,11 +1718,11 @@ func TestTransportDisableKeepAlives(t *testing.T) {
},
}
c := &http.Client{Transport: tr}
- res, err := c.Get(st.ts.URL)
+ res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
- if _, err := ioutil.ReadAll(res.Body); err != nil {
+ if _, err := io.ReadAll(res.Body); err != nil {
t.Fatal(err)
}
defer res.Body.Close()
@@ -2133,14 +1739,12 @@ func TestTransportDisableKeepAlives(t *testing.T) {
// but when things are totally idle, it still needs to close.
func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
const D = 25 * time.Millisecond
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {
time.Sleep(D)
io.WriteString(w, "hi")
},
- optOnlyServer,
)
- defer st.Close()
var dials int32
var conns sync.WaitGroup
@@ -2175,12 +1779,12 @@ func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
}
go func() {
defer reqs.Done()
- res, err := c.Get(st.ts.URL)
+ res, err := c.Get(ts.URL)
if err != nil {
t.Error(err)
return
}
- if _, err := ioutil.ReadAll(res.Body); err != nil {
+ if _, err := io.ReadAll(res.Body); err != nil {
t.Error(err)
return
}
@@ -2224,68 +1828,62 @@ func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
}
func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
- ct := newClientTester(t)
- ct.tr.t1 = &http.Transport{
- ResponseHeaderTimeout: 5 * time.Millisecond,
- }
- ct.client = func() error {
- c := &http.Client{Transport: ct.tr}
- var err error
- var n int64
- const bodySize = 4 << 20
- if body {
- _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
- } else {
- _, err = c.Get("https://dummy.tld/")
+ const bodySize = 4 << 20
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.t1 = &http.Transport{
+ ResponseHeaderTimeout: 5 * time.Millisecond,
}
- if !isTimeout(err) {
- t.Errorf("client expected timeout error; got %#v", err)
- }
- if body && n != bodySize {
- t.Errorf("only read %d bytes of body; want %d", n, bodySize)
- }
- return nil
+ })
+ tc.greet()
+
+ var req *http.Request
+ var reqBody *testRequestBody
+ if body {
+ reqBody = tc.newRequestBody()
+ reqBody.writeBytes(bodySize)
+ reqBody.closeWithError(io.EOF)
+ req, _ = http.NewRequest("POST", "https://dummy.tld/", reqBody)
+ req.Header.Set("Content-Type", "text/foo")
+ } else {
+ req, _ = http.NewRequest("GET", "https://dummy.tld/", nil)
}
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- t.Logf("ReadFrame: %v", err)
- return nil
- }
- switch f := f.(type) {
- case *DataFrame:
- dataLen := len(f.Data())
- if dataLen > 0 {
- if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
- return err
- }
- if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
- return err
- }
- }
- case *RSTStreamFrame:
- if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
- return nil
- }
- }
- }
+
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+
+ tc.writeWindowUpdate(0, bodySize)
+ tc.writeWindowUpdate(rt.streamID(), bodySize)
+
+ if body {
+ tc.wantData(wantData{
+ endStream: true,
+ size: bodySize,
+ multiple: true,
+ })
+ }
+
+ tc.advance(4 * time.Millisecond)
+ if rt.done() {
+ t.Fatalf("RoundTrip is done after 4ms; want still waiting")
+ }
+ tc.advance(1 * time.Millisecond)
+
+ if err := rt.err(); !isTimeout(err) {
+ t.Fatalf("RoundTrip error: %v; want timeout error", err)
}
- ct.run()
}
func TestTransportDisableCompression(t *testing.T) {
const body = "sup"
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
want := http.Header{
"User-Agent": []string{"Go-http-client/2.0"},
}
if !reflect.DeepEqual(r.Header, want) {
t.Errorf("request headers = %v; want %v", r.Header, want)
}
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
@@ -2295,7 +1893,7 @@ func TestTransportDisableCompression(t *testing.T) {
}
defer tr.CloseIdleConnections()
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
@@ -2308,15 +1906,14 @@ func TestTransportDisableCompression(t *testing.T) {
// RFC 7540 section 8.1.2.2
func TestTransportRejectsConnHeaders(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
var got []string
for k := range r.Header {
got = append(got, k)
}
sort.Strings(got)
w.Header().Set("Got-Header", strings.Join(got, ","))
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
@@ -2404,7 +2001,7 @@ func TestTransportRejectsConnHeaders(t *testing.T) {
}
for _, tt := range tests {
- req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req, _ := http.NewRequest("GET", ts.URL, nil)
req.Header[tt.key] = tt.value
res, err := tr.RoundTrip(req)
var got string
@@ -2458,14 +2055,13 @@ func TestTransportRejectsContentLengthWithSign(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", tt.cl[0])
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- req, _ := http.NewRequest("HEAD", st.ts.URL, nil)
+ req, _ := http.NewRequest("HEAD", ts.URL, nil)
res, err := tr.RoundTrip(req)
var got string
@@ -2484,19 +2080,20 @@ func TestTransportRejectsContentLengthWithSign(t *testing.T) {
}
// golang.org/issue/14048
-func TestTransportFailsOnInvalidHeaders(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+// golang.org/issue/64766
+func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
var got []string
for k := range r.Header {
got = append(got, k)
}
sort.Strings(got)
w.Header().Set("Got-Header", strings.Join(got, ","))
- }, optOnlyServer)
- defer st.Close()
+ })
tests := [...]struct {
h http.Header
+ t http.Header
wantErr string
}{
0: {
@@ -2515,14 +2112,23 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) {
h: http.Header{"foo": {"foo\x01bar"}},
wantErr: `invalid HTTP header value for header "foo"`,
},
+ 4: {
+ t: http.Header{"foo": {"foo\x01bar"}},
+ wantErr: `invalid HTTP trailer value for header "foo"`,
+ },
+ 5: {
+ t: http.Header{"x-\r\nda": {"foo\x01bar"}},
+ wantErr: `invalid HTTP trailer name "x-\r\nda"`,
+ },
}
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
for i, tt := range tests {
- req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req, _ := http.NewRequest("GET", ts.URL, nil)
req.Header = tt.h
+ req.Trailer = tt.t
res, err := tr.RoundTrip(req)
var bad bool
if tt.wantErr == "" {
@@ -2549,7 +2155,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) {
// the first Read call's gzip.NewReader returning an error.
func TestGzipReader_DoubleReadCrash(t *testing.T) {
gz := &gzipReader{
- body: ioutil.NopCloser(strings.NewReader("0123456789")),
+ body: io.NopCloser(strings.NewReader("0123456789")),
}
var buf [1]byte
n, err1 := gz.Read(buf[:])
@@ -2568,7 +2174,7 @@ func TestGzipReader_ReadAfterClose(t *testing.T) {
w.Write([]byte("012345679"))
w.Close()
gz := &gzipReader{
- body: ioutil.NopCloser(&body),
+ body: io.NopCloser(&body),
}
var buf [1]byte
n, err := gz.Read(buf[:])
@@ -2658,115 +2264,61 @@ func TestTransportNewTLSConfig(t *testing.T) {
// without END_STREAM, followed by a 0-length DATA frame with
// END_STREAM. Make sure we don't get confused by that. (We did.)
func TestTransportReadHeadResponse(t *testing.T) {
- ct := newClientTester(t)
- clientDone := make(chan struct{})
- ct.client = func() error {
- defer close(clientDone)
- req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- if res.ContentLength != 123 {
- return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength)
- }
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("ReadAll: %v", err)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- t.Logf("ReadFrame: %v", err)
- return nil
- }
- hf, ok := f.(*HeadersFrame)
- if !ok {
- continue
- }
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false, // as the GFE does
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(hf.StreamID, true, nil)
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false, // as the GFE does
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "123",
+ ),
+ })
+ tc.writeData(rt.streamID(), true, nil)
- <-clientDone
- return nil
- }
+ res := rt.response()
+ if res.ContentLength != 123 {
+ t.Fatalf("Content-Length = %d; want 123", res.ContentLength)
}
- ct.run()
+ rt.wantBody(nil)
}
func TestTransportReadHeadResponseWithBody(t *testing.T) {
- // This test use not valid response format.
- // Discarding logger output to not spam tests output.
- log.SetOutput(ioutil.Discard)
+ // This test uses an invalid response format.
+ // Discard logger output to not spam tests output.
+ log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
response := "redirecting to /elsewhere"
- ct := newClientTester(t)
- clientDone := make(chan struct{})
- ct.client = func() error {
- defer close(clientDone)
- req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- if res.ContentLength != int64(len(response)) {
- return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response))
- }
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("ReadAll: %v", err)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- t.Logf("ReadFrame: %v", err)
- return nil
- }
- hf, ok := f.(*HeadersFrame)
- if !ok {
- continue
- }
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(hf.StreamID, true, []byte(response))
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", strconv.Itoa(len(response)),
+ ),
+ })
+ tc.writeData(rt.streamID(), true, []byte(response))
- <-clientDone
- return nil
- }
+ res := rt.response()
+ if res.ContentLength != int64(len(response)) {
+ t.Fatalf("Content-Length = %d; want %d", res.ContentLength, len(response))
}
- ct.run()
+ rt.wantBody(nil)
}
type neverEnding byte
@@ -2784,11 +2336,10 @@ func (b neverEnding) Read(p []byte) (int, error) {
// runs out of flow control tokens)
func TestTransportHandlerBodyClose(t *testing.T) {
const bodySize = 10 << 20
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
r.Body.Close()
io.Copy(w, io.LimitReader(neverEnding('A'), bodySize))
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
@@ -2797,7 +2348,7 @@ func TestTransportHandlerBodyClose(t *testing.T) {
const numReq = 10
for i := 0; i < numReq; i++ {
- req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
+ req, err := http.NewRequest("POST", ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
if err != nil {
t.Fatal(err)
}
@@ -2805,7 +2356,7 @@ func TestTransportHandlerBodyClose(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- n, err := io.Copy(ioutil.Discard, res.Body)
+ n, err := io.Copy(io.Discard, res.Body)
res.Body.Close()
if n != bodySize || err != nil {
t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize)
@@ -2830,7 +2381,7 @@ func TestTransportFlowControl(t *testing.T) {
}
var wrote int64 // updated atomically
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
b := make([]byte, bufLen)
for wrote < total {
n, err := w.Write(b)
@@ -2841,11 +2392,11 @@ func TestTransportFlowControl(t *testing.T) {
}
w.(http.Flusher).Flush()
}
- }, optOnlyServer)
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal("NewRequest error:", err)
}
@@ -2891,190 +2442,128 @@ func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
}
func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
- ct := newClientTester(t)
- clientDone := make(chan struct{})
+ tc := newTestClientConn(t)
+ tc.greet()
const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
const goAwayDebugData = "some debug data"
- ct.client = func() error {
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if failMidBody {
- if err != nil {
- return fmt.Errorf("unexpected client RoundTrip error: %v", err)
- }
- _, err = io.Copy(ioutil.Discard, res.Body)
- res.Body.Close()
- }
- want := GoAwayError{
- LastStreamID: 5,
- ErrCode: goAwayErrCode,
- DebugData: goAwayDebugData,
- }
- if !reflect.DeepEqual(err, want) {
- t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- t.Logf("ReadFrame: %v", err)
- return nil
- }
- hf, ok := f.(*HeadersFrame)
- if !ok {
- continue
- }
- if failMidBody {
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- }
- // Write two GOAWAY frames, to test that the Transport takes
- // the interesting parts of both.
- ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
- ct.fr.WriteGoAway(5, goAwayErrCode, nil)
- ct.sc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- ct.sc.(*net.TCPConn).Close()
- }
- <-clientDone
- return nil
- }
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+
+ if failMidBody {
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "123",
+ ),
+ })
}
- ct.run()
-}
-func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
- ct := newClientTester(t)
+ // Write two GOAWAY frames, to test that the Transport takes
+ // the interesting parts of both.
+ tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
+ tc.writeGoAway(5, goAwayErrCode, nil)
+ tc.closeWrite()
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
+ res, err := rt.result()
+ whence := "RoundTrip"
+ if failMidBody {
+ whence = "Body.Read"
if err != nil {
- return err
+ t.Fatalf("RoundTrip error = %v, want success", err)
}
+ _, err = res.Body.Read(make([]byte, 1))
+ }
- if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
- return fmt.Errorf("body read = %v, %v; want 1, nil", n, err)
- }
- res.Body.Close() // leaving 4999 bytes unread
-
- return nil
+ want := GoAwayError{
+ LastStreamID: 5,
+ ErrCode: goAwayErrCode,
+ DebugData: goAwayDebugData,
}
- ct.server = func() error {
- ct.greet()
+ if !reflect.DeepEqual(err, want) {
+ t.Errorf("%v error = %T: %#v, want %T (%#v)", whence, err, err, want, want)
+ }
+}
- var hf *HeadersFrame
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- }
- var ok bool
- hf, ok = f.(*HeadersFrame)
- if !ok {
- return fmt.Errorf("Got %T; want HeadersFrame", f)
- }
- break
- }
+func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "5000",
+ ),
+ })
+ initialInflow := tc.inflowWindow(0)
+
+ // Two cases:
+ // - Send one DATA frame with 5000 bytes.
+ // - Send two DATA frames with 1 and 4999 bytes each.
+ //
+ // In both cases, the client should consume one byte of data,
+ // refund that byte, then refund the following 4999 bytes.
+ //
+ // In the second case, the server waits for the client to reset the
+ // stream before sending the second DATA frame. This tests the case
+ // where the client receives a DATA frame after it has reset the stream.
+ const streamNotEnded = false
+ if oneDataFrame {
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000))
+ } else {
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1))
+ }
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- initialInflow := ct.inflowWindow(0)
-
- // Two cases:
- // - Send one DATA frame with 5000 bytes.
- // - Send two DATA frames with 1 and 4999 bytes each.
- //
- // In both cases, the client should consume one byte of data,
- // refund that byte, then refund the following 4999 bytes.
- //
- // In the second case, the server waits for the client to reset the
- // stream before sending the second DATA frame. This tests the case
- // where the client receives a DATA frame after it has reset the stream.
- if oneDataFrame {
- ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000))
- } else {
- ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1))
- }
+ res := rt.response()
+ if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
+ t.Fatalf("body read = %v, %v; want 1, nil", n, err)
+ }
+ res.Body.Close() // leaving 4999 bytes unread
+ tc.sync()
- wantRST := true
- wantWUF := true
- if !oneDataFrame {
- wantWUF = false // flow control update is small, and will not be sent
- }
- for wantRST || wantWUF {
- f, err := ct.readNonSettingsFrame()
- if err != nil {
- return err
+ sentAdditionalData := false
+ tc.wantUnorderedFrames(
+ func(f *RSTStreamFrame) bool {
+ if f.ErrCode != ErrCodeCancel {
+ t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
}
- switch f := f.(type) {
- case *RSTStreamFrame:
- if !wantRST {
- return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
- }
- if f.ErrCode != ErrCodeCancel {
- return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
- }
- wantRST = false
- case *WindowUpdateFrame:
- if !wantWUF {
- return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
- }
- if f.Increment != 5000 {
- return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
- }
- wantWUF = false
- default:
- return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
+ if !oneDataFrame {
+ // Send the remaining data now.
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999))
+ sentAdditionalData = true
}
- }
- if !oneDataFrame {
- ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999))
- f, err := ct.readNonSettingsFrame()
- if err != nil {
- return err
+ return true
+ },
+ func(f *PingFrame) bool {
+ return true
+ },
+ func(f *WindowUpdateFrame) bool {
+ if !oneDataFrame && !sentAdditionalData {
+ t.Fatalf("Got WindowUpdateFrame, don't expect one yet")
}
- wuf, ok := f.(*WindowUpdateFrame)
- if !ok || wuf.Increment != 5000 {
- return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f))
+ if f.Increment != 5000 {
+ t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
}
- }
- if err := ct.writeReadPing(); err != nil {
- return err
- }
- if got, want := ct.inflowWindow(0), initialInflow; got != want {
- return fmt.Errorf("connection flow tokens = %v, want %v", got, want)
- }
- return nil
+ return true
+ },
+ )
+
+ if got, want := tc.inflowWindow(0), initialInflow; got != want {
+ t.Fatalf("connection flow tokens = %v, want %v", got, want)
}
- ct.run()
}
// See golang.org/issue/16481
@@ -3090,199 +2579,124 @@ func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
// Issue 16612: adjust flow control on open streams when transport
// receives SETTINGS with INITIAL_WINDOW_SIZE from server.
func TestTransportAdjustsFlowControl(t *testing.T) {
- ct := newClientTester(t)
- clientDone := make(chan struct{})
-
const bodySize = 1 << 20
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- defer close(clientDone)
+ tc := newTestClientConn(t)
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ // Don't write our SETTINGS yet.
- req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
+ body := tc.newRequestBody()
+ body.writeBytes(bodySize)
+ body.closeWithError(io.EOF)
+
+ req, _ := http.NewRequest("POST", "https://dummy.tld/", body)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+
+ gotBytes := int64(0)
+ for {
+ f := readFrame[*DataFrame](t, tc)
+ gotBytes += int64(len(f.Data()))
+ // After we've got half the client's initial flow control window's worth
+ // of request body data, give it just enough flow control to finish.
+ if gotBytes >= initialWindowSize/2 {
+ break
}
- res.Body.Close()
- return nil
}
- ct.server = func() error {
- _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface)))
- if err != nil {
- return fmt.Errorf("reading client preface: %v", err)
- }
- var gotBytes int64
- var sentSettings bool
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- return nil
- default:
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- }
- switch f := f.(type) {
- case *DataFrame:
- gotBytes += int64(len(f.Data()))
- // After we've got half the client's
- // initial flow control window's worth
- // of request body data, give it just
- // enough flow control to finish.
- if gotBytes >= initialWindowSize/2 && !sentSettings {
- sentSettings = true
-
- ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
- ct.fr.WriteWindowUpdate(0, bodySize)
- ct.fr.WriteSettingsAck()
- }
+ tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
+ tc.writeWindowUpdate(0, bodySize)
+ tc.writeSettingsAck()
- if f.StreamEnded() {
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- }
- }
+ tc.wantUnorderedFrames(
+ func(f *SettingsFrame) bool { return true },
+ func(f *DataFrame) bool {
+ gotBytes += int64(len(f.Data()))
+ return f.StreamEnded()
+ },
+ )
+
+ if gotBytes != bodySize {
+ t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize)
}
- ct.run()
+
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
}
// See golang.org/issue/16556
func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
- ct := newClientTester(t)
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "5000",
+ ),
+ })
- unblockClient := make(chan bool, 1)
+ initialConnWindow := tc.inflowWindow(0)
+ initialStreamWindow := tc.inflowWindow(rt.streamID())
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- defer res.Body.Close()
- <-unblockClient
- return nil
+ pad := make([]byte, 5)
+ tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad)
+
+ // Padding flow control should have been returned.
+ if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want {
+ t.Errorf("conn inflow window = %v, want %v", got, want)
+ }
+ if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want {
+ t.Errorf("stream inflow window = %v, want %v", got, want)
}
- ct.server = func() error {
- ct.greet()
-
- var hf *HeadersFrame
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- }
- var ok bool
- hf, ok = f.(*HeadersFrame)
- if !ok {
- return fmt.Errorf("Got %T; want HeadersFrame", f)
- }
- break
- }
-
- initialConnWindow := ct.inflowWindow(0)
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- initialStreamWindow := ct.inflowWindow(hf.StreamID)
- pad := make([]byte, 5)
- ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream
- if err := ct.writeReadPing(); err != nil {
- return err
- }
- // Padding flow control should have been returned.
- if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want {
- t.Errorf("conn inflow window = %v, want %v", got, want)
- }
- if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want {
- t.Errorf("stream inflow window = %v, want %v", got, want)
- }
- unblockClient <- true
- return nil
- }
- ct.run()
-}
+}
// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
// StreamError as a result of the response HEADERS
func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
- ct := newClientTester(t)
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ " content-type", "bogus",
+ ),
+ })
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err == nil {
- res.Body.Close()
- return errors.New("unexpected successful GET")
- }
- want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
- if !reflect.DeepEqual(want, err) {
- t.Errorf("RoundTrip error = %#v; want %#v", err, want)
- }
- return nil
+ err := rt.err()
+ want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
+ if !reflect.DeepEqual(err, want) {
+ t.Fatalf("RoundTrip error = %#v; want %#v", err, want)
}
- ct.server = func() error {
- ct.greet()
-
- hf, err := ct.firstHeaders()
- if err != nil {
- return err
- }
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
-
- for {
- fr, err := ct.readFrame()
- if err != nil {
- return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
- }
- if _, ok := fr.(*SettingsFrame); ok {
- continue
- }
- if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
- t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
- }
- break
- }
- return nil
+ fr := readFrame[*RSTStreamFrame](t, tc)
+ if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol {
+ t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
}
- ct.run()
}
// byteAndEOFReader returns is in an io.Reader which reads one byte
@@ -3307,16 +2721,15 @@ func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
// which returns (non-0, io.EOF) and also needs to set the ContentLength
// explicitly.
func TestTransportBodyDoubleEndStream(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
// Nothing.
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
for i := 0; i < 2; i++ {
- req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a'))
+ req, _ := http.NewRequest("POST", ts.URL, byteAndEOFReader('a'))
req.ContentLength = 1
res, err := tr.RoundTrip(req)
if err != nil {
@@ -3428,11 +2841,15 @@ func TestTransportRequestPathPseudo(t *testing.T) {
},
}
for i, tt := range tests {
- cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
- cc.henc = hpack.NewEncoder(&cc.hbuf)
- cc.mu.Lock()
- hdrs, err := cc.encodeHeaders(tt.req, false, "", -1)
- cc.mu.Unlock()
+ hbuf := &bytes.Buffer{}
+ henc := hpack.NewEncoder(hbuf)
+
+ const addGzipHeader = false
+ const peerMaxHeaderListSize = 0xffffffffffffffff
+ _, err := encodeRequestHeaders(tt.req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) {
+ henc.WriteField(hpack.HeaderField{Name: name, Value: value})
+ })
+ hdrs := hbuf.Bytes()
var got result
hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
if f.Name == ":path" {
@@ -3459,16 +2876,17 @@ func TestTransportRequestPathPseudo(t *testing.T) {
// before we've determined that the ClientConn is usable.
func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
const body = "foo"
- req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body)))
+ req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body)))
cc := &ClientConn{
closed: true,
reqHeaderMu: make(chan struct{}, 1),
+ t: &Transport{},
}
_, err := cc.RoundTrip(req)
if err != errClientConnUnusable {
t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err)
}
- slurp, err := ioutil.ReadAll(req.Body)
+ slurp, err := io.ReadAll(req.Body)
if err != nil {
t.Errorf("ReadAll = %v", err)
}
@@ -3478,12 +2896,11 @@ func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
}
func TestClientConnPing(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
- defer st.Close()
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {})
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
ctx := context.Background()
- cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
+ cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}
@@ -3501,7 +2918,7 @@ func TestTransportCancelDataResponseRace(t *testing.T) {
clientGotResponse := make(chan bool, 1)
const msg = "Hello."
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/hello") {
time.Sleep(50 * time.Millisecond)
io.WriteString(w, msg)
@@ -3516,29 +2933,28 @@ func TestTransportCancelDataResponseRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
}
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
c := &http.Client{Transport: tr}
- req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req, _ := http.NewRequest("GET", ts.URL, nil)
req.Cancel = cancel
res, err := c.Do(req)
clientGotResponse <- true
if err != nil {
t.Fatal(err)
}
- if _, err = io.Copy(ioutil.Discard, res.Body); err == nil {
+ if _, err = io.Copy(io.Discard, res.Body); err == nil {
t.Fatal("unexpected success")
}
- res, err = c.Get(st.ts.URL + "/hello")
+ res, err = c.Get(ts.URL + "/hello")
if err != nil {
t.Fatal(err)
}
- slurp, err := ioutil.ReadAll(res.Body)
+ slurp, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -3550,21 +2966,20 @@ func TestTransportCancelDataResponseRace(t *testing.T) {
// Issue 21316: It should be safe to reuse an http.Request after the
// request has completed.
func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
io.WriteString(w, "body")
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req, _ := http.NewRequest("GET", ts.URL, nil)
resp, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
- if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil {
+ if _, err = io.Copy(io.Discard, resp.Body); err != nil {
t.Fatalf("error reading response body: %v", err)
}
if err := resp.Body.Close(); err != nil {
@@ -3576,34 +2991,30 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
}
func TestTransportCloseAfterLostPing(t *testing.T) {
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.tr.PingTimeout = 1 * time.Second
- ct.tr.ReadIdleTimeout = 1 * time.Second
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- _, err := ct.tr.RoundTrip(req)
- if err == nil || !strings.Contains(err.Error(), "client connection lost") {
- return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- <-clientDone
- return nil
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.PingTimeout = 1 * time.Second
+ tr.ReadIdleTimeout = 1 * time.Second
+ })
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+
+ tc.advance(1 * time.Second)
+ tc.wantFrameType(FramePing)
+
+ tc.advance(1 * time.Second)
+ err := rt.err()
+ if err == nil || !strings.Contains(err.Error(), "client connection lost") {
+ t.Fatalf("expected to get error about \"connection lost\", got %v", err)
}
- ct.run()
}
func TestTransportPingWriteBlocks(t *testing.T) {
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {},
- optOnlyServer,
)
- defer st.Close()
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
@@ -3622,424 +3033,327 @@ func TestTransportPingWriteBlocks(t *testing.T) {
}
defer tr.CloseIdleConnections()
c := &http.Client{Transport: tr}
- _, err := c.Get(st.ts.URL)
+ _, err := c.Get(ts.URL)
if err == nil {
t.Fatalf("Get = nil, want error")
}
}
-func TestTransportPingWhenReading(t *testing.T) {
- testCases := []struct {
- name string
- readIdleTimeout time.Duration
- deadline time.Duration
- expectedPingCount int
- }{
- {
- name: "two pings",
- readIdleTimeout: 100 * time.Millisecond,
- deadline: time.Second,
- expectedPingCount: 2,
- },
- {
- name: "zero ping",
- readIdleTimeout: time.Second,
- deadline: 200 * time.Millisecond,
- expectedPingCount: 0,
- },
- {
- name: "0 readIdleTimeout means no ping",
- readIdleTimeout: 0 * time.Millisecond,
- deadline: 500 * time.Millisecond,
- expectedPingCount: 0,
- },
- }
-
- for _, tc := range testCases {
- tc := tc // capture range variable
- t.Run(tc.name, func(t *testing.T) {
- testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount)
- })
- }
-}
+func TestTransportPingWhenReadingMultiplePings(t *testing.T) {
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.ReadIdleTimeout = 1000 * time.Millisecond
+ })
+ tc.greet()
-func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) {
- var pingCount int
- ct := newClientTester(t)
- ct.tr.ReadIdleTimeout = readIdleTimeout
+ ctx, cancel := context.WithCancel(context.Background())
+ req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
- ctx, cancel := context.WithTimeout(context.Background(), deadline)
- defer cancel()
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
- }
- _, err = ioutil.ReadAll(res.Body)
- if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) {
- return nil
+ for i := 0; i < 5; i++ {
+ // No ping yet...
+ tc.advance(999 * time.Millisecond)
+ if f := tc.readFrame(); f != nil {
+ t.Fatalf("unexpected frame: %v", f)
}
- cancel()
- return err
+ // ...ping now.
+ tc.advance(1 * time.Millisecond)
+ f := readFrame[*PingFrame](t, tc)
+ tc.writePing(true, f.Data)
}
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- var streamID uint32
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-ctx.Done():
- // If the client's done, it
- // will have reported any
- // errors on its side.
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- streamID = f.StreamID
- case *PingFrame:
- pingCount++
- if pingCount == expectedPingCount {
- if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil {
- return err
- }
- }
- if err := ct.fr.WritePing(true, f.Data); err != nil {
- return err
- }
- case *RSTStreamFrame:
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
+ // Cancel the request, Transport resets it and returns an error from body reads.
+ cancel()
+ tc.sync()
+
+ tc.wantFrameType(FrameRSTStream)
+ _, err := rt.readBody()
+ if err == nil {
+ t.Fatalf("Response.Body.Read() = %v, want error", err)
}
- ct.run()
}
-func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) {
- ln := newLocalListener(t)
- defer ln.Close()
+func TestTransportPingWhenReadingPingDisabled(t *testing.T) {
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.ReadIdleTimeout = 0 // PINGs disabled
+ })
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
- var (
- mu sync.Mutex
- count int
- conns []net.Conn
- )
- var wg sync.WaitGroup
- tr := &Transport{
- TLSClientConfig: tlsConfigInsecure,
- }
- tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
- mu.Lock()
- defer mu.Unlock()
- count++
- cc, err := net.Dial("tcp", ln.Addr().String())
- if err != nil {
- return nil, fmt.Errorf("dial error: %v", err)
- }
- conns = append(conns, cc)
- sc, err := ln.Accept()
- if err != nil {
- return nil, fmt.Errorf("accept error: %v", err)
- }
- conns = append(conns, sc)
- ct := &clientTester{
- t: t,
- tr: tr,
- cc: cc,
- sc: sc,
- fr: NewFramer(sc, sc),
- }
- wg.Add(1)
- go func(count int) {
- defer wg.Done()
- server(count, ct)
- }(count)
- return cc, nil
+ // No PING is sent, even after a long delay.
+ tc.advance(1 * time.Minute)
+ if f := tc.readFrame(); f != nil {
+ t.Fatalf("unexpected frame: %v", f)
}
+}
- client(tr)
- tr.CloseIdleConnections()
- ln.Close()
- for _, c := range conns {
- c.Close()
+func TestTransportRetryAfterGOAWAYNoRetry(t *testing.T) {
+ tt := newTestTransport(t)
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // First attempt: Server sends a GOAWAY with an error and
+ // a MaxStreamID less than the request ID.
+ // This probably indicates that there was something wrong with our request,
+ // so we don't retry it.
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeGoAway(0 /*max id*/, ErrCodeInternal, nil)
+ if rt.err() == nil {
+ t.Fatalf("after GOAWAY, RoundTrip is not done, want error")
}
- wg.Wait()
}
-func TestTransportRetryAfterGOAWAY(t *testing.T) {
- client := func(tr *Transport) {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := tr.RoundTrip(req)
- if res != nil {
- res.Body.Close()
- if got := res.Header.Get("Foo"); got != "bar" {
- err = fmt.Errorf("foo header = %q; want bar", got)
- }
- }
- if err != nil {
- t.Errorf("RoundTrip: %v", err)
- }
- }
+func TestTransportRetryAfterGOAWAYRetry(t *testing.T) {
+ tt := newTestTransport(t)
- server := func(count int, ct *clientTester) {
- switch count {
- case 1:
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- t.Errorf("server1 failed reading HEADERS: %v", err)
- return
- }
- t.Logf("server1 got %v", hf)
- if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
- t.Errorf("server1 failed writing GOAWAY: %v", err)
- return
- }
- case 2:
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- t.Errorf("server2 failed reading HEADERS: %v", err)
- return
- }
- t.Logf("server2 got %v", hf)
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
- err = ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- if err != nil {
- t.Errorf("server2 failed writing response HEADERS: %v", err)
- }
- default:
- t.Errorf("unexpected number of dials")
- return
- }
- }
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // First attempt: Server sends a GOAWAY with ErrCodeNo and
+ // a MaxStreamID less than the request ID.
+ // We take the server at its word that nothing has really gone wrong,
+ // and retry the request.
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeGoAway(0 /*max id*/, ErrCodeNo, nil)
+ if rt.done() {
+ t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
+ }
+
+ // Second attempt succeeds on a new connection.
+ tc = tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+
+ rt.wantStatus(200)
+}
- testClientMultipleDials(t, client, server)
+func TestTransportRetryAfterGOAWAYSecondRequest(t *testing.T) {
+ tt := newTestTransport(t)
+
+ // First request succeeds.
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt1 := tt.roundTrip(req)
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.wantFrameType(FrameSettings) // Settings ACK
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt1.wantStatus(200)
+
+ // Second request: Server sends a GOAWAY with
+ // a MaxStreamID less than the request ID.
+ // The server says it didn't see this request,
+ // so we retry it on a new connection.
+ req, _ = http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt2 := tt.roundTrip(req)
+
+ // Second request, first attempt.
+ tc.wantHeaders(wantHeader{
+ streamID: 3,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeGoAway(1 /*max id*/, ErrCodeProtocol, nil)
+ if rt2.done() {
+ t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
+ }
+
+ // Second request, second attempt.
+ tc = tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt2.wantStatus(200)
}
func TestTransportRetryAfterRefusedStream(t *testing.T) {
- clientDone := make(chan struct{})
- client := func(tr *Transport) {
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- resp, err := tr.RoundTrip(req)
- if err != nil {
- t.Errorf("RoundTrip: %v", err)
- return
- }
- resp.Body.Close()
- if resp.StatusCode != 204 {
- t.Errorf("Status = %v; want 204", resp.StatusCode)
- return
- }
+ tt := newTestTransport(t)
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // First attempt: Server sends a RST_STREAM.
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.wantFrameType(FrameSettings) // settings ACK
+ tc.writeRSTStream(1, ErrCodeRefusedStream)
+ if rt.done() {
+ t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying")
}
- server := func(_ int, ct *clientTester) {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- var count int
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- default:
- t.Error(err)
- }
- return
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- t.Errorf("headers should have END_HEADERS be ended: %v", f)
- return
- }
- count++
- if count == 1 {
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
- } else {
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- default:
- t.Errorf("Unexpected client frame %v", f)
- return
- }
- }
- }
+ // Second attempt succeeds on the same connection.
+ tc.wantHeaders(wantHeader{
+ streamID: 3,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 3,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "204",
+ ),
+ })
- testClientMultipleDials(t, client, server)
+ rt.wantStatus(204)
}
func TestTransportRetryHasLimit(t *testing.T) {
- // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s.
- if testing.Short() {
- t.Skip("skipping long test in short mode")
- }
- retryBackoffHook = func(d time.Duration) *time.Timer {
- return time.NewTimer(0) // fires immediately
- }
- defer func() {
- retryBackoffHook = nil
- }()
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- resp, err := ct.tr.RoundTrip(req)
- if err == nil {
- return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
+ tt := newTestTransport(t)
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // First attempt: Server sends a GOAWAY.
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+
+ var totalDelay time.Duration
+ count := 0
+ for streamID := uint32(1); ; streamID += 2 {
+ count++
+ tc.wantHeaders(wantHeader{
+ streamID: streamID,
+ endStream: true,
+ })
+ if streamID == 1 {
+ tc.writeSettings()
+ tc.wantFrameType(FrameSettings) // settings ACK
}
- t.Logf("expected error, got: %v", err)
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
+ tc.writeRSTStream(streamID, ErrCodeRefusedStream)
+
+ d, scheduled := tt.group.TimeUntilEvent()
+ if !scheduled {
+ if streamID == 1 {
+ continue
}
+ break
}
+ totalDelay += d
+ if totalDelay > 5*time.Minute {
+ t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay)
+ }
+ tt.advance(d)
+ }
+ if got, want := count, 5; got < count {
+ t.Errorf("RoundTrip made %v attempts, want at least %v", got, want)
+ }
+ if rt.err() == nil {
+ t.Errorf("RoundTrip succeeded, want error")
}
- ct.run()
}
func TestTransportResponseDataBeforeHeaders(t *testing.T) {
- // This test use not valid response format.
- // Discarding logger output to not spam tests output.
- log.SetOutput(ioutil.Discard)
- defer log.SetOutput(os.Stderr)
+ // Discard log output complaining about protocol error.
+ log.SetOutput(io.Discard)
+ t.Cleanup(func() { log.SetOutput(os.Stderr) }) // after other cleanup is done
+
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ // First request is normal to ensure the check is per stream and not per connection.
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt1 := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt1.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt1.wantStatus(200)
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
- // First request is normal to ensure the check is per stream and not per connection.
- _, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip expected no error, got: %v", err)
- }
- // Second request returns a DATA frame with no HEADERS.
- resp, err := ct.tr.RoundTrip(req)
- if err == nil {
- return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
- }
- if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
- return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err == io.EOF {
- return nil
- } else if err != nil {
- return err
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame:
- case *HeadersFrame:
- switch f.StreamID {
- case 1:
- // Send a valid response to first request.
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- case 3:
- ct.fr.WriteData(f.StreamID, true, []byte("payload"))
- }
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
+ // Second request returns a DATA frame with no HEADERS.
+ rt2 := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+ tc.writeData(rt2.streamID(), true, []byte("payload"))
+ if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol {
+ t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err)
}
- ct.run()
}
func TestTransportMaxFrameReadSize(t *testing.T) {
@@ -4053,39 +3367,27 @@ func TestTransportMaxFrameReadSize(t *testing.T) {
maxReadFrameSize: 1024,
want: minMaxFrameSize,
}} {
- ct := newClientTester(t)
- ct.tr.MaxReadFrameSize = test.maxReadFrameSize
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
- ct.tr.RoundTrip(req)
- return nil
- }
- ct.server = func() error {
- defer ct.cc.(*net.TCPConn).Close()
- ct.greet()
- var got uint32
- ct.settings.ForeachSetting(func(s Setting) error {
- switch s.ID {
- case SettingMaxFrameSize:
- got = s.Val
- }
- return nil
+ t.Run(fmt.Sprint(test.maxReadFrameSize), func(t *testing.T) {
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.MaxReadFrameSize = test.maxReadFrameSize
})
- if got != test.want {
+
+ fr := readFrame[*SettingsFrame](t, tc)
+ got, ok := fr.Value(SettingMaxFrameSize)
+ if !ok {
+ t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want)
+ } else if got != test.want {
t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
}
- return nil
- }
- ct.run()
+ })
}
}
func TestTransportRequestsLowServerLimit(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- }, optOnlyServer, func(s *Server) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(s *Server) {
s.MaxConcurrentStreams = 1
})
- defer st.Close()
var (
connCountMu sync.Mutex
@@ -4104,7 +3406,7 @@ func TestTransportRequestsLowServerLimit(t *testing.T) {
const reqCount = 3
for i := 0; i < reqCount; i++ {
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
@@ -4129,324 +3431,115 @@ func TestTransportRequestsLowServerLimit(t *testing.T) {
func TestTransportRequestsStallAtServerLimit(t *testing.T) {
const maxConcurrent = 2
- greet := make(chan struct{}) // server sends initial SETTINGS frame
- gotRequest := make(chan struct{}) // server received a request
- clientDone := make(chan struct{})
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.StrictMaxConcurrentStreams = true
+ })
+ tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+
cancelClientRequest := make(chan struct{})
- // Collect errors from goroutines.
- var wg sync.WaitGroup
- errs := make(chan error, 100)
- defer func() {
- wg.Wait()
- close(errs)
- for err := range errs {
- t.Error(err)
+ // Start maxConcurrent+2 requests.
+ // The server does not respond to any of them yet.
+ var rts []*testRoundTrip
+ for k := 0; k < maxConcurrent+2; k++ {
+ req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
+ if k == maxConcurrent {
+ req.Cancel = cancelClientRequest
+ }
+ rt := tc.roundTrip(req)
+ rts = append(rts, rt)
+
+ if k < maxConcurrent {
+ // We are under the stream limit, so the client sends the request.
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: true,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{fmt.Sprintf("/%d", k)},
+ },
+ })
+ } else {
+ // We have reached the stream limit,
+ // so the client cannot send the request.
+ if fr := tc.readFrame(); fr != nil {
+ t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr)
+ }
}
- }()
- // We will send maxConcurrent+2 requests. This checker goroutine waits for the
- // following stages:
- // 1. The first maxConcurrent requests are received by the server.
- // 2. The client will cancel the next request
- // 3. The server is unblocked so it can service the first maxConcurrent requests
- // 4. The client will send the final request
- wg.Add(1)
- unblockClient := make(chan struct{})
- clientRequestCancelled := make(chan struct{})
- unblockServer := make(chan struct{})
- go func() {
- defer wg.Done()
- // Stage 1.
- for k := 0; k < maxConcurrent; k++ {
- <-gotRequest
- }
- // Stage 2.
- close(unblockClient)
- <-clientRequestCancelled
- // Stage 3: give some time for the final RoundTrip call to be scheduled and
- // verify that the final request is not sent.
- time.Sleep(50 * time.Millisecond)
- select {
- case <-gotRequest:
- errs <- errors.New("last request did not stall")
- close(unblockServer)
- return
- default:
+ if rt.done() {
+ t.Fatalf("rt %v done", k)
}
- close(unblockServer)
- // Stage 4.
- <-gotRequest
- }()
+ }
- ct := newClientTester(t)
- ct.tr.StrictMaxConcurrentStreams = true
- ct.client = func() error {
- var wg sync.WaitGroup
- defer func() {
- wg.Wait()
- close(clientDone)
- ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- ct.cc.(*net.TCPConn).Close()
- }
- }()
- for k := 0; k < maxConcurrent+2; k++ {
- wg.Add(1)
- go func(k int) {
- defer wg.Done()
- // Don't send the second request until after receiving SETTINGS from the server
- // to avoid a race where we use the default SettingMaxConcurrentStreams, which
- // is much larger than maxConcurrent. We have to send the first request before
- // waiting because the first request triggers the dial and greet.
- if k > 0 {
- <-greet
- }
- // Block until maxConcurrent requests are sent before sending any more.
- if k >= maxConcurrent {
- <-unblockClient
- }
- body := newStaticCloseChecker("")
- req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
- if k == maxConcurrent {
- // This request will be canceled.
- req.Cancel = cancelClientRequest
- close(cancelClientRequest)
- _, err := ct.tr.RoundTrip(req)
- close(clientRequestCancelled)
- if err == nil {
- errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
- return
- }
- } else {
- resp, err := ct.tr.RoundTrip(req)
- if err != nil {
- errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
- return
- }
- ioutil.ReadAll(resp.Body)
- resp.Body.Close()
- if resp.StatusCode != 204 {
- errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
- return
- }
- }
- if err := body.isClosed(); err != nil {
- errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
- }
- }(k)
- }
- return nil
+ // Cancel the maxConcurrent'th request.
+ // The request should fail.
+ close(cancelClientRequest)
+ tc.sync()
+ if err := rts[maxConcurrent].err(); err == nil {
+ t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent)
}
- ct.server = func() error {
- var wg sync.WaitGroup
- defer wg.Wait()
+ // No requests should be complete, except for the canceled one.
+ for i, rt := range rts {
+ if i != maxConcurrent && rt.done() {
+ t.Fatalf("RoundTrip(%d) is done, but should not be", i)
+ }
+ }
- ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
-
- // Server write loop.
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- writeResp := make(chan uint32, maxConcurrent+1)
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- <-unblockServer
- for id := range writeResp {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: id,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- }()
-
- // Server read loop.
- var nreq int
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it will have reported any errors on its side.
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame:
- case *SettingsFrame:
- // Wait for the client SETTINGS ack until ending the greet.
- close(greet)
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- gotRequest <- struct{}{}
- nreq++
- writeResp <- f.StreamID
- if nreq == maxConcurrent+1 {
- close(writeResp)
- }
- case *DataFrame:
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
- }
-
- ct.run()
-}
+ // Server responds to a request, unblocking the last one.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rts[0].streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ tc.wantHeaders(wantHeader{
+ streamID: rts[maxConcurrent+1].streamID(),
+ endStream: true,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{fmt.Sprintf("/%d", maxConcurrent+1)},
+ },
+ })
+ rts[0].wantStatus(200)
+}
func TestTransportMaxDecoderHeaderTableSize(t *testing.T) {
- ct := newClientTester(t)
var reqSize, resSize uint32 = 8192, 16384
- ct.tr.MaxDecoderHeaderTableSize = reqSize
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- cc, err := ct.tr.NewClientConn(ct.cc)
- if err != nil {
- return err
- }
- _, err = cc.RoundTrip(req)
- if err != nil {
- return err
- }
- if got, want := cc.peerMaxHeaderTableSize, resSize; got != want {
- return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want)
- }
- return nil
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.MaxDecoderHeaderTableSize = reqSize
+ })
+
+ fr := readFrame[*SettingsFrame](t, tc)
+ if v, ok := fr.Value(SettingHeaderTableSize); !ok {
+ t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting")
+ } else if v != reqSize {
+ t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize)
}
- ct.server = func() error {
- buf := make([]byte, len(ClientPreface))
- _, err := io.ReadFull(ct.sc, buf)
- if err != nil {
- return fmt.Errorf("reading client preface: %v", err)
- }
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- sf, ok := f.(*SettingsFrame)
- if !ok {
- ct.t.Fatalf("wanted client settings frame; got %v", f)
- _ = sf // stash it away?
- }
- var found bool
- err = sf.ForeachSetting(func(s Setting) error {
- if s.ID == SettingHeaderTableSize {
- found = true
- if got, want := s.Val, reqSize; got != want {
- return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want)
- }
- }
- return nil
- })
- if err != nil {
- return err
- }
- if !found {
- return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting")
- }
- if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil {
- ct.t.Fatal(err)
- }
- if err := ct.fr.WriteSettingsAck(); err != nil {
- ct.t.Fatal(err)
- }
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ tc.writeSettings(Setting{SettingHeaderTableSize, resSize})
+ tc.cc.mu.Lock()
+ defer tc.cc.mu.Unlock()
+ if got, want := tc.cc.peerMaxHeaderTableSize, resSize; got != want {
+ t.Fatalf("peerHeaderTableSize = %d, want %d", got, want)
}
- ct.run()
}
func TestTransportMaxEncoderHeaderTableSize(t *testing.T) {
- ct := newClientTester(t)
var peerAdvertisedMaxHeaderTableSize uint32 = 16384
- ct.tr.MaxEncoderHeaderTableSize = 8192
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- cc, err := ct.tr.NewClientConn(ct.cc)
- if err != nil {
- return err
- }
- _, err = cc.RoundTrip(req)
- if err != nil {
- return err
- }
- if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want {
- return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
- }
- return nil
- }
- ct.server = func() error {
- buf := make([]byte, len(ClientPreface))
- _, err := io.ReadFull(ct.sc, buf)
- if err != nil {
- return fmt.Errorf("reading client preface: %v", err)
- }
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- sf, ok := f.(*SettingsFrame)
- if !ok {
- ct.t.Fatalf("wanted client settings frame; got %v", f)
- _ = sf // stash it away?
- }
- if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil {
- ct.t.Fatal(err)
- }
- if err := ct.fr.WriteSettingsAck(); err != nil {
- ct.t.Fatal(err)
- }
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.MaxEncoderHeaderTableSize = 8192
+ })
+ tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize})
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ if got, want := tc.cc.henc.MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want {
+ t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
}
- ct.run()
}
func TestAuthorityAddr(t *testing.T) {
@@ -4480,7 +3573,7 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
writeErr := make(chan error, 1)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush()
var sum int64
for i := 0; i < 100; i++ {
@@ -4493,13 +3586,12 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
}
t.Logf("wrote all %d bytes", sum)
writeErr <- nil
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
c := &http.Client{Transport: tr}
- res, err := c.Get(st.ts.URL)
+ res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -4530,61 +3622,43 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
// Issue 18891: make sure Request.Body == NoBody means no DATA frame
// is ever sent, even if empty.
func TestTransportNoBodyMeansNoDATA(t *testing.T) {
- ct := newClientTester(t)
-
- unblockClient := make(chan bool)
-
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
- ct.tr.RoundTrip(req)
- <-unblockClient
- return nil
- }
- ct.server = func() error {
- defer close(unblockClient)
- defer ct.cc.(*net.TCPConn).Close()
- ct.greet()
-
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f := f.(type) {
- default:
- return fmt.Errorf("Got %T; want HeadersFrame", f)
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- case *HeadersFrame:
- if !f.StreamEnded() {
- return fmt.Errorf("got headers frame without END_STREAM")
- }
- return nil
- }
- }
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
+ rt := tc.roundTrip(req)
+
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: true, // END_STREAM should be set when body is http.NoBody
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{"/"},
+ },
+ })
+ if fr := tc.readFrame(); fr != nil {
+ t.Fatalf("unexpected frame after headers: %v", fr)
}
- ct.run()
}
func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
- defer disableGoroutineTracking()()
+ disableGoroutineTracking(b)
b.ReportAllocs()
- st := newServerTester(b,
+ ts := newTestServer(b,
func(w http.ResponseWriter, r *http.Request) {
for i := 0; i < nResHeader; i++ {
name := fmt.Sprint("A-", i)
w.Header().Set(name, "*")
}
},
- optOnlyServer,
optQuiet,
)
- defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
b.Fatal(err)
}
@@ -4620,16 +3694,15 @@ func (r infiniteReader) Read(b []byte) (int, error) {
// Issue 20521: it is not an error to receive a response and end stream
// from the server without the body being consumed.
func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
// The request body needs to be big enough to trigger flow control.
- req, _ := http.NewRequest("PUT", st.ts.URL, infiniteReader{})
+ req, _ := http.NewRequest("PUT", ts.URL, infiniteReader{})
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
@@ -4642,41 +3715,22 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
// Verify transport doesn't crash when receiving bogus response lacking a :status header.
// Issue 22880.
func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- _, err := ct.tr.RoundTrip(req)
- const substr = "malformed response from server: missing status pseudo header"
- if !strings.Contains(fmt.Sprint(err), substr) {
- return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
-
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false, // we'll send some DATA to try to crash the transport
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(f.StreamID, true, []byte("payload"))
- return nil
- }
- }
- }
- ct.run()
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false, // we'll send some DATA to try to crash the transport
+ BlockFragment: tc.makeHeaderBlockFragment(
+ "content-type", "text/html", // no :status header
+ ),
+ })
+ tc.writeData(rt.streamID(), true, []byte("payload"))
}
func BenchmarkClientRequestHeaders(b *testing.B) {
@@ -4701,10 +3755,10 @@ func BenchmarkDownloadFrameSize(b *testing.B) {
b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) })
}
func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
- defer disableGoroutineTracking()()
+ disableGoroutineTracking(b)
const transferSize = 1024 * 1024 * 1024 // must be multiple of 1M
b.ReportAllocs()
- st := newServerTester(b,
+ ts := newTestServer(b,
func(w http.ResponseWriter, r *http.Request) {
// test 1GB transfer
w.Header().Set("Content-Length", strconv.Itoa(transferSize))
@@ -4715,12 +3769,11 @@ func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
}
}, optQuiet,
)
- defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize}
defer tr.CloseIdleConnections()
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
b.Fatal(err)
}
@@ -4779,7 +3832,7 @@ func testClientConnClose(t *testing.T, closeMode closeMode) {
closeDone := make(chan struct{})
beforeHeader := func() {}
bodyWrite := func(w http.ResponseWriter) {}
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
defer close(handlerDone)
beforeHeader()
w.WriteHeader(http.StatusOK)
@@ -4796,13 +3849,12 @@ func testClientConnClose(t *testing.T, closeMode closeMode) {
t.Error("expected connection closed by client")
}
}
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
ctx := context.Background()
- cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
- req, err := http.NewRequest("GET", st.ts.URL, nil)
+ cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false)
+ req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
@@ -4902,7 +3954,7 @@ func testClientConnClose(t *testing.T, closeMode closeMode) {
case closeAtHeaders, closeAtBody:
if closeMode == closeAtBody {
go close(sendBody)
- if _, err := io.Copy(ioutil.Discard, res.Body); err == nil {
+ if _, err := io.Copy(io.Discard, res.Body); err == nil {
t.Error("expected a Copy error, got nil")
}
}
@@ -4953,7 +4005,7 @@ func TestClientConnShutdownCancel(t *testing.T) {
func TestTransportUsesGetBodyWhenPresent(t *testing.T) {
calls := 0
someBody := func() io.ReadCloser {
- return struct{ io.ReadCloser }{ioutil.NopCloser(bytes.NewReader(nil))}
+ return struct{ io.ReadCloser }{io.NopCloser(bytes.NewReader(nil))}
}
req := &http.Request{
Body: someBody(),
@@ -5024,95 +4076,42 @@ func (r *errReader) Read(p []byte) (int, error) {
}
func testTransportBodyReadError(t *testing.T, body []byte) {
- if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
- // So far we've only seen this be flaky on Windows and Plan 9,
- // perhaps due to TCP behavior on shutdowns while
- // unread data is in flight. This test should be
- // fixed, but a skip is better than annoying people
- // for now.
- t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS)
- }
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ bodyReadError := errors.New("body read error")
+ b := tc.newRequestBody()
+ b.Write(body)
+ b.closeWithError(bodyReadError)
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", b)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ var receivedBody []byte
+readFrames:
+ for {
+ switch f := tc.readFrame().(type) {
+ case *DataFrame:
+ receivedBody = append(receivedBody, f.Data()...)
+ case *RSTStreamFrame:
+ break readFrames
+ default:
+ t.Fatalf("unexpected frame: %v", f)
+ case nil:
+ t.Fatalf("transport is idle, want RST_STREAM")
}
- defer close(clientDone)
+ }
+ if !bytes.Equal(receivedBody, body) {
+ t.Fatalf("body: %q; expected %q", receivedBody, body)
+ }
- checkNoStreams := func() error {
- cp, ok := ct.tr.connPool().(*clientConnPool)
- if !ok {
- return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
- }
- cp.mu.Lock()
- defer cp.mu.Unlock()
- conns, ok := cp.conns["dummy.tld:443"]
- if !ok {
- return fmt.Errorf("missing connection")
- }
- if len(conns) != 1 {
- return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
- }
- if activeStreams(conns[0]) != 0 {
- return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
- }
- return nil
- }
- bodyReadError := errors.New("body read error")
- body := &errReader{body, bodyReadError}
- req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
- if err != nil {
- return err
- }
- _, err = ct.tr.RoundTrip(req)
- if err != bodyReadError {
- return fmt.Errorf("err = %v; want %v", err, bodyReadError)
- }
- if err = checkNoStreams(); err != nil {
- return err
- }
- return nil
+ if err := rt.err(); err != bodyReadError {
+ t.Fatalf("err = %v; want %v", err, bodyReadError)
}
- ct.server = func() error {
- ct.greet()
- var receivedBody []byte
- var resetCount int
- for {
- f, err := ct.fr.ReadFrame()
- t.Logf("server: ReadFrame = %v, %v", f, err)
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- if bytes.Compare(receivedBody, body) != 0 {
- return fmt.Errorf("body: %q; expected %q", receivedBody, body)
- }
- if resetCount != 1 {
- return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
- }
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- case *DataFrame:
- receivedBody = append(receivedBody, f.Data()...)
- case *RSTStreamFrame:
- resetCount++
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
+
+ if got := activeStreams(tc.cc); got != 0 {
+ t.Fatalf("active streams count: %v; want 0", got)
}
- ct.run()
}
func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
@@ -5125,59 +4124,18 @@ func TestTransportBodyEagerEndStream(t *testing.T) {
const reqBody = "some request body"
const resBody = "some response body"
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- body := strings.NewReader(reqBody)
- req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
- if err != nil {
- return err
- }
- _, err = ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
+ tc := newTestClientConn(t)
+ tc.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
+ body := strings.NewReader(reqBody)
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
+ tc.roundTrip(req)
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- case *DataFrame:
- if !f.StreamEnded() {
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
- return fmt.Errorf("data frame without END_STREAM %v", f)
- }
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.Header().StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(f.StreamID, true, []byte(resBody))
- return nil
- case *RSTStreamFrame:
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
+ tc.wantFrameType(FrameHeaders)
+ f := readFrame[*DataFrame](t, tc)
+ if !f.StreamEnded() {
+ t.Fatalf("data frame without END_STREAM %v", f)
}
- ct.run()
}
type chunkReader struct {
@@ -5217,15 +4175,14 @@ func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) {
}
func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
r.Body.Read(make([]byte, 6))
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- req, _ := http.NewRequest("POST", st.ts.URL, body)
+ req, _ := http.NewRequest("POST", ts.URL, body)
req.ContentLength = contentLen
_, err := tr.RoundTrip(req)
if err != errReqBodyTooLong {
@@ -5305,13 +4262,12 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
- defer st.Close()
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {})
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
ctx := context.Background()
- cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
+ cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}
@@ -5338,12 +4294,11 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
// already. If the request body has started to be sent, one must wait until it
// is completed.
func TestTransportBodyRewindRace(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "close")
w.WriteHeader(http.StatusOK)
return
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &http.Transport{
TLSClientConfig: tlsConfigInsecure,
@@ -5362,7 +4317,7 @@ func TestTransportBodyRewindRace(t *testing.T) {
var wg sync.WaitGroup
wg.Add(clients)
for i := 0; i < clients; i++ {
- req, err := http.NewRequest("POST", st.ts.URL, bytes.NewBufferString("abcdef"))
+ req, err := http.NewRequest("POST", ts.URL, bytes.NewBufferString("abcdef"))
if err != nil {
t.Fatalf("unexpect new request error: %v", err)
}
@@ -5382,11 +4337,10 @@ func TestTransportBodyRewindRace(t *testing.T) {
// Issue 42498: A request with a body will never be sent if the stream is
// reset prior to sending any data.
func TestTransportServerResetStreamAtHeaders(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
return
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &http.Transport{
TLSClientConfig: tlsConfigInsecure,
@@ -5402,7 +4356,7 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) {
Transport: tr,
}
- req, err := http.NewRequest("POST", st.ts.URL, errorReader{io.EOF})
+ req, err := http.NewRequest("POST", ts.URL, errorReader{io.EOF})
if err != nil {
t.Fatalf("unexpect new request error: %v", err)
}
@@ -5430,15 +4384,14 @@ func (tr *trackingReader) WasRead() bool {
}
func TestTransportExpectContinue(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/reject":
w.WriteHeader(403)
default:
io.Copy(io.Discard, r.Body)
}
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &http.Transport{
TLSClientConfig: tlsConfigInsecure,
@@ -5481,7 +4434,7 @@ func TestTransportExpectContinue(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
startTime := time.Now()
- req, err := http.NewRequest("POST", st.ts.URL+tc.Path, tc.Body)
+ req, err := http.NewRequest("POST", ts.URL+tc.Path, tc.Body)
if err != nil {
t.Fatal(err)
}
@@ -5593,11 +4546,11 @@ func (c *blockingWriteConn) Write(b []byte) (n int, err error) {
func TestTransportFrameBufferReuse(t *testing.T) {
filler := hex.EncodeToString([]byte(randString(2048)))
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("Big"), filler; got != want {
t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want)
}
- b, err := ioutil.ReadAll(r.Body)
+ b, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("error reading request body: %v", err)
}
@@ -5607,8 +4560,7 @@ func TestTransportFrameBufferReuse(t *testing.T) {
if got, want := r.Trailer.Get("Big"), filler; got != want {
t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want)
}
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
@@ -5619,7 +4571,7 @@ func TestTransportFrameBufferReuse(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
- req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(filler))
+ req, err := http.NewRequest("POST", ts.URL, strings.NewReader(filler))
if err != nil {
t.Error(err)
return
@@ -5685,7 +4637,7 @@ func TestTransportBlockingRequestWrite(t *testing.T) {
}} {
test := test
t.Run(test.name, func(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
if v := r.Header.Get("Big"); v != "" && v != filler {
t.Errorf("request header mismatch")
}
@@ -5695,10 +4647,9 @@ func TestTransportBlockingRequestWrite(t *testing.T) {
if v := r.Trailer.Get("Big"); v != "" && v != filler {
t.Errorf("request trailer mismatch\ngot: %q\nwant: %q", string(v), filler)
}
- }, optOnlyServer, func(s *Server) {
+ }, func(s *Server) {
s.MaxConcurrentStreams = 1
})
- defer st.Close()
// This Transport creates connections that block on writes after 1024 bytes.
connc := make(chan *blockingWriteConn, 1)
@@ -5720,7 +4671,7 @@ func TestTransportBlockingRequestWrite(t *testing.T) {
// Request 1: A small request to ensure we read the server MaxConcurrentStreams.
{
- req, err := http.NewRequest("POST", st.ts.URL, nil)
+ req, err := http.NewRequest("POST", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
@@ -5740,7 +4691,7 @@ func TestTransportBlockingRequestWrite(t *testing.T) {
reqc := make(chan struct{})
go func() {
defer close(reqc)
- req, err := test.req(st.ts.URL)
+ req, err := test.req(ts.URL)
if err != nil {
t.Error(err)
return
@@ -5756,7 +4707,7 @@ func TestTransportBlockingRequestWrite(t *testing.T) {
// Request 3: A small request that is sent on a new connection, since request 2
// is hogging the only available stream on the previous connection.
{
- req, err := http.NewRequest("POST", st.ts.URL, nil)
+ req, err := http.NewRequest("POST", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
@@ -5791,15 +4742,14 @@ func TestTransportBlockingRequestWrite(t *testing.T) {
func TestTransportCloseRequestBody(t *testing.T) {
var statusCode int
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(statusCode)
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
ctx := context.Background()
- cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
+ cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}
@@ -5826,185 +4776,113 @@ func TestTransportCloseRequestBody(t *testing.T) {
}
}
-// collectClientsConnPool is a ClientConnPool that wraps lower and
-// collects what calls were made on it.
-type collectClientsConnPool struct {
- lower ClientConnPool
-
- mu sync.Mutex
- getErrs int
- got []*ClientConn
-}
-
-func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
- cc, err := p.lower.GetClientConn(req, addr)
- p.mu.Lock()
- defer p.mu.Unlock()
- if err != nil {
- p.getErrs++
- return nil, err
- }
- p.got = append(p.got, cc)
- return cc, nil
-}
-
-func (p *collectClientsConnPool) MarkDead(cc *ClientConn) {
- p.lower.MarkDead(cc)
-}
-
func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
- ct := newClientTester(t)
- pool := &collectClientsConnPool{
- lower: &clientConnPool{t: ct.tr},
- }
- ct.tr.ConnPool = pool
+ // This test verifies that
+ // - receiving a protocol error on a connection does not interfere with
+ // other requests in flight on that connection;
+ // - the connection is not reused for further requests; and
+	// - the failed request is retried on a new connection.
+ tt := newTestTransport(t)
+
+ // Start two requests. The first is a long request
+ // that will finish after the second. The second one
+ // will result in the protocol error.
+
+ // Request #1: The long request.
+ req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt1 := tt.roundTrip(req1)
+ tc1 := tt.getConn()
+ tc1.wantFrameType(FrameSettings)
+ tc1.wantFrameType(FrameWindowUpdate)
+ tc1.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc1.writeSettings()
+ tc1.wantFrameType(FrameSettings) // settings ACK
+
+ // Request #2(a): The short request.
+ req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt2 := tt.roundTrip(req2)
+ tc1.wantHeaders(wantHeader{
+ streamID: 3,
+ endStream: true,
+ })
- gotProtoError := make(chan bool, 1)
- ct.tr.CountError = func(errType string) {
- if errType == "recv_rststream_PROTOCOL_ERROR" {
- select {
- case gotProtoError <- true:
- default:
- }
- }
+ // Request #2(a) fails with ErrCodeProtocol.
+ tc1.writeRSTStream(3, ErrCodeProtocol)
+ if rt1.done() {
+ t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress")
}
- ct.client = func() error {
- // Start two requests. The first is a long request
- // that will finish after the second. The second one
- // will result in the protocol error. We check that
- // after the first one closes, the connection then
- // shuts down.
-
- // The long, outer request.
- req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil)
- res1, err := ct.tr.RoundTrip(req1)
- if err != nil {
- return err
- }
- if got, want := res1.Header.Get("Is-Long"), "1"; got != want {
- return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want)
- }
-
- req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil)
- res, err := ct.tr.RoundTrip(req)
- const want = "only one dial allowed in test mode"
- if got := fmt.Sprint(err); got != want {
- t.Errorf("didn't dial again: got %#q; want %#q", got, want)
- }
- if res != nil {
- res.Body.Close()
- }
- select {
- case <-gotProtoError:
- default:
- t.Errorf("didn't get stream protocol error")
- }
-
- if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 {
- t.Errorf("unexpected body read %v, %v", n, err)
- }
-
- pool.mu.Lock()
- defer pool.mu.Unlock()
- if pool.getErrs != 1 {
- t.Errorf("pool get errors = %v; want 1", pool.getErrs)
- }
- if len(pool.got) == 2 {
- if pool.got[0] != pool.got[1] {
- t.Errorf("requests went on different connections")
- }
- cc := pool.got[0]
- cc.mu.Lock()
- if !cc.doNotReuse {
- t.Error("ClientConn not marked doNotReuse")
- }
- cc.mu.Unlock()
-
- select {
- case <-cc.readerDone:
- case <-time.After(5 * time.Second):
- t.Errorf("timeout waiting for reader to be done")
- }
- } else {
- t.Errorf("pool get success = %v; want 2", len(pool.got))
- }
- return nil
+ if rt2.done() {
+ t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is done; want still in progress")
}
- ct.server = func() error {
- ct.greet()
- var sentErr bool
- var numHeaders int
- var firstStreamID uint32
- var hbuf bytes.Buffer
- enc := hpack.NewEncoder(&hbuf)
-
- for {
- f, err := ct.fr.ReadFrame()
- if err == io.EOF {
- // Client hung up on us, as it should at the end.
- return nil
- }
- if err != nil {
- return nil
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- numHeaders++
- if numHeaders == 1 {
- firstStreamID = f.StreamID
- hbuf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: hbuf.Bytes(),
- })
- continue
- }
- if !sentErr {
- sentErr = true
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol)
- ct.fr.WriteData(firstStreamID, true, nil)
- continue
- }
- }
- }
- }
- ct.run()
+ // Request #2(b): The short request is retried on a new connection.
+ tc2 := tt.getConn()
+ tc2.wantFrameType(FrameSettings)
+ tc2.wantFrameType(FrameWindowUpdate)
+ tc2.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc2.writeSettings()
+ tc2.wantFrameType(FrameSettings) // settings ACK
+
+ // Request #2(b) succeeds.
+ tc2.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc1.makeHeaderBlockFragment(
+ ":status", "201",
+ ),
+ })
+ rt2.wantStatus(201)
+
+ // Request #1 succeeds.
+ tc1.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc1.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt1.wantStatus(200)
}
func TestClientConnReservations(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- }, func(s *Server) {
- s.MaxConcurrentStreams = initialMaxConcurrentStreams
- })
- defer st.Close()
-
- tr := &Transport{TLSClientConfig: tlsConfigInsecure}
- defer tr.CloseIdleConnections()
+ tc := newTestClientConn(t)
+ tc.greet(
+ Setting{ID: SettingMaxConcurrentStreams, Val: initialMaxConcurrentStreams},
+ )
- cc, err := tr.newClientConn(st.cc, false)
- if err != nil {
- t.Fatal(err)
+ doRoundTrip := func() {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
}
- req, _ := http.NewRequest("GET", st.ts.URL, nil)
n := 0
- for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
+ for n <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
n++
}
if n != initialMaxConcurrentStreams {
t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
}
- if _, err := cc.RoundTrip(req); err != nil {
- t.Fatalf("RoundTrip error = %v", err)
- }
+ doRoundTrip()
n2 := 0
- for n2 <= 5 && cc.ReserveNewRequest() {
+ for n2 <= 5 && tc.cc.ReserveNewRequest() {
n2++
}
if n2 != 1 {
@@ -6013,11 +4891,11 @@ func TestClientConnReservations(t *testing.T) {
// Use up all the reservations
for i := 0; i < n; i++ {
- cc.RoundTrip(req)
+ doRoundTrip()
}
n2 = 0
- for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
+ for n2 <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
n2++
}
if n2 != n {
@@ -6026,47 +4904,34 @@ func TestClientConnReservations(t *testing.T) {
}
func TestTransportTimeoutServerHangs(t *testing.T) {
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- defer close(clientDone)
+ tc := newTestClientConn(t)
+ tc.greet()
- req, err := http.NewRequest("PUT", "https://dummy.tld/", nil)
- if err != nil {
- return err
- }
+ ctx, cancel := context.WithCancel(context.Background())
+ req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
- ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
- defer cancel()
- req = req.WithContext(ctx)
- req.Header.Add("Big", strings.Repeat("a", 1<<20))
- _, err = ct.tr.RoundTrip(req)
- if err == nil {
- return errors.New("error should not be nil")
- }
- if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
- return fmt.Errorf("error should be a net error timeout: %v", err)
- }
- return nil
+ tc.wantFrameType(FrameHeaders)
+ tc.advance(5 * time.Second)
+ if f := tc.readFrame(); f != nil {
+ t.Fatalf("unexpected frame: %v", f)
}
- ct.server = func() error {
- ct.greet()
- select {
- case <-time.After(5 * time.Second):
- case <-clientDone:
- }
- return nil
+ if rt.done() {
+ t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned")
+ }
+
+ cancel()
+ tc.sync()
+ if rt.err() != context.Canceled {
+ t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err())
}
- ct.run()
}
func TestTransportContentLengthWithoutBody(t *testing.T) {
contentLength := ""
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", contentLength)
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
@@ -6093,7 +4958,7 @@ func TestTransportContentLengthWithoutBody(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
contentLength = test.contentLength
- req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req, _ := http.NewRequest("GET", ts.URL, nil)
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
@@ -6115,18 +4980,17 @@ func TestTransportContentLengthWithoutBody(t *testing.T) {
}
func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.(http.Flusher).Flush()
io.Copy(io.Discard, r.Body)
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
pr, pw := net.Pipe()
- req, err := http.NewRequest("GET", st.ts.URL, pr)
+ req, err := http.NewRequest("GET", ts.URL, pr)
if err != nil {
t.Fatal(err)
}
@@ -6142,19 +5006,18 @@ func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) {
func TestTransport300ResponseBody(t *testing.T) {
reqc := make(chan struct{})
body := []byte("response body")
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(300)
w.(http.Flusher).Flush()
<-reqc
w.Write(body)
- }, optOnlyServer)
- defer st.Close()
+ })
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
pr, pw := net.Pipe()
- req, err := http.NewRequest("GET", st.ts.URL, pr)
+ req, err := http.NewRequest("GET", ts.URL, pr)
if err != nil {
t.Fatal(err)
}
@@ -6175,11 +5038,9 @@ func TestTransport300ResponseBody(t *testing.T) {
}
func TestTransportWriteByteTimeout(t *testing.T) {
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {},
- optOnlyServer,
)
- defer st.Close()
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
@@ -6191,7 +5052,7 @@ func TestTransportWriteByteTimeout(t *testing.T) {
defer tr.CloseIdleConnections()
c := &http.Client{Transport: tr}
- _, err := c.Get(st.ts.URL)
+ _, err := c.Get(ts.URL)
if !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err)
}
@@ -6219,11 +5080,9 @@ func (c *slowWriteConn) Write(b []byte) (n int, err error) {
}
func TestTransportSlowWrites(t *testing.T) {
- st := newServerTester(t,
+ ts := newTestServer(t,
func(w http.ResponseWriter, r *http.Request) {},
- optOnlyServer,
)
- defer st.Close()
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
@@ -6237,7 +5096,7 @@ func TestTransportSlowWrites(t *testing.T) {
c := &http.Client{Transport: tr}
const bodySize = 1 << 20
- resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
+ resp, err := c.Post(ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
if err != nil {
t.Fatal(err)
}
@@ -6251,20 +5110,6 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) {
testTransportClosesConnAfterGoAway(t, 1)
}
-type closeOnceConn struct {
- net.Conn
- closed uint32
-}
-
-var errClosed = errors.New("Close of closed connection")
-
-func (c *closeOnceConn) Close() error {
- if atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
- return c.Conn.Close()
- }
- return errClosed
-}
-
// testTransportClosesConnAfterGoAway verifies that the transport
// closes a connection after reading a GOAWAY from it.
//
@@ -6272,53 +5117,35 @@ func (c *closeOnceConn) Close() error {
// When 0, the transport (unsuccessfully) retries the request (stream 1);
// when 1, the transport reads the response after receiving the GOAWAY.
func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) {
- ct := newClientTester(t)
- ct.cc = &closeOnceConn{Conn: ct.cc}
-
- var wg sync.WaitGroup
- wg.Add(1)
- ct.client = func() error {
- defer wg.Done()
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err == nil {
- res.Body.Close()
- }
- if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
- t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
- }
- if err = ct.cc.Close(); err != errClosed {
- return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err)
- }
- return nil
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeGoAway(lastStream, ErrCodeNo, nil)
+
+ if lastStream > 0 {
+ // Send a valid response to first request.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
}
- ct.server = func() error {
- defer wg.Wait()
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- return fmt.Errorf("server failed reading HEADERS: %v", err)
- }
- if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil {
- return fmt.Errorf("server failed writing GOAWAY: %v", err)
- }
- if lastStream > 0 {
- // Send a valid response to first request.
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- return nil
+ tc.closeWrite()
+ err := rt.err()
+ if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
+ t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
+ }
+ if !tc.isClosed() {
+ t.Errorf("ClientConn did not close its net.Conn, expected it to")
}
-
- ct.run()
}
type slowCloser struct {
@@ -6337,11 +5164,10 @@ func (r *slowCloser) Close() error {
}
func TestTransportSlowClose(t *testing.T) {
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- }, optOnlyServer)
- defer st.Close()
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ })
- client := st.ts.Client()
+ client := ts.Client()
body := &slowCloser{
closing: make(chan struct{}),
closed: make(chan struct{}),
@@ -6350,7 +5176,7 @@ func TestTransportSlowClose(t *testing.T) {
reqc := make(chan struct{})
go func() {
defer close(reqc)
- res, err := client.Post(st.ts.URL, "text/plain", body)
+ res, err := client.Post(ts.URL, "text/plain", body)
if err != nil {
t.Error(err)
}
@@ -6363,9 +5189,749 @@ func TestTransportSlowClose(t *testing.T) {
<-body.closing // wait for POST request to call body.Close
// This GET request should not be blocked by the in-progress POST.
- res, err := client.Get(st.ts.URL)
+ res, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
}
+
+func TestTransportDialTLSContext(t *testing.T) {
+ blockCh := make(chan struct{})
+ serverTLSConfigFunc := func(ts *httptest.Server) {
+ ts.Config.TLSConfig = &tls.Config{
+ // Triggers the server to request the clients certificate
+ // during TLS handshake.
+ ClientAuth: tls.RequestClientCert,
+ }
+ }
+ ts := newTestServer(t,
+ func(w http.ResponseWriter, r *http.Request) {},
+ serverTLSConfigFunc,
+ )
+ tr := &Transport{
+ TLSClientConfig: &tls.Config{
+ GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+ // Tests that the context provided to `req` is
+ // passed into this function.
+ close(blockCh)
+ <-cri.Context().Done()
+ return nil, cri.Context().Err()
+ },
+ InsecureSkipVerify: true,
+ },
+ }
+ defer tr.CloseIdleConnections()
+ req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ req = req.WithContext(ctx)
+ errCh := make(chan error)
+ go func() {
+ defer close(errCh)
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ res.Body.Close()
+ }()
+ // Wait for GetClientCertificate handler to be called
+ <-blockCh
+ // Cancel the context
+ cancel()
+ // Expect the cancellation error here
+ err = <-errCh
+ if err == nil {
+ t.Fatal("cancelling context during client certificate fetch did not error as expected")
+ return
+ }
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("unexpected error returned after cancellation: %v", err)
+ }
+}
+
+// TestDialRaceResumesDial tests that, given two concurrent requests
+// to the same address, when the first Dial is interrupted because
+// the first request's context is cancelled, the second request
+// resumes the dial automatically.
+func TestDialRaceResumesDial(t *testing.T) {
+ blockCh := make(chan struct{})
+ serverTLSConfigFunc := func(ts *httptest.Server) {
+ ts.Config.TLSConfig = &tls.Config{
+ // Triggers the server to request the clients certificate
+ // during TLS handshake.
+ ClientAuth: tls.RequestClientCert,
+ }
+ }
+ ts := newTestServer(t,
+ func(w http.ResponseWriter, r *http.Request) {},
+ serverTLSConfigFunc,
+ )
+ tr := &Transport{
+ TLSClientConfig: &tls.Config{
+ GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+ select {
+ case <-blockCh:
+ // If we already errored, return without error.
+ return &tls.Certificate{}, nil
+ default:
+ }
+ close(blockCh)
+ <-cri.Context().Done()
+ return nil, cri.Context().Err()
+ },
+ InsecureSkipVerify: true,
+ },
+ }
+ defer tr.CloseIdleConnections()
+ req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Create two requests with independent cancellation.
+ ctx1, cancel1 := context.WithCancel(context.Background())
+ defer cancel1()
+ req1 := req.WithContext(ctx1)
+ ctx2, cancel2 := context.WithCancel(context.Background())
+ defer cancel2()
+ req2 := req.WithContext(ctx2)
+ errCh := make(chan error)
+ go func() {
+ res, err := tr.RoundTrip(req1)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ res.Body.Close()
+ }()
+ successCh := make(chan struct{})
+ go func() {
+ // Don't start request until first request
+ // has initiated the handshake.
+ <-blockCh
+ res, err := tr.RoundTrip(req2)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ res.Body.Close()
+ // Close successCh to indicate that the second request
+ // made it to the server successfully.
+ close(successCh)
+ }()
+ // Wait for GetClientCertificate handler to be called
+ <-blockCh
+ // Cancel the context first
+ cancel1()
+ // Expect the cancellation error here
+ err = <-errCh
+ if err == nil {
+ t.Fatal("cancelling context during client certificate fetch did not error as expected")
+ return
+ }
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("unexpected error returned after cancellation: %v", err)
+ }
+ select {
+ case err := <-errCh:
+ t.Fatalf("unexpected second error: %v", err)
+ case <-successCh:
+ }
+}
+
+func TestTransportDataAfter1xxHeader(t *testing.T) {
+ // Discard logger output to avoid spamming stderr.
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+
+ // https://go.dev/issue/65927 - server sends a 1xx response, followed by a DATA frame.
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "100",
+ ),
+ })
+ tc.writeData(rt.streamID(), true, []byte{0})
+ err := rt.err()
+ if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
+ t.Errorf("RoundTrip error: %v; want ErrCodeProtocol", err)
+ }
+ tc.wantFrameType(FrameRSTStream)
+}
+
+func TestIssue66763Race(t *testing.T) {
+ tr := &Transport{
+ IdleConnTimeout: 1 * time.Nanosecond,
+ AllowHTTP: true, // issue 66763 only occurs when AllowHTTP is true
+ }
+ defer tr.CloseIdleConnections()
+
+ cli, srv := net.Pipe()
+ donec := make(chan struct{})
+ go func() {
+ // Creating the client conn may succeed or fail,
+ // depending on when the idle timeout happens.
+ // Either way, the idle timeout will close the net.Conn.
+ tr.NewClientConn(cli)
+ close(donec)
+ }()
+
+ // The client sends its preface and SETTINGS frame,
+ // and then closes its conn after the idle timeout.
+ io.ReadAll(srv)
+ srv.Close()
+
+ <-donec
+}
+
+// Issue 67671: Sending a Connection: close request on a Transport with AllowHTTP
+// set caused the transport to wedge.
+func TestIssue67671(t *testing.T) {
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {})
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ AllowHTTP: true,
+ }
+ defer tr.CloseIdleConnections()
+ req, _ := http.NewRequest("GET", ts.URL, nil)
+ req.Close = true
+ for i := 0; i < 2; i++ {
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ }
+}
+
+func TestTransport1xxLimits(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ opt any
+ ctxfn func(context.Context) context.Context
+ hcount int
+ limited bool
+ }{{
+ name: "default",
+ hcount: 10,
+ limited: false,
+ }, {
+ name: "MaxHeaderListSize",
+ opt: func(tr *Transport) {
+ tr.MaxHeaderListSize = 10000
+ },
+ hcount: 10,
+ limited: true,
+ }, {
+ name: "MaxResponseHeaderBytes",
+ opt: func(tr *http.Transport) {
+ tr.MaxResponseHeaderBytes = 10000
+ },
+ hcount: 10,
+ limited: true,
+ }, {
+ name: "limit by client trace",
+ ctxfn: func(ctx context.Context) context.Context {
+ count := 0
+ return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ count++
+ if count >= 10 {
+ return errors.New("too many 1xx")
+ }
+ return nil
+ },
+ })
+ },
+ hcount: 10,
+ limited: true,
+ }, {
+ name: "limit disabled by client trace",
+ opt: func(tr *Transport) {
+ tr.MaxHeaderListSize = 10000
+ },
+ ctxfn: func(ctx context.Context) context.Context {
+ return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ return nil
+ },
+ })
+ },
+ hcount: 20,
+ limited: false,
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ tc := newTestClientConn(t, test.opt)
+ tc.greet()
+
+ ctx := context.Background()
+ if test.ctxfn != nil {
+ ctx = test.ctxfn(ctx)
+ }
+ req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+
+ for i := 0; i < test.hcount; i++ {
+ if fr, err := tc.fr.ReadFrame(); err != os.ErrDeadlineExceeded {
+ t.Fatalf("after writing %v 1xx headers: read %v, %v; want idle", i, fr, err)
+ }
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "103",
+ "x-field", strings.Repeat("a", 1000),
+ ),
+ })
+ }
+ if test.limited {
+ tc.wantFrameType(FrameRSTStream)
+ } else {
+ tc.wantIdle()
+ }
+ })
+ }
+}
+
+func TestTransportSendPingWithReset(t *testing.T) {
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.StrictMaxConcurrentStreams = true
+ })
+
+ const maxConcurrent = 3
+ tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+
+ // Start several requests.
+ var rts []*testRoundTrip
+ for i := 0; i < maxConcurrent+1; i++ {
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tc.roundTrip(req)
+ if i >= maxConcurrent {
+ tc.wantIdle()
+ continue
+ }
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
+ rts = append(rts, rt)
+ }
+
+ // Cancel one request. We send a PING frame along with the RST_STREAM.
+ rts[0].response().Body.Close()
+ tc.wantRSTStream(rts[0].streamID(), ErrCodeCancel)
+ pf := readFrame[*PingFrame](t, tc)
+ tc.wantIdle()
+
+ // Cancel another request. No PING frame, since one is in flight.
+ rts[1].response().Body.Close()
+ tc.wantRSTStream(rts[1].streamID(), ErrCodeCancel)
+ tc.wantIdle()
+
+ // Respond to the PING.
+ // This finalizes the previous resets, and allows the pending request to be sent.
+ tc.writePing(true, pf.Data)
+ tc.wantFrameType(FrameHeaders)
+ tc.wantIdle()
+
+ // Receive a byte of data for the remaining stream, which resets our ability
+ // to send pings (see comment on ClientConn.rstStreamPingsBlocked).
+ tc.writeData(rts[2].streamID(), false, []byte{0})
+
+ // Cancel the last request. We send another PING, since none are in flight.
+ rts[2].response().Body.Close()
+ tc.wantRSTStream(rts[2].streamID(), ErrCodeCancel)
+ tc.wantFrameType(FramePing)
+ tc.wantIdle()
+}
+
+// Issue #70505: gRPC gets upset if we send more than 2 pings per HEADERS/DATA frame
+// sent by the server.
+func TestTransportSendNoMoreThanOnePingWithReset(t *testing.T) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ makeAndResetRequest := func() {
+ t.Helper()
+ ctx, cancel := context.WithCancel(context.Background())
+ req := must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+ cancel()
+ tc.wantRSTStream(rt.streamID(), ErrCodeCancel) // client sends RST_STREAM
+ }
+
+ // Create a request and cancel it.
+ // The client sends a PING frame along with the reset.
+ makeAndResetRequest()
+ pf1 := readFrame[*PingFrame](t, tc) // client sends PING
+
+ // Create another request and cancel it.
+ // We do not send a PING frame along with the reset,
+ // because we haven't received a HEADERS or DATA frame from the server
+ // since the last PING we sent.
+ makeAndResetRequest()
+
+ // Server belatedly responds to request 1.
+ // The server has not responded to our first PING yet.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+
+ // Create yet another request and cancel it.
+ // We still do not send a PING frame along with the reset.
+ // We've received a HEADERS frame, but it came before the response to the PING.
+ makeAndResetRequest()
+
+ // The server responds to our PING.
+ tc.writePing(true, pf1.Data)
+
+ // Create yet another request and cancel it.
+ // Still no PING frame; we got a response to the previous one,
+ // but no HEADERS or DATA.
+ makeAndResetRequest()
+
+ // Server belatedly responds to the second request.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 3,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+
+ // One more request.
+ // This time we send a PING frame.
+ makeAndResetRequest()
+ tc.wantFrameType(FramePing)
+}
+
+func TestTransportConnBecomesUnresponsive(t *testing.T) {
+ // We send a number of requests in series to an unresponsive connection.
+ // Each request is canceled or times out without a response.
+ // Eventually, we open a new connection rather than trying to use the old one.
+ tt := newTestTransport(t)
+
+ const maxConcurrent = 3
+
+ t.Logf("first request opens a new connection and succeeds")
+ req1 := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt1 := tt.roundTrip(req1)
+ tc1 := tt.getConn()
+ tc1.wantFrameType(FrameSettings)
+ tc1.wantFrameType(FrameWindowUpdate)
+ hf1 := readFrame[*HeadersFrame](t, tc1)
+ tc1.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+ tc1.wantFrameType(FrameSettings) // ack
+ tc1.writeHeaders(HeadersFrameParam{
+ StreamID: hf1.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc1.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt1.wantStatus(200)
+ rt1.response().Body.Close()
+
+ // Send more requests.
+ // None receive a response.
+ // Each is canceled.
+ for i := 0; i < maxConcurrent; i++ {
+ t.Logf("request %v receives no response and is canceled", i)
+ ctx, cancel := context.WithCancel(context.Background())
+ req := must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
+ tt.roundTrip(req)
+ if tt.hasConn() {
+ t.Fatalf("new connection created; expect existing conn to be reused")
+ }
+ tc1.wantFrameType(FrameHeaders)
+ cancel()
+ tc1.wantFrameType(FrameRSTStream)
+ if i == 0 {
+ tc1.wantFrameType(FramePing)
+ }
+ tc1.wantIdle()
+ }
+
+ // The conn has hit its concurrency limit.
+ // The next request is sent on a new conn.
+ req2 := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt2 := tt.roundTrip(req2)
+ tc2 := tt.getConn()
+ tc2.wantFrameType(FrameSettings)
+ tc2.wantFrameType(FrameWindowUpdate)
+ hf := readFrame[*HeadersFrame](t, tc2)
+ tc2.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+ tc2.wantFrameType(FrameSettings) // ack
+ tc2.writeHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc2.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt2.wantStatus(200)
+ rt2.response().Body.Close()
+}
+
+// Test that the Transport can use a conn provided to it by a TLSNextProto hook.
+func TestTransportTLSNextProtoConnOK(t *testing.T) {
+ t1 := &http.Transport{}
+ t2, _ := ConfigureTransports(t1)
+ tt := newTestTransport(t, t2)
+
+ // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook.
+ cli, _ := synctestNetPipe(tt.group)
+ cliTLS := tls.Client(cli, tlsConfigInsecure)
+ go func() {
+ tt.group.Join()
+ t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
+ }()
+ tt.sync()
+ tc := tt.getConn()
+ tc.greet()
+
+ // Send a request on the Transport.
+ // It uses the conn we provided.
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tt.roundTrip(req)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{"/"},
+ },
+ })
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
+ rt.wantBody(nil)
+}
+
+// Test the case where a conn provided via a TLSNextProto hook immediately encounters an error.
+func TestTransportTLSNextProtoConnImmediateFailureUsed(t *testing.T) {
+ t1 := &http.Transport{}
+ t2, _ := ConfigureTransports(t1)
+ tt := newTestTransport(t, t2)
+
+ // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook.
+ cli, _ := synctestNetPipe(tt.group)
+ cliTLS := tls.Client(cli, tlsConfigInsecure)
+ go func() {
+ tt.group.Join()
+ t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
+ }()
+ tt.sync()
+ tc := tt.getConn()
+
+ // The connection encounters an error before we send a request that uses it.
+ tc.closeWrite()
+
+ // Send a request on the Transport.
+ //
+ // It should fail, because we have no usable connections, but not with ErrNoCachedConn.
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tt.roundTrip(req)
+ if err := rt.err(); err == nil || errors.Is(err, ErrNoCachedConn) {
+ t.Fatalf("RoundTrip with broken conn: got %v, want an error other than ErrNoCachedConn", err)
+ }
+
+ // Send the request again.
+ // This time it should fail with ErrNoCachedConn,
+ // because the dead conn has been removed from the pool.
+ rt = tt.roundTrip(req)
+ if err := rt.err(); !errors.Is(err, ErrNoCachedConn) {
+ t.Fatalf("RoundTrip after broken conn is used: got %v, want ErrNoCachedConn", err)
+ }
+}
+
+// Test the case where a conn provided via a TLSNextProto hook is closed for idleness
+// before we use it.
+func TestTransportTLSNextProtoConnIdleTimoutBeforeUse(t *testing.T) {
+ t1 := &http.Transport{
+ IdleConnTimeout: 1 * time.Second,
+ }
+ t2, _ := ConfigureTransports(t1)
+ tt := newTestTransport(t, t2)
+
+ // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook.
+ cli, _ := synctestNetPipe(tt.group)
+ cliTLS := tls.Client(cli, tlsConfigInsecure)
+ go func() {
+ tt.group.Join()
+ t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
+ }()
+ tt.sync()
+ tc := tt.getConn()
+
+ // The connection encounters an error before we send a request that uses it.
+ tc.advance(2 * time.Second)
+
+ // Send a request on the Transport.
+ //
+ // It should fail with ErrNoCachedConn.
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tt.roundTrip(req)
+ if err := rt.err(); !errors.Is(err, ErrNoCachedConn) {
+ t.Fatalf("RoundTrip with conn closed for idleness: got %v, want ErrNoCachedConn", err)
+ }
+}
+
+// Test the case where a conn provided via a TLSNextProto hook immediately encounters an error,
+// but no requests are sent which would use the bad connection.
+func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) {
+ t1 := &http.Transport{}
+ t2, _ := ConfigureTransports(t1)
+ tt := newTestTransport(t, t2)
+
+ // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook.
+ cli, _ := synctestNetPipe(tt.group)
+ cliTLS := tls.Client(cli, tlsConfigInsecure)
+ go func() {
+ tt.group.Join()
+ t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
+ }()
+ tt.sync()
+ tc := tt.getConn()
+
+ // The connection encounters an error before we send a request that uses it.
+ tc.closeWrite()
+
+ // Some time passes.
+ // The dead connection is removed from the pool.
+ tc.advance(10 * time.Second)
+
+ // Send a request on the Transport.
+ //
+ // It should fail with ErrNoCachedConn, because the pool contains no conns.
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tt.roundTrip(req)
+ if err := rt.err(); !errors.Is(err, ErrNoCachedConn) {
+ t.Fatalf("RoundTrip after broken conn expires: got %v, want ErrNoCachedConn", err)
+ }
+}
+
+func TestExtendedConnectClientWithServerSupport(t *testing.T) {
+ setForTest(t, &disableExtendedConnectProtocol, false)
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get(":protocol") != "extended-connect" {
+ t.Fatalf("unexpected :protocol header received")
+ }
+ t.Log(io.Copy(w, r.Body))
+ })
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ AllowHTTP: true,
+ }
+ defer tr.CloseIdleConnections()
+ pr, pw := io.Pipe()
+ pwDone := make(chan struct{})
+ req, _ := http.NewRequest("CONNECT", ts.URL, pr)
+ req.Header.Set(":protocol", "extended-connect")
+ req.Header.Set("X-A", "A")
+ req.Header.Set("X-B", "B")
+ req.Header.Set("X-C", "C")
+ go func() {
+ pw.Write([]byte("hello, extended connect"))
+ pw.Close()
+ close(pwDone)
+ }()
+
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(body, []byte("hello, extended connect")) {
+ t.Fatal("unexpected body received")
+ }
+}
+
+func TestExtendedConnectClientWithoutServerSupport(t *testing.T) {
+ setForTest(t, &disableExtendedConnectProtocol, true)
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ io.Copy(w, r.Body)
+ })
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ AllowHTTP: true,
+ }
+ defer tr.CloseIdleConnections()
+ pr, pw := io.Pipe()
+ pwDone := make(chan struct{})
+ req, _ := http.NewRequest("CONNECT", ts.URL, pr)
+ req.Header.Set(":protocol", "extended-connect")
+ req.Header.Set("X-A", "A")
+ req.Header.Set("X-B", "B")
+ req.Header.Set("X-C", "C")
+ go func() {
+ pw.Write([]byte("hello, extended connect"))
+ pw.Close()
+ close(pwDone)
+ }()
+
+ _, err := tr.RoundTrip(req)
+ if !errors.Is(err, errExtendedConnectNotSupported) {
+ t.Fatalf("expected error errExtendedConnectNotSupported, got: %v", err)
+ }
+}
+
+// Issue #70658: Make sure extended CONNECT requests don't get stuck if a
+// connection fails early in its lifetime.
+func TestExtendedConnectReadFrameError(t *testing.T) {
+ tc := newTestClientConn(t)
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+
+ req, _ := http.NewRequest("CONNECT", "https://dummy.tld/", nil)
+ req.Header.Set(":protocol", "extended-connect")
+ rt := tc.roundTrip(req)
+ tc.wantIdle() // waiting for SETTINGS response
+
+ tc.closeWrite() // connection breaks without sending SETTINGS
+ if !rt.done() {
+ t.Fatalf("after connection closed: RoundTrip still running; want done")
+ }
+ if rt.err() == nil {
+ t.Fatalf("after connection closed: RoundTrip succeeded; want error")
+ }
+}
diff --git a/http2/unencrypted.go b/http2/unencrypted.go
new file mode 100644
index 0000000000..b2de211613
--- /dev/null
+++ b/http2/unencrypted.go
@@ -0,0 +1,32 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "crypto/tls"
+ "errors"
+ "net"
+)
+
+const nextProtoUnencryptedHTTP2 = "unencrypted_http2"
+
+// unencryptedNetConnFromTLSConn retrieves a net.Conn wrapped in a *tls.Conn.
+//
+// TLSNextProto functions accept a *tls.Conn.
+//
+// When passing an unencrypted HTTP/2 connection to a TLSNextProto function,
+// we pass a *tls.Conn with an underlying net.Conn containing the unencrypted connection.
+// To be extra careful about mistakes (accidentally dropping TLS encryption in a place
+// where we want it), the tls.Conn contains a net.Conn with an UnencryptedNetConn method
+// that returns the actual connection we want to use.
+func unencryptedNetConnFromTLSConn(tc *tls.Conn) (net.Conn, error) {
+ conner, ok := tc.NetConn().(interface {
+ UnencryptedNetConn() net.Conn
+ })
+ if !ok {
+ return nil, errors.New("http2: TLS conn unexpectedly found in unencrypted handoff")
+ }
+ return conner.UnencryptedNetConn(), nil
+}
diff --git a/http2/write.go b/http2/write.go
index 33f61398a1..fdb35b9477 100644
--- a/http2/write.go
+++ b/http2/write.go
@@ -13,6 +13,7 @@ import (
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
+ "golang.org/x/net/internal/httpcommon"
)
// writeFramer is implemented by any type that is used to write frames.
@@ -131,6 +132,16 @@ func (se StreamError) writeFrame(ctx writeContext) error {
func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max }
+type writePing struct {
+ data [8]byte
+}
+
+func (w writePing) writeFrame(ctx writeContext) error {
+ return ctx.Framer().WritePing(false, w.data)
+}
+
+func (w writePing) staysWithinBuffer(max int) bool { return frameHeaderLen+len(w.data) <= max }
+
type writePingAck struct{ pf *PingFrame }
func (w writePingAck) writeFrame(ctx writeContext) error {
@@ -341,7 +352,7 @@ func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
}
for _, k := range keys {
vv := h[k]
- k, ascii := lowerHeader(k)
+ k, ascii := httpcommon.LowerHeader(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
diff --git a/http2/writesched_priority.go b/http2/writesched_priority.go
index 0a242c669e..f6783339d1 100644
--- a/http2/writesched_priority.go
+++ b/http2/writesched_priority.go
@@ -443,8 +443,8 @@ func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, max
}
func (ws *priorityWriteScheduler) removeNode(n *priorityNode) {
- for k := n.kids; k != nil; k = k.next {
- k.setParent(n.parent)
+ for n.kids != nil {
+ n.kids.setParent(n.parent)
}
n.setParent(nil)
delete(ws.nodes, n.id)
diff --git a/http2/writesched_priority_test.go b/http2/writesched_priority_test.go
index b579ef9879..5aad057bea 100644
--- a/http2/writesched_priority_test.go
+++ b/http2/writesched_priority_test.go
@@ -562,3 +562,37 @@ func TestPriorityRstStreamOnNonOpenStreams(t *testing.T) {
t.Error(err)
}
}
+
+// https://go.dev/issue/66514
+func TestPriorityIssue66514(t *testing.T) {
+ addDep := func(ws *priorityWriteScheduler, child uint32, parent uint32) {
+ ws.AdjustStream(child, PriorityParam{
+ StreamDep: parent,
+ Exclusive: false,
+ Weight: 16,
+ })
+ }
+
+ validateDepTree := func(ws *priorityWriteScheduler, id uint32, t *testing.T) {
+ for n := ws.nodes[id]; n != nil; n = n.parent {
+ if n.parent == nil {
+ if n.id != uint32(0) {
+ t.Errorf("detected nodes not parented to 0")
+ }
+ }
+ }
+ }
+
+ ws := NewPriorityWriteScheduler(nil).(*priorityWriteScheduler)
+
+ // Root entry
+ addDep(ws, uint32(1), uint32(0))
+ addDep(ws, uint32(3), uint32(1))
+ addDep(ws, uint32(5), uint32(1))
+
+ for id := uint32(7); id < uint32(100); id += uint32(4) {
+ addDep(ws, id, id-uint32(4))
+ addDep(ws, id+uint32(2), id-uint32(4))
+ validateDepTree(ws, id, t)
+ }
+}
diff --git a/http2/z_spec_test.go b/http2/z_spec_test.go
deleted file mode 100644
index 610b2cdbc2..0000000000
--- a/http2/z_spec_test.go
+++ /dev/null
@@ -1,356 +0,0 @@
-// Copyright 2014 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package http2
-
-import (
- "bytes"
- "encoding/xml"
- "flag"
- "fmt"
- "io"
- "os"
- "reflect"
- "regexp"
- "sort"
- "strconv"
- "strings"
- "sync"
- "testing"
-)
-
-var coverSpec = flag.Bool("coverspec", false, "Run spec coverage tests")
-
-// The global map of sentence coverage for the http2 spec.
-var defaultSpecCoverage specCoverage
-
-var loadSpecOnce sync.Once
-
-func loadSpec() {
- if f, err := os.Open("testdata/draft-ietf-httpbis-http2.xml"); err != nil {
- panic(err)
- } else {
- defaultSpecCoverage = readSpecCov(f)
- f.Close()
- }
-}
-
-// covers marks all sentences for section sec in defaultSpecCoverage. Sentences not
-// "covered" will be included in report outputted by TestSpecCoverage.
-func covers(sec, sentences string) {
- loadSpecOnce.Do(loadSpec)
- defaultSpecCoverage.cover(sec, sentences)
-}
-
-type specPart struct {
- section string
- sentence string
-}
-
-func (ss specPart) Less(oo specPart) bool {
- atoi := func(s string) int {
- n, err := strconv.Atoi(s)
- if err != nil {
- panic(err)
- }
- return n
- }
- a := strings.Split(ss.section, ".")
- b := strings.Split(oo.section, ".")
- for len(a) > 0 {
- if len(b) == 0 {
- return false
- }
- x, y := atoi(a[0]), atoi(b[0])
- if x == y {
- a, b = a[1:], b[1:]
- continue
- }
- return x < y
- }
- if len(b) > 0 {
- return true
- }
- return false
-}
-
-type bySpecSection []specPart
-
-func (a bySpecSection) Len() int { return len(a) }
-func (a bySpecSection) Less(i, j int) bool { return a[i].Less(a[j]) }
-func (a bySpecSection) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
-
-type specCoverage struct {
- coverage map[specPart]bool
- d *xml.Decoder
-}
-
-func joinSection(sec []int) string {
- s := fmt.Sprintf("%d", sec[0])
- for _, n := range sec[1:] {
- s = fmt.Sprintf("%s.%d", s, n)
- }
- return s
-}
-
-func (sc specCoverage) readSection(sec []int) {
- var (
- buf = new(bytes.Buffer)
- sub = 0
- )
- for {
- tk, err := sc.d.Token()
- if err != nil {
- if err == io.EOF {
- return
- }
- panic(err)
- }
- switch v := tk.(type) {
- case xml.StartElement:
- if skipElement(v) {
- if err := sc.d.Skip(); err != nil {
- panic(err)
- }
- if v.Name.Local == "section" {
- sub++
- }
- break
- }
- switch v.Name.Local {
- case "section":
- sub++
- sc.readSection(append(sec, sub))
- case "xref":
- buf.Write(sc.readXRef(v))
- }
- case xml.CharData:
- if len(sec) == 0 {
- break
- }
- buf.Write(v)
- case xml.EndElement:
- if v.Name.Local == "section" {
- sc.addSentences(joinSection(sec), buf.String())
- return
- }
- }
- }
-}
-
-func (sc specCoverage) readXRef(se xml.StartElement) []byte {
- var b []byte
- for {
- tk, err := sc.d.Token()
- if err != nil {
- panic(err)
- }
- switch v := tk.(type) {
- case xml.CharData:
- if b != nil {
- panic("unexpected CharData")
- }
- b = []byte(string(v))
- case xml.EndElement:
- if v.Name.Local != "xref" {
- panic("expected ")
- }
- if b != nil {
- return b
- }
- sig := attrSig(se)
- switch sig {
- case "target":
- return []byte(fmt.Sprintf("[%s]", attrValue(se, "target")))
- case "fmt-of,rel,target", "fmt-,,rel,target":
- return []byte(fmt.Sprintf("[%s, %s]", attrValue(se, "target"), attrValue(se, "rel")))
- case "fmt-of,sec,target", "fmt-,,sec,target":
- return []byte(fmt.Sprintf("[section %s of %s]", attrValue(se, "sec"), attrValue(se, "target")))
- case "fmt-of,rel,sec,target":
- return []byte(fmt.Sprintf("[section %s of %s, %s]", attrValue(se, "sec"), attrValue(se, "target"), attrValue(se, "rel")))
- default:
- panic(fmt.Sprintf("unknown attribute signature %q in %#v", sig, fmt.Sprintf("%#v", se)))
- }
- default:
- panic(fmt.Sprintf("unexpected tag %q", v))
- }
- }
-}
-
-var skipAnchor = map[string]bool{
- "intro": true,
- "Overview": true,
-}
-
-var skipTitle = map[string]bool{
- "Acknowledgements": true,
- "Change Log": true,
- "Document Organization": true,
- "Conventions and Terminology": true,
-}
-
-func skipElement(s xml.StartElement) bool {
- switch s.Name.Local {
- case "artwork":
- return true
- case "section":
- for _, attr := range s.Attr {
- switch attr.Name.Local {
- case "anchor":
- if skipAnchor[attr.Value] || strings.HasPrefix(attr.Value, "changes.since.") {
- return true
- }
- case "title":
- if skipTitle[attr.Value] {
- return true
- }
- }
- }
- }
- return false
-}
-
-func readSpecCov(r io.Reader) specCoverage {
- sc := specCoverage{
- coverage: map[specPart]bool{},
- d: xml.NewDecoder(r)}
- sc.readSection(nil)
- return sc
-}
-
-func (sc specCoverage) addSentences(sec string, sentence string) {
- for _, s := range parseSentences(sentence) {
- sc.coverage[specPart{sec, s}] = false
- }
-}
-
-func (sc specCoverage) cover(sec string, sentence string) {
- for _, s := range parseSentences(sentence) {
- p := specPart{sec, s}
- if _, ok := sc.coverage[p]; !ok {
- panic(fmt.Sprintf("Not found in spec: %q, %q", sec, s))
- }
- sc.coverage[specPart{sec, s}] = true
- }
-
-}
-
-var whitespaceRx = regexp.MustCompile(`\s+`)
-
-func parseSentences(sens string) []string {
- sens = strings.TrimSpace(sens)
- if sens == "" {
- return nil
- }
- ss := strings.Split(whitespaceRx.ReplaceAllString(sens, " "), ". ")
- for i, s := range ss {
- s = strings.TrimSpace(s)
- if !strings.HasSuffix(s, ".") {
- s += "."
- }
- ss[i] = s
- }
- return ss
-}
-
-func TestSpecParseSentences(t *testing.T) {
- tests := []struct {
- ss string
- want []string
- }{
- {"Sentence 1. Sentence 2.",
- []string{
- "Sentence 1.",
- "Sentence 2.",
- }},
- {"Sentence 1. \nSentence 2.\tSentence 3.",
- []string{
- "Sentence 1.",
- "Sentence 2.",
- "Sentence 3.",
- }},
- }
-
- for i, tt := range tests {
- got := parseSentences(tt.ss)
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("%d: got = %q, want %q", i, got, tt.want)
- }
- }
-}
-
-func TestSpecCoverage(t *testing.T) {
- if !*coverSpec {
- t.Skip()
- }
-
- loadSpecOnce.Do(loadSpec)
-
- var (
- list []specPart
- cv = defaultSpecCoverage.coverage
- total = len(cv)
- complete = 0
- )
-
- for sp, touched := range defaultSpecCoverage.coverage {
- if touched {
- complete++
- } else {
- list = append(list, sp)
- }
- }
- sort.Stable(bySpecSection(list))
-
- if testing.Short() && len(list) > 5 {
- list = list[:5]
- }
-
- for _, p := range list {
- t.Errorf("\tSECTION %s: %s", p.section, p.sentence)
- }
-
- t.Logf("%d/%d (%d%%) sentences covered", complete, total, (complete/total)*100)
-}
-
-func attrSig(se xml.StartElement) string {
- var names []string
- for _, attr := range se.Attr {
- if attr.Name.Local == "fmt" {
- names = append(names, "fmt-"+attr.Value)
- } else {
- names = append(names, attr.Name.Local)
- }
- }
- sort.Strings(names)
- return strings.Join(names, ",")
-}
-
-func attrValue(se xml.StartElement, attr string) string {
- for _, a := range se.Attr {
- if a.Name.Local == attr {
- return a.Value
- }
- }
- panic("unknown attribute " + attr)
-}
-
-func TestSpecPartLess(t *testing.T) {
- tests := []struct {
- sec1, sec2 string
- want bool
- }{
- {"6.2.1", "6.2", false},
- {"6.2", "6.2.1", true},
- {"6.10", "6.10.1", true},
- {"6.10", "6.1.1", false}, // 10, not 1
- {"6.1", "6.1", false}, // equal, so not less
- }
- for _, tt := range tests {
- got := (specPart{tt.sec1, "foo"}).Less(specPart{tt.sec2, "foo"})
- if got != tt.want {
- t.Errorf("Less(%q, %q) = %v; want %v", tt.sec1, tt.sec2, got, tt.want)
- }
- }
-}
diff --git a/icmp/helper_posix.go b/icmp/helper_posix.go
index 6c3ebfaed4..f625483f06 100644
--- a/icmp/helper_posix.go
+++ b/icmp/helper_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows
package icmp
diff --git a/icmp/listen_posix.go b/icmp/listen_posix.go
index 6aea804788..b7cb15b7dc 100644
--- a/icmp/listen_posix.go
+++ b/icmp/listen_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows
package icmp
diff --git a/icmp/listen_stub.go b/icmp/listen_stub.go
index 1acfb74b60..7b76be1cb3 100644
--- a/icmp/listen_stub.go
+++ b/icmp/listen_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows
package icmp
diff --git a/idna/go118.go b/idna/go118.go
index c5c4338dbe..712f1ad839 100644
--- a/idna/go118.go
+++ b/idna/go118.go
@@ -5,7 +5,6 @@
// license that can be found in the LICENSE file.
//go:build go1.18
-// +build go1.18
package idna
diff --git a/idna/idna10.0.0.go b/idna/idna10.0.0.go
index 64ccf85feb..7b37178847 100644
--- a/idna/idna10.0.0.go
+++ b/idna/idna10.0.0.go
@@ -5,7 +5,6 @@
// license that can be found in the LICENSE file.
//go:build go1.10
-// +build go1.10
// Package idna implements IDNA2008 using the compatibility processing
// defined by UTS (Unicode Technical Standard) #46, which defines a standard to
diff --git a/idna/idna9.0.0.go b/idna/idna9.0.0.go
index ee1698cefb..cc6a892a4a 100644
--- a/idna/idna9.0.0.go
+++ b/idna/idna9.0.0.go
@@ -5,7 +5,6 @@
// license that can be found in the LICENSE file.
//go:build !go1.10
-// +build !go1.10
// Package idna implements IDNA2008 using the compatibility processing
// defined by UTS (Unicode Technical Standard) #46, which defines a standard to
diff --git a/idna/pre_go118.go b/idna/pre_go118.go
index 3aaccab1c5..40e74bb3d2 100644
--- a/idna/pre_go118.go
+++ b/idna/pre_go118.go
@@ -5,7 +5,6 @@
// license that can be found in the LICENSE file.
//go:build !go1.18
-// +build !go1.18
package idna
diff --git a/idna/tables10.0.0.go b/idna/tables10.0.0.go
index d1d62ef459..c6c2bf10a6 100644
--- a/idna/tables10.0.0.go
+++ b/idna/tables10.0.0.go
@@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.10 && !go1.13
-// +build go1.10,!go1.13
package idna
diff --git a/idna/tables11.0.0.go b/idna/tables11.0.0.go
index 167efba712..76789393cc 100644
--- a/idna/tables11.0.0.go
+++ b/idna/tables11.0.0.go
@@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.13 && !go1.14
-// +build go1.13,!go1.14
package idna
diff --git a/idna/tables12.0.0.go b/idna/tables12.0.0.go
index ab40f7bcc3..0600cd2ae5 100644
--- a/idna/tables12.0.0.go
+++ b/idna/tables12.0.0.go
@@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.14 && !go1.16
-// +build go1.14,!go1.16
package idna
diff --git a/idna/tables13.0.0.go b/idna/tables13.0.0.go
index 66701eadfb..2fb768ef6d 100644
--- a/idna/tables13.0.0.go
+++ b/idna/tables13.0.0.go
@@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.16 && !go1.21
-// +build go1.16,!go1.21
package idna
diff --git a/idna/tables15.0.0.go b/idna/tables15.0.0.go
index 40033778f0..5ff05fe1af 100644
--- a/idna/tables15.0.0.go
+++ b/idna/tables15.0.0.go
@@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.21
-// +build go1.21
package idna
diff --git a/idna/tables9.0.0.go b/idna/tables9.0.0.go
index 4074b5332e..0f25e84ca2 100644
--- a/idna/tables9.0.0.go
+++ b/idna/tables9.0.0.go
@@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build !go1.10
-// +build !go1.10
package idna
diff --git a/idna/trie12.0.0.go b/idna/trie12.0.0.go
index bb63f904b3..8a75b96673 100644
--- a/idna/trie12.0.0.go
+++ b/idna/trie12.0.0.go
@@ -5,7 +5,6 @@
// license that can be found in the LICENSE file.
//go:build !go1.16
-// +build !go1.16
package idna
diff --git a/idna/trie13.0.0.go b/idna/trie13.0.0.go
index 7d68a8dc13..fa45bb9074 100644
--- a/idna/trie13.0.0.go
+++ b/idna/trie13.0.0.go
@@ -5,7 +5,6 @@
// license that can be found in the LICENSE file.
//go:build go1.16
-// +build go1.16
package idna
diff --git a/internal/gate/gate.go b/internal/gate/gate.go
new file mode 100644
index 0000000000..5c026c002d
--- /dev/null
+++ b/internal/gate/gate.go
@@ -0,0 +1,76 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package gate contains an alternative condition variable.
+package gate
+
+import "context"
+
+// A Gate is a monitor (mutex + condition variable) with one bit of state.
+//
+// The condition may be either set or unset.
+// Lock operations may be unconditional, or wait for the condition to be set.
+// Unlock operations record the new state of the condition.
+type Gate struct {
+ // When unlocked, exactly one of set or unset contains a value.
+ // When locked, neither chan contains a value.
+ set chan struct{}
+ unset chan struct{}
+}
+
+// New returns a new, unlocked gate.
+func New(set bool) Gate {
+ g := Gate{
+ set: make(chan struct{}, 1),
+ unset: make(chan struct{}, 1),
+ }
+ g.Unlock(set)
+ return g
+}
+
+// Lock acquires the gate unconditionally.
+// It reports whether the condition is set.
+func (g *Gate) Lock() (set bool) {
+ select {
+ case <-g.set:
+ return true
+ case <-g.unset:
+ return false
+ }
+}
+
+// WaitAndLock waits until the condition is set before acquiring the gate.
+// If the context expires, WaitAndLock returns an error and does not acquire the gate.
+func (g *Gate) WaitAndLock(ctx context.Context) error {
+ select {
+ case <-g.set:
+ return nil
+ default:
+ }
+ select {
+ case <-g.set:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+// LockIfSet acquires the gate if and only if the condition is set.
+func (g *Gate) LockIfSet() (acquired bool) {
+ select {
+ case <-g.set:
+ return true
+ default:
+ return false
+ }
+}
+
+// Unlock sets the condition and releases the gate.
+func (g *Gate) Unlock(set bool) {
+ if set {
+ g.set <- struct{}{}
+ } else {
+ g.unset <- struct{}{}
+ }
+}
diff --git a/internal/gate/gate_test.go b/internal/gate/gate_test.go
new file mode 100644
index 0000000000..87a78b15af
--- /dev/null
+++ b/internal/gate/gate_test.go
@@ -0,0 +1,85 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gate_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "golang.org/x/net/internal/gate"
+)
+
+func TestGateLockAndUnlock(t *testing.T) {
+ g := gate.New(false)
+ if set := g.Lock(); set {
+ t.Errorf("g.Lock of never-locked gate: true, want false")
+ }
+ unlockedc := make(chan struct{})
+ donec := make(chan struct{})
+ go func() {
+ defer close(donec)
+ if set := g.Lock(); !set {
+ t.Errorf("g.Lock of set gate: false, want true")
+ }
+ select {
+ case <-unlockedc:
+ default:
+ t.Errorf("g.Lock succeeded while gate was held")
+ }
+ g.Unlock(false)
+ }()
+ time.Sleep(1 * time.Millisecond)
+ close(unlockedc)
+ g.Unlock(true)
+ <-donec
+ if set := g.Lock(); set {
+ t.Errorf("g.Lock of unset gate: true, want false")
+ }
+}
+
+func TestGateWaitAndLock(t *testing.T) {
+ g := gate.New(false)
+ // WaitAndLock is canceled.
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+ defer cancel()
+ if err := g.WaitAndLock(ctx); err != context.DeadlineExceeded {
+ t.Fatalf("g.WaitAndLock = %v, want context.DeadlineExceeded", err)
+ }
+ // WaitAndLock succeeds.
+ set := false
+ go func() {
+ time.Sleep(1 * time.Millisecond)
+ g.Lock()
+ set = true
+ g.Unlock(true)
+ }()
+ if err := g.WaitAndLock(context.Background()); err != nil {
+ t.Fatalf("g.WaitAndLock = %v, want nil", err)
+ }
+ if !set {
+ t.Fatalf("g.WaitAndLock returned before gate was set")
+ }
+ g.Unlock(true)
+ // WaitAndLock succeeds when the gate is set and the context is canceled.
+ if err := g.WaitAndLock(ctx); err != nil {
+ t.Fatalf("g.WaitAndLock = %v, want nil", err)
+ }
+}
+
+func TestGateLockIfSet(t *testing.T) {
+ g := gate.New(false)
+ if locked := g.LockIfSet(); locked {
+ t.Fatalf("g.LockIfSet of unset gate = %v, want false", locked)
+ }
+ g.Lock()
+ if locked := g.LockIfSet(); locked {
+ t.Fatalf("g.LockIfSet of locked gate = %v, want false", locked)
+ }
+ g.Unlock(true)
+ if locked := g.LockIfSet(); !locked {
+ t.Fatalf("g.LockIfSet of set gate = %v, want true", locked)
+ }
+}
diff --git a/internal/http3/body.go b/internal/http3/body.go
new file mode 100644
index 0000000000..cdde482efb
--- /dev/null
+++ b/internal/http3/body.go
@@ -0,0 +1,142 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "sync"
+)
+
+// A bodyWriter writes a request or response body to a stream
+// as a series of DATA frames.
+type bodyWriter struct {
+ st *stream
+ remain int64 // -1 when content-length is not known
+ flush bool // flush the stream after every write
+ name string // "request" or "response"
+}
+
+func (w *bodyWriter) Write(p []byte) (n int, err error) {
+ if w.remain >= 0 && int64(len(p)) > w.remain {
+ return 0, &streamError{
+ code: errH3InternalError,
+ message: w.name + " body longer than specified content length",
+ }
+ }
+ w.st.writeVarint(int64(frameTypeData))
+ w.st.writeVarint(int64(len(p)))
+ n, err = w.st.Write(p)
+ if w.remain >= 0 {
+ w.remain -= int64(n)
+ }
+ if w.flush && err == nil {
+ err = w.st.Flush()
+ }
+ if err != nil {
+ err = fmt.Errorf("writing %v body: %w", w.name, err)
+ }
+ return n, err
+}
+
+func (w *bodyWriter) Close() error {
+ if w.remain > 0 {
+ return errors.New(w.name + " body shorter than specified content length")
+ }
+ return nil
+}
+
+// A bodyReader reads a request or response body from a stream.
+type bodyReader struct {
+ st *stream
+
+ mu sync.Mutex
+ remain int64
+ err error
+}
+
+func (r *bodyReader) Read(p []byte) (n int, err error) {
+ // The HTTP/1 and HTTP/2 implementations both permit concurrent reads from a body,
+ // in the sense that the race detector won't complain.
+ // Use a mutex here to provide the same behavior.
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.err != nil {
+ return 0, r.err
+ }
+ defer func() {
+ if err != nil {
+ r.err = err
+ }
+ }()
+ if r.st.lim == 0 {
+ // We've finished reading the previous DATA frame, so end it.
+ if err := r.st.endFrame(); err != nil {
+ return 0, err
+ }
+ }
+ // Read the next DATA frame header,
+ // if we aren't already in the middle of one.
+ for r.st.lim < 0 {
+ ftype, err := r.st.readFrameHeader()
+ if err == io.EOF && r.remain > 0 {
+ return 0, &streamError{
+ code: errH3MessageError,
+ message: "body shorter than content-length",
+ }
+ }
+ if err != nil {
+ return 0, err
+ }
+ switch ftype {
+ case frameTypeData:
+ if r.remain >= 0 && r.st.lim > r.remain {
+ return 0, &streamError{
+ code: errH3MessageError,
+ message: "body longer than content-length",
+ }
+ }
+ // Fall out of the loop and process the frame body below.
+ case frameTypeHeaders:
+ // This HEADERS frame contains the message trailers.
+ if r.remain > 0 {
+ return 0, &streamError{
+ code: errH3MessageError,
+ message: "body shorter than content-length",
+ }
+ }
+ // TODO: Fill in Request.Trailer.
+ if err := r.st.discardFrame(); err != nil {
+ return 0, err
+ }
+ return 0, io.EOF
+ default:
+ if err := r.st.discardUnknownFrame(ftype); err != nil {
+ return 0, err
+ }
+ }
+ }
+ // We are now reading the content of a DATA frame.
+ // Fill the read buffer or read to the end of the frame,
+ // whichever comes first.
+ if int64(len(p)) > r.st.lim {
+ p = p[:r.st.lim]
+ }
+ n, err = r.st.Read(p)
+ if r.remain > 0 {
+ r.remain -= int64(n)
+ }
+ return n, err
+}
+
+func (r *bodyReader) Close() error {
+ // Unlike the HTTP/1 and HTTP/2 body readers (at the time of this comment being written),
+ // calling Close concurrently with Read will interrupt the read.
+ r.st.stream.CloseRead()
+ return nil
+}
diff --git a/internal/http3/body_test.go b/internal/http3/body_test.go
new file mode 100644
index 0000000000..599e0df816
--- /dev/null
+++ b/internal/http3/body_test.go
@@ -0,0 +1,276 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "testing"
+)
+
+// TestReadData tests servers reading request bodies, and clients reading response bodies.
+func TestReadData(t *testing.T) {
+ // These tests consist of a series of steps,
+ // where each step is either something arriving on the stream
+ // or the client/server reading from the body.
+ type (
+ // HEADERS frame arrives (headers).
+ receiveHeaders struct {
+ contentLength int64 // -1 for no content-length
+ }
+ // DATA frame header arrives.
+ receiveDataHeader struct {
+ size int64
+ }
+ // DATA frame content arrives.
+ receiveData struct {
+ size int64
+ }
+ // HEADERS frame arrives (trailers).
+ receiveTrailers struct{}
+ // Some other frame arrives.
+ receiveFrame struct {
+ ftype frameType
+ data []byte
+ }
+ // Stream closed, ending the body.
+ receiveEOF struct{}
+ // Server reads from Request.Body, or client reads from Response.Body.
+ wantBody struct {
+ size int64
+ eof bool
+ }
+ wantError struct{}
+ )
+ for _, test := range []struct {
+ name string
+ respHeader http.Header
+ steps []any
+ wantError bool
+ }{{
+ name: "no content length",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ receiveEOF{},
+ wantBody{size: 10, eof: true},
+ },
+ }, {
+ name: "valid content length",
+ steps: []any{
+ receiveHeaders{contentLength: 10},
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ receiveEOF{},
+ wantBody{size: 10, eof: true},
+ },
+ }, {
+ name: "data frame exceeds content length",
+ steps: []any{
+ receiveHeaders{contentLength: 5},
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ wantError{},
+ },
+ }, {
+ name: "data frame after all content read",
+ steps: []any{
+ receiveHeaders{contentLength: 5},
+ receiveDataHeader{size: 5},
+ receiveData{size: 5},
+ wantBody{size: 5},
+ receiveDataHeader{size: 1},
+ receiveData{size: 1},
+ wantError{},
+ },
+ }, {
+ name: "content length too long",
+ steps: []any{
+ receiveHeaders{contentLength: 10},
+ receiveDataHeader{size: 5},
+ receiveData{size: 5},
+ receiveEOF{},
+ wantBody{size: 5},
+ wantError{},
+ },
+ }, {
+ name: "stream ended by trailers",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 5},
+ receiveData{size: 5},
+ receiveTrailers{},
+ wantBody{size: 5, eof: true},
+ },
+ }, {
+ name: "trailers and content length too long",
+ steps: []any{
+ receiveHeaders{contentLength: 10},
+ receiveDataHeader{size: 5},
+ receiveData{size: 5},
+ wantBody{size: 5},
+ receiveTrailers{},
+ wantError{},
+ },
+ }, {
+ name: "unknown frame before headers",
+ steps: []any{
+ receiveFrame{
+ ftype: 0x1f + 0x21, // reserved frame type
+ data: []byte{1, 2, 3, 4},
+ },
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ wantBody{size: 10},
+ },
+ }, {
+ name: "unknown frame after headers",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveFrame{
+ ftype: 0x1f + 0x21, // reserved frame type
+ data: []byte{1, 2, 3, 4},
+ },
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ wantBody{size: 10},
+ },
+ }, {
+ name: "invalid frame",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveFrame{
+ ftype: frameTypeSettings, // not a valid frame on this stream
+ data: []byte{1, 2, 3, 4},
+ },
+ wantError{},
+ },
+ }, {
+ name: "data frame consumed by several reads",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 16},
+ receiveData{size: 16},
+ wantBody{size: 2},
+ wantBody{size: 4},
+ wantBody{size: 8},
+ wantBody{size: 2},
+ },
+ }, {
+ name: "read multiple frames",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 2},
+ receiveData{size: 2},
+ receiveDataHeader{size: 4},
+ receiveData{size: 4},
+ receiveDataHeader{size: 8},
+ receiveData{size: 8},
+ wantBody{size: 2},
+ wantBody{size: 4},
+ wantBody{size: 8},
+ },
+ }} {
+
+ runTest := func(t testing.TB, h http.Header, st *testQUICStream, body func() io.ReadCloser) {
+ var (
+ bytesSent int
+ bytesReceived int
+ )
+ for _, step := range test.steps {
+ switch step := step.(type) {
+ case receiveHeaders:
+ header := h.Clone()
+ if step.contentLength != -1 {
+ header["content-length"] = []string{
+ fmt.Sprint(step.contentLength),
+ }
+ }
+ st.writeHeaders(header)
+ case receiveDataHeader:
+ t.Logf("receive DATA frame header: size=%v", step.size)
+ st.writeVarint(int64(frameTypeData))
+ st.writeVarint(step.size)
+ st.Flush()
+ case receiveData:
+ t.Logf("receive DATA frame content: size=%v", step.size)
+ for range step.size {
+ st.stream.stream.WriteByte(byte(bytesSent))
+ bytesSent++
+ }
+ st.Flush()
+ case receiveTrailers:
+ st.writeHeaders(http.Header{
+ "x-trailer": []string{"trailer"},
+ })
+ case receiveFrame:
+ st.writeVarint(int64(step.ftype))
+ st.writeVarint(int64(len(step.data)))
+ st.Write(step.data)
+ st.Flush()
+ case receiveEOF:
+ t.Logf("receive EOF on request stream")
+ st.stream.stream.CloseWrite()
+ case wantBody:
+ t.Logf("read %v bytes from response body", step.size)
+ want := make([]byte, step.size)
+ for i := range want {
+ want[i] = byte(bytesReceived)
+ bytesReceived++
+ }
+ got := make([]byte, step.size)
+ n, err := body().Read(got)
+ got = got[:n]
+ if !bytes.Equal(got, want) {
+ t.Errorf("resp.Body.Read:")
+ t.Errorf(" got: {%x}", got)
+ t.Fatalf(" want: {%x}", want)
+ }
+ if err != nil {
+ if step.eof && err == io.EOF {
+ continue
+ }
+ t.Fatalf("resp.Body.Read: unexpected error %v", err)
+ }
+ if step.eof {
+ if n, err := body().Read([]byte{0}); n != 0 || err != io.EOF {
+ t.Fatalf("resp.Body.Read() = %v, %v; want io.EOF", n, err)
+ }
+ }
+ case wantError:
+ if n, err := body().Read([]byte{0}); n != 0 || err == nil || err == io.EOF {
+ t.Fatalf("resp.Body.Read() = %v, %v; want error", n, err)
+ }
+ default:
+ t.Fatalf("unknown test step %T", step)
+ }
+ }
+
+ }
+
+ runSynctestSubtest(t, test.name+"/client", func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+
+ header := http.Header{
+ ":status": []string{"200"},
+ }
+ runTest(t, header, st, func() io.ReadCloser {
+ return rt.response().Body
+ })
+ })
+ }
+}
diff --git a/internal/http3/conn.go b/internal/http3/conn.go
new file mode 100644
index 0000000000..5eb803115e
--- /dev/null
+++ b/internal/http3/conn.go
@@ -0,0 +1,133 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "context"
+ "io"
+ "sync"
+
+ "golang.org/x/net/quic"
+)
+
+type streamHandler interface {
+ handleControlStream(*stream) error
+ handlePushStream(*stream) error
+ handleEncoderStream(*stream) error
+ handleDecoderStream(*stream) error
+ handleRequestStream(*stream) error
+ abort(error)
+}
+
+type genericConn struct {
+ mu sync.Mutex
+
+ // The peer may create exactly one control, encoder, and decoder stream.
+ // streamsCreated is a bitset of streams created so far.
+ // Bits are 1 << streamType.
+ streamsCreated uint8
+}
+
+func (c *genericConn) acceptStreams(qconn *quic.Conn, h streamHandler) {
+ for {
+ // Use context.Background: This blocks until a stream is accepted
+ // or the connection closes.
+ st, err := qconn.AcceptStream(context.Background())
+ if err != nil {
+ return // connection closed
+ }
+ if st.IsReadOnly() {
+ go c.handleUnidirectionalStream(newStream(st), h)
+ } else {
+ go c.handleRequestStream(newStream(st), h)
+ }
+ }
+}
+
+func (c *genericConn) handleUnidirectionalStream(st *stream, h streamHandler) {
+ // Unidirectional stream header: One varint with the stream type.
+ v, err := st.readVarint()
+ if err != nil {
+ h.abort(&connectionError{
+ code: errH3StreamCreationError,
+ message: "error reading unidirectional stream header",
+ })
+ return
+ }
+ stype := streamType(v)
+ if err := c.checkStreamCreation(stype); err != nil {
+ h.abort(err)
+ return
+ }
+ switch stype {
+ case streamTypeControl:
+ err = h.handleControlStream(st)
+ case streamTypePush:
+ err = h.handlePushStream(st)
+ case streamTypeEncoder:
+ err = h.handleEncoderStream(st)
+ case streamTypeDecoder:
+ err = h.handleDecoderStream(st)
+ default:
+ // "Recipients of unknown stream types MUST either abort reading
+ // of the stream or discard incoming data without further processing."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2-7
+ //
+ // We should send the H3_STREAM_CREATION_ERROR error code,
+ // but the quic package currently doesn't allow setting error codes
+ // for STOP_SENDING frames.
+ // TODO: Should CloseRead take an error code?
+ err = nil
+ }
+ if err == io.EOF {
+ err = &connectionError{
+ code: errH3ClosedCriticalStream,
+ message: streamType(stype).String() + " stream closed",
+ }
+ }
+ c.handleStreamError(st, h, err)
+}
+
+func (c *genericConn) handleRequestStream(st *stream, h streamHandler) {
+ c.handleStreamError(st, h, h.handleRequestStream(st))
+}
+
+func (c *genericConn) handleStreamError(st *stream, h streamHandler, err error) {
+ switch err := err.(type) {
+ case *connectionError:
+ h.abort(err)
+ case nil:
+ st.stream.CloseRead()
+ st.stream.CloseWrite()
+ case *streamError:
+ st.stream.CloseRead()
+ st.stream.Reset(uint64(err.code))
+ default:
+ st.stream.CloseRead()
+ st.stream.Reset(uint64(errH3InternalError))
+ }
+}
+
+func (c *genericConn) checkStreamCreation(stype streamType) error {
+ switch stype {
+ case streamTypeControl, streamTypeEncoder, streamTypeDecoder:
+ // The peer may create exactly one control, encoder, and decoder stream.
+ default:
+ return nil
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ bit := uint8(1) << stype
+ if c.streamsCreated&bit != 0 {
+ return &connectionError{
+ code: errH3StreamCreationError,
+ message: "multiple " + stype.String() + " streams created",
+ }
+ }
+ c.streamsCreated |= bit
+ return nil
+}
diff --git a/internal/http3/conn_test.go b/internal/http3/conn_test.go
new file mode 100644
index 0000000000..a9afb1f9e9
--- /dev/null
+++ b/internal/http3/conn_test.go
@@ -0,0 +1,154 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "testing"
+ "testing/synctest"
+)
+
+// Tests which apply to both client and server connections.
+
+func TestConnCreatesControlStream(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ controlStream := tc.wantStream(streamTypeControl)
+ controlStream.wantFrameHeader(
+ "server sends SETTINGS frame on control stream",
+ frameTypeSettings)
+ controlStream.discardFrame()
+ })
+}
+
+func TestConnUnknownUnidirectionalStream(t *testing.T) {
+ // "Recipients of unknown stream types MUST either abort reading of the stream
+ // or discard incoming data without further processing."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2-7
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ st := tc.newStream(0x21) // reserved stream type
+
+ // The endpoint should send a STOP_SENDING for this stream,
+ // but it should not close the connection.
+ synctest.Wait()
+ if _, err := st.Write([]byte("hello")); err == nil {
+ t.Fatalf("write to send-only stream with an unknown type succeeded; want error")
+ }
+ tc.wantNotClosed("after receiving unknown unidirectional stream type")
+ })
+}
+
+func TestConnUnknownSettings(t *testing.T) {
+ // "An implementation MUST ignore any [settings] parameter with
+ // an identifier it does not understand."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-9
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ controlStream := tc.newStream(streamTypeControl)
+ controlStream.writeSettings(0x1f+0x21, 0) // reserved settings type
+ controlStream.Flush()
+ tc.wantNotClosed("after receiving unknown settings")
+ })
+}
+
+func TestConnInvalidSettings(t *testing.T) {
+ // "These reserved settings MUST NOT be sent, and their receipt MUST
+ // be treated as a connection error of type H3_SETTINGS_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1-5
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ controlStream := tc.newStream(streamTypeControl)
+ controlStream.writeSettings(0x02, 0) // HTTP/2 SETTINGS_ENABLE_PUSH
+ controlStream.Flush()
+ tc.wantClosed("invalid setting", errH3SettingsError)
+ })
+}
+
+func TestConnDuplicateStream(t *testing.T) {
+ for _, stype := range []streamType{
+ streamTypeControl,
+ streamTypeEncoder,
+ streamTypeDecoder,
+ } {
+ t.Run(stype.String(), func(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ _ = tc.newStream(stype)
+ tc.wantNotClosed("after creating one " + stype.String() + " stream")
+
+ // Opening a second control, encoder, or decoder stream
+ // is a protocol violation.
+ _ = tc.newStream(stype)
+ tc.wantClosed("duplicate stream", errH3StreamCreationError)
+ })
+ })
+ }
+}
+
+func TestConnUnknownFrames(t *testing.T) {
+ for _, stype := range []streamType{
+ streamTypeControl,
+ } {
+ t.Run(stype.String(), func(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ st := tc.newStream(stype)
+
+ if stype == streamTypeControl {
+ // First frame on the control stream must be settings.
+ st.writeVarint(int64(frameTypeSettings))
+ st.writeVarint(0) // size
+ }
+
+ data := "frame content"
+ st.writeVarint(0x1f + 0x21) // reserved frame type
+ st.writeVarint(int64(len(data))) // size
+ st.Write([]byte(data))
+ st.Flush()
+
+ tc.wantNotClosed("after writing unknown frame")
+ })
+ })
+ }
+}
+
+func TestConnInvalidFrames(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ control := tc.newStream(streamTypeControl)
+
+ // SETTINGS frame.
+ control.writeVarint(int64(frameTypeSettings))
+ control.writeVarint(0) // size
+
+ // DATA frame (invalid on the control stream).
+ control.writeVarint(int64(frameTypeData))
+ control.writeVarint(0) // size
+ control.Flush()
+ tc.wantClosed("after writing DATA frame to control stream", errH3FrameUnexpected)
+ })
+}
+
+func TestConnPeerCreatesBadUnidirectionalStream(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ // Create and close a stream without sending the unidirectional stream header.
+ qs, err := tc.qconn.NewSendOnlyStream(canceledCtx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ st := newTestQUICStream(tc.t, newStream(qs))
+ st.stream.stream.Close()
+
+ tc.wantClosed("after peer creates and closes uni stream", errH3StreamCreationError)
+ })
+}
+
+func runConnTest(t *testing.T, f func(testing.TB, *testQUICConn)) {
+ t.Helper()
+ runSynctestSubtest(t, "client", func(t testing.TB) {
+ tc := newTestClientConn(t)
+ f(t, tc.testQUICConn)
+ })
+ runSynctestSubtest(t, "server", func(t testing.TB) {
+ ts := newTestServer(t)
+ tc := ts.connect()
+ f(t, tc.testQUICConn)
+ })
+}
diff --git a/internal/http3/doc.go b/internal/http3/doc.go
new file mode 100644
index 0000000000..5530113f69
--- /dev/null
+++ b/internal/http3/doc.go
@@ -0,0 +1,10 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package http3 implements the HTTP/3 protocol.
+//
+// This package is a work in progress.
+// It is not ready for production usage.
+// Its API is subject to change without notice.
+package http3
diff --git a/internal/http3/errors.go b/internal/http3/errors.go
new file mode 100644
index 0000000000..db46acfcc8
--- /dev/null
+++ b/internal/http3/errors.go
@@ -0,0 +1,104 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import "fmt"
+
+// http3Error is an HTTP/3 error code.
+type http3Error int
+
+const (
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-8.1
+ errH3NoError = http3Error(0x0100)
+ errH3GeneralProtocolError = http3Error(0x0101)
+ errH3InternalError = http3Error(0x0102)
+ errH3StreamCreationError = http3Error(0x0103)
+ errH3ClosedCriticalStream = http3Error(0x0104)
+ errH3FrameUnexpected = http3Error(0x0105)
+ errH3FrameError = http3Error(0x0106)
+ errH3ExcessiveLoad = http3Error(0x0107)
+ errH3IDError = http3Error(0x0108)
+ errH3SettingsError = http3Error(0x0109)
+ errH3MissingSettings = http3Error(0x010a)
+ errH3RequestRejected = http3Error(0x010b)
+ errH3RequestCancelled = http3Error(0x010c)
+ errH3RequestIncomplete = http3Error(0x010d)
+ errH3MessageError = http3Error(0x010e)
+ errH3ConnectError = http3Error(0x010f)
+ errH3VersionFallback = http3Error(0x0110)
+
+ // https://www.rfc-editor.org/rfc/rfc9204.html#section-8.3
+ errQPACKDecompressionFailed = http3Error(0x0200)
+ errQPACKEncoderStreamError = http3Error(0x0201)
+ errQPACKDecoderStreamError = http3Error(0x0202)
+)
+
+func (e http3Error) Error() string {
+ switch e {
+ case errH3NoError:
+ return "H3_NO_ERROR"
+ case errH3GeneralProtocolError:
+ return "H3_GENERAL_PROTOCOL_ERROR"
+ case errH3InternalError:
+ return "H3_INTERNAL_ERROR"
+ case errH3StreamCreationError:
+ return "H3_STREAM_CREATION_ERROR"
+ case errH3ClosedCriticalStream:
+ return "H3_CLOSED_CRITICAL_STREAM"
+ case errH3FrameUnexpected:
+ return "H3_FRAME_UNEXPECTED"
+ case errH3FrameError:
+ return "H3_FRAME_ERROR"
+ case errH3ExcessiveLoad:
+ return "H3_EXCESSIVE_LOAD"
+ case errH3IDError:
+ return "H3_ID_ERROR"
+ case errH3SettingsError:
+ return "H3_SETTINGS_ERROR"
+ case errH3MissingSettings:
+ return "H3_MISSING_SETTINGS"
+ case errH3RequestRejected:
+ return "H3_REQUEST_REJECTED"
+ case errH3RequestCancelled:
+ return "H3_REQUEST_CANCELLED"
+ case errH3RequestIncomplete:
+ return "H3_REQUEST_INCOMPLETE"
+ case errH3MessageError:
+ return "H3_MESSAGE_ERROR"
+ case errH3ConnectError:
+ return "H3_CONNECT_ERROR"
+ case errH3VersionFallback:
+ return "H3_VERSION_FALLBACK"
+ case errQPACKDecompressionFailed:
+ return "QPACK_DECOMPRESSION_FAILED"
+ case errQPACKEncoderStreamError:
+ return "QPACK_ENCODER_STREAM_ERROR"
+ case errQPACKDecoderStreamError:
+ return "QPACK_DECODER_STREAM_ERROR"
+ }
+ return fmt.Sprintf("H3_ERROR_%v", int(e))
+}
+
+// A streamError is an error which terminates a stream, but not the connection.
+// https://www.rfc-editor.org/rfc/rfc9114.html#section-8-1
+type streamError struct {
+ code http3Error
+ message string
+}
+
+func (e *streamError) Error() string { return e.message }
+func (e *streamError) Unwrap() error { return e.code }
+
+// A connectionError is an error which results in the entire connection closing.
+// https://www.rfc-editor.org/rfc/rfc9114.html#section-8-2
+type connectionError struct {
+ code http3Error
+ message string
+}
+
+func (e *connectionError) Error() string { return e.message }
+func (e *connectionError) Unwrap() error { return e.code }
diff --git a/internal/http3/files_test.go b/internal/http3/files_test.go
new file mode 100644
index 0000000000..9c97a6ced4
--- /dev/null
+++ b/internal/http3/files_test.go
@@ -0,0 +1,56 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "os"
+ "strings"
+ "testing"
+)
+
+// TestFiles checks that every file in this package has a build constraint on Go 1.24.
+//
+// Package tests rely on testing/synctest, added as an experiment in Go 1.24.
+// When moving internal/http3 to an importable location, we can decide whether
+// to relax the constraint for non-test files.
+//
+// Drop this test when the x/net go.mod depends on 1.24 or newer.
+func TestFiles(t *testing.T) {
+ f, err := os.Open(".")
+ if err != nil {
+ t.Fatal(err)
+ }
+ names, err := f.Readdirnames(-1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, name := range names {
+ if !strings.HasSuffix(name, ".go") {
+ continue
+ }
+ b, err := os.ReadFile(name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Check for copyright header while we're in here.
+ if !bytes.Contains(b, []byte("The Go Authors.")) {
+ t.Errorf("%v: missing copyright", name)
+ }
+ // doc.go doesn't need a build constraint.
+ if name == "doc.go" {
+ continue
+ }
+ if !bytes.Contains(b, []byte("//go:build go1.24")) {
+ t.Errorf("%v: missing constraint on go1.24", name)
+ }
+ if bytes.Contains(b, []byte(`"testing/synctest"`)) &&
+ !bytes.Contains(b, []byte("//go:build go1.24 && goexperiment.synctest")) {
+ t.Errorf("%v: missing constraint on go1.24 && goexperiment.synctest", name)
+ }
+ }
+}
diff --git a/internal/http3/http3.go b/internal/http3/http3.go
new file mode 100644
index 0000000000..1f60670564
--- /dev/null
+++ b/internal/http3/http3.go
@@ -0,0 +1,86 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import "fmt"
+
+// Stream types.
+//
+// For unidirectional streams, the value is the stream type sent over the wire.
+//
+// For bidirectional streams (which are always request streams),
+// the value is arbitrary and never sent on the wire.
+type streamType int64
+
+const (
+ // Bidirectional request stream.
+ // All bidirectional streams are request streams.
+ // This stream type is never sent over the wire.
+ //
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1
+ streamTypeRequest = streamType(-1)
+
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2
+ streamTypeControl = streamType(0x00)
+ streamTypePush = streamType(0x01)
+
+ // https://www.rfc-editor.org/rfc/rfc9204.html#section-4.2
+ streamTypeEncoder = streamType(0x02)
+ streamTypeDecoder = streamType(0x03)
+)
+
+func (stype streamType) String() string {
+ switch stype {
+ case streamTypeRequest:
+ return "request"
+ case streamTypeControl:
+ return "control"
+ case streamTypePush:
+ return "push"
+ case streamTypeEncoder:
+ return "encoder"
+ case streamTypeDecoder:
+ return "decoder"
+ default:
+ return "unknown"
+ }
+}
+
+// Frame types.
+type frameType int64
+
+const (
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2
+ frameTypeData = frameType(0x00)
+ frameTypeHeaders = frameType(0x01)
+ frameTypeCancelPush = frameType(0x03)
+ frameTypeSettings = frameType(0x04)
+ frameTypePushPromise = frameType(0x05)
+ frameTypeGoaway = frameType(0x07)
+ frameTypeMaxPushID = frameType(0x0d)
+)
+
+func (ftype frameType) String() string {
+ switch ftype {
+ case frameTypeData:
+ return "DATA"
+ case frameTypeHeaders:
+ return "HEADERS"
+ case frameTypeCancelPush:
+ return "CANCEL_PUSH"
+ case frameTypeSettings:
+ return "SETTINGS"
+ case frameTypePushPromise:
+ return "PUSH_PROMISE"
+ case frameTypeGoaway:
+ return "GOAWAY"
+ case frameTypeMaxPushID:
+ return "MAX_PUSH_ID"
+ default:
+ return fmt.Sprintf("UNKNOWN_%d", int64(ftype))
+ }
+}
diff --git a/internal/http3/http3_test.go b/internal/http3/http3_test.go
new file mode 100644
index 0000000000..f490ad3f03
--- /dev/null
+++ b/internal/http3/http3_test.go
@@ -0,0 +1,82 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "encoding/hex"
+ "os"
+ "slices"
+ "strings"
+ "testing"
+ "testing/synctest"
+)
+
func init() {
	// testing/synctest requires asynctimerchan=0 (the default as of Go 1.23),
	// but the x/net go.mod is currently selecting go1.18.
	//
	// Set asynctimerchan=0 explicitly.
	//
	// TODO: Remove this when the x/net go.mod Go version is >= go1.23.
	//
	// NOTE(review): this appends to any GODEBUG already set, so a later
	// asynctimerchan entry wins over an earlier one in the variable.
	os.Setenv("GODEBUG", os.Getenv("GODEBUG")+",asynctimerchan=0")
}
+
// runSynctest runs f in a synctest.Run bubble.
// It arranges for t.Cleanup functions to run within the bubble.
func runSynctest(t *testing.T, f func(t testing.TB)) {
	synctest.Run(func() {
		// Wrap t so cleanups registered by f are collected here and
		// executed (via the deferred done) before the bubble exits,
		// rather than by the testing package after the bubble is gone.
		ct := &cleanupT{T: t}
		defer ct.done()
		f(ct)
	})
}
+
// runSynctestSubtest runs f in a subtest in a synctest.Run bubble.
// Each subtest gets its own bubble.
func runSynctestSubtest(t *testing.T, name string, f func(t testing.TB)) {
	t.Run(name, func(t *testing.T) {
		runSynctest(t, f)
	})
}
+
// cleanupT wraps a testing.T and adds its own Cleanup method.
// Used to execute cleanup functions within a synctest bubble.
type cleanupT struct {
	*testing.T
	cleanups []func() // registered cleanup funcs; run LIFO by done
}

// Cleanup replaces T.Cleanup.
// The function is recorded here and run by done rather than by the
// testing package, so it executes inside the synctest bubble.
func (t *cleanupT) Cleanup(f func()) {
	t.cleanups = append(t.cleanups, f)
}

// done runs the registered cleanups in last-in-first-out order,
// matching the semantics of testing.T.Cleanup.
func (t *cleanupT) done() {
	for _, f := range slices.Backward(t.cleanups) {
		f()
	}
}
+
// unhex decodes a hex string into bytes, first stripping the spaces,
// tabs, and newlines used to format the input in test tables.
// It panics on malformed input (test helper only).
func unhex(s string) []byte {
	var compact strings.Builder
	for _, c := range s {
		if c == ' ' || c == '\t' || c == '\n' {
			continue // formatting only, not part of the hex data
		}
		compact.WriteRune(c)
	}
	b, err := hex.DecodeString(compact.String())
	if err != nil {
		panic(err)
	}
	return b
}
+
// testReader implements io.Reader.
type testReader struct {
	// readFunc is invoked for each Read call; it must be non-nil
	// (a zero testReader panics on Read).
	readFunc func([]byte) (int, error)
}

// Read implements io.Reader by delegating to readFunc.
func (r testReader) Read(p []byte) (n int, err error) { return r.readFunc(p) }
diff --git a/internal/http3/qpack.go b/internal/http3/qpack.go
new file mode 100644
index 0000000000..66f4e29762
--- /dev/null
+++ b/internal/http3/qpack.go
@@ -0,0 +1,334 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "errors"
+ "io"
+
+ "golang.org/x/net/http2/hpack"
+)
+
+// QPACK (RFC 9204) header compression wire encoding.
+// https://www.rfc-editor.org/rfc/rfc9204.html
+
// tableType is the static or dynamic table.
//
// The T bit in QPACK instructions indicates whether a table index refers to
// the dynamic (T=0) or static (T=1) table. tableTypeForTbit and tableType.tbit
// convert a T bit from the wire encoding to/from a tableType.
type tableType byte

const (
	// The values are chosen so that masking works directly:
	// staticTable is all-ones, dynamicTable is all-zeros.
	dynamicTable tableType = 0x00 // T=0, dynamic table
	staticTable  tableType = 0xff // T=1, static table
)

// tableTypeForTbit returns the table type corresponding to a T bit value.
// The input parameter contains a byte masked to contain only the T bit.
func tableTypeForTbit(bit byte) tableType {
	if bit == 0 {
		return dynamicTable
	}
	return staticTable
}

// tbit produces the T bit corresponding to the table type.
// The input parameter contains a byte with the T bit set to 1,
// and the return is either the input or 0 depending on the table type.
func (t tableType) tbit(bit byte) byte {
	return bit & byte(t)
}
+
// indexType indicates a literal's indexing status.
//
// The N bit in QPACK instructions indicates whether a literal is "never-indexed".
// A never-indexed literal (N=1) must not be encoded as an indexed literal if it
// is forwarded on another connection.
//
// (See https://www.rfc-editor.org/rfc/rfc9204.html#section-7.1 for details on the
// security reasons for never-indexed literals.)
type indexType byte

const (
	// The values are chosen so that masking works directly:
	// neverIndex is all-ones, mayIndex is all-zeros.
	mayIndex   indexType = 0x00 // N=0, not a never-indexed literal
	neverIndex indexType = 0xff // N=1, never-indexed literal
)

// indexTypeForNBit returns the index type corresponding to a N bit value.
// The input parameter contains a byte masked to contain only the N bit.
func indexTypeForNBit(bit byte) indexType {
	if bit == 0 {
		return mayIndex
	}
	return neverIndex
}

// nbit produces the N bit corresponding to the index type.
// The input parameter contains a byte with the N bit set to 1,
// and the return is either the input or 0 depending on the index type.
func (t indexType) nbit(bit byte) byte {
	return bit & byte(t)
}
+
+// Indexed Field Line:
+//
+// 0 1 2 3 4 5 6 7
+// +---+---+---+---+---+---+---+---+
+// | 1 | T | Index (6+) |
+// +---+---+-----------------------+
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.2
+
+func appendIndexedFieldLine(b []byte, ttype tableType, index int) []byte {
+ const tbit = 0b_01000000
+ return appendPrefixedInt(b, 0b_1000_0000|ttype.tbit(tbit), 6, int64(index))
+}
+
+func (st *stream) decodeIndexedFieldLine(b byte) (itype indexType, name, value string, err error) {
+ index, err := st.readPrefixedIntWithByte(b, 6)
+ if err != nil {
+ return 0, "", "", err
+ }
+ const tbit = 0b_0100_0000
+ if tableTypeForTbit(b&tbit) == staticTable {
+ ent, err := staticTableEntry(index)
+ if err != nil {
+ return 0, "", "", err
+ }
+ return mayIndex, ent.name, ent.value, nil
+ } else {
+ return 0, "", "", errors.New("dynamic table is not supported yet")
+ }
+}
+
+// Literal Field Line With Name Reference:
+//
+// 0 1 2 3 4 5 6 7
+// +---+---+---+---+---+---+---+---+
+// | 0 | 1 | N | T |Name Index (4+)|
+// +---+---+---+---+---------------+
+// | H | Value Length (7+) |
+// +---+---------------------------+
+// | Value String (Length bytes) |
+// +-------------------------------+
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.4
+
// appendLiteralFieldLineWithNameReference appends a literal field line
// whose name is a reference to a table entry and whose value is a
// literal string.
func appendLiteralFieldLineWithNameReference(b []byte, ttype tableType, itype indexType, nameIndex int, value string) []byte {
	const tbit = 0b_0001_0000
	const nbit = 0b_0010_0000
	b = appendPrefixedInt(b, 0b_0100_0000|itype.nbit(nbit)|ttype.tbit(tbit), 4, int64(nameIndex))
	b = appendPrefixedString(b, 0, 7, value)
	return b
}
+
+func (st *stream) decodeLiteralFieldLineWithNameReference(b byte) (itype indexType, name, value string, err error) {
+ nameIndex, err := st.readPrefixedIntWithByte(b, 4)
+ if err != nil {
+ return 0, "", "", err
+ }
+
+ const tbit = 0b_0001_0000
+ if tableTypeForTbit(b&tbit) == staticTable {
+ ent, err := staticTableEntry(nameIndex)
+ if err != nil {
+ return 0, "", "", err
+ }
+ name = ent.name
+ } else {
+ return 0, "", "", errors.New("dynamic table is not supported yet")
+ }
+
+ _, value, err = st.readPrefixedString(7)
+ if err != nil {
+ return 0, "", "", err
+ }
+
+ const nbit = 0b_0010_0000
+ itype = indexTypeForNBit(b & nbit)
+
+ return itype, name, value, nil
+}
+
+// Literal Field Line with Literal Name:
+//
+// 0 1 2 3 4 5 6 7
+// +---+---+---+---+---+---+---+---+
+// | 0 | 0 | 1 | N | H |NameLen(3+)|
+// +---+---+---+---+---+-----------+
+// | Name String (Length bytes) |
+// +---+---------------------------+
+// | H | Value Length (7+) |
+// +---+---------------------------+
+// | Value String (Length bytes) |
+// +-------------------------------+
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.6
+
// appendLiteralFieldLineWithLiteralName appends a literal field line
// whose name and value are both literal strings.
func appendLiteralFieldLineWithLiteralName(b []byte, itype indexType, name, value string) []byte {
	const nbit = 0b_0001_0000
	b = appendPrefixedString(b, 0b_0010_0000|itype.nbit(nbit), 3, name)
	b = appendPrefixedString(b, 0, 7, value)
	return b
}
+
// decodeLiteralFieldLineWithLiteralName decodes a literal field line
// whose name and value are both literal strings.
func (st *stream) decodeLiteralFieldLineWithLiteralName(b byte) (itype indexType, name, value string, err error) {
	name, err = st.readPrefixedStringWithByte(b, 3)
	if err != nil {
		return 0, "", "", err
	}
	_, value, err = st.readPrefixedString(7)
	if err != nil {
		return 0, "", "", err
	}
	const nbit = 0b_0001_0000
	itype = indexTypeForNBit(b & nbit)
	return itype, name, value, nil
}
+
+// Prefixed-integer encoding from RFC 7541, section 5.1
+//
+// Prefixed integers consist of some number of bits of data,
+// N bits of encoded integer, and 0 or more additional bytes of
+// encoded integer.
+//
+// The RFCs represent this as, for example:
+//
+// 0 1 2 3 4 5 6 7
+// +---+---+---+---+---+---+---+---+
+// | 0 | 0 | 1 | Capacity (5+) |
+// +---+---+---+-------------------+
+//
+// "Capacity" is an integer with a 5-bit prefix.
+//
+// In the following functions, a "prefixLen" parameter is the number
+// of integer bits in the first byte (5 in the above example), and
+// a "firstByte" parameter is a byte containing the first byte of
+// the encoded value (0x001x_xxxx in the above example).
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.1
+// https://www.rfc-editor.org/rfc/rfc7541#section-5.1
+
// readPrefixedInt reads an RFC 7541 prefixed integer from st.
// Any read failure is reported as errQPACKDecompressionFailed.
func (st *stream) readPrefixedInt(prefixLen uint8) (firstByte byte, v int64, err error) {
	firstByte, err = st.ReadByte()
	if err != nil {
		return 0, 0, errQPACKDecompressionFailed
	}
	v, err = st.readPrefixedIntWithByte(firstByte, prefixLen)
	return firstByte, v, err
}
+
+// readPrefixedInt reads an RFC 7541 prefixed integer from st.
+// The first byte has already been read from the stream.
+func (st *stream) readPrefixedIntWithByte(firstByte byte, prefixLen uint8) (v int64, err error) {
+ prefixMask := (byte(1) << prefixLen) - 1
+ v = int64(firstByte & prefixMask)
+ if v != int64(prefixMask) {
+ return v, nil
+ }
+ m := 0
+ for {
+ b, err := st.ReadByte()
+ if err != nil {
+ return 0, errQPACKDecompressionFailed
+ }
+ v += int64(b&127) << m
+ m += 7
+ if b&128 == 0 {
+ break
+ }
+ }
+ return v, err
+}
+
// appendPrefixedInt appends an RFC 7541 prefixed integer to b.
//
// The firstByte parameter includes the non-integer bits of the first byte.
// The other bits must be zero.
func appendPrefixedInt(b []byte, firstByte byte, prefixLen uint8, i int64) []byte {
	u := uint64(i)
	limit := (uint64(1) << prefixLen) - 1
	if u < limit {
		// The value fits entirely within the prefix bits.
		return append(b, firstByte|byte(u))
	}
	// Saturate the prefix, then emit the remainder as a base-128
	// varint with continuation bits.
	b = append(b, firstByte|byte(limit))
	for u -= limit; u >= 0x80; u >>= 7 {
		b = append(b, 0x80|byte(u&0x7f))
	}
	return append(b, byte(u))
}
+
+// String literal encoding from RFC 7541, section 5.2
+//
+// String literals consist of a single bit flag indicating
+// whether the string is Huffman-encoded, a prefixed integer (see above),
+// and the string.
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.2
+// https://www.rfc-editor.org/rfc/rfc7541#section-5.2
+
// readPrefixedString reads an RFC 7541 string from st.
// Any read failure is reported as errQPACKDecompressionFailed.
func (st *stream) readPrefixedString(prefixLen uint8) (firstByte byte, s string, err error) {
	firstByte, err = st.ReadByte()
	if err != nil {
		return 0, "", errQPACKDecompressionFailed
	}
	s, err = st.readPrefixedStringWithByte(firstByte, prefixLen)
	return firstByte, s, err
}
+
+// readPrefixedString reads an RFC 7541 string from st.
+// The first byte has already been read from the stream.
+func (st *stream) readPrefixedStringWithByte(firstByte byte, prefixLen uint8) (s string, err error) {
+ size, err := st.readPrefixedIntWithByte(firstByte, prefixLen)
+ if err != nil {
+ return "", errQPACKDecompressionFailed
+ }
+
+ hbit := byte(1) << prefixLen
+ isHuffman := firstByte&hbit != 0
+
+ // TODO: Avoid allocating here.
+ data := make([]byte, size)
+ if _, err := io.ReadFull(st, data); err != nil {
+ return "", errQPACKDecompressionFailed
+ }
+ if isHuffman {
+ // TODO: Move Huffman functions into a new package that hpack (HTTP/2)
+ // and this package can both import. Most of the hpack package isn't
+ // relevant to HTTP/3.
+ s, err := hpack.HuffmanDecodeToString(data)
+ if err != nil {
+ return "", errQPACKDecompressionFailed
+ }
+ return s, nil
+ }
+ return string(data), nil
+}
+
// appendPrefixedString appends an RFC 7541 string to b,
// applying Huffman encoding and setting the H bit (indicating Huffman encoding)
// when appropriate.
//
// The firstByte parameter includes the non-integer bits of the first byte.
// The other bits must be zero.
func appendPrefixedString(b []byte, firstByte byte, prefixLen uint8, s string) []byte {
	// Huffman-encode only when it is a strict improvement over raw bytes.
	huffmanLen := hpack.HuffmanEncodeLength(s)
	if huffmanLen < uint64(len(s)) {
		hbit := byte(1) << prefixLen
		b = appendPrefixedInt(b, firstByte|hbit, prefixLen, int64(huffmanLen))
		b = hpack.AppendHuffmanString(b, s)
	} else {
		b = appendPrefixedInt(b, firstByte, prefixLen, int64(len(s)))
		b = append(b, s...)
	}
	return b
}
diff --git a/internal/http3/qpack_decode.go b/internal/http3/qpack_decode.go
new file mode 100644
index 0000000000..018867afb1
--- /dev/null
+++ b/internal/http3/qpack_decode.go
@@ -0,0 +1,83 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "errors"
+ "math/bits"
+)
+
// qpackDecoder decodes QPACK-encoded field sections.
type qpackDecoder struct {
	// The decoder has no state for now,
	// but that'll change once we add dynamic table support.
	//
	// TODO: dynamic table support.
}
+
// decode reads one encoded field section from st, calling f for each
// decoded field line. It reads until st's read limit (st.lim) is exhausted.
func (qd *qpackDecoder) decode(st *stream, f func(itype indexType, name, value string) error) error {
	// Encoded Field Section prefix.

	// We set SETTINGS_QPACK_MAX_TABLE_CAPACITY to 0,
	// so the Required Insert Count must be 0.
	_, requiredInsertCount, err := st.readPrefixedInt(8)
	if err != nil {
		return err
	}
	if requiredInsertCount != 0 {
		return errQPACKDecompressionFailed
	}

	// Delta Base. We don't use the dynamic table yet, so this may be ignored.
	_, _, err = st.readPrefixedInt(7)
	if err != nil {
		return err
	}

	// Pseudo-headers (names starting with ':') must precede all regular
	// headers; sawNonPseudo tracks the transition so we can reject
	// out-of-order pseudo-headers below.
	sawNonPseudo := false
	for st.lim > 0 {
		firstByte, err := st.ReadByte()
		if err != nil {
			return err
		}
		var name, value string
		var itype indexType
		// The number of leading zero bits in the first byte identifies
		// the field line representation (RFC 9204, section 4.5).
		switch bits.LeadingZeros8(firstByte) {
		case 0:
			// Indexed Field Line
			itype, name, value, err = st.decodeIndexedFieldLine(firstByte)
		case 1:
			// Literal Field Line With Name Reference
			itype, name, value, err = st.decodeLiteralFieldLineWithNameReference(firstByte)
		case 2:
			// Literal Field Line with Literal Name
			itype, name, value, err = st.decodeLiteralFieldLineWithLiteralName(firstByte)
		case 3:
			// Indexed Field Line With Post-Base Index
			err = errors.New("dynamic table is not supported yet")
		case 4:
			// Literal Field Line With Post-Base Name Reference
			err = errors.New("dynamic table is not supported yet")
		}
		// First bytes with five or more leading zero bits match no case
		// above: name stays empty and is rejected as a message error below.
		if err != nil {
			return err
		}
		if len(name) == 0 {
			return errH3MessageError
		}
		if name[0] == ':' {
			if sawNonPseudo {
				// Pseudo-header after a regular header: malformed.
				return errH3MessageError
			}
		} else {
			sawNonPseudo = true
		}
		if err := f(itype, name, value); err != nil {
			return err
		}
	}
	return nil
}
diff --git a/internal/http3/qpack_decode_test.go b/internal/http3/qpack_decode_test.go
new file mode 100644
index 0000000000..1b779aa782
--- /dev/null
+++ b/internal/http3/qpack_decode_test.go
@@ -0,0 +1,196 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+)
+
// TestQPACKDecode verifies that qpackDecoder decodes known encoded
// field sections into the expected header lists.
func TestQPACKDecode(t *testing.T) {
	type header struct {
		itype       indexType
		name, value string
	}
	// Many test cases here taken from Google QUICHE,
	// quiche/quic/core/qpack/qpack_encoder_test.cc.
	for _, test := range []struct {
		name string
		enc  []byte
		want []header
	}{{
		name: "empty",
		enc:  unhex("0000"),
		want: []header{},
	}, {
		name: "literal entry empty value",
		enc:  unhex("000023666f6f00"),
		want: []header{
			{mayIndex, "foo", ""},
		},
	}, {
		name: "simple literal entry",
		enc:  unhex("000023666f6f03626172"),
		want: []header{
			{mayIndex, "foo", "bar"},
		},
	}, {
		name: "multiple literal entries",
		enc: unhex("0000" + // prefix
			// foo: bar
			"23666f6f03626172" +
			// 7 octet long header name, the smallest number
			// that does not fit on a 3-bit prefix.
			"2700666f6f62616172" +
			// 127 octet long header value, the smallest number
			// that does not fit on a 7-bit prefix.
			"7f00616161616161616161616161616161616161616161616161616161616161616161" +
			"6161616161616161616161616161616161616161616161616161616161616161616161" +
			"6161616161616161616161616161616161616161616161616161616161616161616161" +
			"616161616161616161616161616161616161616161616161",
		),
		want: []header{
			{mayIndex, "foo", "bar"},
			{mayIndex, "foobaar", strings.Repeat("a", 127)},
		},
	}, {
		name: "line feed in value",
		enc:  unhex("000023666f6f0462610a72"),
		want: []header{
			{mayIndex, "foo", "ba\nr"},
		},
	}, {
		name: "huffman simple",
		enc:  unhex("00002f0125a849e95ba97d7f8925a849e95bb8e8b4bf"),
		want: []header{
			{mayIndex, "custom-key", "custom-value"},
		},
	}, {
		name: "alternating huffman nonhuffman",
		enc: unhex("0000" + // Prefix.
			"2f0125a849e95ba97d7f" + // Huffman-encoded name.
			"8925a849e95bb8e8b4bf" + // Huffman-encoded value.
			"2703637573746f6d2d6b6579" + // Non-Huffman encoded name.
			"0c637573746f6d2d76616c7565" + // Non-Huffman encoded value.
			"2f0125a849e95ba97d7f" + // Huffman-encoded name.
			"0c637573746f6d2d76616c7565" + // Non-Huffman encoded value.
			"2703637573746f6d2d6b6579" + // Non-Huffman encoded name.
			"8925a849e95bb8e8b4bf", // Huffman-encoded value.
		),
		want: []header{
			{mayIndex, "custom-key", "custom-value"},
			{mayIndex, "custom-key", "custom-value"},
			{mayIndex, "custom-key", "custom-value"},
			{mayIndex, "custom-key", "custom-value"},
		},
	}, {
		name: "static table",
		enc:  unhex("0000d1d45f00055452414345dfcc5f108621e9aec2a11f5c8294e75f1000"),
		want: []header{
			{mayIndex, ":method", "GET"},
			{mayIndex, ":method", "POST"},
			{mayIndex, ":method", "TRACE"},
			{mayIndex, "accept-encoding", "gzip, deflate, br"},
			{mayIndex, "location", ""},
			{mayIndex, "accept-encoding", "compress"},
			{mayIndex, "location", "foo"},
			{mayIndex, "accept-encoding", ""},
		},
	}} {
		runSynctestSubtest(t, test.name, func(t testing.TB) {
			st1, st2 := newStreamPair(t)
			st1.Write(test.enc)
			st1.Flush()

			// decode reads until the stream's read limit is exhausted,
			// so set the limit to exactly the encoded field section.
			st2.lim = int64(len(test.enc))

			var dec qpackDecoder
			got := []header{}
			err := dec.decode(st2, func(itype indexType, name, value string) error {
				got = append(got, header{itype, name, value})
				return nil
			})
			if err != nil {
				t.Fatalf("decode: %v", err)
			}
			if !reflect.DeepEqual(got, test.want) {
				t.Errorf("encoded: %x", test.enc)
				t.Errorf("got headers:")
				for _, h := range got {
					t.Errorf(" %v: %q", h.name, h.value)
				}
				t.Errorf("want headers:")
				for _, h := range test.want {
					t.Errorf(" %v: %q", h.name, h.value)
				}
			}
		})
	}
}
+
// TestQPACKDecodeErrors verifies that malformed or unsupported encoded
// field sections are rejected with an error.
func TestQPACKDecodeErrors(t *testing.T) {
	// Many test cases here taken from Google QUICHE,
	// quiche/quic/core/qpack/qpack_encoder_test.cc.
	for _, test := range []struct {
		name string
		enc  []byte
	}{{
		name: "literal entry empty name",
		enc:  unhex("00002003666f6f"),
	}, {
		name: "literal entry empty name and value",
		enc:  unhex("00002000"),
	}, {
		name: "name length too large for varint",
		enc:  unhex("000027ffffffffffffffffffff"),
	}, {
		name: "string literal too long",
		enc:  unhex("000027ffff7f"),
	}, {
		name: "value length too large for varint",
		enc:  unhex("000023666f6f7fffffffffffffffffffff"),
	}, {
		name: "value length too long",
		enc:  unhex("000023666f6f7fffff7f"),
	}, {
		name: "incomplete header block",
		enc:  unhex("00002366"),
	}, {
		name: "huffman name does not have eos prefix",
		enc:  unhex("00002f0125a849e95ba97d7e8925a849e95bb8e8b4bf"),
	}, {
		name: "huffman value does not have eos prefix",
		enc:  unhex("00002f0125a849e95ba97d7f8925a849e95bb8e8b4be"),
	}, {
		name: "huffman name eos prefix too long",
		enc:  unhex("00002f0225a849e95ba97d7fff8925a849e95bb8e8b4bf"),
	}, {
		name: "huffman value eos prefix too long",
		enc:  unhex("00002f0125a849e95ba97d7f8a25a849e95bb8e8b4bfff"),
	}, {
		name: "too high static table index",
		enc:  unhex("0000ff23ff24"),
	}} {
		runSynctestSubtest(t, test.name, func(t testing.TB) {
			st1, st2 := newStreamPair(t)
			st1.Write(test.enc)
			st1.Flush()

			// decode reads until the stream's read limit is exhausted.
			st2.lim = int64(len(test.enc))

			var dec qpackDecoder
			err := dec.decode(st2, func(itype indexType, name, value string) error {
				return nil
			})
			if err == nil {
				t.Errorf("encoded: %x", test.enc)
				t.Fatalf("decode succeeded; want error")
			}
		})
	}
}
diff --git a/internal/http3/qpack_encode.go b/internal/http3/qpack_encode.go
new file mode 100644
index 0000000000..0f35e0c54f
--- /dev/null
+++ b/internal/http3/qpack_encode.go
@@ -0,0 +1,47 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
// qpackEncoder encodes header lists into QPACK encoded field sections.
type qpackEncoder struct {
	// The encoder has no state for now,
	// but that'll change once we add dynamic table support.
	//
	// TODO: dynamic table support.
}
+
// init prepares the encoder for use by building the static table
// lookup maps (done at most once, process-wide).
func (qe *qpackEncoder) init() {
	staticTableOnce.Do(initStaticTableMaps)
}
+
// encode encodes a list of headers into a QPACK encoded field section.
//
// The headers func must produce the same headers on repeated calls,
// although the order may vary.
func (qe *qpackEncoder) encode(headers func(func(itype indexType, name, value string))) []byte {
	// Encoded Field Section prefix.
	//
	// We don't yet use the dynamic table, so both values here are zero.
	var b []byte
	b = appendPrefixedInt(b, 0, 8, 0) // Required Insert Count
	b = appendPrefixedInt(b, 0, 7, 0) // Delta Base

	headers(func(itype indexType, name, value string) {
		if itype == mayIndex {
			// Most compact form: a full name+value match in the
			// static table. Skipped for never-indexed literals.
			if i, ok := staticTableByNameValue[tableEntry{name, value}]; ok {
				b = appendIndexedFieldLine(b, staticTable, i)
				return
			}
		}
		// Fall back to a name-only static table reference,
		// then to a fully literal field line.
		if i, ok := staticTableByName[name]; ok {
			b = appendLiteralFieldLineWithNameReference(b, staticTable, itype, i, value)
		} else {
			b = appendLiteralFieldLineWithLiteralName(b, itype, name, value)
		}
	})

	return b
}
diff --git a/internal/http3/qpack_encode_test.go b/internal/http3/qpack_encode_test.go
new file mode 100644
index 0000000000..f426d773a6
--- /dev/null
+++ b/internal/http3/qpack_encode_test.go
@@ -0,0 +1,126 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+)
+
// TestQPACKEncode verifies encoder output against known encodings.
func TestQPACKEncode(t *testing.T) {
	type header struct {
		itype       indexType
		name, value string
	}
	// Many test cases here taken from Google QUICHE,
	// quiche/quic/core/qpack/qpack_encoder_test.cc.
	for _, test := range []struct {
		name    string
		headers []header
		want    []byte
	}{{
		name:    "empty",
		headers: []header{},
		want:    unhex("0000"),
	}, {
		name: "empty name",
		headers: []header{
			{mayIndex, "", "foo"},
		},
		want: unhex("0000208294e7"),
	}, {
		name: "empty value",
		headers: []header{
			{mayIndex, "foo", ""},
		},
		want: unhex("00002a94e700"),
	}, {
		name: "empty name and value",
		headers: []header{
			{mayIndex, "", ""},
		},
		want: unhex("00002000"),
	}, {
		name: "simple",
		headers: []header{
			{mayIndex, "foo", "bar"},
		},
		want: unhex("00002a94e703626172"),
	}, {
		name: "multiple",
		headers: []header{
			{mayIndex, "foo", "bar"},
			{mayIndex, "ZZZZZZZ", strings.Repeat("Z", 127)},
		},
		want: unhex("0000" + // prefix
			// foo: bar
			"2a94e703626172" +
			// 7 octet long header name, the smallest number
			// that does not fit on a 3-bit prefix.
			"27005a5a5a5a5a5a5a" +
			// 127 octet long header value, the smallest
			// number that does not fit on a 7-bit prefix.
			"7f005a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" +
			"5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" +
			"5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" +
			"5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"),
	}, {
		name: "static table 1",
		headers: []header{
			{mayIndex, ":method", "GET"},
			{mayIndex, "accept-encoding", "gzip, deflate, br"},
			{mayIndex, "location", ""},
		},
		want: unhex("0000d1dfcc"),
	}, {
		name: "static table 2",
		headers: []header{
			{mayIndex, ":method", "POST"},
			{mayIndex, "accept-encoding", "compress"},
			{mayIndex, "location", "foo"},
		},
		want: unhex("0000d45f108621e9aec2a11f5c8294e7"),
	}, {
		name: "static table 3",
		headers: []header{
			{mayIndex, ":method", "TRACE"},
			{mayIndex, "accept-encoding", ""},
		},
		want: unhex("00005f000554524143455f1000"),
	}, {
		name: "never indexed literal field line with name reference",
		headers: []header{
			{neverIndex, ":method", ""},
		},
		want: unhex("00007f0000"),
	}, {
		name: "never indexed literal field line with literal name",
		headers: []header{
			{neverIndex, "a", "b"},
		},
		want: unhex("000031610162"),
	}} {
		t.Run(test.name, func(t *testing.T) {
			var enc qpackEncoder
			enc.init()

			// encode passes a yield func to the headers callback;
			// replay the test's header list through it.
			got := enc.encode(func(f func(itype indexType, name, value string)) {
				for _, h := range test.headers {
					f(h.itype, h.name, h.value)
				}
			})
			if !bytes.Equal(got, test.want) {
				for _, h := range test.headers {
					t.Logf("header %v: %q", h.name, h.value)
				}
				t.Errorf("got: %x", got)
				t.Errorf("want: %x", test.want)
			}
		})
	}
}
diff --git a/internal/http3/qpack_static.go b/internal/http3/qpack_static.go
new file mode 100644
index 0000000000..cb0884eb7b
--- /dev/null
+++ b/internal/http3/qpack_static.go
@@ -0,0 +1,144 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import "sync"
+
// tableEntry is a single name/value pair in a QPACK table.
// It is comparable, so it can be used as a map key.
type tableEntry struct {
	name  string
	value string
}
+
+// staticTableEntry returns the static table entry with the given index.
+func staticTableEntry(index int64) (tableEntry, error) {
+ if index >= int64(len(staticTableEntries)) {
+ return tableEntry{}, errQPACKDecompressionFailed
+ }
+ return staticTableEntries[index], nil
+}
+
// initStaticTableMaps builds the encoder's lookup maps:
// staticTableByName maps a header name to the lowest-index entry with
// that name, and staticTableByNameValue maps a full name/value pair to
// its entry index. Called once via staticTableOnce.
func initStaticTableMaps() {
	staticTableByName = make(map[string]int)
	staticTableByNameValue = make(map[tableEntry]int)
	for i, ent := range staticTableEntries {
		if _, ok := staticTableByName[ent.name]; !ok {
			// Keep the first (lowest-index) entry for each name.
			staticTableByName[ent.name] = i
		}
		staticTableByNameValue[ent] = i
	}
}

var (
	staticTableOnce        sync.Once // guards initStaticTableMaps
	staticTableByName      map[string]int
	staticTableByNameValue map[tableEntry]int
)
+
// https://www.rfc-editor.org/rfc/rfc9204.html#appendix-A
//
// Note that this is different from the HTTP/2 static table.
//
// Indexes are written explicitly so each entry can be checked against
// the table in the RFC appendix.
var staticTableEntries = [...]tableEntry{
	0:  {":authority", ""},
	1:  {":path", "/"},
	2:  {"age", "0"},
	3:  {"content-disposition", ""},
	4:  {"content-length", "0"},
	5:  {"cookie", ""},
	6:  {"date", ""},
	7:  {"etag", ""},
	8:  {"if-modified-since", ""},
	9:  {"if-none-match", ""},
	10: {"last-modified", ""},
	11: {"link", ""},
	12: {"location", ""},
	13: {"referer", ""},
	14: {"set-cookie", ""},
	15: {":method", "CONNECT"},
	16: {":method", "DELETE"},
	17: {":method", "GET"},
	18: {":method", "HEAD"},
	19: {":method", "OPTIONS"},
	20: {":method", "POST"},
	21: {":method", "PUT"},
	22: {":scheme", "http"},
	23: {":scheme", "https"},
	24: {":status", "103"},
	25: {":status", "200"},
	26: {":status", "304"},
	27: {":status", "404"},
	28: {":status", "503"},
	29: {"accept", "*/*"},
	30: {"accept", "application/dns-message"},
	31: {"accept-encoding", "gzip, deflate, br"},
	32: {"accept-ranges", "bytes"},
	33: {"access-control-allow-headers", "cache-control"},
	34: {"access-control-allow-headers", "content-type"},
	35: {"access-control-allow-origin", "*"},
	36: {"cache-control", "max-age=0"},
	37: {"cache-control", "max-age=2592000"},
	38: {"cache-control", "max-age=604800"},
	39: {"cache-control", "no-cache"},
	40: {"cache-control", "no-store"},
	41: {"cache-control", "public, max-age=31536000"},
	42: {"content-encoding", "br"},
	43: {"content-encoding", "gzip"},
	44: {"content-type", "application/dns-message"},
	45: {"content-type", "application/javascript"},
	46: {"content-type", "application/json"},
	47: {"content-type", "application/x-www-form-urlencoded"},
	48: {"content-type", "image/gif"},
	49: {"content-type", "image/jpeg"},
	50: {"content-type", "image/png"},
	51: {"content-type", "text/css"},
	52: {"content-type", "text/html; charset=utf-8"},
	53: {"content-type", "text/plain"},
	54: {"content-type", "text/plain;charset=utf-8"},
	55: {"range", "bytes=0-"},
	56: {"strict-transport-security", "max-age=31536000"},
	57: {"strict-transport-security", "max-age=31536000; includesubdomains"},
	58: {"strict-transport-security", "max-age=31536000; includesubdomains; preload"},
	59: {"vary", "accept-encoding"},
	60: {"vary", "origin"},
	61: {"x-content-type-options", "nosniff"},
	62: {"x-xss-protection", "1; mode=block"},
	63: {":status", "100"},
	64: {":status", "204"},
	65: {":status", "206"},
	66: {":status", "302"},
	67: {":status", "400"},
	68: {":status", "403"},
	69: {":status", "421"},
	70: {":status", "425"},
	71: {":status", "500"},
	72: {"accept-language", ""},
	73: {"access-control-allow-credentials", "FALSE"},
	74: {"access-control-allow-credentials", "TRUE"},
	75: {"access-control-allow-headers", "*"},
	76: {"access-control-allow-methods", "get"},
	77: {"access-control-allow-methods", "get, post, options"},
	78: {"access-control-allow-methods", "options"},
	79: {"access-control-expose-headers", "content-length"},
	80: {"access-control-request-headers", "content-type"},
	81: {"access-control-request-method", "get"},
	82: {"access-control-request-method", "post"},
	83: {"alt-svc", "clear"},
	84: {"authorization", ""},
	85: {"content-security-policy", "script-src 'none'; object-src 'none'; base-uri 'none'"},
	86: {"early-data", "1"},
	87: {"expect-ct", ""},
	88: {"forwarded", ""},
	89: {"if-range", ""},
	90: {"origin", ""},
	91: {"purpose", "prefetch"},
	92: {"server", ""},
	93: {"timing-allow-origin", "*"},
	94: {"upgrade-insecure-requests", "1"},
	95: {"user-agent", ""},
	96: {"x-forwarded-for", ""},
	97: {"x-frame-options", "deny"},
	98: {"x-frame-options", "sameorigin"},
}
diff --git a/internal/http3/qpack_test.go b/internal/http3/qpack_test.go
new file mode 100644
index 0000000000..6e16511fc6
--- /dev/null
+++ b/internal/http3/qpack_test.go
@@ -0,0 +1,173 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestPrefixedInt(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, test := range []struct {
+ value int64
+ prefixLen uint8
+ encoded []byte
+ }{
+ // https://www.rfc-editor.org/rfc/rfc7541#appendix-C.1.1
+ {
+ value: 10,
+ prefixLen: 5,
+ encoded: []byte{
+ 0b_0000_1010,
+ },
+ },
+ // https://www.rfc-editor.org/rfc/rfc7541#appendix-C.1.2
+ {
+ value: 1337,
+ prefixLen: 5,
+ encoded: []byte{
+ 0b0001_1111,
+ 0b1001_1010,
+ 0b0000_1010,
+ },
+ },
+ // https://www.rfc-editor.org/rfc/rfc7541#appendix-C.1.3
+ {
+ value: 42,
+ prefixLen: 8,
+ encoded: []byte{
+ 0b0010_1010,
+ },
+ },
+ } {
+ highBitMask := ^((byte(1) << test.prefixLen) - 1)
+ for _, highBits := range []byte{
+ 0, highBitMask, 0b1010_1010 & highBitMask,
+ } {
+ gotEnc := appendPrefixedInt(nil, highBits, test.prefixLen, test.value)
+ wantEnc := append([]byte{}, test.encoded...)
+ wantEnc[0] |= highBits
+ if !bytes.Equal(gotEnc, wantEnc) {
+ t.Errorf("appendPrefixedInt(nil, 0b%08b, %v, %v) = {%x}, want {%x}",
+ highBits, test.prefixLen, test.value, gotEnc, wantEnc)
+ }
+
+ st1.Write(gotEnc)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ gotFirstByte, v, err := st2.readPrefixedInt(test.prefixLen)
+ if err != nil || gotFirstByte&highBitMask != highBits || v != test.value {
+ t.Errorf("st.readPrefixedInt(%v) = 0b%08b, %v, %v; want 0b%08b, %v, nil", test.prefixLen, gotFirstByte, v, err, highBits, test.value)
+ }
+ }
+ }
+}
+
+func TestPrefixedString(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, test := range []struct {
+ value string
+ prefixLen uint8
+ encoded []byte
+ }{
+ // https://www.rfc-editor.org/rfc/rfc7541#appendix-C.6.1
+ {
+ value: "302",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x82, // H bit + length 2
+ 0x64, 0x02,
+ },
+ },
+ {
+ value: "private",
+ prefixLen: 5,
+ encoded: []byte{
+ 0x25, // H bit + length 5
+ 0xae, 0xc3, 0x77, 0x1a, 0x4b,
+ },
+ },
+ {
+ value: "Mon, 21 Oct 2013 20:13:21 GMT",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x96, // H bit + length 22
+ 0xd0, 0x7a, 0xbe, 0x94, 0x10, 0x54, 0xd4, 0x44,
+ 0xa8, 0x20, 0x05, 0x95, 0x04, 0x0b, 0x81, 0x66,
+ 0xe0, 0x82, 0xa6, 0x2d, 0x1b, 0xff,
+ },
+ },
+ {
+ value: "https://www.example.com",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x91, // H bit + length 17
+ 0x9d, 0x29, 0xad, 0x17, 0x18, 0x63, 0xc7, 0x8f,
+ 0x0b, 0x97, 0xc8, 0xe9, 0xae, 0x82, 0xae, 0x43,
+ 0xd3,
+ },
+ },
+ // Not Huffman encoded (encoded size == unencoded size).
+ {
+ value: "a",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x01, // length 1
+ 0x61,
+ },
+ },
+ // Empty string.
+ {
+ value: "",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x00, // length 0
+ },
+ },
+ } {
+ highBitMask := ^((byte(1) << (test.prefixLen + 1)) - 1)
+ for _, highBits := range []byte{
+ 0, highBitMask, 0b1010_1010 & highBitMask,
+ } {
+ gotEnc := appendPrefixedString(nil, highBits, test.prefixLen, test.value)
+ wantEnc := append([]byte{}, test.encoded...)
+ wantEnc[0] |= highBits
+ if !bytes.Equal(gotEnc, wantEnc) {
+ t.Errorf("appendPrefixedString(nil, 0b%08b, %v, %v) = {%x}, want {%x}",
+ highBits, test.prefixLen, test.value, gotEnc, wantEnc)
+ }
+
+ st1.Write(gotEnc)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ gotFirstByte, v, err := st2.readPrefixedString(test.prefixLen)
+ if err != nil || gotFirstByte&highBitMask != highBits || v != test.value {
+ t.Errorf("st.readPrefixedString(%v) = 0b%08b, %q, %v; want 0b%08b, %q, nil", test.prefixLen, gotFirstByte, v, err, highBits, test.value)
+ }
+ }
+ }
+}
+
+func TestHuffmanDecodingFailure(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ st1.Write([]byte{
+ 0x82, // H bit + length 4
+ 0b_1111_1111,
+ 0b_1111_1111,
+ 0b_1111_1111,
+ 0b_1111_1111,
+ })
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ if b, v, err := st2.readPrefixedString(7); err == nil {
+ t.Fatalf("readPrefixedString(7) = %x, %v, nil; want error", b, v)
+ }
+}
diff --git a/internal/http3/quic.go b/internal/http3/quic.go
new file mode 100644
index 0000000000..6d2b120094
--- /dev/null
+++ b/internal/http3/quic.go
@@ -0,0 +1,42 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "crypto/tls"
+
+ "golang.org/x/net/quic"
+)
+
+func initConfig(config *quic.Config) *quic.Config {
+ if config == nil {
+ config = &quic.Config{}
+ }
+
+ // maybeCloneTLSConfig clones the user-provided tls.Config (but only once)
+ // prior to us modifying it.
+ needCloneTLSConfig := true
+ maybeCloneTLSConfig := func() *tls.Config {
+ if needCloneTLSConfig {
+ config.TLSConfig = config.TLSConfig.Clone()
+ needCloneTLSConfig = false
+ }
+ return config.TLSConfig
+ }
+
+ if config.TLSConfig == nil {
+ config.TLSConfig = &tls.Config{}
+ needCloneTLSConfig = false
+ }
+ if config.TLSConfig.MinVersion == 0 {
+ maybeCloneTLSConfig().MinVersion = tls.VersionTLS13
+ }
+ if config.TLSConfig.NextProtos == nil {
+ maybeCloneTLSConfig().NextProtos = []string{"h3"}
+ }
+ return config
+}
diff --git a/internal/http3/quic_test.go b/internal/http3/quic_test.go
new file mode 100644
index 0000000000..bc3b110fe9
--- /dev/null
+++ b/internal/http3/quic_test.go
@@ -0,0 +1,234 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "net"
+ "net/netip"
+ "runtime"
+ "sync"
+ "testing"
+ "time"
+
+ "golang.org/x/net/internal/gate"
+ "golang.org/x/net/internal/testcert"
+ "golang.org/x/net/quic"
+)
+
+// newLocalQUICEndpoint returns a QUIC Endpoint listening on localhost.
+func newLocalQUICEndpoint(t *testing.T) *quic.Endpoint {
+ t.Helper()
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("ReadMsgUDP not supported on %s", runtime.GOOS)
+ }
+ conf := &quic.Config{
+ TLSConfig: testTLSConfig,
+ }
+ e, err := quic.Listen("udp", "127.0.0.1:0", conf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ e.Close(context.Background())
+ })
+ return e
+}
+
+// newQUICEndpointPair returns two QUIC endpoints on the same test network.
+func newQUICEndpointPair(t testing.TB) (e1, e2 *quic.Endpoint) {
+ config := &quic.Config{
+ TLSConfig: testTLSConfig,
+ }
+ tn := &testNet{}
+ e1 = tn.newQUICEndpoint(t, config)
+ e2 = tn.newQUICEndpoint(t, config)
+ return e1, e2
+}
+
+// newQUICStreamPair returns the two sides of a bidirectional QUIC stream.
+func newQUICStreamPair(t testing.TB) (s1, s2 *quic.Stream) {
+ t.Helper()
+ config := &quic.Config{
+ TLSConfig: testTLSConfig,
+ }
+ e1, e2 := newQUICEndpointPair(t)
+ c1, err := e1.Dial(context.Background(), "udp", e2.LocalAddr().String(), config)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c2, err := e2.Accept(context.Background())
+ if err != nil {
+ t.Fatal(err)
+ }
+ s1, err = c1.NewStream(context.Background())
+ if err != nil {
+ t.Fatal(err)
+ }
+ s1.Flush()
+ s2, err = c2.AcceptStream(context.Background())
+ if err != nil {
+ t.Fatal(err)
+ }
+ return s1, s2
+}
+
+// A testNet is a fake network of net.PacketConns.
+type testNet struct {
+ mu sync.Mutex
+ conns map[netip.AddrPort]*testPacketConn
+}
+
+// newPacketConn returns a new PacketConn with a unique source address.
+func (tn *testNet) newPacketConn() *testPacketConn {
+ tn.mu.Lock()
+ defer tn.mu.Unlock()
+ if tn.conns == nil {
+ tn.conns = make(map[netip.AddrPort]*testPacketConn)
+ }
+ localAddr := netip.AddrPortFrom(
+ netip.AddrFrom4([4]byte{
+ 127, 0, 0, byte(len(tn.conns)),
+ }),
+ 443)
+ tc := &testPacketConn{
+ tn: tn,
+ localAddr: localAddr,
+ gate: gate.New(false),
+ }
+ tn.conns[localAddr] = tc
+ return tc
+}
+
+func (tn *testNet) newQUICEndpoint(t testing.TB, config *quic.Config) *quic.Endpoint {
+ t.Helper()
+ pc := tn.newPacketConn()
+ e, err := quic.NewEndpoint(pc, config)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ e.Close(t.Context())
+ })
+ return e
+}
+
+// connForAddr returns the conn with the given source address.
+func (tn *testNet) connForAddr(srcAddr netip.AddrPort) *testPacketConn {
+ tn.mu.Lock()
+ defer tn.mu.Unlock()
+ return tn.conns[srcAddr]
+}
+
+// A testPacketConn is a net.PacketConn on a testNet fake network.
+type testPacketConn struct {
+ tn *testNet
+ localAddr netip.AddrPort
+
+ gate gate.Gate
+ queue []testPacket
+ closed bool
+}
+
+type testPacket struct {
+ b []byte
+ src netip.AddrPort
+}
+
+func (tc *testPacketConn) unlock() {
+ tc.gate.Unlock(tc.closed || len(tc.queue) > 0)
+}
+
+func (tc *testPacketConn) ReadFrom(p []byte) (n int, srcAddr net.Addr, err error) {
+ if err := tc.gate.WaitAndLock(context.Background()); err != nil {
+ return 0, nil, err
+ }
+ defer tc.unlock()
+ if tc.closed {
+ return 0, nil, net.ErrClosed
+ }
+ n = copy(p, tc.queue[0].b)
+ srcAddr = net.UDPAddrFromAddrPort(tc.queue[0].src)
+ tc.queue = tc.queue[1:]
+ return n, srcAddr, nil
+}
+
+func (tc *testPacketConn) WriteTo(p []byte, dstAddr net.Addr) (n int, err error) {
+ tc.gate.Lock()
+ closed := tc.closed
+ tc.unlock()
+ if closed {
+ return 0, net.ErrClosed
+ }
+
+ ap, err := addrPortFromAddr(dstAddr)
+ if err != nil {
+ return 0, err
+ }
+ dst := tc.tn.connForAddr(ap)
+ if dst == nil {
+ return len(p), nil // sent into the void
+ }
+ dst.gate.Lock()
+ defer dst.unlock()
+ dst.queue = append(dst.queue, testPacket{
+ b: bytes.Clone(p),
+ src: tc.localAddr,
+ })
+ return len(p), nil
+}
+
+func (tc *testPacketConn) Close() error {
+ tc.tn.mu.Lock()
+ tc.tn.conns[tc.localAddr] = nil
+ tc.tn.mu.Unlock()
+
+ tc.gate.Lock()
+ defer tc.unlock()
+ tc.closed = true
+ tc.queue = nil
+ return nil
+}
+
+func (tc *testPacketConn) LocalAddr() net.Addr {
+ return net.UDPAddrFromAddrPort(tc.localAddr)
+}
+
+func (tc *testPacketConn) SetDeadline(time.Time) error { panic("unimplemented") }
+func (tc *testPacketConn) SetReadDeadline(time.Time) error { panic("unimplemented") }
+func (tc *testPacketConn) SetWriteDeadline(time.Time) error { panic("unimplemented") }
+
+func addrPortFromAddr(addr net.Addr) (netip.AddrPort, error) {
+ switch a := addr.(type) {
+ case *net.UDPAddr:
+ return a.AddrPort(), nil
+ }
+ return netip.ParseAddrPort(addr.String())
+}
+
+var testTLSConfig = &tls.Config{
+ InsecureSkipVerify: true,
+ CipherSuites: []uint16{
+ tls.TLS_AES_128_GCM_SHA256,
+ tls.TLS_AES_256_GCM_SHA384,
+ tls.TLS_CHACHA20_POLY1305_SHA256,
+ },
+ MinVersion: tls.VersionTLS13,
+ Certificates: []tls.Certificate{testCert},
+ NextProtos: []string{"h3"},
+}
+
+var testCert = func() tls.Certificate {
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ panic(err)
+ }
+ return cert
+}()
diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go
new file mode 100644
index 0000000000..bf55a13159
--- /dev/null
+++ b/internal/http3/roundtrip.go
@@ -0,0 +1,347 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "strconv"
+ "sync"
+
+ "golang.org/x/net/internal/httpcommon"
+)
+
+type roundTripState struct {
+ cc *ClientConn
+ st *stream
+
+ // Request body, provided by the caller.
+ onceCloseReqBody sync.Once
+ reqBody io.ReadCloser
+
+ reqBodyWriter bodyWriter
+
+ // Response.Body, provided to the caller.
+ respBody bodyReader
+
+ errOnce sync.Once
+ err error
+}
+
+// abort terminates the RoundTrip.
+// It returns the first fatal error encountered by the RoundTrip call.
+func (rt *roundTripState) abort(err error) error {
+ rt.errOnce.Do(func() {
+ rt.err = err
+ switch e := err.(type) {
+ case *connectionError:
+ rt.cc.abort(e)
+ case *streamError:
+ rt.st.stream.CloseRead()
+ rt.st.stream.Reset(uint64(e.code))
+ default:
+ rt.st.stream.CloseRead()
+ rt.st.stream.Reset(uint64(errH3NoError))
+ }
+ })
+ return rt.err
+}
+
+// closeReqBody closes the Request.Body, at most once.
+func (rt *roundTripState) closeReqBody() {
+ if rt.reqBody != nil {
+ rt.onceCloseReqBody.Do(func() {
+ rt.reqBody.Close()
+ })
+ }
+}
+
+// RoundTrip sends a request on the connection.
+func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) {
+ // Each request gets its own QUIC stream.
+ st, err := newConnStream(req.Context(), cc.qconn, streamTypeRequest)
+ if err != nil {
+ return nil, err
+ }
+ rt := &roundTripState{
+ cc: cc,
+ st: st,
+ }
+ defer func() {
+ if err != nil {
+ err = rt.abort(err)
+ }
+ }()
+
+ // Cancel reads/writes on the stream when the request expires.
+ st.stream.SetReadContext(req.Context())
+ st.stream.SetWriteContext(req.Context())
+
+ contentLength := actualContentLength(req)
+
+ var encr httpcommon.EncodeHeadersResult
+ headers := cc.enc.encode(func(yield func(itype indexType, name, value string)) {
+ encr, err = httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{
+ Request: httpcommon.Request{
+ URL: req.URL,
+ Method: req.Method,
+ Host: req.Host,
+ Header: req.Header,
+ Trailer: req.Trailer,
+ ActualContentLength: contentLength,
+ },
+ AddGzipHeader: false, // TODO: add when appropriate
+ PeerMaxHeaderListSize: 0,
+ DefaultUserAgent: "Go-http-client/3",
+ }, func(name, value string) {
+ // Issue #71374: Consider supporting never-indexed fields.
+ yield(mayIndex, name, value)
+ })
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Write the HEADERS frame.
+ st.writeVarint(int64(frameTypeHeaders))
+ st.writeVarint(int64(len(headers)))
+ st.Write(headers)
+ if err := st.Flush(); err != nil {
+ return nil, err
+ }
+
+ if encr.HasBody {
+ // TODO: Defer sending the request body when "Expect: 100-continue" is set.
+ rt.reqBody = req.Body
+ rt.reqBodyWriter.st = st
+ rt.reqBodyWriter.remain = contentLength
+ rt.reqBodyWriter.flush = true
+ rt.reqBodyWriter.name = "request"
+ go copyRequestBody(rt)
+ }
+
+ // Read the response headers.
+ for {
+ ftype, err := st.readFrameHeader()
+ if err != nil {
+ return nil, err
+ }
+ switch ftype {
+ case frameTypeHeaders:
+ statusCode, h, err := cc.handleHeaders(st)
+ if err != nil {
+ return nil, err
+ }
+
+ if statusCode >= 100 && statusCode < 200 {
+ // TODO: Handle 1xx responses.
+ continue
+ }
+
+ // We have the response headers.
+ // Set up the response and return it to the caller.
+ contentLength, err := parseResponseContentLength(req.Method, statusCode, h)
+ if err != nil {
+ return nil, err
+ }
+ rt.respBody.st = st
+ rt.respBody.remain = contentLength
+ resp := &http.Response{
+ Proto: "HTTP/3.0",
+ ProtoMajor: 3,
+ Header: h,
+ StatusCode: statusCode,
+ Status: strconv.Itoa(statusCode) + " " + http.StatusText(statusCode),
+ ContentLength: contentLength,
+ Body: (*transportResponseBody)(rt),
+ }
+ // TODO: Automatic Content-Type: gzip decoding.
+ return resp, nil
+ case frameTypePushPromise:
+ if err := cc.handlePushPromise(st); err != nil {
+ return nil, err
+ }
+ default:
+ if err := st.discardUnknownFrame(ftype); err != nil {
+ return nil, err
+ }
+ }
+ }
+}
+
+// actualContentLength returns a sanitized version of req.ContentLength,
+// where 0 actually means zero (not unknown) and -1 means unknown.
+func actualContentLength(req *http.Request) int64 {
+ if req.Body == nil || req.Body == http.NoBody {
+ return 0
+ }
+ if req.ContentLength != 0 {
+ return req.ContentLength
+ }
+ return -1
+}
+
+func copyRequestBody(rt *roundTripState) {
+ defer rt.closeReqBody()
+ _, err := io.Copy(&rt.reqBodyWriter, rt.reqBody)
+ if closeErr := rt.reqBodyWriter.Close(); err == nil {
+ err = closeErr
+ }
+ if err != nil {
+ // Something went wrong writing the body.
+ rt.abort(err)
+ } else {
+ // We wrote the whole body.
+ rt.st.stream.CloseWrite()
+ }
+}
+
+// transportResponseBody is the Response.Body returned by RoundTrip.
+type transportResponseBody roundTripState
+
+// Read is Response.Body.Read.
+func (b *transportResponseBody) Read(p []byte) (n int, err error) {
+ return b.respBody.Read(p)
+}
+
+var errRespBodyClosed = errors.New("response body closed")
+
+// Close is Response.Body.Close.
+// Closing the response body is how the caller signals that they're done with a request.
+func (b *transportResponseBody) Close() error {
+ rt := (*roundTripState)(b)
+ // Close the request body, which should wake up copyRequestBody if it's
+ // currently blocked reading the body.
+ rt.closeReqBody()
+ // Close the request stream, since we're done with the request.
+ // Reset closes the sending half of the stream.
+ rt.st.stream.Reset(uint64(errH3NoError))
+ // respBody.Close is responsible for closing the receiving half.
+ err := rt.respBody.Close()
+ if err == nil {
+ err = errRespBodyClosed
+ }
+ err = rt.abort(err)
+ if err == errRespBodyClosed {
+ // No other errors occurred before closing Response.Body,
+ // so consider this a successful request.
+ return nil
+ }
+ return err
+}
+
+func parseResponseContentLength(method string, statusCode int, h http.Header) (int64, error) {
+ clens := h["Content-Length"]
+ if len(clens) == 0 {
+ return -1, nil
+ }
+
+ // We allow duplicate Content-Length headers,
+ // but only if they all have the same value.
+ for _, v := range clens[1:] {
+ if clens[0] != v {
+ return -1, &streamError{errH3MessageError, "mismatching Content-Length headers"}
+ }
+ }
+
+ // "A server MUST NOT send a Content-Length header field in any response
+ // with a status code of 1xx (Informational) or 204 (No Content).
+ // A server MUST NOT send a Content-Length header field in any 2xx (Successful)
+ // response to a CONNECT request [...]"
+ // https://www.rfc-editor.org/rfc/rfc9110#section-8.6-8
+ if (statusCode >= 100 && statusCode < 200) ||
+ statusCode == 204 ||
+ (method == "CONNECT" && statusCode >= 200 && statusCode < 300) {
+ // This is a protocol violation, but a fairly harmless one.
+ // Just ignore the header.
+ return -1, nil
+ }
+
+ contentLen, err := strconv.ParseUint(clens[0], 10, 63)
+ if err != nil {
+ return -1, &streamError{errH3MessageError, "invalid Content-Length header"}
+ }
+ return int64(contentLen), nil
+}
+
+func (cc *ClientConn) handleHeaders(st *stream) (statusCode int, h http.Header, err error) {
+ haveStatus := false
+ cookie := ""
+ // Issue #71374: Consider tracking the never-indexed status of headers
+ // with the N bit set in their QPACK encoding.
+ err = cc.dec.decode(st, func(_ indexType, name, value string) error {
+ switch {
+ case name == ":status":
+ if haveStatus {
+ return &streamError{errH3MessageError, "duplicate :status"}
+ }
+ haveStatus = true
+ statusCode, err = strconv.Atoi(value)
+ if err != nil {
+ return &streamError{errH3MessageError, "invalid :status"}
+ }
+ case name[0] == ':':
+ // "Endpoints MUST treat a request or response
+ // that contains undefined or invalid
+ // pseudo-header fields as malformed."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3-3
+ return &streamError{errH3MessageError, "undefined pseudo-header"}
+ case name == "cookie":
+ // "If a decompressed field section contains multiple cookie field lines,
+ // these MUST be concatenated into a single byte string [...]"
+ // using the two-byte delimiter of "; ".
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2.1-2
+ if cookie == "" {
+ cookie = value
+ } else {
+ cookie += "; " + value
+ }
+ default:
+ if h == nil {
+ h = make(http.Header)
+ }
+ // TODO: Use a per-connection canonicalization cache as we do in HTTP/2.
+ // Maybe we could put this in the QPACK decoder and have it deliver
+ // pre-canonicalized headers to us here?
+ cname := httpcommon.CanonicalHeader(name)
+ // TODO: Consider using a single []string slice for all headers,
+ // as we do in the HTTP/1 and HTTP/2 cases.
+ // This is a bit tricky, since we don't know the number of headers
+ // at the start of decoding. Perhaps it's worth doing a two-pass decode,
+ // or perhaps we should just allocate header value slices in
+ // reasonably-sized chunks.
+ h[cname] = append(h[cname], value)
+ }
+ return nil
+ })
+ if !haveStatus {
+ // "[The :status] pseudo-header field MUST be included in all responses [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3.2-1
+ err = errH3MessageError
+ }
+ if cookie != "" {
+ if h == nil {
+ h = make(http.Header)
+ }
+ h["Cookie"] = []string{cookie}
+ }
+ if err := st.endFrame(); err != nil {
+ return 0, nil, err
+ }
+ return statusCode, h, err
+}
+
+func (cc *ClientConn) handlePushPromise(st *stream) error {
+ // "A client MUST treat receipt of a PUSH_PROMISE frame that contains a
+ // larger push ID than the client has advertised as a connection error of H3_ID_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.5-5
+ return &connectionError{
+ code: errH3IDError,
+ message: "PUSH_PROMISE received when no MAX_PUSH_ID has been sent",
+ }
+}
diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go
new file mode 100644
index 0000000000..acd8613d0e
--- /dev/null
+++ b/internal/http3/roundtrip_test.go
@@ -0,0 +1,354 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net/http"
+ "testing"
+ "testing/synctest"
+
+ "golang.org/x/net/quic"
+)
+
+func TestRoundTripSimple(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ req.Header["User-Agent"] = nil
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(http.Header{
+ ":authority": []string{"example.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{"/"},
+ ":scheme": []string{"https"},
+ })
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ "x-some-header": []string{"value"},
+ })
+ rt.wantStatus(200)
+ rt.wantHeaders(http.Header{
+ "X-Some-Header": []string{"value"},
+ })
+ })
+}
+
+func TestRoundTripWithBadHeaders(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ req.Header["Invalid\nHeader"] = []string{"x"}
+ rt := tc.roundTrip(req)
+ rt.wantError("RoundTrip fails when request contains invalid headers")
+ })
+}
+
+func TestRoundTripWithUnknownFrame(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+
+ // Write an unknown frame type before the response HEADERS.
+ data := "frame content"
+ st.writeVarint(0x1f + 0x21) // reserved frame type
+ st.writeVarint(int64(len(data))) // size
+ st.Write([]byte(data))
+
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ })
+ rt.wantStatus(200)
+ })
+}
+
+func TestRoundTripWithInvalidPushPromise(t *testing.T) {
+ // "A client MUST treat receipt of a PUSH_PROMISE frame that contains
+ // a larger push ID than the client has advertised as a connection error of H3_ID_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.5-5
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+
+ // Write a PUSH_PROMISE frame.
+ // Since the client hasn't indicated willingness to accept pushes,
+ // this is a connection error.
+ st.writePushPromise(0, http.Header{
+ ":path": []string{"/foo"},
+ })
+ rt.wantError("RoundTrip fails after receiving invalid PUSH_PROMISE")
+ tc.wantClosed(
+ "push ID exceeds client's MAX_PUSH_ID",
+ errH3IDError,
+ )
+ })
+}
+
+func TestRoundTripResponseContentLength(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ respHeader http.Header
+ wantContentLength int64
+ wantError bool
+ }{{
+ name: "valid",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ "content-length": []string{"100"},
+ },
+ wantContentLength: 100,
+ }, {
+ name: "absent",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ },
+ wantContentLength: -1,
+ }, {
+ name: "unparseable",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ "content-length": []string{"1 1"},
+ },
+ wantError: true,
+ }, {
+ name: "duplicated",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ "content-length": []string{"500", "500", "500"},
+ },
+ wantContentLength: 500,
+ }, {
+ name: "inconsistent",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ "content-length": []string{"1", "2"},
+ },
+ wantError: true,
+ }, {
+ // 204 responses aren't allowed to contain a Content-Length header.
+ // We just ignore it.
+ name: "204",
+ respHeader: http.Header{
+ ":status": []string{"204"},
+ "content-length": []string{"100"},
+ },
+ wantContentLength: -1,
+ }} {
+ runSynctestSubtest(t, test.name, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+ st.writeHeaders(test.respHeader)
+ if test.wantError {
+ rt.wantError("invalid content-length in response")
+ return
+ }
+ if got, want := rt.response().ContentLength, test.wantContentLength; got != want {
+ t.Errorf("Response.ContentLength = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+func TestRoundTripMalformedResponses(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ respHeader http.Header
+ }{{
+ name: "duplicate :status",
+ respHeader: http.Header{
+ ":status": []string{"200", "204"},
+ },
+ }, {
+ name: "unparseable :status",
+ respHeader: http.Header{
+ ":status": []string{"frogpants"},
+ },
+ }, {
+ name: "undefined pseudo-header",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ ":unknown": []string{"x"},
+ },
+ }, {
+ name: "no :status",
+ respHeader: http.Header{},
+ }} {
+ runSynctestSubtest(t, test.name, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+ st.writeHeaders(test.respHeader)
+ rt.wantError("malformed response")
+ })
+ }
+}
+
+func TestRoundTripCrumbledCookiesInResponse(t *testing.T) {
+ // "If a decompressed field section contains multiple cookie field lines,
+ // these MUST be concatenated into a single byte string [...]"
+// using the two-byte delimiter of "; ".
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2.1-2
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ "cookie": []string{"a=1", "b=2; c=3", "d=4"},
+ })
+ rt.wantStatus(200)
+ rt.wantHeaders(http.Header{
+ "Cookie": []string{"a=1; b=2; c=3; d=4"},
+ })
+ })
+}
+
+func TestRoundTripRequestBodySent(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ bodyr, bodyw := io.Pipe()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", bodyr)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+
+ bodyw.Write([]byte{0, 1, 2, 3, 4})
+ st.wantData([]byte{0, 1, 2, 3, 4})
+
+ bodyw.Write([]byte{5, 6, 7})
+ st.wantData([]byte{5, 6, 7})
+
+ bodyw.Close()
+ st.wantClosed("request body sent")
+
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ })
+ rt.wantStatus(200)
+ rt.response().Body.Close()
+ })
+}
+
+func TestRoundTripRequestBodyErrors(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ body io.Reader
+ contentLength int64
+ }{{
+ name: "too short",
+ contentLength: 10,
+ body: bytes.NewReader([]byte{0, 1, 2, 3, 4}),
+ }, {
+ name: "too long",
+ contentLength: 5,
+ body: bytes.NewReader([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
+ }, {
+ name: "read error",
+ body: io.MultiReader(
+ bytes.NewReader([]byte{0, 1, 2, 3, 4}),
+ &testReader{
+ readFunc: func([]byte) (int, error) {
+ return 0, errors.New("read error")
+ },
+ },
+ ),
+ }} {
+ runSynctestSubtest(t, test.name, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", test.body)
+ req.ContentLength = test.contentLength
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+
+ // The Transport should send some number of frames before detecting an
+ // error in the request body and aborting the request.
+ synctest.Wait()
+ for {
+ _, err := st.readFrameHeader()
+ if err != nil {
+ var code quic.StreamErrorCode
+ if !errors.As(err, &code) {
+ t.Fatalf("request stream closed with error %v: want QUIC stream error", err)
+ }
+ break
+ }
+ if err := st.discardFrame(); err != nil {
+ t.Fatalf("discardFrame: %v", err)
+ }
+ }
+
+ // RoundTrip returns with an error.
+ rt.wantError("request fails due to body error")
+ })
+ }
+}
+
+func TestRoundTripRequestBodyErrorAfterHeaders(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ bodyr, bodyw := io.Pipe()
+ req, _ := http.NewRequest("GET", "https://example.tld/", bodyr)
+ req.ContentLength = 10
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+
+ // Server sends response headers, and RoundTrip returns.
+ // The request body hasn't been sent yet.
+ st.wantHeaders(nil)
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ })
+ rt.wantStatus(200)
+
+ // Write too many bytes to the request body, triggering a request error.
+ bodyw.Write(make([]byte, req.ContentLength+1))
+
+ //io.Copy(io.Discard, st)
+ st.wantError(quic.StreamErrorCode(errH3InternalError))
+
+ if err := rt.response().Body.Close(); err == nil {
+ t.Fatalf("Response.Body.Close() = %v, want error", err)
+ }
+ })
+}
diff --git a/internal/http3/server.go b/internal/http3/server.go
new file mode 100644
index 0000000000..ca93c5298a
--- /dev/null
+++ b/internal/http3/server.go
@@ -0,0 +1,172 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "context"
+ "net/http"
+ "sync"
+
+ "golang.org/x/net/quic"
+)
+
+// A Server is an HTTP/3 server.
+// The zero value for Server is a valid server.
+type Server struct {
+ // Handler to invoke for requests, http.DefaultServeMux if nil.
+ Handler http.Handler
+
+ // Config is the QUIC configuration used by the server.
+ // The Config may be nil.
+ //
+ // ListenAndServe may clone and modify the Config.
+ // The Config must not be modified after calling ListenAndServe.
+ Config *quic.Config
+
+ initOnce sync.Once
+}
+
+func (s *Server) init() {
+ s.initOnce.Do(func() {
+ s.Config = initConfig(s.Config)
+ if s.Handler == nil {
+ s.Handler = http.DefaultServeMux
+ }
+ })
+}
+
+// ListenAndServe listens on the UDP network address addr
+// and then calls Serve to handle requests on incoming connections.
+func (s *Server) ListenAndServe(addr string) error {
+ s.init()
+ e, err := quic.Listen("udp", addr, s.Config)
+ if err != nil {
+ return err
+ }
+ return s.Serve(e)
+}
+
+// Serve accepts incoming connections on the QUIC endpoint e,
+// and handles requests from those connections.
+func (s *Server) Serve(e *quic.Endpoint) error {
+ s.init()
+ for {
+ qconn, err := e.Accept(context.Background())
+ if err != nil {
+ return err
+ }
+ go newServerConn(qconn)
+ }
+}
+
+type serverConn struct {
+ qconn *quic.Conn
+
+ genericConn // for handleUnidirectionalStream
+ enc qpackEncoder
+ dec qpackDecoder
+}
+
+func newServerConn(qconn *quic.Conn) {
+ sc := &serverConn{
+ qconn: qconn,
+ }
+ sc.enc.init()
+
+ // Create control stream and send SETTINGS frame.
+ // TODO: Time out on creating stream.
+ controlStream, err := newConnStream(context.Background(), sc.qconn, streamTypeControl)
+ if err != nil {
+ return
+ }
+ controlStream.writeSettings()
+ controlStream.Flush()
+
+ sc.acceptStreams(sc.qconn, sc)
+}
+
+func (sc *serverConn) handleControlStream(st *stream) error {
+ // "A SETTINGS frame MUST be sent as the first frame of each control stream [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-2
+ if err := st.readSettings(func(settingsType, settingsValue int64) error {
+ switch settingsType {
+ case settingsMaxFieldSectionSize:
+ _ = settingsValue // TODO
+ case settingsQPACKMaxTableCapacity:
+ _ = settingsValue // TODO
+ case settingsQPACKBlockedStreams:
+ _ = settingsValue // TODO
+ default:
+ // Unknown settings types are ignored.
+ }
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ for {
+ ftype, err := st.readFrameHeader()
+ if err != nil {
+ return err
+ }
+ switch ftype {
+ case frameTypeCancelPush:
+ // "If a server receives a CANCEL_PUSH frame for a push ID
+ // that has not yet been mentioned by a PUSH_PROMISE frame,
+ // this MUST be treated as a connection error of type H3_ID_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.3-8
+ return &connectionError{
+ code: errH3IDError,
+ message: "CANCEL_PUSH for unsent push ID",
+ }
+ case frameTypeGoaway:
+ return errH3NoError
+ default:
+ // Unknown frames are ignored.
+ if err := st.discardUnknownFrame(ftype); err != nil {
+ return err
+ }
+ }
+ }
+}
+
+func (sc *serverConn) handleEncoderStream(*stream) error {
+ // TODO
+ return nil
+}
+
+func (sc *serverConn) handleDecoderStream(*stream) error {
+ // TODO
+ return nil
+}
+
+func (sc *serverConn) handlePushStream(*stream) error {
+ // "[...] if a server receives a client-initiated push stream,
+ // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3
+ return &connectionError{
+ code: errH3StreamCreationError,
+ message: "client created push stream",
+ }
+}
+
+func (sc *serverConn) handleRequestStream(st *stream) error {
+ // TODO
+ return nil
+}
+
+// abort closes the connection with an error.
+func (sc *serverConn) abort(err error) {
+ if e, ok := err.(*connectionError); ok {
+ sc.qconn.Abort(&quic.ApplicationError{
+ Code: uint64(e.code),
+ Reason: e.message,
+ })
+ } else {
+ sc.qconn.Abort(err)
+ }
+}
diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go
new file mode 100644
index 0000000000..8e727d2512
--- /dev/null
+++ b/internal/http3/server_test.go
@@ -0,0 +1,110 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "net/netip"
+ "testing"
+ "testing/synctest"
+
+ "golang.org/x/net/internal/quic/quicwire"
+ "golang.org/x/net/quic"
+)
+
+func TestServerReceivePushStream(t *testing.T) {
+ // "[...] if a server receives a client-initiated push stream,
+ // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3
+ runSynctest(t, func(t testing.TB) {
+ ts := newTestServer(t)
+ tc := ts.connect()
+ tc.newStream(streamTypePush)
+ tc.wantClosed("invalid client-created push stream", errH3StreamCreationError)
+ })
+}
+
+func TestServerCancelPushForUnsentPromise(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ ts := newTestServer(t)
+ tc := ts.connect()
+ tc.greet()
+
+ const pushID = 100
+ tc.control.writeVarint(int64(frameTypeCancelPush))
+ tc.control.writeVarint(int64(quicwire.SizeVarint(pushID)))
+ tc.control.writeVarint(pushID)
+ tc.control.Flush()
+
+ tc.wantClosed("client canceled never-sent push ID", errH3IDError)
+ })
+}
+
+type testServer struct {
+ t testing.TB
+ s *Server
+ tn testNet
+ *testQUICEndpoint
+
+ addr netip.AddrPort
+}
+
+type testQUICEndpoint struct {
+ t testing.TB
+ e *quic.Endpoint
+}
+
+func (te *testQUICEndpoint) dial() {
+}
+
+type testServerConn struct {
+ ts *testServer
+
+ *testQUICConn
+ control *testQUICStream
+}
+
+func newTestServer(t testing.TB) *testServer {
+ t.Helper()
+ ts := &testServer{
+ t: t,
+ s: &Server{
+ Config: &quic.Config{
+ TLSConfig: testTLSConfig,
+ },
+ },
+ }
+ e := ts.tn.newQUICEndpoint(t, ts.s.Config)
+ ts.addr = e.LocalAddr()
+ go ts.s.Serve(e)
+ return ts
+}
+
+func (ts *testServer) connect() *testServerConn {
+ ts.t.Helper()
+ config := &quic.Config{TLSConfig: testTLSConfig}
+ e := ts.tn.newQUICEndpoint(ts.t, nil)
+ qconn, err := e.Dial(ts.t.Context(), "udp", ts.addr.String(), config)
+ if err != nil {
+ ts.t.Fatal(err)
+ }
+ tc := &testServerConn{
+ ts: ts,
+ testQUICConn: newTestQUICConn(ts.t, qconn),
+ }
+ synctest.Wait()
+ return tc
+}
+
+// greet performs initial connection handshaking with the server.
+func (tc *testServerConn) greet() {
+ // Client creates a control stream.
+ tc.control = tc.newStream(streamTypeControl)
+ tc.control.writeVarint(int64(frameTypeSettings))
+ tc.control.writeVarint(0) // size
+ tc.control.Flush()
+ synctest.Wait()
+}
diff --git a/internal/http3/settings.go b/internal/http3/settings.go
new file mode 100644
index 0000000000..b5e562ecad
--- /dev/null
+++ b/internal/http3/settings.go
@@ -0,0 +1,72 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "golang.org/x/net/internal/quic/quicwire"
+)
+
+const (
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1
+ settingsMaxFieldSectionSize = 0x06
+
+ // https://www.rfc-editor.org/rfc/rfc9204.html#section-5
+ settingsQPACKMaxTableCapacity = 0x01
+ settingsQPACKBlockedStreams = 0x07
+)
+
+// writeSettings writes a complete SETTINGS frame.
+// Its parameter is a list of alternating setting types and values.
+func (st *stream) writeSettings(settings ...int64) {
+ var size int64
+ for _, s := range settings {
+ // Settings values that don't fit in a QUIC varint ([0,2^62)) will panic here.
+ size += int64(quicwire.SizeVarint(uint64(s)))
+ }
+ st.writeVarint(int64(frameTypeSettings))
+ st.writeVarint(size)
+ for _, s := range settings {
+ st.writeVarint(s)
+ }
+}
+
+// readSettings reads a complete SETTINGS frame, including the frame header.
+func (st *stream) readSettings(f func(settingType, value int64) error) error {
+ frameType, err := st.readFrameHeader()
+ if err != nil || frameType != frameTypeSettings {
+ return &connectionError{
+ code: errH3MissingSettings,
+ message: "settings not sent on control stream",
+ }
+ }
+ for st.lim > 0 {
+ settingsType, err := st.readVarint()
+ if err != nil {
+ return err
+ }
+ settingsValue, err := st.readVarint()
+ if err != nil {
+ return err
+ }
+
+ // Use of HTTP/2 settings where there is no corresponding HTTP/3 setting
+ // is an error.
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1-5
+ switch settingsType {
+ case 0x02, 0x03, 0x04, 0x05:
+ return &connectionError{
+ code: errH3SettingsError,
+ message: "use of reserved setting",
+ }
+ }
+
+ if err := f(settingsType, settingsValue); err != nil {
+ return err
+ }
+ }
+ return st.endFrame()
+}
diff --git a/internal/http3/stream.go b/internal/http3/stream.go
new file mode 100644
index 0000000000..0f975407be
--- /dev/null
+++ b/internal/http3/stream.go
@@ -0,0 +1,262 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "context"
+ "io"
+
+ "golang.org/x/net/quic"
+)
+
+// A stream wraps a QUIC stream, providing methods to read/write various values.
+type stream struct {
+ stream *quic.Stream
+
+ // lim is the current read limit.
+ // Reading a frame header sets the limit to the end of the frame.
+ // Reading past the limit or reading less than the limit and ending the frame
+ // results in an error.
+ // -1 indicates no limit.
+ lim int64
+}
+
+// newConnStream creates a new stream on a connection.
+// It writes the stream header for unidirectional streams.
+//
+// The stream returned by newStream is not flushed,
+// and will not be sent to the peer until the caller calls
+// Flush or writes enough data to the stream.
+func newConnStream(ctx context.Context, qconn *quic.Conn, stype streamType) (*stream, error) {
+ var qs *quic.Stream
+ var err error
+ if stype == streamTypeRequest {
+ // Request streams are bidirectional.
+ qs, err = qconn.NewStream(ctx)
+ } else {
+ // All other streams are unidirectional.
+ qs, err = qconn.NewSendOnlyStream(ctx)
+ }
+ if err != nil {
+ return nil, err
+ }
+ st := &stream{
+ stream: qs,
+ lim: -1, // no limit
+ }
+ if stype != streamTypeRequest {
+ // Unidirectional stream header.
+ st.writeVarint(int64(stype))
+ }
+ return st, err
+}
+
+func newStream(qs *quic.Stream) *stream {
+ return &stream{
+ stream: qs,
+ lim: -1, // no limit
+ }
+}
+
+// readFrameHeader reads the type and length fields of an HTTP/3 frame.
+// It sets the read limit to the end of the frame.
+//
+// https://www.rfc-editor.org/rfc/rfc9114.html#section-7.1
+func (st *stream) readFrameHeader() (ftype frameType, err error) {
+ if st.lim >= 0 {
+		// We shouldn't call readFrameHeader before ending the previous frame.
+ return 0, errH3FrameError
+ }
+ ftype, err = readVarint[frameType](st)
+ if err != nil {
+ return 0, err
+ }
+ size, err := st.readVarint()
+ if err != nil {
+ return 0, err
+ }
+ st.lim = size
+ return ftype, nil
+}
+
+// endFrame is called after reading a frame to reset the read limit.
+// It returns an error if the entire contents of a frame have not been read.
+func (st *stream) endFrame() error {
+ if st.lim != 0 {
+ return &connectionError{
+ code: errH3FrameError,
+ message: "invalid HTTP/3 frame",
+ }
+ }
+ st.lim = -1
+ return nil
+}
+
+// readFrameData returns the remaining data in the current frame.
+func (st *stream) readFrameData() ([]byte, error) {
+ if st.lim < 0 {
+ return nil, errH3FrameError
+ }
+ // TODO: Pool buffers to avoid allocation here.
+ b := make([]byte, st.lim)
+ _, err := io.ReadFull(st, b)
+ if err != nil {
+ return nil, err
+ }
+ return b, nil
+}
+
+// ReadByte reads one byte from the stream.
+func (st *stream) ReadByte() (b byte, err error) {
+ if err := st.recordBytesRead(1); err != nil {
+ return 0, err
+ }
+ b, err = st.stream.ReadByte()
+ if err != nil {
+ if err == io.EOF && st.lim < 0 {
+ return 0, io.EOF
+ }
+ return 0, errH3FrameError
+ }
+ return b, nil
+}
+
+// Read reads from the stream.
+func (st *stream) Read(b []byte) (int, error) {
+ n, err := st.stream.Read(b)
+ if e2 := st.recordBytesRead(n); e2 != nil {
+ return 0, e2
+ }
+ if err == io.EOF {
+ if st.lim == 0 {
+ // EOF at end of frame, ignore.
+ return n, nil
+ } else if st.lim > 0 {
+ // EOF inside frame, error.
+ return 0, errH3FrameError
+ } else {
+ // EOF outside of frame, surface to caller.
+ return n, io.EOF
+ }
+ }
+ if err != nil {
+ return 0, errH3FrameError
+ }
+ return n, nil
+}
+
+// discardUnknownFrame discards an unknown frame.
+//
+// HTTP/3 requires that unknown frames be ignored on all streams.
+// However, a known frame appearing in an unexpected place is a fatal error,
+// so this returns an error if the frame is one we know.
+func (st *stream) discardUnknownFrame(ftype frameType) error {
+ switch ftype {
+ case frameTypeData,
+ frameTypeHeaders,
+ frameTypeCancelPush,
+ frameTypeSettings,
+ frameTypePushPromise,
+ frameTypeGoaway,
+ frameTypeMaxPushID:
+ return &connectionError{
+ code: errH3FrameUnexpected,
+ message: "unexpected " + ftype.String() + " frame",
+ }
+ }
+ return st.discardFrame()
+}
+
+// discardFrame discards any remaining data in the current frame and resets the read limit.
+func (st *stream) discardFrame() error {
+ // TODO: Consider adding a *quic.Stream method to discard some amount of data.
+ for range st.lim {
+ _, err := st.stream.ReadByte()
+ if err != nil {
+ return &streamError{errH3FrameError, err.Error()}
+ }
+ }
+ st.lim = -1
+ return nil
+}
+
+// Write writes to the stream.
+func (st *stream) Write(b []byte) (int, error) { return st.stream.Write(b) }
+
+// Flush commits data written to the stream.
+func (st *stream) Flush() error { return st.stream.Flush() }
+
+// readVarint reads a QUIC variable-length integer from the stream.
+func (st *stream) readVarint() (v int64, err error) {
+ b, err := st.stream.ReadByte()
+ if err != nil {
+ return 0, err
+ }
+ v = int64(b & 0x3f)
+ n := 1 << (b >> 6)
+ for i := 1; i < n; i++ {
+ b, err := st.stream.ReadByte()
+ if err != nil {
+ return 0, errH3FrameError
+ }
+ v = (v << 8) | int64(b)
+ }
+ if err := st.recordBytesRead(n); err != nil {
+ return 0, err
+ }
+ return v, nil
+}
+
+// readVarint reads a varint of a particular type.
+func readVarint[T ~int64 | ~uint64](st *stream) (T, error) {
+ v, err := st.readVarint()
+ return T(v), err
+}
+
+// writeVarint writes a QUIC variable-length integer to the stream.
+func (st *stream) writeVarint(v int64) {
+ switch {
+ case v <= (1<<6)-1:
+ st.stream.WriteByte(byte(v))
+ case v <= (1<<14)-1:
+ st.stream.WriteByte((1 << 6) | byte(v>>8))
+ st.stream.WriteByte(byte(v))
+ case v <= (1<<30)-1:
+ st.stream.WriteByte((2 << 6) | byte(v>>24))
+ st.stream.WriteByte(byte(v >> 16))
+ st.stream.WriteByte(byte(v >> 8))
+ st.stream.WriteByte(byte(v))
+ case v <= (1<<62)-1:
+ st.stream.WriteByte((3 << 6) | byte(v>>56))
+ st.stream.WriteByte(byte(v >> 48))
+ st.stream.WriteByte(byte(v >> 40))
+ st.stream.WriteByte(byte(v >> 32))
+ st.stream.WriteByte(byte(v >> 24))
+ st.stream.WriteByte(byte(v >> 16))
+ st.stream.WriteByte(byte(v >> 8))
+ st.stream.WriteByte(byte(v))
+ default:
+ panic("varint too large")
+ }
+}
+
+// recordBytesRead records that n bytes have been read.
+// It returns an error if the read passes the current limit.
+func (st *stream) recordBytesRead(n int) error {
+ if st.lim < 0 {
+ return nil
+ }
+ st.lim -= int64(n)
+ if st.lim < 0 {
+ st.stream = nil // panic if we try to read again
+ return &connectionError{
+ code: errH3FrameError,
+ message: "invalid HTTP/3 frame",
+ }
+ }
+ return nil
+}
diff --git a/internal/http3/stream_test.go b/internal/http3/stream_test.go
new file mode 100644
index 0000000000..12b281c558
--- /dev/null
+++ b/internal/http3/stream_test.go
@@ -0,0 +1,319 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "testing"
+
+ "golang.org/x/net/internal/quic/quicwire"
+)
+
+func TestStreamReadVarint(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, b := range [][]byte{
+ {0x00},
+ {0x3f},
+ {0x40, 0x00},
+ {0x7f, 0xff},
+ {0x80, 0x00, 0x00, 0x00},
+ {0xbf, 0xff, 0xff, 0xff},
+ {0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ // Example cases from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.1
+ {0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c},
+ {0x9d, 0x7f, 0x3e, 0x7d},
+ {0x7b, 0xbd},
+ {0x25},
+ {0x40, 0x25},
+ } {
+ trailer := []byte{0xde, 0xad, 0xbe, 0xef}
+ st1.Write(b)
+ st1.Write(trailer)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ got, err := st2.readVarint()
+ if err != nil {
+ t.Fatalf("st.readVarint() = %v", err)
+ }
+ want, _ := quicwire.ConsumeVarintInt64(b)
+ if got != want {
+ t.Fatalf("st.readVarint() = %v, want %v", got, want)
+ }
+ gotTrailer := make([]byte, len(trailer))
+ if _, err := io.ReadFull(st2, gotTrailer); err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(gotTrailer, trailer) {
+ t.Fatalf("after st.readVarint, read %x, want %x", gotTrailer, trailer)
+ }
+ }
+}
+
+func TestStreamWriteVarint(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, v := range []int64{
+ 0,
+ 63,
+ 16383,
+ 1073741823,
+ 4611686018427387903,
+ // Example cases from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.1
+ 151288809941952652,
+ 494878333,
+ 15293,
+ 37,
+ } {
+ trailer := []byte{0xde, 0xad, 0xbe, 0xef}
+ st1.writeVarint(v)
+ st1.Write(trailer)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ want := quicwire.AppendVarint(nil, uint64(v))
+ want = append(want, trailer...)
+
+ got := make([]byte, len(want))
+ if _, err := io.ReadFull(st2, got); err != nil {
+ t.Fatal(err)
+ }
+
+ if !bytes.Equal(got, want) {
+ t.Errorf("AppendVarint(nil, %v) = %x, want %x", v, got, want)
+ }
+ }
+}
+
+func TestStreamReadFrames(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, frame := range []struct {
+ ftype frameType
+ data []byte
+ }{{
+ ftype: 1,
+ data: []byte("hello"),
+ }, {
+ ftype: 2,
+ data: []byte{},
+ }, {
+ ftype: 3,
+ data: []byte("goodbye"),
+ }} {
+ st1.writeVarint(int64(frame.ftype))
+ st1.writeVarint(int64(len(frame.data)))
+ st1.Write(frame.data)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if gotFrameType, err := st2.readFrameHeader(); err != nil || gotFrameType != frame.ftype {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", gotFrameType, err, frame.ftype)
+ }
+ if gotData, err := st2.readFrameData(); err != nil || !bytes.Equal(gotData, frame.data) {
+ t.Fatalf("st.readFrameData() = %x, %v; want %x, nil", gotData, err, frame.data)
+ }
+ if err := st2.endFrame(); err != nil {
+ t.Fatalf("st.endFrame() = %v; want nil", err)
+ }
+ }
+}
+
+func TestStreamReadFrameUnderflow(t *testing.T) {
+ const size = 4
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(0) // type
+ st1.writeVarint(size) // size
+ st1.Write(make([]byte, size)) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if _, err := io.ReadFull(st2, make([]byte, size-1)); err != nil {
+ t.Fatalf("st.Read() = %v", err)
+ }
+ // We have not consumed the full frame: Error.
+ if err := st2.endFrame(); !errors.Is(err, errH3FrameError) {
+ t.Fatalf("st.endFrame before end: %v, want errH3FrameError", err)
+ }
+}
+
+func TestStreamReadFrameWithoutEnd(t *testing.T) {
+ const size = 4
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(0) // type
+ st1.writeVarint(size) // size
+ st1.Write(make([]byte, size)) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if _, err := st2.readFrameHeader(); err == nil {
+ t.Fatalf("st.readFrameHeader before st.endFrame for prior frame: success, want error")
+ }
+}
+
+func TestStreamReadFrameOverflow(t *testing.T) {
+ const size = 4
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(0) // type
+ st1.writeVarint(size) // size
+ st1.Write(make([]byte, size+1)) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if _, err := io.ReadFull(st2, make([]byte, size+1)); !errors.Is(err, errH3FrameError) {
+ t.Fatalf("st.Read past end of frame: %v, want errH3FrameError", err)
+ }
+}
+
+func TestStreamReadFrameHeaderPartial(t *testing.T) {
+ var frame []byte
+ frame = quicwire.AppendVarint(frame, 1000) // type
+ frame = quicwire.AppendVarint(frame, 2000) // size
+
+ for i := 1; i < len(frame)-1; i++ {
+ st1, st2 := newStreamPair(t)
+ st1.Write(frame[:i])
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ st1.stream.CloseWrite()
+
+ if _, err := st2.readFrameHeader(); err == nil {
+			t.Fatalf("%v/%v bytes of frame available: st.readFrameHeader() succeeded; want error", i, len(frame))
+ }
+ }
+}
+
+func TestStreamReadFrameDataPartial(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(1) // type
+ st1.writeVarint(100) // size
+ st1.Write(make([]byte, 50)) // data
+ st1.stream.CloseWrite()
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if n, err := io.ReadAll(st2); err == nil {
+ t.Fatalf("io.ReadAll with partial frame = %v, nil; want error", n)
+ }
+}
+
+func TestStreamReadByteFrameDataPartial(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(1) // type
+ st1.writeVarint(100) // size
+ st1.stream.CloseWrite()
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if b, err := st2.ReadByte(); err == nil {
+		t.Fatalf("st.ReadByte() with partial frame = %v, nil; want error", b)
+ }
+}
+
+func TestStreamReadFrameDataAtEOF(t *testing.T) {
+ const typ = 10
+ data := []byte("hello")
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(typ) // type
+ st1.writeVarint(int64(len(data))) // size
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ if got, err := st2.readFrameHeader(); err != nil || got != typ {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", got, err, typ)
+ }
+
+ st1.Write(data) // data
+ st1.stream.CloseWrite() // end stream
+ got := make([]byte, len(data)+1)
+ if n, err := st2.Read(got); err != nil || n != len(data) || !bytes.Equal(got[:n], data) {
+ t.Fatalf("st.Read() = %v, %v (data=%x); want %v, nil (data=%x)", n, err, got[:n], len(data), data)
+ }
+}
+
+func TestStreamReadFrameData(t *testing.T) {
+ const typ = 10
+ data := []byte("hello")
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(typ) // type
+ st1.writeVarint(int64(len(data))) // size
+ st1.Write(data) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if got, err := st2.readFrameHeader(); err != nil || got != typ {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", got, err, typ)
+ }
+ if got, err := st2.readFrameData(); err != nil || !bytes.Equal(got, data) {
+ t.Fatalf("st.readFrameData() = %x, %v; want %x, nil", got, err, data)
+ }
+}
+
+func TestStreamReadByte(t *testing.T) {
+ const stype = 1
+ const want = 42
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(stype) // stream type
+ st1.writeVarint(1) // size
+ st1.Write([]byte{want}) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if got, err := st2.readFrameHeader(); err != nil || got != stype {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", got, err, stype)
+ }
+ if got, err := st2.ReadByte(); err != nil || got != want {
+ t.Fatalf("st.ReadByte() = %v, %v; want %v, nil", got, err, want)
+ }
+ if got, err := st2.ReadByte(); err == nil {
+ t.Fatalf("reading past end of frame: st.ReadByte() = %v, %v; want error", got, err)
+ }
+}
+
+func TestStreamDiscardFrame(t *testing.T) {
+ const typ = 10
+ data := []byte("hello")
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(typ) // type
+ st1.writeVarint(int64(len(data))) // size
+ st1.Write(data) // data
+ st1.stream.CloseWrite()
+
+ if got, err := st2.readFrameHeader(); err != nil || got != typ {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", got, err, typ)
+ }
+ if err := st2.discardFrame(); err != nil {
+ t.Fatalf("st.discardFrame() = %v", err)
+ }
+ if b, err := io.ReadAll(st2); err != nil || len(b) > 0 {
+ t.Fatalf("after discarding frame, read %x, %v; want EOF", b, err)
+ }
+}
+
+func newStreamPair(t testing.TB) (s1, s2 *stream) {
+ t.Helper()
+ q1, q2 := newQUICStreamPair(t)
+ return newStream(q1), newStream(q2)
+}
diff --git a/internal/http3/transport.go b/internal/http3/transport.go
new file mode 100644
index 0000000000..b26524cbda
--- /dev/null
+++ b/internal/http3/transport.go
@@ -0,0 +1,190 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "context"
+ "fmt"
+ "sync"
+
+ "golang.org/x/net/quic"
+)
+
+// A Transport is an HTTP/3 transport.
+//
+// It does not manage a pool of connections,
+// and therefore does not implement net/http.RoundTripper.
+//
+// TODO: Provide a way to register an HTTP/3 transport with a net/http.Transport's
+// connection pool.
+type Transport struct {
+ // Endpoint is the QUIC endpoint used by connections created by the transport.
+ // If unset, it is initialized by the first call to Dial.
+ Endpoint *quic.Endpoint
+
+ // Config is the QUIC configuration used for client connections.
+ // The Config may be nil.
+ //
+ // Dial may clone and modify the Config.
+ // The Config must not be modified after calling Dial.
+ Config *quic.Config
+
+ initOnce sync.Once
+ initErr error
+}
+
+func (tr *Transport) init() error {
+ tr.initOnce.Do(func() {
+ tr.Config = initConfig(tr.Config)
+ if tr.Endpoint == nil {
+ tr.Endpoint, tr.initErr = quic.Listen("udp", ":0", nil)
+ }
+ })
+ return tr.initErr
+}
+
+// Dial creates a new HTTP/3 client connection.
+func (tr *Transport) Dial(ctx context.Context, target string) (*ClientConn, error) {
+ if err := tr.init(); err != nil {
+ return nil, err
+ }
+ qconn, err := tr.Endpoint.Dial(ctx, "udp", target, tr.Config)
+ if err != nil {
+ return nil, err
+ }
+ return newClientConn(ctx, qconn)
+}
+
+// A ClientConn is a client HTTP/3 connection.
+//
+// Multiple goroutines may invoke methods on a ClientConn simultaneously.
+type ClientConn struct {
+ qconn *quic.Conn
+ genericConn
+
+ enc qpackEncoder
+ dec qpackDecoder
+}
+
+func newClientConn(ctx context.Context, qconn *quic.Conn) (*ClientConn, error) {
+ cc := &ClientConn{
+ qconn: qconn,
+ }
+ cc.enc.init()
+
+ // Create control stream and send SETTINGS frame.
+ controlStream, err := newConnStream(ctx, cc.qconn, streamTypeControl)
+ if err != nil {
+ return nil, fmt.Errorf("http3: cannot create control stream: %v", err)
+ }
+ controlStream.writeSettings()
+ controlStream.Flush()
+
+ go cc.acceptStreams(qconn, cc)
+ return cc, nil
+}
+
+// Close closes the connection.
+// Any in-flight requests are canceled.
+// Close does not wait for the peer to acknowledge the connection closing.
+func (cc *ClientConn) Close() error {
+ // Close the QUIC connection immediately with a status of NO_ERROR.
+ cc.qconn.Abort(nil)
+
+ // Return any existing error from the peer, but don't wait for it.
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ return cc.qconn.Wait(ctx)
+}
+
+func (cc *ClientConn) handleControlStream(st *stream) error {
+ // "A SETTINGS frame MUST be sent as the first frame of each control stream [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-2
+ if err := st.readSettings(func(settingsType, settingsValue int64) error {
+ switch settingsType {
+ case settingsMaxFieldSectionSize:
+ _ = settingsValue // TODO
+ case settingsQPACKMaxTableCapacity:
+ _ = settingsValue // TODO
+ case settingsQPACKBlockedStreams:
+ _ = settingsValue // TODO
+ default:
+ // Unknown settings types are ignored.
+ }
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ for {
+ ftype, err := st.readFrameHeader()
+ if err != nil {
+ return err
+ }
+ switch ftype {
+ case frameTypeCancelPush:
+ // "If a CANCEL_PUSH frame is received that references a push ID
+ // greater than currently allowed on the connection,
+ // this MUST be treated as a connection error of type H3_ID_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.3-7
+ return &connectionError{
+ code: errH3IDError,
+ message: "CANCEL_PUSH received when no MAX_PUSH_ID has been sent",
+ }
+ case frameTypeGoaway:
+ // TODO: Wait for requests to complete before closing connection.
+ return errH3NoError
+ default:
+ // Unknown frames are ignored.
+ if err := st.discardUnknownFrame(ftype); err != nil {
+ return err
+ }
+ }
+ }
+}
+
+func (cc *ClientConn) handleEncoderStream(*stream) error {
+ // TODO
+ return nil
+}
+
+func (cc *ClientConn) handleDecoderStream(*stream) error {
+ // TODO
+ return nil
+}
+
+func (cc *ClientConn) handlePushStream(*stream) error {
+ // "A client MUST treat receipt of a push stream as a connection error
+ // of type H3_ID_ERROR when no MAX_PUSH_ID frame has been sent [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.6-3
+ return &connectionError{
+ code: errH3IDError,
+ message: "push stream created when no MAX_PUSH_ID has been sent",
+ }
+}
+
+func (cc *ClientConn) handleRequestStream(st *stream) error {
+ // "Clients MUST treat receipt of a server-initiated bidirectional
+ // stream as a connection error of type H3_STREAM_CREATION_ERROR [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1-3
+ return &connectionError{
+ code: errH3StreamCreationError,
+ message: "server created bidirectional stream",
+ }
+}
+
+// abort closes the connection with an error.
+func (cc *ClientConn) abort(err error) {
+ if e, ok := err.(*connectionError); ok {
+ cc.qconn.Abort(&quic.ApplicationError{
+ Code: uint64(e.code),
+ Reason: e.message,
+ })
+ } else {
+ cc.qconn.Abort(err)
+ }
+}
diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go
new file mode 100644
index 0000000000..b300866390
--- /dev/null
+++ b/internal/http3/transport_test.go
@@ -0,0 +1,448 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "reflect"
+ "slices"
+ "testing"
+ "testing/synctest"
+
+ "golang.org/x/net/internal/quic/quicwire"
+ "golang.org/x/net/quic"
+)
+
+func TestTransportServerCreatesBidirectionalStream(t *testing.T) {
+ // "Clients MUST treat receipt of a server-initiated bidirectional
+ // stream as a connection error of type H3_STREAM_CREATION_ERROR [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1-3
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+ st := tc.newStream(streamTypeRequest)
+ st.Flush()
+ tc.wantClosed("after server creates bidi stream", errH3StreamCreationError)
+ })
+}
+
+// A testQUICConn wraps a *quic.Conn and provides methods for inspecting it.
+type testQUICConn struct {
+ t testing.TB
+ qconn *quic.Conn
+ streams map[streamType][]*testQUICStream
+}
+
+func newTestQUICConn(t testing.TB, qconn *quic.Conn) *testQUICConn {
+ tq := &testQUICConn{
+ t: t,
+ qconn: qconn,
+ streams: make(map[streamType][]*testQUICStream),
+ }
+
+ go tq.acceptStreams(t.Context())
+
+ t.Cleanup(func() {
+ tq.qconn.Close()
+ })
+ return tq
+}
+
+func (tq *testQUICConn) acceptStreams(ctx context.Context) {
+ for {
+ qst, err := tq.qconn.AcceptStream(ctx)
+ if err != nil {
+ return
+ }
+ st := newStream(qst)
+ stype := streamTypeRequest
+ if qst.IsReadOnly() {
+ v, err := st.readVarint()
+ if err != nil {
+ tq.t.Errorf("error reading stream type from unidirectional stream: %v", err)
+ continue
+ }
+ stype = streamType(v)
+ }
+ tq.streams[stype] = append(tq.streams[stype], newTestQUICStream(tq.t, st))
+ }
+}
+
+func (tq *testQUICConn) newStream(stype streamType) *testQUICStream {
+ tq.t.Helper()
+ var qs *quic.Stream
+ var err error
+ if stype == streamTypeRequest {
+ qs, err = tq.qconn.NewStream(canceledCtx)
+ } else {
+ qs, err = tq.qconn.NewSendOnlyStream(canceledCtx)
+ }
+ if err != nil {
+ tq.t.Fatal(err)
+ }
+ st := newStream(qs)
+ if stype != streamTypeRequest {
+ st.writeVarint(int64(stype))
+ if err := st.Flush(); err != nil {
+ tq.t.Fatal(err)
+ }
+ }
+ return newTestQUICStream(tq.t, st)
+}
+
+// wantNotClosed asserts that the peer has not closed the connection.
+func (tq *testQUICConn) wantNotClosed(reason string) {
+ t := tq.t
+ t.Helper()
+ synctest.Wait()
+ err := tq.qconn.Wait(canceledCtx)
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("%v: want QUIC connection to be alive; closed with error: %v", reason, err)
+ }
+}
+
+// wantClosed asserts that the peer has closed the connection
+// with the provided error code.
+func (tq *testQUICConn) wantClosed(reason string, want error) {
+ t := tq.t
+ t.Helper()
+ synctest.Wait()
+
+ if e, ok := want.(http3Error); ok {
+ want = &quic.ApplicationError{Code: uint64(e)}
+ }
+ got := tq.qconn.Wait(canceledCtx)
+ if errors.Is(got, context.Canceled) {
+ t.Fatalf("%v: want QUIC connection closed, but it is not", reason)
+ }
+ if !errors.Is(got, want) {
+ t.Fatalf("%v: connection closed with error: %v; want %v", reason, got, want)
+ }
+}
+
+// wantStream asserts that a stream of a given type has been created,
+// and returns that stream.
+func (tq *testQUICConn) wantStream(stype streamType) *testQUICStream {
+ tq.t.Helper()
+ synctest.Wait()
+ if len(tq.streams[stype]) == 0 {
+ tq.t.Fatalf("expected a %v stream to be created, but none were", stype)
+ }
+ ts := tq.streams[stype][0]
+ tq.streams[stype] = tq.streams[stype][1:]
+ return ts
+}
+
+// testQUICStream wraps a QUIC stream and provides methods for inspecting it.
+type testQUICStream struct {
+ t testing.TB
+ *stream
+}
+
+func newTestQUICStream(t testing.TB, st *stream) *testQUICStream {
+ st.stream.SetReadContext(canceledCtx)
+ st.stream.SetWriteContext(canceledCtx)
+ return &testQUICStream{
+ t: t,
+ stream: st,
+ }
+}
+
+// wantFrameHeader calls readFrameHeader and asserts that the frame is of a given type.
+func (ts *testQUICStream) wantFrameHeader(reason string, wantType frameType) {
+ ts.t.Helper()
+ synctest.Wait()
+ gotType, err := ts.readFrameHeader()
+ if err != nil {
+ ts.t.Fatalf("%v: failed to read frame header: %v", reason, err)
+ }
+ if gotType != wantType {
+ ts.t.Fatalf("%v: got frame type %v, want %v", reason, gotType, wantType)
+ }
+}
+
+// wantHeaders reads a HEADERS frame.
+// If want is nil, the contents of the frame are ignored.
+func (ts *testQUICStream) wantHeaders(want http.Header) {
+ ts.t.Helper()
+ ftype, err := ts.readFrameHeader()
+ if err != nil {
+ ts.t.Fatalf("want HEADERS frame, got error: %v", err)
+ }
+ if ftype != frameTypeHeaders {
+ ts.t.Fatalf("want HEADERS frame, got: %v", ftype)
+ }
+
+ if want == nil {
+ if err := ts.discardFrame(); err != nil {
+ ts.t.Fatalf("discardFrame: %v", err)
+ }
+ return
+ }
+
+ got := make(http.Header)
+ var dec qpackDecoder
+ err = dec.decode(ts.stream, func(_ indexType, name, value string) error {
+ got.Add(name, value)
+ return nil
+ })
+ if diff := diffHeaders(got, want); diff != "" {
+ ts.t.Fatalf("unexpected response headers:\n%v", diff)
+ }
+ if err := ts.endFrame(); err != nil {
+ ts.t.Fatalf("endFrame: %v", err)
+ }
+}
+
+func (ts *testQUICStream) encodeHeaders(h http.Header) []byte {
+ ts.t.Helper()
+ var enc qpackEncoder
+ return enc.encode(func(yield func(itype indexType, name, value string)) {
+ names := slices.Collect(maps.Keys(h))
+ slices.Sort(names)
+ for _, k := range names {
+ for _, v := range h[k] {
+ yield(mayIndex, k, v)
+ }
+ }
+ })
+}
+
+func (ts *testQUICStream) writeHeaders(h http.Header) {
+ ts.t.Helper()
+ headers := ts.encodeHeaders(h)
+ ts.writeVarint(int64(frameTypeHeaders))
+ ts.writeVarint(int64(len(headers)))
+ ts.Write(headers)
+ if err := ts.Flush(); err != nil {
+ ts.t.Fatalf("flushing HEADERS frame: %v", err)
+ }
+}
+
+func (ts *testQUICStream) wantData(want []byte) {
+ ts.t.Helper()
+ synctest.Wait()
+ ftype, err := ts.readFrameHeader()
+ if err != nil {
+ ts.t.Fatalf("want DATA frame, got error: %v", err)
+ }
+ if ftype != frameTypeData {
+ ts.t.Fatalf("want DATA frame, got: %v", ftype)
+ }
+ got, err := ts.readFrameData()
+ if err != nil {
+ ts.t.Fatalf("error reading DATA frame: %v", err)
+ }
+ if !bytes.Equal(got, want) {
+ ts.t.Fatalf("got data: {%x}, want {%x}", got, want)
+ }
+ if err := ts.endFrame(); err != nil {
+ ts.t.Fatalf("endFrame: %v", err)
+ }
+}
+
+func (ts *testQUICStream) wantClosed(reason string) {
+ ts.t.Helper()
+ synctest.Wait()
+ ftype, err := ts.readFrameHeader()
+ if err != io.EOF {
+ ts.t.Fatalf("%v: want io.EOF, got %v %v", reason, ftype, err)
+ }
+}
+
+func (ts *testQUICStream) wantError(want quic.StreamErrorCode) {
+ ts.t.Helper()
+ synctest.Wait()
+ _, err := ts.stream.stream.ReadByte()
+ if err == nil {
+ ts.t.Fatalf("successfully read from stream; want stream error code %v", want)
+ }
+ var got quic.StreamErrorCode
+ if !errors.As(err, &got) {
+ ts.t.Fatalf("stream error = %v; want %v", err, want)
+ }
+ if got != want {
+ ts.t.Fatalf("stream error code = %v; want %v", got, want)
+ }
+}
+
+func (ts *testQUICStream) writePushPromise(pushID int64, h http.Header) {
+ ts.t.Helper()
+ headers := ts.encodeHeaders(h)
+ ts.writeVarint(int64(frameTypePushPromise))
+ ts.writeVarint(int64(quicwire.SizeVarint(uint64(pushID)) + len(headers)))
+ ts.writeVarint(pushID)
+ ts.Write(headers)
+ if err := ts.Flush(); err != nil {
+ ts.t.Fatalf("flushing PUSH_PROMISE frame: %v", err)
+ }
+}
+
+func diffHeaders(got, want http.Header) string {
+ // nil and 0-length non-nil are equal.
+ if len(got) == 0 && len(want) == 0 {
+ return ""
+ }
+ // We could do a more sophisticated diff here.
+ // DeepEqual is good enough for now.
+ if reflect.DeepEqual(got, want) {
+ return ""
+ }
+ return fmt.Sprintf("got: %v\nwant: %v", got, want)
+}
+
+func (ts *testQUICStream) Flush() error {
+ err := ts.stream.Flush()
+ ts.t.Helper()
+ if err != nil {
+ ts.t.Errorf("unexpected error flushing stream: %v", err)
+ }
+ return err
+}
+
+// A testClientConn is a ClientConn on a test network.
+type testClientConn struct {
+ tr *Transport
+ cc *ClientConn
+
+ // *testQUICConn is the server half of the connection.
+ *testQUICConn
+ control *testQUICStream
+}
+
+func newTestClientConn(t testing.TB) *testClientConn {
+ e1, e2 := newQUICEndpointPair(t)
+ tr := &Transport{
+ Endpoint: e1,
+ Config: &quic.Config{
+ TLSConfig: testTLSConfig,
+ },
+ }
+
+ cc, err := tr.Dial(t.Context(), e2.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ cc.Close()
+ })
+ srvConn, err := e2.Accept(t.Context())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tc := &testClientConn{
+ tr: tr,
+ cc: cc,
+ testQUICConn: newTestQUICConn(t, srvConn),
+ }
+ synctest.Wait()
+ return tc
+}
+
+// greet performs initial connection handshaking with the client.
+func (tc *testClientConn) greet() {
+ // Client creates a control stream.
+ clientControlStream := tc.wantStream(streamTypeControl)
+ clientControlStream.wantFrameHeader(
+ "client sends SETTINGS frame on control stream",
+ frameTypeSettings)
+ clientControlStream.discardFrame()
+
+ // Server creates a control stream.
+ tc.control = tc.newStream(streamTypeControl)
+ tc.control.writeVarint(int64(frameTypeSettings))
+ tc.control.writeVarint(0) // size
+ tc.control.Flush()
+
+ synctest.Wait()
+}
+
+type testRoundTrip struct {
+ t testing.TB
+ resp *http.Response
+ respErr error
+}
+
+func (rt *testRoundTrip) done() bool {
+ synctest.Wait()
+ return rt.resp != nil || rt.respErr != nil
+}
+
+func (rt *testRoundTrip) result() (*http.Response, error) {
+ rt.t.Helper()
+ if !rt.done() {
+ rt.t.Fatal("RoundTrip is not done; want it to be")
+ }
+ return rt.resp, rt.respErr
+}
+
+func (rt *testRoundTrip) response() *http.Response {
+ rt.t.Helper()
+ if !rt.done() {
+ rt.t.Fatal("RoundTrip is not done; want it to be")
+ }
+ if rt.respErr != nil {
+ rt.t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
+ }
+ return rt.resp
+}
+
+// err returns the (possibly nil) error result of RoundTrip.
+func (rt *testRoundTrip) err() error {
+ rt.t.Helper()
+ _, err := rt.result()
+ return err
+}
+
+func (rt *testRoundTrip) wantError(reason string) {
+ rt.t.Helper()
+ synctest.Wait()
+ if !rt.done() {
+ rt.t.Fatalf("%v: RoundTrip is not done; want it to have returned an error", reason)
+ }
+ if rt.respErr == nil {
+ rt.t.Fatalf("%v: RoundTrip succeeded; want it to have returned an error", reason)
+ }
+}
+
+// wantStatus indicates the expected response StatusCode.
+func (rt *testRoundTrip) wantStatus(want int) {
+ rt.t.Helper()
+ if got := rt.response().StatusCode; got != want {
+ rt.t.Fatalf("got response status %v, want %v", got, want)
+ }
+}
+
+func (rt *testRoundTrip) wantHeaders(want http.Header) {
+ rt.t.Helper()
+ if diff := diffHeaders(rt.response().Header, want); diff != "" {
+ rt.t.Fatalf("unexpected response headers:\n%v", diff)
+ }
+}
+
+func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
+ rt := &testRoundTrip{t: tc.t}
+ go func() {
+ rt.resp, rt.respErr = tc.cc.RoundTrip(req)
+ }()
+ return rt
+}
+
+// canceledCtx is a canceled Context.
+// Used for performing non-blocking QUIC operations.
+var canceledCtx = func() context.Context {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ return ctx
+}()
diff --git a/internal/httpcommon/ascii.go b/internal/httpcommon/ascii.go
new file mode 100644
index 0000000000..ed14da5afc
--- /dev/null
+++ b/internal/httpcommon/ascii.go
@@ -0,0 +1,53 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httpcommon
+
+import "strings"
+
+// The HTTP protocols are defined in terms of ASCII, not Unicode. This file
+// contains helper functions which may use Unicode-aware functions which would
+// otherwise be unsafe and could introduce vulnerabilities if used improperly.
+
+// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
+// are equal, ASCII-case-insensitively.
+func asciiEqualFold(s, t string) bool {
+ if len(s) != len(t) {
+ return false
+ }
+ for i := 0; i < len(s); i++ {
+ if lower(s[i]) != lower(t[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+// lower returns the ASCII lowercase version of b.
+func lower(b byte) byte {
+ if 'A' <= b && b <= 'Z' {
+ return b + ('a' - 'A')
+ }
+ return b
+}
+
+// isASCIIPrint returns whether s is ASCII and printable according to
+// https://tools.ietf.org/html/rfc20#section-4.2.
+func isASCIIPrint(s string) bool {
+ for i := 0; i < len(s); i++ {
+ if s[i] < ' ' || s[i] > '~' {
+ return false
+ }
+ }
+ return true
+}
+
+// asciiToLower returns the lowercase version of s if s is ASCII and printable,
+// and whether or not it was.
+func asciiToLower(s string) (lower string, ok bool) {
+ if !isASCIIPrint(s) {
+ return "", false
+ }
+ return strings.ToLower(s), true
+}
diff --git a/http2/headermap.go b/internal/httpcommon/headermap.go
similarity index 74%
rename from http2/headermap.go
rename to internal/httpcommon/headermap.go
index 149b3dd20e..92483d8e41 100644
--- a/http2/headermap.go
+++ b/internal/httpcommon/headermap.go
@@ -1,11 +1,11 @@
-// Copyright 2014 The Go Authors. All rights reserved.
+// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package http2
+package httpcommon
import (
- "net/http"
+ "net/textproto"
"sync"
)
@@ -82,13 +82,15 @@ func buildCommonHeaderMaps() {
commonLowerHeader = make(map[string]string, len(common))
commonCanonHeader = make(map[string]string, len(common))
for _, v := range common {
- chk := http.CanonicalHeaderKey(v)
+ chk := textproto.CanonicalMIMEHeaderKey(v)
commonLowerHeader[chk] = v
commonCanonHeader[v] = chk
}
}
-func lowerHeader(v string) (lower string, ascii bool) {
+// LowerHeader returns the lowercase form of a header name,
+// used on the wire for HTTP/2 and HTTP/3 requests.
+func LowerHeader(v string) (lower string, ascii bool) {
buildCommonHeaderMapsOnce()
if s, ok := commonLowerHeader[v]; ok {
return s, true
@@ -96,10 +98,18 @@ func lowerHeader(v string) (lower string, ascii bool) {
return asciiToLower(v)
}
-func canonicalHeader(v string) string {
+// CanonicalHeader canonicalizes a header name. (For example, "host" becomes "Host".)
+func CanonicalHeader(v string) string {
buildCommonHeaderMapsOnce()
if s, ok := commonCanonHeader[v]; ok {
return s
}
- return http.CanonicalHeaderKey(v)
+ return textproto.CanonicalMIMEHeaderKey(v)
+}
+
+// CachedCanonicalHeader returns the canonical form of a well-known header name.
+func CachedCanonicalHeader(v string) (string, bool) {
+ buildCommonHeaderMapsOnce()
+ s, ok := commonCanonHeader[v]
+ return s, ok
}
diff --git a/internal/httpcommon/httpcommon_test.go b/internal/httpcommon/httpcommon_test.go
new file mode 100644
index 0000000000..e725ec76cb
--- /dev/null
+++ b/internal/httpcommon/httpcommon_test.go
@@ -0,0 +1,37 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httpcommon_test
+
+import (
+ "bytes"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+// This package is imported by the net/http package,
+// and therefore must not itself import net/http.
+func TestNoNetHttp(t *testing.T) {
+ files, err := filepath.Glob("*.go")
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, file := range files {
+ if strings.HasSuffix(file, "_test.go") {
+ continue
+ }
+ // Could use something complex like go/build or x/tools/go/packages,
+ // but there's no reason for "net/http" to appear (in quotes) in the source
+ // otherwise, so just use a simple substring search.
+ data, err := os.ReadFile(file)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if bytes.Contains(data, []byte(`"net/http"`)) {
+ t.Errorf(`%s: cannot import "net/http"`, file)
+ }
+ }
+}
diff --git a/internal/httpcommon/request.go b/internal/httpcommon/request.go
new file mode 100644
index 0000000000..4b70553179
--- /dev/null
+++ b/internal/httpcommon/request.go
@@ -0,0 +1,467 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httpcommon
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http/httptrace"
+ "net/textproto"
+ "net/url"
+ "sort"
+ "strconv"
+ "strings"
+
+ "golang.org/x/net/http/httpguts"
+ "golang.org/x/net/http2/hpack"
+)
+
+var (
+ ErrRequestHeaderListSize = errors.New("request header list larger than peer's advertised limit")
+)
+
+// Request is a subset of http.Request.
+// It'd be simpler to pass an *http.Request, of course, but we can't depend on net/http
+// without creating a dependency cycle.
+type Request struct {
+ URL *url.URL
+ Method string
+ Host string
+ Header map[string][]string
+ Trailer map[string][]string
+ ActualContentLength int64 // 0 means 0, -1 means unknown
+}
+
+// EncodeHeadersParam is parameters to EncodeHeaders.
+type EncodeHeadersParam struct {
+ Request Request
+
+ // AddGzipHeader indicates that an "accept-encoding: gzip" header should be
+ // added to the request.
+ AddGzipHeader bool
+
+ // PeerMaxHeaderListSize, when non-zero, is the peer's MAX_HEADER_LIST_SIZE setting.
+ PeerMaxHeaderListSize uint64
+
+ // DefaultUserAgent is the User-Agent header to send when the request
+ // neither contains a User-Agent nor disables it.
+ DefaultUserAgent string
+}
+
+// EncodeHeadersResult is the result of EncodeHeaders.
+type EncodeHeadersResult struct {
+ HasBody bool
+ HasTrailers bool
+}
+
+// EncodeHeaders constructs request headers common to HTTP/2 and HTTP/3.
+// It validates a request and calls headerf with each pseudo-header and header
+// for the request.
+// The headerf function is called with the validated, canonicalized header name.
+func EncodeHeaders(ctx context.Context, param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) {
+ req := param.Request
+
+ // Check for invalid connection-level headers.
+ if err := checkConnHeaders(req.Header); err != nil {
+ return res, err
+ }
+
+ if req.URL == nil {
+ return res, errors.New("Request.URL is nil")
+ }
+
+ host := req.Host
+ if host == "" {
+ host = req.URL.Host
+ }
+ host, err := httpguts.PunycodeHostPort(host)
+ if err != nil {
+ return res, err
+ }
+ if !httpguts.ValidHostHeader(host) {
+ return res, errors.New("invalid Host header")
+ }
+
+ // isNormalConnect is true if this is a non-extended CONNECT request.
+ isNormalConnect := false
+ var protocol string
+ if vv := req.Header[":protocol"]; len(vv) > 0 {
+ protocol = vv[0]
+ }
+ if req.Method == "CONNECT" && protocol == "" {
+ isNormalConnect = true
+ } else if protocol != "" && req.Method != "CONNECT" {
+ return res, errors.New("invalid :protocol header in non-CONNECT request")
+ }
+
+ // Validate the path, except for non-extended CONNECT requests which have no path.
+ var path string
+ if !isNormalConnect {
+ path = req.URL.RequestURI()
+ if !validPseudoPath(path) {
+ orig := path
+ path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
+ if !validPseudoPath(path) {
+ if req.URL.Opaque != "" {
+ return res, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
+ } else {
+ return res, fmt.Errorf("invalid request :path %q", orig)
+ }
+ }
+ }
+ }
+
+ // Check for any invalid headers+trailers and return an error before we
+ // potentially pollute our hpack state. (We want to be able to
+ // continue to reuse the hpack encoder for future requests)
+ if err := validateHeaders(req.Header); err != "" {
+ return res, fmt.Errorf("invalid HTTP header %s", err)
+ }
+ if err := validateHeaders(req.Trailer); err != "" {
+ return res, fmt.Errorf("invalid HTTP trailer %s", err)
+ }
+
+ trailers, err := commaSeparatedTrailers(req.Trailer)
+ if err != nil {
+ return res, err
+ }
+
+ enumerateHeaders := func(f func(name, value string)) {
+ // 8.1.2.3 Request Pseudo-Header Fields
+ // The :path pseudo-header field includes the path and query parts of the
+ // target URI (the path-absolute production and optionally a '?' character
+ // followed by the query production, see Sections 3.3 and 3.4 of
+ // [RFC3986]).
+ f(":authority", host)
+ m := req.Method
+ if m == "" {
+ m = "GET"
+ }
+ f(":method", m)
+ if !isNormalConnect {
+ f(":path", path)
+ f(":scheme", req.URL.Scheme)
+ }
+ if protocol != "" {
+ f(":protocol", protocol)
+ }
+ if trailers != "" {
+ f("trailer", trailers)
+ }
+
+ var didUA bool
+ for k, vv := range req.Header {
+ if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") {
+ // Host is :authority, already sent.
+ // Content-Length is automatic, set below.
+ continue
+ } else if asciiEqualFold(k, "connection") ||
+ asciiEqualFold(k, "proxy-connection") ||
+ asciiEqualFold(k, "transfer-encoding") ||
+ asciiEqualFold(k, "upgrade") ||
+ asciiEqualFold(k, "keep-alive") {
+ // Per 8.1.2.2 Connection-Specific Header
+ // Fields, don't send connection-specific
+ // fields. We have already checked if any
+ // are error-worthy so just ignore the rest.
+ continue
+ } else if asciiEqualFold(k, "user-agent") {
+ // Match Go's http1 behavior: at most one
+ // User-Agent. If set to nil or empty string,
+ // then omit it. Otherwise if not mentioned,
+ // include the default (below).
+ didUA = true
+ if len(vv) < 1 {
+ continue
+ }
+ vv = vv[:1]
+ if vv[0] == "" {
+ continue
+ }
+ } else if asciiEqualFold(k, "cookie") {
+ // Per 8.1.2.5 To allow for better compression efficiency, the
+ // Cookie header field MAY be split into separate header fields,
+ // each with one or more cookie-pairs.
+ for _, v := range vv {
+ for {
+ p := strings.IndexByte(v, ';')
+ if p < 0 {
+ break
+ }
+ f("cookie", v[:p])
+ p++
+ // strip space after semicolon if any.
+ for p+1 <= len(v) && v[p] == ' ' {
+ p++
+ }
+ v = v[p:]
+ }
+ if len(v) > 0 {
+ f("cookie", v)
+ }
+ }
+ continue
+ } else if k == ":protocol" {
+ // :protocol pseudo-header was already sent above.
+ continue
+ }
+
+ for _, v := range vv {
+ f(k, v)
+ }
+ }
+ if shouldSendReqContentLength(req.Method, req.ActualContentLength) {
+ f("content-length", strconv.FormatInt(req.ActualContentLength, 10))
+ }
+ if param.AddGzipHeader {
+ f("accept-encoding", "gzip")
+ }
+ if !didUA {
+ f("user-agent", param.DefaultUserAgent)
+ }
+ }
+
+ // Do a first pass over the headers counting bytes to ensure
+ // we don't exceed cc.peerMaxHeaderListSize. This is done as a
+ // separate pass before encoding the headers to prevent
+ // modifying the hpack state.
+ if param.PeerMaxHeaderListSize > 0 {
+ hlSize := uint64(0)
+ enumerateHeaders(func(name, value string) {
+ hf := hpack.HeaderField{Name: name, Value: value}
+ hlSize += uint64(hf.Size())
+ })
+
+ if hlSize > param.PeerMaxHeaderListSize {
+ return res, ErrRequestHeaderListSize
+ }
+ }
+
+ trace := httptrace.ContextClientTrace(ctx)
+
+ // Header list size is ok. Write the headers.
+ enumerateHeaders(func(name, value string) {
+ name, ascii := LowerHeader(name)
+ if !ascii {
+ // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
+ // field names have to be ASCII characters (just as in HTTP/1.x).
+ return
+ }
+
+ headerf(name, value)
+
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField(name, []string{value})
+ }
+ })
+
+ res.HasBody = req.ActualContentLength != 0
+ res.HasTrailers = trailers != ""
+ return res, nil
+}
+
+// IsRequestGzip reports whether we should add an Accept-Encoding: gzip header
+// for a request.
+func IsRequestGzip(method string, header map[string][]string, disableCompression bool) bool {
+ // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
+ if !disableCompression &&
+ len(header["Accept-Encoding"]) == 0 &&
+ len(header["Range"]) == 0 &&
+ method != "HEAD" {
+ // Request gzip only, not deflate. Deflate is ambiguous and
+ // not as universally supported anyway.
+ // See: https://zlib.net/zlib_faq.html#faq39
+ //
+ // Note that we don't request this for HEAD requests,
+ // due to a bug in nginx:
+ // http://trac.nginx.org/nginx/ticket/358
+ // https://golang.org/issue/5522
+ //
+ // We don't request gzip if the request is for a range, since
+ // auto-decoding a portion of a gzipped document will just fail
+ // anyway. See https://golang.org/issue/8923
+ return true
+ }
+ return false
+}
+
+// checkConnHeaders checks whether req has any invalid connection-level headers.
+//
+// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2-3
+// https://www.rfc-editor.org/rfc/rfc9113.html#section-8.2.2-1
+//
+// Certain headers are special-cased as okay but not transmitted later.
+// For example, we allow "Transfer-Encoding: chunked", but drop the header when encoding.
+func checkConnHeaders(h map[string][]string) error {
+ if vv := h["Upgrade"]; len(vv) > 0 && (vv[0] != "" && vv[0] != "chunked") {
+ return fmt.Errorf("invalid Upgrade request header: %q", vv)
+ }
+ if vv := h["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
+ return fmt.Errorf("invalid Transfer-Encoding request header: %q", vv)
+ }
+ if vv := h["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) {
+ return fmt.Errorf("invalid Connection request header: %q", vv)
+ }
+ return nil
+}
+
+func commaSeparatedTrailers(trailer map[string][]string) (string, error) {
+ keys := make([]string, 0, len(trailer))
+ for k := range trailer {
+ k = CanonicalHeader(k)
+ switch k {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ return "", fmt.Errorf("invalid Trailer key %q", k)
+ }
+ keys = append(keys, k)
+ }
+ if len(keys) > 0 {
+ sort.Strings(keys)
+ return strings.Join(keys, ","), nil
+ }
+ return "", nil
+}
+
+// validPseudoPath reports whether v is a valid :path pseudo-header
+// value. It must be either:
+//
+// - a non-empty string starting with '/'
+// - the string '*', for OPTIONS requests.
+//
+// For now this is only used as a quick check for deciding when to clean
+// up Opaque URLs before sending requests from the Transport.
+// See golang.org/issue/16847
+//
+// We used to enforce that the path also didn't start with "//", but
+// Google's GFE accepts such paths and Chrome sends them, so ignore
+// that part of the spec. See golang.org/issue/19103.
+func validPseudoPath(v string) bool {
+ return (len(v) > 0 && v[0] == '/') || v == "*"
+}
+
+func validateHeaders(hdrs map[string][]string) string {
+ for k, vv := range hdrs {
+ if !httpguts.ValidHeaderFieldName(k) && k != ":protocol" {
+ return fmt.Sprintf("name %q", k)
+ }
+ for _, v := range vv {
+ if !httpguts.ValidHeaderFieldValue(v) {
+ // Don't include the value in the error,
+ // because it may be sensitive.
+ return fmt.Sprintf("value for header %q", k)
+ }
+ }
+ }
+ return ""
+}
+
+// shouldSendReqContentLength reports whether we should send
+// a "content-length" request header. This logic is basically a copy of the net/http
+// transferWriter.shouldSendContentLength.
+// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
+// -1 means unknown.
+func shouldSendReqContentLength(method string, contentLength int64) bool {
+ if contentLength > 0 {
+ return true
+ }
+ if contentLength < 0 {
+ return false
+ }
+ // For zero bodies, whether we send a content-length depends on the method.
+ // It also kinda doesn't matter for http2 either way, with END_STREAM.
+ switch method {
+ case "POST", "PUT", "PATCH":
+ return true
+ default:
+ return false
+ }
+}
+
+// ServerRequestParam is parameters to NewServerRequest.
+type ServerRequestParam struct {
+ Method string
+ Scheme, Authority, Path string
+ Protocol string
+ Header map[string][]string
+}
+
+// ServerRequestResult is the result of NewServerRequest.
+type ServerRequestResult struct {
+ // Various http.Request fields.
+ URL *url.URL
+ RequestURI string
+ Trailer map[string][]string
+
+ NeedsContinue bool // client provided an "Expect: 100-continue" header
+
+ // If the request should be rejected, this is a short string suitable for passing
+ // to the http2 package's CountError function.
+// It might be a bit odd to return errors this way rather than returning an error,
+ // but this ensures we don't forget to include a CountError reason.
+ InvalidReason string
+}
+
+func NewServerRequest(rp ServerRequestParam) ServerRequestResult {
+ needsContinue := httpguts.HeaderValuesContainsToken(rp.Header["Expect"], "100-continue")
+ if needsContinue {
+ delete(rp.Header, "Expect")
+ }
+ // Merge Cookie headers into one "; "-delimited value.
+ if cookies := rp.Header["Cookie"]; len(cookies) > 1 {
+ rp.Header["Cookie"] = []string{strings.Join(cookies, "; ")}
+ }
+
+ // Setup Trailers
+ var trailer map[string][]string
+ for _, v := range rp.Header["Trailer"] {
+ for _, key := range strings.Split(v, ",") {
+ key = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(key))
+ switch key {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ // Bogus. (copy of http1 rules)
+ // Ignore.
+ default:
+ if trailer == nil {
+ trailer = make(map[string][]string)
+ }
+ trailer[key] = nil
+ }
+ }
+ }
+ delete(rp.Header, "Trailer")
+
+ // "':authority' MUST NOT include the deprecated userinfo subcomponent
+ // for "http" or "https" schemed URIs."
+ // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8
+ if strings.IndexByte(rp.Authority, '@') != -1 && (rp.Scheme == "http" || rp.Scheme == "https") {
+ return ServerRequestResult{
+ InvalidReason: "userinfo_in_authority",
+ }
+ }
+
+ var url_ *url.URL
+ var requestURI string
+ if rp.Method == "CONNECT" && rp.Protocol == "" {
+ url_ = &url.URL{Host: rp.Authority}
+ requestURI = rp.Authority // mimic HTTP/1 server behavior
+ } else {
+ var err error
+ url_, err = url.ParseRequestURI(rp.Path)
+ if err != nil {
+ return ServerRequestResult{
+ InvalidReason: "bad_path",
+ }
+ }
+ requestURI = rp.Path
+ }
+
+ return ServerRequestResult{
+ URL: url_,
+ NeedsContinue: needsContinue,
+ RequestURI: requestURI,
+ Trailer: trailer,
+ }
+}
diff --git a/internal/httpcommon/request_test.go b/internal/httpcommon/request_test.go
new file mode 100644
index 0000000000..b8792977c1
--- /dev/null
+++ b/internal/httpcommon/request_test.go
@@ -0,0 +1,672 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httpcommon
+
+import (
+ "cmp"
+ "context"
+ "io"
+ "net/http"
+ "slices"
+ "strings"
+ "testing"
+)
+
+func TestEncodeHeaders(t *testing.T) {
+ type header struct {
+ name string
+ value string
+ }
+ for _, test := range []struct {
+ name string
+ in EncodeHeadersParam
+ want EncodeHeadersResult
+ wantHeaders []header
+ disableCompression bool
+ }{{
+ name: "simple request",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("GET", "https://example.tld/", nil))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "host set from URL",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Host = ""
+ req.URL.Host = "example.tld"
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "chunked transfer-encoding",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Transfer-Encoding", "chunked") // ignored
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "connection close",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Connection", "close")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "connection keep-alive",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Connection", "keep-alive")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "normal connect",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("CONNECT", "https://example.tld/", nil))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "CONNECT"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "extended connect",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("CONNECT", "https://example.tld/", nil))
+ req.Header.Set(":protocol", "foo")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "CONNECT"},
+ {":path", "/"},
+ {":protocol", "foo"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "trailers",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("a", "1")
+ req.Trailer.Set("b", "2")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: true,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"trailer", "A,B"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "override user-agent",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("User-Agent", "GopherTron 9000")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "GopherTron 9000"},
+ },
+ }, {
+ name: "disable user-agent",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header["User-Agent"] = nil
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ },
+ }, {
+ name: "ignore host header",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Host", "gophers.tld/") // ignored
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "crumble cookie header",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Cookie", "a=b; b=c; c=d")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ // Cookie header is split into separate header fields.
+ {"cookie", "a=b"},
+ {"cookie", "b=c"},
+ {"cookie", "c=d"},
+ },
+ }, {
+ name: "post with nil body",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("POST", "https://example.tld/", nil))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "POST"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ {"content-length", "0"},
+ },
+ }, {
+ name: "post with NoBody",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("POST", "https://example.tld/", http.NoBody))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "POST"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ {"content-length", "0"},
+ },
+ }, {
+ name: "post with Content-Length",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ type reader struct{ io.ReadCloser }
+ req := must(http.NewRequest("POST", "https://example.tld/", reader{}))
+ req.ContentLength = 10
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: true,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "POST"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ {"content-length", "10"},
+ },
+ }, {
+ name: "post with unknown Content-Length",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ type reader struct{ io.ReadCloser }
+ req := must(http.NewRequest("POST", "https://example.tld/", reader{}))
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: true,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "POST"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "explicit accept-encoding",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Accept-Encoding", "deflate")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "deflate"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "head request",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("HEAD", "https://example.tld/", nil))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "HEAD"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "range request",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("HEAD", "https://example.tld/", nil))
+ req.Header.Set("Range", "bytes=0-10")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "HEAD"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"user-agent", "default-user-agent"},
+ {"range", "bytes=0-10"},
+ },
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ var gotHeaders []header
+ if IsRequestGzip(test.in.Request.Method, test.in.Request.Header, test.disableCompression) {
+ test.in.AddGzipHeader = true
+ }
+
+ got, err := EncodeHeaders(context.Background(), test.in, func(name, value string) {
+ gotHeaders = append(gotHeaders, header{name, value})
+ })
+ if err != nil {
+ t.Fatalf("EncodeHeaders = %v", err)
+ }
+ if got.HasBody != test.want.HasBody {
+ t.Errorf("HasBody = %v, want %v", got.HasBody, test.want.HasBody)
+ }
+ if got.HasTrailers != test.want.HasTrailers {
+ t.Errorf("HasTrailers = %v, want %v", got.HasTrailers, test.want.HasTrailers)
+ }
+ cmpHeader := func(a, b header) int {
+ return cmp.Or(
+ cmp.Compare(a.name, b.name),
+ cmp.Compare(a.value, b.value),
+ )
+ }
+ slices.SortFunc(gotHeaders, cmpHeader)
+ slices.SortFunc(test.wantHeaders, cmpHeader)
+ if !slices.Equal(gotHeaders, test.wantHeaders) {
+ t.Errorf("got headers:")
+ for _, h := range gotHeaders {
+ t.Errorf(" %v: %q", h.name, h.value)
+ }
+ t.Errorf("want headers:")
+ for _, h := range test.wantHeaders {
+ t.Errorf(" %v: %q", h.name, h.value)
+ }
+ }
+ })
+ }
+}
+
+func TestEncodeHeaderErrors(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ in EncodeHeadersParam
+ want string
+ }{{
+ name: "URL is nil",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.URL = nil
+ return req
+ }),
+ },
+ want: "URL is nil",
+ }, {
+ name: "upgrade header is set",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Upgrade", "foo")
+ return req
+ }),
+ },
+ want: "Upgrade",
+ }, {
+ name: "unsupported transfer-encoding header",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Transfer-Encoding", "identity")
+ return req
+ }),
+ },
+ want: "Transfer-Encoding",
+ }, {
+ name: "unsupported connection header",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Connection", "x")
+ return req
+ }),
+ },
+ want: "Connection",
+ }, {
+ name: "invalid host",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Host = "\x00.tld"
+ return req
+ }),
+ },
+ want: "Host",
+ }, {
+ name: "protocol header is set",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set(":protocol", "foo")
+ return req
+ }),
+ },
+ want: ":protocol",
+ }, {
+ name: "invalid path",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.URL.Path = "no_leading_slash"
+ return req
+ }),
+ },
+ want: "path",
+ }, {
+ name: "invalid header name",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("x\ny", "foo")
+ return req
+ }),
+ },
+ want: "header",
+ }, {
+ name: "invalid header value",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("x", "foo\nbar")
+ return req
+ }),
+ },
+ want: "header",
+ }, {
+ name: "invalid trailer",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("x\ny", "foo")
+ return req
+ }),
+ },
+ want: "trailer",
+ }, {
+ name: "transfer-encoding trailer",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("Transfer-Encoding", "chunked")
+ return req
+ }),
+ },
+ want: "Trailer",
+ }, {
+ name: "trailer trailer",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("Trailer", "chunked")
+ return req
+ }),
+ },
+ want: "Trailer",
+ }, {
+ name: "content-length trailer",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("Content-Length", "0")
+ return req
+ }),
+ },
+ want: "Trailer",
+ }, {
+ name: "too many headers",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("X-Foo", strings.Repeat("x", 1000))
+ return req
+ }),
+ PeerMaxHeaderListSize: 1000,
+ },
+ want: "limit",
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ _, err := EncodeHeaders(context.Background(), test.in, func(name, value string) {})
+ if err == nil {
+ t.Fatalf("EncodeHeaders = nil, want %q", test.want)
+ }
+ if !strings.Contains(err.Error(), test.want) {
+ t.Fatalf("EncodeHeaders = %q, want error containing %q", err, test.want)
+ }
+ })
+ }
+}
+
+func newReq(f func() *http.Request) Request {
+ req := f()
+ contentLength := req.ContentLength
+ if req.Body == nil || req.Body == http.NoBody {
+ contentLength = 0
+ } else if contentLength == 0 {
+ contentLength = -1
+ }
+ return Request{
+ Header: req.Header,
+ Trailer: req.Trailer,
+ URL: req.URL,
+ Host: req.Host,
+ Method: req.Method,
+ ActualContentLength: contentLength,
+ }
+}
+
+func must[T any](v T, err error) T {
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
diff --git a/internal/iana/gen.go b/internal/iana/gen.go
index 34f0f7eeea..b4470baa75 100644
--- a/internal/iana/gen.go
+++ b/internal/iana/gen.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
//go:generate go run gen.go
@@ -17,7 +16,6 @@ import (
"fmt"
"go/format"
"io"
- "io/ioutil"
"net/http"
"os"
"strconv"
@@ -70,7 +68,7 @@ func main() {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
- if err := ioutil.WriteFile("const.go", b, 0644); err != nil {
+ if err := os.WriteFile("const.go", b, 0644); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
diff --git a/internal/quic/cmd/interop/Dockerfile b/internal/quic/cmd/interop/Dockerfile
new file mode 100644
index 0000000000..b60999a862
--- /dev/null
+++ b/internal/quic/cmd/interop/Dockerfile
@@ -0,0 +1,32 @@
+FROM martenseemann/quic-network-simulator-endpoint:latest AS builder
+
+ARG TARGETPLATFORM
+RUN echo "TARGETPLATFORM: ${TARGETPLATFORM}"
+
+RUN apt-get update && apt-get install -y wget tar git
+
+ENV GOVERSION=1.21.1
+
+RUN platform=$(echo ${TARGETPLATFORM} | tr '/' '-') && \
+ filename="go${GOVERSION}.${platform}.tar.gz" && \
+    wget --no-verbose https://dl.google.com/go/${filename} && \
+    tar xfz ${filename} && \
+    rm ${filename}
+
+ENV PATH="/go/bin:${PATH}"
+
+RUN git clone https://go.googlesource.com/net
+
+WORKDIR /net
+RUN go build -o /interop ./internal/quic/cmd/interop
+
+FROM martenseemann/quic-network-simulator-endpoint:latest
+
+WORKDIR /go-x-net
+
+COPY --from=builder /interop ./
+
+# copy run script and run it
+COPY run_endpoint.sh .
+RUN chmod +x run_endpoint.sh
+ENTRYPOINT [ "./run_endpoint.sh" ]
diff --git a/internal/quic/cmd/interop/README.md b/internal/quic/cmd/interop/README.md
new file mode 100644
index 0000000000..aca0571b91
--- /dev/null
+++ b/internal/quic/cmd/interop/README.md
@@ -0,0 +1,7 @@
+This directory contains configuration and programs used to
+integrate with the QUIC Interop Test Runner.
+
+The QUIC Interop Test Runner executes a variety of test cases
+against a matrix of clients and servers.
+
+https://github.com/marten-seemann/quic-interop-runner
diff --git a/internal/quic/cmd/interop/main.go b/internal/quic/cmd/interop/main.go
new file mode 100644
index 0000000000..5b652a2b15
--- /dev/null
+++ b/internal/quic/cmd/interop/main.go
@@ -0,0 +1,269 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+// The interop command is the client and server used by QUIC interoperability tests.
+//
+// https://github.com/marten-seemann/quic-interop-runner
+package main
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "log/slog"
+ "net"
+ "net/url"
+ "os"
+ "path/filepath"
+ "sync"
+
+ "golang.org/x/net/quic"
+ "golang.org/x/net/quic/qlog"
+)
+
+var (
+ listen = flag.String("listen", "", "listen address")
+ cert = flag.String("cert", "", "certificate")
+ pkey = flag.String("key", "", "private key")
+ root = flag.String("root", "", "serve files from this root")
+ output = flag.String("output", "", "directory to write files to")
+ qlogdir = flag.String("qlog", "", "directory to write qlog output to")
+)
+
+func main() {
+ ctx := context.Background()
+ flag.Parse()
+ urls := flag.Args()
+
+ config := &quic.Config{
+ TLSConfig: &tls.Config{
+ InsecureSkipVerify: true,
+ MinVersion: tls.VersionTLS13,
+ NextProtos: []string{"hq-interop"},
+ },
+ MaxBidiRemoteStreams: -1,
+ MaxUniRemoteStreams: -1,
+ QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
+ Level: quic.QLogLevelFrame,
+ Dir: *qlogdir,
+ })),
+ }
+ if *cert != "" {
+ c, err := tls.LoadX509KeyPair(*cert, *pkey)
+ if err != nil {
+ log.Fatal(err)
+ }
+ config.TLSConfig.Certificates = []tls.Certificate{c}
+ }
+ if *root != "" {
+ config.MaxBidiRemoteStreams = 100
+ }
+ if keylog := os.Getenv("SSLKEYLOGFILE"); keylog != "" {
+ f, err := os.Create(keylog)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer f.Close()
+ config.TLSConfig.KeyLogWriter = f
+ }
+
+ testcase := os.Getenv("TESTCASE")
+ switch testcase {
+ case "handshake", "keyupdate":
+ basicTest(ctx, config, urls)
+ return
+ case "chacha20":
+ // "[...] offer only ChaCha20 as a ciphersuite."
+ //
+ // crypto/tls does not support configuring TLS 1.3 ciphersuites,
+ // so we can't support this test.
+ case "transfer":
+ // "The client should use small initial flow control windows
+ // for both stream- and connection-level flow control
+ // such that the during the transfer of files on the order of 1 MB
+ // the flow control window needs to be increased."
+ config.MaxStreamReadBufferSize = 64 << 10
+ config.MaxConnReadBufferSize = 64 << 10
+ basicTest(ctx, config, urls)
+ return
+ case "http3":
+ // TODO
+ case "multiconnect":
+ // TODO
+ case "resumption":
+ // TODO
+ case "retry":
+ // TODO
+ case "versionnegotiation":
+ // "The client should start a connection using
+ // an unsupported version number [...]"
+ //
+ // We don't support setting the client's version,
+ // so only run this test as a server.
+ if *listen != "" && len(urls) == 0 {
+ basicTest(ctx, config, urls)
+ return
+ }
+ case "v2":
+ // We do not support QUIC v2.
+ case "zerortt":
+ // TODO
+ }
+ fmt.Printf("unsupported test case %q\n", testcase)
+ os.Exit(127)
+}
+
+// basicTest runs the standard test setup.
+//
+// As a server, it serves the contents of the -root directory.
+// As a client, it downloads all the provided URLs in parallel,
+// making one connection to each destination server.
+func basicTest(ctx context.Context, config *quic.Config, urls []string) {
+ l, err := quic.Listen("udp", *listen, config)
+ if err != nil {
+ log.Fatal(err)
+ }
+ log.Printf("listening on %v", l.LocalAddr())
+
+ byAuthority := map[string][]*url.URL{}
+ for _, s := range urls {
+ u, addr, err := parseURL(s)
+ if err != nil {
+ log.Fatal(err)
+ }
+ byAuthority[addr] = append(byAuthority[addr], u)
+ }
+ var g sync.WaitGroup
+ defer g.Wait()
+ for addr, u := range byAuthority {
+ addr, u := addr, u
+ g.Add(1)
+ go func() {
+ defer g.Done()
+ fetchFrom(ctx, config, l, addr, u)
+ }()
+ }
+
+ if config.MaxBidiRemoteStreams >= 0 {
+ serve(ctx, l)
+ }
+}
+
+func serve(ctx context.Context, l *quic.Endpoint) error {
+ for {
+ c, err := l.Accept(ctx)
+ if err != nil {
+ return err
+ }
+ go serveConn(ctx, c)
+ }
+}
+
+func serveConn(ctx context.Context, c *quic.Conn) {
+ for {
+ s, err := c.AcceptStream(ctx)
+ if err != nil {
+ return
+ }
+ go func() {
+ if err := serveReq(ctx, s); err != nil {
+ log.Print("serveReq:", err)
+ }
+ }()
+ }
+}
+
+func serveReq(ctx context.Context, s *quic.Stream) error {
+ defer s.Close()
+ req, err := io.ReadAll(s)
+ if err != nil {
+ return err
+ }
+ if !bytes.HasSuffix(req, []byte("\r\n")) {
+ return errors.New("invalid request")
+ }
+ req = bytes.TrimSuffix(req, []byte("\r\n"))
+ if !bytes.HasPrefix(req, []byte("GET /")) {
+ return errors.New("invalid request")
+ }
+ req = bytes.TrimPrefix(req, []byte("GET /"))
+ if !filepath.IsLocal(string(req)) {
+ return errors.New("invalid request")
+ }
+ f, err := os.Open(filepath.Join(*root, string(req)))
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ _, err = io.Copy(s, f)
+ return err
+}
+
+func parseURL(s string) (u *url.URL, authority string, err error) {
+ u, err = url.Parse(s)
+ if err != nil {
+ return nil, "", err
+ }
+ host := u.Hostname()
+ port := u.Port()
+ if port == "" {
+ port = "443"
+ }
+ authority = net.JoinHostPort(host, port)
+ return u, authority, nil
+}
+
+func fetchFrom(ctx context.Context, config *quic.Config, l *quic.Endpoint, addr string, urls []*url.URL) {
+ conn, err := l.Dial(ctx, "udp", addr, config)
+ if err != nil {
+ log.Printf("%v: %v", addr, err)
+ return
+ }
+ log.Printf("connected to %v", addr)
+ defer conn.Close()
+ var g sync.WaitGroup
+ for _, u := range urls {
+ u := u
+ g.Add(1)
+ go func() {
+ defer g.Done()
+ if err := fetchOne(ctx, conn, u); err != nil {
+ log.Printf("fetch %v: %v", u, err)
+ } else {
+ log.Printf("fetched %v", u)
+ }
+ }()
+ }
+ g.Wait()
+}
+
+func fetchOne(ctx context.Context, conn *quic.Conn, u *url.URL) error {
+ if len(u.Path) == 0 || u.Path[0] != '/' || !filepath.IsLocal(u.Path[1:]) {
+ return errors.New("invalid path")
+ }
+ file, err := os.Create(filepath.Join(*output, u.Path[1:]))
+ if err != nil {
+ return err
+ }
+ s, err := conn.NewStream(ctx)
+ if err != nil {
+ return err
+ }
+ defer s.Close()
+ if _, err := s.Write([]byte("GET " + u.Path + "\r\n")); err != nil {
+ return err
+ }
+ s.CloseWrite()
+ if _, err := io.Copy(file, s); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/internal/quic/cmd/interop/main_test.go b/internal/quic/cmd/interop/main_test.go
new file mode 100644
index 0000000000..4119740e6c
--- /dev/null
+++ b/internal/quic/cmd/interop/main_test.go
@@ -0,0 +1,174 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "net"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "sync"
+ "testing"
+)
+
+func init() {
+ // We reexec the test binary with CMD_INTEROP_MAIN=1 to run main.
+ if os.Getenv("CMD_INTEROP_MAIN") == "1" {
+ main()
+ os.Exit(0)
+ }
+}
+
+var (
+ tryExecOnce sync.Once
+ tryExecErr error
+)
+
+// needsExec skips the test if we can't use exec.Command.
+func needsExec(t *testing.T) {
+ tryExecOnce.Do(func() {
+ cmd := exec.Command(os.Args[0], "-test.list=^$")
+ cmd.Env = []string{}
+ tryExecErr = cmd.Run()
+ })
+ if tryExecErr != nil {
+ t.Skipf("skipping test: cannot exec subprocess: %v", tryExecErr)
+ }
+}
+
+type interopTest struct {
+ donec chan struct{}
+ addr string
+ cmd *exec.Cmd
+}
+
+func run(ctx context.Context, t *testing.T, name, testcase string, args []string) *interopTest {
+ needsExec(t)
+ ctx, cancel := context.WithCancel(ctx)
+ cmd := exec.CommandContext(ctx, os.Args[0], args...)
+ out, err := cmd.StderrPipe()
+ if err != nil {
+ t.Fatal(err)
+ }
+ cmd.Stdout = cmd.Stderr
+ cmd.Env = []string{
+ "CMD_INTEROP_MAIN=1",
+ "TESTCASE=" + testcase,
+ }
+ t.Logf("run %v: %v", name, args)
+ err = cmd.Start()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ addrc := make(chan string, 1)
+ donec := make(chan struct{})
+ go func() {
+ defer close(addrc)
+ defer close(donec)
+ defer t.Logf("%v done", name)
+ s := bufio.NewScanner(out)
+ for s.Scan() {
+ line := s.Text()
+ t.Logf("%v: %v", name, line)
+ _, addr, ok := strings.Cut(line, "listening on ")
+ if ok {
+ select {
+ case addrc <- addr:
+ default:
+ }
+ }
+ }
+ }()
+
+ t.Cleanup(func() {
+ cancel()
+ <-donec
+ })
+
+ addr, ok := <-addrc
+ if !ok {
+ t.Fatal(cmd.Wait())
+ }
+ _, port, _ := net.SplitHostPort(addr)
+ addr = net.JoinHostPort("localhost", port)
+
+ iop := &interopTest{
+ cmd: cmd,
+ donec: donec,
+ addr: addr,
+ }
+ return iop
+}
+
+func (iop *interopTest) wait() {
+ <-iop.donec
+}
+
+func TestTransfer(t *testing.T) {
+ ctx := context.Background()
+ src := t.TempDir()
+ dst := t.TempDir()
+ certs := t.TempDir()
+ certFile := filepath.Join(certs, "cert.pem")
+ keyFile := filepath.Join(certs, "key.pem")
+ sourceName := "source"
+ content := []byte("hello, world\n")
+
+ os.WriteFile(certFile, localhostCert, 0600)
+ os.WriteFile(keyFile, localhostKey, 0600)
+ os.WriteFile(filepath.Join(src, sourceName), content, 0600)
+
+ srv := run(ctx, t, "server", "transfer", []string{
+ "-listen", "localhost:0",
+ "-cert", filepath.Join(certs, "cert.pem"),
+ "-key", filepath.Join(certs, "key.pem"),
+ "-root", src,
+ })
+ cli := run(ctx, t, "client", "transfer", []string{
+ "-output", dst, "https://" + srv.addr + "/" + sourceName,
+ })
+ cli.wait()
+
+ got, err := os.ReadFile(filepath.Join(dst, "source"))
+ if err != nil {
+ t.Fatalf("reading downloaded file: %v", err)
+ }
+ if !bytes.Equal(got, content) {
+ t.Fatalf("got downloaded file: %q, want %q", string(got), string(content))
+ }
+}
+
+// localhostCert is a PEM-encoded TLS cert with SAN IPs
+// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
+// generated from src/crypto/tls:
+// go run generate_cert.go --ecdsa-curve P256 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
+MIIBrDCCAVKgAwIBAgIPCvPhO+Hfv+NW76kWxULUMAoGCCqGSM49BAMCMBIxEDAO
+BgNVBAoTB0FjbWUgQ28wIBcNNzAwMTAxMDAwMDAwWhgPMjA4NDAxMjkxNjAwMDBa
+MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARh
+WRF8p8X9scgW7JjqAwI9nYV8jtkdhqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGms
+PyfMPe5Jrha/LmjgR1G9o4GIMIGFMA4GA1UdDwEB/wQEAwIChDATBgNVHSUEDDAK
+BggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBSOJri/wLQxq6oC
+Y6ZImms/STbTljAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAA
+AAAAAAAAAAAAATAKBggqhkjOPQQDAgNIADBFAiBUguxsW6TGhixBAdORmVNnkx40
+HjkKwncMSDbUaeL9jQIhAJwQ8zV9JpQvYpsiDuMmqCuW35XXil3cQ6Drz82c+fvE
+-----END CERTIFICATE-----`)
+
+// localhostKey is the private key for localhostCert.
+var localhostKey = []byte(testingKey(`-----BEGIN TESTING KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgY1B1eL/Bbwf/MDcs
+rnvvWhFNr1aGmJJR59PdCN9lVVqhRANCAARhWRF8p8X9scgW7JjqAwI9nYV8jtkd
+hqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGmsPyfMPe5Jrha/LmjgR1G9
+-----END TESTING KEY-----`))
+
+// testingKey helps keep security scanners from getting excited about a private key in this file.
+func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
diff --git a/internal/quic/cmd/interop/run_endpoint.sh b/internal/quic/cmd/interop/run_endpoint.sh
new file mode 100644
index 0000000000..442039bc07
--- /dev/null
+++ b/internal/quic/cmd/interop/run_endpoint.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Set up the routing needed for the simulation
+/setup.sh
+
+# The following variables are available for use:
+# - ROLE contains the role of this execution context, client or server
+# - SERVER_PARAMS contains user-supplied command line parameters
+# - CLIENT_PARAMS contains user-supplied command line parameters
+
+if [ "$ROLE" == "client" ]; then
+ # Wait for the simulator to start up.
+ /wait-for-it.sh sim:57832 -s -t 30
+ ./interop -output=/downloads -qlog=$QLOGDIR $CLIENT_PARAMS $REQUESTS
+elif [ "$ROLE" == "server" ]; then
+ ./interop -cert=/certs/cert.pem -key=/certs/priv.key -qlog=$QLOGDIR -listen=:443 -root=/www "$@" $SERVER_PARAMS
+fi
diff --git a/internal/quic/config.go b/internal/quic/config.go
deleted file mode 100644
index b390d6911e..0000000000
--- a/internal/quic/config.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.21
-
-package quic
-
-import (
- "crypto/tls"
-)
-
-// A Config structure configures a QUIC endpoint.
-// A Config must not be modified after it has been passed to a QUIC function.
-// A Config may be reused; the quic package will also not modify it.
-type Config struct {
- // TLSConfig is the endpoint's TLS configuration.
- // It must be non-nil and include at least one certificate or else set GetCertificate.
- TLSConfig *tls.Config
-
- // MaxBidiRemoteStreams limits the number of simultaneous bidirectional streams
- // a peer may open.
- // If zero, the default value of 100 is used.
- // If negative, the limit is zero.
- MaxBidiRemoteStreams int64
-
- // MaxUniRemoteStreams limits the number of simultaneous unidirectional streams
- // a peer may open.
- // If zero, the default value of 100 is used.
- // If negative, the limit is zero.
- MaxUniRemoteStreams int64
-
- // MaxStreamReadBufferSize is the maximum amount of data sent by the peer that a
- // stream will buffer for reading.
- // If zero, the default value of 1MiB is used.
- // If negative, the limit is zero.
- MaxStreamReadBufferSize int64
-
- // MaxStreamWriteBufferSize is the maximum amount of data a stream will buffer for
- // sending to the peer.
- // If zero, the default value of 1MiB is used.
- // If negative, the limit is zero.
- MaxStreamWriteBufferSize int64
-
- // MaxConnReadBufferSize is the maximum amount of data sent by the peer that a
- // connection will buffer for reading, across all streams.
- // If zero, the default value of 1MiB is used.
- // If negative, the limit is zero.
- MaxConnReadBufferSize int64
-}
-
-func configDefault(v, def, limit int64) int64 {
- switch {
- case v == 0:
- return def
- case v < 0:
- return 0
- default:
- return min(v, limit)
- }
-}
-
-func (c *Config) maxBidiRemoteStreams() int64 {
- return configDefault(c.MaxBidiRemoteStreams, 100, maxStreamsLimit)
-}
-
-func (c *Config) maxUniRemoteStreams() int64 {
- return configDefault(c.MaxUniRemoteStreams, 100, maxStreamsLimit)
-}
-
-func (c *Config) maxStreamReadBufferSize() int64 {
- return configDefault(c.MaxStreamReadBufferSize, 1<<20, maxVarint)
-}
-
-func (c *Config) maxStreamWriteBufferSize() int64 {
- return configDefault(c.MaxStreamWriteBufferSize, 1<<20, maxVarint)
-}
-
-func (c *Config) maxConnReadBufferSize() int64 {
- return configDefault(c.MaxConnReadBufferSize, 1<<20, maxVarint)
-}
diff --git a/internal/quic/conn_close.go b/internal/quic/conn_close.go
deleted file mode 100644
index b8b86fd6fb..0000000000
--- a/internal/quic/conn_close.go
+++ /dev/null
@@ -1,252 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.21
-
-package quic
-
-import (
- "context"
- "errors"
- "time"
-)
-
-// lifetimeState tracks the state of a connection.
-//
-// This is fairly coupled to the rest of a Conn, but putting it in a struct of its own helps
-// reason about operations that cause state transitions.
-type lifetimeState struct {
- readyc chan struct{} // closed when TLS handshake completes
- drainingc chan struct{} // closed when entering the draining state
-
- // Possible states for the connection:
- //
- // Alive: localErr and finalErr are both nil.
- //
- // Closing: localErr is non-nil and finalErr is nil.
- // We have sent a CONNECTION_CLOSE to the peer or are about to
- // (if connCloseSentTime is zero) and are waiting for the peer to respond.
- // drainEndTime is set to the time the closing state ends.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.1
- //
- // Draining: finalErr is non-nil.
- // If localErr is nil, we're waiting for the user to provide us with a final status
- // to send to the peer.
- // Otherwise, we've either sent a CONNECTION_CLOSE to the peer or are about to
- // (if connCloseSentTime is zero).
- // drainEndTime is set to the time the draining state ends.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2
- localErr error // error sent to the peer
- finalErr error // error sent by the peer, or transport error; always set before draining
-
- connCloseSentTime time.Time // send time of last CONNECTION_CLOSE frame
- connCloseDelay time.Duration // delay until next CONNECTION_CLOSE frame sent
- drainEndTime time.Time // time the connection exits the draining state
-}
-
-func (c *Conn) lifetimeInit() {
- c.lifetime.readyc = make(chan struct{})
- c.lifetime.drainingc = make(chan struct{})
-}
-
-var errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE")
-
-// advance is called when time passes.
-func (c *Conn) lifetimeAdvance(now time.Time) (done bool) {
- if c.lifetime.drainEndTime.IsZero() || c.lifetime.drainEndTime.After(now) {
- return false
- }
- // The connection drain period has ended, and we can shut down.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-7
- c.lifetime.drainEndTime = time.Time{}
- if c.lifetime.finalErr == nil {
- // The peer never responded to our CONNECTION_CLOSE.
- c.enterDraining(errNoPeerResponse)
- }
- return true
-}
-
-// confirmHandshake is called when the TLS handshake completes.
-func (c *Conn) handshakeDone() {
- close(c.lifetime.readyc)
-}
-
-// isDraining reports whether the conn is in the draining state.
-//
-// The draining state is entered once an endpoint receives a CONNECTION_CLOSE frame.
-// The endpoint will no longer send any packets, but we retain knowledge of the connection
-// until the end of the drain period to ensure we discard packets for the connection
-// rather than treating them as starting a new connection.
-//
-// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2
-func (c *Conn) isDraining() bool {
- return c.lifetime.finalErr != nil
-}
-
-// isClosingOrDraining reports whether the conn is in the closing or draining states.
-func (c *Conn) isClosingOrDraining() bool {
- return c.lifetime.localErr != nil || c.lifetime.finalErr != nil
-}
-
-// sendOK reports whether the conn can send frames at this time.
-func (c *Conn) sendOK(now time.Time) bool {
- if !c.isClosingOrDraining() {
- return true
- }
- // We are closing or draining.
- if c.lifetime.localErr == nil {
- // We're waiting for the user to close the connection, providing us with
- // a final status to send to the peer.
- return false
- }
- // Past this point, returning true will result in the conn sending a CONNECTION_CLOSE
- // due to localErr being set.
- if c.lifetime.drainEndTime.IsZero() {
- // The closing and draining states should last for at least three times
- // the current PTO interval. We currently use exactly that minimum.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-5
- //
- // The drain period begins when we send or receive a CONNECTION_CLOSE,
- // whichever comes first.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2-3
- c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod())
- }
- if c.lifetime.connCloseSentTime.IsZero() {
- // We haven't sent a CONNECTION_CLOSE yet. Do so.
- // Either we're initiating an immediate close
- // (and will enter the closing state as soon as we send CONNECTION_CLOSE),
- // or we've read a CONNECTION_CLOSE from our peer
- // (and may send one CONNECTION_CLOSE before entering the draining state).
- //
- // Set the initial delay before we will send another CONNECTION_CLOSE.
- //
- // RFC 9000 states that we should rate limit CONNECTION_CLOSE frames,
- // but leaves the implementation of the limit up to us. Here, we start
- // with the same delay as the PTO timer (RFC 9002, Section 6.2.1),
- // not including max_ack_delay, and double it on every CONNECTION_CLOSE sent.
- c.lifetime.connCloseDelay = c.loss.rtt.smoothedRTT + max(4*c.loss.rtt.rttvar, timerGranularity)
- c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod())
- return true
- }
- if c.isDraining() {
- // We are in the draining state, and will send no more packets.
- return false
- }
- maxRecvTime := c.acks[initialSpace].maxRecvTime
- if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) {
- maxRecvTime = t
- }
- if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) {
- maxRecvTime = t
- }
- if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) {
- // After sending CONNECTION_CLOSE, ignore packets from the peer for
- // a delay. On the next packet received after the delay, send another
- // CONNECTION_CLOSE.
- return false
- }
- c.lifetime.connCloseSentTime = now
- c.lifetime.connCloseDelay *= 2
- return true
-}
-
-// enterDraining enters the draining state.
-func (c *Conn) enterDraining(err error) {
- if c.isDraining() {
- return
- }
- if e, ok := c.lifetime.localErr.(localTransportError); ok && transportError(e) != errNo {
- // If we've terminated the connection due to a peer protocol violation,
- // record the final error on the connection as our reason for termination.
- c.lifetime.finalErr = c.lifetime.localErr
- } else {
- c.lifetime.finalErr = err
- }
- close(c.lifetime.drainingc)
- c.streams.queue.close(c.lifetime.finalErr)
-}
-
-func (c *Conn) waitReady(ctx context.Context) error {
- select {
- case <-c.lifetime.readyc:
- return nil
- case <-c.lifetime.drainingc:
- return c.lifetime.finalErr
- default:
- }
- select {
- case <-c.lifetime.readyc:
- return nil
- case <-c.lifetime.drainingc:
- return c.lifetime.finalErr
- case <-ctx.Done():
- return ctx.Err()
- }
-}
-
-// Close closes the connection.
-//
-// Close is equivalent to:
-//
-// conn.Abort(nil)
-// err := conn.Wait(context.Background())
-func (c *Conn) Close() error {
- c.Abort(nil)
- <-c.lifetime.drainingc
- return c.lifetime.finalErr
-}
-
-// Wait waits for the peer to close the connection.
-//
-// If the connection is closed locally and the peer does not close its end of the connection,
-// Wait will return with a non-nil error after the drain period expires.
-//
-// If the peer closes the connection with a NO_ERROR transport error, Wait returns nil.
-// If the peer closes the connection with an application error, Wait returns an ApplicationError
-// containing the peer's error code and reason.
-// If the peer closes the connection with any other status, Wait returns a non-nil error.
-func (c *Conn) Wait(ctx context.Context) error {
- if err := c.waitOnDone(ctx, c.lifetime.drainingc); err != nil {
- return err
- }
- return c.lifetime.finalErr
-}
-
-// Abort closes the connection and returns immediately.
-//
-// If err is nil, Abort sends a transport error of NO_ERROR to the peer.
-// If err is an ApplicationError, Abort sends its error code and text.
-// Otherwise, Abort sends a transport error of APPLICATION_ERROR with the error's text.
-func (c *Conn) Abort(err error) {
- if err == nil {
- err = localTransportError(errNo)
- }
- c.sendMsg(func(now time.Time, c *Conn) {
- c.abort(now, err)
- })
-}
-
-// abort terminates a connection with an error.
-func (c *Conn) abort(now time.Time, err error) {
- if c.lifetime.localErr != nil {
- return // already closing
- }
- c.lifetime.localErr = err
-}
-
-// abortImmediately terminates a connection.
-// The connection does not send a CONNECTION_CLOSE, and skips the draining period.
-func (c *Conn) abortImmediately(now time.Time, err error) {
- c.abort(now, err)
- c.enterDraining(err)
- c.exited = true
-}
-
-// exit fully terminates a connection immediately.
-func (c *Conn) exit() {
- c.sendMsg(func(now time.Time, c *Conn) {
- c.enterDraining(errors.New("connection closed"))
- c.exited = true
- })
-}
diff --git a/internal/quic/doc.go b/internal/quic/doc.go
deleted file mode 100644
index 2fe17fe226..0000000000
--- a/internal/quic/doc.go
+++ /dev/null
@@ -1,9 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package quic is an experimental, incomplete implementation of the QUIC protocol.
-// This package is a work in progress, and is not ready for use at this time.
-//
-// This package implements (or will implement) RFC 9000, RFC 9001, and RFC 9002.
-package quic
diff --git a/internal/quic/listener.go b/internal/quic/listener.go
deleted file mode 100644
index 96b1e45934..0000000000
--- a/internal/quic/listener.go
+++ /dev/null
@@ -1,322 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.21
-
-package quic
-
-import (
- "context"
- "errors"
- "net"
- "net/netip"
- "sync"
- "sync/atomic"
- "time"
-)
-
-// A Listener listens for QUIC traffic on a network address.
-// It can accept inbound connections or create outbound ones.
-//
-// Multiple goroutines may invoke methods on a Listener simultaneously.
-type Listener struct {
- config *Config
- udpConn udpConn
- testHooks connTestHooks
-
- acceptQueue queue[*Conn] // new inbound connections
-
- connsMu sync.Mutex
- conns map[*Conn]struct{}
- closing bool // set when Close is called
- closec chan struct{} // closed when the listen loop exits
-
- // The datagram receive loop keeps a mapping of connection IDs to conns.
- // When a conn's connection IDs change, we add it to connIDUpdates and set
- // connIDUpdateNeeded, and the receive loop updates its map.
- connIDUpdateMu sync.Mutex
- connIDUpdateNeeded atomic.Bool
- connIDUpdates []connIDUpdate
-}
-
-// A udpConn is a UDP connection.
-// It is implemented by net.UDPConn.
-type udpConn interface {
- Close() error
- LocalAddr() net.Addr
- ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error)
- WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error)
-}
-
-type connIDUpdate struct {
- conn *Conn
- retired bool
- cid []byte
-}
-
-// Listen listens on a local network address.
-// The configuration config must be non-nil.
-func Listen(network, address string, config *Config) (*Listener, error) {
- if config.TLSConfig == nil {
- return nil, errors.New("TLSConfig is not set")
- }
- a, err := net.ResolveUDPAddr(network, address)
- if err != nil {
- return nil, err
- }
- udpConn, err := net.ListenUDP(network, a)
- if err != nil {
- return nil, err
- }
- return newListener(udpConn, config, nil), nil
-}
-
-func newListener(udpConn udpConn, config *Config, hooks connTestHooks) *Listener {
- l := &Listener{
- config: config,
- udpConn: udpConn,
- testHooks: hooks,
- conns: make(map[*Conn]struct{}),
- acceptQueue: newQueue[*Conn](),
- closec: make(chan struct{}),
- }
- go l.listen()
- return l
-}
-
-// LocalAddr returns the local network address.
-func (l *Listener) LocalAddr() netip.AddrPort {
- a, _ := l.udpConn.LocalAddr().(*net.UDPAddr)
- return a.AddrPort()
-}
-
-// Close closes the listener.
-// Any blocked operations on the Listener or associated Conns and Stream will be unblocked
-// and return errors.
-//
-// Close aborts every open connection.
-// Data in stream read and write buffers is discarded.
-// It waits for the peers of any open connection to acknowledge the connection has been closed.
-func (l *Listener) Close(ctx context.Context) error {
- l.acceptQueue.close(errors.New("listener closed"))
- l.connsMu.Lock()
- if !l.closing {
- l.closing = true
- for c := range l.conns {
- c.Abort(errors.New("listener closed"))
- }
- if len(l.conns) == 0 {
- l.udpConn.Close()
- }
- }
- l.connsMu.Unlock()
- select {
- case <-l.closec:
- case <-ctx.Done():
- l.connsMu.Lock()
- for c := range l.conns {
- c.exit()
- }
- l.connsMu.Unlock()
- return ctx.Err()
- }
- return nil
-}
-
-// Accept waits for and returns the next connection to the listener.
-func (l *Listener) Accept(ctx context.Context) (*Conn, error) {
- return l.acceptQueue.get(ctx, nil)
-}
-
-// Dial creates and returns a connection to a network address.
-func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, error) {
- u, err := net.ResolveUDPAddr(network, address)
- if err != nil {
- return nil, err
- }
- addr := u.AddrPort()
- addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
- c, err := l.newConn(time.Now(), clientSide, nil, addr)
- if err != nil {
- return nil, err
- }
- if err := c.waitReady(ctx); err != nil {
- c.Abort(nil)
- return nil, err
- }
- return c, nil
-}
-
-func (l *Listener) newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort) (*Conn, error) {
- l.connsMu.Lock()
- defer l.connsMu.Unlock()
- if l.closing {
- return nil, errors.New("listener closed")
- }
- c, err := newConn(now, side, initialConnID, peerAddr, l.config, l, l.testHooks)
- if err != nil {
- return nil, err
- }
- l.conns[c] = struct{}{}
- return c, nil
-}
-
-// serverConnEstablished is called by a conn when the handshake completes
-// for an inbound (serverSide) connection.
-func (l *Listener) serverConnEstablished(c *Conn) {
- l.acceptQueue.put(c)
-}
-
-// connDrained is called by a conn when it leaves the draining state,
-// either when the peer acknowledges connection closure or the drain timeout expires.
-func (l *Listener) connDrained(c *Conn) {
- l.connsMu.Lock()
- defer l.connsMu.Unlock()
- delete(l.conns, c)
- if l.closing && len(l.conns) == 0 {
- l.udpConn.Close()
- }
-}
-
-// connIDsChanged is called by a conn when its connection IDs change.
-func (l *Listener) connIDsChanged(c *Conn, retired bool, cids []connID) {
- l.connIDUpdateMu.Lock()
- defer l.connIDUpdateMu.Unlock()
- for _, cid := range cids {
- l.connIDUpdates = append(l.connIDUpdates, connIDUpdate{
- conn: c,
- retired: retired,
- cid: cid.cid,
- })
- }
- l.connIDUpdateNeeded.Store(true)
-}
-
-// updateConnIDs is called by the datagram receive loop to update its connection ID map.
-func (l *Listener) updateConnIDs(conns map[string]*Conn) {
- l.connIDUpdateMu.Lock()
- defer l.connIDUpdateMu.Unlock()
- for i, u := range l.connIDUpdates {
- if u.retired {
- delete(conns, string(u.cid))
- } else {
- conns[string(u.cid)] = u.conn
- }
- l.connIDUpdates[i] = connIDUpdate{} // drop refs
- }
- l.connIDUpdates = l.connIDUpdates[:0]
- l.connIDUpdateNeeded.Store(false)
-}
-
-func (l *Listener) listen() {
- defer close(l.closec)
- conns := map[string]*Conn{}
- for {
- m := newDatagram()
- // TODO: Read and process the ECN (explicit congestion notification) field.
- // https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4
- n, _, _, addr, err := l.udpConn.ReadMsgUDPAddrPort(m.b, nil)
- if err != nil {
- // The user has probably closed the listener.
- // We currently don't surface errors from other causes;
- // we could check to see if the listener has been closed and
- // record the unexpected error if it has not.
- return
- }
- if n == 0 {
- continue
- }
- if l.connIDUpdateNeeded.Load() {
- l.updateConnIDs(conns)
- }
- m.addr = addr
- m.b = m.b[:n]
- l.handleDatagram(m, conns)
- }
-}
-
-func (l *Listener) handleDatagram(m *datagram, conns map[string]*Conn) {
- dstConnID, ok := dstConnIDForDatagram(m.b)
- if !ok {
- m.recycle()
- return
- }
- c := conns[string(dstConnID)]
- if c == nil {
- // TODO: Move this branch into a separate goroutine to avoid blocking
- // the listener while processing packets.
- l.handleUnknownDestinationDatagram(m)
- return
- }
-
- // TODO: This can block the listener while waiting for the conn to accept the dgram.
- // Think about buffering between the receive loop and the conn.
- c.sendMsg(m)
-}
-
-func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
- defer func() {
- if m != nil {
- m.recycle()
- }
- }()
- if len(m.b) < minimumClientInitialDatagramSize {
- return
- }
- p, ok := parseGenericLongHeaderPacket(m.b)
- if !ok {
- // Not a long header packet, or not parseable.
- // Short header (1-RTT) packets don't contain enough information
- // to do anything useful with if we don't recognize the
- // connection ID.
- return
- }
-
- switch p.version {
- case quicVersion1:
- case 0:
- // Version Negotiation for an unknown connection.
- return
- default:
- // Unknown version.
- l.sendVersionNegotiation(p, m.addr)
- return
- }
- if getPacketType(m.b) != packetTypeInitial {
- // This packet isn't trying to create a new connection.
- // It might be associated with some connection we've lost state for.
- // TODO: Send a stateless reset when appropriate.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.3
- return
- }
- var now time.Time
- if l.testHooks != nil {
- now = l.testHooks.timeNow()
- } else {
- now = time.Now()
- }
- var err error
- c, err := l.newConn(now, serverSide, p.dstConnID, m.addr)
- if err != nil {
- // The accept queue is probably full.
- // We could send a CONNECTION_CLOSE to the peer to reject the connection.
- // Currently, we just drop the datagram.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5
- return
- }
- c.sendMsg(m)
- m = nil // don't recycle, sendMsg takes ownership
-}
-
-func (l *Listener) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) {
- m := newDatagram()
- m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
- l.sendDatagram(m.b, addr)
- m.recycle()
-}
-
-func (l *Listener) sendDatagram(p []byte, addr netip.AddrPort) error {
- _, err := l.udpConn.WriteToUDPAddrPort(p, addr)
- return err
-}
diff --git a/internal/quic/listener_test.go b/internal/quic/listener_test.go
deleted file mode 100644
index 9d0f314ecc..0000000000
--- a/internal/quic/listener_test.go
+++ /dev/null
@@ -1,163 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.21
-
-package quic
-
-import (
- "bytes"
- "context"
- "io"
- "net"
- "net/netip"
- "testing"
-)
-
-func TestConnect(t *testing.T) {
- newLocalConnPair(t, &Config{}, &Config{})
-}
-
-func TestStreamTransfer(t *testing.T) {
- ctx := context.Background()
- cli, srv := newLocalConnPair(t, &Config{}, &Config{})
- data := makeTestData(1 << 20)
-
- srvdone := make(chan struct{})
- go func() {
- defer close(srvdone)
- s, err := srv.AcceptStream(ctx)
- if err != nil {
- t.Errorf("AcceptStream: %v", err)
- return
- }
- b, err := io.ReadAll(s)
- if err != nil {
- t.Errorf("io.ReadAll(s): %v", err)
- return
- }
- if !bytes.Equal(b, data) {
- t.Errorf("read data mismatch (got %v bytes, want %v", len(b), len(data))
- }
- if err := s.Close(); err != nil {
- t.Errorf("s.Close() = %v", err)
- }
- }()
-
- s, err := cli.NewStream(ctx)
- if err != nil {
- t.Fatalf("NewStream: %v", err)
- }
- n, err := io.Copy(s, bytes.NewBuffer(data))
- if n != int64(len(data)) || err != nil {
- t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
- }
- if err := s.Close(); err != nil {
- t.Fatalf("s.Close() = %v", err)
- }
-}
-
-func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
- t.Helper()
- ctx := context.Background()
- l1 := newLocalListener(t, serverSide, conf1)
- l2 := newLocalListener(t, clientSide, conf2)
- c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String())
- if err != nil {
- t.Fatal(err)
- }
- c1, err := l1.Accept(ctx)
- if err != nil {
- t.Fatal(err)
- }
- return c2, c1
-}
-
-func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener {
- t.Helper()
- if conf.TLSConfig == nil {
- conf.TLSConfig = newTestTLSConfig(side)
- }
- l, err := Listen("udp", "127.0.0.1:0", conf)
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(func() {
- l.Close(context.Background())
- })
- return l
-}
-
-type testListener struct {
- t *testing.T
- l *Listener
- recvc chan *datagram
- idlec chan struct{}
- sentDatagrams [][]byte
-}
-
-func newTestListener(t *testing.T, config *Config, testHooks connTestHooks) *testListener {
- tl := &testListener{
- t: t,
- recvc: make(chan *datagram),
- idlec: make(chan struct{}),
- }
- tl.l = newListener((*testListenerUDPConn)(tl), config, testHooks)
- t.Cleanup(tl.cleanup)
- return tl
-}
-
-func (tl *testListener) cleanup() {
- tl.l.Close(canceledContext())
-}
-
-func (tl *testListener) wait() {
- tl.idlec <- struct{}{}
-}
-
-func (tl *testListener) write(d *datagram) {
- tl.recvc <- d
- tl.wait()
-}
-
-func (tl *testListener) read() []byte {
- tl.wait()
- if len(tl.sentDatagrams) == 0 {
- return nil
- }
- d := tl.sentDatagrams[0]
- tl.sentDatagrams = tl.sentDatagrams[1:]
- return d
-}
-
-// testListenerUDPConn implements UDPConn.
-type testListenerUDPConn testListener
-
-func (tl *testListenerUDPConn) Close() error {
- close(tl.recvc)
- return nil
-}
-
-func (tl *testListenerUDPConn) LocalAddr() net.Addr {
- return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443"))
-}
-
-func (tl *testListenerUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) {
- for {
- select {
- case d, ok := <-tl.recvc:
- if !ok {
- return 0, 0, 0, netip.AddrPort{}, io.EOF
- }
- n = copy(b, d.b)
- return n, 0, 0, d.addr, nil
- case <-tl.idlec:
- }
- }
-}
-
-func (tl *testListenerUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
- tl.sentDatagrams = append(tl.sentDatagrams, append([]byte(nil), b...))
- return len(b), nil
-}
diff --git a/internal/quic/wire.go b/internal/quic/quicwire/wire.go
similarity index 65%
rename from internal/quic/wire.go
rename to internal/quic/quicwire/wire.go
index 8486029151..0edf42227d 100644
--- a/internal/quic/wire.go
+++ b/internal/quic/quicwire/wire.go
@@ -4,20 +4,22 @@
//go:build go1.21
-package quic
+// Package quicwire encodes and decode QUIC/HTTP3 wire encoding types,
+// particularly variable-length integers.
+package quicwire
import "encoding/binary"
const (
- maxVarintSize = 8 // encoded size in bytes
- maxVarint = (1 << 62) - 1
+ MaxVarintSize = 8 // encoded size in bytes
+ MaxVarint = (1 << 62) - 1
)
-// consumeVarint parses a variable-length integer, reporting its length.
+// ConsumeVarint parses a variable-length integer, reporting its length.
// It returns a negative length upon an error.
//
// https://www.rfc-editor.org/rfc/rfc9000.html#section-16
-func consumeVarint(b []byte) (v uint64, n int) {
+func ConsumeVarint(b []byte) (v uint64, n int) {
if len(b) < 1 {
return 0, -1
}
@@ -44,17 +46,17 @@ func consumeVarint(b []byte) (v uint64, n int) {
return 0, -1
}
-// consumeVarint64 parses a variable-length integer as an int64.
-func consumeVarintInt64(b []byte) (v int64, n int) {
- u, n := consumeVarint(b)
+// consumeVarintInt64 parses a variable-length integer as an int64.
+func ConsumeVarintInt64(b []byte) (v int64, n int) {
+ u, n := ConsumeVarint(b)
// QUIC varints are 62-bits large, so this conversion can never overflow.
return int64(u), n
}
-// appendVarint appends a variable-length integer to b.
+// AppendVarint appends a variable-length integer to b.
//
// https://www.rfc-editor.org/rfc/rfc9000.html#section-16
-func appendVarint(b []byte, v uint64) []byte {
+func AppendVarint(b []byte, v uint64) []byte {
switch {
case v <= 63:
return append(b, byte(v))
@@ -69,8 +71,8 @@ func appendVarint(b []byte, v uint64) []byte {
}
}
-// sizeVarint returns the size of the variable-length integer encoding of f.
-func sizeVarint(v uint64) int {
+// SizeVarint returns the size of the variable-length integer encoding of f.
+func SizeVarint(v uint64) int {
switch {
case v <= 63:
return 1
@@ -85,28 +87,28 @@ func sizeVarint(v uint64) int {
}
}
-// consumeUint32 parses a 32-bit fixed-length, big-endian integer, reporting its length.
+// ConsumeUint32 parses a 32-bit fixed-length, big-endian integer, reporting its length.
// It returns a negative length upon an error.
-func consumeUint32(b []byte) (uint32, int) {
+func ConsumeUint32(b []byte) (uint32, int) {
if len(b) < 4 {
return 0, -1
}
return binary.BigEndian.Uint32(b), 4
}
-// consumeUint64 parses a 64-bit fixed-length, big-endian integer, reporting its length.
+// ConsumeUint64 parses a 64-bit fixed-length, big-endian integer, reporting its length.
// It returns a negative length upon an error.
-func consumeUint64(b []byte) (uint64, int) {
+func ConsumeUint64(b []byte) (uint64, int) {
if len(b) < 8 {
return 0, -1
}
return binary.BigEndian.Uint64(b), 8
}
-// consumeUint8Bytes parses a sequence of bytes prefixed with an 8-bit length,
+// ConsumeUint8Bytes parses a sequence of bytes prefixed with an 8-bit length,
// reporting the total number of bytes consumed.
// It returns a negative length upon an error.
-func consumeUint8Bytes(b []byte) ([]byte, int) {
+func ConsumeUint8Bytes(b []byte) ([]byte, int) {
if len(b) < 1 {
return nil, -1
}
@@ -118,8 +120,8 @@ func consumeUint8Bytes(b []byte) ([]byte, int) {
return b[n:][:size], size + n
}
-// appendUint8Bytes appends a sequence of bytes prefixed by an 8-bit length.
-func appendUint8Bytes(b, v []byte) []byte {
+// AppendUint8Bytes appends a sequence of bytes prefixed by an 8-bit length.
+func AppendUint8Bytes(b, v []byte) []byte {
if len(v) > 0xff {
panic("uint8-prefixed bytes too large")
}
@@ -128,11 +130,11 @@ func appendUint8Bytes(b, v []byte) []byte {
return b
}
-// consumeVarintBytes parses a sequence of bytes preceded by a variable-length integer length,
+// ConsumeVarintBytes parses a sequence of bytes preceded by a variable-length integer length,
// reporting the total number of bytes consumed.
// It returns a negative length upon an error.
-func consumeVarintBytes(b []byte) ([]byte, int) {
- size, n := consumeVarint(b)
+func ConsumeVarintBytes(b []byte) ([]byte, int) {
+ size, n := ConsumeVarint(b)
if n < 0 {
return nil, -1
}
@@ -142,9 +144,9 @@ func consumeVarintBytes(b []byte) ([]byte, int) {
return b[n:][:size], int(size) + n
}
-// appendVarintBytes appends a sequence of bytes prefixed by a variable-length integer length.
-func appendVarintBytes(b, v []byte) []byte {
- b = appendVarint(b, uint64(len(v)))
+// AppendVarintBytes appends a sequence of bytes prefixed by a variable-length integer length.
+func AppendVarintBytes(b, v []byte) []byte {
+ b = AppendVarint(b, uint64(len(v)))
b = append(b, v...)
return b
}
diff --git a/internal/quic/wire_test.go b/internal/quic/quicwire/wire_test.go
similarity index 73%
rename from internal/quic/wire_test.go
rename to internal/quic/quicwire/wire_test.go
index 379da0d349..9167a5b72f 100644
--- a/internal/quic/wire_test.go
+++ b/internal/quic/quicwire/wire_test.go
@@ -4,7 +4,7 @@
//go:build go1.21
-package quic
+package quicwire
import (
"bytes"
@@ -32,22 +32,22 @@ func TestConsumeVarint(t *testing.T) {
{[]byte{0x25}, 37, 1},
{[]byte{0x40, 0x25}, 37, 2},
} {
- got, gotLen := consumeVarint(test.b)
+ got, gotLen := ConsumeVarint(test.b)
if got != test.want || gotLen != test.wantLen {
- t.Errorf("consumeVarint(%x) = %v, %v; want %v, %v", test.b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeVarint(%x) = %v, %v; want %v, %v", test.b, got, gotLen, test.want, test.wantLen)
}
// Extra data in the buffer is ignored.
b := append(test.b, 0)
- got, gotLen = consumeVarint(b)
+ got, gotLen = ConsumeVarint(b)
if got != test.want || gotLen != test.wantLen {
- t.Errorf("consumeVarint(%x) = %v, %v; want %v, %v", b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeVarint(%x) = %v, %v; want %v, %v", b, got, gotLen, test.want, test.wantLen)
}
// Short buffer results in an error.
for i := 1; i <= len(test.b); i++ {
b = test.b[:len(test.b)-i]
- got, gotLen = consumeVarint(b)
+ got, gotLen = ConsumeVarint(b)
if got != 0 || gotLen >= 0 {
- t.Errorf("consumeVarint(%x) = %v, %v; want 0, -1", b, got, gotLen)
+ t.Errorf("ConsumeVarint(%x) = %v, %v; want 0, -1", b, got, gotLen)
}
}
}
@@ -69,11 +69,11 @@ func TestAppendVarint(t *testing.T) {
{15293, []byte{0x7b, 0xbd}},
{37, []byte{0x25}},
} {
- got := appendVarint([]byte{}, test.v)
+ got := AppendVarint([]byte{}, test.v)
if !bytes.Equal(got, test.want) {
t.Errorf("AppendVarint(nil, %v) = %x, want %x", test.v, got, test.want)
}
- if gotLen, wantLen := sizeVarint(test.v), len(got); gotLen != wantLen {
+ if gotLen, wantLen := SizeVarint(test.v), len(got); gotLen != wantLen {
t.Errorf("SizeVarint(%v) = %v, want %v", test.v, gotLen, wantLen)
}
}
@@ -88,8 +88,8 @@ func TestConsumeUint32(t *testing.T) {
{[]byte{0x01, 0x02, 0x03, 0x04}, 0x01020304, 4},
{[]byte{0x01, 0x02, 0x03}, 0, -1},
} {
- if got, n := consumeUint32(test.b); got != test.want || n != test.wantLen {
- t.Errorf("consumeUint32(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen)
+ if got, n := ConsumeUint32(test.b); got != test.want || n != test.wantLen {
+ t.Errorf("ConsumeUint32(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen)
}
}
}
@@ -103,8 +103,8 @@ func TestConsumeUint64(t *testing.T) {
{[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, 0x0102030405060708, 8},
{[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 0, -1},
} {
- if got, n := consumeUint64(test.b); got != test.want || n != test.wantLen {
- t.Errorf("consumeUint32(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen)
+ if got, n := ConsumeUint64(test.b); got != test.want || n != test.wantLen {
+ t.Errorf("ConsumeUint64(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen)
}
}
}
@@ -120,22 +120,22 @@ func TestConsumeVarintBytes(t *testing.T) {
{[]byte{0x04, 0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 5},
{[]byte{0x40, 0x04, 0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 6},
} {
- got, gotLen := consumeVarintBytes(test.b)
+ got, gotLen := ConsumeVarintBytes(test.b)
if !bytes.Equal(got, test.want) || gotLen != test.wantLen {
- t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen)
}
// Extra data in the buffer is ignored.
b := append(test.b, 0)
- got, gotLen = consumeVarintBytes(b)
+ got, gotLen = ConsumeVarintBytes(b)
if !bytes.Equal(got, test.want) || gotLen != test.wantLen {
- t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen)
}
// Short buffer results in an error.
for i := 1; i <= len(test.b); i++ {
b = test.b[:len(test.b)-i]
- got, gotLen := consumeVarintBytes(b)
+ got, gotLen := ConsumeVarintBytes(b)
if len(got) > 0 || gotLen > 0 {
- t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
+ t.Errorf("ConsumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
}
}
@@ -147,9 +147,9 @@ func TestConsumeVarintBytesErrors(t *testing.T) {
{0x01},
{0x40, 0x01},
} {
- got, gotLen := consumeVarintBytes(b)
+ got, gotLen := ConsumeVarintBytes(b)
if len(got) > 0 || gotLen > 0 {
- t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
+ t.Errorf("ConsumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
}
}
}
@@ -164,22 +164,22 @@ func TestConsumeUint8Bytes(t *testing.T) {
{[]byte{0x01, 0x00}, []byte{0x00}, 2},
{[]byte{0x04, 0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 5},
} {
- got, gotLen := consumeUint8Bytes(test.b)
+ got, gotLen := ConsumeUint8Bytes(test.b)
if !bytes.Equal(got, test.want) || gotLen != test.wantLen {
- t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen)
}
// Extra data in the buffer is ignored.
b := append(test.b, 0)
- got, gotLen = consumeUint8Bytes(b)
+ got, gotLen = ConsumeUint8Bytes(b)
if !bytes.Equal(got, test.want) || gotLen != test.wantLen {
- t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen)
}
// Short buffer results in an error.
for i := 1; i <= len(test.b); i++ {
b = test.b[:len(test.b)-i]
- got, gotLen := consumeUint8Bytes(b)
+ got, gotLen := ConsumeUint8Bytes(b)
if len(got) > 0 || gotLen > 0 {
- t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
+ t.Errorf("ConsumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
}
}
@@ -191,35 +191,35 @@ func TestConsumeUint8BytesErrors(t *testing.T) {
{0x01},
{0x04, 0x01, 0x02, 0x03},
} {
- got, gotLen := consumeUint8Bytes(b)
+ got, gotLen := ConsumeUint8Bytes(b)
if len(got) > 0 || gotLen > 0 {
- t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
+ t.Errorf("ConsumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
}
}
}
func TestAppendUint8Bytes(t *testing.T) {
var got []byte
- got = appendUint8Bytes(got, []byte{})
- got = appendUint8Bytes(got, []byte{0xaa, 0xbb})
+ got = AppendUint8Bytes(got, []byte{})
+ got = AppendUint8Bytes(got, []byte{0xaa, 0xbb})
want := []byte{
0x00,
0x02, 0xaa, 0xbb,
}
if !bytes.Equal(got, want) {
- t.Errorf("appendUint8Bytes {}, {aabb} = {%x}; want {%x}", got, want)
+ t.Errorf("AppendUint8Bytes {}, {aabb} = {%x}; want {%x}", got, want)
}
}
func TestAppendVarintBytes(t *testing.T) {
var got []byte
- got = appendVarintBytes(got, []byte{})
- got = appendVarintBytes(got, []byte{0xaa, 0xbb})
+ got = AppendVarintBytes(got, []byte{})
+ got = AppendVarintBytes(got, []byte{0xaa, 0xbb})
want := []byte{
0x00,
0x02, 0xaa, 0xbb,
}
if !bytes.Equal(got, want) {
- t.Errorf("appendVarintBytes {}, {aabb} = {%x}; want {%x}", got, want)
+ t.Errorf("AppendVarintBytes {}, {aabb} = {%x}; want {%x}", got, want)
}
}
diff --git a/internal/socket/cmsghdr.go b/internal/socket/cmsghdr.go
index 4bdaaaf1ad..33a5bf59c3 100644
--- a/internal/socket/cmsghdr.go
+++ b/internal/socket/cmsghdr.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package socket
diff --git a/internal/socket/cmsghdr_bsd.go b/internal/socket/cmsghdr_bsd.go
index 0d30e0a0f2..68f438c845 100644
--- a/internal/socket/cmsghdr_bsd.go
+++ b/internal/socket/cmsghdr_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd
-// +build aix darwin dragonfly freebsd netbsd openbsd
package socket
diff --git a/internal/socket/cmsghdr_linux_32bit.go b/internal/socket/cmsghdr_linux_32bit.go
index 4936e8a6f3..058ea8de89 100644
--- a/internal/socket/cmsghdr_linux_32bit.go
+++ b/internal/socket/cmsghdr_linux_32bit.go
@@ -3,8 +3,6 @@
// license that can be found in the LICENSE file.
//go:build (arm || mips || mipsle || 386 || ppc) && linux
-// +build arm mips mipsle 386 ppc
-// +build linux
package socket
diff --git a/internal/socket/cmsghdr_linux_64bit.go b/internal/socket/cmsghdr_linux_64bit.go
index f6877f98fd..3ca0d3a0ab 100644
--- a/internal/socket/cmsghdr_linux_64bit.go
+++ b/internal/socket/cmsghdr_linux_64bit.go
@@ -3,8 +3,6 @@
// license that can be found in the LICENSE file.
//go:build (arm64 || amd64 || loong64 || ppc64 || ppc64le || mips64 || mips64le || riscv64 || s390x) && linux
-// +build arm64 amd64 loong64 ppc64 ppc64le mips64 mips64le riscv64 s390x
-// +build linux
package socket
diff --git a/internal/socket/cmsghdr_solaris_64bit.go b/internal/socket/cmsghdr_solaris_64bit.go
index d3dbe1b8e0..6d0e426cdd 100644
--- a/internal/socket/cmsghdr_solaris_64bit.go
+++ b/internal/socket/cmsghdr_solaris_64bit.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build amd64 && solaris
-// +build amd64,solaris
package socket
diff --git a/internal/socket/cmsghdr_stub.go b/internal/socket/cmsghdr_stub.go
index 1d9f2ed625..7ca9cb7e78 100644
--- a/internal/socket/cmsghdr_stub.go
+++ b/internal/socket/cmsghdr_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!zos
package socket
diff --git a/internal/socket/cmsghdr_unix.go b/internal/socket/cmsghdr_unix.go
index 19d46789de..0211f225bf 100644
--- a/internal/socket/cmsghdr_unix.go
+++ b/internal/socket/cmsghdr_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package socket
diff --git a/internal/socket/complete_dontwait.go b/internal/socket/complete_dontwait.go
index 5b1d50ae72..2038f29043 100644
--- a/internal/socket/complete_dontwait.go
+++ b/internal/socket/complete_dontwait.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
-// +build darwin dragonfly freebsd linux netbsd openbsd solaris
package socket
diff --git a/internal/socket/complete_nodontwait.go b/internal/socket/complete_nodontwait.go
index be63409583..70e6f448b0 100644
--- a/internal/socket/complete_nodontwait.go
+++ b/internal/socket/complete_nodontwait.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || windows || zos
-// +build aix windows zos
package socket
diff --git a/internal/socket/defs_aix.go b/internal/socket/defs_aix.go
index 0bc1703ca6..2c847bbeb3 100644
--- a/internal/socket/defs_aix.go
+++ b/internal/socket/defs_aix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package socket
diff --git a/internal/socket/defs_darwin.go b/internal/socket/defs_darwin.go
index 0f07b57253..d94fff7558 100644
--- a/internal/socket/defs_darwin.go
+++ b/internal/socket/defs_darwin.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package socket
diff --git a/internal/socket/defs_dragonfly.go b/internal/socket/defs_dragonfly.go
index 0f07b57253..d94fff7558 100644
--- a/internal/socket/defs_dragonfly.go
+++ b/internal/socket/defs_dragonfly.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package socket
diff --git a/internal/socket/defs_freebsd.go b/internal/socket/defs_freebsd.go
index 0f07b57253..d94fff7558 100644
--- a/internal/socket/defs_freebsd.go
+++ b/internal/socket/defs_freebsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package socket
diff --git a/internal/socket/defs_linux.go b/internal/socket/defs_linux.go
index bbaafdf30a..d0d52bdfb7 100644
--- a/internal/socket/defs_linux.go
+++ b/internal/socket/defs_linux.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package socket
diff --git a/internal/socket/defs_netbsd.go b/internal/socket/defs_netbsd.go
index 5b57b0c426..8db525bf49 100644
--- a/internal/socket/defs_netbsd.go
+++ b/internal/socket/defs_netbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package socket
diff --git a/internal/socket/defs_openbsd.go b/internal/socket/defs_openbsd.go
index 0f07b57253..d94fff7558 100644
--- a/internal/socket/defs_openbsd.go
+++ b/internal/socket/defs_openbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package socket
diff --git a/internal/socket/defs_solaris.go b/internal/socket/defs_solaris.go
index 0f07b57253..d94fff7558 100644
--- a/internal/socket/defs_solaris.go
+++ b/internal/socket/defs_solaris.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package socket
diff --git a/internal/socket/empty.s b/internal/socket/empty.s
index 90ab4ca3d8..49d79791e0 100644
--- a/internal/socket/empty.s
+++ b/internal/socket/empty.s
@@ -3,6 +3,5 @@
// license that can be found in the LICENSE file.
//go:build darwin && go1.12
-// +build darwin,go1.12
// This exists solely so we can linkname in symbols from syscall.
diff --git a/internal/socket/error_unix.go b/internal/socket/error_unix.go
index 78f4129047..7a5cc5c43e 100644
--- a/internal/socket/error_unix.go
+++ b/internal/socket/error_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package socket
diff --git a/internal/socket/iovec_32bit.go b/internal/socket/iovec_32bit.go
index 2b8fbb3f3d..340e53fbda 100644
--- a/internal/socket/iovec_32bit.go
+++ b/internal/socket/iovec_32bit.go
@@ -3,8 +3,6 @@
// license that can be found in the LICENSE file.
//go:build (arm || mips || mipsle || 386 || ppc) && (darwin || dragonfly || freebsd || linux || netbsd || openbsd)
-// +build arm mips mipsle 386 ppc
-// +build darwin dragonfly freebsd linux netbsd openbsd
package socket
diff --git a/internal/socket/iovec_64bit.go b/internal/socket/iovec_64bit.go
index 2e94e96f8b..26470c191a 100644
--- a/internal/socket/iovec_64bit.go
+++ b/internal/socket/iovec_64bit.go
@@ -3,8 +3,6 @@
// license that can be found in the LICENSE file.
//go:build (arm64 || amd64 || loong64 || ppc64 || ppc64le || mips64 || mips64le || riscv64 || s390x) && (aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || zos)
-// +build arm64 amd64 loong64 ppc64 ppc64le mips64 mips64le riscv64 s390x
-// +build aix darwin dragonfly freebsd linux netbsd openbsd zos
package socket
diff --git a/internal/socket/iovec_solaris_64bit.go b/internal/socket/iovec_solaris_64bit.go
index f7da2bc4d4..8859ce1035 100644
--- a/internal/socket/iovec_solaris_64bit.go
+++ b/internal/socket/iovec_solaris_64bit.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build amd64 && solaris
-// +build amd64,solaris
package socket
diff --git a/internal/socket/iovec_stub.go b/internal/socket/iovec_stub.go
index 14caf52483..da886b0326 100644
--- a/internal/socket/iovec_stub.go
+++ b/internal/socket/iovec_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!zos
package socket
diff --git a/internal/socket/mmsghdr_stub.go b/internal/socket/mmsghdr_stub.go
index 113e773cd5..4825b21e3e 100644
--- a/internal/socket/mmsghdr_stub.go
+++ b/internal/socket/mmsghdr_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !linux && !netbsd
-// +build !aix,!linux,!netbsd
package socket
diff --git a/internal/socket/mmsghdr_unix.go b/internal/socket/mmsghdr_unix.go
index 41883c530c..311fd2c789 100644
--- a/internal/socket/mmsghdr_unix.go
+++ b/internal/socket/mmsghdr_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || linux || netbsd
-// +build aix linux netbsd
package socket
diff --git a/internal/socket/msghdr_bsd.go b/internal/socket/msghdr_bsd.go
index 25f6847f99..ebff4f6e05 100644
--- a/internal/socket/msghdr_bsd.go
+++ b/internal/socket/msghdr_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd
-// +build aix darwin dragonfly freebsd netbsd openbsd
package socket
diff --git a/internal/socket/msghdr_bsdvar.go b/internal/socket/msghdr_bsdvar.go
index 5b8e00f1cd..62e6fe8616 100644
--- a/internal/socket/msghdr_bsdvar.go
+++ b/internal/socket/msghdr_bsdvar.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || netbsd
-// +build aix darwin dragonfly freebsd netbsd
package socket
diff --git a/internal/socket/msghdr_linux_32bit.go b/internal/socket/msghdr_linux_32bit.go
index b4658fbaeb..3dd07250a6 100644
--- a/internal/socket/msghdr_linux_32bit.go
+++ b/internal/socket/msghdr_linux_32bit.go
@@ -3,8 +3,6 @@
// license that can be found in the LICENSE file.
//go:build (arm || mips || mipsle || 386 || ppc) && linux
-// +build arm mips mipsle 386 ppc
-// +build linux
package socket
diff --git a/internal/socket/msghdr_linux_64bit.go b/internal/socket/msghdr_linux_64bit.go
index 42411affad..5af9ddd6ab 100644
--- a/internal/socket/msghdr_linux_64bit.go
+++ b/internal/socket/msghdr_linux_64bit.go
@@ -3,8 +3,6 @@
// license that can be found in the LICENSE file.
//go:build (arm64 || amd64 || loong64 || ppc64 || ppc64le || mips64 || mips64le || riscv64 || s390x) && linux
-// +build arm64 amd64 loong64 ppc64 ppc64le mips64 mips64le riscv64 s390x
-// +build linux
package socket
diff --git a/internal/socket/msghdr_solaris_64bit.go b/internal/socket/msghdr_solaris_64bit.go
index 3098f5d783..e212b50f8d 100644
--- a/internal/socket/msghdr_solaris_64bit.go
+++ b/internal/socket/msghdr_solaris_64bit.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build amd64 && solaris
-// +build amd64,solaris
package socket
diff --git a/internal/socket/msghdr_stub.go b/internal/socket/msghdr_stub.go
index eb79151f6a..e876776459 100644
--- a/internal/socket/msghdr_stub.go
+++ b/internal/socket/msghdr_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!zos
package socket
diff --git a/internal/socket/msghdr_zos_s390x.go b/internal/socket/msghdr_zos_s390x.go
index 324e9ee7d1..529db68ee3 100644
--- a/internal/socket/msghdr_zos_s390x.go
+++ b/internal/socket/msghdr_zos_s390x.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build s390x && zos
-// +build s390x,zos
package socket
diff --git a/internal/socket/norace.go b/internal/socket/norace.go
index de0ad420fc..8af30ecfbb 100644
--- a/internal/socket/norace.go
+++ b/internal/socket/norace.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !race
-// +build !race
package socket
diff --git a/internal/socket/race.go b/internal/socket/race.go
index f0a28a625d..9afa958083 100644
--- a/internal/socket/race.go
+++ b/internal/socket/race.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build race
-// +build race
package socket
diff --git a/internal/socket/rawconn_mmsg.go b/internal/socket/rawconn_mmsg.go
index 8f79b38f74..0431390789 100644
--- a/internal/socket/rawconn_mmsg.go
+++ b/internal/socket/rawconn_mmsg.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build linux
-// +build linux
package socket
diff --git a/internal/socket/rawconn_msg.go b/internal/socket/rawconn_msg.go
index f7d0b0d2b8..7c0d7410bc 100644
--- a/internal/socket/rawconn_msg.go
+++ b/internal/socket/rawconn_msg.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows zos
package socket
diff --git a/internal/socket/rawconn_nommsg.go b/internal/socket/rawconn_nommsg.go
index 02f3285566..e363fb5a89 100644
--- a/internal/socket/rawconn_nommsg.go
+++ b/internal/socket/rawconn_nommsg.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !linux
-// +build !linux
package socket
diff --git a/internal/socket/rawconn_nomsg.go b/internal/socket/rawconn_nomsg.go
index dd785877b6..ff7a8baf0b 100644
--- a/internal/socket/rawconn_nomsg.go
+++ b/internal/socket/rawconn_nomsg.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package socket
diff --git a/internal/socket/socket_dontwait_test.go b/internal/socket/socket_dontwait_test.go
index 8eab9900b1..1eb3580f63 100644
--- a/internal/socket/socket_dontwait_test.go
+++ b/internal/socket/socket_dontwait_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
-// +build darwin dragonfly freebsd linux netbsd openbsd solaris
package socket_test
diff --git a/internal/socket/socket_test.go b/internal/socket/socket_test.go
index 84907d8bc1..26077a7a5b 100644
--- a/internal/socket/socket_test.go
+++ b/internal/socket/socket_test.go
@@ -3,14 +3,12 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows zos
package socket_test
import (
"bytes"
"fmt"
- "io/ioutil"
"net"
"os"
"os/exec"
@@ -447,11 +445,7 @@ func main() {
if runtime.Compiler == "gccgo" {
t.Skip("skipping race test when built with gccgo")
}
- dir, err := ioutil.TempDir("", "testrace")
- if err != nil {
- t.Fatalf("failed to create temp directory: %v", err)
- }
- defer os.RemoveAll(dir)
+ dir := t.TempDir()
goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
t.Logf("%s version", goBinary)
got, err := exec.Command(goBinary, "version").CombinedOutput()
@@ -464,7 +458,7 @@ func main() {
for i, test := range tests {
t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
- if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
+ if err := os.WriteFile(src, []byte(test), 0644); err != nil {
t.Fatalf("failed to write file: %v", err)
}
t.Logf("%s run -race %s", goBinary, src)
diff --git a/internal/socket/sys_bsd.go b/internal/socket/sys_bsd.go
index b258879d44..e7664d48be 100644
--- a/internal/socket/sys_bsd.go
+++ b/internal/socket/sys_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd openbsd solaris
package socket
diff --git a/internal/socket/sys_const_unix.go b/internal/socket/sys_const_unix.go
index 5d99f2373f..d7627f87eb 100644
--- a/internal/socket/sys_const_unix.go
+++ b/internal/socket/sys_const_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package socket
diff --git a/internal/socket/sys_linux.go b/internal/socket/sys_linux.go
index 76f5b8ae5d..08d4910778 100644
--- a/internal/socket/sys_linux.go
+++ b/internal/socket/sys_linux.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build linux && !s390x && !386
-// +build linux,!s390x,!386
package socket
diff --git a/internal/socket/sys_linux_loong64.go b/internal/socket/sys_linux_loong64.go
index af964e6171..1d182470d0 100644
--- a/internal/socket/sys_linux_loong64.go
+++ b/internal/socket/sys_linux_loong64.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build loong64
-// +build loong64
package socket
diff --git a/internal/socket/sys_linux_riscv64.go b/internal/socket/sys_linux_riscv64.go
index 5b128fbb2a..0e407d1257 100644
--- a/internal/socket/sys_linux_riscv64.go
+++ b/internal/socket/sys_linux_riscv64.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build riscv64
-// +build riscv64
package socket
diff --git a/internal/socket/sys_posix.go b/internal/socket/sys_posix.go
index 42b8f2340e..58d8654824 100644
--- a/internal/socket/sys_posix.go
+++ b/internal/socket/sys_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows zos
package socket
diff --git a/internal/socket/sys_stub.go b/internal/socket/sys_stub.go
index 7cfb349c0c..2e5b473c66 100644
--- a/internal/socket/sys_stub.go
+++ b/internal/socket/sys_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package socket
diff --git a/internal/socket/sys_unix.go b/internal/socket/sys_unix.go
index de823932b9..93058db5b9 100644
--- a/internal/socket/sys_unix.go
+++ b/internal/socket/sys_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package socket
diff --git a/internal/socket/zsys_aix_ppc64.go b/internal/socket/zsys_aix_ppc64.go
index 00691bd524..45bab004c1 100644
--- a/internal/socket/zsys_aix_ppc64.go
+++ b/internal/socket/zsys_aix_ppc64.go
@@ -3,7 +3,6 @@
// Added for go1.11 compatibility
//go:build aix
-// +build aix
package socket
diff --git a/internal/socket/zsys_linux_loong64.go b/internal/socket/zsys_linux_loong64.go
index 6a94fec2c5..b6fc15a1a2 100644
--- a/internal/socket/zsys_linux_loong64.go
+++ b/internal/socket/zsys_linux_loong64.go
@@ -2,7 +2,6 @@
// cgo -godefs defs_linux.go
//go:build loong64
-// +build loong64
package socket
diff --git a/internal/socket/zsys_linux_riscv64.go b/internal/socket/zsys_linux_riscv64.go
index c066272ddd..e67fc3cbaa 100644
--- a/internal/socket/zsys_linux_riscv64.go
+++ b/internal/socket/zsys_linux_riscv64.go
@@ -2,7 +2,6 @@
// cgo -godefs defs_linux.go
//go:build riscv64
-// +build riscv64
package socket
diff --git a/internal/socket/zsys_openbsd_ppc64.go b/internal/socket/zsys_openbsd_ppc64.go
index cebde7634f..3c9576e2d8 100644
--- a/internal/socket/zsys_openbsd_ppc64.go
+++ b/internal/socket/zsys_openbsd_ppc64.go
@@ -4,27 +4,27 @@
package socket
type iovec struct {
- Base *byte
- Len uint64
+ Base *byte
+ Len uint64
}
type msghdr struct {
- Name *byte
- Namelen uint32
- Iov *iovec
- Iovlen uint32
- Control *byte
- Controllen uint32
- Flags int32
+ Name *byte
+ Namelen uint32
+ Iov *iovec
+ Iovlen uint32
+ Control *byte
+ Controllen uint32
+ Flags int32
}
type cmsghdr struct {
- Len uint32
- Level int32
- Type int32
+ Len uint32
+ Level int32
+ Type int32
}
const (
- sizeofIovec = 0x10
- sizeofMsghdr = 0x30
+ sizeofIovec = 0x10
+ sizeofMsghdr = 0x30
)
diff --git a/internal/socket/zsys_openbsd_riscv64.go b/internal/socket/zsys_openbsd_riscv64.go
index cebde7634f..3c9576e2d8 100644
--- a/internal/socket/zsys_openbsd_riscv64.go
+++ b/internal/socket/zsys_openbsd_riscv64.go
@@ -4,27 +4,27 @@
package socket
type iovec struct {
- Base *byte
- Len uint64
+ Base *byte
+ Len uint64
}
type msghdr struct {
- Name *byte
- Namelen uint32
- Iov *iovec
- Iovlen uint32
- Control *byte
- Controllen uint32
- Flags int32
+ Name *byte
+ Namelen uint32
+ Iov *iovec
+ Iovlen uint32
+ Control *byte
+ Controllen uint32
+ Flags int32
}
type cmsghdr struct {
- Len uint32
- Level int32
- Type int32
+ Len uint32
+ Level int32
+ Type int32
}
const (
- sizeofIovec = 0x10
- sizeofMsghdr = 0x30
+ sizeofIovec = 0x10
+ sizeofMsghdr = 0x30
)
diff --git a/internal/quic/tlsconfig_test.go b/internal/testcert/testcert.go
similarity index 62%
rename from internal/quic/tlsconfig_test.go
rename to internal/testcert/testcert.go
index 47bfb05983..4d8ae33bba 100644
--- a/internal/quic/tlsconfig_test.go
+++ b/internal/testcert/testcert.go
@@ -1,45 +1,19 @@
-// Copyright 2023 The Go Authors. All rights reserved.
+// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build go1.21
-
-package quic
+// Package testcert contains a test-only localhost certificate.
+package testcert
import (
- "crypto/tls"
"strings"
)
-func newTestTLSConfig(side connSide) *tls.Config {
- config := &tls.Config{
- InsecureSkipVerify: true,
- CipherSuites: []uint16{
- tls.TLS_AES_128_GCM_SHA256,
- tls.TLS_AES_256_GCM_SHA384,
- tls.TLS_CHACHA20_POLY1305_SHA256,
- },
- MinVersion: tls.VersionTLS13,
- }
- if side == serverSide {
- config.Certificates = []tls.Certificate{testCert}
- }
- return config
-}
-
-var testCert = func() tls.Certificate {
- cert, err := tls.X509KeyPair(localhostCert, localhostKey)
- if err != nil {
- panic(err)
- }
- return cert
-}()
-
-// localhostCert is a PEM-encoded TLS cert with SAN IPs
+// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
// generated from src/crypto/tls:
// go run generate_cert.go --ecdsa-curve P256 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
-var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
+var LocalhostCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBrDCCAVKgAwIBAgIPCvPhO+Hfv+NW76kWxULUMAoGCCqGSM49BAMCMBIxEDAO
BgNVBAoTB0FjbWUgQ28wIBcNNzAwMTAxMDAwMDAwWhgPMjA4NDAxMjkxNjAwMDBa
MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARh
@@ -51,8 +25,8 @@ AAAAAAAAAAAAATAKBggqhkjOPQQDAgNIADBFAiBUguxsW6TGhixBAdORmVNnkx40
HjkKwncMSDbUaeL9jQIhAJwQ8zV9JpQvYpsiDuMmqCuW35XXil3cQ6Drz82c+fvE
-----END CERTIFICATE-----`)
-// localhostKey is the private key for localhostCert.
-var localhostKey = []byte(testingKey(`-----BEGIN TESTING KEY-----
+// LocalhostKey is the private key for LocalhostCert.
+var LocalhostKey = []byte(testingKey(`-----BEGIN TESTING KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgY1B1eL/Bbwf/MDcs
rnvvWhFNr1aGmJJR59PdCN9lVVqhRANCAARhWRF8p8X9scgW7JjqAwI9nYV8jtkd
hqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGmsPyfMPe5Jrha/LmjgR1G9
diff --git a/ipv4/control_bsd.go b/ipv4/control_bsd.go
index b7385dfd95..c88da8cbe7 100644
--- a/ipv4/control_bsd.go
+++ b/ipv4/control_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd
-// +build aix darwin dragonfly freebsd netbsd openbsd
package ipv4
diff --git a/ipv4/control_pktinfo.go b/ipv4/control_pktinfo.go
index 0e748dbdc4..14ae2dae49 100644
--- a/ipv4/control_pktinfo.go
+++ b/ipv4/control_pktinfo.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || linux || solaris
-// +build darwin linux solaris
package ipv4
diff --git a/ipv4/control_stub.go b/ipv4/control_stub.go
index f27322c3ed..3ba6611609 100644
--- a/ipv4/control_stub.go
+++ b/ipv4/control_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package ipv4
diff --git a/ipv4/control_unix.go b/ipv4/control_unix.go
index 2413e02f8f..2e765548f3 100644
--- a/ipv4/control_unix.go
+++ b/ipv4/control_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package ipv4
diff --git a/ipv4/defs_aix.go b/ipv4/defs_aix.go
index b70b618240..5e590a7df2 100644
--- a/ipv4/defs_aix.go
+++ b/ipv4/defs_aix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
diff --git a/ipv4/defs_darwin.go b/ipv4/defs_darwin.go
index 0ceadfce2e..2494ff86a9 100644
--- a/ipv4/defs_darwin.go
+++ b/ipv4/defs_darwin.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
diff --git a/ipv4/defs_dragonfly.go b/ipv4/defs_dragonfly.go
index a84630c5cd..43e9f67bb7 100644
--- a/ipv4/defs_dragonfly.go
+++ b/ipv4/defs_dragonfly.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
diff --git a/ipv4/defs_freebsd.go b/ipv4/defs_freebsd.go
index b068087a47..05899b3b4f 100644
--- a/ipv4/defs_freebsd.go
+++ b/ipv4/defs_freebsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
diff --git a/ipv4/defs_linux.go b/ipv4/defs_linux.go
index 7c8554d4b3..fc869b0194 100644
--- a/ipv4/defs_linux.go
+++ b/ipv4/defs_linux.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
diff --git a/ipv4/defs_netbsd.go b/ipv4/defs_netbsd.go
index a84630c5cd..43e9f67bb7 100644
--- a/ipv4/defs_netbsd.go
+++ b/ipv4/defs_netbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
diff --git a/ipv4/defs_openbsd.go b/ipv4/defs_openbsd.go
index a84630c5cd..43e9f67bb7 100644
--- a/ipv4/defs_openbsd.go
+++ b/ipv4/defs_openbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
diff --git a/ipv4/defs_solaris.go b/ipv4/defs_solaris.go
index 0ceadfce2e..2494ff86a9 100644
--- a/ipv4/defs_solaris.go
+++ b/ipv4/defs_solaris.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
diff --git a/ipv4/errors_other_test.go b/ipv4/errors_other_test.go
index 6154353918..93a7f9d74c 100644
--- a/ipv4/errors_other_test.go
+++ b/ipv4/errors_other_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !(aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris)
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris
package ipv4_test
diff --git a/ipv4/errors_unix_test.go b/ipv4/errors_unix_test.go
index 566e070a50..7cff0097c9 100644
--- a/ipv4/errors_unix_test.go
+++ b/ipv4/errors_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package ipv4_test
diff --git a/ipv4/gen.go b/ipv4/gen.go
index e7b053a17b..f0182be2da 100644
--- a/ipv4/gen.go
+++ b/ipv4/gen.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
//go:generate go run gen.go
@@ -18,7 +17,6 @@ import (
"fmt"
"go/format"
"io"
- "io/ioutil"
"net/http"
"os"
"os/exec"
@@ -62,7 +60,7 @@ func genzsys() error {
case "freebsd", "linux":
zsys = "zsys_" + runtime.GOOS + "_" + runtime.GOARCH + ".go"
}
- if err := ioutil.WriteFile(zsys, b, 0644); err != nil {
+ if err := os.WriteFile(zsys, b, 0644); err != nil {
return err
}
return nil
@@ -101,7 +99,7 @@ func geniana() error {
if err != nil {
return err
}
- if err := ioutil.WriteFile("iana.go", b, 0644); err != nil {
+ if err := os.WriteFile("iana.go", b, 0644); err != nil {
return err
}
return nil
diff --git a/ipv4/helper_posix_test.go b/ipv4/helper_posix_test.go
index 4f6ecc0fd9..ab8ffd90dc 100644
--- a/ipv4/helper_posix_test.go
+++ b/ipv4/helper_posix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows zos
package ipv4_test
diff --git a/ipv4/helper_stub_test.go b/ipv4/helper_stub_test.go
index e47ddf7f36..791e6d4c0a 100644
--- a/ipv4/helper_stub_test.go
+++ b/ipv4/helper_stub_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package ipv4_test
diff --git a/ipv4/icmp_stub.go b/ipv4/icmp_stub.go
index cd4ee6e1c9..c2c4ce7ff5 100644
--- a/ipv4/icmp_stub.go
+++ b/ipv4/icmp_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !linux
-// +build !linux
package ipv4
diff --git a/ipv4/payload_cmsg.go b/ipv4/payload_cmsg.go
index 1bb370e25f..91c685e8fc 100644
--- a/ipv4/payload_cmsg.go
+++ b/ipv4/payload_cmsg.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package ipv4
diff --git a/ipv4/payload_nocmsg.go b/ipv4/payload_nocmsg.go
index 53f0794eb7..2afd4b50ef 100644
--- a/ipv4/payload_nocmsg.go
+++ b/ipv4/payload_nocmsg.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!zos
package ipv4
diff --git a/ipv4/sockopt_posix.go b/ipv4/sockopt_posix.go
index eb07c1c02a..82e2c37838 100644
--- a/ipv4/sockopt_posix.go
+++ b/ipv4/sockopt_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows zos
package ipv4
diff --git a/ipv4/sockopt_stub.go b/ipv4/sockopt_stub.go
index cf036893b7..840108bf76 100644
--- a/ipv4/sockopt_stub.go
+++ b/ipv4/sockopt_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package ipv4
diff --git a/ipv4/sys_aix.go b/ipv4/sys_aix.go
index 02730cdfd2..9244a68a38 100644
--- a/ipv4/sys_aix.go
+++ b/ipv4/sys_aix.go
@@ -4,7 +4,6 @@
// Added for go1.11 compatibility
//go:build aix
-// +build aix
package ipv4
diff --git a/ipv4/sys_asmreq.go b/ipv4/sys_asmreq.go
index 22322b387e..645f254c6d 100644
--- a/ipv4/sys_asmreq.go
+++ b/ipv4/sys_asmreq.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd netbsd openbsd solaris windows
package ipv4
diff --git a/ipv4/sys_asmreq_stub.go b/ipv4/sys_asmreq_stub.go
index fde640142d..48cfb6db2f 100644
--- a/ipv4/sys_asmreq_stub.go
+++ b/ipv4/sys_asmreq_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !windows
-// +build !aix,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!windows
package ipv4
diff --git a/ipv4/sys_asmreqn.go b/ipv4/sys_asmreqn.go
index 54eb9901b5..0b27b632f1 100644
--- a/ipv4/sys_asmreqn.go
+++ b/ipv4/sys_asmreqn.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || freebsd || linux
-// +build darwin freebsd linux
package ipv4
diff --git a/ipv4/sys_asmreqn_stub.go b/ipv4/sys_asmreqn_stub.go
index dcb15f25a5..303a5e2e68 100644
--- a/ipv4/sys_asmreqn_stub.go
+++ b/ipv4/sys_asmreqn_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !darwin && !freebsd && !linux
-// +build !darwin,!freebsd,!linux
package ipv4
diff --git a/ipv4/sys_bpf.go b/ipv4/sys_bpf.go
index fb11e324e2..1b4780df41 100644
--- a/ipv4/sys_bpf.go
+++ b/ipv4/sys_bpf.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build linux
-// +build linux
package ipv4
diff --git a/ipv4/sys_bpf_stub.go b/ipv4/sys_bpf_stub.go
index fc53a0d33a..b1f779b493 100644
--- a/ipv4/sys_bpf_stub.go
+++ b/ipv4/sys_bpf_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !linux
-// +build !linux
package ipv4
diff --git a/ipv4/sys_bsd.go b/ipv4/sys_bsd.go
index e191b2f14f..b7b032d260 100644
--- a/ipv4/sys_bsd.go
+++ b/ipv4/sys_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build netbsd || openbsd
-// +build netbsd openbsd
package ipv4
diff --git a/ipv4/sys_ssmreq.go b/ipv4/sys_ssmreq.go
index 6a4e7abf9b..a295e15ea0 100644
--- a/ipv4/sys_ssmreq.go
+++ b/ipv4/sys_ssmreq.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || freebsd || linux || solaris
-// +build darwin freebsd linux solaris
package ipv4
diff --git a/ipv4/sys_ssmreq_stub.go b/ipv4/sys_ssmreq_stub.go
index 157159fd50..74bd454e25 100644
--- a/ipv4/sys_ssmreq_stub.go
+++ b/ipv4/sys_ssmreq_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !darwin && !freebsd && !linux && !solaris
-// +build !darwin,!freebsd,!linux,!solaris
package ipv4
diff --git a/ipv4/sys_stub.go b/ipv4/sys_stub.go
index d550851658..20af4074c2 100644
--- a/ipv4/sys_stub.go
+++ b/ipv4/sys_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package ipv4
diff --git a/ipv4/zsys_aix_ppc64.go b/ipv4/zsys_aix_ppc64.go
index b7f2d6e5c1..dd454025c7 100644
--- a/ipv4/zsys_aix_ppc64.go
+++ b/ipv4/zsys_aix_ppc64.go
@@ -3,7 +3,6 @@
// Added for go1.11 compatibility
//go:build aix
-// +build aix
package ipv4
diff --git a/ipv4/zsys_linux_loong64.go b/ipv4/zsys_linux_loong64.go
index e15c22c748..54f9e13948 100644
--- a/ipv4/zsys_linux_loong64.go
+++ b/ipv4/zsys_linux_loong64.go
@@ -2,7 +2,6 @@
// cgo -godefs defs_linux.go
//go:build loong64
-// +build loong64
package ipv4
diff --git a/ipv4/zsys_linux_riscv64.go b/ipv4/zsys_linux_riscv64.go
index e2edebdb81..78374a5250 100644
--- a/ipv4/zsys_linux_riscv64.go
+++ b/ipv4/zsys_linux_riscv64.go
@@ -2,7 +2,6 @@
// cgo -godefs defs_linux.go
//go:build riscv64
-// +build riscv64
package ipv4
diff --git a/ipv6/control_rfc2292_unix.go b/ipv6/control_rfc2292_unix.go
index 2733ddbe27..a8f04e7b3b 100644
--- a/ipv6/control_rfc2292_unix.go
+++ b/ipv6/control_rfc2292_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin
-// +build darwin
package ipv6
diff --git a/ipv6/control_rfc3542_unix.go b/ipv6/control_rfc3542_unix.go
index 9c90844aac..51fbbb1f17 100644
--- a/ipv6/control_rfc3542_unix.go
+++ b/ipv6/control_rfc3542_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package ipv6
diff --git a/ipv6/control_stub.go b/ipv6/control_stub.go
index b7e8643fc9..eb28ce7534 100644
--- a/ipv6/control_stub.go
+++ b/ipv6/control_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package ipv6
diff --git a/ipv6/control_unix.go b/ipv6/control_unix.go
index 63e475db83..9c73b8647e 100644
--- a/ipv6/control_unix.go
+++ b/ipv6/control_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package ipv6
diff --git a/ipv6/defs_aix.go b/ipv6/defs_aix.go
index 97db07e8d6..de171ce2c8 100644
--- a/ipv6/defs_aix.go
+++ b/ipv6/defs_aix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in6_addr [16]byte /* in6_addr */
diff --git a/ipv6/defs_darwin.go b/ipv6/defs_darwin.go
index 1d31e22c18..3b9e6ba649 100644
--- a/ipv6/defs_darwin.go
+++ b/ipv6/defs_darwin.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in6_addr [16]byte /* in6_addr */
diff --git a/ipv6/defs_dragonfly.go b/ipv6/defs_dragonfly.go
index ddaed6597c..b40d34b136 100644
--- a/ipv6/defs_dragonfly.go
+++ b/ipv6/defs_dragonfly.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in6_addr [16]byte /* in6_addr */
diff --git a/ipv6/defs_freebsd.go b/ipv6/defs_freebsd.go
index 6f6bc6dbc3..fe9a0f70fb 100644
--- a/ipv6/defs_freebsd.go
+++ b/ipv6/defs_freebsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in6_addr [16]byte /* in6_addr */
diff --git a/ipv6/defs_linux.go b/ipv6/defs_linux.go
index 0adcbd92dc..b947c225ae 100644
--- a/ipv6/defs_linux.go
+++ b/ipv6/defs_linux.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in6_addr [16]byte /* in6_addr */
diff --git a/ipv6/defs_netbsd.go b/ipv6/defs_netbsd.go
index ddaed6597c..b40d34b136 100644
--- a/ipv6/defs_netbsd.go
+++ b/ipv6/defs_netbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in6_addr [16]byte /* in6_addr */
diff --git a/ipv6/defs_openbsd.go b/ipv6/defs_openbsd.go
index ddaed6597c..b40d34b136 100644
--- a/ipv6/defs_openbsd.go
+++ b/ipv6/defs_openbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in6_addr [16]byte /* in6_addr */
diff --git a/ipv6/defs_solaris.go b/ipv6/defs_solaris.go
index 03193da9be..7981a04524 100644
--- a/ipv6/defs_solaris.go
+++ b/ipv6/defs_solaris.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in6_addr [16]byte /* in6_addr */
diff --git a/ipv6/errors_other_test.go b/ipv6/errors_other_test.go
index 5a87d73618..5f6c0cb270 100644
--- a/ipv6/errors_other_test.go
+++ b/ipv6/errors_other_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !(aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris)
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris
package ipv6_test
diff --git a/ipv6/errors_unix_test.go b/ipv6/errors_unix_test.go
index 978ae61f84..9e8efd3137 100644
--- a/ipv6/errors_unix_test.go
+++ b/ipv6/errors_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package ipv6_test
diff --git a/ipv6/gen.go b/ipv6/gen.go
index bd53468eb0..590568a113 100644
--- a/ipv6/gen.go
+++ b/ipv6/gen.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
//go:generate go run gen.go
@@ -18,7 +17,6 @@ import (
"fmt"
"go/format"
"io"
- "io/ioutil"
"net/http"
"os"
"os/exec"
@@ -62,7 +60,7 @@ func genzsys() error {
case "freebsd", "linux":
zsys = "zsys_" + runtime.GOOS + "_" + runtime.GOARCH + ".go"
}
- if err := ioutil.WriteFile(zsys, b, 0644); err != nil {
+ if err := os.WriteFile(zsys, b, 0644); err != nil {
return err
}
return nil
@@ -101,7 +99,7 @@ func geniana() error {
if err != nil {
return err
}
- if err := ioutil.WriteFile("iana.go", b, 0644); err != nil {
+ if err := os.WriteFile("iana.go", b, 0644); err != nil {
return err
}
return nil
diff --git a/ipv6/helper_posix_test.go b/ipv6/helper_posix_test.go
index 8ca6a3c3cb..f412a78cbc 100644
--- a/ipv6/helper_posix_test.go
+++ b/ipv6/helper_posix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows zos
package ipv6_test
diff --git a/ipv6/helper_stub_test.go b/ipv6/helper_stub_test.go
index 15e99fa94a..9412a4cf5d 100644
--- a/ipv6/helper_stub_test.go
+++ b/ipv6/helper_stub_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package ipv6_test
diff --git a/ipv6/helper_unix_test.go b/ipv6/helper_unix_test.go
index 5ccff9d9b2..c2459e320e 100644
--- a/ipv6/helper_unix_test.go
+++ b/ipv6/helper_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package ipv6_test
diff --git a/ipv6/icmp_bsd.go b/ipv6/icmp_bsd.go
index 120bf87758..2814534a0b 100644
--- a/ipv6/icmp_bsd.go
+++ b/ipv6/icmp_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd
-// +build aix darwin dragonfly freebsd netbsd openbsd
package ipv6
diff --git a/ipv6/icmp_stub.go b/ipv6/icmp_stub.go
index d60136a901..c92c9b51e1 100644
--- a/ipv6/icmp_stub.go
+++ b/ipv6/icmp_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package ipv6
diff --git a/ipv6/payload_cmsg.go b/ipv6/payload_cmsg.go
index b0692e4304..be04e4d6ae 100644
--- a/ipv6/payload_cmsg.go
+++ b/ipv6/payload_cmsg.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package ipv6
diff --git a/ipv6/payload_nocmsg.go b/ipv6/payload_nocmsg.go
index cd0ff50838..29b9ccf691 100644
--- a/ipv6/payload_nocmsg.go
+++ b/ipv6/payload_nocmsg.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!zos
package ipv6
diff --git a/ipv6/sockopt_posix.go b/ipv6/sockopt_posix.go
index 37c6287130..34dfed588e 100644
--- a/ipv6/sockopt_posix.go
+++ b/ipv6/sockopt_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows zos
package ipv6
diff --git a/ipv6/sockopt_stub.go b/ipv6/sockopt_stub.go
index 32fd8664ce..a09c3aaf26 100644
--- a/ipv6/sockopt_stub.go
+++ b/ipv6/sockopt_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package ipv6
diff --git a/ipv6/sys_aix.go b/ipv6/sys_aix.go
index a47182afb9..93c8efc468 100644
--- a/ipv6/sys_aix.go
+++ b/ipv6/sys_aix.go
@@ -4,7 +4,6 @@
// Added for go1.11 compatibility
//go:build aix
-// +build aix
package ipv6
diff --git a/ipv6/sys_asmreq.go b/ipv6/sys_asmreq.go
index 6ff9950d13..5c9cb44471 100644
--- a/ipv6/sys_asmreq.go
+++ b/ipv6/sys_asmreq.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows
package ipv6
diff --git a/ipv6/sys_asmreq_stub.go b/ipv6/sys_asmreq_stub.go
index 485290cb82..dc70494680 100644
--- a/ipv6/sys_asmreq_stub.go
+++ b/ipv6/sys_asmreq_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows
package ipv6
diff --git a/ipv6/sys_bpf.go b/ipv6/sys_bpf.go
index b5661fb8f0..e39f75f49f 100644
--- a/ipv6/sys_bpf.go
+++ b/ipv6/sys_bpf.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build linux
-// +build linux
package ipv6
diff --git a/ipv6/sys_bpf_stub.go b/ipv6/sys_bpf_stub.go
index cb00661872..8532a8f5de 100644
--- a/ipv6/sys_bpf_stub.go
+++ b/ipv6/sys_bpf_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !linux
-// +build !linux
package ipv6
diff --git a/ipv6/sys_bsd.go b/ipv6/sys_bsd.go
index bde41a6cef..9f3bc2afde 100644
--- a/ipv6/sys_bsd.go
+++ b/ipv6/sys_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build dragonfly || netbsd || openbsd
-// +build dragonfly netbsd openbsd
package ipv6
diff --git a/ipv6/sys_ssmreq.go b/ipv6/sys_ssmreq.go
index 023488a49c..b40f5c685b 100644
--- a/ipv6/sys_ssmreq.go
+++ b/ipv6/sys_ssmreq.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || freebsd || linux || solaris || zos
-// +build aix darwin freebsd linux solaris zos
package ipv6
diff --git a/ipv6/sys_ssmreq_stub.go b/ipv6/sys_ssmreq_stub.go
index acdf2e5cf7..6526aad581 100644
--- a/ipv6/sys_ssmreq_stub.go
+++ b/ipv6/sys_ssmreq_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !freebsd && !linux && !solaris && !zos
-// +build !aix,!darwin,!freebsd,!linux,!solaris,!zos
package ipv6
diff --git a/ipv6/sys_stub.go b/ipv6/sys_stub.go
index 5807bba392..76602c34e6 100644
--- a/ipv6/sys_stub.go
+++ b/ipv6/sys_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package ipv6
diff --git a/ipv6/zsys_aix_ppc64.go b/ipv6/zsys_aix_ppc64.go
index f604b0f3b4..668716df4d 100644
--- a/ipv6/zsys_aix_ppc64.go
+++ b/ipv6/zsys_aix_ppc64.go
@@ -3,7 +3,6 @@
// Added for go1.11 compatibility
//go:build aix
-// +build aix
package ipv6
diff --git a/ipv6/zsys_linux_loong64.go b/ipv6/zsys_linux_loong64.go
index 598fbfa06f..6a53284dbe 100644
--- a/ipv6/zsys_linux_loong64.go
+++ b/ipv6/zsys_linux_loong64.go
@@ -2,7 +2,6 @@
// cgo -godefs defs_linux.go
//go:build loong64
-// +build loong64
package ipv6
diff --git a/ipv6/zsys_linux_riscv64.go b/ipv6/zsys_linux_riscv64.go
index d4f78e405a..13b3472057 100644
--- a/ipv6/zsys_linux_riscv64.go
+++ b/ipv6/zsys_linux_riscv64.go
@@ -2,7 +2,6 @@
// cgo -godefs defs_linux.go
//go:build riscv64
-// +build riscv64
package ipv6
diff --git a/lif/address.go b/lif/address.go
index 8eaddb508d..0ed62a2c4c 100644
--- a/lif/address.go
+++ b/lif/address.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build solaris
-// +build solaris
package lif
diff --git a/lif/address_test.go b/lif/address_test.go
index fdaa7f3aa4..0e99b8d34e 100644
--- a/lif/address_test.go
+++ b/lif/address_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build solaris
-// +build solaris
package lif
diff --git a/lif/binary.go b/lif/binary.go
index f31ca3ad07..8a6c456061 100644
--- a/lif/binary.go
+++ b/lif/binary.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build solaris
-// +build solaris
package lif
diff --git a/lif/defs_solaris.go b/lif/defs_solaris.go
index dbed7c86ed..6bc8fa8e6b 100644
--- a/lif/defs_solaris.go
+++ b/lif/defs_solaris.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
// +godefs map struct_in6_addr [16]byte /* in6_addr */
diff --git a/lif/lif.go b/lif/lif.go
index f1fce48b34..e9f2a9e0ed 100644
--- a/lif/lif.go
+++ b/lif/lif.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build solaris
-// +build solaris
// Package lif provides basic functions for the manipulation of
// logical network interfaces and interface addresses on Solaris.
diff --git a/lif/link.go b/lif/link.go
index 00b78545b5..d0c615a0b3 100644
--- a/lif/link.go
+++ b/lif/link.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build solaris
-// +build solaris
package lif
diff --git a/lif/link_test.go b/lif/link_test.go
index 40b3f3ff2b..fe56697f82 100644
--- a/lif/link_test.go
+++ b/lif/link_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build solaris
-// +build solaris
package lif
diff --git a/lif/sys.go b/lif/sys.go
index d0b532d9dc..caba2fe90d 100644
--- a/lif/sys.go
+++ b/lif/sys.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build solaris
-// +build solaris
package lif
diff --git a/lif/syscall.go b/lif/syscall.go
index 8d03b4aa92..329a65fe63 100644
--- a/lif/syscall.go
+++ b/lif/syscall.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build solaris
-// +build solaris
package lif
diff --git a/nettest/conntest.go b/nettest/conntest.go
index 615f4980c5..4297d408c0 100644
--- a/nettest/conntest.go
+++ b/nettest/conntest.go
@@ -8,7 +8,6 @@ import (
"bytes"
"encoding/binary"
"io"
- "io/ioutil"
"math/rand"
"net"
"runtime"
@@ -173,7 +172,7 @@ func testRacyRead(t *testing.T, c1, c2 net.Conn) {
// testRacyWrite tests that it is safe to mutate the input Write buffer
// immediately after cancelation has occurred.
func testRacyWrite(t *testing.T, c1, c2 net.Conn) {
- go chunkedCopy(ioutil.Discard, c2)
+ go chunkedCopy(io.Discard, c2)
var wg sync.WaitGroup
defer wg.Wait()
@@ -200,7 +199,7 @@ func testRacyWrite(t *testing.T, c1, c2 net.Conn) {
// testReadTimeout tests that Read timeouts do not affect Write.
func testReadTimeout(t *testing.T, c1, c2 net.Conn) {
- go chunkedCopy(ioutil.Discard, c2)
+ go chunkedCopy(io.Discard, c2)
c1.SetReadDeadline(aLongTimeAgo)
_, err := c1.Read(make([]byte, 1024))
diff --git a/nettest/conntest_test.go b/nettest/conntest_test.go
index 7c5aeb9b32..c57e640048 100644
--- a/nettest/conntest_test.go
+++ b/nettest/conntest_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build go1.8
-// +build go1.8
package nettest
diff --git a/nettest/nettest.go b/nettest/nettest.go
index 3656c3c54b..37e6dcb1b4 100644
--- a/nettest/nettest.go
+++ b/nettest/nettest.go
@@ -8,7 +8,6 @@ package nettest
import (
"errors"
"fmt"
- "io/ioutil"
"net"
"os"
"os/exec"
@@ -226,7 +225,7 @@ func LocalPath() (string, error) {
if runtime.GOOS == "darwin" {
dir = "/tmp"
}
- f, err := ioutil.TempFile(dir, "go-nettest")
+ f, err := os.CreateTemp(dir, "go-nettest")
if err != nil {
return "", err
}
diff --git a/nettest/nettest_stub.go b/nettest/nettest_stub.go
index 6e3a9312b9..1725b6aa18 100644
--- a/nettest/nettest_stub.go
+++ b/nettest/nettest_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos
-// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris,!windows,!zos
package nettest
diff --git a/nettest/nettest_unix.go b/nettest/nettest_unix.go
index b1cb8b2f3b..9ba269d020 100644
--- a/nettest/nettest_unix.go
+++ b/nettest/nettest_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package nettest
diff --git a/proxy/per_host.go b/proxy/per_host.go
index 573fe79e86..32bdf435ec 100644
--- a/proxy/per_host.go
+++ b/proxy/per_host.go
@@ -7,6 +7,7 @@ package proxy
import (
"context"
"net"
+ "net/netip"
"strings"
)
@@ -57,7 +58,8 @@ func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.
}
func (p *PerHost) dialerForRequest(host string) Dialer {
- if ip := net.ParseIP(host); ip != nil {
+ if nip, err := netip.ParseAddr(host); err == nil {
+ ip := net.IP(nip.AsSlice())
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
@@ -108,8 +110,8 @@ func (p *PerHost) AddFromString(s string) {
}
continue
}
- if ip := net.ParseIP(host); ip != nil {
- p.AddIP(ip)
+ if nip, err := netip.ParseAddr(host); err == nil {
+ p.AddIP(net.IP(nip.AsSlice()))
continue
}
if strings.HasPrefix(host, "*.") {
@@ -137,9 +139,7 @@ func (p *PerHost) AddNetwork(net *net.IPNet) {
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *PerHost) AddZone(zone string) {
- if strings.HasSuffix(zone, ".") {
- zone = zone[:len(zone)-1]
- }
+ zone = strings.TrimSuffix(zone, ".")
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
@@ -148,8 +148,6 @@ func (p *PerHost) AddZone(zone string) {
// AddHost specifies a host name that will use the bypass proxy.
func (p *PerHost) AddHost(host string) {
- if strings.HasSuffix(host, ".") {
- host = host[:len(host)-1]
- }
+ host = strings.TrimSuffix(host, ".")
p.bypassHosts = append(p.bypassHosts, host)
}
diff --git a/proxy/per_host_test.go b/proxy/per_host_test.go
index 0447eb427a..b7bcec8ae3 100644
--- a/proxy/per_host_test.go
+++ b/proxy/per_host_test.go
@@ -7,8 +7,9 @@ package proxy
import (
"context"
"errors"
+ "fmt"
"net"
- "reflect"
+ "slices"
"testing"
)
@@ -22,55 +23,118 @@ func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) {
}
func TestPerHost(t *testing.T) {
- expectedDef := []string{
- "example.com:123",
- "1.2.3.4:123",
- "[1001::]:123",
- }
- expectedBypass := []string{
- "localhost:123",
- "zone:123",
- "foo.zone:123",
- "127.0.0.1:123",
- "10.1.2.3:123",
- "[1000::]:123",
- }
-
- t.Run("Dial", func(t *testing.T) {
- var def, bypass recordingProxy
- perHost := NewPerHost(&def, &bypass)
- perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
- for _, addr := range expectedDef {
- perHost.Dial("tcp", addr)
+ for _, test := range []struct {
+ config string // passed to PerHost.AddFromString
+ nomatch []string // addrs using the default dialer
+ match []string // addrs using the bypass dialer
+ }{{
+ config: "localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16",
+ nomatch: []string{
+ "example.com:123",
+ "1.2.3.4:123",
+ "[1001::]:123",
+ },
+ match: []string{
+ "localhost:123",
+ "zone:123",
+ "foo.zone:123",
+ "127.0.0.1:123",
+ "10.1.2.3:123",
+ "[1000::]:123",
+ "[1000::%25.example.com]:123",
+ },
+ }, {
+ config: "localhost",
+ nomatch: []string{
+ "127.0.0.1:80",
+ },
+ match: []string{
+ "localhost:80",
+ },
+ }, {
+ config: "*.zone",
+ nomatch: []string{
+ "foo.com:80",
+ },
+ match: []string{
+ "foo.zone:80",
+ "foo.bar.zone:80",
+ },
+ }, {
+ config: "1.2.3.4",
+ nomatch: []string{
+ "127.0.0.1:80",
+ "11.2.3.4:80",
+ },
+ match: []string{
+ "1.2.3.4:80",
+ },
+ }, {
+ config: "10.0.0.0/24",
+ nomatch: []string{
+ "10.0.1.1:80",
+ },
+ match: []string{
+ "10.0.0.1:80",
+ "10.0.0.255:80",
+ },
+ }, {
+ config: "fe80::/10",
+ nomatch: []string{
+ "[fec0::1]:80",
+ "[fec0::1%en0]:80",
+ },
+ match: []string{
+ "[fe80::1]:80",
+ "[fe80::1%en0]:80",
+ },
+ }, {
+ // We don't allow zone IDs in network prefixes,
+ // so this config matches nothing.
+ config: "fe80::%en0/10",
+ nomatch: []string{
+ "[fec0::1]:80",
+ "[fec0::1%en0]:80",
+ "[fe80::1]:80",
+ "[fe80::1%en0]:80",
+ "[fe80::1%en1]:80",
+ },
+ }} {
+ for _, addr := range test.match {
+ testPerHost(t, test.config, addr, true)
}
- for _, addr := range expectedBypass {
- perHost.Dial("tcp", addr)
+ for _, addr := range test.nomatch {
+ testPerHost(t, test.config, addr, false)
}
+ }
+}
- if !reflect.DeepEqual(expectedDef, def.addrs) {
- t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
- }
- if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
- t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
- }
- })
+func testPerHost(t *testing.T, config, addr string, wantMatch bool) {
+ name := fmt.Sprintf("config %q, dial %q", config, addr)
- t.Run("DialContext", func(t *testing.T) {
- var def, bypass recordingProxy
- perHost := NewPerHost(&def, &bypass)
- perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
- for _, addr := range expectedDef {
- perHost.DialContext(context.Background(), "tcp", addr)
- }
- for _, addr := range expectedBypass {
- perHost.DialContext(context.Background(), "tcp", addr)
- }
+ var def, bypass recordingProxy
+ perHost := NewPerHost(&def, &bypass)
+ perHost.AddFromString(config)
+ perHost.Dial("tcp", addr)
- if !reflect.DeepEqual(expectedDef, def.addrs) {
- t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
- }
- if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
- t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
- }
- })
+ // Dial and DialContext should have the same results.
+ var defc, bypassc recordingProxy
+ perHostc := NewPerHost(&defc, &bypassc)
+ perHostc.AddFromString(config)
+ perHostc.DialContext(context.Background(), "tcp", addr)
+ if !slices.Equal(def.addrs, defc.addrs) {
+ t.Errorf("%v: Dial default=%v, bypass=%v; DialContext default=%v, bypass=%v", name, def.addrs, bypass.addrs, defc.addrs, bypass.addrs)
+ return
+ }
+
+ if got, want := slices.Concat(def.addrs, bypass.addrs), []string{addr}; !slices.Equal(got, want) {
+ t.Errorf("%v: dialed %q, want %q", name, got, want)
+ return
+ }
+
+ gotMatch := len(bypass.addrs) > 0
+ if gotMatch != wantMatch {
+ t.Errorf("%v: matched=%v, want %v", name, gotMatch, wantMatch)
+ return
+ }
}
diff --git a/publicsuffix/gen.go b/publicsuffix/gen.go
index 2ad0abdc1a..5f454e57e9 100644
--- a/publicsuffix/gen.go
+++ b/publicsuffix/gen.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package main
@@ -22,16 +21,16 @@ package main
import (
"bufio"
"bytes"
+ "cmp"
"encoding/binary"
"flag"
"fmt"
"go/format"
"io"
- "io/ioutil"
"net/http"
"os"
"regexp"
- "sort"
+ "slices"
"strings"
"golang.org/x/net/idna"
@@ -64,20 +63,6 @@ var (
maxLo uint32
)
-func max(a, b int) int {
- if a < b {
- return b
- }
- return a
-}
-
-func u32max(a, b uint32) uint32 {
- if a < b {
- return b
- }
- return a
-}
-
const (
nodeTypeNormal = 0
nodeTypeException = 1
@@ -85,18 +70,6 @@ const (
numNodeType = 3
)
-func nodeTypeStr(n int) string {
- switch n {
- case nodeTypeNormal:
- return "+"
- case nodeTypeException:
- return "!"
- case nodeTypeParentOnly:
- return "o"
- }
- panic("unreachable")
-}
-
const (
defaultURL = "https://publicsuffix.org/list/effective_tld_names.dat"
gitCommitURL = "https://api.github.com/repos/publicsuffix/list/commits?path=public_suffix_list.dat"
@@ -253,7 +226,7 @@ func main1() error {
for label := range labelsMap {
labelsList = append(labelsList, label)
}
- sort.Strings(labelsList)
+ slices.Sort(labelsList)
combinedText = combineText(labelsList)
if combinedText == "" {
@@ -299,7 +272,7 @@ func generate(p func(io.Writer, *node) error, root *node, filename string) error
if err != nil {
return err
}
- return ioutil.WriteFile(filename, b, 0644)
+ return os.WriteFile(filename, b, 0644)
}
func gitCommit() (sha, date string, retErr error) {
@@ -311,7 +284,7 @@ func gitCommit() (sha, date string, retErr error) {
return "", "", fmt.Errorf("bad GET status for %s: %s", gitCommitURL, res.Status)
}
defer res.Body.Close()
- b, err := ioutil.ReadAll(res.Body)
+ b, err := io.ReadAll(res.Body)
if err != nil {
return "", "", err
}
@@ -511,15 +484,13 @@ func (n *node) child(label string) *node {
icann: true,
}
n.children = append(n.children, c)
- sort.Sort(byLabel(n.children))
+ slices.SortFunc(n.children, byLabel)
return c
}
-type byLabel []*node
-
-func (b byLabel) Len() int { return len(b) }
-func (b byLabel) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
-func (b byLabel) Less(i, j int) bool { return b[i].label < b[j].label }
+func byLabel(a, b *node) int {
+ return strings.Compare(a.label, b.label)
+}
var nextNodesIndex int
@@ -559,7 +530,7 @@ func assignIndexes(n *node) error {
n.childrenIndex = len(childrenEncoding)
lo := uint32(n.firstChild)
hi := lo + uint32(len(n.children))
- maxLo, maxHi = u32max(maxLo, lo), u32max(maxHi, hi)
+ maxLo, maxHi = max(maxLo, lo), max(maxHi, hi)
if lo >= 1< 0 && ss[0] == "" {
ss = ss[1:]
}
diff --git a/publicsuffix/list.go b/publicsuffix/list.go
index d56e9e7624..56069d0429 100644
--- a/publicsuffix/list.go
+++ b/publicsuffix/list.go
@@ -88,7 +88,7 @@ func PublicSuffix(domain string) (publicSuffix string, icann bool) {
s, suffix, icannNode, wildcard := domain, len(domain), false, false
loop:
for {
- dot := strings.LastIndex(s, ".")
+ dot := strings.LastIndexByte(s, '.')
if wildcard {
icann = icannNode
suffix = 1 + dot
@@ -129,7 +129,7 @@ loop:
}
if suffix == len(domain) {
// If no rules match, the prevailing rule is "*".
- return domain[1+strings.LastIndex(domain, "."):], icann
+ return domain[1+strings.LastIndexByte(domain, '.'):], icann
}
return domain[suffix:], icann
}
@@ -178,26 +178,28 @@ func EffectiveTLDPlusOne(domain string) (string, error) {
if domain[i] != '.' {
return "", fmt.Errorf("publicsuffix: invalid public suffix %q for domain %q", suffix, domain)
}
- return domain[1+strings.LastIndex(domain[:i], "."):], nil
+ return domain[1+strings.LastIndexByte(domain[:i], '.'):], nil
}
type uint32String string
func (u uint32String) get(i uint32) uint32 {
off := i * 4
- return (uint32(u[off])<<24 |
- uint32(u[off+1])<<16 |
- uint32(u[off+2])<<8 |
- uint32(u[off+3]))
+ u = u[off:] // help the compiler reduce bounds checks
+ return uint32(u[3]) |
+ uint32(u[2])<<8 |
+ uint32(u[1])<<16 |
+ uint32(u[0])<<24
}
type uint40String string
func (u uint40String) get(i uint32) uint64 {
off := uint64(i * (nodesBits / 8))
- return uint64(u[off])<<32 |
- uint64(u[off+1])<<24 |
- uint64(u[off+2])<<16 |
- uint64(u[off+3])<<8 |
- uint64(u[off+4])
+ u = u[off:] // help the compiler reduce bounds checks
+ return uint64(u[4]) |
+ uint64(u[3])<<8 |
+ uint64(u[2])<<16 |
+ uint64(u[1])<<24 |
+ uint64(u[0])<<32
}
diff --git a/internal/quic/ack_delay.go b/quic/ack_delay.go
similarity index 100%
rename from internal/quic/ack_delay.go
rename to quic/ack_delay.go
diff --git a/internal/quic/ack_delay_test.go b/quic/ack_delay_test.go
similarity index 100%
rename from internal/quic/ack_delay_test.go
rename to quic/ack_delay_test.go
diff --git a/internal/quic/acks.go b/quic/acks.go
similarity index 91%
rename from internal/quic/acks.go
rename to quic/acks.go
index ba860efb2b..039b7b46e6 100644
--- a/internal/quic/acks.go
+++ b/quic/acks.go
@@ -130,12 +130,19 @@ func (acks *ackState) mustAckImmediately(space numberSpace, num packetNumber) bo
// there are no gaps. If it does not, there must be a gap.
return true
}
- if acks.unackedAckEliciting >= 2 {
- // "[...] after receiving at least two ack-eliciting packets."
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2
- return true
+ // "[...] SHOULD send an ACK frame after receiving at least two ack-eliciting packets."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2
+ //
+ // This ack frequency takes a substantial toll on performance, however.
+ // Follow the behavior of Google QUICHE:
+ // Ack every other packet for the first 100 packets, and then ack every 10th packet.
+ // This keeps ack frequency high during the beginning of slow start when CWND is
+ // increasing rapidly.
+ packetsBeforeAck := 2
+ if acks.seen.max() > 100 {
+ packetsBeforeAck = 10
}
- return false
+ return acks.unackedAckEliciting >= packetsBeforeAck
}
// shouldSendAck reports whether the connection should send an ACK frame at this time,
diff --git a/internal/quic/acks_test.go b/quic/acks_test.go
similarity index 94%
rename from internal/quic/acks_test.go
rename to quic/acks_test.go
index 4f1032910f..d10f917ad9 100644
--- a/internal/quic/acks_test.go
+++ b/quic/acks_test.go
@@ -7,6 +7,7 @@
package quic
import (
+ "slices"
"testing"
"time"
)
@@ -198,7 +199,7 @@ func TestAcksSent(t *testing.T) {
if len(gotNums) == 0 {
wantDelay = 0
}
- if !slicesEqual(gotNums, test.wantAcks) || gotDelay != wantDelay {
+ if !slices.Equal(gotNums, test.wantAcks) || gotDelay != wantDelay {
t.Errorf("acks.acksToSend(T+%v) = %v, %v; want %v, %v", delay, gotNums, gotDelay, test.wantAcks, wantDelay)
}
}
@@ -206,20 +207,6 @@ func TestAcksSent(t *testing.T) {
}
}
-// slicesEqual reports whether two slices are equal.
-// Replace this with slices.Equal once the module go.mod is go1.17 or newer.
-func slicesEqual[E comparable](s1, s2 []E) bool {
- if len(s1) != len(s2) {
- return false
- }
- for i := range s1 {
- if s1[i] != s2[i] {
- return false
- }
- }
- return true
-}
-
func TestAcksDiscardAfterAck(t *testing.T) {
acks := ackState{}
now := time.Now()
diff --git a/internal/quic/atomic_bits.go b/quic/atomic_bits.go
similarity index 100%
rename from internal/quic/atomic_bits.go
rename to quic/atomic_bits.go
diff --git a/quic/bench_test.go b/quic/bench_test.go
new file mode 100644
index 0000000000..636b71327e
--- /dev/null
+++ b/quic/bench_test.go
@@ -0,0 +1,170 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "math"
+ "sync"
+ "testing"
+)
+
+// BenchmarkThroughput is based on the crypto/tls benchmark of the same name.
+func BenchmarkThroughput(b *testing.B) {
+ for size := 1; size <= 64; size <<= 1 {
+ name := fmt.Sprintf("%dMiB", size)
+ b.Run(name, func(b *testing.B) {
+ throughput(b, int64(size<<20))
+ })
+ }
+}
+
+func throughput(b *testing.B, totalBytes int64) {
+ // Same buffer size as crypto/tls's BenchmarkThroughput, for consistency.
+ const bufsize = 32 << 10
+
+ cli, srv := newLocalConnPair(b, &Config{}, &Config{})
+
+ go func() {
+ buf := make([]byte, bufsize)
+ for i := 0; i < b.N; i++ {
+ sconn, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ panic(fmt.Errorf("AcceptStream: %v", err))
+ }
+ if _, err := io.CopyBuffer(sconn, sconn, buf); err != nil {
+ panic(fmt.Errorf("CopyBuffer: %v", err))
+ }
+ sconn.Close()
+ }
+ }()
+
+ b.SetBytes(totalBytes)
+ buf := make([]byte, bufsize)
+ chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
+ for i := 0; i < b.N; i++ {
+ cconn, err := cli.NewStream(context.Background())
+ if err != nil {
+ b.Fatalf("NewStream: %v", err)
+ }
+ closec := make(chan struct{})
+ go func() {
+ defer close(closec)
+ buf := make([]byte, bufsize)
+ if _, err := io.CopyBuffer(io.Discard, cconn, buf); err != nil {
+ panic(fmt.Errorf("Discard: %v", err))
+ }
+ }()
+ for j := 0; j < chunks; j++ {
+ _, err := cconn.Write(buf)
+ if err != nil {
+ b.Fatalf("Write: %v", err)
+ }
+ }
+ cconn.CloseWrite()
+ <-closec
+ cconn.Close()
+ }
+}
+
+func BenchmarkReadByte(b *testing.B) {
+ cli, srv := newLocalConnPair(b, &Config{}, &Config{})
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ buf := make([]byte, 1<<20)
+ sconn, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ panic(fmt.Errorf("AcceptStream: %v", err))
+ }
+ for {
+ if _, err := sconn.Write(buf); err != nil {
+ break
+ }
+ sconn.Flush()
+ }
+ }()
+
+ b.SetBytes(1)
+ cconn, err := cli.NewStream(context.Background())
+ if err != nil {
+ b.Fatalf("NewStream: %v", err)
+ }
+ cconn.Flush()
+ for i := 0; i < b.N; i++ {
+ _, err := cconn.ReadByte()
+ if err != nil {
+ b.Fatalf("ReadByte: %v", err)
+ }
+ }
+ cconn.Close()
+}
+
+func BenchmarkWriteByte(b *testing.B) {
+ cli, srv := newLocalConnPair(b, &Config{}, &Config{})
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ sconn, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ panic(fmt.Errorf("AcceptStream: %v", err))
+ }
+ n, err := io.Copy(io.Discard, sconn)
+ if n != int64(b.N) || err != nil {
+ b.Errorf("server io.Copy() = %v, %v; want %v, nil", n, err, b.N)
+ }
+ }()
+
+ b.SetBytes(1)
+ cconn, err := cli.NewStream(context.Background())
+ if err != nil {
+ b.Fatalf("NewStream: %v", err)
+ }
+ cconn.Flush()
+ for i := 0; i < b.N; i++ {
+ if err := cconn.WriteByte(0); err != nil {
+ b.Fatalf("WriteByte: %v", err)
+ }
+ }
+ cconn.Close()
+}
+
+func BenchmarkStreamCreation(b *testing.B) {
+ cli, srv := newLocalConnPair(b, &Config{}, &Config{})
+
+ go func() {
+ for i := 0; i < b.N; i++ {
+ sconn, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ panic(fmt.Errorf("AcceptStream: %v", err))
+ }
+ sconn.Close()
+ }
+ }()
+
+ buf := make([]byte, 1)
+ for i := 0; i < b.N; i++ {
+ cconn, err := cli.NewStream(context.Background())
+ if err != nil {
+ b.Fatalf("NewStream: %v", err)
+ }
+ cconn.Write(buf)
+ cconn.Flush()
+ cconn.Read(buf)
+ cconn.Close()
+ }
+}
diff --git a/quic/config.go b/quic/config.go
new file mode 100644
index 0000000000..d6aa87730f
--- /dev/null
+++ b/quic/config.go
@@ -0,0 +1,160 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/tls"
+ "log/slog"
+ "math"
+ "time"
+
+ "golang.org/x/net/internal/quic/quicwire"
+)
+
+// A Config structure configures a QUIC endpoint.
+// A Config must not be modified after it has been passed to a QUIC function.
+// A Config may be reused; the quic package will also not modify it.
+type Config struct {
+ // TLSConfig is the endpoint's TLS configuration.
+ // It must be non-nil and include at least one certificate or else set GetCertificate.
+ TLSConfig *tls.Config
+
+ // MaxBidiRemoteStreams limits the number of simultaneous bidirectional streams
+ // a peer may open.
+ // If zero, the default value of 100 is used.
+ // If negative, the limit is zero.
+ MaxBidiRemoteStreams int64
+
+ // MaxUniRemoteStreams limits the number of simultaneous unidirectional streams
+ // a peer may open.
+ // If zero, the default value of 100 is used.
+ // If negative, the limit is zero.
+ MaxUniRemoteStreams int64
+
+ // MaxStreamReadBufferSize is the maximum amount of data sent by the peer that a
+ // stream will buffer for reading.
+ // If zero, the default value of 1MiB is used.
+ // If negative, the limit is zero.
+ MaxStreamReadBufferSize int64
+
+ // MaxStreamWriteBufferSize is the maximum amount of data a stream will buffer for
+ // sending to the peer.
+ // If zero, the default value of 1MiB is used.
+ // If negative, the limit is zero.
+ MaxStreamWriteBufferSize int64
+
+ // MaxConnReadBufferSize is the maximum amount of data sent by the peer that a
+ // connection will buffer for reading, across all streams.
+ // If zero, the default value of 1MiB is used.
+ // If negative, the limit is zero.
+ MaxConnReadBufferSize int64
+
+ // RequireAddressValidation may be set to true to enable address validation
+ // of client connections prior to starting the handshake.
+ //
+ // Enabling this setting reduces the amount of work packets with spoofed
+ // source address information can cause a server to perform,
+ // at the cost of increased handshake latency.
+ RequireAddressValidation bool
+
+ // StatelessResetKey is used to provide stateless reset of connections.
+ // A restart may leave an endpoint without access to the state of
+ // existing connections. Stateless reset permits an endpoint to respond
+ // to a packet for a connection it does not recognize.
+ //
+ // This field should be filled with random bytes.
+ // The contents should remain stable across restarts,
+ // to permit an endpoint to send a reset for
+ // connections created before a restart.
+ //
+ // The contents of the StatelessResetKey should not be exposed.
+ // An attacker can use knowledge of this field's value to
+ // reset existing connections.
+ //
+ // If this field is left as zero, stateless reset is disabled.
+ StatelessResetKey [32]byte
+
+ // HandshakeTimeout is the maximum time in which a connection handshake must complete.
+ // If zero, the default of 10 seconds is used.
+ // If negative, there is no handshake timeout.
+ HandshakeTimeout time.Duration
+
+ // MaxIdleTimeout is the maximum time after which an idle connection will be closed.
+ // If zero, the default of 30 seconds is used.
+ // If negative, idle connections are never closed.
+ //
+ // The idle timeout for a connection is the minimum of the maximum idle timeouts
+ // of the endpoints.
+ MaxIdleTimeout time.Duration
+
+ // KeepAlivePeriod is the time after which a packet will be sent to keep
+ // an idle connection alive.
+ // If zero, keep alive packets are not sent.
+ // If greater than zero, the keep alive period is the smaller of KeepAlivePeriod and
+ // half the connection idle timeout.
+ KeepAlivePeriod time.Duration
+
+ // QLogLogger receives qlog events.
+ //
+ // Events currently correspond to the definitions in draft-ietf-qlog-quic-events-03.
+ // This is not the latest version of the draft, but is the latest version supported
+ // by common event log viewers as of the time this paragraph was written.
+ //
+ // The qlog package contains a slog.Handler which serializes qlog events
+ // to a standard JSON representation.
+ QLogLogger *slog.Logger
+}
+
+// Clone returns a shallow clone of c, or nil if c is nil.
+// It is safe to clone a [Config] that is being used concurrently by a QUIC endpoint.
+func (c *Config) Clone() *Config {
+ n := *c
+ return &n
+}
+
+func configDefault[T ~int64](v, def, limit T) T {
+ switch {
+ case v == 0:
+ return def
+ case v < 0:
+ return 0
+ default:
+ return min(v, limit)
+ }
+}
+
+func (c *Config) maxBidiRemoteStreams() int64 {
+ return configDefault(c.MaxBidiRemoteStreams, 100, maxStreamsLimit)
+}
+
+func (c *Config) maxUniRemoteStreams() int64 {
+ return configDefault(c.MaxUniRemoteStreams, 100, maxStreamsLimit)
+}
+
+func (c *Config) maxStreamReadBufferSize() int64 {
+ return configDefault(c.MaxStreamReadBufferSize, 1<<20, quicwire.MaxVarint)
+}
+
+func (c *Config) maxStreamWriteBufferSize() int64 {
+ return configDefault(c.MaxStreamWriteBufferSize, 1<<20, quicwire.MaxVarint)
+}
+
+func (c *Config) maxConnReadBufferSize() int64 {
+ return configDefault(c.MaxConnReadBufferSize, 1<<20, quicwire.MaxVarint)
+}
+
+func (c *Config) handshakeTimeout() time.Duration {
+ return configDefault(c.HandshakeTimeout, defaultHandshakeTimeout, math.MaxInt64)
+}
+
+func (c *Config) maxIdleTimeout() time.Duration {
+ return configDefault(c.MaxIdleTimeout, defaultMaxIdleTimeout, math.MaxInt64)
+}
+
+func (c *Config) keepAlivePeriod() time.Duration {
+ return configDefault(c.KeepAlivePeriod, defaultKeepAlivePeriod, math.MaxInt64)
+}
diff --git a/internal/quic/config_test.go b/quic/config_test.go
similarity index 100%
rename from internal/quic/config_test.go
rename to quic/config_test.go
diff --git a/internal/quic/congestion_reno.go b/quic/congestion_reno.go
similarity index 83%
rename from internal/quic/congestion_reno.go
rename to quic/congestion_reno.go
index 982cbf4bb4..a539835247 100644
--- a/internal/quic/congestion_reno.go
+++ b/quic/congestion_reno.go
@@ -7,6 +7,8 @@
package quic
import (
+ "context"
+ "log/slog"
"math"
"time"
)
@@ -40,6 +42,9 @@ type ccReno struct {
// true if we haven't sent that packet yet.
sendOnePacketInRecovery bool
+ // inRecovery is set when we are in the recovery state.
+ inRecovery bool
+
// underutilized is set if the congestion window is underutilized
// due to insufficient application data, flow control limits, or
// anti-amplification limits.
@@ -100,12 +105,19 @@ func (c *ccReno) canSend() bool {
// congestion controller permits sending data, but no data is sent.
//
// https://www.rfc-editor.org/rfc/rfc9002#section-7.8
-func (c *ccReno) setUnderutilized(v bool) {
+func (c *ccReno) setUnderutilized(log *slog.Logger, v bool) {
+ if c.underutilized == v {
+ return
+ }
+ oldState := c.state()
c.underutilized = v
+ if logEnabled(log, QLogLevelPacket) {
+ logCongestionStateUpdated(log, oldState, c.state())
+ }
}
// packetSent indicates that a packet has been sent.
-func (c *ccReno) packetSent(now time.Time, space numberSpace, sent *sentPacket) {
+func (c *ccReno) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) {
if !sent.inFlight {
return
}
@@ -185,7 +197,11 @@ func (c *ccReno) packetLost(now time.Time, space numberSpace, sent *sentPacket,
}
// packetBatchEnd is called at the end of processing a batch of acked or lost packets.
-func (c *ccReno) packetBatchEnd(now time.Time, space numberSpace, rtt *rttState, maxAckDelay time.Duration) {
+func (c *ccReno) packetBatchEnd(now time.Time, log *slog.Logger, space numberSpace, rtt *rttState, maxAckDelay time.Duration) {
+ if logEnabled(log, QLogLevelPacket) {
+ oldState := c.state()
+ defer func() { logCongestionStateUpdated(log, oldState, c.state()) }()
+ }
if !c.ackLastLoss.IsZero() && !c.ackLastLoss.Before(c.recoveryStartTime) {
// Enter the recovery state.
// https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.2
@@ -196,8 +212,10 @@ func (c *ccReno) packetBatchEnd(now time.Time, space numberSpace, rtt *rttState,
// Clear congestionPendingAcks to avoid increasing the congestion
// window based on acks in a frame that sends us into recovery.
c.congestionPendingAcks = 0
+ c.inRecovery = true
} else if c.congestionPendingAcks > 0 {
// We are in slow start or congestion avoidance.
+ c.inRecovery = false
if c.congestionWindow < c.slowStartThreshold {
// When the congestion window is less than the slow start threshold,
// we are in slow start and increase the window by the number of
@@ -253,3 +271,38 @@ func (c *ccReno) minimumCongestionWindow() int {
// https://www.rfc-editor.org/rfc/rfc9002.html#section-7.2-4
return 2 * c.maxDatagramSize
}
+
+func logCongestionStateUpdated(log *slog.Logger, oldState, newState congestionState) {
+ if oldState == newState {
+ return
+ }
+ log.LogAttrs(context.Background(), QLogLevelPacket,
+ "recovery:congestion_state_updated",
+ slog.String("old", oldState.String()),
+ slog.String("new", newState.String()),
+ )
+}
+
+type congestionState string
+
+func (s congestionState) String() string { return string(s) }
+
+const (
+ congestionSlowStart = congestionState("slow_start")
+ congestionCongestionAvoidance = congestionState("congestion_avoidance")
+ congestionApplicationLimited = congestionState("application_limited")
+ congestionRecovery = congestionState("recovery")
+)
+
+func (c *ccReno) state() congestionState {
+ switch {
+ case c.inRecovery:
+ return congestionRecovery
+ case c.underutilized:
+ return congestionApplicationLimited
+ case c.congestionWindow < c.slowStartThreshold:
+ return congestionSlowStart
+ default:
+ return congestionCongestionAvoidance
+ }
+}
diff --git a/internal/quic/congestion_reno_test.go b/quic/congestion_reno_test.go
similarity index 99%
rename from internal/quic/congestion_reno_test.go
rename to quic/congestion_reno_test.go
index e9af6452ca..cda7a90a80 100644
--- a/internal/quic/congestion_reno_test.go
+++ b/quic/congestion_reno_test.go
@@ -470,7 +470,7 @@ func (c *ccTest) setRTT(smoothedRTT, rttvar time.Duration) {
func (c *ccTest) setUnderutilized(v bool) {
c.t.Helper()
c.t.Logf("set underutilized = %v", v)
- c.cc.setUnderutilized(v)
+ c.cc.setUnderutilized(nil, v)
}
func (c *ccTest) packetSent(space numberSpace, size int, fns ...func(*sentPacket)) *sentPacket {
@@ -488,7 +488,7 @@ func (c *ccTest) packetSent(space numberSpace, size int, fns ...func(*sentPacket
f(sent)
}
c.t.Logf("packet sent: num=%v.%v, size=%v", space, sent.num, sent.size)
- c.cc.packetSent(c.now, space, sent)
+ c.cc.packetSent(c.now, nil, space, sent)
return sent
}
@@ -519,7 +519,7 @@ func (c *ccTest) packetDiscarded(space numberSpace, sent *sentPacket) {
func (c *ccTest) packetBatchEnd(space numberSpace) {
c.t.Helper()
c.t.Logf("(end of batch)")
- c.cc.packetBatchEnd(c.now, space, &c.rtt, c.maxAckDelay)
+ c.cc.packetBatchEnd(c.now, nil, space, &c.rtt, c.maxAckDelay)
}
func (c *ccTest) wantCanSend(want bool) {
diff --git a/internal/quic/conn.go b/quic/conn.go
similarity index 66%
rename from internal/quic/conn.go
rename to quic/conn.go
index 9db00fe092..1f1cfa6d0a 100644
--- a/internal/quic/conn.go
+++ b/quic/conn.go
@@ -11,6 +11,7 @@ import (
"crypto/tls"
"errors"
"fmt"
+ "log/slog"
"net/netip"
"time"
)
@@ -20,26 +21,23 @@ import (
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn struct {
side connSide
- listener *Listener
+ endpoint *Endpoint
config *Config
testHooks connTestHooks
peerAddr netip.AddrPort
+ localAddr netip.AddrPort
- msgc chan any
- donec chan struct{} // closed when conn loop exits
- exited bool // set to make the conn loop exit immediately
+ msgc chan any
+ donec chan struct{} // closed when conn loop exits
w packetWriter
acks [numberSpaceCount]ackState // indexed by number space
lifetime lifetimeState
+ idle idleState
connIDState connIDState
loss lossState
streams streamsState
-
- // idleTimeout is the time at which the connection will be closed due to inactivity.
- // https://www.rfc-editor.org/rfc/rfc9000#section-10.1
- maxIdleTimeout time.Duration
- idleTimeout time.Time
+ path pathState
// Packet protection keys, CRYPTO streams, and TLS state.
keysInitial fixedKeyPair
@@ -48,6 +46,9 @@ type Conn struct {
crypto [numberSpaceCount]cryptoStream
tls *tls.QUICConn
+ // retryToken is the token provided by the peer in a Retry packet.
+ retryToken []byte
+
// handshakeConfirmed is set when the handshake is confirmed.
// For server connections, it tracks sending HANDSHAKE_DONE.
handshakeConfirmed sentVal
@@ -57,60 +58,98 @@ type Conn struct {
// Tests only: Send a PING in a specific number space.
testSendPingSpace numberSpace
testSendPing sentVal
+
+ log *slog.Logger
}
// connTestHooks override conn behavior in tests.
type connTestHooks interface {
+ // init is called after a conn is created.
+ init()
+
+ // nextMessage is called to request the next event from msgc.
+ // Used to give tests control of the connection event loop.
nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any)
+
+ // handleTLSEvent is called with each TLS event.
handleTLSEvent(tls.QUICEvent)
+
+ // newConnID is called to generate a new connection ID.
+ // Permits tests to generate consistent connection IDs rather than random ones.
newConnID(seq int64) ([]byte, error)
+
+ // waitUntil blocks until the until func returns true or the context is done.
+ // Used to synchronize asynchronous blocking operations in tests.
waitUntil(ctx context.Context, until func() bool) error
+
+ // timeNow returns the current time.
timeNow() time.Time
}
-func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l *Listener, hooks connTestHooks) (*Conn, error) {
+// newServerConnIDs is connection IDs associated with a new server connection.
+type newServerConnIDs struct {
+ srcConnID []byte // source from client's current Initial
+ dstConnID []byte // destination from client's current Initial
+ originalDstConnID []byte // destination from client's first Initial
+ retrySrcConnID []byte // source from server's Retry
+}
+
+func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) {
c := &Conn{
side: side,
- listener: l,
+ endpoint: e,
config: config,
- peerAddr: peerAddr,
+ peerAddr: unmapAddrPort(peerAddr),
msgc: make(chan any, 1),
donec: make(chan struct{}),
- testHooks: hooks,
- maxIdleTimeout: defaultMaxIdleTimeout,
- idleTimeout: now.Add(defaultMaxIdleTimeout),
peerAckDelayExponent: -1,
}
+ defer func() {
+ // If we hit an error in newConn, close donec so tests don't get stuck waiting for it.
+ // This is only relevant if we've got a bug, but it makes tracking that bug down
+ // much easier.
+ if conn == nil {
+ close(c.donec)
+ }
+ }()
// A one-element buffer allows us to wake a Conn's event loop as a
// non-blocking operation.
c.msgc = make(chan any, 1)
- var originalDstConnID []byte
+ if e.testHooks != nil {
+ e.testHooks.newConn(c)
+ }
+
+ // initialConnID is the connection ID used to generate Initial packet protection keys.
+ var initialConnID []byte
if c.side == clientSide {
if err := c.connIDState.initClient(c); err != nil {
return nil, err
}
initialConnID, _ = c.connIDState.dstConnID()
} else {
- if err := c.connIDState.initServer(c, initialConnID); err != nil {
+ initialConnID = cids.originalDstConnID
+ if cids.retrySrcConnID != nil {
+ initialConnID = cids.retrySrcConnID
+ }
+ if err := c.connIDState.initServer(c, cids); err != nil {
return nil, err
}
- originalDstConnID = initialConnID
}
- // The smallest allowed maximum QUIC datagram size is 1200 bytes.
// TODO: PMTU discovery.
- const maxDatagramSize = 1200
+ c.logConnectionStarted(cids.originalDstConnID, peerAddr)
c.keysAppData.init()
- c.loss.init(c.side, maxDatagramSize, now)
+ c.loss.init(c.side, smallestMaxDatagramSize, now)
c.streamsInit()
c.lifetimeInit()
+ c.restartIdleTimer(now)
- // TODO: retry_source_connection_id
- if err := c.startTLS(now, initialConnID, transportParameters{
+ if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{
initialSrcConnID: c.connIDState.srcConnID(),
- originalDstConnID: originalDstConnID,
+ originalDstConnID: cids.originalDstConnID,
+ retrySrcConnID: cids.retrySrcConnID,
ackDelayExponent: ackDelayExponent,
maxUDPPayloadSize: maxUDPPayloadSize,
maxAckDelay: maxAckDelay,
@@ -126,6 +165,9 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.
return nil, err
}
+ if c.testHooks != nil {
+ c.testHooks.init()
+ }
go c.loop(now)
return c, nil
}
@@ -134,6 +176,21 @@ func (c *Conn) String() string {
return fmt.Sprintf("quic.Conn(%v,->%v)", c.side, c.peerAddr)
}
+// LocalAddr returns the local network address, if known.
+func (c *Conn) LocalAddr() netip.AddrPort {
+ return c.localAddr
+}
+
+// RemoteAddr returns the remote network address, if known.
+func (c *Conn) RemoteAddr() netip.AddrPort {
+ return c.peerAddr
+}
+
+// ConnectionState returns basic TLS details about the connection.
+func (c *Conn) ConnectionState() tls.ConnectionState {
+ return c.tls.ConnectionState()
+}
+
// confirmHandshake is called when the handshake is confirmed.
// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2
func (c *Conn) confirmHandshake(now time.Time) {
@@ -147,13 +204,14 @@ func (c *Conn) confirmHandshake(now time.Time) {
if c.side == serverSide {
// When the server confirms the handshake, it sends a HANDSHAKE_DONE.
c.handshakeConfirmed.setUnsent()
- c.listener.serverConnEstablished(c)
+ c.endpoint.serverConnEstablished(c)
} else {
// The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed
// to the received state, indicating that the handshake is confirmed and we
// don't need to send anything.
c.handshakeConfirmed.setReceived()
}
+ c.restartIdleTimer(now)
c.loss.confirmHandshake()
// "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed"
// https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1
@@ -163,18 +221,22 @@ func (c *Conn) confirmHandshake(now time.Time) {
// discardKeys discards unused packet protection keys.
// https://www.rfc-editor.org/rfc/rfc9001#section-4.9
func (c *Conn) discardKeys(now time.Time, space numberSpace) {
+ if err := c.crypto[space].discardKeys(); err != nil {
+ c.abort(now, err)
+ }
switch space {
case initialSpace:
c.keysInitial.discard()
case handshakeSpace:
c.keysHandshake.discard()
}
- c.loss.discardKeys(now, space)
+ c.loss.discardKeys(now, c.log, space)
}
// receiveTransportParameters applies transport parameters sent by the peer.
func (c *Conn) receiveTransportParameters(p transportParameters) error {
- if err := c.connIDState.validateTransportParameters(c.side, p); err != nil {
+ isRetry := c.retryToken != nil
+ if err := c.connIDState.validateTransportParameters(c, isRetry, p); err != nil {
return err
}
c.streams.outflow.setMaxData(p.initialMaxData)
@@ -183,6 +245,7 @@ func (c *Conn) receiveTransportParameters(p transportParameters) error {
c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal
c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote
c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni
+ c.receivePeerMaxIdleTimeout(p.maxIdleTimeout)
c.peerAckDelayExponent = p.ackDelayExponent
c.loss.setMaxAckDelay(p.maxAckDelay)
if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil {
@@ -195,13 +258,14 @@ func (c *Conn) receiveTransportParameters(p transportParameters) error {
resetToken [16]byte
)
copy(resetToken[:], p.preferredAddrResetToken)
- if err := c.connIDState.handleNewConnID(seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
+ if err := c.connIDState.handleNewConnID(c, seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
return err
}
}
-
- // TODO: Many more transport parameters to come.
-
+ // TODO: stateless_reset_token
+ // TODO: max_udp_payload_size
+ // TODO: disable_active_migration
+ // TODO: preferred_address
return nil
}
@@ -210,6 +274,8 @@ type (
wakeEvent struct{}
)
+var errIdleTimeout = errors.New("idle timeout")
+
// loop is the connection main loop.
//
// Except where otherwise noted, all connection state is owned by the loop goroutine.
@@ -217,9 +283,7 @@ type (
// The loop processes messages from c.msgc and timer events.
// Other goroutines may examine or modify conn state by sending the loop funcs to execute.
func (c *Conn) loop(now time.Time) {
- defer close(c.donec)
- defer c.tls.Close()
- defer c.listener.connDrained(c)
+ defer c.cleanup()
// The connection timer sends a message to the connection loop on expiry.
// We need to give it an expiry when creating it, so set the initial timeout to
@@ -236,14 +300,14 @@ func (c *Conn) loop(now time.Time) {
defer timer.Stop()
}
- for !c.exited {
+ for c.lifetime.state != connStateDone {
sendTimeout := c.maybeSend(now) // try sending
// Note that we only need to consider the ack timer for the App Data space,
// since the Initial and Handshake spaces always ack immediately.
nextTimeout := sendTimeout
- nextTimeout = firstTime(nextTimeout, c.idleTimeout)
- if !c.isClosingOrDraining() {
+ nextTimeout = firstTime(nextTimeout, c.idle.nextTimeout)
+ if c.isAlive() {
nextTimeout = firstTime(nextTimeout, c.loss.timer)
nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck)
} else {
@@ -273,15 +337,17 @@ func (c *Conn) loop(now time.Time) {
}
switch m := m.(type) {
case *datagram:
- c.handleDatagram(now, m)
+ if !c.handleDatagram(now, m) {
+ if c.logEnabled(QLogLevelPacket) {
+ c.logPacketDropped(m)
+ }
+ }
m.recycle()
case timerEvent:
// A connection timer has expired.
- if !now.Before(c.idleTimeout) {
- // "[...] the connection is silently closed and
- // its state is discarded [...]"
- // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1
- c.exited = true
+ if c.idleAdvance(now) {
+ // The connection idle timer has expired.
+ c.abortImmediately(now, errIdleTimeout)
return
}
c.loss.advance(now, c.handleAckOrLoss)
@@ -301,6 +367,13 @@ func (c *Conn) loop(now time.Time) {
}
}
+func (c *Conn) cleanup() {
+ c.logConnectionClosed()
+ c.endpoint.connDrained(c)
+ c.tls.Close()
+ close(c.donec)
+}
+
// sendMsg sends a message to the conn's loop.
// It does not wait for the message to be processed.
// The conn may close before processing the message, in which case it is lost.
@@ -320,12 +393,37 @@ func (c *Conn) wake() {
}
// runOnLoop executes a function within the conn's loop goroutine.
-func (c *Conn) runOnLoop(f func(now time.Time, c *Conn)) error {
+func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) error {
donec := make(chan struct{})
- c.sendMsg(func(now time.Time, c *Conn) {
+ msg := func(now time.Time, c *Conn) {
defer close(donec)
f(now, c)
- })
+ }
+ if c.testHooks != nil {
+ // In tests, we can't rely on being able to send a message immediately:
+ // c.msgc might be full, and testConnHooks.nextMessage might be waiting
+ // for us to block before it processes the next message.
+ // To avoid a deadlock, we send the message in waitUntil.
+ // If msgc is empty, the message is buffered.
+ // If msgc is full, we block and let nextMessage process the queue.
+ msgc := c.msgc
+ c.testHooks.waitUntil(ctx, func() bool {
+ for {
+ select {
+ case msgc <- msg:
+ msgc = nil // send msg only once
+ case <-donec:
+ return true
+ case <-c.donec:
+ return true
+ default:
+ return false
+ }
+ }
+ })
+ } else {
+ c.sendMsg(msg)
+ }
select {
case <-donec:
case <-c.donec:
diff --git a/internal/quic/conn_async_test.go b/quic/conn_async_test.go
similarity index 94%
rename from internal/quic/conn_async_test.go
rename to quic/conn_async_test.go
index dc2a57f9dd..4671f8340e 100644
--- a/internal/quic/conn_async_test.go
+++ b/quic/conn_async_test.go
@@ -41,7 +41,7 @@ type asyncOp[T any] struct {
err error
caller string
- state *asyncTestState
+ tc *testConn
donec chan struct{}
cancelFunc context.CancelFunc
}
@@ -55,7 +55,7 @@ func (a *asyncOp[T]) cancel() {
default:
}
a.cancelFunc()
- <-a.state.notify
+ <-a.tc.asyncTestState.notify
select {
case <-a.donec:
default:
@@ -73,6 +73,7 @@ var errNotDone = errors.New("async op is not done")
// control over the progress of operations, an asyncOp can only
// become done in reaction to the test taking some action.
func (a *asyncOp[T]) result() (v T, err error) {
+ a.tc.wait()
select {
case <-a.donec:
return a.v, a.err
@@ -94,8 +95,8 @@ type asyncContextKey struct{}
// The function f should call a blocking function such as
// Stream.Write or Conn.AcceptStream and return its result.
// It must use the provided context.
-func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[T] {
- as := &ts.asyncTestState
+func runAsync[T any](tc *testConn, f func(context.Context) (T, error)) *asyncOp[T] {
+ as := &tc.asyncTestState
if as.notify == nil {
as.notify = make(chan struct{})
as.mu.Lock()
@@ -106,7 +107,7 @@ func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[
ctx := context.WithValue(context.Background(), asyncContextKey{}, true)
ctx, cancel := context.WithCancel(ctx)
a := &asyncOp[T]{
- state: as,
+ tc: tc,
caller: fmt.Sprintf("%v:%v", filepath.Base(file), line),
donec: make(chan struct{}),
cancelFunc: cancel,
@@ -116,14 +117,15 @@ func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[
close(a.donec)
as.notify <- struct{}{}
}()
- ts.t.Cleanup(func() {
+ tc.t.Cleanup(func() {
if _, err := a.result(); err == errNotDone {
- ts.t.Errorf("%v: async operation is still executing at end of test", a.caller)
+ tc.t.Errorf("%v: async operation is still executing at end of test", a.caller)
a.cancel()
}
})
// Wait for the operation to either finish or block.
<-as.notify
+ tc.wait()
return a
}
diff --git a/quic/conn_close.go b/quic/conn_close.go
new file mode 100644
index 0000000000..cd8d7e3c5a
--- /dev/null
+++ b/quic/conn_close.go
@@ -0,0 +1,342 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "errors"
+ "time"
+)
+
+// connState is the state of a connection.
+type connState int
+
+const (
+ // A connection is alive when it is first created.
+ connStateAlive = connState(iota)
+
+ // The connection has received a CONNECTION_CLOSE frame from the peer,
+ // and has not yet sent a CONNECTION_CLOSE in response.
+ //
+ // We will send a CONNECTION_CLOSE, and then enter the draining state.
+ connStatePeerClosed
+
+ // The connection is in the closing state.
+ //
+ // We will send CONNECTION_CLOSE frames to the peer
+ // (once upon entering the closing state, and possibly again in response to peer packets).
+ //
+ // If we receive a CONNECTION_CLOSE from the peer, we will enter the draining state.
+ // Otherwise, we will eventually time out and move to the done state.
+ //
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.1
+ connStateClosing
+
+ // The connection is in the draining state.
+ //
+ // We will neither send packets nor process received packets.
+ // When the drain timer expires, we move to the done state.
+ //
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.2
+ connStateDraining
+
+ // The connection is done, and the conn loop will exit.
+ connStateDone
+)
+
+// lifetimeState tracks the state of a connection.
+//
+// This is fairly coupled to the rest of a Conn, but putting it in a struct of its own helps
+// reason about operations that cause state transitions.
+type lifetimeState struct {
+ state connState
+
+ readyc chan struct{} // closed when TLS handshake completes
+ donec chan struct{} // closed when finalErr is set
+
+ localErr error // error sent to the peer
+ finalErr error // error sent by the peer, or transport error; set before closing donec
+
+ connCloseSentTime time.Time // send time of last CONNECTION_CLOSE frame
+ connCloseDelay time.Duration // delay until next CONNECTION_CLOSE frame sent
+ drainEndTime time.Time // time the connection exits the draining state
+}
+
+func (c *Conn) lifetimeInit() {
+ c.lifetime.readyc = make(chan struct{})
+ c.lifetime.donec = make(chan struct{})
+}
+
+var (
+ errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE")
+ errConnClosed = errors.New("connection closed")
+)
+
+// lifetimeAdvance is called when time passes.
+func (c *Conn) lifetimeAdvance(now time.Time) (done bool) {
+ if c.lifetime.drainEndTime.IsZero() || c.lifetime.drainEndTime.After(now) {
+ return false
+ }
+ // The connection drain period has ended, and we can shut down.
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-7
+ c.lifetime.drainEndTime = time.Time{}
+ if c.lifetime.state != connStateDraining {
+ // We were in the closing state, waiting for a CONNECTION_CLOSE from the peer.
+ c.setFinalError(errNoPeerResponse)
+ }
+ c.setState(now, connStateDone)
+ return true
+}
+
+// setState sets the conn state.
+func (c *Conn) setState(now time.Time, state connState) {
+ if c.lifetime.state == state {
+ return
+ }
+ c.lifetime.state = state
+ switch state {
+ case connStateClosing, connStateDraining:
+ if c.lifetime.drainEndTime.IsZero() {
+ c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod())
+ }
+ case connStateDone:
+ c.setFinalError(nil)
+ }
+ if state != connStateAlive {
+ c.streamsCleanup()
+ }
+}
+
+// handshakeDone is called when the TLS handshake completes.
+func (c *Conn) handshakeDone() {
+ close(c.lifetime.readyc)
+}
+
+// isDraining reports whether the conn is in the draining state.
+//
+// The draining state is entered once an endpoint receives a CONNECTION_CLOSE frame.
+// The endpoint will no longer send any packets, but we retain knowledge of the connection
+// until the end of the drain period to ensure we discard packets for the connection
+// rather than treating them as starting a new connection.
+//
+// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2
+func (c *Conn) isDraining() bool {
+ switch c.lifetime.state {
+ case connStateDraining, connStateDone:
+ return true
+ }
+ return false
+}
+
+// isAlive reports whether the conn is handling packets.
+func (c *Conn) isAlive() bool {
+ return c.lifetime.state == connStateAlive
+}
+
+// sendOK reports whether the conn can send frames at this time.
+func (c *Conn) sendOK(now time.Time) bool {
+ switch c.lifetime.state {
+ case connStateAlive:
+ return true
+ case connStatePeerClosed:
+ if c.lifetime.localErr == nil {
+ // We're waiting for the user to close the connection, providing us with
+ // a final status to send to the peer.
+ return false
+ }
+ // We should send a CONNECTION_CLOSE.
+ return true
+ case connStateClosing:
+ if c.lifetime.connCloseSentTime.IsZero() {
+ return true
+ }
+ maxRecvTime := c.acks[initialSpace].maxRecvTime
+ if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) {
+ maxRecvTime = t
+ }
+ if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) {
+ maxRecvTime = t
+ }
+ if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) {
+ // After sending CONNECTION_CLOSE, ignore packets from the peer for
+ // a delay. On the next packet received after the delay, send another
+ // CONNECTION_CLOSE.
+ return false
+ }
+ return true
+ case connStateDraining:
+ // We are in the draining state, and will send no more packets.
+ return false
+ case connStateDone:
+ return false
+ default:
+ panic("BUG: unhandled connection state")
+ }
+}
+
+// sentConnectionClose reports that the conn has sent a CONNECTION_CLOSE to the peer.
+func (c *Conn) sentConnectionClose(now time.Time) {
+ switch c.lifetime.state {
+ case connStatePeerClosed:
+ c.enterDraining(now)
+ }
+ if c.lifetime.connCloseSentTime.IsZero() {
+ // Set the initial delay before we will send another CONNECTION_CLOSE.
+ //
+ // RFC 9000 states that we should rate limit CONNECTION_CLOSE frames,
+ // but leaves the implementation of the limit up to us. Here, we start
+ // with the same delay as the PTO timer (RFC 9002, Section 6.2.1),
+ // not including max_ack_delay, and double it on every CONNECTION_CLOSE sent.
+ c.lifetime.connCloseDelay = c.loss.rtt.smoothedRTT + max(4*c.loss.rtt.rttvar, timerGranularity)
+ } else if !c.lifetime.connCloseSentTime.Equal(now) {
+ // If connCloseSentTime == now, we're sending two CONNECTION_CLOSE frames
+ // coalesced into the same datagram. We only want to increase the delay once.
+ c.lifetime.connCloseDelay *= 2
+ }
+ c.lifetime.connCloseSentTime = now
+}
+
+// handlePeerConnectionClose handles a CONNECTION_CLOSE from the peer.
+func (c *Conn) handlePeerConnectionClose(now time.Time, err error) {
+ c.setFinalError(err)
+ switch c.lifetime.state {
+ case connStateAlive:
+ c.setState(now, connStatePeerClosed)
+ case connStatePeerClosed:
+ // Duplicate CONNECTION_CLOSE, ignore.
+ case connStateClosing:
+ if c.lifetime.connCloseSentTime.IsZero() {
+ c.setState(now, connStatePeerClosed)
+ } else {
+ c.setState(now, connStateDraining)
+ }
+ case connStateDraining:
+ case connStateDone:
+ }
+}
+
+// setFinalError records the final connection status we report to the user.
+func (c *Conn) setFinalError(err error) {
+ select {
+ case <-c.lifetime.donec:
+ return // already set
+ default:
+ }
+ c.lifetime.finalErr = err
+ close(c.lifetime.donec)
+}
+
+// finalError returns the final connection status reported to the user,
+// or nil if a final status has not yet been set.
+func (c *Conn) finalError() error {
+ select {
+ case <-c.lifetime.donec:
+ return c.lifetime.finalErr
+ default:
+ }
+ return nil
+}
+
+func (c *Conn) waitReady(ctx context.Context) error {
+ select {
+ case <-c.lifetime.readyc:
+ return nil
+ case <-c.lifetime.donec:
+ return c.lifetime.finalErr
+ default:
+ }
+ select {
+ case <-c.lifetime.readyc:
+ return nil
+ case <-c.lifetime.donec:
+ return c.lifetime.finalErr
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+// Close closes the connection.
+//
+// Close is equivalent to:
+//
+// conn.Abort(nil)
+// err := conn.Wait(context.Background())
+func (c *Conn) Close() error {
+ c.Abort(nil)
+ <-c.lifetime.donec
+ return c.lifetime.finalErr
+}
+
+// Wait waits for the peer to close the connection.
+//
+// If the connection is closed locally and the peer does not close its end of the connection,
+// Wait will return with a non-nil error after the drain period expires.
+//
+// If the peer closes the connection with a NO_ERROR transport error, Wait returns nil.
+// If the peer closes the connection with an application error, Wait returns an ApplicationError
+// containing the peer's error code and reason.
+// If the peer closes the connection with any other status, Wait returns a non-nil error.
+func (c *Conn) Wait(ctx context.Context) error {
+ if err := c.waitOnDone(ctx, c.lifetime.donec); err != nil {
+ return err
+ }
+ return c.lifetime.finalErr
+}
+
+// Abort closes the connection and returns immediately.
+//
+// If err is nil, Abort sends a transport error of NO_ERROR to the peer.
+// If err is an ApplicationError, Abort sends its error code and text.
+// Otherwise, Abort sends a transport error of APPLICATION_ERROR with the error's text.
+func (c *Conn) Abort(err error) {
+ if err == nil {
+ err = localTransportError{code: errNo}
+ }
+ c.sendMsg(func(now time.Time, c *Conn) {
+ c.enterClosing(now, err)
+ })
+}
+
+// abort terminates a connection with an error.
+func (c *Conn) abort(now time.Time, err error) {
+ c.setFinalError(err) // this error takes precedence over the peer's CONNECTION_CLOSE
+ c.enterClosing(now, err)
+}
+
+// abortImmediately terminates a connection.
+// The connection does not send a CONNECTION_CLOSE, and skips the draining period.
+func (c *Conn) abortImmediately(now time.Time, err error) {
+ c.setFinalError(err)
+ c.setState(now, connStateDone)
+}
+
+// enterClosing starts an immediate close.
+// We will send a CONNECTION_CLOSE to the peer and wait for their response.
+func (c *Conn) enterClosing(now time.Time, err error) {
+ switch c.lifetime.state {
+ case connStateAlive:
+ c.lifetime.localErr = err
+ c.setState(now, connStateClosing)
+ case connStatePeerClosed:
+ c.lifetime.localErr = err
+ }
+}
+
+// enterDraining moves directly to the draining state, without sending a CONNECTION_CLOSE.
+func (c *Conn) enterDraining(now time.Time) {
+ switch c.lifetime.state {
+ case connStateAlive, connStatePeerClosed, connStateClosing:
+ c.setState(now, connStateDraining)
+ }
+}
+
+// exit fully terminates a connection immediately.
+func (c *Conn) exit() {
+ c.sendMsg(func(now time.Time, c *Conn) {
+ c.abortImmediately(now, errors.New("connection closed"))
+ })
+}
diff --git a/internal/quic/conn_close_test.go b/quic/conn_close_test.go
similarity index 67%
rename from internal/quic/conn_close_test.go
rename to quic/conn_close_test.go
index 20c00e754c..2139750119 100644
--- a/internal/quic/conn_close_test.go
+++ b/quic/conn_close_test.go
@@ -15,7 +15,9 @@ import (
)
func TestConnCloseResponseBackoff(t *testing.T) {
- tc := newTestConn(t, clientSide)
+ tc := newTestConn(t, clientSide, func(c *Config) {
+ clear(c.StatelessResetKey[:])
+ })
tc.handshake()
tc.conn.Abort(nil)
@@ -68,7 +70,8 @@ func TestConnCloseResponseBackoff(t *testing.T) {
}
func TestConnCloseWithPeerResponse(t *testing.T) {
- tc := newTestConn(t, clientSide)
+ qr := &qlogRecord{}
+ tc := newTestConn(t, clientSide, qr.config)
tc.handshake()
tc.conn.Abort(nil)
@@ -97,10 +100,19 @@ func TestConnCloseWithPeerResponse(t *testing.T) {
if err := tc.conn.Wait(canceledContext()); !errors.Is(err, wantErr) {
t.Errorf("non-blocking conn.Wait() = %v, want %v", err, wantErr)
}
+
+ tc.advance(1 * time.Second) // long enough to exit the draining state
+ qr.wantEvents(t, jsonEvent{
+ "name": "connectivity:connection_closed",
+ "data": map[string]any{
+ "trigger": "application",
+ },
+ })
}
func TestConnClosePeerCloses(t *testing.T) {
- tc := newTestConn(t, clientSide)
+ qr := &qlogRecord{}
+ tc := newTestConn(t, clientSide, qr.config)
tc.handshake()
wantErr := &ApplicationError{
@@ -126,6 +138,14 @@ func TestConnClosePeerCloses(t *testing.T) {
code: 9,
reason: "because",
})
+
+ tc.advance(1 * time.Second) // long enough to exit the draining state
+ qr.wantEvents(t, jsonEvent{
+ "name": "connectivity:connection_closed",
+ "data": map[string]any{
+ "trigger": "application",
+ },
+ })
}
func TestConnCloseReceiveInInitial(t *testing.T) {
@@ -184,3 +204,79 @@ func TestConnCloseReceiveInHandshake(t *testing.T) {
})
tc.wantIdle("no more frames to send")
}
+
+func TestConnCloseClosedByEndpoint(t *testing.T) {
+ ctx := canceledContext()
+ tc := newTestConn(t, clientSide)
+ tc.handshake()
+
+ tc.endpoint.e.Close(ctx)
+ tc.wantFrame("endpoint closes connection before exiting",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errNo,
+ })
+}
+
+func testConnCloseUnblocks(t *testing.T, f func(context.Context, *testConn) error, opts ...any) {
+ tc := newTestConn(t, clientSide, opts...)
+ tc.handshake()
+ op := runAsync(tc, func(ctx context.Context) (struct{}, error) {
+ return struct{}{}, f(ctx, tc)
+ })
+ if _, err := op.result(); err != errNotDone {
+ t.Fatalf("before abort, op = %v, want errNotDone", err)
+ }
+ tc.conn.Abort(nil)
+ if _, err := op.result(); err == nil || err == errNotDone {
+ t.Fatalf("after abort, op = %v, want error", err)
+ }
+}
+
+func TestConnCloseUnblocksAcceptStream(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ _, err := tc.conn.AcceptStream(ctx)
+ return err
+ }, permissiveTransportParameters)
+}
+
+func TestConnCloseUnblocksNewStream(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ _, err := tc.conn.NewStream(ctx)
+ return err
+ })
+}
+
+func TestConnCloseUnblocksStreamRead(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ s := newLocalStream(t, tc, bidiStream)
+ s.SetReadContext(ctx)
+ buf := make([]byte, 16)
+ _, err := s.Read(buf)
+ return err
+ }, permissiveTransportParameters)
+}
+
+func TestConnCloseUnblocksStreamWrite(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ s := newLocalStream(t, tc, bidiStream)
+ s.SetWriteContext(ctx)
+ buf := make([]byte, 32)
+ _, err := s.Write(buf)
+ return err
+ }, permissiveTransportParameters, func(c *Config) {
+ c.MaxStreamWriteBufferSize = 16
+ })
+}
+
+func TestConnCloseUnblocksStreamClose(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ s := newLocalStream(t, tc, bidiStream)
+ s.SetWriteContext(ctx)
+ buf := make([]byte, 16)
+ _, err := s.Write(buf)
+ if err != nil {
+ return err
+ }
+ return s.Close()
+ }, permissiveTransportParameters)
+}
diff --git a/internal/quic/conn_flow.go b/quic/conn_flow.go
similarity index 97%
rename from internal/quic/conn_flow.go
rename to quic/conn_flow.go
index 4f1ab6eafc..8b69ef7dba 100644
--- a/internal/quic/conn_flow.go
+++ b/quic/conn_flow.go
@@ -90,7 +90,10 @@ func (c *Conn) shouldUpdateFlowControl(credit int64) bool {
func (c *Conn) handleStreamBytesReceived(n int64) error {
c.streams.inflow.usedLimit += n
if c.streams.inflow.usedLimit > c.streams.inflow.sentLimit {
- return localTransportError(errFlowControl)
+ return localTransportError{
+ code: errFlowControl,
+ reason: "stream exceeded flow control limit",
+ }
}
return nil
}
diff --git a/internal/quic/conn_flow_test.go b/quic/conn_flow_test.go
similarity index 90%
rename from internal/quic/conn_flow_test.go
rename to quic/conn_flow_test.go
index 03e0757a6d..260684bdbc 100644
--- a/internal/quic/conn_flow_test.go
+++ b/quic/conn_flow_test.go
@@ -12,39 +12,34 @@ import (
)
func TestConnInflowReturnOnRead(t *testing.T) {
- ctx := canceledContext()
tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) {
c.MaxConnReadBufferSize = 64
})
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
- data: make([]byte, 64),
+ data: make([]byte, 8),
})
- const readSize = 8
- if n, err := s.ReadContext(ctx, make([]byte, readSize)); n != readSize || err != nil {
- t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, readSize)
- }
- tc.wantFrame("available window increases, send a MAX_DATA",
- packetType1RTT, debugFrameMaxData{
- max: 64 + readSize,
- })
- if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64-readSize || err != nil {
- t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 64-readSize)
+ if n, err := s.Read(make([]byte, 8)); n != 8 || err != nil {
+ t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 8)
}
tc.wantFrame("available window increases, send a MAX_DATA",
packetType1RTT, debugFrameMaxData{
- max: 128,
+ max: 64 + 8,
})
// Peer can write up to the new limit.
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
- off: 64,
+ off: 8,
data: make([]byte, 64),
})
- tc.wantIdle("connection is idle")
- if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64 || err != nil {
- t.Fatalf("offset 64: s.Read() = %v, %v; want %v, nil", n, err, 64)
+ if n, err := s.Read(make([]byte, 64+1)); n != 64 {
+ t.Fatalf("s.Read() = %v, %v; want %v, anything", n, err, 64)
}
+ tc.wantFrame("available window increases, send a MAX_DATA",
+ packetType1RTT, debugFrameMaxData{
+ max: 64 + 8 + 64,
+ })
+ tc.wantIdle("connection is idle")
}
func TestConnInflowReturnOnRacingReads(t *testing.T) {
@@ -64,11 +59,11 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) {
tc.ignoreFrame(frameTypeAck)
tc.writeFrames(packetType1RTT, debugFrameStream{
id: newStreamID(clientSide, uniStream, 0),
- data: make([]byte, 32),
+ data: make([]byte, 16),
})
tc.writeFrames(packetType1RTT, debugFrameStream{
id: newStreamID(clientSide, uniStream, 1),
- data: make([]byte, 32),
+ data: make([]byte, 1),
})
s1, err := tc.conn.AcceptStream(ctx)
if err != nil {
@@ -79,10 +74,10 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) {
t.Fatalf("conn.AcceptStream() = %v", err)
}
read1 := runAsync(tc, func(ctx context.Context) (int, error) {
- return s1.ReadContext(ctx, make([]byte, 16))
+ return s1.Read(make([]byte, 16))
})
read2 := runAsync(tc, func(ctx context.Context) (int, error) {
- return s2.ReadContext(ctx, make([]byte, 1))
+ return s2.Read(make([]byte, 1))
})
// This MAX_DATA might extend the window by 16 or 17, depending on
// whether the second write occurs before the update happens.
@@ -90,10 +85,10 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) {
packetType1RTT, debugFrameMaxData{})
tc.wantIdle("redundant MAX_DATA is not sent")
if _, err := read1.result(); err != nil {
- t.Errorf("ReadContext #1 = %v", err)
+ t.Errorf("Read #1 = %v", err)
}
if _, err := read2.result(); err != nil {
- t.Errorf("ReadContext #2 = %v", err)
+ t.Errorf("Read #2 = %v", err)
}
}
@@ -204,7 +199,6 @@ func TestConnInflowResetViolation(t *testing.T) {
}
func TestConnInflowMultipleStreams(t *testing.T) {
- ctx := canceledContext()
tc := newTestConn(t, serverSide, func(c *Config) {
c.MaxConnReadBufferSize = 128
})
@@ -220,21 +214,26 @@ func TestConnInflowMultipleStreams(t *testing.T) {
} {
tc.writeFrames(packetType1RTT, debugFrameStream{
id: id,
- data: make([]byte, 32),
+ data: make([]byte, 1),
})
- s, err := tc.conn.AcceptStream(ctx)
- if err != nil {
- t.Fatalf("AcceptStream() = %v", err)
- }
+ s := tc.acceptStream()
streams = append(streams, s)
- if n, err := s.ReadContext(ctx, make([]byte, 1)); err != nil || n != 1 {
+ if n, err := s.Read(make([]byte, 1)); err != nil || n != 1 {
t.Fatalf("s.Read() = %v, %v; want 1, nil", n, err)
}
}
tc.wantIdle("streams have read data, but not enough to update MAX_DATA")
- if n, err := streams[0].ReadContext(ctx, make([]byte, 32)); err != nil || n != 31 {
- t.Fatalf("s.Read() = %v, %v; want 31, nil", n, err)
+ for _, s := range streams {
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ off: 1,
+ data: make([]byte, 31),
+ })
+ }
+
+ if n, err := streams[0].Read(make([]byte, 32)); n != 31 {
+ t.Fatalf("s.Read() = %v, %v; want 31, anything", n, err)
}
tc.wantFrame("read enough data to trigger a MAX_DATA update",
packetType1RTT, debugFrameMaxData{
@@ -262,6 +261,7 @@ func TestConnOutflowBlocked(t *testing.T) {
if n != len(data) || err != nil {
t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data))
}
+ s.Flush()
tc.wantFrame("stream writes data up to MAX_DATA limit",
packetType1RTT, debugFrameStream{
@@ -310,6 +310,7 @@ func TestConnOutflowMaxDataDecreases(t *testing.T) {
if n != len(data) || err != nil {
t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data))
}
+ s.Flush()
tc.wantFrame("stream writes data up to MAX_DATA limit",
packetType1RTT, debugFrameStream{
@@ -337,7 +338,9 @@ func TestConnOutflowMaxDataRoundRobin(t *testing.T) {
}
s1.Write(make([]byte, 10))
+ s1.Flush()
s2.Write(make([]byte, 10))
+ s2.Flush()
tc.writeFrames(packetType1RTT, debugFrameMaxData{
max: 1,
@@ -378,6 +381,7 @@ func TestConnOutflowMetaAndData(t *testing.T) {
data := makeTestData(32)
s.Write(data)
+ s.Flush()
s.CloseRead()
tc.wantFrame("CloseRead sends a STOP_SENDING, not flow controlled",
@@ -405,6 +409,7 @@ func TestConnOutflowResentData(t *testing.T) {
data := makeTestData(15)
s.Write(data[:8])
+ s.Flush()
tc.wantFrame("data is under MAX_DATA limit, all sent",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -421,6 +426,7 @@ func TestConnOutflowResentData(t *testing.T) {
})
s.Write(data[8:])
+ s.Flush()
tc.wantFrame("new data is sent up to the MAX_DATA limit",
packetType1RTT, debugFrameStream{
id: s.id,
diff --git a/internal/quic/conn_id.go b/quic/conn_id.go
similarity index 55%
rename from internal/quic/conn_id.go
rename to quic/conn_id.go
index 045e646ac1..2d50f14fa6 100644
--- a/internal/quic/conn_id.go
+++ b/quic/conn_id.go
@@ -9,6 +9,7 @@ package quic
import (
"bytes"
"crypto/rand"
+ "slices"
)
// connIDState is a conn's connection IDs.
@@ -22,11 +23,22 @@ type connIDState struct {
//
// These are []connID rather than []*connID to minimize allocations.
local []connID
- remote []connID
+ remote []remoteConnID
nextLocalSeq int64
- retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer
- peerActiveConnIDLimit int64 // peer's active_connection_id_limit transport parameter
+ peerActiveConnIDLimit int64 // peer's active_connection_id_limit
+
+ // Handling of retirement of remote connection IDs.
+ // The rangesets track ID sequence numbers.
+ // IDs in need of retirement are added to remoteRetiring,
+	// moved to remoteRetiringSent once we send a RETIRE_CONNECTION_ID frame,
+ // and removed from the set once retirement completes.
+ retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer
+ remoteRetiring rangeset[int64] // remote IDs in need of retirement
+ remoteRetiringSent rangeset[int64] // remote IDs waiting for ack of retirement
+
+ originalDstConnID []byte // expected original_destination_connection_id param
+ retrySrcConnID []byte // expected retry_source_connection_id param
needSend bool
}
@@ -42,9 +54,6 @@ type connID struct {
// For the transient destination ID in a client's Initial packet, this is -1.
seq int64
- // retired is set when the connection ID is retired.
- retired bool
-
// send is set when the connection ID's state needs to be sent to the peer.
//
// For local IDs, this indicates a new ID that should be sent
@@ -55,6 +64,12 @@ type connID struct {
send sentVal
}
+// A remoteConnID is a connection ID and stateless reset token.
+type remoteConnID struct {
+ connID
+ resetToken statelessResetToken
+}
+
func (s *connIDState) initClient(c *Conn) error {
// Client chooses its initial connection ID, and sends it
// in the Source Connection ID field of the first Initial packet.
@@ -67,6 +82,9 @@ func (s *connIDState) initClient(c *Conn) error {
cid: locid,
})
s.nextLocalSeq = 1
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+ conns.addConnID(c, locid)
+ })
// Client chooses an initial, transient connection ID for the server,
// and sends it in the Destination Connection ID field of the first Initial packet.
@@ -74,22 +92,24 @@ func (s *connIDState) initClient(c *Conn) error {
if err != nil {
return err
}
- s.remote = append(s.remote, connID{
- seq: -1,
- cid: remid,
+ s.remote = append(s.remote, remoteConnID{
+ connID: connID{
+ seq: -1,
+ cid: remid,
+ },
})
- const retired = false
- c.listener.connIDsChanged(c, retired, s.local[:])
+ s.originalDstConnID = remid
return nil
}
-func (s *connIDState) initServer(c *Conn, dstConnID []byte) error {
+func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error {
+ dstConnID := cloneBytes(cids.dstConnID)
// Client-chosen, transient connection ID received in the first Initial packet.
// The server will not use this as the Source Connection ID of packets it sends,
// but remembers it because it may receive packets sent to this destination.
s.local = append(s.local, connID{
seq: -1,
- cid: cloneBytes(dstConnID),
+ cid: dstConnID,
})
// Server chooses a connection ID, and sends it in the Source Connection ID of
@@ -103,8 +123,18 @@ func (s *connIDState) initServer(c *Conn, dstConnID []byte) error {
cid: locid,
})
s.nextLocalSeq = 1
- const retired = false
- c.listener.connIDsChanged(c, retired, s.local[:])
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+ conns.addConnID(c, dstConnID)
+ conns.addConnID(c, locid)
+ })
+
+ // Client chose its own connection ID.
+ s.remote = append(s.remote, remoteConnID{
+ connID: connID{
+ seq: 0,
+ cid: cloneBytes(cids.srcConnID),
+ },
+ })
return nil
}
@@ -120,13 +150,22 @@ func (s *connIDState) srcConnID() []byte {
// dstConnID is the Destination Connection ID to use in a sent packet.
func (s *connIDState) dstConnID() (cid []byte, ok bool) {
for i := range s.remote {
- if !s.remote[i].retired {
- return s.remote[i].cid, true
- }
+ return s.remote[i].cid, true
}
return nil, false
}
+// isValidStatelessResetToken reports whether the given reset token is
+// associated with a non-retired connection ID which we have used.
+func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool {
+ if len(s.remote) == 0 {
+ return false
+ }
+ // We currently only use the first available remote connection ID,
+ // so any other reset token is not valid.
+ return s.remote[0].resetToken == resetToken
+}
+
// setPeerActiveConnIDLimit sets the active_connection_id_limit
// transport parameter received from the peer.
func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
@@ -137,16 +176,17 @@ func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
func (s *connIDState) issueLocalIDs(c *Conn) error {
toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
for i := range s.local {
- if s.local[i].seq != -1 && !s.local[i].retired {
+ if s.local[i].seq != -1 {
toIssue--
}
}
- prev := len(s.local)
+ var newIDs [][]byte
for toIssue > 0 {
cid, err := c.newConnID(s.nextLocalSeq)
if err != nil {
return err
}
+ newIDs = append(newIDs, cid)
s.local = append(s.local, connID{
seq: s.nextLocalSeq,
cid: cid,
@@ -156,40 +196,62 @@ func (s *connIDState) issueLocalIDs(c *Conn) error {
s.needSend = true
toIssue--
}
- const retired = false
- c.listener.connIDsChanged(c, retired, s.local[prev:])
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+ for _, cid := range newIDs {
+ conns.addConnID(c, cid)
+ }
+ })
return nil
}
// validateTransportParameters verifies the original_destination_connection_id and
// initial_source_connection_id transport parameters match the expected values.
-func (s *connIDState) validateTransportParameters(side connSide, p transportParameters) error {
+func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p transportParameters) error {
// TODO: Consider returning more detailed errors, for debugging.
- switch side {
- case clientSide:
- // Verify original_destination_connection_id matches
- // the transient remote connection ID we chose.
- if len(s.remote) == 0 || s.remote[0].seq != -1 {
- return localTransportError(errInternal)
- }
- if !bytes.Equal(s.remote[0].cid, p.originalDstConnID) {
- return localTransportError(errTransportParameter)
+ // Verify original_destination_connection_id matches
+ // the transient remote connection ID we chose (client)
+ // or is empty (server).
+ if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) {
+ return localTransportError{
+ code: errTransportParameter,
+ reason: "original_destination_connection_id mismatch",
}
- // Remove the transient remote connection ID.
- // We have no further need for it.
- s.remote = append(s.remote[:0], s.remote[1:]...)
- case serverSide:
- if p.originalDstConnID != nil {
- // Clients do not send original_destination_connection_id.
- return localTransportError(errTransportParameter)
+ }
+ s.originalDstConnID = nil // we have no further need for this
+ // Verify retry_source_connection_id matches the value from
+ // the server's Retry packet (when one was sent), or is empty.
+ if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) {
+ return localTransportError{
+ code: errTransportParameter,
+ reason: "retry_source_connection_id mismatch",
}
}
+ s.retrySrcConnID = nil // we have no further need for this
// Verify initial_source_connection_id matches the first remote connection ID.
if len(s.remote) == 0 || s.remote[0].seq != 0 {
- return localTransportError(errInternal)
+ return localTransportError{
+ code: errInternal,
+ reason: "remote connection id missing",
+ }
}
if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
- return localTransportError(errTransportParameter)
+ return localTransportError{
+ code: errTransportParameter,
+ reason: "initial_source_connection_id mismatch",
+ }
+ }
+ if len(p.statelessResetToken) > 0 {
+ if c.side == serverSide {
+ return localTransportError{
+ code: errTransportParameter,
+ reason: "client sent stateless_reset_token",
+ }
+ }
+ token := statelessResetToken(p.statelessResetToken)
+ s.remote[0].resetToken = token
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+ conns.addResetToken(c, token)
+ })
}
return nil
}
@@ -203,63 +265,79 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte)
// We're a client connection processing the first Initial packet
// from the server. Replace the transient remote connection ID
// with the Source Connection ID from the packet.
- // Leave the transient ID the list for now, since we'll need it when
- // processing the transport parameters.
- s.remote[0].retired = true
- s.remote = append(s.remote, connID{
- seq: 0,
- cid: cloneBytes(srcConnID),
- })
- }
- case ptype == packetTypeInitial && c.side == serverSide:
- if len(s.remote) == 0 {
- // We're a server connection processing the first Initial packet
- // from the client. Set the client's connection ID.
- s.remote = append(s.remote, connID{
- seq: 0,
- cid: cloneBytes(srcConnID),
- })
+ s.remote[0] = remoteConnID{
+ connID: connID{
+ seq: 0,
+ cid: cloneBytes(srcConnID),
+ },
+ }
}
case ptype == packetTypeHandshake && c.side == serverSide:
- if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired {
+ if len(s.local) > 0 && s.local[0].seq == -1 {
// We're a server connection processing the first Handshake packet from
// the client. Discard the transient, client-chosen connection ID used
// for Initial packets; the client will never send it again.
- const retired = true
- c.listener.connIDsChanged(c, retired, s.local[0:1])
+ cid := s.local[0].cid
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+ conns.retireConnID(c, cid)
+ })
s.local = append(s.local[:0], s.local[1:]...)
}
}
}
-func (s *connIDState) handleNewConnID(seq, retire int64, cid []byte, resetToken [16]byte) error {
+func (s *connIDState) handleRetryPacket(srcConnID []byte) {
+ if len(s.remote) != 1 || s.remote[0].seq != -1 {
+ panic("BUG: handling retry with non-transient remote conn id")
+ }
+ s.retrySrcConnID = cloneBytes(srcConnID)
+ s.remote[0].cid = s.retrySrcConnID
+}
+
+func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, resetToken statelessResetToken) error {
if len(s.remote[0].cid) == 0 {
// "An endpoint that is sending packets with a zero-length
// Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID
// frame as a connection error of type PROTOCOL_VIOLATION."
// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6
- return localTransportError(errProtocolViolation)
+ return localTransportError{
+ code: errProtocolViolation,
+ reason: "NEW_CONNECTION_ID from peer with zero-length DCID",
+ }
+ }
+
+ if seq < s.retireRemotePriorTo {
+ // This ID was already retired by a previous NEW_CONNECTION_ID frame.
+ // Nothing to do.
+ return nil
}
if retire > s.retireRemotePriorTo {
+ // Add newly-retired connection IDs to the set we need to send
+ // RETIRE_CONNECTION_ID frames for, and remove them from s.remote.
+ //
+ // (This might cause us to send a RETIRE_CONNECTION_ID for an ID we've
+ // never seen. That's fine.)
+ s.remoteRetiring.add(s.retireRemotePriorTo, retire)
s.retireRemotePriorTo = retire
+ s.needSend = true
+ s.remote = slices.DeleteFunc(s.remote, func(rcid remoteConnID) bool {
+ return rcid.seq < s.retireRemotePriorTo
+ })
}
have := false // do we already have this connection ID?
- active := 0
for i := range s.remote {
rcid := &s.remote[i]
- if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo {
- s.retireRemote(rcid)
- }
- if !rcid.retired {
- active++
- }
if rcid.seq == seq {
if !bytes.Equal(rcid.cid, cid) {
- return localTransportError(errProtocolViolation)
+ return localTransportError{
+ code: errProtocolViolation,
+ reason: "NEW_CONNECTION_ID does not match prior id",
+ }
}
have = true // yes, we've seen this sequence number
+ break
}
}
@@ -269,53 +347,57 @@ func (s *connIDState) handleNewConnID(seq, retire int64, cid []byte, resetToken
// We could take steps to keep the list of remote connection IDs
// sorted by sequence number, but there's no particular need
// so we don't bother.
- s.remote = append(s.remote, connID{
- seq: seq,
- cid: cloneBytes(cid),
+ s.remote = append(s.remote, remoteConnID{
+ connID: connID{
+ seq: seq,
+ cid: cloneBytes(cid),
+ },
+ resetToken: resetToken,
+ })
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+ conns.addResetToken(c, resetToken)
})
- if seq < s.retireRemotePriorTo {
- // This ID was already retired by a previous NEW_CONNECTION_ID frame.
- s.retireRemote(&s.remote[len(s.remote)-1])
- } else {
- active++
- }
}
- if active > activeConnIDLimit {
+ if len(s.remote) > activeConnIDLimit {
// Retired connection IDs (including newly-retired ones) do not count
// against the limit.
// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5
- return localTransportError(errConnectionIDLimit)
+ return localTransportError{
+ code: errConnectionIDLimit,
+ reason: "active_connection_id_limit exceeded",
+ }
}
// "An endpoint SHOULD limit the number of connection IDs it has retired locally
// for which RETIRE_CONNECTION_ID frames have not yet been acknowledged."
// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6
//
- // Set a limit of four times the active_connection_id_limit for
- // the total number of remote connection IDs we keep state for locally.
- if len(s.remote) > 4*activeConnIDLimit {
- return localTransportError(errConnectionIDLimit)
+ // Set a limit of three times the active_connection_id_limit for
+ // the total number of remote connection IDs we keep retirement state for.
+ if s.remoteRetiring.size()+s.remoteRetiringSent.size() > 3*activeConnIDLimit {
+ return localTransportError{
+ code: errConnectionIDLimit,
+ reason: "too many unacknowledged retired connection ids",
+ }
}
return nil
}
-// retireRemote marks a remote connection ID as retired.
-func (s *connIDState) retireRemote(rcid *connID) {
- rcid.retired = true
- rcid.send.setUnsent()
- s.needSend = true
-}
-
func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
if seq >= s.nextLocalSeq {
- return localTransportError(errProtocolViolation)
+ return localTransportError{
+ code: errProtocolViolation,
+ reason: "RETIRE_CONNECTION_ID for unissued sequence number",
+ }
}
for i := range s.local {
if s.local[i].seq == seq {
- const retired = true
- c.listener.connIDsChanged(c, retired, s.local[i:i+1])
+ cid := s.local[i].cid
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+ conns.retireConnID(c, cid)
+ })
s.local = append(s.local[:i], s.local[i+1:]...)
break
}
@@ -338,20 +420,11 @@ func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fat
}
func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) {
- for i := 0; i < len(s.remote); i++ {
- if s.remote[i].seq != seq {
- continue
- }
- if fate == packetAcked {
- // We have retired this connection ID, and the peer has acked.
- // Discard its state completely.
- s.remote = append(s.remote[:i], s.remote[i+1:]...)
- } else {
- // RETIRE_CONNECTION_ID frame was lost, mark for retransmission.
- s.needSend = true
- s.remote[i].send.ackOrLoss(pnum, fate)
- }
- return
+ s.remoteRetiringSent.sub(seq, seq+1)
+ if fate == packetLost {
+ // RETIRE_CONNECTION_ID frame was lost, mark for retransmission.
+ s.remoteRetiring.add(seq, seq+1)
+ s.needSend = true
}
}
@@ -360,7 +433,7 @@ func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64,
//
// It returns true if no more frames need appending,
// false if not everything fit in the current packet.
-func (s *connIDState) appendFrames(w *packetWriter, pnum packetNumber, pto bool) bool {
+func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool {
if !s.needSend && !pto {
// Fast path: We don't need to send anything.
return true
@@ -373,24 +446,32 @@ func (s *connIDState) appendFrames(w *packetWriter, pnum packetNumber, pto bool)
if !s.local[i].send.shouldSendPTO(pto) {
continue
}
- if !w.appendNewConnectionIDFrame(
+ if !c.w.appendNewConnectionIDFrame(
s.local[i].seq,
retireBefore,
s.local[i].cid,
- [16]byte{}, // TODO: stateless reset token
+ c.endpoint.resetGen.tokenForConnID(s.local[i].cid),
) {
return false
}
s.local[i].send.setSent(pnum)
}
- for i := range s.remote {
- if !s.remote[i].send.shouldSendPTO(pto) {
- continue
+ if pto {
+ for _, r := range s.remoteRetiringSent {
+ for cid := r.start; cid < r.end; cid++ {
+ if !c.w.appendRetireConnectionIDFrame(cid) {
+ return false
+ }
+ }
}
- if !w.appendRetireConnectionIDFrame(s.remote[i].seq) {
+ }
+ for s.remoteRetiring.numRanges() > 0 {
+ cid := s.remoteRetiring.min()
+ if !c.w.appendRetireConnectionIDFrame(cid) {
return false
}
- s.remote[i].send.setSent(pnum)
+ s.remoteRetiring.sub(cid, cid+1)
+ s.remoteRetiringSent.add(cid, cid+1)
}
s.needSend = false
return true
diff --git a/internal/quic/conn_id_test.go b/quic/conn_id_test.go
similarity index 79%
rename from internal/quic/conn_id_test.go
rename to quic/conn_id_test.go
index 44755ecf45..2c3f170160 100644
--- a/internal/quic/conn_id_test.go
+++ b/quic/conn_id_test.go
@@ -47,15 +47,14 @@ func TestConnIDClientHandshake(t *testing.T) {
if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
- wantRemote := []connID{{
- cid: testLocalConnID(-1),
- seq: -1,
- }, {
- cid: testPeerConnID(0),
- seq: 0,
+ wantRemote := []remoteConnID{{
+ connID: connID{
+ cid: testPeerConnID(0),
+ seq: 0,
+ },
}}
- if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
- t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
+ if got := tc.conn.connIDState.remote; !remoteConnIDListEqual(got, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", fmtRemoteConnIDList(got), fmtRemoteConnIDList(wantRemote))
}
}
@@ -96,12 +95,14 @@ func TestConnIDServerHandshake(t *testing.T) {
if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
- wantRemote := []connID{{
- cid: testPeerConnID(0),
- seq: 0,
+ wantRemote := []remoteConnID{{
+ connID: connID{
+ cid: testPeerConnID(0),
+ seq: 0,
+ },
}}
- if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
- t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
+ if got := tc.conn.connIDState.remote; !remoteConnIDListEqual(got, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", fmtRemoteConnIDList(got), fmtRemoteConnIDList(wantRemote))
}
// The client's first Handshake packet permits the server to discard the
@@ -137,6 +138,24 @@ func connIDListEqual(a, b []connID) bool {
return true
}
+func remoteConnIDListEqual(a, b []remoteConnID) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i].seq != b[i].seq {
+ return false
+ }
+ if !bytes.Equal(a[i].cid, b[i].cid) {
+ return false
+ }
+ if a[i].resetToken != b[i].resetToken {
+ return false
+ }
+ }
+ return true
+}
+
func fmtConnIDList(s []connID) string {
var strs []string
for _, cid := range s {
@@ -145,6 +164,14 @@ func fmtConnIDList(s []connID) string {
return "{" + strings.Join(strs, " ") + "}"
}
+func fmtRemoteConnIDList(s []remoteConnID) string {
+ var strs []string
+ for _, cid := range s {
+ strs = append(strs, fmt.Sprintf("[seq:%v cid:{%x} token:{%x}]", cid.seq, cid.cid, cid.resetToken))
+ }
+ return "{" + strings.Join(strs, " ") + "}"
+}
+
func TestNewRandomConnID(t *testing.T) {
cid, err := newRandomConnID(0)
if len(cid) != connIDLen || err != nil {
@@ -177,16 +204,19 @@ func TestConnIDPeerRequestsManyIDs(t *testing.T) {
packetType1RTT, debugFrameNewConnectionID{
seq: 1,
connID: testLocalConnID(1),
+ token: testLocalStatelessResetToken(1),
})
tc.wantFrame("provide additional connection ID 2",
packetType1RTT, debugFrameNewConnectionID{
seq: 2,
connID: testLocalConnID(2),
+ token: testLocalStatelessResetToken(2),
})
tc.wantFrame("provide additional connection ID 3",
packetType1RTT, debugFrameNewConnectionID{
seq: 3,
connID: testLocalConnID(3),
+ token: testLocalStatelessResetToken(3),
})
tc.wantIdle("connection ID limit reached, no more to provide")
}
@@ -258,6 +288,7 @@ func TestConnIDPeerRetiresConnID(t *testing.T) {
seq: 2,
retirePriorTo: 1,
connID: testLocalConnID(2),
+ token: testLocalStatelessResetToken(2),
})
})
}
@@ -458,6 +489,7 @@ func TestConnIDRepeatedRetireConnectionIDFrame(t *testing.T) {
retirePriorTo: 1,
seq: 2,
connID: testLocalConnID(2),
+ token: testLocalStatelessResetToken(2),
})
tc.wantIdle("repeated RETIRE_CONNECTION_ID frames are not an error")
}
@@ -546,8 +578,11 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) {
p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0")
p.preferredAddrConnID = testPeerConnID(1)
p.preferredAddrResetToken = make([]byte, 16)
+ }, func(cids *newServerConnIDs) {
+ cids.srcConnID = []byte{}
+ }, func(tc *testConn) {
+ tc.peerConnID = []byte{}
})
- tc.peerConnID = []byte{}
tc.writeFrames(packetTypeInitial,
debugFrameCrypto{
@@ -586,3 +621,95 @@ func TestConnIDInitialSrcConnIDMismatch(t *testing.T) {
})
})
}
+
+func TestConnIDsCleanedUpAfterClose(t *testing.T) {
+ testSides(t, "", func(t *testing.T, side connSide) {
+ tc := newTestConn(t, side, func(p *transportParameters) {
+ if side == clientSide {
+ token := testPeerStatelessResetToken(0)
+ p.statelessResetToken = token[:]
+ }
+ })
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ seq: 2,
+ retirePriorTo: 1,
+ connID: testPeerConnID(2),
+ token: testPeerStatelessResetToken(0),
+ })
+ tc.wantFrame("peer asked for conn id 0 to be retired",
+ packetType1RTT, debugFrameRetireConnectionID{
+ seq: 0,
+ })
+ tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{})
+ tc.conn.Abort(nil)
+ tc.wantFrame("CONN_CLOSE sent after user closes connection",
+ packetType1RTT, debugFrameConnectionCloseTransport{})
+
+ // Wait for the conn to drain.
+ // Then wait for the conn loop to exit,
+ // and force an immediate sync of the connsMap updates
+ // (normally only done by the endpoint read loop).
+ tc.advanceToTimer()
+ <-tc.conn.donec
+ tc.endpoint.e.connsMap.applyUpdates()
+
+ if got := len(tc.endpoint.e.connsMap.byConnID); got != 0 {
+ t.Errorf("%v conn ids in endpoint map after closing, want 0", got)
+ }
+ if got := len(tc.endpoint.e.connsMap.byResetToken); got != 0 {
+ t.Errorf("%v reset tokens in endpoint map after closing, want 0", got)
+ }
+ })
+}
+
+func TestConnIDRetiredConnIDResent(t *testing.T) {
+ tc := newTestConn(t, serverSide)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ //tc.ignoreFrame(frameTypeRetireConnectionID)
+
+ // Send CID 2, retire 0-1 (negotiated during the handshake).
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ seq: 2,
+ retirePriorTo: 2,
+ connID: testPeerConnID(2),
+ token: testPeerStatelessResetToken(2),
+ })
+ tc.wantFrame("retire CID 0", packetType1RTT, debugFrameRetireConnectionID{seq: 0})
+ tc.wantFrame("retire CID 1", packetType1RTT, debugFrameRetireConnectionID{seq: 1})
+
+ // Send CID 3, retire 2.
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ seq: 3,
+ retirePriorTo: 3,
+ connID: testPeerConnID(3),
+ token: testPeerStatelessResetToken(3),
+ })
+ tc.wantFrame("retire CID 2", packetType1RTT, debugFrameRetireConnectionID{seq: 2})
+
+ // Acknowledge retirement of CIDs 0-2.
+ // The server should have state for only one CID: 3.
+ tc.writeAckForAll()
+ if got, want := len(tc.conn.connIDState.remote), 1; got != want {
+ t.Fatalf("connection has state for %v connection IDs, want %v", got, want)
+ }
+
+ // Send CID 2 again.
+ // The server should ignore this, since it's already retired the CID.
+ tc.ignoreFrames[frameTypeRetireConnectionID] = false
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ seq: 2,
+ connID: testPeerConnID(2),
+ token: testPeerStatelessResetToken(2),
+ })
+ if got, want := len(tc.conn.connIDState.remote), 1; got != want {
+ t.Fatalf("connection has state for %v connection IDs, want %v", got, want)
+ }
+ tc.wantIdle("server does not re-retire already retired CID 2")
+}
diff --git a/internal/quic/conn_loss.go b/quic/conn_loss.go
similarity index 96%
rename from internal/quic/conn_loss.go
rename to quic/conn_loss.go
index 85bda314ec..623ebdd7c6 100644
--- a/internal/quic/conn_loss.go
+++ b/quic/conn_loss.go
@@ -20,6 +20,10 @@ import "fmt"
// See RFC 9000, Section 13.3 for a complete list of information which is retransmitted on loss.
// https://www.rfc-editor.org/rfc/rfc9000#section-13.3
func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) {
+ if fate == packetLost && c.logEnabled(QLogLevelPacket) {
+ c.logPacketLost(space, sent)
+ }
+
// The list of frames in a sent packet is marshaled into a buffer in the sentPacket
// by the packetWriter. Unmarshal that buffer here. This code must be kept in sync with
// packetWriter.append*.
diff --git a/internal/quic/conn_loss_test.go b/quic/conn_loss_test.go
similarity index 92%
rename from internal/quic/conn_loss_test.go
rename to quic/conn_loss_test.go
index 9b88462518..81d537803d 100644
--- a/internal/quic/conn_loss_test.go
+++ b/quic/conn_loss_test.go
@@ -160,6 +160,7 @@ func TestLostCryptoFrame(t *testing.T) {
packetType1RTT, debugFrameNewConnectionID{
seq: 1,
connID: testLocalConnID(1),
+ token: testLocalStatelessResetToken(1),
})
tc.triggerLossOrPTO(packetTypeHandshake, pto)
tc.wantFrame("client resends Handshake CRYPTO frame",
@@ -182,7 +183,7 @@ func TestLostStreamFrameEmpty(t *testing.T) {
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ c.Flush() // open the stream
tc.wantFrame("created bidirectional stream 0",
packetType1RTT, debugFrameStream{
id: newStreamID(clientSide, bidiStream, 0),
@@ -212,6 +213,7 @@ func TestLostStreamWithData(t *testing.T) {
p.initialMaxStreamDataUni = 1 << 20
})
s.Write(data[:4])
+ s.Flush()
tc.wantFrame("send [0,4)",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -219,6 +221,7 @@ func TestLostStreamWithData(t *testing.T) {
data: data[:4],
})
s.Write(data[4:8])
+ s.Flush()
tc.wantFrame("send [4,8)",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -262,6 +265,7 @@ func TestLostStreamPartialLoss(t *testing.T) {
})
for i := range data {
s.Write(data[i : i+1])
+ s.Flush()
tc.wantFrame(fmt.Sprintf("send STREAM frame with byte %v", i),
packetType1RTT, debugFrameStream{
id: s.id,
@@ -304,9 +308,9 @@ func TestLostMaxDataFrame(t *testing.T) {
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
off: 0,
- data: make([]byte, maxWindowSize),
+ data: make([]byte, maxWindowSize-1),
})
- if n, err := s.Read(buf[:maxWindowSize-1]); err != nil || n != maxWindowSize-1 {
+ if n, err := s.Read(buf[:maxWindowSize]); err != nil || n != maxWindowSize-1 {
t.Fatalf("Read() = %v, %v; want %v, nil", n, err, maxWindowSize-1)
}
tc.wantFrame("conn window is extended after reading data",
@@ -315,7 +319,12 @@ func TestLostMaxDataFrame(t *testing.T) {
})
// MAX_DATA = 64, which is only one more byte, so we don't send the frame.
- if n, err := s.Read(buf); err != nil || n != 1 {
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ off: maxWindowSize - 1,
+ data: make([]byte, 1),
+ })
+ if n, err := s.Read(buf[:1]); err != nil || n != 1 {
t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1)
}
tc.wantIdle("read doesn't extend window enough to send another MAX_DATA")
@@ -344,9 +353,9 @@ func TestLostMaxStreamDataFrame(t *testing.T) {
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
off: 0,
- data: make([]byte, maxWindowSize),
+ data: make([]byte, maxWindowSize-1),
})
- if n, err := s.Read(buf[:maxWindowSize-1]); err != nil || n != maxWindowSize-1 {
+ if n, err := s.Read(buf[:maxWindowSize]); err != nil || n != maxWindowSize-1 {
t.Fatalf("Read() = %v, %v; want %v, nil", n, err, maxWindowSize-1)
}
tc.wantFrame("stream window is extended after reading data",
@@ -356,6 +365,11 @@ func TestLostMaxStreamDataFrame(t *testing.T) {
})
// MAX_STREAM_DATA = 64, which is only one more byte, so we don't send the frame.
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ off: maxWindowSize - 1,
+ data: make([]byte, 1),
+ })
if n, err := s.Read(buf); err != nil || n != 1 {
t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1)
}
@@ -429,7 +443,8 @@ func TestLostMaxStreamsFrameMostRecent(t *testing.T) {
if err != nil {
t.Fatalf("AcceptStream() = %v", err)
}
- s.CloseContext(ctx)
+ s.SetWriteContext(ctx)
+ s.Close()
if styp == bidiStream {
tc.wantFrame("stream is closed",
packetType1RTT, debugFrameStream{
@@ -476,7 +491,7 @@ func TestLostMaxStreamsFrameNotMostRecent(t *testing.T) {
if err != nil {
t.Fatalf("AcceptStream() = %v", err)
}
- if err := s.CloseContext(ctx); err != nil {
+ if err := s.Close(); err != nil {
t.Fatalf("stream.Close() = %v", err)
}
tc.wantFrame("closing stream updates peer's MAX_STREAMS",
@@ -508,7 +523,7 @@ func TestLostStreamDataBlockedFrame(t *testing.T) {
})
w := runAsync(tc, func(ctx context.Context) (int, error) {
- return s.WriteContext(ctx, []byte{0, 1, 2, 3})
+ return s.Write([]byte{0, 1, 2, 3})
})
defer w.cancel()
tc.wantFrame("write is blocked by flow control",
@@ -560,7 +575,7 @@ func TestLostStreamDataBlockedFrameAfterStreamUnblocked(t *testing.T) {
data := []byte{0, 1, 2, 3}
w := runAsync(tc, func(ctx context.Context) (int, error) {
- return s.WriteContext(ctx, data)
+ return s.Write(data)
})
defer w.cancel()
tc.wantFrame("write is blocked by flow control",
@@ -607,6 +622,7 @@ func TestLostNewConnectionIDFrame(t *testing.T) {
packetType1RTT, debugFrameNewConnectionID{
seq: 2,
connID: testLocalConnID(2),
+ token: testLocalStatelessResetToken(2),
})
tc.triggerLossOrPTO(packetType1RTT, pto)
@@ -614,6 +630,7 @@ func TestLostNewConnectionIDFrame(t *testing.T) {
packetType1RTT, debugFrameNewConnectionID{
seq: 2,
connID: testLocalConnID(2),
+ token: testLocalStatelessResetToken(2),
})
})
}
@@ -646,6 +663,29 @@ func TestLostRetireConnectionIDFrame(t *testing.T) {
})
}
+func TestLostPathResponseFrame(t *testing.T) {
+ // "Responses to path validation using PATH_RESPONSE frames are sent just once."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.12
+ lostFrameTest(t, func(t *testing.T, pto bool) {
+ tc := newTestConn(t, clientSide)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypePing)
+
+ data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
+ tc.writeFrames(packetType1RTT, debugFramePathChallenge{
+ data: data,
+ })
+ tc.wantFrame("response to PATH_CHALLENGE",
+ packetType1RTT, debugFramePathResponse{
+ data: data,
+ })
+
+ tc.triggerLossOrPTO(packetType1RTT, pto)
+ tc.wantIdle("lost PATH_RESPONSE frame is not retransmitted")
+ })
+}
+
func TestLostHandshakeDoneFrame(t *testing.T) {
// "The HANDSHAKE_DONE frame MUST be retransmitted until it is acknowledged."
// https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.16
@@ -669,6 +709,7 @@ func TestLostHandshakeDoneFrame(t *testing.T) {
packetType1RTT, debugFrameNewConnectionID{
seq: 1,
connID: testLocalConnID(1),
+ token: testLocalStatelessResetToken(1),
})
tc.writeFrames(packetTypeHandshake,
debugFrameCrypto{
diff --git a/internal/quic/conn_recv.go b/quic/conn_recv.go
similarity index 66%
rename from internal/quic/conn_recv.go
rename to quic/conn_recv.go
index 9b1ba1ae10..dbfe34a343 100644
--- a/internal/quic/conn_recv.go
+++ b/quic/conn_recv.go
@@ -13,43 +13,80 @@ import (
"time"
)
-func (c *Conn) handleDatagram(now time.Time, dgram *datagram) {
+func (c *Conn) handleDatagram(now time.Time, dgram *datagram) (handled bool) {
+ if !c.localAddr.IsValid() {
+ // We don't have any way to tell in the general case what address we're
+ // sending packets from. Set our address from the destination address of
+ // the first packet received from the peer.
+ c.localAddr = dgram.localAddr
+ }
+ if dgram.peerAddr.IsValid() && dgram.peerAddr != c.peerAddr {
+ if c.side == clientSide {
+ // "If a client receives packets from an unknown server address,
+ // the client MUST discard these packets."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-9-6
+ return false
+ }
+ // We currently don't support connection migration,
+ // so for now the server also drops packets from an unknown address.
+ return false
+ }
buf := dgram.b
c.loss.datagramReceived(now, len(buf))
if c.isDraining() {
- return
+ return false
}
for len(buf) > 0 {
var n int
ptype := getPacketType(buf)
switch ptype {
case packetTypeInitial:
- if c.side == serverSide && len(dgram.b) < minimumClientInitialDatagramSize {
+ if c.side == serverSide && len(dgram.b) < paddedInitialDatagramSize {
// Discard client-sent Initial packets in too-short datagrams.
// https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4
- return
+ return false
}
- n = c.handleLongHeader(now, ptype, initialSpace, c.keysInitial.r, buf)
+ n = c.handleLongHeader(now, dgram, ptype, initialSpace, c.keysInitial.r, buf)
case packetTypeHandshake:
- n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf)
+ n = c.handleLongHeader(now, dgram, ptype, handshakeSpace, c.keysHandshake.r, buf)
case packetType1RTT:
- n = c.handle1RTT(now, buf)
+ n = c.handle1RTT(now, dgram, buf)
+ case packetTypeRetry:
+ c.handleRetry(now, buf)
+ return true
case packetTypeVersionNegotiation:
c.handleVersionNegotiation(now, buf)
- return
+ return true
default:
- return
+ n = -1
}
if n <= 0 {
+ // We don't expect to get a stateless reset with a valid
+ // destination connection ID, since the sender of a stateless
+ // reset doesn't know what the connection ID is.
+ //
+ // We're required to perform this check anyway.
+ //
+ // "[...] the comparison MUST be performed when the first packet
+ // in an incoming datagram [...] cannot be decrypted."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-2
+ if len(buf) == len(dgram.b) && len(buf) > statelessResetTokenLen {
+ var token statelessResetToken
+ copy(token[:], buf[len(buf)-len(token):])
+ if c.handleStatelessReset(now, token) {
+ return true
+ }
+ }
// Invalid data at the end of a datagram is ignored.
- break
+ return false
}
- c.idleTimeout = now.Add(c.maxIdleTimeout)
+ c.idleHandlePacketReceived(now)
buf = buf[n:]
}
+ return true
}
-func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int {
+func (c *Conn) handleLongHeader(now time.Time, dgram *datagram, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int {
if !k.isSet() {
return skipLongHeaderPacket(buf)
}
@@ -62,12 +99,18 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa
if buf[0]&reservedLongBits != 0 {
// Reserved header bits must be 0.
// https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1
- c.abort(now, localTransportError(errProtocolViolation))
+ c.abort(now, localTransportError{
+ code: errProtocolViolation,
+ reason: "reserved header bits are not zero",
+ })
return -1
}
if p.version != quicVersion1 {
// The peer has changed versions on us mid-handshake?
- c.abort(now, localTransportError(errProtocolViolation))
+ c.abort(now, localTransportError{
+ code: errProtocolViolation,
+ reason: "protocol version changed during handshake",
+ })
return -1
}
@@ -78,8 +121,11 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa
if logPackets {
logInboundLongPacket(c, p)
}
+ if c.logEnabled(QLogLevelPacket) {
+ c.logLongPacketReceived(p, buf[:n])
+ }
c.connIDState.handlePacket(c, p.ptype, p.srcConnID)
- ackEliciting := c.handleFrames(now, ptype, space, p.payload)
+ ackEliciting := c.handleFrames(now, dgram, ptype, space, p.payload)
c.acks[space].receive(now, space, p.num, ackEliciting)
if p.ptype == packetTypeHandshake && c.side == serverSide {
c.loss.validateClientAddress()
@@ -92,7 +138,7 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa
return n
}
-func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
+func (c *Conn) handle1RTT(now time.Time, dgram *datagram, buf []byte) int {
if !c.keysAppData.canRead() {
// 1-RTT packets extend to the end of the datagram,
// so skip the remainder of the datagram if we can't parse this.
@@ -112,7 +158,10 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
if buf[0]&reserved1RTTBits != 0 {
// Reserved header bits must be 0.
// https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1
- c.abort(now, localTransportError(errProtocolViolation))
+ c.abort(now, localTransportError{
+ code: errProtocolViolation,
+ reason: "reserved header bits are not zero",
+ })
return -1
}
@@ -123,11 +172,50 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
if logPackets {
logInboundShortPacket(c, p)
}
- ackEliciting := c.handleFrames(now, packetType1RTT, appDataSpace, p.payload)
+ if c.logEnabled(QLogLevelPacket) {
+ c.log1RTTPacketReceived(p, buf)
+ }
+ ackEliciting := c.handleFrames(now, dgram, packetType1RTT, appDataSpace, p.payload)
c.acks[appDataSpace].receive(now, appDataSpace, p.num, ackEliciting)
return len(buf)
}
+func (c *Conn) handleRetry(now time.Time, pkt []byte) {
+ if c.side != clientSide {
+ return // clients don't send Retry packets
+ }
+ // "After the client has received and processed an Initial or Retry packet
+ // from the server, it MUST discard any subsequent Retry packets that it receives."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-1
+ if !c.keysInitial.canRead() {
+ return // discarded Initial keys, connection is already established
+ }
+ if c.acks[initialSpace].seen.numRanges() != 0 {
+ return // processed at least one packet
+ }
+ if c.retryToken != nil {
+ return // received a Retry already
+ }
+ // "Clients MUST discard Retry packets that have a Retry Integrity Tag
+ // that cannot be validated."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-2
+ p, ok := parseRetryPacket(pkt, c.connIDState.originalDstConnID)
+ if !ok {
+ return
+ }
+ // "A client MUST discard a Retry packet with a zero-length Retry Token field."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-2
+ if len(p.token) == 0 {
+ return
+ }
+ c.retryToken = cloneBytes(p.token)
+ c.connIDState.handleRetryPacket(p.srcConnID)
+ // We need to resend any data we've already sent in Initial packets.
+ // We must not reuse already sent packet numbers.
+ c.loss.discardPackets(initialSpace, c.log, c.handleAckOrLoss)
+ // TODO: Discard 0-RTT packets as well, once we support 0-RTT.
+}
+
var errVersionNegotiation = errors.New("server does not support QUIC version 1")
func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) {
@@ -164,12 +252,15 @@ func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) {
c.abortImmediately(now, errVersionNegotiation)
}
-func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) {
+func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) {
if len(payload) == 0 {
// "An endpoint MUST treat receipt of a packet containing no frames
// as a connection error of type PROTOCOL_VIOLATION."
// https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3
- c.abort(now, localTransportError(errProtocolViolation))
+ c.abort(now, localTransportError{
+ code: errProtocolViolation,
+ reason: "packet contains no frames",
+ })
return false
}
// frameOK verifies that ptype is one of the packets in mask.
@@ -179,7 +270,10 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
// that is not permitted as a connection error of type
// PROTOCOL_VIOLATION."
// https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3
- c.abort(now, localTransportError(errProtocolViolation))
+ c.abort(now, localTransportError{
+ code: errProtocolViolation,
+ reason: "frame not allowed in packet",
+ })
return false
}
return true
@@ -191,6 +285,7 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
__01 = packetType0RTT | packetType1RTT
___1 = packetType1RTT
)
+ hasCrypto := false
for len(payload) > 0 {
switch payload[0] {
case frameTypePadding, frameTypeAck, frameTypeAckECN,
@@ -228,6 +323,7 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
if !frameOK(c, ptype, IH_1) {
return
}
+ hasCrypto = true
n = c.handleCryptoFrame(now, space, payload)
case frameTypeNewToken:
if !frameOK(c, ptype, ___1) {
@@ -279,6 +375,16 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
return
}
n = c.handleRetireConnectionIDFrame(now, space, payload)
+ case frameTypePathChallenge:
+ if !frameOK(c, ptype, __01) {
+ return
+ }
+ n = c.handlePathChallengeFrame(now, dgram, space, payload)
+ case frameTypePathResponse:
+ if !frameOK(c, ptype, ___1) {
+ return
+ }
+ n = c.handlePathResponseFrame(now, space, payload)
case frameTypeConnectionCloseTransport:
// Transport CONNECTION_CLOSE is OK in all spaces.
n = c.handleConnectionCloseTransportFrame(now, payload)
@@ -294,11 +400,23 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
n = c.handleHandshakeDoneFrame(now, space, payload)
}
if n < 0 {
- c.abort(now, localTransportError(errFrameEncoding))
+ c.abort(now, localTransportError{
+ code: errFrameEncoding,
+ reason: "frame encoding error",
+ })
return false
}
payload = payload[n:]
}
+ if hasCrypto {
+ // Process TLS events after handling all frames in a packet.
+ // TLS events can cause us to drop state for a number space,
+ // so do that last, to avoid handling frames differently
+ // depending on whether they come before or after a CRYPTO frame.
+ if err := c.handleTLSEvents(now); err != nil {
+ c.abort(now, err)
+ }
+ }
return ackEliciting
}
@@ -307,7 +425,10 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte)
largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
if end > c.loss.nextNumber(space) {
// Acknowledgement of a packet we never sent.
- c.abort(now, localTransportError(errProtocolViolation))
+ c.abort(now, localTransportError{
+ code: errProtocolViolation,
+ reason: "acknowledgement for unsent packet",
+ })
return
}
c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss)
@@ -336,7 +457,7 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte)
if c.peerAckDelayExponent >= 0 {
delay = ackDelay.Duration(uint8(c.peerAckDelayExponent))
}
- c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss)
+ c.loss.receiveAckEnd(now, c.log, space, delay, c.handleAckOrLoss)
if space == appDataSpace {
c.keysAppData.handleAckFor(largest)
}
@@ -429,7 +550,7 @@ func (c *Conn) handleNewConnectionIDFrame(now time.Time, space numberSpace, payl
if n < 0 {
return -1
}
- if err := c.connIDState.handleNewConnID(seq, retire, connID, resetToken); err != nil {
+ if err := c.connIDState.handleNewConnID(c, seq, retire, connID, resetToken); err != nil {
c.abort(now, err)
}
return n
@@ -446,12 +567,30 @@ func (c *Conn) handleRetireConnectionIDFrame(now time.Time, space numberSpace, p
return n
}
+func (c *Conn) handlePathChallengeFrame(now time.Time, dgram *datagram, space numberSpace, payload []byte) int {
+ data, n := consumePathChallengeFrame(payload)
+ if n < 0 {
+ return -1
+ }
+ c.handlePathChallenge(now, dgram, data)
+ return n
+}
+
+func (c *Conn) handlePathResponseFrame(now time.Time, space numberSpace, payload []byte) int {
+ data, n := consumePathResponseFrame(payload)
+ if n < 0 {
+ return -1
+ }
+ c.handlePathResponse(now, data)
+ return n
+}
+
func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte) int {
code, _, reason, n := consumeConnectionCloseTransportFrame(payload)
if n < 0 {
return -1
}
- c.enterDraining(peerTransportError{code: code, reason: reason})
+ c.handlePeerConnectionClose(now, peerTransportError{code: code, reason: reason})
return n
}
@@ -460,7 +599,7 @@ func (c *Conn) handleConnectionCloseApplicationFrame(now time.Time, payload []by
if n < 0 {
return -1
}
- c.enterDraining(&ApplicationError{Code: code, Reason: reason})
+ c.handlePeerConnectionClose(now, &ApplicationError{Code: code, Reason: reason})
return n
}
@@ -468,11 +607,25 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa
if c.side == serverSide {
// Clients should never send HANDSHAKE_DONE.
// https://www.rfc-editor.org/rfc/rfc9000#section-19.20-4
- c.abort(now, localTransportError(errProtocolViolation))
+ c.abort(now, localTransportError{
+ code: errProtocolViolation,
+ reason: "client sent HANDSHAKE_DONE",
+ })
return -1
}
- if !c.isClosingOrDraining() {
+ if c.isAlive() {
c.confirmHandshake(now)
}
return 1
}
+
+var errStatelessReset = errors.New("received stateless reset")
+
+func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) (valid bool) {
+ if !c.connIDState.isValidStatelessResetToken(resetToken) {
+ return false
+ }
+ c.setFinalError(errStatelessReset)
+ c.enterDraining(now)
+ return true
+}
diff --git a/quic/conn_recv_test.go b/quic/conn_recv_test.go
new file mode 100644
index 0000000000..0e94731bf7
--- /dev/null
+++ b/quic/conn_recv_test.go
@@ -0,0 +1,60 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/tls"
+ "testing"
+)
+
+func TestConnReceiveAckForUnsentPacket(t *testing.T) {
+ tc := newTestConn(t, serverSide, permissiveTransportParameters)
+ tc.handshake()
+ tc.writeFrames(packetType1RTT,
+ debugFrameAck{
+ ackDelay: 0,
+ ranges: []i64range[packetNumber]{{0, 10}},
+ })
+ tc.wantFrame("ACK for unsent packet causes CONNECTION_CLOSE",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errProtocolViolation,
+ })
+}
+
+// Issue #70703: If a packet contains both a CRYPTO frame which causes us to
+// drop state for a number space, and also contains a valid ACK frame for that space,
+// we shouldn't complain about the ACK.
+func TestConnReceiveAckForDroppedSpace(t *testing.T) {
+ tc := newTestConn(t, serverSide, permissiveTransportParameters)
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeNewConnectionID)
+
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantFrame("send Initial crypto",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantFrame("send Handshake crypto",
+ packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake],
+ })
+
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ },
+ debugFrameAck{
+ ackDelay: 0,
+ ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}},
+ })
+ tc.wantFrame("handshake finishes",
+ packetType1RTT, debugFrameHandshakeDone{})
+ tc.wantIdle("connection is idle")
+}
diff --git a/internal/quic/conn_send.go b/quic/conn_send.go
similarity index 78%
rename from internal/quic/conn_send.go
rename to quic/conn_send.go
index 00b02c2a31..a87cac232e 100644
--- a/internal/quic/conn_send.go
+++ b/quic/conn_send.go
@@ -22,7 +22,10 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// Assumption: The congestion window is not underutilized.
// If congestion control, pacing, and anti-amplification all permit sending,
// but we have no packet to send, then we will declare the window underutilized.
- c.loss.cc.setUnderutilized(false)
+ underutilized := false
+ defer func() {
+ c.loss.cc.setUnderutilized(c.log, underutilized)
+ }()
// Send one datagram on each iteration of this loop,
// until we hit a limit or run out of data to send.
@@ -60,7 +63,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
pad := false
var sentInitial *sentPacket
if c.keysInitial.canWrite() {
- pnumMaxAcked := c.acks[initialSpace].largestSeen()
+ pnumMaxAcked := c.loss.spaces[initialSpace].maxAcked
pnum := c.loss.nextNumber(initialSpace)
p := longPacket{
ptype: packetTypeInitial,
@@ -68,18 +71,23 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
num: pnum,
dstConnID: dstConnID,
srcConnID: c.connIDState.srcConnID(),
+ extra: c.retryToken,
}
c.w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
c.appendFrames(now, initialSpace, pnum, limit)
if logPackets {
logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload())
}
+ if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 {
+ c.logPacketSent(packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload())
+ }
sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p)
if sentInitial != nil {
- // Client initial packets need to be sent in a datagram padded to
- // at least 1200 bytes. We can't add the padding yet, however,
- // since we may want to coalesce additional packets with this one.
- if c.side == clientSide {
+ // Client initial packets and ack-eliciting server initial packets
+ // need to be sent in a datagram padded to at least 1200 bytes.
+ // We can't add the padding yet, however, since we may want to
+ // coalesce additional packets with this one.
+ if c.side == clientSide || sentInitial.ackEliciting {
pad = true
}
}
@@ -87,7 +95,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// Handshake packet.
if c.keysHandshake.canWrite() {
- pnumMaxAcked := c.acks[handshakeSpace].largestSeen()
+ pnumMaxAcked := c.loss.spaces[handshakeSpace].maxAcked
pnum := c.loss.nextNumber(handshakeSpace)
p := longPacket{
ptype: packetTypeHandshake,
@@ -101,8 +109,11 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
if logPackets {
logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload())
}
+ if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 {
+ c.logPacketSent(packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload())
+ }
if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil {
- c.loss.packetSent(now, handshakeSpace, sent)
+ c.packetSent(now, handshakeSpace, sent)
if c.side == clientSide {
// "[...] a client MUST discard Initial keys when it first
// sends a Handshake packet [...]"
@@ -114,7 +125,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// 1-RTT packet.
if c.keysAppData.canWrite() {
- pnumMaxAcked := c.acks[appDataSpace].largestSeen()
+ pnumMaxAcked := c.loss.spaces[appDataSpace].maxAcked
pnum := c.loss.nextNumber(appDataSpace)
c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID)
c.appendFrames(now, appDataSpace, pnum, limit)
@@ -122,14 +133,17 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// 1-RTT packets have no length field and extend to the end
// of the datagram, so if we're sending a datagram that needs
// padding we need to add it inside the 1-RTT packet.
- c.w.appendPaddingTo(minimumClientInitialDatagramSize)
+ c.w.appendPaddingTo(paddedInitialDatagramSize)
pad = false
}
if logPackets {
logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload())
}
+ if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 {
+ c.logPacketSent(packetType1RTT, pnum, nil, dstConnID, c.w.packetLen(), c.w.payload())
+ }
if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil {
- c.loss.packetSent(now, appDataSpace, sent)
+ c.packetSent(now, appDataSpace, sent)
}
}
@@ -138,7 +152,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
if limit == ccOK {
// We have nothing to send, and congestion control does not
// block sending. The congestion window is underutilized.
- c.loss.cc.setUnderutilized(true)
+ underutilized = true
}
return next
}
@@ -148,7 +162,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// Pad out the datagram with zeros, coalescing the Initial
// packet with invalid packets that will be ignored by the peer.
// https://www.rfc-editor.org/rfc/rfc9000.html#section-14.1-1
- for len(buf) < minimumClientInitialDatagramSize {
+ for len(buf) < paddedInitialDatagramSize {
buf = append(buf, 0)
// Technically this padding isn't in any packet, but
// account it to the Initial packet in this datagram
@@ -161,14 +175,22 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// with a Handshake packet, then we've discarded Initial keys
// since constructing the packet and shouldn't record it as in-flight.
if c.keysInitial.canWrite() {
- c.loss.packetSent(now, initialSpace, sentInitial)
+ c.packetSent(now, initialSpace, sentInitial)
}
}
- c.listener.sendDatagram(buf, c.peerAddr)
+ c.endpoint.sendDatagram(datagram{
+ b: buf,
+ peerAddr: c.peerAddr,
+ })
}
}
+func (c *Conn) packetSent(now time.Time, space numberSpace, sent *sentPacket) {
+ c.idleHandlePacketSent(now, sent)
+ c.loss.packetSent(now, c.log, space, sent)
+}
+
func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, limit ccLimit) {
if c.lifetime.localErr != nil {
c.appendConnectionCloseFrame(now, space, c.lifetime.localErr)
@@ -208,11 +230,7 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber,
// Either we are willing to send an ACK-only packet,
// or we've added additional frames.
c.acks[space].sentAck()
- if !c.w.sent.ackEliciting && c.keysAppData.needAckEliciting() {
- // The peer has initiated a key update.
- // We haven't sent them any packets yet in the new phase.
- // Make this an ack-eliciting packet.
- // Their ack of this packet will complete the key update.
+ if !c.w.sent.ackEliciting && c.shouldMakePacketAckEliciting() {
c.w.appendPingFrame()
}
}()
@@ -249,8 +267,15 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber,
}
// NEW_CONNECTION_ID, RETIRE_CONNECTION_ID
- if !c.connIDState.appendFrames(&c.w, pnum, pto) {
+ if !c.connIDState.appendFrames(c, pnum, pto) {
+ return
+ }
+
+ // PATH_RESPONSE
+ if pad, ok := c.appendPathFrames(); !ok {
return
+ } else if pad {
+ defer c.w.appendPaddingTo(smallestMaxDatagramSize)
}
// All stream-related frames. This should come last in the packet,
@@ -259,6 +284,10 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber,
if !c.appendStreamFrames(&c.w, pnum, pto) {
return
}
+
+ if !c.appendKeepAlive(now) {
+ return
+ }
}
// If this is a PTO probe and we haven't added an ack-eliciting frame yet,
@@ -313,6 +342,30 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber,
}
}
+// shouldMakePacketAckEliciting is called when sending a packet containing nothing but an ACK frame.
+// It reports whether we should add a PING frame to the packet to make it ack-eliciting.
+func (c *Conn) shouldMakePacketAckEliciting() bool {
+ if c.keysAppData.needAckEliciting() {
+ // The peer has initiated a key update.
+ // We haven't sent them any packets yet in the new phase.
+ // Make this an ack-eliciting packet.
+ // Their ack of this packet will complete the key update.
+ return true
+ }
+ if c.loss.consecutiveNonAckElicitingPackets >= 19 {
+ // We've sent a run of non-ack-eliciting packets.
+ // Add in an ack-eliciting one every once in a while so the peer
+ // lets us know which ones have arrived.
+ //
+ // Google QUICHE injects a PING after sending 19 packets. We do the same.
+ //
+ // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.4-2
+ return true
+ }
+ // TODO: Consider making every packet sent when in PTO ack-eliciting to speed up recovery.
+ return false
+}
+
func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool {
seen, delay := c.acks[space].acksToSend(now)
if len(seen) == 0 {
@@ -323,10 +376,10 @@ func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool {
}
func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err error) {
- c.lifetime.connCloseSentTime = now
+ c.sentConnectionClose(now)
switch e := err.(type) {
case localTransportError:
- c.w.appendConnectionCloseTransportFrame(transportError(e), 0, "")
+ c.w.appendConnectionCloseTransportFrame(e.code, 0, e.reason)
case *ApplicationError:
if space != appDataSpace {
// "CONNECTION_CLOSE frames signaling application errors (type 0x1d)
@@ -340,11 +393,12 @@ func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err
// TLS alerts are sent using error codes [0x0100,0x01ff).
// https://www.rfc-editor.org/rfc/rfc9000#section-20.1-2.36.1
var alert tls.AlertError
- if errors.As(err, &alert) {
+ switch {
+ case errors.As(err, &alert):
// tls.AlertError is a uint8, so this can't exceed 0x01ff.
code := errTLSBase + transportError(alert)
c.w.appendConnectionCloseTransportFrame(code, 0, "")
- } else {
+ default:
c.w.appendConnectionCloseTransportFrame(errInternal, 0, "")
}
}
diff --git a/quic/conn_send_test.go b/quic/conn_send_test.go
new file mode 100644
index 0000000000..2205ff2f79
--- /dev/null
+++ b/quic/conn_send_test.go
@@ -0,0 +1,83 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "testing"
+ "time"
+)
+
+func TestAckElicitingAck(t *testing.T) {
+ // "A receiver that sends only non-ack-eliciting packets [...] might not receive
+ // an acknowledgment for a long period of time.
+ // [...] a receiver could send a [...] ack-eliciting frame occasionally [...]
+ // to elicit an ACK from the peer."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.4-2
+ //
+ // Send a bunch of ack-eliciting packets, verify that the conn doesn't just
+ // send ACKs in response.
+ tc := newTestConn(t, clientSide, permissiveTransportParameters)
+ tc.handshake()
+ const count = 100
+ for i := 0; i < count; i++ {
+ tc.advance(1 * time.Millisecond)
+ tc.writeFrames(packetType1RTT,
+ debugFramePing{},
+ )
+ got, _ := tc.readFrame()
+ switch got.(type) {
+ case debugFrameAck:
+ continue
+ case debugFramePing:
+ return
+ }
+ }
+ t.Errorf("after sending %v PINGs, got no ack-eliciting response", count)
+}
+
+func TestSendPacketNumberSize(t *testing.T) {
+ tc := newTestConn(t, clientSide, permissiveTransportParameters)
+ tc.handshake()
+
+ recvPing := func() *testPacket {
+ t.Helper()
+ tc.conn.ping(appDataSpace)
+ p := tc.readPacket()
+ if p == nil {
+ t.Fatalf("want packet containing PING, got none")
+ }
+ return p
+ }
+
+ // Desynchronize the packet numbers the conn is sending and the ones it is receiving,
+ // by having the conn send a number of unacked packets.
+ for i := 0; i < 16; i++ {
+ recvPing()
+ }
+
+ // Establish the maximum packet number the conn has received an ACK for.
+ maxAcked := recvPing().num
+ tc.writeAckForAll()
+
+ // Make the conn send a sequence of packets.
+ // Check that the packet number is encoded with two bytes once the difference between the
+ // current packet and the max acked one is sufficiently large.
+ for want := maxAcked + 1; want < maxAcked+0x100; want++ {
+ p := recvPing()
+ if p.num != want {
+ t.Fatalf("received packet number %v, want %v", p.num, want)
+ }
+ gotPnumLen := int(p.header&0x03) + 1
+ wantPnumLen := 1
+ if p.num-maxAcked >= 0x80 {
+ wantPnumLen = 2
+ }
+ if gotPnumLen != wantPnumLen {
+ t.Fatalf("packet number 0x%x encoded with %v bytes, want %v (max acked = %v)", p.num, gotPnumLen, wantPnumLen, maxAcked)
+ }
+ }
+}
diff --git a/internal/quic/conn_streams.go b/quic/conn_streams.go
similarity index 90%
rename from internal/quic/conn_streams.go
rename to quic/conn_streams.go
index a0793297e1..87cfd297ed 100644
--- a/internal/quic/conn_streams.go
+++ b/quic/conn_streams.go
@@ -16,8 +16,14 @@ import (
type streamsState struct {
queue queue[*Stream] // new, peer-created streams
- streamsMu sync.Mutex
- streams map[streamID]*Stream
+ // All peer-created streams.
+ //
+ // Implicitly created streams are included as an empty entry in the map.
+ // (For example, if we receive a frame for stream 4, we implicitly create stream 0 and
+ // insert an empty entry for it to the map.)
+ //
+ // The map value is maybeStream rather than *Stream as a reminder that values can be nil.
+ streams map[streamID]maybeStream
// Limits on the number of streams, indexed by streamType.
localLimit [streamTypeCount]localStreamLimits
@@ -39,8 +45,13 @@ type streamsState struct {
queueData streamRing // streams with only flow-controlled frames
}
+// maybeStream is a possibly nil *Stream. See streamsState.streams.
+type maybeStream struct {
+ s *Stream
+}
+
func (c *Conn) streamsInit() {
- c.streams.streams = make(map[streamID]*Stream)
+ c.streams.streams = make(map[streamID]maybeStream)
c.streams.queue = newQueue[*Stream]()
c.streams.localLimit[bidiStream].init()
c.streams.localLimit[uniStream].init()
@@ -49,6 +60,17 @@ func (c *Conn) streamsInit() {
c.inflowInit()
}
+func (c *Conn) streamsCleanup() {
+ c.streams.queue.close(errConnClosed)
+ c.streams.localLimit[bidiStream].connHasClosed()
+ c.streams.localLimit[uniStream].connHasClosed()
+ for _, s := range c.streams.streams {
+ if s.s != nil {
+ s.s.connHasClosed()
+ }
+ }
+}
+
// AcceptStream waits for and returns the next stream created by the peer.
func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) {
return c.streams.queue.get(ctx, c.testHooks)
@@ -71,9 +93,6 @@ func (c *Conn) NewSendOnlyStream(ctx context.Context) (*Stream, error) {
}
func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, error) {
- c.streams.streamsMu.Lock()
- defer c.streams.streamsMu.Unlock()
-
num, err := c.streams.localLimit[styp].open(ctx, c)
if err != nil {
return nil, err
@@ -89,7 +108,12 @@ func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, er
s.inUnlock()
s.outUnlock()
- c.streams.streams[s.id] = s
+ // Modify c.streams on the conn's loop.
+ if err := c.runOnLoop(ctx, func(now time.Time, c *Conn) {
+ c.streams.streams[s.id] = maybeStream{s}
+ }); err != nil {
+ return nil, err
+ }
return s, nil
}
@@ -108,9 +132,7 @@ const (
// streamForID returns the stream with the given id.
// If the stream does not exist, it returns nil.
func (c *Conn) streamForID(id streamID) *Stream {
- c.streams.streamsMu.Lock()
- defer c.streams.streamsMu.Unlock()
- return c.streams.streams[id]
+ return c.streams.streams[id].s
}
// streamForFrame returns the stream with the given id.
@@ -127,16 +149,17 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
if (id.initiator() == c.side) != (ftype == sendStream) {
// Received an invalid frame for unidirectional stream.
// For example, a RESET_STREAM frame for a send-only stream.
- c.abort(now, localTransportError(errStreamState))
+ c.abort(now, localTransportError{
+ code: errStreamState,
+ reason: "invalid frame for unidirectional stream",
+ })
return nil
}
}
- c.streams.streamsMu.Lock()
- defer c.streams.streamsMu.Unlock()
- s, isOpen := c.streams.streams[id]
- if s != nil {
- return s
+ ms, isOpen := c.streams.streams[id]
+ if ms.s != nil {
+ return ms.s
}
num := id.num()
@@ -148,7 +171,10 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
}
// Received a frame for a stream that should be originated by us,
// but which we never created.
- c.abort(now, localTransportError(errStreamState))
+ c.abort(now, localTransportError{
+ code: errStreamState,
+ reason: "received frame for unknown stream",
+ })
return nil
} else {
// if isOpen, this is a stream that was implicitly opened by a
@@ -170,10 +196,10 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
// with the same initiator and type and a lower number.
// Add a nil entry to the streams map for each implicitly created stream.
for n := newStreamID(id.initiator(), id.streamType(), prevOpened); n < id; n += 4 {
- c.streams.streams[n] = nil
+ c.streams.streams[n] = maybeStream{}
}
- s = newStream(c, id)
+ s := newStream(c, id)
s.inmaxbuf = c.config.maxStreamReadBufferSize()
s.inwin = c.config.maxStreamReadBufferSize()
if id.streamType() == bidiStream {
@@ -183,7 +209,7 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
s.inUnlock()
s.outUnlock()
- c.streams.streams[id] = s
+ c.streams.streams[id] = maybeStream{s}
c.streams.queue.put(s)
return s
}
@@ -387,7 +413,11 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool {
c.streams.sendMu.Lock()
defer c.streams.sendMu.Unlock()
const pto = true
- for _, s := range c.streams.streams {
+ for _, ms := range c.streams.streams {
+ s := ms.s
+ if s == nil {
+ continue
+ }
const pto = true
s.ingate.lock()
inOK := s.appendInFramesLocked(w, pnum, pto)
diff --git a/internal/quic/conn_streams_test.go b/quic/conn_streams_test.go
similarity index 84%
rename from internal/quic/conn_streams_test.go
rename to quic/conn_streams_test.go
index 69f982c3a6..dc81ad9913 100644
--- a/internal/quic/conn_streams_test.go
+++ b/quic/conn_streams_test.go
@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"math"
+ "sync"
"testing"
)
@@ -19,33 +20,33 @@ func TestStreamsCreate(t *testing.T) {
tc := newTestConn(t, clientSide, permissiveTransportParameters)
tc.handshake()
- c, err := tc.conn.NewStream(ctx)
+ s, err := tc.conn.NewStream(ctx)
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ s.Flush() // open the stream
tc.wantFrame("created bidirectional stream 0",
packetType1RTT, debugFrameStream{
id: 0, // client-initiated, bidi, number 0
data: []byte{},
})
- c, err = tc.conn.NewSendOnlyStream(ctx)
+ s, err = tc.conn.NewSendOnlyStream(ctx)
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ s.Flush() // open the stream
tc.wantFrame("created unidirectional stream 0",
packetType1RTT, debugFrameStream{
id: 2, // client-initiated, uni, number 0
data: []byte{},
})
- c, err = tc.conn.NewStream(ctx)
+ s, err = tc.conn.NewStream(ctx)
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ s.Flush() // open the stream
tc.wantFrame("created bidirectional stream 1",
packetType1RTT, debugFrameStream{
id: 4, // client-initiated, uni, number 4
@@ -177,11 +178,11 @@ func TestStreamsStreamSendOnly(t *testing.T) {
tc := newTestConn(t, serverSide, permissiveTransportParameters)
tc.handshake()
- c, err := tc.conn.NewSendOnlyStream(ctx)
+ s, err := tc.conn.NewSendOnlyStream(ctx)
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ s.Flush() // open the stream
tc.wantFrame("created unidirectional stream 0",
packetType1RTT, debugFrameStream{
id: 3, // server-initiated, uni, number 0
@@ -229,8 +230,8 @@ func TestStreamsWriteQueueFairness(t *testing.T) {
t.Fatal(err)
}
streams = append(streams, s)
- if n, err := s.WriteContext(ctx, data); n != len(data) || err != nil {
- t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", n, err, len(data))
+ if n, err := s.Write(data); n != len(data) || err != nil {
+ t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data))
}
// Wait for the stream to finish writing whatever frames it can before
// congestion control blocks it.
@@ -297,7 +298,7 @@ func TestStreamsShutdown(t *testing.T) {
side: localStream,
styp: uniStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
- s.CloseContext(canceledContext())
+ s.Close()
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
tc.writeAckForAll()
@@ -310,7 +311,7 @@ func TestStreamsShutdown(t *testing.T) {
tc.writeFrames(packetType1RTT, debugFrameResetStream{
id: s.id,
})
- s.CloseContext(canceledContext())
+ s.Close()
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
tc.writeAckForAll()
@@ -320,8 +321,8 @@ func TestStreamsShutdown(t *testing.T) {
side: localStream,
styp: bidiStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
- s.CloseContext(canceledContext())
- tc.wantIdle("all frames after CloseContext are ignored")
+ s.Close()
+ tc.wantIdle("all frames after Close are ignored")
tc.writeAckForAll()
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
@@ -334,13 +335,12 @@ func TestStreamsShutdown(t *testing.T) {
side: remoteStream,
styp: uniStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
- ctx := canceledContext()
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
fin: true,
})
- if n, err := s.ReadContext(ctx, make([]byte, 16)); n != 0 || err != io.EOF {
- t.Errorf("ReadContext() = %v, %v; want 0, io.EOF", n, err)
+ if n, err := s.Read(make([]byte, 16)); n != 0 || err != io.EOF {
+ t.Errorf("Read() = %v, %v; want 0, io.EOF", n, err)
}
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
@@ -450,17 +450,14 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) {
id: op.id,
})
case acceptOp:
- s, err := tc.conn.AcceptStream(ctx)
- if err != nil {
- t.Fatalf("AcceptStream() = %q; want stream %v", err, stringID(op.id))
- }
+ s := tc.acceptStream()
if s.id != op.id {
- t.Fatalf("accepted stram %v; want stream %v", err, stringID(op.id))
+ t.Fatalf("accepted stream %v; want stream %v", stringID(s.id), stringID(op.id))
}
t.Logf("accepted stream %v", stringID(op.id))
// Immediately close the stream, so the stream becomes done when the
// peer closes its end.
- s.CloseContext(ctx)
+ s.Close()
}
p := tc.readPacket()
if p != nil {
@@ -478,3 +475,85 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) {
t.Fatalf("after test, stream send queue is not empty; should be")
}
}
+
+func TestStreamsCreateConcurrency(t *testing.T) {
+ cli, srv := newLocalConnPair(t, &Config{}, &Config{})
+
+ srvdone := make(chan int)
+ go func() {
+ defer close(srvdone)
+ for streams := 0; ; streams++ {
+ s, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ srvdone <- streams
+ return
+ }
+ s.Close()
+ }
+ }()
+
+ var wg sync.WaitGroup
+ const concurrency = 10
+ const streams = 10
+ for i := 0; i < concurrency; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < streams; j++ {
+ s, err := cli.NewStream(context.Background())
+ if err != nil {
+ t.Errorf("NewStream: %v", err)
+ return
+ }
+ s.Flush()
+ _, err = io.ReadAll(s)
+ if err != nil {
+ t.Errorf("ReadAll: %v", err)
+ }
+ s.Close()
+ }
+ }()
+ }
+ wg.Wait()
+
+ cli.Abort(nil)
+ srv.Abort(nil)
+ if got, want := <-srvdone, concurrency*streams; got != want {
+ t.Errorf("accepted %v streams, want %v", got, want)
+ }
+}
+
+func TestStreamsPTOWithImplicitStream(t *testing.T) {
+ ctx := canceledContext()
+ tc := newTestConn(t, serverSide, permissiveTransportParameters)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+
+ // Peer creates stream 1, and implicitly creates stream 0.
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: newStreamID(clientSide, bidiStream, 1),
+ })
+
+ // We accept stream 1 and write data to it.
+ data := []byte("data")
+ s, err := tc.conn.AcceptStream(ctx)
+ if err != nil {
+ t.Fatalf("conn.AcceptStream() = %v, want stream", err)
+ }
+ s.Write(data)
+ s.Flush()
+ tc.wantFrame("data written to stream",
+ packetType1RTT, debugFrameStream{
+ id: newStreamID(clientSide, bidiStream, 1),
+ data: data,
+ })
+
+ // PTO expires, and the data is resent.
+ const pto = true
+ tc.triggerLossOrPTO(packetType1RTT, true)
+ tc.wantFrame("data resent after PTO expires",
+ packetType1RTT, debugFrameStream{
+ id: newStreamID(clientSide, bidiStream, 1),
+ data: data,
+ })
+}
diff --git a/internal/quic/conn_test.go b/quic/conn_test.go
similarity index 73%
rename from internal/quic/conn_test.go
rename to quic/conn_test.go
index 6a359e89a1..51402630fc 100644
--- a/internal/quic/conn_test.go
+++ b/quic/conn_test.go
@@ -13,50 +13,63 @@ import (
"errors"
"flag"
"fmt"
+ "log/slog"
"math"
"net/netip"
"reflect"
"strings"
"testing"
"time"
+
+ "golang.org/x/net/quic/qlog"
)
-var testVV = flag.Bool("vv", false, "even more verbose test output")
+var (
+ testVV = flag.Bool("vv", false, "even more verbose test output")
+ qlogdir = flag.String("qlog", "", "write qlog logs to directory")
+)
func TestConnTestConn(t *testing.T) {
tc := newTestConn(t, serverSide)
+ tc.handshake()
if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want {
t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
}
- var ranAt time.Time
- tc.conn.runOnLoop(func(now time.Time, c *Conn) {
- ranAt = now
- })
- if !ranAt.Equal(tc.now) {
- t.Errorf("func ran on loop at %v, want %v", ranAt, tc.now)
+ ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
+ tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
+ when = now
+ })
+ return
+ }).result()
+ if !ranAt.Equal(tc.endpoint.now) {
+ t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
}
tc.wait()
- nextTime := tc.now.Add(defaultMaxIdleTimeout / 2)
+ nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
tc.advanceTo(nextTime)
- tc.conn.runOnLoop(func(now time.Time, c *Conn) {
- ranAt = now
- })
+ ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
+ tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
+ when = now
+ })
+ return
+ }).result()
if !ranAt.Equal(nextTime) {
t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
}
tc.wait()
tc.advanceToTimer()
- if !tc.conn.exited {
- t.Errorf("after advancing to idle timeout, exited = false, want true")
+ if got := tc.conn.lifetime.state; got != connStateDone {
+ t.Errorf("after advancing to idle timeout, conn state = %v, want done", got)
}
}
type testDatagram struct {
packets []*testPacket
paddedSize int
+ addr netip.AddrPort
}
func (d testDatagram) String() string {
@@ -74,14 +87,17 @@ func (d testDatagram) String() string {
}
type testPacket struct {
- ptype packetType
- version uint32
- num packetNumber
- keyPhaseBit bool
- keyNumber int
- dstConnID []byte
- srcConnID []byte
- frames []debugFrame
+ ptype packetType
+ header byte
+ version uint32
+ num packetNumber
+ keyPhaseBit bool
+ keyNumber int
+ dstConnID []byte
+ srcConnID []byte
+ token []byte
+ originalDstConnID []byte // used for encoding Retry packets
+ frames []debugFrame
}
func (p testPacket) String() string {
@@ -96,6 +112,9 @@ func (p testPacket) String() string {
if p.dstConnID != nil {
fmt.Fprintf(&b, " dst={%x}", p.dstConnID)
}
+ if p.token != nil {
+ fmt.Fprintf(&b, " token={%x}", p.token)
+ }
for _, f := range p.frames {
fmt.Fprintf(&b, "\n %v", f)
}
@@ -110,8 +129,7 @@ const maxTestKeyPhases = 3
type testConn struct {
t *testing.T
conn *Conn
- listener *testListener
- now time.Time
+ endpoint *testEndpoint
timer time.Time
timerLastFired time.Time
idlec chan struct{} // only accessed on the conn's loop
@@ -150,6 +168,7 @@ type testConn struct {
sentDatagrams [][]byte
sentPackets []*testPacket
sentFrames []debugFrame
+ lastDatagram *testDatagram
lastPacket *testPacket
recvDatagram chan *datagram
@@ -183,10 +202,65 @@ type keySecret struct {
// allowing test code to access Conn state directly
// by first ensuring the loop goroutine is idle.
func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
+ t.Helper()
+ config := &Config{
+ TLSConfig: newTestTLSConfig(side),
+ StatelessResetKey: testStatelessResetKey,
+ QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
+ Level: QLogLevelFrame,
+ Dir: *qlogdir,
+ })),
+ }
+ var cids newServerConnIDs
+ if side == serverSide {
+ // The initial connection ID for the server is chosen by the client.
+ cids.srcConnID = testPeerConnID(0)
+ cids.dstConnID = testPeerConnID(-1)
+ cids.originalDstConnID = cids.dstConnID
+ }
+ var configTransportParams []func(*transportParameters)
+ var configTestConn []func(*testConn)
+ for _, o := range opts {
+ switch o := o.(type) {
+ case func(*Config):
+ o(config)
+ case func(*tls.Config):
+ o(config.TLSConfig)
+ case func(cids *newServerConnIDs):
+ o(&cids)
+ case func(p *transportParameters):
+ configTransportParams = append(configTransportParams, o)
+ case func(p *testConn):
+ configTestConn = append(configTestConn, o)
+ default:
+ t.Fatalf("unknown newTestConn option %T", o)
+ }
+ }
+
+ endpoint := newTestEndpoint(t, config)
+ endpoint.configTransportParams = configTransportParams
+ endpoint.configTestConn = configTestConn
+ conn, err := endpoint.e.newConn(
+ endpoint.now,
+ config,
+ side,
+ cids,
+ "",
+ netip.MustParseAddrPort("127.0.0.1:443"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ tc := endpoint.conns[conn]
+ tc.wait()
+ return tc
+}
+
+func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn {
t.Helper()
tc := &testConn{
t: t,
- now: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
+ endpoint: endpoint,
+ conn: conn,
peerConnID: testPeerConnID(0),
ignoreFrames: map[byte]bool{
frameTypePadding: true, // ignore PADDING by default
@@ -196,80 +270,51 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
recvDatagram: make(chan *datagram),
}
t.Cleanup(tc.cleanup)
+ for _, f := range endpoint.configTestConn {
+ f(tc)
+ }
+ conn.testHooks = (*testConnHooks)(tc)
- config := &Config{
- TLSConfig: newTestTLSConfig(side),
+ if endpoint.peerTLSConn != nil {
+ tc.peerTLSConn = endpoint.peerTLSConn
+ endpoint.peerTLSConn = nil
+ return tc
}
+
peerProvidedParams := defaultTransportParameters()
peerProvidedParams.initialSrcConnID = testPeerConnID(0)
- if side == clientSide {
+ if conn.side == clientSide {
peerProvidedParams.originalDstConnID = testLocalConnID(-1)
}
- for _, o := range opts {
- switch o := o.(type) {
- case func(*Config):
- o(config)
- case func(*tls.Config):
- o(config.TLSConfig)
- case func(p *transportParameters):
- o(&peerProvidedParams)
- default:
- t.Fatalf("unknown newTestConn option %T", o)
- }
- }
-
- var initialConnID []byte
- if side == serverSide {
- // The initial connection ID for the server is chosen by the client.
- initialConnID = testPeerConnID(-1)
+ for _, f := range endpoint.configTransportParams {
+ f(&peerProvidedParams)
}
- peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(side.peer())}
- if side == clientSide {
+ peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(conn.side.peer())}
+ if conn.side == clientSide {
tc.peerTLSConn = tls.QUICServer(peerQUICConfig)
} else {
tc.peerTLSConn = tls.QUICClient(peerQUICConfig)
}
tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
tc.peerTLSConn.Start(context.Background())
+ t.Cleanup(func() {
+ tc.peerTLSConn.Close()
+ })
- tc.listener = newTestListener(t, config, (*testConnHooks)(tc))
- conn, err := tc.listener.l.newConn(
- tc.now,
- side,
- initialConnID,
- netip.MustParseAddrPort("127.0.0.1:443"))
- if err != nil {
- tc.t.Fatal(err)
- }
- tc.conn = conn
-
- conn.keysAppData.updateAfter = maxPacketNumber // disable key updates
- tc.keysInitial.r = conn.keysInitial.w
- tc.keysInitial.w = conn.keysInitial.r
-
- tc.wait()
return tc
}
// advance causes time to pass.
func (tc *testConn) advance(d time.Duration) {
tc.t.Helper()
- tc.advanceTo(tc.now.Add(d))
+ tc.endpoint.advance(d)
}
// advanceTo sets the current time.
func (tc *testConn) advanceTo(now time.Time) {
tc.t.Helper()
- if tc.now.After(now) {
- tc.t.Fatalf("time moved backwards: %v -> %v", tc.now, now)
- }
- tc.now = now
- if tc.timer.After(tc.now) {
- return
- }
- tc.conn.sendMsg(timerEvent{})
- tc.wait()
+ tc.endpoint.advanceTo(now)
}
// advanceToTimer sets the current time to the time of the Conn's next timer event.
@@ -284,10 +329,10 @@ func (tc *testConn) timerDelay() time.Duration {
if tc.timer.IsZero() {
return math.MaxInt64 // infinite
}
- if tc.timer.Before(tc.now) {
+ if tc.timer.Before(tc.endpoint.now) {
return 0
}
- return tc.timer.Sub(tc.now)
+ return tc.timer.Sub(tc.endpoint.now)
}
const infiniteDuration = time.Duration(math.MaxInt64)
@@ -297,10 +342,10 @@ func (tc *testConn) timeUntilEvent() time.Duration {
if tc.timer.IsZero() {
return infiniteDuration
}
- if tc.timer.Before(tc.now) {
+ if tc.timer.Before(tc.endpoint.now) {
return 0
}
- return tc.timer.Sub(tc.now)
+ return tc.timer.Sub(tc.endpoint.now)
}
// wait blocks until the conn becomes idle.
@@ -340,8 +385,19 @@ func (tc *testConn) cleanup() {
<-tc.conn.donec
}
-func (tc *testConn) logDatagram(text string, d *testDatagram) {
+func (tc *testConn) acceptStream() *Stream {
tc.t.Helper()
+ s, err := tc.conn.AcceptStream(canceledContext())
+ if err != nil {
+ tc.t.Fatalf("conn.AcceptStream() = %v, want stream", err)
+ }
+ s.SetReadContext(canceledContext())
+ s.SetWriteContext(canceledContext())
+ return s
+}
+
+func logDatagram(t *testing.T, text string, d *testDatagram) {
+ t.Helper()
if !*testVV {
return
}
@@ -349,7 +405,7 @@ func (tc *testConn) logDatagram(text string, d *testDatagram) {
if d.paddedSize > 0 {
pad = fmt.Sprintf(" (padded to %v)", d.paddedSize)
}
- tc.t.Logf("%v datagram%v", text, pad)
+ t.Logf("%v datagram%v", text, pad)
for _, p := range d.packets {
var s string
switch p.ptype {
@@ -358,15 +414,18 @@ func (tc *testConn) logDatagram(text string, d *testDatagram) {
default:
s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
}
+ if p.token != nil {
+ s += fmt.Sprintf(" token={%x}", p.token)
+ }
if p.keyPhaseBit {
s += fmt.Sprintf(" KeyPhase")
}
if p.keyNumber != 0 {
s += fmt.Sprintf(" keynum=%v", p.keyNumber)
}
- tc.t.Log(s)
+ t.Log(s)
for _, f := range p.frames {
- tc.t.Logf(" %v", f)
+ t.Logf(" %v", f)
}
}
}
@@ -374,30 +433,10 @@ func (tc *testConn) logDatagram(text string, d *testDatagram) {
// write sends the Conn a datagram.
func (tc *testConn) write(d *testDatagram) {
tc.t.Helper()
- var buf []byte
- tc.logDatagram("<- conn under test receives", d)
- for _, p := range d.packets {
- space := spaceForPacketType(p.ptype)
- if p.num >= tc.peerNextPacketNum[space] {
- tc.peerNextPacketNum[space] = p.num + 1
- }
- pad := 0
- if p.ptype == packetType1RTT {
- pad = d.paddedSize
- }
- buf = append(buf, tc.encodeTestPacket(p, pad)...)
- }
- for len(buf) < d.paddedSize {
- buf = append(buf, 0)
- }
- // TODO: This should use tc.listener.write.
- tc.conn.sendMsg(&datagram{
- b: buf,
- })
- tc.wait()
+ tc.endpoint.writeDatagram(d)
}
-// writeFrame sends the Conn a datagram containing the given frames.
+// writeFrames sends the Conn a datagram containing the given frames.
func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
tc.t.Helper()
space := spaceForPacketType(ptype)
@@ -417,6 +456,7 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
dstConnID: dstConnID,
srcConnID: tc.peerConnID,
}},
+ addr: tc.conn.peerAddr,
}
if ptype == packetTypeInitial && tc.conn.side == serverSide {
d.paddedSize = 1200
@@ -460,14 +500,14 @@ func (tc *testConn) readDatagram() *testDatagram {
tc.wait()
tc.sentPackets = nil
tc.sentFrames = nil
- buf := tc.listener.read()
+ buf := tc.endpoint.read()
if buf == nil {
return nil
}
- d := tc.parseTestDatagram(buf)
+ d := parseTestDatagram(tc.t, tc.endpoint, tc, buf)
// Log the datagram before removing ignored frames.
// When things go wrong, it's useful to see all the frames.
- tc.logDatagram("-> conn under test sends", d)
+ logDatagram(tc.t, "-> conn under test sends", d)
typeForFrame := func(f debugFrame) byte {
// This is very clunky, and points at a problem
// in how we specify what frames to ignore in tests.
@@ -539,6 +579,7 @@ func (tc *testConn) readDatagram() *testDatagram {
}
p.frames = frames
}
+ tc.lastDatagram = d
return d
}
@@ -551,7 +592,13 @@ func (tc *testConn) readPacket() *testPacket {
if d == nil {
return nil
}
- tc.sentPackets = d.packets
+ for _, p := range d.packets {
+ if len(p.frames) == 0 {
+ tc.lastPacket = p
+ continue
+ }
+ tc.sentPackets = append(tc.sentPackets, p)
+ }
}
p := tc.sentPackets[0]
tc.sentPackets = tc.sentPackets[1:]
@@ -579,20 +626,67 @@ func (tc *testConn) readFrame() (debugFrame, packetType) {
func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
tc.t.Helper()
got := tc.readDatagram()
- if !reflect.DeepEqual(got, want) {
+ if !datagramEqual(got, want) {
tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
}
}
+func datagramEqual(a, b *testDatagram) bool {
+ if a == nil && b == nil {
+ return true
+ }
+ if a == nil || b == nil {
+ return false
+ }
+ if a.paddedSize != b.paddedSize ||
+ a.addr != b.addr ||
+ len(a.packets) != len(b.packets) {
+ return false
+ }
+ for i := range a.packets {
+ if !packetEqual(a.packets[i], b.packets[i]) {
+ return false
+ }
+ }
+ return true
+}
+
// wantPacket indicates that we expect the Conn to send a packet.
func (tc *testConn) wantPacket(expectation string, want *testPacket) {
tc.t.Helper()
got := tc.readPacket()
- if !reflect.DeepEqual(got, want) {
+ if !packetEqual(got, want) {
tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want)
}
}
+func packetEqual(a, b *testPacket) bool {
+ if a == nil && b == nil {
+ return true
+ }
+ if a == nil || b == nil {
+ return false
+ }
+ ac := *a
+ ac.frames = nil
+ ac.header = 0
+ bc := *b
+ bc.frames = nil
+ bc.header = 0
+ if !reflect.DeepEqual(ac, bc) {
+ return false
+ }
+ if len(a.frames) != len(b.frames) {
+ return false
+ }
+ for i := range a.frames {
+ if !frameEqual(a.frames[i], b.frames[i]) {
+ return false
+ }
+ }
+ return true
+}
+
// wantFrame indicates that we expect the Conn to send a frame.
func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) {
tc.t.Helper()
@@ -603,11 +697,20 @@ func (tc *testConn) wantFrame(expectation string, wantType packetType, want debu
if gotType != wantType {
tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
}
- if !reflect.DeepEqual(got, want) {
+ if !frameEqual(got, want) {
tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want)
}
}
+func frameEqual(a, b debugFrame) bool {
+ switch af := a.(type) {
+ case debugFrameConnectionCloseTransport:
+ bf, ok := b.(debugFrameConnectionCloseTransport)
+ return ok && af.code == bf.code
+ }
+ return reflect.DeepEqual(a, b)
+}
+
// wantFrameType indicates that we expect the Conn to send a frame,
// although we don't care about the contents.
func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
@@ -638,21 +741,29 @@ func (tc *testConn) wantIdle(expectation string) {
}
}
-func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte {
- tc.t.Helper()
+func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte {
+ t.Helper()
var w packetWriter
w.reset(1200)
var pnumMaxAcked packetNumber
- if p.ptype != packetType1RTT {
+ switch p.ptype {
+ case packetTypeRetry:
+ return encodeRetryPacket(p.originalDstConnID, retryPacket{
+ srcConnID: p.srcConnID,
+ dstConnID: p.dstConnID,
+ token: p.token,
+ })
+ case packetType1RTT:
+ w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID)
+ default:
w.startProtectedLongHeaderPacket(pnumMaxAcked, longPacket{
ptype: p.ptype,
version: p.version,
num: p.num,
dstConnID: p.dstConnID,
srcConnID: p.srcConnID,
+ extra: p.token,
})
- } else {
- w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID)
}
for _, f := range p.frames {
f.write(&w)
@@ -660,14 +771,22 @@ func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte {
w.appendPaddingTo(pad)
if p.ptype != packetType1RTT {
var k fixedKeys
- switch p.ptype {
- case packetTypeInitial:
- k = tc.keysInitial.w
- case packetTypeHandshake:
- k = tc.keysHandshake.w
+ if tc == nil {
+ if p.ptype == packetTypeInitial {
+ k = initialKeys(p.dstConnID, serverSide).r
+ } else {
+ t.Fatalf("sending %v packet with no conn", p.ptype)
+ }
+ } else {
+ switch p.ptype {
+ case packetTypeInitial:
+ k = tc.keysInitial.w
+ case packetTypeHandshake:
+ k = tc.keysHandshake.w
+ }
}
if !k.isSet() {
- tc.t.Fatalf("sending %v packet with no write key", p.ptype)
+ t.Fatalf("sending %v packet with no write key", p.ptype)
}
w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{
ptype: p.ptype,
@@ -675,10 +794,11 @@ func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte {
num: p.num,
dstConnID: p.dstConnID,
srcConnID: p.srcConnID,
+ extra: p.token,
})
} else {
- if !tc.wkeyAppData.hdr.isSet() {
- tc.t.Fatalf("sending 1-RTT packet with no write key")
+ if tc == nil || !tc.wkeyAppData.hdr.isSet() {
+ t.Fatalf("sending 1-RTT packet with no write key")
}
// Somewhat hackish: Generate a temporary updatingKeyPair that will
// always use our desired key phase.
@@ -700,8 +820,8 @@ func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte {
return w.datagram()
}
-func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
- tc.t.Helper()
+func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram {
+ t.Helper()
bufSize := len(buf)
d := &testDatagram{}
size := len(buf)
@@ -711,38 +831,67 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
break
}
ptype := getPacketType(buf)
- if isLongHeader(buf[0]) {
- var k fixedKeyPair
- switch ptype {
- case packetTypeInitial:
- k = tc.keysInitial
- case packetTypeHandshake:
- k = tc.keysHandshake
+ switch ptype {
+ case packetTypeRetry:
+ retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID)
+ if !ok {
+ t.Fatalf("could not parse %v packet", ptype)
}
- if !k.canRead() {
- tc.t.Fatalf("reading %v packet with no read key", ptype)
+ return &testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeRetry,
+ dstConnID: retry.dstConnID,
+ srcConnID: retry.srcConnID,
+ token: retry.token,
+ }},
+ }
+ case packetTypeInitial, packetTypeHandshake:
+ var k fixedKeys
+ if tc == nil {
+ if ptype == packetTypeInitial {
+ p, _ := parseGenericLongHeaderPacket(buf)
+ k = initialKeys(p.srcConnID, serverSide).w
+ } else {
+ t.Fatalf("reading %v packet with no conn", ptype)
+ }
+ } else {
+ switch ptype {
+ case packetTypeInitial:
+ k = tc.keysInitial.r
+ case packetTypeHandshake:
+ k = tc.keysHandshake.r
+ }
+ }
+ if !k.isSet() {
+ t.Fatalf("reading %v packet with no read key", ptype)
}
var pnumMax packetNumber // TODO: Track packet numbers.
- p, n := parseLongHeaderPacket(buf, k.r, pnumMax)
+ p, n := parseLongHeaderPacket(buf, k, pnumMax)
if n < 0 {
- tc.t.Fatalf("packet parse error")
+ t.Fatalf("packet parse error")
}
- frames, err := tc.parseTestFrames(p.payload)
+ frames, err := parseTestFrames(t, p.payload)
if err != nil {
- tc.t.Fatal(err)
+ t.Fatal(err)
+ }
+ var token []byte
+ if ptype == packetTypeInitial && len(p.extra) > 0 {
+ token = p.extra
}
d.packets = append(d.packets, &testPacket{
ptype: p.ptype,
+ header: buf[0],
version: p.version,
num: p.num,
dstConnID: p.dstConnID,
srcConnID: p.srcConnID,
+ token: token,
frames: frames,
})
buf = buf[n:]
- } else {
- if !tc.rkeyAppData.hdr.isSet() {
- tc.t.Fatalf("reading 1-RTT packet with no read key")
+ case packetType1RTT:
+ if tc == nil || !tc.rkeyAppData.hdr.isSet() {
+ t.Fatalf("reading 1-RTT packet with no read key")
}
var pnumMax packetNumber // TODO: Track packet numbers.
pnumOff := 1 + len(tc.peerConnID)
@@ -756,7 +905,7 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
b := append([]byte{}, buf...)
hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
if err != nil {
- tc.t.Fatalf("1-RTT packet header parse error")
+ t.Fatalf("1-RTT packet header parse error")
}
k := tc.rkeyAppData.pkt[phase]
pay, err = k.unprotect(hdr, pay, pnum)
@@ -765,14 +914,15 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
}
}
if err != nil {
- tc.t.Fatalf("1-RTT packet payload parse error")
+ t.Fatalf("1-RTT packet payload parse error")
}
- frames, err := tc.parseTestFrames(pay)
+ frames, err := parseTestFrames(t, pay)
if err != nil {
- tc.t.Fatal(err)
+ t.Fatal(err)
}
d.packets = append(d.packets, &testPacket{
ptype: packetType1RTT,
+ header: hdr[0],
num: pnum,
dstConnID: hdr[1:][:len(tc.peerConnID)],
keyPhaseBit: hdr[0]&keyPhaseBit != 0,
@@ -780,6 +930,8 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
frames: frames,
})
buf = buf[len(buf):]
+ default:
+ t.Fatalf("unhandled packet type %v", ptype)
}
}
// This is rather hackish: If the last frame in the last packet
@@ -799,8 +951,8 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
return d
}
-func (tc *testConn) parseTestFrames(payload []byte) ([]debugFrame, error) {
- tc.t.Helper()
+func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) {
+ t.Helper()
var frames []debugFrame
for len(payload) > 0 {
f, n := parseDebugFrame(payload)
@@ -822,7 +974,7 @@ func spaceForPacketType(ptype packetType) numberSpace {
case packetTypeHandshake:
return handshakeSpace
case packetTypeRetry:
- panic("TODO: packetTypeRetry")
+ panic("retry packets have no number space")
case packetType1RTT:
return appDataSpace
}
@@ -832,6 +984,15 @@ func spaceForPacketType(ptype packetType) numberSpace {
// testConnHooks implements connTestHooks.
type testConnHooks testConn
+func (tc *testConnHooks) init() {
+ tc.conn.keysAppData.updateAfter = maxPacketNumber // disable key updates
+ tc.keysInitial.r = tc.conn.keysInitial.w
+ tc.keysInitial.w = tc.conn.keysInitial.r
+ if tc.conn.side == serverSide {
+ tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc))
+ }
+}
+
// handleTLSEvent processes TLS events generated by
// the connection under test's tls.QUICConn.
//
@@ -929,20 +1090,20 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
tc.timer = timer
for {
- if !timer.IsZero() && !timer.After(tc.now) {
+ if !timer.IsZero() && !timer.After(tc.endpoint.now) {
if timer.Equal(tc.timerLastFired) {
// If the connection timer fires at time T, the Conn should take some
// action to advance the timer into the future. If the Conn reschedules
// the timer for the same time, it isn't making progress and we have a bug.
- tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.now, timer)
+ tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer)
} else {
tc.timerLastFired = timer
- return tc.now, timerEvent{}
+ return tc.endpoint.now, timerEvent{}
}
}
select {
case m := <-msgc:
- return tc.now, m
+ return tc.endpoint.now, m
default:
}
if !tc.wakeAsync() {
@@ -956,7 +1117,7 @@ func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.T
close(idlec)
}
m = <-msgc
- return tc.now, m
+ return tc.endpoint.now, m
}
func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
@@ -964,7 +1125,7 @@ func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
}
func (tc *testConnHooks) timeNow() time.Time {
- return tc.now
+ return tc.endpoint.now
}
// testLocalConnID returns the connection ID with a given sequence number
@@ -984,6 +1145,13 @@ func testPeerConnID(seq int64) []byte {
return []byte{0xbe, 0xee, 0xff, byte(seq)}
}
+func testPeerStatelessResetToken(seq int64) statelessResetToken {
+ return statelessResetToken{
+ 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee,
+ 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq),
+ }
+}
+
// canceledContext returns a canceled Context.
//
// Functions which take a context preference progress over cancelation.
diff --git a/internal/quic/crypto_stream.go b/quic/crypto_stream.go
similarity index 86%
rename from internal/quic/crypto_stream.go
rename to quic/crypto_stream.go
index 8aa8f7b828..806c963943 100644
--- a/internal/quic/crypto_stream.go
+++ b/quic/crypto_stream.go
@@ -30,7 +30,10 @@ type cryptoStream struct {
func (s *cryptoStream) handleCrypto(off int64, b []byte, f func([]byte) error) error {
end := off + int64(len(b))
if end-s.inset.min() > cryptoBufferSize {
- return localTransportError(errCryptoBufferExceeded)
+ return localTransportError{
+ code: errCryptoBufferExceeded,
+ reason: "crypto buffer exceeded",
+ }
}
s.inset.add(off, end)
if off == s.in.start {
@@ -136,3 +139,21 @@ func (s *cryptoStream) sendData(off int64, b []byte) {
s.out.copy(off, b)
s.outunsent.sub(off, off+int64(len(b)))
}
+
+// discardKeys is called when the packet protection keys for the stream are dropped.
+func (s *cryptoStream) discardKeys() error {
+ if s.in.end-s.in.start != 0 {
+ // The peer sent some unprocessed CRYPTO data that we're about to discard.
+ // Close the connection with a TLS unexpected_message alert.
+ // https://www.rfc-editor.org/rfc/rfc5246#section-7.2.2
+ const unexpectedMessage = 10
+ return localTransportError{
+ code: errTLSBase + unexpectedMessage,
+ reason: "excess crypto data",
+ }
+ }
+ // Discard any unacked (but presumably received) data in our output buffer.
+ s.out.discardBefore(s.out.end)
+ *s = cryptoStream{}
+ return nil
+}
diff --git a/internal/quic/crypto_stream_test.go b/quic/crypto_stream_test.go
similarity index 96%
rename from internal/quic/crypto_stream_test.go
rename to quic/crypto_stream_test.go
index a6c1e1b521..6bee8bb9f6 100644
--- a/internal/quic/crypto_stream_test.go
+++ b/quic/crypto_stream_test.go
@@ -94,6 +94,21 @@ func TestCryptoStreamReceive(t *testing.T) {
end: 3000,
want: 4000,
}},
+ }, {
+ name: "resent consumed data",
+ frames: []frame{{
+ start: 0,
+ end: 1000,
+ want: 1000,
+ }, {
+ start: 1000,
+ end: 2000,
+ want: 2000,
+ }, {
+ start: 0,
+ end: 1000,
+ want: 2000,
+ }},
}} {
t.Run(test.name, func(t *testing.T) {
var s cryptoStream
diff --git a/internal/quic/dgram.go b/quic/dgram.go
similarity index 58%
rename from internal/quic/dgram.go
rename to quic/dgram.go
index 79e6650fa4..6155893732 100644
--- a/internal/quic/dgram.go
+++ b/quic/dgram.go
@@ -12,10 +12,25 @@ import (
)
type datagram struct {
- b []byte
- addr netip.AddrPort
+ b []byte
+ localAddr netip.AddrPort
+ peerAddr netip.AddrPort
+ ecn ecnBits
}
+// Explicit Congestion Notification bits.
+//
+// https://www.rfc-editor.org/rfc/rfc3168.html#section-5
+type ecnBits byte
+
+const (
+ ecnMask = 0b000000_11
+ ecnNotECT = 0b000000_00
+ ecnECT1 = 0b000000_01
+ ecnECT0 = 0b000000_10
+ ecnCE = 0b000000_11
+)
+
var datagramPool = sync.Pool{
New: func() any {
return &datagram{
@@ -26,7 +41,9 @@ var datagramPool = sync.Pool{
func newDatagram() *datagram {
m := datagramPool.Get().(*datagram)
- m.b = m.b[:cap(m.b)]
+ *m = datagram{
+ b: m.b[:cap(m.b)],
+ }
return m
}
diff --git a/quic/doc.go b/quic/doc.go
new file mode 100644
index 0000000000..2fd10f0878
--- /dev/null
+++ b/quic/doc.go
@@ -0,0 +1,45 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package quic implements the QUIC protocol.
+//
+// This package is a work in progress.
+// It is not ready for production usage.
+// Its API is subject to change without notice.
+//
+// This package is low-level.
+// Most users will use it indirectly through an HTTP/3 implementation.
+//
+// # Usage
+//
+// An [Endpoint] sends and receives traffic on a network address.
+// Create an Endpoint to either accept inbound QUIC connections
+// or create outbound ones.
+//
+// A [Conn] is a QUIC connection.
+//
+// A [Stream] is a QUIC stream, an ordered, reliable byte stream.
+//
+// # Cancelation
+//
+// All blocking operations may be canceled using a context.Context.
+// When performing an operation with a canceled context, the operation
+// will succeed if doing so does not require blocking. For example,
+// reading from a stream will return data when buffered data is available,
+// even if the stream context is canceled.
+//
+// # Limitations
+//
+// This package is a work in progress.
+// Known limitations include:
+//
+// - Performance is untuned.
+// - 0-RTT is not supported.
+// - Address migration is not supported.
+// - Server preferred addresses are not supported.
+// - The latency spin bit is not supported.
+// - Stream send/receive windows are configurable,
+// but are fixed and do not adapt to available throughput.
+// - Path MTU discovery is not implemented.
+package quic
diff --git a/quic/endpoint.go b/quic/endpoint.go
new file mode 100644
index 0000000000..b9ababe6b1
--- /dev/null
+++ b/quic/endpoint.go
@@ -0,0 +1,480 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "crypto/rand"
+ "errors"
+ "net"
+ "net/netip"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// An Endpoint handles QUIC traffic on a network address.
+// It can accept inbound connections or create outbound ones.
+//
+// Multiple goroutines may invoke methods on an Endpoint simultaneously.
+type Endpoint struct {
+ listenConfig *Config
+ packetConn packetConn
+ testHooks endpointTestHooks
+ resetGen statelessResetTokenGenerator
+ retry retryState
+
+ acceptQueue queue[*Conn] // new inbound connections
+ connsMap connsMap // only accessed by the listen loop
+
+ connsMu sync.Mutex
+ conns map[*Conn]struct{}
+ closing bool // set when Close is called
+ closec chan struct{} // closed when the listen loop exits
+}
+
+type endpointTestHooks interface {
+ timeNow() time.Time
+ newConn(c *Conn)
+}
+
+// A packetConn is the interface to sending and receiving UDP packets.
+type packetConn interface {
+ Close() error
+ LocalAddr() netip.AddrPort
+ Read(f func(*datagram))
+ Write(datagram) error
+}
+
+// Listen listens on a local network address.
+//
+// The config is used for connections accepted by the endpoint.
+// If the config is nil, the endpoint will not accept connections.
+func Listen(network, address string, listenConfig *Config) (*Endpoint, error) {
+ if listenConfig != nil && listenConfig.TLSConfig == nil {
+ return nil, errors.New("TLSConfig is not set")
+ }
+ a, err := net.ResolveUDPAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ udpConn, err := net.ListenUDP(network, a)
+ if err != nil {
+ return nil, err
+ }
+ pc, err := newNetUDPConn(udpConn)
+ if err != nil {
+ return nil, err
+ }
+ return newEndpoint(pc, listenConfig, nil)
+}
+
+// NewEndpoint creates an endpoint using a net.PacketConn as the underlying transport.
+//
+// If the PacketConn is not a *net.UDPConn, the endpoint may be slower and lack
+// access to some features of the network.
+func NewEndpoint(conn net.PacketConn, config *Config) (*Endpoint, error) {
+ var pc packetConn
+ var err error
+ switch conn := conn.(type) {
+ case *net.UDPConn:
+ pc, err = newNetUDPConn(conn)
+ default:
+ pc, err = newNetPacketConn(conn)
+ }
+ if err != nil {
+ return nil, err
+ }
+ return newEndpoint(pc, config, nil)
+}
+
+func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) {
+ e := &Endpoint{
+ listenConfig: config,
+ packetConn: pc,
+ testHooks: hooks,
+ conns: make(map[*Conn]struct{}),
+ acceptQueue: newQueue[*Conn](),
+ closec: make(chan struct{}),
+ }
+ var statelessResetKey [32]byte
+ if config != nil {
+ statelessResetKey = config.StatelessResetKey
+ }
+ e.resetGen.init(statelessResetKey)
+ e.connsMap.init()
+ if config != nil && config.RequireAddressValidation {
+ if err := e.retry.init(); err != nil {
+ return nil, err
+ }
+ }
+ go e.listen()
+ return e, nil
+}
+
+// LocalAddr returns the local network address.
+func (e *Endpoint) LocalAddr() netip.AddrPort {
+ return e.packetConn.LocalAddr()
+}
+
+// Close closes the Endpoint.
+// Any blocked operations on the Endpoint or associated Conns and Streams will be unblocked
+// and return errors.
+//
+// Close aborts every open connection.
+// Data in stream read and write buffers is discarded.
+// It waits for the peers of any open connection to acknowledge the connection has been closed.
+func (e *Endpoint) Close(ctx context.Context) error {
+ e.acceptQueue.close(errors.New("endpoint closed"))
+
+ // It isn't safe to call Conn.Abort or conn.exit with connsMu held,
+ // so copy the list of conns.
+ var conns []*Conn
+ e.connsMu.Lock()
+ if !e.closing {
+ e.closing = true // setting e.closing prevents new conns from being created
+ for c := range e.conns {
+ conns = append(conns, c)
+ }
+ if len(e.conns) == 0 {
+ e.packetConn.Close()
+ }
+ }
+ e.connsMu.Unlock()
+
+ for _, c := range conns {
+ c.Abort(localTransportError{code: errNo})
+ }
+ select {
+ case <-e.closec:
+ case <-ctx.Done():
+ for _, c := range conns {
+ c.exit()
+ }
+ return ctx.Err()
+ }
+ return nil
+}
+
+// Accept waits for and returns the next connection.
+func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) {
+ return e.acceptQueue.get(ctx, nil)
+}
+
+// Dial creates and returns a connection to a network address.
+// The config cannot be nil.
+func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) {
+ u, err := net.ResolveUDPAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ addr := u.AddrPort()
+ addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
+ c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr)
+ if err != nil {
+ return nil, err
+ }
+ if err := c.waitReady(ctx); err != nil {
+ c.Abort(nil)
+ return nil, err
+ }
+ return c, nil
+}
+
+func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) {
+ e.connsMu.Lock()
+ defer e.connsMu.Unlock()
+ if e.closing {
+ return nil, errors.New("endpoint closed")
+ }
+ c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e)
+ if err != nil {
+ return nil, err
+ }
+ e.conns[c] = struct{}{}
+ return c, nil
+}
+
+// serverConnEstablished is called by a conn when the handshake completes
+// for an inbound (serverSide) connection.
+func (e *Endpoint) serverConnEstablished(c *Conn) {
+ e.acceptQueue.put(c)
+}
+
+// connDrained is called by a conn when it leaves the draining state,
+// either when the peer acknowledges connection closure or the drain timeout expires.
+func (e *Endpoint) connDrained(c *Conn) {
+ var cids [][]byte
+ for i := range c.connIDState.local {
+ cids = append(cids, c.connIDState.local[i].cid)
+ }
+ var tokens []statelessResetToken
+ for i := range c.connIDState.remote {
+ tokens = append(tokens, c.connIDState.remote[i].resetToken)
+ }
+ e.connsMap.updateConnIDs(func(conns *connsMap) {
+ for _, cid := range cids {
+ conns.retireConnID(c, cid)
+ }
+ for _, token := range tokens {
+ conns.retireResetToken(c, token)
+ }
+ })
+ e.connsMu.Lock()
+ defer e.connsMu.Unlock()
+ delete(e.conns, c)
+ if e.closing && len(e.conns) == 0 {
+ e.packetConn.Close()
+ }
+}
+
+func (e *Endpoint) listen() {
+ defer close(e.closec)
+ e.packetConn.Read(func(m *datagram) {
+ if e.connsMap.updateNeeded.Load() {
+ e.connsMap.applyUpdates()
+ }
+ e.handleDatagram(m)
+ })
+}
+
+func (e *Endpoint) handleDatagram(m *datagram) {
+ dstConnID, ok := dstConnIDForDatagram(m.b)
+ if !ok {
+ m.recycle()
+ return
+ }
+ c := e.connsMap.byConnID[string(dstConnID)]
+ if c == nil {
+ // TODO: Move this branch into a separate goroutine to avoid blocking
+ // the endpoint while processing packets.
+ e.handleUnknownDestinationDatagram(m)
+ return
+ }
+
+ // TODO: This can block the endpoint while waiting for the conn to accept the dgram.
+ // Think about buffering between the receive loop and the conn.
+ c.sendMsg(m)
+}
+
+func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) {
+ defer func() {
+ if m != nil {
+ m.recycle()
+ }
+ }()
+ const minimumValidPacketSize = 21
+ if len(m.b) < minimumValidPacketSize {
+ return
+ }
+ var now time.Time
+ if e.testHooks != nil {
+ now = e.testHooks.timeNow()
+ } else {
+ now = time.Now()
+ }
+ // Check to see if this is a stateless reset.
+ var token statelessResetToken
+ copy(token[:], m.b[len(m.b)-len(token):])
+ if c := e.connsMap.byResetToken[token]; c != nil {
+ c.sendMsg(func(now time.Time, c *Conn) {
+ c.handleStatelessReset(now, token)
+ })
+ return
+ }
+ // If this is a 1-RTT packet, there's nothing productive we can do with it.
+ // Send a stateless reset if possible.
+ if !isLongHeader(m.b[0]) {
+ e.maybeSendStatelessReset(m.b, m.peerAddr)
+ return
+ }
+ p, ok := parseGenericLongHeaderPacket(m.b)
+ if !ok || len(m.b) < paddedInitialDatagramSize {
+ return
+ }
+ switch p.version {
+ case quicVersion1:
+ case 0:
+ // Version Negotiation for an unknown connection.
+ return
+ default:
+ // Unknown version.
+ e.sendVersionNegotiation(p, m.peerAddr)
+ return
+ }
+ if getPacketType(m.b) != packetTypeInitial {
+ // This packet isn't trying to create a new connection.
+ // It might be associated with some connection we've lost state for.
+ // We are technically permitted to send a stateless reset for
+ // a long-header packet, but this isn't generally useful. See:
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16
+ return
+ }
+ if e.listenConfig == nil {
+ // We are not configured to accept connections.
+ return
+ }
+ cids := newServerConnIDs{
+ srcConnID: p.srcConnID,
+ dstConnID: p.dstConnID,
+ }
+ if e.listenConfig.RequireAddressValidation {
+ var ok bool
+ cids.retrySrcConnID = p.dstConnID
+ cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr)
+ if !ok {
+ return
+ }
+ } else {
+ cids.originalDstConnID = p.dstConnID
+ }
+ var err error
+ c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr)
+ if err != nil {
+ // The accept queue is probably full.
+ // We could send a CONNECTION_CLOSE to the peer to reject the connection.
+ // Currently, we just drop the datagram.
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5
+ return
+ }
+ c.sendMsg(m)
+ m = nil // don't recycle, sendMsg takes ownership
+}
+
+func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) {
+ if !e.resetGen.canReset {
+ // Config.StatelessResetKey isn't set, so we don't send stateless resets.
+ return
+ }
+ // The smallest possible valid packet a peer can send us is:
+ // 1 byte of header
+ // connIDLen bytes of destination connection ID
+ // 1 byte of packet number
+ // 1 byte of payload
+ // 16 bytes AEAD expansion
+ if len(b) < 1+connIDLen+1+1+16 {
+ return
+ }
+ // TODO: Rate limit stateless resets.
+ cid := b[1:][:connIDLen]
+ token := e.resetGen.tokenForConnID(cid)
+ // We want to generate a stateless reset that is as short as possible,
+ // but long enough to be difficult to distinguish from a 1-RTT packet.
+ //
+ // The minimal 1-RTT packet is:
+ // 1 byte of header
+ // 0-20 bytes of destination connection ID
+ // 1-4 bytes of packet number
+ // 1 byte of payload
+ // 16 bytes AEAD expansion
+ //
+ // Assuming the maximum possible connection ID and packet number size,
+ // this gives 1 + 20 + 4 + 1 + 16 = 42 bytes.
+ //
+ // We also must generate a stateless reset that is shorter than the datagram
+ // we are responding to, in order to ensure that reset loops terminate.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc9000#section-10.3
+ size := min(len(b)-1, 42)
+ // Reuse the input buffer for generating the stateless reset.
+ b = b[:size]
+ rand.Read(b[:len(b)-statelessResetTokenLen])
+ b[0] &^= headerFormLong // clear long header bit
+ b[0] |= fixedBit // set fixed bit
+ copy(b[len(b)-statelessResetTokenLen:], token[:])
+ e.sendDatagram(datagram{
+ b: b,
+ peerAddr: peerAddr,
+ })
+}
+
+func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) {
+ m := newDatagram()
+ m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
+ m.peerAddr = peerAddr
+ e.sendDatagram(*m)
+ m.recycle()
+}
+
+func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) {
+ keys := initialKeys(in.dstConnID, serverSide)
+ var w packetWriter
+ p := longPacket{
+ ptype: packetTypeInitial,
+ version: quicVersion1,
+ num: 0,
+ dstConnID: in.srcConnID,
+ srcConnID: in.dstConnID,
+ }
+ const pnumMaxAcked = 0
+ w.reset(paddedInitialDatagramSize)
+ w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
+ w.appendConnectionCloseTransportFrame(code, 0, "")
+ w.finishProtectedLongHeaderPacket(pnumMaxAcked, keys.w, p)
+ buf := w.datagram()
+ if len(buf) == 0 {
+ return
+ }
+ e.sendDatagram(datagram{
+ b: buf,
+ peerAddr: peerAddr,
+ })
+}
+
+func (e *Endpoint) sendDatagram(dgram datagram) error {
+ return e.packetConn.Write(dgram)
+}
+
+// A connsMap is an endpoint's mapping of conn ids and reset tokens to conns.
+type connsMap struct {
+ byConnID map[string]*Conn
+ byResetToken map[statelessResetToken]*Conn
+
+ updateMu sync.Mutex
+ updateNeeded atomic.Bool
+ updates []func(*connsMap)
+}
+
+func (m *connsMap) init() {
+ m.byConnID = map[string]*Conn{}
+ m.byResetToken = map[statelessResetToken]*Conn{}
+}
+
+func (m *connsMap) addConnID(c *Conn, cid []byte) {
+ m.byConnID[string(cid)] = c
+}
+
+func (m *connsMap) retireConnID(c *Conn, cid []byte) {
+ delete(m.byConnID, string(cid))
+}
+
+func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) {
+ m.byResetToken[token] = c
+}
+
+func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) {
+ delete(m.byResetToken, token)
+}
+
+func (m *connsMap) updateConnIDs(f func(*connsMap)) {
+ m.updateMu.Lock()
+ defer m.updateMu.Unlock()
+ m.updates = append(m.updates, f)
+ m.updateNeeded.Store(true)
+}
+
+// applyUpdates is called by the datagram receive loop to update its connection ID map.
+func (m *connsMap) applyUpdates() {
+ m.updateMu.Lock()
+ defer m.updateMu.Unlock()
+ for _, f := range m.updates {
+ f(m)
+ }
+ clear(m.updates)
+ m.updates = m.updates[:0]
+ m.updateNeeded.Store(false)
+}
diff --git a/quic/endpoint_test.go b/quic/endpoint_test.go
new file mode 100644
index 0000000000..dc1c510971
--- /dev/null
+++ b/quic/endpoint_test.go
@@ -0,0 +1,341 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "io"
+ "log/slog"
+ "net/netip"
+ "runtime"
+ "testing"
+ "time"
+
+ "golang.org/x/net/quic/qlog"
+)
+
+func TestConnect(t *testing.T) {
+ newLocalConnPair(t, &Config{}, &Config{})
+}
+
+func TestConnectDefaultTLSConfig(t *testing.T) {
+ serverConfig := newTestTLSConfigWithMoreDefaults(serverSide)
+ clientConfig := newTestTLSConfigWithMoreDefaults(clientSide)
+ newLocalConnPair(t, &Config{TLSConfig: serverConfig}, &Config{TLSConfig: clientConfig})
+}
+
+func TestStreamTransfer(t *testing.T) {
+ ctx := context.Background()
+ cli, srv := newLocalConnPair(t, &Config{}, &Config{})
+ data := makeTestData(1 << 20)
+
+ srvdone := make(chan struct{})
+ go func() {
+ defer close(srvdone)
+ s, err := srv.AcceptStream(ctx)
+ if err != nil {
+ t.Errorf("AcceptStream: %v", err)
+ return
+ }
+ b, err := io.ReadAll(s)
+ if err != nil {
+ t.Errorf("io.ReadAll(s): %v", err)
+ return
+ }
+ if !bytes.Equal(b, data) {
+ t.Errorf("read data mismatch (got %v bytes, want %v", len(b), len(data))
+ }
+ if err := s.Close(); err != nil {
+ t.Errorf("s.Close() = %v", err)
+ }
+ }()
+
+ s, err := cli.NewSendOnlyStream(ctx)
+ if err != nil {
+ t.Fatalf("NewStream: %v", err)
+ }
+ n, err := io.Copy(s, bytes.NewBuffer(data))
+ if n != int64(len(data)) || err != nil {
+ t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
+ }
+ if err := s.Close(); err != nil {
+ t.Fatalf("s.Close() = %v", err)
+ }
+}
+
+func newLocalConnPair(t testing.TB, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("ReadMsgUDP not supported on %s", runtime.GOOS)
+ }
+ t.Helper()
+ ctx := context.Background()
+ e1 := newLocalEndpoint(t, serverSide, conf1)
+ e2 := newLocalEndpoint(t, clientSide, conf2)
+ conf2 = makeTestConfig(conf2, clientSide)
+ c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String(), conf2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c1, err := e1.Accept(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return c2, c1
+}
+
+func newLocalEndpoint(t testing.TB, side connSide, conf *Config) *Endpoint {
+ t.Helper()
+ conf = makeTestConfig(conf, side)
+ e, err := Listen("udp", "127.0.0.1:0", conf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ e.Close(canceledContext())
+ })
+ return e
+}
+
+func makeTestConfig(conf *Config, side connSide) *Config {
+ if conf == nil {
+ return nil
+ }
+ newConf := *conf
+ conf = &newConf
+ if conf.TLSConfig == nil {
+ conf.TLSConfig = newTestTLSConfig(side)
+ }
+ if conf.QLogLogger == nil {
+ conf.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
+ Level: QLogLevelFrame,
+ Dir: *qlogdir,
+ }))
+ }
+ return conf
+}
+
+type testEndpoint struct {
+ t *testing.T
+ e *Endpoint
+ now time.Time
+ recvc chan *datagram
+ idlec chan struct{}
+ conns map[*Conn]*testConn
+ acceptQueue []*testConn
+ configTransportParams []func(*transportParameters)
+ configTestConn []func(*testConn)
+ sentDatagrams [][]byte
+ peerTLSConn *tls.QUICConn
+ lastInitialDstConnID []byte // for parsing Retry packets
+}
+
+func newTestEndpoint(t *testing.T, config *Config) *testEndpoint {
+ te := &testEndpoint{
+ t: t,
+ now: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
+ recvc: make(chan *datagram),
+ idlec: make(chan struct{}),
+ conns: make(map[*Conn]*testConn),
+ }
+ var err error
+ te.e, err = newEndpoint((*testEndpointUDPConn)(te), config, (*testEndpointHooks)(te))
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(te.cleanup)
+ return te
+}
+
+func (te *testEndpoint) cleanup() {
+ te.e.Close(canceledContext())
+}
+
+func (te *testEndpoint) wait() {
+ select {
+ case te.idlec <- struct{}{}:
+ case <-te.e.closec:
+ }
+ for _, tc := range te.conns {
+ tc.wait()
+ }
+}
+
+// accept returns a server connection from the endpoint.
+// Unlike Endpoint.Accept, connections are available as soon as they are created.
+func (te *testEndpoint) accept() *testConn {
+ if len(te.acceptQueue) == 0 {
+ te.t.Fatalf("accept: expected available conn, but found none")
+ }
+ tc := te.acceptQueue[0]
+ te.acceptQueue = te.acceptQueue[1:]
+ return tc
+}
+
+func (te *testEndpoint) write(d *datagram) {
+ te.recvc <- d
+ te.wait()
+}
+
+var testClientAddr = netip.MustParseAddrPort("10.0.0.1:8000")
+
+func (te *testEndpoint) writeDatagram(d *testDatagram) {
+ te.t.Helper()
+ logDatagram(te.t, "<- endpoint under test receives", d)
+ var buf []byte
+ for _, p := range d.packets {
+ tc := te.connForDestination(p.dstConnID)
+ if p.ptype != packetTypeRetry && tc != nil {
+ space := spaceForPacketType(p.ptype)
+ if p.num >= tc.peerNextPacketNum[space] {
+ tc.peerNextPacketNum[space] = p.num + 1
+ }
+ }
+ if p.ptype == packetTypeInitial {
+ te.lastInitialDstConnID = p.dstConnID
+ }
+ pad := 0
+ if p.ptype == packetType1RTT {
+ pad = d.paddedSize - len(buf)
+ }
+ buf = append(buf, encodeTestPacket(te.t, tc, p, pad)...)
+ }
+ for len(buf) < d.paddedSize {
+ buf = append(buf, 0)
+ }
+ te.write(&datagram{
+ b: buf,
+ peerAddr: d.addr,
+ })
+}
+
+func (te *testEndpoint) connForDestination(dstConnID []byte) *testConn {
+ for _, tc := range te.conns {
+ for _, loc := range tc.conn.connIDState.local {
+ if bytes.Equal(loc.cid, dstConnID) {
+ return tc
+ }
+ }
+ }
+ return nil
+}
+
+func (te *testEndpoint) connForSource(srcConnID []byte) *testConn {
+ for _, tc := range te.conns {
+ for _, loc := range tc.conn.connIDState.remote {
+ if bytes.Equal(loc.cid, srcConnID) {
+ return tc
+ }
+ }
+ }
+ return nil
+}
+
+func (te *testEndpoint) read() []byte {
+ te.t.Helper()
+ te.wait()
+ if len(te.sentDatagrams) == 0 {
+ return nil
+ }
+ d := te.sentDatagrams[0]
+ te.sentDatagrams = te.sentDatagrams[1:]
+ return d
+}
+
+func (te *testEndpoint) readDatagram() *testDatagram {
+ te.t.Helper()
+ buf := te.read()
+ if buf == nil {
+ return nil
+ }
+ p, _ := parseGenericLongHeaderPacket(buf)
+ tc := te.connForSource(p.dstConnID)
+ d := parseTestDatagram(te.t, te, tc, buf)
+ logDatagram(te.t, "-> endpoint under test sends", d)
+ return d
+}
+
+// wantDatagram indicates that we expect the Endpoint to send a datagram.
+func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) {
+ te.t.Helper()
+ got := te.readDatagram()
+ if !datagramEqual(got, want) {
+ te.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
+ }
+}
+
+// wantIdle indicates that we expect the Endpoint to not send any more datagrams.
+func (te *testEndpoint) wantIdle(expectation string) {
+ if got := te.readDatagram(); got != nil {
+ te.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got)
+ }
+}
+
+// advance causes time to pass.
+func (te *testEndpoint) advance(d time.Duration) {
+ te.t.Helper()
+ te.advanceTo(te.now.Add(d))
+}
+
+// advanceTo sets the current time.
+func (te *testEndpoint) advanceTo(now time.Time) {
+ te.t.Helper()
+ if te.now.After(now) {
+ te.t.Fatalf("time moved backwards: %v -> %v", te.now, now)
+ }
+ te.now = now
+ for _, tc := range te.conns {
+ if !tc.timer.After(te.now) {
+ tc.conn.sendMsg(timerEvent{})
+ tc.wait()
+ }
+ }
+}
+
+// testEndpointHooks implements endpointTestHooks.
+type testEndpointHooks testEndpoint
+
+func (te *testEndpointHooks) timeNow() time.Time {
+ return te.now
+}
+
+func (te *testEndpointHooks) newConn(c *Conn) {
+ tc := newTestConnForConn(te.t, (*testEndpoint)(te), c)
+ te.conns[c] = tc
+}
+
+// testEndpointUDPConn implements UDPConn.
+type testEndpointUDPConn testEndpoint
+
+func (te *testEndpointUDPConn) Close() error {
+ close(te.recvc)
+ return nil
+}
+
+func (te *testEndpointUDPConn) LocalAddr() netip.AddrPort {
+ return netip.MustParseAddrPort("127.0.0.1:443")
+}
+
+func (te *testEndpointUDPConn) Read(f func(*datagram)) {
+ for {
+ select {
+ case d, ok := <-te.recvc:
+ if !ok {
+ return
+ }
+ f(d)
+ case <-te.idlec:
+ }
+ }
+}
+
+func (te *testEndpointUDPConn) Write(dgram datagram) error {
+ te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), dgram.b...))
+ return nil
+}
diff --git a/internal/quic/errors.go b/quic/errors.go
similarity index 92%
rename from internal/quic/errors.go
rename to quic/errors.go
index 8e01bb7cb7..b805b93c1b 100644
--- a/internal/quic/errors.go
+++ b/quic/errors.go
@@ -83,10 +83,16 @@ func (e transportError) String() string {
}
// A localTransportError is an error sent to the peer.
-type localTransportError transportError
+type localTransportError struct {
+ code transportError
+ reason string
+}
func (e localTransportError) Error() string {
- return "closed connection: " + transportError(e).String()
+ if e.reason == "" {
+ return fmt.Sprintf("closed connection: %v", e.code)
+ }
+ return fmt.Sprintf("closed connection: %v: %q", e.code, e.reason)
}
// A peerTransportError is an error received from the peer.
@@ -115,8 +121,7 @@ type ApplicationError struct {
}
func (e *ApplicationError) Error() string {
- // TODO: Include the Reason string here, but sanitize it first.
- return fmt.Sprintf("AppError %v", e.Code)
+ return fmt.Sprintf("peer closed connection: %v: %q", e.Code, e.Reason)
}
// Is reports a match if err is an *ApplicationError with a matching Code.
diff --git a/internal/quic/files_test.go b/quic/files_test.go
similarity index 100%
rename from internal/quic/files_test.go
rename to quic/files_test.go
diff --git a/internal/quic/frame_debug.go b/quic/frame_debug.go
similarity index 68%
rename from internal/quic/frame_debug.go
rename to quic/frame_debug.go
index 7a5aee57b1..17234dd7cd 100644
--- a/internal/quic/frame_debug.go
+++ b/quic/frame_debug.go
@@ -8,6 +8,9 @@ package quic
import (
"fmt"
+ "log/slog"
+ "strconv"
+ "time"
)
// A debugFrame is a representation of the contents of a QUIC frame,
@@ -15,6 +18,7 @@ import (
type debugFrame interface {
String() string
write(w *packetWriter) bool
+ LogValue() slog.Value
}
func parseDebugFrame(b []byte) (f debugFrame, n int) {
@@ -73,6 +77,7 @@ func parseDebugFrame(b []byte) (f debugFrame, n int) {
// debugFramePadding is a sequence of PADDING frames.
type debugFramePadding struct {
size int
+ to int // alternate for writing packets: pad to
}
func parseDebugFramePadding(b []byte) (f debugFramePadding, n int) {
@@ -91,12 +96,23 @@ func (f debugFramePadding) write(w *packetWriter) bool {
if w.avail() == 0 {
return false
}
+ if f.to > 0 {
+ w.appendPaddingTo(f.to)
+ return true
+ }
for i := 0; i < f.size && w.avail() > 0; i++ {
w.b = append(w.b, frameTypePadding)
}
return true
}
+func (f debugFramePadding) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "padding"),
+ slog.Int("length", f.size),
+ )
+}
+
// debugFramePing is a PING frame.
type debugFramePing struct{}
@@ -112,6 +128,12 @@ func (f debugFramePing) write(w *packetWriter) bool {
return w.appendPingFrame()
}
+func (f debugFramePing) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "ping"),
+ )
+}
+
// debugFrameAck is an ACK frame.
type debugFrameAck struct {
ackDelay unscaledAckDelay
@@ -126,7 +148,7 @@ func parseDebugFrameAck(b []byte) (f debugFrameAck, n int) {
end: end,
})
})
- // Ranges are parsed smallest to highest; reverse ranges slice to order them high to low.
+ // Ranges are parsed high to low; reverse ranges slice to order them low to high.
for i := 0; i < len(f.ranges)/2; i++ {
j := len(f.ranges) - 1
f.ranges[i], f.ranges[j] = f.ranges[j], f.ranges[i]
@@ -146,6 +168,61 @@ func (f debugFrameAck) write(w *packetWriter) bool {
return w.appendAckFrame(rangeset[packetNumber](f.ranges), f.ackDelay)
}
+func (f debugFrameAck) LogValue() slog.Value {
+ return slog.StringValue("error: debugFrameAck should not appear as a slog Value")
+}
+
+// debugFrameScaledAck is an ACK frame with scaled ACK Delay.
+//
+// This type is used in qlog events, which need access to the delay as a duration.
+type debugFrameScaledAck struct {
+ ackDelay time.Duration
+ ranges []i64range[packetNumber]
+}
+
+func (f debugFrameScaledAck) LogValue() slog.Value {
+ var ackDelay slog.Attr
+ if f.ackDelay >= 0 {
+ ackDelay = slog.Duration("ack_delay", f.ackDelay)
+ }
+ return slog.GroupValue(
+ slog.String("frame_type", "ack"),
+ // Rather than trying to convert the ack ranges into the slog data model,
+ // pass a value that can JSON-encode itself.
+ slog.Any("acked_ranges", debugAckRanges(f.ranges)),
+ ackDelay,
+ )
+}
+
+type debugAckRanges []i64range[packetNumber]
+
+// AppendJSON appends a JSON encoding of the ack ranges to b, and returns it.
+// This is different than the standard json.Marshaler, but more efficient.
+// Since we only use this in cooperation with the qlog package,
+// encoding/json compatibility is irrelevant.
+func (r debugAckRanges) AppendJSON(b []byte) []byte {
+ b = append(b, '[')
+ for i, ar := range r {
+ start, end := ar.start, ar.end-1 // qlog ranges are closed-closed
+ if i != 0 {
+ b = append(b, ',')
+ }
+ b = append(b, '[')
+ b = strconv.AppendInt(b, int64(start), 10)
+ if start != end {
+ b = append(b, ',')
+ b = strconv.AppendInt(b, int64(end), 10)
+ }
+ b = append(b, ']')
+ }
+ b = append(b, ']')
+ return b
+}
+
+func (r debugAckRanges) String() string {
+ return string(r.AppendJSON(nil))
+}
+
// debugFrameResetStream is a RESET_STREAM frame.
type debugFrameResetStream struct {
id streamID
@@ -166,6 +243,14 @@ func (f debugFrameResetStream) write(w *packetWriter) bool {
return w.appendResetStreamFrame(f.id, f.code, f.finalSize)
}
+func (f debugFrameResetStream) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "reset_stream"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Uint64("final_size", uint64(f.finalSize)),
+ )
+}
+
// debugFrameStopSending is a STOP_SENDING frame.
type debugFrameStopSending struct {
id streamID
@@ -185,6 +270,14 @@ func (f debugFrameStopSending) write(w *packetWriter) bool {
return w.appendStopSendingFrame(f.id, f.code)
}
+func (f debugFrameStopSending) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "stop_sending"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Uint64("error_code", uint64(f.code)),
+ )
+}
+
// debugFrameCrypto is a CRYPTO frame.
type debugFrameCrypto struct {
off int64
@@ -206,6 +299,14 @@ func (f debugFrameCrypto) write(w *packetWriter) bool {
return added
}
+func (f debugFrameCrypto) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "crypto"),
+ slog.Int64("offset", f.off),
+ slog.Int("length", len(f.data)),
+ )
+}
+
// debugFrameNewToken is a NEW_TOKEN frame.
type debugFrameNewToken struct {
token []byte
@@ -224,6 +325,13 @@ func (f debugFrameNewToken) write(w *packetWriter) bool {
return w.appendNewTokenFrame(f.token)
}
+func (f debugFrameNewToken) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "new_token"),
+ slogHexstring("token", f.token),
+ )
+}
+
// debugFrameStream is a STREAM frame.
type debugFrameStream struct {
id streamID
@@ -251,6 +359,20 @@ func (f debugFrameStream) write(w *packetWriter) bool {
return added
}
+func (f debugFrameStream) LogValue() slog.Value {
+ var fin slog.Attr
+ if f.fin {
+ fin = slog.Bool("fin", true)
+ }
+ return slog.GroupValue(
+ slog.String("frame_type", "stream"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Int64("offset", f.off),
+ slog.Int("length", len(f.data)),
+ fin,
+ )
+}
+
// debugFrameMaxData is a MAX_DATA frame.
type debugFrameMaxData struct {
max int64
@@ -269,6 +391,13 @@ func (f debugFrameMaxData) write(w *packetWriter) bool {
return w.appendMaxDataFrame(f.max)
}
+func (f debugFrameMaxData) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "max_data"),
+ slog.Int64("maximum", f.max),
+ )
+}
+
// debugFrameMaxStreamData is a MAX_STREAM_DATA frame.
type debugFrameMaxStreamData struct {
id streamID
@@ -288,6 +417,14 @@ func (f debugFrameMaxStreamData) write(w *packetWriter) bool {
return w.appendMaxStreamDataFrame(f.id, f.max)
}
+func (f debugFrameMaxStreamData) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "max_stream_data"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Int64("maximum", f.max),
+ )
+}
+
// debugFrameMaxStreams is a MAX_STREAMS frame.
type debugFrameMaxStreams struct {
streamType streamType
@@ -307,6 +444,14 @@ func (f debugFrameMaxStreams) write(w *packetWriter) bool {
return w.appendMaxStreamsFrame(f.streamType, f.max)
}
+func (f debugFrameMaxStreams) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "max_streams"),
+ slog.String("stream_type", f.streamType.qlogString()),
+ slog.Int64("maximum", f.max),
+ )
+}
+
// debugFrameDataBlocked is a DATA_BLOCKED frame.
type debugFrameDataBlocked struct {
max int64
@@ -325,6 +470,13 @@ func (f debugFrameDataBlocked) write(w *packetWriter) bool {
return w.appendDataBlockedFrame(f.max)
}
+func (f debugFrameDataBlocked) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "data_blocked"),
+ slog.Int64("limit", f.max),
+ )
+}
+
// debugFrameStreamDataBlocked is a STREAM_DATA_BLOCKED frame.
type debugFrameStreamDataBlocked struct {
id streamID
@@ -344,6 +496,14 @@ func (f debugFrameStreamDataBlocked) write(w *packetWriter) bool {
return w.appendStreamDataBlockedFrame(f.id, f.max)
}
+func (f debugFrameStreamDataBlocked) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "stream_data_blocked"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Int64("limit", f.max),
+ )
+}
+
// debugFrameStreamsBlocked is a STREAMS_BLOCKED frame.
type debugFrameStreamsBlocked struct {
streamType streamType
@@ -363,12 +523,20 @@ func (f debugFrameStreamsBlocked) write(w *packetWriter) bool {
return w.appendStreamsBlockedFrame(f.streamType, f.max)
}
+func (f debugFrameStreamsBlocked) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "streams_blocked"),
+ slog.String("stream_type", f.streamType.qlogString()),
+ slog.Int64("limit", f.max),
+ )
+}
+
// debugFrameNewConnectionID is a NEW_CONNECTION_ID frame.
type debugFrameNewConnectionID struct {
seq int64
retirePriorTo int64
connID []byte
- token [16]byte
+ token statelessResetToken
}
func parseDebugFrameNewConnectionID(b []byte) (f debugFrameNewConnectionID, n int) {
@@ -384,6 +552,16 @@ func (f debugFrameNewConnectionID) write(w *packetWriter) bool {
return w.appendNewConnectionIDFrame(f.seq, f.retirePriorTo, f.connID, f.token)
}
+func (f debugFrameNewConnectionID) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "new_connection_id"),
+ slog.Int64("sequence_number", f.seq),
+ slog.Int64("retire_prior_to", f.retirePriorTo),
+ slogHexstring("connection_id", f.connID),
+ slogHexstring("stateless_reset_token", f.token[:]),
+ )
+}
+
// debugFrameRetireConnectionID is a NEW_CONNECTION_ID frame.
type debugFrameRetireConnectionID struct {
seq int64
@@ -402,9 +580,16 @@ func (f debugFrameRetireConnectionID) write(w *packetWriter) bool {
return w.appendRetireConnectionIDFrame(f.seq)
}
+func (f debugFrameRetireConnectionID) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "retire_connection_id"),
+ slog.Int64("sequence_number", f.seq),
+ )
+}
+
// debugFramePathChallenge is a PATH_CHALLENGE frame.
type debugFramePathChallenge struct {
- data uint64
+ data pathChallengeData
}
func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) {
@@ -413,16 +598,23 @@ func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) {
}
func (f debugFramePathChallenge) String() string {
- return fmt.Sprintf("PATH_CHALLENGE Data=%016x", f.data)
+ return fmt.Sprintf("PATH_CHALLENGE Data=%x", f.data)
}
func (f debugFramePathChallenge) write(w *packetWriter) bool {
return w.appendPathChallengeFrame(f.data)
}
+func (f debugFramePathChallenge) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "path_challenge"),
+ slog.String("data", fmt.Sprintf("%x", f.data)),
+ )
+}
+
// debugFramePathResponse is a PATH_RESPONSE frame.
type debugFramePathResponse struct {
- data uint64
+ data pathChallengeData
}
func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) {
@@ -431,13 +623,20 @@ func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) {
}
func (f debugFramePathResponse) String() string {
- return fmt.Sprintf("PATH_RESPONSE Data=%016x", f.data)
+ return fmt.Sprintf("PATH_RESPONSE Data=%x", f.data)
}
func (f debugFramePathResponse) write(w *packetWriter) bool {
return w.appendPathResponseFrame(f.data)
}
+func (f debugFramePathResponse) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "path_response"),
+ slog.String("data", fmt.Sprintf("%x", f.data)),
+ )
+}
+
// debugFrameConnectionCloseTransport is a CONNECTION_CLOSE frame carrying a transport error.
type debugFrameConnectionCloseTransport struct {
code transportError
@@ -465,6 +664,15 @@ func (f debugFrameConnectionCloseTransport) write(w *packetWriter) bool {
return w.appendConnectionCloseTransportFrame(f.code, f.frameType, f.reason)
}
+func (f debugFrameConnectionCloseTransport) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "connection_close"),
+ slog.String("error_space", "transport"),
+ slog.Uint64("error_code_value", uint64(f.code)),
+ slog.String("reason", f.reason),
+ )
+}
+
// debugFrameConnectionCloseApplication is a CONNECTION_CLOSE frame carrying an application error.
type debugFrameConnectionCloseApplication struct {
code uint64
@@ -488,6 +696,15 @@ func (f debugFrameConnectionCloseApplication) write(w *packetWriter) bool {
return w.appendConnectionCloseApplicationFrame(f.code, f.reason)
}
+func (f debugFrameConnectionCloseApplication) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "connection_close"),
+ slog.String("error_space", "application"),
+ slog.Uint64("error_code_value", uint64(f.code)),
+ slog.String("reason", f.reason),
+ )
+}
+
// debugFrameHandshakeDone is a HANDSHAKE_DONE frame.
type debugFrameHandshakeDone struct{}
@@ -502,3 +719,9 @@ func (f debugFrameHandshakeDone) String() string {
func (f debugFrameHandshakeDone) write(w *packetWriter) bool {
return w.appendHandshakeDoneFrame()
}
+
+func (f debugFrameHandshakeDone) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "handshake_done"),
+ )
+}
diff --git a/internal/quic/gate.go b/quic/gate.go
similarity index 94%
rename from internal/quic/gate.go
rename to quic/gate.go
index a2fb537115..8f1db2be66 100644
--- a/internal/quic/gate.go
+++ b/quic/gate.go
@@ -27,7 +27,7 @@ func newGate() gate {
return g
}
-// newLocked gate returns a new, locked gate.
+// newLockedGate returns a new, locked gate.
func newLockedGate() gate {
return gate{
set: make(chan struct{}, 1),
@@ -84,7 +84,7 @@ func (g *gate) unlock(set bool) {
}
}
-// unlock sets the condition to the result of f and releases the gate.
+// unlockFunc sets the condition to the result of f and releases the gate.
// Useful in defers.
func (g *gate) unlockFunc(f func() bool) {
g.unlock(f())
diff --git a/internal/quic/gate_test.go b/quic/gate_test.go
similarity index 100%
rename from internal/quic/gate_test.go
rename to quic/gate_test.go
diff --git a/internal/quic/gotraceback_test.go b/quic/gotraceback_test.go
similarity index 100%
rename from internal/quic/gotraceback_test.go
rename to quic/gotraceback_test.go
diff --git a/quic/idle.go b/quic/idle.go
new file mode 100644
index 0000000000..f5b2422adb
--- /dev/null
+++ b/quic/idle.go
@@ -0,0 +1,170 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "time"
+)
+
+// idleState tracks connection idle events.
+//
+// Before the handshake is confirmed, the idle timeout is Config.HandshakeTimeout.
+//
+// After the handshake is confirmed, the idle timeout is
+// the minimum of Config.MaxIdleTimeout and the peer's max_idle_timeout transport parameter.
+//
+// If KeepAlivePeriod is set, keep-alive pings are sent.
+// Keep-alives are only sent after the handshake is confirmed.
+//
+// https://www.rfc-editor.org/rfc/rfc9000#section-10.1
+type idleState struct {
+ // idleDuration is the negotiated idle timeout for the connection.
+ idleDuration time.Duration
+
+ // idleTimeout is the time at which the connection will be closed due to inactivity.
+ idleTimeout time.Time
+
+ // nextTimeout is the time of the next idle event.
+ // If nextTimeout == idleTimeout, this is the idle timeout.
+ // Otherwise, this is the keep-alive timeout.
+ nextTimeout time.Time
+
+ // sentSinceLastReceive is set if we have sent an ack-eliciting packet
+ // since the last time we received and processed a packet from the peer.
+ sentSinceLastReceive bool
+}
+
+// receivePeerMaxIdleTimeout handles the peer's max_idle_timeout transport parameter.
+func (c *Conn) receivePeerMaxIdleTimeout(peerMaxIdleTimeout time.Duration) {
+ localMaxIdleTimeout := c.config.maxIdleTimeout()
+ switch {
+ case localMaxIdleTimeout == 0:
+ c.idle.idleDuration = peerMaxIdleTimeout
+ case peerMaxIdleTimeout == 0:
+ c.idle.idleDuration = localMaxIdleTimeout
+ default:
+ c.idle.idleDuration = min(localMaxIdleTimeout, peerMaxIdleTimeout)
+ }
+}
+
+func (c *Conn) idleHandlePacketReceived(now time.Time) {
+ if !c.handshakeConfirmed.isSet() {
+ return
+ }
+ // "An endpoint restarts its idle timer when a packet from its peer is
+ // received and processed successfully."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3
+ c.idle.sentSinceLastReceive = false
+ c.restartIdleTimer(now)
+}
+
+func (c *Conn) idleHandlePacketSent(now time.Time, sent *sentPacket) {
+ // "An endpoint also restarts its idle timer when sending an ack-eliciting packet
+ // if no other ack-eliciting packets have been sent since
+ // last receiving and processing a packet."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3
+ if c.idle.sentSinceLastReceive || !sent.ackEliciting || !c.handshakeConfirmed.isSet() {
+ return
+ }
+ c.idle.sentSinceLastReceive = true
+ c.restartIdleTimer(now)
+}
+
+func (c *Conn) restartIdleTimer(now time.Time) {
+ if !c.isAlive() {
+ // Connection is closing, disable timeouts.
+ c.idle.idleTimeout = time.Time{}
+ c.idle.nextTimeout = time.Time{}
+ return
+ }
+ var idleDuration time.Duration
+ if c.handshakeConfirmed.isSet() {
+ idleDuration = c.idle.idleDuration
+ } else {
+ idleDuration = c.config.handshakeTimeout()
+ }
+ if idleDuration == 0 {
+ c.idle.idleTimeout = time.Time{}
+ } else {
+ // "[...] endpoints MUST increase the idle timeout period to be
+ // at least three times the current Probe Timeout (PTO)."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-4
+ idleDuration = max(idleDuration, 3*c.loss.ptoPeriod())
+ c.idle.idleTimeout = now.Add(idleDuration)
+ }
+ // Set the time of our next event:
+ // The idle timer if no keep-alive is set, or the keep-alive timer if one is.
+ c.idle.nextTimeout = c.idle.idleTimeout
+ keepAlive := c.config.keepAlivePeriod()
+ switch {
+ case !c.handshakeConfirmed.isSet():
+ // We do not send keep-alives before the handshake is complete.
+ case keepAlive <= 0:
+ // Keep-alives are not enabled.
+ case c.idle.sentSinceLastReceive:
+ // We have sent an ack-eliciting packet to the peer.
+ // If they don't acknowledge it, loss detection will follow up with PTO probes,
+ // which will function as keep-alives.
+ // We don't need to send further pings.
+ case idleDuration == 0:
+ // The connection does not have a negotiated idle timeout.
+ // Send keep-alives anyway, since they may be required to keep middleboxes
+ // from losing state.
+ c.idle.nextTimeout = now.Add(keepAlive)
+ default:
+ // Schedule our next keep-alive.
+ // If our configured keep-alive period is greater than half the negotiated
+ // connection idle timeout, we reduce the keep-alive period to half
+ // the idle timeout to ensure we have time for the ping to arrive.
+ c.idle.nextTimeout = now.Add(min(keepAlive, idleDuration/2))
+ }
+}
+
+func (c *Conn) appendKeepAlive(now time.Time) bool {
+ if c.idle.nextTimeout.IsZero() || c.idle.nextTimeout.After(now) {
+ return true // timer has not expired
+ }
+ if c.idle.nextTimeout.Equal(c.idle.idleTimeout) {
+ return true // no keepalive timer set, only idle
+ }
+ if c.idle.sentSinceLastReceive {
+ return true // already sent an ack-eliciting packet
+ }
+ if c.w.sent.ackEliciting {
+ return true // this packet is already ack-eliciting
+ }
+ // Send an ack-eliciting PING frame to the peer to keep the connection alive.
+ return c.w.appendPingFrame()
+}
+
+var errHandshakeTimeout error = localTransportError{
+ code: errConnectionRefused,
+ reason: "handshake timeout",
+}
+
+func (c *Conn) idleAdvance(now time.Time) (shouldExit bool) {
+ if c.idle.idleTimeout.IsZero() || now.Before(c.idle.idleTimeout) {
+ return false
+ }
+ c.idle.idleTimeout = time.Time{}
+ c.idle.nextTimeout = time.Time{}
+ if !c.handshakeConfirmed.isSet() {
+ // Handshake timeout has expired.
+ // If we're a server, we're refusing the too-slow client.
+ // If we're a client, we're giving up.
+ // In either case, we're going to send a CONNECTION_CLOSE frame and
+ // enter the closing state rather than unceremoniously dropping the connection,
+ // since the peer might still be trying to complete the handshake.
+ c.abort(now, errHandshakeTimeout)
+ return false
+ }
+ // Idle timeout has expired.
+ //
+ // "[...] the connection is silently closed and its state is discarded [...]"
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1
+ return true
+}
diff --git a/quic/idle_test.go b/quic/idle_test.go
new file mode 100644
index 0000000000..18f6a690a4
--- /dev/null
+++ b/quic/idle_test.go
@@ -0,0 +1,225 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "testing"
+ "time"
+)
+
+func TestHandshakeTimeoutExpiresServer(t *testing.T) {
+ const timeout = 5 * time.Second
+ tc := newTestConn(t, serverSide, func(c *Config) {
+ c.HandshakeTimeout = timeout
+ })
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeNewConnectionID)
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ // Server starts its end of the handshake.
+ // Client acks these packets to avoid starting the PTO timer.
+ tc.wantFrameType("server sends Initial CRYPTO flight",
+ packetTypeInitial, debugFrameCrypto{})
+ tc.writeAckForAll()
+ tc.wantFrameType("server sends Handshake CRYPTO flight",
+ packetTypeHandshake, debugFrameCrypto{})
+ tc.writeAckForAll()
+
+ if got, want := tc.timerDelay(), timeout; got != want {
+ t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want)
+ }
+
+ // Client sends a packet, but this does not extend the handshake timer.
+ tc.advance(1 * time.Second)
+ tc.writeFrames(packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][:1], // partial data
+ })
+ tc.wantIdle("handshake is not complete")
+
+ tc.advance(timeout - 1*time.Second)
+ tc.wantFrame("server closes connection after handshake timeout",
+ packetTypeHandshake, debugFrameConnectionCloseTransport{
+ code: errConnectionRefused,
+ })
+}
+
+func TestHandshakeTimeoutExpiresClient(t *testing.T) {
+ const timeout = 5 * time.Second
+ tc := newTestConn(t, clientSide, func(c *Config) {
+ c.HandshakeTimeout = timeout
+ })
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeNewConnectionID)
+ // Start the handshake.
+ // The client always sets a PTO timer until it gets an ack for a handshake packet
+ // or confirms the handshake, so proceed far enough through the handshake to
+ // let us not worry about PTO.
+ tc.wantFrameType("client sends Initial CRYPTO flight",
+ packetTypeInitial, debugFrameCrypto{})
+ tc.writeAckForAll()
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.wantFrameType("client sends Handshake CRYPTO flight",
+ packetTypeHandshake, debugFrameCrypto{})
+ tc.writeAckForAll()
+ tc.wantIdle("client is waiting for end of handshake")
+
+ if got, want := tc.timerDelay(), timeout; got != want {
+ t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want)
+ }
+ tc.advance(timeout)
+ tc.wantFrame("client closes connection after handshake timeout",
+ packetTypeHandshake, debugFrameConnectionCloseTransport{
+ code: errConnectionRefused,
+ })
+}
+
+func TestIdleTimeoutExpires(t *testing.T) {
+ for _, test := range []struct {
+ localMaxIdleTimeout time.Duration
+ peerMaxIdleTimeout time.Duration
+ wantTimeout time.Duration
+ }{{
+ localMaxIdleTimeout: 10 * time.Second,
+ peerMaxIdleTimeout: 20 * time.Second,
+ wantTimeout: 10 * time.Second,
+ }, {
+ localMaxIdleTimeout: 20 * time.Second,
+ peerMaxIdleTimeout: 10 * time.Second,
+ wantTimeout: 10 * time.Second,
+ }, {
+ localMaxIdleTimeout: 0,
+ peerMaxIdleTimeout: 10 * time.Second,
+ wantTimeout: 10 * time.Second,
+ }, {
+ localMaxIdleTimeout: 10 * time.Second,
+ peerMaxIdleTimeout: 0,
+ wantTimeout: 10 * time.Second,
+ }} {
+ name := fmt.Sprintf("local=%v/peer=%v", test.localMaxIdleTimeout, test.peerMaxIdleTimeout)
+ t.Run(name, func(t *testing.T) {
+ tc := newTestConn(t, serverSide, func(p *transportParameters) {
+ p.maxIdleTimeout = test.peerMaxIdleTimeout
+ }, func(c *Config) {
+ c.MaxIdleTimeout = test.localMaxIdleTimeout
+ })
+ tc.handshake()
+ if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want {
+ t.Errorf("new conn timeout=%v, want %v (idle timeout)", got, want)
+ }
+ tc.advance(test.wantTimeout - 1)
+ tc.wantIdle("connection is idle and alive prior to timeout")
+ ctx := canceledContext()
+ if err := tc.conn.Wait(ctx); err != context.Canceled {
+ t.Fatalf("conn.Wait() = %v, want Canceled", err)
+ }
+ tc.advance(1)
+ tc.wantIdle("connection exits after timeout")
+ if err := tc.conn.Wait(ctx); err != errIdleTimeout {
+ t.Fatalf("conn.Wait() = %v, want errIdleTimeout", err)
+ }
+ })
+ }
+}
+
+func TestIdleTimeoutKeepAlive(t *testing.T) {
+ for _, test := range []struct {
+ idleTimeout time.Duration
+ keepAlive time.Duration
+ wantTimeout time.Duration
+ }{{
+ idleTimeout: 30 * time.Second,
+ keepAlive: 10 * time.Second,
+ wantTimeout: 10 * time.Second,
+ }, {
+ idleTimeout: 10 * time.Second,
+ keepAlive: 30 * time.Second,
+ wantTimeout: 5 * time.Second,
+ }, {
+ idleTimeout: -1, // disabled
+ keepAlive: 30 * time.Second,
+ wantTimeout: 30 * time.Second,
+ }} {
+ name := fmt.Sprintf("idle_timeout=%v/keepalive=%v", test.idleTimeout, test.keepAlive)
+ t.Run(name, func(t *testing.T) {
+ tc := newTestConn(t, serverSide, func(c *Config) {
+ c.MaxIdleTimeout = test.idleTimeout
+ c.KeepAlivePeriod = test.keepAlive
+ })
+ tc.handshake()
+ if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want {
+ t.Errorf("new conn timeout=%v, want %v (keepalive timeout)", got, want)
+ }
+ tc.advance(test.wantTimeout - 1)
+ tc.wantIdle("connection is idle prior to timeout")
+ tc.advance(1)
+ tc.wantFrameType("keep-alive ping is sent", packetType1RTT,
+ debugFramePing{})
+ })
+ }
+}
+
+func TestIdleLongTermKeepAliveSent(t *testing.T) {
+ // This test examines a connection sitting idle and sending periodic keep-alive pings.
+ const keepAlivePeriod = 30 * time.Second
+ tc := newTestConn(t, clientSide, func(c *Config) {
+ c.KeepAlivePeriod = keepAlivePeriod
+ c.MaxIdleTimeout = -1
+ })
+ tc.handshake()
+ // The handshake will have completed a little bit after the point at which the
+ // keepalive timer was set. Send two PING frames to the conn, triggering an immediate ack
+ // and resetting the timer.
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+ tc.wantFrameType("conn acks received pings", packetType1RTT, debugFrameAck{})
+ for i := 0; i < 10; i++ {
+ tc.wantIdle("conn has nothing more to send")
+ if got, want := tc.timeUntilEvent(), keepAlivePeriod; got != want {
+ t.Errorf("i=%v conn timeout=%v, want %v (keepalive timeout)", i, got, want)
+ }
+ tc.advance(keepAlivePeriod)
+ tc.wantFrameType("keep-alive ping is sent", packetType1RTT,
+ debugFramePing{})
+ tc.writeAckForAll()
+ }
+}
+
+func TestIdleLongTermKeepAliveReceived(t *testing.T) {
+ // This test examines a connection sitting idle, but receiving periodic peer
+ // traffic to keep the connection alive.
+ const idleTimeout = 30 * time.Second
+ tc := newTestConn(t, serverSide, func(c *Config) {
+ c.MaxIdleTimeout = idleTimeout
+ })
+ tc.handshake()
+ for i := 0; i < 10; i++ {
+ tc.advance(idleTimeout - 1*time.Second)
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+ if got, want := tc.timeUntilEvent(), maxAckDelay-timerGranularity; got != want {
+ t.Errorf("i=%v conn timeout=%v, want %v (max_ack_delay)", i, got, want)
+ }
+ tc.advanceToTimer()
+ tc.wantFrameType("conn acks received ping", packetType1RTT, debugFrameAck{})
+ }
+ // Connection is still alive.
+ ctx := canceledContext()
+ if err := tc.conn.Wait(ctx); err != context.Canceled {
+ t.Fatalf("conn.Wait() = %v, want Canceled", err)
+ }
+}
diff --git a/internal/quic/key_update_test.go b/quic/key_update_test.go
similarity index 100%
rename from internal/quic/key_update_test.go
rename to quic/key_update_test.go
diff --git a/internal/quic/log.go b/quic/log.go
similarity index 100%
rename from internal/quic/log.go
rename to quic/log.go
diff --git a/internal/quic/loss.go b/quic/loss.go
similarity index 85%
rename from internal/quic/loss.go
rename to quic/loss.go
index 152815a291..796b5f7a34 100644
--- a/internal/quic/loss.go
+++ b/quic/loss.go
@@ -7,6 +7,8 @@
package quic
import (
+ "context"
+ "log/slog"
"math"
"time"
)
@@ -50,6 +52,9 @@ type lossState struct {
// https://www.rfc-editor.org/rfc/rfc9000#section-8-2
antiAmplificationLimit int
+ // Count of non-ack-eliciting packets (ACKs) sent since the last ack-eliciting one.
+ consecutiveNonAckElicitingPackets int
+
rtt rttState
pacer pacerState
cc *ccReno
@@ -176,7 +181,7 @@ func (c *lossState) nextNumber(space numberSpace) packetNumber {
}
// packetSent records a sent packet.
-func (c *lossState) packetSent(now time.Time, space numberSpace, sent *sentPacket) {
+func (c *lossState) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) {
sent.time = now
c.spaces[space].add(sent)
size := sent.size
@@ -184,13 +189,21 @@ func (c *lossState) packetSent(now time.Time, space numberSpace, sent *sentPacke
c.antiAmplificationLimit = max(0, c.antiAmplificationLimit-size)
}
if sent.inFlight {
- c.cc.packetSent(now, space, sent)
+ c.cc.packetSent(now, log, space, sent)
c.pacer.packetSent(now, size, c.cc.congestionWindow, c.rtt.smoothedRTT)
if sent.ackEliciting {
c.spaces[space].lastAckEliciting = sent.num
c.ptoExpired = false // reset expired PTO timer after sending probe
}
c.scheduleTimer(now)
+ if logEnabled(log, QLogLevelPacket) {
+ logBytesInFlight(log, c.cc.bytesInFlight)
+ }
+ }
+ if sent.ackEliciting {
+ c.consecutiveNonAckElicitingPackets = 0
+ } else {
+ c.consecutiveNonAckElicitingPackets++
}
}
@@ -259,7 +272,7 @@ func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex
// receiveAckEnd finishes processing an ack frame.
// The lossf function is called for each packet newly detected as lost.
-func (c *lossState) receiveAckEnd(now time.Time, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) {
+func (c *lossState) receiveAckEnd(now time.Time, log *slog.Logger, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) {
c.spaces[space].sentPacketList.clean()
// Update the RTT sample when the largest acknowledged packet in the ACK frame
// is newly acknowledged, and at least one newly acknowledged packet is ack-eliciting.
@@ -278,11 +291,44 @@ func (c *lossState) receiveAckEnd(now time.Time, space numberSpace, ackDelay tim
// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-3
c.timer = time.Time{}
c.detectLoss(now, lossf)
- c.cc.packetBatchEnd(now, space, &c.rtt, c.maxAckDelay)
+ c.cc.packetBatchEnd(now, log, space, &c.rtt, c.maxAckDelay)
+
+ if logEnabled(log, QLogLevelPacket) {
+ var ssthresh slog.Attr
+ if c.cc.slowStartThreshold != math.MaxInt {
+ ssthresh = slog.Int("ssthresh", c.cc.slowStartThreshold)
+ }
+ log.LogAttrs(context.Background(), QLogLevelPacket,
+ "recovery:metrics_updated",
+ slog.Duration("min_rtt", c.rtt.minRTT),
+ slog.Duration("smoothed_rtt", c.rtt.smoothedRTT),
+ slog.Duration("latest_rtt", c.rtt.latestRTT),
+ slog.Duration("rtt_variance", c.rtt.rttvar),
+ slog.Int("congestion_window", c.cc.congestionWindow),
+ slog.Int("bytes_in_flight", c.cc.bytesInFlight),
+ ssthresh,
+ )
+ }
+}
+
+// discardPackets declares that packets within a number space will not be delivered
+// and that data contained in them should be resent.
+// For example, after receiving a Retry packet we discard already-sent Initial packets.
+func (c *lossState) discardPackets(space numberSpace, log *slog.Logger, lossf func(numberSpace, *sentPacket, packetFate)) {
+ for i := 0; i < c.spaces[space].size; i++ {
+ sent := c.spaces[space].nth(i)
+ sent.lost = true
+ c.cc.packetDiscarded(sent)
+ lossf(numberSpace(space), sent, packetLost)
+ }
+ c.spaces[space].clean()
+ if logEnabled(log, QLogLevelPacket) {
+ logBytesInFlight(log, c.cc.bytesInFlight)
+ }
}
// discardKeys is called when dropping packet protection keys for a number space.
-func (c *lossState) discardKeys(now time.Time, space numberSpace) {
+func (c *lossState) discardKeys(now time.Time, log *slog.Logger, space numberSpace) {
// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.4
for i := 0; i < c.spaces[space].size; i++ {
sent := c.spaces[space].nth(i)
@@ -292,6 +338,9 @@ func (c *lossState) discardKeys(now time.Time, space numberSpace) {
c.spaces[space].maxAcked = -1
c.spaces[space].lastAckEliciting = -1
c.scheduleTimer(now)
+ if logEnabled(log, QLogLevelPacket) {
+ logBytesInFlight(log, c.cc.bytesInFlight)
+ }
}
func (c *lossState) lossDuration() time.Duration {
@@ -418,12 +467,15 @@ func (c *lossState) scheduleTimer(now time.Time) {
c.timer = time.Time{}
return
}
- // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1
- pto := c.ptoBasePeriod() << c.ptoBackoffCount
- c.timer = last.Add(pto)
+ c.timer = last.Add(c.ptoPeriod())
c.ptoTimerArmed = true
}
+func (c *lossState) ptoPeriod() time.Duration {
+ // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1
+ return c.ptoBasePeriod() << c.ptoBackoffCount
+}
+
func (c *lossState) ptoBasePeriod() time.Duration {
// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1
pto := c.rtt.smoothedRTT + max(4*c.rtt.rttvar, timerGranularity)
@@ -435,3 +487,10 @@ func (c *lossState) ptoBasePeriod() time.Duration {
}
return pto
}
+
+func logBytesInFlight(log *slog.Logger, bytesInFlight int) {
+ log.LogAttrs(context.Background(), QLogLevelPacket,
+ "recovery:metrics_updated",
+ slog.Int("bytes_in_flight", bytesInFlight),
+ )
+}
diff --git a/internal/quic/loss_test.go b/quic/loss_test.go
similarity index 99%
rename from internal/quic/loss_test.go
rename to quic/loss_test.go
index efbf1649ec..1fb9662e4c 100644
--- a/internal/quic/loss_test.go
+++ b/quic/loss_test.go
@@ -1060,7 +1060,7 @@ func TestLossPersistentCongestion(t *testing.T) {
maxDatagramSize: 1200,
})
test.send(initialSpace, 0, testSentPacketSize(1200))
- test.c.cc.setUnderutilized(true)
+ test.c.cc.setUnderutilized(nil, true)
test.advance(10 * time.Millisecond)
test.ack(initialSpace, 0*time.Millisecond, i64range[packetNumber]{0, 1})
@@ -1377,7 +1377,7 @@ func (c *lossTest) setRTTVar(d time.Duration) {
func (c *lossTest) setUnderutilized(v bool) {
c.t.Logf("set congestion window underutilized: %v", v)
- c.c.cc.setUnderutilized(v)
+ c.c.cc.setUnderutilized(nil, v)
}
func (c *lossTest) advance(d time.Duration) {
@@ -1438,7 +1438,7 @@ func (c *lossTest) send(spaceID numberSpace, opts ...any) {
sent := &sentPacket{}
*sent = prototype
sent.num = num
- c.c.packetSent(c.now, spaceID, sent)
+ c.c.packetSent(c.now, nil, spaceID, sent)
}
}
@@ -1462,7 +1462,7 @@ func (c *lossTest) ack(spaceID numberSpace, ackDelay time.Duration, rs ...i64ran
c.t.Logf("ack %v delay=%v [%v,%v)", spaceID, ackDelay, r.start, r.end)
c.c.receiveAckRange(c.now, spaceID, i, r.start, r.end, c.onAckOrLoss)
}
- c.c.receiveAckEnd(c.now, spaceID, ackDelay, c.onAckOrLoss)
+ c.c.receiveAckEnd(c.now, nil, spaceID, ackDelay, c.onAckOrLoss)
}
func (c *lossTest) onAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) {
@@ -1491,7 +1491,7 @@ func (c *lossTest) discardKeys(spaceID numberSpace) {
c.t.Helper()
c.checkUnexpectedEvents()
c.t.Logf("discard %s keys", spaceID)
- c.c.discardKeys(c.now, spaceID)
+ c.c.discardKeys(c.now, nil, spaceID)
}
func (c *lossTest) setMaxAckDelay(d time.Duration) {
diff --git a/quic/main_test.go b/quic/main_test.go
new file mode 100644
index 0000000000..25e0096e43
--- /dev/null
+++ b/quic/main_test.go
@@ -0,0 +1,72 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "fmt"
+ "os"
+ "runtime"
+ "testing"
+ "time"
+)
+
+func TestMain(m *testing.M) {
+ // Add all goroutines running at the start of the test to the set
+ // of not-leaked goroutines. This includes TestMain, and anything else
+ // that might have been started by test infrastructure.
+ skip := [][]byte{
+ []byte("created by os/signal.Notify"),
+ []byte("gotraceback_test.go"),
+ }
+ buf := make([]byte, 2<<20)
+ buf = buf[:runtime.Stack(buf, true)]
+ for _, g := range bytes.Split(buf, []byte("\n\n")) {
+ id, _, _ := bytes.Cut(g, []byte("["))
+ skip = append(skip, id)
+ }
+
+ defer os.Exit(m.Run())
+
+ // Look for leaked goroutines.
+ //
+ // Checking after every test makes it easier to tell which test is the culprit,
+ // but checking once at the end is faster and less likely to miss something.
+ if runtime.GOOS == "js" {
+ // The js-wasm runtime creates an additional background goroutine.
+ // Just skip the leak check there.
+ return
+ }
+ start := time.Now()
+ warned := false
+ for {
+ buf := make([]byte, 2<<20)
+ buf = buf[:runtime.Stack(buf, true)]
+ leaked := false
+ for _, g := range bytes.Split(buf, []byte("\n\n")) {
+ leaked = true
+ for _, s := range skip {
+ if bytes.Contains(g, s) {
+ leaked = false
+ break
+ }
+ }
+ }
+ if !leaked {
+ break
+ }
+ if !warned && time.Since(start) > 1*time.Second {
+ // Print a warning quickly, in case this is an interactive session.
+ // Keep waiting until the test times out, in case this is a slow trybot.
+ fmt.Printf("Tests seem to have leaked some goroutines, still waiting.\n\n")
+ fmt.Print(string(buf))
+ warned = true
+ }
+ // Goroutines might still be shutting down.
+ time.Sleep(1 * time.Millisecond)
+ }
+}
diff --git a/internal/quic/math.go b/quic/math.go
similarity index 100%
rename from internal/quic/math.go
rename to quic/math.go
diff --git a/internal/quic/pacer.go b/quic/pacer.go
similarity index 100%
rename from internal/quic/pacer.go
rename to quic/pacer.go
diff --git a/internal/quic/pacer_test.go b/quic/pacer_test.go
similarity index 100%
rename from internal/quic/pacer_test.go
rename to quic/pacer_test.go
diff --git a/internal/quic/packet.go b/quic/packet.go
similarity index 88%
rename from internal/quic/packet.go
rename to quic/packet.go
index 7d69f96d27..883754f021 100644
--- a/internal/quic/packet.go
+++ b/quic/packet.go
@@ -9,6 +9,8 @@ package quic
import (
"encoding/binary"
"fmt"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// packetType is a QUIC packet type.
@@ -41,6 +43,22 @@ func (p packetType) String() string {
return fmt.Sprintf("unknown packet type %v", byte(p))
}
+func (p packetType) qlogString() string {
+ switch p {
+ case packetTypeInitial:
+ return "initial"
+ case packetType0RTT:
+ return "0RTT"
+ case packetTypeHandshake:
+ return "handshake"
+ case packetTypeRetry:
+ return "retry"
+ case packetType1RTT:
+ return "1RTT"
+ }
+ return "unknown"
+}
+
// Bits set in the first byte of a packet.
const (
headerFormLong = 0x80 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.2.1
@@ -97,6 +115,9 @@ const (
streamFinBit = 0x01
)
+// Maximum length of a connection ID.
+const maxConnIDLen = 20
+
// isLongHeader returns true if b is the first byte of a long header.
func isLongHeader(b byte) bool {
return b&headerFormLong == headerFormLong
@@ -177,10 +198,10 @@ func parseVersionNegotiation(pkt []byte) (dstConnID, srcConnID, versions []byte)
// appendVersionNegotiation appends a Version Negotiation packet to pkt,
// returning the result.
func appendVersionNegotiation(pkt, dstConnID, srcConnID []byte, versions ...uint32) []byte {
- pkt = append(pkt, headerFormLong|fixedBit) // header byte
- pkt = append(pkt, 0, 0, 0, 0) // Version (0 for Version Negotiation)
- pkt = appendUint8Bytes(pkt, dstConnID) // Destination Connection ID
- pkt = appendUint8Bytes(pkt, srcConnID) // Source Connection ID
+ pkt = append(pkt, headerFormLong|fixedBit) // header byte
+ pkt = append(pkt, 0, 0, 0, 0) // Version (0 for Version Negotiation)
+ pkt = quicwire.AppendUint8Bytes(pkt, dstConnID) // Destination Connection ID
+ pkt = quicwire.AppendUint8Bytes(pkt, srcConnID) // Source Connection ID
for _, v := range versions {
pkt = binary.BigEndian.AppendUint32(pkt, v) // Supported Version
}
@@ -224,21 +245,21 @@ func parseGenericLongHeaderPacket(b []byte) (p genericLongPacket, ok bool) {
b = b[1:]
// Version (32),
var n int
- p.version, n = consumeUint32(b)
+ p.version, n = quicwire.ConsumeUint32(b)
if n < 0 {
return genericLongPacket{}, false
}
b = b[n:]
// Destination Connection ID Length (8),
// Destination Connection ID (0..2048),
- p.dstConnID, n = consumeUint8Bytes(b)
+ p.dstConnID, n = quicwire.ConsumeUint8Bytes(b)
if n < 0 || len(p.dstConnID) > 2048/8 {
return genericLongPacket{}, false
}
b = b[n:]
// Source Connection ID Length (8),
// Source Connection ID (0..2048),
- p.srcConnID, n = consumeUint8Bytes(b)
+ p.srcConnID, n = quicwire.ConsumeUint8Bytes(b)
if n < 0 || len(p.dstConnID) > 2048/8 {
return genericLongPacket{}, false
}
diff --git a/internal/quic/packet_codec_test.go b/quic/packet_codec_test.go
similarity index 85%
rename from internal/quic/packet_codec_test.go
rename to quic/packet_codec_test.go
index 7b01bb00d6..2a2b08f4e3 100644
--- a/internal/quic/packet_codec_test.go
+++ b/quic/packet_codec_test.go
@@ -9,8 +9,14 @@ package quic
import (
"bytes"
"crypto/tls"
+ "io"
+ "log/slog"
"reflect"
"testing"
+ "time"
+
+ "golang.org/x/net/internal/quic/quicwire"
+ "golang.org/x/net/quic/qlog"
)
func TestParseLongHeaderPacket(t *testing.T) {
@@ -207,11 +213,13 @@ func TestRoundtripEncodeShortPacket(t *testing.T) {
func TestFrameEncodeDecode(t *testing.T) {
for _, test := range []struct {
s string
+ j string
f debugFrame
b []byte
truncated []byte
}{{
s: "PADDING*1",
+ j: `{"frame_type":"padding","length":1}`,
f: debugFramePadding{
size: 1,
},
@@ -221,12 +229,14 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "PING",
+ j: `{"frame_type":"ping"}`,
f: debugFramePing{},
b: []byte{
0x01, // TYPE(i) = 0x01
},
}, {
s: "ACK Delay=10 [0,16) [17,32) [48,64)",
+ j: `"error: debugFrameAck should not appear as a slog Value"`,
f: debugFrameAck{
ackDelay: 10,
ranges: []i64range[packetNumber]{
@@ -257,6 +267,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "RESET_STREAM ID=1 Code=2 FinalSize=3",
+ j: `{"frame_type":"reset_stream","stream_id":1,"final_size":3}`,
f: debugFrameResetStream{
id: 1,
code: 2,
@@ -270,6 +281,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STOP_SENDING ID=1 Code=2",
+ j: `{"frame_type":"stop_sending","stream_id":1,"error_code":2}`,
f: debugFrameStopSending{
id: 1,
code: 2,
@@ -281,6 +293,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "CRYPTO Offset=1 Length=2",
+ j: `{"frame_type":"crypto","offset":1,"length":2}`,
f: debugFrameCrypto{
off: 1,
data: []byte{3, 4},
@@ -299,6 +312,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "NEW_TOKEN Token=0304",
+ j: `{"frame_type":"new_token","token":"0304"}`,
f: debugFrameNewToken{
token: []byte{3, 4},
},
@@ -309,6 +323,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM ID=1 Offset=0 Length=0",
+ j: `{"frame_type":"stream","stream_id":1,"offset":0,"length":0}`,
f: debugFrameStream{
id: 1,
fin: false,
@@ -324,6 +339,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM ID=100 Offset=4 Length=3",
+ j: `{"frame_type":"stream","stream_id":100,"offset":4,"length":3}`,
f: debugFrameStream{
id: 100,
fin: false,
@@ -346,6 +362,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM ID=100 FIN Offset=4 Length=3",
+ j: `{"frame_type":"stream","stream_id":100,"offset":4,"length":3,"fin":true}`,
f: debugFrameStream{
id: 100,
fin: true,
@@ -368,6 +385,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM ID=1 FIN Offset=100 Length=0",
+ j: `{"frame_type":"stream","stream_id":1,"offset":100,"length":0,"fin":true}`,
f: debugFrameStream{
id: 1,
fin: true,
@@ -383,6 +401,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "MAX_DATA Max=10",
+ j: `{"frame_type":"max_data","maximum":10}`,
f: debugFrameMaxData{
max: 10,
},
@@ -392,6 +411,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "MAX_STREAM_DATA ID=1 Max=10",
+ j: `{"frame_type":"max_stream_data","stream_id":1,"maximum":10}`,
f: debugFrameMaxStreamData{
id: 1,
max: 10,
@@ -403,6 +423,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "MAX_STREAMS Type=bidi Max=1",
+ j: `{"frame_type":"max_streams","stream_type":"bidirectional","maximum":1}`,
f: debugFrameMaxStreams{
streamType: bidiStream,
max: 1,
@@ -413,6 +434,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "MAX_STREAMS Type=uni Max=1",
+ j: `{"frame_type":"max_streams","stream_type":"unidirectional","maximum":1}`,
f: debugFrameMaxStreams{
streamType: uniStream,
max: 1,
@@ -423,6 +445,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "DATA_BLOCKED Max=1",
+ j: `{"frame_type":"data_blocked","limit":1}`,
f: debugFrameDataBlocked{
max: 1,
},
@@ -432,6 +455,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM_DATA_BLOCKED ID=1 Max=2",
+ j: `{"frame_type":"stream_data_blocked","stream_id":1,"limit":2}`,
f: debugFrameStreamDataBlocked{
id: 1,
max: 2,
@@ -443,6 +467,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAMS_BLOCKED Type=bidi Max=1",
+ j: `{"frame_type":"streams_blocked","stream_type":"bidirectional","limit":1}`,
f: debugFrameStreamsBlocked{
streamType: bidiStream,
max: 1,
@@ -453,6 +478,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAMS_BLOCKED Type=uni Max=1",
+ j: `{"frame_type":"streams_blocked","stream_type":"unidirectional","limit":1}`,
f: debugFrameStreamsBlocked{
streamType: uniStream,
max: 1,
@@ -463,6 +489,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "NEW_CONNECTION_ID Seq=3 Retire=2 ID=a0a1a2a3 Token=0102030405060708090a0b0c0d0e0f10",
+ j: `{"frame_type":"new_connection_id","sequence_number":3,"retire_prior_to":2,"connection_id":"a0a1a2a3","stateless_reset_token":"0102030405060708090a0b0c0d0e0f10"}`,
f: debugFrameNewConnectionID{
seq: 3,
retirePriorTo: 2,
@@ -479,6 +506,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "RETIRE_CONNECTION_ID Seq=1",
+ j: `{"frame_type":"retire_connection_id","sequence_number":1}`,
f: debugFrameRetireConnectionID{
seq: 1,
},
@@ -488,8 +516,9 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "PATH_CHALLENGE Data=0123456789abcdef",
+ j: `{"frame_type":"path_challenge","data":"0123456789abcdef"}`,
f: debugFramePathChallenge{
- data: 0x0123456789abcdef,
+ data: pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef},
},
b: []byte{
0x1a, // Type (i) = 0x1a,
@@ -497,8 +526,9 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "PATH_RESPONSE Data=0123456789abcdef",
+ j: `{"frame_type":"path_response","data":"0123456789abcdef"}`,
f: debugFramePathResponse{
- data: 0x0123456789abcdef,
+ data: pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef},
},
b: []byte{
0x1b, // Type (i) = 0x1b,
@@ -506,6 +536,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: `CONNECTION_CLOSE Code=INTERNAL_ERROR FrameType=2 Reason="oops"`,
+ j: `{"frame_type":"connection_close","error_space":"transport","error_code_value":1,"reason":"oops"}`,
f: debugFrameConnectionCloseTransport{
code: 1,
frameType: 2,
@@ -520,6 +551,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: `CONNECTION_CLOSE AppCode=1 Reason="oops"`,
+ j: `{"frame_type":"connection_close","error_space":"application","error_code_value":1,"reason":"oops"}`,
f: debugFrameConnectionCloseApplication{
code: 1,
reason: "oops",
@@ -532,6 +564,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "HANDSHAKE_DONE",
+ j: `{"frame_type":"handshake_done"}`,
f: debugFrameHandshakeDone{},
b: []byte{
0x1e, // Type (i) = 0x1e,
@@ -554,6 +587,9 @@ func TestFrameEncodeDecode(t *testing.T) {
if got, want := test.f.String(), test.s; got != want {
t.Errorf("frame.String():\ngot %q\nwant %q", got, want)
}
+ if got, want := frameJSON(test.f), test.j; got != want {
+ t.Errorf("frame.LogValue():\ngot %q\nwant %q", got, want)
+ }
// Try encoding the frame into too little space.
// Most frames will result in an error; some (like STREAM frames) will truncate
@@ -579,6 +615,42 @@ func TestFrameEncodeDecode(t *testing.T) {
}
}
+func TestFrameScaledAck(t *testing.T) {
+ for _, test := range []struct {
+ j string
+ f debugFrameScaledAck
+ }{{
+ j: `{"frame_type":"ack","acked_ranges":[[0,15],[17],[48,63]],"ack_delay":10.000000}`,
+ f: debugFrameScaledAck{
+ ackDelay: 10 * time.Millisecond,
+ ranges: []i64range[packetNumber]{
+ {0x00, 0x10},
+ {0x11, 0x12},
+ {0x30, 0x40},
+ },
+ },
+ }} {
+ if got, want := frameJSON(test.f), test.j; got != want {
+ t.Errorf("frame.LogValue():\ngot %q\nwant %q", got, want)
+ }
+ }
+}
+
+func frameJSON(f slog.LogValuer) string {
+ var buf bytes.Buffer
+ h := qlog.NewJSONHandler(qlog.HandlerOptions{
+ Level: QLogLevelFrame,
+ NewTrace: func(info qlog.TraceInfo) (io.WriteCloser, error) {
+ return nopCloseWriter{&buf}, nil
+ },
+ })
+ // Log the frame, and then trim out everything but the frame from the log.
+ slog.New(h).Info("message", slog.Any("frame", f))
+ _, b, _ := bytes.Cut(buf.Bytes(), []byte(`"frame":`))
+ b = bytes.TrimSuffix(b, []byte("}}\n"))
+ return string(b)
+}
+
func TestFrameDecode(t *testing.T) {
for _, test := range []struct {
desc string
@@ -665,7 +737,7 @@ func TestFrameDecodeErrors(t *testing.T) {
name: "MAX_STREAMS with too many streams",
b: func() []byte {
// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.11-5.2.1
- return appendVarint([]byte{frameTypeMaxStreamsBidi}, (1<<60)+1)
+ return quicwire.AppendVarint([]byte{frameTypeMaxStreamsBidi}, (1<<60)+1)
}(),
}, {
name: "NEW_CONNECTION_ID too small",
diff --git a/internal/quic/packet_number.go b/quic/packet_number.go
similarity index 100%
rename from internal/quic/packet_number.go
rename to quic/packet_number.go
diff --git a/internal/quic/packet_number_test.go b/quic/packet_number_test.go
similarity index 100%
rename from internal/quic/packet_number_test.go
rename to quic/packet_number_test.go
diff --git a/internal/quic/packet_parser.go b/quic/packet_parser.go
similarity index 77%
rename from internal/quic/packet_parser.go
rename to quic/packet_parser.go
index ce04339025..dca3018086 100644
--- a/internal/quic/packet_parser.go
+++ b/quic/packet_parser.go
@@ -6,6 +6,8 @@
package quic
+import "golang.org/x/net/internal/quic/quicwire"
+
// parseLongHeaderPacket parses a QUIC long header packet.
//
// It does not parse Version Negotiation packets.
@@ -34,7 +36,7 @@ func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p lon
}
b = b[1:]
// Version (32),
- p.version, n = consumeUint32(b)
+ p.version, n = quicwire.ConsumeUint32(b)
if n < 0 {
return longPacket{}, -1
}
@@ -46,16 +48,16 @@ func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p lon
// Destination Connection ID Length (8),
// Destination Connection ID (0..160),
- p.dstConnID, n = consumeUint8Bytes(b)
- if n < 0 || len(p.dstConnID) > 20 {
+ p.dstConnID, n = quicwire.ConsumeUint8Bytes(b)
+ if n < 0 || len(p.dstConnID) > maxConnIDLen {
return longPacket{}, -1
}
b = b[n:]
// Source Connection ID Length (8),
// Source Connection ID (0..160),
- p.srcConnID, n = consumeUint8Bytes(b)
- if n < 0 || len(p.dstConnID) > 20 {
+ p.srcConnID, n = quicwire.ConsumeUint8Bytes(b)
+ if n < 0 || len(p.dstConnID) > maxConnIDLen {
return longPacket{}, -1
}
b = b[n:]
@@ -64,7 +66,7 @@ func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p lon
case packetTypeInitial:
// Token Length (i),
// Token (..),
- p.extra, n = consumeVarintBytes(b)
+ p.extra, n = quicwire.ConsumeVarintBytes(b)
if n < 0 {
return longPacket{}, -1
}
@@ -77,7 +79,7 @@ func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p lon
}
// Length (i),
- payLen, n := consumeVarint(b)
+ payLen, n := quicwire.ConsumeVarint(b)
if n < 0 {
return longPacket{}, -1
}
@@ -121,14 +123,14 @@ func skipLongHeaderPacket(pkt []byte) int {
}
if getPacketType(pkt) == packetTypeInitial {
// Token length, token.
- _, nn := consumeVarintBytes(pkt[n:])
+ _, nn := quicwire.ConsumeVarintBytes(pkt[n:])
if nn < 0 {
return -1
}
n += nn
}
// Length, packet number, payload.
- _, nn := consumeVarintBytes(pkt[n:])
+ _, nn := quicwire.ConsumeVarintBytes(pkt[n:])
if nn < 0 {
return -1
}
@@ -160,20 +162,20 @@ func parse1RTTPacket(pkt []byte, k *updatingKeyPair, dstConnIDLen int, pnumMax p
func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumber)) (largest packetNumber, ackDelay unscaledAckDelay, n int) {
b := frame[1:] // type
- largestAck, n := consumeVarint(b)
+ largestAck, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
b = b[n:]
- v, n := consumeVarintInt64(b)
+ v, n := quicwire.ConsumeVarintInt64(b)
if n < 0 {
return 0, 0, -1
}
b = b[n:]
ackDelay = unscaledAckDelay(v)
- ackRangeCount, n := consumeVarint(b)
+ ackRangeCount, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
@@ -181,7 +183,7 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe
rangeMax := packetNumber(largestAck)
for i := uint64(0); ; i++ {
- rangeLen, n := consumeVarint(b)
+ rangeLen, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
@@ -196,7 +198,7 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe
break
}
- gap, n := consumeVarint(b)
+ gap, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
@@ -209,17 +211,17 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe
return packetNumber(largestAck), ackDelay, len(frame) - len(b)
}
- ect0Count, n := consumeVarint(b)
+ ect0Count, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
b = b[n:]
- ect1Count, n := consumeVarint(b)
+ ect1Count, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
b = b[n:]
- ecnCECount, n := consumeVarint(b)
+ ecnCECount, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
@@ -236,17 +238,17 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe
func consumeResetStreamFrame(b []byte) (id streamID, code uint64, finalSize int64, n int) {
n = 1
- idInt, nn := consumeVarint(b[n:])
+ idInt, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, 0, -1
}
n += nn
- code, nn = consumeVarint(b[n:])
+ code, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, 0, -1
}
n += nn
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, 0, -1
}
@@ -257,12 +259,12 @@ func consumeResetStreamFrame(b []byte) (id streamID, code uint64, finalSize int6
func consumeStopSendingFrame(b []byte) (id streamID, code uint64, n int) {
n = 1
- idInt, nn := consumeVarint(b[n:])
+ idInt, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
n += nn
- code, nn = consumeVarint(b[n:])
+ code, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -272,13 +274,13 @@ func consumeStopSendingFrame(b []byte) (id streamID, code uint64, n int) {
func consumeCryptoFrame(b []byte) (off int64, data []byte, n int) {
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, nil, -1
}
off = int64(v)
n += nn
- data, nn = consumeVarintBytes(b[n:])
+ data, nn = quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return 0, nil, -1
}
@@ -288,7 +290,7 @@ func consumeCryptoFrame(b []byte) (off int64, data []byte, n int) {
func consumeNewTokenFrame(b []byte) (token []byte, n int) {
n = 1
- data, nn := consumeVarintBytes(b[n:])
+ data, nn := quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return nil, -1
}
@@ -302,13 +304,13 @@ func consumeNewTokenFrame(b []byte) (token []byte, n int) {
func consumeStreamFrame(b []byte) (id streamID, off int64, fin bool, data []byte, n int) {
fin = (b[0] & 0x01) != 0
n = 1
- idInt, nn := consumeVarint(b[n:])
+ idInt, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, false, nil, -1
}
n += nn
if b[0]&0x04 != 0 {
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, false, nil, -1
}
@@ -316,7 +318,7 @@ func consumeStreamFrame(b []byte) (id streamID, off int64, fin bool, data []byte
off = int64(v)
}
if b[0]&0x02 != 0 {
- data, nn = consumeVarintBytes(b[n:])
+ data, nn = quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return 0, 0, false, nil, -1
}
@@ -333,7 +335,7 @@ func consumeStreamFrame(b []byte) (id streamID, off int64, fin bool, data []byte
func consumeMaxDataFrame(b []byte) (max int64, n int) {
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, -1
}
@@ -343,13 +345,13 @@ func consumeMaxDataFrame(b []byte) (max int64, n int) {
func consumeMaxStreamDataFrame(b []byte) (id streamID, max int64, n int) {
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
n += nn
id = streamID(v)
- v, nn = consumeVarint(b[n:])
+ v, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -368,7 +370,7 @@ func consumeMaxStreamsFrame(b []byte) (typ streamType, max int64, n int) {
return 0, 0, -1
}
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -381,13 +383,13 @@ func consumeMaxStreamsFrame(b []byte) (typ streamType, max int64, n int) {
func consumeStreamDataBlockedFrame(b []byte) (id streamID, max int64, n int) {
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
n += nn
id = streamID(v)
- max, nn = consumeVarintInt64(b[n:])
+ max, nn = quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -397,7 +399,7 @@ func consumeStreamDataBlockedFrame(b []byte) (id streamID, max int64, n int) {
func consumeDataBlockedFrame(b []byte) (max int64, n int) {
n = 1
- max, nn := consumeVarintInt64(b[n:])
+ max, nn := quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, -1
}
@@ -412,7 +414,7 @@ func consumeStreamsBlockedFrame(b []byte) (typ streamType, max int64, n int) {
typ = uniStream
}
n = 1
- max, nn := consumeVarintInt64(b[n:])
+ max, nn := quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -420,32 +422,32 @@ func consumeStreamsBlockedFrame(b []byte) (typ streamType, max int64, n int) {
return typ, max, n
}
-func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, resetToken [16]byte, n int) {
+func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, resetToken statelessResetToken, n int) {
n = 1
var nn int
- seq, nn = consumeVarintInt64(b[n:])
+ seq, nn = quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
- return 0, 0, nil, [16]byte{}, -1
+ return 0, 0, nil, statelessResetToken{}, -1
}
n += nn
- retire, nn = consumeVarintInt64(b[n:])
+ retire, nn = quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
- return 0, 0, nil, [16]byte{}, -1
+ return 0, 0, nil, statelessResetToken{}, -1
}
n += nn
if seq < retire {
- return 0, 0, nil, [16]byte{}, -1
+ return 0, 0, nil, statelessResetToken{}, -1
}
- connID, nn = consumeVarintBytes(b[n:])
+ connID, nn = quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
- return 0, 0, nil, [16]byte{}, -1
+ return 0, 0, nil, statelessResetToken{}, -1
}
if len(connID) < 1 || len(connID) > 20 {
- return 0, 0, nil, [16]byte{}, -1
+ return 0, 0, nil, statelessResetToken{}, -1
}
n += nn
if len(b[n:]) < len(resetToken) {
- return 0, 0, nil, [16]byte{}, -1
+ return 0, 0, nil, statelessResetToken{}, -1
}
copy(resetToken[:], b[n:])
n += len(resetToken)
@@ -455,7 +457,7 @@ func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, re
func consumeRetireConnectionIDFrame(b []byte) (seq int64, n int) {
n = 1
var nn int
- seq, nn = consumeVarintInt64(b[n:])
+ seq, nn = quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, -1
}
@@ -463,18 +465,17 @@ func consumeRetireConnectionIDFrame(b []byte) (seq int64, n int) {
return seq, n
}
-func consumePathChallengeFrame(b []byte) (data uint64, n int) {
+func consumePathChallengeFrame(b []byte) (data pathChallengeData, n int) {
n = 1
- var nn int
- data, nn = consumeUint64(b[n:])
- if nn < 0 {
- return 0, -1
+ nn := copy(data[:], b[n:])
+ if nn != len(data) {
+ return data, -1
}
n += nn
return data, n
}
-func consumePathResponseFrame(b []byte) (data uint64, n int) {
+func consumePathResponseFrame(b []byte) (data pathChallengeData, n int) {
return consumePathChallengeFrame(b) // identical frame format
}
@@ -482,18 +483,18 @@ func consumeConnectionCloseTransportFrame(b []byte) (code transportError, frameT
n = 1
var nn int
var codeInt uint64
- codeInt, nn = consumeVarint(b[n:])
+ codeInt, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, "", -1
}
code = transportError(codeInt)
n += nn
- frameType, nn = consumeVarint(b[n:])
+ frameType, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, "", -1
}
n += nn
- reasonb, nn := consumeVarintBytes(b[n:])
+ reasonb, nn := quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return 0, 0, "", -1
}
@@ -505,12 +506,12 @@ func consumeConnectionCloseTransportFrame(b []byte) (code transportError, frameT
func consumeConnectionCloseApplicationFrame(b []byte) (code uint64, reason string, n int) {
n = 1
var nn int
- code, nn = consumeVarint(b[n:])
+ code, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, "", -1
}
n += nn
- reasonb, nn := consumeVarintBytes(b[n:])
+ reasonb, nn := quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return 0, "", -1
}
diff --git a/internal/quic/packet_protection.go b/quic/packet_protection.go
similarity index 97%
rename from internal/quic/packet_protection.go
rename to quic/packet_protection.go
index 7b141ac49e..9f1bbc6a4a 100644
--- a/internal/quic/packet_protection.go
+++ b/quic/packet_protection.go
@@ -351,7 +351,13 @@ func (k *updatingKeyPair) init() {
// We perform the first key update early in the connection so a peer
// which does not support key updates will fail rapidly,
// rather than after the connection has been long established.
- k.updateAfter = 1000
+ //
+ // The QUIC interop runner "keyupdate" test requires that the client
+ // initiate a key rotation early in the connection. Increasing this
+ // value may cause interop test failures; if we do want to increase it,
+ // we should either skip the keyupdate test or provide a way to override
+ // the setting in interop tests.
+ k.updateAfter = 100
}
func (k *updatingKeyPair) canRead() bool {
@@ -441,7 +447,7 @@ func (k *updatingKeyPair) unprotect(pkt []byte, pnumOff int, pnumMax packetNumbe
if err != nil {
k.authFailures++
if k.authFailures >= aeadIntegrityLimit(k.r.suite) {
- return nil, 0, localTransportError(errAEADLimitReached)
+ return nil, 0, localTransportError{code: errAEADLimitReached}
}
return nil, 0, err
}
@@ -513,7 +519,7 @@ func hashForSuite(suite uint16) (h crypto.Hash, keySize int) {
}
}
-// hdkfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
+// hkdfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
//
// Copied from crypto/tls/key_schedule.go.
func hkdfExpandLabel(hash func() hash.Hash, secret []byte, label string, context []byte, length int) []byte {
diff --git a/internal/quic/packet_protection_test.go b/quic/packet_protection_test.go
similarity index 100%
rename from internal/quic/packet_protection_test.go
rename to quic/packet_protection_test.go
diff --git a/internal/quic/packet_test.go b/quic/packet_test.go
similarity index 100%
rename from internal/quic/packet_test.go
rename to quic/packet_test.go
diff --git a/internal/quic/packet_writer.go b/quic/packet_writer.go
similarity index 77%
rename from internal/quic/packet_writer.go
rename to quic/packet_writer.go
index 0c2b2ee41e..e75edcda5b 100644
--- a/internal/quic/packet_writer.go
+++ b/quic/packet_writer.go
@@ -8,6 +8,8 @@ package quic
import (
"encoding/binary"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// A packetWriter constructs QUIC datagrams.
@@ -47,6 +49,11 @@ func (w *packetWriter) datagram() []byte {
return w.b
}
+// packetLen returns the size of the current packet.
+func (w *packetWriter) packetLen() int {
+ return len(w.b[w.pktOff:]) + aeadOverhead
+}
+
// payload returns the payload of the current packet.
func (w *packetWriter) payload() []byte {
return w.b[w.payOff:]
@@ -69,7 +76,7 @@ func (w *packetWriter) startProtectedLongHeaderPacket(pnumMaxAcked packetNumber,
hdrSize += 1 + len(p.srcConnID)
switch p.ptype {
case packetTypeInitial:
- hdrSize += sizeVarint(uint64(len(p.extra))) + len(p.extra)
+ hdrSize += quicwire.SizeVarint(uint64(len(p.extra))) + len(p.extra)
}
hdrSize += 2 // length, hardcoded to a 2-byte varint
pnumOff := len(w.b) + hdrSize
@@ -122,11 +129,11 @@ func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber
}
hdr = append(hdr, headerFormLong|fixedBit|typeBits|byte(pnumLen-1))
hdr = binary.BigEndian.AppendUint32(hdr, p.version)
- hdr = appendUint8Bytes(hdr, p.dstConnID)
- hdr = appendUint8Bytes(hdr, p.srcConnID)
+ hdr = quicwire.AppendUint8Bytes(hdr, p.dstConnID)
+ hdr = quicwire.AppendUint8Bytes(hdr, p.srcConnID)
switch p.ptype {
case packetTypeInitial:
- hdr = appendVarintBytes(hdr, p.extra) // token
+ hdr = quicwire.AppendVarintBytes(hdr, p.extra) // token
}
// Packet length, always encoded as a 2-byte varint.
@@ -136,7 +143,7 @@ func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber
hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked)
k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, p.num)
- return w.finish(p.num)
+ return w.finish(p.ptype, p.num)
}
// start1RTTPacket starts writing a 1-RTT (short header) packet.
@@ -178,7 +185,7 @@ func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConn
hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked)
w.padPacketLength(pnumLen)
k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, pnum)
- return w.finish(pnum)
+ return w.finish(packetType1RTT, pnum)
}
// padPacketLength pads out the payload of the current packet to the minimum size,
@@ -199,9 +206,10 @@ func (w *packetWriter) padPacketLength(pnumLen int) int {
}
// finish finishes the current packet after protection is applied.
-func (w *packetWriter) finish(pnum packetNumber) *sentPacket {
+func (w *packetWriter) finish(ptype packetType, pnum packetNumber) *sentPacket {
w.b = w.b[:len(w.b)+aeadOverhead]
w.sent.size = len(w.b) - w.pktOff
+ w.sent.ptype = ptype
w.sent.num = pnum
sent := w.sent
w.sent = nil
@@ -237,10 +245,7 @@ func (w *packetWriter) appendPingFrame() (added bool) {
return false
}
w.b = append(w.b, frameTypePing)
- // Mark this packet as ack-eliciting and in-flight,
- // but there's no need to record the presence of a PING frame in it.
- w.sent.ackEliciting = true
- w.sent.inFlight = true
+ w.sent.markAckEliciting() // no need to record the frame itself
return true
}
@@ -267,26 +272,26 @@ func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscale
largest = uint64(seen.max())
firstRange = uint64(seen[len(seen)-1].size() - 1)
)
- if w.avail() < 1+sizeVarint(largest)+sizeVarint(uint64(delay))+1+sizeVarint(firstRange) {
+ if w.avail() < 1+quicwire.SizeVarint(largest)+quicwire.SizeVarint(uint64(delay))+1+quicwire.SizeVarint(firstRange) {
return false
}
w.b = append(w.b, frameTypeAck)
- w.b = appendVarint(w.b, largest)
- w.b = appendVarint(w.b, uint64(delay))
+ w.b = quicwire.AppendVarint(w.b, largest)
+ w.b = quicwire.AppendVarint(w.b, uint64(delay))
// The range count is technically a varint, but we'll reserve a single byte for it
// and never add more than 62 ranges (the maximum varint that fits in a byte).
rangeCountOff := len(w.b)
w.b = append(w.b, 0)
- w.b = appendVarint(w.b, firstRange)
+ w.b = quicwire.AppendVarint(w.b, firstRange)
rangeCount := byte(0)
for i := len(seen) - 2; i >= 0; i-- {
gap := uint64(seen[i+1].start - seen[i].end - 1)
size := uint64(seen[i].size() - 1)
- if w.avail() < sizeVarint(gap)+sizeVarint(size) || rangeCount > 62 {
+ if w.avail() < quicwire.SizeVarint(gap)+quicwire.SizeVarint(size) || rangeCount > 62 {
break
}
- w.b = appendVarint(w.b, gap)
- w.b = appendVarint(w.b, size)
+ w.b = quicwire.AppendVarint(w.b, gap)
+ w.b = quicwire.AppendVarint(w.b, size)
rangeCount++
}
w.b[rangeCountOff] = rangeCount
@@ -296,34 +301,34 @@ func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscale
}
func (w *packetWriter) appendNewTokenFrame(token []byte) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(len(token)))+len(token) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(len(token)))+len(token) {
return false
}
w.b = append(w.b, frameTypeNewToken)
- w.b = appendVarintBytes(w.b, token)
+ w.b = quicwire.AppendVarintBytes(w.b, token)
return true
}
func (w *packetWriter) appendResetStreamFrame(id streamID, code uint64, finalSize int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(code)+sizeVarint(uint64(finalSize)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(code)+quicwire.SizeVarint(uint64(finalSize)) {
return false
}
w.b = append(w.b, frameTypeResetStream)
- w.b = appendVarint(w.b, uint64(id))
- w.b = appendVarint(w.b, code)
- w.b = appendVarint(w.b, uint64(finalSize))
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, code)
+ w.b = quicwire.AppendVarint(w.b, uint64(finalSize))
w.sent.appendAckElicitingFrame(frameTypeResetStream)
w.sent.appendInt(uint64(id))
return true
}
func (w *packetWriter) appendStopSendingFrame(id streamID, code uint64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(code) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(code) {
return false
}
w.b = append(w.b, frameTypeStopSending)
- w.b = appendVarint(w.b, uint64(id))
- w.b = appendVarint(w.b, code)
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, code)
w.sent.appendAckElicitingFrame(frameTypeStopSending)
w.sent.appendInt(uint64(id))
return true
@@ -334,9 +339,9 @@ func (w *packetWriter) appendStopSendingFrame(id streamID, code uint64) (added b
// The returned []byte may be smaller than size if the packet cannot hold all the data.
func (w *packetWriter) appendCryptoFrame(off int64, size int) (_ []byte, added bool) {
max := w.avail()
- max -= 1 // frame type
- max -= sizeVarint(uint64(off)) // offset
- max -= sizeVarint(uint64(size)) // maximum length
+ max -= 1 // frame type
+ max -= quicwire.SizeVarint(uint64(off)) // offset
+ max -= quicwire.SizeVarint(uint64(size)) // maximum length
if max <= 0 {
return nil, false
}
@@ -344,8 +349,8 @@ func (w *packetWriter) appendCryptoFrame(off int64, size int) (_ []byte, added b
size = max
}
w.b = append(w.b, frameTypeCrypto)
- w.b = appendVarint(w.b, uint64(off))
- w.b = appendVarint(w.b, uint64(size))
+ w.b = quicwire.AppendVarint(w.b, uint64(off))
+ w.b = quicwire.AppendVarint(w.b, uint64(size))
start := len(w.b)
w.b = w.b[:start+size]
w.sent.appendAckElicitingFrame(frameTypeCrypto)
@@ -360,12 +365,12 @@ func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin b
typ := uint8(frameTypeStreamBase | streamLenBit)
max := w.avail()
max -= 1 // frame type
- max -= sizeVarint(uint64(id))
+ max -= quicwire.SizeVarint(uint64(id))
if off != 0 {
- max -= sizeVarint(uint64(off))
+ max -= quicwire.SizeVarint(uint64(off))
typ |= streamOffBit
}
- max -= sizeVarint(uint64(size)) // maximum length
+ max -= quicwire.SizeVarint(uint64(size)) // maximum length
if max < 0 || (max == 0 && size > 0) {
return nil, false
}
@@ -375,47 +380,43 @@ func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin b
typ |= streamFinBit
}
w.b = append(w.b, typ)
- w.b = appendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
if off != 0 {
- w.b = appendVarint(w.b, uint64(off))
+ w.b = quicwire.AppendVarint(w.b, uint64(off))
}
- w.b = appendVarint(w.b, uint64(size))
+ w.b = quicwire.AppendVarint(w.b, uint64(size))
start := len(w.b)
w.b = w.b[:start+size]
- if fin {
- w.sent.appendAckElicitingFrame(frameTypeStreamBase | streamFinBit)
- } else {
- w.sent.appendAckElicitingFrame(frameTypeStreamBase)
- }
+ w.sent.appendAckElicitingFrame(typ & (frameTypeStreamBase | streamFinBit))
w.sent.appendInt(uint64(id))
w.sent.appendOffAndSize(off, size)
return w.b[start:][:size], true
}
func (w *packetWriter) appendMaxDataFrame(max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
return false
}
w.b = append(w.b, frameTypeMaxData)
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(frameTypeMaxData)
return true
}
func (w *packetWriter) appendMaxStreamDataFrame(id streamID, max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(uint64(max)) {
return false
}
w.b = append(w.b, frameTypeMaxStreamData)
- w.b = appendVarint(w.b, uint64(id))
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(frameTypeMaxStreamData)
w.sent.appendInt(uint64(id))
return true
}
func (w *packetWriter) appendMaxStreamsFrame(streamType streamType, max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
return false
}
var typ byte
@@ -425,35 +426,35 @@ func (w *packetWriter) appendMaxStreamsFrame(streamType streamType, max int64) (
typ = frameTypeMaxStreamsUni
}
w.b = append(w.b, typ)
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(typ)
return true
}
func (w *packetWriter) appendDataBlockedFrame(max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
return false
}
w.b = append(w.b, frameTypeDataBlocked)
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(frameTypeDataBlocked)
return true
}
func (w *packetWriter) appendStreamDataBlockedFrame(id streamID, max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(uint64(max)) {
return false
}
w.b = append(w.b, frameTypeStreamDataBlocked)
- w.b = appendVarint(w.b, uint64(id))
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(frameTypeStreamDataBlocked)
w.sent.appendInt(uint64(id))
return true
}
func (w *packetWriter) appendStreamsBlockedFrame(typ streamType, max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
return false
}
var ftype byte
@@ -463,19 +464,19 @@ func (w *packetWriter) appendStreamsBlockedFrame(typ streamType, max int64) (add
ftype = frameTypeStreamsBlockedUni
}
w.b = append(w.b, ftype)
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(ftype)
return true
}
func (w *packetWriter) appendNewConnectionIDFrame(seq, retirePriorTo int64, connID []byte, token [16]byte) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(seq))+sizeVarint(uint64(retirePriorTo))+1+len(connID)+len(token) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(seq))+quicwire.SizeVarint(uint64(retirePriorTo))+1+len(connID)+len(token) {
return false
}
w.b = append(w.b, frameTypeNewConnectionID)
- w.b = appendVarint(w.b, uint64(seq))
- w.b = appendVarint(w.b, uint64(retirePriorTo))
- w.b = appendUint8Bytes(w.b, connID)
+ w.b = quicwire.AppendVarint(w.b, uint64(seq))
+ w.b = quicwire.AppendVarint(w.b, uint64(retirePriorTo))
+ w.b = quicwire.AppendUint8Bytes(w.b, connID)
w.b = append(w.b, token[:]...)
w.sent.appendAckElicitingFrame(frameTypeNewConnectionID)
w.sent.appendInt(uint64(seq))
@@ -483,60 +484,60 @@ func (w *packetWriter) appendNewConnectionIDFrame(seq, retirePriorTo int64, conn
}
func (w *packetWriter) appendRetireConnectionIDFrame(seq int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(seq)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(seq)) {
return false
}
w.b = append(w.b, frameTypeRetireConnectionID)
- w.b = appendVarint(w.b, uint64(seq))
+ w.b = quicwire.AppendVarint(w.b, uint64(seq))
w.sent.appendAckElicitingFrame(frameTypeRetireConnectionID)
w.sent.appendInt(uint64(seq))
return true
}
-func (w *packetWriter) appendPathChallengeFrame(data uint64) (added bool) {
+func (w *packetWriter) appendPathChallengeFrame(data pathChallengeData) (added bool) {
if w.avail() < 1+8 {
return false
}
w.b = append(w.b, frameTypePathChallenge)
- w.b = binary.BigEndian.AppendUint64(w.b, data)
- w.sent.appendAckElicitingFrame(frameTypePathChallenge)
+ w.b = append(w.b, data[:]...)
+ w.sent.markAckEliciting() // no need to record the frame itself
return true
}
-func (w *packetWriter) appendPathResponseFrame(data uint64) (added bool) {
+func (w *packetWriter) appendPathResponseFrame(data pathChallengeData) (added bool) {
if w.avail() < 1+8 {
return false
}
w.b = append(w.b, frameTypePathResponse)
- w.b = binary.BigEndian.AppendUint64(w.b, data)
- w.sent.appendAckElicitingFrame(frameTypePathResponse)
+ w.b = append(w.b, data[:]...)
+ w.sent.markAckEliciting() // no need to record the frame itself
return true
}
// appendConnectionCloseTransportFrame appends a CONNECTION_CLOSE frame
// carrying a transport error code.
func (w *packetWriter) appendConnectionCloseTransportFrame(code transportError, frameType uint64, reason string) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(code))+sizeVarint(frameType)+sizeVarint(uint64(len(reason)))+len(reason) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(code))+quicwire.SizeVarint(frameType)+quicwire.SizeVarint(uint64(len(reason)))+len(reason) {
return false
}
w.b = append(w.b, frameTypeConnectionCloseTransport)
- w.b = appendVarint(w.b, uint64(code))
- w.b = appendVarint(w.b, frameType)
- w.b = appendVarintBytes(w.b, []byte(reason))
+ w.b = quicwire.AppendVarint(w.b, uint64(code))
+ w.b = quicwire.AppendVarint(w.b, frameType)
+ w.b = quicwire.AppendVarintBytes(w.b, []byte(reason))
// We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or
// detected as lost.
return true
}
-// appendConnectionCloseTransportFrame appends a CONNECTION_CLOSE frame
+// appendConnectionCloseApplicationFrame appends a CONNECTION_CLOSE frame
// carrying an application protocol error code.
func (w *packetWriter) appendConnectionCloseApplicationFrame(code uint64, reason string) (added bool) {
- if w.avail() < 1+sizeVarint(code)+sizeVarint(uint64(len(reason)))+len(reason) {
+ if w.avail() < 1+quicwire.SizeVarint(code)+quicwire.SizeVarint(uint64(len(reason)))+len(reason) {
return false
}
w.b = append(w.b, frameTypeConnectionCloseApplication)
- w.b = appendVarint(w.b, code)
- w.b = appendVarintBytes(w.b, []byte(reason))
+ w.b = quicwire.AppendVarint(w.b, code)
+ w.b = quicwire.AppendVarintBytes(w.b, []byte(reason))
// We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or
// detected as lost.
return true
diff --git a/quic/path.go b/quic/path.go
new file mode 100644
index 0000000000..8c237dd45f
--- /dev/null
+++ b/quic/path.go
@@ -0,0 +1,89 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import "time"
+
+type pathState struct {
+ // Response to a peer's PATH_CHALLENGE.
+ // This is not a sentVal, because we don't resend lost PATH_RESPONSE frames.
+ // We only track the most recent PATH_CHALLENGE.
+ // If the peer sends a second PATH_CHALLENGE before we respond to the first,
+ // we'll drop the first response.
+ sendPathResponse pathResponseType
+ data pathChallengeData
+}
+
+// pathChallengeData is data carried in a PATH_CHALLENGE or PATH_RESPONSE frame.
+type pathChallengeData [64 / 8]byte
+
+type pathResponseType uint8
+
+const (
+ pathResponseNotNeeded = pathResponseType(iota)
+ pathResponseSmall // send PATH_RESPONSE, do not expand datagram
+ pathResponseExpanded // send PATH_RESPONSE, expand datagram to 1200 bytes
+)
+
+func (c *Conn) handlePathChallenge(_ time.Time, dgram *datagram, data pathChallengeData) {
+ // A PATH_RESPONSE is sent in a datagram expanded to 1200 bytes,
+ // except when this would exceed the anti-amplification limit.
+ //
+ // Rather than maintaining anti-amplification state for each path
+ // we may be sending a PATH_RESPONSE on, follow the following heuristic:
+ //
+ // If we receive a PATH_CHALLENGE in an expanded datagram,
+ // respond with an expanded datagram.
+ //
+ // If we receive a PATH_CHALLENGE in a non-expanded datagram,
+ // then the peer is presumably blocked by its own anti-amplification limit.
+ // Respond with a non-expanded datagram. Receiving this PATH_RESPONSE
+ // will validate the path to the peer, remove its anti-amplification limit,
+ // and permit it to send a followup PATH_CHALLENGE in an expanded datagram.
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-8.2.1
+ if len(dgram.b) >= smallestMaxDatagramSize {
+ c.path.sendPathResponse = pathResponseExpanded
+ } else {
+ c.path.sendPathResponse = pathResponseSmall
+ }
+ c.path.data = data
+}
+
+func (c *Conn) handlePathResponse(now time.Time, _ pathChallengeData) {
+ // "If the content of a PATH_RESPONSE frame does not match the content of
+ // a PATH_CHALLENGE frame previously sent by the endpoint,
+ // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4
+ //
+ // We never send PATH_CHALLENGE frames.
+ c.abort(now, localTransportError{
+ code: errProtocolViolation,
+ reason: "PATH_RESPONSE received when no PATH_CHALLENGE sent",
+ })
+}
+
+// appendPathFrames appends path validation related frames to the current packet.
+// If the return value pad is true, then the packet should be padded to 1200 bytes.
+func (c *Conn) appendPathFrames() (pad, ok bool) {
+ if c.path.sendPathResponse == pathResponseNotNeeded {
+ return pad, true
+ }
+ // We're required to send the PATH_RESPONSE on the path where the
+ // PATH_CHALLENGE was received (RFC 9000, Section 8.2.2).
+ //
+ // At the moment, we don't support path migration and reject packets if
+ // the peer changes its source address, so just sending the PATH_RESPONSE
+ // in a regular datagram is fine.
+ if !c.w.appendPathResponseFrame(c.path.data) {
+ return pad, false
+ }
+ if c.path.sendPathResponse == pathResponseExpanded {
+ pad = true
+ }
+ c.path.sendPathResponse = pathResponseNotNeeded
+ return pad, true
+}
diff --git a/quic/path_test.go b/quic/path_test.go
new file mode 100644
index 0000000000..a309ed14ba
--- /dev/null
+++ b/quic/path_test.go
@@ -0,0 +1,66 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "testing"
+)
+
+func TestPathChallengeReceived(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ padTo int
+ wantPadding int
+ }{{
+ name: "unexpanded",
+ padTo: 0,
+ wantPadding: 0,
+ }, {
+ name: "expanded",
+ padTo: 1200,
+ wantPadding: 1200,
+ }} {
+ // "The recipient of [a PATH_CHALLENGE] frame MUST generate
+ // a PATH_RESPONSE frame [...] containing the same Data value."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.17-7
+ tc := newTestConn(t, clientSide)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
+ tc.writeFrames(packetType1RTT, debugFramePathChallenge{
+ data: data,
+ }, debugFramePadding{
+ to: test.padTo,
+ })
+ tc.wantFrame("response to PATH_CHALLENGE",
+ packetType1RTT, debugFramePathResponse{
+ data: data,
+ })
+ if got, want := tc.lastDatagram.paddedSize, test.wantPadding; got != want {
+ t.Errorf("PATH_RESPONSE expanded to %v bytes, want %v", got, want)
+ }
+ tc.wantIdle("connection is idle")
+ }
+}
+
+func TestPathResponseMismatchReceived(t *testing.T) {
+ // "If the content of a PATH_RESPONSE frame does not match the content of
+ // a PATH_CHALLENGE frame previously sent by the endpoint,
+ // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4
+ tc := newTestConn(t, clientSide)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ tc.writeFrames(packetType1RTT, debugFramePathResponse{
+ data: pathChallengeData{},
+ })
+ tc.wantFrame("invalid PATH_RESPONSE causes the connection to close",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errProtocolViolation,
+ },
+ )
+}
diff --git a/internal/quic/ping.go b/quic/ping.go
similarity index 100%
rename from internal/quic/ping.go
rename to quic/ping.go
diff --git a/internal/quic/ping_test.go b/quic/ping_test.go
similarity index 100%
rename from internal/quic/ping_test.go
rename to quic/ping_test.go
diff --git a/internal/quic/pipe.go b/quic/pipe.go
similarity index 71%
rename from internal/quic/pipe.go
rename to quic/pipe.go
index 978a4f3d8b..75cf76db21 100644
--- a/internal/quic/pipe.go
+++ b/quic/pipe.go
@@ -17,14 +17,14 @@ import (
// Writing past the end of the window extends it.
// Data may be discarded from the start of the pipe, advancing the window.
type pipe struct {
- start int64
- end int64
- head *pipebuf
- tail *pipebuf
+ start int64 // stream position of first stored byte
+ end int64 // stream position just past the last stored byte
+ head *pipebuf // if non-nil, then head.off + len(head.b) > start
+ tail *pipebuf // if non-nil, then tail.off + len(tail.b) == end
}
type pipebuf struct {
- off int64
+ off int64 // stream position of b[0]
b []byte
next *pipebuf
}
@@ -111,6 +111,7 @@ func (p *pipe) copy(off int64, b []byte) {
// read calls f with the data in [off, off+n)
// The data may be provided sequentially across multiple calls to f.
+// Note that read (unlike an io.Reader) does not consume the read data.
func (p *pipe) read(off int64, n int, f func([]byte) error) error {
if off < p.start {
panic("invalid read range")
@@ -135,6 +136,30 @@ func (p *pipe) read(off int64, n int, f func([]byte) error) error {
return nil
}
+// peek returns a reference to up to n bytes of internal data buffer, starting at p.start.
+// The returned slice is valid until the next call to discardBefore.
+// The length of the returned slice will be in the range [0,n].
+func (p *pipe) peek(n int64) []byte {
+ pb := p.head
+ if pb == nil {
+ return nil
+ }
+ b := pb.b[p.start-pb.off:]
+ return b[:min(int64(len(b)), n)]
+}
+
+// availableBuffer returns the available contiguous, allocated buffer space
+// following the pipe window.
+//
+// This is used by the stream write fast path, which makes multiple writes into the pipe buffer
+// without a lock, and then adjusts p.end at a later time with a lock held.
+func (p *pipe) availableBuffer() []byte {
+ if p.tail == nil {
+ return nil
+ }
+ return p.tail.b[p.end-p.tail.off:]
+}
+
// discardBefore discards all data prior to off.
func (p *pipe) discardBefore(off int64) {
for p.head != nil && p.head.end() < off {
@@ -146,4 +171,5 @@ func (p *pipe) discardBefore(off int64) {
p.tail = nil
}
p.start = off
+ p.end = max(p.end, off)
}
diff --git a/internal/quic/pipe_test.go b/quic/pipe_test.go
similarity index 92%
rename from internal/quic/pipe_test.go
rename to quic/pipe_test.go
index 7a05ff4d47..bcb3a8bc05 100644
--- a/internal/quic/pipe_test.go
+++ b/quic/pipe_test.go
@@ -61,6 +61,12 @@ func TestPipeWrites(t *testing.T) {
discardBeforeOp{10000},
writeOp{10000, 20000},
},
+ }, {
+ desc: "discard before writing",
+ ops: []op{
+ discardBeforeOp{1000},
+ writeOp{0, 1},
+ },
}} {
var p pipe
var wantset rangeset[int64]
@@ -78,6 +84,9 @@ func TestPipeWrites(t *testing.T) {
p.discardBefore(o.off)
wantset.sub(0, o.off)
wantStart = o.off
+ if o.off > wantEnd {
+ wantEnd = o.off
+ }
}
if p.start != wantStart || p.end != wantEnd {
t.Errorf("%v: after %#v p contains [%v,%v), want [%v,%v)", test.desc, test.ops[:i+1], p.start, p.end, wantStart, wantEnd)
diff --git a/quic/qlog.go b/quic/qlog.go
new file mode 100644
index 0000000000..36831252c6
--- /dev/null
+++ b/quic/qlog.go
@@ -0,0 +1,274 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "encoding/hex"
+ "log/slog"
+ "net/netip"
+ "time"
+)
+
+// Log levels for qlog events.
+const (
+ // QLogLevelFrame includes per-frame information.
+ // When this level is enabled, packet_sent and packet_received events will
+ // contain information on individual frames sent/received.
+ QLogLevelFrame = slog.Level(-6)
+
+ // QLogLevelPacket events occur at most once per packet sent or received.
+ //
+ // For example: packet_sent, packet_received.
+ QLogLevelPacket = slog.Level(-4)
+
+ // QLogLevelConn events occur multiple times over a connection's lifetime,
+ // but less often than the frequency of individual packets.
+ //
+ // For example: connection_state_updated.
+ QLogLevelConn = slog.Level(-2)
+
+ // QLogLevelEndpoint events occur at most once per connection.
+ //
+ // For example: connection_started, connection_closed.
+ QLogLevelEndpoint = slog.Level(0)
+)
+
+func (c *Conn) logEnabled(level slog.Level) bool {
+ return logEnabled(c.log, level)
+}
+
+func logEnabled(log *slog.Logger, level slog.Level) bool {
+ return log != nil && log.Enabled(context.Background(), level)
+}
+
+// slogHexstring returns a slog.Attr for a value of the hexstring type.
+//
+// https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-1.1.1
+func slogHexstring(key string, value []byte) slog.Attr {
+ return slog.String(key, hex.EncodeToString(value))
+}
+
+func slogAddr(key string, value netip.Addr) slog.Attr {
+ return slog.String(key, value.String())
+}
+
+func (c *Conn) logConnectionStarted(originalDstConnID []byte, peerAddr netip.AddrPort) {
+ if c.config.QLogLogger == nil ||
+ !c.config.QLogLogger.Enabled(context.Background(), QLogLevelEndpoint) {
+ return
+ }
+ var vantage string
+ if c.side == clientSide {
+ vantage = "client"
+ originalDstConnID = c.connIDState.originalDstConnID
+ } else {
+ vantage = "server"
+ }
+ // A qlog Trace container includes some metadata (title, description, vantage_point)
+ // and a list of Events. The Trace also includes a common_fields field setting field
+ // values common to all events in the trace.
+ //
+ // Trace = {
+ // ? title: text
+ // ? description: text
+ // ? configuration: Configuration
+ // ? common_fields: CommonFields
+ // ? vantage_point: VantagePoint
+ // events: [* Event]
+ // }
+ //
+ // To map this into slog's data model, we start each per-connection trace with a With
+ // call that includes both the trace metadata and the common fields.
+ //
+ // This means that in slog's model, each trace event will also include
+ // the Trace metadata fields (vantage_point), which is a divergence from the qlog model.
+ c.log = c.config.QLogLogger.With(
+ // The group_id permits associating traces taken from different vantage points
+ // for the same connection.
+ //
+ // We use the original destination connection ID as the group ID.
+ //
+ // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-3.4.6
+ slogHexstring("group_id", originalDstConnID),
+ slog.Group("vantage_point",
+ slog.String("name", "go quic"),
+ slog.String("type", vantage),
+ ),
+ )
+ localAddr := c.endpoint.LocalAddr()
+ // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.2
+ c.log.LogAttrs(context.Background(), QLogLevelEndpoint,
+ "connectivity:connection_started",
+ slogAddr("src_ip", localAddr.Addr()),
+ slog.Int("src_port", int(localAddr.Port())),
+ slogHexstring("src_cid", c.connIDState.local[0].cid),
+ slogAddr("dst_ip", peerAddr.Addr()),
+ slog.Int("dst_port", int(peerAddr.Port())),
+ slogHexstring("dst_cid", c.connIDState.remote[0].cid),
+ )
+}
+
+func (c *Conn) logConnectionClosed() {
+ if !c.logEnabled(QLogLevelEndpoint) {
+ return
+ }
+ err := c.lifetime.finalErr
+ trigger := "error"
+ switch e := err.(type) {
+ case *ApplicationError:
+ // TODO: Distinguish between peer and locally-initiated close.
+ trigger = "application"
+ case localTransportError:
+ switch err {
+ case errHandshakeTimeout:
+ trigger = "handshake_timeout"
+ default:
+ if e.code == errNo {
+ trigger = "clean"
+ }
+ }
+ case peerTransportError:
+ if e.code == errNo {
+ trigger = "clean"
+ }
+ default:
+ switch err {
+ case errIdleTimeout:
+ trigger = "idle_timeout"
+ case errStatelessReset:
+ trigger = "stateless_reset"
+ }
+ }
+ // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.3
+ c.log.LogAttrs(context.Background(), QLogLevelEndpoint,
+ "connectivity:connection_closed",
+ slog.String("trigger", trigger),
+ )
+}
+
+func (c *Conn) logPacketDropped(dgram *datagram) {
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "connectivity:packet_dropped",
+ )
+}
+
+func (c *Conn) logLongPacketReceived(p longPacket, pkt []byte) {
+ var frames slog.Attr
+ if c.logEnabled(QLogLevelFrame) {
+ frames = c.packetFramesAttr(p.payload)
+ }
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "transport:packet_received",
+ slog.Group("header",
+ slog.String("packet_type", p.ptype.qlogString()),
+ slog.Uint64("packet_number", uint64(p.num)),
+ slog.Uint64("flags", uint64(pkt[0])),
+ slogHexstring("scid", p.srcConnID),
+ slogHexstring("dcid", p.dstConnID),
+ ),
+ slog.Group("raw",
+ slog.Int("length", len(pkt)),
+ ),
+ frames,
+ )
+}
+
+func (c *Conn) log1RTTPacketReceived(p shortPacket, pkt []byte) {
+ var frames slog.Attr
+ if c.logEnabled(QLogLevelFrame) {
+ frames = c.packetFramesAttr(p.payload)
+ }
+ dstConnID, _ := dstConnIDForDatagram(pkt)
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "transport:packet_received",
+ slog.Group("header",
+ slog.String("packet_type", packetType1RTT.qlogString()),
+ slog.Uint64("packet_number", uint64(p.num)),
+ slog.Uint64("flags", uint64(pkt[0])),
+ slogHexstring("dcid", dstConnID),
+ ),
+ slog.Group("raw",
+ slog.Int("length", len(pkt)),
+ ),
+ frames,
+ )
+}
+
+func (c *Conn) logPacketSent(ptype packetType, pnum packetNumber, src, dst []byte, pktLen int, payload []byte) {
+ var frames slog.Attr
+ if c.logEnabled(QLogLevelFrame) {
+ frames = c.packetFramesAttr(payload)
+ }
+ var scid slog.Attr
+ if len(src) > 0 {
+ scid = slogHexstring("scid", src)
+ }
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "transport:packet_sent",
+ slog.Group("header",
+ slog.String("packet_type", ptype.qlogString()),
+ slog.Uint64("packet_number", uint64(pnum)),
+ scid,
+ slogHexstring("dcid", dst),
+ ),
+ slog.Group("raw",
+ slog.Int("length", pktLen),
+ ),
+ frames,
+ )
+}
+
+// packetFramesAttr returns the "frames" attribute containing the frames in a packet.
+// We currently pass this as a slog Any containing a []slog.Value,
+// where each Value is a debugFrame that implements slog.LogValuer.
+//
+// This isn't tremendously efficient, but avoids the need to put a JSON encoder
+// in the quic package or a frame parser in the qlog package.
+func (c *Conn) packetFramesAttr(payload []byte) slog.Attr {
+ var frames []slog.Value
+ for len(payload) > 0 {
+ f, n := parseDebugFrame(payload)
+ if n < 0 {
+ break
+ }
+ payload = payload[n:]
+ switch f := f.(type) {
+ case debugFrameAck:
+ // The qlog ACK frame contains the ACK Delay field as a duration.
+ // Interpreting the contents of this field as a duration requires
+ // knowing the peer's ack_delay_exponent transport parameter,
+ // and it's possible for us to parse an ACK frame before we've
+ // received that parameter.
+ //
+ // We could plumb connection state down into the frame parser,
+ // but for now let's minimize the amount of code that needs to
+ // deal with this and convert the unscaled value into a scaled one here.
+ ackDelay := time.Duration(-1)
+ if c.peerAckDelayExponent >= 0 {
+ ackDelay = f.ackDelay.Duration(uint8(c.peerAckDelayExponent))
+ }
+ frames = append(frames, slog.AnyValue(debugFrameScaledAck{
+ ranges: f.ranges,
+ ackDelay: ackDelay,
+ }))
+ default:
+ frames = append(frames, slog.AnyValue(f))
+ }
+ }
+ return slog.Any("frames", frames)
+}
+
+func (c *Conn) logPacketLost(space numberSpace, sent *sentPacket) {
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "recovery:packet_lost",
+ slog.Group("header",
+ slog.String("packet_type", sent.ptype.qlogString()),
+ slog.Uint64("packet_number", uint64(sent.num)),
+ ),
+ )
+}
diff --git a/quic/qlog/handler.go b/quic/qlog/handler.go
new file mode 100644
index 0000000000..35a66cf8bf
--- /dev/null
+++ b/quic/qlog/handler.go
@@ -0,0 +1,76 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package qlog
+
+import (
+ "context"
+ "log/slog"
+)
+
+type withAttrsHandler struct {
+ attrs []slog.Attr
+ h slog.Handler
+}
+
+func withAttrs(h slog.Handler, attrs []slog.Attr) slog.Handler {
+ if len(attrs) == 0 {
+ return h
+ }
+ return &withAttrsHandler{attrs: attrs, h: h}
+}
+
+func (h *withAttrsHandler) Enabled(ctx context.Context, level slog.Level) bool {
+ return h.h.Enabled(ctx, level)
+}
+
+func (h *withAttrsHandler) Handle(ctx context.Context, r slog.Record) error {
+ r.AddAttrs(h.attrs...)
+ return h.h.Handle(ctx, r)
+}
+
+func (h *withAttrsHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
+ return withAttrs(h, attrs)
+}
+
+func (h *withAttrsHandler) WithGroup(name string) slog.Handler {
+ return withGroup(h, name)
+}
+
+type withGroupHandler struct {
+ name string
+ h slog.Handler
+}
+
+func withGroup(h slog.Handler, name string) slog.Handler {
+ if name == "" {
+ return h
+ }
+ return &withGroupHandler{name: name, h: h}
+}
+
+func (h *withGroupHandler) Enabled(ctx context.Context, level slog.Level) bool {
+ return h.h.Enabled(ctx, level)
+}
+
+func (h *withGroupHandler) Handle(ctx context.Context, r slog.Record) error {
+ var attrs []slog.Attr
+ r.Attrs(func(a slog.Attr) bool {
+ attrs = append(attrs, a)
+ return true
+ })
+ nr := slog.NewRecord(r.Time, r.Level, r.Message, r.PC)
+ nr.Add(slog.Any(h.name, slog.GroupValue(attrs...)))
+ return h.h.Handle(ctx, nr)
+}
+
+func (h *withGroupHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
+ return withAttrs(h, attrs)
+}
+
+func (h *withGroupHandler) WithGroup(name string) slog.Handler {
+ return withGroup(h, name)
+}
diff --git a/quic/qlog/json_writer.go b/quic/qlog/json_writer.go
new file mode 100644
index 0000000000..7867c590df
--- /dev/null
+++ b/quic/qlog/json_writer.go
@@ -0,0 +1,261 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package qlog
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "log/slog"
+ "strconv"
+ "sync"
+ "time"
+)
+
+// A jsonWriter writes JSON-SEQ (RFC 7464).
+//
+// A JSON-SEQ file consists of a series of JSON text records,
+// each beginning with an RS (0x1e) character and ending with LF (0x0a).
+type jsonWriter struct {
+ mu sync.Mutex
+ w io.WriteCloser
+ buf bytes.Buffer
+}
+
+// writeRecordStart writes the start of a JSON-SEQ record.
+func (w *jsonWriter) writeRecordStart() {
+ w.mu.Lock()
+ w.buf.WriteByte(0x1e)
+ w.buf.WriteByte('{')
+}
+
+// writeRecordEnd finishes writing a JSON-SEQ record.
+func (w *jsonWriter) writeRecordEnd() {
+ w.buf.WriteByte('}')
+ w.buf.WriteByte('\n')
+ w.w.Write(w.buf.Bytes())
+ w.buf.Reset()
+ w.mu.Unlock()
+}
+
+func (w *jsonWriter) writeAttrs(attrs []slog.Attr) {
+ w.buf.WriteByte('{')
+ for _, a := range attrs {
+ w.writeAttr(a)
+ }
+ w.buf.WriteByte('}')
+}
+
+func (w *jsonWriter) writeAttr(a slog.Attr) {
+ if a.Key == "" {
+ return
+ }
+ w.writeName(a.Key)
+ w.writeValue(a.Value)
+}
+
+// writeAttrsField writes a []slog.Attr as an object field.
+func (w *jsonWriter) writeAttrsField(name string, attrs []slog.Attr) {
+ w.writeName(name)
+ w.writeAttrs(attrs)
+}
+
+func (w *jsonWriter) writeValue(v slog.Value) {
+ v = v.Resolve()
+ switch v.Kind() {
+ case slog.KindAny:
+ switch v := v.Any().(type) {
+ case []slog.Value:
+ w.writeArray(v)
+ case interface{ AppendJSON([]byte) []byte }:
+ w.buf.Write(v.AppendJSON(w.buf.AvailableBuffer()))
+ default:
+ w.writeString(fmt.Sprint(v))
+ }
+ case slog.KindBool:
+ w.writeBool(v.Bool())
+ case slog.KindDuration:
+ w.writeDuration(v.Duration())
+ case slog.KindFloat64:
+ w.writeFloat64(v.Float64())
+ case slog.KindInt64:
+ w.writeInt64(v.Int64())
+ case slog.KindString:
+ w.writeString(v.String())
+ case slog.KindTime:
+ w.writeTime(v.Time())
+ case slog.KindUint64:
+ w.writeUint64(v.Uint64())
+ case slog.KindGroup:
+ w.writeAttrs(v.Group())
+ default:
+ w.writeString("unhandled kind")
+ }
+}
+
+// writeName writes an object field name followed by a colon.
+func (w *jsonWriter) writeName(name string) {
+ if b := w.buf.Bytes(); len(b) > 0 && b[len(b)-1] != '{' {
+ // Add the comma separating this from the previous field.
+ w.buf.WriteByte(',')
+ }
+ w.writeString(name)
+ w.buf.WriteByte(':')
+}
+
+func (w *jsonWriter) writeObject(f func()) {
+ w.buf.WriteByte('{')
+ f()
+ w.buf.WriteByte('}')
+}
+
+// writeObjectField writes an object-valued object field.
+// The function f is called to write the contents.
+func (w *jsonWriter) writeObjectField(name string, f func()) {
+ w.writeName(name)
+ w.writeObject(f)
+}
+
+func (w *jsonWriter) writeArray(vals []slog.Value) {
+ w.buf.WriteByte('[')
+ for i, v := range vals {
+ if i != 0 {
+ w.buf.WriteByte(',')
+ }
+ w.writeValue(v)
+ }
+ w.buf.WriteByte(']')
+}
+
+func (w *jsonWriter) writeRaw(v string) {
+ w.buf.WriteString(v)
+}
+
+// writeRawField writes a field with a raw JSON value.
+func (w *jsonWriter) writeRawField(name, v string) {
+ w.writeName(name)
+ w.writeRaw(v)
+}
+
+func (w *jsonWriter) writeBool(v bool) {
+ if v {
+ w.buf.WriteString("true")
+ } else {
+ w.buf.WriteString("false")
+ }
+}
+
+// writeBoolField writes a bool-valued object field.
+func (w *jsonWriter) writeBoolField(name string, v bool) {
+ w.writeName(name)
+ w.writeBool(v)
+}
+
+// writeDuration writes a duration as milliseconds.
+func (w *jsonWriter) writeDuration(v time.Duration) {
+ if v < 0 {
+ w.buf.WriteByte('-')
+ v = -v
+ }
+ fmt.Fprintf(&w.buf, "%d.%06d", v.Milliseconds(), v%time.Millisecond)
+}
+
+// writeDurationField writes a millisecond duration-valued object field.
+func (w *jsonWriter) writeDurationField(name string, v time.Duration) {
+ w.writeName(name)
+ w.writeDuration(v)
+}
+
+func (w *jsonWriter) writeFloat64(v float64) {
+ w.buf.Write(strconv.AppendFloat(w.buf.AvailableBuffer(), v, 'f', -1, 64))
+}
+
+// writeFloat64Field writes an float64-valued object field.
+func (w *jsonWriter) writeFloat64Field(name string, v float64) {
+ w.writeName(name)
+ w.writeFloat64(v)
+}
+
+func (w *jsonWriter) writeInt64(v int64) {
+ w.buf.Write(strconv.AppendInt(w.buf.AvailableBuffer(), v, 10))
+}
+
+// writeInt64Field writes an int64-valued object field.
+func (w *jsonWriter) writeInt64Field(name string, v int64) {
+ w.writeName(name)
+ w.writeInt64(v)
+}
+
+func (w *jsonWriter) writeUint64(v uint64) {
+ w.buf.Write(strconv.AppendUint(w.buf.AvailableBuffer(), v, 10))
+}
+
+// writeUint64Field writes a uint64-valued object field.
+func (w *jsonWriter) writeUint64Field(name string, v uint64) {
+ w.writeName(name)
+ w.writeUint64(v)
+}
+
+// writeTime writes a time as seconds since the Unix epoch.
+func (w *jsonWriter) writeTime(v time.Time) {
+ fmt.Fprintf(&w.buf, "%d.%06d", v.UnixMilli(), v.Nanosecond()%int(time.Millisecond))
+}
+
+// writeTimeField writes a time-valued object field.
+func (w *jsonWriter) writeTimeField(name string, v time.Time) {
+ w.writeName(name)
+ w.writeTime(v)
+}
+
+func jsonSafeSet(c byte) bool {
+	// mask is a 128-bit bitmap with 1s for allowed bytes,
+	// so that the byte c can be tested with a shift and an and.
+	// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
+	// and so will the result.
+	const mask = 0 |
+		(1<<(0x7f-0x20)-1)<<0x20 & // all printable ASCII (0x20-0x7e)...
+			^(1 << '"') & // ...except the JSON-special bytes " ...
+			^(1 << '\\') // ...and \
+	return ((uint64(1)<<c)&(mask&(1<<64-1)) |
+		(uint64(1)<<(c-64))&(mask>>64)) != 0
+}
+
+func jsonNeedsEscape(s string) bool {
+ for i := range s {
+ if !jsonSafeSet(s[i]) {
+ return true
+ }
+ }
+ return false
+}
+
+// writeString writes an ASCII string.
+//
+// qlog fields should never contain anything that isn't ASCII,
+// so we do the bare minimum to avoid producing invalid output if we
+// do write something unexpected.
+func (w *jsonWriter) writeString(v string) {
+ w.buf.WriteByte('"')
+ if !jsonNeedsEscape(v) {
+ w.buf.WriteString(v)
+ } else {
+ for i := range v {
+ if jsonSafeSet(v[i]) {
+ w.buf.WriteByte(v[i])
+ } else {
+ fmt.Fprintf(&w.buf, `\u%04x`, v[i])
+ }
+ }
+ }
+ w.buf.WriteByte('"')
+}
+
+// writeStringField writes a string-valued object field.
+func (w *jsonWriter) writeStringField(name, v string) {
+ w.writeName(name)
+ w.writeString(v)
+}
diff --git a/quic/qlog/json_writer_test.go b/quic/qlog/json_writer_test.go
new file mode 100644
index 0000000000..03cf6947ce
--- /dev/null
+++ b/quic/qlog/json_writer_test.go
@@ -0,0 +1,196 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package qlog
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "log/slog"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+type testJSONOut struct {
+ bytes.Buffer
+}
+
+func (o *testJSONOut) Close() error { return nil }
+
+func newTestJSONWriter() *jsonWriter {
+ return &jsonWriter{w: &testJSONOut{}}
+}
+
+func wantJSONRecord(t *testing.T, w *jsonWriter, want string) {
+ t.Helper()
+ want = "\x1e" + want + "\n"
+ got := w.w.(*testJSONOut).String()
+ if got != want {
+ t.Errorf("jsonWriter contains unexpected output\ngot: %q\nwant: %q", got, want)
+ }
+}
+
+func TestJSONWriterWriteConcurrentRecords(t *testing.T) {
+ w := newTestJSONWriter()
+ var wg sync.WaitGroup
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ w.writeRecordStart()
+ w.writeInt64Field("field", 0)
+ w.writeRecordEnd()
+ }()
+ }
+ wg.Wait()
+ wantJSONRecord(t, w, strings.Join([]string{
+ `{"field":0}`,
+ `{"field":0}`,
+ `{"field":0}`,
+ }, "\n\x1e"))
+}
+
+func TestJSONWriterAttrs(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeAttrsField("field", []slog.Attr{
+ slog.Any("any", errors.New("value")),
+ slog.Bool("bool", true),
+ slog.Duration("duration", 1*time.Second),
+ slog.Float64("float64", 1),
+ slog.Int64("int64", 1),
+ slog.String("string", "value"),
+ slog.Time("time", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)),
+ slog.Uint64("uint64", 1),
+ slog.Group("group", "a", 1),
+ })
+ w.writeRecordEnd()
+ wantJSONRecord(t, w,
+ `{"field":{`+
+ `"any":"value",`+
+ `"bool":true,`+
+ `"duration":1000.000000,`+
+ `"float64":1,`+
+ `"int64":1,`+
+ `"string":"value",`+
+ `"time":946684800000.000000,`+
+ `"uint64":1,`+
+ `"group":{"a":1}`+
+ `}}`)
+}
+
+func TestJSONWriterAttrEmpty(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ var a slog.Attr
+ w.writeAttr(a)
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{}`)
+}
+
+func TestJSONWriterObjectEmpty(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeObjectField("field", func() {})
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"field":{}}`)
+}
+
+func TestJSONWriterObjectFields(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeObjectField("field", func() {
+ w.writeStringField("a", "value")
+ w.writeInt64Field("b", 10)
+ })
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"field":{"a":"value","b":10}}`)
+}
+
+func TestJSONWriterRawField(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeRawField("field", `[1]`)
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"field":[1]}`)
+}
+
+func TestJSONWriterBoolField(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeBoolField("true", true)
+ w.writeBoolField("false", false)
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"true":true,"false":false}`)
+}
+
+func TestJSONWriterDurationField(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeDurationField("field1", (10*time.Millisecond)+(2*time.Nanosecond))
+ w.writeDurationField("field2", -((10 * time.Millisecond) + (2 * time.Nanosecond)))
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"field1":10.000002,"field2":-10.000002}`)
+}
+
+func TestJSONWriterFloat64Field(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeFloat64Field("field", 1.1)
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"field":1.1}`)
+}
+
+func TestJSONWriterInt64Field(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeInt64Field("field", 1234)
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"field":1234}`)
+}
+
+func TestJSONWriterUint64Field(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeUint64Field("field", 1234)
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"field":1234}`)
+}
+
+func TestJSONWriterStringField(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeStringField("field", "value")
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"field":"value"}`)
+}
+
+func TestJSONWriterStringFieldEscaped(t *testing.T) {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeStringField("field", "va\x00ue")
+ w.writeRecordEnd()
+ wantJSONRecord(t, w, `{"field":"va\u0000ue"}`)
+}
+
+func TestJSONWriterStringEscaping(t *testing.T) {
+ for c := 0; c <= 0xff; c++ {
+ w := newTestJSONWriter()
+ w.writeRecordStart()
+ w.writeStringField("field", string([]byte{byte(c)}))
+ w.writeRecordEnd()
+ var want string
+ if (c >= 0x20 && c <= 0x21) || (c >= 0x23 && c <= 0x5b) || (c >= 0x5d && c <= 0x7e) {
+ want = fmt.Sprintf(`%c`, c)
+ } else {
+ want = fmt.Sprintf(`\u%04x`, c)
+ }
+ wantJSONRecord(t, w, `{"field":"`+want+`"}`)
+ }
+}
diff --git a/quic/qlog/qlog.go b/quic/qlog/qlog.go
new file mode 100644
index 0000000000..f33c6b0fd9
--- /dev/null
+++ b/quic/qlog/qlog.go
@@ -0,0 +1,267 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+// Package qlog serializes qlog events.
+package qlog
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "sync"
+ "time"
+)
+
+// Vantage is the vantage point of a trace.
+type Vantage string
+
+const (
+ // VantageEndpoint traces contain events not specific to a single connection.
+ VantageEndpoint = Vantage("endpoint")
+
+ // VantageClient traces follow a connection from the client's perspective.
+ VantageClient = Vantage("client")
+
+ // VantageServer traces follow a connection from the server's perspective.
+ VantageServer = Vantage("server")
+)
+
+// TraceInfo contains information about a trace.
+type TraceInfo struct {
+ // Vantage is the vantage point of the trace.
+ Vantage Vantage
+
+ // GroupID identifies the logical group the trace belongs to.
+ // For a connection trace, the group will be the same for
+ // both the client and server vantage points.
+ GroupID string
+}
+
+// HandlerOptions are options for a JSONHandler.
+type HandlerOptions struct {
+ // Level reports the minimum record level that will be logged.
+ // If Level is nil, the handler assumes QLogLevelEndpoint.
+ Level slog.Leveler
+
+ // Dir is the directory in which to create trace files.
+ // The handler will create one file per connection.
+ // If NewTrace is non-nil or Dir is "", the handler will not create files.
+ Dir string
+
+ // NewTrace is called to create a new trace.
+ // If NewTrace is nil and Dir is set,
+ // the handler will create a new file in Dir for each trace.
+ NewTrace func(TraceInfo) (io.WriteCloser, error)
+}
+
+type endpointHandler struct {
+ opts HandlerOptions
+
+ traceOnce sync.Once
+ trace *jsonTraceHandler
+}
+
+// NewJSONHandler returns a handler which serializes qlog events to JSON.
+//
+// The handler will write an endpoint-wide trace,
+// and a separate trace for each connection.
+// The HandlerOptions control the location traces are written.
+//
+// It uses the streamable JSON Text Sequences mapping (JSON-SEQ)
+// defined in draft-ietf-quic-qlog-main-schema-04, Section 6.2.
+//
+// A JSONHandler may be used as the handler for a quic.Config.QLogLogger.
+// It is not a general-purpose slog handler,
+// and may not properly handle events from other sources.
+func NewJSONHandler(opts HandlerOptions) slog.Handler {
+ if opts.Dir == "" && opts.NewTrace == nil {
+ return slogDiscard{}
+ }
+ return &endpointHandler{
+ opts: opts,
+ }
+}
+
+func (h *endpointHandler) Enabled(ctx context.Context, level slog.Level) bool {
+ return enabled(h.opts.Level, level)
+}
+
+func (h *endpointHandler) Handle(ctx context.Context, r slog.Record) error {
+ h.traceOnce.Do(func() {
+ h.trace, _ = newJSONTraceHandler(h.opts, nil)
+ })
+ if h.trace != nil {
+ h.trace.Handle(ctx, r)
+ }
+ return nil
+}
+
+func (h *endpointHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
+ // Create a new trace output file for each top-level WithAttrs.
+ tr, err := newJSONTraceHandler(h.opts, attrs)
+ if err != nil {
+ return withAttrs(h, attrs)
+ }
+ return tr
+}
+
+func (h *endpointHandler) WithGroup(name string) slog.Handler {
+ return withGroup(h, name)
+}
+
+type jsonTraceHandler struct {
+ level slog.Leveler
+ w jsonWriter
+ start time.Time
+ buf bytes.Buffer
+}
+
+func newJSONTraceHandler(opts HandlerOptions, attrs []slog.Attr) (*jsonTraceHandler, error) {
+ w, err := newTraceWriter(opts, traceInfoFromAttrs(attrs))
+ if err != nil {
+ return nil, err
+ }
+
+ // For testing, it might be nice to set the start time used for relative timestamps
+ // to the time of the first event.
+ //
+ // At the expense of some additional complexity here, we could defer writing
+ // the reference_time header field until the first event is processed.
+ //
+ // Just use the current time for now.
+ start := time.Now()
+
+ h := &jsonTraceHandler{
+ w: jsonWriter{w: w},
+ level: opts.Level,
+ start: start,
+ }
+ h.writeHeader(attrs)
+ return h, nil
+}
+
+func traceInfoFromAttrs(attrs []slog.Attr) TraceInfo {
+ info := TraceInfo{
+ Vantage: VantageEndpoint, // default if not specified
+ }
+ for _, a := range attrs {
+ if a.Key == "group_id" && a.Value.Kind() == slog.KindString {
+ info.GroupID = a.Value.String()
+ }
+ if a.Key == "vantage_point" && a.Value.Kind() == slog.KindGroup {
+ for _, aa := range a.Value.Group() {
+ if aa.Key == "type" && aa.Value.Kind() == slog.KindString {
+ info.Vantage = Vantage(aa.Value.String())
+ }
+ }
+ }
+ }
+ return info
+}
+
+func newTraceWriter(opts HandlerOptions, info TraceInfo) (io.WriteCloser, error) {
+ var w io.WriteCloser
+ var err error
+ if opts.NewTrace != nil {
+ w, err = opts.NewTrace(info)
+ } else if opts.Dir != "" {
+ var filename string
+ if info.GroupID != "" {
+ filename = info.GroupID + "_"
+ }
+ filename += string(info.Vantage) + ".sqlog"
+ if !filepath.IsLocal(filename) {
+ return nil, errors.New("invalid trace filename")
+ }
+ w, err = os.OpenFile(filepath.Join(opts.Dir, filename), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0666)
+ } else {
+ err = errors.New("no log destination")
+ }
+ return w, err
+}
+
+func (h *jsonTraceHandler) writeHeader(attrs []slog.Attr) {
+ h.w.writeRecordStart()
+ defer h.w.writeRecordEnd()
+
+ // At the time of writing this comment the most recent version is 0.4,
+ // but qvis only supports up to 0.3.
+ h.w.writeStringField("qlog_version", "0.3")
+ h.w.writeStringField("qlog_format", "JSON-SEQ")
+
+ // The attrs flatten both common trace event fields and Trace fields.
+ // This identifies the fields that belong to the Trace.
+ isTraceSeqField := func(s string) bool {
+ switch s {
+ case "title", "description", "configuration", "vantage_point":
+ return true
+ }
+ return false
+ }
+
+ h.w.writeObjectField("trace", func() {
+ h.w.writeObjectField("common_fields", func() {
+ h.w.writeRawField("protocol_type", `["QUIC"]`)
+ h.w.writeStringField("time_format", "relative")
+ h.w.writeTimeField("reference_time", h.start)
+ for _, a := range attrs {
+ if !isTraceSeqField(a.Key) {
+ h.w.writeAttr(a)
+ }
+ }
+ })
+ for _, a := range attrs {
+ if isTraceSeqField(a.Key) {
+ h.w.writeAttr(a)
+ }
+ }
+ })
+}
+
+func (h *jsonTraceHandler) Enabled(ctx context.Context, level slog.Level) bool {
+ return enabled(h.level, level)
+}
+
+func (h *jsonTraceHandler) Handle(ctx context.Context, r slog.Record) error {
+ h.w.writeRecordStart()
+ defer h.w.writeRecordEnd()
+ h.w.writeDurationField("time", r.Time.Sub(h.start))
+ h.w.writeStringField("name", r.Message)
+ h.w.writeObjectField("data", func() {
+ r.Attrs(func(a slog.Attr) bool {
+ h.w.writeAttr(a)
+ return true
+ })
+ })
+ return nil
+}
+
+func (h *jsonTraceHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
+ return withAttrs(h, attrs)
+}
+
+func (h *jsonTraceHandler) WithGroup(name string) slog.Handler {
+ return withGroup(h, name)
+}
+
+func enabled(leveler slog.Leveler, level slog.Level) bool {
+ var minLevel slog.Level
+ if leveler != nil {
+ minLevel = leveler.Level()
+ }
+ return level >= minLevel
+}
+
+type slogDiscard struct{}
+
+func (slogDiscard) Enabled(context.Context, slog.Level) bool { return false }
+func (slogDiscard) Handle(ctx context.Context, r slog.Record) error { return nil }
+func (slogDiscard) WithAttrs(attrs []slog.Attr) slog.Handler { return slogDiscard{} }
+func (slogDiscard) WithGroup(name string) slog.Handler { return slogDiscard{} }
diff --git a/quic/qlog/qlog_test.go b/quic/qlog/qlog_test.go
new file mode 100644
index 0000000000..7575cd890e
--- /dev/null
+++ b/quic/qlog/qlog_test.go
@@ -0,0 +1,151 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package qlog
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "reflect"
+ "testing"
+ "time"
+)
+
+// QLog tests are mostly in the quic package, where we can test event generation
+// and serialization together.
+
+func TestQLogHandlerEvents(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ f func(*slog.Logger)
+ want []map[string]any // events, not counting the trace header
+ }{{
+ name: "various types",
+ f: func(log *slog.Logger) {
+ log.Info("message",
+ "bool", true,
+ "duration", time.Duration(1*time.Second),
+ "float", 0.0,
+ "int", 0,
+ "string", "value",
+ "uint", uint64(0),
+ slog.Group("group",
+ "a", 0,
+ ),
+ )
+ },
+ want: []map[string]any{{
+ "name": "message",
+ "data": map[string]any{
+ "bool": true,
+ "duration": float64(1000),
+ "float": float64(0.0),
+ "int": float64(0),
+ "string": "value",
+ "uint": float64(0),
+ "group": map[string]any{
+ "a": float64(0),
+ },
+ },
+ }},
+ }, {
+ name: "WithAttrs",
+ f: func(log *slog.Logger) {
+ log = log.With(
+ "with_a", "a",
+ "with_b", "b",
+ )
+ log.Info("m1", "field", "1")
+ log.Info("m2", "field", "2")
+ },
+ want: []map[string]any{{
+ "name": "m1",
+ "data": map[string]any{
+ "with_a": "a",
+ "with_b": "b",
+ "field": "1",
+ },
+ }, {
+ "name": "m2",
+ "data": map[string]any{
+ "with_a": "a",
+ "with_b": "b",
+ "field": "2",
+ },
+ }},
+ }, {
+ name: "WithGroup",
+ f: func(log *slog.Logger) {
+ log = log.With(
+ "with_a", "a",
+ "with_b", "b",
+ )
+ log.Info("m1", "field", "1")
+ log.Info("m2", "field", "2")
+ },
+ want: []map[string]any{{
+ "name": "m1",
+ "data": map[string]any{
+ "with_a": "a",
+ "with_b": "b",
+ "field": "1",
+ },
+ }, {
+ "name": "m2",
+ "data": map[string]any{
+ "with_a": "a",
+ "with_b": "b",
+ "field": "2",
+ },
+ }},
+ }} {
+ var out bytes.Buffer
+ opts := HandlerOptions{
+ Level: slog.LevelDebug,
+ NewTrace: func(TraceInfo) (io.WriteCloser, error) {
+ return nopCloseWriter{&out}, nil
+ },
+ }
+ h, err := newJSONTraceHandler(opts, []slog.Attr{
+ slog.String("group_id", "group"),
+ slog.Group("vantage_point",
+ slog.String("type", "client"),
+ ),
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ log := slog.New(h)
+ test.f(log)
+ got := []map[string]any{}
+ for i, e := range bytes.Split(out.Bytes(), []byte{0x1e}) {
+ // i==0: empty string before the initial record separator
+ // i==1: trace header; not part of this test
+ if i < 2 {
+ continue
+ }
+ var val map[string]any
+ if err := json.Unmarshal(e, &val); err != nil {
+ panic(fmt.Errorf("log unmarshal failure: %v\n%q", err, string(e)))
+ }
+ delete(val, "time")
+ got = append(got, val)
+ }
+ if !reflect.DeepEqual(got, test.want) {
+ t.Errorf("event mismatch\ngot: %v\nwant: %v", got, test.want)
+ }
+ }
+
+}
+
+type nopCloseWriter struct {
+ io.Writer
+}
+
+func (nopCloseWriter) Close() error { return nil }
diff --git a/quic/qlog_test.go b/quic/qlog_test.go
new file mode 100644
index 0000000000..c0b5cd170f
--- /dev/null
+++ b/quic/qlog_test.go
@@ -0,0 +1,364 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "reflect"
+ "testing"
+ "time"
+
+ "golang.org/x/net/quic/qlog"
+)
+
+func TestQLogHandshake(t *testing.T) {
+ testSides(t, "", func(t *testing.T, side connSide) {
+ qr := &qlogRecord{}
+ tc := newTestConn(t, side, qr.config)
+ tc.handshake()
+ tc.conn.Abort(nil)
+ tc.wantFrame("aborting connection generates CONN_CLOSE",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errNo,
+ })
+ tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{})
+ tc.advanceToTimer() // let the conn finish draining
+
+ var src, dst []byte
+ if side == clientSide {
+ src = testLocalConnID(0)
+ dst = testLocalConnID(-1)
+ } else {
+ src = testPeerConnID(-1)
+ dst = testPeerConnID(0)
+ }
+ qr.wantEvents(t, jsonEvent{
+ "name": "connectivity:connection_started",
+ "data": map[string]any{
+ "src_cid": hex.EncodeToString(src),
+ "dst_cid": hex.EncodeToString(dst),
+ },
+ }, jsonEvent{
+ "name": "connectivity:connection_closed",
+ "data": map[string]any{
+ "trigger": "clean",
+ },
+ })
+ })
+}
+
+func TestQLogPacketFrames(t *testing.T) {
+ qr := &qlogRecord{}
+ tc := newTestConn(t, clientSide, qr.config)
+ tc.handshake()
+ tc.conn.Abort(nil)
+ tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{})
+ tc.advanceToTimer() // let the conn finish draining
+
+ qr.wantEvents(t, jsonEvent{
+ "name": "transport:packet_sent",
+ "data": map[string]any{
+ "header": map[string]any{
+ "packet_type": "initial",
+ "packet_number": 0,
+ "dcid": hex.EncodeToString(testLocalConnID(-1)),
+ "scid": hex.EncodeToString(testLocalConnID(0)),
+ },
+ "frames": []any{
+ map[string]any{"frame_type": "crypto"},
+ },
+ },
+ }, jsonEvent{
+ "name": "transport:packet_received",
+ "data": map[string]any{
+ "header": map[string]any{
+ "packet_type": "initial",
+ "packet_number": 0,
+ "dcid": hex.EncodeToString(testLocalConnID(0)),
+ "scid": hex.EncodeToString(testPeerConnID(0)),
+ },
+ "frames": []any{map[string]any{"frame_type": "crypto"}},
+ },
+ })
+}
+
+func TestQLogConnectionClosedTrigger(t *testing.T) {
+ for _, test := range []struct {
+ trigger string
+ connOpts []any
+ f func(*testConn)
+ }{{
+ trigger: "clean",
+ f: func(tc *testConn) {
+ tc.handshake()
+ tc.conn.Abort(nil)
+ },
+ }, {
+ trigger: "handshake_timeout",
+ connOpts: []any{
+ func(c *Config) {
+ c.HandshakeTimeout = 5 * time.Second
+ },
+ },
+ f: func(tc *testConn) {
+ tc.ignoreFrame(frameTypeCrypto)
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypePing)
+ tc.advance(5 * time.Second)
+ },
+ }, {
+ trigger: "idle_timeout",
+ connOpts: []any{
+ func(c *Config) {
+ c.MaxIdleTimeout = 5 * time.Second
+ },
+ },
+ f: func(tc *testConn) {
+ tc.handshake()
+ tc.advance(5 * time.Second)
+ },
+ }, {
+ trigger: "error",
+ f: func(tc *testConn) {
+ tc.handshake()
+ tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errProtocolViolation,
+ })
+ tc.conn.Abort(nil)
+ },
+ }} {
+ t.Run(test.trigger, func(t *testing.T) {
+ qr := &qlogRecord{}
+ tc := newTestConn(t, clientSide, append(test.connOpts, qr.config)...)
+ test.f(tc)
+ fr, ptype := tc.readFrame()
+ switch fr := fr.(type) {
+ case debugFrameConnectionCloseTransport:
+ tc.writeFrames(ptype, fr)
+ case nil:
+ default:
+ t.Fatalf("unexpected frame: %v", fr)
+ }
+ tc.wantIdle("connection should be idle while closing")
+ tc.advance(5 * time.Second) // long enough for the drain timer to expire
+ qr.wantEvents(t, jsonEvent{
+ "name": "connectivity:connection_closed",
+ "data": map[string]any{
+ "trigger": test.trigger,
+ },
+ })
+ })
+ }
+}
+
+func TestQLogRecovery(t *testing.T) {
+ qr := &qlogRecord{}
+ tc, s := newTestConnAndLocalStream(t, clientSide, uniStream,
+ permissiveTransportParameters, qr.config)
+
+ // Ignore events from the handshake.
+ qr.ev = nil
+
+ data := make([]byte, 16)
+ s.Write(data)
+ s.CloseWrite()
+ tc.wantFrame("created stream 0",
+ packetType1RTT, debugFrameStream{
+ id: newStreamID(clientSide, uniStream, 0),
+ fin: true,
+ data: data,
+ })
+ tc.writeAckForAll()
+ tc.wantIdle("connection should be idle now")
+
+ // Don't check the contents of fields, but verify that recovery metrics are logged.
+ qr.wantEvents(t, jsonEvent{
+ "name": "recovery:metrics_updated",
+ "data": map[string]any{
+ "bytes_in_flight": nil,
+ },
+ }, jsonEvent{
+ "name": "recovery:metrics_updated",
+ "data": map[string]any{
+ "bytes_in_flight": 0,
+ "congestion_window": nil,
+ "latest_rtt": nil,
+ "min_rtt": nil,
+ "rtt_variance": nil,
+ "smoothed_rtt": nil,
+ },
+ })
+}
+
+func TestQLogLoss(t *testing.T) {
+ qr := &qlogRecord{}
+ tc, s := newTestConnAndLocalStream(t, clientSide, uniStream,
+ permissiveTransportParameters, qr.config)
+
+ // Ignore events from the handshake.
+ qr.ev = nil
+
+ data := make([]byte, 16)
+ s.Write(data)
+ s.CloseWrite()
+ tc.wantFrame("created stream 0",
+ packetType1RTT, debugFrameStream{
+ id: newStreamID(clientSide, uniStream, 0),
+ fin: true,
+ data: data,
+ })
+
+ const pto = false
+ tc.triggerLossOrPTO(packetType1RTT, pto)
+
+ qr.wantEvents(t, jsonEvent{
+ "name": "recovery:packet_lost",
+ "data": map[string]any{
+ "header": map[string]any{
+ "packet_number": nil,
+ "packet_type": "1RTT",
+ },
+ },
+ })
+}
+
+func TestQLogPacketDropped(t *testing.T) {
+ qr := &qlogRecord{}
+ tc := newTestConn(t, clientSide, permissiveTransportParameters, qr.config)
+ tc.handshake()
+
+ // A garbage-filled datagram with a DCID matching this connection.
+ dgram := bytes.Join([][]byte{
+ {headerFormShort | fixedBit},
+ testLocalConnID(0),
+ make([]byte, 100),
+ []byte{1, 2, 3, 4}, // random data, to avoid this looking like a stateless reset
+ }, nil)
+ tc.endpoint.write(&datagram{
+ b: dgram,
+ })
+
+ qr.wantEvents(t, jsonEvent{
+ "name": "connectivity:packet_dropped",
+ })
+}
+
+type nopCloseWriter struct {
+ io.Writer
+}
+
+func (nopCloseWriter) Close() error { return nil }
+
+type jsonEvent map[string]any
+
+func (j jsonEvent) String() string {
+ b, _ := json.MarshalIndent(j, "", " ")
+ return string(b)
+}
+
+// jsonPartialEqual compares two JSON structures.
+// It ignores fields not set in want (see below for specifics).
+func jsonPartialEqual(got, want any) (equal bool) {
+ cmpval := func(v any) any {
+ // Map certain types to a common representation.
+ switch v := v.(type) {
+ case int:
+ // JSON uses float64s rather than ints for numbers.
+ // Map int->float64 so we can use integers in expectations.
+ return float64(v)
+ case jsonEvent:
+ return (map[string]any)(v)
+ case []jsonEvent:
+ s := []any{}
+ for _, e := range v {
+ s = append(s, e)
+ }
+ return s
+ }
+ return v
+ }
+ if want == nil {
+ return true // match anything
+ }
+ got = cmpval(got)
+ want = cmpval(want)
+ if reflect.TypeOf(got) != reflect.TypeOf(want) {
+ return false
+ }
+ switch w := want.(type) {
+ case map[string]any:
+ // JSON object: Every field in want must match a field in got.
+ g := got.(map[string]any)
+ for k := range w {
+ if !jsonPartialEqual(g[k], w[k]) {
+ return false
+ }
+ }
+ case []any:
+ // JSON slice: Every field in want must match a field in got, in order.
+ // So want=[2,4] matches got=[1,2,3,4] but not [4,2].
+ g := got.([]any)
+ for _, ge := range g {
+ if jsonPartialEqual(ge, w[0]) {
+ w = w[1:]
+ if len(w) == 0 {
+ return true
+ }
+ }
+ }
+ return false
+ default:
+ if !reflect.DeepEqual(got, want) {
+ return false
+ }
+ }
+ return true
+}
+
+// A qlogRecord records events.
+type qlogRecord struct {
+ ev []jsonEvent
+}
+
+func (q *qlogRecord) Write(b []byte) (int, error) {
+ // This relies on the property that the Handler always makes one Write call per event.
+ if len(b) < 1 || b[0] != 0x1e {
+ panic(fmt.Errorf("trace Write should start with record separator, got %q", string(b)))
+ }
+ var val map[string]any
+ if err := json.Unmarshal(b[1:], &val); err != nil {
+ panic(fmt.Errorf("log unmarshal failure: %v\n%v", err, string(b)))
+ }
+ q.ev = append(q.ev, val)
+ return len(b), nil
+}
+
+func (q *qlogRecord) Close() error { return nil }
+
+// config may be passed to newTestConn to configure the conn to use this logger.
+func (q *qlogRecord) config(c *Config) {
+ c.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
+ Level: QLogLevelFrame,
+ NewTrace: func(info qlog.TraceInfo) (io.WriteCloser, error) {
+ return q, nil
+ },
+ }))
+}
+
+// wantEvents checks that every event in want occurs in the order specified.
+func (q *qlogRecord) wantEvents(t *testing.T, want ...jsonEvent) {
+ t.Helper()
+ got := q.ev
+ if !jsonPartialEqual(got, want) {
+ t.Fatalf("got events:\n%v\n\nwant events:\n%v", got, want)
+ }
+}
diff --git a/internal/quic/queue.go b/quic/queue.go
similarity index 100%
rename from internal/quic/queue.go
rename to quic/queue.go
diff --git a/internal/quic/queue_test.go b/quic/queue_test.go
similarity index 100%
rename from internal/quic/queue_test.go
rename to quic/queue_test.go
diff --git a/internal/quic/quic.go b/quic/quic.go
similarity index 89%
rename from internal/quic/quic.go
rename to quic/quic.go
index 9de97b6d88..3e62d7cd94 100644
--- a/internal/quic/quic.go
+++ b/quic/quic.go
@@ -54,13 +54,24 @@ const (
maxPeerActiveConnIDLimit = 4
)
+// Time limit for completing the handshake.
+const defaultHandshakeTimeout = 10 * time.Second
+
+// Keep-alive ping frequency.
+const defaultKeepAlivePeriod = 0
+
// Local timer granularity.
// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2-6
const timerGranularity = 1 * time.Millisecond
-// Minimum size of a UDP datagram sent by a client carrying an Initial packet.
+// The smallest allowed maximum datagram size.
+// https://www.rfc-editor.org/rfc/rfc9000#section-14
+const smallestMaxDatagramSize = 1200
+
+// Minimum size of a UDP datagram sent by a client carrying an Initial packet,
+// or a server containing an ack-eliciting Initial packet.
// https://www.rfc-editor.org/rfc/rfc9000#section-14.1
-const minimumClientInitialDatagramSize = 1200
+const paddedInitialDatagramSize = smallestMaxDatagramSize
// Maximum number of streams of a given type which may be created.
// https://www.rfc-editor.org/rfc/rfc9000.html#section-4.6-2
@@ -133,6 +144,17 @@ const (
streamTypeCount
)
+func (s streamType) qlogString() string {
+ switch s {
+ case bidiStream:
+ return "bidirectional"
+ case uniStream:
+ return "unidirectional"
+ default:
+ return "BUG"
+ }
+}
+
func (s streamType) String() string {
switch s {
case bidiStream:
diff --git a/internal/quic/quic_test.go b/quic/quic_test.go
similarity index 100%
rename from internal/quic/quic_test.go
rename to quic/quic_test.go
diff --git a/internal/quic/rangeset.go b/quic/rangeset.go
similarity index 94%
rename from internal/quic/rangeset.go
rename to quic/rangeset.go
index 4966a99d2c..528d53df39 100644
--- a/internal/quic/rangeset.go
+++ b/quic/rangeset.go
@@ -50,7 +50,7 @@ func (s *rangeset[T]) add(start, end T) {
if end <= r.end {
return
}
- // Possibly coalesce subsquent ranges into range i.
+ // Possibly coalesce subsequent ranges into range i.
r.end = end
j := i + 1
for ; j < len(*s) && r.end >= (*s)[j].start; j++ {
@@ -159,6 +159,14 @@ func (s rangeset[T]) numRanges() int {
return len(s)
}
+// size returns the size of all ranges in the rangeset.
+func (s rangeset[T]) size() (total T) {
+ for _, r := range s {
+ total += r.size()
+ }
+ return total
+}
+
// isrange reports if the rangeset covers exactly the range [start, end).
func (s rangeset[T]) isrange(start, end T) bool {
switch len(s) {
diff --git a/internal/quic/rangeset_test.go b/quic/rangeset_test.go
similarity index 100%
rename from internal/quic/rangeset_test.go
rename to quic/rangeset_test.go
diff --git a/quic/retry.go b/quic/retry.go
new file mode 100644
index 0000000000..8c56ee1b10
--- /dev/null
+++ b/quic/retry.go
@@ -0,0 +1,239 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "encoding/binary"
+ "net/netip"
+ "time"
+
+ "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/net/internal/quic/quicwire"
+)
+
+// AEAD and nonce used to compute the Retry Integrity Tag.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.8
+var (
+ retrySecret = []byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}
+ retryNonce = []byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}
+ retryAEAD = func() cipher.AEAD {
+ c, err := aes.NewCipher(retrySecret)
+ if err != nil {
+ panic(err)
+ }
+ aead, err := cipher.NewGCM(c)
+ if err != nil {
+ panic(err)
+ }
+ return aead
+ }()
+)
+
+// retryTokenValidityPeriod is how long we accept a Retry packet token after sending it.
+const retryTokenValidityPeriod = 5 * time.Second
+
+// retryState generates and validates an endpoint's retry tokens.
+type retryState struct {
+ aead cipher.AEAD
+}
+
+func (rs *retryState) init() error {
+ // Retry tokens are authenticated using a per-server key chosen at start time.
+ // TODO: Provide a way for the user to set this key.
+ secret := make([]byte, chacha20poly1305.KeySize)
+ if _, err := rand.Read(secret); err != nil {
+ return err
+ }
+ aead, err := chacha20poly1305.NewX(secret)
+ if err != nil {
+ panic(err)
+ }
+ rs.aead = aead
+ return nil
+}
+
+// Retry tokens are encrypted with an AEAD.
+// The plaintext contains the time the token was created and
+// the original destination connection ID.
+// The additional data contains the sender's source address and original source connection ID.
+// The token nonce is randomly generated.
+// We use the nonce as the Source Connection ID of the Retry packet.
+// Since the 24-byte XChaCha20-Poly1305 nonce is too large to fit in a 20-byte connection ID,
+// we include the remaining 4 bytes of nonce in the token.
+//
+// Token {
+// Last 4 Bytes of Nonce (32),
+// Ciphertext (..),
+// }
+//
+// Plaintext {
+// Timestamp (64),
+// Original Destination Connection ID,
+// }
+//
+//
+// Additional Data {
+// Original Source Connection ID Length (8),
+// Original Source Connection ID (..),
+// IP Address (32..128),
+// Port (16),
+// }
+//
+// TODO: Consider using AES-256-GCM-SIV once crypto/tls supports it.
+
+func (rs *retryState) makeToken(now time.Time, srcConnID, origDstConnID []byte, addr netip.AddrPort) (token, newDstConnID []byte, err error) {
+ nonce := make([]byte, rs.aead.NonceSize())
+ if _, err := rand.Read(nonce); err != nil {
+ return nil, nil, err
+ }
+
+ var plaintext []byte
+ plaintext = binary.BigEndian.AppendUint64(plaintext, uint64(now.Unix()))
+ plaintext = append(plaintext, origDstConnID...)
+
+ token = append(token, nonce[maxConnIDLen:]...)
+ token = rs.aead.Seal(token, nonce, plaintext, rs.additionalData(srcConnID, addr))
+ return token, nonce[:maxConnIDLen], nil
+}
+
+func (rs *retryState) validateToken(now time.Time, token, srcConnID, dstConnID []byte, addr netip.AddrPort) (origDstConnID []byte, ok bool) {
+ tokenNonceLen := rs.aead.NonceSize() - maxConnIDLen
+ if len(token) < tokenNonceLen {
+ return nil, false
+ }
+ nonce := append([]byte{}, dstConnID...)
+ nonce = append(nonce, token[:tokenNonceLen]...)
+ ciphertext := token[tokenNonceLen:]
+
+ plaintext, err := rs.aead.Open(nil, nonce, ciphertext, rs.additionalData(srcConnID, addr))
+ if err != nil {
+ return nil, false
+ }
+ if len(plaintext) < 8 {
+ return nil, false
+ }
+ when := time.Unix(int64(binary.BigEndian.Uint64(plaintext)), 0)
+ origDstConnID = plaintext[8:]
+
+ // We allow for tokens created in the future (up to the validity period),
+ // which likely indicates that the system clock was adjusted backwards.
+ if d := abs(now.Sub(when)); d > retryTokenValidityPeriod {
+ return nil, false
+ }
+
+ return origDstConnID, true
+}
+
+func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []byte {
+ var additional []byte
+ additional = quicwire.AppendUint8Bytes(additional, srcConnID)
+ additional = append(additional, addr.Addr().AsSlice()...)
+ additional = binary.BigEndian.AppendUint16(additional, addr.Port())
+ return additional
+}
+
+func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) (origDstConnID []byte, ok bool) {
+ // The retry token is at the start of an Initial packet's data.
+ token, n := quicwire.ConsumeUint8Bytes(p.data)
+ if n < 0 {
+ // We've already validated that the packet is at least 1200 bytes long,
+ // so there's no way for even a maximum size token to not fit.
+ // Check anyway.
+ return nil, false
+ }
+ if len(token) == 0 {
+ // The sender has not provided a token.
+ // Send a Retry packet to them with one.
+ e.sendRetry(now, p, peerAddr)
+ return nil, false
+ }
+ origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, peerAddr)
+ if !ok {
+ // This does not seem to be a valid token.
+ // Close the connection with an INVALID_TOKEN error.
+ // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5
+ e.sendConnectionClose(p, peerAddr, errInvalidToken)
+ return nil, false
+ }
+ return origDstConnID, true
+}
+
+func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) {
+ token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, peerAddr)
+ if err != nil {
+ return
+ }
+ b := encodeRetryPacket(p.dstConnID, retryPacket{
+ dstConnID: p.srcConnID,
+ srcConnID: srcConnID,
+ token: token,
+ })
+ e.sendDatagram(datagram{
+ b: b,
+ peerAddr: peerAddr,
+ })
+}
+
+type retryPacket struct {
+ dstConnID []byte
+ srcConnID []byte
+ token []byte
+}
+
+func encodeRetryPacket(originalDstConnID []byte, p retryPacket) []byte {
+ // Retry packets include an integrity tag, computed by AEAD_AES_128_GCM over
+ // the original destination connection ID followed by the Retry packet
+ // (less the integrity tag itself).
+ // https://www.rfc-editor.org/rfc/rfc9001#section-5.8
+ //
+ // Create the pseudo-packet (including the original DCID), append the tag,
+ // and return the Retry packet.
+ var b []byte
+ b = quicwire.AppendUint8Bytes(b, originalDstConnID) // Original Destination Connection ID
+ start := len(b) // start of the Retry packet
+ b = append(b, headerFormLong|fixedBit|longPacketTypeRetry)
+ b = binary.BigEndian.AppendUint32(b, quicVersion1) // Version
+ b = quicwire.AppendUint8Bytes(b, p.dstConnID) // Destination Connection ID
+ b = quicwire.AppendUint8Bytes(b, p.srcConnID) // Source Connection ID
+ b = append(b, p.token...) // Token
+ b = retryAEAD.Seal(b, retryNonce, nil, b) // Retry Integrity Tag
+ return b[start:]
+}
+
+func parseRetryPacket(b, origDstConnID []byte) (p retryPacket, ok bool) {
+ const retryIntegrityTagLength = 128 / 8
+
+ lp, ok := parseGenericLongHeaderPacket(b)
+ if !ok {
+ return retryPacket{}, false
+ }
+ if len(lp.data) < retryIntegrityTagLength {
+ return retryPacket{}, false
+ }
+ gotTag := lp.data[len(lp.data)-retryIntegrityTagLength:]
+
+ // Create the pseudo-packet consisting of the original destination connection ID
+ // followed by the Retry packet (less the integrity tag).
+ // Use this to validate the packet integrity tag.
+ pseudo := quicwire.AppendUint8Bytes(nil, origDstConnID)
+ pseudo = append(pseudo, b[:len(b)-retryIntegrityTagLength]...)
+ wantTag := retryAEAD.Seal(nil, retryNonce, nil, pseudo)
+ if !bytes.Equal(gotTag, wantTag) {
+ return retryPacket{}, false
+ }
+
+ token := lp.data[:len(lp.data)-retryIntegrityTagLength]
+ return retryPacket{
+ dstConnID: lp.dstConnID,
+ srcConnID: lp.srcConnID,
+ token: token,
+ }, true
+}
diff --git a/quic/retry_test.go b/quic/retry_test.go
new file mode 100644
index 0000000000..c898ad331d
--- /dev/null
+++ b/quic/retry_test.go
@@ -0,0 +1,570 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "net/netip"
+ "testing"
+ "time"
+)
+
+type retryServerTest struct {
+ te *testEndpoint
+ originalSrcConnID []byte
+ originalDstConnID []byte
+ retry retryPacket
+ initialCrypto []byte
+}
+
+// newRetryServerTest creates a test server connection,
+// sends the connection an Initial packet,
+// and expects a Retry in response.
+func newRetryServerTest(t *testing.T) *retryServerTest {
+ t.Helper()
+ config := &Config{
+ TLSConfig: newTestTLSConfig(serverSide),
+ RequireAddressValidation: true,
+ }
+ te := newTestEndpoint(t, config)
+ srcID := testPeerConnID(0)
+ dstID := testLocalConnID(-1)
+ params := defaultTransportParameters()
+ params.initialSrcConnID = srcID
+ initialCrypto := initialClientCrypto(t, te, params)
+
+ // Initial packet with no Token.
+ // Server responds with a Retry containing a token.
+ te.writeDatagram(&testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeInitial,
+ num: 0,
+ version: quicVersion1,
+ srcConnID: srcID,
+ dstConnID: dstID,
+ frames: []debugFrame{
+ debugFrameCrypto{
+ data: initialCrypto,
+ },
+ },
+ }},
+ paddedSize: 1200,
+ })
+ got := te.readDatagram()
+ if len(got.packets) != 1 || got.packets[0].ptype != packetTypeRetry {
+ t.Fatalf("got datagram: %v\nwant Retry", got)
+ }
+ p := got.packets[0]
+ if got, want := p.dstConnID, srcID; !bytes.Equal(got, want) {
+ t.Fatalf("Retry destination = {%x}, want {%x}", got, want)
+ }
+
+ return &retryServerTest{
+ te: te,
+ originalSrcConnID: srcID,
+ originalDstConnID: dstID,
+ retry: retryPacket{
+ dstConnID: p.dstConnID,
+ srcConnID: p.srcConnID,
+ token: p.token,
+ },
+ initialCrypto: initialCrypto,
+ }
+}
+
+func TestRetryServerSucceeds(t *testing.T) {
+	rt := newRetryServerTest(t)
+	te := rt.te
+	te.advance(retryTokenValidityPeriod)
+	te.writeDatagram(&testDatagram{
+		packets: []*testPacket{{
+			ptype:     packetTypeInitial,
+			num:       1,
+			version:   quicVersion1,
+			srcConnID: rt.originalSrcConnID,
+			dstConnID: rt.retry.srcConnID,
+			token:     rt.retry.token,
+			frames: []debugFrame{
+				debugFrameCrypto{
+					data: rt.initialCrypto,
+				},
+			},
+		}},
+		paddedSize: 1200,
+	})
+	tc := te.accept()
+	initial := tc.readPacket()
+	if initial == nil || initial.ptype != packetTypeInitial {
+		t.Fatalf("got packet:\n%v\nwant: Initial", initial)
+	}
+	handshake := tc.readPacket()
+	if handshake == nil || handshake.ptype != packetTypeHandshake {
+		t.Fatalf("got packet:\n%v\nwant: Handshake", handshake)
+	}
+	if got, want := tc.sentTransportParameters.retrySrcConnID, rt.retry.srcConnID; !bytes.Equal(got, want) {
+		t.Errorf("retry_source_connection_id = {%x}, want {%x}", got, want)
+	}
+	if got, want := tc.sentTransportParameters.initialSrcConnID, initial.srcConnID; !bytes.Equal(got, want) {
+		t.Errorf("initial_source_connection_id = {%x}, want {%x}", got, want)
+	}
+	if got, want := tc.sentTransportParameters.originalDstConnID, rt.originalDstConnID; !bytes.Equal(got, want) {
+		t.Errorf("original_destination_connection_id = {%x}, want {%x}", got, want)
+	}
+}
+
+func TestRetryServerTokenInvalid(t *testing.T) {
+ // "If a server receives a client Initial that contains an invalid Retry token [...]
+ // the server SHOULD immediately close [...] the connection with an
+ // INVALID_TOKEN error."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5
+ rt := newRetryServerTest(t)
+ te := rt.te
+ te.writeDatagram(&testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeInitial,
+ num: 1,
+ version: quicVersion1,
+ srcConnID: rt.originalSrcConnID,
+ dstConnID: rt.retry.srcConnID,
+ token: append(rt.retry.token, 0),
+ frames: []debugFrame{
+ debugFrameCrypto{
+ data: rt.initialCrypto,
+ },
+ },
+ }},
+ paddedSize: 1200,
+ })
+ te.wantDatagram("server closes connection after Initial with invalid Retry token",
+ initialConnectionCloseDatagram(
+ rt.retry.srcConnID,
+ rt.originalSrcConnID,
+ errInvalidToken))
+}
+
+func TestRetryServerTokenTooOld(t *testing.T) {
+ // "[...] a token SHOULD have an expiration time [...]"
+ // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.3-3
+ rt := newRetryServerTest(t)
+ te := rt.te
+ te.advance(retryTokenValidityPeriod + time.Second)
+ te.writeDatagram(&testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeInitial,
+ num: 1,
+ version: quicVersion1,
+ srcConnID: rt.originalSrcConnID,
+ dstConnID: rt.retry.srcConnID,
+ token: rt.retry.token,
+ frames: []debugFrame{
+ debugFrameCrypto{
+ data: rt.initialCrypto,
+ },
+ },
+ }},
+ paddedSize: 1200,
+ })
+ te.wantDatagram("server closes connection after Initial with expired token",
+ initialConnectionCloseDatagram(
+ rt.retry.srcConnID,
+ rt.originalSrcConnID,
+ errInvalidToken))
+}
+
+func TestRetryServerTokenWrongIP(t *testing.T) {
+ // "Tokens sent in Retry packets SHOULD include information that allows the server
+ // to verify that the source IP address and port in client packets remain constant."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.4-3
+ rt := newRetryServerTest(t)
+ te := rt.te
+ te.writeDatagram(&testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeInitial,
+ num: 1,
+ version: quicVersion1,
+ srcConnID: rt.originalSrcConnID,
+ dstConnID: rt.retry.srcConnID,
+ token: rt.retry.token,
+ frames: []debugFrame{
+ debugFrameCrypto{
+ data: rt.initialCrypto,
+ },
+ },
+ }},
+ paddedSize: 1200,
+ addr: netip.MustParseAddrPort("10.0.0.2:8000"),
+ })
+ te.wantDatagram("server closes connection after Initial from wrong address",
+ initialConnectionCloseDatagram(
+ rt.retry.srcConnID,
+ rt.originalSrcConnID,
+ errInvalidToken))
+}
+
+func TestRetryServerIgnoresRetry(t *testing.T) {
+ tc := newTestConn(t, serverSide)
+ tc.handshake()
+ tc.write(&testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeRetry,
+ originalDstConnID: testLocalConnID(-1),
+ srcConnID: testPeerConnID(0),
+ dstConnID: testLocalConnID(0),
+ token: []byte{1, 2, 3, 4},
+ }},
+ })
+ // Send two packets, to trigger an immediate ACK.
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+ tc.wantFrameType("server connection ignores spurious Retry packet",
+ packetType1RTT, debugFrameAck{})
+}
+
+func TestRetryClientSuccess(t *testing.T) {
+ // "This token MUST be repeated by the client in all Initial packets it sends
+ // for that connection after it receives the Retry packet."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-1
+ tc := newTestConn(t, clientSide)
+ tc.wantFrame("client Initial CRYPTO data",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ newServerConnID := []byte("new_conn_id")
+ token := []byte("token")
+ tc.write(&testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeRetry,
+ originalDstConnID: testLocalConnID(-1),
+ srcConnID: newServerConnID,
+ dstConnID: testLocalConnID(0),
+ token: token,
+ }},
+ })
+ tc.wantPacket("client sends a new Initial packet with a token",
+ &testPacket{
+ ptype: packetTypeInitial,
+ num: 1,
+ version: quicVersion1,
+ srcConnID: testLocalConnID(0),
+ dstConnID: newServerConnID,
+ token: token,
+ frames: []debugFrame{
+ debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ },
+ },
+ },
+ )
+ tc.advanceToTimer()
+ tc.wantPacket("after PTO client sends another Initial packet with a token",
+ &testPacket{
+ ptype: packetTypeInitial,
+ num: 2,
+ version: quicVersion1,
+ srcConnID: testLocalConnID(0),
+ dstConnID: newServerConnID,
+ token: token,
+ frames: []debugFrame{
+ debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ },
+ },
+ },
+ )
+}
+
+func TestRetryClientInvalidServerTransportParameters(t *testing.T) {
+ // Various permutations of missing or invalid values for transport parameters
+ // after a Retry.
+ // https://www.rfc-editor.org/rfc/rfc9000#section-7.3
+ initialSrcConnID := testPeerConnID(0)
+ originalDstConnID := testLocalConnID(-1)
+ retrySrcConnID := testPeerConnID(100)
+ for _, test := range []struct {
+ name string
+ f func(*transportParameters)
+ ok bool
+ }{{
+ name: "valid",
+ f: func(p *transportParameters) {},
+ ok: true,
+ }, {
+ name: "missing initial_source_connection_id",
+ f: func(p *transportParameters) {
+ p.initialSrcConnID = nil
+ },
+ }, {
+ name: "invalid initial_source_connection_id",
+ f: func(p *transportParameters) {
+ p.initialSrcConnID = []byte("invalid")
+ },
+ }, {
+ name: "missing original_destination_connection_id",
+ f: func(p *transportParameters) {
+ p.originalDstConnID = nil
+ },
+ }, {
+ name: "invalid original_destination_connection_id",
+ f: func(p *transportParameters) {
+ p.originalDstConnID = []byte("invalid")
+ },
+ }, {
+ name: "missing retry_source_connection_id",
+ f: func(p *transportParameters) {
+ p.retrySrcConnID = nil
+ },
+ }, {
+ name: "invalid retry_source_connection_id",
+ f: func(p *transportParameters) {
+ p.retrySrcConnID = []byte("invalid")
+ },
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ tc := newTestConn(t, clientSide,
+ func(p *transportParameters) {
+ p.initialSrcConnID = initialSrcConnID
+ p.originalDstConnID = originalDstConnID
+ p.retrySrcConnID = retrySrcConnID
+ },
+ test.f)
+ tc.ignoreFrame(frameTypeAck)
+ tc.wantFrameType("client Initial CRYPTO data",
+ packetTypeInitial, debugFrameCrypto{})
+ tc.write(&testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeRetry,
+ originalDstConnID: originalDstConnID,
+ srcConnID: retrySrcConnID,
+ dstConnID: testLocalConnID(0),
+ token: []byte{1, 2, 3, 4},
+ }},
+ })
+ tc.wantFrameType("client resends Initial CRYPTO data",
+ packetTypeInitial, debugFrameCrypto{})
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ if test.ok {
+ tc.wantFrameType("valid params, client sends Handshake",
+ packetTypeHandshake, debugFrameCrypto{})
+ } else {
+ tc.wantFrame("invalid transport parameters",
+ packetTypeInitial, debugFrameConnectionCloseTransport{
+ code: errTransportParameter,
+ })
+ }
+ })
+ }
+}
+
+func TestRetryClientIgnoresRetryAfterReceivingPacket(t *testing.T) {
+ // "After the client has received and processed an Initial or Retry packet
+ // from the server, it MUST discard any subsequent Retry packets that it receives."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-1
+ tc := newTestConn(t, clientSide)
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeNewConnectionID)
+ tc.wantFrameType("client Initial CRYPTO data",
+ packetTypeInitial, debugFrameCrypto{})
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ retry := &testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeRetry,
+ originalDstConnID: testLocalConnID(-1),
+ srcConnID: testPeerConnID(100),
+ dstConnID: testLocalConnID(0),
+ token: []byte{1, 2, 3, 4},
+ }},
+ }
+ tc.write(retry)
+ tc.wantIdle("client ignores Retry after receiving Initial packet")
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.wantFrameType("client Handshake CRYPTO data",
+ packetTypeHandshake, debugFrameCrypto{})
+ tc.write(retry)
+ tc.wantIdle("client ignores Retry after discarding Initial keys")
+}
+
+func TestRetryClientIgnoresRetryAfterReceivingRetry(t *testing.T) {
+ // "After the client has received and processed an Initial or Retry packet
+ // from the server, it MUST discard any subsequent Retry packets that it receives."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-1
+ tc := newTestConn(t, clientSide)
+ tc.wantFrameType("client Initial CRYPTO data",
+ packetTypeInitial, debugFrameCrypto{})
+ retry := &testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeRetry,
+ originalDstConnID: testLocalConnID(-1),
+ srcConnID: testPeerConnID(100),
+ dstConnID: testLocalConnID(0),
+ token: []byte{1, 2, 3, 4},
+ }},
+ }
+ tc.write(retry)
+ tc.wantFrameType("client resends Initial CRYPTO data",
+ packetTypeInitial, debugFrameCrypto{})
+ tc.write(retry)
+ tc.wantIdle("client ignores second Retry")
+}
+
+func TestRetryClientIgnoresRetryWithInvalidIntegrityTag(t *testing.T) {
+ tc := newTestConn(t, clientSide)
+ tc.wantFrameType("client Initial CRYPTO data",
+ packetTypeInitial, debugFrameCrypto{})
+ pkt := encodeRetryPacket(testLocalConnID(-1), retryPacket{
+ srcConnID: testPeerConnID(100),
+ dstConnID: testLocalConnID(0),
+ token: []byte{1, 2, 3, 4},
+ })
+ pkt[len(pkt)-1] ^= 1 // invalidate the integrity tag
+ tc.endpoint.write(&datagram{
+ b: pkt,
+ peerAddr: testClientAddr,
+ })
+ tc.wantIdle("client ignores Retry with invalid integrity tag")
+}
+
+func TestRetryClientIgnoresRetryWithZeroLengthToken(t *testing.T) {
+ // "A client MUST discard a Retry packet with a zero-length Retry Token field."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-2
+ tc := newTestConn(t, clientSide)
+ tc.wantFrameType("client Initial CRYPTO data",
+ packetTypeInitial, debugFrameCrypto{})
+ tc.write(&testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeRetry,
+ originalDstConnID: testLocalConnID(-1),
+ srcConnID: testPeerConnID(100),
+ dstConnID: testLocalConnID(0),
+ token: []byte{},
+ }},
+ })
+ tc.wantIdle("client ignores Retry with zero-length token")
+}
+
+func TestRetryStateValidateInvalidToken(t *testing.T) {
+ // Test handling of tokens that may have a valid signature,
+ // but unexpected contents.
+ var rs retryState
+ if err := rs.init(); err != nil {
+ t.Fatal(err)
+ }
+ nonce := make([]byte, rs.aead.NonceSize())
+ now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
+ srcConnID := []byte{1, 2, 3, 4}
+ dstConnID := nonce[:20]
+ addr := testClientAddr
+
+ for _, test := range []struct {
+ name string
+ token []byte
+ }{{
+ name: "token too short",
+ token: []byte{1, 2, 3},
+ }, {
+ name: "token plaintext too short",
+ token: func() []byte {
+ plaintext := make([]byte, 7) // not enough bytes of content
+ token := append([]byte{}, nonce[20:]...)
+ return rs.aead.Seal(token, nonce, plaintext, rs.additionalData(srcConnID, addr))
+ }(),
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ if _, ok := rs.validateToken(now, test.token, srcConnID, dstConnID, addr); ok {
+ t.Errorf("validateToken succeeded, want failure")
+ }
+ })
+ }
+}
+
+func TestParseInvalidRetryPackets(t *testing.T) {
+ originalDstConnID := []byte{1, 2, 3, 4}
+ goodPkt := encodeRetryPacket(originalDstConnID, retryPacket{
+ dstConnID: []byte{1},
+ srcConnID: []byte{2},
+ token: []byte{3},
+ })
+ for _, test := range []struct {
+ name string
+ pkt []byte
+ }{{
+ name: "packet too short",
+ pkt: goodPkt[:len(goodPkt)-4],
+ }, {
+ name: "packet header invalid",
+ pkt: goodPkt[:5],
+ }, {
+ name: "integrity tag invalid",
+ pkt: func() []byte {
+ pkt := cloneBytes(goodPkt)
+ pkt[len(pkt)-1] ^= 1
+ return pkt
+ }(),
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ if _, ok := parseRetryPacket(test.pkt, originalDstConnID); ok {
+ t.Errorf("parseRetryPacket succeeded, want failure")
+ }
+ })
+ }
+}
+
+func initialClientCrypto(t *testing.T, e *testEndpoint, p transportParameters) []byte {
+ t.Helper()
+ config := &tls.QUICConfig{TLSConfig: newTestTLSConfig(clientSide)}
+ tlsClient := tls.QUICClient(config)
+ tlsClient.SetTransportParameters(marshalTransportParameters(p))
+ tlsClient.Start(context.Background())
+ t.Cleanup(func() {
+ tlsClient.Close()
+ })
+ e.peerTLSConn = tlsClient
+ var data []byte
+ for {
+ e := tlsClient.NextEvent()
+ switch e.Kind {
+ case tls.QUICNoEvent:
+ return data
+ case tls.QUICWriteData:
+ if e.Level != tls.QUICEncryptionLevelInitial {
+ t.Fatal("initial data at unexpected level")
+ }
+ data = append(data, e.Data...)
+ }
+ }
+}
+
+func initialConnectionCloseDatagram(srcConnID, dstConnID []byte, code transportError) *testDatagram {
+ return &testDatagram{
+ packets: []*testPacket{{
+ ptype: packetTypeInitial,
+ num: 0,
+ version: quicVersion1,
+ srcConnID: srcConnID,
+ dstConnID: dstConnID,
+ frames: []debugFrame{
+ debugFrameConnectionCloseTransport{
+ code: code,
+ },
+ },
+ }},
+ }
+}
diff --git a/internal/quic/rtt.go b/quic/rtt.go
similarity index 97%
rename from internal/quic/rtt.go
rename to quic/rtt.go
index 4942f8cca1..494060c67d 100644
--- a/internal/quic/rtt.go
+++ b/quic/rtt.go
@@ -37,7 +37,7 @@ func (r *rttState) establishPersistentCongestion() {
r.minRTT = r.latestRTT
}
-// updateRTTSample is called when we generate a new RTT sample.
+// updateSample is called when we generate a new RTT sample.
// https://www.rfc-editor.org/rfc/rfc9002.html#section-5
func (r *rttState) updateSample(now time.Time, handshakeConfirmed bool, spaceID numberSpace, latestRTT, ackDelay, maxAckDelay time.Duration) {
r.latestRTT = latestRTT
diff --git a/internal/quic/rtt_test.go b/quic/rtt_test.go
similarity index 100%
rename from internal/quic/rtt_test.go
rename to quic/rtt_test.go
diff --git a/internal/quic/sent_packet.go b/quic/sent_packet.go
similarity index 81%
rename from internal/quic/sent_packet.go
rename to quic/sent_packet.go
index 4f11aa1368..eedd2f61b3 100644
--- a/internal/quic/sent_packet.go
+++ b/quic/sent_packet.go
@@ -9,14 +9,17 @@ package quic
import (
"sync"
"time"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// A sentPacket tracks state related to an in-flight packet we sent,
// to be committed when the peer acks it or resent if the packet is lost.
type sentPacket struct {
- num packetNumber
- size int // size in bytes
- time time.Time // time sent
+ num packetNumber
+ size int // size in bytes
+ time time.Time // time sent
+ ptype packetType
ackEliciting bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.4.1
inFlight bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.6.1
@@ -58,6 +61,12 @@ func (sent *sentPacket) reset() {
}
}
+// markAckEliciting marks the packet as containing an ack-eliciting frame.
+func (sent *sentPacket) markAckEliciting() {
+ sent.ackEliciting = true
+ sent.inFlight = true
+}
+
// The append* methods record information about frames in the packet.
func (sent *sentPacket) appendNonAckElicitingFrame(frameType byte) {
@@ -71,12 +80,12 @@ func (sent *sentPacket) appendAckElicitingFrame(frameType byte) {
}
func (sent *sentPacket) appendInt(v uint64) {
- sent.b = appendVarint(sent.b, v)
+ sent.b = quicwire.AppendVarint(sent.b, v)
}
func (sent *sentPacket) appendOffAndSize(start int64, size int) {
- sent.b = appendVarint(sent.b, uint64(start))
- sent.b = appendVarint(sent.b, uint64(size))
+ sent.b = quicwire.AppendVarint(sent.b, uint64(start))
+ sent.b = quicwire.AppendVarint(sent.b, uint64(size))
}
// The next* methods read back information about frames in the packet.
@@ -88,7 +97,7 @@ func (sent *sentPacket) next() (frameType byte) {
}
func (sent *sentPacket) nextInt() uint64 {
- v, n := consumeVarint(sent.b[sent.n:])
+ v, n := quicwire.ConsumeVarint(sent.b[sent.n:])
sent.n += n
return v
}
diff --git a/internal/quic/sent_packet_list.go b/quic/sent_packet_list.go
similarity index 100%
rename from internal/quic/sent_packet_list.go
rename to quic/sent_packet_list.go
diff --git a/internal/quic/sent_packet_list_test.go b/quic/sent_packet_list_test.go
similarity index 100%
rename from internal/quic/sent_packet_list_test.go
rename to quic/sent_packet_list_test.go
diff --git a/internal/quic/sent_packet_test.go b/quic/sent_packet_test.go
similarity index 100%
rename from internal/quic/sent_packet_test.go
rename to quic/sent_packet_test.go
diff --git a/internal/quic/sent_val.go b/quic/sent_val.go
similarity index 98%
rename from internal/quic/sent_val.go
rename to quic/sent_val.go
index 31f69e47d0..920658919b 100644
--- a/internal/quic/sent_val.go
+++ b/quic/sent_val.go
@@ -37,7 +37,7 @@ func (s sentVal) isSet() bool { return s != 0 }
// shouldSend reports whether the value is set and has not been sent to the peer.
func (s sentVal) shouldSend() bool { return s.state() == sentValUnsent }
-// shouldSend reports whether the value needs to be sent to the peer.
+// shouldSendPTO reports whether the value needs to be sent to the peer.
// The value needs to be sent if it is set and has not been sent.
// If pto is true, indicating that we are sending a PTO probe, the value
// should also be sent if it is set and has not been acknowledged.
diff --git a/internal/quic/sent_val_test.go b/quic/sent_val_test.go
similarity index 100%
rename from internal/quic/sent_val_test.go
rename to quic/sent_val_test.go
diff --git a/quic/stateless_reset.go b/quic/stateless_reset.go
new file mode 100644
index 0000000000..53c3ba5399
--- /dev/null
+++ b/quic/stateless_reset.go
@@ -0,0 +1,61 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/sha256"
+ "hash"
+ "sync"
+)
+
+const statelessResetTokenLen = 128 / 8
+
+// A statelessResetToken is a stateless reset token.
+// https://www.rfc-editor.org/rfc/rfc9000#section-10.3
+type statelessResetToken [statelessResetTokenLen]byte
+
+type statelessResetTokenGenerator struct {
+ canReset bool
+
+ // The hash.Hash interface is not concurrency safe,
+ // so we need a mutex here.
+ //
+ // There shouldn't be much contention on stateless reset token generation.
+ // If this proves to be a problem, we could avoid the mutex by using a separate
+ // generator per Conn, or by using a concurrency-safe generator.
+ mu sync.Mutex
+ mac hash.Hash
+}
+
+func (g *statelessResetTokenGenerator) init(secret [32]byte) {
+ zero := true
+ for _, b := range secret {
+ if b != 0 {
+ zero = false
+ break
+ }
+ }
+ if zero {
+ // Generate tokens using a random secret, but don't send stateless resets.
+ rand.Read(secret[:])
+ g.canReset = false
+ } else {
+ g.canReset = true
+ }
+ g.mac = hmac.New(sha256.New, secret[:])
+}
+
+func (g *statelessResetTokenGenerator) tokenForConnID(cid []byte) (token statelessResetToken) {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ defer g.mac.Reset()
+ g.mac.Write(cid)
+ copy(token[:], g.mac.Sum(nil))
+ return token
+}
diff --git a/quic/stateless_reset_test.go b/quic/stateless_reset_test.go
new file mode 100644
index 0000000000..9458d2ea9d
--- /dev/null
+++ b/quic/stateless_reset_test.go
@@ -0,0 +1,288 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "crypto/tls"
+ "errors"
+ "net/netip"
+ "testing"
+ "time"
+)
+
+func TestStatelessResetClientSendsStatelessResetTokenTransportParameter(t *testing.T) {
+ // "[The stateless_reset_token] transport parameter MUST NOT be sent by a client [...]"
+ // https://www.rfc-editor.org/rfc/rfc9000#section-18.2-4.6.1
+ resetToken := testPeerStatelessResetToken(0)
+ tc := newTestConn(t, serverSide, func(p *transportParameters) {
+ p.statelessResetToken = resetToken[:]
+ })
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantFrame("client provided stateless_reset_token transport parameter",
+ packetTypeInitial, debugFrameConnectionCloseTransport{
+ code: errTransportParameter,
+ })
+}
+
+var testStatelessResetKey = func() (key [32]byte) {
+ if _, err := rand.Read(key[:]); err != nil {
+ panic(err)
+ }
+ return key
+}()
+
+func testStatelessResetToken(cid []byte) statelessResetToken {
+ var gen statelessResetTokenGenerator
+ gen.init(testStatelessResetKey)
+ return gen.tokenForConnID(cid)
+}
+
+func testLocalStatelessResetToken(seq int64) statelessResetToken {
+ return testStatelessResetToken(testLocalConnID(seq))
+}
+
+func newDatagramForReset(cid []byte, size int, addr netip.AddrPort) *datagram {
+ dgram := append([]byte{headerFormShort | fixedBit}, cid...)
+ for len(dgram) < size {
+ dgram = append(dgram, byte(len(dgram))) // semi-random junk
+ }
+ return &datagram{
+ b: dgram,
+ peerAddr: addr,
+ }
+}
+
+func TestStatelessResetSentSizes(t *testing.T) {
+ config := &Config{
+ TLSConfig: newTestTLSConfig(serverSide),
+ StatelessResetKey: testStatelessResetKey,
+ }
+ addr := netip.MustParseAddr("127.0.0.1")
+ te := newTestEndpoint(t, config)
+ for i, test := range []struct {
+ reqSize int
+ wantSize int
+ }{{
+ // Datagrams larger than 42 bytes result in a 42-byte stateless reset.
+ // This isn't specifically mandated by RFC 9000, but is implied.
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-11
+ reqSize: 1200,
+ wantSize: 42,
+ }, {
+ // "An endpoint that sends a Stateless Reset in response to a packet
+ // that is 43 bytes or shorter SHOULD send a Stateless Reset that is
+ // one byte shorter than the packet it responds to."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-11
+ reqSize: 43,
+ wantSize: 42,
+ }, {
+ reqSize: 42,
+ wantSize: 41,
+ }, {
+ // We should send a stateless reset in response to the smallest possible
+ // valid datagram the peer can send us.
+ // The smallest packet is 1-RTT:
+ // header byte, conn id, packet num, payload, AEAD.
+ reqSize: 1 + connIDLen + 1 + 1 + 16,
+ wantSize: 1 + connIDLen + 1 + 1 + 16 - 1,
+ }, {
+ // The smallest possible stateless reset datagram is 21 bytes.
+ // Since our response must be smaller than the incoming datagram,
+ // we must not respond to a 21 byte or smaller packet.
+ reqSize: 21,
+ wantSize: 0,
+ }} {
+ cid := testLocalConnID(int64(i))
+ token := testStatelessResetToken(cid)
+ addrport := netip.AddrPortFrom(addr, uint16(8000+i))
+ te.write(newDatagramForReset(cid, test.reqSize, addrport))
+
+ got := te.read()
+ if len(got) != test.wantSize {
+ t.Errorf("got %v-byte response to %v-byte req, want %v",
+ len(got), test.reqSize, test.wantSize)
+ }
+ if len(got) == 0 {
+ continue
+ }
+ // "Endpoints MUST send Stateless Resets formatted as
+ // a packet with a short header."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-15
+ if isLongHeader(got[0]) {
+ t.Errorf("response to %v-byte request is not a short-header packet\ngot: %x", test.reqSize, got)
+ }
+ if !bytes.HasSuffix(got, token[:]) {
+ t.Errorf("response to %v-byte request does not end in stateless reset token\ngot: %x\nwant suffix: %x", test.reqSize, got, token)
+ }
+ }
+}
+
+func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) {
+ // "[...] Stateless Reset Token field values from [...] NEW_CONNECTION_ID frames [...]"
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-1
+ qr := &qlogRecord{}
+ tc := newTestConn(t, clientSide, qr.config)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+
+ // Retire connection ID 0.
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ retirePriorTo: 1,
+ seq: 2,
+ connID: testPeerConnID(2),
+ })
+ tc.wantFrame("peer requested we retire conn id 0",
+ packetType1RTT, debugFrameRetireConnectionID{
+ seq: 0,
+ })
+
+ resetToken := testPeerStatelessResetToken(1) // provided during handshake
+ dgram := append(make([]byte, 100), resetToken[:]...)
+ tc.endpoint.write(&datagram{
+ b: dgram,
+ })
+
+ if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) {
+ t.Errorf("conn.Wait() = %v, want errStatelessReset", err)
+ }
+ tc.wantIdle("closed connection is idle in draining")
+ tc.advance(1 * time.Second) // long enough to exit the draining state
+ tc.wantIdle("closed connection is idle after draining")
+
+ qr.wantEvents(t, jsonEvent{
+ "name": "connectivity:connection_closed",
+ "data": map[string]any{
+ "trigger": "stateless_reset",
+ },
+ })
+}
+
+func TestStatelessResetSuccessfulTransportParameter(t *testing.T) {
+ // "[...] Stateless Reset Token field values from [...]
+ // the server's transport parameters [...]"
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-1
+ resetToken := testPeerStatelessResetToken(0)
+ tc := newTestConn(t, clientSide, func(p *transportParameters) {
+ p.statelessResetToken = resetToken[:]
+ })
+ tc.handshake()
+
+ dgram := append(make([]byte, 100), resetToken[:]...)
+ tc.endpoint.write(&datagram{
+ b: dgram,
+ })
+
+ if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) {
+ t.Errorf("conn.Wait() = %v, want errStatelessReset", err)
+ }
+ tc.wantIdle("closed connection is idle")
+}
+
+func TestStatelessResetSuccessfulPrefix(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ prefix []byte
+ size int
+ }{{
+ name: "short header and fixed bit",
+ prefix: []byte{
+ headerFormShort | fixedBit,
+ },
+ size: 100,
+ }, {
+ // "[...] endpoints MUST treat [long header packets] ending in a
+ // valid stateless reset token as a Stateless Reset [...]"
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-15
+ name: "long header no fixed bit",
+ prefix: []byte{
+ headerFormLong,
+ },
+ size: 100,
+ }, {
+ // "[...] the comparison MUST be performed when the first packet
+ // in an incoming datagram [...] cannot be decrypted."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-2
+ name: "short header valid DCID",
+ prefix: append([]byte{
+ headerFormShort | fixedBit,
+ }, testLocalConnID(0)...),
+ size: 100,
+ }, {
+ name: "handshake valid DCID",
+ prefix: append([]byte{
+ headerFormLong | fixedBit | longPacketTypeHandshake,
+ }, testLocalConnID(0)...),
+ size: 100,
+ }, {
+ name: "no fixed bit valid DCID",
+ prefix: append([]byte{
+ 0,
+ }, testLocalConnID(0)...),
+ size: 100,
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ resetToken := testPeerStatelessResetToken(0)
+ tc := newTestConn(t, clientSide, func(p *transportParameters) {
+ p.statelessResetToken = resetToken[:]
+ })
+ tc.handshake()
+
+ dgram := test.prefix
+ for len(dgram) < test.size-len(resetToken) {
+ dgram = append(dgram, byte(len(dgram))) // semi-random junk
+ }
+ dgram = append(dgram, resetToken[:]...)
+ tc.endpoint.write(&datagram{
+ b: dgram,
+ })
+ if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) {
+ t.Errorf("conn.Wait() = %v, want errStatelessReset", err)
+ }
+ })
+ }
+}
+
+func TestStatelessResetRetiredConnID(t *testing.T) {
+ // "An endpoint MUST NOT check for any stateless reset tokens [...]
+ // for connection IDs that have been retired."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-3
+ resetToken := testPeerStatelessResetToken(0)
+ tc := newTestConn(t, clientSide, func(p *transportParameters) {
+ p.statelessResetToken = resetToken[:]
+ })
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+
+ // We retire connection ID 0.
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ seq: 2,
+ retirePriorTo: 1,
+ connID: testPeerConnID(2),
+ })
+ tc.wantFrame("peer asked for conn id 0 to be retired",
+ packetType1RTT, debugFrameRetireConnectionID{
+ seq: 0,
+ })
+
+ // Receive a stateless reset for connection ID 0.
+ dgram := append(make([]byte, 100), resetToken[:]...)
+ tc.endpoint.write(&datagram{
+ b: dgram,
+ })
+
+ if err := tc.conn.Wait(canceledContext()); !errors.Is(err, context.Canceled) {
+ t.Errorf("conn.Wait() = %v, want connection to be alive", err)
+ }
+}
diff --git a/internal/quic/stream.go b/quic/stream.go
similarity index 69%
rename from internal/quic/stream.go
rename to quic/stream.go
index 89036b19b6..8068b10acd 100644
--- a/internal/quic/stream.go
+++ b/quic/stream.go
@@ -11,13 +11,36 @@ import (
"errors"
"fmt"
"io"
+ "math"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
+// A Stream is an ordered byte stream.
+//
+// Streams may be bidirectional, read-only, or write-only.
+// Methods inappropriate for a stream's direction
+// (for example, [Write] to a read-only stream)
+// return errors.
+//
+// It is not safe to perform concurrent reads from or writes to a stream.
+// It is safe, however, to read and write at the same time.
+//
+// Reads and writes are buffered.
+// It is generally not necessary to wrap a stream in a [bufio.ReadWriter]
+// or otherwise apply additional buffering.
+//
+// To cancel reads or writes, use the [SetReadContext] and [SetWriteContext] methods.
type Stream struct {
id streamID
conn *Conn
- // ingate's lock guards all receive-related state.
+ // Contexts used for read/write operations.
+ // Intentionally not mutex-guarded, to allow the race detector to catch concurrent access.
+ inctx context.Context
+ outctx context.Context
+
+ // ingate's lock guards receive-related state.
//
// The gate condition is set if a read from the stream will not block,
// either because the stream has available data or because the read will fail.
@@ -31,17 +54,18 @@ type Stream struct {
inclosed sentVal // set by CloseRead
inresetcode int64 // RESET_STREAM code received from the peer; -1 if not reset
- // outgate's lock guards all send-related state.
+ // outgate's lock guards send-related state.
//
// The gate condition is set if a write to the stream will not block,
// either because the stream has available flow control or because
// the write will fail.
outgate gate
out pipe // buffered data to send
+ outflushed int64 // offset of last flush call
outwin int64 // maximum MAX_STREAM_DATA received from the peer
outmaxsent int64 // maximum data offset we've sent to the peer
outmaxbuf int64 // maximum amount of data we will buffer
- outunsent rangeset[int64] // ranges buffered but not yet sent
+ outunsent rangeset[int64] // ranges buffered but not yet sent (only flushed data)
outacked rangeset[int64] // ranges sent and acknowledged
outopened sentVal // set if we should open the stream
outclosed sentVal // set by CloseWrite
@@ -50,6 +74,12 @@ type Stream struct {
outresetcode uint64 // reset code to send in RESET_STREAM
outdone chan struct{} // closed when all data sent
+ // Unsynchronized buffers, used for lock-free fast path.
+ inbuf []byte // received data
+ inbufoff int // bytes of inbuf which have been consumed
+ outbuf []byte // written data
+ outbufoff int // bytes of outbuf which contain data to write
+
// Atomic stream state bits.
//
// These bits provide a fast way to coordinate between the
@@ -104,6 +134,11 @@ const (
dataQueue // streamsState.queueData
)
+// streamResetByConnClose is assigned to Stream.inresetcode to indicate that a stream
+// was implicitly reset when the connection closed. It's out of the range of
+// possible reset codes the peer can send.
+const streamResetByConnClose = math.MaxInt64
+
// wantQueue returns the send queue the stream should be on.
func (s streamState) wantQueue() streamQueue {
switch {
@@ -145,6 +180,8 @@ func newStream(c *Conn, id streamID) *Stream {
inresetcode: -1, // -1 indicates no RESET_STREAM received
ingate: newLockedGate(),
outgate: newLockedGate(),
+ inctx: context.Background(),
+ outctx: context.Background(),
}
if !s.IsReadOnly() {
s.outdone = make(chan struct{})
@@ -152,6 +189,22 @@ func newStream(c *Conn, id streamID) *Stream {
return s
}
+// SetReadContext sets the context used for reads from the stream.
+//
+// It is not safe to call SetReadContext concurrently.
+func (s *Stream) SetReadContext(ctx context.Context) {
+ s.inctx = ctx
+}
+
+// SetWriteContext sets the context used for writes to the stream.
+// The write context is also used by Close when waiting for writes to be
+// received by the peer.
+//
+// It is not safe to call SetWriteContext concurrently.
+func (s *Stream) SetWriteContext(ctx context.Context) {
+ s.outctx = ctx
+}
+
// IsReadOnly reports whether the stream is read-only
// (a unidirectional stream created by the peer).
func (s *Stream) IsReadOnly() bool {
@@ -165,31 +218,49 @@ func (s *Stream) IsWriteOnly() bool {
}
// Read reads data from the stream.
-// See ReadContext for more details.
-func (s *Stream) Read(b []byte) (n int, err error) {
- return s.ReadContext(context.Background(), b)
-}
-
-// ReadContext reads data from the stream.
//
-// ReadContext returns as soon as at least one byte of data is available.
+// Read returns as soon as at least one byte of data is available.
//
-// If the peer closes the stream cleanly, ReadContext returns io.EOF after
+// If the peer closes the stream cleanly, Read returns io.EOF after
// returning all data sent by the peer.
-// If the peer aborts reads on the stream, ReadContext returns
+// If the peer aborts reads on the stream, Read returns
// an error wrapping StreamResetCode.
-func (s *Stream) ReadContext(ctx context.Context, b []byte) (n int, err error) {
+//
+// It is not safe to call Read concurrently.
+func (s *Stream) Read(b []byte) (n int, err error) {
if s.IsWriteOnly() {
return 0, errors.New("read from write-only stream")
}
- if err := s.ingate.waitAndLock(ctx, s.conn.testHooks); err != nil {
+ if len(s.inbuf) > s.inbufoff {
+ // Fast path: If s.inbuf contains unread bytes, return them immediately
+ // without taking a lock.
+ n = copy(b, s.inbuf[s.inbufoff:])
+ s.inbufoff += n
+ return n, nil
+ }
+ if err := s.ingate.waitAndLock(s.inctx, s.conn.testHooks); err != nil {
return 0, err
}
+ if s.inbufoff > 0 {
+ // Discard bytes consumed by the fast path above.
+ s.in.discardBefore(s.in.start + int64(s.inbufoff))
+ s.inbufoff = 0
+ s.inbuf = nil
+ }
+ // bytesRead contains the number of bytes of connection-level flow control to return.
+ // We return flow control for bytes read by this Read call, as well as bytes moved
+ // to the fast-path read buffer (s.inbuf).
+ var bytesRead int64
defer func() {
s.inUnlock()
- s.conn.handleStreamBytesReadOffLoop(int64(n)) // must be done with ingate unlocked
+ s.conn.handleStreamBytesReadOffLoop(bytesRead) // must be done with ingate unlocked
}()
if s.inresetcode != -1 {
+ if s.inresetcode == streamResetByConnClose {
+ if err := s.conn.finalError(); err != nil {
+ return 0, err
+ }
+ }
return 0, fmt.Errorf("stream reset by peer: %w", StreamErrorCode(s.inresetcode))
}
if s.inclosed.isSet() {
@@ -205,22 +276,50 @@ func (s *Stream) ReadContext(ctx context.Context, b []byte) (n int, err error) {
if size := int(s.inset[0].end - s.in.start); size < len(b) {
b = b[:size]
}
+ bytesRead = int64(len(b))
start := s.in.start
end := start + int64(len(b))
s.in.copy(start, b)
s.in.discardBefore(end)
+ if end == s.insize {
+ // We have read up to the end of the stream.
+ // No need to update stream flow control.
+ return len(b), io.EOF
+ }
+ if len(s.inset) > 0 && s.inset[0].start <= s.in.start && s.inset[0].end > s.in.start {
+ // If we have more readable bytes available, put the next chunk of data
+ // in s.inbuf for lock-free reads.
+ s.inbuf = s.in.peek(s.inset[0].end - s.in.start)
+ bytesRead += int64(len(s.inbuf))
+ }
if s.insize == -1 || s.insize > s.inwin {
- if shouldUpdateFlowControl(s.inmaxbuf, s.in.start+s.inmaxbuf-s.inwin) {
+ newWindow := s.in.start + int64(len(s.inbuf)) + s.inmaxbuf
+ addedWindow := newWindow - s.inwin
+ if shouldUpdateFlowControl(s.inmaxbuf, addedWindow) {
// Update stream flow control with a STREAM_MAX_DATA frame.
s.insendmax.setUnsent()
}
}
- if end == s.insize {
- return len(b), io.EOF
- }
return len(b), nil
}
+// ReadByte reads and returns a single byte from the stream.
+//
+// It is not safe to call ReadByte concurrently.
+func (s *Stream) ReadByte() (byte, error) {
+ if len(s.inbuf) > s.inbufoff {
+ b := s.inbuf[s.inbufoff]
+ s.inbufoff++
+ return b, nil
+ }
+ var b [1]byte
+ n, err := s.Read(b[:])
+ if n > 0 {
+ return b[0], nil
+ }
+ return 0, err
+}
+
// shouldUpdateFlowControl determines whether to send a flow control window update.
//
// We want to balance keeping the peer well-supplied with flow control with not sending
@@ -230,23 +329,22 @@ func shouldUpdateFlowControl(maxWindow, addedWindow int64) bool {
}
// Write writes data to the stream.
-// See WriteContext for more details.
-func (s *Stream) Write(b []byte) (n int, err error) {
- return s.WriteContext(context.Background(), b)
-}
-
-// WriteContext writes data to the stream.
//
-// WriteContext writes data to the stream write buffer.
+// Write writes data to the stream write buffer.
// Buffered data is only sent when the buffer is sufficiently full.
// Call the Flush method to ensure buffered data is sent.
-//
-// TODO: Implement Flush.
-func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) {
+func (s *Stream) Write(b []byte) (n int, err error) {
if s.IsReadOnly() {
return 0, errors.New("write to read-only stream")
}
+ if len(b) > 0 && len(s.outbuf)-s.outbufoff >= len(b) {
+ // Fast path: The data to write fits in s.outbuf.
+ copy(s.outbuf[s.outbufoff:], b)
+ s.outbufoff += len(b)
+ return len(b), nil
+ }
canWrite := s.outgate.lock()
+ s.flushFastOutputBuffer()
for {
// The first time through this loop, we may or may not be write blocked.
// We exit the loop after writing all data, so on subsequent passes through
@@ -254,25 +352,17 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error)
if len(b) > 0 && !canWrite {
// Our send buffer is full. Wait for the peer to ack some data.
s.outUnlock()
- if err := s.outgate.waitAndLock(ctx, s.conn.testHooks); err != nil {
+ if err := s.outgate.waitAndLock(s.outctx, s.conn.testHooks); err != nil {
return n, err
}
// Successfully returning from waitAndLockGate means we are no longer
// write blocked. (Unlike traditional condition variables, gates do not
// have spurious wakeups.)
}
- if s.outreset.isSet() {
- s.outUnlock()
- return n, errors.New("write to reset stream")
- }
- if s.outclosed.isSet() {
+ if err := s.writeErrorLocked(); err != nil {
s.outUnlock()
- return n, errors.New("write to closed stream")
+ return n, err
}
- // We set outopened here rather than below,
- // so if this is a zero-length write we still
- // open the stream despite not writing any data to it.
- s.outopened.set()
if len(b) == 0 {
break
}
@@ -282,13 +372,26 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error)
// Amount to write is min(the full buffer, data up to the write limit).
// This is a number of bytes.
nn := min(int64(len(b)), lim-s.out.end)
- // Copy the data into the output buffer and mark it as unsent.
- if s.out.end <= s.outwin {
- s.outunsent.add(s.out.end, min(s.out.end+nn, s.outwin))
- }
+ // Copy the data into the output buffer.
s.out.writeAt(b[:nn], s.out.end)
b = b[nn:]
n += int(nn)
+ // Possibly flush the output buffer.
+ // We automatically flush if:
+ // - We have enough data to consume the send window.
+ // Sending this data may cause the peer to extend the window.
+ // - We have buffered as much data as we're willing to.
+ // We need to send data to clear out buffer space.
+ // - We have enough data to fill a 1-RTT packet using the smallest
+ // possible maximum datagram size (1200 bytes, less header byte,
+ // connection ID, packet number, and AEAD overhead).
+ const autoFlushSize = smallestMaxDatagramSize - 1 - connIDLen - 1 - aeadOverhead
+ shouldFlush := s.out.end >= s.outwin || // peer send window is full
+ s.out.end >= lim || // local send buffer is full
+ (s.out.end-s.outflushed) >= autoFlushSize // enough data buffered
+ if shouldFlush {
+ s.flushLocked()
+ }
if s.out.end > s.outwin {
// We're blocked by flow control.
// Send a STREAM_DATA_BLOCKED frame to let the peer know.
@@ -297,32 +400,117 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error)
// If we have bytes left to send, we're blocked.
canWrite = false
}
+ if lim := s.out.start + s.outmaxbuf - s.out.end - 1; lim > 0 {
+ // If s.out has space allocated and available to be written into,
+ // then reference it in s.outbuf for fast-path writes.
+ //
+ // It's perhaps a bit pointless to limit s.outbuf to the send buffer limit.
+ // We've already allocated this buffer so we aren't saving any memory
+ // by not using it.
+ // For now, we limit it anyway to make it easier to reason about limits.
+ //
+ // We set the limit to one less than the send buffer limit (the -1 above)
+ // so that a write which completely fills the buffer will overflow
+ // s.outbuf and trigger a flush.
+ s.outbuf = s.out.availableBuffer()
+ if int64(len(s.outbuf)) > lim {
+ s.outbuf = s.outbuf[:lim]
+ }
+ }
s.outUnlock()
return n, nil
}
-// Close closes the stream.
-// See CloseContext for more details.
-func (s *Stream) Close() error {
- return s.CloseContext(context.Background())
+// WriteByte writes a single byte to the stream.
+func (s *Stream) WriteByte(c byte) error {
+ if s.outbufoff < len(s.outbuf) {
+ s.outbuf[s.outbufoff] = c
+ s.outbufoff++
+ return nil
+ }
+ b := [1]byte{c}
+ _, err := s.Write(b[:])
+ return err
}
-// CloseContext closes the stream.
+func (s *Stream) flushFastOutputBuffer() {
+ if s.outbuf == nil {
+ return
+ }
+ // Commit data previously written to s.outbuf.
+ // s.outbuf is a reference to a buffer in s.out, so we just need to record
+ // that the output buffer has been extended.
+ s.out.end += int64(s.outbufoff)
+ s.outbuf = nil
+ s.outbufoff = 0
+}
+
+// Flush flushes data written to the stream.
+// It does not wait for the peer to acknowledge receipt of the data.
+// Use Close to wait for the peer's acknowledgement.
+func (s *Stream) Flush() error {
+ if s.IsReadOnly() {
+ return errors.New("flush of read-only stream")
+ }
+ s.outgate.lock()
+ defer s.outUnlock()
+ if err := s.writeErrorLocked(); err != nil {
+ return err
+ }
+ s.flushLocked()
+ return nil
+}
+
+// writeErrorLocked returns the error (if any) which should be returned by write operations
+// due to the stream being reset or closed.
+func (s *Stream) writeErrorLocked() error {
+ if s.outreset.isSet() {
+ if s.outresetcode == streamResetByConnClose {
+ if err := s.conn.finalError(); err != nil {
+ return err
+ }
+ }
+ return errors.New("write to reset stream")
+ }
+ if s.outclosed.isSet() {
+ return errors.New("write to closed stream")
+ }
+ return nil
+}
+
+func (s *Stream) flushLocked() {
+ s.flushFastOutputBuffer()
+ s.outopened.set()
+ if s.outflushed < s.outwin {
+ s.outunsent.add(s.outflushed, min(s.outwin, s.out.end))
+ }
+ s.outflushed = s.out.end
+}
+
+// Close closes the stream.
// Any blocked stream operations will be unblocked and return errors.
//
-// CloseContext flushes any data in the stream write buffer and waits for the peer to
+// Close flushes any data in the stream write buffer and waits for the peer to
// acknowledge receipt of the data.
// If the stream has been reset, it waits for the peer to acknowledge the reset.
// If the context expires before the peer receives the stream's data,
-// CloseContext discards the buffer and returns the context error.
-func (s *Stream) CloseContext(ctx context.Context) error {
+// Close discards the buffer and returns the context error.
+func (s *Stream) Close() error {
s.CloseRead()
if s.IsReadOnly() {
return nil
}
s.CloseWrite()
// TODO: Return code from peer's RESET_STREAM frame?
- return s.conn.waitOnDone(ctx, s.outdone)
+ if err := s.conn.waitOnDone(s.outctx, s.outdone); err != nil {
+ return err
+ }
+ s.outgate.lock()
+ defer s.outUnlock()
+ if s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end) {
+ return nil
+ }
+ return errors.New("stream reset")
}
// CloseRead aborts reads on the stream.
@@ -330,7 +518,7 @@ func (s *Stream) CloseContext(ctx context.Context) error {
//
// CloseRead notifies the peer that the stream has been closed for reading.
// It does not wait for the peer to acknowledge the closure.
-// Use CloseContext to wait for the peer's acknowledgement.
+// Use Close to wait for the peer's acknowledgement.
func (s *Stream) CloseRead() {
if s.IsWriteOnly() {
return
@@ -355,7 +543,7 @@ func (s *Stream) CloseRead() {
//
// CloseWrite sends any data in the stream write buffer to the peer.
// It does not wait for the peer to acknowledge receipt of the data.
-// Use CloseContext to wait for the peer's acknowledgement.
+// Use Close to wait for the peer's acknowledgement.
func (s *Stream) CloseWrite() {
if s.IsReadOnly() {
return
@@ -363,6 +551,7 @@ func (s *Stream) CloseWrite() {
s.outgate.lock()
defer s.outUnlock()
s.outclosed.set()
+ s.flushLocked()
}
// Reset aborts writes on the stream and notifies the peer
@@ -372,7 +561,7 @@ func (s *Stream) CloseWrite() {
// Reset sends the application protocol error code, which must be
// less than 2^62, to the peer.
// It does not wait for the peer to acknowledge receipt of the error.
-// Use CloseContext to wait for the peer's acknowledgement.
+// Use Close to wait for the peer's acknowledgement.
//
// Reset does not affect reads.
// Use CloseRead to abort reads on the stream.
@@ -398,19 +587,49 @@ func (s *Stream) resetInternal(code uint64, userClosed bool) {
if s.outreset.isSet() {
return
}
- if code > maxVarint {
- code = maxVarint
+ if code > quicwire.MaxVarint {
+ code = quicwire.MaxVarint
}
// We could check here to see if the stream is closed and the
// peer has acked all the data and the FIN, but sending an
// extra RESET_STREAM in this case is harmless.
s.outreset.set()
s.outresetcode = code
+ s.outbuf = nil
+ s.outbufoff = 0
s.out.discardBefore(s.out.end)
s.outunsent = rangeset[int64]{}
s.outblocked.clear()
}
+// connHasClosed indicates the stream's conn has closed.
+func (s *Stream) connHasClosed() {
+ // If we're in the closing state, the user closed the conn.
+ // Otherwise, the peer initiated the close.
+ // This only matters for the error we're going to return from stream operations.
+ localClose := s.conn.lifetime.state == connStateClosing
+
+ s.ingate.lock()
+ if !s.inset.isrange(0, s.insize) && s.inresetcode == -1 {
+ if localClose {
+ s.inclosed.set()
+ } else {
+ s.inresetcode = streamResetByConnClose
+ }
+ }
+ s.inUnlock()
+
+ s.outgate.lock()
+ if localClose {
+ s.outclosed.set()
+ s.outreset.set()
+ } else {
+ s.outresetcode = streamResetByConnClose
+ s.outreset.setReceived()
+ }
+ s.outUnlock()
+}
+
// inUnlock unlocks s.ingate.
// It sets the gate condition if reads from s will not block.
// If s has receive-related frames to write or if both directions
@@ -423,8 +642,9 @@ func (s *Stream) inUnlock() {
// inUnlockNoQueue is inUnlock,
// but reports whether s has frames to write rather than notifying the Conn.
func (s *Stream) inUnlockNoQueue() streamState {
- canRead := s.inset.contains(s.in.start) || // data available to read
- s.insize == s.in.start || // at EOF
+ nextByte := s.in.start + int64(len(s.inbuf))
+ canRead := s.inset.contains(nextByte) || // data available to read
+ s.insize == s.in.start+int64(len(s.inbuf)) || // at EOF
s.inresetcode != -1 || // reset by peer
s.inclosed.isSet() // closed locally
defer s.ingate.unlock(canRead)
@@ -567,19 +787,31 @@ func (s *Stream) handleReset(code uint64, finalSize int64) error {
func (s *Stream) checkStreamBounds(end int64, fin bool) error {
if end > s.inwin {
// The peer sent us data past the maximum flow control window we gave them.
- return localTransportError(errFlowControl)
+ return localTransportError{
+ code: errFlowControl,
+ reason: "stream flow control window exceeded",
+ }
}
if s.insize != -1 && end > s.insize {
// The peer sent us data past the final size of the stream they previously gave us.
- return localTransportError(errFinalSize)
+ return localTransportError{
+ code: errFinalSize,
+ reason: "data received past end of stream",
+ }
}
if fin && s.insize != -1 && end != s.insize {
// The peer changed the final size of the stream.
- return localTransportError(errFinalSize)
+ return localTransportError{
+ code: errFinalSize,
+ reason: "final size of stream changed",
+ }
}
if fin && end < s.in.end {
// The peer has previously sent us data past the final size.
- return localTransportError(errFinalSize)
+ return localTransportError{
+ code: errFinalSize,
+ reason: "end of stream occurs before prior data",
+ }
}
return nil
}
@@ -600,8 +832,8 @@ func (s *Stream) handleMaxStreamData(maxStreamData int64) error {
if maxStreamData <= s.outwin {
return nil
}
- if s.out.end > s.outwin {
- s.outunsent.add(s.outwin, min(maxStreamData, s.out.end))
+ if s.outflushed > s.outwin {
+ s.outunsent.add(s.outwin, min(maxStreamData, s.outflushed))
}
s.outwin = maxStreamData
if s.out.end > s.outwin {
@@ -729,10 +961,11 @@ func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto b
}
for {
// STREAM
- off, size := dataToSend(min(s.out.start, s.outwin), min(s.out.end, s.outwin), s.outunsent, s.outacked, pto)
+ off, size := dataToSend(min(s.out.start, s.outwin), min(s.outflushed, s.outwin), s.outunsent, s.outacked, pto)
if end := off + size; end > s.outmaxsent {
// This will require connection-level flow control to send.
end = min(end, s.outmaxsent+s.conn.streams.outflow.avail())
+ end = max(end, off)
size = end - off
}
fin := s.outclosed.isSet() && off+size == s.out.end
diff --git a/internal/quic/stream_limits.go b/quic/stream_limits.go
similarity index 87%
rename from internal/quic/stream_limits.go
rename to quic/stream_limits.go
index 6eda7883b9..71cc291351 100644
--- a/internal/quic/stream_limits.go
+++ b/quic/stream_limits.go
@@ -21,7 +21,7 @@ import (
type localStreamLimits struct {
gate gate
max int64 // peer-provided MAX_STREAMS
- opened int64 // number of streams opened by us
+ opened int64 // number of streams opened by us, -1 when conn is closed
}
func (lim *localStreamLimits) init() {
@@ -34,10 +34,21 @@ func (lim *localStreamLimits) open(ctx context.Context, c *Conn) (num int64, err
if err := lim.gate.waitAndLock(ctx, c.testHooks); err != nil {
return 0, err
}
- n := lim.opened
+ if lim.opened < 0 {
+ lim.gate.unlock(true)
+ return 0, errConnClosed
+ }
+ num = lim.opened
lim.opened++
lim.gate.unlock(lim.opened < lim.max)
- return n, nil
+ return num, nil
+}
+
+// connHasClosed indicates the connection has been closed, locally or by the peer.
+func (lim *localStreamLimits) connHasClosed() {
+ lim.gate.lock()
+ lim.opened = -1
+ lim.gate.unlock(true)
}
// setMax sets the MAX_STREAMS provided by the peer.
@@ -66,7 +77,10 @@ func (lim *remoteStreamLimits) init(maxOpen int64) {
func (lim *remoteStreamLimits) open(id streamID) error {
num := id.num()
if num >= lim.max {
- return localTransportError(errStreamLimit)
+ return localTransportError{
+ code: errStreamLimit,
+ reason: "stream limit exceeded",
+ }
}
if num >= lim.opened {
lim.opened = num + 1
diff --git a/internal/quic/stream_limits_test.go b/quic/stream_limits_test.go
similarity index 96%
rename from internal/quic/stream_limits_test.go
rename to quic/stream_limits_test.go
index 3f291e9f4c..8fed825d74 100644
--- a/internal/quic/stream_limits_test.go
+++ b/quic/stream_limits_test.go
@@ -200,7 +200,6 @@ func TestStreamLimitMaxStreamsFrameTooLarge(t *testing.T) {
func TestStreamLimitSendUpdatesMaxStreams(t *testing.T) {
testStreamTypes(t, "", func(t *testing.T, styp streamType) {
- ctx := canceledContext()
tc := newTestConn(t, serverSide, func(c *Config) {
if styp == uniStream {
c.MaxUniRemoteStreams = 4
@@ -218,13 +217,9 @@ func TestStreamLimitSendUpdatesMaxStreams(t *testing.T) {
id: newStreamID(clientSide, styp, int64(i)),
fin: true,
})
- s, err := tc.conn.AcceptStream(ctx)
- if err != nil {
- t.Fatalf("AcceptStream = %v", err)
- }
- streams = append(streams, s)
+ streams = append(streams, tc.acceptStream())
}
- streams[3].CloseContext(ctx)
+ streams[3].Close()
if styp == bidiStream {
tc.wantFrame("stream is closed",
packetType1RTT, debugFrameStream{
@@ -254,7 +249,7 @@ func TestStreamLimitStopSendingDoesNotUpdateMaxStreams(t *testing.T) {
tc.writeFrames(packetType1RTT, debugFrameStopSending{
id: s.id,
})
- tc.wantFrame("recieved STOP_SENDING, send RESET_STREAM",
+ tc.wantFrame("received STOP_SENDING, send RESET_STREAM",
packetType1RTT, debugFrameResetStream{
id: s.id,
})
diff --git a/internal/quic/stream_test.go b/quic/stream_test.go
similarity index 78%
rename from internal/quic/stream_test.go
rename to quic/stream_test.go
index 7c1377faee..2643ae3dba 100644
--- a/internal/quic/stream_test.go
+++ b/quic/stream_test.go
@@ -13,14 +13,14 @@ import (
"errors"
"fmt"
"io"
- "reflect"
"strings"
"testing"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
func TestStreamWriteBlockedByOutputBuffer(t *testing.T) {
testStreamTypes(t, "", func(t *testing.T, styp streamType) {
- ctx := canceledContext()
want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
const writeBufferSize = 4
tc := newTestConn(t, clientSide, permissiveTransportParameters, func(c *Config) {
@@ -29,16 +29,14 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) {
tc.handshake()
tc.ignoreFrame(frameTypeAck)
- s, err := tc.conn.newLocalStream(ctx, styp)
- if err != nil {
- t.Fatal(err)
- }
+ s := newLocalStream(t, tc, styp)
// Non-blocking write.
- n, err := s.WriteContext(ctx, want)
+ n, err := s.Write(want)
if n != writeBufferSize || err != context.Canceled {
- t.Fatalf("s.WriteContext() = %v, %v; want %v, context.Canceled", n, err, writeBufferSize)
+ t.Fatalf("s.Write() = %v, %v; want %v, context.Canceled", n, err, writeBufferSize)
}
+ s.Flush()
tc.wantFrame("first write buffer of data sent",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -48,7 +46,10 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) {
// Blocking write, which must wait for buffer space.
w := runAsync(tc, func(ctx context.Context) (int, error) {
- return s.WriteContext(ctx, want[writeBufferSize:])
+ s.SetWriteContext(ctx)
+ n, err := s.Write(want[writeBufferSize:])
+ s.Flush()
+ return n, err
})
tc.wantIdle("write buffer is full, no more data can be sent")
@@ -73,7 +74,7 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) {
})
if n, err := w.result(); n != len(want)-writeBufferSize || err != nil {
- t.Fatalf("s.WriteContext() = %v, %v; want %v, nil",
+ t.Fatalf("s.Write() = %v, %v; want %v, nil",
len(want)-writeBufferSize, err, writeBufferSize)
}
})
@@ -97,10 +98,11 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) {
}
// Data is written to the stream output buffer, but we have no flow control.
- _, err = s.WriteContext(ctx, want[:1])
+ _, err = s.Write(want[:1])
if err != nil {
t.Fatalf("write with available output buffer: unexpected error: %v", err)
}
+ s.Flush()
tc.wantFrame("write blocked by flow control triggers a STREAM_DATA_BLOCKED frame",
packetType1RTT, debugFrameStreamDataBlocked{
id: s.id,
@@ -108,10 +110,11 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) {
})
// Write more data.
- _, err = s.WriteContext(ctx, want[1:])
+ _, err = s.Write(want[1:])
if err != nil {
t.Fatalf("write with available output buffer: unexpected error: %v", err)
}
+ s.Flush()
tc.wantIdle("adding more blocked data does not trigger another STREAM_DATA_BLOCKED")
// Provide some flow control window.
@@ -170,7 +173,8 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- s.WriteContext(ctx, want[:1])
+ s.Write(want[:1])
+ s.Flush()
tc.wantFrame("sent data (1 byte) fits within flow control limit",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -185,7 +189,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) {
})
// Write [1,4).
- s.WriteContext(ctx, want[1:])
+ s.Write(want[1:])
tc.wantFrame("stream limit is 4 bytes, ignoring decrease in MAX_STREAM_DATA",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -205,7 +209,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) {
})
// Write [1,4).
- s.WriteContext(ctx, want[4:])
+ s.Write(want[4:])
tc.wantFrame("stream limit is 8 bytes, ignoring decrease in MAX_STREAM_DATA",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -217,7 +221,6 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) {
func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) {
testStreamTypes(t, "", func(t *testing.T, styp streamType) {
- ctx := canceledContext()
want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
const maxWriteBuffer = 4
tc := newTestConn(t, clientSide, func(p *transportParameters) {
@@ -235,12 +238,10 @@ func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) {
// Write more data than StreamWriteBufferSize.
// The peer has given us plenty of flow control,
// so we're just blocked by our local limit.
- s, err := tc.conn.newLocalStream(ctx, styp)
- if err != nil {
- t.Fatal(err)
- }
+ s := newLocalStream(t, tc, styp)
w := runAsync(tc, func(ctx context.Context) (int, error) {
- return s.WriteContext(ctx, want)
+ s.SetWriteContext(ctx)
+ return s.Write(want)
})
tc.wantFrame("stream write should send as much data as write buffer allows",
packetType1RTT, debugFrameStream{
@@ -263,7 +264,7 @@ func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) {
w.cancel()
n, err := w.result()
if n != 2*maxWriteBuffer || err == nil {
- t.Fatalf("WriteContext() = %v, %v; want %v bytes, error", n, err, 2*maxWriteBuffer)
+ t.Fatalf("Write() = %v, %v; want %v bytes, error", n, err, 2*maxWriteBuffer)
}
})
}
@@ -394,7 +395,6 @@ func TestStreamReceive(t *testing.T) {
}},
}} {
testStreamTypes(t, test.name, func(t *testing.T, styp streamType) {
- ctx := canceledContext()
tc := newTestConn(t, serverSide)
tc.handshake()
sid := newStreamID(clientSide, styp, 0)
@@ -410,21 +410,17 @@ func TestStreamReceive(t *testing.T) {
fin: f.fin,
})
if s == nil {
- var err error
- s, err = tc.conn.AcceptStream(ctx)
- if err != nil {
- tc.t.Fatalf("conn.AcceptStream() = %v", err)
- }
+ s = tc.acceptStream()
}
for {
- n, err := s.ReadContext(ctx, got[total:])
- t.Logf("s.ReadContext() = %v, %v", n, err)
+ n, err := s.Read(got[total:])
+ t.Logf("s.Read() = %v, %v", n, err)
total += n
if f.wantEOF && err != io.EOF {
- t.Fatalf("ReadContext() error = %v; want io.EOF", err)
+ t.Fatalf("Read() error = %v; want io.EOF", err)
}
if !f.wantEOF && err == io.EOF {
- t.Fatalf("ReadContext() error = io.EOF, want something else")
+ t.Fatalf("Read() error = io.EOF, want something else")
}
if err != nil {
break
@@ -465,8 +461,8 @@ func TestStreamReceiveExtendsStreamWindow(t *testing.T) {
}
tc.wantIdle("stream window is not extended before data is read")
buf := make([]byte, maxWindowSize+1)
- if n, err := s.ReadContext(ctx, buf); n != maxWindowSize || err != nil {
- t.Fatalf("s.ReadContext() = %v, %v; want %v, nil", n, err, maxWindowSize)
+ if n, err := s.Read(buf); n != maxWindowSize || err != nil {
+ t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, maxWindowSize)
}
tc.wantFrame("stream window is extended after reading data",
packetType1RTT, debugFrameMaxStreamData{
@@ -479,8 +475,8 @@ func TestStreamReceiveExtendsStreamWindow(t *testing.T) {
data: make([]byte, maxWindowSize),
fin: true,
})
- if n, err := s.ReadContext(ctx, buf); n != maxWindowSize || err != io.EOF {
- t.Fatalf("s.ReadContext() = %v, %v; want %v, io.EOF", n, err, maxWindowSize)
+ if n, err := s.Read(buf); n != maxWindowSize || err != io.EOF {
+ t.Fatalf("s.Read() = %v, %v; want %v, io.EOF", n, err, maxWindowSize)
}
tc.wantIdle("stream window is not extended after FIN")
})
@@ -546,6 +542,51 @@ func TestStreamReceiveDuplicateDataDoesNotViolateLimits(t *testing.T) {
})
}
+func TestStreamReceiveEmptyEOF(t *testing.T) {
+ // A stream receives some data, we read a byte of that data
+ // (causing the rest to be pulled into the s.inbuf buffer),
+ // and then we receive a FIN with no additional data.
+ testStreamTypes(t, "", func(t *testing.T, styp streamType) {
+ tc, s := newTestConnAndRemoteStream(t, serverSide, styp, permissiveTransportParameters)
+ want := []byte{1, 2, 3}
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ data: want,
+ })
+ if got, err := s.ReadByte(); got != want[0] || err != nil {
+ t.Fatalf("s.ReadByte() = %v, %v; want %v, nil", got, err, want[0])
+ }
+
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ off: 3,
+ fin: true,
+ })
+ if got, err := io.ReadAll(s); !bytes.Equal(got, want[1:]) || err != nil {
+ t.Fatalf("io.ReadAll(s) = {%x}, %v; want {%x}, nil", got, err, want[1:])
+ }
+ })
+}
+
+func TestStreamReadByteFromOneByteStream(t *testing.T) {
+ // ReadByte on the only byte of a stream should not return an error.
+ testStreamTypes(t, "", func(t *testing.T, styp streamType) {
+ tc, s := newTestConnAndRemoteStream(t, serverSide, styp, permissiveTransportParameters)
+ want := byte(1)
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ data: []byte{want},
+ fin: true,
+ })
+ if got, err := s.ReadByte(); got != want || err != nil {
+ t.Fatalf("s.ReadByte() = %v, %v; want %v, nil", got, err, want)
+ }
+ if got, err := s.ReadByte(); err != io.EOF {
+ t.Fatalf("s.ReadByte() = %v, %v; want _, EOF", got, err)
+ }
+ })
+}
+
func finalSizeTest(t *testing.T, wantErr transportError, f func(tc *testConn, sid streamID) (finalSize int64), opts ...any) {
testStreamTypes(t, "", func(t *testing.T, styp streamType) {
for _, test := range []struct {
@@ -670,18 +711,19 @@ func TestStreamReceiveUnblocksReader(t *testing.T) {
t.Fatalf("AcceptStream() = %v", err)
}
- // ReadContext succeeds immediately, since we already have data.
+ // Read succeeds immediately, since we already have data.
got := make([]byte, len(want))
read := runAsync(tc, func(ctx context.Context) (int, error) {
- return s.ReadContext(ctx, got)
+ return s.Read(got)
})
if n, err := read.result(); n != write1size || err != nil {
- t.Fatalf("ReadContext = %v, %v; want %v, nil", n, err, write1size)
+ t.Fatalf("Read = %v, %v; want %v, nil", n, err, write1size)
}
- // ReadContext blocks waiting for more data.
+ // Read blocks waiting for more data.
read = runAsync(tc, func(ctx context.Context) (int, error) {
- return s.ReadContext(ctx, got[write1size:])
+ s.SetReadContext(ctx)
+ return s.Read(got[write1size:])
})
tc.writeFrames(packetType1RTT, debugFrameStream{
id: sid,
@@ -690,7 +732,7 @@ func TestStreamReceiveUnblocksReader(t *testing.T) {
fin: true,
})
if n, err := read.result(); n != len(want)-write1size || err != io.EOF {
- t.Fatalf("ReadContext = %v, %v; want %v, io.EOF", n, err, len(want)-write1size)
+ t.Fatalf("Read = %v, %v; want %v, io.EOF", n, err, len(want)-write1size)
}
if !bytes.Equal(got, want) {
t.Fatalf("read bytes %x, want %x", got, want)
@@ -724,7 +766,7 @@ func testStreamSendFrameInvalidState(t *testing.T, f func(sid streamID) debugFra
if err != nil {
t.Fatal(err)
}
- s.Write(nil) // open the stream
+ s.Flush() // open the stream
tc.wantFrame("new stream is opened",
packetType1RTT, debugFrameStream{
id: sid,
@@ -848,7 +890,7 @@ func TestStreamOffsetTooLarge(t *testing.T) {
got, _ := tc.readFrame()
want1 := debugFrameConnectionCloseTransport{code: errFrameEncoding}
want2 := debugFrameConnectionCloseTransport{code: errFlowControl}
- if !reflect.DeepEqual(got, want1) && !reflect.DeepEqual(got, want2) {
+ if !frameEqual(got, want1) && !frameEqual(got, want2) {
t.Fatalf("STREAM offset exceeds 2^62-1\ngot: %v\nwant: %v\n or: %v", got, want1, want2)
}
}
@@ -932,7 +974,8 @@ func TestStreamResetBlockedStream(t *testing.T) {
})
tc.ignoreFrame(frameTypeStreamDataBlocked)
writing := runAsync(tc, func(ctx context.Context) (int, error) {
- return s.WriteContext(ctx, []byte{0, 1, 2, 3, 4, 5, 6, 7})
+ s.SetWriteContext(ctx)
+ return s.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7})
})
tc.wantFrame("stream writes data until write buffer fills",
packetType1RTT, debugFrameStream{
@@ -969,7 +1012,9 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) {
want := make([]byte, 4096)
rand.Read(want) // doesn't need to be crypto/rand, but non-deprecated and harmless
w := runAsync(tc, func(ctx context.Context) (int, error) {
- return s.WriteContext(ctx, want)
+ n, err := s.Write(want)
+ s.Flush()
+ return n, err
})
got := make([]byte, 0, len(want))
for {
@@ -987,7 +1032,7 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) {
got = append(got, sf.data...)
}
if n, err := w.result(); n != len(want) || err != nil {
- t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", n, err, len(want))
+ t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want))
}
if !bytes.Equal(got, want) {
t.Fatalf("mismatch in received stream data")
@@ -995,16 +1040,16 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) {
}
func TestStreamCloseWaitsForAcks(t *testing.T) {
- ctx := canceledContext()
tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters)
data := make([]byte, 100)
- s.WriteContext(ctx, data)
+ s.Write(data)
+ s.Flush()
tc.wantFrame("conn sends data for the stream",
packetType1RTT, debugFrameStream{
id: s.id,
data: data,
})
- if err := s.CloseContext(ctx); err != context.Canceled {
+ if err := s.Close(); err != context.Canceled {
t.Fatalf("s.Close() = %v, want context.Canceled (data not acked yet)", err)
}
tc.wantFrame("conn sends FIN for closed stream",
@@ -1015,21 +1060,22 @@ func TestStreamCloseWaitsForAcks(t *testing.T) {
data: []byte{},
})
closing := runAsync(tc, func(ctx context.Context) (struct{}, error) {
- return struct{}{}, s.CloseContext(ctx)
+ s.SetWriteContext(ctx)
+ return struct{}{}, s.Close()
})
if _, err := closing.result(); err != errNotDone {
- t.Fatalf("s.CloseContext() = %v, want it to block waiting for acks", err)
+ t.Fatalf("s.Close() = %v, want it to block waiting for acks", err)
}
tc.writeAckForAll()
if _, err := closing.result(); err != nil {
- t.Fatalf("s.CloseContext() = %v, want nil (all data acked)", err)
+ t.Fatalf("s.Close() = %v, want nil (all data acked)", err)
}
}
func TestStreamCloseReadOnly(t *testing.T) {
tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, permissiveTransportParameters)
- if err := s.CloseContext(canceledContext()); err != nil {
- t.Errorf("s.CloseContext() = %v, want nil", err)
+ if err := s.Close(); err != nil {
+ t.Errorf("s.Close() = %v, want nil", err)
}
tc.wantFrame("closed stream sends STOP_SENDING",
packetType1RTT, debugFrameStopSending{
@@ -1041,11 +1087,13 @@ func TestStreamCloseUnblocked(t *testing.T) {
for _, test := range []struct {
name string
unblock func(tc *testConn, s *Stream)
+ success bool
}{{
name: "data received",
unblock: func(tc *testConn, s *Stream) {
tc.writeAckForAll()
},
+ success: true,
}, {
name: "stop sending received",
unblock: func(tc *testConn, s *Stream) {
@@ -1061,16 +1109,16 @@ func TestStreamCloseUnblocked(t *testing.T) {
},
}} {
t.Run(test.name, func(t *testing.T) {
- ctx := canceledContext()
tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters)
data := make([]byte, 100)
- s.WriteContext(ctx, data)
+ s.Write(data)
+ s.Flush()
tc.wantFrame("conn sends data for the stream",
packetType1RTT, debugFrameStream{
id: s.id,
data: data,
})
- if err := s.CloseContext(ctx); err != context.Canceled {
+ if err := s.Close(); err != context.Canceled {
t.Fatalf("s.Close() = %v, want context.Canceled (data not acked yet)", err)
}
tc.wantFrame("conn sends FIN for closed stream",
@@ -1081,28 +1129,34 @@ func TestStreamCloseUnblocked(t *testing.T) {
data: []byte{},
})
closing := runAsync(tc, func(ctx context.Context) (struct{}, error) {
- return struct{}{}, s.CloseContext(ctx)
+ s.SetWriteContext(ctx)
+ return struct{}{}, s.Close()
})
if _, err := closing.result(); err != errNotDone {
- t.Fatalf("s.CloseContext() = %v, want it to block waiting for acks", err)
+ t.Fatalf("s.Close() = %v, want it to block waiting for acks", err)
}
test.unblock(tc, s)
- if _, err := closing.result(); err != nil {
- t.Fatalf("s.CloseContext() = %v, want nil (all data acked)", err)
+ _, err := closing.result()
+ switch {
+ case err == errNotDone:
+ t.Fatalf("s.Close() still blocking; want it to have returned")
+ case err == nil && !test.success:
+ t.Fatalf("s.Close() = nil, want error")
+ case err != nil && test.success:
+ t.Fatalf("s.Close() = %v, want nil (all data acked)", err)
}
})
}
}
func TestStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) {
- ctx := canceledContext()
tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters,
func(p *transportParameters) {
//p.initialMaxData = 0
p.initialMaxStreamDataUni = 0
})
tc.ignoreFrame(frameTypeStreamDataBlocked)
- if _, err := s.WriteContext(ctx, []byte{0, 1}); err != nil {
+ if _, err := s.Write([]byte{0, 1}); err != nil {
t.Fatalf("s.Write = %v", err)
}
s.CloseWrite()
@@ -1134,7 +1188,6 @@ func TestStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) {
func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) {
testStreamTypes(t, "", func(t *testing.T, styp streamType) {
- ctx := canceledContext()
tc, s := newTestConnAndRemoteStream(t, serverSide, styp)
data := []byte{0, 1, 2, 3, 4, 5, 6, 7}
tc.writeFrames(packetType1RTT, debugFrameStream{
@@ -1142,7 +1195,7 @@ func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) {
data: data,
})
got := make([]byte, 4)
- if n, err := s.ReadContext(ctx, got); n != len(got) || err != nil {
+ if n, err := s.Read(got); n != len(got) || err != nil {
t.Fatalf("Read start of stream: got %v, %v; want %v, nil", n, err, len(got))
}
const sentCode = 42
@@ -1152,8 +1205,8 @@ func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) {
code: sentCode,
})
wantErr := StreamErrorCode(sentCode)
- if n, err := s.ReadContext(ctx, got); n != 0 || !errors.Is(err, wantErr) {
- t.Fatalf("Read reset stream: got %v, %v; want 0, %v", n, err, wantErr)
+ if _, err := io.ReadAll(s); !errors.Is(err, wantErr) {
+ t.Fatalf("Read reset stream: ReadAll got error %v; want %v", err, wantErr)
}
})
}
@@ -1162,8 +1215,9 @@ func TestStreamPeerResetWakesBlockedRead(t *testing.T) {
testStreamTypes(t, "", func(t *testing.T, styp streamType) {
tc, s := newTestConnAndRemoteStream(t, serverSide, styp)
reader := runAsync(tc, func(ctx context.Context) (int, error) {
+ s.SetReadContext(ctx)
got := make([]byte, 4)
- return s.ReadContext(ctx, got)
+ return s.Read(got)
})
const sentCode = 42
tc.writeFrames(packetType1RTT, debugFrameResetStream{
@@ -1229,6 +1283,7 @@ func TestStreamPeerStopSendingForActiveStream(t *testing.T) {
tc, s := newTestConnAndLocalStream(t, serverSide, styp, permissiveTransportParameters)
for i := 0; i < 4; i++ {
s.Write([]byte{byte(i)})
+ s.Flush()
tc.wantFrame("write sends a STREAM frame to peer",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -1272,6 +1327,154 @@ func TestStreamReceiveDataBlocked(t *testing.T) {
tc.wantIdle("no response to STREAM_DATA_BLOCKED and DATA_BLOCKED")
}
+func TestStreamFlushExplicit(t *testing.T) {
+ testStreamTypes(t, "", func(t *testing.T, styp streamType) {
+ tc, s := newTestConnAndLocalStream(t, clientSide, styp, permissiveTransportParameters)
+ want := []byte{0, 1, 2, 3}
+ n, err := s.Write(want)
+ if n != len(want) || err != nil {
+ t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want))
+ }
+ tc.wantIdle("unflushed data is not sent")
+ s.Flush()
+ tc.wantFrame("data is sent after flush",
+ packetType1RTT, debugFrameStream{
+ id: s.id,
+ data: want,
+ })
+ })
+}
+
+func TestStreamFlushClosedStream(t *testing.T) {
+ _, s := newTestConnAndLocalStream(t, clientSide, bidiStream,
+ permissiveTransportParameters)
+ s.Close()
+ if err := s.Flush(); err == nil {
+ t.Errorf("s.Flush of closed stream = nil, want error")
+ }
+}
+
+func TestStreamFlushResetStream(t *testing.T) {
+ _, s := newTestConnAndLocalStream(t, clientSide, bidiStream,
+ permissiveTransportParameters)
+ s.Reset(0)
+ if err := s.Flush(); err == nil {
+ t.Errorf("s.Flush of reset stream = nil, want error")
+ }
+}
+
+func TestStreamFlushStreamAfterPeerStopSending(t *testing.T) {
+ tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream,
+ permissiveTransportParameters)
+ s.Flush() // create the stream
+ tc.wantFrame("stream created after flush",
+ packetType1RTT, debugFrameStream{
+ id: s.id,
+ data: []byte{},
+ })
+
+ // Peer sends a STOP_SENDING.
+ tc.writeFrames(packetType1RTT, debugFrameStopSending{
+ id: s.id,
+ })
+ if err := s.Flush(); err == nil {
+ t.Errorf("s.Flush of stream reset by peer = nil, want error")
+ }
+}
+
+func TestStreamErrorsAfterConnectionClosed(t *testing.T) {
+ tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream,
+ permissiveTransportParameters)
+ wantErr := &ApplicationError{Code: 42}
+ tc.writeFrames(packetType1RTT, debugFrameConnectionCloseApplication{
+ code: wantErr.Code,
+ })
+ if _, err := s.Read(make([]byte, 1)); !errors.Is(err, wantErr) {
+ t.Errorf("s.Read on closed connection = %v, want %v", err, wantErr)
+ }
+ if _, err := s.Write(make([]byte, 1)); !errors.Is(err, wantErr) {
+ t.Errorf("s.Write on closed connection = %v, want %v", err, wantErr)
+ }
+ if err := s.Flush(); !errors.Is(err, wantErr) {
+ t.Errorf("s.Flush on closed connection = %v, want %v", err, wantErr)
+ }
+}
+
+func TestStreamFlushImplicitExact(t *testing.T) {
+ testStreamTypes(t, "", func(t *testing.T, styp streamType) {
+ const writeBufferSize = 4
+ tc, s := newTestConnAndLocalStream(t, clientSide, styp,
+ permissiveTransportParameters,
+ func(c *Config) {
+ c.MaxStreamWriteBufferSize = writeBufferSize
+ })
+ want := []byte{0, 1, 2, 3, 4, 5, 6}
+
+ // This write doesn't quite fill the output buffer.
+ n, err := s.Write(want[:3])
+ if n != 3 || err != nil {
+ t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want))
+ }
+ tc.wantIdle("unflushed data is not sent")
+
+ // This write fills the output buffer exactly.
+ n, err = s.Write(want[3:4])
+ if n != 1 || err != nil {
+ t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want))
+ }
+ tc.wantFrame("data is sent after write buffer fills",
+ packetType1RTT, debugFrameStream{
+ id: s.id,
+ data: want[0:4],
+ })
+ })
+}
+
+func TestStreamFlushImplicitLargerThanBuffer(t *testing.T) {
+ testStreamTypes(t, "", func(t *testing.T, styp streamType) {
+ const writeBufferSize = 4
+ tc, s := newTestConnAndLocalStream(t, clientSide, styp,
+ permissiveTransportParameters,
+ func(c *Config) {
+ c.MaxStreamWriteBufferSize = writeBufferSize
+ })
+ want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+
+ w := runAsync(tc, func(ctx context.Context) (int, error) {
+ s.SetWriteContext(ctx)
+ n, err := s.Write(want)
+ return n, err
+ })
+
+ tc.wantFrame("data is sent after write buffer fills",
+ packetType1RTT, debugFrameStream{
+ id: s.id,
+ data: want[0:4],
+ })
+ tc.writeAckForAll()
+ tc.wantFrame("ack permits sending more data",
+ packetType1RTT, debugFrameStream{
+ id: s.id,
+ off: 4,
+ data: want[4:8],
+ })
+ tc.writeAckForAll()
+
+ tc.wantIdle("write buffer is not full")
+ if n, err := w.result(); n != len(want) || err != nil {
+ t.Fatalf("Write() = %v, %v; want %v, nil", n, err, len(want))
+ }
+
+ s.Flush()
+ tc.wantFrame("flush sends last buffer of data",
+ packetType1RTT, debugFrameStream{
+ id: s.id,
+ off: 8,
+ data: want[8:],
+ })
+ })
+}
+
type streamSide string
const (
@@ -1289,41 +1492,61 @@ func newTestConnAndStream(t *testing.T, side connSide, sside streamSide, styp st
func newTestConnAndLocalStream(t *testing.T, side connSide, styp streamType, opts ...any) (*testConn, *Stream) {
t.Helper()
- ctx := canceledContext()
tc := newTestConn(t, side, opts...)
tc.handshake()
tc.ignoreFrame(frameTypeAck)
+ s := newLocalStream(t, tc, styp)
+ s.SetReadContext(canceledContext())
+ s.SetWriteContext(canceledContext())
+ return tc, s
+}
+
+func newLocalStream(t *testing.T, tc *testConn, styp streamType) *Stream {
+ t.Helper()
+ ctx := canceledContext()
s, err := tc.conn.newLocalStream(ctx, styp)
if err != nil {
t.Fatalf("conn.newLocalStream(%v) = %v", styp, err)
}
- return tc, s
+ s.SetReadContext(canceledContext())
+ s.SetWriteContext(canceledContext())
+ return s
}
func newTestConnAndRemoteStream(t *testing.T, side connSide, styp streamType, opts ...any) (*testConn, *Stream) {
t.Helper()
- ctx := canceledContext()
tc := newTestConn(t, side, opts...)
tc.handshake()
tc.ignoreFrame(frameTypeAck)
+ s := newRemoteStream(t, tc, styp)
+ s.SetReadContext(canceledContext())
+ s.SetWriteContext(canceledContext())
+ return tc, s
+}
+
+func newRemoteStream(t *testing.T, tc *testConn, styp streamType) *Stream {
+ t.Helper()
+ ctx := canceledContext()
tc.writeFrames(packetType1RTT, debugFrameStream{
- id: newStreamID(side.peer(), styp, 0),
+ id: newStreamID(tc.conn.side.peer(), styp, 0),
})
s, err := tc.conn.AcceptStream(ctx)
if err != nil {
t.Fatalf("conn.AcceptStream() = %v", err)
}
- return tc, s
+ s.SetReadContext(canceledContext())
+ s.SetWriteContext(canceledContext())
+ return s
}
// permissiveTransportParameters may be passed as an option to newTestConn.
func permissiveTransportParameters(p *transportParameters) {
p.initialMaxStreamsBidi = maxStreamsLimit
p.initialMaxStreamsUni = maxStreamsLimit
- p.initialMaxData = maxVarint
- p.initialMaxStreamDataBidiRemote = maxVarint
- p.initialMaxStreamDataBidiLocal = maxVarint
- p.initialMaxStreamDataUni = maxVarint
+ p.initialMaxData = quicwire.MaxVarint
+ p.initialMaxStreamDataBidiRemote = quicwire.MaxVarint
+ p.initialMaxStreamDataBidiLocal = quicwire.MaxVarint
+ p.initialMaxStreamDataUni = quicwire.MaxVarint
}
func makeTestData(n int) []byte {
diff --git a/internal/quic/tls.go b/quic/tls.go
similarity index 86%
rename from internal/quic/tls.go
rename to quic/tls.go
index a37e26fb8e..89b31842cd 100644
--- a/internal/quic/tls.go
+++ b/quic/tls.go
@@ -11,14 +11,24 @@ import (
"crypto/tls"
"errors"
"fmt"
+ "net"
"time"
)
// startTLS starts the TLS handshake.
-func (c *Conn) startTLS(now time.Time, initialConnID []byte, params transportParameters) error {
+func (c *Conn) startTLS(now time.Time, initialConnID []byte, peerHostname string, params transportParameters) error {
+ tlsConfig := c.config.TLSConfig
+ if a, _, err := net.SplitHostPort(peerHostname); err == nil {
+ peerHostname = a
+ }
+ if tlsConfig.ServerName == "" && peerHostname != "" {
+ tlsConfig = tlsConfig.Clone()
+ tlsConfig.ServerName = peerHostname
+ }
+
c.keysInitial = initialKeys(initialConnID, c.side)
- qconfig := &tls.QUICConfig{TLSConfig: c.config.TLSConfig}
+ qconfig := &tls.QUICConfig{TLSConfig: tlsConfig}
if c.side == clientSide {
c.tls = tls.QUICClient(qconfig)
} else {
@@ -109,11 +119,7 @@ func (c *Conn) handleCrypto(now time.Time, space numberSpace, off int64, data []
default:
return errors.New("quic: internal error: received CRYPTO frame in unexpected number space")
}
- err := c.crypto[space].handleCrypto(off, data, func(b []byte) error {
+ return c.crypto[space].handleCrypto(off, data, func(b []byte) error {
return c.tls.HandleData(level, b)
})
- if err != nil {
- return err
- }
- return c.handleTLSEvents(now)
}
diff --git a/internal/quic/tls_test.go b/quic/tls_test.go
similarity index 90%
rename from internal/quic/tls_test.go
rename to quic/tls_test.go
index 81d17b8587..f4abdda582 100644
--- a/internal/quic/tls_test.go
+++ b/quic/tls_test.go
@@ -10,7 +10,6 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
- "reflect"
"testing"
"time"
)
@@ -36,7 +35,7 @@ func (tc *testConn) handshake() {
for {
if i == len(dgrams)-1 {
if tc.conn.side == clientSide {
- want := tc.now.Add(maxAckDelay - timerGranularity)
+ want := tc.endpoint.now.Add(maxAckDelay - timerGranularity)
if !tc.timer.Equal(want) {
t.Fatalf("want timer = %v (max_ack_delay), got %v", want, tc.timer)
}
@@ -56,7 +55,7 @@ func (tc *testConn) handshake() {
fillCryptoFrames(want, tc.cryptoDataOut)
i++
}
- if !reflect.DeepEqual(got, want) {
+ if !datagramEqual(got, want) {
t.Fatalf("dgram %v:\ngot %v\n\nwant %v", i, got, want)
}
if i >= len(dgrams) {
@@ -71,9 +70,11 @@ func (tc *testConn) handshake() {
func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) {
var (
- clientConnIDs [][]byte
- serverConnIDs [][]byte
- transientConnID []byte
+ clientConnIDs [][]byte
+ serverConnIDs [][]byte
+ clientResetToken statelessResetToken
+ serverResetToken statelessResetToken
+ transientConnID []byte
)
localConnIDs := [][]byte{
testLocalConnID(0),
@@ -83,14 +84,20 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) {
testPeerConnID(0),
testPeerConnID(1),
}
+ localResetToken := tc.endpoint.e.resetGen.tokenForConnID(localConnIDs[1])
+ peerResetToken := testPeerStatelessResetToken(1)
if tc.conn.side == clientSide {
clientConnIDs = localConnIDs
serverConnIDs = peerConnIDs
+ clientResetToken = localResetToken
+ serverResetToken = peerResetToken
transientConnID = testLocalConnID(-1)
} else {
clientConnIDs = peerConnIDs
serverConnIDs = localConnIDs
- transientConnID = []byte{0xde, 0xad, 0xbe, 0xef}
+ clientResetToken = peerResetToken
+ serverResetToken = localResetToken
+ transientConnID = testPeerConnID(-1)
}
return []*testDatagram{{
// Client Initial
@@ -136,9 +143,11 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) {
debugFrameNewConnectionID{
seq: 1,
connID: serverConnIDs[1],
+ token: serverResetToken,
},
},
}},
+ paddedSize: 1200,
}, {
// Client Initial + Handshake + 1-RTT
packets: []*testPacket{{
@@ -175,6 +184,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) {
debugFrameNewConnectionID{
seq: 1,
connID: clientConnIDs[1],
+ token: clientResetToken,
},
},
}},
@@ -337,6 +347,7 @@ func TestConnKeysDiscardedClient(t *testing.T) {
packetType1RTT, debugFrameNewConnectionID{
seq: 1,
connID: testLocalConnID(1),
+ token: testLocalStatelessResetToken(1),
})
// The client discards Initial keys after sending a Handshake packet.
@@ -390,6 +401,7 @@ func TestConnKeysDiscardedServer(t *testing.T) {
packetType1RTT, debugFrameNewConnectionID{
seq: 1,
connID: testLocalConnID(1),
+ token: testLocalStatelessResetToken(1),
})
tc.wantIdle("server has discarded Initial keys, cannot read CONNECTION_CLOSE")
@@ -546,7 +558,9 @@ func TestConnAEADLimitReached(t *testing.T) {
// exceeds the integrity limit for the selected AEAD,
// the endpoint MUST immediately close the connection [...]"
// https://www.rfc-editor.org/rfc/rfc9001#section-6.6-6
- tc := newTestConn(t, clientSide)
+ tc := newTestConn(t, clientSide, func(c *Config) {
+ clear(c.StatelessResetKey[:])
+ })
tc.handshake()
var limit int64
@@ -564,7 +578,7 @@ func TestConnAEADLimitReached(t *testing.T) {
// Only use the transient connection ID in Initial packets.
dstConnID = tc.conn.connIDState.local[1].cid
}
- invalid := tc.encodeTestPacket(&testPacket{
+ invalid := encodeTestPacket(t, tc, &testPacket{
ptype: packetType1RTT,
num: 1000,
frames: []debugFrame{debugFramePing{}},
@@ -601,3 +615,32 @@ func TestConnAEADLimitReached(t *testing.T) {
tc.advance(1 * time.Second)
tc.wantIdle("auth failures at limit: conn does not process additional packets")
}
+
+func TestConnKeysDiscardedWithExcessCryptoData(t *testing.T) {
+ tc := newTestConn(t, serverSide, permissiveTransportParameters)
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeNewConnectionID)
+ tc.ignoreFrame(frameTypeCrypto)
+
+ // One byte of excess CRYPTO data, separated from the valid data by a one-byte gap.
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ off: int64(len(tc.cryptoDataIn[tls.QUICEncryptionLevelInitial]) + 1),
+ data: []byte{0},
+ })
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+
+ // We don't drop the Initial keys and discover the excess data until the client
+ // sends a Handshake packet.
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.wantFrame("connection closed due to excess Initial CRYPTO data",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errTLSBase + 10,
+ })
+}
diff --git a/quic/tlsconfig_test.go b/quic/tlsconfig_test.go
new file mode 100644
index 0000000000..e24cef08ae
--- /dev/null
+++ b/quic/tlsconfig_test.go
@@ -0,0 +1,56 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/tls"
+
+ "golang.org/x/net/internal/testcert"
+)
+
+func newTestTLSConfig(side connSide) *tls.Config {
+ config := &tls.Config{
+ InsecureSkipVerify: true,
+ CipherSuites: []uint16{
+ tls.TLS_AES_128_GCM_SHA256,
+ tls.TLS_AES_256_GCM_SHA384,
+ tls.TLS_CHACHA20_POLY1305_SHA256,
+ },
+ MinVersion: tls.VersionTLS13,
+ // Default key exchange mechanisms as of Go 1.23 minus X25519Kyber768Draft00,
+ // which bloats the client hello enough to spill into a second datagram.
+ // Tests were written with the assuption each flight in the handshake
+ // fits in one datagram, and it's simpler to keep that property.
+ CurvePreferences: []tls.CurveID{
+ tls.X25519, tls.CurveP256, tls.CurveP384, tls.CurveP521,
+ },
+ }
+ if side == serverSide {
+ config.Certificates = []tls.Certificate{testCert}
+ }
+ return config
+}
+
+// newTestTLSConfigWithMoreDefaults returns a *tls.Config for testing
+// which behaves more like a default, empty config.
+//
+// In particular, it uses the default curve preferences, which can increase
+// the size of the handshake.
+func newTestTLSConfigWithMoreDefaults(side connSide) *tls.Config {
+ config := newTestTLSConfig(side)
+ config.CipherSuites = nil
+ config.CurvePreferences = nil
+ return config
+}
+
+var testCert = func() tls.Certificate {
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ panic(err)
+ }
+ return cert
+}()
diff --git a/internal/quic/transport_params.go b/quic/transport_params.go
similarity index 60%
rename from internal/quic/transport_params.go
rename to quic/transport_params.go
index dc76d16509..13d1c7c7d5 100644
--- a/internal/quic/transport_params.go
+++ b/quic/transport_params.go
@@ -10,6 +10,8 @@ import (
"encoding/binary"
"net/netip"
"time"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// transportParameters transferred in the quic_transport_parameters TLS extension.
@@ -77,89 +79,89 @@ const (
func marshalTransportParameters(p transportParameters) []byte {
var b []byte
if v := p.originalDstConnID; v != nil {
- b = appendVarint(b, paramOriginalDestinationConnectionID)
- b = appendVarintBytes(b, v)
+ b = quicwire.AppendVarint(b, paramOriginalDestinationConnectionID)
+ b = quicwire.AppendVarintBytes(b, v)
}
if v := uint64(p.maxIdleTimeout / time.Millisecond); v != 0 {
- b = appendVarint(b, paramMaxIdleTimeout)
- b = appendVarint(b, uint64(sizeVarint(v)))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramMaxIdleTimeout)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(v)))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.statelessResetToken; v != nil {
- b = appendVarint(b, paramStatelessResetToken)
- b = appendVarintBytes(b, v)
+ b = quicwire.AppendVarint(b, paramStatelessResetToken)
+ b = quicwire.AppendVarintBytes(b, v)
}
if v := p.maxUDPPayloadSize; v != defaultParamMaxUDPPayloadSize {
- b = appendVarint(b, paramMaxUDPPayloadSize)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramMaxUDPPayloadSize)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxData; v != 0 {
- b = appendVarint(b, paramInitialMaxData)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxData)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamDataBidiLocal; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamDataBidiLocal)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamDataBidiLocal)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamDataBidiRemote; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamDataBidiRemote)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamDataBidiRemote)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamDataUni; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamDataUni)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamDataUni)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamsBidi; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamsBidi)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamsBidi)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamsUni; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamsUni)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamsUni)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.ackDelayExponent; v != defaultParamAckDelayExponent {
- b = appendVarint(b, paramAckDelayExponent)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramAckDelayExponent)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := uint64(p.maxAckDelay / time.Millisecond); v != defaultParamMaxAckDelayMilliseconds {
- b = appendVarint(b, paramMaxAckDelay)
- b = appendVarint(b, uint64(sizeVarint(v)))
- b = appendVarint(b, v)
+ b = quicwire.AppendVarint(b, paramMaxAckDelay)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(v)))
+ b = quicwire.AppendVarint(b, v)
}
if p.disableActiveMigration {
- b = appendVarint(b, paramDisableActiveMigration)
+ b = quicwire.AppendVarint(b, paramDisableActiveMigration)
b = append(b, 0) // 0-length value
}
if p.preferredAddrConnID != nil {
b = append(b, paramPreferredAddress)
- b = appendVarint(b, uint64(4+2+16+2+1+len(p.preferredAddrConnID)+16))
+ b = quicwire.AppendVarint(b, uint64(4+2+16+2+1+len(p.preferredAddrConnID)+16))
b = append(b, p.preferredAddrV4.Addr().AsSlice()...) // 4 bytes
b = binary.BigEndian.AppendUint16(b, p.preferredAddrV4.Port()) // 2 bytes
b = append(b, p.preferredAddrV6.Addr().AsSlice()...) // 16 bytes
b = binary.BigEndian.AppendUint16(b, p.preferredAddrV6.Port()) // 2 bytes
- b = appendUint8Bytes(b, p.preferredAddrConnID) // 1 byte + len(conn_id)
+ b = quicwire.AppendUint8Bytes(b, p.preferredAddrConnID) // 1 byte + len(conn_id)
b = append(b, p.preferredAddrResetToken...) // 16 bytes
}
if v := p.activeConnIDLimit; v != defaultParamActiveConnIDLimit {
- b = appendVarint(b, paramActiveConnectionIDLimit)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramActiveConnectionIDLimit)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialSrcConnID; v != nil {
- b = appendVarint(b, paramInitialSourceConnectionID)
- b = appendVarintBytes(b, v)
+ b = quicwire.AppendVarint(b, paramInitialSourceConnectionID)
+ b = quicwire.AppendVarintBytes(b, v)
}
if v := p.retrySrcConnID; v != nil {
- b = appendVarint(b, paramRetrySourceConnectionID)
- b = appendVarintBytes(b, v)
+ b = quicwire.AppendVarint(b, paramRetrySourceConnectionID)
+ b = quicwire.AppendVarintBytes(b, v)
}
return b
}
@@ -167,14 +169,14 @@ func marshalTransportParameters(p transportParameters) []byte {
func unmarshalTransportParams(params []byte) (transportParameters, error) {
p := defaultTransportParameters()
for len(params) > 0 {
- id, n := consumeVarint(params)
+ id, n := quicwire.ConsumeVarint(params)
if n < 0 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
params = params[n:]
- val, n := consumeVarintBytes(params)
+ val, n := quicwire.ConsumeVarintBytes(params)
if n < 0 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
params = params[n:]
n = 0
@@ -184,7 +186,7 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) {
n = len(val)
case paramMaxIdleTimeout:
var v uint64
- v, n = consumeVarint(val)
+ v, n = quicwire.ConsumeVarint(val)
// If this is unreasonably large, consider it as no timeout to avoid
// time.Duration overflows.
if v > 1<<32 {
@@ -193,52 +195,52 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) {
p.maxIdleTimeout = time.Duration(v) * time.Millisecond
case paramStatelessResetToken:
if len(val) != 16 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
p.statelessResetToken = val
n = 16
case paramMaxUDPPayloadSize:
- p.maxUDPPayloadSize, n = consumeVarintInt64(val)
+ p.maxUDPPayloadSize, n = quicwire.ConsumeVarintInt64(val)
if p.maxUDPPayloadSize < 1200 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
case paramInitialMaxData:
- p.initialMaxData, n = consumeVarintInt64(val)
+ p.initialMaxData, n = quicwire.ConsumeVarintInt64(val)
case paramInitialMaxStreamDataBidiLocal:
- p.initialMaxStreamDataBidiLocal, n = consumeVarintInt64(val)
+ p.initialMaxStreamDataBidiLocal, n = quicwire.ConsumeVarintInt64(val)
case paramInitialMaxStreamDataBidiRemote:
- p.initialMaxStreamDataBidiRemote, n = consumeVarintInt64(val)
+ p.initialMaxStreamDataBidiRemote, n = quicwire.ConsumeVarintInt64(val)
case paramInitialMaxStreamDataUni:
- p.initialMaxStreamDataUni, n = consumeVarintInt64(val)
+ p.initialMaxStreamDataUni, n = quicwire.ConsumeVarintInt64(val)
case paramInitialMaxStreamsBidi:
- p.initialMaxStreamsBidi, n = consumeVarintInt64(val)
+ p.initialMaxStreamsBidi, n = quicwire.ConsumeVarintInt64(val)
if p.initialMaxStreamsBidi > maxStreamsLimit {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
case paramInitialMaxStreamsUni:
- p.initialMaxStreamsUni, n = consumeVarintInt64(val)
+ p.initialMaxStreamsUni, n = quicwire.ConsumeVarintInt64(val)
if p.initialMaxStreamsUni > maxStreamsLimit {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
case paramAckDelayExponent:
var v uint64
- v, n = consumeVarint(val)
+ v, n = quicwire.ConsumeVarint(val)
if v > 20 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
p.ackDelayExponent = int8(v)
case paramMaxAckDelay:
var v uint64
- v, n = consumeVarint(val)
+ v, n = quicwire.ConsumeVarint(val)
if v >= 1<<14 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
p.maxAckDelay = time.Duration(v) * time.Millisecond
case paramDisableActiveMigration:
p.disableActiveMigration = true
case paramPreferredAddress:
if len(val) < 4+2+16+2+1 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
p.preferredAddrV4 = netip.AddrPortFrom(
netip.AddrFrom4(*(*[4]byte)(val[:4])),
@@ -251,20 +253,20 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) {
)
val = val[16+2:]
var nn int
- p.preferredAddrConnID, nn = consumeUint8Bytes(val)
+ p.preferredAddrConnID, nn = quicwire.ConsumeUint8Bytes(val)
if nn < 0 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
val = val[nn:]
if len(val) != 16 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
p.preferredAddrResetToken = val
val = nil
case paramActiveConnectionIDLimit:
- p.activeConnIDLimit, n = consumeVarintInt64(val)
+ p.activeConnIDLimit, n = quicwire.ConsumeVarintInt64(val)
if p.activeConnIDLimit < 2 {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
case paramInitialSourceConnectionID:
p.initialSrcConnID = val
@@ -276,7 +278,7 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) {
n = len(val)
}
if n != len(val) {
- return p, localTransportError(errTransportParameter)
+ return p, localTransportError{code: errTransportParameter}
}
}
return p, nil
diff --git a/internal/quic/transport_params_test.go b/quic/transport_params_test.go
similarity index 97%
rename from internal/quic/transport_params_test.go
rename to quic/transport_params_test.go
index cc88e83fd6..f1961178e8 100644
--- a/internal/quic/transport_params_test.go
+++ b/quic/transport_params_test.go
@@ -13,6 +13,8 @@ import (
"reflect"
"testing"
"time"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
func TestTransportParametersMarshalUnmarshal(t *testing.T) {
@@ -334,9 +336,9 @@ func TestTransportParameterMaxIdleTimeoutOverflowsDuration(t *testing.T) {
tooManyMS := 1 + (math.MaxInt64 / uint64(time.Millisecond))
var enc []byte
- enc = appendVarint(enc, paramMaxIdleTimeout)
- enc = appendVarint(enc, uint64(sizeVarint(tooManyMS)))
- enc = appendVarint(enc, uint64(tooManyMS))
+ enc = quicwire.AppendVarint(enc, paramMaxIdleTimeout)
+ enc = quicwire.AppendVarint(enc, uint64(quicwire.SizeVarint(tooManyMS)))
+ enc = quicwire.AppendVarint(enc, uint64(tooManyMS))
dec, err := unmarshalTransportParams(enc)
if err != nil {
diff --git a/quic/udp.go b/quic/udp.go
new file mode 100644
index 0000000000..0a578286b2
--- /dev/null
+++ b/quic/udp.go
@@ -0,0 +1,30 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import "net/netip"
+
+// Per-platform consts describing support for various features.
+//
+// const udpECNSupport indicates whether the platform supports setting
+// the ECN (Explicit Congestion Notification) IP header bits.
+//
+// const udpInvalidLocalAddrIsError indicates whether sending a packet
+// from a local address not associated with the system is an error.
+// For example, assuming 127.0.0.2 is not a local address, does sending
+// from it (using IP_PKTINFO or some other such feature) result in an error?
+
+// unmapAddrPort returns a with any IPv4-mapped IPv6 address prefix removed.
+func unmapAddrPort(a netip.AddrPort) netip.AddrPort {
+ if a.Addr().Is4In6() {
+ return netip.AddrPortFrom(
+ a.Addr().Unmap(),
+ a.Port(),
+ )
+ }
+ return a
+}
diff --git a/quic/udp_darwin.go b/quic/udp_darwin.go
new file mode 100644
index 0000000000..2eb2e9f9f0
--- /dev/null
+++ b/quic/udp_darwin.go
@@ -0,0 +1,38 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21 && darwin
+
+package quic
+
+import (
+ "encoding/binary"
+
+ "golang.org/x/sys/unix"
+)
+
+// See udp.go.
+const (
+ udpECNSupport = true
+ udpInvalidLocalAddrIsError = true
+)
+
+// Confusingly, on Darwin the contents of the IP_TOS option differ depending on whether
+// it is used as an inbound or outbound cmsg.
+
+func parseIPTOS(b []byte) (ecnBits, bool) {
+ // Single byte. The low two bits are the ECN field.
+ if len(b) != 1 {
+ return 0, false
+ }
+ return ecnBits(b[0] & ecnMask), true
+}
+
+func appendCmsgECNv4(b []byte, ecn ecnBits) []byte {
+ // 32-bit integer.
+ // https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/in_tclass.c#L1062-L1073
+ b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 4)
+ binary.NativeEndian.PutUint32(data, uint32(ecn))
+ return b
+}
diff --git a/quic/udp_linux.go b/quic/udp_linux.go
new file mode 100644
index 0000000000..6f191ed398
--- /dev/null
+++ b/quic/udp_linux.go
@@ -0,0 +1,33 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21 && linux
+
+package quic
+
+import (
+ "golang.org/x/sys/unix"
+)
+
+// See udp.go.
+const (
+ udpECNSupport = true
+ udpInvalidLocalAddrIsError = false
+)
+
+// The IP_TOS socket option is a single byte containing the IP TOS field.
+// The low two bits are the ECN field.
+
+func parseIPTOS(b []byte) (ecnBits, bool) {
+ if len(b) != 1 {
+ return 0, false
+ }
+ return ecnBits(b[0] & ecnMask), true
+}
+
+func appendCmsgECNv4(b []byte, ecn ecnBits) []byte {
+ b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 1)
+ data[0] = byte(ecn)
+ return b
+}
diff --git a/quic/udp_msg.go b/quic/udp_msg.go
new file mode 100644
index 0000000000..0b600a2b46
--- /dev/null
+++ b/quic/udp_msg.go
@@ -0,0 +1,247 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21 && !quicbasicnet && (darwin || linux)
+
+package quic
+
+import (
+ "encoding/binary"
+ "net"
+ "net/netip"
+ "sync"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// Network interface for platforms using sendmsg/recvmsg with cmsgs.
+
+type netUDPConn struct {
+ c *net.UDPConn
+ localAddr netip.AddrPort
+}
+
+func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) {
+ a, _ := uc.LocalAddr().(*net.UDPAddr)
+ localAddr := a.AddrPort()
+ if localAddr.Addr().IsUnspecified() {
+ // If the conn is not bound to a specified (non-wildcard) address,
+ // then set localAddr.Addr to an invalid netip.Addr.
+ // This better conveys that this is not an address we should be using,
+ // and is a bit more efficient to test against.
+ localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port())
+ }
+
+ sc, err := uc.SyscallConn()
+ if err != nil {
+ return nil, err
+ }
+ sc.Control(func(fd uintptr) {
+ // Ask for ECN info and (when we aren't bound to a fixed local address)
+ // destination info.
+ //
+ // If any of these calls fail, we won't get the requested information.
+ // That's fine, we'll gracefully handle the lack.
+ unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
+ unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
+ if !localAddr.IsValid() {
+ unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+ unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+ }
+ })
+
+ return &netUDPConn{
+ c: uc,
+ localAddr: localAddr,
+ }, nil
+}
+
+func (c *netUDPConn) Close() error { return c.c.Close() }
+
+func (c *netUDPConn) LocalAddr() netip.AddrPort {
+ a, _ := c.c.LocalAddr().(*net.UDPAddr)
+ return a.AddrPort()
+}
+
+func (c *netUDPConn) Read(f func(*datagram)) {
+ // We shouldn't ever see all of these messages at the same time,
+ // but the total is small so just allocate enough space for everything we use.
+ const (
+ inPktinfoSize = 12 // int + in_addr + in_addr
+ in6PktinfoSize = 20 // in6_addr + int
+ ipTOSSize = 4
+ ipv6TclassSize = 4
+ )
+ control := make([]byte, 0+
+ unix.CmsgSpace(inPktinfoSize)+
+ unix.CmsgSpace(in6PktinfoSize)+
+ unix.CmsgSpace(ipTOSSize)+
+ unix.CmsgSpace(ipv6TclassSize))
+
+ for {
+ d := newDatagram()
+ n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control)
+ if err != nil {
+ return
+ }
+ if n == 0 {
+ continue
+ }
+ d.localAddr = c.localAddr
+ d.peerAddr = unmapAddrPort(peerAddr)
+ d.b = d.b[:n]
+ parseControl(d, control[:controlLen])
+ f(d)
+ }
+}
+
+var cmsgPool = sync.Pool{
+ New: func() any {
+ return new([]byte)
+ },
+}
+
+func (c *netUDPConn) Write(dgram datagram) error {
+ controlp := cmsgPool.Get().(*[]byte)
+ control := *controlp
+ defer func() {
+ *controlp = control[:0]
+ cmsgPool.Put(controlp)
+ }()
+
+ localIP := dgram.localAddr.Addr()
+ if localIP.IsValid() {
+ if localIP.Is4() {
+ control = appendCmsgIPSourceAddrV4(control, localIP)
+ } else {
+ control = appendCmsgIPSourceAddrV6(control, localIP)
+ }
+ }
+ if dgram.ecn != ecnNotECT {
+ if dgram.peerAddr.Addr().Is4() {
+ control = appendCmsgECNv4(control, dgram.ecn)
+ } else {
+ control = appendCmsgECNv6(control, dgram.ecn)
+ }
+ }
+
+ _, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, control, dgram.peerAddr)
+ return err
+}
+
+func parseControl(d *datagram, control []byte) {
+ for len(control) > 0 {
+ hdr, data, remainder, err := unix.ParseOneSocketControlMessage(control)
+ if err != nil {
+ return
+ }
+ control = remainder
+ switch hdr.Level {
+ case unix.IPPROTO_IP:
+ switch hdr.Type {
+ case unix.IP_TOS, unix.IP_RECVTOS:
+ // (Linux sets the type to IP_TOS, Darwin to IP_RECVTOS,
+ // just check for both.)
+ if ecn, ok := parseIPTOS(data); ok {
+ d.ecn = ecn
+ }
+ case unix.IP_PKTINFO:
+ if a, ok := parseInPktinfo(data); ok {
+ d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
+ }
+ }
+ case unix.IPPROTO_IPV6:
+ switch hdr.Type {
+ case unix.IPV6_TCLASS:
+ // 32-bit integer containing the traffic class field.
+ // The low two bits are the ECN field.
+ if ecn, ok := parseIPv6TCLASS(data); ok {
+ d.ecn = ecn
+ }
+ case unix.IPV6_PKTINFO:
+ if a, ok := parseIn6Pktinfo(data); ok {
+ d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
+ }
+ }
+ }
+ }
+}
+
+// IPV6_TCLASS is specified by RFC 3542 as an int.
+
+func parseIPv6TCLASS(b []byte) (ecnBits, bool) {
+ if len(b) != 4 {
+ return 0, false
+ }
+ return ecnBits(binary.NativeEndian.Uint32(b) & ecnMask), true
+}
+
+func appendCmsgECNv6(b []byte, ecn ecnBits) []byte {
+ b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 4)
+ binary.NativeEndian.PutUint32(data, uint32(ecn))
+ return b
+}
+
+// struct in_pktinfo {
+// unsigned int ipi_ifindex; /* send/recv interface index */
+// struct in_addr ipi_spec_dst; /* Local address */
+// struct in_addr ipi_addr; /* IP Header dst address */
+// };
+
+// parseInPktinfo returns the destination address from an IP_PKTINFO.
+func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) {
+ if len(b) != 12 {
+ return netip.Addr{}, false
+ }
+ return netip.AddrFrom4([4]byte(b[8:][:4])), true
+}
+
+// appendCmsgIPSourceAddrV4 appends an IP_PKTINFO setting the source address
+// for an outbound datagram.
+func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte {
+ // struct in_pktinfo {
+ // unsigned int ipi_ifindex; /* send/recv interface index */
+ // struct in_addr ipi_spec_dst; /* Local address */
+ // struct in_addr ipi_addr; /* IP Header dst address */
+ // };
+ b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_PKTINFO, 12)
+ ip := src.As4()
+ copy(data[4:], ip[:])
+ return b
+}
+
+// struct in6_pktinfo {
+// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
+// unsigned int ipi6_ifindex; /* send/recv interface index */
+// };
+
+// parseIn6Pktinfo returns the destination address from an IPV6_PKTINFO.
+func parseIn6Pktinfo(b []byte) (netip.Addr, bool) {
+ if len(b) != 20 {
+ return netip.Addr{}, false
+ }
+ return netip.AddrFrom16([16]byte(b[:16])).Unmap(), true
+}
+
+// appendCmsgIPSourceAddrV6 appends an IPV6_PKTINFO setting the source address
+// for an outbound datagram.
+func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte {
+ b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_PKTINFO, 20)
+ ip := src.As16()
+ copy(data[0:], ip[:])
+ return b
+}
+
+// appendCmsg appends a cmsg with the given level, type, and size to b.
+// It returns the new buffer, and the data section of the cmsg.
+func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) {
+ off := len(b)
+ b = append(b, make([]byte, unix.CmsgSpace(size))...)
+ h := (*unix.Cmsghdr)(unsafe.Pointer(&b[off]))
+ h.Level = level
+ h.Type = typ
+ h.SetLen(unix.CmsgLen(size))
+ return b, b[off+unix.CmsgSpace(0):][:size]
+}
diff --git a/quic/udp_other.go b/quic/udp_other.go
new file mode 100644
index 0000000000..28be6d2006
--- /dev/null
+++ b/quic/udp_other.go
@@ -0,0 +1,62 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21 && (quicbasicnet || !(darwin || linux))
+
+package quic
+
+import (
+ "net"
+ "net/netip"
+)
+
+// Lowest common denominator network interface: Basic net.UDPConn, no cmsgs.
+// We will not be able to send or receive ECN bits,
+// and we will not know what our local address is.
+//
+// The quicbasicnet build tag allows selecting this interface on any platform.
+
+// See udp.go.
+const (
+ udpECNSupport = false
+ udpInvalidLocalAddrIsError = false
+)
+
+type netUDPConn struct {
+ c *net.UDPConn
+}
+
+func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) {
+ return &netUDPConn{
+ c: uc,
+ }, nil
+}
+
+func (c *netUDPConn) Close() error { return c.c.Close() }
+
+func (c *netUDPConn) LocalAddr() netip.AddrPort {
+ a, _ := c.c.LocalAddr().(*net.UDPAddr)
+ return a.AddrPort()
+}
+
+func (c *netUDPConn) Read(f func(*datagram)) {
+ for {
+ dgram := newDatagram()
+ n, _, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(dgram.b, nil)
+ if err != nil {
+ return
+ }
+ if n == 0 {
+ continue
+ }
+ dgram.peerAddr = unmapAddrPort(peerAddr)
+ dgram.b = dgram.b[:n]
+ f(dgram)
+ }
+}
+
+func (c *netUDPConn) Write(dgram datagram) error {
+ _, err := c.c.WriteToUDPAddrPort(dgram.b, dgram.peerAddr)
+ return err
+}
diff --git a/quic/udp_packetconn.go b/quic/udp_packetconn.go
new file mode 100644
index 0000000000..85ce349ff1
--- /dev/null
+++ b/quic/udp_packetconn.go
@@ -0,0 +1,69 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "net"
+ "net/netip"
+)
+
+// netPacketConn is a packetConn implementation wrapping a net.PacketConn.
+//
+// This is mostly useful for tests, since PacketConn doesn't provide access to
+// important features such as identifying the local address packets were received on.
+type netPacketConn struct {
+ c net.PacketConn
+ localAddr netip.AddrPort
+}
+
+func newNetPacketConn(pc net.PacketConn) (*netPacketConn, error) {
+ addr, err := addrPortFromAddr(pc.LocalAddr())
+ if err != nil {
+ return nil, err
+ }
+ return &netPacketConn{
+ c: pc,
+ localAddr: addr,
+ }, nil
+}
+
+func (c *netPacketConn) Close() error {
+ return c.c.Close()
+}
+
+func (c *netPacketConn) LocalAddr() netip.AddrPort {
+ return c.localAddr
+}
+
+func (c *netPacketConn) Read(f func(*datagram)) {
+ for {
+ dgram := newDatagram()
+ n, peerAddr, err := c.c.ReadFrom(dgram.b)
+ if err != nil {
+ return
+ }
+ dgram.peerAddr, err = addrPortFromAddr(peerAddr)
+ if err != nil {
+ continue
+ }
+ dgram.b = dgram.b[:n]
+ f(dgram)
+ }
+}
+
+func (c *netPacketConn) Write(dgram datagram) error {
+ _, err := c.c.WriteTo(dgram.b, net.UDPAddrFromAddrPort(dgram.peerAddr))
+ return err
+}
+
+func addrPortFromAddr(addr net.Addr) (netip.AddrPort, error) {
+ switch a := addr.(type) {
+ case *net.UDPAddr:
+ return a.AddrPort(), nil
+ }
+ return netip.ParseAddrPort(addr.String())
+}
diff --git a/quic/udp_test.go b/quic/udp_test.go
new file mode 100644
index 0000000000..5c4ba10fcc
--- /dev/null
+++ b/quic/udp_test.go
@@ -0,0 +1,191 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "fmt"
+ "net"
+ "net/netip"
+ "runtime"
+ "testing"
+)
+
+func TestUDPSourceUnspecified(t *testing.T) {
+ // Send datagram with no source address set.
+ runUDPTest(t, func(t *testing.T, test udpTest) {
+ t.Logf("%v", test.dstAddr)
+ data := []byte("source unspecified")
+ if err := test.src.Write(datagram{
+ b: data,
+ peerAddr: test.dstAddr,
+ }); err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ got := <-test.dgramc
+ if !bytes.Equal(got.b, data) {
+ t.Errorf("got datagram {%x}, want {%x}", got.b, data)
+ }
+ })
+}
+
+func TestUDPSourceSpecified(t *testing.T) {
+ // Send datagram with source address set.
+ runUDPTest(t, func(t *testing.T, test udpTest) {
+ data := []byte("source specified")
+ if err := test.src.Write(datagram{
+ b: data,
+ peerAddr: test.dstAddr,
+ localAddr: test.src.LocalAddr(),
+ }); err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ got := <-test.dgramc
+ if !bytes.Equal(got.b, data) {
+ t.Errorf("got datagram {%x}, want {%x}", got.b, data)
+ }
+ })
+}
+
+func TestUDPSourceInvalid(t *testing.T) {
+ // Send datagram with source address set to an address not associated with the connection.
+ if !udpInvalidLocalAddrIsError {
+ t.Skipf("%v: sending from invalid source succeeds", runtime.GOOS)
+ }
+ runUDPTest(t, func(t *testing.T, test udpTest) {
+ var localAddr netip.AddrPort
+ if test.src.LocalAddr().Addr().Is4() {
+ localAddr = netip.MustParseAddrPort("127.0.0.2:1234")
+ } else {
+ localAddr = netip.MustParseAddrPort("[::2]:1234")
+ }
+ data := []byte("source invalid")
+ if err := test.src.Write(datagram{
+ b: data,
+ peerAddr: test.dstAddr,
+ localAddr: localAddr,
+ }); err == nil {
+ t.Errorf("Write with invalid localAddr succeeded; want error")
+ }
+ })
+}
+
+func TestUDPECN(t *testing.T) {
+ if !udpECNSupport {
+ t.Skipf("%v: no ECN support", runtime.GOOS)
+ }
+ // Send datagrams with ECN bits set, verify the ECN bits are received.
+ runUDPTest(t, func(t *testing.T, test udpTest) {
+ for _, ecn := range []ecnBits{ecnNotECT, ecnECT1, ecnECT0, ecnCE} {
+ if err := test.src.Write(datagram{
+ b: []byte{1, 2, 3, 4},
+ peerAddr: test.dstAddr,
+ ecn: ecn,
+ }); err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ got := <-test.dgramc
+ if got.ecn != ecn {
+ t.Errorf("sending ECN bits %x, got %x", ecn, got.ecn)
+ }
+ }
+ })
+}
+
+type udpTest struct {
+ src *netUDPConn
+ dst *netUDPConn
+ dstAddr netip.AddrPort
+ dgramc chan *datagram
+}
+
+// runUDPTest calls f with a pair of UDPConns in a matrix of network variations:
+// udp, udp4, and udp6, and variations on binding to an unspecified address (0.0.0.0)
+// or a specified one.
+func runUDPTest(t *testing.T, f func(t *testing.T, u udpTest)) {
+ for _, test := range []struct {
+ srcNet, srcAddr, dstNet, dstAddr string
+ }{
+ {"udp4", "127.0.0.1", "udp", ""},
+ {"udp4", "127.0.0.1", "udp4", ""},
+ {"udp4", "127.0.0.1", "udp4", "127.0.0.1"},
+ {"udp6", "::1", "udp", ""},
+ {"udp6", "::1", "udp6", ""},
+ {"udp6", "::1", "udp6", "::1"},
+ } {
+ spec := "spec"
+ if test.dstAddr == "" {
+ spec = "unspec"
+ }
+ t.Run(fmt.Sprintf("%v/%v/%v", test.srcNet, test.dstNet, spec), func(t *testing.T) {
+ // See: https://go.googlesource.com/go/+/refs/tags/go1.22.0/src/net/ipsock.go#47
+ // On these platforms, conns with network="udp" cannot accept IPv6.
+ switch runtime.GOOS {
+ case "dragonfly", "openbsd":
+ if test.srcNet == "udp6" && test.dstNet == "udp" {
+ t.Skipf("%v: no support for mapping IPv4 address to IPv6", runtime.GOOS)
+ }
+ case "plan9":
+ t.Skipf("ReadMsgUDP not supported on %s", runtime.GOOS)
+ }
+ if runtime.GOARCH == "wasm" && test.srcNet == "udp6" {
+ t.Skipf("%v: IPv6 tests fail when using wasm fake net", runtime.GOARCH)
+ }
+
+ srcAddr := netip.AddrPortFrom(netip.MustParseAddr(test.srcAddr), 0)
+ srcConn, err := net.ListenUDP(test.srcNet, net.UDPAddrFromAddrPort(srcAddr))
+ if err != nil {
+ // If ListenUDP fails here, we presumably don't have
+ // IPv4/IPv6 configured.
+ t.Skipf("ListenUDP(%q, %v) = %v", test.srcNet, srcAddr, err)
+ }
+ t.Cleanup(func() { srcConn.Close() })
+ src, err := newNetUDPConn(srcConn)
+ if err != nil {
+ t.Fatalf("newNetUDPConn: %v", err)
+ }
+
+ var dstAddr netip.AddrPort
+ if test.dstAddr != "" {
+ dstAddr = netip.AddrPortFrom(netip.MustParseAddr(test.dstAddr), 0)
+ }
+ dstConn, err := net.ListenUDP(test.dstNet, net.UDPAddrFromAddrPort(dstAddr))
+ if err != nil {
+ t.Skipf("ListenUDP(%q, nil) = %v", test.dstNet, err)
+ }
+ dst, err := newNetUDPConn(dstConn)
+ if err != nil {
+ dstConn.Close()
+ t.Fatalf("newNetUDPConn: %v", err)
+ }
+
+ dgramc := make(chan *datagram)
+ go func() {
+ defer close(dgramc)
+ dst.Read(func(dgram *datagram) {
+ dgramc <- dgram
+ })
+ }()
+ t.Cleanup(func() {
+ dstConn.Close()
+ for range dgramc {
+ t.Errorf("test read unexpected datagram")
+ }
+ })
+
+ f(t, udpTest{
+ src: src,
+ dst: dst,
+ dstAddr: netip.AddrPortFrom(
+ srcAddr.Addr(),
+ dst.LocalAddr().Port(),
+ ),
+ dgramc: dgramc,
+ })
+ })
+ }
+}
diff --git a/internal/quic/version_test.go b/quic/version_test.go
similarity index 90%
rename from internal/quic/version_test.go
rename to quic/version_test.go
index cfb7ce4be7..0bd8bac14b 100644
--- a/internal/quic/version_test.go
+++ b/quic/version_test.go
@@ -17,7 +17,7 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) {
config := &Config{
TLSConfig: newTestTLSConfig(serverSide),
}
- tl := newTestListener(t, config, nil)
+ te := newTestEndpoint(t, config)
// Packet of unknown contents for some unrecognized QUIC version.
dstConnID := []byte{1, 2, 3, 4}
@@ -30,19 +30,19 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) {
pkt = append(pkt, dstConnID...)
pkt = append(pkt, byte(len(srcConnID)))
pkt = append(pkt, srcConnID...)
- for len(pkt) < minimumClientInitialDatagramSize {
+ for len(pkt) < paddedInitialDatagramSize {
pkt = append(pkt, 0)
}
- tl.write(&datagram{
+ te.write(&datagram{
b: pkt,
})
- gotPkt := tl.read()
+ gotPkt := te.read()
if gotPkt == nil {
- t.Fatalf("got no response; want Version Negotiaion")
+ t.Fatalf("got no response; want Version Negotiation")
}
if got := getPacketType(gotPkt); got != packetTypeVersionNegotiation {
- t.Fatalf("got packet type %v; want Version Negotiaion", got)
+ t.Fatalf("got packet type %v; want Version Negotiation", got)
}
gotDst, gotSrc, versions := parseVersionNegotiation(gotPkt)
if got, want := gotDst, srcConnID; !bytes.Equal(got, want) {
@@ -59,7 +59,7 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) {
func TestVersionNegotiationClientAborts(t *testing.T) {
tc := newTestConn(t, clientSide)
p := tc.readPacket() // client Initial packet
- tc.listener.write(&datagram{
+ tc.endpoint.write(&datagram{
b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10),
})
tc.wantIdle("connection does not send a CONNECTION_CLOSE")
@@ -76,7 +76,7 @@ func TestVersionNegotiationClientIgnoresAfterProcessingPacket(t *testing.T) {
debugFrameCrypto{
data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
})
- tc.listener.write(&datagram{
+ tc.endpoint.write(&datagram{
b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10),
})
if err := tc.conn.waitReady(canceledContext()); err != context.Canceled {
@@ -94,7 +94,7 @@ func TestVersionNegotiationClientIgnoresMismatchingSourceConnID(t *testing.T) {
tc := newTestConn(t, clientSide)
tc.ignoreFrame(frameTypeAck)
p := tc.readPacket() // client Initial packet
- tc.listener.write(&datagram{
+ tc.endpoint.write(&datagram{
b: appendVersionNegotiation(nil, p.srcConnID, []byte("mismatch"), 10),
})
tc.writeFrames(packetTypeInitial,
diff --git a/route/address.go b/route/address.go
index 5a3cc06549..492838a7fe 100644
--- a/route/address.go
+++ b/route/address.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package route
@@ -171,29 +170,56 @@ func (a *Inet6Addr) marshal(b []byte) (int, error) {
// parseInetAddr parses b as an internet address for IPv4 or IPv6.
func parseInetAddr(af int, b []byte) (Addr, error) {
+ const (
+ off4 = 4 // offset of in_addr
+ off6 = 8 // offset of in6_addr
+ ipv4Len = 4 // length of IPv4 address in bytes
+ ipv6Len = 16 // length of IPv6 address in bytes
+ )
switch af {
case syscall.AF_INET:
- if len(b) < sizeofSockaddrInet {
+ if len(b) < int(b[0]) {
return nil, errInvalidAddr
}
+ sockAddrLen := int(b[0])
a := &Inet4Addr{}
- copy(a.IP[:], b[4:8])
+ // sockAddrLen of 0 is valid and represents 0.0.0.0
+ if sockAddrLen > off4 {
+ // Calculate how many bytes of the address to copy:
+ // either full IPv4 length or the available length.
+ n := off4 + ipv4Len
+ if sockAddrLen < n {
+ n = sockAddrLen
+ }
+ copy(a.IP[:], b[off4:n])
+ }
return a, nil
case syscall.AF_INET6:
- if len(b) < sizeofSockaddrInet6 {
+ if len(b) < int(b[0]) {
return nil, errInvalidAddr
}
- a := &Inet6Addr{ZoneID: int(nativeEndian.Uint32(b[24:28]))}
- copy(a.IP[:], b[8:24])
- if a.IP[0] == 0xfe && a.IP[1]&0xc0 == 0x80 || a.IP[0] == 0xff && (a.IP[1]&0x0f == 0x01 || a.IP[1]&0x0f == 0x02) {
- // KAME based IPv6 protocol stack usually
- // embeds the interface index in the
- // interface-local or link-local address as
- // the kernel-internal form.
- id := int(bigEndian.Uint16(a.IP[2:4]))
- if id != 0 {
- a.ZoneID = id
- a.IP[2], a.IP[3] = 0, 0
+ sockAddrLen := int(b[0])
+ a := &Inet6Addr{}
+ // sockAddrLen of 0 is valid and represents ::
+ if sockAddrLen > off6 {
+ n := off6 + ipv6Len
+ if sockAddrLen < n {
+ n = sockAddrLen
+ }
+ if sockAddrLen == sizeofSockaddrInet6 {
+ a.ZoneID = int(nativeEndian.Uint32(b[24:28]))
+ }
+ copy(a.IP[:], b[off6:n])
+ if a.IP[0] == 0xfe && a.IP[1]&0xc0 == 0x80 || a.IP[0] == 0xff && (a.IP[1]&0x0f == 0x01 || a.IP[1]&0x0f == 0x02) {
+ // KAME based IPv6 protocol stack usually
+ // embeds the interface index in the
+ // interface-local or link-local address as
+ // the kernel-internal form.
+ id := int(bigEndian.Uint16(a.IP[2:4]))
+ if id != 0 {
+ a.ZoneID = id
+ a.IP[2], a.IP[3] = 0, 0
+ }
}
}
return a, nil
@@ -370,13 +396,19 @@ func marshalAddrs(b []byte, as []Addr) (uint, error) {
func parseAddrs(attrs uint, fn func(int, []byte) (int, Addr, error), b []byte) ([]Addr, error) {
var as [syscall.RTAX_MAX]Addr
af := int(syscall.AF_UNSPEC)
+ isInet := func(fam int) bool {
+ return fam == syscall.AF_INET || fam == syscall.AF_INET6
+ }
+ isMask := func(addrType uint) bool {
+ return addrType == syscall.RTAX_NETMASK || addrType == syscall.RTAX_GENMASK
+ }
for i := uint(0); i < syscall.RTAX_MAX && len(b) >= roundup(0); i++ {
		if attrs&(1<<i) == 0 {
+ // locks: inits:
+ // sockaddrs:
+ // :: fe80::2d0:4cff:fe10:15d2 ::
+ {
+ syscall.RTA_DST | syscall.RTA_GATEWAY | syscall.RTA_NETMASK,
+ parseKernelInetAddr,
+ []byte{
+ 0x1c, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+
+ 0x1c, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x02, 0xd0, 0x4c, 0xff, 0xfe, 0x10, 0x15, 0xd2,
+ 0x00, 0x00, 0x00, 0x00,
+
+ 0x02, 0x1e, 0x00, 0x00,
+ },
+ []Addr{
+ &Inet6Addr{},
+ &Inet6Addr{IP: [16]byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xd0, 0x4c, 0xff, 0xfe, 0x10, 0x15, 0xd2}},
+ &Inet6Addr{},
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ },
+ },
+ // golang/go#70528, the kernel can produce addresses of length 0
+ {
+ syscall.RTA_DST | syscall.RTA_GATEWAY | syscall.RTA_NETMASK,
+ parseKernelInetAddr,
+ []byte{
+ 0x00, 0x1e, 0x00, 0x00,
+
+ 0x1c, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0xfe, 0x80, 0x00, 0x21, 0x00, 0x00, 0x00, 0x00,
+ 0xf2, 0x2f, 0x4b, 0xff, 0xfe, 0x09, 0x3b, 0xff,
+ 0x00, 0x00, 0x00, 0x00,
+
+ 0x0e, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00,
+ },
+ []Addr{
+ &Inet6Addr{IP: [16]byte{}},
+ &Inet6Addr{IP: [16]byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf2, 0x2f, 0x4b, 0xff, 0xfe, 0x09, 0x3b, 0xff}, ZoneID: 33},
+ &Inet6Addr{IP: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ },
+ },
+ // Additional case: golang/go/issues/70528#issuecomment-2498692877
+ {
+ syscall.RTA_DST | syscall.RTA_GATEWAY | syscall.RTA_NETMASK,
+ parseKernelInetAddr,
+ []byte{
+ 0x84, 0x00, 0x05, 0x04, 0x01, 0x00, 0x00, 0x00, 0x03, 0x08, 0x00, 0x01, 0x15, 0x00, 0x00, 0x00,
+ 0x1B, 0x01, 0x00, 0x00, 0xF5, 0x5A, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x02, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00,
+ 0x14, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ },
+ []Addr{
+ &Inet4Addr{IP: [4]byte{0x0, 0x0, 0x0, 0x0}},
+ nil,
+ nil,
nil,
nil,
nil,
diff --git a/route/address_test.go b/route/address_test.go
index bd7db4a1f7..31087576ed 100644
--- a/route/address_test.go
+++ b/route/address_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package route
diff --git a/route/binary.go b/route/binary.go
index a5e28f1e9c..db3f7e0c2a 100644
--- a/route/binary.go
+++ b/route/binary.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package route
diff --git a/route/defs_darwin.go b/route/defs_darwin.go
index 8da5845712..46a4ed6694 100644
--- a/route/defs_darwin.go
+++ b/route/defs_darwin.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package route
@@ -25,14 +24,10 @@ const (
sizeofIfmaMsghdrDarwin15 = C.sizeof_struct_ifma_msghdr
sizeofIfMsghdr2Darwin15 = C.sizeof_struct_if_msghdr2
sizeofIfmaMsghdr2Darwin15 = C.sizeof_struct_ifma_msghdr2
- sizeofIfDataDarwin15 = C.sizeof_struct_if_data
- sizeofIfData64Darwin15 = C.sizeof_struct_if_data64
sizeofRtMsghdrDarwin15 = C.sizeof_struct_rt_msghdr
sizeofRtMsghdr2Darwin15 = C.sizeof_struct_rt_msghdr2
- sizeofRtMetricsDarwin15 = C.sizeof_struct_rt_metrics
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/defs_dragonfly.go b/route/defs_dragonfly.go
index acf3d1c55f..52aa700a6d 100644
--- a/route/defs_dragonfly.go
+++ b/route/defs_dragonfly.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package route
@@ -48,10 +47,8 @@ const (
sizeofIfaMsghdrDragonFlyBSD58 = C.sizeof_struct_ifa_msghdr_dfly58
- sizeofRtMsghdrDragonFlyBSD4 = C.sizeof_struct_rt_msghdr
- sizeofRtMetricsDragonFlyBSD4 = C.sizeof_struct_rt_metrics
+ sizeofRtMsghdrDragonFlyBSD4 = C.sizeof_struct_rt_msghdr
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/defs_freebsd.go b/route/defs_freebsd.go
index 3f115121bc..68778f2d16 100644
--- a/route/defs_freebsd.go
+++ b/route/defs_freebsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package route
@@ -221,7 +220,6 @@ import "C"
const (
sizeofIfMsghdrlFreeBSD10 = C.sizeof_struct_if_msghdrl
sizeofIfaMsghdrFreeBSD10 = C.sizeof_struct_ifa_msghdr
- sizeofIfaMsghdrlFreeBSD10 = C.sizeof_struct_ifa_msghdrl
sizeofIfmaMsghdrFreeBSD10 = C.sizeof_struct_ifma_msghdr
sizeofIfAnnouncemsghdrFreeBSD10 = C.sizeof_struct_if_announcemsghdr
@@ -234,15 +232,7 @@ const (
sizeofIfMsghdrFreeBSD10 = C.sizeof_struct_if_msghdr_freebsd10
sizeofIfMsghdrFreeBSD11 = C.sizeof_struct_if_msghdr_freebsd11
- sizeofIfDataFreeBSD7 = C.sizeof_struct_if_data_freebsd7
- sizeofIfDataFreeBSD8 = C.sizeof_struct_if_data_freebsd8
- sizeofIfDataFreeBSD9 = C.sizeof_struct_if_data_freebsd9
- sizeofIfDataFreeBSD10 = C.sizeof_struct_if_data_freebsd10
- sizeofIfDataFreeBSD11 = C.sizeof_struct_if_data_freebsd11
-
- sizeofIfMsghdrlFreeBSD10Emu = C.sizeof_struct_if_msghdrl
sizeofIfaMsghdrFreeBSD10Emu = C.sizeof_struct_ifa_msghdr
- sizeofIfaMsghdrlFreeBSD10Emu = C.sizeof_struct_ifa_msghdrl
sizeofIfmaMsghdrFreeBSD10Emu = C.sizeof_struct_ifma_msghdr
sizeofIfAnnouncemsghdrFreeBSD10Emu = C.sizeof_struct_if_announcemsghdr
@@ -255,13 +245,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = C.sizeof_struct_if_msghdr_freebsd10
sizeofIfMsghdrFreeBSD11Emu = C.sizeof_struct_if_msghdr_freebsd11
- sizeofIfDataFreeBSD7Emu = C.sizeof_struct_if_data_freebsd7
- sizeofIfDataFreeBSD8Emu = C.sizeof_struct_if_data_freebsd8
- sizeofIfDataFreeBSD9Emu = C.sizeof_struct_if_data_freebsd9
- sizeofIfDataFreeBSD10Emu = C.sizeof_struct_if_data_freebsd10
- sizeofIfDataFreeBSD11Emu = C.sizeof_struct_if_data_freebsd11
-
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/defs_netbsd.go b/route/defs_netbsd.go
index c4304df84f..fb60f43c83 100644
--- a/route/defs_netbsd.go
+++ b/route/defs_netbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package route
@@ -24,10 +23,8 @@ const (
sizeofIfaMsghdrNetBSD7 = C.sizeof_struct_ifa_msghdr
sizeofIfAnnouncemsghdrNetBSD7 = C.sizeof_struct_if_announcemsghdr
- sizeofRtMsghdrNetBSD7 = C.sizeof_struct_rt_msghdr
- sizeofRtMetricsNetBSD7 = C.sizeof_struct_rt_metrics
+ sizeofRtMsghdrNetBSD7 = C.sizeof_struct_rt_msghdr
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/defs_openbsd.go b/route/defs_openbsd.go
index 9af0e1af57..471558d9ef 100644
--- a/route/defs_openbsd.go
+++ b/route/defs_openbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package route
@@ -22,7 +21,6 @@ import "C"
const (
sizeofRtMsghdr = C.sizeof_struct_rt_msghdr
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/empty.s b/route/empty.s
index 90ab4ca3d8..49d79791e0 100644
--- a/route/empty.s
+++ b/route/empty.s
@@ -3,6 +3,5 @@
// license that can be found in the LICENSE file.
//go:build darwin && go1.12
-// +build darwin,go1.12
// This exists solely so we can linkname in symbols from syscall.
diff --git a/route/example_darwin_test.go b/route/example_darwin_test.go
new file mode 100644
index 0000000000..e442c3ecf7
--- /dev/null
+++ b/route/example_darwin_test.go
@@ -0,0 +1,70 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package route_test
+
+import (
+ "fmt"
+ "net/netip"
+ "os"
+ "syscall"
+
+ "golang.org/x/net/route"
+ "golang.org/x/sys/unix"
+)
+
+// This example demonstrates how to parse a response to an RTM_GET request.
+func ExampleParseRIB() {
+ fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ if err != nil {
+ return
+ }
+ defer unix.Close(fd)
+
+ // Create a RouteMessage with RTM_GET type
+ rtm := &route.RouteMessage{
+ Version: syscall.RTM_VERSION,
+ Type: unix.RTM_GET,
+ ID: uintptr(os.Getpid()),
+ Seq: 0,
+ Addrs: []route.Addr{
+ &route.Inet4Addr{IP: [4]byte{127, 0, 0, 0}},
+ },
+ }
+
+ // Marshal the message into bytes
+ msgBytes, err := rtm.Marshal()
+ if err != nil {
+ return
+ }
+
+ // Send the message over the routing socket
+ _, err = unix.Write(fd, msgBytes)
+ if err != nil {
+ return
+ }
+
+ // Read the response from the routing socket
+ var buf [2 << 10]byte
+ n, err := unix.Read(fd, buf[:])
+ if err != nil {
+ return
+ }
+
+ // Parse the response messages
+ msgs, err := route.ParseRIB(route.RIBTypeRoute, buf[:n])
+ if err != nil {
+ return
+ }
+ routeMsg, ok := msgs[0].(*route.RouteMessage)
+ if !ok {
+ return
+ }
+ netmask, ok := routeMsg.Addrs[2].(*route.Inet4Addr)
+ if !ok {
+ return
+ }
+ fmt.Println(netip.AddrFrom4(netmask.IP))
+ // Output: 255.0.0.0
+}
diff --git a/route/interface.go b/route/interface.go
index 9e9407830c..0aa70555ca 100644
--- a/route/interface.go
+++ b/route/interface.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package route
diff --git a/route/interface_announce.go b/route/interface_announce.go
index 8282bfe9e2..70614c1b1a 100644
--- a/route/interface_announce.go
+++ b/route/interface_announce.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || netbsd
-// +build dragonfly freebsd netbsd
package route
diff --git a/route/interface_classic.go b/route/interface_classic.go
index 903a196346..be1bf2652e 100644
--- a/route/interface_classic.go
+++ b/route/interface_classic.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || netbsd
-// +build darwin dragonfly netbsd
package route
diff --git a/route/interface_multicast.go b/route/interface_multicast.go
index dd0b214baa..2ee37b9c74 100644
--- a/route/interface_multicast.go
+++ b/route/interface_multicast.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd
-// +build darwin dragonfly freebsd
package route
diff --git a/route/message.go b/route/message.go
index 456a8363fe..dc8bfc5b3a 100644
--- a/route/message.go
+++ b/route/message.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package route
diff --git a/route/message_test.go b/route/message_test.go
index 61927d62c0..9381f1b2df 100644
--- a/route/message_test.go
+++ b/route/message_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package route
diff --git a/route/route.go b/route/route.go
index 3ab5bcdc01..ca2ce2b887 100644
--- a/route/route.go
+++ b/route/route.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
// Package route provides basic functions for the manipulation of
// packet routing facilities on BSD variants.
diff --git a/route/route_classic.go b/route/route_classic.go
index d6ee42f1b1..e273fe39ab 100644
--- a/route/route_classic.go
+++ b/route/route_classic.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd
-// +build darwin dragonfly freebsd netbsd
package route
diff --git a/route/route_test.go b/route/route_test.go
index 55c8f23727..ba57702178 100644
--- a/route/route_test.go
+++ b/route/route_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package route
diff --git a/route/sys.go b/route/sys.go
index 7c75574f18..fcebee58ec 100644
--- a/route/sys.go
+++ b/route/sys.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package route
diff --git a/route/sys_netbsd.go b/route/sys_netbsd.go
index be4460e13f..c6bb6bc8a2 100644
--- a/route/sys_netbsd.go
+++ b/route/sys_netbsd.go
@@ -25,7 +25,7 @@ func (m *RouteMessage) Sys() []Sys {
}
}
-// RouteMetrics represents route metrics.
+// InterfaceMetrics represents interface metrics.
type InterfaceMetrics struct {
Type int // interface type
MTU int // maximum transmission unit
diff --git a/route/syscall.go b/route/syscall.go
index 68d37c9621..0ed53750a2 100644
--- a/route/syscall.go
+++ b/route/syscall.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package route
diff --git a/route/zsys_darwin.go b/route/zsys_darwin.go
index 56a0c66f44..adaa460026 100644
--- a/route/zsys_darwin.go
+++ b/route/zsys_darwin.go
@@ -9,14 +9,10 @@ const (
sizeofIfmaMsghdrDarwin15 = 0x10
sizeofIfMsghdr2Darwin15 = 0xa0
sizeofIfmaMsghdr2Darwin15 = 0x14
- sizeofIfDataDarwin15 = 0x60
- sizeofIfData64Darwin15 = 0x80
sizeofRtMsghdrDarwin15 = 0x5c
sizeofRtMsghdr2Darwin15 = 0x5c
- sizeofRtMetricsDarwin15 = 0x38
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_dragonfly.go b/route/zsys_dragonfly.go
index f7c7a60cd6..209cb20af8 100644
--- a/route/zsys_dragonfly.go
+++ b/route/zsys_dragonfly.go
@@ -11,10 +11,8 @@ const (
sizeofIfaMsghdrDragonFlyBSD58 = 0x18
- sizeofRtMsghdrDragonFlyBSD4 = 0x98
- sizeofRtMetricsDragonFlyBSD4 = 0x70
+ sizeofRtMsghdrDragonFlyBSD4 = 0x98
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_386.go b/route/zsys_freebsd_386.go
index 3f985c7ee9..ec617772b2 100644
--- a/route/zsys_freebsd_386.go
+++ b/route/zsys_freebsd_386.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0x68
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0x6c
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,18 +18,10 @@ const (
sizeofIfMsghdrFreeBSD10 = 0x64
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x50
- sizeofIfDataFreeBSD8 = 0x50
- sizeofIfDataFreeBSD9 = 0x50
- sizeofIfDataFreeBSD10 = 0x54
- sizeofIfDataFreeBSD11 = 0x98
-
// MODIFIED BY HAND FOR 386 EMULATION ON AMD64
// 386 EMULATION USES THE UNDERLYING RAW DATA LAYOUT
- sizeofIfMsghdrlFreeBSD10Emu = 0xb0
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0xb0
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -43,13 +34,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0xa8
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x98
- sizeofIfDataFreeBSD8Emu = 0x98
- sizeofIfDataFreeBSD9Emu = 0x98
- sizeofIfDataFreeBSD10Emu = 0x98
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_amd64.go b/route/zsys_freebsd_amd64.go
index 9293393698..3d7f31d13e 100644
--- a/route/zsys_freebsd_amd64.go
+++ b/route/zsys_freebsd_amd64.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0xb0
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0xb0
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,15 +18,7 @@ const (
sizeofIfMsghdrFreeBSD10 = 0xa8
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x98
- sizeofIfDataFreeBSD8 = 0x98
- sizeofIfDataFreeBSD9 = 0x98
- sizeofIfDataFreeBSD10 = 0x98
- sizeofIfDataFreeBSD11 = 0x98
-
- sizeofIfMsghdrlFreeBSD10Emu = 0xb0
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0xb0
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -40,13 +31,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0xa8
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x98
- sizeofIfDataFreeBSD8Emu = 0x98
- sizeofIfDataFreeBSD9Emu = 0x98
- sizeofIfDataFreeBSD10Emu = 0x98
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_arm.go b/route/zsys_freebsd_arm.go
index a2bdb4ad3b..931afa3931 100644
--- a/route/zsys_freebsd_arm.go
+++ b/route/zsys_freebsd_arm.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0x68
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0x6c
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,15 +18,7 @@ const (
sizeofIfMsghdrFreeBSD10 = 0x70
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x60
- sizeofIfDataFreeBSD8 = 0x60
- sizeofIfDataFreeBSD9 = 0x60
- sizeofIfDataFreeBSD10 = 0x60
- sizeofIfDataFreeBSD11 = 0x98
-
- sizeofIfMsghdrlFreeBSD10Emu = 0x68
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0x6c
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -40,13 +31,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0x70
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x60
- sizeofIfDataFreeBSD8Emu = 0x60
- sizeofIfDataFreeBSD9Emu = 0x60
- sizeofIfDataFreeBSD10Emu = 0x60
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_arm64.go b/route/zsys_freebsd_arm64.go
index 9293393698..3d7f31d13e 100644
--- a/route/zsys_freebsd_arm64.go
+++ b/route/zsys_freebsd_arm64.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0xb0
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0xb0
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,15 +18,7 @@ const (
sizeofIfMsghdrFreeBSD10 = 0xa8
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x98
- sizeofIfDataFreeBSD8 = 0x98
- sizeofIfDataFreeBSD9 = 0x98
- sizeofIfDataFreeBSD10 = 0x98
- sizeofIfDataFreeBSD11 = 0x98
-
- sizeofIfMsghdrlFreeBSD10Emu = 0xb0
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0xb0
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -40,13 +31,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0xa8
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x98
- sizeofIfDataFreeBSD8Emu = 0x98
- sizeofIfDataFreeBSD9Emu = 0x98
- sizeofIfDataFreeBSD10Emu = 0x98
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_riscv64.go b/route/zsys_freebsd_riscv64.go
index 9293393698..3d7f31d13e 100644
--- a/route/zsys_freebsd_riscv64.go
+++ b/route/zsys_freebsd_riscv64.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0xb0
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0xb0
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,15 +18,7 @@ const (
sizeofIfMsghdrFreeBSD10 = 0xa8
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x98
- sizeofIfDataFreeBSD8 = 0x98
- sizeofIfDataFreeBSD9 = 0x98
- sizeofIfDataFreeBSD10 = 0x98
- sizeofIfDataFreeBSD11 = 0x98
-
- sizeofIfMsghdrlFreeBSD10Emu = 0xb0
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0xb0
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -40,13 +31,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0xa8
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x98
- sizeofIfDataFreeBSD8Emu = 0x98
- sizeofIfDataFreeBSD9Emu = 0x98
- sizeofIfDataFreeBSD10Emu = 0x98
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_netbsd.go b/route/zsys_netbsd.go
index eaffe8c408..90ce707d47 100644
--- a/route/zsys_netbsd.go
+++ b/route/zsys_netbsd.go
@@ -8,10 +8,8 @@ const (
sizeofIfaMsghdrNetBSD7 = 0x18
sizeofIfAnnouncemsghdrNetBSD7 = 0x18
- sizeofRtMsghdrNetBSD7 = 0x78
- sizeofRtMetricsNetBSD7 = 0x50
+ sizeofRtMsghdrNetBSD7 = 0x78
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_openbsd.go b/route/zsys_openbsd.go
index b11b812680..64fbdd98fb 100644
--- a/route/zsys_openbsd.go
+++ b/route/zsys_openbsd.go
@@ -6,7 +6,6 @@ package route
const (
sizeofRtMsghdr = 0x60
- sizeofSockaddrStorage = 0x100
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/webdav/file_test.go b/webdav/file_test.go
index e875c136ca..c9313dc5bb 100644
--- a/webdav/file_test.go
+++ b/webdav/file_test.go
@@ -9,7 +9,6 @@ import (
"encoding/xml"
"fmt"
"io"
- "io/ioutil"
"os"
"path"
"path/filepath"
@@ -518,12 +517,7 @@ func TestDir(t *testing.T) {
t.Skip("see golang.org/issue/11453")
}
- td, err := ioutil.TempDir("", "webdav-test")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(td)
- testFS(t, Dir(td))
+ testFS(t, Dir(t.TempDir()))
}
func TestMemFS(t *testing.T) {
@@ -758,7 +752,7 @@ func TestMemFile(t *testing.T) {
if err != nil {
t.Fatalf("test case #%d %q: OpenFile: %v", i, tc, err)
}
- gotBytes, err := ioutil.ReadAll(g)
+ gotBytes, err := io.ReadAll(g)
if err != nil {
t.Fatalf("test case #%d %q: ReadAll: %v", i, tc, err)
}
diff --git a/webdav/litmus_test_server.go b/webdav/litmus_test_server.go
index 6334d7e233..4d49072c4d 100644
--- a/webdav/litmus_test_server.go
+++ b/webdav/litmus_test_server.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
/*
This program is a server for the WebDAV 'litmus' compliance test at
diff --git a/webdav/webdav.go b/webdav/webdav.go
index add2bcd67c..8ff3d100f9 100644
--- a/webdav/webdav.go
+++ b/webdav/webdav.go
@@ -267,6 +267,9 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int,
f, err := h.FileSystem.OpenFile(ctx, reqPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
if err != nil {
+ if os.IsNotExist(err) {
+ return http.StatusConflict, err
+ }
return http.StatusNotFound, err
}
_, copyErr := io.Copy(f, r.Body)
diff --git a/webdav/webdav_test.go b/webdav/webdav_test.go
index 2baebe3c97..deb60fb885 100644
--- a/webdav/webdav_test.go
+++ b/webdav/webdav_test.go
@@ -9,7 +9,6 @@ import (
"errors"
"fmt"
"io"
- "io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
@@ -256,7 +255,7 @@ func TestFilenameEscape(t *testing.T) {
}
defer res.Body.Close()
- b, err := ioutil.ReadAll(res.Body)
+ b, err := io.ReadAll(res.Body)
if err != nil {
return "", "", err
}
@@ -347,3 +346,63 @@ func TestFilenameEscape(t *testing.T) {
}
}
}
+
+func TestPutRequest(t *testing.T) {
+ h := &Handler{
+ FileSystem: NewMemFS(),
+ LockSystem: NewMemLS(),
+ }
+ srv := httptest.NewServer(h)
+ defer srv.Close()
+
+ do := func(method, urlStr string, body string) (*http.Response, error) {
+ bodyReader := strings.NewReader(body)
+ req, err := http.NewRequest(method, urlStr, bodyReader)
+ if err != nil {
+ return nil, err
+ }
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ return res, nil
+ }
+
+ testCases := []struct {
+ name string
+ urlPrefix string
+ want int
+ }{{
+ name: "put",
+ urlPrefix: "/res",
+ want: http.StatusCreated,
+ }, {
+ name: "put_utf8_segment",
+ urlPrefix: "/res-%e2%82%ac",
+ want: http.StatusCreated,
+ }, {
+ name: "put_empty_segment",
+ urlPrefix: "",
+ want: http.StatusNotFound,
+ }, {
+ name: "put_root_segment",
+ urlPrefix: "/",
+ want: http.StatusNotFound,
+ }, {
+ name: "put_no_parent [RFC4918:S9.7.1]",
+ urlPrefix: "/409me/noparent.txt",
+ want: http.StatusConflict,
+ }}
+
+ for _, tc := range testCases {
+ urlStr := srv.URL + tc.urlPrefix
+ res, err := do("PUT", urlStr, "ABC\n")
+ if err != nil {
+ t.Errorf("name=%q: PUT: %v", tc.name, err)
+ continue
+ }
+ if res.StatusCode != tc.want {
+ t.Errorf("name=%q: got status code %d, want %d", tc.name, res.StatusCode, tc.want)
+ }
+ }
+}
diff --git a/websocket/client.go b/websocket/client.go
index 69a4ac7eef..1e64157f3e 100644
--- a/websocket/client.go
+++ b/websocket/client.go
@@ -6,10 +6,12 @@ package websocket
import (
"bufio"
+ "context"
"io"
"net"
"net/http"
"net/url"
+ "time"
)
// DialError is an error that occurs while dialling a websocket server.
@@ -79,28 +81,59 @@ func parseAuthority(location *url.URL) string {
// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
- var client net.Conn
+ return config.DialContext(context.Background())
+}
+
+// DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation.
+func (config *Config) DialContext(ctx context.Context) (*Conn, error) {
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}
+
dialer := config.Dialer
if dialer == nil {
dialer = &net.Dialer{}
}
- client, err = dialWithDialer(dialer, config)
- if err != nil {
- goto Error
- }
- ws, err = NewClient(config, client)
+
+ client, err := dialWithDialer(ctx, dialer, config)
if err != nil {
- client.Close()
- goto Error
+ return nil, &DialError{config, err}
}
- return
-Error:
- return nil, &DialError{config, err}
+	// Clean up the connection if we fail to create the websocket
+ success := false
+ defer func() {
+ if !success {
+ _ = client.Close()
+ }
+ }()
+
+ var ws *Conn
+ var wsErr error
+ doneConnecting := make(chan struct{})
+ go func() {
+ defer close(doneConnecting)
+ ws, err = NewClient(config, client)
+ if err != nil {
+ wsErr = &DialError{config, err}
+ }
+ }()
+
+	// The websocket.NewClient() function can block indefinitely; make sure that we
+ // respect the deadlines specified by the context.
+ select {
+ case <-ctx.Done():
+ // Force the pending operations to fail, terminating the pending connection attempt
+ _ = client.SetDeadline(time.Now())
+ <-doneConnecting // Wait for the goroutine that tries to establish the connection to finish
+ return nil, &DialError{config, ctx.Err()}
+ case <-doneConnecting:
+ if wsErr == nil {
+ success = true // Disarm the deferred connection cleanup
+ }
+ return ws, wsErr
+ }
}
diff --git a/websocket/dial.go b/websocket/dial.go
index 2dab943a48..8a2d83c473 100644
--- a/websocket/dial.go
+++ b/websocket/dial.go
@@ -5,18 +5,23 @@
package websocket
import (
+ "context"
"crypto/tls"
"net"
)
-func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
+func dialWithDialer(ctx context.Context, dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
switch config.Location.Scheme {
case "ws":
- conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
+ conn, err = dialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
case "wss":
- conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
+ tlsDialer := &tls.Dialer{
+ NetDialer: dialer,
+ Config: config.TlsConfig,
+ }
+ conn, err = tlsDialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
default:
err = ErrBadScheme
}
diff --git a/websocket/dial_test.go b/websocket/dial_test.go
index aa03e30dd1..dd844872c9 100644
--- a/websocket/dial_test.go
+++ b/websocket/dial_test.go
@@ -5,10 +5,13 @@
package websocket
import (
+ "context"
"crypto/tls"
+ "errors"
"fmt"
"log"
"net"
+ "net/http"
"net/http/httptest"
"testing"
"time"
@@ -41,3 +44,37 @@ func TestDialConfigTLSWithDialer(t *testing.T) {
t.Fatalf("expected timeout error, got %#v", neterr)
}
}
+
+func TestDialConfigTLSWithTimeouts(t *testing.T) {
+ t.Parallel()
+
+ finishedRequest := make(chan bool)
+
+ // Context for cancellation
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // This is a TLS server that blocks each request indefinitely (and cancels the context)
+ tlsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ cancel()
+ <-finishedRequest
+ }))
+
+ tlsServerAddr := tlsServer.Listener.Addr().String()
+ log.Print("Test TLS WebSocket server listening on ", tlsServerAddr)
+ defer tlsServer.Close()
+ defer close(finishedRequest)
+
+ config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost")
+ config.TlsConfig = &tls.Config{
+ InsecureSkipVerify: true,
+ }
+
+ _, err := config.DialContext(ctx)
+ dialerr, ok := err.(*DialError)
+ if !ok {
+ t.Fatalf("DialError expected, got %#v", err)
+ }
+ if !errors.Is(dialerr.Err, context.Canceled) {
+ t.Fatalf("context.Canceled error expected, got %#v", dialerr.Err)
+ }
+}
diff --git a/websocket/hybi.go b/websocket/hybi.go
index 48a069e190..dda7434666 100644
--- a/websocket/hybi.go
+++ b/websocket/hybi.go
@@ -16,7 +16,6 @@ import (
"encoding/binary"
"fmt"
"io"
- "io/ioutil"
"net/http"
"net/url"
"strings"
@@ -279,7 +278,7 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, er
}
}
if header := frame.HeaderReader(); header != nil {
- io.Copy(ioutil.Discard, header)
+ io.Copy(io.Discard, header)
}
switch frame.PayloadType() {
case ContinuationFrame:
@@ -294,7 +293,7 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, er
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return nil, err
}
- io.Copy(ioutil.Discard, frame)
+ io.Copy(io.Discard, frame)
if frame.PayloadType() == PingFrame {
if _, err := handler.WritePong(b[:n]); err != nil {
return nil, err
diff --git a/websocket/hybi_test.go b/websocket/hybi_test.go
index 9504aa2d30..f0715d3f6f 100644
--- a/websocket/hybi_test.go
+++ b/websocket/hybi_test.go
@@ -163,7 +163,7 @@ Sec-WebSocket-Protocol: chat
}
for k, v := range expectedHeader {
if req.Header.Get(k) != v {
- t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
+ t.Errorf("%s expected %q but got %q", k, v, req.Header.Get(k))
}
}
}
diff --git a/websocket/websocket.go b/websocket/websocket.go
index 90a2257cd5..ac76165ceb 100644
--- a/websocket/websocket.go
+++ b/websocket/websocket.go
@@ -8,7 +8,7 @@
// This package currently lacks some features found in an alternative
// and more actively maintained WebSocket package:
//
-// https://pkg.go.dev/nhooyr.io/websocket
+// https://pkg.go.dev/github.com/coder/websocket
package websocket // import "golang.org/x/net/websocket"
import (
@@ -17,7 +17,6 @@ import (
"encoding/json"
"errors"
"io"
- "io/ioutil"
"net"
"net/http"
"net/url"
@@ -208,7 +207,7 @@ again:
n, err = ws.frameReader.Read(msg)
if err == io.EOF {
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
- io.Copy(ioutil.Discard, trailer)
+ io.Copy(io.Discard, trailer)
}
ws.frameReader = nil
goto again
@@ -330,7 +329,7 @@ func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
if ws.frameReader != nil {
- _, err = io.Copy(ioutil.Discard, ws.frameReader)
+ _, err = io.Copy(io.Discard, ws.frameReader)
if err != nil {
return err
}
@@ -362,7 +361,7 @@ again:
return ErrFrameTooLarge
}
payloadType := frame.PayloadType()
- data, err := ioutil.ReadAll(frame)
+ data, err := io.ReadAll(frame)
if err != nil {
return err
}
diff --git a/xsrftoken/xsrf.go b/xsrftoken/xsrf.go
index 3ca5d5b9f5..e808e6dd80 100644
--- a/xsrftoken/xsrf.go
+++ b/xsrftoken/xsrf.go
@@ -45,10 +45,9 @@ func generateTokenAtTime(key, userID, actionID string, now time.Time) string {
h := hmac.New(sha1.New, []byte(key))
fmt.Fprintf(h, "%s:%s:%d", clean(userID), clean(actionID), milliTime)
- // Get the padded base64 string then removing the padding.
+ // Get the no padding base64 string.
tok := string(h.Sum(nil))
- tok = base64.URLEncoding.EncodeToString([]byte(tok))
- tok = strings.TrimRight(tok, "=")
+ tok = base64.RawURLEncoding.EncodeToString([]byte(tok))
return fmt.Sprintf("%s:%d", tok, milliTime)
}