From ab271c317248ea0f18481852f96d12d5eca05cf8 Mon Sep 17 00:00:00 2001 From: David Bell Date: Tue, 29 Aug 2023 20:50:33 +0000 Subject: [PATCH 01/19] http2: add IdleConnTimeout to http2.Transport Exposes an IdleConnTimeout on http2.Transport directly, rather than rely on configuring it through the underlying http1 transport. For golang/go#57893 Change-Id: Ibe506da39e314aebec1cd6df64937982182a37ca GitHub-Last-Rev: cc8f1710ed543da8e937aa2446b0a3982dec6ce3 GitHub-Pull-Request: golang/net#173 Reviewed-on: https://go-review.googlesource.com/c/net/+/497195 Reviewed-by: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Matthew Dempsky --- http2/transport.go | 14 ++++++++++ http2/transport_test.go | 62 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/http2/transport.go b/http2/transport.go index c2a5b44b3..b599197e7 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -147,6 +147,12 @@ type Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + // ReadIdleTimeout is the timeout after which a health check using ping // frame will be carried out if no frame is received on the connection. // Note that a ping response will is considered a received frame, so if @@ -3150,9 +3156,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err } func (t *Transport) idleConnTimeout() time.Duration { + // to keep things backwards compatible, we use non-zero values of + // IdleConnTimeout, followed by using the IdleConnTimeout on the underlying + // http1 transport, followed by 0 + if t.IdleConnTimeout != 0 { + return t.IdleConnTimeout + } + if t.t1 != nil { return t.t1.IdleConnTimeout } + return 0 } diff --git a/http2/transport_test.go b/http2/transport_test.go index a81131f29..6ac8e978b 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -95,6 +95,68 @@ func startH2cServer(t *testing.T) net.Listener { return l } +func TestIdleConnTimeout(t *testing.T) { + for _, test := range []struct { + idleConnTimeout time.Duration + wait time.Duration + baseTransport *http.Transport + wantConns int32 + }{{ + idleConnTimeout: 2 * time.Second, + wait: 1 * time.Second, + baseTransport: nil, + wantConns: 1, + }, { + idleConnTimeout: 1 * time.Second, + wait: 2 * time.Second, + baseTransport: nil, + wantConns: 5, + }, { + idleConnTimeout: 0 * time.Second, + wait: 1 * time.Second, + baseTransport: &http.Transport{ + IdleConnTimeout: 2 * time.Second, + }, + wantConns: 1, + }} { + var gotConns int32 + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + }, optOnlyServer) + defer st.Close() + + tr := &Transport{ + IdleConnTimeout: test.idleConnTimeout, + TLSClientConfig: tlsConfigInsecure, + } + defer tr.CloseIdleConnections() + + for i := 0; i < 5; i++ { + req, _ := http.NewRequest("GET", st.ts.URL, http.NoBody) + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConns, 1) + } + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + _, err := tr.RoundTrip(req) + if err != nil { + t.Fatalf("%v", err) + } + + <-time.After(test.wait) + } + + if gotConns != test.wantConns { + t.Errorf("incorrect gotConns: %d != %d", gotConns, test.wantConns) + } + } +} + func TestTransportH2c(t *testing.T) { l := startH2cServer(t) defer l.Close() From 8c07e20f924fb9dec8d39d2793f72a42c3261a7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E7=8E=AE=E6=96=87?= Date: Sun, 3 Sep 2023 14:21:31 +0800 Subject: [PATCH 02/19] httpproxy: allow any scheme currently only http/https/socks5 scheme are allowed. However, any scheme could be possible if user provides their own implementation. Specifically, the widely used "socks5h://localhost" is parsed as Scheme="http" Host="socks5h:", which does not make sense because host name cannot contain ":". This patch allows any scheme to appear in the proxy config. And only fallback to http scheme if parsed scheme or host is empty. url.Parse() result of fallback cases: localhost => Scheme="localhost" localhost:1234 => Scheme="localhost" Opaque="1234" example.com => Path="example.com" Updates golang/go#24135 Change-Id: Ia2c041e37e2ac61be16220fd41d6cb6fabeeca3d Reviewed-on: https://go-review.googlesource.com/c/net/+/525257 LUCI-TryBot-Result: Go LUCI Run-TryBot: Damien Neil Reviewed-by: Michael Knyszek Reviewed-by: Damien Neil TryBot-Result: Gopher Robot Auto-Submit: Damien Neil --- http/httpproxy/proxy.go | 5 +---- http/httpproxy/proxy_test.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go index c3bd9a1ee..6404aaf15 100644 --- a/http/httpproxy/proxy.go +++ b/http/httpproxy/proxy.go @@ -149,10 +149,7 @@ func parseProxy(proxy string) (*url.URL, error) { } proxyURL, err := url.Parse(proxy) - if err != nil || - (proxyURL.Scheme != "http" && - proxyURL.Scheme != "https" && - proxyURL.Scheme != "socks5") { + if err != nil || proxyURL.Scheme == "" || proxyURL.Host == "" { // proxy was bogus. Try prepending "http://" to it and // see if that parses correctly. If not, we fall // through and complain about the original one. diff --git a/http/httpproxy/proxy_test.go b/http/httpproxy/proxy_test.go index d76373295..790afdab7 100644 --- a/http/httpproxy/proxy_test.go +++ b/http/httpproxy/proxy_test.go @@ -68,6 +68,12 @@ var proxyForURLTests = []proxyForURLTest{{ HTTPProxy: "cache.corp.example.com", }, want: "http://cache.corp.example.com", +}, { + // single label domain is recognized as scheme by url.Parse + cfg: httpproxy.Config{ + HTTPProxy: "localhost", + }, + want: "http://localhost", }, { cfg: httpproxy.Config{ HTTPProxy: "https://cache.corp.example.com", @@ -88,6 +94,12 @@ var proxyForURLTests = []proxyForURLTest{{ HTTPProxy: "socks5://127.0.0.1", }, want: "socks5://127.0.0.1", +}, { + // Preserve unknown schemes. + cfg: httpproxy.Config{ + HTTPProxy: "foo://host", + }, + want: "foo://host", }, { // Don't use secure for http cfg: httpproxy.Config{ From ea095bc79b94b4bdc6939ce2dbd0300520089a1f Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Mon, 19 Feb 2024 19:53:31 +0800 Subject: [PATCH 03/19] http2: only set up positive deadlines Fixes golang/go#65785 Change-Id: Icd95d7cae5ed26b8a2fe656daf8365e27a7785d8 Reviewed-on: https://go-review.googlesource.com/c/net/+/565195 LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil Reviewed-by: Carlos Amedee --- http2/server.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/http2/server.go b/http2/server.go index ae94c6408..905206f3e 100644 --- a/http2/server.go +++ b/http2/server.go @@ -434,7 +434,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { // passes the connection off to us with the deadline already set. // Write deadlines are set per stream in serverConn.newStream. // Disarm the net.Conn write deadline here. - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { sc.conn.SetWriteDeadline(time.Time{}) } @@ -2017,7 +2017,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // similar to how the http1 server works. Here it's // technically more like the http1 Server's ReadHeaderTimeout // (in Go 1.8), though. That's a more sane option anyway. - if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } @@ -2038,7 +2038,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) { // Disable any read deadline set by the net/http package // prior to the upgrade. - if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) } @@ -2116,7 +2116,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.flow.conn = &sc.flow // link to conn-level counter st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.srv.initialStreamRecvWindowSize()) - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } From 57a6a7a86bc0e47508781ed988adcecbe8ff2580 Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Sun, 25 Feb 2024 15:09:48 +0800 Subject: [PATCH 04/19] http2: prevent uninitialized pipe from being written For golang/go#65927 Change-Id: I6f48706156384e026968cf9a6d9e0ec76b46fabf Reviewed-on: https://go-review.googlesource.com/c/net/+/566675 Reviewed-by: Damien Neil Reviewed-by: Carlos Amedee LUCI-TryBot-Result: Go LUCI --- http2/pipe.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/http2/pipe.go b/http2/pipe.go index 684d984fd..3b9f06b96 100644 --- a/http2/pipe.go +++ b/http2/pipe.go @@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) { } } -var errClosedPipeWrite = errors.New("write on closed buffer") +var ( + errClosedPipeWrite = errors.New("write on closed buffer") + errUninitializedPipeWrite = errors.New("write on uninitialized buffer") +) // Write copies bytes from p into the buffer and wakes a reader. // It is an error to write more data than the buffer can hold. @@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) { if p.err != nil || p.breakErr != nil { return 0, errClosedPipeWrite } + // pipe.setBuffer is never invoked, leaving the buffer uninitialized. + // We shouldn't try to write to an uninitialized pipe, + // but returning an error is better than panicking. + if p.b == nil { + return 0, errUninitializedPipeWrite + } return p.b.Write(d) } From d600ae05799943851536e26ab37ee23294912c3d Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 9 Feb 2024 18:11:52 -0800 Subject: [PATCH 05/19] http2: add testClientConn for testing client RoundTrips Many RoundTrip tests involve testing against a test-defined server with specific behaviors. For example: Testing RoundTrip's behavior when the server violates flow control limits. Existing tests mostly use the clientTester type, which starts separate goroutines for the Transport and a fake server. This results in tests where the control flow bounces around the test function, and requires each test to manage its own synchronization. Introduce a new framework for writing RoundTrip tests. testClientConn allows client tests to be written linearly, with synchronization provided by the test framework. For example, a testClientConn test can, as a linear sequence of actions: - start RoundTrip; - check the request headers sent; - provide data to the request body; - check that a DATA frame is sent; - send response headers from the server to the client; - check that RoundTrip returns. See TestTestClientConn at the top of clientconn_test.go for a full example. To enable synchronization with tests, this CL instruments the RoundTrip path to record when goroutines start, exit, and block waiting for events. This adds a certain amount of noise and bookkeeping to the client implementation, but (in my opinion) this is more than repaid in improved testability. The testClientConn also permits use of synthetic time in tests. At the moment, this is limited to the response header timeout, but extending it to other timeouts (read, 100-continue) should be straightforward. This CL converts a number of existing clientTester tests to use the new framework, but not all. Change-Id: Ief963889969363ec8469cd3c3de0becb2fc548f9 Reviewed-on: https://go-review.googlesource.com/c/net/+/563540 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- http2/clientconn_test.go | 634 +++++++++++++++++++++ http2/testsync.go | 246 ++++++++ http2/transport.go | 145 ++++- http2/transport_test.go | 1154 ++++++++++++++------------------------ 4 files changed, 1414 insertions(+), 765 deletions(-) create mode 100644 http2/clientconn_test.go create mode 100644 http2/testsync.go diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go new file mode 100644 index 000000000..6d94762e5 --- /dev/null +++ b/http2/clientconn_test.go @@ -0,0 +1,634 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Infrastructure for testing ClientConn.RoundTrip. +// Put actual tests in transport_test.go. + +package http2 + +import ( + "bytes" + "fmt" + "io" + "net" + "net/http" + "reflect" + "testing" + "time" + + "golang.org/x/net/http2/hpack" +) + +// TestTestClientConn demonstrates usage of testClientConn. +func TestTestClientConn(t *testing.T) { + // newTestClientConn creates a *ClientConn and surrounding test infrastructure. + tc := newTestClientConn(t) + + // tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames, + // and sends a SETTINGS frame to the client. + // + // Additional settings may be provided as optional parameters to greet. + tc.greet() + + // Request bodies must either be constant (bytes.Buffer, strings.Reader) + // or created with newRequestBody. + body := tc.newRequestBody() + body.writeBytes(10) // 10 arbitrary bytes... + body.closeWithError(io.EOF) // ...followed by EOF. + + // tc.roundTrip calls RoundTrip, but does not wait for it to return. + // It returns a testRoundTrip. + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + // tc has a number of methods to check for expected frames sent. + // Here, we look for headers and the request body. + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: false, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"PUT"}, + ":path": []string{"/"}, + }, + }) + // Expect 10 bytes of request body in DATA frames. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: 10, + }) + + // tc.writeHeaders sends a HEADERS frame back to the client. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + // Now that we've received headers, RoundTrip has finished. + // testRoundTrip has various methods to examine the response, + // or to fetch the response and/or error returned by RoundTrip + rt.wantStatus(200) + rt.wantBody(nil) +} + +// A testClientConn allows testing ClientConn.RoundTrip against a fake server. +// +// A test using testClientConn consists of: +// - actions on the client (calling RoundTrip, making data available to Request.Body); +// - validation of frames sent by the client to the server; and +// - providing frames from the server to the client. +// +// testClientConn manages synchronization, so tests can generally be written as +// a linear sequence of actions and validations without additional synchronization. +type testClientConn struct { + t *testing.T + + tr *Transport + fr *Framer + cc *ClientConn + hooks *testSyncHooks + + encbuf bytes.Buffer + enc *hpack.Encoder + + roundtrips []*testRoundTrip + + rerr error // returned by Read + rbuf bytes.Buffer // sent to the test conn + wbuf bytes.Buffer // sent by the test conn +} + +func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { + t.Helper() + + tr := &Transport{} + for _, o := range opts { + o(tr) + } + + tc := &testClientConn{ + t: t, + tr: tr, + hooks: newTestSyncHooks(), + } + tc.enc = hpack.NewEncoder(&tc.encbuf) + tc.fr = NewFramer(&tc.rbuf, &tc.wbuf) + tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) + tc.fr.SetMaxReadFrameSize(10 << 20) + + t.Cleanup(func() { + if tc.rerr == nil { + tc.rerr = io.EOF + } + tc.sync() + if tc.hooks.total != 0 { + t.Errorf("%v goroutines still running after test completed", tc.hooks.total) + } + + }) + + tc.hooks.newclientconn = func(cc *ClientConn) { + tc.cc = cc + } + const singleUse = false + _, err := tc.tr.newClientConn((*testClientConnNetConn)(tc), singleUse, tc.hooks) + if err != nil { + t.Fatal(err) + } + tc.sync() + tc.hooks.newclientconn = nil + + // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. + buf := make([]byte, len(clientPreface)) + if _, err := io.ReadFull(&tc.wbuf, buf); err != nil { + t.Fatalf("reading preface: %v", err) + } + if !bytes.Equal(buf, clientPreface) { + t.Fatalf("client preface: %q, want %q", buf, clientPreface) + } + + return tc +} + +// sync waits for the ClientConn under test to reach a stable state, +// with all goroutines blocked on some input. +func (tc *testClientConn) sync() { + tc.hooks.waitInactive() +} + +// advance advances synthetic time by a duration. +func (tc *testClientConn) advance(d time.Duration) { + tc.hooks.advance(d) + tc.sync() +} + +// hasFrame reports whether a frame is available to be read. +func (tc *testClientConn) hasFrame() bool { + return tc.wbuf.Len() > 0 +} + +// readFrame reads the next frame from the conn. +func (tc *testClientConn) readFrame() Frame { + if tc.wbuf.Len() == 0 { + return nil + } + fr, err := tc.fr.ReadFrame() + if err != nil { + return nil + } + return fr +} + +// testClientConnReadFrame reads a frame of a specific type from the conn. +func testClientConnReadFrame[T any](tc *testClientConn) T { + tc.t.Helper() + var v T + fr := tc.readFrame() + if fr == nil { + tc.t.Fatalf("got no frame, want frame %v", v) + } + v, ok := fr.(T) + if !ok { + tc.t.Fatalf("got frame %T, want %T", fr, v) + } + return v +} + +// wantFrameType reads the next frame from the conn. +// It produces an error if the frame type is not the expected value. +func (tc *testClientConn) wantFrameType(want FrameType) { + fr := tc.readFrame() + if fr == nil { + tc.t.Fatalf("got no frame, want frame %v", want) + } + if got := fr.Header().Type; got != want { + tc.t.Fatalf("got frame %v, want %v", got, want) + } +} + +type wantHeader struct { + streamID uint32 + endStream bool + header http.Header +} + +// wantHeaders reads a HEADERS frame and potential CONTINUATION frames, +// and asserts that they contain the expected headers. +func (tc *testClientConn) wantHeaders(want wantHeader) { + fr := tc.readFrame() + got, ok := fr.(*MetaHeadersFrame) + if !ok { + tc.t.Fatalf("got %v, want HEADERS frame", want) + } + if got, want := got.StreamID, want.streamID; got != want { + tc.t.Fatalf("got stream ID %v, want %v", got, want) + } + if got, want := got.StreamEnded(), want.endStream; got != want { + tc.t.Fatalf("got stream ended %v, want %v", got, want) + } + gotHeader := make(http.Header) + for _, f := range got.Fields { + gotHeader[f.Name] = append(gotHeader[f.Name], f.Value) + } + for k, v := range want.header { + if !reflect.DeepEqual(v, gotHeader[k]) { + tc.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k]) + } + } +} + +type wantData struct { + streamID uint32 + endStream bool + size int +} + +// wantData reads zero or more DATA frames, and asserts that they match the expectation. +func (tc *testClientConn) wantData(want wantData) { + tc.t.Helper() + gotSize := 0 + gotEndStream := false + for tc.hasFrame() && !gotEndStream { + data := testClientConnReadFrame[*DataFrame](tc) + gotSize += len(data.Data()) + if data.StreamEnded() { + gotEndStream = true + } + } + if gotSize != want.size { + tc.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size) + } + if gotEndStream != want.endStream { + tc.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream) + } +} + +// testRequestBody is a Request.Body for use in tests. +type testRequestBody struct { + tc *testClientConn + + // At most one of buf or bytes can be set at any given time: + buf bytes.Buffer // specific bytes to read from the body + bytes int // body contains this many arbitrary bytes + + err error // read error (comes after any available bytes) +} + +func (tc *testClientConn) newRequestBody() *testRequestBody { + b := &testRequestBody{ + tc: tc, + } + return b +} + +// Read is called by the ClientConn to read from a request body. +func (b *testRequestBody) Read(p []byte) (n int, _ error) { + b.tc.cc.syncHooks.blockUntil(func() bool { + return b.buf.Len() > 0 || b.bytes > 0 || b.err != nil + }) + switch { + case b.buf.Len() > 0: + return b.buf.Read(p) + case b.bytes > 0: + if len(p) > b.bytes { + p = p[:b.bytes] + } + b.bytes -= len(p) + for i := range p { + p[i] = 'A' + } + return len(p), nil + default: + return 0, b.err + } +} + +// Close is called by the ClientConn when it is done reading from a request body. +func (b *testRequestBody) Close() error { + return nil +} + +// writeBytes adds n arbitrary bytes to the body. +func (b *testRequestBody) writeBytes(n int) { + b.bytes += n + b.checkWrite() + b.tc.sync() +} + +// Write adds bytes to the body. +func (b *testRequestBody) Write(p []byte) (int, error) { + n, err := b.buf.Write(p) + b.checkWrite() + b.tc.sync() + return n, err +} + +func (b *testRequestBody) checkWrite() { + if b.bytes > 0 && b.buf.Len() > 0 { + b.tc.t.Fatalf("can't interleave Write and writeBytes on request body") + } + if b.err != nil { + b.tc.t.Fatalf("can't write to request body after closeWithError") + } +} + +// closeWithError sets an error which will be returned by Read. +func (b *testRequestBody) closeWithError(err error) { + b.err = err + b.tc.sync() +} + +// roundTrip starts a RoundTrip call. +// +// (Note that the RoundTrip won't complete until response headers are received, +// the request times out, or some other terminal condition is reached.) +func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{ + tc: tc, + donec: make(chan struct{}), + } + tc.roundtrips = append(tc.roundtrips, rt) + tc.hooks.newstream = func(cs *clientStream) { rt.cs = cs } + tc.cc.goRun(func() { + defer close(rt.donec) + rt.resp, rt.respErr = tc.cc.RoundTrip(req) + }) + tc.sync() + tc.hooks.newstream = nil + + tc.t.Cleanup(func() { + res, _ := rt.result() + if res != nil { + res.Body.Close() + } + }) + + return rt +} + +func (tc *testClientConn) greet(settings ...Setting) { + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.writeSettings(settings...) + tc.writeSettingsAck() + tc.wantFrameType(FrameSettings) // acknowledgement +} + +func (tc *testClientConn) writeSettings(settings ...Setting) { + tc.t.Helper() + if err := tc.fr.WriteSettings(settings...); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeSettingsAck() { + tc.t.Helper() + if err := tc.fr.WriteSettingsAck(); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeData(streamID uint32, endStream bool, data []byte) { + tc.t.Helper() + if err := tc.fr.WriteData(streamID, endStream, data); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// makeHeaderBlockFragment encodes headers in a form suitable for inclusion +// in a HEADERS or CONTINUATION frame. +// +// It takes a list of alernating names and values. +func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte { + if len(s)%2 != 0 { + tc.t.Fatalf("uneven list of header name/value pairs") + } + tc.encbuf.Reset() + for i := 0; i < len(s); i += 2 { + tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]}) + } + return tc.encbuf.Bytes() +} + +func (tc *testClientConn) writeHeaders(p HeadersFrameParam) { + tc.t.Helper() + if err := tc.fr.WriteHeaders(p); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// writeHeadersMode writes header frames, as modified by mode: +// +// - noHeader: Don't write the header. +// - oneHeader: Write a single HEADERS frame. +// - splitHeader: Write a HEADERS frame and CONTINUATION frame. +func (tc *testClientConn) writeHeadersMode(mode headerType, p HeadersFrameParam) { + tc.t.Helper() + switch mode { + case noHeader: + case oneHeader: + tc.writeHeaders(p) + case splitHeader: + if len(p.BlockFragment) < 2 { + panic("too small") + } + contData := p.BlockFragment[1:] + contEnd := p.EndHeaders + p.BlockFragment = p.BlockFragment[:1] + p.EndHeaders = false + tc.writeHeaders(p) + tc.writeContinuation(p.StreamID, contEnd, contData) + default: + panic("bogus mode") + } +} + +func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) { + tc.t.Helper() + if err := tc.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) { + tc.t.Helper() + if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeWindowUpdate(streamID, incr uint32) { + tc.t.Helper() + if err := tc.fr.WriteWindowUpdate(streamID, incr); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// closeWrite causes the net.Conn used by the ClientConn to return a error +// from Read calls. +func (tc *testClientConn) closeWrite(err error) { + tc.rerr = err + tc.sync() +} + +// testRoundTrip manages a RoundTrip in progress. +type testRoundTrip struct { + tc *testClientConn + resp *http.Response + respErr error + donec chan struct{} + cs *clientStream +} + +// streamID returns the HTTP/2 stream ID of the request. +func (rt *testRoundTrip) streamID() uint32 { + return rt.cs.ID +} + +// done reports whether RoundTrip has returned. +func (rt *testRoundTrip) done() bool { + select { + case <-rt.donec: + return true + default: + return false + } +} + +// result returns the result of the RoundTrip. +func (rt *testRoundTrip) result() (*http.Response, error) { + t := rt.tc.t + t.Helper() + select { + case <-rt.donec: + default: + t.Fatalf("RoundTrip (stream %v) is not done; want it to be", rt.streamID()) + } + return rt.resp, rt.respErr +} + +// response returns the response of a successful RoundTrip. +// If the RoundTrip unexpectedly failed, it calls t.Fatal. +func (rt *testRoundTrip) response() *http.Response { + t := rt.tc.t + t.Helper() + resp, err := rt.result() + if err != nil { + t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr) + } + if resp == nil { + t.Fatalf("RoundTrip returned nil *Response and nil error") + } + return resp +} + +// err returns the (possibly nil) error result of RoundTrip. +func (rt *testRoundTrip) err() error { + t := rt.tc.t + t.Helper() + _, err := rt.result() + return err +} + +// wantStatus indicates the expected response StatusCode. +func (rt *testRoundTrip) wantStatus(want int) { + t := rt.tc.t + t.Helper() + if got := rt.response().StatusCode; got != want { + t.Fatalf("got response status %v, want %v", got, want) + } +} + +// body reads the contents of the response body. +func (rt *testRoundTrip) readBody() ([]byte, error) { + t := rt.tc.t + t.Helper() + return io.ReadAll(rt.response().Body) +} + +// wantBody indicates the expected response body. +// (Note that this consumes the body.) +func (rt *testRoundTrip) wantBody(want []byte) { + t := rt.tc.t + t.Helper() + got, err := rt.readBody() + if err != nil { + t.Fatalf("unexpected error reading response body: %v", err) + } + if !bytes.Equal(got, want) { + t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want) + } +} + +// wantHeaders indicates the expected response headers. +func (rt *testRoundTrip) wantHeaders(want http.Header) { + t := rt.tc.t + t.Helper() + res := rt.response() + if diff := diffHeaders(res.Header, want); diff != "" { + t.Fatalf("unexpected response headers:\n%v", diff) + } +} + +// wantTrailers indicates the expected response trailers. +func (rt *testRoundTrip) wantTrailers(want http.Header) { + t := rt.tc.t + t.Helper() + res := rt.response() + if diff := diffHeaders(res.Trailer, want); diff != "" { + t.Fatalf("unexpected response trailers:\n%v", diff) + } +} + +func diffHeaders(got, want http.Header) string { + // nil and 0-length non-nil are equal. + if len(got) == 0 && len(want) == 0 { + return "" + } + // We could do a more sophisticated diff here. + // DeepEqual is good enough for now. + if reflect.DeepEqual(got, want) { + return "" + } + return fmt.Sprintf("got: %v\nwant: %v", got, want) +} + +// testClientConnNetConn implements net.Conn. +type testClientConnNetConn testClientConn + +func (nc *testClientConnNetConn) Read(b []byte) (n int, err error) { + nc.cc.syncHooks.blockUntil(func() bool { + return nc.rerr != nil || nc.rbuf.Len() > 0 + }) + if nc.rbuf.Len() > 0 { + return nc.rbuf.Read(b) + } + return 0, nc.rerr +} + +func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) { + return nc.wbuf.Write(b) +} + +func (*testClientConnNetConn) Close() error { + return nil +} + +func (*testClientConnNetConn) LocalAddr() (_ net.Addr) { return } +func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return } +func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil } +func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil } +func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/http2/testsync.go b/http2/testsync.go new file mode 100644 index 000000000..b8335c0fb --- /dev/null +++ b/http2/testsync.go @@ -0,0 +1,246 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package http2 + +import ( + "sync" + "time" +) + +// testSyncHooks coordinates goroutines in tests. +// +// For example, a call to ClientConn.RoundTrip involves several goroutines, including: +// - the goroutine running RoundTrip; +// - the clientStream.doRequest goroutine, which writes the request; and +// - the clientStream.readLoop goroutine, which reads the response. +// +// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines +// are blocked waiting for some condition such as reading the Request.Body or waiting for +// flow control to become available. +// +// The testSyncHooks also manage timers and synthetic time in tests. +// This permits us to, for example, start a request and cause it to time out waiting for +// response headers without resorting to time.Sleep calls. +type testSyncHooks struct { + // active/inactive act as a mutex and condition variable. + // + // - neither chan contains a value: testSyncHooks is locked. + // - active contains a value: unlocked, and at least one goroutine is not blocked + // - inactive contains a value: unlocked, and all goroutines are blocked + active chan struct{} + inactive chan struct{} + + // goroutine counts + total int // total goroutines + condwait map[*sync.Cond]int // blocked in sync.Cond.Wait + blocked []*testBlockedGoroutine // otherwise blocked + + // fake time + now time.Time + timers []*fakeTimer + + // Transport testing: Report various events. + newclientconn func(*ClientConn) + newstream func(*clientStream) +} + +// testBlockedGoroutine is a blocked goroutine. +type testBlockedGoroutine struct { + f func() bool // blocked until f returns true + ch chan struct{} // closed when unblocked +} + +func newTestSyncHooks() *testSyncHooks { + h := &testSyncHooks{ + active: make(chan struct{}, 1), + inactive: make(chan struct{}, 1), + condwait: map[*sync.Cond]int{}, + } + h.inactive <- struct{}{} + h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + return h +} + +// lock acquires the testSyncHooks mutex. +func (h *testSyncHooks) lock() { + select { + case <-h.active: + case <-h.inactive: + } +} + +// waitInactive waits for all goroutines to become inactive. +func (h *testSyncHooks) waitInactive() { + for { + <-h.inactive + if !h.unlock() { + break + } + } +} + +// unlock releases the testSyncHooks mutex. +// It reports whether any goroutines are active. +func (h *testSyncHooks) unlock() (active bool) { + // Look for a blocked goroutine which can be unblocked. + blocked := h.blocked[:0] + unblocked := false + for _, b := range h.blocked { + if !unblocked && b.f() { + unblocked = true + close(b.ch) + } else { + blocked = append(blocked, b) + } + } + h.blocked = blocked + + // Count goroutines blocked on condition variables. + condwait := 0 + for _, count := range h.condwait { + condwait += count + } + + if h.total > condwait+len(blocked) { + h.active <- struct{}{} + return true + } else { + h.inactive <- struct{}{} + return false + } +} + +// goRun starts a new goroutine. +func (h *testSyncHooks) goRun(f func()) { + h.lock() + h.total++ + h.unlock() + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + f() + }() +} + +// blockUntil indicates that a goroutine is blocked waiting for some condition to become true. +// It waits until f returns true before proceeding. +// +// Example usage: +// +// h.blockUntil(func() bool { +// // Is the context done yet? +// select { +// case <-ctx.Done(): +// default: +// return false +// } +// return true +// }) +// // Wait for the context to become done. +// <-ctx.Done() +// +// The function f passed to blockUntil must be non-blocking and idempotent. +func (h *testSyncHooks) blockUntil(f func() bool) { + if f() { + return + } + ch := make(chan struct{}) + h.lock() + h.blocked = append(h.blocked, &testBlockedGoroutine{ + f: f, + ch: ch, + }) + h.unlock() + <-ch +} + +// broadcast is sync.Cond.Broadcast. +func (h *testSyncHooks) condBroadcast(cond *sync.Cond) { + h.lock() + delete(h.condwait, cond) + h.unlock() + cond.Broadcast() +} + +// broadcast is sync.Cond.Wait. +func (h *testSyncHooks) condWait(cond *sync.Cond) { + h.lock() + h.condwait[cond]++ + h.unlock() +} + +// newTimer creates a new timer: A time.Timer if h is nil, or a synthetic timer in tests. +func (h *testSyncHooks) newTimer(d time.Duration) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + when: h.now.Add(d), + c: make(chan time.Time), + } + h.timers = append(h.timers, t) + return t +} + +// advance advances time and causes synthetic timers to fire. +func (h *testSyncHooks) advance(d time.Duration) { + h.lock() + defer h.unlock() + h.now = h.now.Add(d) + timers := h.timers[:0] + for _, t := range h.timers { + t.mu.Lock() + switch { + case t.when.After(h.now): + timers = append(timers, t) + case t.when.IsZero(): + // stopped timer + default: + t.when = time.Time{} + close(t.c) + } + t.mu.Unlock() + } + h.timers = timers +} + +// A timer wraps a time.Timer, or a synthetic equivalent in tests. +// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires. +type timer interface { + C() <-chan time.Time + Stop() bool +} + +type timeTimer struct { + t *time.Timer + c chan time.Time +} + +func newTimeTimer(d time.Duration) timer { + ch := make(chan time.Time) + t := time.AfterFunc(d, func() { + close(ch) + }) + return &timeTimer{t, ch} +} + +func (t timeTimer) C() <-chan time.Time { return t.c } +func (t timeTimer) Stop() bool { return t.t.Stop() } + +type fakeTimer struct { + mu sync.Mutex + when time.Time + c chan time.Time +} + +func (t *fakeTimer) C() <-chan time.Time { return t.c } +func (t *fakeTimer) Stop() bool { + t.mu.Lock() + defer t.mu.Unlock() + stopped := t.when.IsZero() + t.when = time.Time{} + return stopped +} diff --git a/http2/transport.go b/http2/transport.go index b599197e7..04db29275 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -350,6 +350,45 @@ type ClientConn struct { werr error // first write error that has occurred hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder + + syncHooks *testSyncHooks // can be nil +} + +// Hook points used for testing. +// Outside of tests, cc.syncHooks is nil and these all have minimal implementations. +// Inside tests, see the testSyncHooks function docs. + +// goRun starts a new goroutine. +func (cc *ClientConn) goRun(f func()) { + if cc.syncHooks != nil { + cc.syncHooks.goRun(f) + return + } + go f() +} + +// condBroadcast is cc.cond.Broadcast. +func (cc *ClientConn) condBroadcast() { + if cc.syncHooks != nil { + cc.syncHooks.condBroadcast(cc.cond) + } + cc.cond.Broadcast() +} + +// condWait is cc.cond.Wait. +func (cc *ClientConn) condWait() { + if cc.syncHooks != nil { + cc.syncHooks.condWait(cc.cond) + } + cc.cond.Wait() +} + +// newTimer creates a new time.Timer, or a synthetic timer in tests. +func (cc *ClientConn) newTimer(d time.Duration) timer { + if cc.syncHooks != nil { + return cc.syncHooks.newTimer(d) + } + return newTimeTimer(d) } // clientStream is the state for a single HTTP/2 stream. One of these @@ -431,7 +470,7 @@ func (cs *clientStream) abortStreamLocked(err error) { // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { // Wake up writeRequestBody if it is waiting on flow control. - cs.cc.cond.Broadcast() + cs.cc.condBroadcast() } } @@ -441,7 +480,7 @@ func (cs *clientStream) abortRequestBodyWrite() { defer cc.mu.Unlock() if cs.reqBody != nil && cs.reqBodyClosed == nil { cs.closeReqBodyLocked() - cc.cond.Broadcast() + cc.condBroadcast() } } @@ -451,10 +490,10 @@ func (cs *clientStream) closeReqBodyLocked() { } cs.reqBodyClosed = make(chan struct{}) reqBodyClosed := cs.reqBodyClosed - go func() { + cs.cc.goRun(func() { cs.reqBody.Close() close(reqBodyClosed) - }() + }) } type stickyErrWriter struct { @@ -672,7 +711,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b if err != nil { return nil, err } - return t.newClientConn(tconn, singleUse) + return t.newClientConn(tconn, singleUse, nil) } func (t *Transport) newTLSConfig(host string) *tls.Config { @@ -738,10 +777,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 { } func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives()) + return t.newClientConn(c, t.disableKeepAlives(), nil) } -func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { +func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) { cc := &ClientConn{ t: t, tconn: c, @@ -756,6 +795,10 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), + syncHooks: hooks, + } + if hooks != nil { + hooks.newclientconn(cc) } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d @@ -824,7 +867,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return nil, cc.werr } - go cc.readLoop() + cc.goRun(cc.readLoop) return cc, nil } @@ -1062,7 +1105,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { // Wait for all in-flight streams to complete or connection to close done := make(chan struct{}) cancelled := false // guarded by cc.mu - go func() { + cc.goRun(func() { cc.mu.Lock() defer cc.mu.Unlock() for { @@ -1074,9 +1117,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { if cancelled { break } - cc.cond.Wait() + cc.condWait() } - }() + }) shutdownEnterWaitStateHook() select { case <-done: @@ -1086,7 +1129,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { cc.mu.Lock() // Free the goroutine above cancelled = true - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() return ctx.Err() } @@ -1124,7 +1167,7 @@ func (cc *ClientConn) closeForError(err error) { for _, cs := range cc.streams { cs.abortStreamLocked(err) } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() cc.closeConn() } @@ -1221,6 +1264,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + return cc.roundTrip(req, nil) +} + +func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) { ctx := req.Context() cs := &clientStream{ cc: cc, @@ -1235,9 +1282,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { respHeaderRecv: make(chan struct{}), donec: make(chan struct{}), } - go cs.doRequest(req) + cc.goRun(func() { + cs.doRequest(req) + }) waitDone := func() error { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.donec: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.donec: return nil @@ -1298,7 +1359,24 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return err } + if streamf != nil { + streamf(cs) + } + for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.respHeaderRecv: return handleResponseHeaders() @@ -1378,6 +1456,10 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { } cc.mu.Unlock() + if cc.syncHooks != nil { + cc.syncHooks.newstream(cs) + } + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && @@ -1458,15 +1540,30 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) + timer := cc.newTimer(d) defer timer.Stop() - respHeaderTimer = timer.C + respHeaderTimer = timer.C() respHeaderRecv = cs.respHeaderRecv } // Wait until the peer half-closes its end of the stream, // or until the request is aborted (via context, error, or otherwise), // whichever comes first. for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.peerClosed: + case <-respHeaderTimer: + case <-respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.peerClosed: return nil @@ -1615,7 +1712,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { return nil } cc.pendingRequests++ - cc.cond.Wait() + cc.condWait() cc.pendingRequests-- select { case <-cs.abort: @@ -1877,7 +1974,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) cs.flow.take(take) return take, nil } - cc.cond.Wait() + cc.condWait() } } @@ -2149,7 +2246,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) { } // Wake up writeRequestBody via clientStream.awaitFlowControl and // wake up RoundTrip if there is a pending request. - cc.cond.Broadcast() + cc.condBroadcast() closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { @@ -2237,7 +2334,7 @@ func (rl *clientConnReadLoop) cleanup() { cs.abortStreamLocked(err) } } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() } @@ -2873,7 +2970,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { for _, cs := range cc.streams { cs.flow.add(delta) } - cc.cond.Broadcast() + cc.condBroadcast() cc.initialWindowSize = s.Val case SettingHeaderTableSize: @@ -2928,7 +3025,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { return ConnectionError(ErrCodeFlowControl) } - cc.cond.Broadcast() + cc.condBroadcast() return nil } @@ -2971,7 +3068,7 @@ func (cc *ClientConn) Ping(ctx context.Context) error { cc.mu.Unlock() } errc := make(chan error, 1) - go func() { + cc.goRun(func() { cc.wmu.Lock() defer cc.wmu.Unlock() if err := cc.fr.WritePing(false, p); err != nil { @@ -2982,7 +3079,7 @@ func (cc *ClientConn) Ping(ctx context.Context) error { errc <- err return } - }() + }) select { case <-c: return nil diff --git a/http2/transport_test.go b/http2/transport_test.go index 6ac8e978b..f889cd12b 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -1014,131 +1014,65 @@ func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyA func testTransportReqBodyAfterResponse(t *testing.T, status int) { const bodySize = 10 << 20 - clientDone := make(chan struct{}) - ct := newClientTester(t) - recvLen := make(chan int64, 1) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - body := &pipe{b: new(bytes.Buffer)} - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - if res.StatusCode != status { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) - } - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) - body.CloseWithError(io.EOF) - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("unexpected body: %q", slurp) - } - res.Body.Close() - if status == 200 { - if got := <-recvLen; got != bodySize { - return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) - } - } else { - if got := <-recvLen; got == 0 || got >= bodySize { - return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) - } - } - return nil + tc := newTestClientConn(t) + tc.greet() + + body := tc.newRequestBody() + body.writeBytes(bodySize / 2) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: false, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"PUT"}, + ":path": []string{"/"}, + }, + }) + + // Provide enough congestion window for the full request body. + tc.writeWindowUpdate(0, bodySize) + tc.writeWindowUpdate(rt.streamID(), bodySize) + + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: false, + size: bodySize / 2, + }) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", strconv.Itoa(status), + ), + }) + + res := rt.response() + if res.StatusCode != status { + t.Fatalf("status code = %v; want %v", res.StatusCode, status) } - ct.server = func() error { - ct.greet() - defer close(recvLen) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var dataRecv int64 - var closed bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - //println(fmt.Sprintf("server got frame: %v", f)) - ended := false - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - if f.StreamEnded() { - return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) - } - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if dataRecv == 0 { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - dataRecv += int64(dataLen) - if !closed && ((status != 200 && dataRecv > 0) || - (status == 200 && f.StreamEnded())) { - closed = true - if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { - return err - } - } + body.writeBytes(bodySize / 2) + body.closeWithError(io.EOF) - if f.StreamEnded() { - ended = true - } - case *RSTStreamFrame: - if status == 200 { - return fmt.Errorf("Unexpected client frame %v", f) - } - ended = true - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - if ended { - select { - case recvLen <- dataRecv: - default: - } - } - } + if status == 200 { + // After a 200 response, client sends the remaining request body. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: bodySize / 2, + }) + } else { + // After a 403 response, client gives up and resets the stream. + tc.wantFrameType(FrameRSTStream) } - ct.run() + + rt.wantBody(nil) } // See golang.org/issue/13444 @@ -1319,121 +1253,74 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy panic("invalid combination") } - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) - if expect100Continue != noHeader { - req.Header.Set("Expect", "100-continue") - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - wantBody := resBody - if !withData { - wantBody = "" - } - if string(slurp) != wantBody { - return fmt.Errorf("body = %q; want %q", slurp, wantBody) - } - if trailers == noHeader { - if len(res.Trailer) > 0 { - t.Errorf("Trailer = %v; want none", res.Trailer) - } - } else { - want := http.Header{"Some-Trailer": {"some-value"}} - if !reflect.DeepEqual(res.Trailer, want) { - t.Errorf("Trailer = %v; want %v", res.Trailer, want) - } - } - return nil + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) + if expect100Continue != noHeader { + req.Header.Set("Expect", "100-continue") } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + rt := tc.roundTrip(req) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - endStream := false - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *DataFrame: - if !f.StreamEnded() { - // No need to send flow control tokens. The test request body is tiny. - continue - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"}) - enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"}) - if trailers != noHeader { - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"}) - } - endStream = withData == false && trailers == noHeader - send(resHeader) - } - if withData { - endStream = trailers == noHeader - ct.fr.WriteData(f.StreamID, endStream, []byte(resBody)) - } - if trailers != noHeader { - endStream = true - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"}) - send(trailers) - } - if endStream { - return nil - } - case *HeadersFrame: - if expect100Continue != noHeader { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) - send(expect100Continue) - } - } - } + tc.wantFrameType(FrameHeaders) + + // Possibly 100-continue, or skip when noHeader. + tc.writeHeadersMode(expect100Continue, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "100", + ), + }) + + // Client sends request body. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: len(reqBody), + }) + + hdr := []string{ + ":status", "200", + "x-foo", "blah", + "x-bar", "more", + } + if trailers != noHeader { + hdr = append(hdr, "trailer", "some-trailer") + } + tc.writeHeadersMode(resHeader, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: withData == false && trailers == noHeader, + BlockFragment: tc.makeHeaderBlockFragment(hdr...), + }) + if withData { + endStream := trailers == noHeader + tc.writeData(rt.streamID(), endStream, []byte(resBody)) + } + tc.writeHeadersMode(trailers, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + "some-trailer", "some-value", + ), + }) + + rt.wantStatus(200) + if !withData { + rt.wantBody(nil) + } else { + rt.wantBody([]byte(resBody)) + } + if trailers == noHeader { + rt.wantTrailers(nil) + } else { + rt.wantTrailers(http.Header{ + "Some-Trailer": {"some-value"}, + }) } - ct.run() } // Issue 26189, Issue 17739: ignore unknown 1xx responses @@ -1445,130 +1332,76 @@ func TestTransportUnknown1xx(t *testing.T) { return nil } - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 204 { - return fmt.Errorf("status code = %v; want 204", res.StatusCode) - } - want := `code=110 header=map[Foo-Bar:[110]] + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + for i := 110; i <= 114; i++ { + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", fmt.Sprint(i), + "foo-bar", fmt.Sprint(i), + ), + }) + } + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) + + res := rt.response() + if res.StatusCode != 204 { + t.Fatalf("status code = %v; want 204", res.StatusCode) + } + want := `code=110 header=map[Foo-Bar:[110]] code=111 header=map[Foo-Bar:[111]] code=112 header=map[Foo-Bar:[112]] code=113 header=map[Foo-Bar:[113]] code=114 header=map[Foo-Bar:[114]] ` - if got := buf.String(); got != want { - t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - for i := 110; i <= 114; i++ { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)}) - enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if got := buf.String(); got != want { + t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) } - ct.run() - } func TestTransportReceiveUndeclaredTrailer(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - if _, ok := res.Trailer["Some-Trailer"]; !ok { - return fmt.Errorf("expected Some-Trailer") - } - return nil - } - ct.server = func() error { - ct.greet() - - var n int - var hf *HeadersFrame - for hf == nil && n < 10 { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - hf, _ = f.(*HeadersFrame) - n++ - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - // send headers without Trailer header - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + "some-trailer", "I'm an undeclared Trailer!", + ), + }) - // send trailers - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - ct.run() + rt.wantStatus(200) + rt.wantBody(nil) + rt.wantTrailers(http.Header{ + "Some-Trailer": []string{"I'm an undeclared Trailer!"}, + }) } func TestTransportInvalidTrailer_Pseudo1(t *testing.T) { @@ -1578,10 +1411,10 @@ func TestTransportInvalidTrailer_Pseudo2(t *testing.T) { testTransportInvalidTrailer_Pseudo(t, splitHeader) } func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - }) + testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), + ":colon", "foo", + "foo", "bar", + ) } func TestTransportInvalidTrailer_Capital1(t *testing.T) { @@ -1591,102 +1424,54 @@ func TestTransportInvalidTrailer_Capital2(t *testing.T) { testTransportInvalidTrailer_Capital(t, splitHeader) } func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"}) - }) + testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), + "foo", "bar", + "Capital", "bad", + ) } func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"}) - }) + testInvalidTrailer(t, oneHeader, headerFieldNameError(""), + "", "bad", + ) } func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) - }) + testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), + "x", "has\nnewline", + ) } -func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - se, ok := err.(StreamError) - if !ok || se.Cause != wantErr { - return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) +func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "trailer", "declared", + ), + }) + tc.writeHeadersMode(mode, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment(trailers...), + }) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var endStream bool - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"}) - endStream = false - send(oneHeader) - } - // Trailers: - { - endStream = true - buf.Reset() - writeTrailer(enc) - send(trailers) - } - return nil - } - } + rt.wantStatus(200) + body, err := rt.readBody() + se, ok := err.(StreamError) + if !ok || se.Cause != wantErr { + t.Fatalf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", body, err, wantErr, wantErr) + } + if len(body) > 0 { + t.Fatalf("body = %q; want nothing", body) } - ct.run() } // headerListSize returns the HTTP2 header list size of h. @@ -1962,115 +1747,80 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { } func TestTransportChecksResponseHeaderListSize(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if e, ok := err.(StreamError); ok { - err = e.Cause - } - if err != errResponseHeaderListSize { - size := int64(0) - if res != nil { - res.Body.Close() - for k, vv := range res.Header { - for _, v := range vv { - size += int64(len(k)) + int64(len(v)) + 32 - } - } - } - return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + hdr := []string{":status", "200"} + large := strings.Repeat("a", 1<<10) + for i := 0; i < 5042; i++ { + hdr = append(hdr, large, large) + } + hbf := tc.makeHeaderBlockFragment(hdr...) + // Note: this number might change if our hpack implementation changes. + // That's fine. This is just a sanity check that our response can fit in a single + // header block fragment frame. + if size, want := len(hbf), 6329; size != want { + t.Fatalf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + } + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: hbf, + }) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - large := strings.Repeat("a", 1<<10) - for i := 0; i < 5042; i++ { - enc.WriteField(hpack.HeaderField{Name: large, Value: large}) - } - if size, want := buf.Len(), 6329; size != want { - // Note: this number might change if - // our hpack implementation - // changes. That's fine. This is - // just a sanity check that our - // response can fit in a single - // header block fragment frame. - return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + res, err := rt.result() + if e, ok := err.(StreamError); ok { + err = e.Cause + } + if err != errResponseHeaderListSize { + size := int64(0) + if res != nil { + res.Body.Close() + for k, vv := range res.Header { + for _, v := range vv { + size += int64(len(k)) + int64(len(v)) + 32 } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil } } + t.Fatalf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) } - ct.run() } func TestTransportCookieHeaderSplit(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - req.Header.Add("Cookie", "a=b;c=d; e=f;") - req.Header.Add("Cookie", "e=f;g=h; ") - req.Header.Add("Cookie", "i=j") - _, err := ct.tr.RoundTrip(req) - return err - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - dec := hpack.NewDecoder(initialHeaderTableSize, nil) - hfs, err := dec.DecodeFull(f.HeaderBlockFragment()) - if err != nil { - return err - } - got := []string{} - want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"} - for _, hf := range hfs { - if hf.Name == "cookie" { - got = append(got, hf.Value) - } - } - if !reflect.DeepEqual(got, want) { - t.Errorf("Cookies = %#v, want %#v", got, want) - } + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + req.Header.Add("Cookie", "a=b;c=d; e=f;") + req.Header.Add("Cookie", "e=f;g=h; ") + req.Header.Add("Cookie", "i=j") + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, + header: http.Header{ + "cookie": []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}, + }, + }) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if err := rt.err(); err != nil { + t.Fatalf("RoundTrip = %v, want success", err) } - ct.run() } // Test that the Transport returns a typed error from Response.Body.Read calls @@ -2286,55 +2036,49 @@ func TestTransportResponseHeaderTimeout_Body(t *testing.T) { } func testTransportResponseHeaderTimeout(t *testing.T, body bool) { - ct := newClientTester(t) - ct.tr.t1 = &http.Transport{ - ResponseHeaderTimeout: 5 * time.Millisecond, - } - ct.client = func() error { - c := &http.Client{Transport: ct.tr} - var err error - var n int64 - const bodySize = 4 << 20 - if body { - _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize)) - } else { - _, err = c.Get("https://dummy.tld/") - } - if !isTimeout(err) { - t.Errorf("client expected timeout error; got %#v", err) - } - if body && n != bodySize { - t.Errorf("only read %d bytes of body; want %d", n, bodySize) + const bodySize = 4 << 20 + tc := newTestClientConn(t, func(tr *Transport) { + tr.t1 = &http.Transport{ + ResponseHeaderTimeout: 5 * time.Millisecond, } - return nil + }) + tc.greet() + + var req *http.Request + var reqBody *testRequestBody + if body { + reqBody = tc.newRequestBody() + reqBody.writeBytes(bodySize) + reqBody.closeWithError(io.EOF) + req, _ = http.NewRequest("POST", "https://dummy.tld/", reqBody) + req.Header.Set("Content-Type", "text/foo") + } else { + req, _ = http.NewRequest("GET", "https://dummy.tld/", nil) } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - switch f := f.(type) { - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - case *RSTStreamFrame: - if f.StreamID == 1 && f.ErrCode == ErrCodeCancel { - return nil - } - } - } + + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + tc.writeWindowUpdate(0, bodySize) + tc.writeWindowUpdate(rt.streamID(), bodySize) + + if body { + tc.wantData(wantData{ + endStream: true, + size: bodySize, + }) + } + + tc.advance(4 * time.Millisecond) + if rt.done() { + t.Fatalf("RoundTrip is done after 4ms; want still waiting") + } + tc.advance(1 * time.Millisecond) + + if err := rt.err(); !isTimeout(err) { + t.Fatalf("RoundTrip error: %v; want timeout error", err) } - ct.run() } func TestTransportDisableCompression(t *testing.T) { @@ -2720,115 +2464,61 @@ func TestTransportNewTLSConfig(t *testing.T) { // without END_STREAM, followed by a 0-length DATA frame with // END_STREAM. Make sure we don't get confused by that. (We did.) func TestTransportReadHeadResponse(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != 123 { - return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, // as the GFE does - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, nil) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, // as the GFE does + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "123", + ), + }) + tc.writeData(rt.streamID(), true, nil) - <-clientDone - return nil - } + res := rt.response() + if res.ContentLength != 123 { + t.Fatalf("Content-Length = %d; want 123", res.ContentLength) } - ct.run() + rt.wantBody(nil) } func TestTransportReadHeadResponseWithBody(t *testing.T) { - // This test use not valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) + // This test uses an invalid response format. + // Discard logger output to not spam tests output. + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) response := "redirecting to /elsewhere" - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != int64(len(response)) { - return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response)) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, []byte(response)) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", strconv.Itoa(len(response)), + ), + }) + tc.writeData(rt.streamID(), true, []byte(response)) - <-clientDone - return nil - } + res := rt.response() + if res.ContentLength != int64(len(response)) { + t.Fatalf("Content-Length = %d; want %d", res.ContentLength, len(response)) } - ct.run() + rt.wantBody(nil) } type neverEnding byte @@ -2953,71 +2643,53 @@ func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { } func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { - ct := newClientTester(t) - clientDone := make(chan struct{}) + tc := newTestClientConn(t) + tc.greet() const goAwayErrCode = ErrCodeHTTP11Required // arbitrary const goAwayDebugData = "some debug data" - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if failMidBody { - if err != nil { - return fmt.Errorf("unexpected client RoundTrip error: %v", err) - } - _, err = io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - } - want := GoAwayError{ - LastStreamID: 5, - ErrCode: goAwayErrCode, - DebugData: goAwayDebugData, - } - if !reflect.DeepEqual(err, want) { - t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want) - } - return nil + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + if failMidBody { + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "123", + ), + }) } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - if failMidBody { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - // Write two GOAWAY frames, to test that the Transport takes - // the interesting parts of both. - ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) - ct.fr.WriteGoAway(5, goAwayErrCode, nil) - ct.sc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.sc.(*net.TCPConn).Close() - } - <-clientDone - return nil + + // Write two GOAWAY frames, to test that the Transport takes + // the interesting parts of both. + tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) + tc.writeGoAway(5, goAwayErrCode, nil) + tc.closeWrite(io.EOF) + + res, err := rt.result() + whence := "RoundTrip" + if failMidBody { + whence = "Body.Read" + if err != nil { + t.Fatalf("RoundTrip error = %v, want success", err) } + _, err = res.Body.Read(make([]byte, 1)) + } + + want := GoAwayError{ + LastStreamID: 5, + ErrCode: goAwayErrCode, + DebugData: goAwayDebugData, + } + if !reflect.DeepEqual(err, want) { + t.Errorf("%v error = %T: %#v, want %T (%#v)", whence, err, err, want, want) } - ct.run() } func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { @@ -6049,7 +5721,7 @@ func TestClientConnReservations(t *testing.T) { tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() - cc, err := tr.newClientConn(st.cc, false) + cc, err := tr.newClientConn(st.cc, false, nil) if err != nil { t.Fatal(err) } From 12ddef72728707026a3d9adbbc28affa76faf688 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 26 Feb 2024 12:25:37 -0800 Subject: [PATCH 06/19] http2: reject DATA frames after 1xx and before final headers When checking to see if a DATA frame can be accepted, check to see if we have received a non-1xx header, not whether we have received any header. Fixes golang/go#65927 Change-Id: Id4fae1862de6179f8fc95e02dec7d4c47a7640e1 Reviewed-on: https://go-review.googlesource.com/c/net/+/567175 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/transport.go | 2 +- http2/transport_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/http2/transport.go b/http2/transport.go index 04db29275..44845bafd 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -2787,7 +2787,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { }) return nil } - if !cs.firstByte { + if !cs.pastHeaders { cc.logf("protocol error: received DATA before a HEADERS frame") rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, diff --git a/http2/transport_test.go b/http2/transport_test.go index f889cd12b..836d45593 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -6254,3 +6254,32 @@ func TestDialRaceResumesDial(t *testing.T) { case <-successCh: } } + +func TestTransportDataAfter1xxHeader(t *testing.T) { + // Discard logger output to avoid spamming stderr. + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + // https://go.dev/issue/65927 - server sends a 1xx response, followed by a DATA frame. + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "100", + ), + }) + tc.writeData(rt.streamID(), true, []byte{0}) + err := rt.err() + if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { + t.Errorf("RoundTrip error: %v; want ErrCodeProtocol", err) + } + tc.wantFrameType(FrameRSTStream) +} From 31d9683ed011ab20a0aa6ab62de563611851a2b8 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 12:58:12 -0700 Subject: [PATCH 07/19] http2: mark several testing functions as helpers Change-Id: Ib5519fd882b3692efadd6191fbebbf042c9aa77d Reviewed-on: https://go-review.googlesource.com/c/net/+/572376 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- http2/clientconn_test.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 6d94762e5..9a5b2b013 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -191,7 +191,7 @@ func testClientConnReadFrame[T any](tc *testClientConn) T { var v T fr := tc.readFrame() if fr == nil { - tc.t.Fatalf("got no frame, want frame %v", v) + tc.t.Fatalf("got no frame, want frame %T", v) } v, ok := fr.(T) if !ok { @@ -203,6 +203,7 @@ func testClientConnReadFrame[T any](tc *testClientConn) T { // wantFrameType reads the next frame from the conn. // It produces an error if the frame type is not the expected value. func (tc *testClientConn) wantFrameType(want FrameType) { + tc.t.Helper() fr := tc.readFrame() if fr == nil { tc.t.Fatalf("got no frame, want frame %v", want) @@ -221,11 +222,8 @@ type wantHeader struct { // wantHeaders reads a HEADERS frame and potential CONTINUATION frames, // and asserts that they contain the expected headers. func (tc *testClientConn) wantHeaders(want wantHeader) { - fr := tc.readFrame() - got, ok := fr.(*MetaHeadersFrame) - if !ok { - tc.t.Fatalf("got %v, want HEADERS frame", want) - } + tc.t.Helper() + got := testClientConnReadFrame[*MetaHeadersFrame](tc) if got, want := got.StreamID, want.streamID; got != want { tc.t.Fatalf("got stream ID %v, want %v", got, want) } From 9e0498de4d22259990fc8eb8440eafd7c353c19c Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 13:06:45 -0700 Subject: [PATCH 08/19] http2: use synthetic timers for ping timeouts in tests Change-Id: I642890519b066937ade3c13e8387c31d29e912f4 Reviewed-on: https://go-review.googlesource.com/c/net/+/572377 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- http2/clientconn_test.go | 9 +++ http2/testsync.go | 101 +++++++++++++++++++++++++++--- http2/transport.go | 69 +++++++++++++++++---- http2/transport_test.go | 131 ++++++++++++++++++++++++--------------- 4 files changed, 240 insertions(+), 70 deletions(-) diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 9a5b2b013..97f884c66 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -123,6 +123,7 @@ func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { tc.fr.SetMaxReadFrameSize(10 << 20) t.Cleanup(func() { + tc.sync() if tc.rerr == nil { tc.rerr = io.EOF } @@ -459,6 +460,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he tc.sync() } +func (tc *testClientConn) writePing(ack bool, data [8]byte) { + tc.t.Helper() + if err := tc.fr.WritePing(ack, data); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) { tc.t.Helper() if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { diff --git a/http2/testsync.go b/http2/testsync.go index b8335c0fb..61075bd16 100644 --- a/http2/testsync.go +++ b/http2/testsync.go @@ -4,6 +4,7 @@ package http2 import ( + "context" "sync" "time" ) @@ -173,18 +174,56 @@ func (h *testSyncHooks) condWait(cond *sync.Cond) { h.unlock() } -// newTimer creates a new timer: A time.Timer if h is nil, or a synthetic timer in tests. +// newTimer creates a new fake timer. func (h *testSyncHooks) newTimer(d time.Duration) timer { h.lock() defer h.unlock() t := &fakeTimer{ - when: h.now.Add(d), - c: make(chan time.Time), + hooks: h, + when: h.now.Add(d), + c: make(chan time.Time), } h.timers = append(h.timers, t) return t } +// afterFunc creates a new fake AfterFunc timer. +func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + hooks: h, + when: h.now.Add(d), + f: f, + } + h.timers = append(h.timers, t) + return t +} + +func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + t := h.afterFunc(d, cancel) + return ctx, func() { + t.Stop() + cancel() + } +} + +func (h *testSyncHooks) timeUntilEvent() time.Duration { + h.lock() + defer h.unlock() + var next time.Time + for _, t := range h.timers { + if next.IsZero() || t.when.Before(next) { + next = t.when + } + } + if d := next.Sub(h.now); d > 0 { + return d + } + return 0 +} + // advance advances time and causes synthetic timers to fire. func (h *testSyncHooks) advance(d time.Duration) { h.lock() @@ -192,6 +231,7 @@ func (h *testSyncHooks) advance(d time.Duration) { h.now = h.now.Add(d) timers := h.timers[:0] for _, t := range h.timers { + t := t // remove after go.mod depends on go1.22 t.mu.Lock() switch { case t.when.After(h.now): @@ -200,7 +240,20 @@ func (h *testSyncHooks) advance(d time.Duration) { // stopped timer default: t.when = time.Time{} - close(t.c) + if t.c != nil { + close(t.c) + } + if t.f != nil { + h.total++ + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + t.f() + }() + } } t.mu.Unlock() } @@ -212,13 +265,16 @@ func (h *testSyncHooks) advance(d time.Duration) { type timer interface { C() <-chan time.Time Stop() bool + Reset(d time.Duration) bool } +// timeTimer implements timer using real time. type timeTimer struct { t *time.Timer c chan time.Time } +// newTimeTimer creates a new timer using real time. func newTimeTimer(d time.Duration) timer { ch := make(chan time.Time) t := time.AfterFunc(d, func() { @@ -227,16 +283,29 @@ func newTimeTimer(d time.Duration) timer { return &timeTimer{t, ch} } -func (t timeTimer) C() <-chan time.Time { return t.c } -func (t timeTimer) Stop() bool { return t.t.Stop() } +// newTimeAfterFunc creates an AfterFunc timer using real time. +func newTimeAfterFunc(d time.Duration, f func()) timer { + return &timeTimer{ + t: time.AfterFunc(d, f), + } +} +func (t timeTimer) C() <-chan time.Time { return t.c } +func (t timeTimer) Stop() bool { return t.t.Stop() } +func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) } + +// fakeTimer implements timer using fake time. type fakeTimer struct { + hooks *testSyncHooks + mu sync.Mutex - when time.Time - c chan time.Time + when time.Time // when the timer will fire + c chan time.Time // closed when the timer fires; mutually exclusive with f + f func() // called when the timer fires; mutually exclusive with c } func (t *fakeTimer) C() <-chan time.Time { return t.c } + func (t *fakeTimer) Stop() bool { t.mu.Lock() defer t.mu.Unlock() @@ -244,3 +313,19 @@ func (t *fakeTimer) Stop() bool { t.when = time.Time{} return stopped } + +func (t *fakeTimer) Reset(d time.Duration) bool { + if t.c != nil || t.f == nil { + panic("fakeTimer only supports Reset on AfterFunc timers") + } + t.mu.Lock() + defer t.mu.Unlock() + t.hooks.lock() + defer t.hooks.unlock() + active := !t.when.IsZero() + t.when = t.hooks.now.Add(d) + if !active { + t.hooks.timers = append(t.hooks.timers, t) + } + return active +} diff --git a/http2/transport.go b/http2/transport.go index 44845bafd..1ce5f125c 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -391,6 +391,21 @@ func (cc *ClientConn) newTimer(d time.Duration) timer { return newTimeTimer(d) } +// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. +func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer { + if cc.syncHooks != nil { + return cc.syncHooks.afterFunc(d, f) + } + return newTimeAfterFunc(d, f) +} + +func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + if cc.syncHooks != nil { + return cc.syncHooks.contextWithTimeout(ctx, d) + } + return context.WithTimeout(ctx, d) +} + // clientStream is the state for a single HTTP/2 stream. One of these // is created for each Transport.RoundTrip call. type clientStream struct { @@ -875,7 +890,7 @@ func (cc *ClientConn) healthCheck() { pingTimeout := cc.t.pingTimeout() // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. - ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout) defer cancel() cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) @@ -1432,6 +1447,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { if cc.reqHeaderMu == nil { panic("RoundTrip on uninitialized ClientConn") // for tests } + var newStreamHook func(*clientStream) + if cc.syncHooks != nil { + newStreamHook = cc.syncHooks.newstream + cc.syncHooks.blockUntil(func() bool { + select { + case cc.reqHeaderMu <- struct{}{}: + <-cc.reqHeaderMu + case <-cs.reqCancel: + case <-ctx.Done(): + default: + return false + } + return true + }) + } select { case cc.reqHeaderMu <- struct{}{}: case <-cs.reqCancel: @@ -1456,8 +1486,8 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { } cc.mu.Unlock() - if cc.syncHooks != nil { - cc.syncHooks.newstream(cs) + if newStreamHook != nil { + newStreamHook(cs) } // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? @@ -2369,10 +2399,9 @@ func (rl *clientConnReadLoop) run() error { cc := rl.cc gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout - var t *time.Timer + var t timer if readIdleTimeout != 0 { - t = time.AfterFunc(readIdleTimeout, cc.healthCheck) - defer t.Stop() + t = cc.afterFunc(readIdleTimeout, cc.healthCheck) } for { f, err := cc.fr.ReadFrame() @@ -3067,24 +3096,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error { } cc.mu.Unlock() } - errc := make(chan error, 1) + var pingError error + errc := make(chan struct{}) cc.goRun(func() { cc.wmu.Lock() defer cc.wmu.Unlock() - if err := cc.fr.WritePing(false, p); err != nil { - errc <- err + if pingError = cc.fr.WritePing(false, p); pingError != nil { + close(errc) return } - if err := cc.bw.Flush(); err != nil { - errc <- err + if pingError = cc.bw.Flush(); pingError != nil { + close(errc) return } }) + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-c: + case <-errc: + case <-ctx.Done(): + case <-cc.readerDone: + default: + return false + } + return true + }) + } select { case <-c: return nil - case err := <-errc: - return err + case <-errc: + return pingError case <-ctx.Done(): return ctx.Err() case <-cc.readerDone: diff --git a/http2/transport_test.go b/http2/transport_test.go index 836d45593..bab2472f3 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3310,26 +3310,24 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { } func TestTransportCloseAfterLostPing(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.tr.PingTimeout = 1 * time.Second - ct.tr.ReadIdleTimeout = 1 * time.Second - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - if err == nil || !strings.Contains(err.Error(), "client connection lost") { - return fmt.Errorf("expected to get error about \"connection lost\", got %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - <-clientDone - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.PingTimeout = 1 * time.Second + tr.ReadIdleTimeout = 1 * time.Second + }) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + tc.advance(1 * time.Second) + tc.wantFrameType(FramePing) + + tc.advance(1 * time.Second) + err := rt.err() + if err == nil || !strings.Contains(err.Error(), "client connection lost") { + t.Fatalf("expected to get error about \"connection lost\", got %v", err) } - ct.run() } func TestTransportPingWriteBlocks(t *testing.T) { @@ -3362,38 +3360,73 @@ func TestTransportPingWriteBlocks(t *testing.T) { } } -func TestTransportPingWhenReading(t *testing.T) { - testCases := []struct { - name string - readIdleTimeout time.Duration - deadline time.Duration - expectedPingCount int - }{ - { - name: "two pings", - readIdleTimeout: 100 * time.Millisecond, - deadline: time.Second, - expectedPingCount: 2, - }, - { - name: "zero ping", - readIdleTimeout: time.Second, - deadline: 200 * time.Millisecond, - expectedPingCount: 0, - }, - { - name: "0 readIdleTimeout means no ping", - readIdleTimeout: 0 * time.Millisecond, - deadline: 500 * time.Millisecond, - expectedPingCount: 0, - }, +func TestTransportPingWhenReadingMultiplePings(t *testing.T) { + tc := newTestClientConn(t, func(tr *Transport) { + tr.ReadIdleTimeout = 1000 * time.Millisecond + }) + tc.greet() + + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + for i := 0; i < 5; i++ { + // No ping yet... + tc.advance(999 * time.Millisecond) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) + } + + // ...ping now. + tc.advance(1 * time.Millisecond) + f := testClientConnReadFrame[*PingFrame](tc) + tc.writePing(true, f.Data) } - for _, tc := range testCases { - tc := tc // capture range variable - t.Run(tc.name, func(t *testing.T) { - testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount) - }) + // Cancel the request, Transport resets it and returns an error from body reads. + cancel() + tc.sync() + + tc.wantFrameType(FrameRSTStream) + _, err := rt.readBody() + if err == nil { + t.Fatalf("Response.Body.Read() = %v, want error", err) + } +} + +func TestTransportPingWhenReadingPingDisabled(t *testing.T) { + tc := newTestClientConn(t, func(tr *Transport) { + tr.ReadIdleTimeout = 0 // PINGs disabled + }) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + // No PING is sent, even after a long delay. + tc.advance(1 * time.Minute) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } } From 6e2c99c943496e33025da68db088edff5dc7d07b Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 13:01:26 -0700 Subject: [PATCH 09/19] http2: allow testing Transports with testSyncHooks Change-Id: Icafc4860ef0691e5133221a0b53bb1d2158346cc Reviewed-on: https://go-review.googlesource.com/c/net/+/572378 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/clientconn_test.go | 202 ++++++++++++++++++++------ http2/transport.go | 35 +++-- http2/transport_test.go | 301 ++++++++++++++------------------------- 3 files changed, 288 insertions(+), 250 deletions(-) diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 97f884c66..73ceefd7b 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -99,62 +99,57 @@ type testClientConn struct { roundtrips []*testRoundTrip - rerr error // returned by Read - rbuf bytes.Buffer // sent to the test conn - wbuf bytes.Buffer // sent by the test conn + rerr error // returned by Read + netConnClosed bool // set when the ClientConn closes the net.Conn + rbuf bytes.Buffer // sent to the test conn + wbuf bytes.Buffer // sent by the test conn } -func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { - t.Helper() - - tr := &Transport{} - for _, o := range opts { - o(tr) - } - +func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn { tc := &testClientConn{ t: t, - tr: tr, - hooks: newTestSyncHooks(), + tr: cc.t, + cc: cc, + hooks: cc.t.syncHooks, } + cc.tconn = (*testClientConnNetConn)(tc) tc.enc = hpack.NewEncoder(&tc.encbuf) tc.fr = NewFramer(&tc.rbuf, &tc.wbuf) tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) tc.fr.SetMaxReadFrameSize(10 << 20) - t.Cleanup(func() { tc.sync() if tc.rerr == nil { tc.rerr = io.EOF } tc.sync() - if tc.hooks.total != 0 { - t.Errorf("%v goroutines still running after test completed", tc.hooks.total) - } - }) + return tc +} - tc.hooks.newclientconn = func(cc *ClientConn) { - tc.cc = cc - } - const singleUse = false - _, err := tc.tr.newClientConn((*testClientConnNetConn)(tc), singleUse, tc.hooks) - if err != nil { - t.Fatal(err) - } - tc.sync() - tc.hooks.newclientconn = nil - +func (tc *testClientConn) readClientPreface() { + tc.t.Helper() // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. buf := make([]byte, len(clientPreface)) if _, err := io.ReadFull(&tc.wbuf, buf); err != nil { - t.Fatalf("reading preface: %v", err) + tc.t.Fatalf("reading preface: %v", err) } if !bytes.Equal(buf, clientPreface) { - t.Fatalf("client preface: %q, want %q", buf, clientPreface) + tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface) } +} - return tc +func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { + t.Helper() + + tt := newTestTransport(t, opts...) + const singleUse = false + _, err := tt.tr.newClientConn(nil, singleUse, tt.tr.syncHooks) + if err != nil { + t.Fatalf("newClientConn: %v", err) + } + + return tt.getConn() } // sync waits for the ClientConn under test to reach a stable state, @@ -349,7 +344,7 @@ func (b *testRequestBody) closeWithError(err error) { // the request times out, or some other terminal condition is reached.) func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { rt := &testRoundTrip{ - tc: tc, + t: tc.t, donec: make(chan struct{}), } tc.roundtrips = append(tc.roundtrips, rt) @@ -362,6 +357,9 @@ func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { tc.hooks.newstream = nil tc.t.Cleanup(func() { + if !rt.done() { + return + } res, _ := rt.result() if res != nil { res.Body.Close() @@ -460,6 +458,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he tc.sync() } +func (tc *testClientConn) writeRSTStream(streamID uint32, code ErrCode) { + tc.t.Helper() + if err := tc.fr.WriteRSTStream(streamID, code); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + func (tc *testClientConn) writePing(ack bool, data [8]byte) { tc.t.Helper() if err := tc.fr.WritePing(ack, data); err != nil { @@ -491,9 +497,25 @@ func (tc *testClientConn) closeWrite(err error) { tc.sync() } +// inflowWindow returns the amount of inbound flow control available for a stream, +// or for the connection if streamID is 0. +func (tc *testClientConn) inflowWindow(streamID uint32) int32 { + tc.cc.mu.Lock() + defer tc.cc.mu.Unlock() + if streamID == 0 { + return tc.cc.inflow.avail + tc.cc.inflow.unsent + } + cs := tc.cc.streams[streamID] + if cs == nil { + tc.t.Errorf("no stream with id %v", streamID) + return -1 + } + return cs.inflow.avail + cs.inflow.unsent +} + // testRoundTrip manages a RoundTrip in progress. type testRoundTrip struct { - tc *testClientConn + t *testing.T resp *http.Response respErr error donec chan struct{} @@ -502,6 +524,9 @@ type testRoundTrip struct { // streamID returns the HTTP/2 stream ID of the request. func (rt *testRoundTrip) streamID() uint32 { + if rt.cs == nil { + panic("stream ID unknown") + } return rt.cs.ID } @@ -517,12 +542,12 @@ func (rt *testRoundTrip) done() bool { // result returns the result of the RoundTrip. func (rt *testRoundTrip) result() (*http.Response, error) { - t := rt.tc.t + t := rt.t t.Helper() select { case <-rt.donec: default: - t.Fatalf("RoundTrip (stream %v) is not done; want it to be", rt.streamID()) + t.Fatalf("RoundTrip is not done; want it to be") } return rt.resp, rt.respErr } @@ -530,7 +555,7 @@ func (rt *testRoundTrip) result() (*http.Response, error) { // response returns the response of a successful RoundTrip. // If the RoundTrip unexpectedly failed, it calls t.Fatal. func (rt *testRoundTrip) response() *http.Response { - t := rt.tc.t + t := rt.t t.Helper() resp, err := rt.result() if err != nil { @@ -544,7 +569,7 @@ func (rt *testRoundTrip) response() *http.Response { // err returns the (possibly nil) error result of RoundTrip. func (rt *testRoundTrip) err() error { - t := rt.tc.t + t := rt.t t.Helper() _, err := rt.result() return err @@ -552,7 +577,7 @@ func (rt *testRoundTrip) err() error { // wantStatus indicates the expected response StatusCode. func (rt *testRoundTrip) wantStatus(want int) { - t := rt.tc.t + t := rt.t t.Helper() if got := rt.response().StatusCode; got != want { t.Fatalf("got response status %v, want %v", got, want) @@ -561,7 +586,7 @@ func (rt *testRoundTrip) wantStatus(want int) { // body reads the contents of the response body. func (rt *testRoundTrip) readBody() ([]byte, error) { - t := rt.tc.t + t := rt.t t.Helper() return io.ReadAll(rt.response().Body) } @@ -569,7 +594,7 @@ func (rt *testRoundTrip) readBody() ([]byte, error) { // wantBody indicates the expected response body. // (Note that this consumes the body.) func (rt *testRoundTrip) wantBody(want []byte) { - t := rt.tc.t + t := rt.t t.Helper() got, err := rt.readBody() if err != nil { @@ -582,7 +607,7 @@ func (rt *testRoundTrip) wantBody(want []byte) { // wantHeaders indicates the expected response headers. func (rt *testRoundTrip) wantHeaders(want http.Header) { - t := rt.tc.t + t := rt.t t.Helper() res := rt.response() if diff := diffHeaders(res.Header, want); diff != "" { @@ -592,7 +617,7 @@ func (rt *testRoundTrip) wantHeaders(want http.Header) { // wantTrailers indicates the expected response trailers. func (rt *testRoundTrip) wantTrailers(want http.Header) { - t := rt.tc.t + t := rt.t t.Helper() res := rt.response() if diff := diffHeaders(res.Trailer, want); diff != "" { @@ -630,7 +655,8 @@ func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) { return nc.wbuf.Write(b) } -func (*testClientConnNetConn) Close() error { +func (nc *testClientConnNetConn) Close() error { + nc.netConnClosed = true return nil } @@ -639,3 +665,91 @@ func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return } func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil } func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil } func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil } + +// A testTransport allows testing Transport.RoundTrip against fake servers. +// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling +// should use testClientConn instead. +type testTransport struct { + t *testing.T + tr *Transport + + ccs []*testClientConn +} + +func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport { + tr := &Transport{ + syncHooks: newTestSyncHooks(), + } + for _, o := range opts { + o(tr) + } + + tt := &testTransport{ + t: t, + tr: tr, + } + tr.syncHooks.newclientconn = func(cc *ClientConn) { + tt.ccs = append(tt.ccs, newTestClientConnFromClientConn(t, cc)) + } + + t.Cleanup(func() { + tt.sync() + if len(tt.ccs) > 0 { + t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs)) + } + if tt.tr.syncHooks.total != 0 { + t.Errorf("%v goroutines still running after test completed", tt.tr.syncHooks.total) + } + }) + + return tt +} + +func (tt *testTransport) sync() { + tt.tr.syncHooks.waitInactive() +} + +func (tt *testTransport) advance(d time.Duration) { + tt.tr.syncHooks.advance(d) + tt.sync() +} + +func (tt *testTransport) hasConn() bool { + return len(tt.ccs) > 0 +} + +func (tt *testTransport) getConn() *testClientConn { + tt.t.Helper() + if len(tt.ccs) == 0 { + tt.t.Fatalf("no new ClientConns created; wanted one") + } + tc := tt.ccs[0] + tt.ccs = tt.ccs[1:] + tc.sync() + tc.readClientPreface() + return tc +} + +func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{ + t: tt.t, + donec: make(chan struct{}), + } + tt.tr.syncHooks.goRun(func() { + defer close(rt.donec) + rt.resp, rt.respErr = tt.tr.RoundTrip(req) + }) + tt.sync() + + tt.t.Cleanup(func() { + if !rt.done() { + return + } + res, _ := rt.result() + if res != nil { + res.Body.Close() + } + }) + + return rt +} diff --git a/http2/transport.go b/http2/transport.go index 1ce5f125c..bf1dacd35 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -184,6 +184,8 @@ type Transport struct { connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool + + syncHooks *testSyncHooks } func (t *Transport) maxHeaderListSize() uint32 { @@ -597,15 +599,6 @@ func authorityAddr(scheme string, authority string) (addr string) { return net.JoinHostPort(host, port) } -var retryBackoffHook func(time.Duration) *time.Timer - -func backoffNewTimer(d time.Duration) *time.Timer { - if retryBackoffHook != nil { - return retryBackoffHook(d) - } - return time.NewTimer(d) -} - // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { @@ -633,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - timer := backoffNewTimer(d) + var tm timer + if t.syncHooks != nil { + tm = t.syncHooks.newTimer(d) + t.syncHooks.blockUntil(func() bool { + select { + case <-tm.C(): + case <-req.Context().Done(): + default: + return false + } + return true + }) + } else { + tm = newTimeTimer(d) + } select { - case <-timer.C: + case <-tm.C(): t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): - timer.Stop() + tm.Stop() err = req.Context().Err() } } @@ -718,6 +725,9 @@ func canRetryError(err error) bool { } func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { + if t.syncHooks != nil { + return t.newClientConn(nil, singleUse, t.syncHooks) + } host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -814,6 +824,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHoo } if hooks != nil { hooks.newclientconn(cc) + c = cc.tconn } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d diff --git a/http2/transport_test.go b/http2/transport_test.go index bab2472f3..5de0ad8c4 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3688,61 +3688,49 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { } func TestTransportRetryHasLimit(t *testing.T) { - // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s. - if testing.Short() { - t.Skip("skipping long test in short mode") - } - retryBackoffHook = func(d time.Duration) *time.Timer { - return time.NewTimer(0) // fires immediately - } - defer func() { - retryBackoffHook = nil - }() - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) + tt := newTestTransport(t) + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // First attempt: Server sends a GOAWAY. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + + var totalDelay time.Duration + count := 0 + for streamID := uint32(1); ; streamID += 2 { + count++ + tc.wantHeaders(wantHeader{ + streamID: streamID, + endStream: true, + }) + if streamID == 1 { + tc.writeSettings() + tc.wantFrameType(FrameSettings) // settings ACK } - t.Logf("expected error, got: %v", err) - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - default: - return fmt.Errorf("Unexpected client frame %v", f) + tc.writeRSTStream(streamID, ErrCodeRefusedStream) + + d := tt.tr.syncHooks.timeUntilEvent() + if d == 0 { + if streamID == 1 { + continue } + break + } + totalDelay += d + if totalDelay > 5*time.Minute { + t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay) } + tt.advance(d) + } + if got, want := count, 5; got < count { + t.Errorf("RoundTrip made %v attempts, want at least %v", got, want) + } + if rt.err() == nil { + t.Errorf("RoundTrip succeeded, want error") } - ct.run() } func TestTransportResponseDataBeforeHeaders(t *testing.T) { @@ -5593,155 +5581,80 @@ func TestTransportCloseRequestBody(t *testing.T) { } } -// collectClientsConnPool is a ClientConnPool that wraps lower and -// collects what calls were made on it. -type collectClientsConnPool struct { - lower ClientConnPool - - mu sync.Mutex - getErrs int - got []*ClientConn -} - -func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { - cc, err := p.lower.GetClientConn(req, addr) - p.mu.Lock() - defer p.mu.Unlock() - if err != nil { - p.getErrs++ - return nil, err - } - p.got = append(p.got, cc) - return cc, nil -} - -func (p *collectClientsConnPool) MarkDead(cc *ClientConn) { - p.lower.MarkDead(cc) -} - func TestTransportRetriesOnStreamProtocolError(t *testing.T) { - ct := newClientTester(t) - pool := &collectClientsConnPool{ - lower: &clientConnPool{t: ct.tr}, - } - ct.tr.ConnPool = pool + // This test verifies that + // - receiving a protocol error on a connection does not interfere with + // other requests in flight on that connection; + // - the connection is not reused for further requests; and + // - the failed request is retried on a new connecection. + tt := newTestTransport(t) + + // Start two requests. The first is a long request + // that will finish after the second. The second one + // will result in the protocol error. + + // Request #1: The long request. + req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt1 := tt.roundTrip(req1) + tc1 := tt.getConn() + tc1.wantFrameType(FrameSettings) + tc1.wantFrameType(FrameWindowUpdate) + tc1.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc1.writeSettings() + tc1.wantFrameType(FrameSettings) // settings ACK + + // Request #2(a): The short request. + req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt2 := tt.roundTrip(req2) + tc1.wantHeaders(wantHeader{ + streamID: 3, + endStream: true, + }) - gotProtoError := make(chan bool, 1) - ct.tr.CountError = func(errType string) { - if errType == "recv_rststream_PROTOCOL_ERROR" { - select { - case gotProtoError <- true: - default: - } - } + // Request #2(a) fails with ErrCodeProtocol. + tc1.writeRSTStream(3, ErrCodeProtocol) + if rt1.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress") } - ct.client = func() error { - // Start two requests. The first is a long request - // that will finish after the second. The second one - // will result in the protocol error. We check that - // after the first one closes, the connection then - // shuts down. - - // The long, outer request. - req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil) - res1, err := ct.tr.RoundTrip(req1) - if err != nil { - return err - } - if got, want := res1.Header.Get("Is-Long"), "1"; got != want { - return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want) - } - - req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil) - res, err := ct.tr.RoundTrip(req) - const want = "only one dial allowed in test mode" - if got := fmt.Sprint(err); got != want { - t.Errorf("didn't dial again: got %#q; want %#q", got, want) - } - if res != nil { - res.Body.Close() - } - select { - case <-gotProtoError: - default: - t.Errorf("didn't get stream protocol error") - } - - if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 { - t.Errorf("unexpected body read %v, %v", n, err) - } - - pool.mu.Lock() - defer pool.mu.Unlock() - if pool.getErrs != 1 { - t.Errorf("pool get errors = %v; want 1", pool.getErrs) - } - if len(pool.got) == 2 { - if pool.got[0] != pool.got[1] { - t.Errorf("requests went on different connections") - } - cc := pool.got[0] - cc.mu.Lock() - if !cc.doNotReuse { - t.Error("ClientConn not marked doNotReuse") - } - cc.mu.Unlock() - - select { - case <-cc.readerDone: - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for reader to be done") - } - } else { - t.Errorf("pool get success = %v; want 2", len(pool.got)) - } - return nil + if rt2.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is done; want still in progress") } - ct.server = func() error { - ct.greet() - var sentErr bool - var numHeaders int - var firstStreamID uint32 - var hbuf bytes.Buffer - enc := hpack.NewEncoder(&hbuf) + // Request #2(b): The short request is retried on a new connection. + tc2 := tt.getConn() + tc2.wantFrameType(FrameSettings) + tc2.wantFrameType(FrameWindowUpdate) + tc2.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc2.writeSettings() + tc2.wantFrameType(FrameSettings) // settings ACK - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - // Client hung up on us, as it should at the end. - return nil - } - if err != nil { - return nil - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - numHeaders++ - if numHeaders == 1 { - firstStreamID = f.StreamID - hbuf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: hbuf.Bytes(), - }) - continue - } - if !sentErr { - sentErr = true - ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol) - ct.fr.WriteData(firstStreamID, true, nil) - continue - } - } - } - } - ct.run() + // Request #2(b) succeeds. + tc2.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "201", + ), + }) + rt2.wantStatus(201) + + // Request #1 succeeds. + tc1.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) } func TestClientConnReservations(t *testing.T) { From 89f602b7bbf237abe0467031a18b42fc742ced08 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 19 Mar 2024 00:18:26 -0700 Subject: [PATCH 10/19] http2: validate client/outgoing trailers This change is a counterpart to the HTTP/1.1 trailers validation CL 572615. This change will ensure that we validate trailers that were set on the HTTP client before they are transformed to HTTP/2 equivalents. Updates golang/go#64766 Change-Id: Id1afd7f7e9af820ea969ef226bbb16e4de6d57a5 Reviewed-on: https://go-review.googlesource.com/c/net/+/572655 Auto-Submit: Damien Neil TryBot-Result: Gopher Robot Reviewed-by: Damien Neil Run-TryBot: Emmanuel Odeke LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase --- http2/transport.go | 33 ++++++++++++++++++++++----------- http2/transport_test.go | 13 ++++++++++++- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/http2/transport.go b/http2/transport.go index bf1dacd35..ba0956e22 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -2019,6 +2019,22 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) } } +func validateHeaders(hdrs http.Header) string { + for k, vv := range hdrs { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Sprintf("name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // Don't include the value in the error, + // because it may be sensitive. + return fmt.Sprintf("value for header %q", k) + } + } + } + return "" +} + var errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. @@ -2056,19 +2072,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } } - // Check for any invalid headers and return an error before we + // Check for any invalid headers+trailers and return an error before we // potentially pollute our hpack state. (We want to be able to // continue to reuse the hpack encoder for future requests) - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("invalid HTTP header name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - // Don't include the value in the error, because it may be sensitive. - return nil, fmt.Errorf("invalid HTTP header value for header %q", k) - } - } + if err := validateHeaders(req.Header); err != "" { + return nil, fmt.Errorf("invalid HTTP header %s", err) + } + if err := validateHeaders(req.Trailer); err != "" { + return nil, fmt.Errorf("invalid HTTP trailer %s", err) } enumerateHeaders := func(f func(name, value string)) { diff --git a/http2/transport_test.go b/http2/transport_test.go index 5de0ad8c4..5226a61f7 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2290,7 +2290,8 @@ func TestTransportRejectsContentLengthWithSign(t *testing.T) { } // golang.org/issue/14048 -func TestTransportFailsOnInvalidHeaders(t *testing.T) { +// golang.org/issue/64766 +func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { var got []string for k := range r.Header { @@ -2303,6 +2304,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { tests := [...]struct { h http.Header + t http.Header wantErr string }{ 0: { @@ -2321,6 +2323,14 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { h: http.Header{"foo": {"foo\x01bar"}}, wantErr: `invalid HTTP header value for header "foo"`, }, + 4: { + t: http.Header{"foo": {"foo\x01bar"}}, + wantErr: `invalid HTTP trailer value for header "foo"`, + }, + 5: { + t: http.Header{"x-\r\nda": {"foo\x01bar"}}, + wantErr: `invalid HTTP trailer name "x-\r\nda"`, + }, } tr := &Transport{TLSClientConfig: tlsConfigInsecure} @@ -2329,6 +2339,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { for i, tt := range tests { req, _ := http.NewRequest("GET", st.ts.URL, nil) req.Header = tt.h + req.Trailer = tt.t res, err := tr.RoundTrip(req) var bad bool if tt.wantErr == "" { From d73acffdc9493532acb85777105bb4a351eea702 Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Sun, 10 Mar 2024 23:40:56 +0800 Subject: [PATCH 11/19] http2: only set up deadline when Server.IdleTimeout is positive Check out https://go-review.googlesource.com/c/go/+/570396 Change-Id: I8bda17acebc27308c2a1723191ea1e4a9e64d585 Reviewed-on: https://go-review.googlesource.com/c/net/+/570455 LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase Reviewed-by: Damien Neil Auto-Submit: Damien Neil --- http2/server.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/http2/server.go b/http2/server.go index 905206f3e..ce2e8b40e 100644 --- a/http2/server.go +++ b/http2/server.go @@ -124,6 +124,7 @@ type Server struct { // IdleTimeout specifies how long until idle clients should be // closed with a GOAWAY frame. PING frames are not considered // activity for the purposes of IdleTimeout. + // If zero or negative, there is no timeout. IdleTimeout time.Duration // MaxUploadBufferPerConnection is the size of the initial flow @@ -924,7 +925,7 @@ func (sc *serverConn) serve() { sc.setConnState(http.StateActive) sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() } @@ -1637,7 +1638,7 @@ func (sc *serverConn) closeStream(st *stream, err error) { delete(sc.streams, st.id) if len(sc.streams) == 0 { sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer.Reset(sc.srv.IdleTimeout) } if h1ServerKeepAlivesDisabled(sc.hs) { From d8870b0bf2f2426fc8d19a9332f652da5c25418f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 13:51:41 -0700 Subject: [PATCH 12/19] http2: use synthetic time in TestIdleConnTimeout Rewrite TestIdleConnTimeout to use the new synthetic time and synchronization test facilities, rather than using real time and sleeps. Reduces the test time from 20 seconds to 0. Reduces all package tests on my laptop from 32 seconds to 12. Change-Id: I33838488168450a7acd6a462777b5a4caf7f5307 Reviewed-on: https://go-review.googlesource.com/c/net/+/572379 Reviewed-by: Jonathan Amsterdam Reviewed-by: Emmanuel Odeke LUCI-TryBot-Result: Go LUCI --- http2/transport.go | 4 +- http2/transport_test.go | 90 +++++++++++++++++++++++++---------------- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/http2/transport.go b/http2/transport.go index ba0956e22..ce375c8c7 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -310,7 +310,7 @@ type ClientConn struct { readerErr error // set before readerDone is closed idleTimeout time.Duration // or 0 for never - idleTimer *time.Timer + idleTimer timer mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes @@ -828,7 +828,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHoo } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d - cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) + cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout) } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) diff --git a/http2/transport_test.go b/http2/transport_test.go index 5226a61f7..18d4db3ed 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -97,63 +97,83 @@ func startH2cServer(t *testing.T) net.Listener { func TestIdleConnTimeout(t *testing.T) { for _, test := range []struct { + name string idleConnTimeout time.Duration wait time.Duration baseTransport *http.Transport - wantConns int32 + wantNewConn bool }{{ + name: "NoExpiry", idleConnTimeout: 2 * time.Second, wait: 1 * time.Second, baseTransport: nil, - wantConns: 1, + wantNewConn: false, }, { + name: "H2TransportTimeoutExpires", idleConnTimeout: 1 * time.Second, wait: 2 * time.Second, baseTransport: nil, - wantConns: 5, + wantNewConn: true, }, { + name: "H1TransportTimeoutExpires", idleConnTimeout: 0 * time.Second, wait: 1 * time.Second, baseTransport: &http.Transport{ IdleConnTimeout: 2 * time.Second, }, - wantConns: 1, + wantNewConn: false, }} { - var gotConns int32 - - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, r.RemoteAddr) - }, optOnlyServer) - defer st.Close() + t.Run(test.name, func(t *testing.T) { + tt := newTestTransport(t, func(tr *Transport) { + tr.IdleConnTimeout = test.idleConnTimeout + }) + var tc *testClientConn + for i := 0; i < 3; i++ { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // This request happens on a new conn if it's the first request + // (and there is no cached conn), or if the test timeout is long + // enough that old conns are being closed. + wantConn := i == 0 || test.wantNewConn + if has := tt.hasConn(); has != wantConn { + t.Fatalf("request %v: hasConn=%v, want %v", i, has, wantConn) + } + if wantConn { + tc = tt.getConn() + // Read client's SETTINGS and first WINDOW_UPDATE, + // send our SETTINGS. + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.writeSettings() + } + if tt.hasConn() { + t.Fatalf("request %v: Transport has more than one conn", i) + } - tr := &Transport{ - IdleConnTimeout: test.idleConnTimeout, - TLSClientConfig: tlsConfigInsecure, - } - defer tr.CloseIdleConnections() + // Respond to the client's request. + hf := testClientConnReadFrame[*MetaHeadersFrame](tc) + tc.writeHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) - for i := 0; i < 5; i++ { - req, _ := http.NewRequest("GET", st.ts.URL, http.NoBody) - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - if !connInfo.Reused { - atomic.AddInt32(&gotConns, 1) - } - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + // If this was a newly-accepted conn, read the SETTINGS ACK. + if wantConn { + tc.wantFrameType(FrameSettings) // ACK to our settings + } - _, err := tr.RoundTrip(req) - if err != nil { - t.Fatalf("%v", err) + tt.advance(test.wait) + if got, want := tc.netConnClosed, test.wantNewConn; got != want { + t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want) + } } - - <-time.After(test.wait) - } - - if gotConns != test.wantConns { - t.Errorf("incorrect gotConns: %d != %d", gotConns, test.wantConns) - } + }) } } From c7877ac4213b2f859831366f5a35b353e0dc9f66 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 13:24:51 -0700 Subject: [PATCH 13/19] http2: convert the remaining clientTester tests to testClientConn Change-Id: Ia7f213346baff48504fef6dfdc112575a5459f35 Reviewed-on: https://go-review.googlesource.com/c/net/+/572380 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/clientconn_test.go | 74 ++ http2/transport_test.go | 1651 +++++++++++--------------------------- 2 files changed, 535 insertions(+), 1190 deletions(-) diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 73ceefd7b..4237b1436 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -14,6 +14,7 @@ import ( "net" "net/http" "reflect" + "slices" "testing" "time" @@ -209,6 +210,71 @@ func (tc *testClientConn) wantFrameType(want FrameType) { } } +// wantUnorderedFrames reads frames from the conn until every condition in want has been satisfied. +// +// want is a list of func(*SomeFrame) bool. +// wantUnorderedFrames will call each func with frames of the appropriate type +// until the func returns true. +// It calls t.Fatal if an unexpected frame is received (no func has that frame type, +// or all funcs with that type have returned true), or if the conn runs out of frames +// with unsatisfied funcs. +// +// Example: +// +// // Read a SETTINGS frame, and any number of DATA frames for a stream. +// // The SETTINGS frame may appear anywhere in the sequence. +// // The last DATA frame must indicate the end of the stream. +// tc.wantUnorderedFrames( +// func(f *SettingsFrame) bool { +// return true +// }, +// func(f *DataFrame) bool { +// return f.StreamEnded() +// }, +// ) +func (tc *testClientConn) wantUnorderedFrames(want ...any) { + tc.t.Helper() + want = slices.Clone(want) + seen := 0 +frame: + for seen < len(want) && !tc.t.Failed() { + fr := tc.readFrame() + if fr == nil { + break + } + for i, f := range want { + if f == nil { + continue + } + typ := reflect.TypeOf(f) + if typ.Kind() != reflect.Func || + typ.NumIn() != 1 || + typ.NumOut() != 1 || + typ.Out(0) != reflect.TypeOf(true) { + tc.t.Fatalf("expected func(*SomeFrame) bool, got %T", f) + } + if typ.In(0) == reflect.TypeOf(fr) { + out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)}) + if out[0].Bool() { + want[i] = nil + seen++ + } + continue frame + } + } + tc.t.Errorf("got unexpected frame type %T", fr) + } + if seen < len(want) { + for _, f := range want { + if f == nil { + continue + } + tc.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0)) + } + tc.t.Fatalf("did not see %v expected frame types", len(want)-seen) + } +} + type wantHeader struct { streamID uint32 endStream bool @@ -401,6 +467,14 @@ func (tc *testClientConn) writeData(streamID uint32, endStream bool, data []byte tc.sync() } +func (tc *testClientConn) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) { + tc.t.Helper() + if err := tc.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + // makeHeaderBlockFragment encodes headers in a form suitable for inclusion // in a HEADERS or CONTINUATION frame. // diff --git a/http2/transport_test.go b/http2/transport_test.go index 18d4db3ed..855c107ef 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2724,122 +2724,75 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { } func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { - ct := newClientTester(t) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } + tc := newTestClientConn(t) + tc.greet() - if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { - return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) - } - res.Body.Close() // leaving 4999 bytes unread + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - return nil + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "5000", + ), + }) + initialInflow := tc.inflowWindow(0) + + // Two cases: + // - Send one DATA frame with 5000 bytes. + // - Send two DATA frames with 1 and 4999 bytes each. + // + // In both cases, the client should consume one byte of data, + // refund that byte, then refund the following 4999 bytes. + // + // In the second case, the server waits for the client to reset the + // stream before sending the second DATA frame. This tests the case + // where the client receives a DATA frame after it has reset the stream. + const streamNotEnded = false + if oneDataFrame { + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000)) + } else { + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1)) } - ct.server = func() error { - ct.greet() - - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialInflow := ct.inflowWindow(0) - - // Two cases: - // - Send one DATA frame with 5000 bytes. - // - Send two DATA frames with 1 and 4999 bytes each. - // - // In both cases, the client should consume one byte of data, - // refund that byte, then refund the following 4999 bytes. - // - // In the second case, the server waits for the client to reset the - // stream before sending the second DATA frame. This tests the case - // where the client receives a DATA frame after it has reset the stream. - if oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) - } else { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) - } + res := rt.response() + if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { + t.Fatalf("body read = %v, %v; want 1, nil", n, err) + } + res.Body.Close() // leaving 4999 bytes unread + tc.sync() - wantRST := true - wantWUF := true - if !oneDataFrame { - wantWUF = false // flow control update is small, and will not be sent - } - for wantRST || wantWUF { - f, err := ct.readNonSettingsFrame() - if err != nil { - return err + sentAdditionalData := false + tc.wantUnorderedFrames( + func(f *RSTStreamFrame) bool { + if f.ErrCode != ErrCodeCancel { + t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) } - switch f := f.(type) { - case *RSTStreamFrame: - if !wantRST { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.ErrCode != ErrCodeCancel { - return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) - } - wantRST = false - case *WindowUpdateFrame: - if !wantWUF { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.Increment != 5000 { - return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) - } - wantWUF = false - default: - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) + if !oneDataFrame { + // Send the remaining data now. + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999)) + sentAdditionalData = true } - } - if !oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) - f, err := ct.readNonSettingsFrame() - if err != nil { - return err + return true + }, + func(f *WindowUpdateFrame) bool { + if !oneDataFrame && !sentAdditionalData { + t.Fatalf("Got WindowUpdateFrame, don't expect one yet") } - wuf, ok := f.(*WindowUpdateFrame) - if !ok || wuf.Increment != 5000 { - return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f)) + if f.Increment != 5000 { + t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) } - } - if err := ct.writeReadPing(); err != nil { - return err - } - if got, want := ct.inflowWindow(0), initialInflow; got != want { - return fmt.Errorf("connection flow tokens = %v, want %v", got, want) - } - return nil + return true + }, + ) + + if got, want := tc.inflowWindow(0), initialInflow; got != want { + t.Fatalf("connection flow tokens = %v, want %v", got, want) } - ct.run() } // See golang.org/issue/16481 @@ -2855,199 +2808,124 @@ func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { // Issue 16612: adjust flow control on open streams when transport // receives SETTINGS with INITIAL_WINDOW_SIZE from server. func TestTransportAdjustsFlowControl(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - const bodySize = 1 << 20 - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) + tc := newTestClientConn(t) + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + // Don't write our SETTINGS yet. - req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err + body := tc.newRequestBody() + body.writeBytes(bodySize) + body.closeWithError(io.EOF) + + req, _ := http.NewRequest("POST", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + gotBytes := int64(0) + for { + f := testClientConnReadFrame[*DataFrame](tc) + gotBytes += int64(len(f.Data())) + // After we've got half the client's initial flow control window's worth + // of request body data, give it just enough flow control to finish. + if gotBytes >= initialWindowSize/2 { + break } - res.Body.Close() - return nil } - ct.server = func() error { - _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface))) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - var gotBytes int64 - var sentSettings bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - return nil - default: - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - } - switch f := f.(type) { - case *DataFrame: - gotBytes += int64(len(f.Data())) - // After we've got half the client's - // initial flow control window's worth - // of request body data, give it just - // enough flow control to finish. - if gotBytes >= initialWindowSize/2 && !sentSettings { - sentSettings = true - - ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) - ct.fr.WriteWindowUpdate(0, bodySize) - ct.fr.WriteSettingsAck() - } + tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) + tc.writeWindowUpdate(0, bodySize) + tc.writeSettingsAck() - if f.StreamEnded() { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - } - } + tc.wantUnorderedFrames( + func(f *SettingsFrame) bool { return true }, + func(f *DataFrame) bool { + gotBytes += int64(len(f.Data())) + return f.StreamEnded() + }, + ) + + if gotBytes != bodySize { + t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize) } - ct.run() + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) } // See golang.org/issue/16556 func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { - ct := newClientTester(t) + tc := newTestClientConn(t) + tc.greet() - unblockClient := make(chan bool, 1) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - defer res.Body.Close() - <-unblockClient - return nil - } - ct.server = func() error { - ct.greet() + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "5000", + ), + }) - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } + initialConnWindow := tc.inflowWindow(0) + initialStreamWindow := tc.inflowWindow(rt.streamID()) - initialConnWindow := ct.inflowWindow(0) + pad := make([]byte, 5) + tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialStreamWindow := ct.inflowWindow(hf.StreamID) - pad := make([]byte, 5) - ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream - if err := ct.writeReadPing(); err != nil { - return err - } - // Padding flow control should have been returned. - if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want { - t.Errorf("conn inflow window = %v, want %v", got, want) - } - if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want { - t.Errorf("stream inflow window = %v, want %v", got, want) - } - unblockClient <- true - return nil + // Padding flow control should have been returned. + if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want { + t.Errorf("conn inflow window = %v, want %v", got, want) + } + if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want { + t.Errorf("stream inflow window = %v, want %v", got, want) } - ct.run() } // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a // StreamError as a result of the response HEADERS func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { - ct := newClientTester(t) + tc := newTestClientConn(t) + tc.greet() - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - return errors.New("unexpected successful GET") - } - want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} - if !reflect.DeepEqual(want, err) { - t.Errorf("RoundTrip error = %#v; want %#v", err, want) - } - return nil - } - ct.server = func() error { - ct.greet() + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - hf, err := ct.firstHeaders() - if err != nil { - return err - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + " content-type", "bogus", + ), + }) - for { - fr, err := ct.readFrame() - if err != nil { - return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) - } - if _, ok := fr.(*SettingsFrame); ok { - continue - } - if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol { - t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) - } - break - } + err := rt.err() + want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} + if !reflect.DeepEqual(err, want) { + t.Fatalf("RoundTrip error = %#v; want %#v", err, want) + } - return nil + fr := testClientConnReadFrame[*RSTStreamFrame](tc) + if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol { + t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) } - ct.run() } // byteAndEOFReader returns is in an io.Reader which reads one byte @@ -3461,261 +3339,84 @@ func TestTransportPingWhenReadingPingDisabled(t *testing.T) { } } -func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) { - var pingCount int - ct := newClientTester(t) - ct.tr.ReadIdleTimeout = readIdleTimeout - - ctx, cancel := context.WithTimeout(context.Background(), deadline) - defer cancel() - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) - } - _, err = ioutil.ReadAll(res.Body) - if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) { - return nil - } +func TestTransportRetryAfterGOAWAY(t *testing.T) { + tt := newTestTransport(t) - cancel() - return err - } + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var streamID uint32 - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-ctx.Done(): - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - streamID = f.StreamID - case *PingFrame: - pingCount++ - if pingCount == expectedPingCount { - if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil { - return err - } - } - if err := ct.fr.WritePing(true, f.Data); err != nil { - return err - } - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + // First attempt: Server sends a GOAWAY. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.writeGoAway(0 /*max id*/, ErrCodeNo, nil) + if rt.done() { + t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying") } - ct.run() -} - -func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) { - ln := newLocalListener(t) - defer ln.Close() - var ( - mu sync.Mutex - count int - conns []net.Conn - ) - var wg sync.WaitGroup - tr := &Transport{ - TLSClientConfig: tlsConfigInsecure, - } - tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { - mu.Lock() - defer mu.Unlock() - count++ - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - return nil, fmt.Errorf("dial error: %v", err) - } - conns = append(conns, cc) - sc, err := ln.Accept() - if err != nil { - return nil, fmt.Errorf("accept error: %v", err) - } - conns = append(conns, sc) - ct := &clientTester{ - t: t, - tr: tr, - cc: cc, - sc: sc, - fr: NewFramer(sc, sc), - } - wg.Add(1) - go func(count int) { - defer wg.Done() - server(count, ct) - }(count) - return cc, nil - } + // Second attempt succeeds on a new connection. + tc = tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) - client(tr) - tr.CloseIdleConnections() - ln.Close() - for _, c := range conns { - c.Close() - } - wg.Wait() + rt.wantStatus(200) } -func TestTransportRetryAfterGOAWAY(t *testing.T) { - client := func(tr *Transport) { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := tr.RoundTrip(req) - if res != nil { - res.Body.Close() - if got := res.Header.Get("Foo"); got != "bar" { - err = fmt.Errorf("foo header = %q; want bar", got) - } - } - if err != nil { - t.Errorf("RoundTrip: %v", err) - } - } - - server := func(count int, ct *clientTester) { - switch count { - case 1: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server1 failed reading HEADERS: %v", err) - return - } - t.Logf("server1 got %v", hf) - if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { - t.Errorf("server1 failed writing GOAWAY: %v", err) - return - } - case 2: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server2 failed reading HEADERS: %v", err) - return - } - t.Logf("server2 got %v", hf) - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - err = ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - if err != nil { - t.Errorf("server2 failed writing response HEADERS: %v", err) - } - default: - t.Errorf("unexpected number of dials") - return - } - } +func TestTransportRetryAfterRefusedStream(t *testing.T) { + tt := newTestTransport(t) - testClientMultipleDials(t, client, server) -} + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) -func TestTransportRetryAfterRefusedStream(t *testing.T) { - clientDone := make(chan struct{}) - client := func(tr *Transport) { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("RoundTrip: %v", err) - return - } - resp.Body.Close() - if resp.StatusCode != 204 { - t.Errorf("Status = %v; want 204", resp.StatusCode) - return - } + // First attempt: Server sends a RST_STREAM. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.wantFrameType(FrameSettings) // settings ACK + tc.writeRSTStream(1, ErrCodeRefusedStream) + if rt.done() { + t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying") } - server := func(_ int, ct *clientTester) { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var count int - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - default: - t.Error(err) - } - return - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - t.Errorf("headers should have END_HEADERS be ended: %v", f) - return - } - count++ - if count == 1 { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - } else { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - default: - t.Errorf("Unexpected client frame %v", f) - return - } - } - } + // Second attempt succeeds on the same connection. + tc.wantHeaders(wantHeader{ + streamID: 3, + endStream: true, + }) + tc.writeSettings() + tc.writeHeaders(HeadersFrameParam{ + StreamID: 3, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) - testClientMultipleDials(t, client, server) + rt.wantStatus(204) } func TestTransportRetryHasLimit(t *testing.T) { @@ -3765,67 +3466,34 @@ func TestTransportRetryHasLimit(t *testing.T) { } func TestTransportResponseDataBeforeHeaders(t *testing.T) { - // This test use not valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) - defer log.SetOutput(os.Stderr) + // Discard log output complaining about protocol error. + log.SetOutput(io.Discard) + t.Cleanup(func() { log.SetOutput(os.Stderr) }) // after other cleanup is done - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req := httptest.NewRequest("GET", "https://dummy.tld/", nil) - // First request is normal to ensure the check is per stream and not per connection. - _, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip expected no error, got: %v", err) - } - // Second request returns a DATA frame with no HEADERS. - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) - } - if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { - return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - return nil - } else if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame: - case *HeadersFrame: - switch f.StreamID { - case 1: - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - case 3: - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - } - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + tc := newTestClientConn(t) + tc.greet() + + // First request is normal to ensure the check is per stream and not per connection. + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt1 := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt1.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) + + // Second request returns a DATA frame with no HEADERS. + rt2 := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.writeData(rt2.streamID(), true, []byte("payload")) + if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol { + t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err) } - ct.run() } func TestTransportMaxFrameReadSize(t *testing.T) { @@ -3839,30 +3507,17 @@ func TestTransportMaxFrameReadSize(t *testing.T) { maxReadFrameSize: 1024, want: minMaxFrameSize, }} { - ct := newClientTester(t) - ct.tr.MaxReadFrameSize = test.maxReadFrameSize - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) - ct.tr.RoundTrip(req) - return nil - } - ct.server = func() error { - defer ct.cc.(*net.TCPConn).Close() - ct.greet() - var got uint32 - ct.settings.ForeachSetting(func(s Setting) error { - switch s.ID { - case SettingMaxFrameSize: - got = s.Val - } - return nil - }) - if got != test.want { - t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want) - } - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxReadFrameSize = test.maxReadFrameSize + }) + + fr := testClientConnReadFrame[*SettingsFrame](tc) + got, ok := fr.Value(SettingMaxFrameSize) + if !ok { + t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want) + } else if got != test.want { + t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want) } - ct.run() } } @@ -3902,337 +3557,126 @@ func TestTransportRequestsLowServerLimit(t *testing.T) { t.Errorf("StatusCode = %v; want %v", got, want) } if res != nil && res.Body != nil { - res.Body.Close() - } - } - - if connCount != 1 { - t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) - } -} - -// tests Transport.StrictMaxConcurrentStreams -func TestTransportRequestsStallAtServerLimit(t *testing.T) { - const maxConcurrent = 2 - - greet := make(chan struct{}) // server sends initial SETTINGS frame - gotRequest := make(chan struct{}) // server received a request - clientDone := make(chan struct{}) - cancelClientRequest := make(chan struct{}) - - // Collect errors from goroutines. - var wg sync.WaitGroup - errs := make(chan error, 100) - defer func() { - wg.Wait() - close(errs) - for err := range errs { - t.Error(err) - } - }() - - // We will send maxConcurrent+2 requests. This checker goroutine waits for the - // following stages: - // 1. The first maxConcurrent requests are received by the server. - // 2. The client will cancel the next request - // 3. The server is unblocked so it can service the first maxConcurrent requests - // 4. The client will send the final request - wg.Add(1) - unblockClient := make(chan struct{}) - clientRequestCancelled := make(chan struct{}) - unblockServer := make(chan struct{}) - go func() { - defer wg.Done() - // Stage 1. - for k := 0; k < maxConcurrent; k++ { - <-gotRequest - } - // Stage 2. - close(unblockClient) - <-clientRequestCancelled - // Stage 3: give some time for the final RoundTrip call to be scheduled and - // verify that the final request is not sent. - time.Sleep(50 * time.Millisecond) - select { - case <-gotRequest: - errs <- errors.New("last request did not stall") - close(unblockServer) - return - default: - } - close(unblockServer) - // Stage 4. - <-gotRequest - }() - - ct := newClientTester(t) - ct.tr.StrictMaxConcurrentStreams = true - ct.client = func() error { - var wg sync.WaitGroup - defer func() { - wg.Wait() - close(clientDone) - ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.cc.(*net.TCPConn).Close() - } - }() - for k := 0; k < maxConcurrent+2; k++ { - wg.Add(1) - go func(k int) { - defer wg.Done() - // Don't send the second request until after receiving SETTINGS from the server - // to avoid a race where we use the default SettingMaxConcurrentStreams, which - // is much larger than maxConcurrent. We have to send the first request before - // waiting because the first request triggers the dial and greet. - if k > 0 { - <-greet - } - // Block until maxConcurrent requests are sent before sending any more. - if k >= maxConcurrent { - <-unblockClient - } - body := newStaticCloseChecker("") - req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body) - if k == maxConcurrent { - // This request will be canceled. - req.Cancel = cancelClientRequest - close(cancelClientRequest) - _, err := ct.tr.RoundTrip(req) - close(clientRequestCancelled) - if err == nil { - errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k) - return - } - } else { - resp, err := ct.tr.RoundTrip(req) - if err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - return - } - ioutil.ReadAll(resp.Body) - resp.Body.Close() - if resp.StatusCode != 204 { - errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode) - return - } - } - if err := body.isClosed(); err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - } - }(k) + res.Body.Close() } - return nil } - ct.server = func() error { - var wg sync.WaitGroup - defer wg.Wait() + if connCount != 1 { + t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) + } +} - ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) +// tests Transport.StrictMaxConcurrentStreams +func TestTransportRequestsStallAtServerLimit(t *testing.T) { + const maxConcurrent = 2 - // Server write loop. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - writeResp := make(chan uint32, maxConcurrent+1) + tc := newTestClientConn(t, func(tr *Transport) { + tr.StrictMaxConcurrentStreams = true + }) + tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) - wg.Add(1) - go func() { - defer wg.Done() - <-unblockServer - for id := range writeResp { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: id, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - }() + cancelClientRequest := make(chan struct{}) - // Server read loop. - var nreq int - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it will have reported any errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame: - case *SettingsFrame: - // Wait for the client SETTINGS ack until ending the greet. - close(greet) - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - gotRequest <- struct{}{} - nreq++ - writeResp <- f.StreamID - if nreq == maxConcurrent+1 { - close(writeResp) - } - case *DataFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) + // Start maxConcurrent+2 requests. + // The server does not respond to any of them yet. + var rts []*testRoundTrip + for k := 0; k < maxConcurrent+2; k++ { + req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil) + if k == maxConcurrent { + req.Cancel = cancelClientRequest + } + rt := tc.roundTrip(req) + rts = append(rts, rt) + + if k < maxConcurrent { + // We are under the stream limit, so the client sends the request. + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{fmt.Sprintf("/%d", k)}, + }, + }) + } else { + // We have reached the stream limit, + // so the client cannot send the request. + if fr := tc.readFrame(); fr != nil { + t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr) } } + + if rt.done() { + t.Fatalf("rt %v done", k) + } + } + + // Cancel the maxConcurrent'th request. + // The request should fail. + close(cancelClientRequest) + tc.sync() + if err := rts[maxConcurrent].err(); err == nil { + t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent) + } + + // No requests should be complete, except for the canceled one. + for i, rt := range rts { + if i != maxConcurrent && rt.done() { + t.Fatalf("RoundTrip(%d) is done, but should not be", i) + } } - ct.run() + // Server responds to a request, unblocking the last one. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rts[0].streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.wantHeaders(wantHeader{ + streamID: rts[maxConcurrent+1].streamID(), + endStream: true, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{fmt.Sprintf("/%d", maxConcurrent+1)}, + }, + }) + rts[0].wantStatus(200) } func TestTransportMaxDecoderHeaderTableSize(t *testing.T) { - ct := newClientTester(t) var reqSize, resSize uint32 = 8192, 16384 - ct.tr.MaxDecoderHeaderTableSize = reqSize - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - cc, err := ct.tr.NewClientConn(ct.cc) - if err != nil { - return err - } - _, err = cc.RoundTrip(req) - if err != nil { - return err - } - if got, want := cc.peerMaxHeaderTableSize, resSize; got != want { - return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want) - } - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxDecoderHeaderTableSize = reqSize + }) + + fr := testClientConnReadFrame[*SettingsFrame](tc) + if v, ok := fr.Value(SettingHeaderTableSize); !ok { + t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting") + } else if v != reqSize { + t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize) } - ct.server = func() error { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - sf, ok := f.(*SettingsFrame) - if !ok { - ct.t.Fatalf("wanted client settings frame; got %v", f) - _ = sf // stash it away? - } - var found bool - err = sf.ForeachSetting(func(s Setting) error { - if s.ID == SettingHeaderTableSize { - found = true - if got, want := s.Val, reqSize; got != want { - return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want) - } - } - return nil - }) - if err != nil { - return err - } - if !found { - return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting") - } - if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + tc.writeSettings(Setting{SettingHeaderTableSize, resSize}) + if got, want := tc.cc.peerMaxHeaderTableSize, resSize; got != want { + t.Fatalf("peerHeaderTableSize = %d, want %d", got, want) } - ct.run() } func TestTransportMaxEncoderHeaderTableSize(t *testing.T) { - ct := newClientTester(t) var peerAdvertisedMaxHeaderTableSize uint32 = 16384 - ct.tr.MaxEncoderHeaderTableSize = 8192 - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - cc, err := ct.tr.NewClientConn(ct.cc) - if err != nil { - return err - } - _, err = cc.RoundTrip(req) - if err != nil { - return err - } - if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want { - return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want) - } - return nil - } - ct.server = func() error { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - sf, ok := f.(*SettingsFrame) - if !ok { - ct.t.Fatalf("wanted client settings frame; got %v", f) - _ = sf // stash it away? - } - if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxEncoderHeaderTableSize = 8192 + }) + tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if got, want := tc.cc.henc.MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want { + t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want) } - ct.run() } func TestAuthorityAddr(t *testing.T) { @@ -4316,40 +3760,24 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { // Issue 18891: make sure Request.Body == NoBody means no DATA frame // is ever sent, even if empty. func TestTransportNoBodyMeansNoDATA(t *testing.T) { - ct := newClientTester(t) - - unblockClient := make(chan bool) + tc := newTestClientConn(t) + tc.greet() - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) - ct.tr.RoundTrip(req) - <-unblockClient - return nil - } - ct.server = func() error { - defer close(unblockClient) - defer ct.cc.(*net.TCPConn).Close() - ct.greet() + req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) + rt := tc.roundTrip(req) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f := f.(type) { - default: - return fmt.Errorf("Got %T; want HeadersFrame", f) - case *WindowUpdateFrame, *SettingsFrame: - continue - case *HeadersFrame: - if !f.StreamEnded() { - return fmt.Errorf("got headers frame without END_STREAM") - } - return nil - } - } + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, // END_STREAM should be set when body is http.NoBody + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{"/"}, + }, + }) + if fr := tc.readFrame(); fr != nil { + t.Fatalf("unexpected frame after headers: %v", fr) } - ct.run() } func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { @@ -4428,41 +3856,22 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { // Verify transport doesn't crash when receiving bogus response lacking a :status header. // Issue 22880. func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - const substr = "malformed response from server: missing status pseudo header" - if !strings.Contains(fmt.Sprint(err), substr) { - return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + tc := newTestClientConn(t) + tc.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, // we'll send some DATA to try to crash the transport - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - return nil - } - } - } - ct.run() + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, // we'll send some DATA to try to crash the transport + BlockFragment: tc.makeHeaderBlockFragment( + "content-type", "text/html", // no :status header + ), + }) + tc.writeData(rt.streamID(), true, []byte("payload")) } func BenchmarkClientRequestHeaders(b *testing.B) { @@ -4810,95 +4219,42 @@ func (r *errReader) Read(p []byte) (int, error) { } func testTransportBodyReadError(t *testing.T, body []byte) { - if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { - // So far we've only seen this be flaky on Windows and Plan 9, - // perhaps due to TCP behavior on shutdowns while - // unread data is in flight. This test should be - // fixed, but a skip is better than annoying people - // for now. - t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS) - } - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - - checkNoStreams := func() error { - cp, ok := ct.tr.connPool().(*clientConnPool) - if !ok { - return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool()) - } - cp.mu.Lock() - defer cp.mu.Unlock() - conns, ok := cp.conns["dummy.tld:443"] - if !ok { - return fmt.Errorf("missing connection") - } - if len(conns) != 1 { - return fmt.Errorf("conn pool size: %v; expect 1", len(conns)) - } - if activeStreams(conns[0]) != 0 { - return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0])) - } - return nil - } - bodyReadError := errors.New("body read error") - body := &errReader{body, bodyReadError} - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != bodyReadError { - return fmt.Errorf("err = %v; want %v", err, bodyReadError) - } - if err = checkNoStreams(); err != nil { - return err + tc := newTestClientConn(t) + tc.greet() + + bodyReadError := errors.New("body read error") + b := tc.newRequestBody() + b.Write(body) + b.closeWithError(bodyReadError) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", b) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + var receivedBody []byte +readFrames: + for { + switch f := tc.readFrame().(type) { + case *DataFrame: + receivedBody = append(receivedBody, f.Data()...) + case *RSTStreamFrame: + break readFrames + default: + t.Fatalf("unexpected frame: %v", f) + case nil: + t.Fatalf("transport is idle, want RST_STREAM") } - return nil } - ct.server = func() error { - ct.greet() - var receivedBody []byte - var resetCount int - for { - f, err := ct.fr.ReadFrame() - t.Logf("server: ReadFrame = %v, %v", f, err) - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - if bytes.Compare(receivedBody, body) != 0 { - return fmt.Errorf("body: %q; expected %q", receivedBody, body) - } - if resetCount != 1 { - return fmt.Errorf("stream reset count: %v; expected: 1", resetCount) - } - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - receivedBody = append(receivedBody, f.Data()...) - case *RSTStreamFrame: - resetCount++ - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + if !bytes.Equal(receivedBody, body) { + t.Fatalf("body: %q; expected %q", receivedBody, body) + } + + if err := rt.err(); err != bodyReadError { + t.Fatalf("err = %v; want %v", err, bodyReadError) + } + + if got := activeStreams(tc.cc); got != 0 { + t.Fatalf("active streams count: %v; want 0", got) } - ct.run() } func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } @@ -4911,59 +4267,18 @@ func TestTransportBodyEagerEndStream(t *testing.T) { const reqBody = "some request body" const resBody = "some response body" - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - body := strings.NewReader(reqBody) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != nil { - return err - } - return nil - } - ct.server = func() error { - ct.greet() + tc := newTestClientConn(t) + tc.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } + body := strings.NewReader(reqBody) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + tc.roundTrip(req) - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - if !f.StreamEnded() { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - return fmt.Errorf("data frame without END_STREAM %v", f) - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte(resBody)) - return nil - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + tc.wantFrameType(FrameHeaders) + f := testClientConnReadFrame[*DataFrame](tc) + if !f.StreamEnded() { + t.Fatalf("data frame without END_STREAM %v", f) } - ct.run() } type chunkReader struct { @@ -5737,39 +5052,27 @@ func TestClientConnReservations(t *testing.T) { } func TestTransportTimeoutServerHangs(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) + tc := newTestClientConn(t) + tc.greet() - req, err := http.NewRequest("PUT", "https://dummy.tld/", nil) - if err != nil { - return err - } + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - req = req.WithContext(ctx) - req.Header.Add("Big", strings.Repeat("a", 1<<20)) - _, err = ct.tr.RoundTrip(req) - if err == nil { - return errors.New("error should not be nil") - } - if ne, ok := err.(net.Error); !ok || !ne.Timeout() { - return fmt.Errorf("error should be a net error timeout: %v", err) - } - return nil + tc.wantFrameType(FrameHeaders) + tc.advance(5 * time.Second) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } - ct.server = func() error { - ct.greet() - select { - case <-time.After(5 * time.Second): - case <-clientDone: - } - return nil + if rt.done() { + t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned") + } + + cancel() + tc.sync() + if rt.err() != context.Canceled { + t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err()) } - ct.run() } func TestTransportContentLengthWithoutBody(t *testing.T) { @@ -5962,20 +5265,6 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { testTransportClosesConnAfterGoAway(t, 1) } -type closeOnceConn struct { - net.Conn - closed uint32 -} - -var errClosed = errors.New("Close of closed connection") - -func (c *closeOnceConn) Close() error { - if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - return c.Conn.Close() - } - return errClosed -} - // testTransportClosesConnAfterGoAway verifies that the transport // closes a connection after reading a GOAWAY from it. // @@ -5983,53 +5272,35 @@ func (c *closeOnceConn) Close() error { // When 0, the transport (unsuccessfully) retries the request (stream 1); // when 1, the transport reads the response after receiving the GOAWAY. func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { - ct := newClientTester(t) - ct.cc = &closeOnceConn{Conn: ct.cc} + tc := newTestClientConn(t) + tc.greet() - var wg sync.WaitGroup - wg.Add(1) - ct.client = func() error { - defer wg.Done() - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - } - if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { - t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) - } - if err = ct.cc.Close(); err != errClosed { - return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err) - } - return nil - } + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - ct.server = func() error { - defer wg.Wait() - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - return fmt.Errorf("server failed reading HEADERS: %v", err) - } - if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil { - return fmt.Errorf("server failed writing GOAWAY: %v", err) - } - if lastStream > 0 { - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - return nil + tc.wantFrameType(FrameHeaders) + tc.writeGoAway(lastStream, ErrCodeNo, nil) + + if lastStream > 0 { + // Send a valid response to first request. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) } - ct.run() + tc.closeWrite(io.EOF) + err := rt.err() + if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { + t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) + } + if !tc.netConnClosed { + t.Errorf("ClientConn did not close its net.Conn, expected it to") + } } type slowCloser struct { From 448c44f9287b6745f958d74aa2a17ec7761c2f13 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 19 Mar 2024 10:37:19 -0700 Subject: [PATCH 14/19] http2: remove clientTester All tests which use clientTester have been converted to use testClientConn, so delete clientTester. Change-Id: Id9a88bf7ee6760fada8442d383d5e68455c6dc3e Reviewed-on: https://go-review.googlesource.com/c/net/+/572815 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/transport_test.go | 195 ---------------------------------------- 1 file changed, 195 deletions(-) diff --git a/http2/transport_test.go b/http2/transport_test.go index 855c107ef..11ff67b4c 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -822,53 +822,6 @@ func (fw flushWriter) Write(p []byte) (n int, err error) { return } -type clientTester struct { - t *testing.T - tr *Transport - sc, cc net.Conn // server and client conn - fr *Framer // server's framer - settings *SettingsFrame - client func() error - server func() error -} - -func newClientTester(t *testing.T) *clientTester { - var dialOnce struct { - sync.Mutex - dialed bool - } - ct := &clientTester{ - t: t, - } - ct.tr = &Transport{ - TLSClientConfig: tlsConfigInsecure, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - dialOnce.Lock() - defer dialOnce.Unlock() - if dialOnce.dialed { - return nil, errors.New("only one dial allowed in test mode") - } - dialOnce.dialed = true - return ct.cc, nil - }, - } - - ln := newLocalListener(t) - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - sc, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - ln.Close() - ct.cc = cc - ct.sc = sc - ct.fr = NewFramer(sc, sc) - return ct -} - func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp4", "127.0.0.1:0") if err == nil { @@ -881,154 +834,6 @@ func newLocalListener(t *testing.T) net.Listener { return ln } -func (ct *clientTester) greet(settings ...Setting) { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - ct.t.Fatalf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - ct.t.Fatalf("Reading client settings frame: %v", err) - } - var ok bool - if ct.settings, ok = f.(*SettingsFrame); !ok { - ct.t.Fatalf("Wanted client settings frame; got %v", f) - } - if err := ct.fr.WriteSettings(settings...); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } -} - -func (ct *clientTester) readNonSettingsFrame() (Frame, error) { - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return nil, err - } - if _, ok := f.(*SettingsFrame); ok { - continue - } - return f, nil - } -} - -// writeReadPing sends a PING and immediately reads the PING ACK. -// It will fail if any other unread data was pending on the connection, -// aside from SETTINGS frames. -func (ct *clientTester) writeReadPing() error { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - if err := ct.fr.WritePing(false, data); err != nil { - return fmt.Errorf("Error writing PING: %v", err) - } - f, err := ct.readNonSettingsFrame() - if err != nil { - return err - } - p, ok := f.(*PingFrame) - if !ok { - return fmt.Errorf("got a %v, want a PING ACK", f) - } - if p.Flags&FlagPingAck == 0 { - return fmt.Errorf("got a PING, want a PING ACK") - } - if p.Data != data { - return fmt.Errorf("got PING data = %x, want %x", p.Data, data) - } - return nil -} - -func (ct *clientTester) inflowWindow(streamID uint32) int32 { - pool := ct.tr.connPoolOrDef.(*clientConnPool) - pool.mu.Lock() - defer pool.mu.Unlock() - if n := len(pool.keys); n != 1 { - ct.t.Errorf("clientConnPool contains %v keys, expected 1", n) - return -1 - } - for cc := range pool.keys { - cc.mu.Lock() - defer cc.mu.Unlock() - if streamID == 0 { - return cc.inflow.avail + cc.inflow.unsent - } - cs := cc.streams[streamID] - if cs == nil { - ct.t.Errorf("no stream with id %v", streamID) - return -1 - } - return cs.inflow.avail + cs.inflow.unsent - } - return -1 -} - -func (ct *clientTester) cleanup() { - ct.tr.CloseIdleConnections() - - // close both connections, ignore the error if its already closed - ct.sc.Close() - ct.cc.Close() -} - -func (ct *clientTester) run() { - var errOnce sync.Once - var wg sync.WaitGroup - - run := func(which string, fn func() error) { - defer wg.Done() - if err := fn(); err != nil { - errOnce.Do(func() { - ct.t.Errorf("%s: %v", which, err) - ct.cleanup() - }) - } - } - - wg.Add(2) - go run("client", ct.client) - go run("server", ct.server) - wg.Wait() - - errOnce.Do(ct.cleanup) // clean up if no error -} - -func (ct *clientTester) readFrame() (Frame, error) { - return ct.fr.ReadFrame() -} - -func (ct *clientTester) firstHeaders() (*HeadersFrame, error) { - for { - f, err := ct.readFrame() - if err != nil { - return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - hf, ok := f.(*HeadersFrame) - if !ok { - return nil, fmt.Errorf("Got %T; want HeadersFrame", f) - } - return hf, nil - } -} - -type countingReader struct { - n *int64 -} - -func (r countingReader) Read(p []byte) (n int, err error) { - for i := range p { - p[i] = byte(i) - } - atomic.AddInt64(r.n, int64(len(p))) - return len(p), err -} - func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } From 3678185f8a652e52864c44049a9ea96b7bcc066a Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 19 Mar 2024 13:42:48 -0700 Subject: [PATCH 15/19] http2: make TestCanonicalHeaderCacheGrowth faster Lower the number of iterations that this test runs for. Reduces runtime with -race from 27s on my M1 Mac to 0.06s. Change-Id: Ibd4b225277c79d9030c0a21b3077173a787cc4c1 Reviewed-on: https://go-review.googlesource.com/c/net/+/572656 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/server_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/http2/server_test.go b/http2/server_test.go index 1fdd191ef..afccd9ecd 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4578,13 +4578,16 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) { sc := &serverConn{ serveG: newGoroutineLock(), } - const count = 1000 - for i := 0; i < count; i++ { - h := fmt.Sprintf("%v-%v", base, i) + count := 0 + added := 0 + for added < 10*maxCachedCanonicalHeadersKeysSize { + h := fmt.Sprintf("%v-%v", base, count) c := sc.canonicalHeader(h) if len(h) != len(c) { t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c) } + count++ + added += len(h) } total := 0 for k, v := range sc.canonHeader { From ebc8168ac8ac742194df729305175940790c55a2 Mon Sep 17 00:00:00 2001 From: vitalmotif Date: Wed, 20 Mar 2024 09:32:28 +0000 Subject: [PATCH 16/19] all: fix some typos Change-Id: I7e2c867efcc960553da77e395b0069ab6776cd9f GitHub-Last-Rev: eaa122d1b6086c22f329227053411b0a73a5215b GitHub-Pull-Request: golang/net#205 Reviewed-on: https://go-review.googlesource.com/c/net/+/572995 Reviewed-by: Emmanuel Odeke Reviewed-by: David Chase Auto-Submit: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil --- dns/dnsmessage/message_test.go | 2 +- quic/rangeset.go | 2 +- quic/retry_test.go | 2 +- quic/stream_limits_test.go | 2 +- quic/version_test.go | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go index e60ec42d9..255530598 100644 --- a/dns/dnsmessage/message_test.go +++ b/dns/dnsmessage/message_test.go @@ -1635,7 +1635,7 @@ func FuzzUnpackPack(f *testing.F) { msgPacked, err := m.Pack() if err != nil { - t.Fatalf("failed to pack message that was succesfully unpacked: %v", err) + t.Fatalf("failed to pack message that was successfully unpacked: %v", err) } var m2 Message diff --git a/quic/rangeset.go b/quic/rangeset.go index 4966a99d2..b8b2e9367 100644 --- a/quic/rangeset.go +++ b/quic/rangeset.go @@ -50,7 +50,7 @@ func (s *rangeset[T]) add(start, end T) { if end <= r.end { return } - // Possibly coalesce subsquent ranges into range i. + // Possibly coalesce subsequent ranges into range i. r.end = end j := i + 1 for ; j < len(*s) && r.end >= (*s)[j].start; j++ { diff --git a/quic/retry_test.go b/quic/retry_test.go index 42f2bdd4a..c898ad331 100644 --- a/quic/retry_test.go +++ b/quic/retry_test.go @@ -521,7 +521,7 @@ func TestParseInvalidRetryPackets(t *testing.T) { }} { t.Run(test.name, func(t *testing.T) { if _, ok := parseRetryPacket(test.pkt, originalDstConnID); ok { - t.Errorf("parseRetryPacket succeded, want failure") + t.Errorf("parseRetryPacket succeeded, want failure") } }) } diff --git a/quic/stream_limits_test.go b/quic/stream_limits_test.go index 9c2f71ec1..8fed825d7 100644 --- a/quic/stream_limits_test.go +++ b/quic/stream_limits_test.go @@ -249,7 +249,7 @@ func TestStreamLimitStopSendingDoesNotUpdateMaxStreams(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameStopSending{ id: s.id, }) - tc.wantFrame("recieved STOP_SENDING, send RESET_STREAM", + tc.wantFrame("received STOP_SENDING, send RESET_STREAM", packetType1RTT, debugFrameResetStream{ id: s.id, }) diff --git a/quic/version_test.go b/quic/version_test.go index 92fabd7b3..0bd8bac14 100644 --- a/quic/version_test.go +++ b/quic/version_test.go @@ -39,10 +39,10 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { }) gotPkt := te.read() if gotPkt == nil { - t.Fatalf("got no response; want Version Negotiaion") + t.Fatalf("got no response; want Version Negotiation") } if got := getPacketType(gotPkt); got != packetTypeVersionNegotiation { - t.Fatalf("got packet type %v; want Version Negotiaion", got) + t.Fatalf("got packet type %v; want Version Negotiation", got) } gotDst, gotSrc, versions := parseVersionNegotiation(gotPkt) if got, want := gotDst, srcConnID; !bytes.Equal(got, want) { From ba872109ef2dc8f1da778651bd1fd3792d0e4587 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 10 Jan 2024 13:41:39 -0800 Subject: [PATCH 17/19] http2: close connections when receiving too many headers Maintaining HPACK state requires that we parse and process all HEADERS and CONTINUATION frames on a connection. When a request's headers exceed MaxHeaderBytes, we don't allocate memory to store the excess headers but we do parse them. This permits an attacker to cause an HTTP/2 endpoint to read arbitrary amounts of data, all associated with a request which is going to be rejected. Set a limit on the amount of excess header frames we will process before closing a connection. Thanks to Bartek Nowotarski for reporting this issue. Fixes CVE-2023-45288 Fixes golang/go#65051 Change-Id: I15df097268df13bb5a9e9d3a5c04a8a141d850f6 Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/2130527 Reviewed-by: Roland Shoemaker Reviewed-by: Tatiana Bradley Reviewed-on: https://go-review.googlesource.com/c/net/+/576155 Reviewed-by: Dmitri Shuralyov Auto-Submit: Dmitri Shuralyov Reviewed-by: Than McIntosh LUCI-TryBot-Result: Go LUCI --- http2/frame.go | 31 ++++++++++++++++ http2/server_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) diff --git a/http2/frame.go b/http2/frame.go index e2b298d85..a5a94411d 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -1564,6 +1564,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { if size > remainSize { hdec.SetEmitEnabled(false) mh.Truncated = true + remainSize = 0 return } remainSize -= size @@ -1576,6 +1577,36 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { var hc headersOrContinuation = hf for { frag := hc.HeaderBlockFragment() + + // Avoid parsing large amounts of headers that we will then discard. + // If the sender exceeds the max header list size by too much, + // skip parsing the fragment and close the connection. + // + // "Too much" is either any CONTINUATION frame after we've already + // exceeded the max header list size (in which case remainSize is 0), + // or a frame whose encoded size is more than twice the remaining + // header list bytes we're willing to accept. + if int64(len(frag)) > int64(2*remainSize) { + if VerboseLogs { + log.Printf("http2: header list too large") + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the struture of the server's frame writer makes this difficult. + return nil, ConnectionError(ErrCodeProtocol) + } + + // Also close the connection after any CONTINUATION frame following an + // invalid header, since we stop tracking the size of the headers after + // an invalid one. + if invalid != nil { + if VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the struture of the server's frame writer makes this difficult. + return nil, ConnectionError(ErrCodeProtocol) + } + if _, err := hdec.Write(frag); err != nil { return nil, ConnectionError(ErrCodeCompression) } diff --git a/http2/server_test.go b/http2/server_test.go index afccd9ecd..d400990d2 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4786,3 +4786,87 @@ Frames: close(s) } } + +func TestServerContinuationFlood(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.Header) + }, func(ts *httptest.Server) { + ts.Config.MaxHeaderBytes = 4096 + }) + defer st.Close() + + st.writePreface() + st.writeInitialSettings() + st.writeSettingsAck() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + }) + for i := 0; i < 1000; i++ { + st.fr.WriteContinuation(1, false, st.encodeHeaderRaw( + fmt.Sprintf("x-%v", i), "1234567890", + )) + } + st.fr.WriteContinuation(1, true, st.encodeHeaderRaw( + "x-last-header", "1", + )) + + var sawGoAway bool + for { + f, err := st.readFrame() + if err != nil { + break + } + switch f.(type) { + case *GoAwayFrame: + sawGoAway = true + case *HeadersFrame: + t.Fatalf("received HEADERS frame; want GOAWAY") + } + } + if !sawGoAway { + t.Errorf("connection closed with no GOAWAY frame; want one") + } +} + +func TestServerContinuationAfterInvalidHeader(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.Header) + }) + defer st.Close() + + st.writePreface() + st.writeInitialSettings() + st.writeSettingsAck() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + }) + st.fr.WriteContinuation(1, false, st.encodeHeaderRaw( + "x-invalid-header", "\x00", + )) + st.fr.WriteContinuation(1, true, st.encodeHeaderRaw( + "x-valid-header", "1", + )) + + var sawGoAway bool + for { + f, err := st.readFrame() + if err != nil { + break + } + switch f.(type) { + case *GoAwayFrame: + sawGoAway = true + case *HeadersFrame: + t.Fatalf("received HEADERS frame; want GOAWAY") + } + } + if !sawGoAway { + t.Errorf("connection closed with no GOAWAY frame; want one") + } +} From 762b58d1cf6e0779780decad89c6c1523386638d Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Wed, 3 Apr 2024 09:32:37 -0700 Subject: [PATCH 18/19] http2: fix tipos in comment Change-Id: I20cd0f8db534fe2a849306eb7e0c8ee5b434e88f Reviewed-on: https://go-review.googlesource.com/c/net/+/576175 Auto-Submit: Ian Lance Taylor Reviewed-by: Ian Lance Taylor LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil --- http2/frame.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/http2/frame.go b/http2/frame.go index a5a94411d..43557ab7e 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -1591,7 +1591,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { log.Printf("http2: header list too large") } // It would be nice to send a RST_STREAM before sending the GOAWAY, - // but the struture of the server's frame writer makes this difficult. + // but the structure of the server's frame writer makes this difficult. return nil, ConnectionError(ErrCodeProtocol) } @@ -1603,7 +1603,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { log.Printf("http2: invalid header: %v", invalid) } // It would be nice to send a RST_STREAM before sending the GOAWAY, - // but the struture of the server's frame writer makes this difficult. + // but the structure of the server's frame writer makes this difficult. return nil, ConnectionError(ErrCodeProtocol) } From c48da131589f122489348be5dfbcb6457640046f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 3 Apr 2024 10:17:18 -0700 Subject: [PATCH 19/19] http2: fix TestServerContinuationFlood flakes This test causes the server to send a GOAWAY and close a connection. The server GOAWAY path writes a GOAWAY frame asynchronously, and closes the connection if the write doesn't complete within 1s. This is causing failures on some builders, when the frame write doesn't complete in time. The important aspect of this test is that the connection be closed. Drop the check for the GOAWAY frame. Change-Id: I099413be9c4dfe71d8fe83d2c6242e82e282293e Reviewed-on: https://go-review.googlesource.com/c/net/+/576235 Reviewed-by: Dmitri Shuralyov Reviewed-by: Dmitri Shuralyov Reviewed-by: Than McIntosh LUCI-TryBot-Result: Go LUCI --- http2/server_test.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/http2/server_test.go b/http2/server_test.go index d400990d2..a931a06e5 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4813,22 +4813,24 @@ func TestServerContinuationFlood(t *testing.T) { "x-last-header", "1", )) - var sawGoAway bool for { f, err := st.readFrame() if err != nil { break } switch f.(type) { - case *GoAwayFrame: - sawGoAway = true case *HeadersFrame: - t.Fatalf("received HEADERS frame; want GOAWAY") + t.Fatalf("received HEADERS frame; want GOAWAY and a closed connection") } } - if !sawGoAway { - t.Errorf("connection closed with no GOAWAY frame; want one") - } + // We expect to have seen a GOAWAY before the connection closes, + // but the server will close the connection after one second + // whether or not it has finished sending the GOAWAY. On windows-amd64-race + // builders, this fairly consistently results in the connection closing without + // the GOAWAY being sent. + // + // Since the server's behavior is inherently racy here and the important thing + // is that the connection is closed, don't check for the GOAWAY having been sent. } func TestServerContinuationAfterInvalidHeader(t *testing.T) {