diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go index e60ec42d9..255530598 100644 --- a/dns/dnsmessage/message_test.go +++ b/dns/dnsmessage/message_test.go @@ -1635,7 +1635,7 @@ func FuzzUnpackPack(f *testing.F) { msgPacked, err := m.Pack() if err != nil { - t.Fatalf("failed to pack message that was succesfully unpacked: %v", err) + t.Fatalf("failed to pack message that was successfully unpacked: %v", err) } var m2 Message diff --git a/go.mod b/go.mod index 36207106d..1446f39a4 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module golang.org/x/net go 1.18 require ( - golang.org/x/crypto v0.21.0 - golang.org/x/sys v0.18.0 - golang.org/x/term v0.18.0 + golang.org/x/crypto v0.22.0 + golang.org/x/sys v0.19.0 + golang.org/x/term v0.19.0 golang.org/x/text v0.14.0 ) diff --git a/go.sum b/go.sum index 69fb10498..a4e7f116d 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q= +golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go index c3bd9a1ee..6404aaf15 100644 --- a/http/httpproxy/proxy.go +++ b/http/httpproxy/proxy.go @@ -149,10 +149,7 @@ func parseProxy(proxy string) (*url.URL, error) { } proxyURL, err := url.Parse(proxy) - if err != nil || - (proxyURL.Scheme != "http" && - proxyURL.Scheme != "https" && - proxyURL.Scheme != "socks5") { + if err != nil || proxyURL.Scheme == "" || proxyURL.Host == "" { // proxy was bogus. Try prepending "http://" to it and // see if that parses correctly. If not, we fall // through and complain about the original one. diff --git a/http/httpproxy/proxy_test.go b/http/httpproxy/proxy_test.go index d76373295..790afdab7 100644 --- a/http/httpproxy/proxy_test.go +++ b/http/httpproxy/proxy_test.go @@ -68,6 +68,12 @@ var proxyForURLTests = []proxyForURLTest{{ HTTPProxy: "cache.corp.example.com", }, want: "http://cache.corp.example.com", +}, { + // single label domain is recognized as scheme by url.Parse + cfg: httpproxy.Config{ + HTTPProxy: "localhost", + }, + want: "http://localhost", }, { cfg: httpproxy.Config{ HTTPProxy: "https://cache.corp.example.com", @@ -88,6 +94,12 @@ var proxyForURLTests = []proxyForURLTest{{ HTTPProxy: "socks5://127.0.0.1", }, want: "socks5://127.0.0.1", +}, { + // Preserve unknown schemes. + cfg: httpproxy.Config{ + HTTPProxy: "foo://host", + }, + want: "foo://host", }, { // Don't use secure for http cfg: httpproxy.Config{ diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go new file mode 100644 index 000000000..4237b1436 --- /dev/null +++ b/http2/clientconn_test.go @@ -0,0 +1,829 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Infrastructure for testing ClientConn.RoundTrip. +// Put actual tests in transport_test.go. + +package http2 + +import ( + "bytes" + "fmt" + "io" + "net" + "net/http" + "reflect" + "slices" + "testing" + "time" + + "golang.org/x/net/http2/hpack" +) + +// TestTestClientConn demonstrates usage of testClientConn. +func TestTestClientConn(t *testing.T) { + // newTestClientConn creates a *ClientConn and surrounding test infrastructure. + tc := newTestClientConn(t) + + // tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames, + // and sends a SETTINGS frame to the client. + // + // Additional settings may be provided as optional parameters to greet. + tc.greet() + + // Request bodies must either be constant (bytes.Buffer, strings.Reader) + // or created with newRequestBody. + body := tc.newRequestBody() + body.writeBytes(10) // 10 arbitrary bytes... + body.closeWithError(io.EOF) // ...followed by EOF. + + // tc.roundTrip calls RoundTrip, but does not wait for it to return. + // It returns a testRoundTrip. + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + // tc has a number of methods to check for expected frames sent. + // Here, we look for headers and the request body. + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: false, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"PUT"}, + ":path": []string{"/"}, + }, + }) + // Expect 10 bytes of request body in DATA frames. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: 10, + }) + + // tc.writeHeaders sends a HEADERS frame back to the client. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + // Now that we've received headers, RoundTrip has finished. + // testRoundTrip has various methods to examine the response, + // or to fetch the response and/or error returned by RoundTrip + rt.wantStatus(200) + rt.wantBody(nil) +} + +// A testClientConn allows testing ClientConn.RoundTrip against a fake server. +// +// A test using testClientConn consists of: +// - actions on the client (calling RoundTrip, making data available to Request.Body); +// - validation of frames sent by the client to the server; and +// - providing frames from the server to the client. +// +// testClientConn manages synchronization, so tests can generally be written as +// a linear sequence of actions and validations without additional synchronization. +type testClientConn struct { + t *testing.T + + tr *Transport + fr *Framer + cc *ClientConn + hooks *testSyncHooks + + encbuf bytes.Buffer + enc *hpack.Encoder + + roundtrips []*testRoundTrip + + rerr error // returned by Read + netConnClosed bool // set when the ClientConn closes the net.Conn + rbuf bytes.Buffer // sent to the test conn + wbuf bytes.Buffer // sent by the test conn +} + +func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn { + tc := &testClientConn{ + t: t, + tr: cc.t, + cc: cc, + hooks: cc.t.syncHooks, + } + cc.tconn = (*testClientConnNetConn)(tc) + tc.enc = hpack.NewEncoder(&tc.encbuf) + tc.fr = NewFramer(&tc.rbuf, &tc.wbuf) + tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) + tc.fr.SetMaxReadFrameSize(10 << 20) + t.Cleanup(func() { + tc.sync() + if tc.rerr == nil { + tc.rerr = io.EOF + } + tc.sync() + }) + return tc +} + +func (tc *testClientConn) readClientPreface() { + tc.t.Helper() + // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. + buf := make([]byte, len(clientPreface)) + if _, err := io.ReadFull(&tc.wbuf, buf); err != nil { + tc.t.Fatalf("reading preface: %v", err) + } + if !bytes.Equal(buf, clientPreface) { + tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface) + } +} + +func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { + t.Helper() + + tt := newTestTransport(t, opts...) + const singleUse = false + _, err := tt.tr.newClientConn(nil, singleUse, tt.tr.syncHooks) + if err != nil { + t.Fatalf("newClientConn: %v", err) + } + + return tt.getConn() +} + +// sync waits for the ClientConn under test to reach a stable state, +// with all goroutines blocked on some input. +func (tc *testClientConn) sync() { + tc.hooks.waitInactive() +} + +// advance advances synthetic time by a duration. +func (tc *testClientConn) advance(d time.Duration) { + tc.hooks.advance(d) + tc.sync() +} + +// hasFrame reports whether a frame is available to be read. +func (tc *testClientConn) hasFrame() bool { + return tc.wbuf.Len() > 0 +} + +// readFrame reads the next frame from the conn. +func (tc *testClientConn) readFrame() Frame { + if tc.wbuf.Len() == 0 { + return nil + } + fr, err := tc.fr.ReadFrame() + if err != nil { + return nil + } + return fr +} + +// testClientConnReadFrame reads a frame of a specific type from the conn. +func testClientConnReadFrame[T any](tc *testClientConn) T { + tc.t.Helper() + var v T + fr := tc.readFrame() + if fr == nil { + tc.t.Fatalf("got no frame, want frame %T", v) + } + v, ok := fr.(T) + if !ok { + tc.t.Fatalf("got frame %T, want %T", fr, v) + } + return v +} + +// wantFrameType reads the next frame from the conn. +// It produces an error if the frame type is not the expected value. +func (tc *testClientConn) wantFrameType(want FrameType) { + tc.t.Helper() + fr := tc.readFrame() + if fr == nil { + tc.t.Fatalf("got no frame, want frame %v", want) + } + if got := fr.Header().Type; got != want { + tc.t.Fatalf("got frame %v, want %v", got, want) + } +} + +// wantUnorderedFrames reads frames from the conn until every condition in want has been satisfied. +// +// want is a list of func(*SomeFrame) bool. +// wantUnorderedFrames will call each func with frames of the appropriate type +// until the func returns true. +// It calls t.Fatal if an unexpected frame is received (no func has that frame type, +// or all funcs with that type have returned true), or if the conn runs out of frames +// with unsatisfied funcs. +// +// Example: +// +// // Read a SETTINGS frame, and any number of DATA frames for a stream. +// // The SETTINGS frame may appear anywhere in the sequence. +// // The last DATA frame must indicate the end of the stream. +// tc.wantUnorderedFrames( +// func(f *SettingsFrame) bool { +// return true +// }, +// func(f *DataFrame) bool { +// return f.StreamEnded() +// }, +// ) +func (tc *testClientConn) wantUnorderedFrames(want ...any) { + tc.t.Helper() + want = slices.Clone(want) + seen := 0 +frame: + for seen < len(want) && !tc.t.Failed() { + fr := tc.readFrame() + if fr == nil { + break + } + for i, f := range want { + if f == nil { + continue + } + typ := reflect.TypeOf(f) + if typ.Kind() != reflect.Func || + typ.NumIn() != 1 || + typ.NumOut() != 1 || + typ.Out(0) != reflect.TypeOf(true) { + tc.t.Fatalf("expected func(*SomeFrame) bool, got %T", f) + } + if typ.In(0) == reflect.TypeOf(fr) { + out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)}) + if out[0].Bool() { + want[i] = nil + seen++ + } + continue frame + } + } + tc.t.Errorf("got unexpected frame type %T", fr) + } + if seen < len(want) { + for _, f := range want { + if f == nil { + continue + } + tc.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0)) + } + tc.t.Fatalf("did not see %v expected frame types", len(want)-seen) + } +} + +type wantHeader struct { + streamID uint32 + endStream bool + header http.Header +} + +// wantHeaders reads a HEADERS frame and potential CONTINUATION frames, +// and asserts that they contain the expected headers. +func (tc *testClientConn) wantHeaders(want wantHeader) { + tc.t.Helper() + got := testClientConnReadFrame[*MetaHeadersFrame](tc) + if got, want := got.StreamID, want.streamID; got != want { + tc.t.Fatalf("got stream ID %v, want %v", got, want) + } + if got, want := got.StreamEnded(), want.endStream; got != want { + tc.t.Fatalf("got stream ended %v, want %v", got, want) + } + gotHeader := make(http.Header) + for _, f := range got.Fields { + gotHeader[f.Name] = append(gotHeader[f.Name], f.Value) + } + for k, v := range want.header { + if !reflect.DeepEqual(v, gotHeader[k]) { + tc.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k]) + } + } +} + +type wantData struct { + streamID uint32 + endStream bool + size int +} + +// wantData reads zero or more DATA frames, and asserts that they match the expectation. +func (tc *testClientConn) wantData(want wantData) { + tc.t.Helper() + gotSize := 0 + gotEndStream := false + for tc.hasFrame() && !gotEndStream { + data := testClientConnReadFrame[*DataFrame](tc) + gotSize += len(data.Data()) + if data.StreamEnded() { + gotEndStream = true + } + } + if gotSize != want.size { + tc.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size) + } + if gotEndStream != want.endStream { + tc.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream) + } +} + +// testRequestBody is a Request.Body for use in tests. +type testRequestBody struct { + tc *testClientConn + + // At most one of buf or bytes can be set at any given time: + buf bytes.Buffer // specific bytes to read from the body + bytes int // body contains this many arbitrary bytes + + err error // read error (comes after any available bytes) +} + +func (tc *testClientConn) newRequestBody() *testRequestBody { + b := &testRequestBody{ + tc: tc, + } + return b +} + +// Read is called by the ClientConn to read from a request body. +func (b *testRequestBody) Read(p []byte) (n int, _ error) { + b.tc.cc.syncHooks.blockUntil(func() bool { + return b.buf.Len() > 0 || b.bytes > 0 || b.err != nil + }) + switch { + case b.buf.Len() > 0: + return b.buf.Read(p) + case b.bytes > 0: + if len(p) > b.bytes { + p = p[:b.bytes] + } + b.bytes -= len(p) + for i := range p { + p[i] = 'A' + } + return len(p), nil + default: + return 0, b.err + } +} + +// Close is called by the ClientConn when it is done reading from a request body. +func (b *testRequestBody) Close() error { + return nil +} + +// writeBytes adds n arbitrary bytes to the body. +func (b *testRequestBody) writeBytes(n int) { + b.bytes += n + b.checkWrite() + b.tc.sync() +} + +// Write adds bytes to the body. +func (b *testRequestBody) Write(p []byte) (int, error) { + n, err := b.buf.Write(p) + b.checkWrite() + b.tc.sync() + return n, err +} + +func (b *testRequestBody) checkWrite() { + if b.bytes > 0 && b.buf.Len() > 0 { + b.tc.t.Fatalf("can't interleave Write and writeBytes on request body") + } + if b.err != nil { + b.tc.t.Fatalf("can't write to request body after closeWithError") + } +} + +// closeWithError sets an error which will be returned by Read. +func (b *testRequestBody) closeWithError(err error) { + b.err = err + b.tc.sync() +} + +// roundTrip starts a RoundTrip call. +// +// (Note that the RoundTrip won't complete until response headers are received, +// the request times out, or some other terminal condition is reached.) +func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{ + t: tc.t, + donec: make(chan struct{}), + } + tc.roundtrips = append(tc.roundtrips, rt) + tc.hooks.newstream = func(cs *clientStream) { rt.cs = cs } + tc.cc.goRun(func() { + defer close(rt.donec) + rt.resp, rt.respErr = tc.cc.RoundTrip(req) + }) + tc.sync() + tc.hooks.newstream = nil + + tc.t.Cleanup(func() { + if !rt.done() { + return + } + res, _ := rt.result() + if res != nil { + res.Body.Close() + } + }) + + return rt +} + +func (tc *testClientConn) greet(settings ...Setting) { + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.writeSettings(settings...) + tc.writeSettingsAck() + tc.wantFrameType(FrameSettings) // acknowledgement +} + +func (tc *testClientConn) writeSettings(settings ...Setting) { + tc.t.Helper() + if err := tc.fr.WriteSettings(settings...); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeSettingsAck() { + tc.t.Helper() + if err := tc.fr.WriteSettingsAck(); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeData(streamID uint32, endStream bool, data []byte) { + tc.t.Helper() + if err := tc.fr.WriteData(streamID, endStream, data); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) { + tc.t.Helper() + if err := tc.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// makeHeaderBlockFragment encodes headers in a form suitable for inclusion +// in a HEADERS or CONTINUATION frame. +// +// It takes a list of alernating names and values. +func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte { + if len(s)%2 != 0 { + tc.t.Fatalf("uneven list of header name/value pairs") + } + tc.encbuf.Reset() + for i := 0; i < len(s); i += 2 { + tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]}) + } + return tc.encbuf.Bytes() +} + +func (tc *testClientConn) writeHeaders(p HeadersFrameParam) { + tc.t.Helper() + if err := tc.fr.WriteHeaders(p); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// writeHeadersMode writes header frames, as modified by mode: +// +// - noHeader: Don't write the header. +// - oneHeader: Write a single HEADERS frame. +// - splitHeader: Write a HEADERS frame and CONTINUATION frame. +func (tc *testClientConn) writeHeadersMode(mode headerType, p HeadersFrameParam) { + tc.t.Helper() + switch mode { + case noHeader: + case oneHeader: + tc.writeHeaders(p) + case splitHeader: + if len(p.BlockFragment) < 2 { + panic("too small") + } + contData := p.BlockFragment[1:] + contEnd := p.EndHeaders + p.BlockFragment = p.BlockFragment[:1] + p.EndHeaders = false + tc.writeHeaders(p) + tc.writeContinuation(p.StreamID, contEnd, contData) + default: + panic("bogus mode") + } +} + +func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) { + tc.t.Helper() + if err := tc.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeRSTStream(streamID uint32, code ErrCode) { + tc.t.Helper() + if err := tc.fr.WriteRSTStream(streamID, code); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writePing(ack bool, data [8]byte) { + tc.t.Helper() + if err := tc.fr.WritePing(ack, data); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) { + tc.t.Helper() + if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeWindowUpdate(streamID, incr uint32) { + tc.t.Helper() + if err := tc.fr.WriteWindowUpdate(streamID, incr); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// closeWrite causes the net.Conn used by the ClientConn to return a error +// from Read calls. +func (tc *testClientConn) closeWrite(err error) { + tc.rerr = err + tc.sync() +} + +// inflowWindow returns the amount of inbound flow control available for a stream, +// or for the connection if streamID is 0. +func (tc *testClientConn) inflowWindow(streamID uint32) int32 { + tc.cc.mu.Lock() + defer tc.cc.mu.Unlock() + if streamID == 0 { + return tc.cc.inflow.avail + tc.cc.inflow.unsent + } + cs := tc.cc.streams[streamID] + if cs == nil { + tc.t.Errorf("no stream with id %v", streamID) + return -1 + } + return cs.inflow.avail + cs.inflow.unsent +} + +// testRoundTrip manages a RoundTrip in progress. +type testRoundTrip struct { + t *testing.T + resp *http.Response + respErr error + donec chan struct{} + cs *clientStream +} + +// streamID returns the HTTP/2 stream ID of the request. +func (rt *testRoundTrip) streamID() uint32 { + if rt.cs == nil { + panic("stream ID unknown") + } + return rt.cs.ID +} + +// done reports whether RoundTrip has returned. +func (rt *testRoundTrip) done() bool { + select { + case <-rt.donec: + return true + default: + return false + } +} + +// result returns the result of the RoundTrip. +func (rt *testRoundTrip) result() (*http.Response, error) { + t := rt.t + t.Helper() + select { + case <-rt.donec: + default: + t.Fatalf("RoundTrip is not done; want it to be") + } + return rt.resp, rt.respErr +} + +// response returns the response of a successful RoundTrip. +// If the RoundTrip unexpectedly failed, it calls t.Fatal. +func (rt *testRoundTrip) response() *http.Response { + t := rt.t + t.Helper() + resp, err := rt.result() + if err != nil { + t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr) + } + if resp == nil { + t.Fatalf("RoundTrip returned nil *Response and nil error") + } + return resp +} + +// err returns the (possibly nil) error result of RoundTrip. +func (rt *testRoundTrip) err() error { + t := rt.t + t.Helper() + _, err := rt.result() + return err +} + +// wantStatus indicates the expected response StatusCode. +func (rt *testRoundTrip) wantStatus(want int) { + t := rt.t + t.Helper() + if got := rt.response().StatusCode; got != want { + t.Fatalf("got response status %v, want %v", got, want) + } +} + +// body reads the contents of the response body. +func (rt *testRoundTrip) readBody() ([]byte, error) { + t := rt.t + t.Helper() + return io.ReadAll(rt.response().Body) +} + +// wantBody indicates the expected response body. +// (Note that this consumes the body.) +func (rt *testRoundTrip) wantBody(want []byte) { + t := rt.t + t.Helper() + got, err := rt.readBody() + if err != nil { + t.Fatalf("unexpected error reading response body: %v", err) + } + if !bytes.Equal(got, want) { + t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want) + } +} + +// wantHeaders indicates the expected response headers. +func (rt *testRoundTrip) wantHeaders(want http.Header) { + t := rt.t + t.Helper() + res := rt.response() + if diff := diffHeaders(res.Header, want); diff != "" { + t.Fatalf("unexpected response headers:\n%v", diff) + } +} + +// wantTrailers indicates the expected response trailers. +func (rt *testRoundTrip) wantTrailers(want http.Header) { + t := rt.t + t.Helper() + res := rt.response() + if diff := diffHeaders(res.Trailer, want); diff != "" { + t.Fatalf("unexpected response trailers:\n%v", diff) + } +} + +func diffHeaders(got, want http.Header) string { + // nil and 0-length non-nil are equal. + if len(got) == 0 && len(want) == 0 { + return "" + } + // We could do a more sophisticated diff here. + // DeepEqual is good enough for now. + if reflect.DeepEqual(got, want) { + return "" + } + return fmt.Sprintf("got: %v\nwant: %v", got, want) +} + +// testClientConnNetConn implements net.Conn. +type testClientConnNetConn testClientConn + +func (nc *testClientConnNetConn) Read(b []byte) (n int, err error) { + nc.cc.syncHooks.blockUntil(func() bool { + return nc.rerr != nil || nc.rbuf.Len() > 0 + }) + if nc.rbuf.Len() > 0 { + return nc.rbuf.Read(b) + } + return 0, nc.rerr +} + +func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) { + return nc.wbuf.Write(b) +} + +func (nc *testClientConnNetConn) Close() error { + nc.netConnClosed = true + return nil +} + +func (*testClientConnNetConn) LocalAddr() (_ net.Addr) { return } +func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return } +func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil } +func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil } +func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil } + +// A testTransport allows testing Transport.RoundTrip against fake servers. +// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling +// should use testClientConn instead. +type testTransport struct { + t *testing.T + tr *Transport + + ccs []*testClientConn +} + +func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport { + tr := &Transport{ + syncHooks: newTestSyncHooks(), + } + for _, o := range opts { + o(tr) + } + + tt := &testTransport{ + t: t, + tr: tr, + } + tr.syncHooks.newclientconn = func(cc *ClientConn) { + tt.ccs = append(tt.ccs, newTestClientConnFromClientConn(t, cc)) + } + + t.Cleanup(func() { + tt.sync() + if len(tt.ccs) > 0 { + t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs)) + } + if tt.tr.syncHooks.total != 0 { + t.Errorf("%v goroutines still running after test completed", tt.tr.syncHooks.total) + } + }) + + return tt +} + +func (tt *testTransport) sync() { + tt.tr.syncHooks.waitInactive() +} + +func (tt *testTransport) advance(d time.Duration) { + tt.tr.syncHooks.advance(d) + tt.sync() +} + +func (tt *testTransport) hasConn() bool { + return len(tt.ccs) > 0 +} + +func (tt *testTransport) getConn() *testClientConn { + tt.t.Helper() + if len(tt.ccs) == 0 { + tt.t.Fatalf("no new ClientConns created; wanted one") + } + tc := tt.ccs[0] + tt.ccs = tt.ccs[1:] + tc.sync() + tc.readClientPreface() + return tc +} + +func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{ + t: tt.t, + donec: make(chan struct{}), + } + tt.tr.syncHooks.goRun(func() { + defer close(rt.donec) + rt.resp, rt.respErr = tt.tr.RoundTrip(req) + }) + tt.sync() + + tt.t.Cleanup(func() { + if !rt.done() { + return + } + res, _ := rt.result() + if res != nil { + res.Body.Close() + } + }) + + return rt +} diff --git a/http2/frame.go b/http2/frame.go index e2b298d85..43557ab7e 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -1564,6 +1564,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { if size > remainSize { hdec.SetEmitEnabled(false) mh.Truncated = true + remainSize = 0 return } remainSize -= size @@ -1576,6 +1577,36 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { var hc headersOrContinuation = hf for { frag := hc.HeaderBlockFragment() + + // Avoid parsing large amounts of headers that we will then discard. + // If the sender exceeds the max header list size by too much, + // skip parsing the fragment and close the connection. + // + // "Too much" is either any CONTINUATION frame after we've already + // exceeded the max header list size (in which case remainSize is 0), + // or a frame whose encoded size is more than twice the remaining + // header list bytes we're willing to accept. + if int64(len(frag)) > int64(2*remainSize) { + if VerboseLogs { + log.Printf("http2: header list too large") + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the structure of the server's frame writer makes this difficult. + return nil, ConnectionError(ErrCodeProtocol) + } + + // Also close the connection after any CONTINUATION frame following an + // invalid header, since we stop tracking the size of the headers after + // an invalid one. + if invalid != nil { + if VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the structure of the server's frame writer makes this difficult. + return nil, ConnectionError(ErrCodeProtocol) + } + if _, err := hdec.Write(frag); err != nil { return nil, ConnectionError(ErrCodeCompression) } diff --git a/http2/pipe.go b/http2/pipe.go index 684d984fd..3b9f06b96 100644 --- a/http2/pipe.go +++ b/http2/pipe.go @@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) { } } -var errClosedPipeWrite = errors.New("write on closed buffer") +var ( + errClosedPipeWrite = errors.New("write on closed buffer") + errUninitializedPipeWrite = errors.New("write on uninitialized buffer") +) // Write copies bytes from p into the buffer and wakes a reader. // It is an error to write more data than the buffer can hold. @@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) { if p.err != nil || p.breakErr != nil { return 0, errClosedPipeWrite } + // pipe.setBuffer is never invoked, leaving the buffer uninitialized. + // We shouldn't try to write to an uninitialized pipe, + // but returning an error is better than panicking. + if p.b == nil { + return 0, errUninitializedPipeWrite + } return p.b.Write(d) } diff --git a/http2/server.go b/http2/server.go index ae94c6408..ce2e8b40e 100644 --- a/http2/server.go +++ b/http2/server.go @@ -124,6 +124,7 @@ type Server struct { // IdleTimeout specifies how long until idle clients should be // closed with a GOAWAY frame. PING frames are not considered // activity for the purposes of IdleTimeout. + // If zero or negative, there is no timeout. IdleTimeout time.Duration // MaxUploadBufferPerConnection is the size of the initial flow @@ -434,7 +435,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { // passes the connection off to us with the deadline already set. // Write deadlines are set per stream in serverConn.newStream. // Disarm the net.Conn write deadline here. - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { sc.conn.SetWriteDeadline(time.Time{}) } @@ -924,7 +925,7 @@ func (sc *serverConn) serve() { sc.setConnState(http.StateActive) sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() } @@ -1637,7 +1638,7 @@ func (sc *serverConn) closeStream(st *stream, err error) { delete(sc.streams, st.id) if len(sc.streams) == 0 { sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer.Reset(sc.srv.IdleTimeout) } if h1ServerKeepAlivesDisabled(sc.hs) { @@ -2017,7 +2018,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // similar to how the http1 server works. Here it's // technically more like the http1 Server's ReadHeaderTimeout // (in Go 1.8), though. That's a more sane option anyway. - if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } @@ -2038,7 +2039,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) { // Disable any read deadline set by the net/http package // prior to the upgrade. - if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) } @@ -2116,7 +2117,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.flow.conn = &sc.flow // link to conn-level counter st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.srv.initialStreamRecvWindowSize()) - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } diff --git a/http2/server_test.go b/http2/server_test.go index 1fdd191ef..a931a06e5 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4578,13 +4578,16 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) { sc := &serverConn{ serveG: newGoroutineLock(), } - const count = 1000 - for i := 0; i < count; i++ { - h := fmt.Sprintf("%v-%v", base, i) + count := 0 + added := 0 + for added < 10*maxCachedCanonicalHeadersKeysSize { + h := fmt.Sprintf("%v-%v", base, count) c := sc.canonicalHeader(h) if len(h) != len(c) { t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c) } + count++ + added += len(h) } total := 0 for k, v := range sc.canonHeader { @@ -4783,3 +4786,89 @@ Frames: close(s) } } + +func TestServerContinuationFlood(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.Header) + }, func(ts *httptest.Server) { + ts.Config.MaxHeaderBytes = 4096 + }) + defer st.Close() + + st.writePreface() + st.writeInitialSettings() + st.writeSettingsAck() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + }) + for i := 0; i < 1000; i++ { + st.fr.WriteContinuation(1, false, st.encodeHeaderRaw( + fmt.Sprintf("x-%v", i), "1234567890", + )) + } + st.fr.WriteContinuation(1, true, st.encodeHeaderRaw( + "x-last-header", "1", + )) + + for { + f, err := st.readFrame() + if err != nil { + break + } + switch f.(type) { + case *HeadersFrame: + t.Fatalf("received HEADERS frame; want GOAWAY and a closed connection") + } + } + // We expect to have seen a GOAWAY before the connection closes, + // but the server will close the connection after one second + // whether or not it has finished sending the GOAWAY. On windows-amd64-race + // builders, this fairly consistently results in the connection closing without + // the GOAWAY being sent. + // + // Since the server's behavior is inherently racy here and the important thing + // is that the connection is closed, don't check for the GOAWAY having been sent. +} + +func TestServerContinuationAfterInvalidHeader(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.Header) + }) + defer st.Close() + + st.writePreface() + st.writeInitialSettings() + st.writeSettingsAck() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + }) + st.fr.WriteContinuation(1, false, st.encodeHeaderRaw( + "x-invalid-header", "\x00", + )) + st.fr.WriteContinuation(1, true, st.encodeHeaderRaw( + "x-valid-header", "1", + )) + + var sawGoAway bool + for { + f, err := st.readFrame() + if err != nil { + break + } + switch f.(type) { + case *GoAwayFrame: + sawGoAway = true + case *HeadersFrame: + t.Fatalf("received HEADERS frame; want GOAWAY") + } + } + if !sawGoAway { + t.Errorf("connection closed with no GOAWAY frame; want one") + } +} diff --git a/http2/testsync.go b/http2/testsync.go new file mode 100644 index 000000000..61075bd16 --- /dev/null +++ b/http2/testsync.go @@ -0,0 +1,331 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package http2 + +import ( + "context" + "sync" + "time" +) + +// testSyncHooks coordinates goroutines in tests. +// +// For example, a call to ClientConn.RoundTrip involves several goroutines, including: +// - the goroutine running RoundTrip; +// - the clientStream.doRequest goroutine, which writes the request; and +// - the clientStream.readLoop goroutine, which reads the response. +// +// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines +// are blocked waiting for some condition such as reading the Request.Body or waiting for +// flow control to become available. +// +// The testSyncHooks also manage timers and synthetic time in tests. +// This permits us to, for example, start a request and cause it to time out waiting for +// response headers without resorting to time.Sleep calls. +type testSyncHooks struct { + // active/inactive act as a mutex and condition variable. + // + // - neither chan contains a value: testSyncHooks is locked. + // - active contains a value: unlocked, and at least one goroutine is not blocked + // - inactive contains a value: unlocked, and all goroutines are blocked + active chan struct{} + inactive chan struct{} + + // goroutine counts + total int // total goroutines + condwait map[*sync.Cond]int // blocked in sync.Cond.Wait + blocked []*testBlockedGoroutine // otherwise blocked + + // fake time + now time.Time + timers []*fakeTimer + + // Transport testing: Report various events. + newclientconn func(*ClientConn) + newstream func(*clientStream) +} + +// testBlockedGoroutine is a blocked goroutine. +type testBlockedGoroutine struct { + f func() bool // blocked until f returns true + ch chan struct{} // closed when unblocked +} + +func newTestSyncHooks() *testSyncHooks { + h := &testSyncHooks{ + active: make(chan struct{}, 1), + inactive: make(chan struct{}, 1), + condwait: map[*sync.Cond]int{}, + } + h.inactive <- struct{}{} + h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + return h +} + +// lock acquires the testSyncHooks mutex. +func (h *testSyncHooks) lock() { + select { + case <-h.active: + case <-h.inactive: + } +} + +// waitInactive waits for all goroutines to become inactive. +func (h *testSyncHooks) waitInactive() { + for { + <-h.inactive + if !h.unlock() { + break + } + } +} + +// unlock releases the testSyncHooks mutex. +// It reports whether any goroutines are active. +func (h *testSyncHooks) unlock() (active bool) { + // Look for a blocked goroutine which can be unblocked. + blocked := h.blocked[:0] + unblocked := false + for _, b := range h.blocked { + if !unblocked && b.f() { + unblocked = true + close(b.ch) + } else { + blocked = append(blocked, b) + } + } + h.blocked = blocked + + // Count goroutines blocked on condition variables. + condwait := 0 + for _, count := range h.condwait { + condwait += count + } + + if h.total > condwait+len(blocked) { + h.active <- struct{}{} + return true + } else { + h.inactive <- struct{}{} + return false + } +} + +// goRun starts a new goroutine. +func (h *testSyncHooks) goRun(f func()) { + h.lock() + h.total++ + h.unlock() + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + f() + }() +} + +// blockUntil indicates that a goroutine is blocked waiting for some condition to become true. +// It waits until f returns true before proceeding. +// +// Example usage: +// +// h.blockUntil(func() bool { +// // Is the context done yet? +// select { +// case <-ctx.Done(): +// default: +// return false +// } +// return true +// }) +// // Wait for the context to become done. +// <-ctx.Done() +// +// The function f passed to blockUntil must be non-blocking and idempotent. +func (h *testSyncHooks) blockUntil(f func() bool) { + if f() { + return + } + ch := make(chan struct{}) + h.lock() + h.blocked = append(h.blocked, &testBlockedGoroutine{ + f: f, + ch: ch, + }) + h.unlock() + <-ch +} + +// broadcast is sync.Cond.Broadcast. +func (h *testSyncHooks) condBroadcast(cond *sync.Cond) { + h.lock() + delete(h.condwait, cond) + h.unlock() + cond.Broadcast() +} + +// broadcast is sync.Cond.Wait. +func (h *testSyncHooks) condWait(cond *sync.Cond) { + h.lock() + h.condwait[cond]++ + h.unlock() +} + +// newTimer creates a new fake timer. +func (h *testSyncHooks) newTimer(d time.Duration) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + hooks: h, + when: h.now.Add(d), + c: make(chan time.Time), + } + h.timers = append(h.timers, t) + return t +} + +// afterFunc creates a new fake AfterFunc timer. +func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + hooks: h, + when: h.now.Add(d), + f: f, + } + h.timers = append(h.timers, t) + return t +} + +func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + t := h.afterFunc(d, cancel) + return ctx, func() { + t.Stop() + cancel() + } +} + +func (h *testSyncHooks) timeUntilEvent() time.Duration { + h.lock() + defer h.unlock() + var next time.Time + for _, t := range h.timers { + if next.IsZero() || t.when.Before(next) { + next = t.when + } + } + if d := next.Sub(h.now); d > 0 { + return d + } + return 0 +} + +// advance advances time and causes synthetic timers to fire. +func (h *testSyncHooks) advance(d time.Duration) { + h.lock() + defer h.unlock() + h.now = h.now.Add(d) + timers := h.timers[:0] + for _, t := range h.timers { + t := t // remove after go.mod depends on go1.22 + t.mu.Lock() + switch { + case t.when.After(h.now): + timers = append(timers, t) + case t.when.IsZero(): + // stopped timer + default: + t.when = time.Time{} + if t.c != nil { + close(t.c) + } + if t.f != nil { + h.total++ + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + t.f() + }() + } + } + t.mu.Unlock() + } + h.timers = timers +} + +// A timer wraps a time.Timer, or a synthetic equivalent in tests. +// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires. +type timer interface { + C() <-chan time.Time + Stop() bool + Reset(d time.Duration) bool +} + +// timeTimer implements timer using real time. +type timeTimer struct { + t *time.Timer + c chan time.Time +} + +// newTimeTimer creates a new timer using real time. +func newTimeTimer(d time.Duration) timer { + ch := make(chan time.Time) + t := time.AfterFunc(d, func() { + close(ch) + }) + return &timeTimer{t, ch} +} + +// newTimeAfterFunc creates an AfterFunc timer using real time. +func newTimeAfterFunc(d time.Duration, f func()) timer { + return &timeTimer{ + t: time.AfterFunc(d, f), + } +} + +func (t timeTimer) C() <-chan time.Time { return t.c } +func (t timeTimer) Stop() bool { return t.t.Stop() } +func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) } + +// fakeTimer implements timer using fake time. +type fakeTimer struct { + hooks *testSyncHooks + + mu sync.Mutex + when time.Time // when the timer will fire + c chan time.Time // closed when the timer fires; mutually exclusive with f + f func() // called when the timer fires; mutually exclusive with c +} + +func (t *fakeTimer) C() <-chan time.Time { return t.c } + +func (t *fakeTimer) Stop() bool { + t.mu.Lock() + defer t.mu.Unlock() + stopped := t.when.IsZero() + t.when = time.Time{} + return stopped +} + +func (t *fakeTimer) Reset(d time.Duration) bool { + if t.c != nil || t.f == nil { + panic("fakeTimer only supports Reset on AfterFunc timers") + } + t.mu.Lock() + defer t.mu.Unlock() + t.hooks.lock() + defer t.hooks.unlock() + active := !t.when.IsZero() + t.when = t.hooks.now.Add(d) + if !active { + t.hooks.timers = append(t.hooks.timers, t) + } + return active +} diff --git a/http2/transport.go b/http2/transport.go index c2a5b44b3..ce375c8c7 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -147,6 +147,12 @@ type Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + // ReadIdleTimeout is the timeout after which a health check using ping // frame will be carried out if no frame is received on the connection. // Note that a ping response will is considered a received frame, so if @@ -178,6 +184,8 @@ type Transport struct { connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool + + syncHooks *testSyncHooks } func (t *Transport) maxHeaderListSize() uint32 { @@ -302,7 +310,7 @@ type ClientConn struct { readerErr error // set before readerDone is closed idleTimeout time.Duration // or 0 for never - idleTimer *time.Timer + idleTimer timer mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes @@ -344,6 +352,60 @@ type ClientConn struct { werr error // first write error that has occurred hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder + + syncHooks *testSyncHooks // can be nil +} + +// Hook points used for testing. +// Outside of tests, cc.syncHooks is nil and these all have minimal implementations. +// Inside tests, see the testSyncHooks function docs. + +// goRun starts a new goroutine. +func (cc *ClientConn) goRun(f func()) { + if cc.syncHooks != nil { + cc.syncHooks.goRun(f) + return + } + go f() +} + +// condBroadcast is cc.cond.Broadcast. +func (cc *ClientConn) condBroadcast() { + if cc.syncHooks != nil { + cc.syncHooks.condBroadcast(cc.cond) + } + cc.cond.Broadcast() +} + +// condWait is cc.cond.Wait. +func (cc *ClientConn) condWait() { + if cc.syncHooks != nil { + cc.syncHooks.condWait(cc.cond) + } + cc.cond.Wait() +} + +// newTimer creates a new time.Timer, or a synthetic timer in tests. +func (cc *ClientConn) newTimer(d time.Duration) timer { + if cc.syncHooks != nil { + return cc.syncHooks.newTimer(d) + } + return newTimeTimer(d) +} + +// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. +func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer { + if cc.syncHooks != nil { + return cc.syncHooks.afterFunc(d, f) + } + return newTimeAfterFunc(d, f) +} + +func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + if cc.syncHooks != nil { + return cc.syncHooks.contextWithTimeout(ctx, d) + } + return context.WithTimeout(ctx, d) } // clientStream is the state for a single HTTP/2 stream. One of these @@ -425,7 +487,7 @@ func (cs *clientStream) abortStreamLocked(err error) { // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { // Wake up writeRequestBody if it is waiting on flow control. - cs.cc.cond.Broadcast() + cs.cc.condBroadcast() } } @@ -435,7 +497,7 @@ func (cs *clientStream) abortRequestBodyWrite() { defer cc.mu.Unlock() if cs.reqBody != nil && cs.reqBodyClosed == nil { cs.closeReqBodyLocked() - cc.cond.Broadcast() + cc.condBroadcast() } } @@ -445,10 +507,10 @@ func (cs *clientStream) closeReqBodyLocked() { } cs.reqBodyClosed = make(chan struct{}) reqBodyClosed := cs.reqBodyClosed - go func() { + cs.cc.goRun(func() { cs.reqBody.Close() close(reqBodyClosed) - }() + }) } type stickyErrWriter struct { @@ -537,15 +599,6 @@ func authorityAddr(scheme string, authority string) (addr string) { return net.JoinHostPort(host, port) } -var retryBackoffHook func(time.Duration) *time.Timer - -func backoffNewTimer(d time.Duration) *time.Timer { - if retryBackoffHook != nil { - return retryBackoffHook(d) - } - return time.NewTimer(d) -} - // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { @@ -573,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - timer := backoffNewTimer(d) + var tm timer + if t.syncHooks != nil { + tm = t.syncHooks.newTimer(d) + t.syncHooks.blockUntil(func() bool { + select { + case <-tm.C(): + case <-req.Context().Done(): + default: + return false + } + return true + }) + } else { + tm = newTimeTimer(d) + } select { - case <-timer.C: + case <-tm.C(): t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): - timer.Stop() + tm.Stop() err = req.Context().Err() } } @@ -658,6 +725,9 @@ func canRetryError(err error) bool { } func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { + if t.syncHooks != nil { + return t.newClientConn(nil, singleUse, t.syncHooks) + } host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -666,7 +736,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b if err != nil { return nil, err } - return t.newClientConn(tconn, singleUse) + return t.newClientConn(tconn, singleUse, nil) } func (t *Transport) newTLSConfig(host string) *tls.Config { @@ -732,10 +802,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 { } func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives()) + return t.newClientConn(c, t.disableKeepAlives(), nil) } -func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { +func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) { cc := &ClientConn{ t: t, tconn: c, @@ -750,10 +820,15 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), + syncHooks: hooks, + } + if hooks != nil { + hooks.newclientconn(cc) + c = cc.tconn } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d - cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) + cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout) } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -818,7 +893,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return nil, cc.werr } - go cc.readLoop() + cc.goRun(cc.readLoop) return cc, nil } @@ -826,7 +901,7 @@ func (cc *ClientConn) healthCheck() { pingTimeout := cc.t.pingTimeout() // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. - ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout) defer cancel() cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) @@ -1056,7 +1131,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { // Wait for all in-flight streams to complete or connection to close done := make(chan struct{}) cancelled := false // guarded by cc.mu - go func() { + cc.goRun(func() { cc.mu.Lock() defer cc.mu.Unlock() for { @@ -1068,9 +1143,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { if cancelled { break } - cc.cond.Wait() + cc.condWait() } - }() + }) shutdownEnterWaitStateHook() select { case <-done: @@ -1080,7 +1155,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { cc.mu.Lock() // Free the goroutine above cancelled = true - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() return ctx.Err() } @@ -1118,7 +1193,7 @@ func (cc *ClientConn) closeForError(err error) { for _, cs := range cc.streams { cs.abortStreamLocked(err) } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() cc.closeConn() } @@ -1215,6 +1290,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + return cc.roundTrip(req, nil) +} + +func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) { ctx := req.Context() cs := &clientStream{ cc: cc, @@ -1229,9 +1308,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { respHeaderRecv: make(chan struct{}), donec: make(chan struct{}), } - go cs.doRequest(req) + cc.goRun(func() { + cs.doRequest(req) + }) waitDone := func() error { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.donec: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.donec: return nil @@ -1292,7 +1385,24 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return err } + if streamf != nil { + streamf(cs) + } + for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.respHeaderRecv: return handleResponseHeaders() @@ -1348,6 +1458,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { if cc.reqHeaderMu == nil { panic("RoundTrip on uninitialized ClientConn") // for tests } + var newStreamHook func(*clientStream) + if cc.syncHooks != nil { + newStreamHook = cc.syncHooks.newstream + cc.syncHooks.blockUntil(func() bool { + select { + case cc.reqHeaderMu <- struct{}{}: + <-cc.reqHeaderMu + case <-cs.reqCancel: + case <-ctx.Done(): + default: + return false + } + return true + }) + } select { case cc.reqHeaderMu <- struct{}{}: case <-cs.reqCancel: @@ -1372,6 +1497,10 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { } cc.mu.Unlock() + if newStreamHook != nil { + newStreamHook(cs) + } + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && @@ -1452,15 +1581,30 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) + timer := cc.newTimer(d) defer timer.Stop() - respHeaderTimer = timer.C + respHeaderTimer = timer.C() respHeaderRecv = cs.respHeaderRecv } // Wait until the peer half-closes its end of the stream, // or until the request is aborted (via context, error, or otherwise), // whichever comes first. for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.peerClosed: + case <-respHeaderTimer: + case <-respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.peerClosed: return nil @@ -1609,7 +1753,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { return nil } cc.pendingRequests++ - cc.cond.Wait() + cc.condWait() cc.pendingRequests-- select { case <-cs.abort: @@ -1871,8 +2015,24 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) cs.flow.take(take) return take, nil } - cc.cond.Wait() + cc.condWait() + } +} + +func validateHeaders(hdrs http.Header) string { + for k, vv := range hdrs { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Sprintf("name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // Don't include the value in the error, + // because it may be sensitive. + return fmt.Sprintf("value for header %q", k) + } + } } + return "" } var errNilRequestURL = errors.New("http2: Request.URI is nil") @@ -1912,19 +2072,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } } - // Check for any invalid headers and return an error before we + // Check for any invalid headers+trailers and return an error before we // potentially pollute our hpack state. (We want to be able to // continue to reuse the hpack encoder for future requests) - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("invalid HTTP header name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - // Don't include the value in the error, because it may be sensitive. - return nil, fmt.Errorf("invalid HTTP header value for header %q", k) - } - } + if err := validateHeaders(req.Header); err != "" { + return nil, fmt.Errorf("invalid HTTP header %s", err) + } + if err := validateHeaders(req.Trailer); err != "" { + return nil, fmt.Errorf("invalid HTTP trailer %s", err) } enumerateHeaders := func(f func(name, value string)) { @@ -2143,7 +2298,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) { } // Wake up writeRequestBody via clientStream.awaitFlowControl and // wake up RoundTrip if there is a pending request. - cc.cond.Broadcast() + cc.condBroadcast() closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { @@ -2231,7 +2386,7 @@ func (rl *clientConnReadLoop) cleanup() { cs.abortStreamLocked(err) } } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() } @@ -2266,10 +2421,9 @@ func (rl *clientConnReadLoop) run() error { cc := rl.cc gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout - var t *time.Timer + var t timer if readIdleTimeout != 0 { - t = time.AfterFunc(readIdleTimeout, cc.healthCheck) - defer t.Stop() + t = cc.afterFunc(readIdleTimeout, cc.healthCheck) } for { f, err := cc.fr.ReadFrame() @@ -2684,7 +2838,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { }) return nil } - if !cs.firstByte { + if !cs.pastHeaders { cc.logf("protocol error: received DATA before a HEADERS frame") rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, @@ -2867,7 +3021,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { for _, cs := range cc.streams { cs.flow.add(delta) } - cc.cond.Broadcast() + cc.condBroadcast() cc.initialWindowSize = s.Val case SettingHeaderTableSize: @@ -2922,7 +3076,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { return ConnectionError(ErrCodeFlowControl) } - cc.cond.Broadcast() + cc.condBroadcast() return nil } @@ -2964,24 +3118,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error { } cc.mu.Unlock() } - errc := make(chan error, 1) - go func() { + var pingError error + errc := make(chan struct{}) + cc.goRun(func() { cc.wmu.Lock() defer cc.wmu.Unlock() - if err := cc.fr.WritePing(false, p); err != nil { - errc <- err + if pingError = cc.fr.WritePing(false, p); pingError != nil { + close(errc) return } - if err := cc.bw.Flush(); err != nil { - errc <- err + if pingError = cc.bw.Flush(); pingError != nil { + close(errc) return } - }() + }) + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-c: + case <-errc: + case <-ctx.Done(): + case <-cc.readerDone: + default: + return false + } + return true + }) + } select { case <-c: return nil - case err := <-errc: - return err + case <-errc: + return pingError case <-ctx.Done(): return ctx.Err() case <-cc.readerDone: @@ -3150,9 +3318,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err } func (t *Transport) idleConnTimeout() time.Duration { + // to keep things backwards compatible, we use non-zero values of + // IdleConnTimeout, followed by using the IdleConnTimeout on the underlying + // http1 transport, followed by 0 + if t.IdleConnTimeout != 0 { + return t.IdleConnTimeout + } + if t.t1 != nil { return t.t1.IdleConnTimeout } + return 0 } diff --git a/http2/transport_test.go b/http2/transport_test.go index a81131f29..11ff67b4c 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -95,6 +95,88 @@ func startH2cServer(t *testing.T) net.Listener { return l } +func TestIdleConnTimeout(t *testing.T) { + for _, test := range []struct { + name string + idleConnTimeout time.Duration + wait time.Duration + baseTransport *http.Transport + wantNewConn bool + }{{ + name: "NoExpiry", + idleConnTimeout: 2 * time.Second, + wait: 1 * time.Second, + baseTransport: nil, + wantNewConn: false, + }, { + name: "H2TransportTimeoutExpires", + idleConnTimeout: 1 * time.Second, + wait: 2 * time.Second, + baseTransport: nil, + wantNewConn: true, + }, { + name: "H1TransportTimeoutExpires", + idleConnTimeout: 0 * time.Second, + wait: 1 * time.Second, + baseTransport: &http.Transport{ + IdleConnTimeout: 2 * time.Second, + }, + wantNewConn: false, + }} { + t.Run(test.name, func(t *testing.T) { + tt := newTestTransport(t, func(tr *Transport) { + tr.IdleConnTimeout = test.idleConnTimeout + }) + var tc *testClientConn + for i := 0; i < 3; i++ { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // This request happens on a new conn if it's the first request + // (and there is no cached conn), or if the test timeout is long + // enough that old conns are being closed. + wantConn := i == 0 || test.wantNewConn + if has := tt.hasConn(); has != wantConn { + t.Fatalf("request %v: hasConn=%v, want %v", i, has, wantConn) + } + if wantConn { + tc = tt.getConn() + // Read client's SETTINGS and first WINDOW_UPDATE, + // send our SETTINGS. + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.writeSettings() + } + if tt.hasConn() { + t.Fatalf("request %v: Transport has more than one conn", i) + } + + // Respond to the client's request. + hf := testClientConnReadFrame[*MetaHeadersFrame](tc) + tc.writeHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) + + // If this was a newly-accepted conn, read the SETTINGS ACK. + if wantConn { + tc.wantFrameType(FrameSettings) // ACK to our settings + } + + tt.advance(test.wait) + if got, want := tc.netConnClosed, test.wantNewConn; got != want { + t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want) + } + } + }) + } +} + func TestTransportH2c(t *testing.T) { l := startH2cServer(t) defer l.Close() @@ -740,53 +822,6 @@ func (fw flushWriter) Write(p []byte) (n int, err error) { return } -type clientTester struct { - t *testing.T - tr *Transport - sc, cc net.Conn // server and client conn - fr *Framer // server's framer - settings *SettingsFrame - client func() error - server func() error -} - -func newClientTester(t *testing.T) *clientTester { - var dialOnce struct { - sync.Mutex - dialed bool - } - ct := &clientTester{ - t: t, - } - ct.tr = &Transport{ - TLSClientConfig: tlsConfigInsecure, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - dialOnce.Lock() - defer dialOnce.Unlock() - if dialOnce.dialed { - return nil, errors.New("only one dial allowed in test mode") - } - dialOnce.dialed = true - return ct.cc, nil - }, - } - - ln := newLocalListener(t) - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - sc, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - ln.Close() - ct.cc = cc - ct.sc = sc - ct.fr = NewFramer(sc, sc) - return ct -} - func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp4", "127.0.0.1:0") if err == nil { @@ -799,284 +834,70 @@ func newLocalListener(t *testing.T) net.Listener { return ln } -func (ct *clientTester) greet(settings ...Setting) { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - ct.t.Fatalf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - ct.t.Fatalf("Reading client settings frame: %v", err) - } - var ok bool - if ct.settings, ok = f.(*SettingsFrame); !ok { - ct.t.Fatalf("Wanted client settings frame; got %v", f) - } - if err := ct.fr.WriteSettings(settings...); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } -} - -func (ct *clientTester) readNonSettingsFrame() (Frame, error) { - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return nil, err - } - if _, ok := f.(*SettingsFrame); ok { - continue - } - return f, nil - } -} - -// writeReadPing sends a PING and immediately reads the PING ACK. -// It will fail if any other unread data was pending on the connection, -// aside from SETTINGS frames. -func (ct *clientTester) writeReadPing() error { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - if err := ct.fr.WritePing(false, data); err != nil { - return fmt.Errorf("Error writing PING: %v", err) - } - f, err := ct.readNonSettingsFrame() - if err != nil { - return err - } - p, ok := f.(*PingFrame) - if !ok { - return fmt.Errorf("got a %v, want a PING ACK", f) - } - if p.Flags&FlagPingAck == 0 { - return fmt.Errorf("got a PING, want a PING ACK") - } - if p.Data != data { - return fmt.Errorf("got PING data = %x, want %x", p.Data, data) - } - return nil -} - -func (ct *clientTester) inflowWindow(streamID uint32) int32 { - pool := ct.tr.connPoolOrDef.(*clientConnPool) - pool.mu.Lock() - defer pool.mu.Unlock() - if n := len(pool.keys); n != 1 { - ct.t.Errorf("clientConnPool contains %v keys, expected 1", n) - return -1 - } - for cc := range pool.keys { - cc.mu.Lock() - defer cc.mu.Unlock() - if streamID == 0 { - return cc.inflow.avail + cc.inflow.unsent - } - cs := cc.streams[streamID] - if cs == nil { - ct.t.Errorf("no stream with id %v", streamID) - return -1 - } - return cs.inflow.avail + cs.inflow.unsent - } - return -1 -} - -func (ct *clientTester) cleanup() { - ct.tr.CloseIdleConnections() - - // close both connections, ignore the error if its already closed - ct.sc.Close() - ct.cc.Close() -} - -func (ct *clientTester) run() { - var errOnce sync.Once - var wg sync.WaitGroup - - run := func(which string, fn func() error) { - defer wg.Done() - if err := fn(); err != nil { - errOnce.Do(func() { - ct.t.Errorf("%s: %v", which, err) - ct.cleanup() - }) - } - } +func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } +func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } - wg.Add(2) - go run("client", ct.client) - go run("server", ct.server) - wg.Wait() +func testTransportReqBodyAfterResponse(t *testing.T, status int) { + const bodySize = 10 << 20 - errOnce.Do(ct.cleanup) // clean up if no error -} + tc := newTestClientConn(t) + tc.greet() + + body := tc.newRequestBody() + body.writeBytes(bodySize / 2) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: false, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"PUT"}, + ":path": []string{"/"}, + }, + }) -func (ct *clientTester) readFrame() (Frame, error) { - return ct.fr.ReadFrame() -} + // Provide enough congestion window for the full request body. + tc.writeWindowUpdate(0, bodySize) + tc.writeWindowUpdate(rt.streamID(), bodySize) -func (ct *clientTester) firstHeaders() (*HeadersFrame, error) { - for { - f, err := ct.readFrame() - if err != nil { - return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - hf, ok := f.(*HeadersFrame) - if !ok { - return nil, fmt.Errorf("Got %T; want HeadersFrame", f) - } - return hf, nil - } -} + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: false, + size: bodySize / 2, + }) -type countingReader struct { - n *int64 -} + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", strconv.Itoa(status), + ), + }) -func (r countingReader) Read(p []byte) (n int, err error) { - for i := range p { - p[i] = byte(i) + res := rt.response() + if res.StatusCode != status { + t.Fatalf("status code = %v; want %v", res.StatusCode, status) } - atomic.AddInt64(r.n, int64(len(p))) - return len(p), err -} -func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } -func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } + body.writeBytes(bodySize / 2) + body.closeWithError(io.EOF) -func testTransportReqBodyAfterResponse(t *testing.T, status int) { - const bodySize = 10 << 20 - clientDone := make(chan struct{}) - ct := newClientTester(t) - recvLen := make(chan int64, 1) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - - body := &pipe{b: new(bytes.Buffer)} - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - if res.StatusCode != status { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) - } - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) - body.CloseWithError(io.EOF) - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("unexpected body: %q", slurp) - } - res.Body.Close() - if status == 200 { - if got := <-recvLen; got != bodySize { - return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) - } - } else { - if got := <-recvLen; got == 0 || got >= bodySize { - return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) - } - } - return nil + if status == 200 { + // After a 200 response, client sends the remaining request body. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: bodySize / 2, + }) + } else { + // After a 403 response, client gives up and resets the stream. + tc.wantFrameType(FrameRSTStream) } - ct.server = func() error { - ct.greet() - defer close(recvLen) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var dataRecv int64 - var closed bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - //println(fmt.Sprintf("server got frame: %v", f)) - ended := false - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - if f.StreamEnded() { - return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) - } - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if dataRecv == 0 { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - dataRecv += int64(dataLen) - - if !closed && ((status != 200 && dataRecv > 0) || - (status == 200 && f.StreamEnded())) { - closed = true - if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { - return err - } - } - if f.StreamEnded() { - ended = true - } - case *RSTStreamFrame: - if status == 200 { - return fmt.Errorf("Unexpected client frame %v", f) - } - ended = true - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - if ended { - select { - case recvLen <- dataRecv: - default: - } - } - } - } - ct.run() + rt.wantBody(nil) } // See golang.org/issue/13444 @@ -1257,121 +1078,74 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy panic("invalid combination") } - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) - if expect100Continue != noHeader { - req.Header.Set("Expect", "100-continue") - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - wantBody := resBody - if !withData { - wantBody = "" - } - if string(slurp) != wantBody { - return fmt.Errorf("body = %q; want %q", slurp, wantBody) - } - if trailers == noHeader { - if len(res.Trailer) > 0 { - t.Errorf("Trailer = %v; want none", res.Trailer) - } - } else { - want := http.Header{"Some-Trailer": {"some-value"}} - if !reflect.DeepEqual(res.Trailer, want) { - t.Errorf("Trailer = %v; want %v", res.Trailer, want) - } - } - return nil + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) + if expect100Continue != noHeader { + req.Header.Set("Expect", "100-continue") } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + rt := tc.roundTrip(req) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - endStream := false - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *DataFrame: - if !f.StreamEnded() { - // No need to send flow control tokens. The test request body is tiny. - continue - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"}) - enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"}) - if trailers != noHeader { - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"}) - } - endStream = withData == false && trailers == noHeader - send(resHeader) - } - if withData { - endStream = trailers == noHeader - ct.fr.WriteData(f.StreamID, endStream, []byte(resBody)) - } - if trailers != noHeader { - endStream = true - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"}) - send(trailers) - } - if endStream { - return nil - } - case *HeadersFrame: - if expect100Continue != noHeader { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) - send(expect100Continue) - } - } - } + tc.wantFrameType(FrameHeaders) + + // Possibly 100-continue, or skip when noHeader. + tc.writeHeadersMode(expect100Continue, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "100", + ), + }) + + // Client sends request body. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: len(reqBody), + }) + + hdr := []string{ + ":status", "200", + "x-foo", "blah", + "x-bar", "more", + } + if trailers != noHeader { + hdr = append(hdr, "trailer", "some-trailer") + } + tc.writeHeadersMode(resHeader, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: withData == false && trailers == noHeader, + BlockFragment: tc.makeHeaderBlockFragment(hdr...), + }) + if withData { + endStream := trailers == noHeader + tc.writeData(rt.streamID(), endStream, []byte(resBody)) + } + tc.writeHeadersMode(trailers, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + "some-trailer", "some-value", + ), + }) + + rt.wantStatus(200) + if !withData { + rt.wantBody(nil) + } else { + rt.wantBody([]byte(resBody)) + } + if trailers == noHeader { + rt.wantTrailers(nil) + } else { + rt.wantTrailers(http.Header{ + "Some-Trailer": {"some-value"}, + }) } - ct.run() } // Issue 26189, Issue 17739: ignore unknown 1xx responses @@ -1383,130 +1157,76 @@ func TestTransportUnknown1xx(t *testing.T) { return nil } - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 204 { - return fmt.Errorf("status code = %v; want 204", res.StatusCode) - } - want := `code=110 header=map[Foo-Bar:[110]] + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + for i := 110; i <= 114; i++ { + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", fmt.Sprint(i), + "foo-bar", fmt.Sprint(i), + ), + }) + } + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) + + res := rt.response() + if res.StatusCode != 204 { + t.Fatalf("status code = %v; want 204", res.StatusCode) + } + want := `code=110 header=map[Foo-Bar:[110]] code=111 header=map[Foo-Bar:[111]] code=112 header=map[Foo-Bar:[112]] code=113 header=map[Foo-Bar:[113]] code=114 header=map[Foo-Bar:[114]] ` - if got := buf.String(); got != want { - t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - for i := 110; i <= 114; i++ { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)}) - enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if got := buf.String(); got != want { + t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) } - ct.run() - } func TestTransportReceiveUndeclaredTrailer(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - if _, ok := res.Trailer["Some-Trailer"]; !ok { - return fmt.Errorf("expected Some-Trailer") - } - return nil - } - ct.server = func() error { - ct.greet() - - var n int - var hf *HeadersFrame - for hf == nil && n < 10 { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - hf, _ = f.(*HeadersFrame) - n++ - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - // send headers without Trailer header - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + "some-trailer", "I'm an undeclared Trailer!", + ), + }) - // send trailers - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - ct.run() + rt.wantStatus(200) + rt.wantBody(nil) + rt.wantTrailers(http.Header{ + "Some-Trailer": []string{"I'm an undeclared Trailer!"}, + }) } func TestTransportInvalidTrailer_Pseudo1(t *testing.T) { @@ -1516,10 +1236,10 @@ func TestTransportInvalidTrailer_Pseudo2(t *testing.T) { testTransportInvalidTrailer_Pseudo(t, splitHeader) } func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - }) + testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), + ":colon", "foo", + "foo", "bar", + ) } func TestTransportInvalidTrailer_Capital1(t *testing.T) { @@ -1529,102 +1249,54 @@ func TestTransportInvalidTrailer_Capital2(t *testing.T) { testTransportInvalidTrailer_Capital(t, splitHeader) } func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"}) - }) + testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), + "foo", "bar", + "Capital", "bad", + ) } func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"}) - }) + testInvalidTrailer(t, oneHeader, headerFieldNameError(""), + "", "bad", + ) } func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) - }) + testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), + "x", "has\nnewline", + ) } -func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - se, ok := err.(StreamError) - if !ok || se.Cause != wantErr { - return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) +func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) { + tc := newTestClientConn(t) + tc.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var endStream bool - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"}) - endStream = false - send(oneHeader) - } - // Trailers: - { - endStream = true - buf.Reset() - writeTrailer(enc) - send(trailers) - } - return nil - } - } + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "trailer", "declared", + ), + }) + tc.writeHeadersMode(mode, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment(trailers...), + }) + + rt.wantStatus(200) + body, err := rt.readBody() + se, ok := err.(StreamError) + if !ok || se.Cause != wantErr { + t.Fatalf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", body, err, wantErr, wantErr) + } + if len(body) > 0 { + t.Fatalf("body = %q; want nothing", body) } - ct.run() } // headerListSize returns the HTTP2 header list size of h. @@ -1900,115 +1572,80 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { } func TestTransportChecksResponseHeaderListSize(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if e, ok := err.(StreamError); ok { - err = e.Cause - } - if err != errResponseHeaderListSize { - size := int64(0) - if res != nil { - res.Body.Close() - for k, vv := range res.Header { - for _, v := range vv { - size += int64(len(k)) + int64(len(v)) + 32 - } - } - } - return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + hdr := []string{":status", "200"} + large := strings.Repeat("a", 1<<10) + for i := 0; i < 5042; i++ { + hdr = append(hdr, large, large) + } + hbf := tc.makeHeaderBlockFragment(hdr...) + // Note: this number might change if our hpack implementation changes. + // That's fine. This is just a sanity check that our response can fit in a single + // header block fragment frame. + if size, want := len(hbf), 6329; size != want { + t.Fatalf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + } + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: hbf, + }) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - large := strings.Repeat("a", 1<<10) - for i := 0; i < 5042; i++ { - enc.WriteField(hpack.HeaderField{Name: large, Value: large}) - } - if size, want := buf.Len(), 6329; size != want { - // Note: this number might change if - // our hpack implementation - // changes. That's fine. This is - // just a sanity check that our - // response can fit in a single - // header block fragment frame. - return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + res, err := rt.result() + if e, ok := err.(StreamError); ok { + err = e.Cause + } + if err != errResponseHeaderListSize { + size := int64(0) + if res != nil { + res.Body.Close() + for k, vv := range res.Header { + for _, v := range vv { + size += int64(len(k)) + int64(len(v)) + 32 } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil } } + t.Fatalf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) } - ct.run() } func TestTransportCookieHeaderSplit(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - req.Header.Add("Cookie", "a=b;c=d; e=f;") - req.Header.Add("Cookie", "e=f;g=h; ") - req.Header.Add("Cookie", "i=j") - _, err := ct.tr.RoundTrip(req) - return err - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - dec := hpack.NewDecoder(initialHeaderTableSize, nil) - hfs, err := dec.DecodeFull(f.HeaderBlockFragment()) - if err != nil { - return err - } - got := []string{} - want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"} - for _, hf := range hfs { - if hf.Name == "cookie" { - got = append(got, hf.Value) - } - } - if !reflect.DeepEqual(got, want) { - t.Errorf("Cookies = %#v, want %#v", got, want) - } + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + req.Header.Add("Cookie", "a=b;c=d; e=f;") + req.Header.Add("Cookie", "e=f;g=h; ") + req.Header.Add("Cookie", "i=j") + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, + header: http.Header{ + "cookie": []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}, + }, + }) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if err := rt.err(); err != nil { + t.Fatalf("RoundTrip = %v, want success", err) } - ct.run() } // Test that the Transport returns a typed error from Response.Body.Read calls @@ -2224,55 +1861,49 @@ func TestTransportResponseHeaderTimeout_Body(t *testing.T) { } func testTransportResponseHeaderTimeout(t *testing.T, body bool) { - ct := newClientTester(t) - ct.tr.t1 = &http.Transport{ - ResponseHeaderTimeout: 5 * time.Millisecond, - } - ct.client = func() error { - c := &http.Client{Transport: ct.tr} - var err error - var n int64 - const bodySize = 4 << 20 - if body { - _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize)) - } else { - _, err = c.Get("https://dummy.tld/") - } - if !isTimeout(err) { - t.Errorf("client expected timeout error; got %#v", err) - } - if body && n != bodySize { - t.Errorf("only read %d bytes of body; want %d", n, bodySize) + const bodySize = 4 << 20 + tc := newTestClientConn(t, func(tr *Transport) { + tr.t1 = &http.Transport{ + ResponseHeaderTimeout: 5 * time.Millisecond, } - return nil + }) + tc.greet() + + var req *http.Request + var reqBody *testRequestBody + if body { + reqBody = tc.newRequestBody() + reqBody.writeBytes(bodySize) + reqBody.closeWithError(io.EOF) + req, _ = http.NewRequest("POST", "https://dummy.tld/", reqBody) + req.Header.Set("Content-Type", "text/foo") + } else { + req, _ = http.NewRequest("GET", "https://dummy.tld/", nil) } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - switch f := f.(type) { - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - case *RSTStreamFrame: - if f.StreamID == 1 && f.ErrCode == ErrCodeCancel { - return nil - } - } - } + + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + tc.writeWindowUpdate(0, bodySize) + tc.writeWindowUpdate(rt.streamID(), bodySize) + + if body { + tc.wantData(wantData{ + endStream: true, + size: bodySize, + }) + } + + tc.advance(4 * time.Millisecond) + if rt.done() { + t.Fatalf("RoundTrip is done after 4ms; want still waiting") + } + tc.advance(1 * time.Millisecond) + + if err := rt.err(); !isTimeout(err) { + t.Fatalf("RoundTrip error: %v; want timeout error", err) } - ct.run() } func TestTransportDisableCompression(t *testing.T) { @@ -2484,7 +2115,8 @@ func TestTransportRejectsContentLengthWithSign(t *testing.T) { } // golang.org/issue/14048 -func TestTransportFailsOnInvalidHeaders(t *testing.T) { +// golang.org/issue/64766 +func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { var got []string for k := range r.Header { @@ -2497,6 +2129,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { tests := [...]struct { h http.Header + t http.Header wantErr string }{ 0: { @@ -2515,6 +2148,14 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { h: http.Header{"foo": {"foo\x01bar"}}, wantErr: `invalid HTTP header value for header "foo"`, }, + 4: { + t: http.Header{"foo": {"foo\x01bar"}}, + wantErr: `invalid HTTP trailer value for header "foo"`, + }, + 5: { + t: http.Header{"x-\r\nda": {"foo\x01bar"}}, + wantErr: `invalid HTTP trailer name "x-\r\nda"`, + }, } tr := &Transport{TLSClientConfig: tlsConfigInsecure} @@ -2523,6 +2164,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { for i, tt := range tests { req, _ := http.NewRequest("GET", st.ts.URL, nil) req.Header = tt.h + req.Trailer = tt.t res, err := tr.RoundTrip(req) var bad bool if tt.wantErr == "" { @@ -2658,115 +2300,61 @@ func TestTransportNewTLSConfig(t *testing.T) { // without END_STREAM, followed by a 0-length DATA frame with // END_STREAM. Make sure we don't get confused by that. (We did.) func TestTransportReadHeadResponse(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != 123 { - return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, // as the GFE does - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, nil) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, // as the GFE does + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "123", + ), + }) + tc.writeData(rt.streamID(), true, nil) - <-clientDone - return nil - } + res := rt.response() + if res.ContentLength != 123 { + t.Fatalf("Content-Length = %d; want 123", res.ContentLength) } - ct.run() + rt.wantBody(nil) } func TestTransportReadHeadResponseWithBody(t *testing.T) { - // This test use not valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) + // This test uses an invalid response format. + // Discard logger output to not spam tests output. + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) response := "redirecting to /elsewhere" - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != int64(len(response)) { - return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response)) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, []byte(response)) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", strconv.Itoa(len(response)), + ), + }) + tc.writeData(rt.streamID(), true, []byte(response)) - <-clientDone - return nil - } + res := rt.response() + if res.ContentLength != int64(len(response)) { + t.Fatalf("Content-Length = %d; want %d", res.ContentLength, len(response)) } - ct.run() + rt.wantBody(nil) } type neverEnding byte @@ -2891,190 +2479,125 @@ func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { } func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { - ct := newClientTester(t) - clientDone := make(chan struct{}) + tc := newTestClientConn(t) + tc.greet() const goAwayErrCode = ErrCodeHTTP11Required // arbitrary const goAwayDebugData = "some debug data" - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if failMidBody { - if err != nil { - return fmt.Errorf("unexpected client RoundTrip error: %v", err) - } - _, err = io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - } - want := GoAwayError{ - LastStreamID: 5, - ErrCode: goAwayErrCode, - DebugData: goAwayDebugData, - } - if !reflect.DeepEqual(err, want) { - t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - if failMidBody { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - // Write two GOAWAY frames, to test that the Transport takes - // the interesting parts of both. - ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) - ct.fr.WriteGoAway(5, goAwayErrCode, nil) - ct.sc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.sc.(*net.TCPConn).Close() - } - <-clientDone - return nil - } + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + if failMidBody { + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "123", + ), + }) } - ct.run() -} -func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { - ct := newClientTester(t) + // Write two GOAWAY frames, to test that the Transport takes + // the interesting parts of both. + tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) + tc.writeGoAway(5, goAwayErrCode, nil) + tc.closeWrite(io.EOF) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) + res, err := rt.result() + whence := "RoundTrip" + if failMidBody { + whence = "Body.Read" if err != nil { - return err + t.Fatalf("RoundTrip error = %v, want success", err) } + _, err = res.Body.Read(make([]byte, 1)) + } - if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { - return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) - } - res.Body.Close() // leaving 4999 bytes unread - - return nil + want := GoAwayError{ + LastStreamID: 5, + ErrCode: goAwayErrCode, + DebugData: goAwayDebugData, + } + if !reflect.DeepEqual(err, want) { + t.Errorf("%v error = %T: %#v, want %T (%#v)", whence, err, err, want, want) } - ct.server = func() error { - ct.greet() +} - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } +func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "5000", + ), + }) + initialInflow := tc.inflowWindow(0) + + // Two cases: + // - Send one DATA frame with 5000 bytes. + // - Send two DATA frames with 1 and 4999 bytes each. + // + // In both cases, the client should consume one byte of data, + // refund that byte, then refund the following 4999 bytes. + // + // In the second case, the server waits for the client to reset the + // stream before sending the second DATA frame. This tests the case + // where the client receives a DATA frame after it has reset the stream. + const streamNotEnded = false + if oneDataFrame { + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000)) + } else { + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1)) + } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialInflow := ct.inflowWindow(0) - - // Two cases: - // - Send one DATA frame with 5000 bytes. - // - Send two DATA frames with 1 and 4999 bytes each. - // - // In both cases, the client should consume one byte of data, - // refund that byte, then refund the following 4999 bytes. - // - // In the second case, the server waits for the client to reset the - // stream before sending the second DATA frame. This tests the case - // where the client receives a DATA frame after it has reset the stream. - if oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) - } else { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) - } + res := rt.response() + if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { + t.Fatalf("body read = %v, %v; want 1, nil", n, err) + } + res.Body.Close() // leaving 4999 bytes unread + tc.sync() - wantRST := true - wantWUF := true - if !oneDataFrame { - wantWUF = false // flow control update is small, and will not be sent - } - for wantRST || wantWUF { - f, err := ct.readNonSettingsFrame() - if err != nil { - return err + sentAdditionalData := false + tc.wantUnorderedFrames( + func(f *RSTStreamFrame) bool { + if f.ErrCode != ErrCodeCancel { + t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) } - switch f := f.(type) { - case *RSTStreamFrame: - if !wantRST { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.ErrCode != ErrCodeCancel { - return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) - } - wantRST = false - case *WindowUpdateFrame: - if !wantWUF { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.Increment != 5000 { - return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) - } - wantWUF = false - default: - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) + if !oneDataFrame { + // Send the remaining data now. + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999)) + sentAdditionalData = true } - } - if !oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) - f, err := ct.readNonSettingsFrame() - if err != nil { - return err + return true + }, + func(f *WindowUpdateFrame) bool { + if !oneDataFrame && !sentAdditionalData { + t.Fatalf("Got WindowUpdateFrame, don't expect one yet") } - wuf, ok := f.(*WindowUpdateFrame) - if !ok || wuf.Increment != 5000 { - return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f)) + if f.Increment != 5000 { + t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) } - } - if err := ct.writeReadPing(); err != nil { - return err - } - if got, want := ct.inflowWindow(0), initialInflow; got != want { - return fmt.Errorf("connection flow tokens = %v, want %v", got, want) - } - return nil + return true + }, + ) + + if got, want := tc.inflowWindow(0), initialInflow; got != want { + t.Fatalf("connection flow tokens = %v, want %v", got, want) } - ct.run() } // See golang.org/issue/16481 @@ -3090,199 +2613,124 @@ func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { // Issue 16612: adjust flow control on open streams when transport // receives SETTINGS with INITIAL_WINDOW_SIZE from server. func TestTransportAdjustsFlowControl(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - const bodySize = 1 << 20 - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) + tc := newTestClientConn(t) + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + // Don't write our SETTINGS yet. - req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err + body := tc.newRequestBody() + body.writeBytes(bodySize) + body.closeWithError(io.EOF) + + req, _ := http.NewRequest("POST", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + gotBytes := int64(0) + for { + f := testClientConnReadFrame[*DataFrame](tc) + gotBytes += int64(len(f.Data())) + // After we've got half the client's initial flow control window's worth + // of request body data, give it just enough flow control to finish. + if gotBytes >= initialWindowSize/2 { + break } - res.Body.Close() - return nil } - ct.server = func() error { - _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface))) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - var gotBytes int64 - var sentSettings bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - return nil - default: - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - } - switch f := f.(type) { - case *DataFrame: - gotBytes += int64(len(f.Data())) - // After we've got half the client's - // initial flow control window's worth - // of request body data, give it just - // enough flow control to finish. - if gotBytes >= initialWindowSize/2 && !sentSettings { - sentSettings = true - - ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) - ct.fr.WriteWindowUpdate(0, bodySize) - ct.fr.WriteSettingsAck() - } + tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) + tc.writeWindowUpdate(0, bodySize) + tc.writeSettingsAck() - if f.StreamEnded() { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - } - } + tc.wantUnorderedFrames( + func(f *SettingsFrame) bool { return true }, + func(f *DataFrame) bool { + gotBytes += int64(len(f.Data())) + return f.StreamEnded() + }, + ) + + if gotBytes != bodySize { + t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize) } - ct.run() + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) } // See golang.org/issue/16556 func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { - ct := newClientTester(t) - - unblockClient := make(chan bool, 1) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - defer res.Body.Close() - <-unblockClient - return nil - } - ct.server = func() error { - ct.greet() + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "5000", + ), + }) - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } + initialConnWindow := tc.inflowWindow(0) + initialStreamWindow := tc.inflowWindow(rt.streamID()) - initialConnWindow := ct.inflowWindow(0) + pad := make([]byte, 5) + tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialStreamWindow := ct.inflowWindow(hf.StreamID) - pad := make([]byte, 5) - ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream - if err := ct.writeReadPing(); err != nil { - return err - } - // Padding flow control should have been returned. - if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want { - t.Errorf("conn inflow window = %v, want %v", got, want) - } - if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want { - t.Errorf("stream inflow window = %v, want %v", got, want) - } - unblockClient <- true - return nil + // Padding flow control should have been returned. + if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want { + t.Errorf("conn inflow window = %v, want %v", got, want) + } + if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want { + t.Errorf("stream inflow window = %v, want %v", got, want) } - ct.run() } // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a // StreamError as a result of the response HEADERS func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { - ct := newClientTester(t) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + " content-type", "bogus", + ), + }) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - return errors.New("unexpected successful GET") - } - want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} - if !reflect.DeepEqual(want, err) { - t.Errorf("RoundTrip error = %#v; want %#v", err, want) - } - return nil + err := rt.err() + want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} + if !reflect.DeepEqual(err, want) { + t.Fatalf("RoundTrip error = %#v; want %#v", err, want) } - ct.server = func() error { - ct.greet() - - hf, err := ct.firstHeaders() - if err != nil { - return err - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - for { - fr, err := ct.readFrame() - if err != nil { - return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) - } - if _, ok := fr.(*SettingsFrame); ok { - continue - } - if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol { - t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) - } - break - } - - return nil + fr := testClientConnReadFrame[*RSTStreamFrame](tc) + if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol { + t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) } - ct.run() } // byteAndEOFReader returns is in an io.Reader which reads one byte @@ -3576,26 +3024,24 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { } func TestTransportCloseAfterLostPing(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.tr.PingTimeout = 1 * time.Second - ct.tr.ReadIdleTimeout = 1 * time.Second - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - if err == nil || !strings.Contains(err.Error(), "client connection lost") { - return fmt.Errorf("expected to get error about \"connection lost\", got %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - <-clientDone - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.PingTimeout = 1 * time.Second + tr.ReadIdleTimeout = 1 * time.Second + }) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + tc.advance(1 * time.Second) + tc.wantFrameType(FramePing) + + tc.advance(1 * time.Second) + err := rt.err() + if err == nil || !strings.Contains(err.Error(), "client connection lost") { + t.Fatalf("expected to get error about \"connection lost\", got %v", err) } - ct.run() } func TestTransportPingWriteBlocks(t *testing.T) { @@ -3628,418 +3074,231 @@ func TestTransportPingWriteBlocks(t *testing.T) { } } -func TestTransportPingWhenReading(t *testing.T) { - testCases := []struct { - name string - readIdleTimeout time.Duration - deadline time.Duration - expectedPingCount int - }{ - { - name: "two pings", - readIdleTimeout: 100 * time.Millisecond, - deadline: time.Second, - expectedPingCount: 2, - }, - { - name: "zero ping", - readIdleTimeout: time.Second, - deadline: 200 * time.Millisecond, - expectedPingCount: 0, - }, - { - name: "0 readIdleTimeout means no ping", - readIdleTimeout: 0 * time.Millisecond, - deadline: 500 * time.Millisecond, - expectedPingCount: 0, - }, - } - - for _, tc := range testCases { - tc := tc // capture range variable - t.Run(tc.name, func(t *testing.T) { - testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount) - }) - } -} +func TestTransportPingWhenReadingMultiplePings(t *testing.T) { + tc := newTestClientConn(t, func(tr *Transport) { + tr.ReadIdleTimeout = 1000 * time.Millisecond + }) + tc.greet() -func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) { - var pingCount int - ct := newClientTester(t) - ct.tr.ReadIdleTimeout = readIdleTimeout + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) - ctx, cancel := context.WithTimeout(context.Background(), deadline) - defer cancel() - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) - } - _, err = ioutil.ReadAll(res.Body) - if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) { - return nil + for i := 0; i < 5; i++ { + // No ping yet... + tc.advance(999 * time.Millisecond) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } - cancel() - return err + // ...ping now. + tc.advance(1 * time.Millisecond) + f := testClientConnReadFrame[*PingFrame](tc) + tc.writePing(true, f.Data) } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var streamID uint32 - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-ctx.Done(): - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - streamID = f.StreamID - case *PingFrame: - pingCount++ - if pingCount == expectedPingCount { - if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil { - return err - } - } - if err := ct.fr.WritePing(true, f.Data); err != nil { - return err - } - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + // Cancel the request, Transport resets it and returns an error from body reads. + cancel() + tc.sync() + + tc.wantFrameType(FrameRSTStream) + _, err := rt.readBody() + if err == nil { + t.Fatalf("Response.Body.Read() = %v, want error", err) } - ct.run() } -func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) { - ln := newLocalListener(t) - defer ln.Close() - - var ( - mu sync.Mutex - count int - conns []net.Conn - ) - var wg sync.WaitGroup - tr := &Transport{ - TLSClientConfig: tlsConfigInsecure, - } - tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { - mu.Lock() - defer mu.Unlock() - count++ - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - return nil, fmt.Errorf("dial error: %v", err) - } - conns = append(conns, cc) - sc, err := ln.Accept() - if err != nil { - return nil, fmt.Errorf("accept error: %v", err) - } - conns = append(conns, sc) - ct := &clientTester{ - t: t, - tr: tr, - cc: cc, - sc: sc, - fr: NewFramer(sc, sc), - } - wg.Add(1) - go func(count int) { - defer wg.Done() - server(count, ct) - }(count) - return cc, nil - } +func TestTransportPingWhenReadingPingDisabled(t *testing.T) { + tc := newTestClientConn(t, func(tr *Transport) { + tr.ReadIdleTimeout = 0 // PINGs disabled + }) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) - client(tr) - tr.CloseIdleConnections() - ln.Close() - for _, c := range conns { - c.Close() + // No PING is sent, even after a long delay. + tc.advance(1 * time.Minute) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } - wg.Wait() } func TestTransportRetryAfterGOAWAY(t *testing.T) { - client := func(tr *Transport) { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := tr.RoundTrip(req) - if res != nil { - res.Body.Close() - if got := res.Header.Get("Foo"); got != "bar" { - err = fmt.Errorf("foo header = %q; want bar", got) - } - } - if err != nil { - t.Errorf("RoundTrip: %v", err) - } - } - - server := func(count int, ct *clientTester) { - switch count { - case 1: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server1 failed reading HEADERS: %v", err) - return - } - t.Logf("server1 got %v", hf) - if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { - t.Errorf("server1 failed writing GOAWAY: %v", err) - return - } - case 2: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server2 failed reading HEADERS: %v", err) - return - } - t.Logf("server2 got %v", hf) - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - err = ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - if err != nil { - t.Errorf("server2 failed writing response HEADERS: %v", err) - } - default: - t.Errorf("unexpected number of dials") - return - } - } + tt := newTestTransport(t) + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // First attempt: Server sends a GOAWAY. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.writeGoAway(0 /*max id*/, ErrCodeNo, nil) + if rt.done() { + t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying") + } + + // Second attempt succeeds on a new connection. + tc = tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) - testClientMultipleDials(t, client, server) + rt.wantStatus(200) } func TestTransportRetryAfterRefusedStream(t *testing.T) { - clientDone := make(chan struct{}) - client := func(tr *Transport) { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("RoundTrip: %v", err) - return - } - resp.Body.Close() - if resp.StatusCode != 204 { - t.Errorf("Status = %v; want 204", resp.StatusCode) - return - } + tt := newTestTransport(t) + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // First attempt: Server sends a RST_STREAM. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.wantFrameType(FrameSettings) // settings ACK + tc.writeRSTStream(1, ErrCodeRefusedStream) + if rt.done() { + t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying") } - server := func(_ int, ct *clientTester) { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var count int - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - default: - t.Error(err) - } - return - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - t.Errorf("headers should have END_HEADERS be ended: %v", f) - return - } - count++ - if count == 1 { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - } else { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - default: - t.Errorf("Unexpected client frame %v", f) - return - } - } - } + // Second attempt succeeds on the same connection. + tc.wantHeaders(wantHeader{ + streamID: 3, + endStream: true, + }) + tc.writeSettings() + tc.writeHeaders(HeadersFrameParam{ + StreamID: 3, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) - testClientMultipleDials(t, client, server) + rt.wantStatus(204) } func TestTransportRetryHasLimit(t *testing.T) { - // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s. - if testing.Short() { - t.Skip("skipping long test in short mode") - } - retryBackoffHook = func(d time.Duration) *time.Timer { - return time.NewTimer(0) // fires immediately - } - defer func() { - retryBackoffHook = nil - }() - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) + tt := newTestTransport(t) + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // First attempt: Server sends a GOAWAY. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + + var totalDelay time.Duration + count := 0 + for streamID := uint32(1); ; streamID += 2 { + count++ + tc.wantHeaders(wantHeader{ + streamID: streamID, + endStream: true, + }) + if streamID == 1 { + tc.writeSettings() + tc.wantFrameType(FrameSettings) // settings ACK } - t.Logf("expected error, got: %v", err) - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - default: - return fmt.Errorf("Unexpected client frame %v", f) + tc.writeRSTStream(streamID, ErrCodeRefusedStream) + + d := tt.tr.syncHooks.timeUntilEvent() + if d == 0 { + if streamID == 1 { + continue } + break + } + totalDelay += d + if totalDelay > 5*time.Minute { + t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay) } + tt.advance(d) + } + if got, want := count, 5; got < count { + t.Errorf("RoundTrip made %v attempts, want at least %v", got, want) + } + if rt.err() == nil { + t.Errorf("RoundTrip succeeded, want error") } - ct.run() } func TestTransportResponseDataBeforeHeaders(t *testing.T) { - // This test use not valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) - defer log.SetOutput(os.Stderr) + // Discard log output complaining about protocol error. + log.SetOutput(io.Discard) + t.Cleanup(func() { log.SetOutput(os.Stderr) }) // after other cleanup is done + + tc := newTestClientConn(t) + tc.greet() + + // First request is normal to ensure the check is per stream and not per connection. + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt1 := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt1.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req := httptest.NewRequest("GET", "https://dummy.tld/", nil) - // First request is normal to ensure the check is per stream and not per connection. - _, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip expected no error, got: %v", err) - } - // Second request returns a DATA frame with no HEADERS. - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) - } - if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { - return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - return nil - } else if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame: - case *HeadersFrame: - switch f.StreamID { - case 1: - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - case 3: - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - } - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + // Second request returns a DATA frame with no HEADERS. + rt2 := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.writeData(rt2.streamID(), true, []byte("payload")) + if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol { + t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err) } - ct.run() } func TestTransportMaxFrameReadSize(t *testing.T) { @@ -4053,30 +3312,17 @@ func TestTransportMaxFrameReadSize(t *testing.T) { maxReadFrameSize: 1024, want: minMaxFrameSize, }} { - ct := newClientTester(t) - ct.tr.MaxReadFrameSize = test.maxReadFrameSize - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) - ct.tr.RoundTrip(req) - return nil - } - ct.server = func() error { - defer ct.cc.(*net.TCPConn).Close() - ct.greet() - var got uint32 - ct.settings.ForeachSetting(func(s Setting) error { - switch s.ID { - case SettingMaxFrameSize: - got = s.Val - } - return nil - }) - if got != test.want { - t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want) - } - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxReadFrameSize = test.maxReadFrameSize + }) + + fr := testClientConnReadFrame[*SettingsFrame](tc) + got, ok := fr.Value(SettingMaxFrameSize) + if !ok { + t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want) + } else if got != test.want { + t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want) } - ct.run() } } @@ -4108,345 +3354,134 @@ func TestTransportRequestsLowServerLimit(t *testing.T) { if err != nil { t.Fatal(err) } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - if got, want := res.StatusCode, 200; got != want { - t.Errorf("StatusCode = %v; want %v", got, want) - } - if res != nil && res.Body != nil { - res.Body.Close() - } - } - - if connCount != 1 { - t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) - } -} - -// tests Transport.StrictMaxConcurrentStreams -func TestTransportRequestsStallAtServerLimit(t *testing.T) { - const maxConcurrent = 2 - - greet := make(chan struct{}) // server sends initial SETTINGS frame - gotRequest := make(chan struct{}) // server received a request - clientDone := make(chan struct{}) - cancelClientRequest := make(chan struct{}) - - // Collect errors from goroutines. - var wg sync.WaitGroup - errs := make(chan error, 100) - defer func() { - wg.Wait() - close(errs) - for err := range errs { - t.Error(err) - } - }() - - // We will send maxConcurrent+2 requests. This checker goroutine waits for the - // following stages: - // 1. The first maxConcurrent requests are received by the server. - // 2. The client will cancel the next request - // 3. The server is unblocked so it can service the first maxConcurrent requests - // 4. The client will send the final request - wg.Add(1) - unblockClient := make(chan struct{}) - clientRequestCancelled := make(chan struct{}) - unblockServer := make(chan struct{}) - go func() { - defer wg.Done() - // Stage 1. - for k := 0; k < maxConcurrent; k++ { - <-gotRequest - } - // Stage 2. - close(unblockClient) - <-clientRequestCancelled - // Stage 3: give some time for the final RoundTrip call to be scheduled and - // verify that the final request is not sent. - time.Sleep(50 * time.Millisecond) - select { - case <-gotRequest: - errs <- errors.New("last request did not stall") - close(unblockServer) - return - default: - } - close(unblockServer) - // Stage 4. - <-gotRequest - }() - - ct := newClientTester(t) - ct.tr.StrictMaxConcurrentStreams = true - ct.client = func() error { - var wg sync.WaitGroup - defer func() { - wg.Wait() - close(clientDone) - ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.cc.(*net.TCPConn).Close() - } - }() - for k := 0; k < maxConcurrent+2; k++ { - wg.Add(1) - go func(k int) { - defer wg.Done() - // Don't send the second request until after receiving SETTINGS from the server - // to avoid a race where we use the default SettingMaxConcurrentStreams, which - // is much larger than maxConcurrent. We have to send the first request before - // waiting because the first request triggers the dial and greet. - if k > 0 { - <-greet - } - // Block until maxConcurrent requests are sent before sending any more. - if k >= maxConcurrent { - <-unblockClient - } - body := newStaticCloseChecker("") - req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body) - if k == maxConcurrent { - // This request will be canceled. - req.Cancel = cancelClientRequest - close(cancelClientRequest) - _, err := ct.tr.RoundTrip(req) - close(clientRequestCancelled) - if err == nil { - errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k) - return - } - } else { - resp, err := ct.tr.RoundTrip(req) - if err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - return - } - ioutil.ReadAll(resp.Body) - resp.Body.Close() - if resp.StatusCode != 204 { - errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode) - return - } - } - if err := body.isClosed(); err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - } - }(k) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() } - return nil } - ct.server = func() error { - var wg sync.WaitGroup - defer wg.Wait() + if connCount != 1 { + t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) + } +} - ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) +// tests Transport.StrictMaxConcurrentStreams +func TestTransportRequestsStallAtServerLimit(t *testing.T) { + const maxConcurrent = 2 - // Server write loop. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - writeResp := make(chan uint32, maxConcurrent+1) + tc := newTestClientConn(t, func(tr *Transport) { + tr.StrictMaxConcurrentStreams = true + }) + tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) - wg.Add(1) - go func() { - defer wg.Done() - <-unblockServer - for id := range writeResp { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: id, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - }() + cancelClientRequest := make(chan struct{}) - // Server read loop. - var nreq int - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it will have reported any errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame: - case *SettingsFrame: - // Wait for the client SETTINGS ack until ending the greet. - close(greet) - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - gotRequest <- struct{}{} - nreq++ - writeResp <- f.StreamID - if nreq == maxConcurrent+1 { - close(writeResp) - } - case *DataFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) + // Start maxConcurrent+2 requests. + // The server does not respond to any of them yet. + var rts []*testRoundTrip + for k := 0; k < maxConcurrent+2; k++ { + req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil) + if k == maxConcurrent { + req.Cancel = cancelClientRequest + } + rt := tc.roundTrip(req) + rts = append(rts, rt) + + if k < maxConcurrent { + // We are under the stream limit, so the client sends the request. + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{fmt.Sprintf("/%d", k)}, + }, + }) + } else { + // We have reached the stream limit, + // so the client cannot send the request. + if fr := tc.readFrame(); fr != nil { + t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr) } } + + if rt.done() { + t.Fatalf("rt %v done", k) + } + } + + // Cancel the maxConcurrent'th request. + // The request should fail. + close(cancelClientRequest) + tc.sync() + if err := rts[maxConcurrent].err(); err == nil { + t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent) + } + + // No requests should be complete, except for the canceled one. + for i, rt := range rts { + if i != maxConcurrent && rt.done() { + t.Fatalf("RoundTrip(%d) is done, but should not be", i) + } } - ct.run() + // Server responds to a request, unblocking the last one. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rts[0].streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.wantHeaders(wantHeader{ + streamID: rts[maxConcurrent+1].streamID(), + endStream: true, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{fmt.Sprintf("/%d", maxConcurrent+1)}, + }, + }) + rts[0].wantStatus(200) } func TestTransportMaxDecoderHeaderTableSize(t *testing.T) { - ct := newClientTester(t) var reqSize, resSize uint32 = 8192, 16384 - ct.tr.MaxDecoderHeaderTableSize = reqSize - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - cc, err := ct.tr.NewClientConn(ct.cc) - if err != nil { - return err - } - _, err = cc.RoundTrip(req) - if err != nil { - return err - } - if got, want := cc.peerMaxHeaderTableSize, resSize; got != want { - return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want) - } - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxDecoderHeaderTableSize = reqSize + }) + + fr := testClientConnReadFrame[*SettingsFrame](tc) + if v, ok := fr.Value(SettingHeaderTableSize); !ok { + t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting") + } else if v != reqSize { + t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize) } - ct.server = func() error { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - sf, ok := f.(*SettingsFrame) - if !ok { - ct.t.Fatalf("wanted client settings frame; got %v", f) - _ = sf // stash it away? - } - var found bool - err = sf.ForeachSetting(func(s Setting) error { - if s.ID == SettingHeaderTableSize { - found = true - if got, want := s.Val, reqSize; got != want { - return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want) - } - } - return nil - }) - if err != nil { - return err - } - if !found { - return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting") - } - if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + tc.writeSettings(Setting{SettingHeaderTableSize, resSize}) + if got, want := tc.cc.peerMaxHeaderTableSize, resSize; got != want { + t.Fatalf("peerHeaderTableSize = %d, want %d", got, want) } - ct.run() } func TestTransportMaxEncoderHeaderTableSize(t *testing.T) { - ct := newClientTester(t) var peerAdvertisedMaxHeaderTableSize uint32 = 16384 - ct.tr.MaxEncoderHeaderTableSize = 8192 - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - cc, err := ct.tr.NewClientConn(ct.cc) - if err != nil { - return err - } - _, err = cc.RoundTrip(req) - if err != nil { - return err - } - if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want { - return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want) - } - return nil - } - ct.server = func() error { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - sf, ok := f.(*SettingsFrame) - if !ok { - ct.t.Fatalf("wanted client settings frame; got %v", f) - _ = sf // stash it away? - } - if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxEncoderHeaderTableSize = 8192 + }) + tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if got, want := tc.cc.henc.MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want { + t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want) } - ct.run() } func TestAuthorityAddr(t *testing.T) { @@ -4530,40 +3565,24 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { // Issue 18891: make sure Request.Body == NoBody means no DATA frame // is ever sent, even if empty. func TestTransportNoBodyMeansNoDATA(t *testing.T) { - ct := newClientTester(t) - - unblockClient := make(chan bool) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) - ct.tr.RoundTrip(req) - <-unblockClient - return nil - } - ct.server = func() error { - defer close(unblockClient) - defer ct.cc.(*net.TCPConn).Close() - ct.greet() - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f := f.(type) { - default: - return fmt.Errorf("Got %T; want HeadersFrame", f) - case *WindowUpdateFrame, *SettingsFrame: - continue - case *HeadersFrame: - if !f.StreamEnded() { - return fmt.Errorf("got headers frame without END_STREAM") - } - return nil - } - } + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, // END_STREAM should be set when body is http.NoBody + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{"/"}, + }, + }) + if fr := tc.readFrame(); fr != nil { + t.Fatalf("unexpected frame after headers: %v", fr) } - ct.run() } func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { @@ -4642,41 +3661,22 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { // Verify transport doesn't crash when receiving bogus response lacking a :status header. // Issue 22880. func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - const substr = "malformed response from server: missing status pseudo header" - if !strings.Contains(fmt.Sprint(err), substr) { - return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, // we'll send some DATA to try to crash the transport - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - return nil - } - } - } - ct.run() + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, // we'll send some DATA to try to crash the transport + BlockFragment: tc.makeHeaderBlockFragment( + "content-type", "text/html", // no :status header + ), + }) + tc.writeData(rt.streamID(), true, []byte("payload")) } func BenchmarkClientRequestHeaders(b *testing.B) { @@ -5024,95 +4024,42 @@ func (r *errReader) Read(p []byte) (int, error) { } func testTransportBodyReadError(t *testing.T, body []byte) { - if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { - // So far we've only seen this be flaky on Windows and Plan 9, - // perhaps due to TCP behavior on shutdowns while - // unread data is in flight. This test should be - // fixed, but a skip is better than annoying people - // for now. - t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS) - } - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - - checkNoStreams := func() error { - cp, ok := ct.tr.connPool().(*clientConnPool) - if !ok { - return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool()) - } - cp.mu.Lock() - defer cp.mu.Unlock() - conns, ok := cp.conns["dummy.tld:443"] - if !ok { - return fmt.Errorf("missing connection") - } - if len(conns) != 1 { - return fmt.Errorf("conn pool size: %v; expect 1", len(conns)) - } - if activeStreams(conns[0]) != 0 { - return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0])) - } - return nil - } - bodyReadError := errors.New("body read error") - body := &errReader{body, bodyReadError} - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != bodyReadError { - return fmt.Errorf("err = %v; want %v", err, bodyReadError) - } - if err = checkNoStreams(); err != nil { - return err + tc := newTestClientConn(t) + tc.greet() + + bodyReadError := errors.New("body read error") + b := tc.newRequestBody() + b.Write(body) + b.closeWithError(bodyReadError) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", b) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + var receivedBody []byte +readFrames: + for { + switch f := tc.readFrame().(type) { + case *DataFrame: + receivedBody = append(receivedBody, f.Data()...) + case *RSTStreamFrame: + break readFrames + default: + t.Fatalf("unexpected frame: %v", f) + case nil: + t.Fatalf("transport is idle, want RST_STREAM") } - return nil } - ct.server = func() error { - ct.greet() - var receivedBody []byte - var resetCount int - for { - f, err := ct.fr.ReadFrame() - t.Logf("server: ReadFrame = %v, %v", f, err) - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - if bytes.Compare(receivedBody, body) != 0 { - return fmt.Errorf("body: %q; expected %q", receivedBody, body) - } - if resetCount != 1 { - return fmt.Errorf("stream reset count: %v; expected: 1", resetCount) - } - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - receivedBody = append(receivedBody, f.Data()...) - case *RSTStreamFrame: - resetCount++ - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + if !bytes.Equal(receivedBody, body) { + t.Fatalf("body: %q; expected %q", receivedBody, body) + } + + if err := rt.err(); err != bodyReadError { + t.Fatalf("err = %v; want %v", err, bodyReadError) + } + + if got := activeStreams(tc.cc); got != 0 { + t.Fatalf("active streams count: %v; want 0", got) } - ct.run() } func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } @@ -5125,59 +4072,18 @@ func TestTransportBodyEagerEndStream(t *testing.T) { const reqBody = "some request body" const resBody = "some response body" - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - body := strings.NewReader(reqBody) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != nil { - return err - } - return nil - } - ct.server = func() error { - ct.greet() + tc := newTestClientConn(t) + tc.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } + body := strings.NewReader(reqBody) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + tc.roundTrip(req) - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - if !f.StreamEnded() { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - return fmt.Errorf("data frame without END_STREAM %v", f) - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte(resBody)) - return nil - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + tc.wantFrameType(FrameHeaders) + f := testClientConnReadFrame[*DataFrame](tc) + if !f.StreamEnded() { + t.Fatalf("data frame without END_STREAM %v", f) } - ct.run() } type chunkReader struct { @@ -5826,155 +4732,80 @@ func TestTransportCloseRequestBody(t *testing.T) { } } -// collectClientsConnPool is a ClientConnPool that wraps lower and -// collects what calls were made on it. -type collectClientsConnPool struct { - lower ClientConnPool - - mu sync.Mutex - getErrs int - got []*ClientConn -} - -func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { - cc, err := p.lower.GetClientConn(req, addr) - p.mu.Lock() - defer p.mu.Unlock() - if err != nil { - p.getErrs++ - return nil, err - } - p.got = append(p.got, cc) - return cc, nil -} - -func (p *collectClientsConnPool) MarkDead(cc *ClientConn) { - p.lower.MarkDead(cc) -} - func TestTransportRetriesOnStreamProtocolError(t *testing.T) { - ct := newClientTester(t) - pool := &collectClientsConnPool{ - lower: &clientConnPool{t: ct.tr}, - } - ct.tr.ConnPool = pool + // This test verifies that + // - receiving a protocol error on a connection does not interfere with + // other requests in flight on that connection; + // - the connection is not reused for further requests; and + // - the failed request is retried on a new connecection. + tt := newTestTransport(t) + + // Start two requests. The first is a long request + // that will finish after the second. The second one + // will result in the protocol error. + + // Request #1: The long request. + req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt1 := tt.roundTrip(req1) + tc1 := tt.getConn() + tc1.wantFrameType(FrameSettings) + tc1.wantFrameType(FrameWindowUpdate) + tc1.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc1.writeSettings() + tc1.wantFrameType(FrameSettings) // settings ACK + + // Request #2(a): The short request. + req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt2 := tt.roundTrip(req2) + tc1.wantHeaders(wantHeader{ + streamID: 3, + endStream: true, + }) - gotProtoError := make(chan bool, 1) - ct.tr.CountError = func(errType string) { - if errType == "recv_rststream_PROTOCOL_ERROR" { - select { - case gotProtoError <- true: - default: - } - } + // Request #2(a) fails with ErrCodeProtocol. + tc1.writeRSTStream(3, ErrCodeProtocol) + if rt1.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress") } - ct.client = func() error { - // Start two requests. The first is a long request - // that will finish after the second. The second one - // will result in the protocol error. We check that - // after the first one closes, the connection then - // shuts down. - - // The long, outer request. - req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil) - res1, err := ct.tr.RoundTrip(req1) - if err != nil { - return err - } - if got, want := res1.Header.Get("Is-Long"), "1"; got != want { - return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want) - } - - req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil) - res, err := ct.tr.RoundTrip(req) - const want = "only one dial allowed in test mode" - if got := fmt.Sprint(err); got != want { - t.Errorf("didn't dial again: got %#q; want %#q", got, want) - } - if res != nil { - res.Body.Close() - } - select { - case <-gotProtoError: - default: - t.Errorf("didn't get stream protocol error") - } - - if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 { - t.Errorf("unexpected body read %v, %v", n, err) - } - - pool.mu.Lock() - defer pool.mu.Unlock() - if pool.getErrs != 1 { - t.Errorf("pool get errors = %v; want 1", pool.getErrs) - } - if len(pool.got) == 2 { - if pool.got[0] != pool.got[1] { - t.Errorf("requests went on different connections") - } - cc := pool.got[0] - cc.mu.Lock() - if !cc.doNotReuse { - t.Error("ClientConn not marked doNotReuse") - } - cc.mu.Unlock() - - select { - case <-cc.readerDone: - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for reader to be done") - } - } else { - t.Errorf("pool get success = %v; want 2", len(pool.got)) - } - return nil + if rt2.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is done; want still in progress") } - ct.server = func() error { - ct.greet() - var sentErr bool - var numHeaders int - var firstStreamID uint32 - - var hbuf bytes.Buffer - enc := hpack.NewEncoder(&hbuf) - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - // Client hung up on us, as it should at the end. - return nil - } - if err != nil { - return nil - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - numHeaders++ - if numHeaders == 1 { - firstStreamID = f.StreamID - hbuf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: hbuf.Bytes(), - }) - continue - } - if !sentErr { - sentErr = true - ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol) - ct.fr.WriteData(firstStreamID, true, nil) - continue - } - } - } - } - ct.run() + // Request #2(b): The short request is retried on a new connection. + tc2 := tt.getConn() + tc2.wantFrameType(FrameSettings) + tc2.wantFrameType(FrameWindowUpdate) + tc2.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc2.writeSettings() + tc2.wantFrameType(FrameSettings) // settings ACK + + // Request #2(b) succeeds. + tc2.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "201", + ), + }) + rt2.wantStatus(201) + + // Request #1 succeeds. + tc1.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) } func TestClientConnReservations(t *testing.T) { @@ -5987,7 +4818,7 @@ func TestClientConnReservations(t *testing.T) { tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() - cc, err := tr.newClientConn(st.cc, false) + cc, err := tr.newClientConn(st.cc, false, nil) if err != nil { t.Fatal(err) } @@ -6026,39 +4857,27 @@ func TestClientConnReservations(t *testing.T) { } func TestTransportTimeoutServerHangs(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) + tc := newTestClientConn(t) + tc.greet() - req, err := http.NewRequest("PUT", "https://dummy.tld/", nil) - if err != nil { - return err - } + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - req = req.WithContext(ctx) - req.Header.Add("Big", strings.Repeat("a", 1<<20)) - _, err = ct.tr.RoundTrip(req) - if err == nil { - return errors.New("error should not be nil") - } - if ne, ok := err.(net.Error); !ok || !ne.Timeout() { - return fmt.Errorf("error should be a net error timeout: %v", err) - } - return nil + tc.wantFrameType(FrameHeaders) + tc.advance(5 * time.Second) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } - ct.server = func() error { - ct.greet() - select { - case <-time.After(5 * time.Second): - case <-clientDone: - } - return nil + if rt.done() { + t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned") + } + + cancel() + tc.sync() + if rt.err() != context.Canceled { + t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err()) } - ct.run() } func TestTransportContentLengthWithoutBody(t *testing.T) { @@ -6251,20 +5070,6 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { testTransportClosesConnAfterGoAway(t, 1) } -type closeOnceConn struct { - net.Conn - closed uint32 -} - -var errClosed = errors.New("Close of closed connection") - -func (c *closeOnceConn) Close() error { - if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - return c.Conn.Close() - } - return errClosed -} - // testTransportClosesConnAfterGoAway verifies that the transport // closes a connection after reading a GOAWAY from it. // @@ -6272,53 +5077,35 @@ func (c *closeOnceConn) Close() error { // When 0, the transport (unsuccessfully) retries the request (stream 1); // when 1, the transport reads the response after receiving the GOAWAY. func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { - ct := newClientTester(t) - ct.cc = &closeOnceConn{Conn: ct.cc} - - var wg sync.WaitGroup - wg.Add(1) - ct.client = func() error { - defer wg.Done() - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - } - if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { - t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) - } - if err = ct.cc.Close(); err != errClosed { - return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err) - } - return nil + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeGoAway(lastStream, ErrCodeNo, nil) + + if lastStream > 0 { + // Send a valid response to first request. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) } - ct.server = func() error { - defer wg.Wait() - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - return fmt.Errorf("server failed reading HEADERS: %v", err) - } - if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil { - return fmt.Errorf("server failed writing GOAWAY: %v", err) - } - if lastStream > 0 { - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - return nil + tc.closeWrite(io.EOF) + err := rt.err() + if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { + t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) + } + if !tc.netConnClosed { + t.Errorf("ClientConn did not close its net.Conn, expected it to") } - - ct.run() } type slowCloser struct { @@ -6520,3 +5307,32 @@ func TestDialRaceResumesDial(t *testing.T) { case <-successCh: } } + +func TestTransportDataAfter1xxHeader(t *testing.T) { + // Discard logger output to avoid spamming stderr. + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + // https://go.dev/issue/65927 - server sends a 1xx response, followed by a DATA frame. + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "100", + ), + }) + tc.writeData(rt.streamID(), true, []byte{0}) + err := rt.err() + if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { + t.Errorf("RoundTrip error: %v; want ErrCodeProtocol", err) + } + tc.wantFrameType(FrameRSTStream) +} diff --git a/quic/rangeset.go b/quic/rangeset.go index 4966a99d2..b8b2e9367 100644 --- a/quic/rangeset.go +++ b/quic/rangeset.go @@ -50,7 +50,7 @@ func (s *rangeset[T]) add(start, end T) { if end <= r.end { return } - // Possibly coalesce subsquent ranges into range i. + // Possibly coalesce subsequent ranges into range i. r.end = end j := i + 1 for ; j < len(*s) && r.end >= (*s)[j].start; j++ { diff --git a/quic/retry_test.go b/quic/retry_test.go index 42f2bdd4a..c898ad331 100644 --- a/quic/retry_test.go +++ b/quic/retry_test.go @@ -521,7 +521,7 @@ func TestParseInvalidRetryPackets(t *testing.T) { }} { t.Run(test.name, func(t *testing.T) { if _, ok := parseRetryPacket(test.pkt, originalDstConnID); ok { - t.Errorf("parseRetryPacket succeded, want failure") + t.Errorf("parseRetryPacket succeeded, want failure") } }) } diff --git a/quic/stream_limits_test.go b/quic/stream_limits_test.go index 9c2f71ec1..8fed825d7 100644 --- a/quic/stream_limits_test.go +++ b/quic/stream_limits_test.go @@ -249,7 +249,7 @@ func TestStreamLimitStopSendingDoesNotUpdateMaxStreams(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameStopSending{ id: s.id, }) - tc.wantFrame("recieved STOP_SENDING, send RESET_STREAM", + tc.wantFrame("received STOP_SENDING, send RESET_STREAM", packetType1RTT, debugFrameResetStream{ id: s.id, }) diff --git a/quic/version_test.go b/quic/version_test.go index 92fabd7b3..0bd8bac14 100644 --- a/quic/version_test.go +++ b/quic/version_test.go @@ -39,10 +39,10 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { }) gotPkt := te.read() if gotPkt == nil { - t.Fatalf("got no response; want Version Negotiaion") + t.Fatalf("got no response; want Version Negotiation") } if got := getPacketType(gotPkt); got != packetTypeVersionNegotiation { - t.Fatalf("got packet type %v; want Version Negotiaion", got) + t.Fatalf("got packet type %v; want Version Negotiation", got) } gotDst, gotSrc, versions := parseVersionNegotiation(gotPkt) if got, want := gotDst, srcConnID; !bytes.Equal(got, want) {