diff --git a/dns/dnsmessage/message.go b/dns/dnsmessage/message.go
index 42987ab7c..a656efc12 100644
--- a/dns/dnsmessage/message.go
+++ b/dns/dnsmessage/message.go
@@ -273,7 +273,6 @@ var (
errTooManyAdditionals = errors.New("too many Additionals to pack (>65535)")
errNonCanonicalName = errors.New("name is not in canonical format (it must end with a .)")
errStringTooLong = errors.New("character string exceeds maximum length (255)")
- errCompressedSRV = errors.New("compressed name in SRV resource data")
)

// Internal constants.
@@ -2028,10 +2027,6 @@ func (n *Name) pack(msg []byte, compression map[string]uint16, compressionOff in

// unpack unpacks a domain name.
func (n *Name) unpack(msg []byte, off int) (int, error) {
- return n.unpackCompressed(msg, off, true /* allowCompression */)
-}
-
-func (n *Name) unpackCompressed(msg []byte, off int, allowCompression bool) (int, error) {
// currOff is the current working offset.
currOff := off

@@ -2076,9 +2071,6 @@ Loop:
name = append(name, '.')
currOff = endOff
case 0xC0: // Pointer
- if !allowCompression {
- return off, errCompressedSRV
- }
if currOff >= len(msg) {
return off, errInvalidPtr
}
@@ -2549,7 +2541,7 @@ func unpackSRVResource(msg []byte, off int) (SRVResource, error) {
return SRVResource{}, &nestedError{"Port", err}
}
var target Name
- if _, err := target.unpackCompressed(msg, off, false /* allowCompression */); err != nil {
+ if _, err := target.unpack(msg, off); err != nil {
return SRVResource{}, &nestedError{"Target", err}
}
return SRVResource{priority, weight, port, target}, nil
diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go
index c84d5a3aa..255530598 100644
--- a/dns/dnsmessage/message_test.go
+++ b/dns/dnsmessage/message_test.go
@@ -303,28 +303,6 @@ func TestNameUnpackTooLongName(t *testing.T) {
}
}

-func TestIncompressibleName(t *testing.T) {
- name := MustNewName("example.com.")
- compression := map[string]uint16{}
- buf, err := name.pack(make([]byte, 0, 100), compression, 0)
- if err != nil {
- t.Fatal("first Name.pack() =", err)
- }
- buf, err = name.pack(buf, compression, 0)
- if err != nil {
- t.Fatal("second Name.pack() =", err)
- }
- var n1 Name
- off, err := n1.unpackCompressed(buf, 0, false /* allowCompression */)
- if err != nil {
- t.Fatal("unpacking incompressible name without pointers failed:", err)
- }
- var n2 Name
- if _, err := n2.unpackCompressed(buf, off, false /* allowCompression */); err != errCompressedSRV {
- t.Errorf("unpacking compressed incompressible name with pointers: got %v, want = %v", err, errCompressedSRV)
- }
-}
-
func checkErrorPrefix(err error, prefix string) bool {
e, ok := err.(*nestedError)
return ok && e.s == prefix
@@ -1657,7 +1635,7 @@ func FuzzUnpackPack(f *testing.F) {

msgPacked, err := m.Pack()
if err != nil {
- t.Fatalf("failed to pack message that was succesfully unpacked: %v", err)
+ t.Fatalf("failed to pack message that was successfully unpacked: %v", err)
}

var m2 Message
diff --git a/go.mod b/go.mod
index 21deffd4b..36207106d 100644
--- a/go.mod
+++ b/go.mod
@@ -3,8 +3,8 @@ module golang.org/x/net
go 1.18

require (
- golang.org/x/crypto v0.15.0
- golang.org/x/sys v0.14.0
- golang.org/x/term v0.14.0
+ golang.org/x/crypto v0.21.0
+ golang.org/x/sys v0.18.0
+ golang.org/x/term v0.18.0
golang.org/x/text v0.14.0
)
diff --git a/go.sum b/go.sum
index 54759e489..69fb10498 100644
--- a/go.sum
+++ b/go.sum
@@ -1,8 +1,8 @@
-golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA=
-golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g=
-golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
-golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8=
-golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww=
+golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
+golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
+golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
+golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
+golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
diff --git a/html/token.go b/html/token.go
index de67f938a..3c57880d6 100644
--- a/html/token.go
+++ b/html/token.go
@@ -910,9 +910,6 @@ func (z *Tokenizer) readTagAttrKey() {
return
}
switch c {
- case ' ', '\n', '\r', '\t', '\f', '/':
- z.pendingAttr[0].end = z.raw.end - 1
- return
case '=':
if z.pendingAttr[0].start+1 == z.raw.end {
// WHATWG 13.2.5.32, if we see an equals sign before the attribute name
@@ -920,7 +917,9 @@ func (z *Tokenizer) readTagAttrKey() {
continue
}
fallthrough
- case '>':
+ case ' ', '\n', '\r', '\t', '\f', '/', '>':
+ // WHATWG 13.2.5.33 Attribute name state
+ // We need to reconsume the char in the after attribute name state to support the / character
z.raw.end--
z.pendingAttr[0].end = z.raw.end
return
@@ -939,6 +938,11 @@ func (z *Tokenizer) readTagAttrVal() {
if z.err != nil {
return
}
+ if c == '/' {
+ // WHATWG 13.2.5.34 After attribute name state
+ // U+002F SOLIDUS (/) - Switch to the self-closing start tag state.
+ return
+ }
if c != '=' {
z.raw.end--
return
diff --git a/html/token_test.go b/html/token_test.go
index b2383a951..8b0d5aab6 100644
--- a/html/token_test.go
+++ b/html/token_test.go
@@ -601,6 +601,21 @@ var tokenTests = []tokenTest{
``,
``,
},
+ {
+ "forward slash before attribute name",
+ ``,
+ ``,
+ },
+ {
+ "forward slash before attribute name with spaces around",
+ ``,
+ ``,
+ },
+ {
+ "forward slash after attribute name followed by a character",
+ ``,
+ ``,
+ },
}

func TestTokenizer(t *testing.T) {
diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go
index c3bd9a1ee..6404aaf15 100644
--- a/http/httpproxy/proxy.go
+++ b/http/httpproxy/proxy.go
@@ -149,10 +149,7 @@ func parseProxy(proxy string) (*url.URL, error) {
}
proxyURL, err := url.Parse(proxy)
- if err != nil ||
- (proxyURL.Scheme != "http" &&
- proxyURL.Scheme != "https" &&
- proxyURL.Scheme != "socks5") {
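+ // Accept any scheme as long as the URL also includes a host; unknown
+ // schemes such as "foo://host" are preserved for the caller (see tests).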
+ if err != nil || proxyURL.Scheme == "" || proxyURL.Host == "" {
// proxy was bogus. Try prepending "http://" to it and
// see if that parses correctly. If not, we fall
// through and complain about the original one.
diff --git a/http/httpproxy/proxy_test.go b/http/httpproxy/proxy_test.go
index d76373295..790afdab7 100644
--- a/http/httpproxy/proxy_test.go
+++ b/http/httpproxy/proxy_test.go
@@ -68,6 +68,12 @@ var proxyForURLTests = []proxyForURLTest{{
HTTPProxy: "cache.corp.example.com",
},
want: "http://cache.corp.example.com",
+}, {
+ // single label domain is recognized as scheme by url.Parse
+ cfg: httpproxy.Config{
+ HTTPProxy: "localhost",
+ },
+ want: "http://localhost",
}, {
cfg: httpproxy.Config{
HTTPProxy: "https://cache.corp.example.com",
@@ -88,6 +94,12 @@ var proxyForURLTests = []proxyForURLTest{{
HTTPProxy: "socks5://127.0.0.1",
},
want: "socks5://127.0.0.1",
+}, {
+ // Preserve unknown schemes.
+ cfg: httpproxy.Config{
+ HTTPProxy: "foo://host",
+ },
+ want: "foo://host",
}, {
// Don't use secure for http
cfg: httpproxy.Config{
diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go
new file mode 100644
index 000000000..4237b1436
--- /dev/null
+++ b/http2/clientconn_test.go
@@ -0,0 +1,829 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Infrastructure for testing ClientConn.RoundTrip.
+// Put actual tests in transport_test.go.
+
+package http2
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "reflect"
+ "slices"
+ "testing"
+ "time"
+
+ "golang.org/x/net/http2/hpack"
+)
+
+// TestTestClientConn demonstrates usage of testClientConn.
+func TestTestClientConn(t *testing.T) {
+ // newTestClientConn creates a *ClientConn and surrounding test infrastructure.
+ tc := newTestClientConn(t)
+
+ // tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames,
+ // and sends a SETTINGS frame to the client.
+ //
+ // Additional settings may be provided as optional parameters to greet.
+ tc.greet()
+
+ // Request bodies must either be constant (bytes.Buffer, strings.Reader)
+ // or created with newRequestBody.
+ body := tc.newRequestBody()
+ body.writeBytes(10) // 10 arbitrary bytes...
+ body.closeWithError(io.EOF) // ...followed by EOF.
+
+ // tc.roundTrip calls RoundTrip, but does not wait for it to return.
+ // It returns a testRoundTrip.
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
+ rt := tc.roundTrip(req)
+
+ // tc has a number of methods to check for expected frames sent.
+ // Here, we look for headers and the request body.
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: false,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"PUT"},
+ ":path": []string{"/"},
+ },
+ })
+ // Expect 10 bytes of request body in DATA frames.
+ tc.wantData(wantData{
+ streamID: rt.streamID(),
+ endStream: true,
+ size: 10,
+ })
+
+ // tc.writeHeaders sends a HEADERS frame back to the client.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+
+ // Now that we've received headers, RoundTrip has finished.
+ // testRoundTrip has various methods to examine the response,
+ // or to fetch the response and/or error returned by RoundTrip.
+ rt.wantStatus(200)
+ rt.wantBody(nil)
+}
+
+// A testClientConn allows testing ClientConn.RoundTrip against a fake server.
+//
+// A test using testClientConn consists of:
+// - actions on the client (calling RoundTrip, making data available to Request.Body);
+// - validation of frames sent by the client to the server; and
+// - providing frames from the server to the client.
+//
+// testClientConn manages synchronization, so tests can generally be written as
+// a linear sequence of actions and validations without additional synchronization.
+type testClientConn struct {
+ t *testing.T
+
+ tr *Transport
+ fr *Framer
+ cc *ClientConn
+ hooks *testSyncHooks
+
+ encbuf bytes.Buffer
+ enc *hpack.Encoder
+
+ roundtrips []*testRoundTrip
+
+ rerr error // returned by Read
+ netConnClosed bool // set when the ClientConn closes the net.Conn
+ rbuf bytes.Buffer // sent to the test conn
+ wbuf bytes.Buffer // sent by the test conn
+}
+
+func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn {
+ tc := &testClientConn{
+ t: t,
+ tr: cc.t,
+ cc: cc,
+ hooks: cc.t.syncHooks,
+ }
+ cc.tconn = (*testClientConnNetConn)(tc)
+ tc.enc = hpack.NewEncoder(&tc.encbuf)
+ tc.fr = NewFramer(&tc.rbuf, &tc.wbuf)
+ tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
+ tc.fr.SetMaxReadFrameSize(10 << 20)
+ t.Cleanup(func() {
+ tc.sync()
+ if tc.rerr == nil {
+ tc.rerr = io.EOF
+ }
+ tc.sync()
+ })
+ return tc
+}
+
+func (tc *testClientConn) readClientPreface() {
+ tc.t.Helper()
+ // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames.
+ buf := make([]byte, len(clientPreface))
+ if _, err := io.ReadFull(&tc.wbuf, buf); err != nil {
+ tc.t.Fatalf("reading preface: %v", err)
+ }
+ if !bytes.Equal(buf, clientPreface) {
+ tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface)
+ }
+}
+
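+// newTestClientConn creates a ClientConn using a testTransport and returns
+// its testClientConn wrapper, after consuming the client's connection preface.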
+func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
+ t.Helper()
+
+ tt := newTestTransport(t, opts...)
+ const singleUse = false
+ _, err := tt.tr.newClientConn(nil, singleUse, tt.tr.syncHooks)
+ if err != nil {
+ t.Fatalf("newClientConn: %v", err)
+ }
+
+ return tt.getConn()
+}
+
+// sync waits for the ClientConn under test to reach a stable state,
+// with all goroutines blocked on some input.
+func (tc *testClientConn) sync() {
+ tc.hooks.waitInactive()
+}
+
+// advance advances synthetic time by a duration.
+func (tc *testClientConn) advance(d time.Duration) {
+ tc.hooks.advance(d)
+ tc.sync()
+}
+
+// hasFrame reports whether a frame is available to be read.
+func (tc *testClientConn) hasFrame() bool {
+ return tc.wbuf.Len() > 0
+}
+
+// readFrame reads the next frame from the conn.
+func (tc *testClientConn) readFrame() Frame {
+ if tc.wbuf.Len() == 0 {
+ return nil
+ }
+ fr, err := tc.fr.ReadFrame()
+ if err != nil {
+ return nil
+ }
+ return fr
+}
+
+// testClientConnReadFrame reads a frame of a specific type from the conn.
+func testClientConnReadFrame[T any](tc *testClientConn) T {
+ tc.t.Helper()
+ var v T
+ fr := tc.readFrame()
+ if fr == nil {
+ tc.t.Fatalf("got no frame, want frame %T", v)
+ }
+ v, ok := fr.(T)
+ if !ok {
+ tc.t.Fatalf("got frame %T, want %T", fr, v)
+ }
+ return v
+}
+
+// wantFrameType reads the next frame from the conn.
+// It produces an error if the frame type is not the expected value.
+func (tc *testClientConn) wantFrameType(want FrameType) {
+ tc.t.Helper()
+ fr := tc.readFrame()
+ if fr == nil {
+ tc.t.Fatalf("got no frame, want frame %v", want)
+ }
+ if got := fr.Header().Type; got != want {
+ tc.t.Fatalf("got frame %v, want %v", got, want)
+ }
+}
+
+// wantUnorderedFrames reads frames from the conn until every condition in want has been satisfied.
+//
+// want is a list of func(*SomeFrame) bool.
+// wantUnorderedFrames will call each func with frames of the appropriate type
+// until the func returns true.
+// It calls t.Fatal if an unexpected frame is received (no func has that frame type,
+// or all funcs with that type have returned true), or if the conn runs out of frames
+// with unsatisfied funcs.
+//
+// Example:
+//
+// // Read a SETTINGS frame, and any number of DATA frames for a stream.
+// // The SETTINGS frame may appear anywhere in the sequence.
+// // The last DATA frame must indicate the end of the stream.
+// tc.wantUnorderedFrames(
+// func(f *SettingsFrame) bool {
+// return true
+// },
+// func(f *DataFrame) bool {
+// return f.StreamEnded()
+// },
+// )
+func (tc *testClientConn) wantUnorderedFrames(want ...any) {
+ tc.t.Helper()
+ want = slices.Clone(want)
+ seen := 0
+frame:
+ for seen < len(want) && !tc.t.Failed() {
+ fr := tc.readFrame()
+ if fr == nil {
+ break
+ }
+ for i, f := range want {
+ if f == nil {
+ continue
+ }
+ typ := reflect.TypeOf(f)
+ if typ.Kind() != reflect.Func ||
+ typ.NumIn() != 1 ||
+ typ.NumOut() != 1 ||
+ typ.Out(0) != reflect.TypeOf(true) {
+ tc.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
+ }
+ if typ.In(0) == reflect.TypeOf(fr) {
+ out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
+ if out[0].Bool() {
+ want[i] = nil
+ seen++
+ }
+ continue frame
+ }
+ }
+ tc.t.Errorf("got unexpected frame type %T", fr)
+ }
+ if seen < len(want) {
+ for _, f := range want {
+ if f == nil {
+ continue
+ }
+ tc.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
+ }
+ tc.t.Fatalf("did not see %v expected frame types", len(want)-seen)
+ }
+}
+
+type wantHeader struct {
+ streamID uint32
+ endStream bool
+ header http.Header
+}
+
+// wantHeaders reads a HEADERS frame and potential CONTINUATION frames,
+// and asserts that they contain the expected headers.
+func (tc *testClientConn) wantHeaders(want wantHeader) {
+ tc.t.Helper()
+ got := testClientConnReadFrame[*MetaHeadersFrame](tc)
+ if got, want := got.StreamID, want.streamID; got != want {
+ tc.t.Fatalf("got stream ID %v, want %v", got, want)
+ }
+ if got, want := got.StreamEnded(), want.endStream; got != want {
+ tc.t.Fatalf("got stream ended %v, want %v", got, want)
+ }
+ gotHeader := make(http.Header)
+ for _, f := range got.Fields {
+ gotHeader[f.Name] = append(gotHeader[f.Name], f.Value)
+ }
+ for k, v := range want.header {
+ if !reflect.DeepEqual(v, gotHeader[k]) {
+ tc.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k])
+ }
+ }
+}
+
+type wantData struct {
+ streamID uint32
+ endStream bool
+ size int
+}
+
+// wantData reads zero or more DATA frames, and asserts that they match the expectation.
+func (tc *testClientConn) wantData(want wantData) {
+ tc.t.Helper()
+ gotSize := 0
+ gotEndStream := false
+ for tc.hasFrame() && !gotEndStream {
+ data := testClientConnReadFrame[*DataFrame](tc)
+ gotSize += len(data.Data())
+ if data.StreamEnded() {
+ gotEndStream = true
+ }
+ }
+ if gotSize != want.size {
+ tc.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
+ }
+ if gotEndStream != want.endStream {
+ tc.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
+ }
+}
+
+// testRequestBody is a Request.Body for use in tests.
+type testRequestBody struct {
+ tc *testClientConn
+
+ // At most one of buf or bytes can be set at any given time:
+ buf bytes.Buffer // specific bytes to read from the body
+ bytes int // body contains this many arbitrary bytes
+
+ err error // read error (comes after any available bytes)
+}
+
+func (tc *testClientConn) newRequestBody() *testRequestBody {
+ b := &testRequestBody{
+ tc: tc,
+ }
+ return b
+}
+
+// Read is called by the ClientConn to read from a request body.
+func (b *testRequestBody) Read(p []byte) (n int, _ error) {
+ b.tc.cc.syncHooks.blockUntil(func() bool {
+ return b.buf.Len() > 0 || b.bytes > 0 || b.err != nil
+ })
+ switch {
+ case b.buf.Len() > 0:
+ return b.buf.Read(p)
+ case b.bytes > 0:
+ if len(p) > b.bytes {
+ p = p[:b.bytes]
+ }
+ b.bytes -= len(p)
+ for i := range p {
+ p[i] = 'A'
+ }
+ return len(p), nil
+ default:
+ return 0, b.err
+ }
+}
+
+// Close is called by the ClientConn when it is done reading from a request body.
+func (b *testRequestBody) Close() error {
+ return nil
+}
+
+// writeBytes adds n arbitrary bytes to the body.
+func (b *testRequestBody) writeBytes(n int) {
+ b.bytes += n
+ b.checkWrite()
+ b.tc.sync()
+}
+
+// Write adds bytes to the body.
+func (b *testRequestBody) Write(p []byte) (int, error) {
+ n, err := b.buf.Write(p)
+ b.checkWrite()
+ b.tc.sync()
+ return n, err
+}
+
+func (b *testRequestBody) checkWrite() {
+ if b.bytes > 0 && b.buf.Len() > 0 {
+ b.tc.t.Fatalf("can't interleave Write and writeBytes on request body")
+ }
+ if b.err != nil {
+ b.tc.t.Fatalf("can't write to request body after closeWithError")
+ }
+}
+
+// closeWithError sets an error which will be returned by Read.
+func (b *testRequestBody) closeWithError(err error) {
+ b.err = err
+ b.tc.sync()
+}
+
+// roundTrip starts a RoundTrip call.
+//
+// (Note that the RoundTrip won't complete until response headers are received,
+// the request times out, or some other terminal condition is reached.)
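+//
+// A sketch of typical use, with helpers as defined in this file:
+//
+// req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+// rt := tc.roundTrip(req)
+// tc.wantFrameType(FrameHeaders) // request headers sent by the client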
+func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
+ rt := &testRoundTrip{
+ t: tc.t,
+ donec: make(chan struct{}),
+ }
+ tc.roundtrips = append(tc.roundtrips, rt)
+ tc.hooks.newstream = func(cs *clientStream) { rt.cs = cs }
+ tc.cc.goRun(func() {
+ defer close(rt.donec)
+ rt.resp, rt.respErr = tc.cc.RoundTrip(req)
+ })
+ tc.sync()
+ tc.hooks.newstream = nil
+
+ tc.t.Cleanup(func() {
+ if !rt.done() {
+ return
+ }
+ res, _ := rt.result()
+ if res != nil {
+ res.Body.Close()
+ }
+ })
+
+ return rt
+}
+
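+// greet performs the initial handshake: it reads the client's SETTINGS and
+// WINDOW_UPDATE frames, writes a server SETTINGS frame (with any provided
+// settings), acks the client's settings, and reads the client's SETTINGS ack.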
+func (tc *testClientConn) greet(settings ...Setting) {
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.writeSettings(settings...)
+ tc.writeSettingsAck()
+ tc.wantFrameType(FrameSettings) // acknowledgement
+}
+
+func (tc *testClientConn) writeSettings(settings ...Setting) {
+ tc.t.Helper()
+ if err := tc.fr.WriteSettings(settings...); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+func (tc *testClientConn) writeSettingsAck() {
+ tc.t.Helper()
+ if err := tc.fr.WriteSettingsAck(); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+func (tc *testClientConn) writeData(streamID uint32, endStream bool, data []byte) {
+ tc.t.Helper()
+ if err := tc.fr.WriteData(streamID, endStream, data); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+func (tc *testClientConn) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
+ tc.t.Helper()
+ if err := tc.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+// makeHeaderBlockFragment encodes headers in a form suitable for inclusion
+// in a HEADERS or CONTINUATION frame.
+//
+// It takes a list of alternating names and values.
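+//
+// Example:
+//
+// frag := tc.makeHeaderBlockFragment(
+// ":status", "200",
+// "content-type", "text/plain",
+// )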
+func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte {
+ if len(s)%2 != 0 {
+ tc.t.Fatalf("uneven list of header name/value pairs")
+ }
+ tc.encbuf.Reset()
+ for i := 0; i < len(s); i += 2 {
+ tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]})
+ }
+ return tc.encbuf.Bytes()
+}
+
+func (tc *testClientConn) writeHeaders(p HeadersFrameParam) {
+ tc.t.Helper()
+ if err := tc.fr.WriteHeaders(p); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+// writeHeadersMode writes header frames, as modified by mode:
+//
+// - noHeader: Don't write the header.
+// - oneHeader: Write a single HEADERS frame.
+// - splitHeader: Write a HEADERS frame and CONTINUATION frame.
+func (tc *testClientConn) writeHeadersMode(mode headerType, p HeadersFrameParam) {
+ tc.t.Helper()
+ switch mode {
+ case noHeader:
+ case oneHeader:
+ tc.writeHeaders(p)
+ case splitHeader:
+ if len(p.BlockFragment) < 2 {
+ panic("too small")
+ }
+ contData := p.BlockFragment[1:]
+ contEnd := p.EndHeaders
+ p.BlockFragment = p.BlockFragment[:1]
+ p.EndHeaders = false
+ tc.writeHeaders(p)
+ tc.writeContinuation(p.StreamID, contEnd, contData)
+ default:
+ panic("bogus mode")
+ }
+}
+
+func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
+ tc.t.Helper()
+ if err := tc.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+func (tc *testClientConn) writeRSTStream(streamID uint32, code ErrCode) {
+ tc.t.Helper()
+ if err := tc.fr.WriteRSTStream(streamID, code); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+func (tc *testClientConn) writePing(ack bool, data [8]byte) {
+ tc.t.Helper()
+ if err := tc.fr.WritePing(ack, data); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
+ tc.t.Helper()
+ if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+func (tc *testClientConn) writeWindowUpdate(streamID, incr uint32) {
+ tc.t.Helper()
+ if err := tc.fr.WriteWindowUpdate(streamID, incr); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
+// closeWrite causes the net.Conn used by the ClientConn to return an error
+// from Read calls.
+func (tc *testClientConn) closeWrite(err error) {
+ tc.rerr = err
+ tc.sync()
+}
+
+// inflowWindow returns the amount of inbound flow control available for a stream,
+// or for the connection if streamID is 0.
+func (tc *testClientConn) inflowWindow(streamID uint32) int32 {
+ tc.cc.mu.Lock()
+ defer tc.cc.mu.Unlock()
+ if streamID == 0 {
+ return tc.cc.inflow.avail + tc.cc.inflow.unsent
+ }
+ cs := tc.cc.streams[streamID]
+ if cs == nil {
+ tc.t.Errorf("no stream with id %v", streamID)
+ return -1
+ }
+ return cs.inflow.avail + cs.inflow.unsent
+}
+
+// testRoundTrip manages a RoundTrip in progress.
+type testRoundTrip struct {
+ t *testing.T
+ resp *http.Response
+ respErr error
+ donec chan struct{}
+ cs *clientStream
+}
+
+// streamID returns the HTTP/2 stream ID of the request.
+func (rt *testRoundTrip) streamID() uint32 {
+ if rt.cs == nil {
+ panic("stream ID unknown")
+ }
+ return rt.cs.ID
+}
+
+// done reports whether RoundTrip has returned.
+func (rt *testRoundTrip) done() bool {
+ select {
+ case <-rt.donec:
+ return true
+ default:
+ return false
+ }
+}
+
+// result returns the result of the RoundTrip.
+func (rt *testRoundTrip) result() (*http.Response, error) {
+ t := rt.t
+ t.Helper()
+ select {
+ case <-rt.donec:
+ default:
+ t.Fatalf("RoundTrip is not done; want it to be")
+ }
+ return rt.resp, rt.respErr
+}
+
+// response returns the response of a successful RoundTrip.
+// If the RoundTrip unexpectedly failed, it calls t.Fatal.
+func (rt *testRoundTrip) response() *http.Response {
+ t := rt.t
+ t.Helper()
+ resp, err := rt.result()
+ if err != nil {
+ t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
+ }
+ if resp == nil {
+ t.Fatalf("RoundTrip returned nil *Response and nil error")
+ }
+ return resp
+}
+
+// err returns the (possibly nil) error result of RoundTrip.
+func (rt *testRoundTrip) err() error {
+ t := rt.t
+ t.Helper()
+ _, err := rt.result()
+ return err
+}
+
+// wantStatus indicates the expected response StatusCode.
+func (rt *testRoundTrip) wantStatus(want int) {
+ t := rt.t
+ t.Helper()
+ if got := rt.response().StatusCode; got != want {
+ t.Fatalf("got response status %v, want %v", got, want)
+ }
+}
+
+// readBody reads the contents of the response body.
+func (rt *testRoundTrip) readBody() ([]byte, error) {
+ t := rt.t
+ t.Helper()
+ return io.ReadAll(rt.response().Body)
+}
+
+// wantBody indicates the expected response body.
+// (Note that this consumes the body.)
+func (rt *testRoundTrip) wantBody(want []byte) {
+ t := rt.t
+ t.Helper()
+ got, err := rt.readBody()
+ if err != nil {
+ t.Fatalf("unexpected error reading response body: %v", err)
+ }
+ if !bytes.Equal(got, want) {
+ t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want)
+ }
+}
+
+// wantHeaders indicates the expected response headers.
+func (rt *testRoundTrip) wantHeaders(want http.Header) {
+ t := rt.t
+ t.Helper()
+ res := rt.response()
+ if diff := diffHeaders(res.Header, want); diff != "" {
+ t.Fatalf("unexpected response headers:\n%v", diff)
+ }
+}
+
+// wantTrailers indicates the expected response trailers.
+func (rt *testRoundTrip) wantTrailers(want http.Header) {
+ t := rt.t
+ t.Helper()
+ res := rt.response()
+ if diff := diffHeaders(res.Trailer, want); diff != "" {
+ t.Fatalf("unexpected response trailers:\n%v", diff)
+ }
+}
+
+func diffHeaders(got, want http.Header) string {
+ // nil and 0-length non-nil are equal.
+ if len(got) == 0 && len(want) == 0 {
+ return ""
+ }
+ // We could do a more sophisticated diff here.
+ // DeepEqual is good enough for now.
+ if reflect.DeepEqual(got, want) {
+ return ""
+ }
+ return fmt.Sprintf("got: %v\nwant: %v", got, want)
+}
+
+// testClientConnNetConn implements net.Conn.
+type testClientConnNetConn testClientConn
+
+func (nc *testClientConnNetConn) Read(b []byte) (n int, err error) {
+ nc.cc.syncHooks.blockUntil(func() bool {
+ return nc.rerr != nil || nc.rbuf.Len() > 0
+ })
+ if nc.rbuf.Len() > 0 {
+ return nc.rbuf.Read(b)
+ }
+ return 0, nc.rerr
+}
+
+func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) {
+ return nc.wbuf.Write(b)
+}
+
+func (nc *testClientConnNetConn) Close() error {
+ nc.netConnClosed = true
+ return nil
+}
+
+func (*testClientConnNetConn) LocalAddr() (_ net.Addr) { return }
+func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return }
+func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil }
+func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil }
+func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil }
+
+// A testTransport allows testing Transport.RoundTrip against fake servers.
+// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling
+// should use testClientConn instead.
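+//
+// A sketch of typical use (compare TestTestClientConn above):
+//
+// req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+// tt := newTestTransport(t)
+// rt := tt.roundTrip(req) // RoundTrip dials a new conn via the test hooks
+// tc := tt.getConn()      // retrieve it to read and write frames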
+type testTransport struct {
+ t *testing.T
+ tr *Transport
+
+ ccs []*testClientConn
+}
+
+func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport {
+ tr := &Transport{
+ syncHooks: newTestSyncHooks(),
+ }
+ for _, o := range opts {
+ o(tr)
+ }
+
+ tt := &testTransport{
+ t: t,
+ tr: tr,
+ }
+ tr.syncHooks.newclientconn = func(cc *ClientConn) {
+ tt.ccs = append(tt.ccs, newTestClientConnFromClientConn(t, cc))
+ }
+
+ t.Cleanup(func() {
+ tt.sync()
+ if len(tt.ccs) > 0 {
+ t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs))
+ }
+ if tt.tr.syncHooks.total != 0 {
+ t.Errorf("%v goroutines still running after test completed", tt.tr.syncHooks.total)
+ }
+ })
+
+ return tt
+}
+
+func (tt *testTransport) sync() {
+ tt.tr.syncHooks.waitInactive()
+}
+
+func (tt *testTransport) advance(d time.Duration) {
+ tt.tr.syncHooks.advance(d)
+ tt.sync()
+}
+
+func (tt *testTransport) hasConn() bool {
+ return len(tt.ccs) > 0
+}
+
+func (tt *testTransport) getConn() *testClientConn {
+ tt.t.Helper()
+ if len(tt.ccs) == 0 {
+ tt.t.Fatalf("no new ClientConns created; wanted one")
+ }
+ tc := tt.ccs[0]
+ tt.ccs = tt.ccs[1:]
+ tc.sync()
+ tc.readClientPreface()
+ return tc
+}
+
+func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip {
+ rt := &testRoundTrip{
+ t: tt.t,
+ donec: make(chan struct{}),
+ }
+ tt.tr.syncHooks.goRun(func() {
+ defer close(rt.donec)
+ rt.resp, rt.respErr = tt.tr.RoundTrip(req)
+ })
+ tt.sync()
+
+ tt.t.Cleanup(func() {
+ if !rt.done() {
+ return
+ }
+ res, _ := rt.result()
+ if res != nil {
+ res.Body.Close()
+ }
+ })
+
+ return rt
+}
diff --git a/http2/frame.go b/http2/frame.go
index c1f6b90dc..43557ab7e 100644
--- a/http2/frame.go
+++ b/http2/frame.go
@@ -1510,13 +1510,12 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
}
func (fr *Framer) maxHeaderStringLen() int {
- v := fr.maxHeaderListSize()
- if uint32(int(v)) == v {
- return int(v)
+ v := int(fr.maxHeaderListSize())
+ if v < 0 {
+ // If maxHeaderListSize overflows an int, use no limit (0).
+ return 0
}
- // They had a crazy big number for MaxHeaderBytes anyway,
- // so give them unlimited header lengths:
- return 0
+ return v
}
// readMetaFrame returns 0 or more CONTINUATION frames from fr and
@@ -1565,6 +1564,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
if size > remainSize {
hdec.SetEmitEnabled(false)
mh.Truncated = true
+ remainSize = 0
return
}
remainSize -= size
@@ -1577,6 +1577,36 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
var hc headersOrContinuation = hf
for {
frag := hc.HeaderBlockFragment()
+
+ // Avoid parsing large amounts of headers that we will then discard.
+ // If the sender exceeds the max header list size by too much,
+ // skip parsing the fragment and close the connection.
+ //
+ // "Too much" is either any CONTINUATION frame after we've already
+ // exceeded the max header list size (in which case remainSize is 0),
+ // or a frame whose encoded size is more than twice the remaining
+ // header list bytes we're willing to accept.
+ if int64(len(frag)) > int64(2*remainSize) {
+ if VerboseLogs {
+ log.Printf("http2: header list too large")
+ }
+ // It would be nice to send a RST_STREAM before sending the GOAWAY,
+ // but the structure of the server's frame writer makes this difficult.
+ return nil, ConnectionError(ErrCodeProtocol)
+ }
+
+ // Also close the connection after any CONTINUATION frame following an
+ // invalid header, since we stop tracking the size of the headers after
+ // an invalid one.
+ if invalid != nil {
+ if VerboseLogs {
+ log.Printf("http2: invalid header: %v", invalid)
+ }
+ // It would be nice to send a RST_STREAM before sending the GOAWAY,
+ // but the structure of the server's frame writer makes this difficult.
+ return nil, ConnectionError(ErrCodeProtocol)
+ }
+
if _, err := hdec.Write(frag); err != nil {
return nil, ConnectionError(ErrCodeCompression)
}
diff --git a/http2/pipe.go b/http2/pipe.go
index 684d984fd..3b9f06b96 100644
--- a/http2/pipe.go
+++ b/http2/pipe.go
@@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) {
}
}
-var errClosedPipeWrite = errors.New("write on closed buffer")
+var (
+ errClosedPipeWrite = errors.New("write on closed buffer")
+ errUninitializedPipeWrite = errors.New("write on uninitialized buffer")
+)
// Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold.
@@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) {
if p.err != nil || p.breakErr != nil {
return 0, errClosedPipeWrite
}
+ // pipe.setBuffer is never invoked, leaving the buffer uninitialized.
+ // We shouldn't try to write to an uninitialized pipe,
+ // but returning an error is better than panicking.
+ if p.b == nil {
+ return 0, errUninitializedPipeWrite
+ }
return p.b.Write(d)
}
diff --git a/http2/server.go b/http2/server.go
index ae94c6408..ce2e8b40e 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -124,6 +124,7 @@ type Server struct {
// IdleTimeout specifies how long until idle clients should be
// closed with a GOAWAY frame. PING frames are not considered
// activity for the purposes of IdleTimeout.
+ // If zero or negative, there is no timeout.
IdleTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow
@@ -434,7 +435,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// passes the connection off to us with the deadline already set.
// Write deadlines are set per stream in serverConn.newStream.
// Disarm the net.Conn write deadline here.
- if sc.hs.WriteTimeout != 0 {
+ if sc.hs.WriteTimeout > 0 {
sc.conn.SetWriteDeadline(time.Time{})
}
@@ -924,7 +925,7 @@ func (sc *serverConn) serve() {
sc.setConnState(http.StateActive)
sc.setConnState(http.StateIdle)
- if sc.srv.IdleTimeout != 0 {
+ if sc.srv.IdleTimeout > 0 {
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop()
}
@@ -1637,7 +1638,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id)
if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle)
- if sc.srv.IdleTimeout != 0 {
+ if sc.srv.IdleTimeout > 0 {
sc.idleTimer.Reset(sc.srv.IdleTimeout)
}
if h1ServerKeepAlivesDisabled(sc.hs) {
@@ -2017,7 +2018,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// similar to how the http1 server works. Here it's
// technically more like the http1 Server's ReadHeaderTimeout
// (in Go 1.8), though. That's a more sane option anyway.
- if sc.hs.ReadTimeout != 0 {
+ if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
@@ -2038,7 +2039,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
// Disable any read deadline set by the net/http package
// prior to the upgrade.
- if sc.hs.ReadTimeout != 0 {
+ if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
}
@@ -2116,7 +2117,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize())
- if sc.hs.WriteTimeout != 0 {
+ if sc.hs.WriteTimeout > 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
diff --git a/http2/server_push_test.go b/http2/server_push_test.go
index 9882d9ef7..cda8f4336 100644
--- a/http2/server_push_test.go
+++ b/http2/server_push_test.go
@@ -11,6 +11,7 @@ import (
"io/ioutil"
"net/http"
"reflect"
+ "runtime"
"strconv"
"sync"
"testing"
@@ -483,11 +484,7 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) {
ready := make(chan struct{})
errc := make(chan error, 2)
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- select {
- case <-ready:
- case <-time.After(5 * time.Second):
- errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed")
- }
+ <-ready
if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
errc <- fmt.Errorf("Push()=%v, want %v", got, want)
}
@@ -505,6 +502,10 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) {
case <-ready:
return
default:
+ if runtime.GOARCH == "wasm" {
+ // Work around https://go.dev/issue/65178 to avoid goroutine starvation.
+ runtime.Gosched()
+ }
}
st.sc.serveMsgCh <- func(loopNum int) {
if !st.sc.pushEnabled {
diff --git a/http2/server_test.go b/http2/server_test.go
index 22657cbfe..a931a06e5 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -145,6 +145,12 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
ConfigureServer(ts.Config, h2server)
+ // Go 1.22 changes the default minimum TLS version to TLS 1.2.
+ // In order to properly test cases where we want to reject low
+ // TLS versions, we need to explicitly configure the minimum
+ // version here.
+ ts.Config.TLSConfig.MinVersion = tls.VersionTLS10
+
st := &serverTester{
t: t,
ts: ts,
@@ -4572,13 +4578,16 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) {
sc := &serverConn{
serveG: newGoroutineLock(),
}
- const count = 1000
- for i := 0; i < count; i++ {
- h := fmt.Sprintf("%v-%v", base, i)
+ count := 0
+ added := 0
+ for added < 10*maxCachedCanonicalHeadersKeysSize {
+ h := fmt.Sprintf("%v-%v", base, count)
c := sc.canonicalHeader(h)
if len(h) != len(c) {
t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c)
}
+ count++
+ added += len(h)
}
total := 0
for k, v := range sc.canonHeader {
@@ -4777,3 +4786,89 @@ Frames:
close(s)
}
}
+
+func TestServerContinuationFlood(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ fmt.Println(r.Header)
+ }, func(ts *httptest.Server) {
+ ts.Config.MaxHeaderBytes = 4096
+ })
+ defer st.Close()
+
+ st.writePreface()
+ st.writeInitialSettings()
+ st.writeSettingsAck()
+
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ })
+ for i := 0; i < 1000; i++ {
+ st.fr.WriteContinuation(1, false, st.encodeHeaderRaw(
+ fmt.Sprintf("x-%v", i), "1234567890",
+ ))
+ }
+ st.fr.WriteContinuation(1, true, st.encodeHeaderRaw(
+ "x-last-header", "1",
+ ))
+
+ for {
+ f, err := st.readFrame()
+ if err != nil {
+ break
+ }
+ switch f.(type) {
+ case *HeadersFrame:
+ t.Fatalf("received HEADERS frame; want GOAWAY and a closed connection")
+ }
+ }
+ // We expect to have seen a GOAWAY before the connection closes,
+ // but the server will close the connection after one second
+ // whether or not it has finished sending the GOAWAY. On windows-amd64-race
+ // builders, this fairly consistently results in the connection closing without
+ // the GOAWAY being sent.
+ //
+ // Since the server's behavior is inherently racy here and the important thing
+ // is that the connection is closed, don't check for the GOAWAY having been sent.
+}
+
+func TestServerContinuationAfterInvalidHeader(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ fmt.Println(r.Header)
+ })
+ defer st.Close()
+
+ st.writePreface()
+ st.writeInitialSettings()
+ st.writeSettingsAck()
+
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ })
+ st.fr.WriteContinuation(1, false, st.encodeHeaderRaw(
+ "x-invalid-header", "\x00",
+ ))
+ st.fr.WriteContinuation(1, true, st.encodeHeaderRaw(
+ "x-valid-header", "1",
+ ))
+
+ var sawGoAway bool
+ for {
+ f, err := st.readFrame()
+ if err != nil {
+ break
+ }
+ switch f.(type) {
+ case *GoAwayFrame:
+ sawGoAway = true
+ case *HeadersFrame:
+ t.Fatalf("received HEADERS frame; want GOAWAY")
+ }
+ }
+ if !sawGoAway {
+ t.Errorf("connection closed with no GOAWAY frame; want one")
+ }
+}
diff --git a/http2/testsync.go b/http2/testsync.go
new file mode 100644
index 000000000..61075bd16
--- /dev/null
+++ b/http2/testsync.go
@@ -0,0 +1,331 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "context"
+ "sync"
+ "time"
+)
+
+// testSyncHooks coordinates goroutines in tests.
+//
+// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
+// - the goroutine running RoundTrip;
+// - the clientStream.doRequest goroutine, which writes the request; and
+// - the clientStream.readLoop goroutine, which reads the response.
+//
+// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
+// are blocked waiting for some condition such as reading the Request.Body or waiting for
+// flow control to become available.
+//
+// The testSyncHooks also manage timers and synthetic time in tests.
+// This permits us to, for example, start a request and cause it to time out waiting for
+// response headers without resorting to time.Sleep calls.
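+//
+// A sketch of direct use (tests normally reach these hooks via testClientConn):
+//
+// hooks := newTestSyncHooks()
+// hooks.goRun(func() { /* goroutine under test */ })
+// hooks.waitInactive()           // wait until every goroutine is blocked
+// hooks.advance(5 * time.Second) // fire synthetic timers due within 5s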
+type testSyncHooks struct {
+ // active/inactive act as a mutex and condition variable.
+ //
+ // - neither chan contains a value: testSyncHooks is locked.
+ // - active contains a value: unlocked, and at least one goroutine is not blocked
+ // - inactive contains a value: unlocked, and all goroutines are blocked
+ active chan struct{}
+ inactive chan struct{}
+
+ // goroutine counts
+ total int // total goroutines
+ condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
+ blocked []*testBlockedGoroutine // otherwise blocked
+
+ // fake time
+ now time.Time
+ timers []*fakeTimer
+
+ // Transport testing: Report various events.
+ newclientconn func(*ClientConn)
+ newstream func(*clientStream)
+}
+
+// testBlockedGoroutine is a blocked goroutine.
+type testBlockedGoroutine struct {
+ f func() bool // blocked until f returns true
+ ch chan struct{} // closed when unblocked
+}
+
+func newTestSyncHooks() *testSyncHooks {
+ h := &testSyncHooks{
+ active: make(chan struct{}, 1),
+ inactive: make(chan struct{}, 1),
+ condwait: map[*sync.Cond]int{},
+ }
+ h.inactive <- struct{}{}
+ h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
+ return h
+}
+
+// lock acquires the testSyncHooks mutex.
+func (h *testSyncHooks) lock() {
+ select {
+ case <-h.active:
+ case <-h.inactive:
+ }
+}
+
+// waitInactive waits for all goroutines to become inactive.
+func (h *testSyncHooks) waitInactive() {
+ for {
+ <-h.inactive
+ if !h.unlock() {
+ break
+ }
+ }
+}
+
+// unlock releases the testSyncHooks mutex.
+// It reports whether any goroutines are active.
+func (h *testSyncHooks) unlock() (active bool) {
+ // Look for a blocked goroutine which can be unblocked.
+ blocked := h.blocked[:0]
+ unblocked := false
+ for _, b := range h.blocked {
+ if !unblocked && b.f() {
+ unblocked = true
+ close(b.ch)
+ } else {
+ blocked = append(blocked, b)
+ }
+ }
+ h.blocked = blocked
+
+ // Count goroutines blocked on condition variables.
+ condwait := 0
+ for _, count := range h.condwait {
+ condwait += count
+ }
+
+ if h.total > condwait+len(blocked) {
+ h.active <- struct{}{}
+ return true
+ } else {
+ h.inactive <- struct{}{}
+ return false
+ }
+}
+
+// goRun starts a new goroutine.
+func (h *testSyncHooks) goRun(f func()) {
+ h.lock()
+ h.total++
+ h.unlock()
+ go func() {
+ defer func() {
+ h.lock()
+ h.total--
+ h.unlock()
+ }()
+ f()
+ }()
+}
+
+// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
+// It waits until f returns true before proceeding.
+//
+// Example usage:
+//
+// h.blockUntil(func() bool {
+// // Is the context done yet?
+// select {
+// case <-ctx.Done():
+// default:
+// return false
+// }
+// return true
+// })
+// // Wait for the context to become done.
+// <-ctx.Done()
+//
+// The function f passed to blockUntil must be non-blocking and idempotent.
+func (h *testSyncHooks) blockUntil(f func() bool) {
+ if f() {
+ return
+ }
+ ch := make(chan struct{})
+ h.lock()
+ h.blocked = append(h.blocked, &testBlockedGoroutine{
+ f: f,
+ ch: ch,
+ })
+ h.unlock()
+ <-ch
+}
+
+// condBroadcast is sync.Cond.Broadcast.
+func (h *testSyncHooks) condBroadcast(cond *sync.Cond) {
+ h.lock()
+ delete(h.condwait, cond)
+ h.unlock()
+ cond.Broadcast()
+}
+
+// condWait is sync.Cond.Wait.
+func (h *testSyncHooks) condWait(cond *sync.Cond) {
+ h.lock()
+ h.condwait[cond]++
+ h.unlock()
+}
+
+// newTimer creates a new fake timer.
+func (h *testSyncHooks) newTimer(d time.Duration) timer {
+ h.lock()
+ defer h.unlock()
+ t := &fakeTimer{
+ hooks: h,
+ when: h.now.Add(d),
+ c: make(chan time.Time),
+ }
+ h.timers = append(h.timers, t)
+ return t
+}
+
+// afterFunc creates a new fake AfterFunc timer.
+func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
+ h.lock()
+ defer h.unlock()
+ t := &fakeTimer{
+ hooks: h,
+ when: h.now.Add(d),
+ f: f,
+ }
+ h.timers = append(h.timers, t)
+ return t
+}
+
+func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
+ ctx, cancel := context.WithCancel(ctx)
+ t := h.afterFunc(d, cancel)
+ return ctx, func() {
+ t.Stop()
+ cancel()
+ }
+}
+
+func (h *testSyncHooks) timeUntilEvent() time.Duration {
+ h.lock()
+ defer h.unlock()
+ var next time.Time
+ for _, t := range h.timers {
+ if next.IsZero() || t.when.Before(next) {
+ next = t.when
+ }
+ }
+ if d := next.Sub(h.now); d > 0 {
+ return d
+ }
+ return 0
+}
+
+// advance advances time and causes synthetic timers to fire.
+func (h *testSyncHooks) advance(d time.Duration) {
+ h.lock()
+ defer h.unlock()
+ h.now = h.now.Add(d)
+ timers := h.timers[:0]
+ for _, t := range h.timers {
+ t := t // remove after go.mod depends on go1.22
+ t.mu.Lock()
+ switch {
+ case t.when.After(h.now):
+ timers = append(timers, t)
+ case t.when.IsZero():
+ // stopped timer
+ default:
+ t.when = time.Time{}
+ if t.c != nil {
+ close(t.c)
+ }
+ if t.f != nil {
+ h.total++
+ go func() {
+ defer func() {
+ h.lock()
+ h.total--
+ h.unlock()
+ }()
+ t.f()
+ }()
+ }
+ }
+ t.mu.Unlock()
+ }
+ h.timers = timers
+}
+
+// A timer wraps a time.Timer, or a synthetic equivalent in tests.
+// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
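+//
+// A sketch of intended use (cc.newTimer returns the real or synthetic kind):
+//
+// t := cc.newTimer(d)
+// defer t.Stop()
+// <-t.C() // closed on expiry, so late receives do not block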
+type timer interface {
+ C() <-chan time.Time
+ Stop() bool
+ Reset(d time.Duration) bool
+}
+
+// timeTimer implements timer using real time.
+type timeTimer struct {
+ t *time.Timer
+ c chan time.Time
+}
+
+// newTimeTimer creates a new timer using real time.
+func newTimeTimer(d time.Duration) timer {
+ ch := make(chan time.Time)
+ t := time.AfterFunc(d, func() {
+ close(ch)
+ })
+ return &timeTimer{t, ch}
+}
+
+// newTimeAfterFunc creates an AfterFunc timer using real time.
+func newTimeAfterFunc(d time.Duration, f func()) timer {
+ return &timeTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+func (t timeTimer) C() <-chan time.Time { return t.c }
+func (t timeTimer) Stop() bool { return t.t.Stop() }
+func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
+
+// fakeTimer implements timer using fake time.
+type fakeTimer struct {
+ hooks *testSyncHooks
+
+ mu sync.Mutex
+ when time.Time // when the timer will fire
+ c chan time.Time // closed when the timer fires; mutually exclusive with f
+ f func() // called when the timer fires; mutually exclusive with c
+}
+
+func (t *fakeTimer) C() <-chan time.Time { return t.c }
+
+func (t *fakeTimer) Stop() bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ stopped := t.when.IsZero()
+ t.when = time.Time{}
+ return stopped
+}
+
+func (t *fakeTimer) Reset(d time.Duration) bool {
+ if t.c != nil || t.f == nil {
+ panic("fakeTimer only supports Reset on AfterFunc timers")
+ }
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.hooks.lock()
+ defer t.hooks.unlock()
+ active := !t.when.IsZero()
+ t.when = t.hooks.now.Add(d)
+ if !active {
+ t.hooks.timers = append(t.hooks.timers, t)
+ }
+ return active
+}
diff --git a/http2/transport.go b/http2/transport.go
index df578b86c..ce375c8c7 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -147,6 +147,12 @@ type Transport struct {
// waiting for their turn.
StrictMaxConcurrentStreams bool
+ // IdleConnTimeout is the maximum amount of time an idle
+ // (keep-alive) connection will remain idle before closing
+ // itself.
+ // Zero means no limit.
+ IdleConnTimeout time.Duration
+
// ReadIdleTimeout is the timeout after which a health check using ping
// frame will be carried out if no frame is received on the connection.
// Note that a ping response is considered a received frame, so if
@@ -178,6 +184,8 @@ type Transport struct {
connPoolOnce sync.Once
connPoolOrDef ClientConnPool // non-nil version of ConnPool
+
+ syncHooks *testSyncHooks
}
func (t *Transport) maxHeaderListSize() uint32 {
@@ -302,7 +310,7 @@ type ClientConn struct {
readerErr error // set before readerDone is closed
idleTimeout time.Duration // or 0 for never
- idleTimer *time.Timer
+ idleTimer timer
mu sync.Mutex // guards following
cond *sync.Cond // hold mu; broadcast on flow/closed changes
@@ -344,6 +352,60 @@ type ClientConn struct {
werr error // first write error that has occurred
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
+
+ syncHooks *testSyncHooks // can be nil
+}
+
+// Hook points used for testing.
+// Outside of tests, cc.syncHooks is nil and these all have minimal implementations.
+// Inside tests, see the testSyncHooks function docs.
+
+// goRun starts a new goroutine.
+func (cc *ClientConn) goRun(f func()) {
+ if cc.syncHooks != nil {
+ cc.syncHooks.goRun(f)
+ return
+ }
+ go f()
+}
+
+// condBroadcast is cc.cond.Broadcast.
+func (cc *ClientConn) condBroadcast() {
+ if cc.syncHooks != nil {
+ cc.syncHooks.condBroadcast(cc.cond)
+ }
+ cc.cond.Broadcast()
+}
+
+// condWait is cc.cond.Wait.
+func (cc *ClientConn) condWait() {
+ if cc.syncHooks != nil {
+ cc.syncHooks.condWait(cc.cond)
+ }
+ cc.cond.Wait()
+}
+
+// newTimer creates a new time.Timer, or a synthetic timer in tests.
+func (cc *ClientConn) newTimer(d time.Duration) timer {
+ if cc.syncHooks != nil {
+ return cc.syncHooks.newTimer(d)
+ }
+ return newTimeTimer(d)
+}
+
+// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
+func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
+ if cc.syncHooks != nil {
+ return cc.syncHooks.afterFunc(d, f)
+ }
+ return newTimeAfterFunc(d, f)
+}
+
+func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
+ if cc.syncHooks != nil {
+ return cc.syncHooks.contextWithTimeout(ctx, d)
+ }
+ return context.WithTimeout(ctx, d)
}
// clientStream is the state for a single HTTP/2 stream. One of these
@@ -425,7 +487,7 @@ func (cs *clientStream) abortStreamLocked(err error) {
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
- cs.cc.cond.Broadcast()
+ cs.cc.condBroadcast()
}
}
@@ -435,7 +497,7 @@ func (cs *clientStream) abortRequestBodyWrite() {
defer cc.mu.Unlock()
if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked()
- cc.cond.Broadcast()
+ cc.condBroadcast()
}
}
@@ -445,10 +507,10 @@ func (cs *clientStream) closeReqBodyLocked() {
}
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
- go func() {
+ cs.cc.goRun(func() {
cs.reqBody.Close()
close(reqBodyClosed)
- }()
+ })
}
type stickyErrWriter struct {
@@ -537,15 +599,6 @@ func authorityAddr(scheme string, authority string) (addr string) {
return net.JoinHostPort(host, port)
}
-var retryBackoffHook func(time.Duration) *time.Timer
-
-func backoffNewTimer(d time.Duration) *time.Timer {
- if retryBackoffHook != nil {
- return retryBackoffHook(d)
- }
- return time.NewTimer(d)
-}
-
// RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
@@ -573,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
- timer := backoffNewTimer(d)
+ var tm timer
+ if t.syncHooks != nil {
+ tm = t.syncHooks.newTimer(d)
+ t.syncHooks.blockUntil(func() bool {
+ select {
+ case <-tm.C():
+ case <-req.Context().Done():
+ default:
+ return false
+ }
+ return true
+ })
+ } else {
+ tm = newTimeTimer(d)
+ }
select {
- case <-timer.C:
+ case <-tm.C():
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
continue
case <-req.Context().Done():
- timer.Stop()
+ tm.Stop()
err = req.Context().Err()
}
}
@@ -658,6 +725,9 @@ func canRetryError(err error) bool {
}
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
+ if t.syncHooks != nil {
+ return t.newClientConn(nil, singleUse, t.syncHooks)
+ }
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
@@ -666,7 +736,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b
if err != nil {
return nil, err
}
- return t.newClientConn(tconn, singleUse)
+ return t.newClientConn(tconn, singleUse, nil)
}
func (t *Transport) newTLSConfig(host string) *tls.Config {
@@ -732,10 +802,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 {
}
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
- return t.newClientConn(c, t.disableKeepAlives())
+ return t.newClientConn(c, t.disableKeepAlives(), nil)
}
-func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
+func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) {
cc := &ClientConn{
t: t,
tconn: c,
@@ -750,10 +820,15 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
+ syncHooks: hooks,
+ }
+ if hooks != nil {
+ hooks.newclientconn(cc)
+ c = cc.tconn
}
if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d
- cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout)
+ cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout)
}
if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -818,7 +893,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
return nil, cc.werr
}
- go cc.readLoop()
+ cc.goRun(cc.readLoop)
return cc, nil
}
@@ -826,7 +901,7 @@ func (cc *ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
- ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+ ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
@@ -1056,7 +1131,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
// Wait for all in-flight streams to complete or connection to close
done := make(chan struct{})
cancelled := false // guarded by cc.mu
- go func() {
+ cc.goRun(func() {
cc.mu.Lock()
defer cc.mu.Unlock()
for {
@@ -1068,9 +1143,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
if cancelled {
break
}
- cc.cond.Wait()
+ cc.condWait()
}
- }()
+ })
shutdownEnterWaitStateHook()
select {
case <-done:
@@ -1080,7 +1155,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
cc.mu.Lock()
// Free the goroutine above
cancelled = true
- cc.cond.Broadcast()
+ cc.condBroadcast()
cc.mu.Unlock()
return ctx.Err()
}
@@ -1118,7 +1193,7 @@ func (cc *ClientConn) closeForError(err error) {
for _, cs := range cc.streams {
cs.abortStreamLocked(err)
}
- cc.cond.Broadcast()
+ cc.condBroadcast()
cc.mu.Unlock()
cc.closeConn()
}
@@ -1215,6 +1290,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() {
}
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
+ return cc.roundTrip(req, nil)
+}
+
+func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) {
ctx := req.Context()
cs := &clientStream{
cc: cc,
@@ -1229,9 +1308,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}),
}
- go cs.doRequest(req)
+ cc.goRun(func() {
+ cs.doRequest(req)
+ })
waitDone := func() error {
+ if cc.syncHooks != nil {
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case <-cs.donec:
+ case <-ctx.Done():
+ case <-cs.reqCancel:
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case <-cs.donec:
return nil
@@ -1292,7 +1385,24 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return err
}
+ if streamf != nil {
+ streamf(cs)
+ }
+
for {
+ if cc.syncHooks != nil {
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case <-cs.respHeaderRecv:
+ case <-cs.abort:
+ case <-ctx.Done():
+ case <-cs.reqCancel:
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case <-cs.respHeaderRecv:
return handleResponseHeaders()
@@ -1348,6 +1458,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
+ var newStreamHook func(*clientStream)
+ if cc.syncHooks != nil {
+ newStreamHook = cc.syncHooks.newstream
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case cc.reqHeaderMu <- struct{}{}:
+ <-cc.reqHeaderMu
+ case <-cs.reqCancel:
+ case <-ctx.Done():
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
@@ -1372,6 +1497,10 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
}
cc.mu.Unlock()
+ if newStreamHook != nil {
+ newStreamHook(cs)
+ }
+
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" &&
@@ -1452,15 +1581,30 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 {
- timer := time.NewTimer(d)
+ timer := cc.newTimer(d)
defer timer.Stop()
- respHeaderTimer = timer.C
+ respHeaderTimer = timer.C()
respHeaderRecv = cs.respHeaderRecv
}
// Wait until the peer half-closes its end of the stream,
// or until the request is aborted (via context, error, or otherwise),
// whichever comes first.
for {
+ if cc.syncHooks != nil {
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case <-cs.peerClosed:
+ case <-respHeaderTimer:
+ case <-respHeaderRecv:
+ case <-cs.abort:
+ case <-ctx.Done():
+ case <-cs.reqCancel:
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case <-cs.peerClosed:
return nil
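// Editor's sketch (assumption): cc.newTimer and cc.afterFunc above return a
// small interface rather than a *time.Timer so tests can substitute a fake
// clock. C() and Stop() appear in this patch; Reset is assumed for parity
// with *time.Timer:
//
//	type timer interface {
//		C() <-chan time.Time
//		Stop() bool
//		Reset(d time.Duration) bool
//	}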
@@ -1609,7 +1753,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
return nil
}
cc.pendingRequests++
- cc.cond.Wait()
+ cc.condWait()
cc.pendingRequests--
select {
case <-cs.abort:
@@ -1871,8 +2015,24 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
cs.flow.take(take)
return take, nil
}
- cc.cond.Wait()
+ cc.condWait()
+ }
+}
+
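+// validateHeaders reports a short description of the first invalid field
+// in hdrs (the offending name, or the name whose value is bad; values are
+// omitted because they may be sensitive), or "" if every field is valid.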
+func validateHeaders(hdrs http.Header) string {
+ for k, vv := range hdrs {
+ if !httpguts.ValidHeaderFieldName(k) {
+ return fmt.Sprintf("name %q", k)
+ }
+ for _, v := range vv {
+ if !httpguts.ValidHeaderFieldValue(v) {
+ // Don't include the value in the error,
+ // because it may be sensitive.
+ return fmt.Sprintf("value for header %q", k)
+ }
+ }
}
+ return ""
}
var errNilRequestURL = errors.New("http2: Request.URI is nil")
@@ -1912,19 +2072,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
}
}
- // Check for any invalid headers and return an error before we
+ // Check for any invalid headers or trailers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
- for k, vv := range req.Header {
- if !httpguts.ValidHeaderFieldName(k) {
- return nil, fmt.Errorf("invalid HTTP header name %q", k)
- }
- for _, v := range vv {
- if !httpguts.ValidHeaderFieldValue(v) {
- // Don't include the value in the error, because it may be sensitive.
- return nil, fmt.Errorf("invalid HTTP header value for header %q", k)
- }
- }
+ if err := validateHeaders(req.Header); err != "" {
+ return nil, fmt.Errorf("invalid HTTP header %s", err)
+ }
+ if err := validateHeaders(req.Trailer); err != "" {
+ return nil, fmt.Errorf("invalid HTTP trailer %s", err)
}
enumerateHeaders := func(f func(name, value string)) {
@@ -2143,7 +2298,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
- cc.cond.Broadcast()
+ cc.condBroadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
@@ -2231,7 +2386,7 @@ func (rl *clientConnReadLoop) cleanup() {
cs.abortStreamLocked(err)
}
}
- cc.cond.Broadcast()
+ cc.condBroadcast()
cc.mu.Unlock()
}
@@ -2266,10 +2421,9 @@ func (rl *clientConnReadLoop) run() error {
cc := rl.cc
gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout
- var t *time.Timer
+ var t timer
if readIdleTimeout != 0 {
- t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
- defer t.Stop()
+ t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
}
for {
f, err := cc.fr.ReadFrame()
@@ -2684,7 +2838,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
})
return nil
}
- if !cs.firstByte {
+ if !cs.pastHeaders {
cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, StreamError{
StreamID: f.StreamID,
@@ -2867,7 +3021,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
for _, cs := range cc.streams {
cs.flow.add(delta)
}
- cc.cond.Broadcast()
+ cc.condBroadcast()
cc.initialWindowSize = s.Val
case SettingHeaderTableSize:
@@ -2911,9 +3065,18 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
fl = &cs.flow
}
if !fl.add(int32(f.Increment)) {
+ // For a stream-level error, the sender sends RST_STREAM with an error code of FLOW_CONTROL_ERROR.
+ if cs != nil {
+ rl.endStreamError(cs, StreamError{
+ StreamID: f.StreamID,
+ Code: ErrCodeFlowControl,
+ })
+ return nil
+ }
+
return ConnectionError(ErrCodeFlowControl)
}
- cc.cond.Broadcast()
+ cc.condBroadcast()
return nil
}
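// Editor's sketch (assumption about the unshown flow type): add reports
// whether the window can grow by n without exceeding HTTP/2's limit of
// 2^31-1 octets. When it fails, the code above sends RST_STREAM with
// FLOW_CONTROL_ERROR for a stream-level window and tears down the
// connection for the connection-level window, per RFC 9113 section 6.9:
//
//	func add(avail, n int32) (int32, bool) {
//		sum := avail + n
//		if (sum > n) == (avail > 0) { // signed-overflow check
//			return sum, true
//		}
//		return avail, false
//	}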
@@ -2955,24 +3118,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
}
cc.mu.Unlock()
}
- errc := make(chan error, 1)
- go func() {
+ var pingError error
+ errc := make(chan struct{})
+ cc.goRun(func() {
cc.wmu.Lock()
defer cc.wmu.Unlock()
- if err := cc.fr.WritePing(false, p); err != nil {
- errc <- err
+ if pingError = cc.fr.WritePing(false, p); pingError != nil {
+ close(errc)
return
}
- if err := cc.bw.Flush(); err != nil {
- errc <- err
+ if pingError = cc.bw.Flush(); pingError != nil {
+ close(errc)
return
}
- }()
+ })
+ if cc.syncHooks != nil {
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case <-c:
+ case <-errc:
+ case <-ctx.Done():
+ case <-cc.readerDone:
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case <-c:
return nil
- case err := <-errc:
- return err
+ case <-errc:
+ return pingError
case <-ctx.Done():
return ctx.Err()
case <-cc.readerDone:
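// Editor's note (inferred from the code above): errc is closed rather than
// sent on because both the blockUntil probe and the final select receive
// from it. A sent value would be consumed by whichever receive ran first,
// leaving the other to block forever; a closed channel stays readable from
// both places, and the write error travels through pingError instead.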
@@ -3141,9 +3318,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err
}
func (t *Transport) idleConnTimeout() time.Duration {
+ // To keep things backwards compatible, we prefer a non-zero value of
+ // IdleConnTimeout, then fall back to the IdleConnTimeout on the underlying
+ // http1 transport, and finally to 0.
+ if t.IdleConnTimeout != 0 {
+ return t.IdleConnTimeout
+ }
+
if t.t1 != nil {
return t.t1.IdleConnTimeout
}
+
return 0
}
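// Editor's note: the precedence implemented above, spelled out.
//
//	t.IdleConnTimeout != 0 -> use it (the http2 Transport's own setting)
//	else if t.t1 != nil    -> use t.t1.IdleConnTimeout (the wrapped http1 transport)
//	else                   -> 0 (no idle timeout)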
diff --git a/http2/transport_test.go b/http2/transport_test.go
index a81131f29..11ff67b4c 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -95,6 +95,88 @@ func startH2cServer(t *testing.T) net.Listener {
return l
}
+func TestIdleConnTimeout(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ idleConnTimeout time.Duration
+ wait time.Duration
+ baseTransport *http.Transport
+ wantNewConn bool
+ }{{
+ name: "NoExpiry",
+ idleConnTimeout: 2 * time.Second,
+ wait: 1 * time.Second,
+ baseTransport: nil,
+ wantNewConn: false,
+ }, {
+ name: "H2TransportTimeoutExpires",
+ idleConnTimeout: 1 * time.Second,
+ wait: 2 * time.Second,
+ baseTransport: nil,
+ wantNewConn: true,
+ }, {
+ name: "H1TransportTimeoutExpires",
+ idleConnTimeout: 0 * time.Second,
+ wait: 1 * time.Second,
+ baseTransport: &http.Transport{
+ IdleConnTimeout: 2 * time.Second,
+ },
+ wantNewConn: false,
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ tt := newTestTransport(t, func(tr *Transport) {
+ tr.IdleConnTimeout = test.idleConnTimeout
+ })
+ var tc *testClientConn
+ for i := 0; i < 3; i++ {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // This request happens on a new conn if it's the first request
+ // (and there is no cached conn), or if the test timeout is long
+ // enough that old conns are being closed.
+ wantConn := i == 0 || test.wantNewConn
+ if has := tt.hasConn(); has != wantConn {
+ t.Fatalf("request %v: hasConn=%v, want %v", i, has, wantConn)
+ }
+ if wantConn {
+ tc = tt.getConn()
+ // Read client's SETTINGS and first WINDOW_UPDATE,
+ // send our SETTINGS.
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.writeSettings()
+ }
+ if tt.hasConn() {
+ t.Fatalf("request %v: Transport has more than one conn", i)
+ }
+
+ // Respond to the client's request.
+ hf := testClientConnReadFrame[*MetaHeadersFrame](tc)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
+
+ // If this was a newly-accepted conn, read the SETTINGS ACK.
+ if wantConn {
+ tc.wantFrameType(FrameSettings) // ACK to our settings
+ }
+
+ tt.advance(test.wait)
+ if got, want := tc.netConnClosed, test.wantNewConn; got != want {
+ t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want)
+ }
+ }
+ })
+ }
+}
+
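// Editor's note (assumption): tt.advance moves the test framework's fake
// clock forward, synchronously firing any timers created via cc.newTimer or
// cc.afterFunc whose deadlines fall within the window, with no real sleep.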
func TestTransportH2c(t *testing.T) {
l := startH2cServer(t)
defer l.Close()
@@ -740,53 +822,6 @@ func (fw flushWriter) Write(p []byte) (n int, err error) {
return
}
-type clientTester struct {
- t *testing.T
- tr *Transport
- sc, cc net.Conn // server and client conn
- fr *Framer // server's framer
- settings *SettingsFrame
- client func() error
- server func() error
-}
-
-func newClientTester(t *testing.T) *clientTester {
- var dialOnce struct {
- sync.Mutex
- dialed bool
- }
- ct := &clientTester{
- t: t,
- }
- ct.tr = &Transport{
- TLSClientConfig: tlsConfigInsecure,
- DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
- dialOnce.Lock()
- defer dialOnce.Unlock()
- if dialOnce.dialed {
- return nil, errors.New("only one dial allowed in test mode")
- }
- dialOnce.dialed = true
- return ct.cc, nil
- },
- }
-
- ln := newLocalListener(t)
- cc, err := net.Dial("tcp", ln.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- sc, err := ln.Accept()
- if err != nil {
- t.Fatal(err)
- }
- ln.Close()
- ct.cc = cc
- ct.sc = sc
- ct.fr = NewFramer(sc, sc)
- return ct
-}
-
func newLocalListener(t *testing.T) net.Listener {
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err == nil {
@@ -799,284 +834,70 @@ func newLocalListener(t *testing.T) net.Listener {
return ln
}
-func (ct *clientTester) greet(settings ...Setting) {
- buf := make([]byte, len(ClientPreface))
- _, err := io.ReadFull(ct.sc, buf)
- if err != nil {
- ct.t.Fatalf("reading client preface: %v", err)
- }
- f, err := ct.fr.ReadFrame()
- if err != nil {
- ct.t.Fatalf("Reading client settings frame: %v", err)
- }
- var ok bool
- if ct.settings, ok = f.(*SettingsFrame); !ok {
- ct.t.Fatalf("Wanted client settings frame; got %v", f)
- }
- if err := ct.fr.WriteSettings(settings...); err != nil {
- ct.t.Fatal(err)
- }
- if err := ct.fr.WriteSettingsAck(); err != nil {
- ct.t.Fatal(err)
- }
-}
-
-func (ct *clientTester) readNonSettingsFrame() (Frame, error) {
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return nil, err
- }
- if _, ok := f.(*SettingsFrame); ok {
- continue
- }
- return f, nil
- }
-}
-
-// writeReadPing sends a PING and immediately reads the PING ACK.
-// It will fail if any other unread data was pending on the connection,
-// aside from SETTINGS frames.
-func (ct *clientTester) writeReadPing() error {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- if err := ct.fr.WritePing(false, data); err != nil {
- return fmt.Errorf("Error writing PING: %v", err)
- }
- f, err := ct.readNonSettingsFrame()
- if err != nil {
- return err
- }
- p, ok := f.(*PingFrame)
- if !ok {
- return fmt.Errorf("got a %v, want a PING ACK", f)
- }
- if p.Flags&FlagPingAck == 0 {
- return fmt.Errorf("got a PING, want a PING ACK")
- }
- if p.Data != data {
- return fmt.Errorf("got PING data = %x, want %x", p.Data, data)
- }
- return nil
-}
-
-func (ct *clientTester) inflowWindow(streamID uint32) int32 {
- pool := ct.tr.connPoolOrDef.(*clientConnPool)
- pool.mu.Lock()
- defer pool.mu.Unlock()
- if n := len(pool.keys); n != 1 {
- ct.t.Errorf("clientConnPool contains %v keys, expected 1", n)
- return -1
- }
- for cc := range pool.keys {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- if streamID == 0 {
- return cc.inflow.avail + cc.inflow.unsent
- }
- cs := cc.streams[streamID]
- if cs == nil {
- ct.t.Errorf("no stream with id %v", streamID)
- return -1
- }
- return cs.inflow.avail + cs.inflow.unsent
- }
- return -1
-}
-
-func (ct *clientTester) cleanup() {
- ct.tr.CloseIdleConnections()
-
- // close both connections, ignore the error if its already closed
- ct.sc.Close()
- ct.cc.Close()
-}
-
-func (ct *clientTester) run() {
- var errOnce sync.Once
- var wg sync.WaitGroup
-
- run := func(which string, fn func() error) {
- defer wg.Done()
- if err := fn(); err != nil {
- errOnce.Do(func() {
- ct.t.Errorf("%s: %v", which, err)
- ct.cleanup()
- })
- }
- }
+func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
+func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
- wg.Add(2)
- go run("client", ct.client)
- go run("server", ct.server)
- wg.Wait()
+func testTransportReqBodyAfterResponse(t *testing.T, status int) {
+ const bodySize = 10 << 20
- errOnce.Do(ct.cleanup) // clean up if no error
-}
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ body := tc.newRequestBody()
+ body.writeBytes(bodySize / 2)
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
+ rt := tc.roundTrip(req)
+
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: false,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"PUT"},
+ ":path": []string{"/"},
+ },
+ })
-func (ct *clientTester) readFrame() (Frame, error) {
- return ct.fr.ReadFrame()
-}
+ // Provide enough flow-control window for the full request body.
+ tc.writeWindowUpdate(0, bodySize)
+ tc.writeWindowUpdate(rt.streamID(), bodySize)
-func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
- for {
- f, err := ct.readFrame()
- if err != nil {
- return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- }
- hf, ok := f.(*HeadersFrame)
- if !ok {
- return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
- }
- return hf, nil
- }
-}
+ tc.wantData(wantData{
+ streamID: rt.streamID(),
+ endStream: false,
+ size: bodySize / 2,
+ })
-type countingReader struct {
- n *int64
-}
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", strconv.Itoa(status),
+ ),
+ })
-func (r countingReader) Read(p []byte) (n int, err error) {
- for i := range p {
- p[i] = byte(i)
+ res := rt.response()
+ if res.StatusCode != status {
+ t.Fatalf("status code = %v; want %v", res.StatusCode, status)
}
- atomic.AddInt64(r.n, int64(len(p)))
- return len(p), err
-}
-func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
-func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
+ body.writeBytes(bodySize / 2)
+ body.closeWithError(io.EOF)
-func testTransportReqBodyAfterResponse(t *testing.T, status int) {
- const bodySize = 10 << 20
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- recvLen := make(chan int64, 1)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- defer close(clientDone)
-
- body := &pipe{b: new(bytes.Buffer)}
- io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
- req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
- if err != nil {
- return err
- }
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- if res.StatusCode != status {
- return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
- }
- io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
- body.CloseWithError(io.EOF)
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("Slurp: %v", err)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("unexpected body: %q", slurp)
- }
- res.Body.Close()
- if status == 200 {
- if got := <-recvLen; got != bodySize {
- return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
- }
- } else {
- if got := <-recvLen; got == 0 || got >= bodySize {
- return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
- }
- }
- return nil
+ if status == 200 {
+ // After a 200 response, client sends the remaining request body.
+ tc.wantData(wantData{
+ streamID: rt.streamID(),
+ endStream: true,
+ size: bodySize / 2,
+ })
+ } else {
+ // After a 403 response, client gives up and resets the stream.
+ tc.wantFrameType(FrameRSTStream)
}
- ct.server = func() error {
- ct.greet()
- defer close(recvLen)
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- var dataRecv int64
- var closed bool
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- return nil
- default:
- return err
- }
- }
- //println(fmt.Sprintf("server got frame: %v", f))
- ended := false
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- if f.StreamEnded() {
- return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
- }
- case *DataFrame:
- dataLen := len(f.Data())
- if dataLen > 0 {
- if dataRecv == 0 {
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- }
- if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
- return err
- }
- if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
- return err
- }
- }
- dataRecv += int64(dataLen)
-
- if !closed && ((status != 200 && dataRecv > 0) ||
- (status == 200 && f.StreamEnded())) {
- closed = true
- if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil {
- return err
- }
- }
- if f.StreamEnded() {
- ended = true
- }
- case *RSTStreamFrame:
- if status == 200 {
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- ended = true
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- if ended {
- select {
- case recvLen <- dataRecv:
- default:
- }
- }
- }
- }
- ct.run()
+ rt.wantBody(nil)
}
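// Editor's sketch (assumption): the general shape of tests in the new
// synchronous testClientConn style, which replaces the goroutine-per-side
// clientTester:
//
//	tc := newTestClientConn(t)
//	tc.greet()                     // exchange SETTINGS
//	rt := tc.roundTrip(req)        // start RoundTrip under the test hooks
//	tc.wantFrameType(FrameHeaders) // assert the request reached the wire
//	tc.writeHeaders(...)           // play the server's response
//	rt.wantStatus(200)             // assert the client-visible result
//
// Frames are written and asserted inline, so a failure surfaces at a
// deterministic step instead of racing a server goroutine.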
// See golang.org/issue/13444
@@ -1257,121 +1078,74 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy
panic("invalid combination")
}
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
- if expect100Continue != noHeader {
- req.Header.Set("Expect", "100-continue")
- }
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return fmt.Errorf("status code = %v; want 200", res.StatusCode)
- }
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("Slurp: %v", err)
- }
- wantBody := resBody
- if !withData {
- wantBody = ""
- }
- if string(slurp) != wantBody {
- return fmt.Errorf("body = %q; want %q", slurp, wantBody)
- }
- if trailers == noHeader {
- if len(res.Trailer) > 0 {
- t.Errorf("Trailer = %v; want none", res.Trailer)
- }
- } else {
- want := http.Header{"Some-Trailer": {"some-value"}}
- if !reflect.DeepEqual(res.Trailer, want) {
- t.Errorf("Trailer = %v; want %v", res.Trailer, want)
- }
- }
- return nil
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
+ if expect100Continue != noHeader {
+ req.Header.Set("Expect", "100-continue")
}
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
+ rt := tc.roundTrip(req)
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- endStream := false
- send := func(mode headerType) {
- hbf := buf.Bytes()
- switch mode {
- case oneHeader:
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.Header().StreamID,
- EndHeaders: true,
- EndStream: endStream,
- BlockFragment: hbf,
- })
- case splitHeader:
- if len(hbf) < 2 {
- panic("too small")
- }
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.Header().StreamID,
- EndHeaders: false,
- EndStream: endStream,
- BlockFragment: hbf[:1],
- })
- ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:])
- default:
- panic("bogus mode")
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *DataFrame:
- if !f.StreamEnded() {
- // No need to send flow control tokens. The test request body is tiny.
- continue
- }
- // Response headers (1+ frames; 1 or 2 in this test, but never 0)
- {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
- enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
- if trailers != noHeader {
- enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
- }
- endStream = withData == false && trailers == noHeader
- send(resHeader)
- }
- if withData {
- endStream = trailers == noHeader
- ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
- }
- if trailers != noHeader {
- endStream = true
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
- send(trailers)
- }
- if endStream {
- return nil
- }
- case *HeadersFrame:
- if expect100Continue != noHeader {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
- send(expect100Continue)
- }
- }
- }
+ tc.wantFrameType(FrameHeaders)
+
+ // Possibly write a 100-continue response; skipped when expect100Continue is noHeader.
+ tc.writeHeadersMode(expect100Continue, HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "100",
+ ),
+ })
+
+ // Client sends request body.
+ tc.wantData(wantData{
+ streamID: rt.streamID(),
+ endStream: true,
+ size: len(reqBody),
+ })
+
+ hdr := []string{
+ ":status", "200",
+ "x-foo", "blah",
+ "x-bar", "more",
+ }
+ if trailers != noHeader {
+ hdr = append(hdr, "trailer", "some-trailer")
+ }
+ tc.writeHeadersMode(resHeader, HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: withData == false && trailers == noHeader,
+ BlockFragment: tc.makeHeaderBlockFragment(hdr...),
+ })
+ if withData {
+ endStream := trailers == noHeader
+ tc.writeData(rt.streamID(), endStream, []byte(resBody))
+ }
+ tc.writeHeadersMode(trailers, HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ "some-trailer", "some-value",
+ ),
+ })
+
+ rt.wantStatus(200)
+ if !withData {
+ rt.wantBody(nil)
+ } else {
+ rt.wantBody([]byte(resBody))
+ }
+ if trailers == noHeader {
+ rt.wantTrailers(nil)
+ } else {
+ rt.wantTrailers(http.Header{
+ "Some-Trailer": {"some-value"},
+ })
}
- ct.run()
}
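// Editor's note (assumption, mirroring the removed send helper above):
// writeHeadersMode writes nothing for noHeader, a single HEADERS frame for
// oneHeader, and for splitHeader a HEADERS frame carrying one byte of the
// block followed by a CONTINUATION frame with the remainder.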
// Issue 26189, Issue 17739: ignore unknown 1xx responses
@@ -1383,130 +1157,76 @@ func TestTransportUnknown1xx(t *testing.T) {
return nil
}
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 204 {
- return fmt.Errorf("status code = %v; want 204", res.StatusCode)
- }
- want := `code=110 header=map[Foo-Bar:[110]]
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ for i := 110; i <= 114; i++ {
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", fmt.Sprint(i),
+ "foo-bar", fmt.Sprint(i),
+ ),
+ })
+ }
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "204",
+ ),
+ })
+
+ res := rt.response()
+ if res.StatusCode != 204 {
+ t.Fatalf("status code = %v; want 204", res.StatusCode)
+ }
+ want := `code=110 header=map[Foo-Bar:[110]]
code=111 header=map[Foo-Bar:[111]]
code=112 header=map[Foo-Bar:[112]]
code=113 header=map[Foo-Bar:[113]]
code=114 header=map[Foo-Bar:[114]]
`
- if got := buf.String(); got != want {
- t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
-
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- for i := 110; i <= 114; i++ {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)})
- enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- }
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ if got := buf.String(); got != want {
+ t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
}
- ct.run()
-
}
func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return fmt.Errorf("status code = %v; want 200", res.StatusCode)
- }
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("body = %q; want nothing", slurp)
- }
- if _, ok := res.Trailer["Some-Trailer"]; !ok {
- return fmt.Errorf("expected Some-Trailer")
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
-
- var n int
- var hf *HeadersFrame
- for hf == nil && n < 10 {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- hf, _ = f.(*HeadersFrame)
- n++
- }
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
-
- // send headers without Trailer header
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ "some-trailer", "I'm an undeclared Trailer!",
+ ),
+ })
- // send trailers
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- ct.run()
+ rt.wantStatus(200)
+ rt.wantBody(nil)
+ rt.wantTrailers(http.Header{
+ "Some-Trailer": []string{"I'm an undeclared Trailer!"},
+ })
}
func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
@@ -1516,10 +1236,10 @@ func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
testTransportInvalidTrailer_Pseudo(t, splitHeader)
}
func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
- testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
- enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
- enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
- })
+ testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"),
+ ":colon", "foo",
+ "foo", "bar",
+ )
}
func TestTransportInvalidTrailer_Capital1(t *testing.T) {
@@ -1529,102 +1249,54 @@ func TestTransportInvalidTrailer_Capital2(t *testing.T) {
testTransportInvalidTrailer_Capital(t, splitHeader)
}
func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
- testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
- enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
- enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
- })
+ testInvalidTrailer(t, trailers, headerFieldNameError("Capital"),
+ "foo", "bar",
+ "Capital", "bad",
+ )
}
func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
- testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
- enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
- })
+ testInvalidTrailer(t, oneHeader, headerFieldNameError(""),
+ "", "bad",
+ )
}
func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
- testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) {
- enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
- })
+ testInvalidTrailer(t, oneHeader, headerFieldValueError("x"),
+ "x", "has\nnewline",
+ )
}
-func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return fmt.Errorf("status code = %v; want 200", res.StatusCode)
- }
- slurp, err := ioutil.ReadAll(res.Body)
- se, ok := err.(StreamError)
- if !ok || se.Cause != wantErr {
- return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("body = %q; want nothing", slurp)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
+func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) {
+ tc := newTestClientConn(t)
+ tc.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- var endStream bool
- send := func(mode headerType) {
- hbf := buf.Bytes()
- switch mode {
- case oneHeader:
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: endStream,
- BlockFragment: hbf,
- })
- case splitHeader:
- if len(hbf) < 2 {
- panic("too small")
- }
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: false,
- EndStream: endStream,
- BlockFragment: hbf[:1],
- })
- ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
- default:
- panic("bogus mode")
- }
- }
- // Response headers (1+ frames; 1 or 2 in this test, but never 0)
- {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"})
- endStream = false
- send(oneHeader)
- }
- // Trailers:
- {
- endStream = true
- buf.Reset()
- writeTrailer(enc)
- send(trailers)
- }
- return nil
- }
- }
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "trailer", "declared",
+ ),
+ })
+ tc.writeHeadersMode(mode, HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(trailers...),
+ })
+
+ rt.wantStatus(200)
+ body, err := rt.readBody()
+ se, ok := err.(StreamError)
+ if !ok || se.Cause != wantErr {
+ t.Fatalf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", body, err, wantErr, wantErr)
+ }
+ if len(body) > 0 {
+ t.Fatalf("body = %q; want nothing", body)
}
- ct.run()
}
// headerListSize returns the HTTP2 header list size of h.
@@ -1900,115 +1572,80 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
}
func TestTransportChecksResponseHeaderListSize(t *testing.T) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if e, ok := err.(StreamError); ok {
- err = e.Cause
- }
- if err != errResponseHeaderListSize {
- size := int64(0)
- if res != nil {
- res.Body.Close()
- for k, vv := range res.Header {
- for _, v := range vv {
- size += int64(len(k)) + int64(len(v)) + 32
- }
- }
- }
- return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+
+ hdr := []string{":status", "200"}
+ large := strings.Repeat("a", 1<<10)
+ for i := 0; i < 5042; i++ {
+ hdr = append(hdr, large, large)
+ }
+ hbf := tc.makeHeaderBlockFragment(hdr...)
+ // Note: this number might change if our hpack implementation changes.
+ // That's fine. This is just a sanity check that our response can fit in a single
+ // header block fragment frame.
+ if size, want := len(hbf), 6329; size != want {
+ t.Fatalf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
+ }
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: hbf,
+ })
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- large := strings.Repeat("a", 1<<10)
- for i := 0; i < 5042; i++ {
- enc.WriteField(hpack.HeaderField{Name: large, Value: large})
- }
- if size, want := buf.Len(), 6329; size != want {
- // Note: this number might change if
- // our hpack implementation
- // changes. That's fine. This is
- // just a sanity check that our
- // response can fit in a single
- // header block fragment frame.
- return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
+ res, err := rt.result()
+ if e, ok := err.(StreamError); ok {
+ err = e.Cause
+ }
+ if err != errResponseHeaderListSize {
+ size := int64(0)
+ if res != nil {
+ res.Body.Close()
+ for k, vv := range res.Header {
+ for _, v := range vv {
+ size += int64(len(k)) + int64(len(v)) + 32
}
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
}
}
+ t.Fatalf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
}
- ct.run()
}
func TestTransportCookieHeaderSplit(t *testing.T) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- req.Header.Add("Cookie", "a=b;c=d; e=f;")
- req.Header.Add("Cookie", "e=f;g=h; ")
- req.Header.Add("Cookie", "i=j")
- _, err := ct.tr.RoundTrip(req)
- return err
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- dec := hpack.NewDecoder(initialHeaderTableSize, nil)
- hfs, err := dec.DecodeFull(f.HeaderBlockFragment())
- if err != nil {
- return err
- }
- got := []string{}
- want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}
- for _, hf := range hfs {
- if hf.Name == "cookie" {
- got = append(got, hf.Value)
- }
- }
- if !reflect.DeepEqual(got, want) {
- t.Errorf("Cookies = %#v, want %#v", got, want)
- }
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ req.Header.Add("Cookie", "a=b;c=d; e=f;")
+ req.Header.Add("Cookie", "e=f;g=h; ")
+ req.Header.Add("Cookie", "i=j")
+ rt := tc.roundTrip(req)
+
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: true,
+ header: http.Header{
+ "cookie": []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"},
+ },
+ })
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "204",
+ ),
+ })
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ if err := rt.err(); err != nil {
+ t.Fatalf("RoundTrip = %v, want success", err)
}
- ct.run()
}
// Test that the Transport returns a typed error from Response.Body.Read calls
@@ -2224,55 +1861,49 @@ func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
}
func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
- ct := newClientTester(t)
- ct.tr.t1 = &http.Transport{
- ResponseHeaderTimeout: 5 * time.Millisecond,
- }
- ct.client = func() error {
- c := &http.Client{Transport: ct.tr}
- var err error
- var n int64
- const bodySize = 4 << 20
- if body {
- _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
- } else {
- _, err = c.Get("https://dummy.tld/")
- }
- if !isTimeout(err) {
- t.Errorf("client expected timeout error; got %#v", err)
- }
- if body && n != bodySize {
- t.Errorf("only read %d bytes of body; want %d", n, bodySize)
+ const bodySize = 4 << 20
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.t1 = &http.Transport{
+ ResponseHeaderTimeout: 5 * time.Millisecond,
}
- return nil
+ })
+ tc.greet()
+
+ var req *http.Request
+ var reqBody *testRequestBody
+ if body {
+ reqBody = tc.newRequestBody()
+ reqBody.writeBytes(bodySize)
+ reqBody.closeWithError(io.EOF)
+ req, _ = http.NewRequest("POST", "https://dummy.tld/", reqBody)
+ req.Header.Set("Content-Type", "text/foo")
+ } else {
+ req, _ = http.NewRequest("GET", "https://dummy.tld/", nil)
}
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- t.Logf("ReadFrame: %v", err)
- return nil
- }
- switch f := f.(type) {
- case *DataFrame:
- dataLen := len(f.Data())
- if dataLen > 0 {
- if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
- return err
- }
- if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
- return err
- }
- }
- case *RSTStreamFrame:
- if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
- return nil
- }
- }
- }
+
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+
+ tc.writeWindowUpdate(0, bodySize)
+ tc.writeWindowUpdate(rt.streamID(), bodySize)
+
+ if body {
+ tc.wantData(wantData{
+ endStream: true,
+ size: bodySize,
+ })
+ }
+
+ tc.advance(4 * time.Millisecond)
+ if rt.done() {
+ t.Fatalf("RoundTrip is done after 4ms; want still waiting")
+ }
+ tc.advance(1 * time.Millisecond)
+
+ if err := rt.err(); !isTimeout(err) {
+ t.Fatalf("RoundTrip error: %v; want timeout error", err)
}
- ct.run()
}
func TestTransportDisableCompression(t *testing.T) {
@@ -2484,7 +2115,8 @@ func TestTransportRejectsContentLengthWithSign(t *testing.T) {
}
// golang.org/issue/14048
-func TestTransportFailsOnInvalidHeaders(t *testing.T) {
+// golang.org/issue/64766
+func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
var got []string
for k := range r.Header {
@@ -2497,6 +2129,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) {
tests := [...]struct {
h http.Header
+ t http.Header
wantErr string
}{
0: {
@@ -2515,6 +2148,14 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) {
h: http.Header{"foo": {"foo\x01bar"}},
wantErr: `invalid HTTP header value for header "foo"`,
},
+ 4: {
+ t: http.Header{"foo": {"foo\x01bar"}},
+ wantErr: `invalid HTTP trailer value for header "foo"`,
+ },
+ 5: {
+ t: http.Header{"x-\r\nda": {"foo\x01bar"}},
+ wantErr: `invalid HTTP trailer name "x-\r\nda"`,
+ },
}
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
@@ -2523,6 +2164,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) {
for i, tt := range tests {
req, _ := http.NewRequest("GET", st.ts.URL, nil)
req.Header = tt.h
+ req.Trailer = tt.t
res, err := tr.RoundTrip(req)
var bad bool
if tt.wantErr == "" {
@@ -2658,115 +2300,61 @@ func TestTransportNewTLSConfig(t *testing.T) {
// without END_STREAM, followed by a 0-length DATA frame with
// END_STREAM. Make sure we don't get confused by that. (We did.)
func TestTransportReadHeadResponse(t *testing.T) {
- ct := newClientTester(t)
- clientDone := make(chan struct{})
- ct.client = func() error {
- defer close(clientDone)
- req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- if res.ContentLength != 123 {
- return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength)
- }
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("ReadAll: %v", err)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- t.Logf("ReadFrame: %v", err)
- return nil
- }
- hf, ok := f.(*HeadersFrame)
- if !ok {
- continue
- }
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false, // as the GFE does
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(hf.StreamID, true, nil)
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false, // as the GFE does
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "123",
+ ),
+ })
+ tc.writeData(rt.streamID(), true, nil)
- <-clientDone
- return nil
- }
+ res := rt.response()
+ if res.ContentLength != 123 {
+ t.Fatalf("Content-Length = %d; want 123", res.ContentLength)
}
- ct.run()
+ rt.wantBody(nil)
}
func TestTransportReadHeadResponseWithBody(t *testing.T) {
- // This test use not valid response format.
- // Discarding logger output to not spam tests output.
- log.SetOutput(ioutil.Discard)
+ // This test uses an invalid response format.
+ // Discard logger output to not spam tests output.
+ log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
response := "redirecting to /elsewhere"
- ct := newClientTester(t)
- clientDone := make(chan struct{})
- ct.client = func() error {
- defer close(clientDone)
- req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- if res.ContentLength != int64(len(response)) {
- return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response))
- }
- slurp, err := ioutil.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("ReadAll: %v", err)
- }
- if len(slurp) > 0 {
- return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- t.Logf("ReadFrame: %v", err)
- return nil
- }
- hf, ok := f.(*HeadersFrame)
- if !ok {
- continue
- }
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(hf.StreamID, true, []byte(response))
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", strconv.Itoa(len(response)),
+ ),
+ })
+ tc.writeData(rt.streamID(), true, []byte(response))
- <-clientDone
- return nil
- }
+ res := rt.response()
+ if res.ContentLength != int64(len(response)) {
+ t.Fatalf("Content-Length = %d; want %d", res.ContentLength, len(response))
}
- ct.run()
+ rt.wantBody(nil)
}
type neverEnding byte
@@ -2891,190 +2479,125 @@ func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
}
func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
- ct := newClientTester(t)
- clientDone := make(chan struct{})
+ tc := newTestClientConn(t)
+ tc.greet()
const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
const goAwayDebugData = "some debug data"
- ct.client = func() error {
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if failMidBody {
- if err != nil {
- return fmt.Errorf("unexpected client RoundTrip error: %v", err)
- }
- _, err = io.Copy(ioutil.Discard, res.Body)
- res.Body.Close()
- }
- want := GoAwayError{
- LastStreamID: 5,
- ErrCode: goAwayErrCode,
- DebugData: goAwayDebugData,
- }
- if !reflect.DeepEqual(err, want) {
- t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- t.Logf("ReadFrame: %v", err)
- return nil
- }
- hf, ok := f.(*HeadersFrame)
- if !ok {
- continue
- }
- if failMidBody {
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- }
- // Write two GOAWAY frames, to test that the Transport takes
- // the interesting parts of both.
- ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
- ct.fr.WriteGoAway(5, goAwayErrCode, nil)
- ct.sc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- ct.sc.(*net.TCPConn).Close()
- }
- <-clientDone
- return nil
- }
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+
+ if failMidBody {
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "123",
+ ),
+ })
}
- ct.run()
-}
-func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
- ct := newClientTester(t)
+ // Write two GOAWAY frames, to test that the Transport takes
+ // the interesting parts of both.
+ tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
+ tc.writeGoAway(5, goAwayErrCode, nil)
+ tc.closeWrite(io.EOF)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
+ res, err := rt.result()
+ whence := "RoundTrip"
+ if failMidBody {
+ whence = "Body.Read"
if err != nil {
- return err
+ t.Fatalf("RoundTrip error = %v, want success", err)
}
+ _, err = res.Body.Read(make([]byte, 1))
+ }
- if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
- return fmt.Errorf("body read = %v, %v; want 1, nil", n, err)
- }
- res.Body.Close() // leaving 4999 bytes unread
-
- return nil
+ want := GoAwayError{
+ LastStreamID: 5,
+ ErrCode: goAwayErrCode,
+ DebugData: goAwayDebugData,
+ }
+ if !reflect.DeepEqual(err, want) {
+ t.Errorf("%v error = %T: %#v, want %T (%#v)", whence, err, err, want, want)
}
- ct.server = func() error {
- ct.greet()
+}
- var hf *HeadersFrame
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- }
- var ok bool
- hf, ok = f.(*HeadersFrame)
- if !ok {
- return fmt.Errorf("Got %T; want HeadersFrame", f)
- }
- break
- }
+func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "5000",
+ ),
+ })
+ initialInflow := tc.inflowWindow(0)
+
+ // Two cases:
+ // - Send one DATA frame with 5000 bytes.
+ // - Send two DATA frames with 1 and 4999 bytes each.
+ //
+ // In both cases, the client should consume one byte of data,
+ // refund that byte, then refund the following 4999 bytes.
+ //
+ // In the second case, the server waits for the client to reset the
+ // stream before sending the second DATA frame. This tests the case
+ // where the client receives a DATA frame after it has reset the stream.
+ const streamNotEnded = false
+ if oneDataFrame {
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000))
+ } else {
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1))
+ }
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- initialInflow := ct.inflowWindow(0)
-
- // Two cases:
- // - Send one DATA frame with 5000 bytes.
- // - Send two DATA frames with 1 and 4999 bytes each.
- //
- // In both cases, the client should consume one byte of data,
- // refund that byte, then refund the following 4999 bytes.
- //
- // In the second case, the server waits for the client to reset the
- // stream before sending the second DATA frame. This tests the case
- // where the client receives a DATA frame after it has reset the stream.
- if oneDataFrame {
- ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000))
- } else {
- ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1))
- }
+ res := rt.response()
+ if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
+ t.Fatalf("body read = %v, %v; want 1, nil", n, err)
+ }
+ res.Body.Close() // leaving 4999 bytes unread
+ tc.sync()
- wantRST := true
- wantWUF := true
- if !oneDataFrame {
- wantWUF = false // flow control update is small, and will not be sent
- }
- for wantRST || wantWUF {
- f, err := ct.readNonSettingsFrame()
- if err != nil {
- return err
+ sentAdditionalData := false
+ tc.wantUnorderedFrames(
+ func(f *RSTStreamFrame) bool {
+ if f.ErrCode != ErrCodeCancel {
+ t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
}
- switch f := f.(type) {
- case *RSTStreamFrame:
- if !wantRST {
- return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
- }
- if f.ErrCode != ErrCodeCancel {
- return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
- }
- wantRST = false
- case *WindowUpdateFrame:
- if !wantWUF {
- return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
- }
- if f.Increment != 5000 {
- return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
- }
- wantWUF = false
- default:
- return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
+ if !oneDataFrame {
+ // Send the remaining data now.
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999))
+ sentAdditionalData = true
}
- }
- if !oneDataFrame {
- ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999))
- f, err := ct.readNonSettingsFrame()
- if err != nil {
- return err
+ return true
+ },
+ func(f *WindowUpdateFrame) bool {
+ if !oneDataFrame && !sentAdditionalData {
+ t.Fatalf("Got WindowUpdateFrame, don't expect one yet")
}
- wuf, ok := f.(*WindowUpdateFrame)
- if !ok || wuf.Increment != 5000 {
- return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f))
+ if f.Increment != 5000 {
+ t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
}
- }
- if err := ct.writeReadPing(); err != nil {
- return err
- }
- if got, want := ct.inflowWindow(0), initialInflow; got != want {
- return fmt.Errorf("connection flow tokens = %v, want %v", got, want)
- }
- return nil
+ return true
+ },
+ )
+
+ if got, want := tc.inflowWindow(0), initialInflow; got != want {
+ t.Fatalf("connection flow tokens = %v, want %v", got, want)
}
- ct.run()
}
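// Editor's note (assumption): wantUnorderedFrames asserts that a set of
// frames arrives in any order; each func(f *FrameType) bool argument claims
// a frame of its parameter type, and returning true marks that expectation
// satisfied (here, one RST_STREAM and one WINDOW_UPDATE).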
// See golang.org/issue/16481
@@ -3090,199 +2613,124 @@ func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
// Issue 16612: adjust flow control on open streams when transport
// receives SETTINGS with INITIAL_WINDOW_SIZE from server.
func TestTransportAdjustsFlowControl(t *testing.T) {
- ct := newClientTester(t)
- clientDone := make(chan struct{})
-
const bodySize = 1 << 20
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- defer close(clientDone)
+ tc := newTestClientConn(t)
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ // Don't write our SETTINGS yet.
- req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
+ body := tc.newRequestBody()
+ body.writeBytes(bodySize)
+ body.closeWithError(io.EOF)
+
+ req, _ := http.NewRequest("POST", "https://dummy.tld/", body)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+
+ gotBytes := int64(0)
+ for {
+ f := testClientConnReadFrame[*DataFrame](tc)
+ gotBytes += int64(len(f.Data()))
+ // After we've got half the client's initial flow control window's worth
+ // of request body data, give it just enough flow control to finish.
+ if gotBytes >= initialWindowSize/2 {
+ break
}
- res.Body.Close()
- return nil
}
- ct.server = func() error {
- _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface)))
- if err != nil {
- return fmt.Errorf("reading client preface: %v", err)
- }
- var gotBytes int64
- var sentSettings bool
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- return nil
- default:
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- }
- switch f := f.(type) {
- case *DataFrame:
- gotBytes += int64(len(f.Data()))
- // After we've got half the client's
- // initial flow control window's worth
- // of request body data, give it just
- // enough flow control to finish.
- if gotBytes >= initialWindowSize/2 && !sentSettings {
- sentSettings = true
-
- ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
- ct.fr.WriteWindowUpdate(0, bodySize)
- ct.fr.WriteSettingsAck()
- }
+ tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
+ tc.writeWindowUpdate(0, bodySize)
+ tc.writeSettingsAck()
- if f.StreamEnded() {
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- }
- }
+ tc.wantUnorderedFrames(
+ func(f *SettingsFrame) bool { return true },
+ func(f *DataFrame) bool {
+ gotBytes += int64(len(f.Data()))
+ return f.StreamEnded()
+ },
+ )
+
+ if gotBytes != bodySize {
+ t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize)
}
- ct.run()
+
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
}
// See golang.org/issue/16556
func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
- ct := newClientTester(t)
-
- unblockClient := make(chan bool, 1)
-
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- defer res.Body.Close()
- <-unblockClient
- return nil
- }
- ct.server = func() error {
- ct.greet()
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "5000",
+ ),
+ })
- var hf *HeadersFrame
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- }
- var ok bool
- hf, ok = f.(*HeadersFrame)
- if !ok {
- return fmt.Errorf("Got %T; want HeadersFrame", f)
- }
- break
- }
+ initialConnWindow := tc.inflowWindow(0)
+ initialStreamWindow := tc.inflowWindow(rt.streamID())
- initialConnWindow := ct.inflowWindow(0)
+ pad := make([]byte, 5)
+ tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad)
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- initialStreamWindow := ct.inflowWindow(hf.StreamID)
- pad := make([]byte, 5)
- ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream
- if err := ct.writeReadPing(); err != nil {
- return err
- }
- // Padding flow control should have been returned.
- if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want {
- t.Errorf("conn inflow window = %v, want %v", got, want)
- }
- if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want {
- t.Errorf("stream inflow window = %v, want %v", got, want)
- }
- unblockClient <- true
- return nil
+ // The DATA frame consumed 5006 bytes of flow control: 5000 bytes of
+ // data, 5 bytes of padding, and the 1-byte pad length prefix. The
+ // transport should already have returned the 6 padding-related bytes,
+ // leaving the windows down by exactly the 5000 data bytes.
+ if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want {
+ t.Errorf("conn inflow window = %v, want %v", got, want)
+ }
+ if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want {
+ t.Errorf("stream inflow window = %v, want %v", got, want)
}
- ct.run()
}
// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
// StreamError as a result of the response HEADERS
func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
- ct := newClientTester(t)
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ " content-type", "bogus",
+ ),
+ })
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err == nil {
- res.Body.Close()
- return errors.New("unexpected successful GET")
- }
- want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
- if !reflect.DeepEqual(want, err) {
- t.Errorf("RoundTrip error = %#v; want %#v", err, want)
- }
- return nil
+ err := rt.err()
+ want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
+ if !reflect.DeepEqual(err, want) {
+ t.Fatalf("RoundTrip error = %#v; want %#v", err, want)
}
- ct.server = func() error {
- ct.greet()
-
- hf, err := ct.firstHeaders()
- if err != nil {
- return err
- }
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- for {
- fr, err := ct.readFrame()
- if err != nil {
- return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
- }
- if _, ok := fr.(*SettingsFrame); ok {
- continue
- }
- if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
- t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
- }
- break
- }
-
- return nil
+ fr := testClientConnReadFrame[*RSTStreamFrame](tc)
+ if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol {
+ t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
}
- ct.run()
}
// byteAndEOFReader is an io.Reader which reads one byte
@@ -3576,26 +3024,24 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
}
func TestTransportCloseAfterLostPing(t *testing.T) {
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.tr.PingTimeout = 1 * time.Second
- ct.tr.ReadIdleTimeout = 1 * time.Second
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- _, err := ct.tr.RoundTrip(req)
- if err == nil || !strings.Contains(err.Error(), "client connection lost") {
- return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- <-clientDone
- return nil
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.PingTimeout = 1 * time.Second
+ tr.ReadIdleTimeout = 1 * time.Second
+ })
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+
+ tc.advance(1 * time.Second)
+ tc.wantFrameType(FramePing)
+
+ tc.advance(1 * time.Second)
+ err := rt.err()
+ if err == nil || !strings.Contains(err.Error(), "client connection lost") {
+ t.Fatalf("expected to get error about \"connection lost\", got %v", err)
}
- ct.run()
}
func TestTransportPingWriteBlocks(t *testing.T) {
@@ -3628,418 +3074,231 @@ func TestTransportPingWriteBlocks(t *testing.T) {
}
}
-func TestTransportPingWhenReading(t *testing.T) {
- testCases := []struct {
- name string
- readIdleTimeout time.Duration
- deadline time.Duration
- expectedPingCount int
- }{
- {
- name: "two pings",
- readIdleTimeout: 100 * time.Millisecond,
- deadline: time.Second,
- expectedPingCount: 2,
- },
- {
- name: "zero ping",
- readIdleTimeout: time.Second,
- deadline: 200 * time.Millisecond,
- expectedPingCount: 0,
- },
- {
- name: "0 readIdleTimeout means no ping",
- readIdleTimeout: 0 * time.Millisecond,
- deadline: 500 * time.Millisecond,
- expectedPingCount: 0,
- },
- }
-
- for _, tc := range testCases {
- tc := tc // capture range variable
- t.Run(tc.name, func(t *testing.T) {
- testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount)
- })
- }
-}
+func TestTransportPingWhenReadingMultiplePings(t *testing.T) {
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.ReadIdleTimeout = 1000 * time.Millisecond
+ })
+ tc.greet()
-func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) {
- var pingCount int
- ct := newClientTester(t)
- ct.tr.ReadIdleTimeout = readIdleTimeout
+ ctx, cancel := context.WithCancel(context.Background())
+ req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
- ctx, cancel := context.WithTimeout(context.Background(), deadline)
- defer cancel()
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
- }
- _, err = ioutil.ReadAll(res.Body)
- if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) {
- return nil
+ for i := 0; i < 5; i++ {
+ // No ping yet...
+ tc.advance(999 * time.Millisecond)
+ if f := tc.readFrame(); f != nil {
+ t.Fatalf("unexpected frame: %v", f)
}
- cancel()
- return err
+ // ...ping now.
+ tc.advance(1 * time.Millisecond)
+ f := testClientConnReadFrame[*PingFrame](tc)
+ tc.writePing(true, f.Data)
}
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- var streamID uint32
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-ctx.Done():
- // If the client's done, it
- // will have reported any
- // errors on its side.
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- streamID = f.StreamID
- case *PingFrame:
- pingCount++
- if pingCount == expectedPingCount {
- if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil {
- return err
- }
- }
- if err := ct.fr.WritePing(true, f.Data); err != nil {
- return err
- }
- case *RSTStreamFrame:
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
+ // Cancel the request; the Transport resets the stream and returns an error from body reads.
+ cancel()
+ tc.sync()
+
+ tc.wantFrameType(FrameRSTStream)
+ _, err := rt.readBody()
+ if err == nil {
+ t.Fatalf("Response.Body.Read() = %v, want error", err)
}
- ct.run()
}
-func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) {
- ln := newLocalListener(t)
- defer ln.Close()
-
- var (
- mu sync.Mutex
- count int
- conns []net.Conn
- )
- var wg sync.WaitGroup
- tr := &Transport{
- TLSClientConfig: tlsConfigInsecure,
- }
- tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
- mu.Lock()
- defer mu.Unlock()
- count++
- cc, err := net.Dial("tcp", ln.Addr().String())
- if err != nil {
- return nil, fmt.Errorf("dial error: %v", err)
- }
- conns = append(conns, cc)
- sc, err := ln.Accept()
- if err != nil {
- return nil, fmt.Errorf("accept error: %v", err)
- }
- conns = append(conns, sc)
- ct := &clientTester{
- t: t,
- tr: tr,
- cc: cc,
- sc: sc,
- fr: NewFramer(sc, sc),
- }
- wg.Add(1)
- go func(count int) {
- defer wg.Done()
- server(count, ct)
- }(count)
- return cc, nil
- }
+func TestTransportPingWhenReadingPingDisabled(t *testing.T) {
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.ReadIdleTimeout = 0 // PINGs disabled
+ })
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
- client(tr)
- tr.CloseIdleConnections()
- ln.Close()
- for _, c := range conns {
- c.Close()
+ // No PING is sent, even after a long delay.
+ tc.advance(1 * time.Minute)
+ if f := tc.readFrame(); f != nil {
+ t.Fatalf("unexpected frame: %v", f)
}
- wg.Wait()
}
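
The two ping tests above exercise the transport's connection health checks under synthetic time. For orientation, a minimal sketch (not part of this patch; the function name and durations are illustrative) of how a client enables the same knobs via the exported http2.Transport fields the tests configure:

```go
package example

import (
	"crypto/tls"
	"net/http"
	"time"

	"golang.org/x/net/http2"
)

// newPingingClient returns a client that sends an HTTP/2 PING after
// ReadIdleTimeout with no frames received, and drops the connection if
// the PING is not acknowledged within PingTimeout. Leaving
// ReadIdleTimeout at its zero value disables the health checks, which
// is what TestTransportPingWhenReadingPingDisabled verifies.
func newPingingClient() *http.Client {
	tr := &http2.Transport{
		TLSClientConfig: &tls.Config{},
		ReadIdleTimeout: 15 * time.Second, // idle time before a PING is sent
		PingTimeout:     5 * time.Second,  // time allowed for the PING ack
	}
	return &http.Client{Transport: tr}
}
```
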
func TestTransportRetryAfterGOAWAY(t *testing.T) {
- client := func(tr *Transport) {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := tr.RoundTrip(req)
- if res != nil {
- res.Body.Close()
- if got := res.Header.Get("Foo"); got != "bar" {
- err = fmt.Errorf("foo header = %q; want bar", got)
- }
- }
- if err != nil {
- t.Errorf("RoundTrip: %v", err)
- }
- }
-
- server := func(count int, ct *clientTester) {
- switch count {
- case 1:
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- t.Errorf("server1 failed reading HEADERS: %v", err)
- return
- }
- t.Logf("server1 got %v", hf)
- if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
- t.Errorf("server1 failed writing GOAWAY: %v", err)
- return
- }
- case 2:
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- t.Errorf("server2 failed reading HEADERS: %v", err)
- return
- }
- t.Logf("server2 got %v", hf)
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
- err = ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- if err != nil {
- t.Errorf("server2 failed writing response HEADERS: %v", err)
- }
- default:
- t.Errorf("unexpected number of dials")
- return
- }
- }
+ tt := newTestTransport(t)
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // First attempt: Server sends a GOAWAY.
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeGoAway(0 /*max id*/, ErrCodeNo, nil)
+ if rt.done() {
+ t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
+ }
+
+ // Second attempt succeeds on a new connection.
+ tc = tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
- testClientMultipleDials(t, client, server)
+ rt.wantStatus(200)
}
func TestTransportRetryAfterRefusedStream(t *testing.T) {
- clientDone := make(chan struct{})
- client := func(tr *Transport) {
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- resp, err := tr.RoundTrip(req)
- if err != nil {
- t.Errorf("RoundTrip: %v", err)
- return
- }
- resp.Body.Close()
- if resp.StatusCode != 204 {
- t.Errorf("Status = %v; want 204", resp.StatusCode)
- return
- }
+ tt := newTestTransport(t)
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // First attempt: Server sends a RST_STREAM.
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.wantFrameType(FrameSettings) // settings ACK
+ tc.writeRSTStream(1, ErrCodeRefusedStream)
+ if rt.done() {
+ t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying")
}
- server := func(_ int, ct *clientTester) {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- var count int
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- default:
- t.Error(err)
- }
- return
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- t.Errorf("headers should have END_HEADERS be ended: %v", f)
- return
- }
- count++
- if count == 1 {
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
- } else {
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- default:
- t.Errorf("Unexpected client frame %v", f)
- return
- }
- }
- }
+ // Second attempt succeeds on the same connection.
+ tc.wantHeaders(wantHeader{
+ streamID: 3,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 3,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "204",
+ ),
+ })
- testClientMultipleDials(t, client, server)
+ rt.wantStatus(204)
}
func TestTransportRetryHasLimit(t *testing.T) {
- // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s.
- if testing.Short() {
- t.Skip("skipping long test in short mode")
- }
- retryBackoffHook = func(d time.Duration) *time.Timer {
- return time.NewTimer(0) // fires immediately
- }
- defer func() {
- retryBackoffHook = nil
- }()
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- resp, err := ct.tr.RoundTrip(req)
- if err == nil {
- return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
+ tt := newTestTransport(t)
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // First attempt: Server sends a GOAWAY.
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+
+ var totalDelay time.Duration
+ count := 0
+ for streamID := uint32(1); ; streamID += 2 {
+ count++
+ tc.wantHeaders(wantHeader{
+ streamID: streamID,
+ endStream: true,
+ })
+ if streamID == 1 {
+ tc.writeSettings()
+ tc.wantFrameType(FrameSettings) // settings ACK
}
- t.Logf("expected error, got: %v", err)
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
+ tc.writeRSTStream(streamID, ErrCodeRefusedStream)
+
+ d := tt.tr.syncHooks.timeUntilEvent()
+ if d == 0 {
+ if streamID == 1 {
+ continue
}
+ break
+ }
+ totalDelay += d
+ if totalDelay > 5*time.Minute {
+ t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay)
}
+ tt.advance(d)
+ }
+ if got, want := count, 5; got < want {
+ t.Errorf("RoundTrip made %v attempts, want at least %v", got, want)
+ }
+ if rt.err() == nil {
+ t.Errorf("RoundTrip succeeded, want error")
}
- ct.run()
}
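
A rough illustration (an assumption, not the transport's actual code, which also adds random jitter) of the retry backoff the deleted comment described, where each refused attempt roughly doubles the wait:

```go
package example

import "time"

// retryDelay sketches the schedule the old comment listed
// (1s, 2s, 4s, 8s, 16s, ...): attempt n waits about 2^(n-1) seconds.
// The test above instead observes the delays through the synthetic
// clock and gives up if they ever exceed five minutes in total.
func retryDelay(attempt int) time.Duration {
	return time.Duration(1<<uint(attempt-1)) * time.Second
}
```
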
func TestTransportResponseDataBeforeHeaders(t *testing.T) {
- // This test use not valid response format.
- // Discarding logger output to not spam tests output.
- log.SetOutput(ioutil.Discard)
- defer log.SetOutput(os.Stderr)
+ // Discard log output complaining about protocol error.
+ log.SetOutput(io.Discard)
+ t.Cleanup(func() { log.SetOutput(os.Stderr) }) // after other cleanup is done
+
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ // First request is normal to ensure the check is per stream and not per connection.
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt1 := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt1.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt1.wantStatus(200)
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
- // First request is normal to ensure the check is per stream and not per connection.
- _, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip expected no error, got: %v", err)
- }
- // Second request returns a DATA frame with no HEADERS.
- resp, err := ct.tr.RoundTrip(req)
- if err == nil {
- return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
- }
- if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
- return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err == io.EOF {
- return nil
- } else if err != nil {
- return err
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame:
- case *HeadersFrame:
- switch f.StreamID {
- case 1:
- // Send a valid response to first request.
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- case 3:
- ct.fr.WriteData(f.StreamID, true, []byte("payload"))
- }
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
+ // Second request returns a DATA frame with no HEADERS.
+ rt2 := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+ tc.writeData(rt2.streamID(), true, []byte("payload"))
+ if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol {
+ t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err)
}
- ct.run()
}
func TestTransportMaxFrameReadSize(t *testing.T) {
@@ -4053,30 +3312,17 @@ func TestTransportMaxFrameReadSize(t *testing.T) {
maxReadFrameSize: 1024,
want: minMaxFrameSize,
}} {
- ct := newClientTester(t)
- ct.tr.MaxReadFrameSize = test.maxReadFrameSize
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
- ct.tr.RoundTrip(req)
- return nil
- }
- ct.server = func() error {
- defer ct.cc.(*net.TCPConn).Close()
- ct.greet()
- var got uint32
- ct.settings.ForeachSetting(func(s Setting) error {
- switch s.ID {
- case SettingMaxFrameSize:
- got = s.Val
- }
- return nil
- })
- if got != test.want {
- t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
- }
- return nil
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.MaxReadFrameSize = test.maxReadFrameSize
+ })
+
+ fr := testClientConnReadFrame[*SettingsFrame](tc)
+ got, ok := fr.Value(SettingMaxFrameSize)
+ if !ok {
+ t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want)
+ } else if got != test.want {
+ t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
}
- ct.run()
}
}
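
The want of minMaxFrameSize above reflects that SETTINGS_MAX_FRAME_SIZE must lie between 2^14 (16384) and 2^24-1 per RFC 7540, so out-of-range values of MaxReadFrameSize are clamped before being advertised. A sketch of the configuration side (illustrative value; not code from this patch):

```go
package example

import "golang.org/x/net/http2"

// tr advertises how large a frame the transport is willing to read.
// Asking for 1024, as the test does, is below the protocol minimum,
// so the server instead sees 16384 (minMaxFrameSize).
var tr = &http2.Transport{
	MaxReadFrameSize: 1 << 20, // advertise a 1 MiB SETTINGS_MAX_FRAME_SIZE
}
```
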
@@ -4108,345 +3354,134 @@ func TestTransportRequestsLowServerLimit(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- res, err := tr.RoundTrip(req)
- if err != nil {
- t.Fatal(err)
- }
- if got, want := res.StatusCode, 200; got != want {
- t.Errorf("StatusCode = %v; want %v", got, want)
- }
- if res != nil && res.Body != nil {
- res.Body.Close()
- }
- }
-
- if connCount != 1 {
- t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount)
- }
-}
-
-// tests Transport.StrictMaxConcurrentStreams
-func TestTransportRequestsStallAtServerLimit(t *testing.T) {
- const maxConcurrent = 2
-
- greet := make(chan struct{}) // server sends initial SETTINGS frame
- gotRequest := make(chan struct{}) // server received a request
- clientDone := make(chan struct{})
- cancelClientRequest := make(chan struct{})
-
- // Collect errors from goroutines.
- var wg sync.WaitGroup
- errs := make(chan error, 100)
- defer func() {
- wg.Wait()
- close(errs)
- for err := range errs {
- t.Error(err)
- }
- }()
-
- // We will send maxConcurrent+2 requests. This checker goroutine waits for the
- // following stages:
- // 1. The first maxConcurrent requests are received by the server.
- // 2. The client will cancel the next request
- // 3. The server is unblocked so it can service the first maxConcurrent requests
- // 4. The client will send the final request
- wg.Add(1)
- unblockClient := make(chan struct{})
- clientRequestCancelled := make(chan struct{})
- unblockServer := make(chan struct{})
- go func() {
- defer wg.Done()
- // Stage 1.
- for k := 0; k < maxConcurrent; k++ {
- <-gotRequest
- }
- // Stage 2.
- close(unblockClient)
- <-clientRequestCancelled
- // Stage 3: give some time for the final RoundTrip call to be scheduled and
- // verify that the final request is not sent.
- time.Sleep(50 * time.Millisecond)
- select {
- case <-gotRequest:
- errs <- errors.New("last request did not stall")
- close(unblockServer)
- return
- default:
- }
- close(unblockServer)
- // Stage 4.
- <-gotRequest
- }()
-
- ct := newClientTester(t)
- ct.tr.StrictMaxConcurrentStreams = true
- ct.client = func() error {
- var wg sync.WaitGroup
- defer func() {
- wg.Wait()
- close(clientDone)
- ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- ct.cc.(*net.TCPConn).Close()
- }
- }()
- for k := 0; k < maxConcurrent+2; k++ {
- wg.Add(1)
- go func(k int) {
- defer wg.Done()
- // Don't send the second request until after receiving SETTINGS from the server
- // to avoid a race where we use the default SettingMaxConcurrentStreams, which
- // is much larger than maxConcurrent. We have to send the first request before
- // waiting because the first request triggers the dial and greet.
- if k > 0 {
- <-greet
- }
- // Block until maxConcurrent requests are sent before sending any more.
- if k >= maxConcurrent {
- <-unblockClient
- }
- body := newStaticCloseChecker("")
- req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
- if k == maxConcurrent {
- // This request will be canceled.
- req.Cancel = cancelClientRequest
- close(cancelClientRequest)
- _, err := ct.tr.RoundTrip(req)
- close(clientRequestCancelled)
- if err == nil {
- errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
- return
- }
- } else {
- resp, err := ct.tr.RoundTrip(req)
- if err != nil {
- errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
- return
- }
- ioutil.ReadAll(resp.Body)
- resp.Body.Close()
- if resp.StatusCode != 204 {
- errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
- return
- }
- }
- if err := body.isClosed(); err != nil {
- errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
- }
- }(k)
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := res.StatusCode, 200; got != want {
+ t.Errorf("StatusCode = %v; want %v", got, want)
+ }
+ if res != nil && res.Body != nil {
+ res.Body.Close()
}
- return nil
}
- ct.server = func() error {
- var wg sync.WaitGroup
- defer wg.Wait()
+ if connCount != 1 {
+ t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount)
+ }
+}
- ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+// tests Transport.StrictMaxConcurrentStreams
+func TestTransportRequestsStallAtServerLimit(t *testing.T) {
+ const maxConcurrent = 2
- // Server write loop.
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- writeResp := make(chan uint32, maxConcurrent+1)
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.StrictMaxConcurrentStreams = true
+ })
+ tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
- wg.Add(1)
- go func() {
- defer wg.Done()
- <-unblockServer
- for id := range writeResp {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: id,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- }()
+ cancelClientRequest := make(chan struct{})
- // Server read loop.
- var nreq int
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it will have reported any errors on its side.
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame:
- case *SettingsFrame:
- // Wait for the client SETTINGS ack until ending the greet.
- close(greet)
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- gotRequest <- struct{}{}
- nreq++
- writeResp <- f.StreamID
- if nreq == maxConcurrent+1 {
- close(writeResp)
- }
- case *DataFrame:
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
+ // Start maxConcurrent+2 requests.
+ // The server does not respond to any of them yet.
+ var rts []*testRoundTrip
+ for k := 0; k < maxConcurrent+2; k++ {
+ req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
+ if k == maxConcurrent {
+ req.Cancel = cancelClientRequest
+ }
+ rt := tc.roundTrip(req)
+ rts = append(rts, rt)
+
+ if k < maxConcurrent {
+ // We are under the stream limit, so the client sends the request.
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: true,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{fmt.Sprintf("/%d", k)},
+ },
+ })
+ } else {
+ // We have reached the stream limit,
+ // so the client cannot send the request.
+ if fr := tc.readFrame(); fr != nil {
+ t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr)
}
}
+
+ if rt.done() {
+ t.Fatalf("rt %v done", k)
+ }
+ }
+
+ // Cancel the maxConcurrent'th request.
+ // The request should fail.
+ close(cancelClientRequest)
+ tc.sync()
+ if err := rts[maxConcurrent].err(); err == nil {
+ t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent)
+ }
+
+ // No requests should be complete, except for the canceled one.
+ for i, rt := range rts {
+ if i != maxConcurrent && rt.done() {
+ t.Fatalf("RoundTrip(%d) is done, but should not be", i)
+ }
}
- ct.run()
+ // Server responds to a request, unblocking the last one.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rts[0].streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ tc.wantHeaders(wantHeader{
+ streamID: rts[maxConcurrent+1].streamID(),
+ endStream: true,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{fmt.Sprintf("/%d", maxConcurrent+1)},
+ },
+ })
+ rts[0].wantStatus(200)
}
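
This test depends on Transport.StrictMaxConcurrentStreams, which turns the server's SETTINGS_MAX_CONCURRENT_STREAMS into a hard, transport-wide cap: requests beyond the limit queue on the existing connection instead of triggering a new dial. A one-line sketch of that configuration (not part of this patch):

```go
package example

import "golang.org/x/net/http2"

// With StrictMaxConcurrentStreams set, a request made while the
// connection is at the server's stream limit produces no frames until
// an earlier stream completes or the request is canceled, exactly the
// behavior the test drives for requests past maxConcurrent.
var tr = &http2.Transport{
	StrictMaxConcurrentStreams: true,
}
```
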
func TestTransportMaxDecoderHeaderTableSize(t *testing.T) {
- ct := newClientTester(t)
var reqSize, resSize uint32 = 8192, 16384
- ct.tr.MaxDecoderHeaderTableSize = reqSize
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- cc, err := ct.tr.NewClientConn(ct.cc)
- if err != nil {
- return err
- }
- _, err = cc.RoundTrip(req)
- if err != nil {
- return err
- }
- if got, want := cc.peerMaxHeaderTableSize, resSize; got != want {
- return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want)
- }
- return nil
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.MaxDecoderHeaderTableSize = reqSize
+ })
+
+ fr := testClientConnReadFrame[*SettingsFrame](tc)
+ if v, ok := fr.Value(SettingHeaderTableSize); !ok {
+ t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting")
+ } else if v != reqSize {
+ t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize)
}
- ct.server = func() error {
- buf := make([]byte, len(ClientPreface))
- _, err := io.ReadFull(ct.sc, buf)
- if err != nil {
- return fmt.Errorf("reading client preface: %v", err)
- }
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- sf, ok := f.(*SettingsFrame)
- if !ok {
- ct.t.Fatalf("wanted client settings frame; got %v", f)
- _ = sf // stash it away?
- }
- var found bool
- err = sf.ForeachSetting(func(s Setting) error {
- if s.ID == SettingHeaderTableSize {
- found = true
- if got, want := s.Val, reqSize; got != want {
- return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want)
- }
- }
- return nil
- })
- if err != nil {
- return err
- }
- if !found {
- return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting")
- }
- if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil {
- ct.t.Fatal(err)
- }
- if err := ct.fr.WriteSettingsAck(); err != nil {
- ct.t.Fatal(err)
- }
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ tc.writeSettings(Setting{SettingHeaderTableSize, resSize})
+ if got, want := tc.cc.peerMaxHeaderTableSize, resSize; got != want {
+ t.Fatalf("peerHeaderTableSize = %d, want %d", got, want)
}
- ct.run()
}
func TestTransportMaxEncoderHeaderTableSize(t *testing.T) {
- ct := newClientTester(t)
var peerAdvertisedMaxHeaderTableSize uint32 = 16384
- ct.tr.MaxEncoderHeaderTableSize = 8192
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- cc, err := ct.tr.NewClientConn(ct.cc)
- if err != nil {
- return err
- }
- _, err = cc.RoundTrip(req)
- if err != nil {
- return err
- }
- if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want {
- return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
- }
- return nil
- }
- ct.server = func() error {
- buf := make([]byte, len(ClientPreface))
- _, err := io.ReadFull(ct.sc, buf)
- if err != nil {
- return fmt.Errorf("reading client preface: %v", err)
- }
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- sf, ok := f.(*SettingsFrame)
- if !ok {
- ct.t.Fatalf("wanted client settings frame; got %v", f)
- _ = sf // stash it away?
- }
- if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil {
- ct.t.Fatal(err)
- }
- if err := ct.fr.WriteSettingsAck(); err != nil {
- ct.t.Fatal(err)
- }
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.MaxEncoderHeaderTableSize = 8192
+ })
+ tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize})
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ if got, want := tc.cc.henc.MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want {
+ t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
}
- ct.run()
}
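
These two tests cover both directions of HPACK dynamic-table sizing: MaxDecoderHeaderTableSize is advertised to the peer as SETTINGS_HEADER_TABLE_SIZE and sizes our decoder's table, while MaxEncoderHeaderTableSize caps how much of the peer's advertised table our encoder will use. A combined sketch (sizes illustrative; not code from this patch):

```go
package example

import "golang.org/x/net/http2"

// MaxDecoderHeaderTableSize is sent in SETTINGS_HEADER_TABLE_SIZE and
// bounds the table our HPACK decoder keeps for the peer's encoder.
// MaxEncoderHeaderTableSize bounds our own encoder's table: even if
// the peer advertises a larger SETTINGS_HEADER_TABLE_SIZE, the encoder
// never grows past this limit.
var tr = &http2.Transport{
	MaxDecoderHeaderTableSize: 8192,
	MaxEncoderHeaderTableSize: 8192,
}
```
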
func TestAuthorityAddr(t *testing.T) {
@@ -4530,40 +3565,24 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
// Issue 18891: make sure Request.Body == NoBody means no DATA frame
// is ever sent, even if empty.
func TestTransportNoBodyMeansNoDATA(t *testing.T) {
- ct := newClientTester(t)
-
- unblockClient := make(chan bool)
-
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
- ct.tr.RoundTrip(req)
- <-unblockClient
- return nil
- }
- ct.server = func() error {
- defer close(unblockClient)
- defer ct.cc.(*net.TCPConn).Close()
- ct.greet()
-
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f := f.(type) {
- default:
- return fmt.Errorf("Got %T; want HeadersFrame", f)
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- case *HeadersFrame:
- if !f.StreamEnded() {
- return fmt.Errorf("got headers frame without END_STREAM")
- }
- return nil
- }
- }
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
+ rt := tc.roundTrip(req)
+
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: true, // END_STREAM should be set when body is http.NoBody
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{"/"},
+ },
+ })
+ if fr := tc.readFrame(); fr != nil {
+ t.Fatalf("unexpected frame after headers: %v", fr)
}
- ct.run()
}
func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
@@ -4642,41 +3661,22 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
// Verify transport doesn't crash when receiving bogus response lacking a :status header.
// Issue 22880.
func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- _, err := ct.tr.RoundTrip(req)
- const substr = "malformed response from server: missing status pseudo header"
- if !strings.Contains(fmt.Sprint(err), substr) {
- return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
-
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false, // we'll send some DATA to try to crash the transport
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(f.StreamID, true, []byte("payload"))
- return nil
- }
- }
- }
- ct.run()
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false, // we'll send some DATA to try to crash the transport
+ BlockFragment: tc.makeHeaderBlockFragment(
+ "content-type", "text/html", // no :status header
+ ),
+ })
+ tc.writeData(rt.streamID(), true, []byte("payload"))
}
func BenchmarkClientRequestHeaders(b *testing.B) {
@@ -5024,95 +4024,42 @@ func (r *errReader) Read(p []byte) (int, error) {
}
func testTransportBodyReadError(t *testing.T, body []byte) {
- if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
- // So far we've only seen this be flaky on Windows and Plan 9,
- // perhaps due to TCP behavior on shutdowns while
- // unread data is in flight. This test should be
- // fixed, but a skip is better than annoying people
- // for now.
- t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS)
- }
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- defer close(clientDone)
-
- checkNoStreams := func() error {
- cp, ok := ct.tr.connPool().(*clientConnPool)
- if !ok {
- return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
- }
- cp.mu.Lock()
- defer cp.mu.Unlock()
- conns, ok := cp.conns["dummy.tld:443"]
- if !ok {
- return fmt.Errorf("missing connection")
- }
- if len(conns) != 1 {
- return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
- }
- if activeStreams(conns[0]) != 0 {
- return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
- }
- return nil
- }
- bodyReadError := errors.New("body read error")
- body := &errReader{body, bodyReadError}
- req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
- if err != nil {
- return err
- }
- _, err = ct.tr.RoundTrip(req)
- if err != bodyReadError {
- return fmt.Errorf("err = %v; want %v", err, bodyReadError)
- }
- if err = checkNoStreams(); err != nil {
- return err
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ bodyReadError := errors.New("body read error")
+ b := tc.newRequestBody()
+ b.Write(body)
+ b.closeWithError(bodyReadError)
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", b)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ var receivedBody []byte
+readFrames:
+ for {
+ switch f := tc.readFrame().(type) {
+ case *DataFrame:
+ receivedBody = append(receivedBody, f.Data()...)
+ case *RSTStreamFrame:
+ break readFrames
+ default:
+ t.Fatalf("unexpected frame: %v", f)
+ case nil:
+ t.Fatalf("transport is idle, want RST_STREAM")
}
- return nil
}
- ct.server = func() error {
- ct.greet()
- var receivedBody []byte
- var resetCount int
- for {
- f, err := ct.fr.ReadFrame()
- t.Logf("server: ReadFrame = %v, %v", f, err)
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- if bytes.Compare(receivedBody, body) != 0 {
- return fmt.Errorf("body: %q; expected %q", receivedBody, body)
- }
- if resetCount != 1 {
- return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
- }
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- case *DataFrame:
- receivedBody = append(receivedBody, f.Data()...)
- case *RSTStreamFrame:
- resetCount++
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
+ if !bytes.Equal(receivedBody, body) {
+ t.Fatalf("body: %q; expected %q", receivedBody, body)
+ }
+
+ if err := rt.err(); err != bodyReadError {
+ t.Fatalf("err = %v; want %v", err, bodyReadError)
+ }
+
+ if got := activeStreams(tc.cc); got != 0 {
+ t.Fatalf("active streams count: %v; want 0", got)
}
- ct.run()
}
func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
@@ -5125,59 +4072,18 @@ func TestTransportBodyEagerEndStream(t *testing.T) {
const reqBody = "some request body"
const resBody = "some response body"
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- body := strings.NewReader(reqBody)
- req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
- if err != nil {
- return err
- }
- _, err = ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
+ tc := newTestClientConn(t)
+ tc.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
+ body := strings.NewReader(reqBody)
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
+ tc.roundTrip(req)
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- case *DataFrame:
- if !f.StreamEnded() {
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
- return fmt.Errorf("data frame without END_STREAM %v", f)
- }
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.Header().StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(f.StreamID, true, []byte(resBody))
- return nil
- case *RSTStreamFrame:
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
+ tc.wantFrameType(FrameHeaders)
+ f := testClientConnReadFrame[*DataFrame](tc)
+ if !f.StreamEnded() {
+ t.Fatalf("data frame without END_STREAM %v", f)
}
- ct.run()
}
type chunkReader struct {
@@ -5826,155 +4732,80 @@ func TestTransportCloseRequestBody(t *testing.T) {
}
}
-// collectClientsConnPool is a ClientConnPool that wraps lower and
-// collects what calls were made on it.
-type collectClientsConnPool struct {
- lower ClientConnPool
-
- mu sync.Mutex
- getErrs int
- got []*ClientConn
-}
-
-func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
- cc, err := p.lower.GetClientConn(req, addr)
- p.mu.Lock()
- defer p.mu.Unlock()
- if err != nil {
- p.getErrs++
- return nil, err
- }
- p.got = append(p.got, cc)
- return cc, nil
-}
-
-func (p *collectClientsConnPool) MarkDead(cc *ClientConn) {
- p.lower.MarkDead(cc)
-}
-
func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
- ct := newClientTester(t)
- pool := &collectClientsConnPool{
- lower: &clientConnPool{t: ct.tr},
- }
- ct.tr.ConnPool = pool
+ // This test verifies that:
+ // - receiving a protocol error on a connection does not interfere with
+ // other requests in flight on that connection;
+ // - the connection is not reused for further requests; and
+ // - the failed request is retried on a new connection.
+ tt := newTestTransport(t)
+
+ // Start two requests. The first is a long request
+ // that will finish after the second. The second one
+ // will result in the protocol error.
+
+ // Request #1: The long request.
+ req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt1 := tt.roundTrip(req1)
+ tc1 := tt.getConn()
+ tc1.wantFrameType(FrameSettings)
+ tc1.wantFrameType(FrameWindowUpdate)
+ tc1.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc1.writeSettings()
+ tc1.wantFrameType(FrameSettings) // settings ACK
+
+ // Request #2(a): The short request.
+ req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt2 := tt.roundTrip(req2)
+ tc1.wantHeaders(wantHeader{
+ streamID: 3,
+ endStream: true,
+ })
- gotProtoError := make(chan bool, 1)
- ct.tr.CountError = func(errType string) {
- if errType == "recv_rststream_PROTOCOL_ERROR" {
- select {
- case gotProtoError <- true:
- default:
- }
- }
+ // Request #2(a) fails with ErrCodeProtocol.
+ tc1.writeRSTStream(3, ErrCodeProtocol)
+ if rt1.done() {
+ t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress")
}
- ct.client = func() error {
- // Start two requests. The first is a long request
- // that will finish after the second. The second one
- // will result in the protocol error. We check that
- // after the first one closes, the connection then
- // shuts down.
-
- // The long, outer request.
- req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil)
- res1, err := ct.tr.RoundTrip(req1)
- if err != nil {
- return err
- }
- if got, want := res1.Header.Get("Is-Long"), "1"; got != want {
- return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want)
- }
-
- req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil)
- res, err := ct.tr.RoundTrip(req)
- const want = "only one dial allowed in test mode"
- if got := fmt.Sprint(err); got != want {
- t.Errorf("didn't dial again: got %#q; want %#q", got, want)
- }
- if res != nil {
- res.Body.Close()
- }
- select {
- case <-gotProtoError:
- default:
- t.Errorf("didn't get stream protocol error")
- }
-
- if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 {
- t.Errorf("unexpected body read %v, %v", n, err)
- }
-
- pool.mu.Lock()
- defer pool.mu.Unlock()
- if pool.getErrs != 1 {
- t.Errorf("pool get errors = %v; want 1", pool.getErrs)
- }
- if len(pool.got) == 2 {
- if pool.got[0] != pool.got[1] {
- t.Errorf("requests went on different connections")
- }
- cc := pool.got[0]
- cc.mu.Lock()
- if !cc.doNotReuse {
- t.Error("ClientConn not marked doNotReuse")
- }
- cc.mu.Unlock()
-
- select {
- case <-cc.readerDone:
- case <-time.After(5 * time.Second):
- t.Errorf("timeout waiting for reader to be done")
- }
- } else {
- t.Errorf("pool get success = %v; want 2", len(pool.got))
- }
- return nil
+ if rt2.done() {
+ t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is done; want still in progress")
}
- ct.server = func() error {
- ct.greet()
- var sentErr bool
- var numHeaders int
- var firstStreamID uint32
-
- var hbuf bytes.Buffer
- enc := hpack.NewEncoder(&hbuf)
- for {
- f, err := ct.fr.ReadFrame()
- if err == io.EOF {
- // Client hung up on us, as it should at the end.
- return nil
- }
- if err != nil {
- return nil
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- numHeaders++
- if numHeaders == 1 {
- firstStreamID = f.StreamID
- hbuf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: hbuf.Bytes(),
- })
- continue
- }
- if !sentErr {
- sentErr = true
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol)
- ct.fr.WriteData(firstStreamID, true, nil)
- continue
- }
- }
- }
- }
- ct.run()
+ // Request #2(b): The short request is retried on a new connection.
+ tc2 := tt.getConn()
+ tc2.wantFrameType(FrameSettings)
+ tc2.wantFrameType(FrameWindowUpdate)
+ tc2.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc2.writeSettings()
+ tc2.wantFrameType(FrameSettings) // settings ACK
+
+ // Request #2(b) succeeds.
+ tc2.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc2.makeHeaderBlockFragment(
+ ":status", "201",
+ ),
+ })
+ rt2.wantStatus(201)
+
+ // Request #1 succeeds.
+ tc1.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc1.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt1.wantStatus(200)
}
func TestClientConnReservations(t *testing.T) {
@@ -5987,7 +4818,7 @@ func TestClientConnReservations(t *testing.T) {
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- cc, err := tr.newClientConn(st.cc, false)
+ cc, err := tr.newClientConn(st.cc, false, nil)
if err != nil {
t.Fatal(err)
}
@@ -6026,39 +4857,27 @@ func TestClientConnReservations(t *testing.T) {
}
func TestTransportTimeoutServerHangs(t *testing.T) {
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- defer close(clientDone)
+ tc := newTestClientConn(t)
+ tc.greet()
- req, err := http.NewRequest("PUT", "https://dummy.tld/", nil)
- if err != nil {
- return err
- }
+ ctx, cancel := context.WithCancel(context.Background())
+ req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
- ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
- defer cancel()
- req = req.WithContext(ctx)
- req.Header.Add("Big", strings.Repeat("a", 1<<20))
- _, err = ct.tr.RoundTrip(req)
- if err == nil {
- return errors.New("error should not be nil")
- }
- if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
- return fmt.Errorf("error should be a net error timeout: %v", err)
- }
- return nil
+ tc.wantFrameType(FrameHeaders)
+ tc.advance(5 * time.Second)
+ if f := tc.readFrame(); f != nil {
+ t.Fatalf("unexpected frame: %v", f)
}
- ct.server = func() error {
- ct.greet()
- select {
- case <-time.After(5 * time.Second):
- case <-clientDone:
- }
- return nil
+ if rt.done() {
+ t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned")
+ }
+
+ cancel()
+ tc.sync()
+ if rt.err() != context.Canceled {
+ t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err())
}
- ct.run()
}
func TestTransportContentLengthWithoutBody(t *testing.T) {
@@ -6251,20 +5070,6 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) {
testTransportClosesConnAfterGoAway(t, 1)
}
-type closeOnceConn struct {
- net.Conn
- closed uint32
-}
-
-var errClosed = errors.New("Close of closed connection")
-
-func (c *closeOnceConn) Close() error {
- if atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
- return c.Conn.Close()
- }
- return errClosed
-}
-
// testTransportClosesConnAfterGoAway verifies that the transport
// closes a connection after reading a GOAWAY from it.
//
@@ -6272,53 +5077,35 @@ func (c *closeOnceConn) Close() error {
// When 0, the transport (unsuccessfully) retries the request (stream 1);
// when 1, the transport reads the response after receiving the GOAWAY.
func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) {
- ct := newClientTester(t)
- ct.cc = &closeOnceConn{Conn: ct.cc}
-
- var wg sync.WaitGroup
- wg.Add(1)
- ct.client = func() error {
- defer wg.Done()
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err == nil {
- res.Body.Close()
- }
- if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
- t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
- }
- if err = ct.cc.Close(); err != errClosed {
- return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err)
- }
- return nil
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeGoAway(lastStream, ErrCodeNo, nil)
+
+ if lastStream > 0 {
+ // Send a valid response to first request.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
}
- ct.server = func() error {
- defer wg.Wait()
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- return fmt.Errorf("server failed reading HEADERS: %v", err)
- }
- if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil {
- return fmt.Errorf("server failed writing GOAWAY: %v", err)
- }
- if lastStream > 0 {
- // Send a valid response to first request.
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- return nil
+ tc.closeWrite(io.EOF)
+ err := rt.err()
+ if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
+ t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
+ }
+ if !tc.netConnClosed {
+ t.Errorf("ClientConn did not close its net.Conn, expected it to")
}
-
- ct.run()
}
type slowCloser struct {
@@ -6520,3 +5307,32 @@ func TestDialRaceResumesDial(t *testing.T) {
case <-successCh:
}
}
+
+func TestTransportDataAfter1xxHeader(t *testing.T) {
+ // Discard logger output to avoid spamming stderr.
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+
+ // https://go.dev/issue/65927 - server sends a 1xx response, followed by a DATA frame.
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "100",
+ ),
+ })
+ tc.writeData(rt.streamID(), true, []byte{0})
+ err := rt.err()
+ if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
+ t.Errorf("RoundTrip error: %v; want ErrCodeProtocol", err)
+ }
+ tc.wantFrameType(FrameRSTStream)
+}
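
The new test encodes the rule that a 1xx interim response carries headers only, so a DATA frame arriving after it (with no final response in between) is a stream-level PROTOCOL_ERROR. For context, a sketch (the function name is hypothetical; httptrace is the standard-library API) of how a caller can observe interim responses at all:

```go
package example

import (
	"net/http"
	"net/http/httptrace"
	"net/textproto"
)

// attach1xxTrace returns a copy of req whose context reports interim
// (1xx) responses. A conforming server follows a 1xx with the final
// HEADERS; sending DATA instead is the violation the test provokes.
func attach1xxTrace(req *http.Request, saw func(code int)) *http.Request {
	trace := &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			saw(code)
			return nil // a non-nil error here would abort the request
		},
	}
	return req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
}
```
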
diff --git a/internal/quic/cmd/interop/Dockerfile b/internal/quic/cmd/interop/Dockerfile
index 4b52e5356..b60999a86 100644
--- a/internal/quic/cmd/interop/Dockerfile
+++ b/internal/quic/cmd/interop/Dockerfile
@@ -9,7 +9,7 @@ ENV GOVERSION=1.21.1
RUN platform=$(echo ${TARGETPLATFORM} | tr '/' '-') && \
filename="go${GOVERSION}.${platform}.tar.gz" && \
- wget https://dl.google.com/go/${filename} && \
+ wget --no-verbose https://dl.google.com/go/${filename} && \
tar xfz ${filename} && \
rm ${filename}
diff --git a/internal/quic/cmd/interop/main.go b/internal/quic/cmd/interop/main.go
index cc5292e9e..5b652a2b1 100644
--- a/internal/quic/cmd/interop/main.go
+++ b/internal/quic/cmd/interop/main.go
@@ -18,21 +18,24 @@ import (
"fmt"
"io"
"log"
+ "log/slog"
"net"
"net/url"
"os"
"path/filepath"
"sync"
- "golang.org/x/net/internal/quic"
+ "golang.org/x/net/quic"
+ "golang.org/x/net/quic/qlog"
)
var (
- listen = flag.String("listen", "", "listen address")
- cert = flag.String("cert", "", "certificate")
- pkey = flag.String("key", "", "private key")
- root = flag.String("root", "", "serve files from this root")
- output = flag.String("output", "", "directory to write files to")
+ listen = flag.String("listen", "", "listen address")
+ cert = flag.String("cert", "", "certificate")
+ pkey = flag.String("key", "", "private key")
+ root = flag.String("root", "", "serve files from this root")
+ output = flag.String("output", "", "directory to write files to")
+ qlogdir = flag.String("qlog", "", "directory to write qlog output to")
)
func main() {
@@ -48,6 +51,10 @@ func main() {
},
MaxBidiRemoteStreams: -1,
MaxUniRemoteStreams: -1,
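+		// Write frame-level qlog events as JSON into the -qlog directory;
+		// run_endpoint.sh points this at the interop runner's QLOGDIR.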
+ QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
+ Level: quic.QLogLevelFrame,
+ Dir: *qlogdir,
+ })),
}
if *cert != "" {
c, err := tls.LoadX509KeyPair(*cert, *pkey)
@@ -141,7 +148,7 @@ func basicTest(ctx context.Context, config *quic.Config, urls []string) {
g.Add(1)
go func() {
defer g.Done()
- fetchFrom(ctx, l, addr, u)
+ fetchFrom(ctx, config, l, addr, u)
}()
}
@@ -150,7 +157,7 @@ func basicTest(ctx context.Context, config *quic.Config, urls []string) {
}
}
-func serve(ctx context.Context, l *quic.Listener) error {
+func serve(ctx context.Context, l *quic.Endpoint) error {
for {
c, err := l.Accept(ctx)
if err != nil {
@@ -214,8 +221,8 @@ func parseURL(s string) (u *url.URL, authority string, err error) {
return u, authority, nil
}
-func fetchFrom(ctx context.Context, l *quic.Listener, addr string, urls []*url.URL) {
- conn, err := l.Dial(ctx, "udp", addr)
+func fetchFrom(ctx context.Context, config *quic.Config, l *quic.Endpoint, addr string, urls []*url.URL) {
+ conn, err := l.Dial(ctx, "udp", addr, config)
if err != nil {
log.Printf("%v: %v", addr, err)
return
diff --git a/internal/quic/cmd/interop/run_endpoint.sh b/internal/quic/cmd/interop/run_endpoint.sh
index d72335d8e..442039bc0 100644
--- a/internal/quic/cmd/interop/run_endpoint.sh
+++ b/internal/quic/cmd/interop/run_endpoint.sh
@@ -11,7 +11,7 @@
if [ "$ROLE" == "client" ]; then
# Wait for the simulator to start up.
/wait-for-it.sh sim:57832 -s -t 30
- ./interop -output=/downloads $CLIENT_PARAMS $REQUESTS
+ ./interop -output=/downloads -qlog=$QLOGDIR $CLIENT_PARAMS $REQUESTS
elif [ "$ROLE" == "server" ]; then
- ./interop -cert=/certs/cert.pem -key=/certs/priv.key -listen=:443 -root=/www "$@" $SERVER_PARAMS
+ ./interop -cert=/certs/cert.pem -key=/certs/priv.key -qlog=$QLOGDIR -listen=:443 -root=/www "$@" $SERVER_PARAMS
fi
diff --git a/internal/quic/conn_close.go b/internal/quic/conn_close.go
deleted file mode 100644
index a9ef0db5e..000000000
--- a/internal/quic/conn_close.go
+++ /dev/null
@@ -1,258 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.21
-
-package quic
-
-import (
- "context"
- "errors"
- "time"
-)
-
-// lifetimeState tracks the state of a connection.
-//
-// This is fairly coupled to the rest of a Conn, but putting it in a struct of its own helps
-// reason about operations that cause state transitions.
-type lifetimeState struct {
- readyc chan struct{} // closed when TLS handshake completes
- drainingc chan struct{} // closed when entering the draining state
-
- // Possible states for the connection:
- //
- // Alive: localErr and finalErr are both nil.
- //
- // Closing: localErr is non-nil and finalErr is nil.
- // We have sent a CONNECTION_CLOSE to the peer or are about to
- // (if connCloseSentTime is zero) and are waiting for the peer to respond.
- // drainEndTime is set to the time the closing state ends.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.1
- //
- // Draining: finalErr is non-nil.
- // If localErr is nil, we're waiting for the user to provide us with a final status
- // to send to the peer.
- // Otherwise, we've either sent a CONNECTION_CLOSE to the peer or are about to
- // (if connCloseSentTime is zero).
- // drainEndTime is set to the time the draining state ends.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2
- localErr error // error sent to the peer
- finalErr error // error sent by the peer, or transport error; always set before draining
-
- connCloseSentTime time.Time // send time of last CONNECTION_CLOSE frame
- connCloseDelay time.Duration // delay until next CONNECTION_CLOSE frame sent
- drainEndTime time.Time // time the connection exits the draining state
-}
-
-func (c *Conn) lifetimeInit() {
- c.lifetime.readyc = make(chan struct{})
- c.lifetime.drainingc = make(chan struct{})
-}
-
-var errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE")
-
-// advance is called when time passes.
-func (c *Conn) lifetimeAdvance(now time.Time) (done bool) {
- if c.lifetime.drainEndTime.IsZero() || c.lifetime.drainEndTime.After(now) {
- return false
- }
- // The connection drain period has ended, and we can shut down.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-7
- c.lifetime.drainEndTime = time.Time{}
- if c.lifetime.finalErr == nil {
- // The peer never responded to our CONNECTION_CLOSE.
- c.enterDraining(now, errNoPeerResponse)
- }
- return true
-}
-
-// confirmHandshake is called when the TLS handshake completes.
-func (c *Conn) handshakeDone() {
- close(c.lifetime.readyc)
-}
-
-// isDraining reports whether the conn is in the draining state.
-//
-// The draining state is entered once an endpoint receives a CONNECTION_CLOSE frame.
-// The endpoint will no longer send any packets, but we retain knowledge of the connection
-// until the end of the drain period to ensure we discard packets for the connection
-// rather than treating them as starting a new connection.
-//
-// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2
-func (c *Conn) isDraining() bool {
- return c.lifetime.finalErr != nil
-}
-
-// isClosingOrDraining reports whether the conn is in the closing or draining states.
-func (c *Conn) isClosingOrDraining() bool {
- return c.lifetime.localErr != nil || c.lifetime.finalErr != nil
-}
-
-// sendOK reports whether the conn can send frames at this time.
-func (c *Conn) sendOK(now time.Time) bool {
- if !c.isClosingOrDraining() {
- return true
- }
- // We are closing or draining.
- if c.lifetime.localErr == nil {
- // We're waiting for the user to close the connection, providing us with
- // a final status to send to the peer.
- return false
- }
- // Past this point, returning true will result in the conn sending a CONNECTION_CLOSE
- // due to localErr being set.
- if c.lifetime.drainEndTime.IsZero() {
- // The closing and draining states should last for at least three times
- // the current PTO interval. We currently use exactly that minimum.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-5
- //
- // The drain period begins when we send or receive a CONNECTION_CLOSE,
- // whichever comes first.
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2-3
- c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod())
- }
- if c.lifetime.connCloseSentTime.IsZero() {
- // We haven't sent a CONNECTION_CLOSE yet. Do so.
- // Either we're initiating an immediate close
- // (and will enter the closing state as soon as we send CONNECTION_CLOSE),
- // or we've read a CONNECTION_CLOSE from our peer
- // (and may send one CONNECTION_CLOSE before entering the draining state).
- //
- // Set the initial delay before we will send another CONNECTION_CLOSE.
- //
- // RFC 9000 states that we should rate limit CONNECTION_CLOSE frames,
- // but leaves the implementation of the limit up to us. Here, we start
- // with the same delay as the PTO timer (RFC 9002, Section 6.2.1),
- // not including max_ack_delay, and double it on every CONNECTION_CLOSE sent.
- c.lifetime.connCloseDelay = c.loss.rtt.smoothedRTT + max(4*c.loss.rtt.rttvar, timerGranularity)
- c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod())
- return true
- }
- if c.isDraining() {
- // We are in the draining state, and will send no more packets.
- return false
- }
- maxRecvTime := c.acks[initialSpace].maxRecvTime
- if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) {
- maxRecvTime = t
- }
- if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) {
- maxRecvTime = t
- }
- if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) {
- // After sending CONNECTION_CLOSE, ignore packets from the peer for
- // a delay. On the next packet received after the delay, send another
- // CONNECTION_CLOSE.
- return false
- }
- c.lifetime.connCloseSentTime = now
- c.lifetime.connCloseDelay *= 2
- return true
-}
-
-// enterDraining enters the draining state.
-func (c *Conn) enterDraining(now time.Time, err error) {
- if c.isDraining() {
- return
- }
- if err == errStatelessReset {
- // If we've received a stateless reset, then we must not send a CONNECTION_CLOSE.
- // Setting connCloseSentTime here prevents us from doing so.
- c.lifetime.finalErr = errStatelessReset
- c.lifetime.localErr = errStatelessReset
- c.lifetime.connCloseSentTime = now
- } else if e, ok := c.lifetime.localErr.(localTransportError); ok && e.code != errNo {
- // If we've terminated the connection due to a peer protocol violation,
- // record the final error on the connection as our reason for termination.
- c.lifetime.finalErr = c.lifetime.localErr
- } else {
- c.lifetime.finalErr = err
- }
- close(c.lifetime.drainingc)
- c.streams.queue.close(c.lifetime.finalErr)
-}
-
-func (c *Conn) waitReady(ctx context.Context) error {
- select {
- case <-c.lifetime.readyc:
- return nil
- case <-c.lifetime.drainingc:
- return c.lifetime.finalErr
- default:
- }
- select {
- case <-c.lifetime.readyc:
- return nil
- case <-c.lifetime.drainingc:
- return c.lifetime.finalErr
- case <-ctx.Done():
- return ctx.Err()
- }
-}
-
-// Close closes the connection.
-//
-// Close is equivalent to:
-//
-// conn.Abort(nil)
-// err := conn.Wait(context.Background())
-func (c *Conn) Close() error {
- c.Abort(nil)
- <-c.lifetime.drainingc
- return c.lifetime.finalErr
-}
-
-// Wait waits for the peer to close the connection.
-//
-// If the connection is closed locally and the peer does not close its end of the connection,
-// Wait will return with a non-nil error after the drain period expires.
-//
-// If the peer closes the connection with a NO_ERROR transport error, Wait returns nil.
-// If the peer closes the connection with an application error, Wait returns an ApplicationError
-// containing the peer's error code and reason.
-// If the peer closes the connection with any other status, Wait returns a non-nil error.
-func (c *Conn) Wait(ctx context.Context) error {
- if err := c.waitOnDone(ctx, c.lifetime.drainingc); err != nil {
- return err
- }
- return c.lifetime.finalErr
-}
-
-// Abort closes the connection and returns immediately.
-//
-// If err is nil, Abort sends a transport error of NO_ERROR to the peer.
-// If err is an ApplicationError, Abort sends its error code and text.
-// Otherwise, Abort sends a transport error of APPLICATION_ERROR with the error's text.
-func (c *Conn) Abort(err error) {
- if err == nil {
- err = localTransportError{code: errNo}
- }
- c.sendMsg(func(now time.Time, c *Conn) {
- c.abort(now, err)
- })
-}
-
-// abort terminates a connection with an error.
-func (c *Conn) abort(now time.Time, err error) {
- if c.lifetime.localErr != nil {
- return // already closing
- }
- c.lifetime.localErr = err
-}
-
-// abortImmediately terminates a connection.
-// The connection does not send a CONNECTION_CLOSE, and skips the draining period.
-func (c *Conn) abortImmediately(now time.Time, err error) {
- c.abort(now, err)
- c.enterDraining(now, err)
- c.exited = true
-}
-
-// exit fully terminates a connection immediately.
-func (c *Conn) exit() {
- c.sendMsg(func(now time.Time, c *Conn) {
- c.enterDraining(now, errors.New("connection closed"))
- c.exited = true
- })
-}
diff --git a/internal/quic/doc.go b/internal/quic/doc.go
deleted file mode 100644
index 2fe17fe22..000000000
--- a/internal/quic/doc.go
+++ /dev/null
@@ -1,9 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package quic is an experimental, incomplete implementation of the QUIC protocol.
-// This package is a work in progress, and is not ready for use at this time.
-//
-// This package implements (or will implement) RFC 9000, RFC 9001, and RFC 9002.
-package quic
diff --git a/internal/quic/listener_test.go b/internal/quic/listener_test.go
deleted file mode 100644
index 037fb21b4..000000000
--- a/internal/quic/listener_test.go
+++ /dev/null
@@ -1,319 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.21
-
-package quic
-
-import (
- "bytes"
- "context"
- "crypto/tls"
- "io"
- "net"
- "net/netip"
- "reflect"
- "testing"
- "time"
-)
-
-func TestConnect(t *testing.T) {
- newLocalConnPair(t, &Config{}, &Config{})
-}
-
-func TestStreamTransfer(t *testing.T) {
- ctx := context.Background()
- cli, srv := newLocalConnPair(t, &Config{}, &Config{})
- data := makeTestData(1 << 20)
-
- srvdone := make(chan struct{})
- go func() {
- defer close(srvdone)
- s, err := srv.AcceptStream(ctx)
- if err != nil {
- t.Errorf("AcceptStream: %v", err)
- return
- }
- b, err := io.ReadAll(s)
- if err != nil {
- t.Errorf("io.ReadAll(s): %v", err)
- return
- }
- if !bytes.Equal(b, data) {
- t.Errorf("read data mismatch (got %v bytes, want %v", len(b), len(data))
- }
- if err := s.Close(); err != nil {
- t.Errorf("s.Close() = %v", err)
- }
- }()
-
- s, err := cli.NewStream(ctx)
- if err != nil {
- t.Fatalf("NewStream: %v", err)
- }
- n, err := io.Copy(s, bytes.NewBuffer(data))
- if n != int64(len(data)) || err != nil {
- t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
- }
- if err := s.Close(); err != nil {
- t.Fatalf("s.Close() = %v", err)
- }
-}
-
-func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
- t.Helper()
- ctx := context.Background()
- l1 := newLocalListener(t, serverSide, conf1)
- l2 := newLocalListener(t, clientSide, conf2)
- c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String())
- if err != nil {
- t.Fatal(err)
- }
- c1, err := l1.Accept(ctx)
- if err != nil {
- t.Fatal(err)
- }
- return c2, c1
-}
-
-func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener {
- t.Helper()
- if conf.TLSConfig == nil {
- newConf := *conf
- conf = &newConf
- conf.TLSConfig = newTestTLSConfig(side)
- }
- l, err := Listen("udp", "127.0.0.1:0", conf)
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(func() {
- l.Close(context.Background())
- })
- return l
-}
-
-type testListener struct {
- t *testing.T
- l *Listener
- now time.Time
- recvc chan *datagram
- idlec chan struct{}
- conns map[*Conn]*testConn
- acceptQueue []*testConn
- configTransportParams []func(*transportParameters)
- configTestConn []func(*testConn)
- sentDatagrams [][]byte
- peerTLSConn *tls.QUICConn
- lastInitialDstConnID []byte // for parsing Retry packets
-}
-
-func newTestListener(t *testing.T, config *Config) *testListener {
- tl := &testListener{
- t: t,
- now: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
- recvc: make(chan *datagram),
- idlec: make(chan struct{}),
- conns: make(map[*Conn]*testConn),
- }
- var err error
- tl.l, err = newListener((*testListenerUDPConn)(tl), config, (*testListenerHooks)(tl))
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(tl.cleanup)
- return tl
-}
-
-func (tl *testListener) cleanup() {
- tl.l.Close(canceledContext())
-}
-
-func (tl *testListener) wait() {
- select {
- case tl.idlec <- struct{}{}:
- case <-tl.l.closec:
- }
- for _, tc := range tl.conns {
- tc.wait()
- }
-}
-
-// accept returns a server connection from the listener.
-// Unlike Listener.Accept, connections are available as soon as they are created.
-func (tl *testListener) accept() *testConn {
- if len(tl.acceptQueue) == 0 {
- tl.t.Fatalf("accept: expected available conn, but found none")
- }
- tc := tl.acceptQueue[0]
- tl.acceptQueue = tl.acceptQueue[1:]
- return tc
-}
-
-func (tl *testListener) write(d *datagram) {
- tl.recvc <- d
- tl.wait()
-}
-
-var testClientAddr = netip.MustParseAddrPort("10.0.0.1:8000")
-
-func (tl *testListener) writeDatagram(d *testDatagram) {
- tl.t.Helper()
- logDatagram(tl.t, "<- listener under test receives", d)
- var buf []byte
- for _, p := range d.packets {
- tc := tl.connForDestination(p.dstConnID)
- if p.ptype != packetTypeRetry && tc != nil {
- space := spaceForPacketType(p.ptype)
- if p.num >= tc.peerNextPacketNum[space] {
- tc.peerNextPacketNum[space] = p.num + 1
- }
- }
- if p.ptype == packetTypeInitial {
- tl.lastInitialDstConnID = p.dstConnID
- }
- pad := 0
- if p.ptype == packetType1RTT {
- pad = d.paddedSize - len(buf)
- }
- buf = append(buf, encodeTestPacket(tl.t, tc, p, pad)...)
- }
- for len(buf) < d.paddedSize {
- buf = append(buf, 0)
- }
- addr := d.addr
- if !addr.IsValid() {
- addr = testClientAddr
- }
- tl.write(&datagram{
- b: buf,
- addr: addr,
- })
-}
-
-func (tl *testListener) connForDestination(dstConnID []byte) *testConn {
- for _, tc := range tl.conns {
- for _, loc := range tc.conn.connIDState.local {
- if bytes.Equal(loc.cid, dstConnID) {
- return tc
- }
- }
- }
- return nil
-}
-
-func (tl *testListener) connForSource(srcConnID []byte) *testConn {
- for _, tc := range tl.conns {
- for _, loc := range tc.conn.connIDState.remote {
- if bytes.Equal(loc.cid, srcConnID) {
- return tc
- }
- }
- }
- return nil
-}
-
-func (tl *testListener) read() []byte {
- tl.t.Helper()
- tl.wait()
- if len(tl.sentDatagrams) == 0 {
- return nil
- }
- d := tl.sentDatagrams[0]
- tl.sentDatagrams = tl.sentDatagrams[1:]
- return d
-}
-
-func (tl *testListener) readDatagram() *testDatagram {
- tl.t.Helper()
- buf := tl.read()
- if buf == nil {
- return nil
- }
- p, _ := parseGenericLongHeaderPacket(buf)
- tc := tl.connForSource(p.dstConnID)
- d := parseTestDatagram(tl.t, tl, tc, buf)
- logDatagram(tl.t, "-> listener under test sends", d)
- return d
-}
-
-// wantDatagram indicates that we expect the Listener to send a datagram.
-func (tl *testListener) wantDatagram(expectation string, want *testDatagram) {
- tl.t.Helper()
- got := tl.readDatagram()
- if !reflect.DeepEqual(got, want) {
- tl.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
- }
-}
-
-// wantIdle indicates that we expect the Listener to not send any more datagrams.
-func (tl *testListener) wantIdle(expectation string) {
- if got := tl.readDatagram(); got != nil {
- tl.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got)
- }
-}
-
-// advance causes time to pass.
-func (tl *testListener) advance(d time.Duration) {
- tl.t.Helper()
- tl.advanceTo(tl.now.Add(d))
-}
-
-// advanceTo sets the current time.
-func (tl *testListener) advanceTo(now time.Time) {
- tl.t.Helper()
- if tl.now.After(now) {
- tl.t.Fatalf("time moved backwards: %v -> %v", tl.now, now)
- }
- tl.now = now
- for _, tc := range tl.conns {
- if !tc.timer.After(tl.now) {
- tc.conn.sendMsg(timerEvent{})
- tc.wait()
- }
- }
-}
-
-// testListenerHooks implements listenerTestHooks.
-type testListenerHooks testListener
-
-func (tl *testListenerHooks) timeNow() time.Time {
- return tl.now
-}
-
-func (tl *testListenerHooks) newConn(c *Conn) {
- tc := newTestConnForConn(tl.t, (*testListener)(tl), c)
- tl.conns[c] = tc
-}
-
-// testListenerUDPConn implements UDPConn.
-type testListenerUDPConn testListener
-
-func (tl *testListenerUDPConn) Close() error {
- close(tl.recvc)
- return nil
-}
-
-func (tl *testListenerUDPConn) LocalAddr() net.Addr {
- return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443"))
-}
-
-func (tl *testListenerUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) {
- for {
- select {
- case d, ok := <-tl.recvc:
- if !ok {
- return 0, 0, 0, netip.AddrPort{}, io.EOF
- }
- n = copy(b, d.b)
- return n, 0, 0, d.addr, nil
- case <-tl.idlec:
- }
- }
-}
-
-func (tl *testListenerUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
- tl.sentDatagrams = append(tl.sentDatagrams, append([]byte(nil), b...))
- return len(b), nil
-}
diff --git a/internal/quic/ack_delay.go b/quic/ack_delay.go
similarity index 100%
rename from internal/quic/ack_delay.go
rename to quic/ack_delay.go
diff --git a/internal/quic/ack_delay_test.go b/quic/ack_delay_test.go
similarity index 100%
rename from internal/quic/ack_delay_test.go
rename to quic/ack_delay_test.go
diff --git a/internal/quic/acks.go b/quic/acks.go
similarity index 91%
rename from internal/quic/acks.go
rename to quic/acks.go
index ba860efb2..039b7b46e 100644
--- a/internal/quic/acks.go
+++ b/quic/acks.go
@@ -130,12 +130,19 @@ func (acks *ackState) mustAckImmediately(space numberSpace, num packetNumber) bo
// there are no gaps. If it does not, there must be a gap.
return true
}
- if acks.unackedAckEliciting >= 2 {
- // "[...] after receiving at least two ack-eliciting packets."
- // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2
- return true
+ // "[...] SHOULD send an ACK frame after receiving at least two ack-eliciting packets."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2
+ //
+ // This ack frequency takes a substantial toll on performance, however.
+ // Follow the behavior of Google QUICHE:
+ // Ack every other packet for the first 100 packets, and then ack every 10th packet.
+ // This keeps ack frequency high during the beginning of slow start when CWND is
+ // increasing rapidly.
+ packetsBeforeAck := 2
+ if acks.seen.max() > 100 {
+ packetsBeforeAck = 10
}
- return false
+ return acks.unackedAckEliciting >= packetsBeforeAck
}
// shouldSendAck reports whether the connection should send an ACK frame at this time,
diff --git a/internal/quic/acks_test.go b/quic/acks_test.go
similarity index 94%
rename from internal/quic/acks_test.go
rename to quic/acks_test.go
index 4f1032910..d10f917ad 100644
--- a/internal/quic/acks_test.go
+++ b/quic/acks_test.go
@@ -7,6 +7,7 @@
package quic
import (
+ "slices"
"testing"
"time"
)
@@ -198,7 +199,7 @@ func TestAcksSent(t *testing.T) {
if len(gotNums) == 0 {
wantDelay = 0
}
- if !slicesEqual(gotNums, test.wantAcks) || gotDelay != wantDelay {
+ if !slices.Equal(gotNums, test.wantAcks) || gotDelay != wantDelay {
t.Errorf("acks.acksToSend(T+%v) = %v, %v; want %v, %v", delay, gotNums, gotDelay, test.wantAcks, wantDelay)
}
}
@@ -206,20 +207,6 @@ func TestAcksSent(t *testing.T) {
}
}
-// slicesEqual reports whether two slices are equal.
-// Replace this with slices.Equal once the module go.mod is go1.17 or newer.
-func slicesEqual[E comparable](s1, s2 []E) bool {
- if len(s1) != len(s2) {
- return false
- }
- for i := range s1 {
- if s1[i] != s2[i] {
- return false
- }
- }
- return true
-}
-
func TestAcksDiscardAfterAck(t *testing.T) {
acks := ackState{}
now := time.Now()
diff --git a/internal/quic/atomic_bits.go b/quic/atomic_bits.go
similarity index 100%
rename from internal/quic/atomic_bits.go
rename to quic/atomic_bits.go
diff --git a/quic/bench_test.go b/quic/bench_test.go
new file mode 100644
index 000000000..636b71327
--- /dev/null
+++ b/quic/bench_test.go
@@ -0,0 +1,170 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "math"
+ "sync"
+ "testing"
+)
+
+// BenchmarkThroughput is based on the crypto/tls benchmark of the same name.
+func BenchmarkThroughput(b *testing.B) {
+ for size := 1; size <= 64; size <<= 1 {
+ name := fmt.Sprintf("%dMiB", size)
+ b.Run(name, func(b *testing.B) {
+ throughput(b, int64(size<<20))
+ })
+ }
+}
+
+func throughput(b *testing.B, totalBytes int64) {
+ // Same buffer size as crypto/tls's BenchmarkThroughput, for consistency.
+ const bufsize = 32 << 10
+
+ cli, srv := newLocalConnPair(b, &Config{}, &Config{})
+
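+	// Server side: echo each accepted stream back to the client, b.N times.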
+ go func() {
+ buf := make([]byte, bufsize)
+ for i := 0; i < b.N; i++ {
+ sconn, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ panic(fmt.Errorf("AcceptStream: %v", err))
+ }
+ if _, err := io.CopyBuffer(sconn, sconn, buf); err != nil {
+ panic(fmt.Errorf("CopyBuffer: %v", err))
+ }
+ sconn.Close()
+ }
+ }()
+
+ b.SetBytes(totalBytes)
+ buf := make([]byte, bufsize)
+ chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
+ for i := 0; i < b.N; i++ {
+ cconn, err := cli.NewStream(context.Background())
+ if err != nil {
+ b.Fatalf("NewStream: %v", err)
+ }
+ closec := make(chan struct{})
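+		// Drain the echoed data concurrently so that flow control does
+		// not stall the writes below.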
+ go func() {
+ defer close(closec)
+ buf := make([]byte, bufsize)
+ if _, err := io.CopyBuffer(io.Discard, cconn, buf); err != nil {
+ panic(fmt.Errorf("Discard: %v", err))
+ }
+ }()
+ for j := 0; j < chunks; j++ {
+ _, err := cconn.Write(buf)
+ if err != nil {
+ b.Fatalf("Write: %v", err)
+ }
+ }
+ cconn.CloseWrite()
+ <-closec
+ cconn.Close()
+ }
+}
+
+func BenchmarkReadByte(b *testing.B) {
+ cli, srv := newLocalConnPair(b, &Config{}, &Config{})
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ wg.Add(1)
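+	// Server: keep the stream filled with data so that every ReadByte in
+	// the benchmark loop finds a buffered byte.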
+ go func() {
+ defer wg.Done()
+ buf := make([]byte, 1<<20)
+ sconn, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ panic(fmt.Errorf("AcceptStream: %v", err))
+ }
+ for {
+ if _, err := sconn.Write(buf); err != nil {
+ break
+ }
+ sconn.Flush()
+ }
+ }()
+
+ b.SetBytes(1)
+ cconn, err := cli.NewStream(context.Background())
+ if err != nil {
+ b.Fatalf("NewStream: %v", err)
+ }
+ cconn.Flush()
+ for i := 0; i < b.N; i++ {
+ _, err := cconn.ReadByte()
+ if err != nil {
+ b.Fatalf("ReadByte: %v", err)
+ }
+ }
+ cconn.Close()
+}
+
+func BenchmarkWriteByte(b *testing.B) {
+ cli, srv := newLocalConnPair(b, &Config{}, &Config{})
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ wg.Add(1)
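+	// Server: read and count everything the client writes; the total must
+	// come to exactly b.N bytes.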
+ go func() {
+ defer wg.Done()
+ sconn, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ panic(fmt.Errorf("AcceptStream: %v", err))
+ }
+ n, err := io.Copy(io.Discard, sconn)
+ if n != int64(b.N) || err != nil {
+ b.Errorf("server io.Copy() = %v, %v; want %v, nil", n, err, b.N)
+ }
+ }()
+
+ b.SetBytes(1)
+ cconn, err := cli.NewStream(context.Background())
+ if err != nil {
+ b.Fatalf("NewStream: %v", err)
+ }
+ cconn.Flush()
+ for i := 0; i < b.N; i++ {
+ if err := cconn.WriteByte(0); err != nil {
+ b.Fatalf("WriteByte: %v", err)
+ }
+ }
+ cconn.Close()
+}
+
+func BenchmarkStreamCreation(b *testing.B) {
+ cli, srv := newLocalConnPair(b, &Config{}, &Config{})
+
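+	// Server: accept and immediately close each stream the benchmark loop
+	// creates.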
+ go func() {
+ for i := 0; i < b.N; i++ {
+ sconn, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ panic(fmt.Errorf("AcceptStream: %v", err))
+ }
+ sconn.Close()
+ }
+ }()
+
+ buf := make([]byte, 1)
+ for i := 0; i < b.N; i++ {
+ cconn, err := cli.NewStream(context.Background())
+ if err != nil {
+ b.Fatalf("NewStream: %v", err)
+ }
+ cconn.Write(buf)
+ cconn.Flush()
+ cconn.Read(buf)
+ cconn.Close()
+ }
+}
diff --git a/internal/quic/config.go b/quic/config.go
similarity index 64%
rename from internal/quic/config.go
rename to quic/config.go
index 6278bf89c..5d420312b 100644
--- a/internal/quic/config.go
+++ b/quic/config.go
@@ -8,6 +8,9 @@ package quic
import (
"crypto/tls"
+ "log/slog"
+ "math"
+ "time"
)
// A Config structure configures a QUIC endpoint.
@@ -72,9 +75,46 @@ type Config struct {
//
// If this field is left as zero, stateless reset is disabled.
StatelessResetKey [32]byte
+
+ // HandshakeTimeout is the maximum time in which a connection handshake must complete.
+ // If zero, the default of 10 seconds is used.
+ // If negative, there is no handshake timeout.
+ HandshakeTimeout time.Duration
+
+ // MaxIdleTimeout is the maximum time after which an idle connection will be closed.
+ // If zero, the default of 30 seconds is used.
+ // If negative, idle connections are never closed.
+ //
+ // The idle timeout for a connection is the minimum of the maximum idle timeouts
+ // of the endpoints.
+ MaxIdleTimeout time.Duration
+
+ // KeepAlivePeriod is the time after which a packet will be sent to keep
+ // an idle connection alive.
+ // If zero, keep alive packets are not sent.
+ // If greater than zero, the keep alive period is the smaller of KeepAlivePeriod and
+ // half the connection idle timeout.
+ KeepAlivePeriod time.Duration
+
+ // QLogLogger receives qlog events.
+ //
+ // Events currently correspond to the definitions in draft-ietf-qlog-quic-events-03.
+ // This is not the latest version of the draft, but is the latest version supported
+ // by common event log viewers as of the time this paragraph was written.
+ //
+ // The qlog package contains a slog.Handler which serializes qlog events
+ // to a standard JSON representation.
+ QLogLogger *slog.Logger
+}
+
+// Clone returns a shallow clone of c, or nil if c is nil.
+// It is safe to clone a [Config] that is being used concurrently by a QUIC endpoint.
+func (c *Config) Clone() *Config {
+ n := *c
+ return &n
}
-func configDefault(v, def, limit int64) int64 {
+func configDefault[T ~int64](v, def, limit T) T {
switch {
case v == 0:
return def
@@ -104,3 +144,15 @@ func (c *Config) maxStreamWriteBufferSize() int64 {
func (c *Config) maxConnReadBufferSize() int64 {
return configDefault(c.MaxConnReadBufferSize, 1<<20, maxVarint)
}
+
+func (c *Config) handshakeTimeout() time.Duration {
+ return configDefault(c.HandshakeTimeout, defaultHandshakeTimeout, math.MaxInt64)
+}
+
+func (c *Config) maxIdleTimeout() time.Duration {
+ return configDefault(c.MaxIdleTimeout, defaultMaxIdleTimeout, math.MaxInt64)
+}
+
+func (c *Config) keepAlivePeriod() time.Duration {
+ return configDefault(c.KeepAlivePeriod, defaultKeepAlivePeriod, math.MaxInt64)
+}
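+
+// Example (editorial sketch): a Config that enables frame-level qlog output
+// and a shorter handshake timeout, mirroring the wiring in cmd/interop/main.go.
+// tlsConf and the qlog directory are placeholders.
+//
+//	cfg := &Config{
+//		TLSConfig:        tlsConf,
+//		HandshakeTimeout: 5 * time.Second,
+//		QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
+//			Level: QLogLevelFrame,
+//			Dir:   "/tmp/qlog",
+//		})),
+//	}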
diff --git a/internal/quic/config_test.go b/quic/config_test.go
similarity index 100%
rename from internal/quic/config_test.go
rename to quic/config_test.go
diff --git a/internal/quic/congestion_reno.go b/quic/congestion_reno.go
similarity index 83%
rename from internal/quic/congestion_reno.go
rename to quic/congestion_reno.go
index 982cbf4bb..a53983524 100644
--- a/internal/quic/congestion_reno.go
+++ b/quic/congestion_reno.go
@@ -7,6 +7,8 @@
package quic
import (
+ "context"
+ "log/slog"
"math"
"time"
)
@@ -40,6 +42,9 @@ type ccReno struct {
// true if we haven't sent that packet yet.
sendOnePacketInRecovery bool
+ // inRecovery is set when we are in the recovery state.
+ inRecovery bool
+
// underutilized is set if the congestion window is underutilized
// due to insufficient application data, flow control limits, or
// anti-amplification limits.
@@ -100,12 +105,19 @@ func (c *ccReno) canSend() bool {
// congestion controller permits sending data, but no data is sent.
//
// https://www.rfc-editor.org/rfc/rfc9002#section-7.8
-func (c *ccReno) setUnderutilized(v bool) {
+func (c *ccReno) setUnderutilized(log *slog.Logger, v bool) {
+ if c.underutilized == v {
+ return
+ }
+ oldState := c.state()
c.underutilized = v
+ if logEnabled(log, QLogLevelPacket) {
+ logCongestionStateUpdated(log, oldState, c.state())
+ }
}
// packetSent indicates that a packet has been sent.
-func (c *ccReno) packetSent(now time.Time, space numberSpace, sent *sentPacket) {
+func (c *ccReno) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) {
if !sent.inFlight {
return
}
@@ -185,7 +197,11 @@ func (c *ccReno) packetLost(now time.Time, space numberSpace, sent *sentPacket,
}
// packetBatchEnd is called at the end of processing a batch of acked or lost packets.
-func (c *ccReno) packetBatchEnd(now time.Time, space numberSpace, rtt *rttState, maxAckDelay time.Duration) {
+func (c *ccReno) packetBatchEnd(now time.Time, log *slog.Logger, space numberSpace, rtt *rttState, maxAckDelay time.Duration) {
+ if logEnabled(log, QLogLevelPacket) {
+ oldState := c.state()
+ defer func() { logCongestionStateUpdated(log, oldState, c.state()) }()
+ }
if !c.ackLastLoss.IsZero() && !c.ackLastLoss.Before(c.recoveryStartTime) {
// Enter the recovery state.
// https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.2
@@ -196,8 +212,10 @@ func (c *ccReno) packetBatchEnd(now time.Time, space numberSpace, rtt *rttState,
// Clear congestionPendingAcks to avoid increasing the congestion
// window based on acks in a frame that sends us into recovery.
c.congestionPendingAcks = 0
+ c.inRecovery = true
} else if c.congestionPendingAcks > 0 {
// We are in slow start or congestion avoidance.
+ c.inRecovery = false
if c.congestionWindow < c.slowStartThreshold {
// When the congestion window is less than the slow start threshold,
// we are in slow start and increase the window by the number of
@@ -253,3 +271,38 @@ func (c *ccReno) minimumCongestionWindow() int {
// https://www.rfc-editor.org/rfc/rfc9002.html#section-7.2-4
return 2 * c.maxDatagramSize
}
+
+func logCongestionStateUpdated(log *slog.Logger, oldState, newState congestionState) {
+ if oldState == newState {
+ return
+ }
+ log.LogAttrs(context.Background(), QLogLevelPacket,
+ "recovery:congestion_state_updated",
+ slog.String("old", oldState.String()),
+ slog.String("new", newState.String()),
+ )
+}
+
+type congestionState string
+
+func (s congestionState) String() string { return string(s) }
+
+const (
+ congestionSlowStart = congestionState("slow_start")
+ congestionCongestionAvoidance = congestionState("congestion_avoidance")
+ congestionApplicationLimited = congestionState("application_limited")
+ congestionRecovery = congestionState("recovery")
+)
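+
+// The state names above match the congestion states of the qlog
+// recovery:congestion_state_updated event in draft-ietf-qlog-quic-events,
+// which logCongestionStateUpdated above emits.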
+
+func (c *ccReno) state() congestionState {
+ switch {
+ case c.inRecovery:
+ return congestionRecovery
+ case c.underutilized:
+ return congestionApplicationLimited
+ case c.congestionWindow < c.slowStartThreshold:
+ return congestionSlowStart
+ default:
+ return congestionCongestionAvoidance
+ }
+}
diff --git a/internal/quic/congestion_reno_test.go b/quic/congestion_reno_test.go
similarity index 99%
rename from internal/quic/congestion_reno_test.go
rename to quic/congestion_reno_test.go
index e9af6452c..cda7a90a8 100644
--- a/internal/quic/congestion_reno_test.go
+++ b/quic/congestion_reno_test.go
@@ -470,7 +470,7 @@ func (c *ccTest) setRTT(smoothedRTT, rttvar time.Duration) {
func (c *ccTest) setUnderutilized(v bool) {
c.t.Helper()
c.t.Logf("set underutilized = %v", v)
- c.cc.setUnderutilized(v)
+ c.cc.setUnderutilized(nil, v)
}
func (c *ccTest) packetSent(space numberSpace, size int, fns ...func(*sentPacket)) *sentPacket {
@@ -488,7 +488,7 @@ func (c *ccTest) packetSent(space numberSpace, size int, fns ...func(*sentPacket
f(sent)
}
c.t.Logf("packet sent: num=%v.%v, size=%v", space, sent.num, sent.size)
- c.cc.packetSent(c.now, space, sent)
+ c.cc.packetSent(c.now, nil, space, sent)
return sent
}
@@ -519,7 +519,7 @@ func (c *ccTest) packetDiscarded(space numberSpace, sent *sentPacket) {
func (c *ccTest) packetBatchEnd(space numberSpace) {
c.t.Helper()
c.t.Logf("(end of batch)")
- c.cc.packetBatchEnd(c.now, space, &c.rtt, c.maxAckDelay)
+ c.cc.packetBatchEnd(c.now, nil, space, &c.rtt, c.maxAckDelay)
}
func (c *ccTest) wantCanSend(want bool) {
diff --git a/internal/quic/conn.go b/quic/conn.go
similarity index 82%
rename from internal/quic/conn.go
rename to quic/conn.go
index 1292f2b20..38e8fe8f4 100644
--- a/internal/quic/conn.go
+++ b/quic/conn.go
@@ -11,6 +11,7 @@ import (
"crypto/tls"
"errors"
"fmt"
+ "log/slog"
"net/netip"
"time"
)
@@ -20,26 +21,23 @@ import (
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn struct {
side connSide
- listener *Listener
+ endpoint *Endpoint
config *Config
testHooks connTestHooks
peerAddr netip.AddrPort
+ localAddr netip.AddrPort
- msgc chan any
- donec chan struct{} // closed when conn loop exits
- exited bool // set to make the conn loop exit immediately
+ msgc chan any
+ donec chan struct{} // closed when conn loop exits
w packetWriter
acks [numberSpaceCount]ackState // indexed by number space
lifetime lifetimeState
+ idle idleState
connIDState connIDState
loss lossState
streams streamsState
-
- // idleTimeout is the time at which the connection will be closed due to inactivity.
- // https://www.rfc-editor.org/rfc/rfc9000#section-10.1
- maxIdleTimeout time.Duration
- idleTimeout time.Time
+ path pathState
// Packet protection keys, CRYPTO streams, and TLS state.
keysInitial fixedKeyPair
@@ -60,6 +58,8 @@ type Conn struct {
// Tests only: Send a PING in a specific number space.
testSendPingSpace numberSpace
testSendPing sentVal
+
+ log *slog.Logger
}
// connTestHooks override conn behavior in tests.
@@ -94,25 +94,31 @@ type newServerConnIDs struct {
retrySrcConnID []byte // source from server's Retry
}
-func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, l *Listener) (*Conn, error) {
+func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) {
c := &Conn{
side: side,
- listener: l,
+ endpoint: e,
config: config,
- peerAddr: peerAddr,
+ peerAddr: unmapAddrPort(peerAddr),
msgc: make(chan any, 1),
donec: make(chan struct{}),
- maxIdleTimeout: defaultMaxIdleTimeout,
- idleTimeout: now.Add(defaultMaxIdleTimeout),
peerAckDelayExponent: -1,
}
+ defer func() {
+ // If we hit an error in newConn, close donec so tests don't get stuck waiting for it.
+ // This is only relevant if we've got a bug, but it makes tracking that bug down
+ // much easier.
+ if conn == nil {
+ close(c.donec)
+ }
+ }()
// A one-element buffer allows us to wake a Conn's event loop as a
// non-blocking operation.
c.msgc = make(chan any, 1)
- if l.testHooks != nil {
- l.testHooks.newConn(c)
+ if e.testHooks != nil {
+ e.testHooks.newConn(c)
}
// initialConnID is the connection ID used to generate Initial packet protection keys.
@@ -132,15 +138,15 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip
}
}
- // The smallest allowed maximum QUIC datagram size is 1200 bytes.
// TODO: PMTU discovery.
- const maxDatagramSize = 1200
+ c.logConnectionStarted(cids.originalDstConnID, peerAddr)
c.keysAppData.init()
- c.loss.init(c.side, maxDatagramSize, now)
+ c.loss.init(c.side, smallestMaxDatagramSize, now)
c.streamsInit()
c.lifetimeInit()
+ c.restartIdleTimer(now)
- if err := c.startTLS(now, initialConnID, transportParameters{
+ if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{
initialSrcConnID: c.connIDState.srcConnID(),
originalDstConnID: cids.originalDstConnID,
retrySrcConnID: cids.retrySrcConnID,
@@ -183,13 +189,14 @@ func (c *Conn) confirmHandshake(now time.Time) {
if c.side == serverSide {
// When the server confirms the handshake, it sends a HANDSHAKE_DONE.
c.handshakeConfirmed.setUnsent()
- c.listener.serverConnEstablished(c)
+ c.endpoint.serverConnEstablished(c)
} else {
// The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed
// to the received state, indicating that the handshake is confirmed and we
// don't need to send anything.
c.handshakeConfirmed.setReceived()
}
+ c.restartIdleTimer(now)
c.loss.confirmHandshake()
// "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed"
// https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1
@@ -205,7 +212,7 @@ func (c *Conn) discardKeys(now time.Time, space numberSpace) {
case handshakeSpace:
c.keysHandshake.discard()
}
- c.loss.discardKeys(now, space)
+ c.loss.discardKeys(now, c.log, space)
}
// receiveTransportParameters applies transport parameters sent by the peer.
@@ -220,6 +227,7 @@ func (c *Conn) receiveTransportParameters(p transportParameters) error {
c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal
c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote
c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni
+ c.receivePeerMaxIdleTimeout(p.maxIdleTimeout)
c.peerAckDelayExponent = p.ackDelayExponent
c.loss.setMaxAckDelay(p.maxAckDelay)
if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil {
@@ -236,7 +244,6 @@ func (c *Conn) receiveTransportParameters(p transportParameters) error {
return err
}
}
- // TODO: max_idle_timeout
// TODO: stateless_reset_token
// TODO: max_udp_payload_size
// TODO: disable_active_migration
@@ -249,6 +256,8 @@ type (
wakeEvent struct{}
)
+var errIdleTimeout = errors.New("idle timeout")
+
// loop is the connection main loop.
//
// Except where otherwise noted, all connection state is owned by the loop goroutine.
@@ -256,9 +265,7 @@ type (
// The loop processes messages from c.msgc and timer events.
// Other goroutines may examine or modify conn state by sending the loop funcs to execute.
func (c *Conn) loop(now time.Time) {
- defer close(c.donec)
- defer c.tls.Close()
- defer c.listener.connDrained(c)
+ defer c.cleanup()
// The connection timer sends a message to the connection loop on expiry.
// We need to give it an expiry when creating it, so set the initial timeout to
@@ -275,14 +282,14 @@ func (c *Conn) loop(now time.Time) {
defer timer.Stop()
}
- for !c.exited {
+ for c.lifetime.state != connStateDone {
sendTimeout := c.maybeSend(now) // try sending
// Note that we only need to consider the ack timer for the App Data space,
// since the Initial and Handshake spaces always ack immediately.
nextTimeout := sendTimeout
- nextTimeout = firstTime(nextTimeout, c.idleTimeout)
- if !c.isClosingOrDraining() {
+ nextTimeout = firstTime(nextTimeout, c.idle.nextTimeout)
+ if c.isAlive() {
nextTimeout = firstTime(nextTimeout, c.loss.timer)
nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck)
} else {
@@ -312,15 +319,17 @@ func (c *Conn) loop(now time.Time) {
}
switch m := m.(type) {
case *datagram:
- c.handleDatagram(now, m)
+ if !c.handleDatagram(now, m) {
+ if c.logEnabled(QLogLevelPacket) {
+ c.logPacketDropped(m)
+ }
+ }
m.recycle()
case timerEvent:
// A connection timer has expired.
- if !now.Before(c.idleTimeout) {
- // "[...] the connection is silently closed and
- // its state is discarded [...]"
- // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1
- c.exited = true
+ if c.idleAdvance(now) {
+ // The connection idle timer has expired.
+ c.abortImmediately(now, errIdleTimeout)
return
}
c.loss.advance(now, c.handleAckOrLoss)
@@ -340,6 +349,13 @@ func (c *Conn) loop(now time.Time) {
}
}
+func (c *Conn) cleanup() {
+ c.logConnectionClosed()
+ c.endpoint.connDrained(c)
+ c.tls.Close()
+ close(c.donec)
+}
+
// sendMsg sends a message to the conn's loop.
// It does not wait for the message to be processed.
// The conn may close before processing the message, in which case it is lost.
@@ -359,12 +375,37 @@ func (c *Conn) wake() {
}
// runOnLoop executes a function within the conn's loop goroutine.
-func (c *Conn) runOnLoop(f func(now time.Time, c *Conn)) error {
+func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) error {
donec := make(chan struct{})
- c.sendMsg(func(now time.Time, c *Conn) {
+ msg := func(now time.Time, c *Conn) {
defer close(donec)
f(now, c)
- })
+ }
+ if c.testHooks != nil {
+ // In tests, we can't rely on being able to send a message immediately:
+ // c.msgc might be full, and testConnHooks.nextMessage might be waiting
+ // for us to block before it processes the next message.
+ // To avoid a deadlock, we send the message in waitUntil.
+ // If msgc is empty, the message is buffered.
+ // If msgc is full, we block and let nextMessage process the queue.
+ msgc := c.msgc
+ c.testHooks.waitUntil(ctx, func() bool {
+ for {
+ select {
+ case msgc <- msg:
+ msgc = nil // send msg only once
+ case <-donec:
+ return true
+ case <-c.donec:
+ return true
+ default:
+ return false
+ }
+ }
+ })
+ } else {
+ c.sendMsg(msg)
+ }
select {
case <-donec:
case <-c.donec:
diff --git a/internal/quic/conn_async_test.go b/quic/conn_async_test.go
similarity index 94%
rename from internal/quic/conn_async_test.go
rename to quic/conn_async_test.go
index dc2a57f9d..4671f8340 100644
--- a/internal/quic/conn_async_test.go
+++ b/quic/conn_async_test.go
@@ -41,7 +41,7 @@ type asyncOp[T any] struct {
err error
caller string
- state *asyncTestState
+ tc *testConn
donec chan struct{}
cancelFunc context.CancelFunc
}
@@ -55,7 +55,7 @@ func (a *asyncOp[T]) cancel() {
default:
}
a.cancelFunc()
- <-a.state.notify
+ <-a.tc.asyncTestState.notify
select {
case <-a.donec:
default:
@@ -73,6 +73,7 @@ var errNotDone = errors.New("async op is not done")
// control over the progress of operations, an asyncOp can only
// become done in reaction to the test taking some action.
func (a *asyncOp[T]) result() (v T, err error) {
+ a.tc.wait()
select {
case <-a.donec:
return a.v, a.err
@@ -94,8 +95,8 @@ type asyncContextKey struct{}
// The function f should call a blocking function such as
// Stream.Write or Conn.AcceptStream and return its result.
// It must use the provided context.
-func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[T] {
- as := &ts.asyncTestState
+func runAsync[T any](tc *testConn, f func(context.Context) (T, error)) *asyncOp[T] {
+ as := &tc.asyncTestState
if as.notify == nil {
as.notify = make(chan struct{})
as.mu.Lock()
@@ -106,7 +107,7 @@ func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[
ctx := context.WithValue(context.Background(), asyncContextKey{}, true)
ctx, cancel := context.WithCancel(ctx)
a := &asyncOp[T]{
- state: as,
+ tc: tc,
caller: fmt.Sprintf("%v:%v", filepath.Base(file), line),
donec: make(chan struct{}),
cancelFunc: cancel,
@@ -116,14 +117,15 @@ func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[
close(a.donec)
as.notify <- struct{}{}
}()
- ts.t.Cleanup(func() {
+ tc.t.Cleanup(func() {
if _, err := a.result(); err == errNotDone {
- ts.t.Errorf("%v: async operation is still executing at end of test", a.caller)
+ tc.t.Errorf("%v: async operation is still executing at end of test", a.caller)
a.cancel()
}
})
// Wait for the operation to either finish or block.
<-as.notify
+ tc.wait()
return a
}
diff --git a/quic/conn_close.go b/quic/conn_close.go
new file mode 100644
index 000000000..1798d0536
--- /dev/null
+++ b/quic/conn_close.go
@@ -0,0 +1,331 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "errors"
+ "time"
+)
+
+// connState is the state of a connection.
+type connState int
+
+const (
+ // A connection is alive when it is first created.
+ connStateAlive = connState(iota)
+
+ // The connection has received a CONNECTION_CLOSE frame from the peer,
+ // and has not yet sent a CONNECTION_CLOSE in response.
+ //
+ // We will send a CONNECTION_CLOSE, and then enter the draining state.
+ connStatePeerClosed
+
+ // The connection is in the closing state.
+ //
+ // We will send CONNECTION_CLOSE frames to the peer
+ // (once upon entering the closing state, and possibly again in response to peer packets).
+ //
+ // If we receive a CONNECTION_CLOSE from the peer, we will enter the draining state.
+ // Otherwise, we will eventually time out and move to the done state.
+ //
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.1
+ connStateClosing
+
+ // The connection is in the draining state.
+ //
+ // We will neither send packets nor process received packets.
+ // When the drain timer expires, we move to the done state.
+ //
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.2
+ connStateDraining
+
+ // The connection is done, and the conn loop will exit.
+ connStateDone
+)
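+
+// State transitions (editorial summary of the states above):
+//
+//	alive      -> peerClosed  on receiving CONNECTION_CLOSE from the peer
+//	alive      -> closing     on a local close (Abort, Close, or a transport error)
+//	peerClosed -> draining    after sending our own CONNECTION_CLOSE
+//	closing    -> draining    on receiving CONNECTION_CLOSE from the peer
+//	closing    -> done        when the drain timer expires with no peer response
+//	draining   -> done        when the drain timer expires
+//	any        -> done        on abortImmediately (e.g. idle timeout)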
+
+// lifetimeState tracks the state of a connection.
+//
+// This is fairly coupled to the rest of a Conn, but putting it in a struct of its own helps
+// reason about operations that cause state transitions.
+type lifetimeState struct {
+ state connState
+
+ readyc chan struct{} // closed when TLS handshake completes
+ donec chan struct{} // closed when finalErr is set
+
+ localErr error // error sent to the peer
+ finalErr error // error sent by the peer, or transport error; set before closing donec
+
+ connCloseSentTime time.Time // send time of last CONNECTION_CLOSE frame
+ connCloseDelay time.Duration // delay until next CONNECTION_CLOSE frame sent
+ drainEndTime time.Time // time the connection exits the draining state
+}
+
+func (c *Conn) lifetimeInit() {
+ c.lifetime.readyc = make(chan struct{})
+ c.lifetime.donec = make(chan struct{})
+}
+
+var (
+ errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE")
+ errConnClosed = errors.New("connection closed")
+)
+
+// lifetimeAdvance is called when time passes.
+func (c *Conn) lifetimeAdvance(now time.Time) (done bool) {
+ if c.lifetime.drainEndTime.IsZero() || c.lifetime.drainEndTime.After(now) {
+ return false
+ }
+ // The connection drain period has ended, and we can shut down.
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-7
+ c.lifetime.drainEndTime = time.Time{}
+ if c.lifetime.state != connStateDraining {
+ // We were in the closing state, waiting for a CONNECTION_CLOSE from the peer.
+ c.setFinalError(errNoPeerResponse)
+ }
+ c.setState(now, connStateDone)
+ return true
+}
+
+// setState sets the conn state.
+func (c *Conn) setState(now time.Time, state connState) {
+ if c.lifetime.state == state {
+ return
+ }
+ c.lifetime.state = state
+ switch state {
+ case connStateClosing, connStateDraining:
+ if c.lifetime.drainEndTime.IsZero() {
+ c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod())
+ }
+ case connStateDone:
+ c.setFinalError(nil)
+ }
+ if state != connStateAlive {
+ c.streamsCleanup()
+ }
+}
+
+// handshakeDone is called when the TLS handshake completes.
+func (c *Conn) handshakeDone() {
+ close(c.lifetime.readyc)
+}
+
+// isDraining reports whether the conn is in the draining state.
+//
+// The draining state is entered once an endpoint receives a CONNECTION_CLOSE frame.
+// The endpoint will no longer send any packets, but we retain knowledge of the connection
+// until the end of the drain period to ensure we discard packets for the connection
+// rather than treating them as starting a new connection.
+//
+// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2
+func (c *Conn) isDraining() bool {
+ switch c.lifetime.state {
+ case connStateDraining, connStateDone:
+ return true
+ }
+ return false
+}
+
+// isAlive reports whether the conn is handling packets.
+func (c *Conn) isAlive() bool {
+ return c.lifetime.state == connStateAlive
+}
+
+// sendOK reports whether the conn can send frames at this time.
+func (c *Conn) sendOK(now time.Time) bool {
+ switch c.lifetime.state {
+ case connStateAlive:
+ return true
+ case connStatePeerClosed:
+ if c.lifetime.localErr == nil {
+ // We're waiting for the user to close the connection, providing us with
+ // a final status to send to the peer.
+ return false
+ }
+ // We should send a CONNECTION_CLOSE.
+ return true
+ case connStateClosing:
+ if c.lifetime.connCloseSentTime.IsZero() {
+ return true
+ }
+ maxRecvTime := c.acks[initialSpace].maxRecvTime
+ if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) {
+ maxRecvTime = t
+ }
+ if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) {
+ maxRecvTime = t
+ }
+ if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) {
+ // After sending CONNECTION_CLOSE, ignore packets from the peer for
+ // a delay. On the next packet received after the delay, send another
+ // CONNECTION_CLOSE.
+ return false
+ }
+ return true
+ case connStateDraining:
+ // We are in the draining state, and will send no more packets.
+ return false
+ case connStateDone:
+ return false
+ default:
+ panic("BUG: unhandled connection state")
+ }
+}
+
+// sentConnectionClose records that the conn has sent a CONNECTION_CLOSE to the peer.
+func (c *Conn) sentConnectionClose(now time.Time) {
+ switch c.lifetime.state {
+ case connStatePeerClosed:
+ c.enterDraining(now)
+ }
+ if c.lifetime.connCloseSentTime.IsZero() {
+ // Set the initial delay before we will send another CONNECTION_CLOSE.
+ //
+ // RFC 9000 states that we should rate limit CONNECTION_CLOSE frames,
+ // but leaves the implementation of the limit up to us. Here, we start
+ // with the same delay as the PTO timer (RFC 9002, Section 6.2.1),
+ // not including max_ack_delay, and double it on every CONNECTION_CLOSE sent.
+ c.lifetime.connCloseDelay = c.loss.rtt.smoothedRTT + max(4*c.loss.rtt.rttvar, timerGranularity)
+ } else if !c.lifetime.connCloseSentTime.Equal(now) {
+ // If connCloseSentTime == now, we're sending two CONNECTION_CLOSE frames
+ // coalesced into the same datagram. We only want to increase the delay once.
+ c.lifetime.connCloseDelay *= 2
+ }
+ c.lifetime.connCloseSentTime = now
+}
+
+// handlePeerConnectionClose handles a CONNECTION_CLOSE from the peer.
+func (c *Conn) handlePeerConnectionClose(now time.Time, err error) {
+ c.setFinalError(err)
+ switch c.lifetime.state {
+ case connStateAlive:
+ c.setState(now, connStatePeerClosed)
+ case connStatePeerClosed:
+ // Duplicate CONNECTION_CLOSE, ignore.
+ case connStateClosing:
+ if c.lifetime.connCloseSentTime.IsZero() {
+ c.setState(now, connStatePeerClosed)
+ } else {
+ c.setState(now, connStateDraining)
+ }
+ case connStateDraining:
+ case connStateDone:
+ }
+}
+
+// setFinalError records the final connection status we report to the user.
+func (c *Conn) setFinalError(err error) {
+ select {
+ case <-c.lifetime.donec:
+ return // already set
+ default:
+ }
+ c.lifetime.finalErr = err
+ close(c.lifetime.donec)
+}
+
+func (c *Conn) waitReady(ctx context.Context) error {
+ select {
+ case <-c.lifetime.readyc:
+ return nil
+ case <-c.lifetime.donec:
+ return c.lifetime.finalErr
+ default:
+ }
+ select {
+ case <-c.lifetime.readyc:
+ return nil
+ case <-c.lifetime.donec:
+ return c.lifetime.finalErr
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+// Close closes the connection.
+//
+// Close is equivalent to:
+//
+// conn.Abort(nil)
+// err := conn.Wait(context.Background())
+func (c *Conn) Close() error {
+ c.Abort(nil)
+ <-c.lifetime.donec
+ return c.lifetime.finalErr
+}
+
+// Wait waits for the peer to close the connection.
+//
+// If the connection is closed locally and the peer does not close its end of the connection,
+// Wait will return with a non-nil error after the drain period expires.
+//
+// If the peer closes the connection with a NO_ERROR transport error, Wait returns nil.
+// If the peer closes the connection with an application error, Wait returns an ApplicationError
+// containing the peer's error code and reason.
+// If the peer closes the connection with any other status, Wait returns a non-nil error.
+func (c *Conn) Wait(ctx context.Context) error {
+ if err := c.waitOnDone(ctx, c.lifetime.donec); err != nil {
+ return err
+ }
+ return c.lifetime.finalErr
+}
+
+// Abort closes the connection and returns immediately.
+//
+// If err is nil, Abort sends a transport error of NO_ERROR to the peer.
+// If err is an ApplicationError, Abort sends its error code and text.
+// Otherwise, Abort sends a transport error of APPLICATION_ERROR with the error's text.
+func (c *Conn) Abort(err error) {
+ if err == nil {
+ err = localTransportError{code: errNo}
+ }
+ c.sendMsg(func(now time.Time, c *Conn) {
+ c.enterClosing(now, err)
+ })
+}
+
+// abort terminates a connection with an error.
+func (c *Conn) abort(now time.Time, err error) {
+ c.setFinalError(err) // this error takes precedence over the peer's CONNECTION_CLOSE
+ c.enterClosing(now, err)
+}
+
+// abortImmediately terminates a connection.
+// The connection does not send a CONNECTION_CLOSE, and skips the draining period.
+func (c *Conn) abortImmediately(now time.Time, err error) {
+ c.setFinalError(err)
+ c.setState(now, connStateDone)
+}
+
+// enterClosing starts an immediate close.
+// We will send a CONNECTION_CLOSE to the peer and wait for their response.
+func (c *Conn) enterClosing(now time.Time, err error) {
+ switch c.lifetime.state {
+ case connStateAlive:
+ c.lifetime.localErr = err
+ c.setState(now, connStateClosing)
+ case connStatePeerClosed:
+ c.lifetime.localErr = err
+ }
+}
+
+// enterDraining moves directly to the draining state, without sending a CONNECTION_CLOSE.
+func (c *Conn) enterDraining(now time.Time) {
+ switch c.lifetime.state {
+ case connStateAlive, connStatePeerClosed, connStateClosing:
+ c.setState(now, connStateDraining)
+ }
+}
+
+// exit fully terminates a connection immediately.
+func (c *Conn) exit() {
+ c.sendMsg(func(now time.Time, c *Conn) {
+ c.abortImmediately(now, errors.New("connection closed"))
+ })
+}
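+
+// Example (editorial sketch, not part of this change): closing a connection
+// with an application error and waiting for the final status. The error code
+// and reason are illustrative.
+//
+//	func shutdown(ctx context.Context, c *Conn) error {
+//		c.Abort(&ApplicationError{Code: 1, Reason: "shutting down"})
+//		return c.Wait(ctx)
+//	}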
diff --git a/internal/quic/conn_close_test.go b/quic/conn_close_test.go
similarity index 69%
rename from internal/quic/conn_close_test.go
rename to quic/conn_close_test.go
index d583ae92a..213975011 100644
--- a/internal/quic/conn_close_test.go
+++ b/quic/conn_close_test.go
@@ -70,7 +70,8 @@ func TestConnCloseResponseBackoff(t *testing.T) {
}
func TestConnCloseWithPeerResponse(t *testing.T) {
- tc := newTestConn(t, clientSide)
+ qr := &qlogRecord{}
+ tc := newTestConn(t, clientSide, qr.config)
tc.handshake()
tc.conn.Abort(nil)
@@ -99,10 +100,19 @@ func TestConnCloseWithPeerResponse(t *testing.T) {
if err := tc.conn.Wait(canceledContext()); !errors.Is(err, wantErr) {
t.Errorf("non-blocking conn.Wait() = %v, want %v", err, wantErr)
}
+
+ tc.advance(1 * time.Second) // long enough to exit the draining state
+ qr.wantEvents(t, jsonEvent{
+ "name": "connectivity:connection_closed",
+ "data": map[string]any{
+ "trigger": "application",
+ },
+ })
}
func TestConnClosePeerCloses(t *testing.T) {
- tc := newTestConn(t, clientSide)
+ qr := &qlogRecord{}
+ tc := newTestConn(t, clientSide, qr.config)
tc.handshake()
wantErr := &ApplicationError{
@@ -128,6 +138,14 @@ func TestConnClosePeerCloses(t *testing.T) {
code: 9,
reason: "because",
})
+
+ tc.advance(1 * time.Second) // long enough to exit the draining state
+ qr.wantEvents(t, jsonEvent{
+ "name": "connectivity:connection_closed",
+ "data": map[string]any{
+ "trigger": "application",
+ },
+ })
}
func TestConnCloseReceiveInInitial(t *testing.T) {
@@ -187,14 +205,78 @@ func TestConnCloseReceiveInHandshake(t *testing.T) {
tc.wantIdle("no more frames to send")
}
-func TestConnCloseClosedByListener(t *testing.T) {
+func TestConnCloseClosedByEndpoint(t *testing.T) {
ctx := canceledContext()
tc := newTestConn(t, clientSide)
tc.handshake()
- tc.listener.l.Close(ctx)
- tc.wantFrame("listener closes connection before exiting",
+ tc.endpoint.e.Close(ctx)
+ tc.wantFrame("endpoint closes connection before exiting",
packetType1RTT, debugFrameConnectionCloseTransport{
code: errNo,
})
}
+
+func testConnCloseUnblocks(t *testing.T, f func(context.Context, *testConn) error, opts ...any) {
+ tc := newTestConn(t, clientSide, opts...)
+ tc.handshake()
+ op := runAsync(tc, func(ctx context.Context) (struct{}, error) {
+ return struct{}{}, f(ctx, tc)
+ })
+ if _, err := op.result(); err != errNotDone {
+ t.Fatalf("before abort, op = %v, want errNotDone", err)
+ }
+ tc.conn.Abort(nil)
+ if _, err := op.result(); err == nil || err == errNotDone {
+ t.Fatalf("after abort, op = %v, want error", err)
+ }
+}
+
+func TestConnCloseUnblocksAcceptStream(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ _, err := tc.conn.AcceptStream(ctx)
+ return err
+ }, permissiveTransportParameters)
+}
+
+func TestConnCloseUnblocksNewStream(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ _, err := tc.conn.NewStream(ctx)
+ return err
+ })
+}
+
+func TestConnCloseUnblocksStreamRead(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ s := newLocalStream(t, tc, bidiStream)
+ s.SetReadContext(ctx)
+ buf := make([]byte, 16)
+ _, err := s.Read(buf)
+ return err
+ }, permissiveTransportParameters)
+}
+
+func TestConnCloseUnblocksStreamWrite(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ s := newLocalStream(t, tc, bidiStream)
+ s.SetWriteContext(ctx)
+ buf := make([]byte, 32)
+ _, err := s.Write(buf)
+ return err
+ }, permissiveTransportParameters, func(c *Config) {
+ c.MaxStreamWriteBufferSize = 16
+ })
+}
+
+func TestConnCloseUnblocksStreamClose(t *testing.T) {
+ testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error {
+ s := newLocalStream(t, tc, bidiStream)
+ s.SetWriteContext(ctx)
+ buf := make([]byte, 16)
+ _, err := s.Write(buf)
+ if err != nil {
+ return err
+ }
+ return s.Close()
+ }, permissiveTransportParameters)
+}
diff --git a/internal/quic/conn_flow.go b/quic/conn_flow.go
similarity index 100%
rename from internal/quic/conn_flow.go
rename to quic/conn_flow.go
diff --git a/internal/quic/conn_flow_test.go b/quic/conn_flow_test.go
similarity index 90%
rename from internal/quic/conn_flow_test.go
rename to quic/conn_flow_test.go
index 03e0757a6..260684bdb 100644
--- a/internal/quic/conn_flow_test.go
+++ b/quic/conn_flow_test.go
@@ -12,39 +12,34 @@ import (
)
func TestConnInflowReturnOnRead(t *testing.T) {
- ctx := canceledContext()
tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) {
c.MaxConnReadBufferSize = 64
})
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
- data: make([]byte, 64),
+ data: make([]byte, 8),
})
- const readSize = 8
- if n, err := s.ReadContext(ctx, make([]byte, readSize)); n != readSize || err != nil {
- t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, readSize)
- }
- tc.wantFrame("available window increases, send a MAX_DATA",
- packetType1RTT, debugFrameMaxData{
- max: 64 + readSize,
- })
- if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64-readSize || err != nil {
- t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 64-readSize)
+ if n, err := s.Read(make([]byte, 8)); n != 8 || err != nil {
+ t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 8)
}
tc.wantFrame("available window increases, send a MAX_DATA",
packetType1RTT, debugFrameMaxData{
- max: 128,
+ max: 64 + 8,
})
// Peer can write up to the new limit.
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
- off: 64,
+ off: 8,
data: make([]byte, 64),
})
- tc.wantIdle("connection is idle")
- if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64 || err != nil {
- t.Fatalf("offset 64: s.Read() = %v, %v; want %v, nil", n, err, 64)
+ if n, err := s.Read(make([]byte, 64+1)); n != 64 {
+ t.Fatalf("s.Read() = %v, %v; want %v, anything", n, err, 64)
}
+ tc.wantFrame("available window increases, send a MAX_DATA",
+ packetType1RTT, debugFrameMaxData{
+ max: 64 + 8 + 64,
+ })
+ tc.wantIdle("connection is idle")
}
func TestConnInflowReturnOnRacingReads(t *testing.T) {
@@ -64,11 +59,11 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) {
tc.ignoreFrame(frameTypeAck)
tc.writeFrames(packetType1RTT, debugFrameStream{
id: newStreamID(clientSide, uniStream, 0),
- data: make([]byte, 32),
+ data: make([]byte, 16),
})
tc.writeFrames(packetType1RTT, debugFrameStream{
id: newStreamID(clientSide, uniStream, 1),
- data: make([]byte, 32),
+ data: make([]byte, 1),
})
s1, err := tc.conn.AcceptStream(ctx)
if err != nil {
@@ -79,10 +74,10 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) {
t.Fatalf("conn.AcceptStream() = %v", err)
}
read1 := runAsync(tc, func(ctx context.Context) (int, error) {
- return s1.ReadContext(ctx, make([]byte, 16))
+ return s1.Read(make([]byte, 16))
})
read2 := runAsync(tc, func(ctx context.Context) (int, error) {
- return s2.ReadContext(ctx, make([]byte, 1))
+ return s2.Read(make([]byte, 1))
})
// This MAX_DATA might extend the window by 16 or 17, depending on
// whether the second write occurs before the update happens.
@@ -90,10 +85,10 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) {
packetType1RTT, debugFrameMaxData{})
tc.wantIdle("redundant MAX_DATA is not sent")
if _, err := read1.result(); err != nil {
- t.Errorf("ReadContext #1 = %v", err)
+ t.Errorf("Read #1 = %v", err)
}
if _, err := read2.result(); err != nil {
- t.Errorf("ReadContext #2 = %v", err)
+ t.Errorf("Read #2 = %v", err)
}
}
@@ -204,7 +199,6 @@ func TestConnInflowResetViolation(t *testing.T) {
}
func TestConnInflowMultipleStreams(t *testing.T) {
- ctx := canceledContext()
tc := newTestConn(t, serverSide, func(c *Config) {
c.MaxConnReadBufferSize = 128
})
@@ -220,21 +214,26 @@ func TestConnInflowMultipleStreams(t *testing.T) {
} {
tc.writeFrames(packetType1RTT, debugFrameStream{
id: id,
- data: make([]byte, 32),
+ data: make([]byte, 1),
})
- s, err := tc.conn.AcceptStream(ctx)
- if err != nil {
- t.Fatalf("AcceptStream() = %v", err)
- }
+ s := tc.acceptStream()
streams = append(streams, s)
- if n, err := s.ReadContext(ctx, make([]byte, 1)); err != nil || n != 1 {
+ if n, err := s.Read(make([]byte, 1)); err != nil || n != 1 {
t.Fatalf("s.Read() = %v, %v; want 1, nil", n, err)
}
}
tc.wantIdle("streams have read data, but not enough to update MAX_DATA")
- if n, err := streams[0].ReadContext(ctx, make([]byte, 32)); err != nil || n != 31 {
- t.Fatalf("s.Read() = %v, %v; want 31, nil", n, err)
+ for _, s := range streams {
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ off: 1,
+ data: make([]byte, 31),
+ })
+ }
+
+ if n, err := streams[0].Read(make([]byte, 32)); n != 31 {
+ t.Fatalf("s.Read() = %v, %v; want 31, anything", n, err)
}
tc.wantFrame("read enough data to trigger a MAX_DATA update",
packetType1RTT, debugFrameMaxData{
@@ -262,6 +261,7 @@ func TestConnOutflowBlocked(t *testing.T) {
if n != len(data) || err != nil {
t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data))
}
+ s.Flush()
tc.wantFrame("stream writes data up to MAX_DATA limit",
packetType1RTT, debugFrameStream{
@@ -310,6 +310,7 @@ func TestConnOutflowMaxDataDecreases(t *testing.T) {
if n != len(data) || err != nil {
t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data))
}
+ s.Flush()
tc.wantFrame("stream writes data up to MAX_DATA limit",
packetType1RTT, debugFrameStream{
@@ -337,7 +338,9 @@ func TestConnOutflowMaxDataRoundRobin(t *testing.T) {
}
s1.Write(make([]byte, 10))
+ s1.Flush()
s2.Write(make([]byte, 10))
+ s2.Flush()
tc.writeFrames(packetType1RTT, debugFrameMaxData{
max: 1,
@@ -378,6 +381,7 @@ func TestConnOutflowMetaAndData(t *testing.T) {
data := makeTestData(32)
s.Write(data)
+ s.Flush()
s.CloseRead()
tc.wantFrame("CloseRead sends a STOP_SENDING, not flow controlled",
@@ -405,6 +409,7 @@ func TestConnOutflowResentData(t *testing.T) {
data := makeTestData(15)
s.Write(data[:8])
+ s.Flush()
tc.wantFrame("data is under MAX_DATA limit, all sent",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -421,6 +426,7 @@ func TestConnOutflowResentData(t *testing.T) {
})
s.Write(data[8:])
+ s.Flush()
tc.wantFrame("new data is sent up to the MAX_DATA limit",
packetType1RTT, debugFrameStream{
id: s.id,
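
Restating the window arithmetic that the rewritten TestConnInflowReturnOnRead
exercises (MaxConnReadBufferSize = 64; every byte consumed by Read is returned
to the peer as MAX_DATA credit):

    // start:                          MAX_DATA = 64
    // peer sends 8,  conn reads 8:    MAX_DATA = 64 + 8      = 72
    // peer sends 64, conn reads 64:   MAX_DATA = 64 + 8 + 64 = 136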
diff --git a/internal/quic/conn_id.go b/quic/conn_id.go
similarity index 96%
rename from internal/quic/conn_id.go
rename to quic/conn_id.go
index 439c22123..2efe8d6b5 100644
--- a/internal/quic/conn_id.go
+++ b/quic/conn_id.go
@@ -76,7 +76,7 @@ func (s *connIDState) initClient(c *Conn) error {
cid: locid,
})
s.nextLocalSeq = 1
- c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
conns.addConnID(c, locid)
})
@@ -117,7 +117,7 @@ func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error {
cid: locid,
})
s.nextLocalSeq = 1
- c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
conns.addConnID(c, dstConnID)
conns.addConnID(c, locid)
})
@@ -194,7 +194,7 @@ func (s *connIDState) issueLocalIDs(c *Conn) error {
s.needSend = true
toIssue--
}
- c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
for _, cid := range newIDs {
conns.addConnID(c, cid)
}
@@ -247,7 +247,7 @@ func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p trans
}
token := statelessResetToken(p.statelessResetToken)
s.remote[0].resetToken = token
- c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
conns.addResetToken(c, token)
})
}
@@ -276,7 +276,7 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte)
// the client. Discard the transient, client-chosen connection ID used
// for Initial packets; the client will never send it again.
cid := s.local[0].cid
- c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
conns.retireConnID(c, cid)
})
s.local = append(s.local[:0], s.local[1:]...)
@@ -314,7 +314,7 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
rcid := &s.remote[i]
if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo {
s.retireRemote(rcid)
- c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
conns.retireResetToken(c, rcid.resetToken)
})
}
@@ -350,7 +350,7 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
s.retireRemote(&s.remote[len(s.remote)-1])
} else {
active++
- c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
conns.addResetToken(c, resetToken)
})
}
@@ -399,7 +399,7 @@ func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
for i := range s.local {
if s.local[i].seq == seq {
cid := s.local[i].cid
- c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
conns.retireConnID(c, cid)
})
s.local = append(s.local[:i], s.local[i+1:]...)
@@ -463,7 +463,7 @@ func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool {
s.local[i].seq,
retireBefore,
s.local[i].cid,
- c.listener.resetGen.tokenForConnID(s.local[i].cid),
+ c.endpoint.resetGen.tokenForConnID(s.local[i].cid),
) {
return false
}
diff --git a/internal/quic/conn_id_test.go b/quic/conn_id_test.go
similarity index 98%
rename from internal/quic/conn_id_test.go
rename to quic/conn_id_test.go
index 314a6b384..d44472e81 100644
--- a/internal/quic/conn_id_test.go
+++ b/quic/conn_id_test.go
@@ -651,16 +651,16 @@ func TestConnIDsCleanedUpAfterClose(t *testing.T) {
// Wait for the conn to drain.
// Then wait for the conn loop to exit,
// and force an immediate sync of the connsMap updates
- // (normally only done by the listener read loop).
+ // (normally only done by the endpoint read loop).
tc.advanceToTimer()
<-tc.conn.donec
- tc.listener.l.connsMap.applyUpdates()
+ tc.endpoint.e.connsMap.applyUpdates()
- if got := len(tc.listener.l.connsMap.byConnID); got != 0 {
- t.Errorf("%v conn ids in listener map after closing, want 0", got)
+ if got := len(tc.endpoint.e.connsMap.byConnID); got != 0 {
+ t.Errorf("%v conn ids in endpoint map after closing, want 0", got)
}
- if got := len(tc.listener.l.connsMap.byResetToken); got != 0 {
- t.Errorf("%v reset tokens in listener map after closing, want 0", got)
+ if got := len(tc.endpoint.e.connsMap.byResetToken); got != 0 {
+ t.Errorf("%v reset tokens in endpoint map after closing, want 0", got)
}
})
}
diff --git a/internal/quic/conn_loss.go b/quic/conn_loss.go
similarity index 96%
rename from internal/quic/conn_loss.go
rename to quic/conn_loss.go
index 85bda314e..623ebdd7c 100644
--- a/internal/quic/conn_loss.go
+++ b/quic/conn_loss.go
@@ -20,6 +20,10 @@ import "fmt"
// See RFC 9000, Section 13.3 for a complete list of information which is retransmitted on loss.
// https://www.rfc-editor.org/rfc/rfc9000#section-13.3
func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) {
+ if fate == packetLost && c.logEnabled(QLogLevelPacket) {
+ c.logPacketLost(space, sent)
+ }
+
// The list of frames in a sent packet is marshaled into a buffer in the sentPacket
// by the packetWriter. Unmarshal that buffer here. This code must be kept in sync with
// packetWriter.append*.
diff --git a/internal/quic/conn_loss_test.go b/quic/conn_loss_test.go
similarity index 93%
rename from internal/quic/conn_loss_test.go
rename to quic/conn_loss_test.go
index 5144be6ac..81d537803 100644
--- a/internal/quic/conn_loss_test.go
+++ b/quic/conn_loss_test.go
@@ -183,7 +183,7 @@ func TestLostStreamFrameEmpty(t *testing.T) {
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ c.Flush() // open the stream
tc.wantFrame("created bidirectional stream 0",
packetType1RTT, debugFrameStream{
id: newStreamID(clientSide, bidiStream, 0),
@@ -213,6 +213,7 @@ func TestLostStreamWithData(t *testing.T) {
p.initialMaxStreamDataUni = 1 << 20
})
s.Write(data[:4])
+ s.Flush()
tc.wantFrame("send [0,4)",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -220,6 +221,7 @@ func TestLostStreamWithData(t *testing.T) {
data: data[:4],
})
s.Write(data[4:8])
+ s.Flush()
tc.wantFrame("send [4,8)",
packetType1RTT, debugFrameStream{
id: s.id,
@@ -263,6 +265,7 @@ func TestLostStreamPartialLoss(t *testing.T) {
})
for i := range data {
s.Write(data[i : i+1])
+ s.Flush()
tc.wantFrame(fmt.Sprintf("send STREAM frame with byte %v", i),
packetType1RTT, debugFrameStream{
id: s.id,
@@ -305,9 +308,9 @@ func TestLostMaxDataFrame(t *testing.T) {
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
off: 0,
- data: make([]byte, maxWindowSize),
+ data: make([]byte, maxWindowSize-1),
})
- if n, err := s.Read(buf[:maxWindowSize-1]); err != nil || n != maxWindowSize-1 {
+ if n, err := s.Read(buf[:maxWindowSize]); err != nil || n != maxWindowSize-1 {
t.Fatalf("Read() = %v, %v; want %v, nil", n, err, maxWindowSize-1)
}
tc.wantFrame("conn window is extended after reading data",
@@ -316,7 +319,12 @@ func TestLostMaxDataFrame(t *testing.T) {
})
// MAX_DATA = 64, which is only one more byte, so we don't send the frame.
- if n, err := s.Read(buf); err != nil || n != 1 {
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ off: maxWindowSize - 1,
+ data: make([]byte, 1),
+ })
+ if n, err := s.Read(buf[:1]); err != nil || n != 1 {
t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1)
}
tc.wantIdle("read doesn't extend window enough to send another MAX_DATA")
@@ -345,9 +353,9 @@ func TestLostMaxStreamDataFrame(t *testing.T) {
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
off: 0,
- data: make([]byte, maxWindowSize),
+ data: make([]byte, maxWindowSize-1),
})
- if n, err := s.Read(buf[:maxWindowSize-1]); err != nil || n != maxWindowSize-1 {
+ if n, err := s.Read(buf[:maxWindowSize]); err != nil || n != maxWindowSize-1 {
t.Fatalf("Read() = %v, %v; want %v, nil", n, err, maxWindowSize-1)
}
tc.wantFrame("stream window is extended after reading data",
@@ -357,6 +365,11 @@ func TestLostMaxStreamDataFrame(t *testing.T) {
})
// MAX_STREAM_DATA = 64, which is only one more byte, so we don't send the frame.
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ off: maxWindowSize - 1,
+ data: make([]byte, 1),
+ })
if n, err := s.Read(buf); err != nil || n != 1 {
t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1)
}
@@ -430,7 +443,8 @@ func TestLostMaxStreamsFrameMostRecent(t *testing.T) {
if err != nil {
t.Fatalf("AcceptStream() = %v", err)
}
- s.CloseContext(ctx)
+ s.SetWriteContext(ctx)
+ s.Close()
if styp == bidiStream {
tc.wantFrame("stream is closed",
packetType1RTT, debugFrameStream{
@@ -477,7 +491,7 @@ func TestLostMaxStreamsFrameNotMostRecent(t *testing.T) {
if err != nil {
t.Fatalf("AcceptStream() = %v", err)
}
- if err := s.CloseContext(ctx); err != nil {
+ if err := s.Close(); err != nil {
t.Fatalf("stream.Close() = %v", err)
}
tc.wantFrame("closing stream updates peer's MAX_STREAMS",
@@ -509,7 +523,7 @@ func TestLostStreamDataBlockedFrame(t *testing.T) {
})
w := runAsync(tc, func(ctx context.Context) (int, error) {
- return s.WriteContext(ctx, []byte{0, 1, 2, 3})
+ return s.Write([]byte{0, 1, 2, 3})
})
defer w.cancel()
tc.wantFrame("write is blocked by flow control",
@@ -561,7 +575,7 @@ func TestLostStreamDataBlockedFrameAfterStreamUnblocked(t *testing.T) {
data := []byte{0, 1, 2, 3}
w := runAsync(tc, func(ctx context.Context) (int, error) {
- return s.WriteContext(ctx, data)
+ return s.Write(data)
})
defer w.cancel()
tc.wantFrame("write is blocked by flow control",
@@ -649,6 +663,29 @@ func TestLostRetireConnectionIDFrame(t *testing.T) {
})
}
+func TestLostPathResponseFrame(t *testing.T) {
+ // "Responses to path validation using PATH_RESPONSE frames are sent just once."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.12
+ lostFrameTest(t, func(t *testing.T, pto bool) {
+ tc := newTestConn(t, clientSide)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypePing)
+
+ data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
+ tc.writeFrames(packetType1RTT, debugFramePathChallenge{
+ data: data,
+ })
+ tc.wantFrame("response to PATH_CHALLENGE",
+ packetType1RTT, debugFramePathResponse{
+ data: data,
+ })
+
+ tc.triggerLossOrPTO(packetType1RTT, pto)
+ tc.wantIdle("lost PATH_RESPONSE frame is not retransmitted")
+ })
+}
+
func TestLostHandshakeDoneFrame(t *testing.T) {
// "The HANDSHAKE_DONE frame MUST be retransmitted until it is acknowledged."
// https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.16
diff --git a/internal/quic/conn_recv.go b/quic/conn_recv.go
similarity index 84%
rename from internal/quic/conn_recv.go
rename to quic/conn_recv.go
index 896c6d74e..b1354cd3a 100644
--- a/internal/quic/conn_recv.go
+++ b/quic/conn_recv.go
@@ -13,11 +13,28 @@ import (
"time"
)
-func (c *Conn) handleDatagram(now time.Time, dgram *datagram) {
+func (c *Conn) handleDatagram(now time.Time, dgram *datagram) (handled bool) {
+ if !c.localAddr.IsValid() {
+ // We don't have any way to tell in the general case what address we're
+ // sending packets from. Set our address from the destination address of
+ // the first packet received from the peer.
+ c.localAddr = dgram.localAddr
+ }
+ if dgram.peerAddr.IsValid() && dgram.peerAddr != c.peerAddr {
+ if c.side == clientSide {
+ // "If a client receives packets from an unknown server address,
+ // the client MUST discard these packets."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-9-6
+ return false
+ }
+ // We currently don't support connection migration,
+ // so for now the server also drops packets from an unknown address.
+ return false
+ }
buf := dgram.b
c.loss.datagramReceived(now, len(buf))
if c.isDraining() {
- return
+ return false
}
for len(buf) > 0 {
var n int
@@ -27,19 +44,19 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) {
if c.side == serverSide && len(dgram.b) < paddedInitialDatagramSize {
// Discard client-sent Initial packets in too-short datagrams.
// https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4
- return
+ return false
}
- n = c.handleLongHeader(now, ptype, initialSpace, c.keysInitial.r, buf)
+ n = c.handleLongHeader(now, dgram, ptype, initialSpace, c.keysInitial.r, buf)
case packetTypeHandshake:
- n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf)
+ n = c.handleLongHeader(now, dgram, ptype, handshakeSpace, c.keysHandshake.r, buf)
case packetType1RTT:
- n = c.handle1RTT(now, buf)
+ n = c.handle1RTT(now, dgram, buf)
case packetTypeRetry:
c.handleRetry(now, buf)
- return
+ return true
case packetTypeVersionNegotiation:
c.handleVersionNegotiation(now, buf)
- return
+ return true
default:
n = -1
}
@@ -56,17 +73,20 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) {
if len(buf) == len(dgram.b) && len(buf) > statelessResetTokenLen {
var token statelessResetToken
copy(token[:], buf[len(buf)-len(token):])
- c.handleStatelessReset(now, token)
+ if c.handleStatelessReset(now, token) {
+ return true
+ }
}
// Invalid data at the end of a datagram is ignored.
- break
+ return false
}
- c.idleTimeout = now.Add(c.maxIdleTimeout)
+ c.idleHandlePacketReceived(now)
buf = buf[n:]
}
+ return true
}
-func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int {
+func (c *Conn) handleLongHeader(now time.Time, dgram *datagram, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int {
if !k.isSet() {
return skipLongHeaderPacket(buf)
}
@@ -101,8 +121,11 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa
if logPackets {
logInboundLongPacket(c, p)
}
+ if c.logEnabled(QLogLevelPacket) {
+ c.logLongPacketReceived(p, buf[:n])
+ }
c.connIDState.handlePacket(c, p.ptype, p.srcConnID)
- ackEliciting := c.handleFrames(now, ptype, space, p.payload)
+ ackEliciting := c.handleFrames(now, dgram, ptype, space, p.payload)
c.acks[space].receive(now, space, p.num, ackEliciting)
if p.ptype == packetTypeHandshake && c.side == serverSide {
c.loss.validateClientAddress()
@@ -115,7 +138,7 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa
return n
}
-func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
+func (c *Conn) handle1RTT(now time.Time, dgram *datagram, buf []byte) int {
if !c.keysAppData.canRead() {
// 1-RTT packets extend to the end of the datagram,
// so skip the remainder of the datagram if we can't parse this.
@@ -149,7 +172,10 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
if logPackets {
logInboundShortPacket(c, p)
}
- ackEliciting := c.handleFrames(now, packetType1RTT, appDataSpace, p.payload)
+ if c.logEnabled(QLogLevelPacket) {
+ c.log1RTTPacketReceived(p, buf)
+ }
+ ackEliciting := c.handleFrames(now, dgram, packetType1RTT, appDataSpace, p.payload)
c.acks[appDataSpace].receive(now, appDataSpace, p.num, ackEliciting)
return len(buf)
}
@@ -186,7 +212,7 @@ func (c *Conn) handleRetry(now time.Time, pkt []byte) {
c.connIDState.handleRetryPacket(p.srcConnID)
// We need to resend any data we've already sent in Initial packets.
// We must not reuse already sent packet numbers.
- c.loss.discardPackets(initialSpace, c.handleAckOrLoss)
+ c.loss.discardPackets(initialSpace, c.log, c.handleAckOrLoss)
// TODO: Discard 0-RTT packets as well, once we support 0-RTT.
}
@@ -226,7 +252,7 @@ func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) {
c.abortImmediately(now, errVersionNegotiation)
}
-func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) {
+func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) {
if len(payload) == 0 {
// "An endpoint MUST treat receipt of a packet containing no frames
// as a connection error of type PROTOCOL_VIOLATION."
@@ -347,6 +373,16 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
return
}
n = c.handleRetireConnectionIDFrame(now, space, payload)
+ case frameTypePathChallenge:
+ if !frameOK(c, ptype, __01) {
+ return
+ }
+ n = c.handlePathChallengeFrame(now, dgram, space, payload)
+ case frameTypePathResponse:
+ if !frameOK(c, ptype, ___1) {
+ return
+ }
+ n = c.handlePathResponseFrame(now, space, payload)
case frameTypeConnectionCloseTransport:
// Transport CONNECTION_CLOSE is OK in all spaces.
n = c.handleConnectionCloseTransportFrame(now, payload)
@@ -410,7 +446,7 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte)
if c.peerAckDelayExponent >= 0 {
delay = ackDelay.Duration(uint8(c.peerAckDelayExponent))
}
- c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss)
+ c.loss.receiveAckEnd(now, c.log, space, delay, c.handleAckOrLoss)
if space == appDataSpace {
c.keysAppData.handleAckFor(largest)
}
@@ -520,12 +556,30 @@ func (c *Conn) handleRetireConnectionIDFrame(now time.Time, space numberSpace, p
return n
}
+func (c *Conn) handlePathChallengeFrame(now time.Time, dgram *datagram, space numberSpace, payload []byte) int {
+ data, n := consumePathChallengeFrame(payload)
+ if n < 0 {
+ return -1
+ }
+ c.handlePathChallenge(now, dgram, data)
+ return n
+}
+
+func (c *Conn) handlePathResponseFrame(now time.Time, space numberSpace, payload []byte) int {
+ data, n := consumePathResponseFrame(payload)
+ if n < 0 {
+ return -1
+ }
+ c.handlePathResponse(now, data)
+ return n
+}
+
func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte) int {
code, _, reason, n := consumeConnectionCloseTransportFrame(payload)
if n < 0 {
return -1
}
- c.enterDraining(now, peerTransportError{code: code, reason: reason})
+ c.handlePeerConnectionClose(now, peerTransportError{code: code, reason: reason})
return n
}
@@ -534,7 +588,7 @@ func (c *Conn) handleConnectionCloseApplicationFrame(now time.Time, payload []by
if n < 0 {
return -1
}
- c.enterDraining(now, &ApplicationError{Code: code, Reason: reason})
+ c.handlePeerConnectionClose(now, &ApplicationError{Code: code, Reason: reason})
return n
}
@@ -548,7 +602,7 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa
})
return -1
}
- if !c.isClosingOrDraining() {
+ if c.isAlive() {
c.confirmHandshake(now)
}
return 1
@@ -556,9 +610,11 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa
var errStatelessReset = errors.New("received stateless reset")
-func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) {
+func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) (valid bool) {
if !c.connIDState.isValidStatelessResetToken(resetToken) {
- return
+ return false
}
- c.enterDraining(now, errStatelessReset)
+ c.setFinalError(errStatelessReset)
+ c.enterDraining(now)
+ return true
}
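
The new address check at the top of handleDatagram reduces to a single
predicate. A hypothetical standalone restatement (the actual implementation
inlines the check):

    // dropFromUnknownSender reports whether a datagram arrived from an
    // address other than the known peer address. Clients MUST discard such
    // packets (RFC 9000, Section 9); this server drops them too, since
    // connection migration is not yet supported.
    func dropFromUnknownSender(c *Conn, dgram *datagram) bool {
        return dgram.peerAddr.IsValid() && dgram.peerAddr != c.peerAddr
    }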
diff --git a/internal/quic/conn_send.go b/quic/conn_send.go
similarity index 82%
rename from internal/quic/conn_send.go
rename to quic/conn_send.go
index 22e780479..a87cac232 100644
--- a/internal/quic/conn_send.go
+++ b/quic/conn_send.go
@@ -22,7 +22,10 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// Assumption: The congestion window is not underutilized.
// If congestion control, pacing, and anti-amplification all permit sending,
// but we have no packet to send, then we will declare the window underutilized.
- c.loss.cc.setUnderutilized(false)
+ underutilized := false
+ defer func() {
+ c.loss.cc.setUnderutilized(c.log, underutilized)
+ }()
// Send one datagram on each iteration of this loop,
// until we hit a limit or run out of data to send.
@@ -60,7 +63,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
pad := false
var sentInitial *sentPacket
if c.keysInitial.canWrite() {
- pnumMaxAcked := c.acks[initialSpace].largestSeen()
+ pnumMaxAcked := c.loss.spaces[initialSpace].maxAcked
pnum := c.loss.nextNumber(initialSpace)
p := longPacket{
ptype: packetTypeInitial,
@@ -75,6 +78,9 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
if logPackets {
logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload())
}
+ if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 {
+ c.logPacketSent(packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload())
+ }
sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p)
if sentInitial != nil {
// Client initial packets and ack-eliciting server initial packets
@@ -89,7 +95,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// Handshake packet.
if c.keysHandshake.canWrite() {
- pnumMaxAcked := c.acks[handshakeSpace].largestSeen()
+ pnumMaxAcked := c.loss.spaces[handshakeSpace].maxAcked
pnum := c.loss.nextNumber(handshakeSpace)
p := longPacket{
ptype: packetTypeHandshake,
@@ -103,8 +109,11 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
if logPackets {
logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload())
}
+ if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 {
+ c.logPacketSent(packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload())
+ }
if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil {
- c.loss.packetSent(now, handshakeSpace, sent)
+ c.packetSent(now, handshakeSpace, sent)
if c.side == clientSide {
// "[...] a client MUST discard Initial keys when it first
// sends a Handshake packet [...]"
@@ -116,7 +125,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// 1-RTT packet.
if c.keysAppData.canWrite() {
- pnumMaxAcked := c.acks[appDataSpace].largestSeen()
+ pnumMaxAcked := c.loss.spaces[appDataSpace].maxAcked
pnum := c.loss.nextNumber(appDataSpace)
c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID)
c.appendFrames(now, appDataSpace, pnum, limit)
@@ -130,8 +139,11 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
if logPackets {
logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload())
}
+ if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 {
+ c.logPacketSent(packetType1RTT, pnum, nil, dstConnID, c.w.packetLen(), c.w.payload())
+ }
if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil {
- c.loss.packetSent(now, appDataSpace, sent)
+ c.packetSent(now, appDataSpace, sent)
}
}
@@ -140,7 +152,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
if limit == ccOK {
// We have nothing to send, and congestion control does not
// block sending. The congestion window is underutilized.
- c.loss.cc.setUnderutilized(true)
+ underutilized = true
}
return next
}
@@ -163,14 +175,22 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
// with a Handshake packet, then we've discarded Initial keys
// since constructing the packet and shouldn't record it as in-flight.
if c.keysInitial.canWrite() {
- c.loss.packetSent(now, initialSpace, sentInitial)
+ c.packetSent(now, initialSpace, sentInitial)
}
}
- c.listener.sendDatagram(buf, c.peerAddr)
+ c.endpoint.sendDatagram(datagram{
+ b: buf,
+ peerAddr: c.peerAddr,
+ })
}
}
+func (c *Conn) packetSent(now time.Time, space numberSpace, sent *sentPacket) {
+ c.idleHandlePacketSent(now, sent)
+ c.loss.packetSent(now, c.log, space, sent)
+}
+
func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, limit ccLimit) {
if c.lifetime.localErr != nil {
c.appendConnectionCloseFrame(now, space, c.lifetime.localErr)
@@ -210,11 +230,7 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber,
// Either we are willing to send an ACK-only packet,
// or we've added additional frames.
c.acks[space].sentAck()
- if !c.w.sent.ackEliciting && c.keysAppData.needAckEliciting() {
- // The peer has initiated a key update.
- // We haven't sent them any packets yet in the new phase.
- // Make this an ack-eliciting packet.
- // Their ack of this packet will complete the key update.
+ if !c.w.sent.ackEliciting && c.shouldMakePacketAckEliciting() {
c.w.appendPingFrame()
}
}()
@@ -255,12 +271,23 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber,
return
}
+ // PATH_RESPONSE
+ if pad, ok := c.appendPathFrames(); !ok {
+ return
+ } else if pad {
+ defer c.w.appendPaddingTo(smallestMaxDatagramSize)
+ }
+
// All stream-related frames. This should come last in the packet,
// so large amounts of STREAM data don't crowd out other frames
// we may need to send.
if !c.appendStreamFrames(&c.w, pnum, pto) {
return
}
+
+ if !c.appendKeepAlive(now) {
+ return
+ }
}
// If this is a PTO probe and we haven't added an ack-eliciting frame yet,
@@ -315,6 +342,30 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber,
}
}
+// shouldMakePacketAckEliciting is called when sending a packet containing nothing but an ACK frame.
+// It reports whether we should add a PING frame to the packet to make it ack-eliciting.
+func (c *Conn) shouldMakePacketAckEliciting() bool {
+ if c.keysAppData.needAckEliciting() {
+ // The peer has initiated a key update.
+ // We haven't sent them any packets yet in the new phase.
+ // Make this an ack-eliciting packet.
+ // Their ack of this packet will complete the key update.
+ return true
+ }
+ if c.loss.consecutiveNonAckElicitingPackets >= 19 {
+ // We've sent a run of non-ack-eliciting packets.
+ // Add in an ack-eliciting one every once in a while so the peer
+ // lets us know which ones have arrived.
+ //
+ // Google QUICHE injects a PING after sending 19 packets. We do the same.
+ //
+ // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.4-2
+ return true
+ }
+ // TODO: Consider making every packet sent when in PTO ack-eliciting to speed up recovery.
+ return false
+}
+
func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool {
seen, delay := c.acks[space].acksToSend(now)
if len(seen) == 0 {
@@ -325,7 +376,7 @@ func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool {
}
func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err error) {
- c.lifetime.connCloseSentTime = now
+ c.sentConnectionClose(now)
switch e := err.(type) {
case localTransportError:
c.w.appendConnectionCloseTransportFrame(e.code, 0, e.reason)
@@ -342,11 +393,12 @@ func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err
// TLS alerts are sent using error codes [0x0100,0x01ff).
// https://www.rfc-editor.org/rfc/rfc9000#section-20.1-2.36.1
var alert tls.AlertError
- if errors.As(err, &alert) {
+ switch {
+ case errors.As(err, &alert):
// tls.AlertError is a uint8, so this can't exceed 0x01ff.
code := errTLSBase + transportError(alert)
c.w.appendConnectionCloseTransportFrame(code, 0, "")
- } else {
+ default:
c.w.appendConnectionCloseTransportFrame(errInternal, 0, "")
}
}
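
The "PING after 19" rule in shouldMakePacketAckEliciting depends on a counter
kept on the loss state. A toy model of the assumed counter semantics (reset
behavior inferred from the field name; the real bookkeeping lives in the loss
tracker):

    type pingPolicy struct{ consecutiveNonAckEliciting int }

    // packetSent updates the counter after each packet goes out.
    func (p *pingPolicy) packetSent(ackEliciting bool) {
        if ackEliciting {
            p.consecutiveNonAckEliciting = 0
        } else {
            p.consecutiveNonAckEliciting++
        }
    }

    // shouldAddPing mirrors the threshold used above.
    func (p *pingPolicy) shouldAddPing() bool {
        return p.consecutiveNonAckEliciting >= 19
    }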
diff --git a/quic/conn_send_test.go b/quic/conn_send_test.go
new file mode 100644
index 000000000..2205ff2f7
--- /dev/null
+++ b/quic/conn_send_test.go
@@ -0,0 +1,83 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "testing"
+ "time"
+)
+
+func TestAckElicitingAck(t *testing.T) {
+ // "A receiver that sends only non-ack-eliciting packets [...] might not receive
+ // an acknowledgment for a long period of time.
+ // [...] a receiver could send a [...] ack-eliciting frame occasionally [...]
+ // to elicit an ACK from the peer."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.4-2
+ //
+ // Send a bunch of ack-eliciting packets, verify that the conn doesn't just
+ // send ACKs in response.
+ tc := newTestConn(t, clientSide, permissiveTransportParameters)
+ tc.handshake()
+ const count = 100
+ for i := 0; i < count; i++ {
+ tc.advance(1 * time.Millisecond)
+ tc.writeFrames(packetType1RTT,
+ debugFramePing{},
+ )
+ got, _ := tc.readFrame()
+ switch got.(type) {
+ case debugFrameAck:
+ continue
+ case debugFramePing:
+ return
+ }
+ }
+ t.Errorf("after sending %v PINGs, got no ack-eliciting response", count)
+}
+
+func TestSendPacketNumberSize(t *testing.T) {
+ tc := newTestConn(t, clientSide, permissiveTransportParameters)
+ tc.handshake()
+
+ recvPing := func() *testPacket {
+ t.Helper()
+ tc.conn.ping(appDataSpace)
+ p := tc.readPacket()
+ if p == nil {
+ t.Fatalf("want packet containing PING, got none")
+ }
+ return p
+ }
+
+ // Desynchronize the packet numbers the conn is sending and the ones it is receiving,
+ // by having the conn send a number of unacked packets.
+ for i := 0; i < 16; i++ {
+ recvPing()
+ }
+
+ // Establish the maximum packet number the conn has received an ACK for.
+ maxAcked := recvPing().num
+ tc.writeAckForAll()
+
+ // Make the conn send a sequence of packets.
+ // Check that the packet number is encoded with two bytes once the difference between the
+ // current packet and the max acked one is sufficiently large.
+ for want := maxAcked + 1; want < maxAcked+0x100; want++ {
+ p := recvPing()
+ if p.num != want {
+ t.Fatalf("received packet number %v, want %v", p.num, want)
+ }
+ gotPnumLen := int(p.header&0x03) + 1
+ wantPnumLen := 1
+ if p.num-maxAcked >= 0x80 {
+ wantPnumLen = 2
+ }
+ if gotPnumLen != wantPnumLen {
+ t.Fatalf("packet number 0x%x encoded with %v bytes, want %v (max acked = %v)", p.num, gotPnumLen, wantPnumLen, maxAcked)
+ }
+ }
+}
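
The truncation rule TestSendPacketNumberSize checks can be sketched as a
standalone function (hypothetical helper; the test only pins down the
one-to-two byte boundary at 0x80, and the higher thresholds extrapolate the
same pattern from RFC 9000, Section 17.1):

    // packetNumberLength reports the expected truncated encoding size for
    // pnum given the largest packet number the peer has acknowledged.
    func packetNumberLength(pnum, maxAcked packetNumber) int {
        switch d := pnum - maxAcked; {
        case d < 0x80:
            return 1
        case d < 0x8000:
            return 2
        case d < 0x800000:
            return 3
        default:
            return 4
        }
    }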
diff --git a/internal/quic/conn_streams.go b/quic/conn_streams.go
similarity index 91%
rename from internal/quic/conn_streams.go
rename to quic/conn_streams.go
index 83ab5554c..87cfd297e 100644
--- a/internal/quic/conn_streams.go
+++ b/quic/conn_streams.go
@@ -16,8 +16,14 @@ import (
type streamsState struct {
queue queue[*Stream] // new, peer-created streams
- streamsMu sync.Mutex
- streams map[streamID]*Stream
+ // All peer-created streams.
+ //
+ // Implicitly created streams are included as an empty entry in the map.
+ // (For example, if we receive a frame for stream 4, we implicitly create stream 0 and
+ // insert an empty entry for it to the map.)
+ //
+ // The map value is maybeStream rather than *Stream as a reminder that values can be nil.
+ streams map[streamID]maybeStream
// Limits on the number of streams, indexed by streamType.
localLimit [streamTypeCount]localStreamLimits
@@ -39,8 +45,13 @@ type streamsState struct {
queueData streamRing // streams with only flow-controlled frames
}
+// maybeStream is a possibly nil *Stream. See streamsState.streams.
+type maybeStream struct {
+ s *Stream
+}
+
func (c *Conn) streamsInit() {
- c.streams.streams = make(map[streamID]*Stream)
+ c.streams.streams = make(map[streamID]maybeStream)
c.streams.queue = newQueue[*Stream]()
c.streams.localLimit[bidiStream].init()
c.streams.localLimit[uniStream].init()
@@ -49,6 +60,17 @@ func (c *Conn) streamsInit() {
c.inflowInit()
}
+func (c *Conn) streamsCleanup() {
+ c.streams.queue.close(errConnClosed)
+ c.streams.localLimit[bidiStream].connHasClosed()
+ c.streams.localLimit[uniStream].connHasClosed()
+ for _, s := range c.streams.streams {
+ if s.s != nil {
+ s.s.connHasClosed()
+ }
+ }
+}
+
// AcceptStream waits for and returns the next stream created by the peer.
func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) {
return c.streams.queue.get(ctx, c.testHooks)
@@ -71,9 +93,6 @@ func (c *Conn) NewSendOnlyStream(ctx context.Context) (*Stream, error) {
}
func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, error) {
- c.streams.streamsMu.Lock()
- defer c.streams.streamsMu.Unlock()
-
num, err := c.streams.localLimit[styp].open(ctx, c)
if err != nil {
return nil, err
@@ -89,7 +108,12 @@ func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, er
s.inUnlock()
s.outUnlock()
- c.streams.streams[s.id] = s
+ // Modify c.streams on the conn's loop.
+ if err := c.runOnLoop(ctx, func(now time.Time, c *Conn) {
+ c.streams.streams[s.id] = maybeStream{s}
+ }); err != nil {
+ return nil, err
+ }
return s, nil
}
@@ -108,9 +132,7 @@ const (
// streamForID returns the stream with the given id.
// If the stream does not exist, it returns nil.
func (c *Conn) streamForID(id streamID) *Stream {
- c.streams.streamsMu.Lock()
- defer c.streams.streamsMu.Unlock()
- return c.streams.streams[id]
+ return c.streams.streams[id].s
}
// streamForFrame returns the stream with the given id.
@@ -135,11 +157,9 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
}
}
- c.streams.streamsMu.Lock()
- defer c.streams.streamsMu.Unlock()
- s, isOpen := c.streams.streams[id]
- if s != nil {
- return s
+ ms, isOpen := c.streams.streams[id]
+ if ms.s != nil {
+ return ms.s
}
num := id.num()
@@ -176,10 +196,10 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
// with the same initiator and type and a lower number.
// Add a nil entry to the streams map for each implicitly created stream.
for n := newStreamID(id.initiator(), id.streamType(), prevOpened); n < id; n += 4 {
- c.streams.streams[n] = nil
+ c.streams.streams[n] = maybeStream{}
}
- s = newStream(c, id)
+ s := newStream(c, id)
s.inmaxbuf = c.config.maxStreamReadBufferSize()
s.inwin = c.config.maxStreamReadBufferSize()
if id.streamType() == bidiStream {
@@ -189,7 +209,7 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
s.inUnlock()
s.outUnlock()
- c.streams.streams[id] = s
+ c.streams.streams[id] = maybeStream{s}
c.streams.queue.put(s)
return s
}
@@ -393,7 +413,11 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool {
c.streams.sendMu.Lock()
defer c.streams.sendMu.Unlock()
const pto = true
- for _, s := range c.streams.streams {
+ for _, ms := range c.streams.streams {
+ s := ms.s
+ if s == nil {
+ continue
+ }
const pto = true
s.ingate.lock()
inOK := s.appendInFramesLocked(w, pnum, pto)
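
The maybeStream wrapper makes three map states distinguishable at lookup time.
A sketch of the semantics streamForFrame relies on above:

    ms, ok := c.streams.streams[id]
    switch {
    case !ok:
        // Stream was never created, explicitly or implicitly.
    case ms.s == nil:
        // Stream was implicitly created when a higher-numbered stream
        // of the same type was opened; no *Stream exists yet.
    default:
        // Fully created stream: use ms.s.
    }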
diff --git a/internal/quic/conn_streams_test.go b/quic/conn_streams_test.go
similarity index 84%
rename from internal/quic/conn_streams_test.go
rename to quic/conn_streams_test.go
index 69f982c3a..dc81ad991 100644
--- a/internal/quic/conn_streams_test.go
+++ b/quic/conn_streams_test.go
@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"math"
+ "sync"
"testing"
)
@@ -19,33 +20,33 @@ func TestStreamsCreate(t *testing.T) {
tc := newTestConn(t, clientSide, permissiveTransportParameters)
tc.handshake()
- c, err := tc.conn.NewStream(ctx)
+ s, err := tc.conn.NewStream(ctx)
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ s.Flush() // open the stream
tc.wantFrame("created bidirectional stream 0",
packetType1RTT, debugFrameStream{
id: 0, // client-initiated, bidi, number 0
data: []byte{},
})
- c, err = tc.conn.NewSendOnlyStream(ctx)
+ s, err = tc.conn.NewSendOnlyStream(ctx)
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ s.Flush() // open the stream
tc.wantFrame("created unidirectional stream 0",
packetType1RTT, debugFrameStream{
id: 2, // client-initiated, uni, number 0
data: []byte{},
})
- c, err = tc.conn.NewStream(ctx)
+ s, err = tc.conn.NewStream(ctx)
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ s.Flush() // open the stream
tc.wantFrame("created bidirectional stream 1",
packetType1RTT, debugFrameStream{
id: 4, // client-initiated, bidi, number 1
@@ -177,11 +178,11 @@ func TestStreamsStreamSendOnly(t *testing.T) {
tc := newTestConn(t, serverSide, permissiveTransportParameters)
tc.handshake()
- c, err := tc.conn.NewSendOnlyStream(ctx)
+ s, err := tc.conn.NewSendOnlyStream(ctx)
if err != nil {
t.Fatalf("NewStream: %v", err)
}
- c.Write(nil) // open the stream
+ s.Flush() // open the stream
tc.wantFrame("created unidirectional stream 0",
packetType1RTT, debugFrameStream{
id: 3, // server-initiated, uni, number 0
@@ -229,8 +230,8 @@ func TestStreamsWriteQueueFairness(t *testing.T) {
t.Fatal(err)
}
streams = append(streams, s)
- if n, err := s.WriteContext(ctx, data); n != len(data) || err != nil {
- t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", n, err, len(data))
+ if n, err := s.Write(data); n != len(data) || err != nil {
+ t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data))
}
// Wait for the stream to finish writing whatever frames it can before
// congestion control blocks it.
@@ -297,7 +298,7 @@ func TestStreamsShutdown(t *testing.T) {
side: localStream,
styp: uniStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
- s.CloseContext(canceledContext())
+ s.Close()
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
tc.writeAckForAll()
@@ -310,7 +311,7 @@ func TestStreamsShutdown(t *testing.T) {
tc.writeFrames(packetType1RTT, debugFrameResetStream{
id: s.id,
})
- s.CloseContext(canceledContext())
+ s.Close()
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
tc.writeAckForAll()
@@ -320,8 +321,8 @@ func TestStreamsShutdown(t *testing.T) {
side: localStream,
styp: bidiStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
- s.CloseContext(canceledContext())
- tc.wantIdle("all frames after CloseContext are ignored")
+ s.Close()
+ tc.wantIdle("all frames after Close are ignored")
tc.writeAckForAll()
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
@@ -334,13 +335,12 @@ func TestStreamsShutdown(t *testing.T) {
side: remoteStream,
styp: uniStream,
setup: func(t *testing.T, tc *testConn, s *Stream) {
- ctx := canceledContext()
tc.writeFrames(packetType1RTT, debugFrameStream{
id: s.id,
fin: true,
})
- if n, err := s.ReadContext(ctx, make([]byte, 16)); n != 0 || err != io.EOF {
- t.Errorf("ReadContext() = %v, %v; want 0, io.EOF", n, err)
+ if n, err := s.Read(make([]byte, 16)); n != 0 || err != io.EOF {
+ t.Errorf("Read() = %v, %v; want 0, io.EOF", n, err)
}
},
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
@@ -450,17 +450,14 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) {
id: op.id,
})
case acceptOp:
- s, err := tc.conn.AcceptStream(ctx)
- if err != nil {
- t.Fatalf("AcceptStream() = %q; want stream %v", err, stringID(op.id))
- }
+ s := tc.acceptStream()
if s.id != op.id {
- t.Fatalf("accepted stram %v; want stream %v", err, stringID(op.id))
+ t.Fatalf("accepted stream %v; want stream %v", stringID(s.id), stringID(op.id))
}
t.Logf("accepted stream %v", stringID(op.id))
// Immediately close the stream, so the stream becomes done when the
// peer closes its end.
- s.CloseContext(ctx)
+ s.Close()
}
p := tc.readPacket()
if p != nil {
@@ -478,3 +475,85 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) {
t.Fatalf("after test, stream send queue is not empty; should be")
}
}
+
+func TestStreamsCreateConcurrency(t *testing.T) {
+ cli, srv := newLocalConnPair(t, &Config{}, &Config{})
+
+ srvdone := make(chan int)
+ go func() {
+ defer close(srvdone)
+ for streams := 0; ; streams++ {
+ s, err := srv.AcceptStream(context.Background())
+ if err != nil {
+ srvdone <- streams
+ return
+ }
+ s.Close()
+ }
+ }()
+
+ var wg sync.WaitGroup
+ const concurrency = 10
+ const streams = 10
+ for i := 0; i < concurrency; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < streams; j++ {
+ s, err := cli.NewStream(context.Background())
+ if err != nil {
+ t.Errorf("NewStream: %v", err)
+ return
+ }
+ s.Flush()
+ _, err = io.ReadAll(s)
+ if err != nil {
+ t.Errorf("ReadFull: %v", err)
+ }
+ s.Close()
+ }
+ }()
+ }
+ wg.Wait()
+
+ cli.Abort(nil)
+ srv.Abort(nil)
+ if got, want := <-srvdone, concurrency*streams; got != want {
+ t.Errorf("accepted %v streams, want %v", got, want)
+ }
+}
+
+func TestStreamsPTOWithImplicitStream(t *testing.T) {
+ ctx := canceledContext()
+ tc := newTestConn(t, serverSide, permissiveTransportParameters)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+
+ // Peer creates stream 1, and implicitly creates stream 0.
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: newStreamID(clientSide, bidiStream, 1),
+ })
+
+ // We accept stream 1 and write data to it.
+ data := []byte("data")
+ s, err := tc.conn.AcceptStream(ctx)
+ if err != nil {
+ t.Fatalf("conn.AcceptStream() = %v, want stream", err)
+ }
+ s.Write(data)
+ s.Flush()
+ tc.wantFrame("data written to stream",
+ packetType1RTT, debugFrameStream{
+ id: newStreamID(clientSide, bidiStream, 1),
+ data: data,
+ })
+
+ // PTO expires, and the data is resent.
+ const pto = true
+ tc.triggerLossOrPTO(packetType1RTT, pto)
+ tc.wantFrame("data resent after PTO expires",
+ packetType1RTT, debugFrameStream{
+ id: newStreamID(clientSide, bidiStream, 1),
+ data: data,
+ })
+}
diff --git a/internal/quic/conn_test.go b/quic/conn_test.go
similarity index 90%
rename from internal/quic/conn_test.go
rename to quic/conn_test.go
index c70c58ef0..f4f1818a6 100644
--- a/internal/quic/conn_test.go
+++ b/quic/conn_test.go
@@ -13,44 +13,56 @@ import (
"errors"
"flag"
"fmt"
+ "log/slog"
"math"
"net/netip"
"reflect"
"strings"
"testing"
"time"
+
+ "golang.org/x/net/quic/qlog"
)
-var testVV = flag.Bool("vv", false, "even more verbose test output")
+var (
+ testVV = flag.Bool("vv", false, "even more verbose test output")
+ qlogdir = flag.String("qlog", "", "write qlog logs to directory")
+)
func TestConnTestConn(t *testing.T) {
tc := newTestConn(t, serverSide)
+ tc.handshake()
if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want {
t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
}
- var ranAt time.Time
- tc.conn.runOnLoop(func(now time.Time, c *Conn) {
- ranAt = now
- })
- if !ranAt.Equal(tc.listener.now) {
- t.Errorf("func ran on loop at %v, want %v", ranAt, tc.listener.now)
+ ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
+ tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
+ when = now
+ })
+ return
+ }).result()
+ if !ranAt.Equal(tc.endpoint.now) {
+ t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
}
tc.wait()
- nextTime := tc.listener.now.Add(defaultMaxIdleTimeout / 2)
+ nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
tc.advanceTo(nextTime)
- tc.conn.runOnLoop(func(now time.Time, c *Conn) {
- ranAt = now
- })
+ ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
+ tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
+ when = now
+ })
+ return
+ }).result()
if !ranAt.Equal(nextTime) {
t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
}
tc.wait()
tc.advanceToTimer()
- if !tc.conn.exited {
- t.Errorf("after advancing to idle timeout, exited = false, want true")
+ if got := tc.conn.lifetime.state; got != connStateDone {
+ t.Errorf("after advancing to idle timeout, conn state = %v, want done", got)
}
}
@@ -76,6 +88,7 @@ func (d testDatagram) String() string {
type testPacket struct {
ptype packetType
+ header byte
version uint32
num packetNumber
keyPhaseBit bool
@@ -116,7 +129,7 @@ const maxTestKeyPhases = 3
type testConn struct {
t *testing.T
conn *Conn
- listener *testListener
+ endpoint *testEndpoint
timer time.Time
timerLastFired time.Time
idlec chan struct{} // only accessed on the conn's loop
@@ -155,6 +168,7 @@ type testConn struct {
sentDatagrams [][]byte
sentPackets []*testPacket
sentFrames []debugFrame
+ lastDatagram *testDatagram
lastPacket *testPacket
recvDatagram chan *datagram
@@ -192,12 +206,17 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
config := &Config{
TLSConfig: newTestTLSConfig(side),
StatelessResetKey: testStatelessResetKey,
+ QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
+ Level: QLogLevelFrame,
+ Dir: *qlogdir,
+ })),
}
var cids newServerConnIDs
if side == serverSide {
// The initial connection ID for the server is chosen by the client.
cids.srcConnID = testPeerConnID(0)
cids.dstConnID = testPeerConnID(-1)
+ cids.originalDstConnID = cids.dstConnID
}
var configTransportParams []func(*transportParameters)
var configTestConn []func(*testConn)
@@ -218,27 +237,29 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
}
}
- listener := newTestListener(t, config)
- listener.configTransportParams = configTransportParams
- listener.configTestConn = configTestConn
- conn, err := listener.l.newConn(
- listener.now,
+ endpoint := newTestEndpoint(t, config)
+ endpoint.configTransportParams = configTransportParams
+ endpoint.configTestConn = configTestConn
+ conn, err := endpoint.e.newConn(
+ endpoint.now,
+ config,
side,
cids,
+ "",
netip.MustParseAddrPort("127.0.0.1:443"))
if err != nil {
t.Fatal(err)
}
- tc := listener.conns[conn]
+ tc := endpoint.conns[conn]
tc.wait()
return tc
}
-func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testConn {
+func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn {
t.Helper()
tc := &testConn{
t: t,
- listener: listener,
+ endpoint: endpoint,
conn: conn,
peerConnID: testPeerConnID(0),
ignoreFrames: map[byte]bool{
@@ -249,14 +270,14 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC
recvDatagram: make(chan *datagram),
}
t.Cleanup(tc.cleanup)
- for _, f := range listener.configTestConn {
+ for _, f := range endpoint.configTestConn {
f(tc)
}
conn.testHooks = (*testConnHooks)(tc)
- if listener.peerTLSConn != nil {
- tc.peerTLSConn = listener.peerTLSConn
- listener.peerTLSConn = nil
+ if endpoint.peerTLSConn != nil {
+ tc.peerTLSConn = endpoint.peerTLSConn
+ endpoint.peerTLSConn = nil
return tc
}
@@ -265,7 +286,7 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC
if conn.side == clientSide {
peerProvidedParams.originalDstConnID = testLocalConnID(-1)
}
- for _, f := range listener.configTransportParams {
+ for _, f := range endpoint.configTransportParams {
f(&peerProvidedParams)
}
@@ -277,6 +298,9 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC
}
tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
tc.peerTLSConn.Start(context.Background())
+ t.Cleanup(func() {
+ tc.peerTLSConn.Close()
+ })
return tc
}
@@ -284,13 +308,13 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC
// advance causes time to pass.
func (tc *testConn) advance(d time.Duration) {
tc.t.Helper()
- tc.listener.advance(d)
+ tc.endpoint.advance(d)
}
// advanceTo sets the current time.
func (tc *testConn) advanceTo(now time.Time) {
tc.t.Helper()
- tc.listener.advanceTo(now)
+ tc.endpoint.advanceTo(now)
}
// advanceToTimer sets the current time to the time of the Conn's next timer event.
@@ -305,10 +329,10 @@ func (tc *testConn) timerDelay() time.Duration {
if tc.timer.IsZero() {
return math.MaxInt64 // infinite
}
- if tc.timer.Before(tc.listener.now) {
+ if tc.timer.Before(tc.endpoint.now) {
return 0
}
- return tc.timer.Sub(tc.listener.now)
+ return tc.timer.Sub(tc.endpoint.now)
}
const infiniteDuration = time.Duration(math.MaxInt64)
@@ -318,10 +342,10 @@ func (tc *testConn) timeUntilEvent() time.Duration {
if tc.timer.IsZero() {
return infiniteDuration
}
- if tc.timer.Before(tc.listener.now) {
+ if tc.timer.Before(tc.endpoint.now) {
return 0
}
- return tc.timer.Sub(tc.listener.now)
+ return tc.timer.Sub(tc.endpoint.now)
}
// wait blocks until the conn becomes idle.
@@ -361,6 +385,17 @@ func (tc *testConn) cleanup() {
<-tc.conn.donec
}
+func (tc *testConn) acceptStream() *Stream {
+ tc.t.Helper()
+ s, err := tc.conn.AcceptStream(canceledContext())
+ if err != nil {
+ tc.t.Fatalf("conn.AcceptStream() = %v, want stream", err)
+ }
+ s.SetReadContext(canceledContext())
+ s.SetWriteContext(canceledContext())
+ return s
+}
+
func logDatagram(t *testing.T, text string, d *testDatagram) {
t.Helper()
if !*testVV {
@@ -398,7 +433,7 @@ func logDatagram(t *testing.T, text string, d *testDatagram) {
// write sends the Conn a datagram.
func (tc *testConn) write(d *testDatagram) {
tc.t.Helper()
- tc.listener.writeDatagram(d)
+ tc.endpoint.writeDatagram(d)
}
// writeFrame sends the Conn a datagram containing the given frames.
@@ -421,6 +456,7 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
dstConnID: dstConnID,
srcConnID: tc.peerConnID,
}},
+ addr: tc.conn.peerAddr,
}
if ptype == packetTypeInitial && tc.conn.side == serverSide {
d.paddedSize = 1200
@@ -464,11 +500,11 @@ func (tc *testConn) readDatagram() *testDatagram {
tc.wait()
tc.sentPackets = nil
tc.sentFrames = nil
- buf := tc.listener.read()
+ buf := tc.endpoint.read()
if buf == nil {
return nil
}
- d := parseTestDatagram(tc.t, tc.listener, tc, buf)
+ d := parseTestDatagram(tc.t, tc.endpoint, tc, buf)
// Log the datagram before removing ignored frames.
// When things go wrong, it's useful to see all the frames.
logDatagram(tc.t, "-> conn under test sends", d)
@@ -543,6 +579,7 @@ func (tc *testConn) readDatagram() *testDatagram {
}
p.frames = frames
}
+ tc.lastDatagram = d
return d
}
@@ -589,12 +626,18 @@ func (tc *testConn) readFrame() (debugFrame, packetType) {
func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
tc.t.Helper()
got := tc.readDatagram()
- if !reflect.DeepEqual(got, want) {
+ if !datagramEqual(got, want) {
tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
}
}
func datagramEqual(a, b *testDatagram) bool {
+ if a == nil && b == nil {
+ return true
+ }
+ if a == nil || b == nil {
+ return false
+ }
if a.paddedSize != b.paddedSize ||
a.addr != b.addr ||
len(a.packets) != len(b.packets) {
@@ -612,16 +655,24 @@ func datagramEqual(a, b *testDatagram) bool {
func (tc *testConn) wantPacket(expectation string, want *testPacket) {
tc.t.Helper()
got := tc.readPacket()
- if !reflect.DeepEqual(got, want) {
+ if !packetEqual(got, want) {
tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want)
}
}
func packetEqual(a, b *testPacket) bool {
+ if a == nil && b == nil {
+ return true
+ }
+ if a == nil || b == nil {
+ return false
+ }
ac := *a
ac.frames = nil
+ ac.header = 0
bc := *b
bc.frames = nil
+ bc.header = 0
if !reflect.DeepEqual(ac, bc) {
return false
}
@@ -769,7 +820,7 @@ func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte
return w.datagram()
}
-func parseTestDatagram(t *testing.T, tl *testListener, tc *testConn, buf []byte) *testDatagram {
+func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram {
t.Helper()
bufSize := len(buf)
d := &testDatagram{}
@@ -782,7 +833,7 @@ func parseTestDatagram(t *testing.T, tl *testListener, tc *testConn, buf []byte)
ptype := getPacketType(buf)
switch ptype {
case packetTypeRetry:
- retry, ok := parseRetryPacket(buf, tl.lastInitialDstConnID)
+ retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID)
if !ok {
t.Fatalf("could not parse %v packet", ptype)
}
@@ -829,6 +880,7 @@ func parseTestDatagram(t *testing.T, tl *testListener, tc *testConn, buf []byte)
}
d.packets = append(d.packets, &testPacket{
ptype: p.ptype,
+ header: buf[0],
version: p.version,
num: p.num,
dstConnID: p.dstConnID,
@@ -870,6 +922,7 @@ func parseTestDatagram(t *testing.T, tl *testListener, tc *testConn, buf []byte)
}
d.packets = append(d.packets, &testPacket{
ptype: packetType1RTT,
+ header: hdr[0],
num: pnum,
dstConnID: hdr[1:][:len(tc.peerConnID)],
keyPhaseBit: hdr[0]&keyPhaseBit != 0,
@@ -936,7 +989,7 @@ func (tc *testConnHooks) init() {
tc.keysInitial.r = tc.conn.keysInitial.w
tc.keysInitial.w = tc.conn.keysInitial.r
if tc.conn.side == serverSide {
- tc.listener.acceptQueue = append(tc.listener.acceptQueue, (*testConn)(tc))
+ tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc))
}
}
@@ -1037,20 +1090,20 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
tc.timer = timer
for {
- if !timer.IsZero() && !timer.After(tc.listener.now) {
+ if !timer.IsZero() && !timer.After(tc.endpoint.now) {
if timer.Equal(tc.timerLastFired) {
// If the connection timer fires at time T, the Conn should take some
// action to advance the timer into the future. If the Conn reschedules
// the timer for the same time, it isn't making progress and we have a bug.
- tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.listener.now, timer)
+ tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer)
} else {
tc.timerLastFired = timer
- return tc.listener.now, timerEvent{}
+ return tc.endpoint.now, timerEvent{}
}
}
select {
case m := <-msgc:
- return tc.listener.now, m
+ return tc.endpoint.now, m
default:
}
if !tc.wakeAsync() {
@@ -1064,7 +1117,7 @@ func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
close(idlec)
}
m = <-msgc
- return tc.listener.now, m
+ return tc.endpoint.now, m
}
func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
@@ -1072,7 +1125,7 @@ func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
}
func (tc *testConnHooks) timeNow() time.Time {
- return tc.listener.now
+ return tc.endpoint.now
}
// testLocalConnID returns the connection ID with a given sequence number
diff --git a/internal/quic/crypto_stream.go b/quic/crypto_stream.go
similarity index 100%
rename from internal/quic/crypto_stream.go
rename to quic/crypto_stream.go
diff --git a/internal/quic/crypto_stream_test.go b/quic/crypto_stream_test.go
similarity index 100%
rename from internal/quic/crypto_stream_test.go
rename to quic/crypto_stream_test.go
diff --git a/internal/quic/dgram.go b/quic/dgram.go
similarity index 58%
rename from internal/quic/dgram.go
rename to quic/dgram.go
index 79e6650fa..615589373 100644
--- a/internal/quic/dgram.go
+++ b/quic/dgram.go
@@ -12,10 +12,25 @@ import (
)
type datagram struct {
- b []byte
- addr netip.AddrPort
+ b []byte
+ localAddr netip.AddrPort
+ peerAddr netip.AddrPort
+ ecn ecnBits
}
+// Explicit Congestion Notification bits.
+//
+// https://www.rfc-editor.org/rfc/rfc3168.html#section-5
+type ecnBits byte
+
+const (
+ ecnMask = 0b000000_11
+ ecnNotECT = 0b000000_00
+ ecnECT1 = 0b000000_01
+ ecnECT0 = 0b000000_10
+ ecnCE = 0b000000_11
+)
+
var datagramPool = sync.Pool{
New: func() any {
return &datagram{
@@ -26,7 +41,9 @@ var datagramPool = sync.Pool{
func newDatagram() *datagram {
m := datagramPool.Get().(*datagram)
- m.b = m.b[:cap(m.b)]
+ *m = datagram{
+ b: m.b[:cap(m.b)],
+ }
return m
}
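
The ECN codepoint introduced above occupies the two low bits of the IPv4 TOS / IPv6 traffic-class byte (RFC 3168, Section 5). As a minimal package-internal sketch of classifying it — assuming a tos byte obtained from per-packet socket control messages, which is platform-specific and not part of this change:

func classifyECN(tos byte) string {
	switch ecnBits(tos & ecnMask) {
	case ecnECT0, ecnECT1:
		return "ECN-capable transport" // sender opted in to ECN
	case ecnCE:
		return "congestion experienced" // a router marked the packet
	}
	return "not ECN-capable" // ecnNotECT
}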
diff --git a/quic/doc.go b/quic/doc.go
new file mode 100644
index 000000000..2fd10f087
--- /dev/null
+++ b/quic/doc.go
@@ -0,0 +1,45 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package quic implements the QUIC protocol.
+//
+// This package is a work in progress.
+// It is not ready for production usage.
+// Its API is subject to change without notice.
+//
+// This package is low-level.
+// Most users will use it indirectly through an HTTP/3 implementation.
+//
+// # Usage
+//
+// An [Endpoint] sends and receives traffic on a network address.
+// Create an Endpoint to either accept inbound QUIC connections
+// or create outbound ones.
+//
+// A [Conn] is a QUIC connection.
+//
+// A [Stream] is a QUIC stream, an ordered, reliable byte stream.
+//
+// # Cancelation
+//
+// All blocking operations may be canceled using a context.Context.
+// When performing an operation with a canceled context, the operation
+// will succeed if doing so does not require blocking. For example,
+// reading from a stream will return data when buffered data is available,
+// even if the stream context is canceled.
+//
+// # Limitations
+//
+// This package is a work in progress.
+// Known limitations include:
+//
+// - Performance is untuned.
+// - 0-RTT is not supported.
+// - Address migration is not supported.
+// - Server preferred addresses are not supported.
+// - The latency spin bit is not supported.
+// - Stream send/receive windows are configurable,
+// but are fixed and do not adapt to available throughput.
+// - Path MTU discovery is not implemented.
+package quic
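
As a concrete sketch of the Endpoint → Conn → Stream flow the doc comment describes, using only the API in this patch. The serverTLS/clientTLS pair is assumed to be matching *tls.Config values (certificates and NextProtos elided):

package quicexample

import (
	"context"
	"crypto/tls"
	"io"

	"golang.org/x/net/quic"
)

// run sketches the full flow: the server endpoint accepts one connection
// and drains one stream; the client endpoint dials and writes to it.
func run(ctx context.Context, serverTLS, clientTLS *tls.Config) error {
	srv, err := quic.Listen("udp", "127.0.0.1:0", &quic.Config{TLSConfig: serverTLS})
	if err != nil {
		return err
	}
	defer srv.Close(ctx)

	go func() {
		conn, err := srv.Accept(ctx)
		if err != nil {
			return
		}
		s, err := conn.AcceptStream(ctx)
		if err != nil {
			return
		}
		io.Copy(io.Discard, s) // drain the peer's data
	}()

	cli, err := quic.Listen("udp", "127.0.0.1:0", nil) // nil config: dial-only
	if err != nil {
		return err
	}
	defer cli.Close(ctx)

	conn, err := cli.Dial(ctx, "udp", srv.LocalAddr().String(), &quic.Config{TLSConfig: clientTLS})
	if err != nil {
		return err
	}
	s, err := conn.NewSendOnlyStream(ctx)
	if err != nil {
		return err
	}
	if _, err := s.Write([]byte("hello")); err != nil {
		return err
	}
	return s.Close() // Close commits buffered stream data and ends the stream
}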
diff --git a/internal/quic/listener.go b/quic/endpoint.go
similarity index 60%
rename from internal/quic/listener.go
rename to quic/endpoint.go
index ca8f9b25a..a55336b24 100644
--- a/internal/quic/listener.go
+++ b/quic/endpoint.go
@@ -17,16 +17,16 @@ import (
"time"
)
-// A Listener listens for QUIC traffic on a network address.
+// An Endpoint handles QUIC traffic on a network address.
// It can accept inbound connections or create outbound ones.
//
-// Multiple goroutines may invoke methods on a Listener simultaneously.
-type Listener struct {
- config *Config
- udpConn udpConn
- testHooks listenerTestHooks
- resetGen statelessResetTokenGenerator
- retry retryState
+// Multiple goroutines may invoke methods on an Endpoint simultaneously.
+type Endpoint struct {
+ listenConfig *Config
+ packetConn packetConn
+ testHooks endpointTestHooks
+ resetGen statelessResetTokenGenerator
+ retry retryState
acceptQueue queue[*Conn] // new inbound connections
connsMap connsMap // only accessed by the listen loop
@@ -37,24 +37,25 @@ type Listener struct {
closec chan struct{} // closed when the listen loop exits
}
-type listenerTestHooks interface {
+type endpointTestHooks interface {
timeNow() time.Time
newConn(c *Conn)
}
-// A udpConn is a UDP connection.
-// It is implemented by net.UDPConn.
-type udpConn interface {
+// A packetConn is the interface for sending and receiving UDP packets.
+type packetConn interface {
Close() error
- LocalAddr() net.Addr
- ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error)
- WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error)
+ LocalAddr() netip.AddrPort
+ Read(f func(*datagram))
+ Write(datagram) error
}
// Listen listens on a local network address.
-// The configuration config must be non-nil.
-func Listen(network, address string, config *Config) (*Listener, error) {
- if config.TLSConfig == nil {
+//
+// The config is used for connections accepted by the endpoint.
+// If the config is nil, the endpoint will not accept connections.
+func Listen(network, address string, listenConfig *Config) (*Endpoint, error) {
+ if listenConfig != nil && listenConfig.TLSConfig == nil {
return nil, errors.New("TLSConfig is not set")
}
a, err := net.ResolveUDPAddr(network, address)
@@ -65,82 +66,96 @@ func Listen(network, address string, config *Config) (*Listener, error) {
if err != nil {
return nil, err
}
- return newListener(udpConn, config, nil)
+ pc, err := newNetUDPConn(udpConn)
+ if err != nil {
+ return nil, err
+ }
+ return newEndpoint(pc, listenConfig, nil)
}
-func newListener(udpConn udpConn, config *Config, hooks listenerTestHooks) (*Listener, error) {
- l := &Listener{
- config: config,
- udpConn: udpConn,
- testHooks: hooks,
- conns: make(map[*Conn]struct{}),
- acceptQueue: newQueue[*Conn](),
- closec: make(chan struct{}),
- }
- l.resetGen.init(config.StatelessResetKey)
- l.connsMap.init()
- if config.RequireAddressValidation {
- if err := l.retry.init(); err != nil {
+func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) {
+ e := &Endpoint{
+ listenConfig: config,
+ packetConn: pc,
+ testHooks: hooks,
+ conns: make(map[*Conn]struct{}),
+ acceptQueue: newQueue[*Conn](),
+ closec: make(chan struct{}),
+ }
+ var statelessResetKey [32]byte
+ if config != nil {
+ statelessResetKey = config.StatelessResetKey
+ }
+ e.resetGen.init(statelessResetKey)
+ e.connsMap.init()
+ if config != nil && config.RequireAddressValidation {
+ if err := e.retry.init(); err != nil {
return nil, err
}
}
- go l.listen()
- return l, nil
+ go e.listen()
+ return e, nil
}
// LocalAddr returns the local network address.
-func (l *Listener) LocalAddr() netip.AddrPort {
- a, _ := l.udpConn.LocalAddr().(*net.UDPAddr)
- return a.AddrPort()
+func (e *Endpoint) LocalAddr() netip.AddrPort {
+ return e.packetConn.LocalAddr()
}
-// Close closes the listener.
-// Any blocked operations on the Listener or associated Conns and Stream will be unblocked
+// Close closes the Endpoint.
+// Any blocked operations on the Endpoint or associated Conns and Streams will be unblocked
// and return errors.
//
// Close aborts every open connection.
// Data in stream read and write buffers is discarded.
// It waits for the peers of any open connection to acknowledge the connection has been closed.
-func (l *Listener) Close(ctx context.Context) error {
- l.acceptQueue.close(errors.New("listener closed"))
- l.connsMu.Lock()
- if !l.closing {
- l.closing = true
- for c := range l.conns {
- c.Abort(localTransportError{code: errNo})
+func (e *Endpoint) Close(ctx context.Context) error {
+ e.acceptQueue.close(errors.New("endpoint closed"))
+
+ // It isn't safe to call Conn.Abort or conn.exit with connsMu held,
+ // so copy the list of conns.
+ var conns []*Conn
+ e.connsMu.Lock()
+ if !e.closing {
+ e.closing = true // setting e.closing prevents new conns from being created
+ for c := range e.conns {
+ conns = append(conns, c)
}
- if len(l.conns) == 0 {
- l.udpConn.Close()
+ if len(e.conns) == 0 {
+ e.packetConn.Close()
}
}
- l.connsMu.Unlock()
+ e.connsMu.Unlock()
+
+ for _, c := range conns {
+ c.Abort(localTransportError{code: errNo})
+ }
select {
- case <-l.closec:
+ case <-e.closec:
case <-ctx.Done():
- l.connsMu.Lock()
- for c := range l.conns {
+ for _, c := range conns {
c.exit()
}
- l.connsMu.Unlock()
return ctx.Err()
}
return nil
}
-// Accept waits for and returns the next connection to the listener.
-func (l *Listener) Accept(ctx context.Context) (*Conn, error) {
- return l.acceptQueue.get(ctx, nil)
+// Accept waits for and returns the next connection.
+func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) {
+ return e.acceptQueue.get(ctx, nil)
}
// Dial creates and returns a connection to a network address.
-func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, error) {
+// The config cannot be nil.
+func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) {
u, err := net.ResolveUDPAddr(network, address)
if err != nil {
return nil, err
}
addr := u.AddrPort()
addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
- c, err := l.newConn(time.Now(), clientSide, newServerConnIDs{}, addr)
+ c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr)
if err != nil {
return nil, err
}
@@ -151,29 +166,29 @@ func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, error) {
return c, nil
}
-func (l *Listener) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) {
- l.connsMu.Lock()
- defer l.connsMu.Unlock()
- if l.closing {
- return nil, errors.New("listener closed")
+func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) {
+ e.connsMu.Lock()
+ defer e.connsMu.Unlock()
+ if e.closing {
+ return nil, errors.New("endpoint closed")
}
- c, err := newConn(now, side, cids, peerAddr, l.config, l)
+ c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e)
if err != nil {
return nil, err
}
- l.conns[c] = struct{}{}
+ e.conns[c] = struct{}{}
return c, nil
}
// serverConnEstablished is called by a conn when the handshake completes
// for an inbound (serverSide) connection.
-func (l *Listener) serverConnEstablished(c *Conn) {
- l.acceptQueue.put(c)
+func (e *Endpoint) serverConnEstablished(c *Conn) {
+ e.acceptQueue.put(c)
}
// connDrained is called by a conn when it leaves the draining state,
// either when the peer acknowledges connection closure or the drain timeout expires.
-func (l *Listener) connDrained(c *Conn) {
+func (e *Endpoint) connDrained(c *Conn) {
var cids [][]byte
for i := range c.connIDState.local {
cids = append(cids, c.connIDState.local[i].cid)
@@ -182,7 +197,7 @@ func (l *Listener) connDrained(c *Conn) {
for i := range c.connIDState.remote {
tokens = append(tokens, c.connIDState.remote[i].resetToken)
}
- l.connsMap.updateConnIDs(func(conns *connsMap) {
+ e.connsMap.updateConnIDs(func(conns *connsMap) {
for _, cid := range cids {
conns.retireConnID(c, cid)
}
@@ -190,60 +205,44 @@ func (l *Listener) connDrained(c *Conn) {
conns.retireResetToken(c, token)
}
})
- l.connsMu.Lock()
- defer l.connsMu.Unlock()
- delete(l.conns, c)
- if l.closing && len(l.conns) == 0 {
- l.udpConn.Close()
+ e.connsMu.Lock()
+ defer e.connsMu.Unlock()
+ delete(e.conns, c)
+ if e.closing && len(e.conns) == 0 {
+ e.packetConn.Close()
}
}
-func (l *Listener) listen() {
- defer close(l.closec)
- for {
- m := newDatagram()
- // TODO: Read and process the ECN (explicit congestion notification) field.
- // https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4
- n, _, _, addr, err := l.udpConn.ReadMsgUDPAddrPort(m.b, nil)
- if err != nil {
- // The user has probably closed the listener.
- // We currently don't surface errors from other causes;
- // we could check to see if the listener has been closed and
- // record the unexpected error if it has not.
- return
- }
- if n == 0 {
- continue
+func (e *Endpoint) listen() {
+ defer close(e.closec)
+ e.packetConn.Read(func(m *datagram) {
+ if e.connsMap.updateNeeded.Load() {
+ e.connsMap.applyUpdates()
}
- if l.connsMap.updateNeeded.Load() {
- l.connsMap.applyUpdates()
- }
- m.addr = addr
- m.b = m.b[:n]
- l.handleDatagram(m)
- }
+ e.handleDatagram(m)
+ })
}
-func (l *Listener) handleDatagram(m *datagram) {
+func (e *Endpoint) handleDatagram(m *datagram) {
dstConnID, ok := dstConnIDForDatagram(m.b)
if !ok {
m.recycle()
return
}
- c := l.connsMap.byConnID[string(dstConnID)]
+ c := e.connsMap.byConnID[string(dstConnID)]
if c == nil {
// TODO: Move this branch into a separate goroutine to avoid blocking
- // the listener while processing packets.
- l.handleUnknownDestinationDatagram(m)
+ // the endpoint while processing packets.
+ e.handleUnknownDestinationDatagram(m)
return
}
- // TODO: This can block the listener while waiting for the conn to accept the dgram.
+ // TODO: This can block the endpoint while waiting for the conn to accept the dgram.
// Think about buffering between the receive loop and the conn.
c.sendMsg(m)
}
-func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
+func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) {
defer func() {
if m != nil {
m.recycle()
@@ -254,15 +253,15 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
return
}
var now time.Time
- if l.testHooks != nil {
- now = l.testHooks.timeNow()
+ if e.testHooks != nil {
+ now = e.testHooks.timeNow()
} else {
now = time.Now()
}
// Check to see if this is a stateless reset.
var token statelessResetToken
copy(token[:], m.b[len(m.b)-len(token):])
- if c := l.connsMap.byResetToken[token]; c != nil {
+ if c := e.connsMap.byResetToken[token]; c != nil {
c.sendMsg(func(now time.Time, c *Conn) {
c.handleStatelessReset(now, token)
})
@@ -271,7 +270,7 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
// If this is a 1-RTT packet, there's nothing productive we can do with it.
// Send a stateless reset if possible.
if !isLongHeader(m.b[0]) {
- l.maybeSendStatelessReset(m.b, m.addr)
+ e.maybeSendStatelessReset(m.b, m.peerAddr)
return
}
p, ok := parseGenericLongHeaderPacket(m.b)
@@ -285,7 +284,7 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
return
default:
// Unknown version.
- l.sendVersionNegotiation(p, m.addr)
+ e.sendVersionNegotiation(p, m.peerAddr)
return
}
if getPacketType(m.b) != packetTypeInitial {
@@ -296,14 +295,18 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16
return
}
+ if e.listenConfig == nil {
+ // We are not configured to accept connections.
+ return
+ }
cids := newServerConnIDs{
srcConnID: p.srcConnID,
dstConnID: p.dstConnID,
}
- if l.config.RequireAddressValidation {
+ if e.listenConfig.RequireAddressValidation {
var ok bool
cids.retrySrcConnID = p.dstConnID
- cids.originalDstConnID, ok = l.validateInitialAddress(now, p, m.addr)
+ cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr)
if !ok {
return
}
@@ -311,7 +314,7 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
cids.originalDstConnID = p.dstConnID
}
var err error
- c, err := l.newConn(now, serverSide, cids, m.addr)
+ c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr)
if err != nil {
// The accept queue is probably full.
// We could send a CONNECTION_CLOSE to the peer to reject the connection.
@@ -323,8 +326,8 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
m = nil // don't recycle, sendMsg takes ownership
}
-func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) {
- if !l.resetGen.canReset {
+func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) {
+ if !e.resetGen.canReset {
// Config.StatelessResetKey isn't set, so we don't send stateless resets.
return
}
@@ -339,7 +342,7 @@ func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) {
}
// TODO: Rate limit stateless resets.
cid := b[1:][:connIDLen]
- token := l.resetGen.tokenForConnID(cid)
+ token := e.resetGen.tokenForConnID(cid)
// We want to generate a stateless reset that is as short as possible,
// but long enough to be difficult to distinguish from a 1-RTT packet.
//
@@ -364,17 +367,21 @@ func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) {
b[0] &^= headerFormLong // clear long header bit
b[0] |= fixedBit // set fixed bit
copy(b[len(b)-statelessResetTokenLen:], token[:])
- l.sendDatagram(b, addr)
+ e.sendDatagram(datagram{
+ b: b,
+ peerAddr: peerAddr,
+ })
}
-func (l *Listener) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) {
+func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) {
m := newDatagram()
m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
- l.sendDatagram(m.b, addr)
+ m.peerAddr = peerAddr
+ e.sendDatagram(*m)
m.recycle()
}
-func (l *Listener) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) {
+func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) {
keys := initialKeys(in.dstConnID, serverSide)
var w packetWriter
p := longPacket{
@@ -393,15 +400,17 @@ func (l *Listener) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) {
if len(buf) == 0 {
return
}
- l.sendDatagram(buf, addr)
+ e.sendDatagram(datagram{
+ b: buf,
+ peerAddr: peerAddr,
+ })
}
-func (l *Listener) sendDatagram(p []byte, addr netip.AddrPort) error {
- _, err := l.udpConn.WriteToUDPAddrPort(p, addr)
- return err
+func (e *Endpoint) sendDatagram(dgram datagram) error {
+ return e.packetConn.Write(dgram)
}
-// A connsMap is a listener's mapping of conn ids and reset tokens to conns.
+// A connsMap is an endpoint's mapping of conn ids and reset tokens to conns.
type connsMap struct {
byConnID map[string]*Conn
byResetToken map[statelessResetToken]*Conn
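
One consequence of the Close rewrite above: Close aborts every open connection outside the mutex, then blocks until peers acknowledge closure, with the caller's context as the escape hatch. A small usage sketch (the five-second bound is illustrative; assumes imports "context", "time", and "golang.org/x/net/quic"):

// closeEndpoint shuts e down, waiting at most five seconds for peers to
// acknowledge. If the deadline expires first, the remaining connections
// are discarded immediately and the context error is returned.
func closeEndpoint(e *quic.Endpoint) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return e.Close(ctx)
}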
diff --git a/quic/endpoint_test.go b/quic/endpoint_test.go
new file mode 100644
index 000000000..d5f436e6d
--- /dev/null
+++ b/quic/endpoint_test.go
@@ -0,0 +1,330 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "io"
+ "log/slog"
+ "net/netip"
+ "testing"
+ "time"
+
+ "golang.org/x/net/quic/qlog"
+)
+
+func TestConnect(t *testing.T) {
+ newLocalConnPair(t, &Config{}, &Config{})
+}
+
+func TestStreamTransfer(t *testing.T) {
+ ctx := context.Background()
+ cli, srv := newLocalConnPair(t, &Config{}, &Config{})
+ data := makeTestData(1 << 20)
+
+ srvdone := make(chan struct{})
+ go func() {
+ defer close(srvdone)
+ s, err := srv.AcceptStream(ctx)
+ if err != nil {
+ t.Errorf("AcceptStream: %v", err)
+ return
+ }
+ b, err := io.ReadAll(s)
+ if err != nil {
+ t.Errorf("io.ReadAll(s): %v", err)
+ return
+ }
+ if !bytes.Equal(b, data) {
+ t.Errorf("read data mismatch (got %v bytes, want %v", len(b), len(data))
+ }
+ if err := s.Close(); err != nil {
+ t.Errorf("s.Close() = %v", err)
+ }
+ }()
+
+ s, err := cli.NewSendOnlyStream(ctx)
+ if err != nil {
+ t.Fatalf("NewStream: %v", err)
+ }
+ n, err := io.Copy(s, bytes.NewBuffer(data))
+ if n != int64(len(data)) || err != nil {
+ t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
+ }
+ if err := s.Close(); err != nil {
+ t.Fatalf("s.Close() = %v", err)
+ }
+}
+
+func newLocalConnPair(t testing.TB, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
+ t.Helper()
+ ctx := context.Background()
+ e1 := newLocalEndpoint(t, serverSide, conf1)
+ e2 := newLocalEndpoint(t, clientSide, conf2)
+ conf2 = makeTestConfig(conf2, clientSide)
+ c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String(), conf2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c1, err := e1.Accept(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return c2, c1
+}
+
+func newLocalEndpoint(t testing.TB, side connSide, conf *Config) *Endpoint {
+ t.Helper()
+ conf = makeTestConfig(conf, side)
+ e, err := Listen("udp", "127.0.0.1:0", conf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ e.Close(canceledContext())
+ })
+ return e
+}
+
+func makeTestConfig(conf *Config, side connSide) *Config {
+ if conf == nil {
+ return nil
+ }
+ newConf := *conf
+ conf = &newConf
+ if conf.TLSConfig == nil {
+ conf.TLSConfig = newTestTLSConfig(side)
+ }
+ if conf.QLogLogger == nil {
+ conf.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
+ Level: QLogLevelFrame,
+ Dir: *qlogdir,
+ }))
+ }
+ return conf
+}
+
+type testEndpoint struct {
+ t *testing.T
+ e *Endpoint
+ now time.Time
+ recvc chan *datagram
+ idlec chan struct{}
+ conns map[*Conn]*testConn
+ acceptQueue []*testConn
+ configTransportParams []func(*transportParameters)
+ configTestConn []func(*testConn)
+ sentDatagrams [][]byte
+ peerTLSConn *tls.QUICConn
+ lastInitialDstConnID []byte // for parsing Retry packets
+}
+
+func newTestEndpoint(t *testing.T, config *Config) *testEndpoint {
+ te := &testEndpoint{
+ t: t,
+ now: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
+ recvc: make(chan *datagram),
+ idlec: make(chan struct{}),
+ conns: make(map[*Conn]*testConn),
+ }
+ var err error
+ te.e, err = newEndpoint((*testEndpointUDPConn)(te), config, (*testEndpointHooks)(te))
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(te.cleanup)
+ return te
+}
+
+func (te *testEndpoint) cleanup() {
+ te.e.Close(canceledContext())
+}
+
+func (te *testEndpoint) wait() {
+ select {
+ case te.idlec <- struct{}{}:
+ case <-te.e.closec:
+ }
+ for _, tc := range te.conns {
+ tc.wait()
+ }
+}
+
+// accept returns a server connection from the endpoint.
+// Unlike Endpoint.Accept, connections are available as soon as they are created.
+func (te *testEndpoint) accept() *testConn {
+ if len(te.acceptQueue) == 0 {
+ te.t.Fatalf("accept: expected available conn, but found none")
+ }
+ tc := te.acceptQueue[0]
+ te.acceptQueue = te.acceptQueue[1:]
+ return tc
+}
+
+func (te *testEndpoint) write(d *datagram) {
+ te.recvc <- d
+ te.wait()
+}
+
+var testClientAddr = netip.MustParseAddrPort("10.0.0.1:8000")
+
+func (te *testEndpoint) writeDatagram(d *testDatagram) {
+ te.t.Helper()
+ logDatagram(te.t, "<- endpoint under test receives", d)
+ var buf []byte
+ for _, p := range d.packets {
+ tc := te.connForDestination(p.dstConnID)
+ if p.ptype != packetTypeRetry && tc != nil {
+ space := spaceForPacketType(p.ptype)
+ if p.num >= tc.peerNextPacketNum[space] {
+ tc.peerNextPacketNum[space] = p.num + 1
+ }
+ }
+ if p.ptype == packetTypeInitial {
+ te.lastInitialDstConnID = p.dstConnID
+ }
+ pad := 0
+ if p.ptype == packetType1RTT {
+ pad = d.paddedSize - len(buf)
+ }
+ buf = append(buf, encodeTestPacket(te.t, tc, p, pad)...)
+ }
+ for len(buf) < d.paddedSize {
+ buf = append(buf, 0)
+ }
+ te.write(&datagram{
+ b: buf,
+ peerAddr: d.addr,
+ })
+}
+
+func (te *testEndpoint) connForDestination(dstConnID []byte) *testConn {
+ for _, tc := range te.conns {
+ for _, loc := range tc.conn.connIDState.local {
+ if bytes.Equal(loc.cid, dstConnID) {
+ return tc
+ }
+ }
+ }
+ return nil
+}
+
+func (te *testEndpoint) connForSource(srcConnID []byte) *testConn {
+ for _, tc := range te.conns {
+ for _, loc := range tc.conn.connIDState.remote {
+ if bytes.Equal(loc.cid, srcConnID) {
+ return tc
+ }
+ }
+ }
+ return nil
+}
+
+func (te *testEndpoint) read() []byte {
+ te.t.Helper()
+ te.wait()
+ if len(te.sentDatagrams) == 0 {
+ return nil
+ }
+ d := te.sentDatagrams[0]
+ te.sentDatagrams = te.sentDatagrams[1:]
+ return d
+}
+
+func (te *testEndpoint) readDatagram() *testDatagram {
+ te.t.Helper()
+ buf := te.read()
+ if buf == nil {
+ return nil
+ }
+ p, _ := parseGenericLongHeaderPacket(buf)
+ tc := te.connForSource(p.dstConnID)
+ d := parseTestDatagram(te.t, te, tc, buf)
+ logDatagram(te.t, "-> endpoint under test sends", d)
+ return d
+}
+
+// wantDatagram indicates that we expect the Endpoint to send a datagram.
+func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) {
+ te.t.Helper()
+ got := te.readDatagram()
+ if !datagramEqual(got, want) {
+ te.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
+ }
+}
+
+// wantIdle indicates that we expect the Endpoint to not send any more datagrams.
+func (te *testEndpoint) wantIdle(expectation string) {
+ if got := te.readDatagram(); got != nil {
+ te.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got)
+ }
+}
+
+// advance causes time to pass.
+func (te *testEndpoint) advance(d time.Duration) {
+ te.t.Helper()
+ te.advanceTo(te.now.Add(d))
+}
+
+// advanceTo sets the current time.
+func (te *testEndpoint) advanceTo(now time.Time) {
+ te.t.Helper()
+ if te.now.After(now) {
+ te.t.Fatalf("time moved backwards: %v -> %v", te.now, now)
+ }
+ te.now = now
+ for _, tc := range te.conns {
+ if !tc.timer.After(te.now) {
+ tc.conn.sendMsg(timerEvent{})
+ tc.wait()
+ }
+ }
+}
+
+// testEndpointHooks implements endpointTestHooks.
+type testEndpointHooks testEndpoint
+
+func (te *testEndpointHooks) timeNow() time.Time {
+ return te.now
+}
+
+func (te *testEndpointHooks) newConn(c *Conn) {
+ tc := newTestConnForConn(te.t, (*testEndpoint)(te), c)
+ te.conns[c] = tc
+}
+
+// testEndpointUDPConn implements packetConn.
+type testEndpointUDPConn testEndpoint
+
+func (te *testEndpointUDPConn) Close() error {
+ close(te.recvc)
+ return nil
+}
+
+func (te *testEndpointUDPConn) LocalAddr() netip.AddrPort {
+ return netip.MustParseAddrPort("127.0.0.1:443")
+}
+
+func (te *testEndpointUDPConn) Read(f func(*datagram)) {
+ for {
+ select {
+ case d, ok := <-te.recvc:
+ if !ok {
+ return
+ }
+ f(d)
+ case <-te.idlec:
+ }
+ }
+}
+
+func (te *testEndpointUDPConn) Write(dgram datagram) error {
+ te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), dgram.b...))
+ return nil
+}
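
The harness above fakes both halves of the environment: testEndpointUDPConn replaces the network, and the now/advanceTo fields replace the clock. The clock half of the pattern, distilled into a self-contained sketch — none of these names exist in the package:

package fakeclock

import "time"

// Clock is a manual test clock: time moves only when AdvanceTo is called,
// and every waiter whose deadline has passed fires synchronously, so tests
// are deterministic regardless of real elapsed time.
type Clock struct {
	now     time.Time
	waiters []waiter
}

type waiter struct {
	deadline time.Time
	fire     func(now time.Time)
}

// AfterFunc registers f to run when the clock reaches deadline.
func (c *Clock) AfterFunc(deadline time.Time, f func(time.Time)) {
	c.waiters = append(c.waiters, waiter{deadline, f})
}

// AdvanceTo moves the clock forward, firing expired waiters in order of
// registration (like testEndpoint.advanceTo, it rejects moving backwards).
func (c *Clock) AdvanceTo(now time.Time) {
	if now.Before(c.now) {
		panic("time moved backwards")
	}
	c.now = now
	kept := c.waiters[:0]
	for _, w := range c.waiters {
		if w.deadline.After(now) {
			kept = append(kept, w)
			continue
		}
		w.fire(now)
	}
	c.waiters = kept
}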
diff --git a/internal/quic/errors.go b/quic/errors.go
similarity index 100%
rename from internal/quic/errors.go
rename to quic/errors.go
diff --git a/internal/quic/files_test.go b/quic/files_test.go
similarity index 100%
rename from internal/quic/files_test.go
rename to quic/files_test.go
diff --git a/internal/quic/frame_debug.go b/quic/frame_debug.go
similarity index 68%
rename from internal/quic/frame_debug.go
rename to quic/frame_debug.go
index dc8009037..17234dd7c 100644
--- a/internal/quic/frame_debug.go
+++ b/quic/frame_debug.go
@@ -8,6 +8,9 @@ package quic
import (
"fmt"
+ "log/slog"
+ "strconv"
+ "time"
)
// A debugFrame is a representation of the contents of a QUIC frame,
@@ -15,6 +18,7 @@ import (
type debugFrame interface {
String() string
write(w *packetWriter) bool
+ LogValue() slog.Value
}
func parseDebugFrame(b []byte) (f debugFrame, n int) {
@@ -73,6 +77,7 @@ func parseDebugFrame(b []byte) (f debugFrame, n int) {
// debugFramePadding is a sequence of PADDING frames.
type debugFramePadding struct {
size int
+ to int // alternate for writing packets: pad to
}
func parseDebugFramePadding(b []byte) (f debugFramePadding, n int) {
@@ -91,12 +96,23 @@ func (f debugFramePadding) write(w *packetWriter) bool {
if w.avail() == 0 {
return false
}
+ if f.to > 0 {
+ w.appendPaddingTo(f.to)
+ return true
+ }
for i := 0; i < f.size && w.avail() > 0; i++ {
w.b = append(w.b, frameTypePadding)
}
return true
}
+func (f debugFramePadding) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "padding"),
+ slog.Int("length", f.size),
+ )
+}
+
// debugFramePing is a PING frame.
type debugFramePing struct{}
@@ -112,6 +128,12 @@ func (f debugFramePing) write(w *packetWriter) bool {
return w.appendPingFrame()
}
+func (f debugFramePing) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "ping"),
+ )
+}
+
// debugFrameAck is an ACK frame.
type debugFrameAck struct {
ackDelay unscaledAckDelay
@@ -126,7 +148,7 @@ func parseDebugFrameAck(b []byte) (f debugFrameAck, n int) {
end: end,
})
})
- // Ranges are parsed smallest to highest; reverse ranges slice to order them high to low.
+ // Ranges are parsed high to low; reverse ranges slice to order them low to high.
for i := 0; i < len(f.ranges)/2; i++ {
j := len(f.ranges) - 1
f.ranges[i], f.ranges[j] = f.ranges[j], f.ranges[i]
@@ -146,6 +168,61 @@ func (f debugFrameAck) write(w *packetWriter) bool {
return w.appendAckFrame(rangeset[packetNumber](f.ranges), f.ackDelay)
}
+func (f debugFrameAck) LogValue() slog.Value {
+ return slog.StringValue("error: debugFrameAck should not appear as a slog Value")
+}
+
+// debugFrameScaledAck is an ACK frame with scaled ACK Delay.
+//
+// This type is used in qlog events, which need access to the delay as a duration.
+type debugFrameScaledAck struct {
+ ackDelay time.Duration
+ ranges []i64range[packetNumber]
+}
+
+func (f debugFrameScaledAck) LogValue() slog.Value {
+ var ackDelay slog.Attr
+ if f.ackDelay >= 0 {
+ ackDelay = slog.Duration("ack_delay", f.ackDelay)
+ }
+ return slog.GroupValue(
+ slog.String("frame_type", "ack"),
+ // Rather than trying to convert the ack ranges into the slog data model,
+ // pass a value that can JSON-encode itself.
+ slog.Any("acked_ranges", debugAckRanges(f.ranges)),
+ ackDelay,
+ )
+}
+
+type debugAckRanges []i64range[packetNumber]
+
+// AppendJSON appends a JSON encoding of the ack ranges to b, and returns it.
+// This is different than the standard json.Marshaler, but more efficient.
+// Since we only use this in cooperation with the qlog package,
+// encoding/json compatibility is irrelevant.
+func (r debugAckRanges) AppendJSON(b []byte) []byte {
+ b = append(b, '[')
+ for i, ar := range r {
+ start, end := ar.start, ar.end-1 // qlog ranges are closed-closed
+ if i != 0 {
+ b = append(b, ',')
+ }
+ b = append(b, '[')
+ b = strconv.AppendInt(b, int64(start), 10)
+ if start != end {
+ b = append(b, ',')
+ b = strconv.AppendInt(b, int64(end), 10)
+ }
+ b = append(b, ']')
+ }
+ b = append(b, ']')
+ return b
+}
+
+func (r debugAckRanges) String() string {
+ return string(r.AppendJSON(nil))
+}
+
// debugFrameResetStream is a RESET_STREAM frame.
type debugFrameResetStream struct {
id streamID
@@ -166,6 +243,14 @@ func (f debugFrameResetStream) write(w *packetWriter) bool {
return w.appendResetStreamFrame(f.id, f.code, f.finalSize)
}
+func (f debugFrameResetStream) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "reset_stream"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Uint64("final_size", uint64(f.finalSize)),
+ )
+}
+
// debugFrameStopSending is a STOP_SENDING frame.
type debugFrameStopSending struct {
id streamID
@@ -185,6 +270,14 @@ func (f debugFrameStopSending) write(w *packetWriter) bool {
return w.appendStopSendingFrame(f.id, f.code)
}
+func (f debugFrameStopSending) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "stop_sending"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Uint64("error_code", uint64(f.code)),
+ )
+}
+
// debugFrameCrypto is a CRYPTO frame.
type debugFrameCrypto struct {
off int64
@@ -206,6 +299,14 @@ func (f debugFrameCrypto) write(w *packetWriter) bool {
return added
}
+func (f debugFrameCrypto) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "crypto"),
+ slog.Int64("offset", f.off),
+ slog.Int("length", len(f.data)),
+ )
+}
+
// debugFrameNewToken is a NEW_TOKEN frame.
type debugFrameNewToken struct {
token []byte
@@ -224,6 +325,13 @@ func (f debugFrameNewToken) write(w *packetWriter) bool {
return w.appendNewTokenFrame(f.token)
}
+func (f debugFrameNewToken) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "new_token"),
+ slogHexstring("token", f.token),
+ )
+}
+
// debugFrameStream is a STREAM frame.
type debugFrameStream struct {
id streamID
@@ -251,6 +359,20 @@ func (f debugFrameStream) write(w *packetWriter) bool {
return added
}
+func (f debugFrameStream) LogValue() slog.Value {
+ var fin slog.Attr
+ if f.fin {
+ fin = slog.Bool("fin", true)
+ }
+ return slog.GroupValue(
+ slog.String("frame_type", "stream"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Int64("offset", f.off),
+ slog.Int("length", len(f.data)),
+ fin,
+ )
+}
+
// debugFrameMaxData is a MAX_DATA frame.
type debugFrameMaxData struct {
max int64
@@ -269,6 +391,13 @@ func (f debugFrameMaxData) write(w *packetWriter) bool {
return w.appendMaxDataFrame(f.max)
}
+func (f debugFrameMaxData) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "max_data"),
+ slog.Int64("maximum", f.max),
+ )
+}
+
// debugFrameMaxStreamData is a MAX_STREAM_DATA frame.
type debugFrameMaxStreamData struct {
id streamID
@@ -288,6 +417,14 @@ func (f debugFrameMaxStreamData) write(w *packetWriter) bool {
return w.appendMaxStreamDataFrame(f.id, f.max)
}
+func (f debugFrameMaxStreamData) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "max_stream_data"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Int64("maximum", f.max),
+ )
+}
+
// debugFrameMaxStreams is a MAX_STREAMS frame.
type debugFrameMaxStreams struct {
streamType streamType
@@ -307,6 +444,14 @@ func (f debugFrameMaxStreams) write(w *packetWriter) bool {
return w.appendMaxStreamsFrame(f.streamType, f.max)
}
+func (f debugFrameMaxStreams) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "max_streams"),
+ slog.String("stream_type", f.streamType.qlogString()),
+ slog.Int64("maximum", f.max),
+ )
+}
+
// debugFrameDataBlocked is a DATA_BLOCKED frame.
type debugFrameDataBlocked struct {
max int64
@@ -325,6 +470,13 @@ func (f debugFrameDataBlocked) write(w *packetWriter) bool {
return w.appendDataBlockedFrame(f.max)
}
+func (f debugFrameDataBlocked) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "data_blocked"),
+ slog.Int64("limit", f.max),
+ )
+}
+
// debugFrameStreamDataBlocked is a STREAM_DATA_BLOCKED frame.
type debugFrameStreamDataBlocked struct {
id streamID
@@ -344,6 +496,14 @@ func (f debugFrameStreamDataBlocked) write(w *packetWriter) bool {
return w.appendStreamDataBlockedFrame(f.id, f.max)
}
+func (f debugFrameStreamDataBlocked) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "stream_data_blocked"),
+ slog.Uint64("stream_id", uint64(f.id)),
+ slog.Int64("limit", f.max),
+ )
+}
+
// debugFrameStreamsBlocked is a STREAMS_BLOCKED frame.
type debugFrameStreamsBlocked struct {
streamType streamType
@@ -363,6 +523,14 @@ func (f debugFrameStreamsBlocked) write(w *packetWriter) bool {
return w.appendStreamsBlockedFrame(f.streamType, f.max)
}
+func (f debugFrameStreamsBlocked) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "streams_blocked"),
+ slog.String("stream_type", f.streamType.qlogString()),
+ slog.Int64("limit", f.max),
+ )
+}
+
// debugFrameNewConnectionID is a NEW_CONNECTION_ID frame.
type debugFrameNewConnectionID struct {
seq int64
@@ -384,6 +552,16 @@ func (f debugFrameNewConnectionID) write(w *packetWriter) bool {
return w.appendNewConnectionIDFrame(f.seq, f.retirePriorTo, f.connID, f.token)
}
+func (f debugFrameNewConnectionID) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "new_connection_id"),
+ slog.Int64("sequence_number", f.seq),
+ slog.Int64("retire_prior_to", f.retirePriorTo),
+ slogHexstring("connection_id", f.connID),
+ slogHexstring("stateless_reset_token", f.token[:]),
+ )
+}
+
// debugFrameRetireConnectionID is a RETIRE_CONNECTION_ID frame.
type debugFrameRetireConnectionID struct {
seq int64
@@ -402,9 +580,16 @@ func (f debugFrameRetireConnectionID) write(w *packetWriter) bool {
return w.appendRetireConnectionIDFrame(f.seq)
}
+func (f debugFrameRetireConnectionID) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "retire_connection_id"),
+ slog.Int64("sequence_number", f.seq),
+ )
+}
+
// debugFramePathChallenge is a PATH_CHALLENGE frame.
type debugFramePathChallenge struct {
- data uint64
+ data pathChallengeData
}
func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) {
@@ -413,16 +598,23 @@ func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) {
}
func (f debugFramePathChallenge) String() string {
- return fmt.Sprintf("PATH_CHALLENGE Data=%016x", f.data)
+ return fmt.Sprintf("PATH_CHALLENGE Data=%x", f.data)
}
func (f debugFramePathChallenge) write(w *packetWriter) bool {
return w.appendPathChallengeFrame(f.data)
}
+func (f debugFramePathChallenge) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "path_challenge"),
+ slog.String("data", fmt.Sprintf("%x", f.data)),
+ )
+}
+
// debugFramePathResponse is a PATH_RESPONSE frame.
type debugFramePathResponse struct {
- data uint64
+ data pathChallengeData
}
func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) {
@@ -431,13 +623,20 @@ func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) {
}
func (f debugFramePathResponse) String() string {
- return fmt.Sprintf("PATH_RESPONSE Data=%016x", f.data)
+ return fmt.Sprintf("PATH_RESPONSE Data=%x", f.data)
}
func (f debugFramePathResponse) write(w *packetWriter) bool {
return w.appendPathResponseFrame(f.data)
}
+func (f debugFramePathResponse) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "path_response"),
+ slog.String("data", fmt.Sprintf("%x", f.data)),
+ )
+}
+
// debugFrameConnectionCloseTransport is a CONNECTION_CLOSE frame carrying a transport error.
type debugFrameConnectionCloseTransport struct {
code transportError
@@ -465,6 +664,15 @@ func (f debugFrameConnectionCloseTransport) write(w *packetWriter) bool {
return w.appendConnectionCloseTransportFrame(f.code, f.frameType, f.reason)
}
+func (f debugFrameConnectionCloseTransport) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "connection_close"),
+ slog.String("error_space", "transport"),
+ slog.Uint64("error_code_value", uint64(f.code)),
+ slog.String("reason", f.reason),
+ )
+}
+
// debugFrameConnectionCloseApplication is a CONNECTION_CLOSE frame carrying an application error.
type debugFrameConnectionCloseApplication struct {
code uint64
@@ -488,6 +696,15 @@ func (f debugFrameConnectionCloseApplication) write(w *packetWriter) bool {
return w.appendConnectionCloseApplicationFrame(f.code, f.reason)
}
+func (f debugFrameConnectionCloseApplication) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "connection_close"),
+ slog.String("error_space", "application"),
+ slog.Uint64("error_code_value", uint64(f.code)),
+ slog.String("reason", f.reason),
+ )
+}
+
// debugFrameHandshakeDone is a HANDSHAKE_DONE frame.
type debugFrameHandshakeDone struct{}
@@ -502,3 +719,9 @@ func (f debugFrameHandshakeDone) String() string {
func (f debugFrameHandshakeDone) write(w *packetWriter) bool {
return w.appendHandshakeDoneFrame()
}
+
+func (f debugFrameHandshakeDone) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("frame_type", "handshake_done"),
+ )
+}
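
To make the closed-closed qlog encoding concrete: the package's half-open ranges [0,3) and [5,6) cover packets 0–2 and 5, so AppendJSON emits [[0,2],[5]]. A package-internal sketch (i64range's fields are unexported, so this only compiles inside package quic):

func exampleAckRanges() string {
	r := debugAckRanges{
		{start: 0, end: 3}, // packets 0, 1, 2
		{start: 5, end: 6}, // packet 5 only
	}
	return r.String() // "[[0,2],[5]]"
}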
diff --git a/internal/quic/gate.go b/quic/gate.go
similarity index 100%
rename from internal/quic/gate.go
rename to quic/gate.go
diff --git a/internal/quic/gate_test.go b/quic/gate_test.go
similarity index 100%
rename from internal/quic/gate_test.go
rename to quic/gate_test.go
diff --git a/internal/quic/gotraceback_test.go b/quic/gotraceback_test.go
similarity index 100%
rename from internal/quic/gotraceback_test.go
rename to quic/gotraceback_test.go
diff --git a/quic/idle.go b/quic/idle.go
new file mode 100644
index 000000000..f5b2422ad
--- /dev/null
+++ b/quic/idle.go
@@ -0,0 +1,170 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "time"
+)
+
+// idleState tracks connection idle events.
+//
+// Before the handshake is confirmed, the idle timeout is Config.HandshakeTimeout.
+//
+// After the handshake is confirmed, the idle timeout is
+// the minimum of Config.MaxIdleTimeout and the peer's max_idle_timeout transport parameter.
+//
+// If KeepAlivePeriod is set, keep-alive pings are sent.
+// Keep-alives are only sent after the handshake is confirmed.
+//
+// https://www.rfc-editor.org/rfc/rfc9000#section-10.1
+type idleState struct {
+ // idleDuration is the negotiated idle timeout for the connection.
+ idleDuration time.Duration
+
+ // idleTimeout is the time at which the connection will be closed due to inactivity.
+ idleTimeout time.Time
+
+ // nextTimeout is the time of the next idle event.
+ // If nextTimeout == idleTimeout, this is the idle timeout.
+ // Otherwise, this is the keep-alive timeout.
+ nextTimeout time.Time
+
+ // sentSinceLastReceive is set if we have sent an ack-eliciting packet
+ // since the last time we received and processed a packet from the peer.
+ sentSinceLastReceive bool
+}
+
+// receivePeerMaxIdleTimeout handles the peer's max_idle_timeout transport parameter.
+func (c *Conn) receivePeerMaxIdleTimeout(peerMaxIdleTimeout time.Duration) {
+ localMaxIdleTimeout := c.config.maxIdleTimeout()
+ switch {
+ case localMaxIdleTimeout == 0:
+ c.idle.idleDuration = peerMaxIdleTimeout
+ case peerMaxIdleTimeout == 0:
+ c.idle.idleDuration = localMaxIdleTimeout
+ default:
+ c.idle.idleDuration = min(localMaxIdleTimeout, peerMaxIdleTimeout)
+ }
+}
+
+func (c *Conn) idleHandlePacketReceived(now time.Time) {
+ if !c.handshakeConfirmed.isSet() {
+ return
+ }
+ // "An endpoint restarts its idle timer when a packet from its peer is
+ // received and processed successfully."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3
+ c.idle.sentSinceLastReceive = false
+ c.restartIdleTimer(now)
+}
+
+func (c *Conn) idleHandlePacketSent(now time.Time, sent *sentPacket) {
+ // "An endpoint also restarts its idle timer when sending an ack-eliciting packet
+ // if no other ack-eliciting packets have been sent since
+ // last receiving and processing a packet."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3
+ if c.idle.sentSinceLastReceive || !sent.ackEliciting || !c.handshakeConfirmed.isSet() {
+ return
+ }
+ c.idle.sentSinceLastReceive = true
+ c.restartIdleTimer(now)
+}
+
+func (c *Conn) restartIdleTimer(now time.Time) {
+ if !c.isAlive() {
+ // Connection is closing, disable timeouts.
+ c.idle.idleTimeout = time.Time{}
+ c.idle.nextTimeout = time.Time{}
+ return
+ }
+ var idleDuration time.Duration
+ if c.handshakeConfirmed.isSet() {
+ idleDuration = c.idle.idleDuration
+ } else {
+ idleDuration = c.config.handshakeTimeout()
+ }
+ if idleDuration == 0 {
+ c.idle.idleTimeout = time.Time{}
+ } else {
+ // "[...] endpoints MUST increase the idle timeout period to be
+ // at least three times the current Probe Timeout (PTO)."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-4
+ idleDuration = max(idleDuration, 3*c.loss.ptoPeriod())
+ c.idle.idleTimeout = now.Add(idleDuration)
+ }
+ // Set the time of our next event:
+ // The idle timer if no keep-alive is set, or the keep-alive timer if one is.
+ c.idle.nextTimeout = c.idle.idleTimeout
+ keepAlive := c.config.keepAlivePeriod()
+ switch {
+ case !c.handshakeConfirmed.isSet():
+ // We do not send keep-alives before the handshake is complete.
+ case keepAlive <= 0:
+ // Keep-alives are not enabled.
+ case c.idle.sentSinceLastReceive:
+ // We have sent an ack-eliciting packet to the peer.
+ // If they don't acknowledge it, loss detection will follow up with PTO probes,
+ // which will function as keep-alives.
+ // We don't need to send further pings.
+ case idleDuration == 0:
+ // The connection does not have a negotiated idle timeout.
+ // Send keep-alives anyway, since they may be required to keep middleboxes
+ // from losing state.
+ c.idle.nextTimeout = now.Add(keepAlive)
+ default:
+ // Schedule our next keep-alive.
+ // If our configured keep-alive period is greater than half the negotiated
+ // connection idle timeout, we reduce the keep-alive period to half
+ // the idle timeout to ensure we have time for the ping to arrive.
+ c.idle.nextTimeout = now.Add(min(keepAlive, idleDuration/2))
+ }
+}
+
+func (c *Conn) appendKeepAlive(now time.Time) bool {
+ if c.idle.nextTimeout.IsZero() || c.idle.nextTimeout.After(now) {
+ return true // timer has not expired
+ }
+ if c.idle.nextTimeout.Equal(c.idle.idleTimeout) {
+ return true // no keepalive timer set, only idle
+ }
+ if c.idle.sentSinceLastReceive {
+ return true // already sent an ack-eliciting packet
+ }
+ if c.w.sent.ackEliciting {
+ return true // this packet is already ack-eliciting
+ }
+ // Send an ack-eliciting PING frame to the peer to keep the connection alive.
+ return c.w.appendPingFrame()
+}
+
+var errHandshakeTimeout error = localTransportError{
+ code: errConnectionRefused,
+ reason: "handshake timeout",
+}
+
+func (c *Conn) idleAdvance(now time.Time) (shouldExit bool) {
+ if c.idle.idleTimeout.IsZero() || now.Before(c.idle.idleTimeout) {
+ return false
+ }
+ c.idle.idleTimeout = time.Time{}
+ c.idle.nextTimeout = time.Time{}
+ if !c.handshakeConfirmed.isSet() {
+ // Handshake timeout has expired.
+ // If we're a server, we're refusing the too-slow client.
+ // If we're a client, we're giving up.
+ // In either case, we're going to send a CONNECTION_CLOSE frame and
+ // enter the closing state rather than unceremoniously dropping the connection,
+ // since the peer might still be trying to complete the handshake.
+ c.abort(now, errHandshakeTimeout)
+ return false
+ }
+ // Idle timeout has expired.
+ //
+ // "[...] the connection is silently closed and its state is discarded [...]"
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1
+ return true
+}
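
A distilled sketch of the timeout arithmetic in restartIdleTimer above, pulled out as pure functions for clarity. These helpers are not in the package; pto stands in for c.loss.ptoPeriod(), and the min/max builtins match this file's go1.21 build constraint:

func negotiatedIdleTimeout(local, peer time.Duration) time.Duration {
	// A zero value means "no limit advertised"; the other side's limit
	// (or none at all) wins. Otherwise take the smaller of the two.
	switch {
	case local == 0:
		return peer
	case peer == 0:
		return local
	}
	return min(local, peer)
}

func idleDeadlineDelay(idle, pto time.Duration) time.Duration {
	if idle == 0 {
		return 0 // idle timeout disabled
	}
	// RFC 9000, Section 10.1: at least three times the current PTO.
	return max(idle, 3*pto)
}

func keepAliveDelay(keepAlive, idle time.Duration) time.Duration {
	// The real code also skips keep-alives before handshake confirmation
	// and when an ack-eliciting packet is already in flight.
	if idle == 0 {
		return keepAlive // no negotiated timeout; ping for middlebox state
	}
	// Ping at most every half idle timeout so the ping can arrive in time.
	return min(keepAlive, idle/2)
}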
diff --git a/quic/idle_test.go b/quic/idle_test.go
new file mode 100644
index 000000000..18f6a690a
--- /dev/null
+++ b/quic/idle_test.go
@@ -0,0 +1,225 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "testing"
+ "time"
+)
+
+func TestHandshakeTimeoutExpiresServer(t *testing.T) {
+ const timeout = 5 * time.Second
+ tc := newTestConn(t, serverSide, func(c *Config) {
+ c.HandshakeTimeout = timeout
+ })
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeNewConnectionID)
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ // Server starts its end of the handshake.
+ // Client acks these packets to avoid starting the PTO timer.
+ tc.wantFrameType("server sends Initial CRYPTO flight",
+ packetTypeInitial, debugFrameCrypto{})
+ tc.writeAckForAll()
+ tc.wantFrameType("server sends Handshake CRYPTO flight",
+ packetTypeHandshake, debugFrameCrypto{})
+ tc.writeAckForAll()
+
+ if got, want := tc.timerDelay(), timeout; got != want {
+ t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want)
+ }
+
+ // Client sends a packet, but this does not extend the handshake timer.
+ tc.advance(1 * time.Second)
+ tc.writeFrames(packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][:1], // partial data
+ })
+ tc.wantIdle("handshake is not complete")
+
+ tc.advance(timeout - 1*time.Second)
+ tc.wantFrame("server closes connection after handshake timeout",
+ packetTypeHandshake, debugFrameConnectionCloseTransport{
+ code: errConnectionRefused,
+ })
+}
+
+func TestHandshakeTimeoutExpiresClient(t *testing.T) {
+ const timeout = 5 * time.Second
+ tc := newTestConn(t, clientSide, func(c *Config) {
+ c.HandshakeTimeout = timeout
+ })
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeNewConnectionID)
+ // Start the handshake.
+ // The client always sets a PTO timer until it gets an ack for a handshake packet
+ // or confirms the handshake, so proceed far enough through the handshake to
+ // let us not worry about PTO.
+ tc.wantFrameType("client sends Initial CRYPTO flight",
+ packetTypeInitial, debugFrameCrypto{})
+ tc.writeAckForAll()
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.wantFrameType("client sends Handshake CRYPTO flight",
+ packetTypeHandshake, debugFrameCrypto{})
+ tc.writeAckForAll()
+ tc.wantIdle("client is waiting for end of handshake")
+
+ if got, want := tc.timerDelay(), timeout; got != want {
+ t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want)
+ }
+ tc.advance(timeout)
+ tc.wantFrame("client closes connection after handshake timeout",
+ packetTypeHandshake, debugFrameConnectionCloseTransport{
+ code: errConnectionRefused,
+ })
+}
+
+func TestIdleTimeoutExpires(t *testing.T) {
+ for _, test := range []struct {
+ localMaxIdleTimeout time.Duration
+ peerMaxIdleTimeout time.Duration
+ wantTimeout time.Duration
+ }{{
+ localMaxIdleTimeout: 10 * time.Second,
+ peerMaxIdleTimeout: 20 * time.Second,
+ wantTimeout: 10 * time.Second,
+ }, {
+ localMaxIdleTimeout: 20 * time.Second,
+ peerMaxIdleTimeout: 10 * time.Second,
+ wantTimeout: 10 * time.Second,
+ }, {
+ localMaxIdleTimeout: 0,
+ peerMaxIdleTimeout: 10 * time.Second,
+ wantTimeout: 10 * time.Second,
+ }, {
+ localMaxIdleTimeout: 10 * time.Second,
+ peerMaxIdleTimeout: 0,
+ wantTimeout: 10 * time.Second,
+ }} {
+ name := fmt.Sprintf("local=%v/peer=%v", test.localMaxIdleTimeout, test.peerMaxIdleTimeout)
+ t.Run(name, func(t *testing.T) {
+ tc := newTestConn(t, serverSide, func(p *transportParameters) {
+ p.maxIdleTimeout = test.peerMaxIdleTimeout
+ }, func(c *Config) {
+ c.MaxIdleTimeout = test.localMaxIdleTimeout
+ })
+ tc.handshake()
+ if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want {
+ t.Errorf("new conn timeout=%v, want %v (idle timeout)", got, want)
+ }
+ tc.advance(test.wantTimeout - 1)
+ tc.wantIdle("connection is idle and alive prior to timeout")
+ ctx := canceledContext()
+ if err := tc.conn.Wait(ctx); err != context.Canceled {
+ t.Fatalf("conn.Wait() = %v, want Canceled", err)
+ }
+ tc.advance(1)
+ tc.wantIdle("connection exits after timeout")
+ if err := tc.conn.Wait(ctx); err != errIdleTimeout {
+ t.Fatalf("conn.Wait() = %v, want errIdleTimeout", err)
+ }
+ })
+ }
+}
+
+func TestIdleTimeoutKeepAlive(t *testing.T) {
+ for _, test := range []struct {
+ idleTimeout time.Duration
+ keepAlive time.Duration
+ wantTimeout time.Duration
+ }{{
+ idleTimeout: 30 * time.Second,
+ keepAlive: 10 * time.Second,
+ wantTimeout: 10 * time.Second,
+ }, {
+ idleTimeout: 10 * time.Second,
+ keepAlive: 30 * time.Second,
+ wantTimeout: 5 * time.Second,
+ }, {
+ idleTimeout: -1, // disabled
+ keepAlive: 30 * time.Second,
+ wantTimeout: 30 * time.Second,
+ }} {
+ name := fmt.Sprintf("idle_timeout=%v/keepalive=%v", test.idleTimeout, test.keepAlive)
+ t.Run(name, func(t *testing.T) {
+ tc := newTestConn(t, serverSide, func(c *Config) {
+ c.MaxIdleTimeout = test.idleTimeout
+ c.KeepAlivePeriod = test.keepAlive
+ })
+ tc.handshake()
+ if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want {
+ t.Errorf("new conn timeout=%v, want %v (keepalive timeout)", got, want)
+ }
+ tc.advance(test.wantTimeout - 1)
+ tc.wantIdle("connection is idle prior to timeout")
+ tc.advance(1)
+ tc.wantFrameType("keep-alive ping is sent", packetType1RTT,
+ debugFramePing{})
+ })
+ }
+}
+
+func TestIdleLongTermKeepAliveSent(t *testing.T) {
+ // This test examines a connection sitting idle and sending periodic keep-alive pings.
+ const keepAlivePeriod = 30 * time.Second
+ tc := newTestConn(t, clientSide, func(c *Config) {
+ c.KeepAlivePeriod = keepAlivePeriod
+ c.MaxIdleTimeout = -1
+ })
+ tc.handshake()
+ // The handshake will have completed a little bit after the point at which the
+ // keepalive timer was set. Send two PING frames to the conn, triggering an immediate ack
+ // and resetting the timer.
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+ tc.wantFrameType("conn acks received pings", packetType1RTT, debugFrameAck{})
+ for i := 0; i < 10; i++ {
+ tc.wantIdle("conn has nothing more to send")
+ if got, want := tc.timeUntilEvent(), keepAlivePeriod; got != want {
+ t.Errorf("i=%v conn timeout=%v, want %v (keepalive timeout)", i, got, want)
+ }
+ tc.advance(keepAlivePeriod)
+ tc.wantFrameType("keep-alive ping is sent", packetType1RTT,
+ debugFramePing{})
+ tc.writeAckForAll()
+ }
+}
+
+func TestIdleLongTermKeepAliveReceived(t *testing.T) {
+ // This test examines a connection sitting idle, but receiving periodic peer
+ // traffic to keep the connection alive.
+ const idleTimeout = 30 * time.Second
+ tc := newTestConn(t, serverSide, func(c *Config) {
+ c.MaxIdleTimeout = idleTimeout
+ })
+ tc.handshake()
+ for i := 0; i < 10; i++ {
+ tc.advance(idleTimeout - 1*time.Second)
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+ if got, want := tc.timeUntilEvent(), maxAckDelay-timerGranularity; got != want {
+ t.Errorf("i=%v conn timeout=%v, want %v (max_ack_delay)", i, got, want)
+ }
+ tc.advanceToTimer()
+ tc.wantFrameType("conn acks received ping", packetType1RTT, debugFrameAck{})
+ }
+ // Connection is still alive.
+ ctx := canceledContext()
+ if err := tc.conn.Wait(ctx); err != context.Canceled {
+ t.Fatalf("conn.Wait() = %v, want Canceled", err)
+ }
+}
diff --git a/internal/quic/key_update_test.go b/quic/key_update_test.go
similarity index 100%
rename from internal/quic/key_update_test.go
rename to quic/key_update_test.go
diff --git a/internal/quic/log.go b/quic/log.go
similarity index 100%
rename from internal/quic/log.go
rename to quic/log.go
diff --git a/internal/quic/loss.go b/quic/loss.go
similarity index 87%
rename from internal/quic/loss.go
rename to quic/loss.go
index c0f915b42..796b5f7a3 100644
--- a/internal/quic/loss.go
+++ b/quic/loss.go
@@ -7,6 +7,8 @@
package quic
import (
+ "context"
+ "log/slog"
"math"
"time"
)
@@ -50,6 +52,9 @@ type lossState struct {
// https://www.rfc-editor.org/rfc/rfc9000#section-8-2
antiAmplificationLimit int
+ // Count of non-ack-eliciting packets (ACKs) sent since the last ack-eliciting one.
+ consecutiveNonAckElicitingPackets int
+
rtt rttState
pacer pacerState
cc *ccReno
@@ -176,7 +181,7 @@ func (c *lossState) nextNumber(space numberSpace) packetNumber {
}
// packetSent records a sent packet.
-func (c *lossState) packetSent(now time.Time, space numberSpace, sent *sentPacket) {
+func (c *lossState) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) {
sent.time = now
c.spaces[space].add(sent)
size := sent.size
@@ -184,13 +189,21 @@ func (c *lossState) packetSent(now time.Time, space numberSpace, sent *sentPacke
c.antiAmplificationLimit = max(0, c.antiAmplificationLimit-size)
}
if sent.inFlight {
- c.cc.packetSent(now, space, sent)
+ c.cc.packetSent(now, log, space, sent)
c.pacer.packetSent(now, size, c.cc.congestionWindow, c.rtt.smoothedRTT)
if sent.ackEliciting {
c.spaces[space].lastAckEliciting = sent.num
c.ptoExpired = false // reset expired PTO timer after sending probe
}
c.scheduleTimer(now)
+ if logEnabled(log, QLogLevelPacket) {
+ logBytesInFlight(log, c.cc.bytesInFlight)
+ }
+ }
+ if sent.ackEliciting {
+ c.consecutiveNonAckElicitingPackets = 0
+ } else {
+ c.consecutiveNonAckElicitingPackets++
}
}
@@ -259,7 +272,7 @@ func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex
// receiveAckEnd finishes processing an ack frame.
// The lossf function is called for each packet newly detected as lost.
-func (c *lossState) receiveAckEnd(now time.Time, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) {
+func (c *lossState) receiveAckEnd(now time.Time, log *slog.Logger, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) {
c.spaces[space].sentPacketList.clean()
// Update the RTT sample when the largest acknowledged packet in the ACK frame
// is newly acknowledged, and at least one newly acknowledged packet is ack-eliciting.
@@ -278,13 +291,30 @@ func (c *lossState) receiveAckEnd(now time.Time, space numberSpace, ackDelay tim
// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-3
c.timer = time.Time{}
c.detectLoss(now, lossf)
- c.cc.packetBatchEnd(now, space, &c.rtt, c.maxAckDelay)
+ c.cc.packetBatchEnd(now, log, space, &c.rtt, c.maxAckDelay)
+
+ if logEnabled(log, QLogLevelPacket) {
+ var ssthresh slog.Attr
+ if c.cc.slowStartThreshold != math.MaxInt {
+ ssthresh = slog.Int("ssthresh", c.cc.slowStartThreshold)
+ }
+ log.LogAttrs(context.Background(), QLogLevelPacket,
+ "recovery:metrics_updated",
+ slog.Duration("min_rtt", c.rtt.minRTT),
+ slog.Duration("smoothed_rtt", c.rtt.smoothedRTT),
+ slog.Duration("latest_rtt", c.rtt.latestRTT),
+ slog.Duration("rtt_variance", c.rtt.rttvar),
+ slog.Int("congestion_window", c.cc.congestionWindow),
+ slog.Int("bytes_in_flight", c.cc.bytesInFlight),
+ ssthresh,
+ )
+ }
}
// discardPackets declares that packets within a number space will not be delivered
// and that data contained in them should be resent.
// For example, after receiving a Retry packet we discard already-sent Initial packets.
-func (c *lossState) discardPackets(space numberSpace, lossf func(numberSpace, *sentPacket, packetFate)) {
+func (c *lossState) discardPackets(space numberSpace, log *slog.Logger, lossf func(numberSpace, *sentPacket, packetFate)) {
for i := 0; i < c.spaces[space].size; i++ {
sent := c.spaces[space].nth(i)
sent.lost = true
@@ -292,10 +322,13 @@ func (c *lossState) discardPackets(space numberSpace, lossf func(numberSpace, *s
lossf(numberSpace(space), sent, packetLost)
}
c.spaces[space].clean()
+ if logEnabled(log, QLogLevelPacket) {
+ logBytesInFlight(log, c.cc.bytesInFlight)
+ }
}
// discardKeys is called when dropping packet protection keys for a number space.
-func (c *lossState) discardKeys(now time.Time, space numberSpace) {
+func (c *lossState) discardKeys(now time.Time, log *slog.Logger, space numberSpace) {
// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.4
for i := 0; i < c.spaces[space].size; i++ {
sent := c.spaces[space].nth(i)
@@ -305,6 +338,9 @@ func (c *lossState) discardKeys(now time.Time, space numberSpace) {
c.spaces[space].maxAcked = -1
c.spaces[space].lastAckEliciting = -1
c.scheduleTimer(now)
+ if logEnabled(log, QLogLevelPacket) {
+ logBytesInFlight(log, c.cc.bytesInFlight)
+ }
}
func (c *lossState) lossDuration() time.Duration {
@@ -431,12 +467,15 @@ func (c *lossState) scheduleTimer(now time.Time) {
c.timer = time.Time{}
return
}
- // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1
- pto := c.ptoBasePeriod() << c.ptoBackoffCount
- c.timer = last.Add(pto)
+ c.timer = last.Add(c.ptoPeriod())
c.ptoTimerArmed = true
}
+func (c *lossState) ptoPeriod() time.Duration {
+ // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1
+ return c.ptoBasePeriod() << c.ptoBackoffCount
+}
+
func (c *lossState) ptoBasePeriod() time.Duration {
// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1
pto := c.rtt.smoothedRTT + max(4*c.rtt.rttvar, timerGranularity)
@@ -448,3 +487,10 @@ func (c *lossState) ptoBasePeriod() time.Duration {
}
return pto
}
+
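+// logBytesInFlight emits a recovery:metrics_updated event recording the current bytes in flight.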
+func logBytesInFlight(log *slog.Logger, bytesInFlight int) {
+ log.LogAttrs(context.Background(), QLogLevelPacket,
+ "recovery:metrics_updated",
+ slog.Int("bytes_in_flight", bytesInFlight),
+ )
+}
diff --git a/internal/quic/loss_test.go b/quic/loss_test.go
similarity index 99%
rename from internal/quic/loss_test.go
rename to quic/loss_test.go
index efbf1649e..1fb9662e4 100644
--- a/internal/quic/loss_test.go
+++ b/quic/loss_test.go
@@ -1060,7 +1060,7 @@ func TestLossPersistentCongestion(t *testing.T) {
maxDatagramSize: 1200,
})
test.send(initialSpace, 0, testSentPacketSize(1200))
- test.c.cc.setUnderutilized(true)
+ test.c.cc.setUnderutilized(nil, true)
test.advance(10 * time.Millisecond)
test.ack(initialSpace, 0*time.Millisecond, i64range[packetNumber]{0, 1})
@@ -1377,7 +1377,7 @@ func (c *lossTest) setRTTVar(d time.Duration) {
func (c *lossTest) setUnderutilized(v bool) {
c.t.Logf("set congestion window underutilized: %v", v)
- c.c.cc.setUnderutilized(v)
+ c.c.cc.setUnderutilized(nil, v)
}
func (c *lossTest) advance(d time.Duration) {
@@ -1438,7 +1438,7 @@ func (c *lossTest) send(spaceID numberSpace, opts ...any) {
sent := &sentPacket{}
*sent = prototype
sent.num = num
- c.c.packetSent(c.now, spaceID, sent)
+ c.c.packetSent(c.now, nil, spaceID, sent)
}
}
@@ -1462,7 +1462,7 @@ func (c *lossTest) ack(spaceID numberSpace, ackDelay time.Duration, rs ...i64ran
c.t.Logf("ack %v delay=%v [%v,%v)", spaceID, ackDelay, r.start, r.end)
c.c.receiveAckRange(c.now, spaceID, i, r.start, r.end, c.onAckOrLoss)
}
- c.c.receiveAckEnd(c.now, spaceID, ackDelay, c.onAckOrLoss)
+ c.c.receiveAckEnd(c.now, nil, spaceID, ackDelay, c.onAckOrLoss)
}
func (c *lossTest) onAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) {
@@ -1491,7 +1491,7 @@ func (c *lossTest) discardKeys(spaceID numberSpace) {
c.t.Helper()
c.checkUnexpectedEvents()
c.t.Logf("discard %s keys", spaceID)
- c.c.discardKeys(c.now, spaceID)
+ c.c.discardKeys(c.now, nil, spaceID)
}
func (c *lossTest) setMaxAckDelay(d time.Duration) {
diff --git a/quic/main_test.go b/quic/main_test.go
new file mode 100644
index 000000000..ecd0b1e9f
--- /dev/null
+++ b/quic/main_test.go
@@ -0,0 +1,57 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "fmt"
+ "os"
+ "runtime"
+ "testing"
+ "time"
+)
+
+func TestMain(m *testing.M) {
+ defer os.Exit(m.Run())
+
+ // Look for leaked goroutines.
+ //
+ // Checking after every test makes it easier to tell which test is the culprit,
+ // but checking once at the end is faster and less likely to miss something.
+ if runtime.GOOS == "js" {
+ // The js-wasm runtime creates an additional background goroutine.
+ // Just skip the leak check there.
+ return
+ }
+ start := time.Now()
+ warned := false
+ for {
+ buf := make([]byte, 2<<20)
+ buf = buf[:runtime.Stack(buf, true)]
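+ // runtime.Stack with all=true dumps every goroutine's stack, separated by blank lines.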
+ leaked := false
+ for _, g := range bytes.Split(buf, []byte("\n\n")) {
+ if bytes.Contains(g, []byte("quic.TestMain")) ||
+ bytes.Contains(g, []byte("created by os/signal.Notify")) ||
+ bytes.Contains(g, []byte("gotraceback_test.go")) {
+ continue
+ }
+ leaked = true
+ }
+ if !leaked {
+ break
+ }
+ if !warned && time.Since(start) > 1*time.Second {
+ // Print a warning quickly, in case this is an interactive session.
+ // Keep waiting until the test times out, in case this is a slow trybot.
+ fmt.Printf("Tests seem to have leaked some goroutines, still waiting.\n\n")
+ fmt.Print(string(buf))
+ warned = true
+ }
+ // Goroutines might still be shutting down.
+ time.Sleep(1 * time.Millisecond)
+ }
+}
diff --git a/internal/quic/math.go b/quic/math.go
similarity index 100%
rename from internal/quic/math.go
rename to quic/math.go
diff --git a/internal/quic/pacer.go b/quic/pacer.go
similarity index 100%
rename from internal/quic/pacer.go
rename to quic/pacer.go
diff --git a/internal/quic/pacer_test.go b/quic/pacer_test.go
similarity index 100%
rename from internal/quic/pacer_test.go
rename to quic/pacer_test.go
diff --git a/internal/quic/packet.go b/quic/packet.go
similarity index 96%
rename from internal/quic/packet.go
rename to quic/packet.go
index df589ccca..7a874319d 100644
--- a/internal/quic/packet.go
+++ b/quic/packet.go
@@ -41,6 +41,22 @@ func (p packetType) String() string {
return fmt.Sprintf("unknown packet type %v", byte(p))
}
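+// qlogString returns the packet type in the form used by qlog events.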
+func (p packetType) qlogString() string {
+ switch p {
+ case packetTypeInitial:
+ return "initial"
+ case packetType0RTT:
+ return "0RTT"
+ case packetTypeHandshake:
+ return "handshake"
+ case packetTypeRetry:
+ return "retry"
+ case packetType1RTT:
+ return "1RTT"
+ }
+ return "unknown"
+}
+
// Bits set in the first byte of a packet.
const (
headerFormLong = 0x80 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.2.1
diff --git a/internal/quic/packet_codec_test.go b/quic/packet_codec_test.go
similarity index 85%
rename from internal/quic/packet_codec_test.go
rename to quic/packet_codec_test.go
index 7b01bb00d..3b39795ef 100644
--- a/internal/quic/packet_codec_test.go
+++ b/quic/packet_codec_test.go
@@ -9,8 +9,13 @@ package quic
import (
"bytes"
"crypto/tls"
+ "io"
+ "log/slog"
"reflect"
"testing"
+ "time"
+
+ "golang.org/x/net/quic/qlog"
)
func TestParseLongHeaderPacket(t *testing.T) {
@@ -207,11 +212,13 @@ func TestRoundtripEncodeShortPacket(t *testing.T) {
func TestFrameEncodeDecode(t *testing.T) {
for _, test := range []struct {
s string
+ j string
f debugFrame
b []byte
truncated []byte
}{{
s: "PADDING*1",
+ j: `{"frame_type":"padding","length":1}`,
f: debugFramePadding{
size: 1,
},
@@ -221,12 +228,14 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "PING",
+ j: `{"frame_type":"ping"}`,
f: debugFramePing{},
b: []byte{
0x01, // TYPE(i) = 0x01
},
}, {
s: "ACK Delay=10 [0,16) [17,32) [48,64)",
+ j: `"error: debugFrameAck should not appear as a slog Value"`,
f: debugFrameAck{
ackDelay: 10,
ranges: []i64range[packetNumber]{
@@ -257,6 +266,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "RESET_STREAM ID=1 Code=2 FinalSize=3",
+ j: `{"frame_type":"reset_stream","stream_id":1,"final_size":3}`,
f: debugFrameResetStream{
id: 1,
code: 2,
@@ -270,6 +280,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STOP_SENDING ID=1 Code=2",
+ j: `{"frame_type":"stop_sending","stream_id":1,"error_code":2}`,
f: debugFrameStopSending{
id: 1,
code: 2,
@@ -281,6 +292,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "CRYPTO Offset=1 Length=2",
+ j: `{"frame_type":"crypto","offset":1,"length":2}`,
f: debugFrameCrypto{
off: 1,
data: []byte{3, 4},
@@ -299,6 +311,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "NEW_TOKEN Token=0304",
+ j: `{"frame_type":"new_token","token":"0304"}`,
f: debugFrameNewToken{
token: []byte{3, 4},
},
@@ -309,6 +322,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM ID=1 Offset=0 Length=0",
+ j: `{"frame_type":"stream","stream_id":1,"offset":0,"length":0}`,
f: debugFrameStream{
id: 1,
fin: false,
@@ -324,6 +338,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM ID=100 Offset=4 Length=3",
+ j: `{"frame_type":"stream","stream_id":100,"offset":4,"length":3}`,
f: debugFrameStream{
id: 100,
fin: false,
@@ -346,6 +361,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM ID=100 FIN Offset=4 Length=3",
+ j: `{"frame_type":"stream","stream_id":100,"offset":4,"length":3,"fin":true}`,
f: debugFrameStream{
id: 100,
fin: true,
@@ -368,6 +384,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM ID=1 FIN Offset=100 Length=0",
+ j: `{"frame_type":"stream","stream_id":1,"offset":100,"length":0,"fin":true}`,
f: debugFrameStream{
id: 1,
fin: true,
@@ -383,6 +400,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "MAX_DATA Max=10",
+ j: `{"frame_type":"max_data","maximum":10}`,
f: debugFrameMaxData{
max: 10,
},
@@ -392,6 +410,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "MAX_STREAM_DATA ID=1 Max=10",
+ j: `{"frame_type":"max_stream_data","stream_id":1,"maximum":10}`,
f: debugFrameMaxStreamData{
id: 1,
max: 10,
@@ -403,6 +422,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "MAX_STREAMS Type=bidi Max=1",
+ j: `{"frame_type":"max_streams","stream_type":"bidirectional","maximum":1}`,
f: debugFrameMaxStreams{
streamType: bidiStream,
max: 1,
@@ -413,6 +433,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "MAX_STREAMS Type=uni Max=1",
+ j: `{"frame_type":"max_streams","stream_type":"unidirectional","maximum":1}`,
f: debugFrameMaxStreams{
streamType: uniStream,
max: 1,
@@ -423,6 +444,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "DATA_BLOCKED Max=1",
+ j: `{"frame_type":"data_blocked","limit":1}`,
f: debugFrameDataBlocked{
max: 1,
},
@@ -432,6 +454,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAM_DATA_BLOCKED ID=1 Max=2",
+ j: `{"frame_type":"stream_data_blocked","stream_id":1,"limit":2}`,
f: debugFrameStreamDataBlocked{
id: 1,
max: 2,
@@ -443,6 +466,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAMS_BLOCKED Type=bidi Max=1",
+ j: `{"frame_type":"streams_blocked","stream_type":"bidirectional","limit":1}`,
f: debugFrameStreamsBlocked{
streamType: bidiStream,
max: 1,
@@ -453,6 +477,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "STREAMS_BLOCKED Type=uni Max=1",
+ j: `{"frame_type":"streams_blocked","stream_type":"unidirectional","limit":1}`,
f: debugFrameStreamsBlocked{
streamType: uniStream,
max: 1,
@@ -463,6 +488,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "NEW_CONNECTION_ID Seq=3 Retire=2 ID=a0a1a2a3 Token=0102030405060708090a0b0c0d0e0f10",
+ j: `{"frame_type":"new_connection_id","sequence_number":3,"retire_prior_to":2,"connection_id":"a0a1a2a3","stateless_reset_token":"0102030405060708090a0b0c0d0e0f10"}`,
f: debugFrameNewConnectionID{
seq: 3,
retirePriorTo: 2,
@@ -479,6 +505,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "RETIRE_CONNECTION_ID Seq=1",
+ j: `{"frame_type":"retire_connection_id","sequence_number":1}`,
f: debugFrameRetireConnectionID{
seq: 1,
},
@@ -488,8 +515,9 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "PATH_CHALLENGE Data=0123456789abcdef",
+ j: `{"frame_type":"path_challenge","data":"0123456789abcdef"}`,
f: debugFramePathChallenge{
- data: 0x0123456789abcdef,
+ data: pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef},
},
b: []byte{
0x1a, // Type (i) = 0x1a,
@@ -497,8 +525,9 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "PATH_RESPONSE Data=0123456789abcdef",
+ j: `{"frame_type":"path_response","data":"0123456789abcdef"}`,
f: debugFramePathResponse{
- data: 0x0123456789abcdef,
+ data: pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef},
},
b: []byte{
0x1b, // Type (i) = 0x1b,
@@ -506,6 +535,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: `CONNECTION_CLOSE Code=INTERNAL_ERROR FrameType=2 Reason="oops"`,
+ j: `{"frame_type":"connection_close","error_space":"transport","error_code_value":1,"reason":"oops"}`,
f: debugFrameConnectionCloseTransport{
code: 1,
frameType: 2,
@@ -520,6 +550,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: `CONNECTION_CLOSE AppCode=1 Reason="oops"`,
+ j: `{"frame_type":"connection_close","error_space":"application","error_code_value":1,"reason":"oops"}`,
f: debugFrameConnectionCloseApplication{
code: 1,
reason: "oops",
@@ -532,6 +563,7 @@ func TestFrameEncodeDecode(t *testing.T) {
},
}, {
s: "HANDSHAKE_DONE",
+ j: `{"frame_type":"handshake_done"}`,
f: debugFrameHandshakeDone{},
b: []byte{
0x1e, // Type (i) = 0x1e,
@@ -554,6 +586,9 @@ func TestFrameEncodeDecode(t *testing.T) {
if got, want := test.f.String(), test.s; got != want {
t.Errorf("frame.String():\ngot %q\nwant %q", got, want)
}
+ if got, want := frameJSON(test.f), test.j; got != want {
+ t.Errorf("frame.LogValue():\ngot %q\nwant %q", got, want)
+ }
// Try encoding the frame into too little space.
// Most frames will result in an error; some (like STREAM frames) will truncate
@@ -579,6 +614,42 @@ func TestFrameEncodeDecode(t *testing.T) {
}
}
+func TestFrameScaledAck(t *testing.T) {
+ for _, test := range []struct {
+ j string
+ f debugFrameScaledAck
+ }{{
+ j: `{"frame_type":"ack","acked_ranges":[[0,15],[17],[48,63]],"ack_delay":10.000000}`,
+ f: debugFrameScaledAck{
+ ackDelay: 10 * time.Millisecond,
+ ranges: []i64range[packetNumber]{
+ {0x00, 0x10},
+ {0x11, 0x12},
+ {0x30, 0x40},
+ },
+ },
+ }} {
+ if got, want := frameJSON(test.f), test.j; got != want {
+ t.Errorf("frame.LogValue():\ngot %q\nwant %q", got, want)
+ }
+ }
+}
+
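+// frameJSON logs f through a qlog handler and extracts the resulting "frame" JSON value.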
+func frameJSON(f slog.LogValuer) string {
+ var buf bytes.Buffer
+ h := qlog.NewJSONHandler(qlog.HandlerOptions{
+ Level: QLogLevelFrame,
+ NewTrace: func(info qlog.TraceInfo) (io.WriteCloser, error) {
+ return nopCloseWriter{&buf}, nil
+ },
+ })
+ // Log the frame, and then trim out everything but the frame from the log.
+ slog.New(h).Info("message", slog.Any("frame", f))
+ _, b, _ := bytes.Cut(buf.Bytes(), []byte(`"frame":`))
+ b = bytes.TrimSuffix(b, []byte("}}\n"))
+ return string(b)
+}
+
func TestFrameDecode(t *testing.T) {
for _, test := range []struct {
desc string
diff --git a/internal/quic/packet_number.go b/quic/packet_number.go
similarity index 100%
rename from internal/quic/packet_number.go
rename to quic/packet_number.go
diff --git a/internal/quic/packet_number_test.go b/quic/packet_number_test.go
similarity index 100%
rename from internal/quic/packet_number_test.go
rename to quic/packet_number_test.go
diff --git a/internal/quic/packet_parser.go b/quic/packet_parser.go
similarity index 98%
rename from internal/quic/packet_parser.go
rename to quic/packet_parser.go
index 02ef9fb14..feef9eac7 100644
--- a/internal/quic/packet_parser.go
+++ b/quic/packet_parser.go
@@ -463,18 +463,17 @@ func consumeRetireConnectionIDFrame(b []byte) (seq int64, n int) {
return seq, n
}
-func consumePathChallengeFrame(b []byte) (data uint64, n int) {
+func consumePathChallengeFrame(b []byte) (data pathChallengeData, n int) {
n = 1
- var nn int
- data, nn = consumeUint64(b[n:])
- if nn < 0 {
- return 0, -1
+ nn := copy(data[:], b[n:])
+ if nn != len(data) {
+ return data, -1
}
n += nn
return data, n
}
-func consumePathResponseFrame(b []byte) (data uint64, n int) {
+func consumePathResponseFrame(b []byte) (data pathChallengeData, n int) {
return consumePathChallengeFrame(b) // identical frame format
}
diff --git a/internal/quic/packet_protection.go b/quic/packet_protection.go
similarity index 100%
rename from internal/quic/packet_protection.go
rename to quic/packet_protection.go
diff --git a/internal/quic/packet_protection_test.go b/quic/packet_protection_test.go
similarity index 100%
rename from internal/quic/packet_protection_test.go
rename to quic/packet_protection_test.go
diff --git a/internal/quic/packet_test.go b/quic/packet_test.go
similarity index 100%
rename from internal/quic/packet_test.go
rename to quic/packet_test.go
diff --git a/internal/quic/packet_writer.go b/quic/packet_writer.go
similarity index 95%
rename from internal/quic/packet_writer.go
rename to quic/packet_writer.go
index 0c2b2ee41..e4d71e622 100644
--- a/internal/quic/packet_writer.go
+++ b/quic/packet_writer.go
@@ -47,6 +47,11 @@ func (w *packetWriter) datagram() []byte {
return w.b
}
+// packetLen returns the size of the current packet.
+func (w *packetWriter) packetLen() int {
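+ // The packet is not yet sealed; aeadOverhead accounts for the AEAD tag that finish will append.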
+ return len(w.b[w.pktOff:]) + aeadOverhead
+}
+
// payload returns the payload of the current packet.
func (w *packetWriter) payload() []byte {
return w.b[w.payOff:]
@@ -136,7 +141,7 @@ func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber
hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked)
k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, p.num)
- return w.finish(p.num)
+ return w.finish(p.ptype, p.num)
}
// start1RTTPacket starts writing a 1-RTT (short header) packet.
@@ -178,7 +183,7 @@ func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConn
hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked)
w.padPacketLength(pnumLen)
k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, pnum)
- return w.finish(pnum)
+ return w.finish(packetType1RTT, pnum)
}
// padPacketLength pads out the payload of the current packet to the minimum size,
@@ -199,9 +204,10 @@ func (w *packetWriter) padPacketLength(pnumLen int) int {
}
// finish finishes the current packet after protection is applied.
-func (w *packetWriter) finish(pnum packetNumber) *sentPacket {
+func (w *packetWriter) finish(ptype packetType, pnum packetNumber) *sentPacket {
w.b = w.b[:len(w.b)+aeadOverhead]
w.sent.size = len(w.b) - w.pktOff
+ w.sent.ptype = ptype
w.sent.num = pnum
sent := w.sent
w.sent = nil
@@ -237,10 +243,7 @@ func (w *packetWriter) appendPingFrame() (added bool) {
return false
}
w.b = append(w.b, frameTypePing)
- // Mark this packet as ack-eliciting and in-flight,
- // but there's no need to record the presence of a PING frame in it.
- w.sent.ackEliciting = true
- w.sent.inFlight = true
+ w.sent.markAckEliciting() // no need to record the frame itself
return true
}
@@ -382,11 +385,7 @@ func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin b
w.b = appendVarint(w.b, uint64(size))
start := len(w.b)
w.b = w.b[:start+size]
- if fin {
- w.sent.appendAckElicitingFrame(frameTypeStreamBase | streamFinBit)
- } else {
- w.sent.appendAckElicitingFrame(frameTypeStreamBase)
- }
+ w.sent.appendAckElicitingFrame(typ & (frameTypeStreamBase | streamFinBit))
w.sent.appendInt(uint64(id))
w.sent.appendOffAndSize(off, size)
return w.b[start:][:size], true
@@ -493,23 +492,23 @@ func (w *packetWriter) appendRetireConnectionIDFrame(seq int64) (added bool) {
return true
}
-func (w *packetWriter) appendPathChallengeFrame(data uint64) (added bool) {
+func (w *packetWriter) appendPathChallengeFrame(data pathChallengeData) (added bool) {
if w.avail() < 1+8 {
return false
}
w.b = append(w.b, frameTypePathChallenge)
- w.b = binary.BigEndian.AppendUint64(w.b, data)
- w.sent.appendAckElicitingFrame(frameTypePathChallenge)
+ w.b = append(w.b, data[:]...)
+ w.sent.markAckEliciting() // no need to record the frame itself
return true
}
-func (w *packetWriter) appendPathResponseFrame(data uint64) (added bool) {
+func (w *packetWriter) appendPathResponseFrame(data pathChallengeData) (added bool) {
if w.avail() < 1+8 {
return false
}
w.b = append(w.b, frameTypePathResponse)
- w.b = binary.BigEndian.AppendUint64(w.b, data)
- w.sent.appendAckElicitingFrame(frameTypePathResponse)
+ w.b = append(w.b, data[:]...)
+ w.sent.markAckEliciting() // no need to record the frame itself
return true
}
diff --git a/quic/path.go b/quic/path.go
new file mode 100644
index 000000000..8c237dd45
--- /dev/null
+++ b/quic/path.go
@@ -0,0 +1,89 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import "time"
+
+type pathState struct {
+ // Response to a peer's PATH_CHALLENGE.
+ // This is not a sentVal, because we don't resend lost PATH_RESPONSE frames.
+ // We only track the most recent PATH_CHALLENGE.
+ // If the peer sends a second PATH_CHALLENGE before we respond to the first,
+ // the second replaces it and only the second gets a response.
+ sendPathResponse pathResponseType
+ data pathChallengeData
+}
+
+// pathChallengeData is data carried in a PATH_CHALLENGE or PATH_RESPONSE frame.
+type pathChallengeData [64 / 8]byte
+
+type pathResponseType uint8
+
+const (
+ pathResponseNotNeeded = pathResponseType(iota)
+ pathResponseSmall // send PATH_RESPONSE, do not expand datagram
+ pathResponseExpanded // send PATH_RESPONSE, expand datagram to 1200 bytes
+)
+
+func (c *Conn) handlePathChallenge(_ time.Time, dgram *datagram, data pathChallengeData) {
+ // A PATH_RESPONSE is sent in a datagram expanded to 1200 bytes,
+ // except when this would exceed the anti-amplification limit.
+ //
+ // Rather than maintaining anti-amplification state for each path
+ // we may be sending a PATH_RESPONSE on, use the following heuristic:
+ //
+ // If we receive a PATH_CHALLENGE in an expanded datagram,
+ // respond with an expanded datagram.
+ //
+ // If we receive a PATH_CHALLENGE in a non-expanded datagram,
+ // then the peer is presumably blocked by its own anti-amplification limit.
+ // Respond with a non-expanded datagram. Receiving this PATH_RESPONSE
+ // will validate the path to the peer, remove its anti-amplification limit,
+ // and permit it to send a followup PATH_CHALLENGE in an expanded datagram.
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-8.2.1
+ if len(dgram.b) >= smallestMaxDatagramSize {
+ c.path.sendPathResponse = pathResponseExpanded
+ } else {
+ c.path.sendPathResponse = pathResponseSmall
+ }
+ c.path.data = data
+}
+
+func (c *Conn) handlePathResponse(now time.Time, _ pathChallengeData) {
+ // "If the content of a PATH_RESPONSE frame does not match the content of
+ // a PATH_CHALLENGE frame previously sent by the endpoint,
+ // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4
+ //
+ // We never send PATH_CHALLENGE frames.
+ c.abort(now, localTransportError{
+ code: errProtocolViolation,
+ reason: "PATH_RESPONSE received when no PATH_CHALLENGE sent",
+ })
+}
+
+// appendPathFrames appends path validation related frames to the current packet.
+// If the return value pad is true, then the packet should be padded to 1200 bytes.
+func (c *Conn) appendPathFrames() (pad, ok bool) {
+ if c.path.sendPathResponse == pathResponseNotNeeded {
+ return pad, true
+ }
+ // We're required to send the PATH_RESPONSE on the path where the
+ // PATH_CHALLENGE was received (RFC 9000, Section 8.2.2).
+ //
+ // At the moment, we don't support path migration and reject packets if
+ // the peer changes its source address, so just sending the PATH_RESPONSE
+ // in a regular datagram is fine.
+ if !c.w.appendPathResponseFrame(c.path.data) {
+ return pad, false
+ }
+ if c.path.sendPathResponse == pathResponseExpanded {
+ pad = true
+ }
+ c.path.sendPathResponse = pathResponseNotNeeded
+ return pad, true
+}
diff --git a/quic/path_test.go b/quic/path_test.go
new file mode 100644
index 000000000..a309ed14b
--- /dev/null
+++ b/quic/path_test.go
@@ -0,0 +1,66 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "testing"
+)
+
+func TestPathChallengeReceived(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ padTo int
+ wantPadding int
+ }{{
+ name: "unexpanded",
+ padTo: 0,
+ wantPadding: 0,
+ }, {
+ name: "expanded",
+ padTo: 1200,
+ wantPadding: 1200,
+ }} {
+ // "The recipient of [a PATH_CHALLENGE] frame MUST generate
+ // a PATH_RESPONSE frame [...] containing the same Data value."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.17-7
+ tc := newTestConn(t, clientSide)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
+ tc.writeFrames(packetType1RTT, debugFramePathChallenge{
+ data: data,
+ }, debugFramePadding{
+ to: test.padTo,
+ })
+ tc.wantFrame("response to PATH_CHALLENGE",
+ packetType1RTT, debugFramePathResponse{
+ data: data,
+ })
+ if got, want := tc.lastDatagram.paddedSize, test.wantPadding; got != want {
+ t.Errorf("PATH_RESPONSE expanded to %v bytes, want %v", got, want)
+ }
+ tc.wantIdle("connection is idle")
+ }
+}
+
+func TestPathResponseMismatchReceived(t *testing.T) {
+ // "If the content of a PATH_RESPONSE frame does not match the content of
+ // a PATH_CHALLENGE frame previously sent by the endpoint,
+ // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4
+ tc := newTestConn(t, clientSide)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ tc.writeFrames(packetType1RTT, debugFramePathResponse{
+ data: pathChallengeData{},
+ })
+ tc.wantFrame("invalid PATH_RESPONSE causes the connection to close",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errProtocolViolation,
+ },
+ )
+}
diff --git a/internal/quic/ping.go b/quic/ping.go
similarity index 100%
rename from internal/quic/ping.go
rename to quic/ping.go
diff --git a/internal/quic/ping_test.go b/quic/ping_test.go
similarity index 100%
rename from internal/quic/ping_test.go
rename to quic/ping_test.go
diff --git a/internal/quic/pipe.go b/quic/pipe.go
similarity index 71%
rename from internal/quic/pipe.go
rename to quic/pipe.go
index d3a448df3..75cf76db2 100644
--- a/internal/quic/pipe.go
+++ b/quic/pipe.go
@@ -17,14 +17,14 @@ import (
// Writing past the end of the window extends it.
// Data may be discarded from the start of the pipe, advancing the window.
type pipe struct {
- start int64
- end int64
- head *pipebuf
- tail *pipebuf
+ start int64 // stream position of first stored byte
+ end int64 // stream position just past the last stored byte
+ head *pipebuf // if non-nil, then head.off + len(head.b) > start
+ tail *pipebuf // if non-nil, then tail.off + len(tail.b) == end
}
type pipebuf struct {
- off int64
+ off int64 // stream position of b[0]
b []byte
next *pipebuf
}
@@ -111,6 +111,7 @@ func (p *pipe) copy(off int64, b []byte) {
// read calls f with the data in [off, off+n)
// The data may be provided sequentially across multiple calls to f.
+// Note that read (unlike an io.Reader) does not consume the read data.
func (p *pipe) read(off int64, n int, f func([]byte) error) error {
if off < p.start {
panic("invalid read range")
@@ -135,6 +136,30 @@ func (p *pipe) read(off int64, n int, f func([]byte) error) error {
return nil
}
+// peek returns a reference to up to n bytes of the internal data buffer, starting at p.start.
+// The returned slice is valid until the next call to discardBefore.
+// The length of the returned slice will be in the range [0,n].
+func (p *pipe) peek(n int64) []byte {
+ pb := p.head
+ if pb == nil {
+ return nil
+ }
+ b := pb.b[p.start-pb.off:]
+ return b[:min(int64(len(b)), n)]
+}
+
+// availableBuffer returns the available contiguous, allocated buffer space
+// following the pipe window.
+//
+// This is used by the stream write fast path, which makes multiple writes into the pipe buffer
+// without a lock, and then adjusts p.end at a later time with a lock held.
+func (p *pipe) availableBuffer() []byte {
+ if p.tail == nil {
+ return nil
+ }
+ return p.tail.b[p.end-p.tail.off:]
+}
+
// discardBefore discards all data prior to off.
func (p *pipe) discardBefore(off int64) {
for p.head != nil && p.head.end() < off {
diff --git a/internal/quic/pipe_test.go b/quic/pipe_test.go
similarity index 100%
rename from internal/quic/pipe_test.go
rename to quic/pipe_test.go
diff --git a/quic/qlog.go b/quic/qlog.go
new file mode 100644
index 000000000..36831252c
--- /dev/null
+++ b/quic/qlog.go
@@ -0,0 +1,274 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "encoding/hex"
+ "log/slog"
+ "net/netip"
+ "time"
+)
+
+// Log levels for qlog events.
+const (
+ // QLogLevelFrame includes per-frame information.
+ // When this level is enabled, packet_sent and packet_received events will
+ // contain information on individual frames sent/received.
+ QLogLevelFrame = slog.Level(-6)
+
+ // QLogLevelPacket events occur at most once per packet sent or received.
+ //
+ // For example: packet_sent, packet_received.
+ QLogLevelPacket = slog.Level(-4)
+
+ // QLogLevelConn events occur multiple times over a connection's lifetime,
+ // but less often than the frequency of individual packets.
+ //
+ // For example: connection_state_updated.
+ QLogLevelConn = slog.Level(-2)
+
+ // QLogLevelEndpoint events occur at most once per connection.
+ //
+ // For example: connection_started, connection_closed.
+ QLogLevelEndpoint = slog.Level(0)
+)
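+
+// The levels are cumulative: a handler enabled at QLogLevelFrame also receives
+// packet, connection, and endpoint events, since slog enables all levels at or
+// above a handler's minimum.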
+
+func (c *Conn) logEnabled(level slog.Level) bool {
+ return logEnabled(c.log, level)
+}
+
+func logEnabled(log *slog.Logger, level slog.Level) bool {
+ return log != nil && log.Enabled(context.Background(), level)
+}
+
+// slogHexstring returns a slog.Attr for a value of the hexstring type.
+//
+// https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-1.1.1
+func slogHexstring(key string, value []byte) slog.Attr {
+ return slog.String(key, hex.EncodeToString(value))
+}
+
+func slogAddr(key string, value netip.Addr) slog.Attr {
+ return slog.String(key, value.String())
+}
+
+func (c *Conn) logConnectionStarted(originalDstConnID []byte, peerAddr netip.AddrPort) {
+ if c.config.QLogLogger == nil ||
+ !c.config.QLogLogger.Enabled(context.Background(), QLogLevelEndpoint) {
+ return
+ }
+ var vantage string
+ if c.side == clientSide {
+ vantage = "client"
+ originalDstConnID = c.connIDState.originalDstConnID
+ } else {
+ vantage = "server"
+ }
+ // A qlog Trace container includes some metadata (title, description, vantage_point)
+ // and a list of Events. The Trace also includes a common_fields field setting field
+ // values common to all events in the trace.
+ //
+ // Trace = {
+ // ? title: text
+ // ? description: text
+ // ? configuration: Configuration
+ // ? common_fields: CommonFields
+ // ? vantage_point: VantagePoint
+ // events: [* Event]
+ // }
+ //
+ // To map this into slog's data model, we start each per-connection trace with a With
+ // call that includes both the trace metadata and the common fields.
+ //
+ // This means that in slog's model, each trace event will also include
+ // the Trace metadata fields (vantage_point), which is a divergence from the qlog model.
+ c.log = c.config.QLogLogger.With(
+ // The group_id permits associating traces taken from different vantage points
+ // for the same connection.
+ //
+ // We use the original destination connection ID as the group ID.
+ //
+ // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-3.4.6
+ slogHexstring("group_id", originalDstConnID),
+ slog.Group("vantage_point",
+ slog.String("name", "go quic"),
+ slog.String("type", vantage),
+ ),
+ )
+ localAddr := c.endpoint.LocalAddr()
+ // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.2
+ c.log.LogAttrs(context.Background(), QLogLevelEndpoint,
+ "connectivity:connection_started",
+ slogAddr("src_ip", localAddr.Addr()),
+ slog.Int("src_port", int(localAddr.Port())),
+ slogHexstring("src_cid", c.connIDState.local[0].cid),
+ slogAddr("dst_ip", peerAddr.Addr()),
+ slog.Int("dst_port", int(peerAddr.Port())),
+ slogHexstring("dst_cid", c.connIDState.remote[0].cid),
+ )
+}
+
+func (c *Conn) logConnectionClosed() {
+ if !c.logEnabled(QLogLevelEndpoint) {
+ return
+ }
+ err := c.lifetime.finalErr
+ trigger := "error"
+ switch e := err.(type) {
+ case *ApplicationError:
+ // TODO: Distinguish between peer and locally-initiated close.
+ trigger = "application"
+ case localTransportError:
+ switch err {
+ case errHandshakeTimeout:
+ trigger = "handshake_timeout"
+ default:
+ if e.code == errNo {
+ trigger = "clean"
+ }
+ }
+ case peerTransportError:
+ if e.code == errNo {
+ trigger = "clean"
+ }
+ default:
+ switch err {
+ case errIdleTimeout:
+ trigger = "idle_timeout"
+ case errStatelessReset:
+ trigger = "stateless_reset"
+ }
+ }
+ // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.3
+ c.log.LogAttrs(context.Background(), QLogLevelEndpoint,
+ "connectivity:connection_closed",
+ slog.String("trigger", trigger),
+ )
+}
+
+func (c *Conn) logPacketDropped(dgram *datagram) {
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "connectivity:packet_dropped",
+ )
+}
+
+func (c *Conn) logLongPacketReceived(p longPacket, pkt []byte) {
+ var frames slog.Attr
+ if c.logEnabled(QLogLevelFrame) {
+ frames = c.packetFramesAttr(p.payload)
+ }
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "transport:packet_received",
+ slog.Group("header",
+ slog.String("packet_type", p.ptype.qlogString()),
+ slog.Uint64("packet_number", uint64(p.num)),
+ slog.Uint64("flags", uint64(pkt[0])),
+ slogHexstring("scid", p.srcConnID),
+ slogHexstring("dcid", p.dstConnID),
+ ),
+ slog.Group("raw",
+ slog.Int("length", len(pkt)),
+ ),
+ frames,
+ )
+}
+
+func (c *Conn) log1RTTPacketReceived(p shortPacket, pkt []byte) {
+ var frames slog.Attr
+ if c.logEnabled(QLogLevelFrame) {
+ frames = c.packetFramesAttr(p.payload)
+ }
+ dstConnID, _ := dstConnIDForDatagram(pkt)
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "transport:packet_received",
+ slog.Group("header",
+ slog.String("packet_type", packetType1RTT.qlogString()),
+ slog.Uint64("packet_number", uint64(p.num)),
+ slog.Uint64("flags", uint64(pkt[0])),
+ slogHexstring("dcid", dstConnID),
+ ),
+ slog.Group("raw",
+ slog.Int("length", len(pkt)),
+ ),
+ frames,
+ )
+}
+
+func (c *Conn) logPacketSent(ptype packetType, pnum packetNumber, src, dst []byte, pktLen int, payload []byte) {
+ var frames slog.Attr
+ if c.logEnabled(QLogLevelFrame) {
+ frames = c.packetFramesAttr(payload)
+ }
+ var scid slog.Attr
+ if len(src) > 0 {
+ scid = slogHexstring("scid", src)
+ }
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "transport:packet_sent",
+ slog.Group("header",
+ slog.String("packet_type", ptype.qlogString()),
+ slog.Uint64("packet_number", uint64(pnum)),
+ scid,
+ slogHexstring("dcid", dst),
+ ),
+ slog.Group("raw",
+ slog.Int("length", pktLen),
+ ),
+ frames,
+ )
+}
+
+// packetFramesAttr returns the "frames" attribute containing the frames in a packet.
+// We currently pass this as a slog Any containing a []slog.Value,
+// where each Value is a debugFrame that implements slog.LogValuer.
+//
+// This isn't tremendously efficient, but avoids the need to put a JSON encoder
+// in the quic package or a frame parser in the qlog package.
+func (c *Conn) packetFramesAttr(payload []byte) slog.Attr {
+ var frames []slog.Value
+ for len(payload) > 0 {
+ f, n := parseDebugFrame(payload)
+ if n < 0 {
+ break
+ }
+ payload = payload[n:]
+ switch f := f.(type) {
+ case debugFrameAck:
+ // The qlog ACK frame contains the ACK Delay field as a duration.
+ // Interpreting the contents of this field as a duration requires
+ // knowing the peer's ack_delay_exponent transport parameter,
+ // and it's possible for us to parse an ACK frame before we've
+ // received that parameter.
+ //
+ // We could plumb connection state down into the frame parser,
+ // but for now let's minimize the amount of code that needs to
+ // deal with this and convert the unscaled value into a scaled one here.
+ ackDelay := time.Duration(-1)
+ if c.peerAckDelayExponent >= 0 {
+ ackDelay = f.ackDelay.Duration(uint8(c.peerAckDelayExponent))
+ }
+ frames = append(frames, slog.AnyValue(debugFrameScaledAck{
+ ranges: f.ranges,
+ ackDelay: ackDelay,
+ }))
+ default:
+ frames = append(frames, slog.AnyValue(f))
+ }
+ }
+ return slog.Any("frames", frames)
+}
+
+func (c *Conn) logPacketLost(space numberSpace, sent *sentPacket) {
+ c.log.LogAttrs(context.Background(), QLogLevelPacket,
+ "recovery:packet_lost",
+ slog.Group("header",
+ slog.String("packet_type", sent.ptype.qlogString()),
+ slog.Uint64("packet_number", uint64(sent.num)),
+ ),
+ )
+}
diff --git a/quic/qlog/handler.go b/quic/qlog/handler.go
new file mode 100644
index 000000000..35a66cf8b
--- /dev/null
+++ b/quic/qlog/handler.go
@@ -0,0 +1,76 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package qlog
+
+import (
+ "context"
+ "log/slog"
+)
+
+type withAttrsHandler struct {
+ attrs []slog.Attr
+ h slog.Handler
+}
+
+func withAttrs(h slog.Handler, attrs []slog.Attr) slog.Handler {
+ if len(attrs) == 0 {
+ return h
+ }
+ return &withAttrsHandler{attrs: attrs, h: h}
+}
+
+func (h *withAttrsHandler) Enabled(ctx context.Context, level slog.Level) bool {
+ return h.h.Enabled(ctx, level)
+}
+
+func (h *withAttrsHandler) Handle(ctx context.Context, r slog.Record) error {
+ r.AddAttrs(h.attrs...)
+ return h.h.Handle(ctx, r)
+}
+
+func (h *withAttrsHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
+ return withAttrs(h, attrs)
+}
+
+func (h *withAttrsHandler) WithGroup(name string) slog.Handler {
+ return withGroup(h, name)
+}
+
+type withGroupHandler struct {
+ name string
+ h slog.Handler
+}
+
+func withGroup(h slog.Handler, name string) slog.Handler {
+ if name == "" {
+ return h
+ }
+ return &withGroupHandler{name: name, h: h}
+}
+
+func (h *withGroupHandler) Enabled(ctx context.Context, level slog.Level) bool {
+ return h.h.Enabled(ctx, level)
+}
+
+func (h *withGroupHandler) Handle(ctx context.Context, r slog.Record) error {
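+ // Rebuild the record so its attrs are nested under the group name.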
+ var attrs []slog.Attr
+ r.Attrs(func(a slog.Attr) bool {
+ attrs = append(attrs, a)
+ return true
+ })
+ nr := slog.NewRecord(r.Time, r.Level, r.Message, r.PC)
+ nr.Add(slog.Any(h.name, slog.GroupValue(attrs...)))
+ return h.h.Handle(ctx, nr)
+}
+
+func (h *withGroupHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
+ return withAttrs(h, attrs)
+}
+
+func (h *withGroupHandler) WithGroup(name string) slog.Handler {
+ return withGroup(h, name)
+}
diff --git a/quic/qlog/json_writer.go b/quic/qlog/json_writer.go
new file mode 100644
index 000000000..6fb8d33b2
--- /dev/null
+++ b/quic/qlog/json_writer.go
@@ -0,0 +1,261 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package qlog
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "log/slog"
+ "strconv"
+ "sync"
+ "time"
+)
+
+// A jsonWriter writes JSON-SEQ (RFC 7464).
+//
+// A JSON-SEQ file consists of a series of JSON text records,
+// each beginning with an RS (0x1e) character and ending with LF (0x0a).
+type jsonWriter struct {
+ mu sync.Mutex
+ w io.WriteCloser
+ buf bytes.Buffer
+}
+
+// writeRecordStart writes the start of a JSON-SEQ record.
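+// It acquires the writer's mutex; writeRecordEnd flushes the record and releases it.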
+func (w *jsonWriter) writeRecordStart() {
+ w.mu.Lock()
+ w.buf.WriteByte(0x1e)
+ w.buf.WriteByte('{')
+}
+
+// writeRecordEnd finishes writing a JSON-SEQ record.
+func (w *jsonWriter) writeRecordEnd() {
+ w.buf.WriteByte('}')
+ w.buf.WriteByte('\n')
+ w.w.Write(w.buf.Bytes())
+ w.buf.Reset()
+ w.mu.Unlock()
+}
+
+func (w *jsonWriter) writeAttrs(attrs []slog.Attr) {
+ w.buf.WriteByte('{')
+ for _, a := range attrs {
+ w.writeAttr(a)
+ }
+ w.buf.WriteByte('}')
+}
+
+func (w *jsonWriter) writeAttr(a slog.Attr) {
+ if a.Key == "" {
+ return
+ }
+ w.writeName(a.Key)
+ w.writeValue(a.Value)
+}
+
+// writeAttrsField writes a []slog.Attr as an object field.
+func (w *jsonWriter) writeAttrsField(name string, attrs []slog.Attr) {
+ w.writeName(name)
+ w.writeAttrs(attrs)
+}
+
+func (w *jsonWriter) writeValue(v slog.Value) {
+ v = v.Resolve()
+ switch v.Kind() {
+ case slog.KindAny:
+ switch v := v.Any().(type) {
+ case []slog.Value:
+ w.writeArray(v)
+ case interface{ AppendJSON([]byte) []byte }:
+ w.buf.Write(v.AppendJSON(w.buf.AvailableBuffer()))
+ default:
+ w.writeString(fmt.Sprint(v))
+ }
+ case slog.KindBool:
+ w.writeBool(v.Bool())
+ case slog.KindDuration:
+ w.writeDuration(v.Duration())
+ case slog.KindFloat64:
+ w.writeFloat64(v.Float64())
+ case slog.KindInt64:
+ w.writeInt64(v.Int64())
+ case slog.KindString:
+ w.writeString(v.String())
+ case slog.KindTime:
+ w.writeTime(v.Time())
+ case slog.KindUint64:
+ w.writeUint64(v.Uint64())
+ case slog.KindGroup:
+ w.writeAttrs(v.Group())
+ default:
+ w.writeString("unhandled kind")
+ }
+}
+
+// writeName writes an object field name followed by a colon.
+func (w *jsonWriter) writeName(name string) {
+ if b := w.buf.Bytes(); len(b) > 0 && b[len(b)-1] != '{' {
+ // Add the comma separating this from the previous field.
+ w.buf.WriteByte(',')
+ }
+ w.writeString(name)
+ w.buf.WriteByte(':')
+}
+
+func (w *jsonWriter) writeObject(f func()) {
+ w.buf.WriteByte('{')
+ f()
+ w.buf.WriteByte('}')
+}
+
+// writeObjectField writes an object-valued object field.
+// The function f is called to write the contents.
+func (w *jsonWriter) writeObjectField(name string, f func()) {
+ w.writeName(name)
+ w.writeObject(f)
+}
+
+func (w *jsonWriter) writeArray(vals []slog.Value) {
+ w.buf.WriteByte('[')
+ for i, v := range vals {
+ if i != 0 {
+ w.buf.WriteByte(',')
+ }
+ w.writeValue(v)
+ }
+ w.buf.WriteByte(']')
+}
+
+func (w *jsonWriter) writeRaw(v string) {
+ w.buf.WriteString(v)
+}
+
+// writeRawField writes a field with a raw JSON value.
+func (w *jsonWriter) writeRawField(name, v string) {
+ w.writeName(name)
+ w.writeRaw(v)
+}
+
+func (w *jsonWriter) writeBool(v bool) {
+ if v {
+ w.buf.WriteString("true")
+ } else {
+ w.buf.WriteString("false")
+ }
+}
+
+// writeBoolField writes a bool-valued object field.
+func (w *jsonWriter) writeBoolField(name string, v bool) {
+ w.writeName(name)
+ w.writeBool(v)
+}
+
+// writeDuration writes a duration as milliseconds.
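+// The fractional part is the nanosecond remainder, zero-padded to six digits.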
+func (w *jsonWriter) writeDuration(v time.Duration) {
+ if v < 0 {
+ w.buf.WriteByte('-')
+ v = -v
+ }
+ fmt.Fprintf(&w.buf, "%d.%06d", v.Milliseconds(), v%time.Millisecond)
+}
+
+// writeDurationField writes a millisecond duration-valued object field.
+func (w *jsonWriter) writeDurationField(name string, v time.Duration) {
+ w.writeName(name)
+ w.writeDuration(v)
+}
+
+func (w *jsonWriter) writeFloat64(v float64) {
+ w.buf.Write(strconv.AppendFloat(w.buf.AvailableBuffer(), v, 'f', -1, 64))
+}
+
+// writeFloat64Field writes a float64-valued object field.
+func (w *jsonWriter) writeFloat64Field(name string, v float64) {
+ w.writeName(name)
+ w.writeFloat64(v)
+}
+
+func (w *jsonWriter) writeInt64(v int64) {
+ w.buf.Write(strconv.AppendInt(w.buf.AvailableBuffer(), v, 10))
+}
+
+// writeInt64Field writes an int64-valued object field.
+func (w *jsonWriter) writeInt64Field(name string, v int64) {
+ w.writeName(name)
+ w.writeInt64(v)
+}
+
+func (w *jsonWriter) writeUint64(v uint64) {
+ w.buf.Write(strconv.AppendUint(w.buf.AvailableBuffer(), v, 10))
+}
+
+// writeUint64Field writes a uint64-valued object field.
+func (w *jsonWriter) writeUint64Field(name string, v uint64) {
+ w.writeName(name)
+ w.writeUint64(v)
+}
+
+// writeTime writes a time as milliseconds since the Unix epoch.
+func (w *jsonWriter) writeTime(v time.Time) {
+ fmt.Fprintf(&w.buf, "%d.%06d", v.UnixMilli(), v.Nanosecond()%int(time.Millisecond))
+}
+
+// writeTimeField writes a time-valued object field.
+func (w *jsonWriter) writeTimeField(name string, v time.Time) {
+ w.writeName(name)
+ w.writeTime(v)
+}
+
+func jsonSafeSet(c byte) bool {
+ // mask is a 128-bit bitmap with 1s for allowed bytes,
+ // so that the byte c can be tested with a shift and an and.
+ // If c > 128, then 1<