From 7b5abfaf7fdcca5801931b9726eb8ce9ae586ba1 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 26 Oct 2023 11:57:09 -0700 Subject: [PATCH 01/70] quic: basic qlog support Add the structure for generating and writing qlog events. Events are generated as slog events using the structure of the qlog events (draft-ietf-quic-qlog-quic-events-03). The qlog package contains a slog Handler implementation that converts the quic package events to qlog JSON. This CL generates events for connection creation and closure. Future CLs will add additional events. Events follow draft-ietf-quic-qlog-quic-events-03, which is the most recent draft supported by the qvis visualization tool. https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html For golang/go#58547 Change-Id: I5fb1b7653d0257cb86726bd5bc9e8775da74686a Reviewed-on: https://go-review.googlesource.com/c/net/+/537936 Auto-Submit: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/cmd/interop/main.go | 17 +- internal/quic/cmd/interop/run_endpoint.sh | 4 +- internal/quic/config.go | 11 + internal/quic/conn.go | 15 +- internal/quic/conn_close_test.go | 22 +- internal/quic/conn_test.go | 1 + internal/quic/qlog.go | 141 ++++++++++++ internal/quic/qlog/handler.go | 76 ++++++ internal/quic/qlog/json_writer.go | 194 ++++++++++++++++ internal/quic/qlog/json_writer_test.go | 186 +++++++++++++++ internal/quic/qlog/qlog.go | 267 ++++++++++++++++++++++ internal/quic/qlog/qlog_test.go | 151 ++++++++++++ internal/quic/qlog_test.go | 132 +++++++++++ internal/quic/stateless_reset_test.go | 10 +- 14 files changed, 1216 insertions(+), 11 deletions(-) create mode 100644 internal/quic/qlog.go create mode 100644 internal/quic/qlog/handler.go create mode 100644 internal/quic/qlog/json_writer.go create mode 100644 internal/quic/qlog/json_writer_test.go create mode 100644 internal/quic/qlog/qlog.go create mode 100644 internal/quic/qlog/qlog_test.go create mode 100644 internal/quic/qlog_test.go diff --git a/internal/quic/cmd/interop/main.go b/internal/quic/cmd/interop/main.go index cc5292e9e..2ca5d652a 100644 --- a/internal/quic/cmd/interop/main.go +++ b/internal/quic/cmd/interop/main.go @@ -18,6 +18,7 @@ import ( "fmt" "io" "log" + "log/slog" "net" "net/url" "os" @@ -25,14 +26,16 @@ import ( "sync" "golang.org/x/net/internal/quic" + "golang.org/x/net/internal/quic/qlog" ) var ( - listen = flag.String("listen", "", "listen address") - cert = flag.String("cert", "", "certificate") - pkey = flag.String("key", "", "private key") - root = flag.String("root", "", "serve files from this root") - output = flag.String("output", "", "directory to write files to") + listen = flag.String("listen", "", "listen address") + cert = flag.String("cert", "", "certificate") + pkey = flag.String("key", "", "private key") + root = flag.String("root", "", "serve files from this root") + output = flag.String("output", "", "directory to write files to") + qlogdir = flag.String("qlog", "", "directory to write qlog output to") ) func main() { @@ -48,6 +51,10 @@ func main() { }, MaxBidiRemoteStreams: -1, MaxUniRemoteStreams: -1, + QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{ + Level: quic.QLogLevelFrame, + Dir: *qlogdir, + })), } if *cert != "" { c, err := tls.LoadX509KeyPair(*cert, *pkey) diff --git a/internal/quic/cmd/interop/run_endpoint.sh b/internal/quic/cmd/interop/run_endpoint.sh index d72335d8e..442039bc0 100644 --- 
a/internal/quic/cmd/interop/run_endpoint.sh +++ b/internal/quic/cmd/interop/run_endpoint.sh @@ -11,7 +11,7 @@ if [ "$ROLE" == "client" ]; then # Wait for the simulator to start up. /wait-for-it.sh sim:57832 -s -t 30 - ./interop -output=/downloads $CLIENT_PARAMS $REQUESTS + ./interop -output=/downloads -qlog=$QLOGDIR $CLIENT_PARAMS $REQUESTS elif [ "$ROLE" == "server" ]; then - ./interop -cert=/certs/cert.pem -key=/certs/priv.key -listen=:443 -root=/www "$@" $SERVER_PARAMS + ./interop -cert=/certs/cert.pem -key=/certs/priv.key -qlog=$QLOGDIR -listen=:443 -root=/www "$@" $SERVER_PARAMS fi diff --git a/internal/quic/config.go b/internal/quic/config.go index 6278bf89c..b10ecc79e 100644 --- a/internal/quic/config.go +++ b/internal/quic/config.go @@ -8,6 +8,7 @@ package quic import ( "crypto/tls" + "log/slog" ) // A Config structure configures a QUIC endpoint. @@ -72,6 +73,16 @@ type Config struct { // // If this field is left as zero, stateless reset is disabled. StatelessResetKey [32]byte + + // QLogLogger receives qlog events. + // + // Events currently correspond to the definitions in draft-ietf-qlog-quic-events-03. + // This is not the latest version of the draft, but is the latest version supported + // by common event log viewers as of the time this paragraph was written. + // + // The qlog package contains a slog.Handler which serializes qlog events + // to a standard JSON representation. + QLogLogger *slog.Logger } func configDefault(v, def, limit int64) int64 { diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 1292f2b20..cca11166c 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "errors" "fmt" + "log/slog" "net/netip" "time" ) @@ -60,6 +61,8 @@ type Conn struct { // Tests only: Send a PING in a specific number space. testSendPingSpace numberSpace testSendPing sentVal + + log *slog.Logger } // connTestHooks override conn behavior in tests. @@ -94,7 +97,7 @@ type newServerConnIDs struct { retrySrcConnID []byte // source from server's Retry } -func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, l *Listener) (*Conn, error) { +func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, l *Listener) (conn *Conn, _ error) { c := &Conn{ side: side, listener: l, @@ -106,6 +109,14 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip idleTimeout: now.Add(defaultMaxIdleTimeout), peerAckDelayExponent: -1, } + defer func() { + // If we hit an error in newConn, close donec so tests don't get stuck waiting for it. + // This is only relevant if we've got a bug, but it makes tracking that bug down + // much easier. + if conn == nil { + close(c.donec) + } + }() // A one-element buffer allows us to wake a Conn's event loop as a // non-blocking operation. @@ -135,6 +146,7 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip // The smallest allowed maximum QUIC datagram size is 1200 bytes. // TODO: PMTU discovery. const maxDatagramSize = 1200 + c.logConnectionStarted(cids.originalDstConnID, peerAddr) c.keysAppData.init() c.loss.init(c.side, maxDatagramSize, now) c.streamsInit() @@ -259,6 +271,7 @@ func (c *Conn) loop(now time.Time) { defer close(c.donec) defer c.tls.Close() defer c.listener.connDrained(c) + defer c.logConnectionClosed() // The connection timer sends a message to the connection loop on expiry. 
// We need to give it an expiry when creating it, so set the initial timeout to diff --git a/internal/quic/conn_close_test.go b/internal/quic/conn_close_test.go index d583ae92a..0dd46dd20 100644 --- a/internal/quic/conn_close_test.go +++ b/internal/quic/conn_close_test.go @@ -70,7 +70,8 @@ func TestConnCloseResponseBackoff(t *testing.T) { } func TestConnCloseWithPeerResponse(t *testing.T) { - tc := newTestConn(t, clientSide) + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, qr.config) tc.handshake() tc.conn.Abort(nil) @@ -99,10 +100,19 @@ func TestConnCloseWithPeerResponse(t *testing.T) { if err := tc.conn.Wait(canceledContext()); !errors.Is(err, wantErr) { t.Errorf("non-blocking conn.Wait() = %v, want %v", err, wantErr) } + + tc.advance(1 * time.Second) // long enough to exit the draining state + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": "application", + }, + }) } func TestConnClosePeerCloses(t *testing.T) { - tc := newTestConn(t, clientSide) + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, qr.config) tc.handshake() wantErr := &ApplicationError{ @@ -128,6 +138,14 @@ func TestConnClosePeerCloses(t *testing.T) { code: 9, reason: "because", }) + + tc.advance(1 * time.Second) // long enough to exit the draining state + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": "application", + }, + }) } func TestConnCloseReceiveInInitial(t *testing.T) { diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index c70c58ef0..514a8775e 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -198,6 +198,7 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { // The initial connection ID for the server is chosen by the client. cids.srcConnID = testPeerConnID(0) cids.dstConnID = testPeerConnID(-1) + cids.originalDstConnID = cids.dstConnID } var configTransportParams []func(*transportParameters) var configTestConn []func(*testConn) diff --git a/internal/quic/qlog.go b/internal/quic/qlog.go new file mode 100644 index 000000000..29875693e --- /dev/null +++ b/internal/quic/qlog.go @@ -0,0 +1,141 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "encoding/hex" + "log/slog" + "net/netip" +) + +// Log levels for qlog events. +const ( + // QLogLevelFrame includes per-frame information. + // When this level is enabled, packet_sent and packet_received events will + // contain information on individual frames sent/received. + QLogLevelFrame = slog.Level(-6) + + // QLogLevelPacket events occur at most once per packet sent or received. + // + // For example: packet_sent, packet_received. + QLogLevelPacket = slog.Level(-4) + + // QLogLevelConn events occur multiple times over a connection's lifetime, + // but less often than the frequency of individual packets. + // + // For example: connection_state_updated. + QLogLevelConn = slog.Level(-2) + + // QLogLevelEndpoint events occur at most once per connection. + // + // For example: connection_started, connection_closed. + QLogLevelEndpoint = slog.Level(0) +) + +func (c *Conn) logEnabled(level slog.Level) bool { + return c.log != nil && c.log.Enabled(context.Background(), level) +} + +// slogHexstring returns a slog.Attr for a value of the hexstring type. 
+// +// https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-1.1.1 +func slogHexstring(key string, value []byte) slog.Attr { + return slog.String(key, hex.EncodeToString(value)) +} + +func slogAddr(key string, value netip.Addr) slog.Attr { + return slog.String(key, value.String()) +} + +func (c *Conn) logConnectionStarted(originalDstConnID []byte, peerAddr netip.AddrPort) { + if c.config.QLogLogger == nil || + !c.config.QLogLogger.Enabled(context.Background(), QLogLevelEndpoint) { + return + } + var vantage string + if c.side == clientSide { + vantage = "client" + originalDstConnID = c.connIDState.originalDstConnID + } else { + vantage = "server" + } + // A qlog Trace container includes some metadata (title, description, vantage_point) + // and a list of Events. The Trace also includes a common_fields field setting field + // values common to all events in the trace. + // + // Trace = { + // ? title: text + // ? description: text + // ? configuration: Configuration + // ? common_fields: CommonFields + // ? vantage_point: VantagePoint + // events: [* Event] + // } + // + // To map this into slog's data model, we start each per-connection trace with a With + // call that includes both the trace metadata and the common fields. + // + // This means that in slog's model, each trace event will also include + // the Trace metadata fields (vantage_point), which is a divergence from the qlog model. + c.log = c.config.QLogLogger.With( + // The group_id permits associating traces taken from different vantage points + // for the same connection. + // + // We use the original destination connection ID as the group ID. + // + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-3.4.6 + slogHexstring("group_id", originalDstConnID), + slog.Group("vantage_point", + slog.String("name", "go quic"), + slog.String("type", vantage), + ), + ) + localAddr := c.listener.LocalAddr() + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.2 + c.log.LogAttrs(context.Background(), QLogLevelEndpoint, + "connectivity:connection_started", + slogAddr("src_ip", localAddr.Addr()), + slog.Int("src_port", int(localAddr.Port())), + slogHexstring("src_cid", c.connIDState.local[0].cid), + slogAddr("dst_ip", peerAddr.Addr()), + slog.Int("dst_port", int(peerAddr.Port())), + slogHexstring("dst_cid", c.connIDState.remote[0].cid), + ) +} + +func (c *Conn) logConnectionClosed() { + if !c.logEnabled(QLogLevelEndpoint) { + return + } + err := c.lifetime.finalErr + trigger := "error" + switch e := err.(type) { + case *ApplicationError: + // TODO: Distinguish between peer and locally-initiated close. + trigger = "application" + case localTransportError: + if e.code == errNo { + trigger = "clean" + } + case peerTransportError: + if e.code == errNo { + trigger = "clean" + } + default: + switch err { + case errStatelessReset: + trigger = "stateless_reset" + } + // TODO: idle_timeout, handshake_timeout + } + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.3 + c.log.LogAttrs(context.Background(), QLogLevelEndpoint, + "connectivity:connection_closed", + slog.String("trigger", trigger), + ) +} diff --git a/internal/quic/qlog/handler.go b/internal/quic/qlog/handler.go new file mode 100644 index 000000000..35a66cf8b --- /dev/null +++ b/internal/quic/qlog/handler.go @@ -0,0 +1,76 @@ +// Copyright 2023 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package qlog + +import ( + "context" + "log/slog" +) + +type withAttrsHandler struct { + attrs []slog.Attr + h slog.Handler +} + +func withAttrs(h slog.Handler, attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return h + } + return &withAttrsHandler{attrs: attrs, h: h} +} + +func (h *withAttrsHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.h.Enabled(ctx, level) +} + +func (h *withAttrsHandler) Handle(ctx context.Context, r slog.Record) error { + r.AddAttrs(h.attrs...) + return h.h.Handle(ctx, r) +} + +func (h *withAttrsHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return withAttrs(h, attrs) +} + +func (h *withAttrsHandler) WithGroup(name string) slog.Handler { + return withGroup(h, name) +} + +type withGroupHandler struct { + name string + h slog.Handler +} + +func withGroup(h slog.Handler, name string) slog.Handler { + if name == "" { + return h + } + return &withGroupHandler{name: name, h: h} +} + +func (h *withGroupHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.h.Enabled(ctx, level) +} + +func (h *withGroupHandler) Handle(ctx context.Context, r slog.Record) error { + var attrs []slog.Attr + r.Attrs(func(a slog.Attr) bool { + attrs = append(attrs, a) + return true + }) + nr := slog.NewRecord(r.Time, r.Level, r.Message, r.PC) + nr.Add(slog.Any(h.name, slog.GroupValue(attrs...))) + return h.h.Handle(ctx, nr) +} + +func (h *withGroupHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return withAttrs(h, attrs) +} + +func (h *withGroupHandler) WithGroup(name string) slog.Handler { + return withGroup(h, name) +} diff --git a/internal/quic/qlog/json_writer.go b/internal/quic/qlog/json_writer.go new file mode 100644 index 000000000..50cf33bc5 --- /dev/null +++ b/internal/quic/qlog/json_writer.go @@ -0,0 +1,194 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package qlog + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "strconv" + "sync" + "time" +) + +// A jsonWriter writes JSON-SEQ (RFC 7464). +// +// A JSON-SEQ file consists of a series of JSON text records, +// each beginning with an RS (0x1e) character and ending with LF (0x0a). +type jsonWriter struct { + mu sync.Mutex + w io.WriteCloser + buf bytes.Buffer +} + +// writeRecordStart writes the start of a JSON-SEQ record. +func (w *jsonWriter) writeRecordStart() { + w.mu.Lock() + w.buf.WriteByte(0x1e) + w.buf.WriteByte('{') +} + +// writeRecordEnd finishes writing a JSON-SEQ record. +func (w *jsonWriter) writeRecordEnd() { + w.buf.WriteByte('}') + w.buf.WriteByte('\n') + w.w.Write(w.buf.Bytes()) + w.buf.Reset() + w.mu.Unlock() +} + +// writeAttrsField writes a []slog.Attr as an object field. +func (w *jsonWriter) writeAttrsField(name string, attrs []slog.Attr) { + w.writeName(name) + w.buf.WriteByte('{') + for _, a := range attrs { + w.writeAttr(a) + } + w.buf.WriteByte('}') +} + +// writeAttr writes a slog.Attr as an object field. 
+func (w *jsonWriter) writeAttr(a slog.Attr) { + v := a.Value.Resolve() + switch v.Kind() { + case slog.KindAny: + w.writeStringField(a.Key, fmt.Sprint(v.Any())) + case slog.KindBool: + w.writeBoolField(a.Key, v.Bool()) + case slog.KindDuration: + w.writeDurationField(a.Key, v.Duration()) + case slog.KindFloat64: + w.writeFloat64Field(a.Key, v.Float64()) + case slog.KindInt64: + w.writeInt64Field(a.Key, v.Int64()) + case slog.KindString: + w.writeStringField(a.Key, v.String()) + case slog.KindTime: + w.writeTimeField(a.Key, v.Time()) + case slog.KindUint64: + w.writeUint64Field(a.Key, v.Uint64()) + case slog.KindGroup: + w.writeAttrsField(a.Key, v.Group()) + default: + w.writeString("unhandled kind") + } +} + +// writeName writes an object field name followed by a colon. +func (w *jsonWriter) writeName(name string) { + if b := w.buf.Bytes(); len(b) > 0 && b[len(b)-1] != '{' { + // Add the comma separating this from the previous field. + w.buf.WriteByte(',') + } + w.writeString(name) + w.buf.WriteByte(':') +} + +// writeObjectField writes an object-valued object field. +// The function f is called to write the contents. +func (w *jsonWriter) writeObjectField(name string, f func()) { + w.writeName(name) + w.buf.WriteByte('{') + f() + w.buf.WriteByte('}') +} + +// writeRawField writes a field with a raw JSON value. +func (w *jsonWriter) writeRawField(name, v string) { + w.writeName(name) + w.buf.WriteString(v) +} + +// writeBoolField writes a bool-valued object field. +func (w *jsonWriter) writeBoolField(name string, v bool) { + w.writeName(name) + if v { + w.buf.WriteString("true") + } else { + w.buf.WriteString("false") + } +} + +// writeDurationField writes a millisecond duration-valued object field. +func (w *jsonWriter) writeDurationField(name string, v time.Duration) { + w.writeName(name) + fmt.Fprintf(&w.buf, "%d.%06d", v.Milliseconds(), v%time.Millisecond) +} + +// writeFloat64Field writes a float64-valued object field. +func (w *jsonWriter) writeFloat64Field(name string, v float64) { + w.writeName(name) + w.buf.Write(strconv.AppendFloat(w.buf.AvailableBuffer(), v, 'f', -1, 64)) +} + +// writeInt64Field writes an int64-valued object field. +func (w *jsonWriter) writeInt64Field(name string, v int64) { + w.writeName(name) + w.buf.Write(strconv.AppendInt(w.buf.AvailableBuffer(), v, 10)) +} + +// writeUint64Field writes a uint64-valued object field. +func (w *jsonWriter) writeUint64Field(name string, v uint64) { + w.writeName(name) + w.buf.Write(strconv.AppendUint(w.buf.AvailableBuffer(), v, 10)) +} + +// writeStringField writes a string-valued object field. +func (w *jsonWriter) writeStringField(name, v string) { + w.writeName(name) + w.writeString(v) +} + +// writeTimeField writes a time-valued object field. +func (w *jsonWriter) writeTimeField(name string, v time.Time) { + w.writeName(name) + fmt.Fprintf(&w.buf, "%d.%06d", v.UnixMilli(), v.Nanosecond()%int(time.Millisecond)) +} + +func jsonSafeSet(c byte) bool { + // mask is a 128-bit bitmap with 1s for allowed bytes, + // so that the byte c can be tested with a shift and an and. + // If c > 128, then 1<<c and 1<<(c-64) will both be zero, + // and this function will return false. + const mask = ((1<<(0x7f-0x20) - 1) << 0x20) &^ (1 << '"') &^ (1 << '\\') // printable ASCII except '"' and '\' + return ((uint64(1)<<c)&(mask&(1<<64-1)) | (uint64(1)<<(c-64))&(mask>>64)) != 0 +} + +func jsonNeedsEscape(s string) bool { + for i := range s { + if !jsonSafeSet(s[i]) { + return true + } + } + return false +} + +// writeString writes an ASCII string. +// +// qlog fields should never contain anything that isn't ASCII, +// so we do the bare minimum to avoid producing invalid output if we +// do write something unexpected.
+func (w *jsonWriter) writeString(v string) { + w.buf.WriteByte('"') + if !jsonNeedsEscape(v) { + w.buf.WriteString(v) + } else { + for i := range v { + if jsonSafeSet(v[i]) { + w.buf.WriteByte(v[i]) + } else { + fmt.Fprintf(&w.buf, `\u%04x`, v[i]) + } + } + } + w.buf.WriteByte('"') +} diff --git a/internal/quic/qlog/json_writer_test.go b/internal/quic/qlog/json_writer_test.go new file mode 100644 index 000000000..7ba5e1737 --- /dev/null +++ b/internal/quic/qlog/json_writer_test.go @@ -0,0 +1,186 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package qlog + +import ( + "bytes" + "errors" + "fmt" + "log/slog" + "strings" + "sync" + "testing" + "time" +) + +type testJSONOut struct { + bytes.Buffer +} + +func (o *testJSONOut) Close() error { return nil } + +func newTestJSONWriter() *jsonWriter { + return &jsonWriter{w: &testJSONOut{}} +} + +func wantJSONRecord(t *testing.T, w *jsonWriter, want string) { + t.Helper() + want = "\x1e" + want + "\n" + got := w.w.(*testJSONOut).String() + if got != want { + t.Errorf("jsonWriter contains unexpected output\ngot: %q\nwant: %q", got, want) + } +} + +func TestJSONWriterWriteConcurrentRecords(t *testing.T) { + w := newTestJSONWriter() + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + w.writeRecordStart() + w.writeInt64Field("field", 0) + w.writeRecordEnd() + }() + } + wg.Wait() + wantJSONRecord(t, w, strings.Join([]string{ + `{"field":0}`, + `{"field":0}`, + `{"field":0}`, + }, "\n\x1e")) +} + +func TestJSONWriterAttrs(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeAttrsField("field", []slog.Attr{ + slog.Any("any", errors.New("value")), + slog.Bool("bool", true), + slog.Duration("duration", 1*time.Second), + slog.Float64("float64", 1), + slog.Int64("int64", 1), + slog.String("string", "value"), + slog.Time("time", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), + slog.Uint64("uint64", 1), + slog.Group("group", "a", 1), + }) + w.writeRecordEnd() + wantJSONRecord(t, w, + `{"field":{`+ + `"any":"value",`+ + `"bool":true,`+ + `"duration":1000.000000,`+ + `"float64":1,`+ + `"int64":1,`+ + `"string":"value",`+ + `"time":946684800000.000000,`+ + `"uint64":1,`+ + `"group":{"a":1}`+ + `}}`) +} + +func TestJSONWriterObjectEmpty(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeObjectField("field", func() {}) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":{}}`) +} + +func TestJSONWriterObjectFields(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeObjectField("field", func() { + w.writeStringField("a", "value") + w.writeInt64Field("b", 10) + }) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":{"a":"value","b":10}}`) +} + +func TestJSONWriterRawField(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeRawField("field", `[1]`) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":[1]}`) +} + +func TestJSONWriterBoolField(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeBoolField("true", true) + w.writeBoolField("false", false) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"true":true,"false":false}`) +} + +func TestJSONWriterDurationField(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeDurationField("field", (10*time.Millisecond)+(2*time.Nanosecond)) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":10.000002}`) +} 
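// Illustrative sketch, not part of the change: the jsonWriter framing described
// above is RFC 7464 JSON-SEQ, so every record is one RS byte (0x1e), a single
// JSON object, and a trailing LF. This hypothetical test, reusing the helpers
// defined in this file, shows the framing across two consecutive records; the
// field name and values are arbitrary.
func TestJSONWriterRecordFramingSketch(t *testing.T) {
	w := newTestJSONWriter()
	w.writeRecordStart()
	w.writeStringField("name", "first")
	w.writeRecordEnd()
	w.writeRecordStart()
	w.writeStringField("name", "second")
	w.writeRecordEnd()
	// wantJSONRecord supplies the leading RS and trailing LF for the sequence,
	// so only the separator between the two records appears explicitly here.
	wantJSONRecord(t, w, `{"name":"first"}`+"\n\x1e"+`{"name":"second"}`)
}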
+ +func TestJSONWriterFloat64Field(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeFloat64Field("field", 1.1) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":1.1}`) +} + +func TestJSONWriterInt64Field(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeInt64Field("field", 1234) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":1234}`) +} + +func TestJSONWriterUint64Field(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeUint64Field("field", 1234) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":1234}`) +} + +func TestJSONWriterStringField(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeStringField("field", "value") + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":"value"}`) +} + +func TestJSONWriterStringFieldEscaped(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeStringField("field", "va\x00ue") + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":"va\u0000ue"}`) +} + +func TestJSONWriterStringEscaping(t *testing.T) { + for c := 0; c <= 0xff; c++ { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeStringField("field", string([]byte{byte(c)})) + w.writeRecordEnd() + var want string + if (c >= 0x20 && c <= 0x21) || (c >= 0x23 && c <= 0x5b) || (c >= 0x5d && c <= 0x7e) { + want = fmt.Sprintf(`%c`, c) + } else { + want = fmt.Sprintf(`\u%04x`, c) + } + wantJSONRecord(t, w, `{"field":"`+want+`"}`) + } +} diff --git a/internal/quic/qlog/qlog.go b/internal/quic/qlog/qlog.go new file mode 100644 index 000000000..0e71d71aa --- /dev/null +++ b/internal/quic/qlog/qlog.go @@ -0,0 +1,267 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +// Package qlog serializes qlog events. +package qlog + +import ( + "bytes" + "context" + "errors" + "io" + "log/slog" + "os" + "path/filepath" + "sync" + "time" +) + +// Vantage is the vantage point of a trace. +type Vantage string + +const ( + // VantageEndpoint traces contain events not specific to a single connection. + VantageEndpoint = Vantage("endpoint") + + // VantageClient traces follow a connection from the client's perspective. + VantageClient = Vantage("client") + + // VantageClient traces follow a connection from the server's perspective. + VantageServer = Vantage("server") +) + +// TraceInfo contains information about a trace. +type TraceInfo struct { + // Vantage is the vantage point of the trace. + Vantage Vantage + + // GroupID identifies the logical group the trace belongs to. + // For a connection trace, the group will be the same for + // both the client and server vantage points. + GroupID string +} + +// HandlerOptions are options for a JSONHandler. +type HandlerOptions struct { + // Level reports the minimum record level that will be logged. + // If Level is nil, the handler assumes QLogLevelEndpoint. + Level slog.Leveler + + // Dir is the directory in which to create trace files. + // The handler will create one file per connection. + // If NewTrace is non-nil or Dir is "", the handler will not create files. + Dir string + + // NewTrace is called to create a new trace. + // If NewTrace is nil and Dir is set, + // the handler will create a new file in Dir for each trace. 
+ NewTrace func(TraceInfo) (io.WriteCloser, error) +} + +type endpointHandler struct { + opts HandlerOptions + + traceOnce sync.Once + trace *jsonTraceHandler +} + +// NewJSONHandler returns a handler which serializes qlog events to JSON. +// +// The handler will write an endpoint-wide trace, +// and a separate trace for each connection. +// The HandlerOptions control the location traces are written. +// +// It uses the streamable JSON Text Sequences mapping (JSON-SEQ) +// defined in draft-ietf-quic-qlog-main-schema-04, Section 6.2. +// +// A JSONHandler may be used as the handler for a quic.Config.QLogLogger. +// It is not a general-purpose slog handler, +// and may not properly handle events from other sources. +func NewJSONHandler(opts HandlerOptions) slog.Handler { + if opts.Dir == "" && opts.NewTrace == nil { + return slogDiscard{} + } + return &endpointHandler{ + opts: opts, + } +} + +func (h *endpointHandler) Enabled(ctx context.Context, level slog.Level) bool { + return enabled(h.opts.Level, level) +} + +func (h *endpointHandler) Handle(ctx context.Context, r slog.Record) error { + h.traceOnce.Do(func() { + h.trace, _ = newJSONTraceHandler(h.opts, nil) + }) + if h.trace != nil { + h.trace.Handle(ctx, r) + } + return nil +} + +func (h *endpointHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + // Create a new trace output file for each top-level WithAttrs. + tr, err := newJSONTraceHandler(h.opts, attrs) + if err != nil { + return withAttrs(h, attrs) + } + return tr +} + +func (h *endpointHandler) WithGroup(name string) slog.Handler { + return withGroup(h, name) +} + +type jsonTraceHandler struct { + level slog.Leveler + w jsonWriter + start time.Time + buf bytes.Buffer +} + +func newJSONTraceHandler(opts HandlerOptions, attrs []slog.Attr) (*jsonTraceHandler, error) { + w, err := newTraceWriter(opts, traceInfoFromAttrs(attrs)) + if err != nil { + return nil, err + } + + // For testing, it might be nice to set the start time used for relative timestamps + // to the time of the first event. + // + // At the expense of some additional complexity here, we could defer writing + // the reference_time header field until the first event is processed. + // + // Just use the current time for now. 
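// (The start time also anchors event timestamps: Handle below writes each
// record's "time" field as r.Time.Sub(h.start) in milliseconds, which is what
// the "relative" time_format declared in common_fields refers to.)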
+ start := time.Now() + + h := &jsonTraceHandler{ + w: jsonWriter{w: w}, + level: opts.Level, + start: start, + } + h.writeHeader(attrs) + return h, nil +} + +func traceInfoFromAttrs(attrs []slog.Attr) TraceInfo { + info := TraceInfo{ + Vantage: VantageEndpoint, // default if not specified + } + for _, a := range attrs { + if a.Key == "group_id" && a.Value.Kind() == slog.KindString { + info.GroupID = a.Value.String() + } + if a.Key == "vantage_point" && a.Value.Kind() == slog.KindGroup { + for _, aa := range a.Value.Group() { + if aa.Key == "type" && aa.Value.Kind() == slog.KindString { + info.Vantage = Vantage(aa.Value.String()) + } + } + } + } + return info +} + +func newTraceWriter(opts HandlerOptions, info TraceInfo) (io.WriteCloser, error) { + var w io.WriteCloser + var err error + if opts.NewTrace != nil { + w, err = opts.NewTrace(info) + } else if opts.Dir != "" { + var filename string + if info.GroupID != "" { + filename = info.GroupID + "_" + } + filename += string(info.Vantage) + ".sqlog" + if !filepath.IsLocal(filename) { + return nil, errors.New("invalid trace filename") + } + w, err = os.Create(filepath.Join(opts.Dir, filename)) + } else { + err = errors.New("no log destination") + } + return w, err +} + +func (h *jsonTraceHandler) writeHeader(attrs []slog.Attr) { + h.w.writeRecordStart() + defer h.w.writeRecordEnd() + + // At the time of writing this comment the most recent version is 0.4, + // but qvis only supports up to 0.3. + h.w.writeStringField("qlog_version", "0.3") + h.w.writeStringField("qlog_format", "JSON-SEQ") + + // The attrs flatten both common trace event fields and Trace fields. + // This identifies the fields that belong to the Trace. + isTraceSeqField := func(s string) bool { + switch s { + case "title", "description", "configuration", "vantage_point": + return true + } + return false + } + + h.w.writeObjectField("trace", func() { + h.w.writeObjectField("common_fields", func() { + h.w.writeRawField("protocol_type", `["QUIC"]`) + h.w.writeStringField("time_format", "relative") + h.w.writeTimeField("reference_time", h.start) + for _, a := range attrs { + if !isTraceSeqField(a.Key) { + h.w.writeAttr(a) + } + } + }) + for _, a := range attrs { + if isTraceSeqField(a.Key) { + h.w.writeAttr(a) + } + } + }) +} + +func (h *jsonTraceHandler) Enabled(ctx context.Context, level slog.Level) bool { + return enabled(h.level, level) +} + +func (h *jsonTraceHandler) Handle(ctx context.Context, r slog.Record) error { + h.w.writeRecordStart() + defer h.w.writeRecordEnd() + h.w.writeDurationField("time", r.Time.Sub(h.start)) + h.w.writeStringField("name", r.Message) + h.w.writeObjectField("data", func() { + r.Attrs(func(a slog.Attr) bool { + h.w.writeAttr(a) + return true + }) + }) + return nil +} + +func (h *jsonTraceHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return withAttrs(h, attrs) +} + +func (h *jsonTraceHandler) WithGroup(name string) slog.Handler { + return withGroup(h, name) +} + +func enabled(leveler slog.Leveler, level slog.Level) bool { + var minLevel slog.Level + if leveler != nil { + minLevel = leveler.Level() + } + return level >= minLevel +} + +type slogDiscard struct{} + +func (slogDiscard) Enabled(context.Context, slog.Level) bool { return false } +func (slogDiscard) Handle(ctx context.Context, r slog.Record) error { return nil } +func (slogDiscard) WithAttrs(attrs []slog.Attr) slog.Handler { return slogDiscard{} } +func (slogDiscard) WithGroup(name string) slog.Handler { return slogDiscard{} } diff --git a/internal/quic/qlog/qlog_test.go 
b/internal/quic/qlog/qlog_test.go new file mode 100644 index 000000000..7575cd890 --- /dev/null +++ b/internal/quic/qlog/qlog_test.go @@ -0,0 +1,151 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package qlog + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log/slog" + "reflect" + "testing" + "time" +) + +// QLog tests are mostly in the quic package, where we can test event generation +// and serialization together. + +func TestQLogHandlerEvents(t *testing.T) { + for _, test := range []struct { + name string + f func(*slog.Logger) + want []map[string]any // events, not counting the trace header + }{{ + name: "various types", + f: func(log *slog.Logger) { + log.Info("message", + "bool", true, + "duration", time.Duration(1*time.Second), + "float", 0.0, + "int", 0, + "string", "value", + "uint", uint64(0), + slog.Group("group", + "a", 0, + ), + ) + }, + want: []map[string]any{{ + "name": "message", + "data": map[string]any{ + "bool": true, + "duration": float64(1000), + "float": float64(0.0), + "int": float64(0), + "string": "value", + "uint": float64(0), + "group": map[string]any{ + "a": float64(0), + }, + }, + }}, + }, { + name: "WithAttrs", + f: func(log *slog.Logger) { + log = log.With( + "with_a", "a", + "with_b", "b", + ) + log.Info("m1", "field", "1") + log.Info("m2", "field", "2") + }, + want: []map[string]any{{ + "name": "m1", + "data": map[string]any{ + "with_a": "a", + "with_b": "b", + "field": "1", + }, + }, { + "name": "m2", + "data": map[string]any{ + "with_a": "a", + "with_b": "b", + "field": "2", + }, + }}, + }, { + name: "WithGroup", + f: func(log *slog.Logger) { + log = log.With( + "with_a", "a", + "with_b", "b", + ) + log.Info("m1", "field", "1") + log.Info("m2", "field", "2") + }, + want: []map[string]any{{ + "name": "m1", + "data": map[string]any{ + "with_a": "a", + "with_b": "b", + "field": "1", + }, + }, { + "name": "m2", + "data": map[string]any{ + "with_a": "a", + "with_b": "b", + "field": "2", + }, + }}, + }} { + var out bytes.Buffer + opts := HandlerOptions{ + Level: slog.LevelDebug, + NewTrace: func(TraceInfo) (io.WriteCloser, error) { + return nopCloseWriter{&out}, nil + }, + } + h, err := newJSONTraceHandler(opts, []slog.Attr{ + slog.String("group_id", "group"), + slog.Group("vantage_point", + slog.String("type", "client"), + ), + }) + if err != nil { + t.Fatal(err) + } + log := slog.New(h) + test.f(log) + got := []map[string]any{} + for i, e := range bytes.Split(out.Bytes(), []byte{0x1e}) { + // i==0: empty string before the initial record separator + // i==1: trace header; not part of this test + if i < 2 { + continue + } + var val map[string]any + if err := json.Unmarshal(e, &val); err != nil { + panic(fmt.Errorf("log unmarshal failure: %v\n%q", err, string(e))) + } + delete(val, "time") + got = append(got, val) + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("event mismatch\ngot: %v\nwant: %v", got, test.want) + } + } + +} + +type nopCloseWriter struct { + io.Writer +} + +func (nopCloseWriter) Close() error { return nil } diff --git a/internal/quic/qlog_test.go b/internal/quic/qlog_test.go new file mode 100644 index 000000000..5a2858b8b --- /dev/null +++ b/internal/quic/qlog_test.go @@ -0,0 +1,132 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build go1.21 + +package quic + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log/slog" + "reflect" + "testing" + + "golang.org/x/net/internal/quic/qlog" +) + +func TestQLogHandshake(t *testing.T) { + testSides(t, "", func(t *testing.T, side connSide) { + qr := &qlogRecord{} + tc := newTestConn(t, side, qr.config) + tc.handshake() + tc.conn.Abort(nil) + tc.wantFrame("aborting connection generates CONN_CLOSE", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errNo, + }) + tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{}) + tc.advanceToTimer() // let the conn finish draining + + var src, dst []byte + if side == clientSide { + src = testLocalConnID(0) + dst = testLocalConnID(-1) + } else { + src = testPeerConnID(-1) + dst = testPeerConnID(0) + } + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_started", + "data": map[string]any{ + "src_cid": hex.EncodeToString(src), + "dst_cid": hex.EncodeToString(dst), + }, + }, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": "clean", + }, + }) + }) +} + +type nopCloseWriter struct { + io.Writer +} + +func (nopCloseWriter) Close() error { return nil } + +type jsonEvent map[string]any + +func (j jsonEvent) String() string { + b, _ := json.MarshalIndent(j, "", " ") + return string(b) +} + +// eventPartialEqual verifies that every field set in want matches the corresponding field in got. +// It ignores additional fields in got. +func eventPartialEqual(got, want jsonEvent) bool { + for k := range want { + ge, gok := got[k].(map[string]any) + we, wok := want[k].(map[string]any) + if gok && wok { + if !eventPartialEqual(ge, we) { + return false + } + } else { + if !reflect.DeepEqual(got[k], want[k]) { + return false + } + } + } + return true +} + +// A qlogRecord records events. +type qlogRecord struct { + ev []jsonEvent +} + +func (q *qlogRecord) Write(b []byte) (int, error) { + // This relies on the property that the Handler always makes one Write call per event. + if len(b) < 1 || b[0] != 0x1e { + panic(fmt.Errorf("trace Write should start with record separator, got %q", string(b))) + } + var val map[string]any + if err := json.Unmarshal(b[1:], &val); err != nil { + panic(fmt.Errorf("log unmarshal failure: %v\n%v", err, string(b))) + } + q.ev = append(q.ev, val) + return len(b), nil +} + +func (q *qlogRecord) Close() error { return nil } + +// config may be passed to newTestConn to configure the conn to use this logger. +func (q *qlogRecord) config(c *Config) { + c.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{ + NewTrace: func(info qlog.TraceInfo) (io.WriteCloser, error) { + return q, nil + }, + })) +} + +// wantEvents checks that every event in want occurs in the order specified. +func (q *qlogRecord) wantEvents(t *testing.T, want ...jsonEvent) { + t.Helper() + got := q.ev + unseen := want + for _, g := range got { + if eventPartialEqual(g, unseen[0]) { + unseen = unseen[1:] + if len(unseen) == 0 { + return + } + } + } + t.Fatalf("got events:\n%v\n\nwant events:\n%v", got, want) +} diff --git a/internal/quic/stateless_reset_test.go b/internal/quic/stateless_reset_test.go index 8a16597c4..c01375fbd 100644 --- a/internal/quic/stateless_reset_test.go +++ b/internal/quic/stateless_reset_test.go @@ -130,7 +130,8 @@ func TestStatelessResetSentSizes(t *testing.T) { func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { // "[...] Stateless Reset Token field values from [...] 
NEW_CONNECTION_ID frames [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-1 - tc := newTestConn(t, clientSide) + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, qr.config) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -158,6 +159,13 @@ func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { tc.wantIdle("closed connection is idle in draining") tc.advance(1 * time.Second) // long enough to exit the draining state tc.wantIdle("closed connection is idle after draining") + + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": "stateless_reset", + }, + }) } func TestStatelessResetSuccessfulTransportParameter(t *testing.T) { From d87f99be5d1813013851ce74ed7d22743fa33f21 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 30 Oct 2023 10:51:45 -0700 Subject: [PATCH 02/70] quic: idle timeouts, handshake timeouts, and keepalive Negotiate the connection idle timeout based on the sent and received max_idle_timeout transport parameter values. Set a configurable limit on how long a handshake can take to complete. Add a configuration option to send keep-alive PING frames to avoid connection closure due to the idle timeout. RFC 9000, Section 10.1. For golang/go#58547 Change-Id: If6a611090ab836cd6937fcfbb1360a0f07425102 Reviewed-on: https://go-review.googlesource.com/c/net/+/540895 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/config.go | 36 ++++- internal/quic/conn.go | 33 ++--- internal/quic/conn_close.go | 270 ++++++++++++++++++++++-------------- internal/quic/conn_recv.go | 11 +- internal/quic/conn_send.go | 14 +- internal/quic/conn_test.go | 5 +- internal/quic/idle.go | 170 +++++++++++++++++++++++ internal/quic/idle_test.go | 225 ++++++++++++++++++++++++++++++ internal/quic/loss.go | 9 +- internal/quic/qlog.go | 12 +- internal/quic/qlog_test.go | 70 ++++++++++ internal/quic/quic.go | 6 + 12 files changed, 721 insertions(+), 140 deletions(-) create mode 100644 internal/quic/idle.go create mode 100644 internal/quic/idle_test.go diff --git a/internal/quic/config.go b/internal/quic/config.go index b10ecc79e..b045b7b92 100644 --- a/internal/quic/config.go +++ b/internal/quic/config.go @@ -9,6 +9,8 @@ package quic import ( "crypto/tls" "log/slog" + "math" + "time" ) // A Config structure configures a QUIC endpoint. @@ -74,6 +76,26 @@ type Config struct { // If this field is left as zero, stateless reset is disabled. StatelessResetKey [32]byte + // HandshakeTimeout is the maximum time in which a connection handshake must complete. + // If zero, the default of 10 seconds is used. + // If negative, there is no handshake timeout. + HandshakeTimeout time.Duration + + // MaxIdleTimeout is the maximum time after which an idle connection will be closed. + // If zero, the default of 30 seconds is used. + // If negative, idle connections are never closed. + // + // The idle timeout for a connection is the minimum of the maximum idle timeouts + // of the endpoints. + MaxIdleTimeout time.Duration + + // KeepAlivePeriod is the time after which a packet will be sent to keep + // an idle connection alive. + // If zero, keep alive packets are not sent. + // If greater than zero, the keep alive period is the smaller of KeepAlivePeriod and + // half the connection idle timeout. + KeepAlivePeriod time.Duration + // QLogLogger receives qlog events. // // Events currently correspond to the definitions in draft-ietf-qlog-quic-events-03. 
@@ -85,7 +107,7 @@ type Config struct { QLogLogger *slog.Logger } -func configDefault(v, def, limit int64) int64 { +func configDefault[T ~int64](v, def, limit T) T { switch { case v == 0: return def @@ -115,3 +137,15 @@ func (c *Config) maxStreamWriteBufferSize() int64 { func (c *Config) maxConnReadBufferSize() int64 { return configDefault(c.MaxConnReadBufferSize, 1<<20, maxVarint) } + +func (c *Config) handshakeTimeout() time.Duration { + return configDefault(c.HandshakeTimeout, defaultHandshakeTimeout, math.MaxInt64) +} + +func (c *Config) maxIdleTimeout() time.Duration { + return configDefault(c.MaxIdleTimeout, defaultMaxIdleTimeout, math.MaxInt64) +} + +func (c *Config) keepAlivePeriod() time.Duration { + return configDefault(c.KeepAlivePeriod, defaultKeepAlivePeriod, math.MaxInt64) +} diff --git a/internal/quic/conn.go b/internal/quic/conn.go index cca11166c..b2b6a0877 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -26,22 +26,17 @@ type Conn struct { testHooks connTestHooks peerAddr netip.AddrPort - msgc chan any - donec chan struct{} // closed when conn loop exits - exited bool // set to make the conn loop exit immediately + msgc chan any + donec chan struct{} // closed when conn loop exits w packetWriter acks [numberSpaceCount]ackState // indexed by number space lifetime lifetimeState + idle idleState connIDState connIDState loss lossState streams streamsState - // idleTimeout is the time at which the connection will be closed due to inactivity. - // https://www.rfc-editor.org/rfc/rfc9000#section-10.1 - maxIdleTimeout time.Duration - idleTimeout time.Time - // Packet protection keys, CRYPTO streams, and TLS state. keysInitial fixedKeyPair keysHandshake fixedKeyPair @@ -105,8 +100,6 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip peerAddr: peerAddr, msgc: make(chan any, 1), donec: make(chan struct{}), - maxIdleTimeout: defaultMaxIdleTimeout, - idleTimeout: now.Add(defaultMaxIdleTimeout), peerAckDelayExponent: -1, } defer func() { @@ -151,6 +144,7 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip c.loss.init(c.side, maxDatagramSize, now) c.streamsInit() c.lifetimeInit() + c.restartIdleTimer(now) if err := c.startTLS(now, initialConnID, transportParameters{ initialSrcConnID: c.connIDState.srcConnID(), @@ -202,6 +196,7 @@ func (c *Conn) confirmHandshake(now time.Time) { // don't need to send anything. 
c.handshakeConfirmed.setReceived() } + c.restartIdleTimer(now) c.loss.confirmHandshake() // "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed" // https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1 @@ -232,6 +227,7 @@ func (c *Conn) receiveTransportParameters(p transportParameters) error { c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni + c.receivePeerMaxIdleTimeout(p.maxIdleTimeout) c.peerAckDelayExponent = p.ackDelayExponent c.loss.setMaxAckDelay(p.maxAckDelay) if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil { @@ -248,7 +244,6 @@ func (c *Conn) receiveTransportParameters(p transportParameters) error { return err } } - // TODO: max_idle_timeout // TODO: stateless_reset_token // TODO: max_udp_payload_size // TODO: disable_active_migration @@ -261,6 +256,8 @@ type ( wakeEvent struct{} ) +var errIdleTimeout = errors.New("idle timeout") + // loop is the connection main loop. // // Except where otherwise noted, all connection state is owned by the loop goroutine. @@ -288,14 +285,14 @@ func (c *Conn) loop(now time.Time) { defer timer.Stop() } - for !c.exited { + for c.lifetime.state != connStateDone { sendTimeout := c.maybeSend(now) // try sending // Note that we only need to consider the ack timer for the App Data space, // since the Initial and Handshake spaces always ack immediately. nextTimeout := sendTimeout - nextTimeout = firstTime(nextTimeout, c.idleTimeout) - if !c.isClosingOrDraining() { + nextTimeout = firstTime(nextTimeout, c.idle.nextTimeout) + if c.isAlive() { nextTimeout = firstTime(nextTimeout, c.loss.timer) nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck) } else { @@ -329,11 +326,9 @@ func (c *Conn) loop(now time.Time) { m.recycle() case timerEvent: // A connection timer has expired. - if !now.Before(c.idleTimeout) { - // "[...] the connection is silently closed and - // its state is discarded [...]" - // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1 - c.exited = true + if c.idleAdvance(now) { + // The connection idle timer has expired. + c.abortImmediately(now, errIdleTimeout) return } c.loss.advance(now, c.handleAckOrLoss) diff --git a/internal/quic/conn_close.go b/internal/quic/conn_close.go index a9ef0db5e..246a12638 100644 --- a/internal/quic/conn_close.go +++ b/internal/quic/conn_close.go @@ -12,33 +12,54 @@ import ( "time" ) +// connState is the state of a connection. +type connState int + +const ( + // A connection is alive when it is first created. + connStateAlive = connState(iota) + + // The connection has received a CONNECTION_CLOSE frame from the peer, + // and has not yet sent a CONNECTION_CLOSE in response. + // + // We will send a CONNECTION_CLOSE, and then enter the draining state. + connStatePeerClosed + + // The connection is in the closing state. + // + // We will send CONNECTION_CLOSE frames to the peer + // (once upon entering the closing state, and possibly again in response to peer packets). + // + // If we receive a CONNECTION_CLOSE from the peer, we will enter the draining state. + // Otherwise, we will eventually time out and move to the done state. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.1 + connStateClosing + + // The connection is in the draining state. + // + // We will neither send packets nor process received packets. 
+ // When the drain timer expires, we move to the done state. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.2 + connStateDraining + + // The connection is done, and the conn loop will exit. + connStateDone +) + // lifetimeState tracks the state of a connection. // // This is fairly coupled to the rest of a Conn, but putting it in a struct of its own helps // reason about operations that cause state transitions. type lifetimeState struct { - readyc chan struct{} // closed when TLS handshake completes - drainingc chan struct{} // closed when entering the draining state + state connState + + readyc chan struct{} // closed when TLS handshake completes + donec chan struct{} // closed when finalErr is set - // Possible states for the connection: - // - // Alive: localErr and finalErr are both nil. - // - // Closing: localErr is non-nil and finalErr is nil. - // We have sent a CONNECTION_CLOSE to the peer or are about to - // (if connCloseSentTime is zero) and are waiting for the peer to respond. - // drainEndTime is set to the time the closing state ends. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.1 - // - // Draining: finalErr is non-nil. - // If localErr is nil, we're waiting for the user to provide us with a final status - // to send to the peer. - // Otherwise, we've either sent a CONNECTION_CLOSE to the peer or are about to - // (if connCloseSentTime is zero). - // drainEndTime is set to the time the draining state ends. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 localErr error // error sent to the peer - finalErr error // error sent by the peer, or transport error; always set before draining + finalErr error // error sent by the peer, or transport error; set before closing donec connCloseSentTime time.Time // send time of last CONNECTION_CLOSE frame connCloseDelay time.Duration // delay until next CONNECTION_CLOSE frame sent @@ -47,7 +68,7 @@ type lifetimeState struct { func (c *Conn) lifetimeInit() { c.lifetime.readyc = make(chan struct{}) - c.lifetime.drainingc = make(chan struct{}) + c.lifetime.donec = make(chan struct{}) } var errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE") @@ -60,13 +81,25 @@ func (c *Conn) lifetimeAdvance(now time.Time) (done bool) { // The connection drain period has ended, and we can shut down. // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-7 c.lifetime.drainEndTime = time.Time{} - if c.lifetime.finalErr == nil { - // The peer never responded to our CONNECTION_CLOSE. - c.enterDraining(now, errNoPeerResponse) + if c.lifetime.state != connStateDraining { + // We were in the closing state, waiting for a CONNECTION_CLOSE from the peer. + c.setFinalError(errNoPeerResponse) } + c.setState(now, connStateDone) return true } +// setState sets the conn state. +func (c *Conn) setState(now time.Time, state connState) { + switch state { + case connStateClosing, connStateDraining: + if c.lifetime.drainEndTime.IsZero() { + c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) + } + } + c.lifetime.state = state +} + // confirmHandshake is called when the TLS handshake completes. 
func (c *Conn) handshakeDone() { close(c.lifetime.readyc) @@ -81,44 +114,66 @@ func (c *Conn) handshakeDone() { // // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 func (c *Conn) isDraining() bool { - return c.lifetime.finalErr != nil + switch c.lifetime.state { + case connStateDraining, connStateDone: + return true + } + return false } -// isClosingOrDraining reports whether the conn is in the closing or draining states. -func (c *Conn) isClosingOrDraining() bool { - return c.lifetime.localErr != nil || c.lifetime.finalErr != nil +// isAlive reports whether the conn is handling packets. +func (c *Conn) isAlive() bool { + return c.lifetime.state == connStateAlive } // sendOK reports whether the conn can send frames at this time. func (c *Conn) sendOK(now time.Time) bool { - if !c.isClosingOrDraining() { + switch c.lifetime.state { + case connStateAlive: return true - } - // We are closing or draining. - if c.lifetime.localErr == nil { - // We're waiting for the user to close the connection, providing us with - // a final status to send to the peer. + case connStatePeerClosed: + if c.lifetime.localErr == nil { + // We're waiting for the user to close the connection, providing us with + // a final status to send to the peer. + return false + } + // We should send a CONNECTION_CLOSE. + return true + case connStateClosing: + if c.lifetime.connCloseSentTime.IsZero() { + return true + } + maxRecvTime := c.acks[initialSpace].maxRecvTime + if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) { + maxRecvTime = t + } + if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) { + maxRecvTime = t + } + if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) { + // After sending CONNECTION_CLOSE, ignore packets from the peer for + // a delay. On the next packet received after the delay, send another + // CONNECTION_CLOSE. + return false + } + return true + case connStateDraining: + // We are in the draining state, and will send no more packets. return false + case connStateDone: + return false + default: + panic("BUG: unhandled connection state") } - // Past this point, returning true will result in the conn sending a CONNECTION_CLOSE - // due to localErr being set. - if c.lifetime.drainEndTime.IsZero() { - // The closing and draining states should last for at least three times - // the current PTO interval. We currently use exactly that minimum. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-5 - // - // The drain period begins when we send or receive a CONNECTION_CLOSE, - // whichever comes first. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2-3 - c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) +} + +// sendConnectionClose reports that the conn has sent a CONNECTION_CLOSE to the peer. +func (c *Conn) sentConnectionClose(now time.Time) { + switch c.lifetime.state { + case connStatePeerClosed: + c.enterDraining(now) } if c.lifetime.connCloseSentTime.IsZero() { - // We haven't sent a CONNECTION_CLOSE yet. Do so. - // Either we're initiating an immediate close - // (and will enter the closing state as soon as we send CONNECTION_CLOSE), - // or we've read a CONNECTION_CLOSE from our peer - // (and may send one CONNECTION_CLOSE before entering the draining state). - // // Set the initial delay before we will send another CONNECTION_CLOSE. 
// // RFC 9000 states that we should rate limit CONNECTION_CLOSE frames, @@ -126,65 +181,56 @@ func (c *Conn) sendOK(now time.Time) bool { // with the same delay as the PTO timer (RFC 9002, Section 6.2.1), // not including max_ack_delay, and double it on every CONNECTION_CLOSE sent. c.lifetime.connCloseDelay = c.loss.rtt.smoothedRTT + max(4*c.loss.rtt.rttvar, timerGranularity) - c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) - return true - } - if c.isDraining() { - // We are in the draining state, and will send no more packets. - return false - } - maxRecvTime := c.acks[initialSpace].maxRecvTime - if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) { - maxRecvTime = t - } - if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) { - maxRecvTime = t - } - if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) { - // After sending CONNECTION_CLOSE, ignore packets from the peer for - // a delay. On the next packet received after the delay, send another - // CONNECTION_CLOSE. - return false + } else if !c.lifetime.connCloseSentTime.Equal(now) { + // If connCloseSentTime == now, we're sending two CONNECTION_CLOSE frames + // coalesced into the same datagram. We only want to increase the delay once. + c.lifetime.connCloseDelay *= 2 } c.lifetime.connCloseSentTime = now - c.lifetime.connCloseDelay *= 2 - return true } -// enterDraining enters the draining state. -func (c *Conn) enterDraining(now time.Time, err error) { - if c.isDraining() { - return +// handlePeerConnectionClose handles a CONNECTION_CLOSE from the peer. +func (c *Conn) handlePeerConnectionClose(now time.Time, err error) { + c.setFinalError(err) + switch c.lifetime.state { + case connStateAlive: + c.setState(now, connStatePeerClosed) + case connStatePeerClosed: + // Duplicate CONNECTION_CLOSE, ignore. + case connStateClosing: + if c.lifetime.connCloseSentTime.IsZero() { + c.setState(now, connStatePeerClosed) + } else { + c.setState(now, connStateDraining) + } + case connStateDraining: + case connStateDone: } - if err == errStatelessReset { - // If we've received a stateless reset, then we must not send a CONNECTION_CLOSE. - // Setting connCloseSentTime here prevents us from doing so. - c.lifetime.finalErr = errStatelessReset - c.lifetime.localErr = errStatelessReset - c.lifetime.connCloseSentTime = now - } else if e, ok := c.lifetime.localErr.(localTransportError); ok && e.code != errNo { - // If we've terminated the connection due to a peer protocol violation, - // record the final error on the connection as our reason for termination. - c.lifetime.finalErr = c.lifetime.localErr - } else { - c.lifetime.finalErr = err +} + +// setFinalError records the final connection status we report to the user. 
+func (c *Conn) setFinalError(err error) { + select { + case <-c.lifetime.donec: + return // already set + default: } - close(c.lifetime.drainingc) - c.streams.queue.close(c.lifetime.finalErr) + c.lifetime.finalErr = err + close(c.lifetime.donec) } func (c *Conn) waitReady(ctx context.Context) error { select { case <-c.lifetime.readyc: return nil - case <-c.lifetime.drainingc: + case <-c.lifetime.donec: return c.lifetime.finalErr default: } select { case <-c.lifetime.readyc: return nil - case <-c.lifetime.drainingc: + case <-c.lifetime.donec: return c.lifetime.finalErr case <-ctx.Done(): return ctx.Err() @@ -199,7 +245,7 @@ func (c *Conn) waitReady(ctx context.Context) error { // err := conn.Wait(context.Background()) func (c *Conn) Close() error { c.Abort(nil) - <-c.lifetime.drainingc + <-c.lifetime.donec return c.lifetime.finalErr } @@ -213,7 +259,7 @@ func (c *Conn) Close() error { // containing the peer's error code and reason. // If the peer closes the connection with any other status, Wait returns a non-nil error. func (c *Conn) Wait(ctx context.Context) error { - if err := c.waitOnDone(ctx, c.lifetime.drainingc); err != nil { + if err := c.waitOnDone(ctx, c.lifetime.donec); err != nil { return err } return c.lifetime.finalErr @@ -229,30 +275,46 @@ func (c *Conn) Abort(err error) { err = localTransportError{code: errNo} } c.sendMsg(func(now time.Time, c *Conn) { - c.abort(now, err) + c.enterClosing(now, err) }) } // abort terminates a connection with an error. func (c *Conn) abort(now time.Time, err error) { - if c.lifetime.localErr != nil { - return // already closing - } - c.lifetime.localErr = err + c.setFinalError(err) // this error takes precedence over the peer's CONNECTION_CLOSE + c.enterClosing(now, err) } // abortImmediately terminates a connection. // The connection does not send a CONNECTION_CLOSE, and skips the draining period. func (c *Conn) abortImmediately(now time.Time, err error) { - c.abort(now, err) - c.enterDraining(now, err) - c.exited = true + c.setFinalError(err) + c.setState(now, connStateDone) +} + +// enterClosing starts an immediate close. +// We will send a CONNECTION_CLOSE to the peer and wait for their response. +func (c *Conn) enterClosing(now time.Time, err error) { + switch c.lifetime.state { + case connStateAlive: + c.lifetime.localErr = err + c.setState(now, connStateClosing) + case connStatePeerClosed: + c.lifetime.localErr = err + } +} + +// enterDraining moves directly to the draining state, without sending a CONNECTION_CLOSE. +func (c *Conn) enterDraining(now time.Time) { + switch c.lifetime.state { + case connStateAlive, connStatePeerClosed, connStateClosing: + c.setState(now, connStateDraining) + } } // exit fully terminates a connection immediately. func (c *Conn) exit() { c.sendMsg(func(now time.Time, c *Conn) { - c.enterDraining(now, errors.New("connection closed")) - c.exited = true + c.abortImmediately(now, errors.New("connection closed")) }) } diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index 896c6d74e..156ef5dd5 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -61,7 +61,7 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { // Invalid data at the end of a datagram is ignored. 
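// Illustrative sketch (not part of this patch): the lifetime rework above
// replaces the localErr/finalErr flags with an explicit state machine.
// Roughly:
//
//	alive --local close--> closing --peer CONNECTION_CLOSE (after ours sent)--> draining --> done
//	alive --peer CONNECTION_CLOSE--> peerClosed --we send CONNECTION_CLOSE--> draining --> done
//
// A hypothetical transition helper for the peer-close event, mirroring
// handlePeerConnectionClose; only the state names come from the patch:
func onPeerClose(state connState, closeSent bool) connState {
	switch state {
	case connStateAlive:
		return connStatePeerClosed // wait for the application to supply a final status
	case connStateClosing:
		if closeSent {
			return connStateDraining // we already answered; stop sending
		}
		return connStatePeerClosed
	default:
		return state // duplicate or late CONNECTION_CLOSE: no change
	}
}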
break } - c.idleTimeout = now.Add(c.maxIdleTimeout) + c.idleHandlePacketReceived(now) buf = buf[n:] } } @@ -525,7 +525,7 @@ func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte if n < 0 { return -1 } - c.enterDraining(now, peerTransportError{code: code, reason: reason}) + c.handlePeerConnectionClose(now, peerTransportError{code: code, reason: reason}) return n } @@ -534,7 +534,7 @@ func (c *Conn) handleConnectionCloseApplicationFrame(now time.Time, payload []by if n < 0 { return -1 } - c.enterDraining(now, &ApplicationError{Code: code, Reason: reason}) + c.handlePeerConnectionClose(now, &ApplicationError{Code: code, Reason: reason}) return n } @@ -548,7 +548,7 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa }) return -1 } - if !c.isClosingOrDraining() { + if c.isAlive() { c.confirmHandshake(now) } return 1 @@ -560,5 +560,6 @@ func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToke if !c.connIDState.isValidStatelessResetToken(resetToken) { return } - c.enterDraining(now, errStatelessReset) + c.setFinalError(errStatelessReset) + c.enterDraining(now) } diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 22e780479..e45dc8af3 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -77,6 +77,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p) if sentInitial != nil { + c.idleHandlePacketSent(now, sentInitial) // Client initial packets and ack-eliciting server initial packaets // need to be sent in a datagram padded to at least 1200 bytes. // We can't add the padding yet, however, since we may want to @@ -104,6 +105,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload()) } if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil { + c.idleHandlePacketSent(now, sent) c.loss.packetSent(now, handshakeSpace, sent) if c.side == clientSide { // "[...] a client MUST discard Initial keys when it first @@ -131,6 +133,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload()) } if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil { + c.idleHandlePacketSent(now, sent) c.loss.packetSent(now, appDataSpace, sent) } } @@ -261,6 +264,10 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, if !c.appendStreamFrames(&c.w, pnum, pto) { return } + + if !c.appendKeepAlive(now) { + return + } } // If this is a PTO probe and we haven't added an ack-eliciting frame yet, @@ -325,7 +332,7 @@ func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool { } func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err error) { - c.lifetime.connCloseSentTime = now + c.sentConnectionClose(now) switch e := err.(type) { case localTransportError: c.w.appendConnectionCloseTransportFrame(e.code, 0, e.reason) @@ -342,11 +349,12 @@ func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err // TLS alerts are sent using error codes [0x0100,0x01ff). // https://www.rfc-editor.org/rfc/rfc9000#section-20.1-2.36.1 var alert tls.AlertError - if errors.As(err, &alert) { + switch { + case errors.As(err, &alert): // tls.AlertError is a uint8, so this can't exceed 0x01ff. 
code := errTLSBase + transportError(alert) c.w.appendConnectionCloseTransportFrame(code, 0, "") - } else { + default: c.w.appendConnectionCloseTransportFrame(errInternal, 0, "") } } diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index 514a8775e..70ba7b392 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -25,6 +25,7 @@ var testVV = flag.Bool("vv", false, "even more verbose test output") func TestConnTestConn(t *testing.T) { tc := newTestConn(t, serverSide) + tc.handshake() if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want { t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want) } @@ -49,8 +50,8 @@ func TestConnTestConn(t *testing.T) { tc.wait() tc.advanceToTimer() - if !tc.conn.exited { - t.Errorf("after advancing to idle timeout, exited = false, want true") + if got := tc.conn.lifetime.state; got != connStateDone { + t.Errorf("after advancing to idle timeout, conn state = %v, want done", got) } } diff --git a/internal/quic/idle.go b/internal/quic/idle.go new file mode 100644 index 000000000..f5b2422ad --- /dev/null +++ b/internal/quic/idle.go @@ -0,0 +1,170 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "time" +) + +// idleState tracks connection idle events. +// +// Before the handshake is confirmed, the idle timeout is Config.HandshakeTimeout. +// +// After the handshake is confirmed, the idle timeout is +// the minimum of Config.MaxIdleTimeout and the peer's max_idle_timeout transport parameter. +// +// If KeepAlivePeriod is set, keep-alive pings are sent. +// Keep-alives are only sent after the handshake is confirmed. +// +// https://www.rfc-editor.org/rfc/rfc9000#section-10.1 +type idleState struct { + // idleDuration is the negotiated idle timeout for the connection. + idleDuration time.Duration + + // idleTimeout is the time at which the connection will be closed due to inactivity. + idleTimeout time.Time + + // nextTimeout is the time of the next idle event. + // If nextTimeout == idleTimeout, this is the idle timeout. + // Otherwise, this is the keep-alive timeout. + nextTimeout time.Time + + // sentSinceLastReceive is set if we have sent an ack-eliciting packet + // since the last time we received and processed a packet from the peer. + sentSinceLastReceive bool +} + +// receivePeerMaxIdleTimeout handles the peer's max_idle_timeout transport parameter. +func (c *Conn) receivePeerMaxIdleTimeout(peerMaxIdleTimeout time.Duration) { + localMaxIdleTimeout := c.config.maxIdleTimeout() + switch { + case localMaxIdleTimeout == 0: + c.idle.idleDuration = peerMaxIdleTimeout + case peerMaxIdleTimeout == 0: + c.idle.idleDuration = localMaxIdleTimeout + default: + c.idle.idleDuration = min(localMaxIdleTimeout, peerMaxIdleTimeout) + } +} + +func (c *Conn) idleHandlePacketReceived(now time.Time) { + if !c.handshakeConfirmed.isSet() { + return + } + // "An endpoint restarts its idle timer when a packet from its peer is + // received and processed successfully." + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3 + c.idle.sentSinceLastReceive = false + c.restartIdleTimer(now) +} + +func (c *Conn) idleHandlePacketSent(now time.Time, sent *sentPacket) { + // "An endpoint also restarts its idle timer when sending an ack-eliciting packet + // if no other ack-eliciting packets have been sent since + // last receiving and processing a packet." 
+ // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3 + if c.idle.sentSinceLastReceive || !sent.ackEliciting || !c.handshakeConfirmed.isSet() { + return + } + c.idle.sentSinceLastReceive = true + c.restartIdleTimer(now) +} + +func (c *Conn) restartIdleTimer(now time.Time) { + if !c.isAlive() { + // Connection is closing, disable timeouts. + c.idle.idleTimeout = time.Time{} + c.idle.nextTimeout = time.Time{} + return + } + var idleDuration time.Duration + if c.handshakeConfirmed.isSet() { + idleDuration = c.idle.idleDuration + } else { + idleDuration = c.config.handshakeTimeout() + } + if idleDuration == 0 { + c.idle.idleTimeout = time.Time{} + } else { + // "[...] endpoints MUST increase the idle timeout period to be + // at least three times the current Probe Timeout (PTO)." + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-4 + idleDuration = max(idleDuration, 3*c.loss.ptoPeriod()) + c.idle.idleTimeout = now.Add(idleDuration) + } + // Set the time of our next event: + // The idle timer if no keep-alive is set, or the keep-alive timer if one is. + c.idle.nextTimeout = c.idle.idleTimeout + keepAlive := c.config.keepAlivePeriod() + switch { + case !c.handshakeConfirmed.isSet(): + // We do not send keep-alives before the handshake is complete. + case keepAlive <= 0: + // Keep-alives are not enabled. + case c.idle.sentSinceLastReceive: + // We have sent an ack-eliciting packet to the peer. + // If they don't acknowledge it, loss detection will follow up with PTO probes, + // which will function as keep-alives. + // We don't need to send further pings. + case idleDuration == 0: + // The connection does not have a negotiated idle timeout. + // Send keep-alives anyway, since they may be required to keep middleboxes + // from losing state. + c.idle.nextTimeout = now.Add(keepAlive) + default: + // Schedule our next keep-alive. + // If our configured keep-alive period is greater than half the negotiated + // connection idle timeout, we reduce the keep-alive period to half + // the idle timeout to ensure we have time for the ping to arrive. + c.idle.nextTimeout = now.Add(min(keepAlive, idleDuration/2)) + } +} + +func (c *Conn) appendKeepAlive(now time.Time) bool { + if c.idle.nextTimeout.IsZero() || c.idle.nextTimeout.After(now) { + return true // timer has not expired + } + if c.idle.nextTimeout.Equal(c.idle.idleTimeout) { + return true // no keepalive timer set, only idle + } + if c.idle.sentSinceLastReceive { + return true // already sent an ack-eliciting packet + } + if c.w.sent.ackEliciting { + return true // this packet is already ack-eliciting + } + // Send an ack-eliciting PING frame to the peer to keep the connection alive. + return c.w.appendPingFrame() +} + +var errHandshakeTimeout error = localTransportError{ + code: errConnectionRefused, + reason: "handshake timeout", +} + +func (c *Conn) idleAdvance(now time.Time) (shouldExit bool) { + if c.idle.idleTimeout.IsZero() || now.Before(c.idle.idleTimeout) { + return false + } + c.idle.idleTimeout = time.Time{} + c.idle.nextTimeout = time.Time{} + if !c.handshakeConfirmed.isSet() { + // Handshake timeout has expired. + // If we're a server, we're refusing the too-slow client. + // If we're a client, we're giving up. + // In either case, we're going to send a CONNECTION_CLOSE frame and + // enter the closing state rather than unceremoniously dropping the connection, + // since the peer might still be trying to complete the handshake. + c.abort(now, errHandshakeTimeout) + return false + } + // Idle timeout has expired. 
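// Illustrative sketch (not part of this patch): the keep-alive scheduling rule
// in restartIdleTimer above, pulled out as a standalone helper. It covers only
// the common case (handshake confirmed, no ack-eliciting data already in
// flight); the name and signature are hypothetical, and keepAlive/idleDuration
// correspond to Config.KeepAlivePeriod and the negotiated idle timeout.
func nextKeepAliveTime(now time.Time, keepAlive, idleDuration time.Duration) time.Time {
	if keepAlive <= 0 {
		return time.Time{} // keep-alives disabled (the default)
	}
	if idleDuration == 0 {
		// No negotiated idle timeout: ping anyway to preserve middlebox state.
		return now.Add(keepAlive)
	}
	// Never wait longer than half the idle timeout, so the ping has time to
	// arrive and reset the peer's idle timer before it expires.
	return now.Add(min(keepAlive, idleDuration/2))
}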
+ // + // "[...] the connection is silently closed and its state is discarded [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1 + return true +} diff --git a/internal/quic/idle_test.go b/internal/quic/idle_test.go new file mode 100644 index 000000000..18f6a690a --- /dev/null +++ b/internal/quic/idle_test.go @@ -0,0 +1,225 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "crypto/tls" + "fmt" + "testing" + "time" +) + +func TestHandshakeTimeoutExpiresServer(t *testing.T) { + const timeout = 5 * time.Second + tc := newTestConn(t, serverSide, func(c *Config) { + c.HandshakeTimeout = timeout + }) + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypeNewConnectionID) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + // Server starts its end of the handshake. + // Client acks these packets to avoid starting the PTO timer. + tc.wantFrameType("server sends Initial CRYPTO flight", + packetTypeInitial, debugFrameCrypto{}) + tc.writeAckForAll() + tc.wantFrameType("server sends Handshake CRYPTO flight", + packetTypeHandshake, debugFrameCrypto{}) + tc.writeAckForAll() + + if got, want := tc.timerDelay(), timeout; got != want { + t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want) + } + + // Client sends a packet, but this does not extend the handshake timer. + tc.advance(1 * time.Second) + tc.writeFrames(packetTypeHandshake, debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][:1], // partial data + }) + tc.wantIdle("handshake is not complete") + + tc.advance(timeout - 1*time.Second) + tc.wantFrame("server closes connection after handshake timeout", + packetTypeHandshake, debugFrameConnectionCloseTransport{ + code: errConnectionRefused, + }) +} + +func TestHandshakeTimeoutExpiresClient(t *testing.T) { + const timeout = 5 * time.Second + tc := newTestConn(t, clientSide, func(c *Config) { + c.HandshakeTimeout = timeout + }) + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypeNewConnectionID) + // Start the handshake. + // The client always sets a PTO timer until it gets an ack for a handshake packet + // or confirms the handshake, so proceed far enough through the handshake to + // let us not worry about PTO. 
+ tc.wantFrameType("client sends Initial CRYPTO flight", + packetTypeInitial, debugFrameCrypto{}) + tc.writeAckForAll() + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrameType("client sends Handshake CRYPTO flight", + packetTypeHandshake, debugFrameCrypto{}) + tc.writeAckForAll() + tc.wantIdle("client is waiting for end of handshake") + + if got, want := tc.timerDelay(), timeout; got != want { + t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want) + } + tc.advance(timeout) + tc.wantFrame("client closes connection after handshake timeout", + packetTypeHandshake, debugFrameConnectionCloseTransport{ + code: errConnectionRefused, + }) +} + +func TestIdleTimeoutExpires(t *testing.T) { + for _, test := range []struct { + localMaxIdleTimeout time.Duration + peerMaxIdleTimeout time.Duration + wantTimeout time.Duration + }{{ + localMaxIdleTimeout: 10 * time.Second, + peerMaxIdleTimeout: 20 * time.Second, + wantTimeout: 10 * time.Second, + }, { + localMaxIdleTimeout: 20 * time.Second, + peerMaxIdleTimeout: 10 * time.Second, + wantTimeout: 10 * time.Second, + }, { + localMaxIdleTimeout: 0, + peerMaxIdleTimeout: 10 * time.Second, + wantTimeout: 10 * time.Second, + }, { + localMaxIdleTimeout: 10 * time.Second, + peerMaxIdleTimeout: 0, + wantTimeout: 10 * time.Second, + }} { + name := fmt.Sprintf("local=%v/peer=%v", test.localMaxIdleTimeout, test.peerMaxIdleTimeout) + t.Run(name, func(t *testing.T) { + tc := newTestConn(t, serverSide, func(p *transportParameters) { + p.maxIdleTimeout = test.peerMaxIdleTimeout + }, func(c *Config) { + c.MaxIdleTimeout = test.localMaxIdleTimeout + }) + tc.handshake() + if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want { + t.Errorf("new conn timeout=%v, want %v (idle timeout)", got, want) + } + tc.advance(test.wantTimeout - 1) + tc.wantIdle("connection is idle and alive prior to timeout") + ctx := canceledContext() + if err := tc.conn.Wait(ctx); err != context.Canceled { + t.Fatalf("conn.Wait() = %v, want Canceled", err) + } + tc.advance(1) + tc.wantIdle("connection exits after timeout") + if err := tc.conn.Wait(ctx); err != errIdleTimeout { + t.Fatalf("conn.Wait() = %v, want errIdleTimeout", err) + } + }) + } +} + +func TestIdleTimeoutKeepAlive(t *testing.T) { + for _, test := range []struct { + idleTimeout time.Duration + keepAlive time.Duration + wantTimeout time.Duration + }{{ + idleTimeout: 30 * time.Second, + keepAlive: 10 * time.Second, + wantTimeout: 10 * time.Second, + }, { + idleTimeout: 10 * time.Second, + keepAlive: 30 * time.Second, + wantTimeout: 5 * time.Second, + }, { + idleTimeout: -1, // disabled + keepAlive: 30 * time.Second, + wantTimeout: 30 * time.Second, + }} { + name := fmt.Sprintf("idle_timeout=%v/keepalive=%v", test.idleTimeout, test.keepAlive) + t.Run(name, func(t *testing.T) { + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxIdleTimeout = test.idleTimeout + c.KeepAlivePeriod = test.keepAlive + }) + tc.handshake() + if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want { + t.Errorf("new conn timeout=%v, want %v (keepalive timeout)", got, want) + } + tc.advance(test.wantTimeout - 1) + tc.wantIdle("connection is idle prior to timeout") + tc.advance(1) + tc.wantFrameType("keep-alive ping is sent", packetType1RTT, + debugFramePing{}) + }) + } +} + +func TestIdleLongTermKeepAliveSent(t 
*testing.T) { + // This test examines a connection sitting idle and sending periodic keep-alive pings. + const keepAlivePeriod = 30 * time.Second + tc := newTestConn(t, clientSide, func(c *Config) { + c.KeepAlivePeriod = keepAlivePeriod + c.MaxIdleTimeout = -1 + }) + tc.handshake() + // The handshake will have completed a little bit after the point at which the + // keepalive timer was set. Send two PING frames to the conn, triggering an immediate ack + // and resetting the timer. + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.wantFrameType("conn acks received pings", packetType1RTT, debugFrameAck{}) + for i := 0; i < 10; i++ { + tc.wantIdle("conn has nothing more to send") + if got, want := tc.timeUntilEvent(), keepAlivePeriod; got != want { + t.Errorf("i=%v conn timeout=%v, want %v (keepalive timeout)", i, got, want) + } + tc.advance(keepAlivePeriod) + tc.wantFrameType("keep-alive ping is sent", packetType1RTT, + debugFramePing{}) + tc.writeAckForAll() + } +} + +func TestIdleLongTermKeepAliveReceived(t *testing.T) { + // This test examines a connection sitting idle, but receiving periodic peer + // traffic to keep the connection alive. + const idleTimeout = 30 * time.Second + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxIdleTimeout = idleTimeout + }) + tc.handshake() + for i := 0; i < 10; i++ { + tc.advance(idleTimeout - 1*time.Second) + tc.writeFrames(packetType1RTT, debugFramePing{}) + if got, want := tc.timeUntilEvent(), maxAckDelay-timerGranularity; got != want { + t.Errorf("i=%v conn timeout=%v, want %v (max_ack_delay)", i, got, want) + } + tc.advanceToTimer() + tc.wantFrameType("conn acks received ping", packetType1RTT, debugFrameAck{}) + } + // Connection is still alive. + ctx := canceledContext() + if err := tc.conn.Wait(ctx); err != context.Canceled { + t.Fatalf("conn.Wait() = %v, want Canceled", err) + } +} diff --git a/internal/quic/loss.go b/internal/quic/loss.go index c0f915b42..4a0767bd0 100644 --- a/internal/quic/loss.go +++ b/internal/quic/loss.go @@ -431,12 +431,15 @@ func (c *lossState) scheduleTimer(now time.Time) { c.timer = time.Time{} return } - // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 - pto := c.ptoBasePeriod() << c.ptoBackoffCount - c.timer = last.Add(pto) + c.timer = last.Add(c.ptoPeriod()) c.ptoTimerArmed = true } +func (c *lossState) ptoPeriod() time.Duration { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 + return c.ptoBasePeriod() << c.ptoBackoffCount +} + func (c *lossState) ptoBasePeriod() time.Duration { // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 pto := c.rtt.smoothedRTT + max(4*c.rtt.rttvar, timerGranularity) diff --git a/internal/quic/qlog.go b/internal/quic/qlog.go index 29875693e..c8ee429fe 100644 --- a/internal/quic/qlog.go +++ b/internal/quic/qlog.go @@ -119,8 +119,13 @@ func (c *Conn) logConnectionClosed() { // TODO: Distinguish between peer and locally-initiated close. 
trigger = "application" case localTransportError: - if e.code == errNo { - trigger = "clean" + switch err { + case errHandshakeTimeout: + trigger = "handshake_timeout" + default: + if e.code == errNo { + trigger = "clean" + } } case peerTransportError: if e.code == errNo { @@ -128,10 +133,11 @@ func (c *Conn) logConnectionClosed() { } default: switch err { + case errIdleTimeout: + trigger = "idle_timeout" case errStatelessReset: trigger = "stateless_reset" } - // TODO: idle_timeout, handshake_timeout } // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.3 c.log.LogAttrs(context.Background(), QLogLevelEndpoint, diff --git a/internal/quic/qlog_test.go b/internal/quic/qlog_test.go index 5a2858b8b..119f5d16a 100644 --- a/internal/quic/qlog_test.go +++ b/internal/quic/qlog_test.go @@ -14,6 +14,7 @@ import ( "log/slog" "reflect" "testing" + "time" "golang.org/x/net/internal/quic/qlog" ) @@ -54,6 +55,75 @@ func TestQLogHandshake(t *testing.T) { }) } +func TestQLogConnectionClosedTrigger(t *testing.T) { + for _, test := range []struct { + trigger string + connOpts []any + f func(*testConn) + }{{ + trigger: "clean", + f: func(tc *testConn) { + tc.handshake() + tc.conn.Abort(nil) + }, + }, { + trigger: "handshake_timeout", + connOpts: []any{ + func(c *Config) { + c.HandshakeTimeout = 5 * time.Second + }, + }, + f: func(tc *testConn) { + tc.ignoreFrame(frameTypeCrypto) + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypePing) + tc.advance(5 * time.Second) + }, + }, { + trigger: "idle_timeout", + connOpts: []any{ + func(c *Config) { + c.MaxIdleTimeout = 5 * time.Second + }, + }, + f: func(tc *testConn) { + tc.handshake() + tc.advance(5 * time.Second) + }, + }, { + trigger: "error", + f: func(tc *testConn) { + tc.handshake() + tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{ + code: errProtocolViolation, + }) + tc.conn.Abort(nil) + }, + }} { + t.Run(test.trigger, func(t *testing.T) { + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, append(test.connOpts, qr.config)...) + test.f(tc) + fr, ptype := tc.readFrame() + switch fr := fr.(type) { + case debugFrameConnectionCloseTransport: + tc.writeFrames(ptype, fr) + case nil: + default: + t.Fatalf("unexpected frame: %v", fr) + } + tc.wantIdle("connection should be idle while closing") + tc.advance(5 * time.Second) // long enough for the drain timer to expire + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": test.trigger, + }, + }) + }) + } +} + type nopCloseWriter struct { io.Writer } diff --git a/internal/quic/quic.go b/internal/quic/quic.go index 084887be6..6b60db869 100644 --- a/internal/quic/quic.go +++ b/internal/quic/quic.go @@ -54,6 +54,12 @@ const ( maxPeerActiveConnIDLimit = 4 ) +// Time limit for completing the handshake. +const defaultHandshakeTimeout = 10 * time.Second + +// Keep-alive ping frequency. +const defaultKeepAlivePeriod = 0 + // Local timer granularity. // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2-6 const timerGranularity = 1 * time.Millisecond From 399218d6bcdde008df7f43cf82a92b69e842c049 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 30 Oct 2023 15:07:39 -0700 Subject: [PATCH 03/70] quic: implement stream flush Do not commit data written to a stream to the network until the user explicitly flushes the stream, the stream output buffer fills, or the output buffer contains enough data to fill a packet. 
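As a sketch of the resulting usage (hypothetical caller code, relying only on the Write and Flush methods this change adds to Stream): a caller that assembles a message from several Write calls commits it with a single Flush, so the pieces can share a packet.

	// writeMessage buffers a header and body on the stream and commits both
	// with one flush; until Flush is called, nothing is handed to the network.
	func writeMessage(s *quic.Stream, hdr, body []byte) error {
		if _, err := s.Write(hdr); err != nil { // buffered only
			return err
		}
		if _, err := s.Write(body); err != nil { // still buffered
			return err
		}
		s.Flush() // queue hdr+body for the next outgoing packet
		return nil
	}

Writes also flush automatically when the buffered data would fill a packet, when the stream's write buffer is full, or when it consumes the peer's flow control window; see the shouldFlush conditions added to stream.go below.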
We could write data immediately (as net.TCPConn does), but this can require the user to put their own buffer in front of the stream. Since we necessarily need to maintain a retransmit buffer in the stream, this is redundant. We could do something like Nagle's algorithm, but nobody wants that. So make flushes explicit. For golang/go#58547 Change-Id: I29dc9d79556c7a358a360ef79beb38b45040b6bc Reviewed-on: https://go-review.googlesource.com/c/net/+/543083 Auto-Submit: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/conn.go | 4 +- internal/quic/conn_flow_test.go | 7 ++ internal/quic/conn_loss_test.go | 5 +- internal/quic/conn_streams_test.go | 16 ++--- internal/quic/quic.go | 6 +- internal/quic/stream.go | 55 +++++++++++---- internal/quic/stream_test.go | 108 ++++++++++++++++++++++++++++- 7 files changed, 171 insertions(+), 30 deletions(-) diff --git a/internal/quic/conn.go b/internal/quic/conn.go index b2b6a0877..ff96ff760 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -136,12 +136,10 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip } } - // The smallest allowed maximum QUIC datagram size is 1200 bytes. // TODO: PMTU discovery. - const maxDatagramSize = 1200 c.logConnectionStarted(cids.originalDstConnID, peerAddr) c.keysAppData.init() - c.loss.init(c.side, maxDatagramSize, now) + c.loss.init(c.side, smallestMaxDatagramSize, now) c.streamsInit() c.lifetimeInit() c.restartIdleTimer(now) diff --git a/internal/quic/conn_flow_test.go b/internal/quic/conn_flow_test.go index 03e0757a6..39c879346 100644 --- a/internal/quic/conn_flow_test.go +++ b/internal/quic/conn_flow_test.go @@ -262,6 +262,7 @@ func TestConnOutflowBlocked(t *testing.T) { if n != len(data) || err != nil { t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data)) } + s.Flush() tc.wantFrame("stream writes data up to MAX_DATA limit", packetType1RTT, debugFrameStream{ @@ -310,6 +311,7 @@ func TestConnOutflowMaxDataDecreases(t *testing.T) { if n != len(data) || err != nil { t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data)) } + s.Flush() tc.wantFrame("stream writes data up to MAX_DATA limit", packetType1RTT, debugFrameStream{ @@ -337,7 +339,9 @@ func TestConnOutflowMaxDataRoundRobin(t *testing.T) { } s1.Write(make([]byte, 10)) + s1.Flush() s2.Write(make([]byte, 10)) + s2.Flush() tc.writeFrames(packetType1RTT, debugFrameMaxData{ max: 1, @@ -378,6 +382,7 @@ func TestConnOutflowMetaAndData(t *testing.T) { data := makeTestData(32) s.Write(data) + s.Flush() s.CloseRead() tc.wantFrame("CloseRead sends a STOP_SENDING, not flow controlled", @@ -405,6 +410,7 @@ func TestConnOutflowResentData(t *testing.T) { data := makeTestData(15) s.Write(data[:8]) + s.Flush() tc.wantFrame("data is under MAX_DATA limit, all sent", packetType1RTT, debugFrameStream{ id: s.id, @@ -421,6 +427,7 @@ func TestConnOutflowResentData(t *testing.T) { }) s.Write(data[8:]) + s.Flush() tc.wantFrame("new data is sent up to the MAX_DATA limit", packetType1RTT, debugFrameStream{ id: s.id, diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go index 5144be6ac..818816335 100644 --- a/internal/quic/conn_loss_test.go +++ b/internal/quic/conn_loss_test.go @@ -183,7 +183,7 @@ func TestLostStreamFrameEmpty(t *testing.T) { if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + c.Flush() // open the stream tc.wantFrame("created bidirectional stream 0", packetType1RTT, debugFrameStream{ id: newStreamID(clientSide, 
bidiStream, 0), @@ -213,6 +213,7 @@ func TestLostStreamWithData(t *testing.T) { p.initialMaxStreamDataUni = 1 << 20 }) s.Write(data[:4]) + s.Flush() tc.wantFrame("send [0,4)", packetType1RTT, debugFrameStream{ id: s.id, @@ -220,6 +221,7 @@ func TestLostStreamWithData(t *testing.T) { data: data[:4], }) s.Write(data[4:8]) + s.Flush() tc.wantFrame("send [4,8)", packetType1RTT, debugFrameStream{ id: s.id, @@ -263,6 +265,7 @@ func TestLostStreamPartialLoss(t *testing.T) { }) for i := range data { s.Write(data[i : i+1]) + s.Flush() tc.wantFrame(fmt.Sprintf("send STREAM frame with byte %v", i), packetType1RTT, debugFrameStream{ id: s.id, diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index 69f982c3a..c90354db8 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -19,33 +19,33 @@ func TestStreamsCreate(t *testing.T) { tc := newTestConn(t, clientSide, permissiveTransportParameters) tc.handshake() - c, err := tc.conn.NewStream(ctx) + s, err := tc.conn.NewStream(ctx) if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("created bidirectional stream 0", packetType1RTT, debugFrameStream{ id: 0, // client-initiated, bidi, number 0 data: []byte{}, }) - c, err = tc.conn.NewSendOnlyStream(ctx) + s, err = tc.conn.NewSendOnlyStream(ctx) if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("created unidirectional stream 0", packetType1RTT, debugFrameStream{ id: 2, // client-initiated, uni, number 0 data: []byte{}, }) - c, err = tc.conn.NewStream(ctx) + s, err = tc.conn.NewStream(ctx) if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("created bidirectional stream 1", packetType1RTT, debugFrameStream{ id: 4, // client-initiated, uni, number 4 @@ -177,11 +177,11 @@ func TestStreamsStreamSendOnly(t *testing.T) { tc := newTestConn(t, serverSide, permissiveTransportParameters) tc.handshake() - c, err := tc.conn.NewSendOnlyStream(ctx) + s, err := tc.conn.NewSendOnlyStream(ctx) if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("created unidirectional stream 0", packetType1RTT, debugFrameStream{ id: 3, // server-initiated, uni, number 0 diff --git a/internal/quic/quic.go b/internal/quic/quic.go index 6b60db869..e4d0d77c7 100644 --- a/internal/quic/quic.go +++ b/internal/quic/quic.go @@ -64,10 +64,14 @@ const defaultKeepAlivePeriod = 0 // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2-6 const timerGranularity = 1 * time.Millisecond +// The smallest allowed maximum datagram size. +// https://www.rfc-editor.org/rfc/rfc9000#section-14 +const smallestMaxDatagramSize = 1200 + // Minimum size of a UDP datagram sent by a client carrying an Initial packet, // or a server containing an ack-eliciting Initial packet. // https://www.rfc-editor.org/rfc/rfc9000#section-14.1 -const paddedInitialDatagramSize = 1200 +const paddedInitialDatagramSize = smallestMaxDatagramSize // Maximum number of streams of a given type which may be created. // https://www.rfc-editor.org/rfc/rfc9000.html#section-4.6-2 diff --git a/internal/quic/stream.go b/internal/quic/stream.go index 58d84ed1b..36c80f6af 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -38,10 +38,11 @@ type Stream struct { // the write will fail. 
outgate gate out pipe // buffered data to send + outflushed int64 // offset of last flush call outwin int64 // maximum MAX_STREAM_DATA received from the peer outmaxsent int64 // maximum data offset we've sent to the peer outmaxbuf int64 // maximum amount of data we will buffer - outunsent rangeset[int64] // ranges buffered but not yet sent + outunsent rangeset[int64] // ranges buffered but not yet sent (only flushed data) outacked rangeset[int64] // ranges sent and acknowledged outopened sentVal // set if we should open the stream outclosed sentVal // set by CloseWrite @@ -240,8 +241,6 @@ func (s *Stream) Write(b []byte) (n int, err error) { // WriteContext writes data to the stream write buffer. // Buffered data is only sent when the buffer is sufficiently full. // Call the Flush method to ensure buffered data is sent. -// -// TODO: Implement Flush. func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) { if s.IsReadOnly() { return 0, errors.New("write to read-only stream") @@ -269,10 +268,6 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) s.outUnlock() return n, errors.New("write to closed stream") } - // We set outopened here rather than below, - // so if this is a zero-length write we still - // open the stream despite not writing any data to it. - s.outopened.set() if len(b) == 0 { break } @@ -282,13 +277,26 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) // Amount to write is min(the full buffer, data up to the write limit). // This is a number of bytes. nn := min(int64(len(b)), lim-s.out.end) - // Copy the data into the output buffer and mark it as unsent. - if s.out.end <= s.outwin { - s.outunsent.add(s.out.end, min(s.out.end+nn, s.outwin)) - } + // Copy the data into the output buffer. s.out.writeAt(b[:nn], s.out.end) b = b[nn:] n += int(nn) + // Possibly flush the output buffer. + // We automatically flush if: + // - We have enough data to consume the send window. + // Sending this data may cause the peer to extend the window. + // - We have buffered as much data as we're willing do. + // We need to send data to clear out buffer space. + // - We have enough data to fill a 1-RTT packet using the smallest + // possible maximum datagram size (1200 bytes, less header byte, + // connection ID, packet number, and AEAD overhead). + const autoFlushSize = smallestMaxDatagramSize - 1 - connIDLen - 1 - aeadOverhead + shouldFlush := s.out.end >= s.outwin || // peer send window is full + s.out.end >= lim || // local send buffer is full + (s.out.end-s.outflushed) >= autoFlushSize // enough data buffered + if shouldFlush { + s.flushLocked() + } if s.out.end > s.outwin { // We're blocked by flow control. // Send a STREAM_DATA_BLOCKED frame to let the peer know. @@ -301,6 +309,23 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) return n, nil } +// Flush flushes data written to the stream. +// It does not wait for the peer to acknowledge receipt of the data. +// Use CloseContext to wait for the peer's acknowledgement. +func (s *Stream) Flush() { + s.outgate.lock() + defer s.outUnlock() + s.flushLocked() +} + +func (s *Stream) flushLocked() { + s.outopened.set() + if s.outflushed < s.outwin { + s.outunsent.add(s.outflushed, min(s.outwin, s.out.end)) + } + s.outflushed = s.out.end +} + // Close closes the stream. // See CloseContext for more details. 
func (s *Stream) Close() error { @@ -363,6 +388,7 @@ func (s *Stream) CloseWrite() { s.outgate.lock() defer s.outUnlock() s.outclosed.set() + s.flushLocked() } // Reset aborts writes on the stream and notifies the peer @@ -612,8 +638,8 @@ func (s *Stream) handleMaxStreamData(maxStreamData int64) error { if maxStreamData <= s.outwin { return nil } - if s.out.end > s.outwin { - s.outunsent.add(s.outwin, min(maxStreamData, s.out.end)) + if s.outflushed > s.outwin { + s.outunsent.add(s.outwin, min(maxStreamData, s.outflushed)) } s.outwin = maxStreamData if s.out.end > s.outwin { @@ -741,10 +767,11 @@ func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto b } for { // STREAM - off, size := dataToSend(min(s.out.start, s.outwin), min(s.out.end, s.outwin), s.outunsent, s.outacked, pto) + off, size := dataToSend(min(s.out.start, s.outwin), min(s.outflushed, s.outwin), s.outunsent, s.outacked, pto) if end := off + size; end > s.outmaxsent { // This will require connection-level flow control to send. end = min(end, s.outmaxsent+s.conn.streams.outflow.avail()) + end = max(end, off) size = end - off } fin := s.outclosed.isSet() && off+size == s.out.end diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go index 9bf2b5871..93c8839ff 100644 --- a/internal/quic/stream_test.go +++ b/internal/quic/stream_test.go @@ -38,6 +38,7 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { if n != writeBufferSize || err != context.Canceled { t.Fatalf("s.WriteContext() = %v, %v; want %v, context.Canceled", n, err, writeBufferSize) } + s.Flush() tc.wantFrame("first write buffer of data sent", packetType1RTT, debugFrameStream{ id: s.id, @@ -47,7 +48,9 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { // Blocking write, which must wait for buffer space. 
w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, want[writeBufferSize:]) + n, err := s.WriteContext(ctx, want[writeBufferSize:]) + s.Flush() + return n, err }) tc.wantIdle("write buffer is full, no more data can be sent") @@ -170,6 +173,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { t.Fatal(err) } s.WriteContext(ctx, want[:1]) + s.Flush() tc.wantFrame("sent data (1 byte) fits within flow control limit", packetType1RTT, debugFrameStream{ id: s.id, @@ -723,7 +727,7 @@ func testStreamSendFrameInvalidState(t *testing.T, f func(sid streamID) debugFra if err != nil { t.Fatal(err) } - s.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("new stream is opened", packetType1RTT, debugFrameStream{ id: sid, @@ -968,7 +972,9 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { want := make([]byte, 4096) rand.Read(want) // doesn't need to be crypto/rand, but non-deprecated and harmless w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, want) + n, err := s.WriteContext(ctx, want) + s.Flush() + return n, err }) got := make([]byte, 0, len(want)) for { @@ -998,6 +1004,7 @@ func TestStreamCloseWaitsForAcks(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) s.WriteContext(ctx, data) + s.Flush() tc.wantFrame("conn sends data for the stream", packetType1RTT, debugFrameStream{ id: s.id, @@ -1064,6 +1071,7 @@ func TestStreamCloseUnblocked(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) s.WriteContext(ctx, data) + s.Flush() tc.wantFrame("conn sends data for the stream", packetType1RTT, debugFrameStream{ id: s.id, @@ -1228,6 +1236,7 @@ func TestStreamPeerStopSendingForActiveStream(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, styp, permissiveTransportParameters) for i := 0; i < 4; i++ { s.Write([]byte{byte(i)}) + s.Flush() tc.wantFrame("write sends a STREAM frame to peer", packetType1RTT, debugFrameStream{ id: s.id, @@ -1271,6 +1280,99 @@ func TestStreamReceiveDataBlocked(t *testing.T) { tc.wantIdle("no response to STREAM_DATA_BLOCKED and DATA_BLOCKED") } +func TestStreamFlushExplicit(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + tc, s := newTestConnAndLocalStream(t, clientSide, styp, permissiveTransportParameters) + want := []byte{0, 1, 2, 3} + n, err := s.Write(want) + if n != len(want) || err != nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want)) + } + tc.wantIdle("unflushed data is not sent") + s.Flush() + tc.wantFrame("data is sent after flush", + packetType1RTT, debugFrameStream{ + id: s.id, + data: want, + }) + }) +} + +func TestStreamFlushImplicitExact(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + const writeBufferSize = 4 + tc, s := newTestConnAndLocalStream(t, clientSide, styp, + permissiveTransportParameters, + func(c *Config) { + c.MaxStreamWriteBufferSize = writeBufferSize + }) + want := []byte{0, 1, 2, 3, 4, 5, 6} + + // This write doesn't quite fill the output buffer. + n, err := s.Write(want[:3]) + if n != 3 || err != nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want)) + } + tc.wantIdle("unflushed data is not sent") + + // This write fills the output buffer exactly. 
+ n, err = s.Write(want[3:4]) + if n != 1 || err != nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want)) + } + tc.wantFrame("data is sent after write buffer fills", + packetType1RTT, debugFrameStream{ + id: s.id, + data: want[0:4], + }) + + }) +} + +func TestStreamFlushImplicitLargerThanBuffer(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + const writeBufferSize = 4 + tc, s := newTestConnAndLocalStream(t, clientSide, styp, + permissiveTransportParameters, + func(c *Config) { + c.MaxStreamWriteBufferSize = writeBufferSize + }) + want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + w := runAsync(tc, func(ctx context.Context) (int, error) { + n, err := s.WriteContext(ctx, want) + return n, err + }) + + tc.wantFrame("data is sent after write buffer fills", + packetType1RTT, debugFrameStream{ + id: s.id, + data: want[0:4], + }) + tc.writeAckForAll() + tc.wantFrame("ack permits sending more data", + packetType1RTT, debugFrameStream{ + id: s.id, + off: 4, + data: want[4:8], + }) + tc.writeAckForAll() + + tc.wantIdle("write buffer is not full") + if n, err := w.result(); n != len(want) || err != nil { + t.Fatalf("Write() = %v, %v; want %v, nil", n, err, len(want)) + } + + s.Flush() + tc.wantFrame("flush sends last buffer of data", + packetType1RTT, debugFrameStream{ + id: s.id, + off: 8, + data: want[8:], + }) + }) +} + type streamSide string const ( From e26b9a44574ff997838ad359431007e3a9ee6766 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 30 Oct 2023 15:18:17 -0700 Subject: [PATCH 04/70] quic: rename Listener to Endpoint The name Listener is confusing, because unlike a net.Listener a quic.Listener manages outgoing connections as well as inbound ones. Rename to "endpoint" which doesn't map to any existing net package name and matches the terminology of the QUIC RFCs. 
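For example (illustrative only; the package still lives under internal/ and is not importable outside x/net), a single Endpoint serves both roles:

	e, err := quic.Listen("udp", "127.0.0.1:4433", config)
	if err != nil {
		log.Fatal(err)
	}
	defer e.Close(context.Background())

	// Outbound: dial a peer from the same endpoint...
	go func() {
		if _, err := e.Dial(ctx, "udp", "peer.example:4433"); err != nil {
			log.Print(err)
		}
	}()

	// ...while also accepting inbound connections on it.
	for {
		conn, err := e.Accept(ctx)
		if err != nil {
			return
		}
		go handle(conn)
	}

Here config, ctx, and handle are assumed to be supplied by the caller; Listen, Dial, Accept, and Close are the existing methods being renamed along with the type.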
For golang/go#58547 Change-Id: If87f8c67ac7dd15d89d2d082a8ba2c63ea7f6e26 Reviewed-on: https://go-review.googlesource.com/c/net/+/543298 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/cmd/interop/main.go | 4 +- internal/quic/conn.go | 14 +- internal/quic/conn_close_test.go | 6 +- internal/quic/conn_id.go | 18 +-- internal/quic/conn_id_test.go | 12 +- internal/quic/conn_send.go | 2 +- internal/quic/conn_test.go | 70 +++++----- internal/quic/listener.go | 170 ++++++++++++------------ internal/quic/listener_test.go | 180 +++++++++++++------------- internal/quic/qlog.go | 2 +- internal/quic/retry.go | 16 +-- internal/quic/retry_test.go | 46 +++---- internal/quic/stateless_reset_test.go | 14 +- internal/quic/tls_test.go | 4 +- internal/quic/version_test.go | 12 +- 15 files changed, 285 insertions(+), 285 deletions(-) diff --git a/internal/quic/cmd/interop/main.go b/internal/quic/cmd/interop/main.go index 2ca5d652a..20f737b52 100644 --- a/internal/quic/cmd/interop/main.go +++ b/internal/quic/cmd/interop/main.go @@ -157,7 +157,7 @@ func basicTest(ctx context.Context, config *quic.Config, urls []string) { } } -func serve(ctx context.Context, l *quic.Listener) error { +func serve(ctx context.Context, l *quic.Endpoint) error { for { c, err := l.Accept(ctx) if err != nil { @@ -221,7 +221,7 @@ func parseURL(s string) (u *url.URL, authority string, err error) { return u, authority, nil } -func fetchFrom(ctx context.Context, l *quic.Listener, addr string, urls []*url.URL) { +func fetchFrom(ctx context.Context, l *quic.Endpoint, addr string, urls []*url.URL) { conn, err := l.Dial(ctx, "udp", addr) if err != nil { log.Printf("%v: %v", addr, err) diff --git a/internal/quic/conn.go b/internal/quic/conn.go index ff96ff760..31e789b1d 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -21,7 +21,7 @@ import ( // Multiple goroutines may invoke methods on a Conn simultaneously. type Conn struct { side connSide - listener *Listener + endpoint *Endpoint config *Config testHooks connTestHooks peerAddr netip.AddrPort @@ -92,10 +92,10 @@ type newServerConnIDs struct { retrySrcConnID []byte // source from server's Retry } -func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, l *Listener) (conn *Conn, _ error) { +func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) { c := &Conn{ side: side, - listener: l, + endpoint: e, config: config, peerAddr: peerAddr, msgc: make(chan any, 1), @@ -115,8 +115,8 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip // non-blocking operation. c.msgc = make(chan any, 1) - if l.testHooks != nil { - l.testHooks.newConn(c) + if e.testHooks != nil { + e.testHooks.newConn(c) } // initialConnID is the connection ID used to generate Initial packet protection keys. @@ -187,7 +187,7 @@ func (c *Conn) confirmHandshake(now time.Time) { if c.side == serverSide { // When the server confirms the handshake, it sends a HANDSHAKE_DONE. 
c.handshakeConfirmed.setUnsent() - c.listener.serverConnEstablished(c) + c.endpoint.serverConnEstablished(c) } else { // The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed // to the received state, indicating that the handshake is confirmed and we @@ -265,7 +265,7 @@ var errIdleTimeout = errors.New("idle timeout") func (c *Conn) loop(now time.Time) { defer close(c.donec) defer c.tls.Close() - defer c.listener.connDrained(c) + defer c.endpoint.connDrained(c) defer c.logConnectionClosed() // The connection timer sends a message to the connection loop on expiry. diff --git a/internal/quic/conn_close_test.go b/internal/quic/conn_close_test.go index 0dd46dd20..49881e62f 100644 --- a/internal/quic/conn_close_test.go +++ b/internal/quic/conn_close_test.go @@ -205,13 +205,13 @@ func TestConnCloseReceiveInHandshake(t *testing.T) { tc.wantIdle("no more frames to send") } -func TestConnCloseClosedByListener(t *testing.T) { +func TestConnCloseClosedByEndpoint(t *testing.T) { ctx := canceledContext() tc := newTestConn(t, clientSide) tc.handshake() - tc.listener.l.Close(ctx) - tc.wantFrame("listener closes connection before exiting", + tc.endpoint.e.Close(ctx) + tc.wantFrame("endpoint closes connection before exiting", packetType1RTT, debugFrameConnectionCloseTransport{ code: errNo, }) diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go index 439c22123..2efe8d6b5 100644 --- a/internal/quic/conn_id.go +++ b/internal/quic/conn_id.go @@ -76,7 +76,7 @@ func (s *connIDState) initClient(c *Conn) error { cid: locid, }) s.nextLocalSeq = 1 - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addConnID(c, locid) }) @@ -117,7 +117,7 @@ func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error { cid: locid, }) s.nextLocalSeq = 1 - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addConnID(c, dstConnID) conns.addConnID(c, locid) }) @@ -194,7 +194,7 @@ func (s *connIDState) issueLocalIDs(c *Conn) error { s.needSend = true toIssue-- } - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { for _, cid := range newIDs { conns.addConnID(c, cid) } @@ -247,7 +247,7 @@ func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p trans } token := statelessResetToken(p.statelessResetToken) s.remote[0].resetToken = token - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addResetToken(c, token) }) } @@ -276,7 +276,7 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) // the client. Discard the transient, client-chosen connection ID used // for Initial packets; the client will never send it again. cid := s.local[0].cid - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.retireConnID(c, cid) }) s.local = append(s.local[:0], s.local[1:]...) 
@@ -314,7 +314,7 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re rcid := &s.remote[i] if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo { s.retireRemote(rcid) - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.retireResetToken(c, rcid.resetToken) }) } @@ -350,7 +350,7 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re s.retireRemote(&s.remote[len(s.remote)-1]) } else { active++ - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addResetToken(c, resetToken) }) } @@ -399,7 +399,7 @@ func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error { for i := range s.local { if s.local[i].seq == seq { cid := s.local[i].cid - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.retireConnID(c, cid) }) s.local = append(s.local[:i], s.local[i+1:]...) @@ -463,7 +463,7 @@ func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool { s.local[i].seq, retireBefore, s.local[i].cid, - c.listener.resetGen.tokenForConnID(s.local[i].cid), + c.endpoint.resetGen.tokenForConnID(s.local[i].cid), ) { return false } diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go index 314a6b384..d44472e81 100644 --- a/internal/quic/conn_id_test.go +++ b/internal/quic/conn_id_test.go @@ -651,16 +651,16 @@ func TestConnIDsCleanedUpAfterClose(t *testing.T) { // Wait for the conn to drain. // Then wait for the conn loop to exit, // and force an immediate sync of the connsMap updates - // (normally only done by the listener read loop). + // (normally only done by the endpoint read loop). 
tc.advanceToTimer() <-tc.conn.donec - tc.listener.l.connsMap.applyUpdates() + tc.endpoint.e.connsMap.applyUpdates() - if got := len(tc.listener.l.connsMap.byConnID); got != 0 { - t.Errorf("%v conn ids in listener map after closing, want 0", got) + if got := len(tc.endpoint.e.connsMap.byConnID); got != 0 { + t.Errorf("%v conn ids in endpoint map after closing, want 0", got) } - if got := len(tc.listener.l.connsMap.byResetToken); got != 0 { - t.Errorf("%v reset tokens in listener map after closing, want 0", got) + if got := len(tc.endpoint.e.connsMap.byResetToken); got != 0 { + t.Errorf("%v reset tokens in endpoint map after closing, want 0", got) } }) } diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index e45dc8af3..4065474d2 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -170,7 +170,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } } - c.listener.sendDatagram(buf, c.peerAddr) + c.endpoint.sendDatagram(buf, c.peerAddr) } } diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index 70ba7b392..c57ba1487 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -34,12 +34,12 @@ func TestConnTestConn(t *testing.T) { tc.conn.runOnLoop(func(now time.Time, c *Conn) { ranAt = now }) - if !ranAt.Equal(tc.listener.now) { - t.Errorf("func ran on loop at %v, want %v", ranAt, tc.listener.now) + if !ranAt.Equal(tc.endpoint.now) { + t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now) } tc.wait() - nextTime := tc.listener.now.Add(defaultMaxIdleTimeout / 2) + nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2) tc.advanceTo(nextTime) tc.conn.runOnLoop(func(now time.Time, c *Conn) { ranAt = now @@ -117,7 +117,7 @@ const maxTestKeyPhases = 3 type testConn struct { t *testing.T conn *Conn - listener *testListener + endpoint *testEndpoint timer time.Time timerLastFired time.Time idlec chan struct{} // only accessed on the conn's loop @@ -220,27 +220,27 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { } } - listener := newTestListener(t, config) - listener.configTransportParams = configTransportParams - listener.configTestConn = configTestConn - conn, err := listener.l.newConn( - listener.now, + endpoint := newTestEndpoint(t, config) + endpoint.configTransportParams = configTransportParams + endpoint.configTestConn = configTestConn + conn, err := endpoint.e.newConn( + endpoint.now, side, cids, netip.MustParseAddrPort("127.0.0.1:443")) if err != nil { t.Fatal(err) } - tc := listener.conns[conn] + tc := endpoint.conns[conn] tc.wait() return tc } -func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testConn { +func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn { t.Helper() tc := &testConn{ t: t, - listener: listener, + endpoint: endpoint, conn: conn, peerConnID: testPeerConnID(0), ignoreFrames: map[byte]bool{ @@ -251,14 +251,14 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC recvDatagram: make(chan *datagram), } t.Cleanup(tc.cleanup) - for _, f := range listener.configTestConn { + for _, f := range endpoint.configTestConn { f(tc) } conn.testHooks = (*testConnHooks)(tc) - if listener.peerTLSConn != nil { - tc.peerTLSConn = listener.peerTLSConn - listener.peerTLSConn = nil + if endpoint.peerTLSConn != nil { + tc.peerTLSConn = endpoint.peerTLSConn + endpoint.peerTLSConn = nil return tc } @@ -267,7 +267,7 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) 
*testC if conn.side == clientSide { peerProvidedParams.originalDstConnID = testLocalConnID(-1) } - for _, f := range listener.configTransportParams { + for _, f := range endpoint.configTransportParams { f(&peerProvidedParams) } @@ -286,13 +286,13 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC // advance causes time to pass. func (tc *testConn) advance(d time.Duration) { tc.t.Helper() - tc.listener.advance(d) + tc.endpoint.advance(d) } // advanceTo sets the current time. func (tc *testConn) advanceTo(now time.Time) { tc.t.Helper() - tc.listener.advanceTo(now) + tc.endpoint.advanceTo(now) } // advanceToTimer sets the current time to the time of the Conn's next timer event. @@ -307,10 +307,10 @@ func (tc *testConn) timerDelay() time.Duration { if tc.timer.IsZero() { return math.MaxInt64 // infinite } - if tc.timer.Before(tc.listener.now) { + if tc.timer.Before(tc.endpoint.now) { return 0 } - return tc.timer.Sub(tc.listener.now) + return tc.timer.Sub(tc.endpoint.now) } const infiniteDuration = time.Duration(math.MaxInt64) @@ -320,10 +320,10 @@ func (tc *testConn) timeUntilEvent() time.Duration { if tc.timer.IsZero() { return infiniteDuration } - if tc.timer.Before(tc.listener.now) { + if tc.timer.Before(tc.endpoint.now) { return 0 } - return tc.timer.Sub(tc.listener.now) + return tc.timer.Sub(tc.endpoint.now) } // wait blocks until the conn becomes idle. @@ -400,7 +400,7 @@ func logDatagram(t *testing.T, text string, d *testDatagram) { // write sends the Conn a datagram. func (tc *testConn) write(d *testDatagram) { tc.t.Helper() - tc.listener.writeDatagram(d) + tc.endpoint.writeDatagram(d) } // writeFrame sends the Conn a datagram containing the given frames. @@ -466,11 +466,11 @@ func (tc *testConn) readDatagram() *testDatagram { tc.wait() tc.sentPackets = nil tc.sentFrames = nil - buf := tc.listener.read() + buf := tc.endpoint.read() if buf == nil { return nil } - d := parseTestDatagram(tc.t, tc.listener, tc, buf) + d := parseTestDatagram(tc.t, tc.endpoint, tc, buf) // Log the datagram before removing ignored frames. // When things go wrong, it's useful to see all the frames. 
logDatagram(tc.t, "-> conn under test sends", d) @@ -771,7 +771,7 @@ func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte return w.datagram() } -func parseTestDatagram(t *testing.T, tl *testListener, tc *testConn, buf []byte) *testDatagram { +func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram { t.Helper() bufSize := len(buf) d := &testDatagram{} @@ -784,7 +784,7 @@ func parseTestDatagram(t *testing.T, tl *testListener, tc *testConn, buf []byte) ptype := getPacketType(buf) switch ptype { case packetTypeRetry: - retry, ok := parseRetryPacket(buf, tl.lastInitialDstConnID) + retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID) if !ok { t.Fatalf("could not parse %v packet", ptype) } @@ -938,7 +938,7 @@ func (tc *testConnHooks) init() { tc.keysInitial.r = tc.conn.keysInitial.w tc.keysInitial.w = tc.conn.keysInitial.r if tc.conn.side == serverSide { - tc.listener.acceptQueue = append(tc.listener.acceptQueue, (*testConn)(tc)) + tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc)) } } @@ -1039,20 +1039,20 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) { tc.timer = timer for { - if !timer.IsZero() && !timer.After(tc.listener.now) { + if !timer.IsZero() && !timer.After(tc.endpoint.now) { if timer.Equal(tc.timerLastFired) { // If the connection timer fires at time T, the Conn should take some // action to advance the timer into the future. If the Conn reschedules // the timer for the same time, it isn't making progress and we have a bug. - tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.listener.now, timer) + tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer) } else { tc.timerLastFired = timer - return tc.listener.now, timerEvent{} + return tc.endpoint.now, timerEvent{} } } select { case m := <-msgc: - return tc.listener.now, m + return tc.endpoint.now, m default: } if !tc.wakeAsync() { @@ -1066,7 +1066,7 @@ func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.T close(idlec) } m = <-msgc - return tc.listener.now, m + return tc.endpoint.now, m } func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) { @@ -1074,7 +1074,7 @@ func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) { } func (tc *testConnHooks) timeNow() time.Time { - return tc.listener.now + return tc.endpoint.now } // testLocalConnID returns the connection ID with a given sequence number diff --git a/internal/quic/listener.go b/internal/quic/listener.go index ca8f9b25a..82a08a18c 100644 --- a/internal/quic/listener.go +++ b/internal/quic/listener.go @@ -17,14 +17,14 @@ import ( "time" ) -// A Listener listens for QUIC traffic on a network address. +// An Endpoint handles QUIC traffic on a network address. // It can accept inbound connections or create outbound ones. // -// Multiple goroutines may invoke methods on a Listener simultaneously. -type Listener struct { +// Multiple goroutines may invoke methods on an Endpoint simultaneously. 
+type Endpoint struct { config *Config udpConn udpConn - testHooks listenerTestHooks + testHooks endpointTestHooks resetGen statelessResetTokenGenerator retry retryState @@ -37,7 +37,7 @@ type Listener struct { closec chan struct{} // closed when the listen loop exits } -type listenerTestHooks interface { +type endpointTestHooks interface { timeNow() time.Time newConn(c *Conn) } @@ -53,7 +53,7 @@ type udpConn interface { // Listen listens on a local network address. // The configuration config must be non-nil. -func Listen(network, address string, config *Config) (*Listener, error) { +func Listen(network, address string, config *Config) (*Endpoint, error) { if config.TLSConfig == nil { return nil, errors.New("TLSConfig is not set") } @@ -65,11 +65,11 @@ func Listen(network, address string, config *Config) (*Listener, error) { if err != nil { return nil, err } - return newListener(udpConn, config, nil) + return newEndpoint(udpConn, config, nil) } -func newListener(udpConn udpConn, config *Config, hooks listenerTestHooks) (*Listener, error) { - l := &Listener{ +func newEndpoint(udpConn udpConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { + e := &Endpoint{ config: config, udpConn: udpConn, testHooks: hooks, @@ -77,70 +77,70 @@ func newListener(udpConn udpConn, config *Config, hooks listenerTestHooks) (*Lis acceptQueue: newQueue[*Conn](), closec: make(chan struct{}), } - l.resetGen.init(config.StatelessResetKey) - l.connsMap.init() + e.resetGen.init(config.StatelessResetKey) + e.connsMap.init() if config.RequireAddressValidation { - if err := l.retry.init(); err != nil { + if err := e.retry.init(); err != nil { return nil, err } } - go l.listen() - return l, nil + go e.listen() + return e, nil } // LocalAddr returns the local network address. -func (l *Listener) LocalAddr() netip.AddrPort { - a, _ := l.udpConn.LocalAddr().(*net.UDPAddr) +func (e *Endpoint) LocalAddr() netip.AddrPort { + a, _ := e.udpConn.LocalAddr().(*net.UDPAddr) return a.AddrPort() } -// Close closes the listener. -// Any blocked operations on the Listener or associated Conns and Stream will be unblocked +// Close closes the Endpoint. +// Any blocked operations on the Endpoint or associated Conns and Stream will be unblocked // and return errors. // // Close aborts every open connection. // Data in stream read and write buffers is discarded. // It waits for the peers of any open connection to acknowledge the connection has been closed. -func (l *Listener) Close(ctx context.Context) error { - l.acceptQueue.close(errors.New("listener closed")) - l.connsMu.Lock() - if !l.closing { - l.closing = true - for c := range l.conns { +func (e *Endpoint) Close(ctx context.Context) error { + e.acceptQueue.close(errors.New("endpoint closed")) + e.connsMu.Lock() + if !e.closing { + e.closing = true + for c := range e.conns { c.Abort(localTransportError{code: errNo}) } - if len(l.conns) == 0 { - l.udpConn.Close() + if len(e.conns) == 0 { + e.udpConn.Close() } } - l.connsMu.Unlock() + e.connsMu.Unlock() select { - case <-l.closec: + case <-e.closec: case <-ctx.Done(): - l.connsMu.Lock() - for c := range l.conns { + e.connsMu.Lock() + for c := range e.conns { c.exit() } - l.connsMu.Unlock() + e.connsMu.Unlock() return ctx.Err() } return nil } -// Accept waits for and returns the next connection to the listener. -func (l *Listener) Accept(ctx context.Context) (*Conn, error) { - return l.acceptQueue.get(ctx, nil) +// Accept waits for and returns the next connection. 
+func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) { + return e.acceptQueue.get(ctx, nil) } // Dial creates and returns a connection to a network address. -func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, error) { +func (e *Endpoint) Dial(ctx context.Context, network, address string) (*Conn, error) { u, err := net.ResolveUDPAddr(network, address) if err != nil { return nil, err } addr := u.AddrPort() addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) - c, err := l.newConn(time.Now(), clientSide, newServerConnIDs{}, addr) + c, err := e.newConn(time.Now(), clientSide, newServerConnIDs{}, addr) if err != nil { return nil, err } @@ -151,29 +151,29 @@ func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, er return c, nil } -func (l *Listener) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) { - l.connsMu.Lock() - defer l.connsMu.Unlock() - if l.closing { - return nil, errors.New("listener closed") +func (e *Endpoint) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) { + e.connsMu.Lock() + defer e.connsMu.Unlock() + if e.closing { + return nil, errors.New("endpoint closed") } - c, err := newConn(now, side, cids, peerAddr, l.config, l) + c, err := newConn(now, side, cids, peerAddr, e.config, e) if err != nil { return nil, err } - l.conns[c] = struct{}{} + e.conns[c] = struct{}{} return c, nil } // serverConnEstablished is called by a conn when the handshake completes // for an inbound (serverSide) connection. -func (l *Listener) serverConnEstablished(c *Conn) { - l.acceptQueue.put(c) +func (e *Endpoint) serverConnEstablished(c *Conn) { + e.acceptQueue.put(c) } // connDrained is called by a conn when it leaves the draining state, // either when the peer acknowledges connection closure or the drain timeout expires. -func (l *Listener) connDrained(c *Conn) { +func (e *Endpoint) connDrained(c *Conn) { var cids [][]byte for i := range c.connIDState.local { cids = append(cids, c.connIDState.local[i].cid) @@ -182,7 +182,7 @@ func (l *Listener) connDrained(c *Conn) { for i := range c.connIDState.remote { tokens = append(tokens, c.connIDState.remote[i].resetToken) } - l.connsMap.updateConnIDs(func(conns *connsMap) { + e.connsMap.updateConnIDs(func(conns *connsMap) { for _, cid := range cids { conns.retireConnID(c, cid) } @@ -190,60 +190,60 @@ func (l *Listener) connDrained(c *Conn) { conns.retireResetToken(c, token) } }) - l.connsMu.Lock() - defer l.connsMu.Unlock() - delete(l.conns, c) - if l.closing && len(l.conns) == 0 { - l.udpConn.Close() + e.connsMu.Lock() + defer e.connsMu.Unlock() + delete(e.conns, c) + if e.closing && len(e.conns) == 0 { + e.udpConn.Close() } } -func (l *Listener) listen() { - defer close(l.closec) +func (e *Endpoint) listen() { + defer close(e.closec) for { m := newDatagram() // TODO: Read and process the ECN (explicit congestion notification) field. // https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4 - n, _, _, addr, err := l.udpConn.ReadMsgUDPAddrPort(m.b, nil) + n, _, _, addr, err := e.udpConn.ReadMsgUDPAddrPort(m.b, nil) if err != nil { - // The user has probably closed the listener. + // The user has probably closed the endpoint. // We currently don't surface errors from other causes; - // we could check to see if the listener has been closed and + // we could check to see if the endpoint has been closed and // record the unexpected error if it has not. 
return } if n == 0 { continue } - if l.connsMap.updateNeeded.Load() { - l.connsMap.applyUpdates() + if e.connsMap.updateNeeded.Load() { + e.connsMap.applyUpdates() } m.addr = addr m.b = m.b[:n] - l.handleDatagram(m) + e.handleDatagram(m) } } -func (l *Listener) handleDatagram(m *datagram) { +func (e *Endpoint) handleDatagram(m *datagram) { dstConnID, ok := dstConnIDForDatagram(m.b) if !ok { m.recycle() return } - c := l.connsMap.byConnID[string(dstConnID)] + c := e.connsMap.byConnID[string(dstConnID)] if c == nil { // TODO: Move this branch into a separate goroutine to avoid blocking - // the listener while processing packets. - l.handleUnknownDestinationDatagram(m) + // the endpoint while processing packets. + e.handleUnknownDestinationDatagram(m) return } - // TODO: This can block the listener while waiting for the conn to accept the dgram. + // TODO: This can block the endpoint while waiting for the conn to accept the dgram. // Think about buffering between the receive loop and the conn. c.sendMsg(m) } -func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { +func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { defer func() { if m != nil { m.recycle() @@ -254,15 +254,15 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { return } var now time.Time - if l.testHooks != nil { - now = l.testHooks.timeNow() + if e.testHooks != nil { + now = e.testHooks.timeNow() } else { now = time.Now() } // Check to see if this is a stateless reset. var token statelessResetToken copy(token[:], m.b[len(m.b)-len(token):]) - if c := l.connsMap.byResetToken[token]; c != nil { + if c := e.connsMap.byResetToken[token]; c != nil { c.sendMsg(func(now time.Time, c *Conn) { c.handleStatelessReset(now, token) }) @@ -271,7 +271,7 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { // If this is a 1-RTT packet, there's nothing productive we can do with it. // Send a stateless reset if possible. if !isLongHeader(m.b[0]) { - l.maybeSendStatelessReset(m.b, m.addr) + e.maybeSendStatelessReset(m.b, m.addr) return } p, ok := parseGenericLongHeaderPacket(m.b) @@ -285,7 +285,7 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { return default: // Unknown version. - l.sendVersionNegotiation(p, m.addr) + e.sendVersionNegotiation(p, m.addr) return } if getPacketType(m.b) != packetTypeInitial { @@ -300,10 +300,10 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { srcConnID: p.srcConnID, dstConnID: p.dstConnID, } - if l.config.RequireAddressValidation { + if e.config.RequireAddressValidation { var ok bool cids.retrySrcConnID = p.dstConnID - cids.originalDstConnID, ok = l.validateInitialAddress(now, p, m.addr) + cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.addr) if !ok { return } @@ -311,7 +311,7 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { cids.originalDstConnID = p.dstConnID } var err error - c, err := l.newConn(now, serverSide, cids, m.addr) + c, err := e.newConn(now, serverSide, cids, m.addr) if err != nil { // The accept queue is probably full. // We could send a CONNECTION_CLOSE to the peer to reject the connection. 
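For orientation while reading the rename: after this change the exported surface is Listen, which returns an *Endpoint, and the Endpoint's Accept, Dial, and Close methods shown in the hunks above. A minimal sketch of that flow as it might look from inside the package follows; the function name, listen address, and error handling are illustrative assumptions, not code from this change.

	package quic

	import "context"

	// serveSketch is a hypothetical illustration of the renamed Endpoint API
	// (not part of this change). It assumes config.TLSConfig is non-nil,
	// since Listen rejects configurations without a TLS config.
	func serveSketch(ctx context.Context, config *Config) error {
		e, err := Listen("udp", "127.0.0.1:0", config)
		if err != nil {
			return err
		}
		defer e.Close(ctx)
		for {
			conn, err := e.Accept(ctx)
			if err != nil {
				return err
			}
			_ = conn // handle the connection; stream I/O elided
		}
	}
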
@@ -323,8 +323,8 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { m = nil // don't recycle, sendMsg takes ownership } -func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { - if !l.resetGen.canReset { +func (e *Endpoint) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { + if !e.resetGen.canReset { // Config.StatelessResetKey isn't set, so we don't send stateless resets. return } @@ -339,7 +339,7 @@ func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { } // TODO: Rate limit stateless resets. cid := b[1:][:connIDLen] - token := l.resetGen.tokenForConnID(cid) + token := e.resetGen.tokenForConnID(cid) // We want to generate a stateless reset that is as short as possible, // but long enough to be difficult to distinguish from a 1-RTT packet. // @@ -364,17 +364,17 @@ func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { b[0] &^= headerFormLong // clear long header bit b[0] |= fixedBit // set fixed bit copy(b[len(b)-statelessResetTokenLen:], token[:]) - l.sendDatagram(b, addr) + e.sendDatagram(b, addr) } -func (l *Listener) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) { +func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) { m := newDatagram() m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1) - l.sendDatagram(m.b, addr) + e.sendDatagram(m.b, addr) m.recycle() } -func (l *Listener) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) { +func (e *Endpoint) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) { keys := initialKeys(in.dstConnID, serverSide) var w packetWriter p := longPacket{ @@ -393,15 +393,15 @@ func (l *Listener) sendConnectionClose(in genericLongPacket, addr netip.AddrPort if len(buf) == 0 { return } - l.sendDatagram(buf, addr) + e.sendDatagram(buf, addr) } -func (l *Listener) sendDatagram(p []byte, addr netip.AddrPort) error { - _, err := l.udpConn.WriteToUDPAddrPort(p, addr) +func (e *Endpoint) sendDatagram(p []byte, addr netip.AddrPort) error { + _, err := e.udpConn.WriteToUDPAddrPort(p, addr) return err } -// A connsMap is a listener's mapping of conn ids and reset tokens to conns. +// A connsMap is an endpoint's mapping of conn ids and reset tokens to conns. 
type connsMap struct { byConnID map[string]*Conn byResetToken map[statelessResetToken]*Conn diff --git a/internal/quic/listener_test.go b/internal/quic/listener_test.go index 037fb21b4..f9fc80152 100644 --- a/internal/quic/listener_test.go +++ b/internal/quic/listener_test.go @@ -64,39 +64,39 @@ func TestStreamTransfer(t *testing.T) { func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) { t.Helper() ctx := context.Background() - l1 := newLocalListener(t, serverSide, conf1) - l2 := newLocalListener(t, clientSide, conf2) - c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String()) + e1 := newLocalEndpoint(t, serverSide, conf1) + e2 := newLocalEndpoint(t, clientSide, conf2) + c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String()) if err != nil { t.Fatal(err) } - c1, err := l1.Accept(ctx) + c1, err := e1.Accept(ctx) if err != nil { t.Fatal(err) } return c2, c1 } -func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener { +func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint { t.Helper() if conf.TLSConfig == nil { newConf := *conf conf = &newConf conf.TLSConfig = newTestTLSConfig(side) } - l, err := Listen("udp", "127.0.0.1:0", conf) + e, err := Listen("udp", "127.0.0.1:0", conf) if err != nil { t.Fatal(err) } t.Cleanup(func() { - l.Close(context.Background()) + e.Close(context.Background()) }) - return l + return e } -type testListener struct { +type testEndpoint struct { t *testing.T - l *Listener + e *Endpoint now time.Time recvc chan *datagram idlec chan struct{} @@ -109,8 +109,8 @@ type testListener struct { lastInitialDstConnID []byte // for parsing Retry packets } -func newTestListener(t *testing.T, config *Config) *testListener { - tl := &testListener{ +func newTestEndpoint(t *testing.T, config *Config) *testEndpoint { + te := &testEndpoint{ t: t, now: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), recvc: make(chan *datagram), @@ -118,52 +118,52 @@ func newTestListener(t *testing.T, config *Config) *testListener { conns: make(map[*Conn]*testConn), } var err error - tl.l, err = newListener((*testListenerUDPConn)(tl), config, (*testListenerHooks)(tl)) + te.e, err = newEndpoint((*testEndpointUDPConn)(te), config, (*testEndpointHooks)(te)) if err != nil { t.Fatal(err) } - t.Cleanup(tl.cleanup) - return tl + t.Cleanup(te.cleanup) + return te } -func (tl *testListener) cleanup() { - tl.l.Close(canceledContext()) +func (te *testEndpoint) cleanup() { + te.e.Close(canceledContext()) } -func (tl *testListener) wait() { +func (te *testEndpoint) wait() { select { - case tl.idlec <- struct{}{}: - case <-tl.l.closec: + case te.idlec <- struct{}{}: + case <-te.e.closec: } - for _, tc := range tl.conns { + for _, tc := range te.conns { tc.wait() } } -// accept returns a server connection from the listener. -// Unlike Listener.Accept, connections are available as soon as they are created. -func (tl *testListener) accept() *testConn { - if len(tl.acceptQueue) == 0 { - tl.t.Fatalf("accept: expected available conn, but found none") +// accept returns a server connection from the endpoint. +// Unlike Endpoint.Accept, connections are available as soon as they are created. 
+func (te *testEndpoint) accept() *testConn { + if len(te.acceptQueue) == 0 { + te.t.Fatalf("accept: expected available conn, but found none") } - tc := tl.acceptQueue[0] - tl.acceptQueue = tl.acceptQueue[1:] + tc := te.acceptQueue[0] + te.acceptQueue = te.acceptQueue[1:] return tc } -func (tl *testListener) write(d *datagram) { - tl.recvc <- d - tl.wait() +func (te *testEndpoint) write(d *datagram) { + te.recvc <- d + te.wait() } var testClientAddr = netip.MustParseAddrPort("10.0.0.1:8000") -func (tl *testListener) writeDatagram(d *testDatagram) { - tl.t.Helper() - logDatagram(tl.t, "<- listener under test receives", d) +func (te *testEndpoint) writeDatagram(d *testDatagram) { + te.t.Helper() + logDatagram(te.t, "<- endpoint under test receives", d) var buf []byte for _, p := range d.packets { - tc := tl.connForDestination(p.dstConnID) + tc := te.connForDestination(p.dstConnID) if p.ptype != packetTypeRetry && tc != nil { space := spaceForPacketType(p.ptype) if p.num >= tc.peerNextPacketNum[space] { @@ -171,13 +171,13 @@ func (tl *testListener) writeDatagram(d *testDatagram) { } } if p.ptype == packetTypeInitial { - tl.lastInitialDstConnID = p.dstConnID + te.lastInitialDstConnID = p.dstConnID } pad := 0 if p.ptype == packetType1RTT { pad = d.paddedSize - len(buf) } - buf = append(buf, encodeTestPacket(tl.t, tc, p, pad)...) + buf = append(buf, encodeTestPacket(te.t, tc, p, pad)...) } for len(buf) < d.paddedSize { buf = append(buf, 0) @@ -186,14 +186,14 @@ func (tl *testListener) writeDatagram(d *testDatagram) { if !addr.IsValid() { addr = testClientAddr } - tl.write(&datagram{ + te.write(&datagram{ b: buf, addr: addr, }) } -func (tl *testListener) connForDestination(dstConnID []byte) *testConn { - for _, tc := range tl.conns { +func (te *testEndpoint) connForDestination(dstConnID []byte) *testConn { + for _, tc := range te.conns { for _, loc := range tc.conn.connIDState.local { if bytes.Equal(loc.cid, dstConnID) { return tc @@ -203,8 +203,8 @@ func (tl *testListener) connForDestination(dstConnID []byte) *testConn { return nil } -func (tl *testListener) connForSource(srcConnID []byte) *testConn { - for _, tc := range tl.conns { +func (te *testEndpoint) connForSource(srcConnID []byte) *testConn { + for _, tc := range te.conns { for _, loc := range tc.conn.connIDState.remote { if bytes.Equal(loc.cid, srcConnID) { return tc @@ -214,106 +214,106 @@ func (tl *testListener) connForSource(srcConnID []byte) *testConn { return nil } -func (tl *testListener) read() []byte { - tl.t.Helper() - tl.wait() - if len(tl.sentDatagrams) == 0 { +func (te *testEndpoint) read() []byte { + te.t.Helper() + te.wait() + if len(te.sentDatagrams) == 0 { return nil } - d := tl.sentDatagrams[0] - tl.sentDatagrams = tl.sentDatagrams[1:] + d := te.sentDatagrams[0] + te.sentDatagrams = te.sentDatagrams[1:] return d } -func (tl *testListener) readDatagram() *testDatagram { - tl.t.Helper() - buf := tl.read() +func (te *testEndpoint) readDatagram() *testDatagram { + te.t.Helper() + buf := te.read() if buf == nil { return nil } p, _ := parseGenericLongHeaderPacket(buf) - tc := tl.connForSource(p.dstConnID) - d := parseTestDatagram(tl.t, tl, tc, buf) - logDatagram(tl.t, "-> listener under test sends", d) + tc := te.connForSource(p.dstConnID) + d := parseTestDatagram(te.t, te, tc, buf) + logDatagram(te.t, "-> endpoint under test sends", d) return d } -// wantDatagram indicates that we expect the Listener to send a datagram. 
-func (tl *testListener) wantDatagram(expectation string, want *testDatagram) { - tl.t.Helper() - got := tl.readDatagram() +// wantDatagram indicates that we expect the Endpoint to send a datagram. +func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) { + te.t.Helper() + got := te.readDatagram() if !reflect.DeepEqual(got, want) { - tl.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) + te.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) } } -// wantIdle indicates that we expect the Listener to not send any more datagrams. -func (tl *testListener) wantIdle(expectation string) { - if got := tl.readDatagram(); got != nil { - tl.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got) +// wantIdle indicates that we expect the Endpoint to not send any more datagrams. +func (te *testEndpoint) wantIdle(expectation string) { + if got := te.readDatagram(); got != nil { + te.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got) } } // advance causes time to pass. -func (tl *testListener) advance(d time.Duration) { - tl.t.Helper() - tl.advanceTo(tl.now.Add(d)) +func (te *testEndpoint) advance(d time.Duration) { + te.t.Helper() + te.advanceTo(te.now.Add(d)) } // advanceTo sets the current time. -func (tl *testListener) advanceTo(now time.Time) { - tl.t.Helper() - if tl.now.After(now) { - tl.t.Fatalf("time moved backwards: %v -> %v", tl.now, now) +func (te *testEndpoint) advanceTo(now time.Time) { + te.t.Helper() + if te.now.After(now) { + te.t.Fatalf("time moved backwards: %v -> %v", te.now, now) } - tl.now = now - for _, tc := range tl.conns { - if !tc.timer.After(tl.now) { + te.now = now + for _, tc := range te.conns { + if !tc.timer.After(te.now) { tc.conn.sendMsg(timerEvent{}) tc.wait() } } } -// testListenerHooks implements listenerTestHooks. -type testListenerHooks testListener +// testEndpointHooks implements endpointTestHooks. +type testEndpointHooks testEndpoint -func (tl *testListenerHooks) timeNow() time.Time { - return tl.now +func (te *testEndpointHooks) timeNow() time.Time { + return te.now } -func (tl *testListenerHooks) newConn(c *Conn) { - tc := newTestConnForConn(tl.t, (*testListener)(tl), c) - tl.conns[c] = tc +func (te *testEndpointHooks) newConn(c *Conn) { + tc := newTestConnForConn(te.t, (*testEndpoint)(te), c) + te.conns[c] = tc } -// testListenerUDPConn implements UDPConn. -type testListenerUDPConn testListener +// testEndpointUDPConn implements UDPConn. 
+type testEndpointUDPConn testEndpoint -func (tl *testListenerUDPConn) Close() error { - close(tl.recvc) +func (te *testEndpointUDPConn) Close() error { + close(te.recvc) return nil } -func (tl *testListenerUDPConn) LocalAddr() net.Addr { +func (te *testEndpointUDPConn) LocalAddr() net.Addr { return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443")) } -func (tl *testListenerUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) { +func (te *testEndpointUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) { for { select { - case d, ok := <-tl.recvc: + case d, ok := <-te.recvc: if !ok { return 0, 0, 0, netip.AddrPort{}, io.EOF } n = copy(b, d.b) return n, 0, 0, d.addr, nil - case <-tl.idlec: + case <-te.idlec: } } } -func (tl *testListenerUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - tl.sentDatagrams = append(tl.sentDatagrams, append([]byte(nil), b...)) +func (te *testEndpointUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), b...)) return len(b), nil } diff --git a/internal/quic/qlog.go b/internal/quic/qlog.go index c8ee429fe..ea53cab1e 100644 --- a/internal/quic/qlog.go +++ b/internal/quic/qlog.go @@ -95,7 +95,7 @@ func (c *Conn) logConnectionStarted(originalDstConnID []byte, peerAddr netip.Add slog.String("type", vantage), ), ) - localAddr := c.listener.LocalAddr() + localAddr := c.endpoint.LocalAddr() // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.2 c.log.LogAttrs(context.Background(), QLogLevelEndpoint, "connectivity:connection_started", diff --git a/internal/quic/retry.go b/internal/quic/retry.go index e3d9f4d7d..31cb57b88 100644 --- a/internal/quic/retry.go +++ b/internal/quic/retry.go @@ -39,7 +39,7 @@ var ( // retryTokenValidityPeriod is how long we accept a Retry packet token after sending it. const retryTokenValidityPeriod = 5 * time.Second -// retryState generates and validates a listener's retry tokens. +// retryState generates and validates an endpoint's retry tokens. type retryState struct { aead cipher.AEAD } @@ -139,7 +139,7 @@ func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []by return additional } -func (l *Listener) validateInitialAddress(now time.Time, p genericLongPacket, addr netip.AddrPort) (origDstConnID []byte, ok bool) { +func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, addr netip.AddrPort) (origDstConnID []byte, ok bool) { // The retry token is at the start of an Initial packet's data. token, n := consumeUint8Bytes(p.data) if n < 0 { @@ -151,22 +151,22 @@ func (l *Listener) validateInitialAddress(now time.Time, p genericLongPacket, ad if len(token) == 0 { // The sender has not provided a token. // Send a Retry packet to them with one. - l.sendRetry(now, p, addr) + e.sendRetry(now, p, addr) return nil, false } - origDstConnID, ok = l.retry.validateToken(now, token, p.srcConnID, p.dstConnID, addr) + origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, addr) if !ok { // This does not seem to be a valid token. // Close the connection with an INVALID_TOKEN error. 
// https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5 - l.sendConnectionClose(p, addr, errInvalidToken) + e.sendConnectionClose(p, addr, errInvalidToken) return nil, false } return origDstConnID, true } -func (l *Listener) sendRetry(now time.Time, p genericLongPacket, addr netip.AddrPort) { - token, srcConnID, err := l.retry.makeToken(now, p.srcConnID, p.dstConnID, addr) +func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, addr netip.AddrPort) { + token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, addr) if err != nil { return } @@ -175,7 +175,7 @@ func (l *Listener) sendRetry(now time.Time, p genericLongPacket, addr netip.Addr srcConnID: srcConnID, token: token, }) - l.sendDatagram(b, addr) + e.sendDatagram(b, addr) } type retryPacket struct { diff --git a/internal/quic/retry_test.go b/internal/quic/retry_test.go index f754270a5..4a21a4ca1 100644 --- a/internal/quic/retry_test.go +++ b/internal/quic/retry_test.go @@ -16,7 +16,7 @@ import ( ) type retryServerTest struct { - tl *testListener + te *testEndpoint originalSrcConnID []byte originalDstConnID []byte retry retryPacket @@ -32,16 +32,16 @@ func newRetryServerTest(t *testing.T) *retryServerTest { TLSConfig: newTestTLSConfig(serverSide), RequireAddressValidation: true, } - tl := newTestListener(t, config) + te := newTestEndpoint(t, config) srcID := testPeerConnID(0) dstID := testLocalConnID(-1) params := defaultTransportParameters() params.initialSrcConnID = srcID - initialCrypto := initialClientCrypto(t, tl, params) + initialCrypto := initialClientCrypto(t, te, params) // Initial packet with no Token. // Server responds with a Retry containing a token. - tl.writeDatagram(&testDatagram{ + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 0, @@ -56,7 +56,7 @@ func newRetryServerTest(t *testing.T) *retryServerTest { }}, paddedSize: 1200, }) - got := tl.readDatagram() + got := te.readDatagram() if len(got.packets) != 1 || got.packets[0].ptype != packetTypeRetry { t.Fatalf("got datagram: %v\nwant Retry", got) } @@ -66,7 +66,7 @@ func newRetryServerTest(t *testing.T) *retryServerTest { } return &retryServerTest{ - tl: tl, + te: te, originalSrcConnID: srcID, originalDstConnID: dstID, retry: retryPacket{ @@ -80,9 +80,9 @@ func newRetryServerTest(t *testing.T) *retryServerTest { func TestRetryServerSucceeds(t *testing.T) { rt := newRetryServerTest(t) - tl := rt.tl - tl.advance(retryTokenValidityPeriod) - tl.writeDatagram(&testDatagram{ + te := rt.te + te.advance(retryTokenValidityPeriod) + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, @@ -98,7 +98,7 @@ func TestRetryServerSucceeds(t *testing.T) { }}, paddedSize: 1200, }) - tc := tl.accept() + tc := te.accept() initial := tc.readPacket() if initial == nil || initial.ptype != packetTypeInitial { t.Fatalf("got packet:\n%v\nwant: Initial", initial) @@ -124,8 +124,8 @@ func TestRetryServerTokenInvalid(t *testing.T) { // INVALID_TOKEN error." 
// https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5 rt := newRetryServerTest(t) - tl := rt.tl - tl.writeDatagram(&testDatagram{ + te := rt.te + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, @@ -141,7 +141,7 @@ func TestRetryServerTokenInvalid(t *testing.T) { }}, paddedSize: 1200, }) - tl.wantDatagram("server closes connection after Initial with invalid Retry token", + te.wantDatagram("server closes connection after Initial with invalid Retry token", initialConnectionCloseDatagram( rt.retry.srcConnID, rt.originalSrcConnID, @@ -152,9 +152,9 @@ func TestRetryServerTokenTooOld(t *testing.T) { // "[...] a token SHOULD have an expiration time [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.3-3 rt := newRetryServerTest(t) - tl := rt.tl - tl.advance(retryTokenValidityPeriod + time.Second) - tl.writeDatagram(&testDatagram{ + te := rt.te + te.advance(retryTokenValidityPeriod + time.Second) + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, @@ -170,7 +170,7 @@ func TestRetryServerTokenTooOld(t *testing.T) { }}, paddedSize: 1200, }) - tl.wantDatagram("server closes connection after Initial with expired token", + te.wantDatagram("server closes connection after Initial with expired token", initialConnectionCloseDatagram( rt.retry.srcConnID, rt.originalSrcConnID, @@ -182,8 +182,8 @@ func TestRetryServerTokenWrongIP(t *testing.T) { // to verify that the source IP address and port in client packets remain constant." // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.4-3 rt := newRetryServerTest(t) - tl := rt.tl - tl.writeDatagram(&testDatagram{ + te := rt.te + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, @@ -200,7 +200,7 @@ func TestRetryServerTokenWrongIP(t *testing.T) { paddedSize: 1200, addr: netip.MustParseAddrPort("10.0.0.2:8000"), }) - tl.wantDatagram("server closes connection after Initial from wrong address", + te.wantDatagram("server closes connection after Initial from wrong address", initialConnectionCloseDatagram( rt.retry.srcConnID, rt.originalSrcConnID, @@ -435,7 +435,7 @@ func TestRetryClientIgnoresRetryWithInvalidIntegrityTag(t *testing.T) { token: []byte{1, 2, 3, 4}, }) pkt[len(pkt)-1] ^= 1 // invalidate the integrity tag - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: pkt, addr: testClientAddr, }) @@ -527,14 +527,14 @@ func TestParseInvalidRetryPackets(t *testing.T) { } } -func initialClientCrypto(t *testing.T, l *testListener, p transportParameters) []byte { +func initialClientCrypto(t *testing.T, e *testEndpoint, p transportParameters) []byte { t.Helper() config := &tls.QUICConfig{TLSConfig: newTestTLSConfig(clientSide)} tlsClient := tls.QUICClient(config) tlsClient.SetTransportParameters(marshalTransportParameters(p)) tlsClient.Start(context.Background()) //defer tlsClient.Close() - l.peerTLSConn = tlsClient + e.peerTLSConn = tlsClient var data []byte for { e := tlsClient.NextEvent() diff --git a/internal/quic/stateless_reset_test.go b/internal/quic/stateless_reset_test.go index c01375fbd..45a49e81e 100644 --- a/internal/quic/stateless_reset_test.go +++ b/internal/quic/stateless_reset_test.go @@ -68,7 +68,7 @@ func TestStatelessResetSentSizes(t *testing.T) { StatelessResetKey: testStatelessResetKey, } addr := netip.MustParseAddr("127.0.0.1") - tl := newTestListener(t, config) + te := newTestEndpoint(t, config) for i, test := range []struct { reqSize int wantSize int @@ -105,9 +105,9 @@ func 
TestStatelessResetSentSizes(t *testing.T) { cid := testLocalConnID(int64(i)) token := testStatelessResetToken(cid) addrport := netip.AddrPortFrom(addr, uint16(8000+i)) - tl.write(newDatagramForReset(cid, test.reqSize, addrport)) + te.write(newDatagramForReset(cid, test.reqSize, addrport)) - got := tl.read() + got := te.read() if len(got) != test.wantSize { t.Errorf("got %v-byte response to %v-byte req, want %v", len(got), test.reqSize, test.wantSize) @@ -149,7 +149,7 @@ func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { resetToken := testPeerStatelessResetToken(1) // provided during handshake dgram := append(make([]byte, 100), resetToken[:]...) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: dgram, }) @@ -179,7 +179,7 @@ func TestStatelessResetSuccessfulTransportParameter(t *testing.T) { tc.handshake() dgram := append(make([]byte, 100), resetToken[:]...) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: dgram, }) @@ -243,7 +243,7 @@ func TestStatelessResetSuccessfulPrefix(t *testing.T) { dgram = append(dgram, byte(len(dgram))) // semi-random junk } dgram = append(dgram, resetToken[:]...) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: dgram, }) if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) { @@ -278,7 +278,7 @@ func TestStatelessResetRetiredConnID(t *testing.T) { // Receive a stateless reset for connection ID 0. dgram := append(make([]byte, 100), resetToken[:]...) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: dgram, }) diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go index fa339b9fa..14f74a00a 100644 --- a/internal/quic/tls_test.go +++ b/internal/quic/tls_test.go @@ -36,7 +36,7 @@ func (tc *testConn) handshake() { for { if i == len(dgrams)-1 { if tc.conn.side == clientSide { - want := tc.listener.now.Add(maxAckDelay - timerGranularity) + want := tc.endpoint.now.Add(maxAckDelay - timerGranularity) if !tc.timer.Equal(want) { t.Fatalf("want timer = %v (max_ack_delay), got %v", want, tc.timer) } @@ -85,7 +85,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { testPeerConnID(0), testPeerConnID(1), } - localResetToken := tc.listener.l.resetGen.tokenForConnID(localConnIDs[1]) + localResetToken := tc.endpoint.e.resetGen.tokenForConnID(localConnIDs[1]) peerResetToken := testPeerStatelessResetToken(1) if tc.conn.side == clientSide { clientConnIDs = localConnIDs diff --git a/internal/quic/version_test.go b/internal/quic/version_test.go index 830e0e1c8..92fabd7b3 100644 --- a/internal/quic/version_test.go +++ b/internal/quic/version_test.go @@ -17,7 +17,7 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { config := &Config{ TLSConfig: newTestTLSConfig(serverSide), } - tl := newTestListener(t, config) + te := newTestEndpoint(t, config) // Packet of unknown contents for some unrecognized QUIC version. 
dstConnID := []byte{1, 2, 3, 4} @@ -34,10 +34,10 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { pkt = append(pkt, 0) } - tl.write(&datagram{ + te.write(&datagram{ b: pkt, }) - gotPkt := tl.read() + gotPkt := te.read() if gotPkt == nil { t.Fatalf("got no response; want Version Negotiaion") } @@ -59,7 +59,7 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { func TestVersionNegotiationClientAborts(t *testing.T) { tc := newTestConn(t, clientSide) p := tc.readPacket() // client Initial packet - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10), }) tc.wantIdle("connection does not send a CONNECTION_CLOSE") @@ -76,7 +76,7 @@ func TestVersionNegotiationClientIgnoresAfterProcessingPacket(t *testing.T) { debugFrameCrypto{ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], }) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10), }) if err := tc.conn.waitReady(canceledContext()); err != context.Canceled { @@ -94,7 +94,7 @@ func TestVersionNegotiationClientIgnoresMismatchingSourceConnID(t *testing.T) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) p := tc.readPacket() // client Initial packet - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: appendVersionNegotiation(nil, p.srcConnID, []byte("mismatch"), 10), }) tc.writeFrames(packetTypeInitial, From 13e88dd2f74327f590622c561597594022c45de5 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 17 Nov 2023 13:23:46 -0800 Subject: [PATCH 05/70] quic: rename listener{_test}.go to endpoint{_test}.go Separate from CL 543298 to help git recognize that this is a rename. Change-Id: I1cbdffeb66d0960c951a564b8fc1a3dcf2cf40f6 Reviewed-on: https://go-review.googlesource.com/c/net/+/543299 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/{listener.go => endpoint.go} | 0 internal/quic/{listener_test.go => endpoint_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename internal/quic/{listener.go => endpoint.go} (100%) rename internal/quic/{listener_test.go => endpoint_test.go} (100%) diff --git a/internal/quic/listener.go b/internal/quic/endpoint.go similarity index 100% rename from internal/quic/listener.go rename to internal/quic/endpoint.go diff --git a/internal/quic/listener_test.go b/internal/quic/endpoint_test.go similarity index 100% rename from internal/quic/listener_test.go rename to internal/quic/endpoint_test.go From a8e0109124268a0a063b5900bce0c2b33398ec01 Mon Sep 17 00:00:00 2001 From: Gopher Robot Date: Mon, 27 Nov 2023 17:02:04 +0000 Subject: [PATCH 06/70] go.mod: update golang.org/x dependencies Update golang.org/x dependencies to their latest tagged versions. 
Change-Id: Ia3b446633ffc0b3264692cfaae765bfb79063dab Reviewed-on: https://go-review.googlesource.com/c/net/+/545175 Auto-Submit: Gopher Robot LUCI-TryBot-Result: Go LUCI Reviewed-by: Michael Knyszek Reviewed-by: Dmitri Shuralyov --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 21deffd4b..8ab3f40e1 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module golang.org/x/net go 1.18 require ( - golang.org/x/crypto v0.15.0 - golang.org/x/sys v0.14.0 - golang.org/x/term v0.14.0 + golang.org/x/crypto v0.16.0 + golang.org/x/sys v0.15.0 + golang.org/x/term v0.15.0 golang.org/x/text v0.14.0 ) diff --git a/go.sum b/go.sum index 54759e489..bb6ed68a0 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= -golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8= -golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= From f812076c5dd92f30fe0b9ed860869246746c9954 Mon Sep 17 00:00:00 2001 From: Roland Shoemaker Date: Wed, 29 Nov 2023 13:48:34 -0800 Subject: [PATCH 07/70] http2: explicitly set minimum TLS version in tests Fixes tests when using 1.22 in certain cases where the go.mod 'go' directive is not being respected. Change-Id: Ia986a7c900287abd67f0a05f662906a665cdeb87 Reviewed-on: https://go-review.googlesource.com/c/net/+/546115 LUCI-TryBot-Result: Go LUCI Auto-Submit: Roland Shoemaker Reviewed-by: Damien Neil --- http2/server_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/http2/server_test.go b/http2/server_test.go index 22657cbfe..1fdd191ef 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -145,6 +145,12 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} ConfigureServer(ts.Config, h2server) + // Go 1.22 changes the default minimum TLS version to TLS 1.2, + // in order to properly test cases where we want to reject low + // TLS versions, we need to explicitly configure the minimum + // version here. 
+ ts.Config.TLSConfig.MinVersion = tls.VersionTLS10 + st := &serverTester{ t: t, ts: ts, From 491f3545934c0aa6f51ce63beb323406693597ec Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 10 Nov 2023 08:01:03 -0800 Subject: [PATCH 08/70] quic: log packets and frames For golang/go#58547 Change-Id: I601f1e74417c0de206f71da58cef5938bba6e860 Reviewed-on: https://go-review.googlesource.com/c/net/+/543084 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn_recv.go | 6 + internal/quic/conn_send.go | 9 ++ internal/quic/frame_debug.go | 220 ++++++++++++++++++++++++++++- internal/quic/packet.go | 16 +++ internal/quic/packet_codec_test.go | 71 ++++++++++ internal/quic/qlog.go | 102 +++++++++++++ internal/quic/qlog/json_writer.go | 125 ++++++++++++---- internal/quic/qlog_test.go | 108 +++++++++++--- internal/quic/quic.go | 11 ++ 9 files changed, 616 insertions(+), 52 deletions(-) diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index 156ef5dd5..045bf861c 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -101,6 +101,9 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa if logPackets { logInboundLongPacket(c, p) } + if c.logEnabled(QLogLevelPacket) { + c.logLongPacketReceived(p, buf[:n]) + } c.connIDState.handlePacket(c, p.ptype, p.srcConnID) ackEliciting := c.handleFrames(now, ptype, space, p.payload) c.acks[space].receive(now, space, p.num, ackEliciting) @@ -149,6 +152,9 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int { if logPackets { logInboundShortPacket(c, p) } + if c.logEnabled(QLogLevelPacket) { + c.log1RTTPacketReceived(p, buf) + } ackEliciting := c.handleFrames(now, packetType1RTT, appDataSpace, p.payload) c.acks[appDataSpace].receive(now, appDataSpace, p.num, ackEliciting) return len(buf) diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 4065474d2..e2240f2fd 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -75,6 +75,9 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if logPackets { logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload()) } + if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 { + c.logPacketSent(packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload()) + } sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p) if sentInitial != nil { c.idleHandlePacketSent(now, sentInitial) @@ -104,6 +107,9 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if logPackets { logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload()) } + if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 { + c.logPacketSent(packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload()) + } if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil { c.idleHandlePacketSent(now, sent) c.loss.packetSent(now, handshakeSpace, sent) @@ -132,6 +138,9 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if logPackets { logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload()) } + if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 { + c.logPacketSent(packetType1RTT, pnum, nil, dstConnID, c.w.payload()) + } if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil { c.idleHandlePacketSent(now, sent) c.loss.packetSent(now, appDataSpace, sent) diff --git a/internal/quic/frame_debug.go b/internal/quic/frame_debug.go index 
dc8009037..0902c385f 100644 --- a/internal/quic/frame_debug.go +++ b/internal/quic/frame_debug.go @@ -8,6 +8,9 @@ package quic import ( "fmt" + "log/slog" + "strconv" + "time" ) // A debugFrame is a representation of the contents of a QUIC frame, @@ -15,6 +18,7 @@ import ( type debugFrame interface { String() string write(w *packetWriter) bool + LogValue() slog.Value } func parseDebugFrame(b []byte) (f debugFrame, n int) { @@ -97,6 +101,13 @@ func (f debugFramePadding) write(w *packetWriter) bool { return true } +func (f debugFramePadding) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "padding"), + slog.Int("length", f.size), + ) +} + // debugFramePing is a PING frame. type debugFramePing struct{} @@ -112,6 +123,12 @@ func (f debugFramePing) write(w *packetWriter) bool { return w.appendPingFrame() } +func (f debugFramePing) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "ping"), + ) +} + // debugFrameAck is an ACK frame. type debugFrameAck struct { ackDelay unscaledAckDelay @@ -126,7 +143,7 @@ func parseDebugFrameAck(b []byte) (f debugFrameAck, n int) { end: end, }) }) - // Ranges are parsed smallest to highest; reverse ranges slice to order them high to low. + // Ranges are parsed high to low; reverse ranges slice to order them low to high. for i := 0; i < len(f.ranges)/2; i++ { j := len(f.ranges) - 1 f.ranges[i], f.ranges[j] = f.ranges[j], f.ranges[i] @@ -146,6 +163,61 @@ func (f debugFrameAck) write(w *packetWriter) bool { return w.appendAckFrame(rangeset[packetNumber](f.ranges), f.ackDelay) } +func (f debugFrameAck) LogValue() slog.Value { + return slog.StringValue("error: debugFrameAck should not appear as a slog Value") +} + +// debugFrameScaledAck is an ACK frame with scaled ACK Delay. +// +// This type is used in qlog events, which need access to the delay as a duration. +type debugFrameScaledAck struct { + ackDelay time.Duration + ranges []i64range[packetNumber] +} + +func (f debugFrameScaledAck) LogValue() slog.Value { + var ackDelay slog.Attr + if f.ackDelay >= 0 { + ackDelay = slog.Duration("ack_delay", f.ackDelay) + } + return slog.GroupValue( + slog.String("frame_type", "ack"), + // Rather than trying to convert the ack ranges into the slog data model, + // pass a value that can JSON-encode itself. + slog.Any("acked_ranges", debugAckRanges(f.ranges)), + ackDelay, + ) +} + +type debugAckRanges []i64range[packetNumber] + +// AppendJSON appends a JSON encoding of the ack ranges to b, and returns it. +// This is different than the standard json.Marshaler, but more efficient. +// Since we only use this in cooperation with the qlog package, +// encoding/json compatibility is irrelevant. +func (r debugAckRanges) AppendJSON(b []byte) []byte { + b = append(b, '[') + for i, ar := range r { + start, end := ar.start, ar.end-1 // qlog ranges are closed-closed + if i != 0 { + b = append(b, ',') + } + b = append(b, '[') + b = strconv.AppendInt(b, int64(start), 10) + if start != end { + b = append(b, ',') + b = strconv.AppendInt(b, int64(end), 10) + } + b = append(b, ']') + } + b = append(b, ']') + return b +} + +func (r debugAckRanges) String() string { + return string(r.AppendJSON(nil)) +} + // debugFrameResetStream is a RESET_STREAM frame. 
type debugFrameResetStream struct { id streamID @@ -166,6 +238,14 @@ func (f debugFrameResetStream) write(w *packetWriter) bool { return w.appendResetStreamFrame(f.id, f.code, f.finalSize) } +func (f debugFrameResetStream) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "reset_stream"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Uint64("final_size", uint64(f.finalSize)), + ) +} + // debugFrameStopSending is a STOP_SENDING frame. type debugFrameStopSending struct { id streamID @@ -185,6 +265,14 @@ func (f debugFrameStopSending) write(w *packetWriter) bool { return w.appendStopSendingFrame(f.id, f.code) } +func (f debugFrameStopSending) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "stop_sending"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Uint64("error_code", uint64(f.code)), + ) +} + // debugFrameCrypto is a CRYPTO frame. type debugFrameCrypto struct { off int64 @@ -206,6 +294,14 @@ func (f debugFrameCrypto) write(w *packetWriter) bool { return added } +func (f debugFrameCrypto) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "crypto"), + slog.Int64("offset", f.off), + slog.Int("length", len(f.data)), + ) +} + // debugFrameNewToken is a NEW_TOKEN frame. type debugFrameNewToken struct { token []byte @@ -224,6 +320,13 @@ func (f debugFrameNewToken) write(w *packetWriter) bool { return w.appendNewTokenFrame(f.token) } +func (f debugFrameNewToken) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "new_token"), + slogHexstring("token", f.token), + ) +} + // debugFrameStream is a STREAM frame. type debugFrameStream struct { id streamID @@ -251,6 +354,20 @@ func (f debugFrameStream) write(w *packetWriter) bool { return added } +func (f debugFrameStream) LogValue() slog.Value { + var fin slog.Attr + if f.fin { + fin = slog.Bool("fin", true) + } + return slog.GroupValue( + slog.String("frame_type", "stream"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Int64("offset", f.off), + slog.Int("length", len(f.data)), + fin, + ) +} + // debugFrameMaxData is a MAX_DATA frame. type debugFrameMaxData struct { max int64 @@ -269,6 +386,13 @@ func (f debugFrameMaxData) write(w *packetWriter) bool { return w.appendMaxDataFrame(f.max) } +func (f debugFrameMaxData) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "max_data"), + slog.Int64("maximum", f.max), + ) +} + // debugFrameMaxStreamData is a MAX_STREAM_DATA frame. type debugFrameMaxStreamData struct { id streamID @@ -288,6 +412,14 @@ func (f debugFrameMaxStreamData) write(w *packetWriter) bool { return w.appendMaxStreamDataFrame(f.id, f.max) } +func (f debugFrameMaxStreamData) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "max_stream_data"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Int64("maximum", f.max), + ) +} + // debugFrameMaxStreams is a MAX_STREAMS frame. type debugFrameMaxStreams struct { streamType streamType @@ -307,6 +439,14 @@ func (f debugFrameMaxStreams) write(w *packetWriter) bool { return w.appendMaxStreamsFrame(f.streamType, f.max) } +func (f debugFrameMaxStreams) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "max_streams"), + slog.String("stream_type", f.streamType.qlogString()), + slog.Int64("maximum", f.max), + ) +} + // debugFrameDataBlocked is a DATA_BLOCKED frame. 
type debugFrameDataBlocked struct { max int64 @@ -325,6 +465,13 @@ func (f debugFrameDataBlocked) write(w *packetWriter) bool { return w.appendDataBlockedFrame(f.max) } +func (f debugFrameDataBlocked) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "data_blocked"), + slog.Int64("limit", f.max), + ) +} + // debugFrameStreamDataBlocked is a STREAM_DATA_BLOCKED frame. type debugFrameStreamDataBlocked struct { id streamID @@ -344,6 +491,14 @@ func (f debugFrameStreamDataBlocked) write(w *packetWriter) bool { return w.appendStreamDataBlockedFrame(f.id, f.max) } +func (f debugFrameStreamDataBlocked) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "stream_data_blocked"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Int64("limit", f.max), + ) +} + // debugFrameStreamsBlocked is a STREAMS_BLOCKED frame. type debugFrameStreamsBlocked struct { streamType streamType @@ -363,6 +518,14 @@ func (f debugFrameStreamsBlocked) write(w *packetWriter) bool { return w.appendStreamsBlockedFrame(f.streamType, f.max) } +func (f debugFrameStreamsBlocked) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "streams_blocked"), + slog.String("stream_type", f.streamType.qlogString()), + slog.Int64("limit", f.max), + ) +} + // debugFrameNewConnectionID is a NEW_CONNECTION_ID frame. type debugFrameNewConnectionID struct { seq int64 @@ -384,6 +547,16 @@ func (f debugFrameNewConnectionID) write(w *packetWriter) bool { return w.appendNewConnectionIDFrame(f.seq, f.retirePriorTo, f.connID, f.token) } +func (f debugFrameNewConnectionID) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "new_connection_id"), + slog.Int64("sequence_number", f.seq), + slog.Int64("retire_prior_to", f.retirePriorTo), + slogHexstring("connection_id", f.connID), + slogHexstring("stateless_reset_token", f.token[:]), + ) +} + // debugFrameRetireConnectionID is a NEW_CONNECTION_ID frame. type debugFrameRetireConnectionID struct { seq int64 @@ -402,6 +575,13 @@ func (f debugFrameRetireConnectionID) write(w *packetWriter) bool { return w.appendRetireConnectionIDFrame(f.seq) } +func (f debugFrameRetireConnectionID) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "retire_connection_id"), + slog.Int64("sequence_number", f.seq), + ) +} + // debugFramePathChallenge is a PATH_CHALLENGE frame. type debugFramePathChallenge struct { data uint64 @@ -420,6 +600,13 @@ func (f debugFramePathChallenge) write(w *packetWriter) bool { return w.appendPathChallengeFrame(f.data) } +func (f debugFramePathChallenge) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "path_challenge"), + slog.String("data", fmt.Sprintf("%016x", f.data)), + ) +} + // debugFramePathResponse is a PATH_RESPONSE frame. type debugFramePathResponse struct { data uint64 @@ -438,6 +625,13 @@ func (f debugFramePathResponse) write(w *packetWriter) bool { return w.appendPathResponseFrame(f.data) } +func (f debugFramePathResponse) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "path_response"), + slog.String("data", fmt.Sprintf("%016x", f.data)), + ) +} + // debugFrameConnectionCloseTransport is a CONNECTION_CLOSE frame carrying a transport error. 
type debugFrameConnectionCloseTransport struct { code transportError @@ -465,6 +659,15 @@ func (f debugFrameConnectionCloseTransport) write(w *packetWriter) bool { return w.appendConnectionCloseTransportFrame(f.code, f.frameType, f.reason) } +func (f debugFrameConnectionCloseTransport) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "connection_close"), + slog.String("error_space", "transport"), + slog.Uint64("error_code_value", uint64(f.code)), + slog.String("reason", f.reason), + ) +} + // debugFrameConnectionCloseApplication is a CONNECTION_CLOSE frame carrying an application error. type debugFrameConnectionCloseApplication struct { code uint64 @@ -488,6 +691,15 @@ func (f debugFrameConnectionCloseApplication) write(w *packetWriter) bool { return w.appendConnectionCloseApplicationFrame(f.code, f.reason) } +func (f debugFrameConnectionCloseApplication) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "connection_close"), + slog.String("error_space", "application"), + slog.Uint64("error_code_value", uint64(f.code)), + slog.String("reason", f.reason), + ) +} + // debugFrameHandshakeDone is a HANDSHAKE_DONE frame. type debugFrameHandshakeDone struct{} @@ -502,3 +714,9 @@ func (f debugFrameHandshakeDone) String() string { func (f debugFrameHandshakeDone) write(w *packetWriter) bool { return w.appendHandshakeDoneFrame() } + +func (f debugFrameHandshakeDone) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "handshake_done"), + ) +} diff --git a/internal/quic/packet.go b/internal/quic/packet.go index df589ccca..7a874319d 100644 --- a/internal/quic/packet.go +++ b/internal/quic/packet.go @@ -41,6 +41,22 @@ func (p packetType) String() string { return fmt.Sprintf("unknown packet type %v", byte(p)) } +func (p packetType) qlogString() string { + switch p { + case packetTypeInitial: + return "initial" + case packetType0RTT: + return "0RTT" + case packetTypeHandshake: + return "handshake" + case packetTypeRetry: + return "retry" + case packetType1RTT: + return "1RTT" + } + return "unknown" +} + // Bits set in the first byte of a packet. 
const ( headerFormLong = 0x80 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.2.1 diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go index 7b01bb00d..475e18c1d 100644 --- a/internal/quic/packet_codec_test.go +++ b/internal/quic/packet_codec_test.go @@ -9,8 +9,13 @@ package quic import ( "bytes" "crypto/tls" + "io" + "log/slog" "reflect" "testing" + "time" + + "golang.org/x/net/internal/quic/qlog" ) func TestParseLongHeaderPacket(t *testing.T) { @@ -207,11 +212,13 @@ func TestRoundtripEncodeShortPacket(t *testing.T) { func TestFrameEncodeDecode(t *testing.T) { for _, test := range []struct { s string + j string f debugFrame b []byte truncated []byte }{{ s: "PADDING*1", + j: `{"frame_type":"padding","length":1}`, f: debugFramePadding{ size: 1, }, @@ -221,12 +228,14 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "PING", + j: `{"frame_type":"ping"}`, f: debugFramePing{}, b: []byte{ 0x01, // TYPE(i) = 0x01 }, }, { s: "ACK Delay=10 [0,16) [17,32) [48,64)", + j: `"error: debugFrameAck should not appear as a slog Value"`, f: debugFrameAck{ ackDelay: 10, ranges: []i64range[packetNumber]{ @@ -257,6 +266,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "RESET_STREAM ID=1 Code=2 FinalSize=3", + j: `{"frame_type":"reset_stream","stream_id":1,"final_size":3}`, f: debugFrameResetStream{ id: 1, code: 2, @@ -270,6 +280,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "STOP_SENDING ID=1 Code=2", + j: `{"frame_type":"stop_sending","stream_id":1,"error_code":2}`, f: debugFrameStopSending{ id: 1, code: 2, @@ -281,6 +292,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "CRYPTO Offset=1 Length=2", + j: `{"frame_type":"crypto","offset":1,"length":2}`, f: debugFrameCrypto{ off: 1, data: []byte{3, 4}, @@ -299,6 +311,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "NEW_TOKEN Token=0304", + j: `{"frame_type":"new_token","token":"0304"}`, f: debugFrameNewToken{ token: []byte{3, 4}, }, @@ -309,6 +322,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "STREAM ID=1 Offset=0 Length=0", + j: `{"frame_type":"stream","stream_id":1,"offset":0,"length":0}`, f: debugFrameStream{ id: 1, fin: false, @@ -324,6 +338,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "STREAM ID=100 Offset=4 Length=3", + j: `{"frame_type":"stream","stream_id":100,"offset":4,"length":3}`, f: debugFrameStream{ id: 100, fin: false, @@ -346,6 +361,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "STREAM ID=100 FIN Offset=4 Length=3", + j: `{"frame_type":"stream","stream_id":100,"offset":4,"length":3,"fin":true}`, f: debugFrameStream{ id: 100, fin: true, @@ -368,6 +384,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "STREAM ID=1 FIN Offset=100 Length=0", + j: `{"frame_type":"stream","stream_id":1,"offset":100,"length":0,"fin":true}`, f: debugFrameStream{ id: 1, fin: true, @@ -383,6 +400,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "MAX_DATA Max=10", + j: `{"frame_type":"max_data","maximum":10}`, f: debugFrameMaxData{ max: 10, }, @@ -392,6 +410,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "MAX_STREAM_DATA ID=1 Max=10", + j: `{"frame_type":"max_stream_data","stream_id":1,"maximum":10}`, f: debugFrameMaxStreamData{ id: 1, max: 10, @@ -403,6 +422,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "MAX_STREAMS Type=bidi Max=1", + j: `{"frame_type":"max_streams","stream_type":"bidirectional","maximum":1}`, f: debugFrameMaxStreams{ streamType: bidiStream, max: 1, @@ -413,6 +433,7 @@ func 
TestFrameEncodeDecode(t *testing.T) { }, }, { s: "MAX_STREAMS Type=uni Max=1", + j: `{"frame_type":"max_streams","stream_type":"unidirectional","maximum":1}`, f: debugFrameMaxStreams{ streamType: uniStream, max: 1, @@ -423,6 +444,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "DATA_BLOCKED Max=1", + j: `{"frame_type":"data_blocked","limit":1}`, f: debugFrameDataBlocked{ max: 1, }, @@ -432,6 +454,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "STREAM_DATA_BLOCKED ID=1 Max=2", + j: `{"frame_type":"stream_data_blocked","stream_id":1,"limit":2}`, f: debugFrameStreamDataBlocked{ id: 1, max: 2, @@ -443,6 +466,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "STREAMS_BLOCKED Type=bidi Max=1", + j: `{"frame_type":"streams_blocked","stream_type":"bidirectional","limit":1}`, f: debugFrameStreamsBlocked{ streamType: bidiStream, max: 1, @@ -453,6 +477,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "STREAMS_BLOCKED Type=uni Max=1", + j: `{"frame_type":"streams_blocked","stream_type":"unidirectional","limit":1}`, f: debugFrameStreamsBlocked{ streamType: uniStream, max: 1, @@ -463,6 +488,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "NEW_CONNECTION_ID Seq=3 Retire=2 ID=a0a1a2a3 Token=0102030405060708090a0b0c0d0e0f10", + j: `{"frame_type":"new_connection_id","sequence_number":3,"retire_prior_to":2,"connection_id":"a0a1a2a3","stateless_reset_token":"0102030405060708090a0b0c0d0e0f10"}`, f: debugFrameNewConnectionID{ seq: 3, retirePriorTo: 2, @@ -479,6 +505,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "RETIRE_CONNECTION_ID Seq=1", + j: `{"frame_type":"retire_connection_id","sequence_number":1}`, f: debugFrameRetireConnectionID{ seq: 1, }, @@ -488,6 +515,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "PATH_CHALLENGE Data=0123456789abcdef", + j: `{"frame_type":"path_challenge","data":"0123456789abcdef"}`, f: debugFramePathChallenge{ data: 0x0123456789abcdef, }, @@ -497,6 +525,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "PATH_RESPONSE Data=0123456789abcdef", + j: `{"frame_type":"path_response","data":"0123456789abcdef"}`, f: debugFramePathResponse{ data: 0x0123456789abcdef, }, @@ -506,6 +535,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: `CONNECTION_CLOSE Code=INTERNAL_ERROR FrameType=2 Reason="oops"`, + j: `{"frame_type":"connection_close","error_space":"transport","error_code_value":1,"reason":"oops"}`, f: debugFrameConnectionCloseTransport{ code: 1, frameType: 2, @@ -520,6 +550,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: `CONNECTION_CLOSE AppCode=1 Reason="oops"`, + j: `{"frame_type":"connection_close","error_space":"application","error_code_value":1,"reason":"oops"}`, f: debugFrameConnectionCloseApplication{ code: 1, reason: "oops", @@ -532,6 +563,7 @@ func TestFrameEncodeDecode(t *testing.T) { }, }, { s: "HANDSHAKE_DONE", + j: `{"frame_type":"handshake_done"}`, f: debugFrameHandshakeDone{}, b: []byte{ 0x1e, // Type (i) = 0x1e, @@ -554,6 +586,9 @@ func TestFrameEncodeDecode(t *testing.T) { if got, want := test.f.String(), test.s; got != want { t.Errorf("frame.String():\ngot %q\nwant %q", got, want) } + if got, want := frameJSON(test.f), test.j; got != want { + t.Errorf("frame.LogValue():\ngot %q\nwant %q", got, want) + } // Try encoding the frame into too little space. 
// Most frames will result in an error; some (like STREAM frames) will truncate @@ -579,6 +614,42 @@ func TestFrameEncodeDecode(t *testing.T) { } } +func TestFrameScaledAck(t *testing.T) { + for _, test := range []struct { + j string + f debugFrameScaledAck + }{{ + j: `{"frame_type":"ack","acked_ranges":[[0,15],[17],[48,63]],"ack_delay":10.000000}`, + f: debugFrameScaledAck{ + ackDelay: 10 * time.Millisecond, + ranges: []i64range[packetNumber]{ + {0x00, 0x10}, + {0x11, 0x12}, + {0x30, 0x40}, + }, + }, + }} { + if got, want := frameJSON(test.f), test.j; got != want { + t.Errorf("frame.LogValue():\ngot %q\nwant %q", got, want) + } + } +} + +func frameJSON(f slog.LogValuer) string { + var buf bytes.Buffer + h := qlog.NewJSONHandler(qlog.HandlerOptions{ + Level: QLogLevelFrame, + NewTrace: func(info qlog.TraceInfo) (io.WriteCloser, error) { + return nopCloseWriter{&buf}, nil + }, + }) + // Log the frame, and then trim out everything but the frame from the log. + slog.New(h).Info("message", slog.Any("frame", f)) + _, b, _ := bytes.Cut(buf.Bytes(), []byte(`"frame":`)) + b = bytes.TrimSuffix(b, []byte("}}\n")) + return string(b) +} + func TestFrameDecode(t *testing.T) { for _, test := range []struct { desc string diff --git a/internal/quic/qlog.go b/internal/quic/qlog.go index ea53cab1e..fea8b38ee 100644 --- a/internal/quic/qlog.go +++ b/internal/quic/qlog.go @@ -11,6 +11,7 @@ import ( "encoding/hex" "log/slog" "net/netip" + "time" ) // Log levels for qlog events. @@ -145,3 +146,104 @@ func (c *Conn) logConnectionClosed() { slog.String("trigger", trigger), ) } + +func (c *Conn) logLongPacketReceived(p longPacket, pkt []byte) { + pnumLen := 1 + int(pkt[0]&0x03) + length := pnumLen + len(p.payload) + var frames slog.Attr + if c.logEnabled(QLogLevelFrame) { + frames = c.packetFramesAttr(p.payload) + } + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "transport:packet_received", + slog.Group("header", + slog.String("packet_type", p.ptype.qlogString()), + slog.Uint64("packet_number", uint64(p.num)), + slog.Uint64("flags", uint64(pkt[0])), + slogHexstring("scid", p.srcConnID), + slogHexstring("dcid", p.dstConnID), + slog.Int("length", length), + ), + frames, + ) +} + +func (c *Conn) log1RTTPacketReceived(p shortPacket, pkt []byte) { + var frames slog.Attr + if c.logEnabled(QLogLevelFrame) { + frames = c.packetFramesAttr(p.payload) + } + dstConnID, _ := dstConnIDForDatagram(pkt) + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "transport:packet_received", + slog.Group("header", + slog.String("packet_type", packetType1RTT.qlogString()), + slog.Uint64("packet_number", uint64(p.num)), + slog.Uint64("flags", uint64(pkt[0])), + slog.String("scid", ""), + slogHexstring("dcid", dstConnID), + ), + frames, + ) +} + +func (c *Conn) logPacketSent(ptype packetType, pnum packetNumber, src, dst, payload []byte) { + var frames slog.Attr + if c.logEnabled(QLogLevelFrame) { + frames = c.packetFramesAttr(payload) + } + var scid slog.Attr + if len(src) > 0 { + scid = slogHexstring("scid", src) + } + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "transport:packet_sent", + slog.Group("header", + slog.String("packet_type", ptype.qlogString()), + slog.Uint64("packet_number", uint64(pnum)), + scid, + slogHexstring("dcid", dst), + ), + frames, + ) +} + +// packetFramesAttr returns the "frames" attribute containing the frames in a packet. +// We currently pass this as a slog Any containing a []slog.Value, +// where each Value is a debugFrame that implements slog.LogValuer. 
+// +// This isn't tremendously efficient, but avoids the need to put a JSON encoder +// in the quic package or a frame parser in the qlog package. +func (c *Conn) packetFramesAttr(payload []byte) slog.Attr { + var frames []slog.Value + for len(payload) > 0 { + f, n := parseDebugFrame(payload) + if n < 0 { + break + } + payload = payload[n:] + switch f := f.(type) { + case debugFrameAck: + // The qlog ACK frame contains the ACK Delay field as a duration. + // Interpreting the contents of this field as a duration requires + // knowing the peer's ack_delay_exponent transport parameter, + // and it's possible for us to parse an ACK frame before we've + // received that parameter. + // + // We could plumb connection state down into the frame parser, + // but for now let's minimize the amount of code that needs to + // deal with this and convert the unscaled value into a scaled one here. + ackDelay := time.Duration(-1) + if c.peerAckDelayExponent >= 0 { + ackDelay = f.ackDelay.Duration(uint8(c.peerAckDelayExponent)) + } + frames = append(frames, slog.AnyValue(debugFrameScaledAck{ + ranges: f.ranges, + ackDelay: ackDelay, + })) + default: + frames = append(frames, slog.AnyValue(f)) + } + } + return slog.Any("frames", frames) +} diff --git a/internal/quic/qlog/json_writer.go b/internal/quic/qlog/json_writer.go index 50cf33bc5..3950ab42f 100644 --- a/internal/quic/qlog/json_writer.go +++ b/internal/quic/qlog/json_writer.go @@ -42,38 +42,56 @@ func (w *jsonWriter) writeRecordEnd() { w.mu.Unlock() } -// writeAttrsField writes a []slog.Attr as an object field. -func (w *jsonWriter) writeAttrsField(name string, attrs []slog.Attr) { - w.writeName(name) +func (w *jsonWriter) writeAttrs(attrs []slog.Attr) { w.buf.WriteByte('{') for _, a := range attrs { + if a.Key == "" { + continue + } w.writeAttr(a) } w.buf.WriteByte('}') } -// writeAttr writes a slog.Attr as an object field. func (w *jsonWriter) writeAttr(a slog.Attr) { - v := a.Value.Resolve() + w.writeName(a.Key) + w.writeValue(a.Value) +} + +// writeAttr writes a []slog.Attr as an object field. +func (w *jsonWriter) writeAttrsField(name string, attrs []slog.Attr) { + w.writeName(name) + w.writeAttrs(attrs) +} + +func (w *jsonWriter) writeValue(v slog.Value) { + v = v.Resolve() switch v.Kind() { case slog.KindAny: - w.writeStringField(a.Key, fmt.Sprint(v.Any())) + switch v := v.Any().(type) { + case []slog.Value: + w.writeArray(v) + case interface{ AppendJSON([]byte) []byte }: + w.buf.Write(v.AppendJSON(w.buf.AvailableBuffer())) + default: + w.writeString(fmt.Sprint(v)) + } case slog.KindBool: - w.writeBoolField(a.Key, v.Bool()) + w.writeBool(v.Bool()) case slog.KindDuration: - w.writeDurationField(a.Key, v.Duration()) + w.writeDuration(v.Duration()) case slog.KindFloat64: - w.writeFloat64Field(a.Key, v.Float64()) + w.writeFloat64(v.Float64()) case slog.KindInt64: - w.writeInt64Field(a.Key, v.Int64()) + w.writeInt64(v.Int64()) case slog.KindString: - w.writeStringField(a.Key, v.String()) + w.writeString(v.String()) case slog.KindTime: - w.writeTimeField(a.Key, v.Time()) + w.writeTime(v.Time()) case slog.KindUint64: - w.writeUint64Field(a.Key, v.Uint64()) + w.writeUint64(v.Uint64()) case slog.KindGroup: - w.writeAttrsField(a.Key, v.Group()) + w.writeAttrs(v.Group()) default: w.writeString("unhandled kind") } @@ -89,24 +107,41 @@ func (w *jsonWriter) writeName(name string) { w.buf.WriteByte(':') } -// writeObject writes an object-valued object field. -// The function f is called to write the contents. 
-func (w *jsonWriter) writeObjectField(name string, f func()) { - w.writeName(name) +func (w *jsonWriter) writeObject(f func()) { w.buf.WriteByte('{') f() w.buf.WriteByte('}') } -// writeRawField writes an field with a raw JSON value. -func (w *jsonWriter) writeRawField(name, v string) { +// writeObject writes an object-valued object field. +// The function f is called to write the contents. +func (w *jsonWriter) writeObjectField(name string, f func()) { w.writeName(name) + w.writeObject(f) +} + +func (w *jsonWriter) writeArray(vals []slog.Value) { + w.buf.WriteByte('[') + for i, v := range vals { + if i != 0 { + w.buf.WriteByte(',') + } + w.writeValue(v) + } + w.buf.WriteByte(']') +} + +func (w *jsonWriter) writeRaw(v string) { w.buf.WriteString(v) } -// writeBoolField writes a bool-valued object field. -func (w *jsonWriter) writeBoolField(name string, v bool) { +// writeRawField writes a field with a raw JSON value. +func (w *jsonWriter) writeRawField(name, v string) { w.writeName(name) + w.writeRaw(v) +} + +func (w *jsonWriter) writeBool(v bool) { if v { w.buf.WriteString("true") } else { @@ -114,40 +149,62 @@ func (w *jsonWriter) writeBoolField(name string, v bool) { } } +// writeBoolField writes a bool-valued object field. +func (w *jsonWriter) writeBoolField(name string, v bool) { + w.writeName(name) + w.writeBool(v) +} + +// writeDuration writes a duration as milliseconds. +func (w *jsonWriter) writeDuration(v time.Duration) { + fmt.Fprintf(&w.buf, "%d.%06d", v.Milliseconds(), v%time.Millisecond) +} + // writeDurationField writes a millisecond duration-valued object field. func (w *jsonWriter) writeDurationField(name string, v time.Duration) { w.writeName(name) - fmt.Fprintf(&w.buf, "%d.%06d", v.Milliseconds(), v%time.Millisecond) + w.writeDuration(v) +} + +func (w *jsonWriter) writeFloat64(v float64) { + w.buf.Write(strconv.AppendFloat(w.buf.AvailableBuffer(), v, 'f', -1, 64)) } // writeFloat64Field writes an float64-valued object field. func (w *jsonWriter) writeFloat64Field(name string, v float64) { w.writeName(name) - w.buf.Write(strconv.AppendFloat(w.buf.AvailableBuffer(), v, 'f', -1, 64)) + w.writeFloat64(v) +} + +func (w *jsonWriter) writeInt64(v int64) { + w.buf.Write(strconv.AppendInt(w.buf.AvailableBuffer(), v, 10)) } // writeInt64Field writes an int64-valued object field. func (w *jsonWriter) writeInt64Field(name string, v int64) { w.writeName(name) - w.buf.Write(strconv.AppendInt(w.buf.AvailableBuffer(), v, 10)) + w.writeInt64(v) +} + +func (w *jsonWriter) writeUint64(v uint64) { + w.buf.Write(strconv.AppendUint(w.buf.AvailableBuffer(), v, 10)) } // writeUint64Field writes a uint64-valued object field. func (w *jsonWriter) writeUint64Field(name string, v uint64) { w.writeName(name) - w.buf.Write(strconv.AppendUint(w.buf.AvailableBuffer(), v, 10)) + w.writeUint64(v) } -// writeStringField writes a string-valued object field. -func (w *jsonWriter) writeStringField(name, v string) { - w.writeName(name) - w.writeString(v) +// writeTime writes a time as seconds since the Unix epoch. +func (w *jsonWriter) writeTime(v time.Time) { + fmt.Fprintf(&w.buf, "%d.%06d", v.UnixMilli(), v.Nanosecond()%int(time.Millisecond)) } // writeTimeField writes a time-valued object field. 
func (w *jsonWriter) writeTimeField(name string, v time.Time) { w.writeName(name) - fmt.Fprintf(&w.buf, "%d.%06d", v.UnixMilli(), v.Nanosecond()%int(time.Millisecond)) + w.writeTime(v) } func jsonSafeSet(c byte) bool { @@ -192,3 +249,9 @@ func (w *jsonWriter) writeString(v string) { } w.buf.WriteByte('"') } + +// writeStringField writes a string-valued object field. +func (w *jsonWriter) writeStringField(name, v string) { + w.writeName(name) + w.writeString(v) +} diff --git a/internal/quic/qlog_test.go b/internal/quic/qlog_test.go index 119f5d16a..e98b11838 100644 --- a/internal/quic/qlog_test.go +++ b/internal/quic/qlog_test.go @@ -55,6 +55,41 @@ func TestQLogHandshake(t *testing.T) { }) } +func TestQLogPacketFrames(t *testing.T) { + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, qr.config) + tc.handshake() + tc.conn.Abort(nil) + tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{}) + tc.advanceToTimer() // let the conn finish draining + + qr.wantEvents(t, jsonEvent{ + "name": "transport:packet_sent", + "data": map[string]any{ + "header": map[string]any{ + "packet_type": "initial", + "packet_number": 0, + "dcid": hex.EncodeToString(testLocalConnID(-1)), + "scid": hex.EncodeToString(testLocalConnID(0)), + }, + "frames": []any{ + map[string]any{"frame_type": "crypto"}, + }, + }, + }, jsonEvent{ + "name": "transport:packet_received", + "data": map[string]any{ + "header": map[string]any{ + "packet_type": "initial", + "packet_number": 0, + "dcid": hex.EncodeToString(testLocalConnID(0)), + "scid": hex.EncodeToString(testPeerConnID(0)), + }, + "frames": []any{map[string]any{"frame_type": "crypto"}}, + }, + }) +} + func TestQLogConnectionClosedTrigger(t *testing.T) { for _, test := range []struct { trigger string @@ -137,21 +172,60 @@ func (j jsonEvent) String() string { return string(b) } -// eventPartialEqual verifies that every field set in want matches the corresponding field in got. -// It ignores additional fields in got. -func eventPartialEqual(got, want jsonEvent) bool { - for k := range want { - ge, gok := got[k].(map[string]any) - we, wok := want[k].(map[string]any) - if gok && wok { - if !eventPartialEqual(ge, we) { - return false +// jsonPartialEqual compares two JSON structures. +// It ignores fields not set in want (see below for specifics). +func jsonPartialEqual(got, want any) (equal bool) { + cmpval := func(v any) any { + // Map certain types to a common representation. + switch v := v.(type) { + case int: + // JSON uses float64s rather than ints for numbers. + // Map int->float64 so we can use integers in expectations. + return float64(v) + case jsonEvent: + return (map[string]any)(v) + case []jsonEvent: + s := []any{} + for _, e := range v { + s = append(s, e) } - } else { - if !reflect.DeepEqual(got[k], want[k]) { + return s + } + return v + } + got = cmpval(got) + want = cmpval(want) + if reflect.TypeOf(got) != reflect.TypeOf(want) { + return false + } + switch w := want.(type) { + case nil: + // Match anything. + case map[string]any: + // JSON object: Every field in want must match a field in got. + g := got.(map[string]any) + for k := range w { + if !jsonPartialEqual(g[k], w[k]) { return false } } + case []any: + // JSON slice: Every field in want must match a field in got, in order. + // So want=[2,4] matches got=[1,2,3,4] but not [4,2]. 
+ g := got.([]any) + for _, ge := range g { + if jsonPartialEqual(ge, w[0]) { + w = w[1:] + if len(w) == 0 { + return true + } + } + } + return false + default: + if !reflect.DeepEqual(got, want) { + return false + } } return true } @@ -179,6 +253,7 @@ func (q *qlogRecord) Close() error { return nil } // config may be passed to newTestConn to configure the conn to use this logger. func (q *qlogRecord) config(c *Config) { c.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{ + Level: QLogLevelFrame, NewTrace: func(info qlog.TraceInfo) (io.WriteCloser, error) { return q, nil }, @@ -189,14 +264,7 @@ func (q *qlogRecord) config(c *Config) { func (q *qlogRecord) wantEvents(t *testing.T, want ...jsonEvent) { t.Helper() got := q.ev - unseen := want - for _, g := range got { - if eventPartialEqual(g, unseen[0]) { - unseen = unseen[1:] - if len(unseen) == 0 { - return - } - } + if !jsonPartialEqual(got, want) { + t.Fatalf("got events:\n%v\n\nwant events:\n%v", got, want) } - t.Fatalf("got events:\n%v\n\nwant events:\n%v", got, want) } diff --git a/internal/quic/quic.go b/internal/quic/quic.go index e4d0d77c7..3e62d7cd9 100644 --- a/internal/quic/quic.go +++ b/internal/quic/quic.go @@ -144,6 +144,17 @@ const ( streamTypeCount ) +func (s streamType) qlogString() string { + switch s { + case bidiStream: + return "bidirectional" + case uniStream: + return "unidirectional" + default: + return "BUG" + } +} + func (s streamType) String() string { switch s { case bidiStream: From c1b6eee3f608179effef5e5964776391ef81e619 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 17 Nov 2023 14:49:16 -0800 Subject: [PATCH 09/70] quic: send occasional ack-eliciting packets A receiver that is sending only non-ack-eliciting packets (for example, a connection reading data from a stream but not sending anything other than ACKs in response) can accumulate a large amount of state for in-flight, unacknowledged packets. Add an occasional PING frame when in this state, to cause the peer to send an ACK for our outstanding packets. Change-Id: Iaf6b5a9735fa356fdebaff24200420a280b0c9a5 Reviewed-on: https://go-review.googlesource.com/c/net/+/545215 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/conn_send.go | 30 ++++++++++++++++++++----- internal/quic/conn_send_test.go | 40 +++++++++++++++++++++++++++++++++ internal/quic/loss.go | 8 +++++++ 3 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 internal/quic/conn_send_test.go diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index e2240f2fd..a8d930898 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -222,11 +222,7 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, // Either we are willing to send an ACK-only packet, // or we've added additional frames. c.acks[space].sentAck() - if !c.w.sent.ackEliciting && c.keysAppData.needAckEliciting() { - // The peer has initiated a key update. - // We haven't sent them any packets yet in the new phase. - // Make this an ack-eliciting packet. - // Their ack of this packet will complete the key update. + if !c.w.sent.ackEliciting && c.shouldMakePacketAckEliciting() { c.w.appendPingFrame() } }() @@ -331,6 +327,30 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, } } +// shouldMakePacketAckEliciting is called when sending a packet containing nothing but an ACK frame. +// It reports whether we should add a PING frame to the packet to make it ack-eliciting. 
+func (c *Conn) shouldMakePacketAckEliciting() bool { + if c.keysAppData.needAckEliciting() { + // The peer has initiated a key update. + // We haven't sent them any packets yet in the new phase. + // Make this an ack-eliciting packet. + // Their ack of this packet will complete the key update. + return true + } + if c.loss.consecutiveNonAckElicitingPackets >= 19 { + // We've sent a run of non-ack-eliciting packets. + // Add in an ack-eliciting one every once in a while so the peer + // lets us know which ones have arrived. + // + // Google QUICHE injects a PING after sending 19 packets. We do the same. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.4-2 + return true + } + // TODO: Consider making every packet sent when in PTO ack-eliciting to speed up recovery. + return false +} + func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool { seen, delay := c.acks[space].acksToSend(now) if len(seen) == 0 { diff --git a/internal/quic/conn_send_test.go b/internal/quic/conn_send_test.go new file mode 100644 index 000000000..822783c41 --- /dev/null +++ b/internal/quic/conn_send_test.go @@ -0,0 +1,40 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "testing" + "time" +) + +func TestAckElicitingAck(t *testing.T) { + // "A receiver that sends only non-ack-eliciting packets [...] might not receive + // an acknowledgment for a long period of time. + // [...] a receiver could send a [...] ack-eliciting frame occasionally [...] + // to elicit an ACK from the peer." + // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.4-2 + // + // Send a bunch of ack-eliciting packets, verify that the conn doesn't just + // send ACKs in response. + tc := newTestConn(t, clientSide, permissiveTransportParameters) + tc.handshake() + const count = 100 + for i := 0; i < count; i++ { + tc.advance(1 * time.Millisecond) + tc.writeFrames(packetType1RTT, + debugFramePing{}, + ) + got, _ := tc.readFrame() + switch got.(type) { + case debugFrameAck: + continue + case debugFramePing: + return + } + } + t.Errorf("after sending %v PINGs, got no ack-eliciting response", count) +} diff --git a/internal/quic/loss.go b/internal/quic/loss.go index 4a0767bd0..a59081fd5 100644 --- a/internal/quic/loss.go +++ b/internal/quic/loss.go @@ -50,6 +50,9 @@ type lossState struct { // https://www.rfc-editor.org/rfc/rfc9000#section-8-2 antiAmplificationLimit int + // Count of non-ack-eliciting packets (ACKs) sent since the last ack-eliciting one. + consecutiveNonAckElicitingPackets int + rtt rttState pacer pacerState cc *ccReno @@ -192,6 +195,11 @@ func (c *lossState) packetSent(now time.Time, space numberSpace, sent *sentPacke } c.scheduleTimer(now) } + if sent.ackEliciting { + c.consecutiveNonAckElicitingPackets = 0 + } else { + c.consecutiveNonAckElicitingPackets++ + } } // datagramReceived records a datagram (not packet!) received from the peer. From 08a78b1eeae5f15e658ca8972aa74b6857e3b37b Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 20 Nov 2023 16:41:14 -0800 Subject: [PATCH 10/70] quic: unblock operations when closing conns Blocking operations associated with a connection, such as accepting a stream or writing data to a stream, should be canceled when the connection is closed. 
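As an editorial illustration (not part of the patch), the caller-visible effect looks like this, assuming conn is an established *Conn obtained elsewhere, some other goroutine decides to close it, and the standard library imports (context, log, time) are in place:

// Sketch only: conn is assumed to be an established *Conn from elsewhere.
go func() {
	time.Sleep(1 * time.Second) // some other part of the program gives up
	conn.Abort(nil)
}()

// Before this change, AcceptStream could stay blocked indefinitely after the
// connection closed; with it, the call returns promptly with an error such as
// errConnClosed ("connection closed").
if _, err := conn.AcceptStream(context.Background()); err != nil {
	log.Printf("accept stream: %v", err)
}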
Change-Id: I3b25789885a6c1a2b5aa2178a8d6219a8ea77cbb Reviewed-on: https://go-review.googlesource.com/c/net/+/545216 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam Auto-Submit: Damien Neil --- internal/quic/conn.go | 12 ++++--- internal/quic/conn_async_test.go | 15 ++++---- internal/quic/conn_close.go | 15 ++++++-- internal/quic/conn_close_test.go | 61 ++++++++++++++++++++++++++++++++ internal/quic/conn_streams.go | 11 ++++++ internal/quic/endpoint_test.go | 2 +- internal/quic/stream.go | 41 ++++++++++++++++++++- internal/quic/stream_limits.go | 17 +++++++-- internal/quic/stream_test.go | 30 ++++++++++++---- 9 files changed, 180 insertions(+), 24 deletions(-) diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 31e789b1d..4abc74030 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -263,10 +263,7 @@ var errIdleTimeout = errors.New("idle timeout") // The loop processes messages from c.msgc and timer events. // Other goroutines may examine or modify conn state by sending the loop funcs to execute. func (c *Conn) loop(now time.Time) { - defer close(c.donec) - defer c.tls.Close() - defer c.endpoint.connDrained(c) - defer c.logConnectionClosed() + defer c.cleanup() // The connection timer sends a message to the connection loop on expiry. // We need to give it an expiry when creating it, so set the initial timeout to @@ -346,6 +343,13 @@ func (c *Conn) loop(now time.Time) { } } +func (c *Conn) cleanup() { + c.logConnectionClosed() + c.endpoint.connDrained(c) + c.tls.Close() + close(c.donec) +} + // sendMsg sends a message to the conn's loop. // It does not wait for the message to be processed. // The conn may close before processing the message, in which case it is lost. diff --git a/internal/quic/conn_async_test.go b/internal/quic/conn_async_test.go index dc2a57f9d..fcc101d19 100644 --- a/internal/quic/conn_async_test.go +++ b/internal/quic/conn_async_test.go @@ -41,7 +41,7 @@ type asyncOp[T any] struct { err error caller string - state *asyncTestState + tc *testConn donec chan struct{} cancelFunc context.CancelFunc } @@ -55,7 +55,7 @@ func (a *asyncOp[T]) cancel() { default: } a.cancelFunc() - <-a.state.notify + <-a.tc.asyncTestState.notify select { case <-a.donec: default: @@ -73,6 +73,7 @@ var errNotDone = errors.New("async op is not done") // control over the progress of operations, an asyncOp can only // become done in reaction to the test taking some action. func (a *asyncOp[T]) result() (v T, err error) { + a.tc.wait() select { case <-a.donec: return a.v, a.err @@ -94,8 +95,8 @@ type asyncContextKey struct{} // The function f should call a blocking function such as // Stream.Write or Conn.AcceptStream and return its result. // It must use the provided context. 
-func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[T] { - as := &ts.asyncTestState +func runAsync[T any](tc *testConn, f func(context.Context) (T, error)) *asyncOp[T] { + as := &tc.asyncTestState if as.notify == nil { as.notify = make(chan struct{}) as.mu.Lock() @@ -106,7 +107,7 @@ func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[ ctx := context.WithValue(context.Background(), asyncContextKey{}, true) ctx, cancel := context.WithCancel(ctx) a := &asyncOp[T]{ - state: as, + tc: tc, caller: fmt.Sprintf("%v:%v", filepath.Base(file), line), donec: make(chan struct{}), cancelFunc: cancel, @@ -116,9 +117,9 @@ func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[ close(a.donec) as.notify <- struct{}{} }() - ts.t.Cleanup(func() { + tc.t.Cleanup(func() { if _, err := a.result(); err == errNotDone { - ts.t.Errorf("%v: async operation is still executing at end of test", a.caller) + tc.t.Errorf("%v: async operation is still executing at end of test", a.caller) a.cancel() } }) diff --git a/internal/quic/conn_close.go b/internal/quic/conn_close.go index 246a12638..1798d0536 100644 --- a/internal/quic/conn_close.go +++ b/internal/quic/conn_close.go @@ -71,7 +71,10 @@ func (c *Conn) lifetimeInit() { c.lifetime.donec = make(chan struct{}) } -var errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE") +var ( + errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE") + errConnClosed = errors.New("connection closed") +) // advance is called when time passes. func (c *Conn) lifetimeAdvance(now time.Time) (done bool) { @@ -91,13 +94,21 @@ func (c *Conn) lifetimeAdvance(now time.Time) (done bool) { // setState sets the conn state. func (c *Conn) setState(now time.Time, state connState) { + if c.lifetime.state == state { + return + } + c.lifetime.state = state switch state { case connStateClosing, connStateDraining: if c.lifetime.drainEndTime.IsZero() { c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) } + case connStateDone: + c.setFinalError(nil) + } + if state != connStateAlive { + c.streamsCleanup() } - c.lifetime.state = state } // confirmHandshake is called when the TLS handshake completes. diff --git a/internal/quic/conn_close_test.go b/internal/quic/conn_close_test.go index 49881e62f..63d4911e8 100644 --- a/internal/quic/conn_close_test.go +++ b/internal/quic/conn_close_test.go @@ -216,3 +216,64 @@ func TestConnCloseClosedByEndpoint(t *testing.T) { code: errNo, }) } + +func testConnCloseUnblocks(t *testing.T, f func(context.Context, *testConn) error, opts ...any) { + tc := newTestConn(t, clientSide, opts...) 
+ tc.handshake() + op := runAsync(tc, func(ctx context.Context) (struct{}, error) { + return struct{}{}, f(ctx, tc) + }) + if _, err := op.result(); err != errNotDone { + t.Fatalf("before abort, op = %v, want errNotDone", err) + } + tc.conn.Abort(nil) + if _, err := op.result(); err == nil || err == errNotDone { + t.Fatalf("after abort, op = %v, want error", err) + } +} + +func TestConnCloseUnblocksAcceptStream(t *testing.T) { + testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { + _, err := tc.conn.AcceptStream(ctx) + return err + }, permissiveTransportParameters) +} + +func TestConnCloseUnblocksNewStream(t *testing.T) { + testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { + _, err := tc.conn.NewStream(ctx) + return err + }) +} + +func TestConnCloseUnblocksStreamRead(t *testing.T) { + testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { + s := newLocalStream(t, tc, bidiStream) + buf := make([]byte, 16) + _, err := s.ReadContext(ctx, buf) + return err + }, permissiveTransportParameters) +} + +func TestConnCloseUnblocksStreamWrite(t *testing.T) { + testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { + s := newLocalStream(t, tc, bidiStream) + buf := make([]byte, 32) + _, err := s.WriteContext(ctx, buf) + return err + }, permissiveTransportParameters, func(c *Config) { + c.MaxStreamWriteBufferSize = 16 + }) +} + +func TestConnCloseUnblocksStreamClose(t *testing.T) { + testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { + s := newLocalStream(t, tc, bidiStream) + buf := make([]byte, 16) + _, err := s.WriteContext(ctx, buf) + if err != nil { + return err + } + return s.CloseContext(ctx) + }, permissiveTransportParameters) +} diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go index 83ab5554c..818ec3e57 100644 --- a/internal/quic/conn_streams.go +++ b/internal/quic/conn_streams.go @@ -49,6 +49,17 @@ func (c *Conn) streamsInit() { c.inflowInit() } +func (c *Conn) streamsCleanup() { + c.streams.queue.close(errConnClosed) + c.streams.localLimit[bidiStream].connHasClosed() + c.streams.localLimit[uniStream].connHasClosed() + for _, s := range c.streams.streams { + if s != nil { + s.connHasClosed() + } + } +} + // AcceptStream waits for and returns the next stream created by the peer. func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) { return c.streams.queue.get(ctx, c.testHooks) diff --git a/internal/quic/endpoint_test.go b/internal/quic/endpoint_test.go index f9fc80152..2a6daa076 100644 --- a/internal/quic/endpoint_test.go +++ b/internal/quic/endpoint_test.go @@ -48,7 +48,7 @@ func TestStreamTransfer(t *testing.T) { } }() - s, err := cli.NewStream(ctx) + s, err := cli.NewSendOnlyStream(ctx) if err != nil { t.Fatalf("NewStream: %v", err) } diff --git a/internal/quic/stream.go b/internal/quic/stream.go index 36c80f6af..fb9c1cf3c 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "math" ) type Stream struct { @@ -105,6 +106,11 @@ const ( dataQueue // streamsState.queueData ) +// streamResetByConnClose is assigned to Stream.inresetcode to indicate that a stream +// was implicitly reset when the connection closed. It's out of the range of +// possible reset codes the peer can send. +const streamResetByConnClose = math.MaxInt64 + // wantQueue returns the send queue the stream should be on. 
func (s streamState) wantQueue() streamQueue { switch { @@ -347,7 +353,15 @@ func (s *Stream) CloseContext(ctx context.Context) error { } s.CloseWrite() // TODO: Return code from peer's RESET_STREAM frame? - return s.conn.waitOnDone(ctx, s.outdone) + if err := s.conn.waitOnDone(ctx, s.outdone); err != nil { + return err + } + s.outgate.lock() + defer s.outUnlock() + if s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end) { + return nil + } + return errors.New("stream reset") } // CloseRead aborts reads on the stream. @@ -437,6 +451,31 @@ func (s *Stream) resetInternal(code uint64, userClosed bool) { s.outblocked.clear() } +// connHasClosed indicates the stream's conn has closed. +func (s *Stream) connHasClosed() { + // If we're in the closing state, the user closed the conn. + // Otherwise, we the peer initiated the close. + // This only matters for the error we're going to return from stream operations. + localClose := s.conn.lifetime.state == connStateClosing + + s.ingate.lock() + if !s.inset.isrange(0, s.insize) && s.inresetcode == -1 { + if localClose { + s.inclosed.set() + } else { + s.inresetcode = streamResetByConnClose + } + } + s.inUnlock() + + s.outgate.lock() + if localClose { + s.outclosed.set() + } + s.outreset.set() + s.outUnlock() +} + // inUnlock unlocks s.ingate. // It sets the gate condition if reads from s will not block. // If s has receive-related frames to write or if both directions diff --git a/internal/quic/stream_limits.go b/internal/quic/stream_limits.go index 2f42cf418..71cc29135 100644 --- a/internal/quic/stream_limits.go +++ b/internal/quic/stream_limits.go @@ -21,7 +21,7 @@ import ( type localStreamLimits struct { gate gate max int64 // peer-provided MAX_STREAMS - opened int64 // number of streams opened by us + opened int64 // number of streams opened by us, -1 when conn is closed } func (lim *localStreamLimits) init() { @@ -34,10 +34,21 @@ func (lim *localStreamLimits) open(ctx context.Context, c *Conn) (num int64, err if err := lim.gate.waitAndLock(ctx, c.testHooks); err != nil { return 0, err } - n := lim.opened + if lim.opened < 0 { + lim.gate.unlock(true) + return 0, errConnClosed + } + num = lim.opened lim.opened++ lim.gate.unlock(lim.opened < lim.max) - return n, nil + return num, nil +} + +// connHasClosed indicates the connection has been closed, locally or by the peer. +func (lim *localStreamLimits) connHasClosed() { + lim.gate.lock() + lim.opened = -1 + lim.gate.unlock(true) } // setMax sets the MAX_STREAMS provided by the peer. 
diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go index 93c8839ff..00e392dba 100644 --- a/internal/quic/stream_test.go +++ b/internal/quic/stream_test.go @@ -1047,11 +1047,13 @@ func TestStreamCloseUnblocked(t *testing.T) { for _, test := range []struct { name string unblock func(tc *testConn, s *Stream) + success bool }{{ name: "data received", unblock: func(tc *testConn, s *Stream) { tc.writeAckForAll() }, + success: true, }, { name: "stop sending received", unblock: func(tc *testConn, s *Stream) { @@ -1094,7 +1096,13 @@ func TestStreamCloseUnblocked(t *testing.T) { t.Fatalf("s.CloseContext() = %v, want it to block waiting for acks", err) } test.unblock(tc, s) - if _, err := closing.result(); err != nil { + _, err := closing.result() + switch { + case err == errNotDone: + t.Fatalf("s.CloseContext() still blocking; want it to have returned") + case err == nil && !test.success: + t.Fatalf("s.CloseContext() = nil, want error") + case err != nil && test.success: t.Fatalf("s.CloseContext() = %v, want nil (all data acked)", err) } }) @@ -1390,31 +1398,41 @@ func newTestConnAndStream(t *testing.T, side connSide, sside streamSide, styp st func newTestConnAndLocalStream(t *testing.T, side connSide, styp streamType, opts ...any) (*testConn, *Stream) { t.Helper() - ctx := canceledContext() tc := newTestConn(t, side, opts...) tc.handshake() tc.ignoreFrame(frameTypeAck) + return tc, newLocalStream(t, tc, styp) +} + +func newLocalStream(t *testing.T, tc *testConn, styp streamType) *Stream { + t.Helper() + ctx := canceledContext() s, err := tc.conn.newLocalStream(ctx, styp) if err != nil { t.Fatalf("conn.newLocalStream(%v) = %v", styp, err) } - return tc, s + return s } func newTestConnAndRemoteStream(t *testing.T, side connSide, styp streamType, opts ...any) (*testConn, *Stream) { t.Helper() - ctx := canceledContext() tc := newTestConn(t, side, opts...) tc.handshake() tc.ignoreFrame(frameTypeAck) + return tc, newRemoteStream(t, tc, styp) +} + +func newRemoteStream(t *testing.T, tc *testConn, styp streamType) *Stream { + t.Helper() + ctx := canceledContext() tc.writeFrames(packetType1RTT, debugFrameStream{ - id: newStreamID(side.peer(), styp, 0), + id: newStreamID(tc.conn.side.peer(), styp, 0), }) s, err := tc.conn.AcceptStream(ctx) if err != nil { t.Fatalf("conn.AcceptStream() = %v", err) } - return tc, s + return s } // permissiveTransportParameters may be passed as an option to newTestConn. 
From 65efbad9474a514a2f3c08716b8cf38011fa2736 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 5 Dec 2023 15:05:16 -0800 Subject: [PATCH 11/70] quic: avoid leaking tls goroutines in tests Change-Id: Iaf273294ba3245bfeb387a72e068c048d0fcf93a Reviewed-on: https://go-review.googlesource.com/c/net/+/547736 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn_test.go | 3 +++ internal/quic/main_test.go | 52 +++++++++++++++++++++++++++++++++++++ internal/quic/retry_test.go | 4 ++- 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 internal/quic/main_test.go diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index c57ba1487..b48bee803 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -279,6 +279,9 @@ func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testC } tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams)) tc.peerTLSConn.Start(context.Background()) + t.Cleanup(func() { + tc.peerTLSConn.Close() + }) return tc } diff --git a/internal/quic/main_test.go b/internal/quic/main_test.go new file mode 100644 index 000000000..5ad0042fa --- /dev/null +++ b/internal/quic/main_test.go @@ -0,0 +1,52 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "bytes" + "fmt" + "os" + "runtime" + "testing" + "time" +) + +func TestMain(m *testing.M) { + defer os.Exit(m.Run()) + + // Look for leaked goroutines. + // + // Checking after every test makes it easier to tell which test is the culprit, + // but checking once at the end is faster and less likely to miss something. + start := time.Now() + warned := false + for { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + leaked := false + for _, g := range bytes.Split(buf, []byte("\n\n")) { + if bytes.Contains(g, []byte("quic.TestMain")) || + bytes.Contains(g, []byte("created by os/signal.Notify")) || + bytes.Contains(g, []byte("gotraceback_test.go")) { + continue + } + leaked = true + } + if !leaked { + break + } + if !warned && time.Since(start) > 1*time.Second { + // Print a warning quickly, in case this is an interactive session. + // Keep waiting until the test times out, in case this is a slow trybot. + fmt.Printf("Tests seem to have leaked some goroutines, still waiting.\n\n") + fmt.Print(string(buf)) + warned = true + } + // Goroutines might still be shutting down. 
+ time.Sleep(1 * time.Millisecond) + } +} diff --git a/internal/quic/retry_test.go b/internal/quic/retry_test.go index 4a21a4ca1..8f36e1bd3 100644 --- a/internal/quic/retry_test.go +++ b/internal/quic/retry_test.go @@ -533,7 +533,9 @@ func initialClientCrypto(t *testing.T, e *testEndpoint, p transportParameters) [ tlsClient := tls.QUICClient(config) tlsClient.SetTransportParameters(marshalTransportParameters(p)) tlsClient.Start(context.Background()) - //defer tlsClient.Close() + t.Cleanup(func() { + tlsClient.Close() + }) e.peerTLSConn = tlsClient var data []byte for { From 577e44a5cee023bd639dd2dcc4008644bcb71472 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 8 Dec 2023 07:44:43 -0800 Subject: [PATCH 12/70] quic: skip leaked goroutine check on GOOS=js Fixes golang/go#64620 Change-Id: I3b5ff4d1e1132a47b7cc7eb00861e9f7b76f8764 Reviewed-on: https://go-review.googlesource.com/c/net/+/548455 Auto-Submit: Damien Neil Reviewed-by: Bryan Mills LUCI-TryBot-Result: Go LUCI --- internal/quic/main_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/internal/quic/main_test.go b/internal/quic/main_test.go index 5ad0042fa..ecd0b1e9f 100644 --- a/internal/quic/main_test.go +++ b/internal/quic/main_test.go @@ -22,6 +22,11 @@ func TestMain(m *testing.M) { // // Checking after every test makes it easier to tell which test is the culprit, // but checking once at the end is faster and less likely to miss something. + if runtime.GOOS == "js" { + // The js-wasm runtime creates an additional background goroutine. + // Just skip the leak check there. + return + } start := time.Now() warned := false for { From b952594c266f3a75031e9ba2b43483a735526d39 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 20 Nov 2023 15:52:36 -0800 Subject: [PATCH 13/70] quic: fix data race in connection close We were failing to hold streamsState.streamsMu when removing a closed stream from the conn's stream map. Rework this to remove the mutex entirely. The only access to the map that isn't on the conn's loop is during stream creation. Send a message to the loop to register the stream instead of using a mutex. Change-Id: I2e87089e87c61a6ade8219dfb8acec3809bf95de Reviewed-on: https://go-review.googlesource.com/c/net/+/545217 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/conn.go | 31 +++++++++++++++++++-- internal/quic/conn_async_test.go | 1 + internal/quic/conn_streams.go | 20 ++++++-------- internal/quic/conn_streams_test.go | 44 ++++++++++++++++++++++++++++++ internal/quic/conn_test.go | 19 ++++++++----- 5 files changed, 93 insertions(+), 22 deletions(-) diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 4abc74030..6d79013eb 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -369,12 +369,37 @@ func (c *Conn) wake() { } // runOnLoop executes a function within the conn's loop goroutine. -func (c *Conn) runOnLoop(f func(now time.Time, c *Conn)) error { +func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) error { donec := make(chan struct{}) - c.sendMsg(func(now time.Time, c *Conn) { + msg := func(now time.Time, c *Conn) { defer close(donec) f(now, c) - }) + } + if c.testHooks != nil { + // In tests, we can't rely on being able to send a message immediately: + // c.msgc might be full, and testConnHooks.nextMessage might be waiting + // for us to block before it processes the next message. + // To avoid a deadlock, we send the message in waitUntil. + // If msgc is empty, the message is buffered. 
+ // If msgc is full, we block and let nextMessage process the queue. + msgc := c.msgc + c.testHooks.waitUntil(ctx, func() bool { + for { + select { + case msgc <- msg: + msgc = nil // send msg only once + case <-donec: + return true + case <-c.donec: + return true + default: + return false + } + } + }) + } else { + c.sendMsg(msg) + } select { case <-donec: case <-c.donec: diff --git a/internal/quic/conn_async_test.go b/internal/quic/conn_async_test.go index fcc101d19..4671f8340 100644 --- a/internal/quic/conn_async_test.go +++ b/internal/quic/conn_async_test.go @@ -125,6 +125,7 @@ func runAsync[T any](tc *testConn, f func(context.Context) (T, error)) *asyncOp[ }) // Wait for the operation to either finish or block. <-as.notify + tc.wait() return a } diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go index 818ec3e57..dc82f8b0f 100644 --- a/internal/quic/conn_streams.go +++ b/internal/quic/conn_streams.go @@ -14,10 +14,8 @@ import ( ) type streamsState struct { - queue queue[*Stream] // new, peer-created streams - - streamsMu sync.Mutex - streams map[streamID]*Stream + queue queue[*Stream] // new, peer-created streams + streams map[streamID]*Stream // Limits on the number of streams, indexed by streamType. localLimit [streamTypeCount]localStreamLimits @@ -82,9 +80,6 @@ func (c *Conn) NewSendOnlyStream(ctx context.Context) (*Stream, error) { } func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, error) { - c.streams.streamsMu.Lock() - defer c.streams.streamsMu.Unlock() - num, err := c.streams.localLimit[styp].open(ctx, c) if err != nil { return nil, err @@ -100,7 +95,12 @@ func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, er s.inUnlock() s.outUnlock() - c.streams.streams[s.id] = s + // Modify c.streams on the conn's loop. + if err := c.runOnLoop(ctx, func(now time.Time, c *Conn) { + c.streams.streams[s.id] = s + }); err != nil { + return nil, err + } return s, nil } @@ -119,8 +119,6 @@ const ( // streamForID returns the stream with the given id. // If the stream does not exist, it returns nil. 
func (c *Conn) streamForID(id streamID) *Stream { - c.streams.streamsMu.Lock() - defer c.streams.streamsMu.Unlock() return c.streams.streams[id] } @@ -146,8 +144,6 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) } } - c.streams.streamsMu.Lock() - defer c.streams.streamsMu.Unlock() s, isOpen := c.streams.streams[id] if s != nil { return s diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index c90354db8..90f5cb75c 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "math" + "sync" "testing" ) @@ -478,3 +479,46 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) { t.Fatalf("after test, stream send queue is not empty; should be") } } + +func TestStreamsCreateConcurrency(t *testing.T) { + cli, srv := newLocalConnPair(t, &Config{}, &Config{}) + + srvdone := make(chan int) + go func() { + defer close(srvdone) + for streams := 0; ; streams++ { + s, err := srv.AcceptStream(context.Background()) + if err != nil { + srvdone <- streams + return + } + s.Close() + } + }() + + var wg sync.WaitGroup + const concurrency = 10 + const streams = 10 + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < streams; j++ { + s, err := cli.NewStream(context.Background()) + if err != nil { + t.Errorf("NewStream: %v", err) + return + } + s.Flush() + s.Close() + } + }() + } + wg.Wait() + + cli.Abort(nil) + srv.Abort(nil) + if got, want := <-srvdone, concurrency*streams; got != want { + t.Errorf("accepted %v streams, want %v", got, want) + } +} diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index b48bee803..058aa7edc 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -30,10 +30,12 @@ func TestConnTestConn(t *testing.T) { t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want) } - var ranAt time.Time - tc.conn.runOnLoop(func(now time.Time, c *Conn) { - ranAt = now - }) + ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) { + tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) { + when = now + }) + return + }).result() if !ranAt.Equal(tc.endpoint.now) { t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now) } @@ -41,9 +43,12 @@ func TestConnTestConn(t *testing.T) { nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2) tc.advanceTo(nextTime) - tc.conn.runOnLoop(func(now time.Time, c *Conn) { - ranAt = now - }) + ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) { + tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) { + when = now + }) + return + }).result() if !ranAt.Equal(nextTime) { t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime) } From b0eb4d6c942abf81c513c88af3ea23aaaaa5a4e0 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 27 Nov 2023 16:42:25 -0800 Subject: [PATCH 14/70] quic: compute pnum len from max ack received, not sent QUIC packet numbers are truncated to include only the least significant bits of the packet number. The number of bits which must be retained is computed based on the largest packet number known to have been received by the peer. See RFC 9000, section 17.1. We were incorrectly using the largest packet number we have received *from* the peer. Oops. (Test infrastructure change: Include the header byte in the testPacket structure, so we can see how many bytes the packet number was encoded with. Ignore this byte when comparing packets.) 
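To make the rule concrete, here is a small helper written for illustration only (it is not code from the patch): the length of the truncated packet number follows from the distance between the packet being sent and the largest packet the peer has acknowledged, matching the expectations in TestSendPacketNumberSize below.

// Illustration only. RFC 9000, section 17.1 requires the encoding to represent
// a range more than twice the difference between the packet being sent and the
// largest acknowledged packet, so one byte suffices while that difference is
// below 0x80, two bytes below 0x8000, and so on.
func illustrativePacketNumberLength(pnum, maxAcked int64) int {
	switch d := pnum - maxAcked; {
	case d < 0x80:
		return 1
	case d < 0x8000:
		return 2
	case d < 0x800000:
		return 3
	default:
		return 4
	}
}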
Change-Id: Iec17c69f007f8b39d14d24b0ca216c6a0018ae22 Reviewed-on: https://go-review.googlesource.com/c/net/+/545575 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn_send.go | 6 ++--- internal/quic/conn_send_test.go | 43 +++++++++++++++++++++++++++++++++ internal/quic/conn_test.go | 15 ++++++++++-- internal/quic/endpoint_test.go | 3 +-- internal/quic/tls_test.go | 3 +-- 5 files changed, 61 insertions(+), 9 deletions(-) diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index a8d930898..c2d8d146b 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -60,7 +60,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { pad := false var sentInitial *sentPacket if c.keysInitial.canWrite() { - pnumMaxAcked := c.acks[initialSpace].largestSeen() + pnumMaxAcked := c.loss.spaces[initialSpace].maxAcked pnum := c.loss.nextNumber(initialSpace) p := longPacket{ ptype: packetTypeInitial, @@ -93,7 +93,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // Handshake packet. if c.keysHandshake.canWrite() { - pnumMaxAcked := c.acks[handshakeSpace].largestSeen() + pnumMaxAcked := c.loss.spaces[handshakeSpace].maxAcked pnum := c.loss.nextNumber(handshakeSpace) p := longPacket{ ptype: packetTypeHandshake, @@ -124,7 +124,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // 1-RTT packet. if c.keysAppData.canWrite() { - pnumMaxAcked := c.acks[appDataSpace].largestSeen() + pnumMaxAcked := c.loss.spaces[appDataSpace].maxAcked pnum := c.loss.nextNumber(appDataSpace) c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID) c.appendFrames(now, appDataSpace, pnum, limit) diff --git a/internal/quic/conn_send_test.go b/internal/quic/conn_send_test.go index 822783c41..2205ff2f7 100644 --- a/internal/quic/conn_send_test.go +++ b/internal/quic/conn_send_test.go @@ -38,3 +38,46 @@ func TestAckElicitingAck(t *testing.T) { } t.Errorf("after sending %v PINGs, got no ack-eliciting response", count) } + +func TestSendPacketNumberSize(t *testing.T) { + tc := newTestConn(t, clientSide, permissiveTransportParameters) + tc.handshake() + + recvPing := func() *testPacket { + t.Helper() + tc.conn.ping(appDataSpace) + p := tc.readPacket() + if p == nil { + t.Fatalf("want packet containing PING, got none") + } + return p + } + + // Desynchronize the packet numbers the conn is sending and the ones it is receiving, + // by having the conn send a number of unacked packets. + for i := 0; i < 16; i++ { + recvPing() + } + + // Establish the maximum packet number the conn has received an ACK for. + maxAcked := recvPing().num + tc.writeAckForAll() + + // Make the conn send a sequence of packets. + // Check that the packet number is encoded with two bytes once the difference between the + // current packet and the max acked one is sufficiently large. 
+ for want := maxAcked + 1; want < maxAcked+0x100; want++ { + p := recvPing() + if p.num != want { + t.Fatalf("received packet number %v, want %v", p.num, want) + } + gotPnumLen := int(p.header&0x03) + 1 + wantPnumLen := 1 + if p.num-maxAcked >= 0x80 { + wantPnumLen = 2 + } + if gotPnumLen != wantPnumLen { + t.Fatalf("packet number 0x%x encoded with %v bytes, want %v (max acked = %v)", p.num, gotPnumLen, wantPnumLen, maxAcked) + } + } +} diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index 058aa7edc..abf7eede7 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -82,6 +82,7 @@ func (d testDatagram) String() string { type testPacket struct { ptype packetType + header byte version uint32 num packetNumber keyPhaseBit bool @@ -599,12 +600,18 @@ func (tc *testConn) readFrame() (debugFrame, packetType) { func (tc *testConn) wantDatagram(expectation string, want *testDatagram) { tc.t.Helper() got := tc.readDatagram() - if !reflect.DeepEqual(got, want) { + if !datagramEqual(got, want) { tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) } } func datagramEqual(a, b *testDatagram) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } if a.paddedSize != b.paddedSize || a.addr != b.addr || len(a.packets) != len(b.packets) { @@ -622,7 +629,7 @@ func datagramEqual(a, b *testDatagram) bool { func (tc *testConn) wantPacket(expectation string, want *testPacket) { tc.t.Helper() got := tc.readPacket() - if !reflect.DeepEqual(got, want) { + if !packetEqual(got, want) { tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want) } } @@ -630,8 +637,10 @@ func (tc *testConn) wantPacket(expectation string, want *testPacket) { func packetEqual(a, b *testPacket) bool { ac := *a ac.frames = nil + ac.header = 0 bc := *b bc.frames = nil + bc.header = 0 if !reflect.DeepEqual(ac, bc) { return false } @@ -839,6 +848,7 @@ func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) } d.packets = append(d.packets, &testPacket{ ptype: p.ptype, + header: buf[0], version: p.version, num: p.num, dstConnID: p.dstConnID, @@ -880,6 +890,7 @@ func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) } d.packets = append(d.packets, &testPacket{ ptype: packetType1RTT, + header: hdr[0], num: pnum, dstConnID: hdr[1:][:len(tc.peerConnID)], keyPhaseBit: hdr[0]&keyPhaseBit != 0, diff --git a/internal/quic/endpoint_test.go b/internal/quic/endpoint_test.go index 2a6daa076..452d26052 100644 --- a/internal/quic/endpoint_test.go +++ b/internal/quic/endpoint_test.go @@ -13,7 +13,6 @@ import ( "io" "net" "net/netip" - "reflect" "testing" "time" ) @@ -242,7 +241,7 @@ func (te *testEndpoint) readDatagram() *testDatagram { func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) { te.t.Helper() got := te.readDatagram() - if !reflect.DeepEqual(got, want) { + if !datagramEqual(got, want) { te.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) } } diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go index 14f74a00a..9c1dd364e 100644 --- a/internal/quic/tls_test.go +++ b/internal/quic/tls_test.go @@ -10,7 +10,6 @@ import ( "crypto/tls" "crypto/x509" "errors" - "reflect" "testing" "time" ) @@ -56,7 +55,7 @@ func (tc *testConn) handshake() { fillCryptoFrames(want, tc.cryptoDataOut) i++ } - if !reflect.DeepEqual(got, want) { + if !datagramEqual(got, want) { t.Fatalf("dgram %v:\ngot %v\n\nwant %v", i, got, want) } if i 
>= len(dgrams) { From 1e59a7e58ce15106ab0248605c5de0701624072b Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 28 Nov 2023 09:17:02 -0800 Subject: [PATCH 15/70] quic/qlog: correctly write negative durations "-10.000001", not "10.-000001". Change-Id: I84f6487bad15ab3a190e73e655236376b1781e85 Reviewed-on: https://go-review.googlesource.com/c/net/+/545576 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/qlog/json_writer.go | 4 ++++ internal/quic/qlog/json_writer_test.go | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/internal/quic/qlog/json_writer.go b/internal/quic/qlog/json_writer.go index 3950ab42f..b2fa3e03e 100644 --- a/internal/quic/qlog/json_writer.go +++ b/internal/quic/qlog/json_writer.go @@ -157,6 +157,10 @@ func (w *jsonWriter) writeBoolField(name string, v bool) { // writeDuration writes a duration as milliseconds. func (w *jsonWriter) writeDuration(v time.Duration) { + if v < 0 { + w.buf.WriteByte('-') + v = -v + } fmt.Fprintf(&w.buf, "%d.%06d", v.Milliseconds(), v%time.Millisecond) } diff --git a/internal/quic/qlog/json_writer_test.go b/internal/quic/qlog/json_writer_test.go index 7ba5e1737..6da556641 100644 --- a/internal/quic/qlog/json_writer_test.go +++ b/internal/quic/qlog/json_writer_test.go @@ -124,9 +124,10 @@ func TestJSONWriterBoolField(t *testing.T) { func TestJSONWriterDurationField(t *testing.T) { w := newTestJSONWriter() w.writeRecordStart() - w.writeDurationField("field", (10*time.Millisecond)+(2*time.Nanosecond)) + w.writeDurationField("field1", (10*time.Millisecond)+(2*time.Nanosecond)) + w.writeDurationField("field2", -((10 * time.Millisecond) + (2 * time.Nanosecond))) w.writeRecordEnd() - wantJSONRecord(t, w, `{"field":10.000002}`) + wantJSONRecord(t, w, `{"field1":10.000002,"field2":-10.000002}`) } func TestJSONWriterFloat64Field(t *testing.T) { From 2b416c3c961a9829f7ca97dd44690e71719f68f2 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 28 Nov 2023 09:20:32 -0800 Subject: [PATCH 16/70] quic/qlog: create log files with O_EXCL Avoid confusing log corruption if two loggers try to write to the same file simultaneously. Change-Id: I3bfbcf56aa55c778ada0178d7c662c414878c9d1 Reviewed-on: https://go-review.googlesource.com/c/net/+/545577 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/qlog/qlog.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/quic/qlog/qlog.go b/internal/quic/qlog/qlog.go index 0e71d71aa..e54c839f0 100644 --- a/internal/quic/qlog/qlog.go +++ b/internal/quic/qlog/qlog.go @@ -180,7 +180,7 @@ func newTraceWriter(opts HandlerOptions, info TraceInfo) (io.WriteCloser, error) if !filepath.IsLocal(filename) { return nil, errors.New("invalid trace filename") } - w, err = os.Create(filepath.Join(opts.Dir, filename)) + w, err = os.OpenFile(filepath.Join(opts.Dir, filename), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0666) } else { err = errors.New("no log destination") } From c337daf7db6b2f45306e9b972588478201259c0d Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 28 Nov 2023 09:19:54 -0800 Subject: [PATCH 17/70] quic: enable qlog output in tests Set QLOG=/some/dir to enable qlog logging in tests. 
Change-Id: Id4006c66fd555ad0ca47914d0af9f9ab46467c9c Reviewed-on: https://go-review.googlesource.com/c/net/+/550796 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn_test.go | 12 +++++++++++- internal/quic/endpoint_test.go | 9 +++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index abf7eede7..ddf0740e2 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -13,15 +13,21 @@ import ( "errors" "flag" "fmt" + "log/slog" "math" "net/netip" "reflect" "strings" "testing" "time" + + "golang.org/x/net/internal/quic/qlog" ) -var testVV = flag.Bool("vv", false, "even more verbose test output") +var ( + testVV = flag.Bool("vv", false, "even more verbose test output") + qlogdir = flag.String("qlog", "", "write qlog logs to directory") +) func TestConnTestConn(t *testing.T) { tc := newTestConn(t, serverSide) @@ -199,6 +205,10 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { config := &Config{ TLSConfig: newTestTLSConfig(side), StatelessResetKey: testStatelessResetKey, + QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{ + Level: QLogLevelFrame, + Dir: *qlogdir, + })), } var cids newServerConnIDs if side == serverSide { diff --git a/internal/quic/endpoint_test.go b/internal/quic/endpoint_test.go index 452d26052..ab6cd1cf5 100644 --- a/internal/quic/endpoint_test.go +++ b/internal/quic/endpoint_test.go @@ -11,10 +11,13 @@ import ( "context" "crypto/tls" "io" + "log/slog" "net" "net/netip" "testing" "time" + + "golang.org/x/net/internal/quic/qlog" ) func TestConnect(t *testing.T) { @@ -83,6 +86,12 @@ func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint { conf = &newConf conf.TLSConfig = newTestTLSConfig(side) } + if conf.QLogLogger == nil { + conf.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{ + Level: QLogLevelFrame, + Dir: *qlogdir, + })) + } e, err := Listen("udp", "127.0.0.1:0", conf) if err != nil { t.Fatal(err) From f9726a9e4a0fba67ce78802b47601ba194d15b3f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 11 Dec 2023 13:54:56 -0800 Subject: [PATCH 18/70] quic: fix packet size logging The qlog schema puts packet sizes as part of a "raw" field of type RawInfo, not in the packet_sent/packet_received event. Move to the correct location. 
Change-Id: I4308d4bdb961cf83e29af014b60f50ed029cb915 Reviewed-on: https://go-review.googlesource.com/c/net/+/550797 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/conn_send.go | 6 +++--- internal/quic/packet_writer.go | 5 +++++ internal/quic/qlog.go | 15 ++++++++++----- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index c2d8d146b..ccb467591 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -76,7 +76,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload()) } if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 { - c.logPacketSent(packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload()) + c.logPacketSent(packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload()) } sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p) if sentInitial != nil { @@ -108,7 +108,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload()) } if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 { - c.logPacketSent(packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload()) + c.logPacketSent(packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload()) } if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil { c.idleHandlePacketSent(now, sent) @@ -139,7 +139,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload()) } if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 { - c.logPacketSent(packetType1RTT, pnum, nil, dstConnID, c.w.payload()) + c.logPacketSent(packetType1RTT, pnum, nil, dstConnID, c.w.packetLen(), c.w.payload()) } if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil { c.idleHandlePacketSent(now, sent) diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go index 0c2b2ee41..b4e54ce4b 100644 --- a/internal/quic/packet_writer.go +++ b/internal/quic/packet_writer.go @@ -47,6 +47,11 @@ func (w *packetWriter) datagram() []byte { return w.b } +// packet returns the size of the current packet. +func (w *packetWriter) packetLen() int { + return len(w.b[w.pktOff:]) + aeadOverhead +} + // payload returns the payload of the current packet. 
func (w *packetWriter) payload() []byte { return w.b[w.payOff:] diff --git a/internal/quic/qlog.go index fea8b38ee..82ad92ac8 100644 --- a/internal/quic/qlog.go +++ b/internal/quic/qlog.go @@ -148,8 +148,6 @@ func (c *Conn) logConnectionClosed() { } func (c *Conn) logLongPacketReceived(p longPacket, pkt []byte) { - pnumLen := 1 + int(pkt[0]&0x03) - length := pnumLen + len(p.payload) var frames slog.Attr if c.logEnabled(QLogLevelFrame) { frames = c.packetFramesAttr(p.payload) } @@ -162,7 +160,9 @@ slog.Uint64("flags", uint64(pkt[0])), slogHexstring("scid", p.srcConnID), slogHexstring("dcid", p.dstConnID), - slog.Int("length", length), + ), + slog.Group("raw", + slog.Int("length", len(pkt)), + ), frames, ) @@ -180,14 +180,16 @@ slog.String("packet_type", packetType1RTT.qlogString()), slog.Uint64("packet_number", uint64(p.num)), slog.Uint64("flags", uint64(pkt[0])), - slog.String("scid", ""), slogHexstring("dcid", dstConnID), ), + slog.Group("raw", + slog.Int("length", len(pkt)), + ), frames, ) } -func (c *Conn) logPacketSent(ptype packetType, pnum packetNumber, src, dst, payload []byte) { +func (c *Conn) logPacketSent(ptype packetType, pnum packetNumber, src, dst []byte, pktLen int, payload []byte) { var frames slog.Attr if c.logEnabled(QLogLevelFrame) { frames = c.packetFramesAttr(payload) } @@ -204,6 +206,9 @@ scid, slogHexstring("dcid", dst), ), + slog.Group("raw", + slog.Int("length", pktLen), + ), frames, ) } From c136d0c937afa54dca414a69603bb1570a28879f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 19 Dec 2023 09:01:27 -0800 Subject: [PATCH 19/70] quic: avoid panic when PTO expires and implicitly-created streams exist The streams map contains nil entries for implicitly-created streams. (Receiving a packet for stream N implicitly creates all streams of the same type and initiator with IDs less than N.) appendStreamFramesPTO iterates over the streams map when retransmitting after a probe timeout, and panicked when it encountered one of these nil entries. Make the possibly-nil values explicit by storing a maybeStream wrapper in the map rather than a bare *Stream, and skip the empty entries. LUCI-TryBot-Result: Go LUCI --- internal/quic/conn_streams.go | 45 ++++++++++++++++++++---------- internal/quic/conn_streams_test.go | 35 +++++++++++++++++++++++ 2 files changed, 66 insertions(+), 14 deletions(-) diff --git a/internal/quic/conn_streams.go index dc82f8b0f..87cfd297e 100644 --- a/internal/quic/conn_streams.go +++ b/internal/quic/conn_streams.go @@ -14,8 +14,16 @@ import ( ) type streamsState struct { - queue queue[*Stream] // new, peer-created streams - streams map[streamID]*Stream + queue queue[*Stream] // new, peer-created streams + + // All peer-created streams. + // + // Implicitly created streams are included as an empty entry in the map. + // (For example, if we receive a frame for stream 4, we implicitly create stream 0 and + // insert an empty entry for it to the map.) + // + // The map value is maybeStream rather than *Stream as a reminder that values can be nil. + streams map[streamID]maybeStream // Limits on the number of streams, indexed by streamType. localLimit [streamTypeCount]localStreamLimits @@ -37,8 +45,13 @@ type streamsState struct { queueData streamRing // streams with only flow-controlled frames } +// maybeStream is a possibly nil *Stream. See streamsState.streams.
+type maybeStream struct { + s *Stream +} + func (c *Conn) streamsInit() { - c.streams.streams = make(map[streamID]*Stream) + c.streams.streams = make(map[streamID]maybeStream) c.streams.queue = newQueue[*Stream]() c.streams.localLimit[bidiStream].init() c.streams.localLimit[uniStream].init() @@ -52,8 +65,8 @@ func (c *Conn) streamsCleanup() { c.streams.localLimit[bidiStream].connHasClosed() c.streams.localLimit[uniStream].connHasClosed() for _, s := range c.streams.streams { - if s != nil { - s.connHasClosed() + if s.s != nil { + s.s.connHasClosed() } } } @@ -97,7 +110,7 @@ func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, er // Modify c.streams on the conn's loop. if err := c.runOnLoop(ctx, func(now time.Time, c *Conn) { - c.streams.streams[s.id] = s + c.streams.streams[s.id] = maybeStream{s} }); err != nil { return nil, err } @@ -119,7 +132,7 @@ const ( // streamForID returns the stream with the given id. // If the stream does not exist, it returns nil. func (c *Conn) streamForID(id streamID) *Stream { - return c.streams.streams[id] + return c.streams.streams[id].s } // streamForFrame returns the stream with the given id. @@ -144,9 +157,9 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) } } - s, isOpen := c.streams.streams[id] - if s != nil { - return s + ms, isOpen := c.streams.streams[id] + if ms.s != nil { + return ms.s } num := id.num() @@ -183,10 +196,10 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) // with the same initiator and type and a lower number. // Add a nil entry to the streams map for each implicitly created stream. for n := newStreamID(id.initiator(), id.streamType(), prevOpened); n < id; n += 4 { - c.streams.streams[n] = nil + c.streams.streams[n] = maybeStream{} } - s = newStream(c, id) + s := newStream(c, id) s.inmaxbuf = c.config.maxStreamReadBufferSize() s.inwin = c.config.maxStreamReadBufferSize() if id.streamType() == bidiStream { @@ -196,7 +209,7 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) s.inUnlock() s.outUnlock() - c.streams.streams[id] = s + c.streams.streams[id] = maybeStream{s} c.streams.queue.put(s) return s } @@ -400,7 +413,11 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool { c.streams.sendMu.Lock() defer c.streams.sendMu.Unlock() const pto = true - for _, s := range c.streams.streams { + for _, ms := range c.streams.streams { + s := ms.s + if s == nil { + continue + } const pto = true s.ingate.lock() inOK := s.appendInFramesLocked(w, pnum, pto) diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index 90f5cb75c..fb9af47eb 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -522,3 +522,38 @@ func TestStreamsCreateConcurrency(t *testing.T) { t.Errorf("accepted %v streams, want %v", got, want) } } + +func TestStreamsPTOWithImplicitStream(t *testing.T) { + ctx := canceledContext() + tc := newTestConn(t, serverSide, permissiveTransportParameters) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + // Peer creates stream 1, and implicitly creates stream 0. + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 1), + }) + + // We accept stream 1 and write data to it. 
+ data := []byte("data") + s, err := tc.conn.AcceptStream(ctx) + if err != nil { + t.Fatalf("conn.AcceptStream() = %v, want stream", err) + } + s.Write(data) + s.Flush() + tc.wantFrame("data written to stream", + packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 1), + data: data, + }) + + // PTO expires, and the data is resent. + const pto = true + tc.triggerLossOrPTO(packetType1RTT, true) + tc.wantFrame("data resent after PTO expires", + packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 1), + data: data, + }) +} From f12db26b1c9293fa3eb95c936e548d2c1fba4ba9 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 19 Dec 2023 11:33:13 -0800 Subject: [PATCH 20/70] internal/quic/cmd/interop: use wget --no-verbose in Dockerfile Pass --no-verbose to wget to avoid spamming the build logs with progress indicators. Change-Id: I36a0b91f8dac09cc4055c5d5db3fc61c9b269d6e Reviewed-on: https://go-review.googlesource.com/c/net/+/551495 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/cmd/interop/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/quic/cmd/interop/Dockerfile b/internal/quic/cmd/interop/Dockerfile index 4b52e5356..b60999a86 100644 --- a/internal/quic/cmd/interop/Dockerfile +++ b/internal/quic/cmd/interop/Dockerfile @@ -9,7 +9,7 @@ ENV GOVERSION=1.21.1 RUN platform=$(echo ${TARGETPLATFORM} | tr '/' '-') && \ filename="go${GOVERSION}.${platform}.tar.gz" && \ - wget https://dl.google.com/go/${filename} && \ + wget --no-verbose https://dl.google.com/go/${filename} && \ tar xfz ${filename} && \ rm ${filename} From 689bbc7005f6bbf9fac1a8333bf03436fa4b4b2a Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 4 Jan 2024 10:29:48 -0800 Subject: [PATCH 21/70] quic: deflake TestStreamsCreateConcurrency This test assumed that creating a stream and flushing it on the client ensured the server had accepted the stream. This isn't the case; the stream has been delivered to the server, but there's no guarantee that it been accepted by the user layer. Change the test to make a full loop: The client creates a stream, and then waits for the server to close it. Fixes golang/go#64788 Change-Id: I24f08502e9f5d8bd5a17e680b0aa19dcc2623841 Reviewed-on: https://go-review.googlesource.com/c/net/+/554175 Reviewed-by: Bryan Mills LUCI-TryBot-Result: Go LUCI --- internal/quic/conn_streams_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index fb9af47eb..6815e403e 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -510,6 +510,10 @@ func TestStreamsCreateConcurrency(t *testing.T) { return } s.Flush() + _, err = io.ReadAll(s) + if err != nil { + t.Errorf("ReadFull: %v", err) + } s.Close() } }() From cb5b10f0bbc51089bf49030ce3bd43bbfee08c23 Mon Sep 17 00:00:00 2001 From: Gopher Robot Date: Mon, 8 Jan 2024 17:35:41 +0000 Subject: [PATCH 22/70] go.mod: update golang.org/x dependencies Update golang.org/x dependencies to their latest tagged versions. 
Change-Id: I77f3c5560bd989f4e9c6b8c3f36e900fefe9bb0e Reviewed-on: https://go-review.googlesource.com/c/net/+/554675 Reviewed-by: Than McIntosh Reviewed-by: Dmitri Shuralyov Auto-Submit: Gopher Robot LUCI-TryBot-Result: Go LUCI --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 8ab3f40e1..3bd487f5a 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module golang.org/x/net go 1.18 require ( - golang.org/x/crypto v0.16.0 - golang.org/x/sys v0.15.0 - golang.org/x/term v0.15.0 + golang.org/x/crypto v0.18.0 + golang.org/x/sys v0.16.0 + golang.org/x/term v0.16.0 golang.org/x/text v0.14.0 ) diff --git a/go.sum b/go.sum index bb6ed68a0..8eeaf16c6 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= +golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= From 26b646ea024741dd5d8e141fc33d8149c465686a Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 8 Jan 2024 09:36:16 -0800 Subject: [PATCH 23/70] quic: avoid deadlock in Endpoint.Close Don't hold Endpoint.connsMu while calling Conn methods that can indirectly depend on acquiring it. Also change test cleanup to not wait for connections to drain when closing a test Endpoint, removing an unnecessary 0.1s delay in test runtime. Fixes golang/go#64982. Change-Id: If336e63b0a7f5b8d2ef63986d36f9ee38a92c477 Reviewed-on: https://go-review.googlesource.com/c/net/+/554695 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/endpoint.go | 16 +++++++++++----- internal/quic/endpoint_test.go | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/internal/quic/endpoint.go b/internal/quic/endpoint.go index 82a08a18c..8ed67de54 100644 --- a/internal/quic/endpoint.go +++ b/internal/quic/endpoint.go @@ -103,25 +103,31 @@ func (e *Endpoint) LocalAddr() netip.AddrPort { // It waits for the peers of any open connection to acknowledge the connection has been closed. func (e *Endpoint) Close(ctx context.Context) error { e.acceptQueue.close(errors.New("endpoint closed")) + + // It isn't safe to call Conn.Abort or conn.exit with connsMu held, + // so copy the list of conns. 
+ var conns []*Conn e.connsMu.Lock() if !e.closing { - e.closing = true + e.closing = true // setting e.closing prevents new conns from being created for c := range e.conns { - c.Abort(localTransportError{code: errNo}) + conns = append(conns, c) } if len(e.conns) == 0 { e.udpConn.Close() } } e.connsMu.Unlock() + + for _, c := range conns { + c.Abort(localTransportError{code: errNo}) + } select { case <-e.closec: case <-ctx.Done(): - e.connsMu.Lock() - for c := range e.conns { + for _, c := range conns { c.exit() } - e.connsMu.Unlock() return ctx.Err() } return nil diff --git a/internal/quic/endpoint_test.go b/internal/quic/endpoint_test.go index ab6cd1cf5..16c3e0bce 100644 --- a/internal/quic/endpoint_test.go +++ b/internal/quic/endpoint_test.go @@ -97,7 +97,7 @@ func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint { t.Fatal(err) } t.Cleanup(func() { - e.Close(context.Background()) + e.Close(canceledContext()) }) return e } From 07e05fd6e95ab445ebe48840c81a027dbace3b8e Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 5 Jan 2024 11:25:21 +0400 Subject: [PATCH 24/70] http2: remove suspicious uint32->v conversion in frame code Function maxHeaderStringLen(...) uses uint32(int(v)) == v check to validate if length will fit in the int type. This check is a no-op on any architecture because int type always has at least 32 bits, so we can potentially encounter negative return values from maxHeaderStringLen(...) function. This can be bad as this outcome clearly breaks code intention and maybe some further code invariants. This patch replaces uint32(int(v)) == v check with more robust and simpler int(v) > 0 validation which is correct for our case when we operating with uint32 Fixes golang/go#64961 Change-Id: I31f95709df9d25593ade3200696ac5cef9f88652 Reviewed-on: https://go-review.googlesource.com/c/net/+/554235 Auto-Submit: Dmitri Shuralyov Reviewed-by: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Dmitri Shuralyov --- http2/frame.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/http2/frame.go b/http2/frame.go index c1f6b90dc..e2b298d85 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -1510,13 +1510,12 @@ func (mh *MetaHeadersFrame) checkPseudos() error { } func (fr *Framer) maxHeaderStringLen() int { - v := fr.maxHeaderListSize() - if uint32(int(v)) == v { - return int(v) + v := int(fr.maxHeaderListSize()) + if v < 0 { + // If maxHeaderListSize overflows an int, use no limit (0). + return 0 } - // They had a crazy big number for MaxHeaderBytes anyway, - // so give them unlimited header lengths: - return 0 + return v } // readMetaFrame returns 0 or more CONTINUATION frames from fr and From 0d0b98c1378dba60d10c77c383c40f94c1641cfc Mon Sep 17 00:00:00 2001 From: "Bryan C. Mills" Date: Mon, 22 Jan 2024 16:26:00 -0500 Subject: [PATCH 25/70] http2: avoid goroutine starvation in TestServer_Push_RejectAfterGoAway CL 557037 added a runtime.Gosched to prevent goroutine starvation in the wasm fake-net stack. Unfortunately, that Gosched causes the scheduler to enter a very similar starvation loop in this test. Add another runtime.Gosched to break this new loop. For golang/go#65178. 
Change-Id: I24b3f50dd728800462f71f27290b0d8f99d5ae5b Cq-Include-Trybots: luci.golang.try:x_net-gotip-wasip1-wasm_wasmtime,x_net-gotip-wasip1-wasm_wazero,x_net-gotip-js-wasm Reviewed-on: https://go-review.googlesource.com/c/net/+/557615 Auto-Submit: Bryan Mills LUCI-TryBot-Result: Go LUCI Reviewed-by: Michael Pratt --- http2/server_push_test.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/http2/server_push_test.go b/http2/server_push_test.go index 9882d9ef7..cda8f4336 100644 --- a/http2/server_push_test.go +++ b/http2/server_push_test.go @@ -11,6 +11,7 @@ import ( "io/ioutil" "net/http" "reflect" + "runtime" "strconv" "sync" "testing" @@ -483,11 +484,7 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) { ready := make(chan struct{}) errc := make(chan error, 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - select { - case <-ready: - case <-time.After(5 * time.Second): - errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed") - } + <-ready if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { errc <- fmt.Errorf("Push()=%v, want %v", got, want) } @@ -505,6 +502,10 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) { case <-ready: return default: + if runtime.GOARCH == "wasm" { + // Work around https://go.dev/issue/65178 to avoid goroutine starvation. + runtime.Gosched() + } } st.sc.serveMsgCh <- func(loopNum int) { if !st.sc.pushEnabled { From b2208d046df5625a4f78624149cba7722c4ccfee Mon Sep 17 00:00:00 2001 From: btwiuse <54848194+btwiuse@users.noreply.github.com> Date: Sun, 14 Jan 2024 13:20:58 +0000 Subject: [PATCH 26/70] internal/quic/qlog: fix typo VantageClient -> VantageServer Change-Id: Ie9738cffb06f03f961815853247e6f9cbe7fe466 GitHub-Last-Rev: 5d440ad29c49ef4cd529a076449114696662afec GitHub-Pull-Request: golang/net#202 Reviewed-on: https://go-review.googlesource.com/c/net/+/555795 LUCI-TryBot-Result: Go LUCI Reviewed-by: Michael Knyszek Reviewed-by: Damien Neil Auto-Submit: Damien Neil --- internal/quic/qlog/qlog.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/quic/qlog/qlog.go b/internal/quic/qlog/qlog.go index e54c839f0..f33c6b0fd 100644 --- a/internal/quic/qlog/qlog.go +++ b/internal/quic/qlog/qlog.go @@ -29,7 +29,7 @@ const ( // VantageClient traces follow a connection from the client's perspective. VantageClient = Vantage("client") - // VantageClient traces follow a connection from the server's perspective. + // VantageServer traces follow a connection from the server's perspective. VantageServer = Vantage("server") ) From 73e4b50dadcf3bd6015efb8b6e8ddbeb7dfe74c5 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak <19653795+mateusz834@users.noreply.github.com> Date: Fri, 2 Feb 2024 15:48:44 +0000 Subject: [PATCH 27/70] dns/dnsmessage: allow name compression for SRV resource parsing As per RFC 3597: Receiving servers MUST decompress domain names in RRs of well-known type, and SHOULD also decompress RRs of type RP, AFSDB, RT, SIG, PX, NXT, NAPTR, and SRV (although the current specification of the SRV RR in RFC2782 prohibits compression, RFC2052 mandated it, and some servers following that earlier specification are still in use). This change allows SRV resource decompression. 
Updates golang/go#36718 Updates golang/go#37362 Change-Id: I473c0d3803758e5b12886f378d2ed54bd5392144 GitHub-Last-Rev: 88d2e0642a7c7ba618d642801ebc72ba82ef30b7 GitHub-Pull-Request: golang/net#199 Reviewed-on: https://go-review.googlesource.com/c/net/+/540375 LUCI-TryBot-Result: Go LUCI Reviewed-by: Carlos Amedee Auto-Submit: Damien Neil Reviewed-by: Damien Neil --- dns/dnsmessage/message.go | 10 +--------- dns/dnsmessage/message_test.go | 22 ---------------------- 2 files changed, 1 insertion(+), 31 deletions(-) diff --git a/dns/dnsmessage/message.go b/dns/dnsmessage/message.go index 42987ab7c..a656efc12 100644 --- a/dns/dnsmessage/message.go +++ b/dns/dnsmessage/message.go @@ -273,7 +273,6 @@ var ( errTooManyAdditionals = errors.New("too many Additionals to pack (>65535)") errNonCanonicalName = errors.New("name is not in canonical format (it must end with a .)") errStringTooLong = errors.New("character string exceeds maximum length (255)") - errCompressedSRV = errors.New("compressed name in SRV resource data") ) // Internal constants. @@ -2028,10 +2027,6 @@ func (n *Name) pack(msg []byte, compression map[string]uint16, compressionOff in // unpack unpacks a domain name. func (n *Name) unpack(msg []byte, off int) (int, error) { - return n.unpackCompressed(msg, off, true /* allowCompression */) -} - -func (n *Name) unpackCompressed(msg []byte, off int, allowCompression bool) (int, error) { // currOff is the current working offset. currOff := off @@ -2076,9 +2071,6 @@ Loop: name = append(name, '.') currOff = endOff case 0xC0: // Pointer - if !allowCompression { - return off, errCompressedSRV - } if currOff >= len(msg) { return off, errInvalidPtr } @@ -2549,7 +2541,7 @@ func unpackSRVResource(msg []byte, off int) (SRVResource, error) { return SRVResource{}, &nestedError{"Port", err} } var target Name - if _, err := target.unpackCompressed(msg, off, false /* allowCompression */); err != nil { + if _, err := target.unpack(msg, off); err != nil { return SRVResource{}, &nestedError{"Target", err} } return SRVResource{priority, weight, port, target}, nil diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go index c84d5a3aa..e60ec42d9 100644 --- a/dns/dnsmessage/message_test.go +++ b/dns/dnsmessage/message_test.go @@ -303,28 +303,6 @@ func TestNameUnpackTooLongName(t *testing.T) { } } -func TestIncompressibleName(t *testing.T) { - name := MustNewName("example.com.") - compression := map[string]uint16{} - buf, err := name.pack(make([]byte, 0, 100), compression, 0) - if err != nil { - t.Fatal("first Name.pack() =", err) - } - buf, err = name.pack(buf, compression, 0) - if err != nil { - t.Fatal("second Name.pack() =", err) - } - var n1 Name - off, err := n1.unpackCompressed(buf, 0, false /* allowCompression */) - if err != nil { - t.Fatal("unpacking incompressible name without pointers failed:", err) - } - var n2 Name - if _, err := n2.unpackCompressed(buf, off, false /* allowCompression */); err != errCompressedSRV { - t.Errorf("unpacking compressed incompressible name with pointers: got %v, want = %v", err, errCompressedSRV) - } -} - func checkErrorPrefix(err error, prefix string) bool { e, ok := err.(*nestedError) return ok && e.s == prefix From 643fd162e36ae58085b92ff4c0fec0bafe5a46a7 Mon Sep 17 00:00:00 2001 From: Maciej Mionskowski Date: Thu, 19 Oct 2023 20:16:20 +0000 Subject: [PATCH 28/70] html: fix SOLIDUS '/' handling in attribute parsing Calling the Tokenizer with HTML elements containing SOLIDUS (/) character in the attribute name results in incorrect tokenization. 
This is due to violation of the following rule transitions in the WHATWG spec: - https://html.spec.whatwg.org/multipage/parsing.html#attribute-name-state, where we are not reconsuming the character if '/' is encountered - https://html.spec.whatwg.org/multipage/parsing.html#after-attribute-name-state, where we are not switching to self closing state Fixes golang/go#63402 Change-Id: I90d998dd8decde877bd63aa664f3657aa6161024 GitHub-Last-Rev: 3546db808c5fbf46ea25a10cdadb2802f763b6de GitHub-Pull-Request: golang/net#195 Reviewed-on: https://go-review.googlesource.com/c/net/+/533518 LUCI-TryBot-Result: Go LUCI Auto-Submit: Michael Pratt Reviewed-by: Roland Shoemaker Reviewed-by: David Chase --- html/token.go | 12 ++++++++---- html/token_test.go | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/html/token.go b/html/token.go index de67f938a..3c57880d6 100644 --- a/html/token.go +++ b/html/token.go @@ -910,9 +910,6 @@ func (z *Tokenizer) readTagAttrKey() { return } switch c { - case ' ', '\n', '\r', '\t', '\f', '/': - z.pendingAttr[0].end = z.raw.end - 1 - return case '=': if z.pendingAttr[0].start+1 == z.raw.end { // WHATWG 13.2.5.32, if we see an equals sign before the attribute name @@ -920,7 +917,9 @@ func (z *Tokenizer) readTagAttrKey() { continue } fallthrough - case '>': + case ' ', '\n', '\r', '\t', '\f', '/', '>': + // WHATWG 13.2.5.33 Attribute name state + // We need to reconsume the char in the after attribute name state to support the / character z.raw.end-- z.pendingAttr[0].end = z.raw.end return @@ -939,6 +938,11 @@ func (z *Tokenizer) readTagAttrVal() { if z.err != nil { return } + if c == '/' { + // WHATWG 13.2.5.34 After attribute name state + // U+002F SOLIDUS (/) - Switch to the self-closing start tag state. + return + } if c != '=' { z.raw.end-- return diff --git a/html/token_test.go b/html/token_test.go index b2383a951..8b0d5aab6 100644 --- a/html/token_test.go +++ b/html/token_test.go @@ -601,6 +601,21 @@ var tokenTests = []tokenTest{ `

`, `

`, }, + { + "forward slash before attribute name", + `

`, + `

`, + }, + { + "forward slash before attribute name with spaces around", + `

`, + `

`, + }, + { + "forward slash after attribute name followed by a character", + `

`, + `

`, + }, } func TestTokenizer(t *testing.T) { From 73d21fdbb4d7dc7115b50526b93b6c37a4e3377f Mon Sep 17 00:00:00 2001 From: Gopher Robot Date: Wed, 7 Feb 2024 19:22:03 +0000 Subject: [PATCH 29/70] go.mod: update golang.org/x dependencies Update golang.org/x dependencies to their latest tagged versions. Change-Id: I314af161ceac84fec04c729a71860ad35335513b Reviewed-on: https://go-review.googlesource.com/c/net/+/562495 Auto-Submit: Gopher Robot Reviewed-by: Dmitri Shuralyov LUCI-TryBot-Result: Go LUCI Reviewed-by: Than McIntosh --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 3bd487f5a..7f512d703 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module golang.org/x/net go 1.18 require ( - golang.org/x/crypto v0.18.0 - golang.org/x/sys v0.16.0 - golang.org/x/term v0.16.0 + golang.org/x/crypto v0.19.0 + golang.org/x/sys v0.17.0 + golang.org/x/term v0.17.0 golang.org/x/text v0.14.0 ) diff --git a/go.sum b/go.sum index 8eeaf16c6..683b469d6 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= -golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= From 5a444b4f2fe893ea00f0376da46aa5376c3f3e28 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 29 Nov 2023 10:41:41 -0800 Subject: [PATCH 30/70] quic: add Stream.Set{Read,Write}Context, drop {Read,Write,Close}Context The ReadContext, WriteContext, and CloseContext Stream methods are difficult to use in conjunction with functions that operate on an io.Reader, io.Writer, or io.Closer. For example, it's reasonable to want to use io.ReadFull with a Stream, but doing so with a context is cumbersome. Drop the Stream methods that take a Context in favor of stateful methods that set the Context to use for read and write operations. (Close counts as a write operation, since it blocks waiting for data to be sent.) Intentionally make Set{Read,Write}Context not concurrency safe, to allow the race detector to catch misuse. This shouldn't be a problem for correct programs, since reads and writes are inherently not concurrency-safe. 
For golang/go#58547 Change-Id: I41378eb552d89a720921fc8644d3637c1a545676 Reviewed-on: https://go-review.googlesource.com/c/net/+/550795 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/conn_close_test.go | 11 ++- internal/quic/conn_flow_test.go | 19 ++-- internal/quic/conn_loss_test.go | 9 +- internal/quic/conn_streams_test.go | 26 +++--- internal/quic/conn_test.go | 11 +++ internal/quic/stream.go | 73 ++++++++------- internal/quic/stream_limits_test.go | 9 +- internal/quic/stream_test.go | 139 ++++++++++++++-------------- 8 files changed, 154 insertions(+), 143 deletions(-) diff --git a/internal/quic/conn_close_test.go b/internal/quic/conn_close_test.go index 63d4911e8..213975011 100644 --- a/internal/quic/conn_close_test.go +++ b/internal/quic/conn_close_test.go @@ -249,8 +249,9 @@ func TestConnCloseUnblocksNewStream(t *testing.T) { func TestConnCloseUnblocksStreamRead(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { s := newLocalStream(t, tc, bidiStream) + s.SetReadContext(ctx) buf := make([]byte, 16) - _, err := s.ReadContext(ctx, buf) + _, err := s.Read(buf) return err }, permissiveTransportParameters) } @@ -258,8 +259,9 @@ func TestConnCloseUnblocksStreamRead(t *testing.T) { func TestConnCloseUnblocksStreamWrite(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { s := newLocalStream(t, tc, bidiStream) + s.SetWriteContext(ctx) buf := make([]byte, 32) - _, err := s.WriteContext(ctx, buf) + _, err := s.Write(buf) return err }, permissiveTransportParameters, func(c *Config) { c.MaxStreamWriteBufferSize = 16 @@ -269,11 +271,12 @@ func TestConnCloseUnblocksStreamWrite(t *testing.T) { func TestConnCloseUnblocksStreamClose(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { s := newLocalStream(t, tc, bidiStream) + s.SetWriteContext(ctx) buf := make([]byte, 16) - _, err := s.WriteContext(ctx, buf) + _, err := s.Write(buf) if err != nil { return err } - return s.CloseContext(ctx) + return s.Close() }, permissiveTransportParameters) } diff --git a/internal/quic/conn_flow_test.go b/internal/quic/conn_flow_test.go index 39c879346..8e04e20d9 100644 --- a/internal/quic/conn_flow_test.go +++ b/internal/quic/conn_flow_test.go @@ -12,7 +12,6 @@ import ( ) func TestConnInflowReturnOnRead(t *testing.T) { - ctx := canceledContext() tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { c.MaxConnReadBufferSize = 64 }) @@ -21,14 +20,14 @@ func TestConnInflowReturnOnRead(t *testing.T) { data: make([]byte, 64), }) const readSize = 8 - if n, err := s.ReadContext(ctx, make([]byte, readSize)); n != readSize || err != nil { + if n, err := s.Read(make([]byte, readSize)); n != readSize || err != nil { t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, readSize) } tc.wantFrame("available window increases, send a MAX_DATA", packetType1RTT, debugFrameMaxData{ max: 64 + readSize, }) - if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64-readSize || err != nil { + if n, err := s.Read(make([]byte, 64)); n != 64-readSize || err != nil { t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 64-readSize) } tc.wantFrame("available window increases, send a MAX_DATA", @@ -42,7 +41,7 @@ func TestConnInflowReturnOnRead(t *testing.T) { data: make([]byte, 64), }) tc.wantIdle("connection is idle") - if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64 || err != nil { + if n, err := s.Read(make([]byte, 64)); n != 64 || err != nil { t.Fatalf("offset 64: 
s.Read() = %v, %v; want %v, nil", n, err, 64) } } @@ -79,10 +78,10 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) { t.Fatalf("conn.AcceptStream() = %v", err) } read1 := runAsync(tc, func(ctx context.Context) (int, error) { - return s1.ReadContext(ctx, make([]byte, 16)) + return s1.Read(make([]byte, 16)) }) read2 := runAsync(tc, func(ctx context.Context) (int, error) { - return s2.ReadContext(ctx, make([]byte, 1)) + return s2.Read(make([]byte, 1)) }) // This MAX_DATA might extend the window by 16 or 17, depending on // whether the second write occurs before the update happens. @@ -90,10 +89,10 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) { packetType1RTT, debugFrameMaxData{}) tc.wantIdle("redundant MAX_DATA is not sent") if _, err := read1.result(); err != nil { - t.Errorf("ReadContext #1 = %v", err) + t.Errorf("Read #1 = %v", err) } if _, err := read2.result(); err != nil { - t.Errorf("ReadContext #2 = %v", err) + t.Errorf("Read #2 = %v", err) } } @@ -227,13 +226,13 @@ func TestConnInflowMultipleStreams(t *testing.T) { t.Fatalf("AcceptStream() = %v", err) } streams = append(streams, s) - if n, err := s.ReadContext(ctx, make([]byte, 1)); err != nil || n != 1 { + if n, err := s.Read(make([]byte, 1)); err != nil || n != 1 { t.Fatalf("s.Read() = %v, %v; want 1, nil", n, err) } } tc.wantIdle("streams have read data, but not enough to update MAX_DATA") - if n, err := streams[0].ReadContext(ctx, make([]byte, 32)); err != nil || n != 31 { + if n, err := streams[0].Read(make([]byte, 32)); err != nil || n != 31 { t.Fatalf("s.Read() = %v, %v; want 31, nil", n, err) } tc.wantFrame("read enough data to trigger a MAX_DATA update", diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go index 818816335..876ffd093 100644 --- a/internal/quic/conn_loss_test.go +++ b/internal/quic/conn_loss_test.go @@ -433,7 +433,8 @@ func TestLostMaxStreamsFrameMostRecent(t *testing.T) { if err != nil { t.Fatalf("AcceptStream() = %v", err) } - s.CloseContext(ctx) + s.SetWriteContext(ctx) + s.Close() if styp == bidiStream { tc.wantFrame("stream is closed", packetType1RTT, debugFrameStream{ @@ -480,7 +481,7 @@ func TestLostMaxStreamsFrameNotMostRecent(t *testing.T) { if err != nil { t.Fatalf("AcceptStream() = %v", err) } - if err := s.CloseContext(ctx); err != nil { + if err := s.Close(); err != nil { t.Fatalf("stream.Close() = %v", err) } tc.wantFrame("closing stream updates peer's MAX_STREAMS", @@ -512,7 +513,7 @@ func TestLostStreamDataBlockedFrame(t *testing.T) { }) w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, []byte{0, 1, 2, 3}) + return s.Write([]byte{0, 1, 2, 3}) }) defer w.cancel() tc.wantFrame("write is blocked by flow control", @@ -564,7 +565,7 @@ func TestLostStreamDataBlockedFrameAfterStreamUnblocked(t *testing.T) { data := []byte{0, 1, 2, 3} w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, data) + return s.Write(data) }) defer w.cancel() tc.wantFrame("write is blocked by flow control", diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index 6815e403e..dc81ad991 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -230,8 +230,8 @@ func TestStreamsWriteQueueFairness(t *testing.T) { t.Fatal(err) } streams = append(streams, s) - if n, err := s.WriteContext(ctx, data); n != len(data) || err != nil { - t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", n, err, len(data)) + if n, err := s.Write(data); n != len(data) || err 
!= nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data)) } // Wait for the stream to finish writing whatever frames it can before // congestion control blocks it. @@ -298,7 +298,7 @@ func TestStreamsShutdown(t *testing.T) { side: localStream, styp: uniStream, setup: func(t *testing.T, tc *testConn, s *Stream) { - s.CloseContext(canceledContext()) + s.Close() }, shutdown: func(t *testing.T, tc *testConn, s *Stream) { tc.writeAckForAll() @@ -311,7 +311,7 @@ func TestStreamsShutdown(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameResetStream{ id: s.id, }) - s.CloseContext(canceledContext()) + s.Close() }, shutdown: func(t *testing.T, tc *testConn, s *Stream) { tc.writeAckForAll() @@ -321,8 +321,8 @@ func TestStreamsShutdown(t *testing.T) { side: localStream, styp: bidiStream, setup: func(t *testing.T, tc *testConn, s *Stream) { - s.CloseContext(canceledContext()) - tc.wantIdle("all frames after CloseContext are ignored") + s.Close() + tc.wantIdle("all frames after Close are ignored") tc.writeAckForAll() }, shutdown: func(t *testing.T, tc *testConn, s *Stream) { @@ -335,13 +335,12 @@ func TestStreamsShutdown(t *testing.T) { side: remoteStream, styp: uniStream, setup: func(t *testing.T, tc *testConn, s *Stream) { - ctx := canceledContext() tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, fin: true, }) - if n, err := s.ReadContext(ctx, make([]byte, 16)); n != 0 || err != io.EOF { - t.Errorf("ReadContext() = %v, %v; want 0, io.EOF", n, err) + if n, err := s.Read(make([]byte, 16)); n != 0 || err != io.EOF { + t.Errorf("Read() = %v, %v; want 0, io.EOF", n, err) } }, shutdown: func(t *testing.T, tc *testConn, s *Stream) { @@ -451,17 +450,14 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) { id: op.id, }) case acceptOp: - s, err := tc.conn.AcceptStream(ctx) - if err != nil { - t.Fatalf("AcceptStream() = %q; want stream %v", err, stringID(op.id)) - } + s := tc.acceptStream() if s.id != op.id { - t.Fatalf("accepted stram %v; want stream %v", err, stringID(op.id)) + t.Fatalf("accepted stream %v; want stream %v", stringID(s.id), stringID(op.id)) } t.Logf("accepted stream %v", stringID(op.id)) // Immediately close the stream, so the stream becomes done when the // peer closes its end. - s.CloseContext(ctx) + s.Close() } p := tc.readPacket() if p != nil { diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index ddf0740e2..2d3c946d6 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -382,6 +382,17 @@ func (tc *testConn) cleanup() { <-tc.conn.donec } +func (tc *testConn) acceptStream() *Stream { + tc.t.Helper() + s, err := tc.conn.AcceptStream(canceledContext()) + if err != nil { + tc.t.Fatalf("conn.AcceptStream() = %v, want stream", err) + } + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) + return s +} + func logDatagram(t *testing.T, text string, d *testDatagram) { t.Helper() if !*testVV { diff --git a/internal/quic/stream.go b/internal/quic/stream.go index fb9c1cf3c..d0122b951 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -18,6 +18,11 @@ type Stream struct { id streamID conn *Conn + // Contexts used for read/write operations. + // Intentionally not mutex-guarded, to allow the race detector to catch concurrent access. + inctx context.Context + outctx context.Context + // ingate's lock guards all receive-related state. 
// // The gate condition is set if a read from the stream will not block, @@ -152,6 +157,8 @@ func newStream(c *Conn, id streamID) *Stream { inresetcode: -1, // -1 indicates no RESET_STREAM received ingate: newLockedGate(), outgate: newLockedGate(), + inctx: context.Background(), + outctx: context.Background(), } if !s.IsReadOnly() { s.outdone = make(chan struct{}) @@ -159,6 +166,22 @@ func newStream(c *Conn, id streamID) *Stream { return s } +// SetReadContext sets the context used for reads from the stream. +// +// It is not safe to call SetReadContext concurrently. +func (s *Stream) SetReadContext(ctx context.Context) { + s.inctx = ctx +} + +// SetWriteContext sets the context used for writes to the stream. +// The write context is also used by Close when waiting for writes to be +// received by the peer. +// +// It is not safe to call SetWriteContext concurrently. +func (s *Stream) SetWriteContext(ctx context.Context) { + s.outctx = ctx +} + // IsReadOnly reports whether the stream is read-only // (a unidirectional stream created by the peer). func (s *Stream) IsReadOnly() bool { @@ -172,24 +195,18 @@ func (s *Stream) IsWriteOnly() bool { } // Read reads data from the stream. -// See ReadContext for more details. -func (s *Stream) Read(b []byte) (n int, err error) { - return s.ReadContext(context.Background(), b) -} - -// ReadContext reads data from the stream. // -// ReadContext returns as soon as at least one byte of data is available. +// Read returns as soon as at least one byte of data is available. // -// If the peer closes the stream cleanly, ReadContext returns io.EOF after +// If the peer closes the stream cleanly, Read returns io.EOF after // returning all data sent by the peer. -// If the peer aborts reads on the stream, ReadContext returns +// If the peer aborts reads on the stream, Read returns // an error wrapping StreamResetCode. -func (s *Stream) ReadContext(ctx context.Context, b []byte) (n int, err error) { +func (s *Stream) Read(b []byte) (n int, err error) { if s.IsWriteOnly() { return 0, errors.New("read from write-only stream") } - if err := s.ingate.waitAndLock(ctx, s.conn.testHooks); err != nil { + if err := s.ingate.waitAndLock(s.inctx, s.conn.testHooks); err != nil { return 0, err } defer func() { @@ -237,17 +254,11 @@ func shouldUpdateFlowControl(maxWindow, addedWindow int64) bool { } // Write writes data to the stream. -// See WriteContext for more details. -func (s *Stream) Write(b []byte) (n int, err error) { - return s.WriteContext(context.Background(), b) -} - -// WriteContext writes data to the stream. // -// WriteContext writes data to the stream write buffer. +// Write writes data to the stream write buffer. // Buffered data is only sent when the buffer is sufficiently full. // Call the Flush method to ensure buffered data is sent. -func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) { +func (s *Stream) Write(b []byte) (n int, err error) { if s.IsReadOnly() { return 0, errors.New("write to read-only stream") } @@ -259,7 +270,7 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) if len(b) > 0 && !canWrite { // Our send buffer is full. Wait for the peer to ack some data. 
s.outUnlock() - if err := s.outgate.waitAndLock(ctx, s.conn.testHooks); err != nil { + if err := s.outgate.waitAndLock(s.outctx, s.conn.testHooks); err != nil { return n, err } // Successfully returning from waitAndLockGate means we are no longer @@ -317,7 +328,7 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) // Flush flushes data written to the stream. // It does not wait for the peer to acknowledge receipt of the data. -// Use CloseContext to wait for the peer's acknowledgement. +// Use Close to wait for the peer's acknowledgement. func (s *Stream) Flush() { s.outgate.lock() defer s.outUnlock() @@ -333,27 +344,21 @@ func (s *Stream) flushLocked() { } // Close closes the stream. -// See CloseContext for more details. -func (s *Stream) Close() error { - return s.CloseContext(context.Background()) -} - -// CloseContext closes the stream. // Any blocked stream operations will be unblocked and return errors. // -// CloseContext flushes any data in the stream write buffer and waits for the peer to +// Close flushes any data in the stream write buffer and waits for the peer to // acknowledge receipt of the data. // If the stream has been reset, it waits for the peer to acknowledge the reset. // If the context expires before the peer receives the stream's data, -// CloseContext discards the buffer and returns the context error. -func (s *Stream) CloseContext(ctx context.Context) error { +// Close discards the buffer and returns the context error. +func (s *Stream) Close() error { s.CloseRead() if s.IsReadOnly() { return nil } s.CloseWrite() // TODO: Return code from peer's RESET_STREAM frame? - if err := s.conn.waitOnDone(ctx, s.outdone); err != nil { + if err := s.conn.waitOnDone(s.outctx, s.outdone); err != nil { return err } s.outgate.lock() @@ -369,7 +374,7 @@ func (s *Stream) CloseContext(ctx context.Context) error { // // CloseRead notifies the peer that the stream has been closed for reading. // It does not wait for the peer to acknowledge the closure. -// Use CloseContext to wait for the peer's acknowledgement. +// Use Close to wait for the peer's acknowledgement. func (s *Stream) CloseRead() { if s.IsWriteOnly() { return @@ -394,7 +399,7 @@ func (s *Stream) CloseRead() { // // CloseWrite sends any data in the stream write buffer to the peer. // It does not wait for the peer to acknowledge receipt of the data. -// Use CloseContext to wait for the peer's acknowledgement. +// Use Close to wait for the peer's acknowledgement. func (s *Stream) CloseWrite() { if s.IsReadOnly() { return @@ -412,7 +417,7 @@ func (s *Stream) CloseWrite() { // Reset sends the application protocol error code, which must be // less than 2^62, to the peer. // It does not wait for the peer to acknowledge receipt of the error. -// Use CloseContext to wait for the peer's acknowledgement. +// Use Close to wait for the peer's acknowledgement. // // Reset does not affect reads. // Use CloseRead to abort reads on the stream. 
diff --git a/internal/quic/stream_limits_test.go b/internal/quic/stream_limits_test.go index 3f291e9f4..9c2f71ec1 100644 --- a/internal/quic/stream_limits_test.go +++ b/internal/quic/stream_limits_test.go @@ -200,7 +200,6 @@ func TestStreamLimitMaxStreamsFrameTooLarge(t *testing.T) { func TestStreamLimitSendUpdatesMaxStreams(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { - ctx := canceledContext() tc := newTestConn(t, serverSide, func(c *Config) { if styp == uniStream { c.MaxUniRemoteStreams = 4 @@ -218,13 +217,9 @@ func TestStreamLimitSendUpdatesMaxStreams(t *testing.T) { id: newStreamID(clientSide, styp, int64(i)), fin: true, }) - s, err := tc.conn.AcceptStream(ctx) - if err != nil { - t.Fatalf("AcceptStream = %v", err) - } - streams = append(streams, s) + streams = append(streams, tc.acceptStream()) } - streams[3].CloseContext(ctx) + streams[3].Close() if styp == bidiStream { tc.wantFrame("stream is closed", packetType1RTT, debugFrameStream{ diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go index 00e392dba..08e89b24c 100644 --- a/internal/quic/stream_test.go +++ b/internal/quic/stream_test.go @@ -19,7 +19,6 @@ import ( func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { - ctx := canceledContext() want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} const writeBufferSize = 4 tc := newTestConn(t, clientSide, permissiveTransportParameters, func(c *Config) { @@ -28,15 +27,12 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { tc.handshake() tc.ignoreFrame(frameTypeAck) - s, err := tc.conn.newLocalStream(ctx, styp) - if err != nil { - t.Fatal(err) - } + s := newLocalStream(t, tc, styp) // Non-blocking write. - n, err := s.WriteContext(ctx, want) + n, err := s.Write(want) if n != writeBufferSize || err != context.Canceled { - t.Fatalf("s.WriteContext() = %v, %v; want %v, context.Canceled", n, err, writeBufferSize) + t.Fatalf("s.Write() = %v, %v; want %v, context.Canceled", n, err, writeBufferSize) } s.Flush() tc.wantFrame("first write buffer of data sent", @@ -48,7 +44,8 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { // Blocking write, which must wait for buffer space. w := runAsync(tc, func(ctx context.Context) (int, error) { - n, err := s.WriteContext(ctx, want[writeBufferSize:]) + s.SetWriteContext(ctx) + n, err := s.Write(want[writeBufferSize:]) s.Flush() return n, err }) @@ -75,7 +72,7 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { }) if n, err := w.result(); n != len(want)-writeBufferSize || err != nil { - t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", + t.Fatalf("s.Write() = %v, %v; want %v, nil", len(want)-writeBufferSize, err, writeBufferSize) } }) @@ -99,7 +96,7 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { } // Data is written to the stream output buffer, but we have no flow control. - _, err = s.WriteContext(ctx, want[:1]) + _, err = s.Write(want[:1]) if err != nil { t.Fatalf("write with available output buffer: unexpected error: %v", err) } @@ -110,7 +107,7 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { }) // Write more data. 
- _, err = s.WriteContext(ctx, want[1:]) + _, err = s.Write(want[1:]) if err != nil { t.Fatalf("write with available output buffer: unexpected error: %v", err) } @@ -172,7 +169,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { if err != nil { t.Fatal(err) } - s.WriteContext(ctx, want[:1]) + s.Write(want[:1]) s.Flush() tc.wantFrame("sent data (1 byte) fits within flow control limit", packetType1RTT, debugFrameStream{ @@ -188,7 +185,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { }) // Write [1,4). - s.WriteContext(ctx, want[1:]) + s.Write(want[1:]) tc.wantFrame("stream limit is 4 bytes, ignoring decrease in MAX_STREAM_DATA", packetType1RTT, debugFrameStream{ id: s.id, @@ -208,7 +205,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { }) // Write [1,4). - s.WriteContext(ctx, want[4:]) + s.Write(want[4:]) tc.wantFrame("stream limit is 8 bytes, ignoring decrease in MAX_STREAM_DATA", packetType1RTT, debugFrameStream{ id: s.id, @@ -220,7 +217,6 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { - ctx := canceledContext() want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} const maxWriteBuffer = 4 tc := newTestConn(t, clientSide, func(p *transportParameters) { @@ -238,12 +234,10 @@ func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) { // Write more data than StreamWriteBufferSize. // The peer has given us plenty of flow control, // so we're just blocked by our local limit. - s, err := tc.conn.newLocalStream(ctx, styp) - if err != nil { - t.Fatal(err) - } + s := newLocalStream(t, tc, styp) w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, want) + s.SetWriteContext(ctx) + return s.Write(want) }) tc.wantFrame("stream write should send as much data as write buffer allows", packetType1RTT, debugFrameStream{ @@ -266,7 +260,7 @@ func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) { w.cancel() n, err := w.result() if n != 2*maxWriteBuffer || err == nil { - t.Fatalf("WriteContext() = %v, %v; want %v bytes, error", n, err, 2*maxWriteBuffer) + t.Fatalf("Write() = %v, %v; want %v bytes, error", n, err, 2*maxWriteBuffer) } }) } @@ -397,7 +391,6 @@ func TestStreamReceive(t *testing.T) { }}, }} { testStreamTypes(t, test.name, func(t *testing.T, styp streamType) { - ctx := canceledContext() tc := newTestConn(t, serverSide) tc.handshake() sid := newStreamID(clientSide, styp, 0) @@ -413,21 +406,17 @@ func TestStreamReceive(t *testing.T) { fin: f.fin, }) if s == nil { - var err error - s, err = tc.conn.AcceptStream(ctx) - if err != nil { - tc.t.Fatalf("conn.AcceptStream() = %v", err) - } + s = tc.acceptStream() } for { - n, err := s.ReadContext(ctx, got[total:]) - t.Logf("s.ReadContext() = %v, %v", n, err) + n, err := s.Read(got[total:]) + t.Logf("s.Read() = %v, %v", n, err) total += n if f.wantEOF && err != io.EOF { - t.Fatalf("ReadContext() error = %v; want io.EOF", err) + t.Fatalf("Read() error = %v; want io.EOF", err) } if !f.wantEOF && err == io.EOF { - t.Fatalf("ReadContext() error = io.EOF, want something else") + t.Fatalf("Read() error = io.EOF, want something else") } if err != nil { break @@ -468,8 +457,8 @@ func TestStreamReceiveExtendsStreamWindow(t *testing.T) { } tc.wantIdle("stream window is not extended before data is read") buf := make([]byte, maxWindowSize+1) - if n, err := s.ReadContext(ctx, buf); n != maxWindowSize || err != nil { - t.Fatalf("s.ReadContext() = %v, %v; 
want %v, nil", n, err, maxWindowSize) + if n, err := s.Read(buf); n != maxWindowSize || err != nil { + t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, maxWindowSize) } tc.wantFrame("stream window is extended after reading data", packetType1RTT, debugFrameMaxStreamData{ @@ -482,8 +471,8 @@ func TestStreamReceiveExtendsStreamWindow(t *testing.T) { data: make([]byte, maxWindowSize), fin: true, }) - if n, err := s.ReadContext(ctx, buf); n != maxWindowSize || err != io.EOF { - t.Fatalf("s.ReadContext() = %v, %v; want %v, io.EOF", n, err, maxWindowSize) + if n, err := s.Read(buf); n != maxWindowSize || err != io.EOF { + t.Fatalf("s.Read() = %v, %v; want %v, io.EOF", n, err, maxWindowSize) } tc.wantIdle("stream window is not extended after FIN") }) @@ -673,18 +662,19 @@ func TestStreamReceiveUnblocksReader(t *testing.T) { t.Fatalf("AcceptStream() = %v", err) } - // ReadContext succeeds immediately, since we already have data. + // Read succeeds immediately, since we already have data. got := make([]byte, len(want)) read := runAsync(tc, func(ctx context.Context) (int, error) { - return s.ReadContext(ctx, got) + return s.Read(got) }) if n, err := read.result(); n != write1size || err != nil { - t.Fatalf("ReadContext = %v, %v; want %v, nil", n, err, write1size) + t.Fatalf("Read = %v, %v; want %v, nil", n, err, write1size) } - // ReadContext blocks waiting for more data. + // Read blocks waiting for more data. read = runAsync(tc, func(ctx context.Context) (int, error) { - return s.ReadContext(ctx, got[write1size:]) + s.SetReadContext(ctx) + return s.Read(got[write1size:]) }) tc.writeFrames(packetType1RTT, debugFrameStream{ id: sid, @@ -693,7 +683,7 @@ func TestStreamReceiveUnblocksReader(t *testing.T) { fin: true, }) if n, err := read.result(); n != len(want)-write1size || err != io.EOF { - t.Fatalf("ReadContext = %v, %v; want %v, io.EOF", n, err, len(want)-write1size) + t.Fatalf("Read = %v, %v; want %v, io.EOF", n, err, len(want)-write1size) } if !bytes.Equal(got, want) { t.Fatalf("read bytes %x, want %x", got, want) @@ -935,7 +925,8 @@ func TestStreamResetBlockedStream(t *testing.T) { }) tc.ignoreFrame(frameTypeStreamDataBlocked) writing := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, []byte{0, 1, 2, 3, 4, 5, 6, 7}) + s.SetWriteContext(ctx) + return s.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7}) }) tc.wantFrame("stream writes data until write buffer fills", packetType1RTT, debugFrameStream{ @@ -972,7 +963,7 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { want := make([]byte, 4096) rand.Read(want) // doesn't need to be crypto/rand, but non-deprecated and harmless w := runAsync(tc, func(ctx context.Context) (int, error) { - n, err := s.WriteContext(ctx, want) + n, err := s.Write(want) s.Flush() return n, err }) @@ -992,7 +983,7 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { got = append(got, sf.data...) 
} if n, err := w.result(); n != len(want) || err != nil { - t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", n, err, len(want)) + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want)) } if !bytes.Equal(got, want) { t.Fatalf("mismatch in received stream data") @@ -1000,17 +991,16 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { } func TestStreamCloseWaitsForAcks(t *testing.T) { - ctx := canceledContext() tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) - s.WriteContext(ctx, data) + s.Write(data) s.Flush() tc.wantFrame("conn sends data for the stream", packetType1RTT, debugFrameStream{ id: s.id, data: data, }) - if err := s.CloseContext(ctx); err != context.Canceled { + if err := s.Close(); err != context.Canceled { t.Fatalf("s.Close() = %v, want context.Canceled (data not acked yet)", err) } tc.wantFrame("conn sends FIN for closed stream", @@ -1021,21 +1011,22 @@ func TestStreamCloseWaitsForAcks(t *testing.T) { data: []byte{}, }) closing := runAsync(tc, func(ctx context.Context) (struct{}, error) { - return struct{}{}, s.CloseContext(ctx) + s.SetWriteContext(ctx) + return struct{}{}, s.Close() }) if _, err := closing.result(); err != errNotDone { - t.Fatalf("s.CloseContext() = %v, want it to block waiting for acks", err) + t.Fatalf("s.Close() = %v, want it to block waiting for acks", err) } tc.writeAckForAll() if _, err := closing.result(); err != nil { - t.Fatalf("s.CloseContext() = %v, want nil (all data acked)", err) + t.Fatalf("s.Close() = %v, want nil (all data acked)", err) } } func TestStreamCloseReadOnly(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, permissiveTransportParameters) - if err := s.CloseContext(canceledContext()); err != nil { - t.Errorf("s.CloseContext() = %v, want nil", err) + if err := s.Close(); err != nil { + t.Errorf("s.Close() = %v, want nil", err) } tc.wantFrame("closed stream sends STOP_SENDING", packetType1RTT, debugFrameStopSending{ @@ -1069,17 +1060,16 @@ func TestStreamCloseUnblocked(t *testing.T) { }, }} { t.Run(test.name, func(t *testing.T) { - ctx := canceledContext() tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) - s.WriteContext(ctx, data) + s.Write(data) s.Flush() tc.wantFrame("conn sends data for the stream", packetType1RTT, debugFrameStream{ id: s.id, data: data, }) - if err := s.CloseContext(ctx); err != context.Canceled { + if err := s.Close(); err != context.Canceled { t.Fatalf("s.Close() = %v, want context.Canceled (data not acked yet)", err) } tc.wantFrame("conn sends FIN for closed stream", @@ -1090,34 +1080,34 @@ func TestStreamCloseUnblocked(t *testing.T) { data: []byte{}, }) closing := runAsync(tc, func(ctx context.Context) (struct{}, error) { - return struct{}{}, s.CloseContext(ctx) + s.SetWriteContext(ctx) + return struct{}{}, s.Close() }) if _, err := closing.result(); err != errNotDone { - t.Fatalf("s.CloseContext() = %v, want it to block waiting for acks", err) + t.Fatalf("s.Close() = %v, want it to block waiting for acks", err) } test.unblock(tc, s) _, err := closing.result() switch { case err == errNotDone: - t.Fatalf("s.CloseContext() still blocking; want it to have returned") + t.Fatalf("s.Close() still blocking; want it to have returned") case err == nil && !test.success: - t.Fatalf("s.CloseContext() = nil, want error") + t.Fatalf("s.Close() = nil, want error") case err != nil && test.success: - t.Fatalf("s.CloseContext() = %v, want nil 
(all data acked)", err) + t.Fatalf("s.Close() = %v, want nil (all data acked)", err) } }) } } func TestStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) { - ctx := canceledContext() tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters, func(p *transportParameters) { //p.initialMaxData = 0 p.initialMaxStreamDataUni = 0 }) tc.ignoreFrame(frameTypeStreamDataBlocked) - if _, err := s.WriteContext(ctx, []byte{0, 1}); err != nil { + if _, err := s.Write([]byte{0, 1}); err != nil { t.Fatalf("s.Write = %v", err) } s.CloseWrite() @@ -1149,7 +1139,6 @@ func TestStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) { func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { - ctx := canceledContext() tc, s := newTestConnAndRemoteStream(t, serverSide, styp) data := []byte{0, 1, 2, 3, 4, 5, 6, 7} tc.writeFrames(packetType1RTT, debugFrameStream{ @@ -1157,7 +1146,7 @@ func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { data: data, }) got := make([]byte, 4) - if n, err := s.ReadContext(ctx, got); n != len(got) || err != nil { + if n, err := s.Read(got); n != len(got) || err != nil { t.Fatalf("Read start of stream: got %v, %v; want %v, nil", n, err, len(got)) } const sentCode = 42 @@ -1167,7 +1156,7 @@ func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { code: sentCode, }) wantErr := StreamErrorCode(sentCode) - if n, err := s.ReadContext(ctx, got); n != 0 || !errors.Is(err, wantErr) { + if n, err := s.Read(got); n != 0 || !errors.Is(err, wantErr) { t.Fatalf("Read reset stream: got %v, %v; want 0, %v", n, err, wantErr) } }) @@ -1177,8 +1166,9 @@ func TestStreamPeerResetWakesBlockedRead(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { tc, s := newTestConnAndRemoteStream(t, serverSide, styp) reader := runAsync(tc, func(ctx context.Context) (int, error) { + s.SetReadContext(ctx) got := make([]byte, 4) - return s.ReadContext(ctx, got) + return s.Read(got) }) const sentCode = 42 tc.writeFrames(packetType1RTT, debugFrameResetStream{ @@ -1348,7 +1338,8 @@ func TestStreamFlushImplicitLargerThanBuffer(t *testing.T) { want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} w := runAsync(tc, func(ctx context.Context) (int, error) { - n, err := s.WriteContext(ctx, want) + s.SetWriteContext(ctx) + n, err := s.Write(want) return n, err }) @@ -1401,7 +1392,10 @@ func newTestConnAndLocalStream(t *testing.T, side connSide, styp streamType, opt tc := newTestConn(t, side, opts...) tc.handshake() tc.ignoreFrame(frameTypeAck) - return tc, newLocalStream(t, tc, styp) + s := newLocalStream(t, tc, styp) + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) + return tc, s } func newLocalStream(t *testing.T, tc *testConn, styp streamType) *Stream { @@ -1411,6 +1405,8 @@ func newLocalStream(t *testing.T, tc *testConn, styp streamType) *Stream { if err != nil { t.Fatalf("conn.newLocalStream(%v) = %v", styp, err) } + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) return s } @@ -1419,7 +1415,10 @@ func newTestConnAndRemoteStream(t *testing.T, side connSide, styp streamType, op tc := newTestConn(t, side, opts...) 
tc.handshake() tc.ignoreFrame(frameTypeAck) - return tc, newRemoteStream(t, tc, styp) + s := newRemoteStream(t, tc, styp) + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) + return tc, s } func newRemoteStream(t *testing.T, tc *testConn, styp streamType) *Stream { @@ -1432,6 +1431,8 @@ func newRemoteStream(t *testing.T, tc *testConn, styp streamType) *Stream { if err != nil { t.Fatalf("conn.AcceptStream() = %v", err) } + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) return s } From 840656f9213922d0bb729d201162410b0bd74d9b Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 13 Feb 2024 15:36:02 -0800 Subject: [PATCH 31/70] quic/qlog: don't output empty slog.Attrs For golang/go#58547 Change-Id: I49a27ab82781c817511c6f7da0268529abc3f27f Reviewed-on: https://go-review.googlesource.com/c/net/+/564015 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/qlog/json_writer.go | 6 +++--- internal/quic/qlog/json_writer_test.go | 9 +++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/internal/quic/qlog/json_writer.go b/internal/quic/qlog/json_writer.go index b2fa3e03e..6fb8d33b2 100644 --- a/internal/quic/qlog/json_writer.go +++ b/internal/quic/qlog/json_writer.go @@ -45,15 +45,15 @@ func (w *jsonWriter) writeRecordEnd() { func (w *jsonWriter) writeAttrs(attrs []slog.Attr) { w.buf.WriteByte('{') for _, a := range attrs { - if a.Key == "" { - continue - } w.writeAttr(a) } w.buf.WriteByte('}') } func (w *jsonWriter) writeAttr(a slog.Attr) { + if a.Key == "" { + return + } w.writeName(a.Key) w.writeValue(a.Value) } diff --git a/internal/quic/qlog/json_writer_test.go b/internal/quic/qlog/json_writer_test.go index 6da556641..03cf6947c 100644 --- a/internal/quic/qlog/json_writer_test.go +++ b/internal/quic/qlog/json_writer_test.go @@ -85,6 +85,15 @@ func TestJSONWriterAttrs(t *testing.T) { `}}`) } +func TestJSONWriterAttrEmpty(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + var a slog.Attr + w.writeAttr(a) + w.writeRecordEnd() + wantJSONRecord(t, w, `{}`) +} + func TestJSONWriterObjectEmpty(t *testing.T) { w := newTestJSONWriter() w.writeRecordStart() From 6e383c4aaf0635c980378ed3217f2a65391895a5 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 17 Nov 2023 08:06:43 -0800 Subject: [PATCH 32/70] quic: add qlog recovery metrics Log events for various congestion control and loss recovery metrics. For golang/go#58547 Change-Id: Ife3b3897f6ca731049c78b934a7123aa1ed4aee2 Reviewed-on: https://go-review.googlesource.com/c/net/+/564016 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/congestion_reno.go | 59 +++++++++++++++++++-- internal/quic/congestion_reno_test.go | 6 +-- internal/quic/conn.go | 2 +- internal/quic/conn_loss.go | 4 ++ internal/quic/conn_recv.go | 4 +- internal/quic/conn_send.go | 21 +++++--- internal/quic/loss.go | 47 ++++++++++++++--- internal/quic/loss_test.go | 10 ++-- internal/quic/packet_writer.go | 7 +-- internal/quic/qlog.go | 16 +++++- internal/quic/qlog_test.go | 76 ++++++++++++++++++++++++++- internal/quic/sent_packet.go | 7 +-- 12 files changed, 222 insertions(+), 37 deletions(-) diff --git a/internal/quic/congestion_reno.go b/internal/quic/congestion_reno.go index 982cbf4bb..a53983524 100644 --- a/internal/quic/congestion_reno.go +++ b/internal/quic/congestion_reno.go @@ -7,6 +7,8 @@ package quic import ( + "context" + "log/slog" "math" "time" ) @@ -40,6 +42,9 @@ type ccReno struct { // true if we haven't sent that packet yet. 
sendOnePacketInRecovery bool + // inRecovery is set when we are in the recovery state. + inRecovery bool + // underutilized is set if the congestion window is underutilized // due to insufficient application data, flow control limits, or // anti-amplification limits. @@ -100,12 +105,19 @@ func (c *ccReno) canSend() bool { // congestion controller permits sending data, but no data is sent. // // https://www.rfc-editor.org/rfc/rfc9002#section-7.8 -func (c *ccReno) setUnderutilized(v bool) { +func (c *ccReno) setUnderutilized(log *slog.Logger, v bool) { + if c.underutilized == v { + return + } + oldState := c.state() c.underutilized = v + if logEnabled(log, QLogLevelPacket) { + logCongestionStateUpdated(log, oldState, c.state()) + } } // packetSent indicates that a packet has been sent. -func (c *ccReno) packetSent(now time.Time, space numberSpace, sent *sentPacket) { +func (c *ccReno) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) { if !sent.inFlight { return } @@ -185,7 +197,11 @@ func (c *ccReno) packetLost(now time.Time, space numberSpace, sent *sentPacket, } // packetBatchEnd is called at the end of processing a batch of acked or lost packets. -func (c *ccReno) packetBatchEnd(now time.Time, space numberSpace, rtt *rttState, maxAckDelay time.Duration) { +func (c *ccReno) packetBatchEnd(now time.Time, log *slog.Logger, space numberSpace, rtt *rttState, maxAckDelay time.Duration) { + if logEnabled(log, QLogLevelPacket) { + oldState := c.state() + defer func() { logCongestionStateUpdated(log, oldState, c.state()) }() + } if !c.ackLastLoss.IsZero() && !c.ackLastLoss.Before(c.recoveryStartTime) { // Enter the recovery state. // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.2 @@ -196,8 +212,10 @@ func (c *ccReno) packetBatchEnd(now time.Time, space numberSpace, rtt *rttState, // Clear congestionPendingAcks to avoid increasing the congestion // window based on acks in a frame that sends us into recovery. c.congestionPendingAcks = 0 + c.inRecovery = true } else if c.congestionPendingAcks > 0 { // We are in slow start or congestion avoidance. 
+ c.inRecovery = false if c.congestionWindow < c.slowStartThreshold { // When the congestion window is less than the slow start threshold, // we are in slow start and increase the window by the number of @@ -253,3 +271,38 @@ func (c *ccReno) minimumCongestionWindow() int { // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.2-4 return 2 * c.maxDatagramSize } + +func logCongestionStateUpdated(log *slog.Logger, oldState, newState congestionState) { + if oldState == newState { + return + } + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:congestion_state_updated", + slog.String("old", oldState.String()), + slog.String("new", newState.String()), + ) +} + +type congestionState string + +func (s congestionState) String() string { return string(s) } + +const ( + congestionSlowStart = congestionState("slow_start") + congestionCongestionAvoidance = congestionState("congestion_avoidance") + congestionApplicationLimited = congestionState("application_limited") + congestionRecovery = congestionState("recovery") +) + +func (c *ccReno) state() congestionState { + switch { + case c.inRecovery: + return congestionRecovery + case c.underutilized: + return congestionApplicationLimited + case c.congestionWindow < c.slowStartThreshold: + return congestionSlowStart + default: + return congestionCongestionAvoidance + } +} diff --git a/internal/quic/congestion_reno_test.go b/internal/quic/congestion_reno_test.go index e9af6452c..cda7a90a8 100644 --- a/internal/quic/congestion_reno_test.go +++ b/internal/quic/congestion_reno_test.go @@ -470,7 +470,7 @@ func (c *ccTest) setRTT(smoothedRTT, rttvar time.Duration) { func (c *ccTest) setUnderutilized(v bool) { c.t.Helper() c.t.Logf("set underutilized = %v", v) - c.cc.setUnderutilized(v) + c.cc.setUnderutilized(nil, v) } func (c *ccTest) packetSent(space numberSpace, size int, fns ...func(*sentPacket)) *sentPacket { @@ -488,7 +488,7 @@ func (c *ccTest) packetSent(space numberSpace, size int, fns ...func(*sentPacket f(sent) } c.t.Logf("packet sent: num=%v.%v, size=%v", space, sent.num, sent.size) - c.cc.packetSent(c.now, space, sent) + c.cc.packetSent(c.now, nil, space, sent) return sent } @@ -519,7 +519,7 @@ func (c *ccTest) packetDiscarded(space numberSpace, sent *sentPacket) { func (c *ccTest) packetBatchEnd(space numberSpace) { c.t.Helper() c.t.Logf("(end of batch)") - c.cc.packetBatchEnd(c.now, space, &c.rtt, c.maxAckDelay) + c.cc.packetBatchEnd(c.now, nil, space, &c.rtt, c.maxAckDelay) } func (c *ccTest) wantCanSend(want bool) { diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 6d79013eb..020bc81a4 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -210,7 +210,7 @@ func (c *Conn) discardKeys(now time.Time, space numberSpace) { case handshakeSpace: c.keysHandshake.discard() } - c.loss.discardKeys(now, space) + c.loss.discardKeys(now, c.log, space) } // receiveTransportParameters applies transport parameters sent by the peer. diff --git a/internal/quic/conn_loss.go b/internal/quic/conn_loss.go index 85bda314e..623ebdd7c 100644 --- a/internal/quic/conn_loss.go +++ b/internal/quic/conn_loss.go @@ -20,6 +20,10 @@ import "fmt" // See RFC 9000, Section 13.3 for a complete list of information which is retransmitted on loss. 
// https://www.rfc-editor.org/rfc/rfc9000#section-13.3 func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) { + if fate == packetLost && c.logEnabled(QLogLevelPacket) { + c.logPacketLost(space, sent) + } + // The list of frames in a sent packet is marshaled into a buffer in the sentPacket // by the packetWriter. Unmarshal that buffer here. This code must be kept in sync with // packetWriter.append*. diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index 045bf861c..b666ce8eb 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -192,7 +192,7 @@ func (c *Conn) handleRetry(now time.Time, pkt []byte) { c.connIDState.handleRetryPacket(p.srcConnID) // We need to resend any data we've already sent in Initial packets. // We must not reuse already sent packet numbers. - c.loss.discardPackets(initialSpace, c.handleAckOrLoss) + c.loss.discardPackets(initialSpace, c.log, c.handleAckOrLoss) // TODO: Discard 0-RTT packets as well, once we support 0-RTT. } @@ -416,7 +416,7 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) if c.peerAckDelayExponent >= 0 { delay = ackDelay.Duration(uint8(c.peerAckDelayExponent)) } - c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss) + c.loss.receiveAckEnd(now, c.log, space, delay, c.handleAckOrLoss) if space == appDataSpace { c.keysAppData.handleAckFor(largest) } diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index ccb467591..575b8f9b4 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -22,7 +22,10 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // Assumption: The congestion window is not underutilized. // If congestion control, pacing, and anti-amplification all permit sending, // but we have no packet to send, then we will declare the window underutilized. - c.loss.cc.setUnderutilized(false) + underutilized := false + defer func() { + c.loss.cc.setUnderutilized(c.log, underutilized) + }() // Send one datagram on each iteration of this loop, // until we hit a limit or run out of data to send. @@ -80,7 +83,6 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p) if sentInitial != nil { - c.idleHandlePacketSent(now, sentInitial) // Client initial packets and ack-eliciting server initial packaets // need to be sent in a datagram padded to at least 1200 bytes. // We can't add the padding yet, however, since we may want to @@ -111,8 +113,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { c.logPacketSent(packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload()) } if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil { - c.idleHandlePacketSent(now, sent) - c.loss.packetSent(now, handshakeSpace, sent) + c.packetSent(now, handshakeSpace, sent) if c.side == clientSide { // "[...] 
a client MUST discard Initial keys when it first // sends a Handshake packet [...]" @@ -142,8 +143,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { c.logPacketSent(packetType1RTT, pnum, nil, dstConnID, c.w.packetLen(), c.w.payload()) } if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil { - c.idleHandlePacketSent(now, sent) - c.loss.packetSent(now, appDataSpace, sent) + c.packetSent(now, appDataSpace, sent) } } @@ -152,7 +152,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if limit == ccOK { // We have nothing to send, and congestion control does not // block sending. The congestion window is underutilized. - c.loss.cc.setUnderutilized(true) + underutilized = true } return next } @@ -175,7 +175,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // with a Handshake packet, then we've discarded Initial keys // since constructing the packet and shouldn't record it as in-flight. if c.keysInitial.canWrite() { - c.loss.packetSent(now, initialSpace, sentInitial) + c.packetSent(now, initialSpace, sentInitial) } } @@ -183,6 +183,11 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } } +func (c *Conn) packetSent(now time.Time, space numberSpace, sent *sentPacket) { + c.idleHandlePacketSent(now, sent) + c.loss.packetSent(now, c.log, space, sent) +} + func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, limit ccLimit) { if c.lifetime.localErr != nil { c.appendConnectionCloseFrame(now, space, c.lifetime.localErr) diff --git a/internal/quic/loss.go b/internal/quic/loss.go index a59081fd5..796b5f7a3 100644 --- a/internal/quic/loss.go +++ b/internal/quic/loss.go @@ -7,6 +7,8 @@ package quic import ( + "context" + "log/slog" "math" "time" ) @@ -179,7 +181,7 @@ func (c *lossState) nextNumber(space numberSpace) packetNumber { } // packetSent records a sent packet. -func (c *lossState) packetSent(now time.Time, space numberSpace, sent *sentPacket) { +func (c *lossState) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) { sent.time = now c.spaces[space].add(sent) size := sent.size @@ -187,13 +189,16 @@ func (c *lossState) packetSent(now time.Time, space numberSpace, sent *sentPacke c.antiAmplificationLimit = max(0, c.antiAmplificationLimit-size) } if sent.inFlight { - c.cc.packetSent(now, space, sent) + c.cc.packetSent(now, log, space, sent) c.pacer.packetSent(now, size, c.cc.congestionWindow, c.rtt.smoothedRTT) if sent.ackEliciting { c.spaces[space].lastAckEliciting = sent.num c.ptoExpired = false // reset expired PTO timer after sending probe } c.scheduleTimer(now) + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } } if sent.ackEliciting { c.consecutiveNonAckElicitingPackets = 0 @@ -267,7 +272,7 @@ func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex // receiveAckEnd finishes processing an ack frame. // The lossf function is called for each packet newly detected as lost. -func (c *lossState) receiveAckEnd(now time.Time, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) { +func (c *lossState) receiveAckEnd(now time.Time, log *slog.Logger, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) { c.spaces[space].sentPacketList.clean() // Update the RTT sample when the largest acknowledged packet in the ACK frame // is newly acknowledged, and at least one newly acknowledged packet is ack-eliciting. 
@@ -286,13 +291,30 @@ func (c *lossState) receiveAckEnd(now time.Time, space numberSpace, ackDelay tim // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-3 c.timer = time.Time{} c.detectLoss(now, lossf) - c.cc.packetBatchEnd(now, space, &c.rtt, c.maxAckDelay) + c.cc.packetBatchEnd(now, log, space, &c.rtt, c.maxAckDelay) + + if logEnabled(log, QLogLevelPacket) { + var ssthresh slog.Attr + if c.cc.slowStartThreshold != math.MaxInt { + ssthresh = slog.Int("ssthresh", c.cc.slowStartThreshold) + } + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:metrics_updated", + slog.Duration("min_rtt", c.rtt.minRTT), + slog.Duration("smoothed_rtt", c.rtt.smoothedRTT), + slog.Duration("latest_rtt", c.rtt.latestRTT), + slog.Duration("rtt_variance", c.rtt.rttvar), + slog.Int("congestion_window", c.cc.congestionWindow), + slog.Int("bytes_in_flight", c.cc.bytesInFlight), + ssthresh, + ) + } } // discardPackets declares that packets within a number space will not be delivered // and that data contained in them should be resent. // For example, after receiving a Retry packet we discard already-sent Initial packets. -func (c *lossState) discardPackets(space numberSpace, lossf func(numberSpace, *sentPacket, packetFate)) { +func (c *lossState) discardPackets(space numberSpace, log *slog.Logger, lossf func(numberSpace, *sentPacket, packetFate)) { for i := 0; i < c.spaces[space].size; i++ { sent := c.spaces[space].nth(i) sent.lost = true @@ -300,10 +322,13 @@ func (c *lossState) discardPackets(space numberSpace, lossf func(numberSpace, *s lossf(numberSpace(space), sent, packetLost) } c.spaces[space].clean() + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } } // discardKeys is called when dropping packet protection keys for a number space. 
-func (c *lossState) discardKeys(now time.Time, space numberSpace) { +func (c *lossState) discardKeys(now time.Time, log *slog.Logger, space numberSpace) { // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.4 for i := 0; i < c.spaces[space].size; i++ { sent := c.spaces[space].nth(i) @@ -313,6 +338,9 @@ func (c *lossState) discardKeys(now time.Time, space numberSpace) { c.spaces[space].maxAcked = -1 c.spaces[space].lastAckEliciting = -1 c.scheduleTimer(now) + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } } func (c *lossState) lossDuration() time.Duration { @@ -459,3 +487,10 @@ func (c *lossState) ptoBasePeriod() time.Duration { } return pto } + +func logBytesInFlight(log *slog.Logger, bytesInFlight int) { + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:metrics_updated", + slog.Int("bytes_in_flight", bytesInFlight), + ) +} diff --git a/internal/quic/loss_test.go b/internal/quic/loss_test.go index efbf1649e..1fb9662e4 100644 --- a/internal/quic/loss_test.go +++ b/internal/quic/loss_test.go @@ -1060,7 +1060,7 @@ func TestLossPersistentCongestion(t *testing.T) { maxDatagramSize: 1200, }) test.send(initialSpace, 0, testSentPacketSize(1200)) - test.c.cc.setUnderutilized(true) + test.c.cc.setUnderutilized(nil, true) test.advance(10 * time.Millisecond) test.ack(initialSpace, 0*time.Millisecond, i64range[packetNumber]{0, 1}) @@ -1377,7 +1377,7 @@ func (c *lossTest) setRTTVar(d time.Duration) { func (c *lossTest) setUnderutilized(v bool) { c.t.Logf("set congestion window underutilized: %v", v) - c.c.cc.setUnderutilized(v) + c.c.cc.setUnderutilized(nil, v) } func (c *lossTest) advance(d time.Duration) { @@ -1438,7 +1438,7 @@ func (c *lossTest) send(spaceID numberSpace, opts ...any) { sent := &sentPacket{} *sent = prototype sent.num = num - c.c.packetSent(c.now, spaceID, sent) + c.c.packetSent(c.now, nil, spaceID, sent) } } @@ -1462,7 +1462,7 @@ func (c *lossTest) ack(spaceID numberSpace, ackDelay time.Duration, rs ...i64ran c.t.Logf("ack %v delay=%v [%v,%v)", spaceID, ackDelay, r.start, r.end) c.c.receiveAckRange(c.now, spaceID, i, r.start, r.end, c.onAckOrLoss) } - c.c.receiveAckEnd(c.now, spaceID, ackDelay, c.onAckOrLoss) + c.c.receiveAckEnd(c.now, nil, spaceID, ackDelay, c.onAckOrLoss) } func (c *lossTest) onAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) { @@ -1491,7 +1491,7 @@ func (c *lossTest) discardKeys(spaceID numberSpace) { c.t.Helper() c.checkUnexpectedEvents() c.t.Logf("discard %s keys", spaceID) - c.c.discardKeys(c.now, spaceID) + c.c.discardKeys(c.now, nil, spaceID) } func (c *lossTest) setMaxAckDelay(d time.Duration) { diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go index b4e54ce4b..85149f607 100644 --- a/internal/quic/packet_writer.go +++ b/internal/quic/packet_writer.go @@ -141,7 +141,7 @@ func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked) k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, p.num) - return w.finish(p.num) + return w.finish(p.ptype, p.num) } // start1RTTPacket starts writing a 1-RTT (short header) packet. 
@@ -183,7 +183,7 @@ func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConn hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked) w.padPacketLength(pnumLen) k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, pnum) - return w.finish(pnum) + return w.finish(packetType1RTT, pnum) } // padPacketLength pads out the payload of the current packet to the minimum size, @@ -204,9 +204,10 @@ func (w *packetWriter) padPacketLength(pnumLen int) int { } // finish finishes the current packet after protection is applied. -func (w *packetWriter) finish(pnum packetNumber) *sentPacket { +func (w *packetWriter) finish(ptype packetType, pnum packetNumber) *sentPacket { w.b = w.b[:len(w.b)+aeadOverhead] w.sent.size = len(w.b) - w.pktOff + w.sent.ptype = ptype w.sent.num = pnum sent := w.sent w.sent = nil diff --git a/internal/quic/qlog.go b/internal/quic/qlog.go index 82ad92ac8..e37e2f8ce 100644 --- a/internal/quic/qlog.go +++ b/internal/quic/qlog.go @@ -39,7 +39,11 @@ const ( ) func (c *Conn) logEnabled(level slog.Level) bool { - return c.log != nil && c.log.Enabled(context.Background(), level) + return logEnabled(c.log, level) +} + +func logEnabled(log *slog.Logger, level slog.Level) bool { + return log != nil && log.Enabled(context.Background(), level) } // slogHexstring returns a slog.Attr for a value of the hexstring type. @@ -252,3 +256,13 @@ func (c *Conn) packetFramesAttr(payload []byte) slog.Attr { } return slog.Any("frames", frames) } + +func (c *Conn) logPacketLost(space numberSpace, sent *sentPacket) { + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:packet_lost", + slog.Group("header", + slog.String("packet_type", sent.ptype.qlogString()), + slog.Uint64("packet_number", uint64(sent.num)), + ), + ) +} diff --git a/internal/quic/qlog_test.go b/internal/quic/qlog_test.go index e98b11838..7ad65524c 100644 --- a/internal/quic/qlog_test.go +++ b/internal/quic/qlog_test.go @@ -159,6 +159,77 @@ func TestQLogConnectionClosedTrigger(t *testing.T) { } } +func TestQLogRecovery(t *testing.T) { + qr := &qlogRecord{} + tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, + permissiveTransportParameters, qr.config) + + // Ignore events from the handshake. + qr.ev = nil + + data := make([]byte, 16) + s.Write(data) + s.CloseWrite() + tc.wantFrame("created stream 0", + packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, 0), + fin: true, + data: data, + }) + tc.writeAckForAll() + tc.wantIdle("connection should be idle now") + + // Don't check the contents of fields, but verify that recovery metrics are logged. + qr.wantEvents(t, jsonEvent{ + "name": "recovery:metrics_updated", + "data": map[string]any{ + "bytes_in_flight": nil, + }, + }, jsonEvent{ + "name": "recovery:metrics_updated", + "data": map[string]any{ + "bytes_in_flight": 0, + "congestion_window": nil, + "latest_rtt": nil, + "min_rtt": nil, + "rtt_variance": nil, + "smoothed_rtt": nil, + }, + }) +} + +func TestQLogLoss(t *testing.T) { + qr := &qlogRecord{} + tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, + permissiveTransportParameters, qr.config) + + // Ignore events from the handshake. 
+ qr.ev = nil + + data := make([]byte, 16) + s.Write(data) + s.CloseWrite() + tc.wantFrame("created stream 0", + packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, 0), + fin: true, + data: data, + }) + + const pto = false + tc.triggerLossOrPTO(packetType1RTT, pto) + + qr.wantEvents(t, jsonEvent{ + "name": "recovery:packet_lost", + "data": map[string]any{ + "header": map[string]any{ + "packet_number": nil, + "packet_type": "1RTT", + }, + }, + }) +} + type nopCloseWriter struct { io.Writer } @@ -193,14 +264,15 @@ func jsonPartialEqual(got, want any) (equal bool) { } return v } + if want == nil { + return true // match anything + } got = cmpval(got) want = cmpval(want) if reflect.TypeOf(got) != reflect.TypeOf(want) { return false } switch w := want.(type) { - case nil: - // Match anything. case map[string]any: // JSON object: Every field in want must match a field in got. g := got.(map[string]any) diff --git a/internal/quic/sent_packet.go b/internal/quic/sent_packet.go index 4f11aa136..194cdc9fa 100644 --- a/internal/quic/sent_packet.go +++ b/internal/quic/sent_packet.go @@ -14,9 +14,10 @@ import ( // A sentPacket tracks state related to an in-flight packet we sent, // to be committed when the peer acks it or resent if the packet is lost. type sentPacket struct { - num packetNumber - size int // size in bytes - time time.Time // time sent + num packetNumber + size int // size in bytes + time time.Time // time sent + ptype packetType ackEliciting bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.4.1 inFlight bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.6.1 From 93be8fe122ca52e008630144471e9473d94cc43f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 28 Nov 2023 09:19:01 -0800 Subject: [PATCH 33/70] quic: log packet_dropped events Log unparsable or otherwise discarded packets. For golang/go#58547 Change-Id: Ief64174d91c93691bd524515aa6518e487543ced Reviewed-on: https://go-review.googlesource.com/c/net/+/564017 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn_recv.go | 15 ++++++++++++--- internal/quic/qlog_test.go | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index b666ce8eb..1b3219723 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -8,6 +8,7 @@ package quic import ( "bytes" + "context" "encoding/binary" "errors" "time" @@ -56,9 +57,16 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { if len(buf) == len(dgram.b) && len(buf) > statelessResetTokenLen { var token statelessResetToken copy(token[:], buf[len(buf)-len(token):]) - c.handleStatelessReset(now, token) + if c.handleStatelessReset(now, token) { + return + } } // Invalid data at the end of a datagram is ignored. 
+ if c.logEnabled(QLogLevelPacket) { + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "connectivity:packet_dropped", + ) + } break } c.idleHandlePacketReceived(now) @@ -562,10 +570,11 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa var errStatelessReset = errors.New("received stateless reset") -func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) { +func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) (valid bool) { if !c.connIDState.isValidStatelessResetToken(resetToken) { - return + return false } c.setFinalError(errStatelessReset) c.enterDraining(now) + return true } diff --git a/internal/quic/qlog_test.go b/internal/quic/qlog_test.go index 7ad65524c..6c79c6cf4 100644 --- a/internal/quic/qlog_test.go +++ b/internal/quic/qlog_test.go @@ -7,6 +7,7 @@ package quic import ( + "bytes" "encoding/hex" "encoding/json" "fmt" @@ -230,6 +231,27 @@ func TestQLogLoss(t *testing.T) { }) } +func TestQLogPacketDropped(t *testing.T) { + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, permissiveTransportParameters, qr.config) + tc.handshake() + + // A garbage-filled datagram with a DCID matching this connection. + dgram := bytes.Join([][]byte{ + {headerFormShort | fixedBit}, + testLocalConnID(0), + make([]byte, 100), + []byte{1, 2, 3, 4}, // random data, to avoid this looking like a stateless reset + }, nil) + tc.endpoint.write(&datagram{ + b: dgram, + }) + + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:packet_dropped", + }) +} + type nopCloseWriter struct { io.Writer } From 117945d00a55197e260d73c6272a2588d39bdebe Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 20 Nov 2023 16:05:25 -0800 Subject: [PATCH 34/70] quic: add throughput and stream creation benchmarks For golang/go#58547 Change-Id: Ie62fcf596bf020bda5a167f7a0d3d95bac9e591a Reviewed-on: https://go-review.googlesource.com/c/net/+/564475 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/bench_test.go | 99 ++++++++++++++++++++++++++++++++++ internal/quic/endpoint_test.go | 4 +- 2 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 internal/quic/bench_test.go diff --git a/internal/quic/bench_test.go b/internal/quic/bench_test.go new file mode 100644 index 000000000..f883b788c --- /dev/null +++ b/internal/quic/bench_test.go @@ -0,0 +1,99 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "fmt" + "io" + "math" + "testing" +) + +// BenchmarkThroughput is based on the crypto/tls benchmark of the same name. +func BenchmarkThroughput(b *testing.B) { + for size := 1; size <= 64; size <<= 1 { + name := fmt.Sprintf("%dMiB", size) + b.Run(name, func(b *testing.B) { + throughput(b, int64(size<<20)) + }) + } +} + +func throughput(b *testing.B, totalBytes int64) { + // Same buffer size as crypto/tls's BenchmarkThroughput, for consistency. 
+ const bufsize = 32 << 10 + + cli, srv := newLocalConnPair(b, &Config{}, &Config{}) + + go func() { + buf := make([]byte, bufsize) + for i := 0; i < b.N; i++ { + sconn, err := srv.AcceptStream(context.Background()) + if err != nil { + panic(fmt.Errorf("AcceptStream: %v", err)) + } + if _, err := io.CopyBuffer(sconn, sconn, buf); err != nil { + panic(fmt.Errorf("CopyBuffer: %v", err)) + } + sconn.Close() + } + }() + + b.SetBytes(totalBytes) + buf := make([]byte, bufsize) + chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf)))) + for i := 0; i < b.N; i++ { + cconn, err := cli.NewStream(context.Background()) + if err != nil { + b.Fatalf("NewStream: %v", err) + } + closec := make(chan struct{}) + go func() { + defer close(closec) + buf := make([]byte, bufsize) + if _, err := io.CopyBuffer(io.Discard, cconn, buf); err != nil { + panic(fmt.Errorf("Discard: %v", err)) + } + }() + for j := 0; j < chunks; j++ { + _, err := cconn.Write(buf) + if err != nil { + b.Fatalf("Write: %v", err) + } + } + cconn.CloseWrite() + <-closec + cconn.Close() + } +} + +func BenchmarkStreamCreation(b *testing.B) { + cli, srv := newLocalConnPair(b, &Config{}, &Config{}) + + go func() { + for i := 0; i < b.N; i++ { + sconn, err := srv.AcceptStream(context.Background()) + if err != nil { + panic(fmt.Errorf("AcceptStream: %v", err)) + } + sconn.Close() + } + }() + + buf := make([]byte, 1) + for i := 0; i < b.N; i++ { + cconn, err := cli.NewStream(context.Background()) + if err != nil { + b.Fatalf("NewStream: %v", err) + } + cconn.Write(buf) + cconn.Flush() + cconn.Read(buf) + cconn.Close() + } +} diff --git a/internal/quic/endpoint_test.go b/internal/quic/endpoint_test.go index 16c3e0bce..6d103f061 100644 --- a/internal/quic/endpoint_test.go +++ b/internal/quic/endpoint_test.go @@ -63,7 +63,7 @@ func TestStreamTransfer(t *testing.T) { } } -func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) { +func newLocalConnPair(t testing.TB, conf1, conf2 *Config) (clientConn, serverConn *Conn) { t.Helper() ctx := context.Background() e1 := newLocalEndpoint(t, serverSide, conf1) @@ -79,7 +79,7 @@ func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverCon return c2, c1 } -func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint { +func newLocalEndpoint(t testing.TB, side connSide, conf *Config) *Endpoint { t.Helper() if conf.TLSConfig == nil { newConf := *conf From e94da73eedb3c3244dcc3857c74accb642dd8eac Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 12 Dec 2023 15:48:00 -0800 Subject: [PATCH 35/70] quic: reduce ack frequency after the first 100 packets RFC 9000 recommends sending an ack for every second ack-eliciting packet received. This frequency is high enough to have a noticeable impact on performance. Follow the approach used by Google QUICHE: Ack every other packet for the first 100 packets, and then switch to acking every 10th packet. (Various other implementations also use a reduced ack frequency; see Custura et al., 2022.) 
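To make the policy concrete, it reduces to roughly the following check. This is an
illustrative sketch only; unackedAckEliciting and largestSeen are placeholder names,
not the fields the package uses (the real change lives in mustAckImmediately below).

    // shouldAckNow sketches the reduced ack-frequency policy: ack every
    // other ack-eliciting packet early on, then every 10th packet once
    // more than 100 packets have been seen.
    func shouldAckNow(unackedAckEliciting int, largestSeen int64) bool {
        packetsBeforeAck := 2 // RFC 9000's recommended frequency
        if largestSeen > 100 {
            packetsBeforeAck = 10
        }
        return unackedAckEliciting >= packetsBeforeAck
    }
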
For golang/go#58547 Change-Id: Idc7051cec23c279811030eb555bc49bb888d6795 Reviewed-on: https://go-review.googlesource.com/c/net/+/564476 Reviewed-by: Jonathan Amsterdam Auto-Submit: Damien Neil LUCI-TryBot-Result: Go LUCI --- internal/quic/acks.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/internal/quic/acks.go b/internal/quic/acks.go index ba860efb2..039b7b46e 100644 --- a/internal/quic/acks.go +++ b/internal/quic/acks.go @@ -130,12 +130,19 @@ func (acks *ackState) mustAckImmediately(space numberSpace, num packetNumber) bo // there are no gaps. If it does not, there must be a gap. return true } - if acks.unackedAckEliciting >= 2 { - // "[...] after receiving at least two ack-eliciting packets." - // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2 - return true + // "[...] SHOULD send an ACK frame after receiving at least two ack-eliciting packets." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2 + // + // This ack frequency takes a substantial toll on performance, however. + // Follow the behavior of Google QUICHE: + // Ack every other packet for the first 100 packets, and then ack every 10th packet. + // This keeps ack frequency high during the beginning of slow start when CWND is + // increasing rapidly. + packetsBeforeAck := 2 + if acks.seen.max() > 100 { + packetsBeforeAck = 10 } - return false + return acks.unackedAckEliciting >= packetsBeforeAck } // shouldSendAck reports whether the connection should send an ACK frame at this time, From dda3687b193e5e1fb31df72be5e0bc6ae7841d2e Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 31 Oct 2023 14:06:26 -0700 Subject: [PATCH 36/70] quic: add Stream.ReadByte, Stream.WriteByte Currently unoptimized and slow. Adding along with a benchmark to compare to the fast-path followup. 
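As a usage illustration (a hedged sketch, not part of the patch), the new methods
allow byte-at-a-time processing directly on a Stream; s is assumed here to be an
open *quic.Stream, and error handling is kept minimal:

    // echoByte reads one byte from the stream and writes it back,
    // flushing so the written byte is not held in the send buffer.
    func echoByte(s *quic.Stream) error {
        b, err := s.ReadByte()
        if err != nil {
            return err
        }
        if err := s.WriteByte(b); err != nil {
            return err
        }
        s.Flush()
        return nil
    }
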
For golang/go#58547 Change-Id: If02b65e6e7cfc770d3f949e5fb9fbb9d8a765a90 Reviewed-on: https://go-review.googlesource.com/c/net/+/564477 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/bench_test.go | 71 +++++++++++++++++++++++++++++++++++++ internal/quic/stream.go | 14 ++++++++ 2 files changed, 85 insertions(+) diff --git a/internal/quic/bench_test.go b/internal/quic/bench_test.go index f883b788c..636b71327 100644 --- a/internal/quic/bench_test.go +++ b/internal/quic/bench_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "math" + "sync" "testing" ) @@ -72,6 +73,76 @@ func throughput(b *testing.B, totalBytes int64) { } } +func BenchmarkReadByte(b *testing.B) { + cli, srv := newLocalConnPair(b, &Config{}, &Config{}) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 1<<20) + sconn, err := srv.AcceptStream(context.Background()) + if err != nil { + panic(fmt.Errorf("AcceptStream: %v", err)) + } + for { + if _, err := sconn.Write(buf); err != nil { + break + } + sconn.Flush() + } + }() + + b.SetBytes(1) + cconn, err := cli.NewStream(context.Background()) + if err != nil { + b.Fatalf("NewStream: %v", err) + } + cconn.Flush() + for i := 0; i < b.N; i++ { + _, err := cconn.ReadByte() + if err != nil { + b.Fatalf("ReadByte: %v", err) + } + } + cconn.Close() +} + +func BenchmarkWriteByte(b *testing.B) { + cli, srv := newLocalConnPair(b, &Config{}, &Config{}) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(1) + go func() { + defer wg.Done() + sconn, err := srv.AcceptStream(context.Background()) + if err != nil { + panic(fmt.Errorf("AcceptStream: %v", err)) + } + n, err := io.Copy(io.Discard, sconn) + if n != int64(b.N) || err != nil { + b.Errorf("server io.Copy() = %v, %v; want %v, nil", n, err, b.N) + } + }() + + b.SetBytes(1) + cconn, err := cli.NewStream(context.Background()) + if err != nil { + b.Fatalf("NewStream: %v", err) + } + cconn.Flush() + for i := 0; i < b.N; i++ { + if err := cconn.WriteByte(0); err != nil { + b.Fatalf("WriteByte: %v", err) + } + } + cconn.Close() +} + func BenchmarkStreamCreation(b *testing.B) { cli, srv := newLocalConnPair(b, &Config{}, &Config{}) diff --git a/internal/quic/stream.go b/internal/quic/stream.go index d0122b951..670b34263 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -245,6 +245,13 @@ func (s *Stream) Read(b []byte) (n int, err error) { return len(b), nil } +// ReadByte reads and returns a single byte from the stream. +func (s *Stream) ReadByte() (byte, error) { + var b [1]byte + _, err := s.Read(b[:]) + return b[0], err +} + // shouldUpdateFlowControl determines whether to send a flow control window update. // // We want to balance keeping the peer well-supplied with flow control with not sending @@ -326,6 +333,13 @@ func (s *Stream) Write(b []byte) (n int, err error) { return n, nil } +// WriteBytes writes a single byte to the stream. +func (s *Stream) WriteByte(c byte) error { + b := [1]byte{c} + _, err := s.Write(b[:]) + return err +} + // Flush flushes data written to the stream. // It does not wait for the peer to acknowledge receipt of the data. // Use Close to wait for the peer's acknowledgement. From cc568eace4e2768d6befe9748ee0f3cd4edd9a10 Mon Sep 17 00:00:00 2001 From: Tobias Klauser Date: Tue, 20 Feb 2024 12:29:30 +0100 Subject: [PATCH 37/70] internal/quic: use slices.Equal in TestAcksSent The module go.mod uses go 1.18 and acks_test.go has a go:build go1.21 tag. 
Change-Id: Ic0785bcb4795bedecc6a752f5e67a967851237e6 Reviewed-on: https://go-review.googlesource.com/c/net/+/565137 Reviewed-by: Than McIntosh Auto-Submit: Tobias Klauser Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/acks_test.go | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/internal/quic/acks_test.go b/internal/quic/acks_test.go index 4f1032910..d10f917ad 100644 --- a/internal/quic/acks_test.go +++ b/internal/quic/acks_test.go @@ -7,6 +7,7 @@ package quic import ( + "slices" "testing" "time" ) @@ -198,7 +199,7 @@ func TestAcksSent(t *testing.T) { if len(gotNums) == 0 { wantDelay = 0 } - if !slicesEqual(gotNums, test.wantAcks) || gotDelay != wantDelay { + if !slices.Equal(gotNums, test.wantAcks) || gotDelay != wantDelay { t.Errorf("acks.acksToSend(T+%v) = %v, %v; want %v, %v", delay, gotNums, gotDelay, test.wantAcks, wantDelay) } } @@ -206,20 +207,6 @@ func TestAcksSent(t *testing.T) { } } -// slicesEqual reports whether two slices are equal. -// Replace this with slices.Equal once the module go.mod is go1.17 or newer. -func slicesEqual[E comparable](s1, s2 []E) bool { - if len(s1) != len(s2) { - return false - } - for i := range s1 { - if s1[i] != s2[i] { - return false - } - } - return true -} - func TestAcksDiscardAfterAck(t *testing.T) { acks := ackState{} now := time.Now() From 08d27e39b9ef291f25ae7e4d34440c8d89d6b7f7 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 27 Nov 2023 09:07:49 -0800 Subject: [PATCH 38/70] quic: fast path for stream reads Keep a reference to the next chunk of bytes available for reading in an unsynchronized buffer. Read and ReadByte calls read from this buffer when possible, avoiding the need to lock the stream. This change makes it unnecessary to wrap a stream in a *bytes.Buffer when making small reads, at the expense of making reads concurrency-unsafe. Since the quic package is a low-level one and this lets us avoid an extra buffer in the HTTP/3 implementation, the tradeoff seems worthwhile. 
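To illustrate the pattern (a minimal, self-contained sketch with placeholder types;
not the package's implementation), small reads drain an unsynchronized buffer and
only fall back to a locked slow path when that buffer is empty:

    package sketch

    import (
        "io"
        "sync"
    )

    // fastReader mirrors the idea: inbuf/inbufoff are read without
    // synchronization, so byte-at-a-time reads normally avoid the mutex.
    type fastReader struct {
        inbuf    []byte // unsynchronized fast-path buffer
        inbufoff int    // bytes of inbuf already consumed

        mu      sync.Mutex
        pending []byte // stands in for the locked receive buffer
    }

    func (r *fastReader) ReadByte() (byte, error) {
        if r.inbufoff < len(r.inbuf) { // fast path: no lock taken
            b := r.inbuf[r.inbufoff]
            r.inbufoff++
            return b, nil
        }
        // Slow path: take the lock and pull the next chunk into the
        // fast-path buffer. As with the stream, concurrent reads are
        // not safe under this scheme.
        r.mu.Lock()
        defer r.mu.Unlock()
        if len(r.pending) == 0 {
            return 0, io.EOF
        }
        r.inbuf, r.inbufoff, r.pending = r.pending, 1, nil
        return r.inbuf[0], nil
    }
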
For golang/go#58547 Change-Id: Ib3ca446311974571c2367295b302f36a6349b00d Reviewed-on: https://go-review.googlesource.com/c/net/+/564495 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn_flow_test.go | 52 ++++++++++++------------- internal/quic/conn_loss_test.go | 20 +++++++--- internal/quic/pipe.go | 23 ++++++++--- internal/quic/stream.go | 67 +++++++++++++++++++++++++++------ internal/quic/stream_test.go | 30 ++++++++++++++- 5 files changed, 143 insertions(+), 49 deletions(-) diff --git a/internal/quic/conn_flow_test.go b/internal/quic/conn_flow_test.go index 8e04e20d9..260684bdb 100644 --- a/internal/quic/conn_flow_test.go +++ b/internal/quic/conn_flow_test.go @@ -17,33 +17,29 @@ func TestConnInflowReturnOnRead(t *testing.T) { }) tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, - data: make([]byte, 64), + data: make([]byte, 8), }) - const readSize = 8 - if n, err := s.Read(make([]byte, readSize)); n != readSize || err != nil { - t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, readSize) - } - tc.wantFrame("available window increases, send a MAX_DATA", - packetType1RTT, debugFrameMaxData{ - max: 64 + readSize, - }) - if n, err := s.Read(make([]byte, 64)); n != 64-readSize || err != nil { - t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 64-readSize) + if n, err := s.Read(make([]byte, 8)); n != 8 || err != nil { + t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 8) } tc.wantFrame("available window increases, send a MAX_DATA", packetType1RTT, debugFrameMaxData{ - max: 128, + max: 64 + 8, }) // Peer can write up to the new limit. tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, - off: 64, + off: 8, data: make([]byte, 64), }) - tc.wantIdle("connection is idle") - if n, err := s.Read(make([]byte, 64)); n != 64 || err != nil { - t.Fatalf("offset 64: s.Read() = %v, %v; want %v, nil", n, err, 64) + if n, err := s.Read(make([]byte, 64+1)); n != 64 { + t.Fatalf("s.Read() = %v, %v; want %v, anything", n, err, 64) } + tc.wantFrame("available window increases, send a MAX_DATA", + packetType1RTT, debugFrameMaxData{ + max: 64 + 8 + 64, + }) + tc.wantIdle("connection is idle") } func TestConnInflowReturnOnRacingReads(t *testing.T) { @@ -63,11 +59,11 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) { tc.ignoreFrame(frameTypeAck) tc.writeFrames(packetType1RTT, debugFrameStream{ id: newStreamID(clientSide, uniStream, 0), - data: make([]byte, 32), + data: make([]byte, 16), }) tc.writeFrames(packetType1RTT, debugFrameStream{ id: newStreamID(clientSide, uniStream, 1), - data: make([]byte, 32), + data: make([]byte, 1), }) s1, err := tc.conn.AcceptStream(ctx) if err != nil { @@ -203,7 +199,6 @@ func TestConnInflowResetViolation(t *testing.T) { } func TestConnInflowMultipleStreams(t *testing.T) { - ctx := canceledContext() tc := newTestConn(t, serverSide, func(c *Config) { c.MaxConnReadBufferSize = 128 }) @@ -219,12 +214,9 @@ func TestConnInflowMultipleStreams(t *testing.T) { } { tc.writeFrames(packetType1RTT, debugFrameStream{ id: id, - data: make([]byte, 32), + data: make([]byte, 1), }) - s, err := tc.conn.AcceptStream(ctx) - if err != nil { - t.Fatalf("AcceptStream() = %v", err) - } + s := tc.acceptStream() streams = append(streams, s) if n, err := s.Read(make([]byte, 1)); err != nil || n != 1 { t.Fatalf("s.Read() = %v, %v; want 1, nil", n, err) @@ -232,8 +224,16 @@ func TestConnInflowMultipleStreams(t *testing.T) { } tc.wantIdle("streams have read data, but not enough to update MAX_DATA") - if n, err := streams[0].Read(make([]byte, 32)); 
err != nil || n != 31 { - t.Fatalf("s.Read() = %v, %v; want 31, nil", n, err) + for _, s := range streams { + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: 1, + data: make([]byte, 31), + }) + } + + if n, err := streams[0].Read(make([]byte, 32)); n != 31 { + t.Fatalf("s.Read() = %v, %v; want 31, anything", n, err) } tc.wantFrame("read enough data to trigger a MAX_DATA update", packetType1RTT, debugFrameMaxData{ diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go index 876ffd093..86ef23db0 100644 --- a/internal/quic/conn_loss_test.go +++ b/internal/quic/conn_loss_test.go @@ -308,9 +308,9 @@ func TestLostMaxDataFrame(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, off: 0, - data: make([]byte, maxWindowSize), + data: make([]byte, maxWindowSize-1), }) - if n, err := s.Read(buf[:maxWindowSize-1]); err != nil || n != maxWindowSize-1 { + if n, err := s.Read(buf[:maxWindowSize]); err != nil || n != maxWindowSize-1 { t.Fatalf("Read() = %v, %v; want %v, nil", n, err, maxWindowSize-1) } tc.wantFrame("conn window is extended after reading data", @@ -319,7 +319,12 @@ func TestLostMaxDataFrame(t *testing.T) { }) // MAX_DATA = 64, which is only one more byte, so we don't send the frame. - if n, err := s.Read(buf); err != nil || n != 1 { + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: maxWindowSize - 1, + data: make([]byte, 1), + }) + if n, err := s.Read(buf[:1]); err != nil || n != 1 { t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1) } tc.wantIdle("read doesn't extend window enough to send another MAX_DATA") @@ -348,9 +353,9 @@ func TestLostMaxStreamDataFrame(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, off: 0, - data: make([]byte, maxWindowSize), + data: make([]byte, maxWindowSize-1), }) - if n, err := s.Read(buf[:maxWindowSize-1]); err != nil || n != maxWindowSize-1 { + if n, err := s.Read(buf[:maxWindowSize]); err != nil || n != maxWindowSize-1 { t.Fatalf("Read() = %v, %v; want %v, nil", n, err, maxWindowSize-1) } tc.wantFrame("stream window is extended after reading data", @@ -360,6 +365,11 @@ func TestLostMaxStreamDataFrame(t *testing.T) { }) // MAX_STREAM_DATA = 64, which is only one more byte, so we don't send the frame. + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: maxWindowSize - 1, + data: make([]byte, 1), + }) if n, err := s.Read(buf); err != nil || n != 1 { t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1) } diff --git a/internal/quic/pipe.go b/internal/quic/pipe.go index d3a448df3..42a0049da 100644 --- a/internal/quic/pipe.go +++ b/internal/quic/pipe.go @@ -17,14 +17,14 @@ import ( // Writing past the end of the window extends it. // Data may be discarded from the start of the pipe, advancing the window. type pipe struct { - start int64 - end int64 - head *pipebuf - tail *pipebuf + start int64 // stream position of first stored byte + end int64 // stream position just past the last stored byte + head *pipebuf // if non-nil, then head.off + len(head.b) > start + tail *pipebuf // if non-nil, then tail.off + len(tail.b) == end } type pipebuf struct { - off int64 + off int64 // stream position of b[0] b []byte next *pipebuf } @@ -111,6 +111,7 @@ func (p *pipe) copy(off int64, b []byte) { // read calls f with the data in [off, off+n) // The data may be provided sequentially across multiple calls to f. +// Note that read (unlike an io.Reader) does not consume the read data. 
func (p *pipe) read(off int64, n int, f func([]byte) error) error { if off < p.start { panic("invalid read range") @@ -135,6 +136,18 @@ func (p *pipe) read(off int64, n int, f func([]byte) error) error { return nil } +// peek returns a reference to up to n bytes of internal data buffer, starting at p.start. +// The returned slice is valid until the next call to discardBefore. +// The length of the returned slice will be in the range [0,n]. +func (p *pipe) peek(n int64) []byte { + pb := p.head + if pb == nil { + return nil + } + b := pb.b[p.start-pb.off:] + return b[:min(int64(len(b)), n)] +} + // discardBefore discards all data prior to off. func (p *pipe) discardBefore(off int64) { for p.head != nil && p.head.end() < off { diff --git a/internal/quic/stream.go b/internal/quic/stream.go index 670b34263..17ca8b7d6 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -23,7 +23,7 @@ type Stream struct { inctx context.Context outctx context.Context - // ingate's lock guards all receive-related state. + // ingate's lock guards receive-related state. // // The gate condition is set if a read from the stream will not block, // either because the stream has available data or because the read will fail. @@ -37,7 +37,7 @@ type Stream struct { inclosed sentVal // set by CloseRead inresetcode int64 // RESET_STREAM code received from the peer; -1 if not reset - // outgate's lock guards all send-related state. + // outgate's lock guards send-related state. // // The gate condition is set if a write to the stream will not block, // either because the stream has available flow control or because @@ -57,6 +57,10 @@ type Stream struct { outresetcode uint64 // reset code to send in RESET_STREAM outdone chan struct{} // closed when all data sent + // Unsynchronized buffers, used for lock-free fast path. + inbuf []byte // received data + inbufoff int // bytes of inbuf which have been consumed + // Atomic stream state bits. // // These bits provide a fast way to coordinate between the @@ -202,16 +206,35 @@ func (s *Stream) IsWriteOnly() bool { // returning all data sent by the peer. // If the peer aborts reads on the stream, Read returns // an error wrapping StreamResetCode. +// +// It is not safe to call Read concurrently. func (s *Stream) Read(b []byte) (n int, err error) { if s.IsWriteOnly() { return 0, errors.New("read from write-only stream") } + if len(s.inbuf) > s.inbufoff { + // Fast path: If s.inbuf contains unread bytes, return them immediately + // without taking a lock. + n = copy(b, s.inbuf[s.inbufoff:]) + s.inbufoff += n + return n, nil + } if err := s.ingate.waitAndLock(s.inctx, s.conn.testHooks); err != nil { return 0, err } + if s.inbufoff > 0 { + // Discard bytes consumed by the fast path above. + s.in.discardBefore(s.in.start + int64(s.inbufoff)) + s.inbufoff = 0 + s.inbuf = nil + } + // bytesRead contains the number of bytes of connection-level flow control to return. + // We return flow control for bytes read by this Read call, as well as bytes moved + // to the fast-path read buffer (s.inbuf). 
+ var bytesRead int64 defer func() { s.inUnlock() - s.conn.handleStreamBytesReadOffLoop(int64(n)) // must be done with ingate unlocked + s.conn.handleStreamBytesReadOffLoop(bytesRead) // must be done with ingate unlocked }() if s.inresetcode != -1 { return 0, fmt.Errorf("stream reset by peer: %w", StreamErrorCode(s.inresetcode)) @@ -229,27 +252,48 @@ func (s *Stream) Read(b []byte) (n int, err error) { if size := int(s.inset[0].end - s.in.start); size < len(b) { b = b[:size] } + bytesRead = int64(len(b)) start := s.in.start end := start + int64(len(b)) s.in.copy(start, b) s.in.discardBefore(end) + if end == s.insize { + // We have read up to the end of the stream. + // No need to update stream flow control. + return len(b), io.EOF + } + if len(s.inset) > 0 && s.inset[0].start <= s.in.start && s.inset[0].end > s.in.start { + // If we have more readable bytes available, put the next chunk of data + // in s.inbuf for lock-free reads. + s.inbuf = s.in.peek(s.inset[0].end - s.in.start) + bytesRead += int64(len(s.inbuf)) + } if s.insize == -1 || s.insize > s.inwin { - if shouldUpdateFlowControl(s.inmaxbuf, s.in.start+s.inmaxbuf-s.inwin) { + newWindow := s.in.start + int64(len(s.inbuf)) + s.inmaxbuf + addedWindow := newWindow - s.inwin + if shouldUpdateFlowControl(s.inmaxbuf, addedWindow) { // Update stream flow control with a STREAM_MAX_DATA frame. s.insendmax.setUnsent() } } - if end == s.insize { - return len(b), io.EOF - } return len(b), nil } // ReadByte reads and returns a single byte from the stream. +// +// It is not safe to call ReadByte concurrently. func (s *Stream) ReadByte() (byte, error) { + if len(s.inbuf) > s.inbufoff { + b := s.inbuf[s.inbufoff] + s.inbufoff++ + return b, nil + } var b [1]byte - _, err := s.Read(b[:]) - return b[0], err + n, err := s.Read(b[:]) + if n > 0 { + return b[0], err + } + return 0, err } // shouldUpdateFlowControl determines whether to send a flow control window update. @@ -507,8 +551,9 @@ func (s *Stream) inUnlock() { // inUnlockNoQueue is inUnlock, // but reports whether s has frames to write rather than notifying the Conn. func (s *Stream) inUnlockNoQueue() streamState { - canRead := s.inset.contains(s.in.start) || // data available to read - s.insize == s.in.start || // at EOF + nextByte := s.in.start + int64(len(s.inbuf)) + canRead := s.inset.contains(nextByte) || // data available to read + s.insize == s.in.start+int64(len(s.inbuf)) || // at EOF s.inresetcode != -1 || // reset by peer s.inclosed.isSet() // closed locally defer s.ingate.unlock(canRead) diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go index 08e89b24c..d1cfb34db 100644 --- a/internal/quic/stream_test.go +++ b/internal/quic/stream_test.go @@ -538,6 +538,32 @@ func TestStreamReceiveDuplicateDataDoesNotViolateLimits(t *testing.T) { }) } +func TestStreamReceiveEmptyEOF(t *testing.T) { + // A stream receives some data, we read a byte of that data + // (causing the rest to be pulled into the s.inbuf buffer), + // and then we receive a FIN with no additional data. 
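+	// The FIN must still be surfaced as io.EOF once the buffered data is drained.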
+ testStreamTypes(t, "", func(t *testing.T, styp streamType) { + tc, s := newTestConnAndRemoteStream(t, serverSide, styp, permissiveTransportParameters) + want := []byte{1, 2, 3} + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + data: want, + }) + if got, err := s.ReadByte(); got != want[0] || err != nil { + t.Fatalf("s.ReadByte() = %v, %v; want %v, nil", got, err, want[0]) + } + + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: 3, + fin: true, + }) + if got, err := io.ReadAll(s); !bytes.Equal(got, want[1:]) || err != nil { + t.Fatalf("io.ReadAll(s) = {%x}, %v; want {%x}, nil", got, err, want[1:]) + } + }) +} + func finalSizeTest(t *testing.T, wantErr transportError, f func(tc *testConn, sid streamID) (finalSize int64), opts ...any) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { for _, test := range []struct { @@ -1156,8 +1182,8 @@ func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { code: sentCode, }) wantErr := StreamErrorCode(sentCode) - if n, err := s.Read(got); n != 0 || !errors.Is(err, wantErr) { - t.Fatalf("Read reset stream: got %v, %v; want 0, %v", n, err, wantErr) + if _, err := io.ReadAll(s); !errors.Is(err, wantErr) { + t.Fatalf("Read reset stream: ReadAll got error %v; want %v", err, wantErr) } }) } From 5e097125fdec6a2b4d9123a57f9551c2b89c7315 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 14 Feb 2024 12:31:31 -0800 Subject: [PATCH 39/70] quic: fast path for stream writes Similar to the fast-path for reads, writes are buffered in an unsynchronized []byte allowing for lock-free small writes. For golang/go#58547 Change-Id: I305cb5f91eff662a473f44a4bc051acc7c213e4c Reviewed-on: https://go-review.googlesource.com/c/net/+/564496 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/pipe.go | 12 +++++++++ internal/quic/stream.go | 50 ++++++++++++++++++++++++++++++++++-- internal/quic/stream_test.go | 3 ++- 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/internal/quic/pipe.go b/internal/quic/pipe.go index 42a0049da..75cf76db2 100644 --- a/internal/quic/pipe.go +++ b/internal/quic/pipe.go @@ -148,6 +148,18 @@ func (p *pipe) peek(n int64) []byte { return b[:min(int64(len(b)), n)] } +// availableBuffer returns the available contiguous, allocated buffer space +// following the pipe window. +// +// This is used by the stream write fast path, which makes multiple writes into the pipe buffer +// without a lock, and then adjusts p.end at a later time with a lock held. +func (p *pipe) availableBuffer() []byte { + if p.tail == nil { + return nil + } + return p.tail.b[p.end-p.tail.off:] +} + // discardBefore discards all data prior to off. func (p *pipe) discardBefore(off int64) { for p.head != nil && p.head.end() < off { diff --git a/internal/quic/stream.go b/internal/quic/stream.go index 17ca8b7d6..c5fafdf1d 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -58,8 +58,10 @@ type Stream struct { outdone chan struct{} // closed when all data sent // Unsynchronized buffers, used for lock-free fast path. - inbuf []byte // received data - inbufoff int // bytes of inbuf which have been consumed + inbuf []byte // received data + inbufoff int // bytes of inbuf which have been consumed + outbuf []byte // written data + outbufoff int // bytes of outbuf which contain data to write // Atomic stream state bits. 
// @@ -313,7 +315,14 @@ func (s *Stream) Write(b []byte) (n int, err error) { if s.IsReadOnly() { return 0, errors.New("write to read-only stream") } + if len(b) > 0 && len(s.outbuf)-s.outbufoff >= len(b) { + // Fast path: The data to write fits in s.outbuf. + copy(s.outbuf[s.outbufoff:], b) + s.outbufoff += len(b) + return len(b), nil + } canWrite := s.outgate.lock() + s.flushFastOutputBuffer() for { // The first time through this loop, we may or may not be write blocked. // We exit the loop after writing all data, so on subsequent passes through @@ -373,17 +382,51 @@ func (s *Stream) Write(b []byte) (n int, err error) { // If we have bytes left to send, we're blocked. canWrite = false } + if lim := s.out.start + s.outmaxbuf - s.out.end - 1; lim > 0 { + // If s.out has space allocated and available to be written into, + // then reference it in s.outbuf for fast-path writes. + // + // It's perhaps a bit pointless to limit s.outbuf to the send buffer limit. + // We've already allocated this buffer so we aren't saving any memory + // by not using it. + // For now, we limit it anyway to make it easier to reason about limits. + // + // We set the limit to one less than the send buffer limit (the -1 above) + // so that a write which completely fills the buffer will overflow + // s.outbuf and trigger a flush. + s.outbuf = s.out.availableBuffer() + if int64(len(s.outbuf)) > lim { + s.outbuf = s.outbuf[:lim] + } + } s.outUnlock() return n, nil } // WriteBytes writes a single byte to the stream. func (s *Stream) WriteByte(c byte) error { + if s.outbufoff < len(s.outbuf) { + s.outbuf[s.outbufoff] = c + s.outbufoff++ + return nil + } b := [1]byte{c} _, err := s.Write(b[:]) return err } +func (s *Stream) flushFastOutputBuffer() { + if s.outbuf == nil { + return + } + // Commit data previously written to s.outbuf. + // s.outbuf is a reference to a buffer in s.out, so we just need to record + // that the output buffer has been extended. + s.out.end += int64(s.outbufoff) + s.outbuf = nil + s.outbufoff = 0 +} + // Flush flushes data written to the stream. // It does not wait for the peer to acknowledge receipt of the data. // Use Close to wait for the peer's acknowledgement. @@ -394,6 +437,7 @@ func (s *Stream) Flush() { } func (s *Stream) flushLocked() { + s.flushFastOutputBuffer() s.outopened.set() if s.outflushed < s.outwin { s.outunsent.add(s.outflushed, min(s.outwin, s.out.end)) @@ -509,6 +553,8 @@ func (s *Stream) resetInternal(code uint64, userClosed bool) { // extra RESET_STREAM in this case is harmless. s.outreset.set() s.outresetcode = code + s.outbuf = nil + s.outbufoff = 0 s.out.discardBefore(s.out.end) s.outunsent = rangeset[int64]{} s.outblocked.clear() diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go index d1cfb34db..9f857f29d 100644 --- a/internal/quic/stream_test.go +++ b/internal/quic/stream_test.go @@ -100,6 +100,7 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { if err != nil { t.Fatalf("write with available output buffer: unexpected error: %v", err) } + s.Flush() tc.wantFrame("write blocked by flow control triggers a STREAM_DATA_BLOCKED frame", packetType1RTT, debugFrameStreamDataBlocked{ id: s.id, @@ -111,6 +112,7 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { if err != nil { t.Fatalf("write with available output buffer: unexpected error: %v", err) } + s.Flush() tc.wantIdle("adding more blocked data does not trigger another STREAM_DATA_BLOCKED") // Provide some flow control window. 
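With the write fast path, small writes may be parked in the unsynchronized output buffer rather than handed to the packet writer immediately, which is why the tests above now call s.Flush() explicitly after Write. A minimal caller-side sketch of that pattern (illustrative only: it assumes package-internal access to *Stream, the helper name sendMessage is invented, and error handling is abbreviated):

    // sendMessage writes msg and then commits it for transmission.
    func sendMessage(s *Stream, msg []byte) error {
        // Write may only fill the unsynchronized output buffer;
        // nothing is necessarily queued for the wire yet.
        if _, err := s.Write(msg); err != nil {
            return err
        }
        // Flush commits the buffered bytes and schedules STREAM frames.
        s.Flush()
        return nil
    }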
@@ -1349,7 +1351,6 @@ func TestStreamFlushImplicitExact(t *testing.T) {
 			id:   s.id,
 			data: want[0:4],
 		})
-
 	})
 }

From 2a8baeab1851a3c0f336a9185a02c177a0365232 Mon Sep 17 00:00:00 2001
From: Damien Neil
Date: Sun, 18 Feb 2024 12:13:12 -0800
Subject: [PATCH 40/70] quic: don't record fin bit as sent when it wasn't

When appendStreamFrame is provided with the last chunk of data for a
stream but doesn't have enough space in the packet to include all the
data, don't incorrectly record the packet as including a FIN bit. We were
correctly sending a STREAM frame with no FIN bit--it's just the sent
packet accounting that was off.

No test, because I can't figure out a scenario where this actually has an
observable effect, since we're always going to send the FIN when the
remaining stream data is sent.

Change-Id: I0ee81273165fcf10a52da76b33d2bf1b9c4f3523
Reviewed-on: https://go-review.googlesource.com/c/net/+/564796
LUCI-TryBot-Result: Go LUCI
Reviewed-by: Jonathan Amsterdam
---
 internal/quic/packet_writer.go | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go
index 85149f607..9ed393502 100644
--- a/internal/quic/packet_writer.go
+++ b/internal/quic/packet_writer.go
@@ -388,11 +388,7 @@ func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin b
 	w.b = appendVarint(w.b, uint64(size))
 	start := len(w.b)
 	w.b = w.b[:start+size]
-	if fin {
-		w.sent.appendAckElicitingFrame(frameTypeStreamBase | streamFinBit)
-	} else {
-		w.sent.appendAckElicitingFrame(frameTypeStreamBase)
-	}
+	w.sent.appendAckElicitingFrame(typ & (frameTypeStreamBase | streamFinBit))
 	w.sent.appendInt(uint64(id))
 	w.sent.appendOffAndSize(off, size)
 	return w.b[start:][:size], true

From a6a24dd292f82221e069bd497ff2a93756f63d20 Mon Sep 17 00:00:00 2001
From: Damien Neil
Date: Thu, 15 Feb 2024 09:52:29 -0800
Subject: [PATCH 41/70] quic: source address and ECN support in the network layer

Make the abstraction over UDP connections higher level, and add support
for setting the source address and ECN bits in sent packets, and
receiving the destination address and ECN bits in received packets.

There is no good way that I can find to identify the source IP address of
packets we send. Look up the destination IP address of the first packet
received on each connection, and use this as the source address for all
future packets we send. This avoids unexpected path migration, where the
address we send from changes without our knowing it.

Reject received packets sent from an unexpected peer address. In the
future, when we support path migration, we will want to relax these
restrictions.

ECN bits may be used to detect network congestion. We don't make use of
them at this time, but this CL adds the necessary UDP layer support to do
so in the future.

This CL also lays the groundwork for using more efficient platform APIs to
send/receive packets in the future. (sendmmsg/recvmmsg/GSO/GRO) These
features require platform-specific APIs. Add support for Darwin and Linux
to start with, with a graceful fallback on other OSs.
For golang/go#58547 Change-Id: I1c97cc0d3e52fff18e724feaaac4a50d3df671bc Reviewed-on: https://go-review.googlesource.com/c/net/+/565255 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn.go | 9 +- internal/quic/conn_recv.go | 38 ++-- internal/quic/conn_send.go | 5 +- internal/quic/conn_test.go | 7 + internal/quic/dgram.go | 23 ++- internal/quic/endpoint.go | 90 +++++----- internal/quic/endpoint_test.go | 26 ++- internal/quic/qlog.go | 6 + internal/quic/retry.go | 17 +- internal/quic/retry_test.go | 4 +- internal/quic/stateless_reset_test.go | 4 +- internal/quic/udp.go | 30 ++++ internal/quic/udp_darwin.go | 13 ++ internal/quic/udp_linux.go | 13 ++ internal/quic/udp_msg.go | 248 ++++++++++++++++++++++++++ internal/quic/udp_other.go | 62 +++++++ internal/quic/udp_test.go | 176 ++++++++++++++++++ 17 files changed, 676 insertions(+), 95 deletions(-) create mode 100644 internal/quic/udp.go create mode 100644 internal/quic/udp_darwin.go create mode 100644 internal/quic/udp_linux.go create mode 100644 internal/quic/udp_msg.go create mode 100644 internal/quic/udp_other.go create mode 100644 internal/quic/udp_test.go diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 020bc81a4..5738b6dbb 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -25,6 +25,7 @@ type Conn struct { config *Config testHooks connTestHooks peerAddr netip.AddrPort + localAddr netip.AddrPort msgc chan any donec chan struct{} // closed when conn loop exits @@ -97,7 +98,7 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip side: side, endpoint: e, config: config, - peerAddr: peerAddr, + peerAddr: unmapAddrPort(peerAddr), msgc: make(chan any, 1), donec: make(chan struct{}), peerAckDelayExponent: -1, @@ -317,7 +318,11 @@ func (c *Conn) loop(now time.Time) { } switch m := m.(type) { case *datagram: - c.handleDatagram(now, m) + if !c.handleDatagram(now, m) { + if c.logEnabled(QLogLevelPacket) { + c.logPacketDropped(m) + } + } m.recycle() case timerEvent: // A connection timer has expired. diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index 1b3219723..c8d70d85c 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -8,17 +8,33 @@ package quic import ( "bytes" - "context" "encoding/binary" "errors" "time" ) -func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { +func (c *Conn) handleDatagram(now time.Time, dgram *datagram) (handled bool) { + if !c.localAddr.IsValid() { + // We don't have any way to tell in the general case what address we're + // sending packets from. Set our address from the destination address of + // the first packet received from the peer. + c.localAddr = dgram.localAddr + } + if dgram.peerAddr.IsValid() && dgram.peerAddr != c.peerAddr { + if c.side == clientSide { + // "If a client receives packets from an unknown server address, + // the client MUST discard these packets." + // https://www.rfc-editor.org/rfc/rfc9000#section-9-6 + return false + } + // We currently don't support connection migration, + // so for now the server also drops packets from an unknown address. + return false + } buf := dgram.b c.loss.datagramReceived(now, len(buf)) if c.isDraining() { - return + return false } for len(buf) > 0 { var n int @@ -28,7 +44,7 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { if c.side == serverSide && len(dgram.b) < paddedInitialDatagramSize { // Discard client-sent Initial packets in too-short datagrams. 
// https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4 - return + return false } n = c.handleLongHeader(now, ptype, initialSpace, c.keysInitial.r, buf) case packetTypeHandshake: @@ -37,10 +53,10 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { n = c.handle1RTT(now, buf) case packetTypeRetry: c.handleRetry(now, buf) - return + return true case packetTypeVersionNegotiation: c.handleVersionNegotiation(now, buf) - return + return true default: n = -1 } @@ -58,20 +74,16 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { var token statelessResetToken copy(token[:], buf[len(buf)-len(token):]) if c.handleStatelessReset(now, token) { - return + return true } } // Invalid data at the end of a datagram is ignored. - if c.logEnabled(QLogLevelPacket) { - c.log.LogAttrs(context.Background(), QLogLevelPacket, - "connectivity:packet_dropped", - ) - } - break + return false } c.idleHandlePacketReceived(now) buf = buf[n:] } + return true } func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int { diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 575b8f9b4..12bcfe308 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -179,7 +179,10 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } } - c.endpoint.sendDatagram(buf, c.peerAddr) + c.endpoint.sendDatagram(datagram{ + b: buf, + peerAddr: c.peerAddr, + }) } } diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index 2d3c946d6..a8f3fc7fd 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -453,6 +453,7 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) { dstConnID: dstConnID, srcConnID: tc.peerConnID, }}, + addr: tc.conn.peerAddr, } if ptype == packetTypeInitial && tc.conn.side == serverSide { d.paddedSize = 1200 @@ -656,6 +657,12 @@ func (tc *testConn) wantPacket(expectation string, want *testPacket) { } func packetEqual(a, b *testPacket) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } ac := *a ac.frames = nil ac.header = 0 diff --git a/internal/quic/dgram.go b/internal/quic/dgram.go index 79e6650fa..615589373 100644 --- a/internal/quic/dgram.go +++ b/internal/quic/dgram.go @@ -12,10 +12,25 @@ import ( ) type datagram struct { - b []byte - addr netip.AddrPort + b []byte + localAddr netip.AddrPort + peerAddr netip.AddrPort + ecn ecnBits } +// Explicit Congestion Notification bits. +// +// https://www.rfc-editor.org/rfc/rfc3168.html#section-5 +type ecnBits byte + +const ( + ecnMask = 0b000000_11 + ecnNotECT = 0b000000_00 + ecnECT1 = 0b000000_01 + ecnECT0 = 0b000000_10 + ecnCE = 0b000000_11 +) + var datagramPool = sync.Pool{ New: func() any { return &datagram{ @@ -26,7 +41,9 @@ var datagramPool = sync.Pool{ func newDatagram() *datagram { m := datagramPool.Get().(*datagram) - m.b = m.b[:cap(m.b)] + *m = datagram{ + b: m.b[:cap(m.b)], + } return m } diff --git a/internal/quic/endpoint.go b/internal/quic/endpoint.go index 8ed67de54..6631708b8 100644 --- a/internal/quic/endpoint.go +++ b/internal/quic/endpoint.go @@ -22,11 +22,11 @@ import ( // // Multiple goroutines may invoke methods on an Endpoint simultaneously. 
type Endpoint struct { - config *Config - udpConn udpConn - testHooks endpointTestHooks - resetGen statelessResetTokenGenerator - retry retryState + config *Config + packetConn packetConn + testHooks endpointTestHooks + resetGen statelessResetTokenGenerator + retry retryState acceptQueue queue[*Conn] // new inbound connections connsMap connsMap // only accessed by the listen loop @@ -42,13 +42,12 @@ type endpointTestHooks interface { newConn(c *Conn) } -// A udpConn is a UDP connection. -// It is implemented by net.UDPConn. -type udpConn interface { +// A packetConn is the interface to sending and receiving UDP packets. +type packetConn interface { Close() error - LocalAddr() net.Addr - ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) - WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) + LocalAddr() netip.AddrPort + Read(f func(*datagram)) + Write(datagram) error } // Listen listens on a local network address. @@ -65,13 +64,17 @@ func Listen(network, address string, config *Config) (*Endpoint, error) { if err != nil { return nil, err } - return newEndpoint(udpConn, config, nil) + pc, err := newNetUDPConn(udpConn) + if err != nil { + return nil, err + } + return newEndpoint(pc, config, nil) } -func newEndpoint(udpConn udpConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { +func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { e := &Endpoint{ config: config, - udpConn: udpConn, + packetConn: pc, testHooks: hooks, conns: make(map[*Conn]struct{}), acceptQueue: newQueue[*Conn](), @@ -90,8 +93,7 @@ func newEndpoint(udpConn udpConn, config *Config, hooks endpointTestHooks) (*End // LocalAddr returns the local network address. func (e *Endpoint) LocalAddr() netip.AddrPort { - a, _ := e.udpConn.LocalAddr().(*net.UDPAddr) - return a.AddrPort() + return e.packetConn.LocalAddr() } // Close closes the Endpoint. @@ -114,7 +116,7 @@ func (e *Endpoint) Close(ctx context.Context) error { conns = append(conns, c) } if len(e.conns) == 0 { - e.udpConn.Close() + e.packetConn.Close() } } e.connsMu.Unlock() @@ -200,34 +202,18 @@ func (e *Endpoint) connDrained(c *Conn) { defer e.connsMu.Unlock() delete(e.conns, c) if e.closing && len(e.conns) == 0 { - e.udpConn.Close() + e.packetConn.Close() } } func (e *Endpoint) listen() { defer close(e.closec) - for { - m := newDatagram() - // TODO: Read and process the ECN (explicit congestion notification) field. - // https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4 - n, _, _, addr, err := e.udpConn.ReadMsgUDPAddrPort(m.b, nil) - if err != nil { - // The user has probably closed the endpoint. - // We currently don't surface errors from other causes; - // we could check to see if the endpoint has been closed and - // record the unexpected error if it has not. - return - } - if n == 0 { - continue - } + e.packetConn.Read(func(m *datagram) { if e.connsMap.updateNeeded.Load() { e.connsMap.applyUpdates() } - m.addr = addr - m.b = m.b[:n] e.handleDatagram(m) - } + }) } func (e *Endpoint) handleDatagram(m *datagram) { @@ -277,7 +263,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { // If this is a 1-RTT packet, there's nothing productive we can do with it. // Send a stateless reset if possible. 
if !isLongHeader(m.b[0]) { - e.maybeSendStatelessReset(m.b, m.addr) + e.maybeSendStatelessReset(m.b, m.peerAddr) return } p, ok := parseGenericLongHeaderPacket(m.b) @@ -291,7 +277,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { return default: // Unknown version. - e.sendVersionNegotiation(p, m.addr) + e.sendVersionNegotiation(p, m.peerAddr) return } if getPacketType(m.b) != packetTypeInitial { @@ -309,7 +295,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { if e.config.RequireAddressValidation { var ok bool cids.retrySrcConnID = p.dstConnID - cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.addr) + cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr) if !ok { return } @@ -317,7 +303,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { cids.originalDstConnID = p.dstConnID } var err error - c, err := e.newConn(now, serverSide, cids, m.addr) + c, err := e.newConn(now, serverSide, cids, m.peerAddr) if err != nil { // The accept queue is probably full. // We could send a CONNECTION_CLOSE to the peer to reject the connection. @@ -329,7 +315,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { m = nil // don't recycle, sendMsg takes ownership } -func (e *Endpoint) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { +func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) { if !e.resetGen.canReset { // Config.StatelessResetKey isn't set, so we don't send stateless resets. return @@ -370,17 +356,21 @@ func (e *Endpoint) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { b[0] &^= headerFormLong // clear long header bit b[0] |= fixedBit // set fixed bit copy(b[len(b)-statelessResetTokenLen:], token[:]) - e.sendDatagram(b, addr) + e.sendDatagram(datagram{ + b: b, + peerAddr: peerAddr, + }) } -func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) { +func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) { m := newDatagram() m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1) - e.sendDatagram(m.b, addr) + m.peerAddr = peerAddr + e.sendDatagram(*m) m.recycle() } -func (e *Endpoint) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) { +func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) { keys := initialKeys(in.dstConnID, serverSide) var w packetWriter p := longPacket{ @@ -399,12 +389,14 @@ func (e *Endpoint) sendConnectionClose(in genericLongPacket, addr netip.AddrPort if len(buf) == 0 { return } - e.sendDatagram(buf, addr) + e.sendDatagram(datagram{ + b: buf, + peerAddr: peerAddr, + }) } -func (e *Endpoint) sendDatagram(p []byte, addr netip.AddrPort) error { - _, err := e.udpConn.WriteToUDPAddrPort(p, addr) - return err +func (e *Endpoint) sendDatagram(dgram datagram) error { + return e.packetConn.Write(dgram) } // A connsMap is an endpoint's mapping of conn ids and reset tokens to conns. 
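The packetConn interface above is the seam between the endpoint and the platform-specific UDP code added later in this CL, and it is also what the test endpoint below implements. A rough in-memory sketch of an implementation, assuming it lives inside the package so it can use the unexported datagram type and imports net/netip (the type and field names here are invented for illustration):

    // chanPacketConn delivers datagrams over channels instead of a socket.
    type chanPacketConn struct {
        local netip.AddrPort
        recvc chan *datagram // datagrams to hand to Read's callback
        sentc chan datagram  // datagrams passed to Write
        donec chan struct{}  // closed by Close
    }

    func (c *chanPacketConn) LocalAddr() netip.AddrPort { return c.local }

    func (c *chanPacketConn) Close() error {
        close(c.donec)
        return nil
    }

    // Read blocks, invoking f for each received datagram until Close is called,
    // matching the callback contract of packetConn.Read.
    func (c *chanPacketConn) Read(f func(*datagram)) {
        for {
            select {
            case d := <-c.recvc:
                f(d)
            case <-c.donec:
                return
            }
        }
    }

    func (c *chanPacketConn) Write(dgram datagram) error {
        c.sentc <- dgram
        return nil
    }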
diff --git a/internal/quic/endpoint_test.go b/internal/quic/endpoint_test.go index 6d103f061..b9fb55fb3 100644 --- a/internal/quic/endpoint_test.go +++ b/internal/quic/endpoint_test.go @@ -12,7 +12,6 @@ import ( "crypto/tls" "io" "log/slog" - "net" "net/netip" "testing" "time" @@ -190,13 +189,9 @@ func (te *testEndpoint) writeDatagram(d *testDatagram) { for len(buf) < d.paddedSize { buf = append(buf, 0) } - addr := d.addr - if !addr.IsValid() { - addr = testClientAddr - } te.write(&datagram{ - b: buf, - addr: addr, + b: buf, + peerAddr: d.addr, }) } @@ -303,25 +298,24 @@ func (te *testEndpointUDPConn) Close() error { return nil } -func (te *testEndpointUDPConn) LocalAddr() net.Addr { - return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443")) +func (te *testEndpointUDPConn) LocalAddr() netip.AddrPort { + return netip.MustParseAddrPort("127.0.0.1:443") } -func (te *testEndpointUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) { +func (te *testEndpointUDPConn) Read(f func(*datagram)) { for { select { case d, ok := <-te.recvc: if !ok { - return 0, 0, 0, netip.AddrPort{}, io.EOF + return } - n = copy(b, d.b) - return n, 0, 0, d.addr, nil + f(d) case <-te.idlec: } } } -func (te *testEndpointUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), b...)) - return len(b), nil +func (te *testEndpointUDPConn) Write(dgram datagram) error { + te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), dgram.b...)) + return nil } diff --git a/internal/quic/qlog.go b/internal/quic/qlog.go index e37e2f8ce..36831252c 100644 --- a/internal/quic/qlog.go +++ b/internal/quic/qlog.go @@ -151,6 +151,12 @@ func (c *Conn) logConnectionClosed() { ) } +func (c *Conn) logPacketDropped(dgram *datagram) { + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "connectivity:packet_dropped", + ) +} + func (c *Conn) logLongPacketReceived(p longPacket, pkt []byte) { var frames slog.Attr if c.logEnabled(QLogLevelFrame) { diff --git a/internal/quic/retry.go b/internal/quic/retry.go index 31cb57b88..5dc39d1d9 100644 --- a/internal/quic/retry.go +++ b/internal/quic/retry.go @@ -139,7 +139,7 @@ func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []by return additional } -func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, addr netip.AddrPort) (origDstConnID []byte, ok bool) { +func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) (origDstConnID []byte, ok bool) { // The retry token is at the start of an Initial packet's data. token, n := consumeUint8Bytes(p.data) if n < 0 { @@ -151,22 +151,22 @@ func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, ad if len(token) == 0 { // The sender has not provided a token. // Send a Retry packet to them with one. - e.sendRetry(now, p, addr) + e.sendRetry(now, p, peerAddr) return nil, false } - origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, addr) + origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, peerAddr) if !ok { // This does not seem to be a valid token. // Close the connection with an INVALID_TOKEN error. 
// https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5 - e.sendConnectionClose(p, addr, errInvalidToken) + e.sendConnectionClose(p, peerAddr, errInvalidToken) return nil, false } return origDstConnID, true } -func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, addr netip.AddrPort) { - token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, addr) +func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) { + token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, peerAddr) if err != nil { return } @@ -175,7 +175,10 @@ func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, addr netip.Addr srcConnID: srcConnID, token: token, }) - e.sendDatagram(b, addr) + e.sendDatagram(datagram{ + b: b, + peerAddr: peerAddr, + }) } type retryPacket struct { diff --git a/internal/quic/retry_test.go b/internal/quic/retry_test.go index 8f36e1bd3..42f2bdd4a 100644 --- a/internal/quic/retry_test.go +++ b/internal/quic/retry_test.go @@ -436,8 +436,8 @@ func TestRetryClientIgnoresRetryWithInvalidIntegrityTag(t *testing.T) { }) pkt[len(pkt)-1] ^= 1 // invalidate the integrity tag tc.endpoint.write(&datagram{ - b: pkt, - addr: testClientAddr, + b: pkt, + peerAddr: testClientAddr, }) tc.wantIdle("client ignores Retry with invalid integrity tag") } diff --git a/internal/quic/stateless_reset_test.go b/internal/quic/stateless_reset_test.go index 45a49e81e..9458d2ea9 100644 --- a/internal/quic/stateless_reset_test.go +++ b/internal/quic/stateless_reset_test.go @@ -57,8 +57,8 @@ func newDatagramForReset(cid []byte, size int, addr netip.AddrPort) *datagram { dgram = append(dgram, byte(len(dgram))) // semi-random junk } return &datagram{ - b: dgram, - addr: addr, + b: dgram, + peerAddr: addr, } } diff --git a/internal/quic/udp.go b/internal/quic/udp.go new file mode 100644 index 000000000..0a578286b --- /dev/null +++ b/internal/quic/udp.go @@ -0,0 +1,30 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import "net/netip" + +// Per-plaform consts describing support for various features. +// +// const udpECNSupport indicates whether the platform supports setting +// the ECN (Explicit Congestion Notification) IP header bits. +// +// const udpInvalidLocalAddrIsError indicates whether sending a packet +// from an local address not associated with the system is an error. +// For example, assuming 127.0.0.2 is not a local address, does sending +// from it (using IP_PKTINFO or some other such feature) result in an error? + +// unmapAddrPort returns a with any IPv4-mapped IPv6 address prefix removed. +func unmapAddrPort(a netip.AddrPort) netip.AddrPort { + if a.Addr().Is4In6() { + return netip.AddrPortFrom( + a.Addr().Unmap(), + a.Port(), + ) + } + return a +} diff --git a/internal/quic/udp_darwin.go b/internal/quic/udp_darwin.go new file mode 100644 index 000000000..3868a36a8 --- /dev/null +++ b/internal/quic/udp_darwin.go @@ -0,0 +1,13 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 && darwin + +package quic + +// See udp.go. 
+const ( + udpECNSupport = true + udpInvalidLocalAddrIsError = true +) diff --git a/internal/quic/udp_linux.go b/internal/quic/udp_linux.go new file mode 100644 index 000000000..2ba3e6f2f --- /dev/null +++ b/internal/quic/udp_linux.go @@ -0,0 +1,13 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 && linux + +package quic + +// See udp.go. +const ( + udpECNSupport = true + udpInvalidLocalAddrIsError = false +) diff --git a/internal/quic/udp_msg.go b/internal/quic/udp_msg.go new file mode 100644 index 000000000..bdc1b710d --- /dev/null +++ b/internal/quic/udp_msg.go @@ -0,0 +1,248 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 && !quicbasicnet && (darwin || linux) + +package quic + +import ( + "net" + "net/netip" + "sync" + "unsafe" + + "golang.org/x/sys/unix" +) + +// Network interface for platforms using sendmsg/recvmsg with cmsgs. + +type netUDPConn struct { + c *net.UDPConn + localAddr netip.AddrPort +} + +func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) { + a, _ := uc.LocalAddr().(*net.UDPAddr) + localAddr := a.AddrPort() + if localAddr.Addr().IsUnspecified() { + // If the conn is not bound to a specified (non-wildcard) address, + // then set localAddr.Addr to an invalid netip.Addr. + // This better conveys that this is not an address we should be using, + // and is a bit more efficient to test against. + localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port()) + } + + sc, err := uc.SyscallConn() + if err != nil { + return nil, err + } + sc.Control(func(fd uintptr) { + // Ask for ECN info and (when we aren't bound to a fixed local address) + // destination info. + // + // If any of these calls fail, we won't get the requested information. + // That's fine, we'll gracefully handle the lack. + unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1) + unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) + if !localAddr.IsValid() { + unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) + unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) + } + }) + + return &netUDPConn{ + c: uc, + localAddr: localAddr, + }, nil +} + +func (c *netUDPConn) Close() error { return c.c.Close() } + +func (c *netUDPConn) LocalAddr() netip.AddrPort { + a, _ := c.c.LocalAddr().(*net.UDPAddr) + return a.AddrPort() +} + +func (c *netUDPConn) Read(f func(*datagram)) { + // We shouldn't ever see all of these messages at the same time, + // but the total is small so just allocate enough space for everything we use. 
+ const ( + inPktinfoSize = 12 // int + in_addr + in_addr + in6PktinfoSize = 20 // in6_addr + int + ipTOSSize = 4 + ipv6TclassSize = 4 + ) + control := make([]byte, 0+ + unix.CmsgSpace(inPktinfoSize)+ + unix.CmsgSpace(in6PktinfoSize)+ + unix.CmsgSpace(ipTOSSize)+ + unix.CmsgSpace(ipv6TclassSize)) + + for { + d := newDatagram() + n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control) + if err != nil { + return + } + if n == 0 { + continue + } + d.localAddr = c.localAddr + d.peerAddr = unmapAddrPort(peerAddr) + d.b = d.b[:n] + parseControl(d, control[:controlLen]) + f(d) + } +} + +var cmsgPool = sync.Pool{ + New: func() any { + return new([]byte) + }, +} + +func (c *netUDPConn) Write(dgram datagram) error { + controlp := cmsgPool.Get().(*[]byte) + control := *controlp + defer func() { + *controlp = control[:0] + cmsgPool.Put(controlp) + }() + + localIP := dgram.localAddr.Addr() + if localIP.IsValid() { + if localIP.Is4() { + control = appendCmsgIPSourceAddrV4(control, localIP) + } else { + control = appendCmsgIPSourceAddrV6(control, localIP) + } + } + if dgram.ecn != ecnNotECT { + if dgram.peerAddr.Addr().Is4() { + control = appendCmsgECNv4(control, dgram.ecn) + } else { + control = appendCmsgECNv6(control, dgram.ecn) + } + } + + _, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, control, dgram.peerAddr) + return err +} + +func parseControl(d *datagram, control []byte) { + for len(control) > 0 { + hdr, data, remainder, err := unix.ParseOneSocketControlMessage(control) + if err != nil { + return + } + control = remainder + switch hdr.Level { + case unix.IPPROTO_IP: + switch hdr.Type { + case unix.IP_TOS, unix.IP_RECVTOS: + // Single byte containing the IP TOS field. + // The low two bits are the ECN field. + // + // (Linux sets the type to IP_TOS, Darwin to IP_RECVTOS, + // jus check for both.) + if len(data) < 1 { + break + } + d.ecn = ecnBits(data[0] & ecnMask) + case unix.IP_PKTINFO: + if a, ok := parseInPktinfo(data); ok { + d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) + } + } + case unix.IPPROTO_IPV6: + switch hdr.Type { + case unix.IPV6_TCLASS: + // Single byte containing the traffic class field. + // The low two bits are the ECN field. + if len(data) < 1 { + break + } + d.ecn = ecnBits(data[0] & ecnMask) + case unix.IPV6_PKTINFO: + if a, ok := parseIn6Pktinfo(data); ok { + d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) + } + } + } + } +} + +func parseInPktinfo(b []byte) (netip.Addr, bool) { + // struct in_pktinfo { + // unsigned int ipi_ifindex; /* send/recv interface index */ + // struct in_addr ipi_spec_dst; /* Local address */ + // struct in_addr ipi_addr; /* IP Header dst address */ + // }; + if len(b) != 12 { + return netip.Addr{}, false + } + return netip.AddrFrom4([4]byte(b[8:][:4])), true +} + +func parseIn6Pktinfo(b []byte) (netip.Addr, bool) { + // struct in6_pktinfo { + // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ + // unsigned int ipi6_ifindex; /* send/recv interface index */ + // }; + if len(b) != 20 { + return netip.Addr{}, false + } + return netip.AddrFrom16([16]byte(b[:16])).Unmap(), true +} + +// appendCmsgIPSourceAddrV4 appends an IP_PKTINFO setting the source address +// for an outbound datagram. 
+func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte { + // struct in_pktinfo { + // unsigned int ipi_ifindex; /* send/recv interface index */ + // struct in_addr ipi_spec_dst; /* Local address */ + // struct in_addr ipi_addr; /* IP Header dst address */ + // }; + b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_PKTINFO, 12) + ip := src.As4() + copy(data[4:], ip[:]) + return b +} + +// appendCmsgIPSourceAddrV6 appends an IP_PKTINFO or IPV6_PKTINFO +// setting the source address for an outbound datagram. +func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte { + // struct in6_pktinfo { + // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ + // unsigned int ipi6_ifindex; /* send/recv interface index */ + // }; + b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_PKTINFO, 20) + ip := src.As16() + copy(data[0:], ip[:]) + return b +} + +func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 4) + data[0] = byte(ecn) + return b +} + +func appendCmsgECNv6(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 4) + data[0] = byte(ecn) + return b +} + +// appendCmsg appends a cmsg with the given level, type, and size to b. +// It returns the new buffer, and the data section of the cmsg. +func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) { + off := len(b) + b = append(b, make([]byte, unix.CmsgSpace(size))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[off])) + h.Level = level + h.Type = typ + h.SetLen(unix.CmsgLen(size)) + return b, b[off+unix.CmsgSpace(0):][:size] +} diff --git a/internal/quic/udp_other.go b/internal/quic/udp_other.go new file mode 100644 index 000000000..28be6d200 --- /dev/null +++ b/internal/quic/udp_other.go @@ -0,0 +1,62 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 && (quicbasicnet || !(darwin || linux)) + +package quic + +import ( + "net" + "net/netip" +) + +// Lowest common denominator network interface: Basic net.UDPConn, no cmsgs. +// We will not be able to send or receive ECN bits, +// and we will not know what our local address is. +// +// The quicbasicnet build tag allows selecting this interface on any platform. + +// See udp.go. +const ( + udpECNSupport = false + udpInvalidLocalAddrIsError = false +) + +type netUDPConn struct { + c *net.UDPConn +} + +func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) { + return &netUDPConn{ + c: uc, + }, nil +} + +func (c *netUDPConn) Close() error { return c.c.Close() } + +func (c *netUDPConn) LocalAddr() netip.AddrPort { + a, _ := c.c.LocalAddr().(*net.UDPAddr) + return a.AddrPort() +} + +func (c *netUDPConn) Read(f func(*datagram)) { + for { + dgram := newDatagram() + n, _, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(dgram.b, nil) + if err != nil { + return + } + if n == 0 { + continue + } + dgram.peerAddr = unmapAddrPort(peerAddr) + dgram.b = dgram.b[:n] + f(dgram) + } +} + +func (c *netUDPConn) Write(dgram datagram) error { + _, err := c.c.WriteToUDPAddrPort(dgram.b, dgram.peerAddr) + return err +} diff --git a/internal/quic/udp_test.go b/internal/quic/udp_test.go new file mode 100644 index 000000000..27eddf811 --- /dev/null +++ b/internal/quic/udp_test.go @@ -0,0 +1,176 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build go1.21 + +package quic + +import ( + "bytes" + "fmt" + "net" + "net/netip" + "runtime" + "testing" +) + +func TestUDPSourceUnspecified(t *testing.T) { + // Send datagram with no source address set. + runUDPTest(t, func(t *testing.T, test udpTest) { + data := []byte("source unspecified") + if err := test.src.Write(datagram{ + b: data, + peerAddr: test.dstAddr, + }); err != nil { + t.Fatalf("Write: %v", err) + } + got := <-test.dgramc + if !bytes.Equal(got.b, data) { + t.Errorf("got datagram {%x}, want {%x}", got.b, data) + } + }) +} + +func TestUDPSourceSpecified(t *testing.T) { + // Send datagram with source address set. + runUDPTest(t, func(t *testing.T, test udpTest) { + data := []byte("source specified") + if err := test.src.Write(datagram{ + b: data, + peerAddr: test.dstAddr, + localAddr: test.src.LocalAddr(), + }); err != nil { + t.Fatalf("Write: %v", err) + } + got := <-test.dgramc + if !bytes.Equal(got.b, data) { + t.Errorf("got datagram {%x}, want {%x}", got.b, data) + } + }) +} + +func TestUDPSourceInvalid(t *testing.T) { + // Send datagram with source address set to an address not associated with the connection. + if !udpInvalidLocalAddrIsError { + t.Skipf("%v: sending from invalid source succeeds", runtime.GOOS) + } + runUDPTest(t, func(t *testing.T, test udpTest) { + var localAddr netip.AddrPort + if test.src.LocalAddr().Addr().Is4() { + localAddr = netip.MustParseAddrPort("127.0.0.2:1234") + } else { + localAddr = netip.MustParseAddrPort("[::2]:1234") + } + data := []byte("source invalid") + if err := test.src.Write(datagram{ + b: data, + peerAddr: test.dstAddr, + localAddr: localAddr, + }); err == nil { + t.Errorf("Write with invalid localAddr succeeded; want error") + } + }) +} + +func TestUDPECN(t *testing.T) { + if !udpECNSupport { + t.Skipf("%v: no ECN support", runtime.GOOS) + } + // Send datagrams with ECN bits set, verify the ECN bits are received. + runUDPTest(t, func(t *testing.T, test udpTest) { + for _, ecn := range []ecnBits{ecnNotECT, ecnECT1, ecnECT0, ecnCE} { + if err := test.src.Write(datagram{ + b: []byte{1, 2, 3, 4}, + peerAddr: test.dstAddr, + ecn: ecn, + }); err != nil { + t.Fatalf("Write: %v", err) + } + got := <-test.dgramc + if got.ecn != ecn { + t.Errorf("sending ECN bits %x, got %x", ecn, got.ecn) + } + } + }) +} + +type udpTest struct { + src *netUDPConn + dst *netUDPConn + dstAddr netip.AddrPort + dgramc chan *datagram +} + +// runUDPTest calls f with a pair of UDPConns in a matrix of network variations: +// udp, udp4, and udp6, and variations on binding to an unspecified address (0.0.0.0) +// or a specified one. +func runUDPTest(t *testing.T, f func(t *testing.T, u udpTest)) { + for _, test := range []struct { + srcNet, srcAddr, dstNet, dstAddr string + }{ + {"udp4", "127.0.0.1", "udp", ""}, + {"udp4", "127.0.0.1", "udp4", ""}, + {"udp4", "127.0.0.1", "udp4", "127.0.0.1"}, + {"udp6", "::1", "udp", ""}, + {"udp6", "::1", "udp6", ""}, + {"udp6", "::1", "udp6", "::1"}, + } { + spec := "spec" + if test.dstAddr == "" { + spec = "unspec" + } + t.Run(fmt.Sprintf("%v/%v/%v", test.srcNet, test.dstNet, spec), func(t *testing.T) { + srcAddr := netip.AddrPortFrom(netip.MustParseAddr(test.srcAddr), 0) + srcConn, err := net.ListenUDP(test.srcNet, net.UDPAddrFromAddrPort(srcAddr)) + if err != nil { + // If ListenUDP fails here, we presumably don't have + // IPv4/IPv6 configured. 
+ t.Skipf("ListenUDP(%q, %v) = %v", test.srcNet, srcAddr, err) + } + t.Cleanup(func() { srcConn.Close() }) + src, err := newNetUDPConn(srcConn) + if err != nil { + t.Fatalf("newNetUDPConn: %v", err) + } + + var dstAddr netip.AddrPort + if test.dstAddr != "" { + dstAddr = netip.AddrPortFrom(netip.MustParseAddr(test.dstAddr), 0) + } + dstConn, err := net.ListenUDP(test.dstNet, net.UDPAddrFromAddrPort(dstAddr)) + if err != nil { + t.Skipf("ListenUDP(%q, nil) = %v", test.dstNet, err) + } + dst, err := newNetUDPConn(dstConn) + if err != nil { + dstConn.Close() + t.Fatalf("newNetUDPConn: %v", err) + } + + dgramc := make(chan *datagram) + go func() { + defer close(dgramc) + dst.Read(func(dgram *datagram) { + dgramc <- dgram + }) + }() + t.Cleanup(func() { + dstConn.Close() + for range dgramc { + t.Errorf("test read unexpected datagram") + } + }) + + f(t, udpTest{ + src: src, + dst: dst, + dstAddr: netip.AddrPortFrom( + srcAddr.Addr(), + dst.LocalAddr().Port(), + ), + dgramc: dgramc, + }) + }) + } +} From 57e4cc7d885a72ee5111b227ba5790ea5a170656 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 28 Nov 2023 15:31:58 -0800 Subject: [PATCH 42/70] quic: handle PATH_CHALLENGE and PATH_RESPONSE frames We do not support path migration yet, and will ignore packets sent from anything other than the peer's original address. Handle PATH_CHALLENGE frames by sending a PATH_RESPONSE. Handle PATH_RESPONSE frames by closing the connection (since we never send a challenge to respond to). For golang/go#58547 Change-Id: I828b9dcb23e17f5edf3d605b8f04efdafb392807 Reviewed-on: https://go-review.googlesource.com/c/net/+/565795 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn.go | 1 + internal/quic/conn_loss_test.go | 23 ++++++++ internal/quic/conn_recv.go | 44 ++++++++++++--- internal/quic/conn_send.go | 7 +++ internal/quic/conn_test.go | 2 + internal/quic/frame_debug.go | 17 ++++-- internal/quic/packet_codec_test.go | 4 +- internal/quic/packet_parser.go | 11 ++-- internal/quic/packet_writer.go | 17 +++--- internal/quic/path.go | 89 ++++++++++++++++++++++++++++++ internal/quic/path_test.go | 66 ++++++++++++++++++++++ internal/quic/sent_packet.go | 6 ++ 12 files changed, 255 insertions(+), 32 deletions(-) create mode 100644 internal/quic/path.go create mode 100644 internal/quic/path_test.go diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 5738b6dbb..d462e9617 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -37,6 +37,7 @@ type Conn struct { connIDState connIDState loss lossState streams streamsState + path pathState // Packet protection keys, CRYPTO streams, and TLS state. keysInitial fixedKeyPair diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go index 86ef23db0..81d537803 100644 --- a/internal/quic/conn_loss_test.go +++ b/internal/quic/conn_loss_test.go @@ -663,6 +663,29 @@ func TestLostRetireConnectionIDFrame(t *testing.T) { }) } +func TestLostPathResponseFrame(t *testing.T) { + // "Responses to path validation using PATH_RESPONSE frames are sent just once." 
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.12 + lostFrameTest(t, func(t *testing.T, pto bool) { + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypePing) + + data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + tc.writeFrames(packetType1RTT, debugFramePathChallenge{ + data: data, + }) + tc.wantFrame("response to PATH_CHALLENGE", + packetType1RTT, debugFramePathResponse{ + data: data, + }) + + tc.triggerLossOrPTO(packetType1RTT, pto) + tc.wantIdle("lost PATH_RESPONSE frame is not retransmitted") + }) +} + func TestLostHandshakeDoneFrame(t *testing.T) { // "The HANDSHAKE_DONE frame MUST be retransmitted until it is acknowledged." // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.16 diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index c8d70d85c..b1354cd3a 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -46,11 +46,11 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) (handled bool) { // https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4 return false } - n = c.handleLongHeader(now, ptype, initialSpace, c.keysInitial.r, buf) + n = c.handleLongHeader(now, dgram, ptype, initialSpace, c.keysInitial.r, buf) case packetTypeHandshake: - n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf) + n = c.handleLongHeader(now, dgram, ptype, handshakeSpace, c.keysHandshake.r, buf) case packetType1RTT: - n = c.handle1RTT(now, buf) + n = c.handle1RTT(now, dgram, buf) case packetTypeRetry: c.handleRetry(now, buf) return true @@ -86,7 +86,7 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) (handled bool) { return true } -func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int { +func (c *Conn) handleLongHeader(now time.Time, dgram *datagram, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int { if !k.isSet() { return skipLongHeaderPacket(buf) } @@ -125,7 +125,7 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa c.logLongPacketReceived(p, buf[:n]) } c.connIDState.handlePacket(c, p.ptype, p.srcConnID) - ackEliciting := c.handleFrames(now, ptype, space, p.payload) + ackEliciting := c.handleFrames(now, dgram, ptype, space, p.payload) c.acks[space].receive(now, space, p.num, ackEliciting) if p.ptype == packetTypeHandshake && c.side == serverSide { c.loss.validateClientAddress() @@ -138,7 +138,7 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa return n } -func (c *Conn) handle1RTT(now time.Time, buf []byte) int { +func (c *Conn) handle1RTT(now time.Time, dgram *datagram, buf []byte) int { if !c.keysAppData.canRead() { // 1-RTT packets extend to the end of the datagram, // so skip the remainder of the datagram if we can't parse this. 
@@ -175,7 +175,7 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int { if c.logEnabled(QLogLevelPacket) { c.log1RTTPacketReceived(p, buf) } - ackEliciting := c.handleFrames(now, packetType1RTT, appDataSpace, p.payload) + ackEliciting := c.handleFrames(now, dgram, packetType1RTT, appDataSpace, p.payload) c.acks[appDataSpace].receive(now, appDataSpace, p.num, ackEliciting) return len(buf) } @@ -252,7 +252,7 @@ func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) { c.abortImmediately(now, errVersionNegotiation) } -func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) { +func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) { if len(payload) == 0 { // "An endpoint MUST treat receipt of a packet containing no frames // as a connection error of type PROTOCOL_VIOLATION." @@ -373,6 +373,16 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, return } n = c.handleRetireConnectionIDFrame(now, space, payload) + case frameTypePathChallenge: + if !frameOK(c, ptype, __01) { + return + } + n = c.handlePathChallengeFrame(now, dgram, space, payload) + case frameTypePathResponse: + if !frameOK(c, ptype, ___1) { + return + } + n = c.handlePathResponseFrame(now, space, payload) case frameTypeConnectionCloseTransport: // Transport CONNECTION_CLOSE is OK in all spaces. n = c.handleConnectionCloseTransportFrame(now, payload) @@ -546,6 +556,24 @@ func (c *Conn) handleRetireConnectionIDFrame(now time.Time, space numberSpace, p return n } +func (c *Conn) handlePathChallengeFrame(now time.Time, dgram *datagram, space numberSpace, payload []byte) int { + data, n := consumePathChallengeFrame(payload) + if n < 0 { + return -1 + } + c.handlePathChallenge(now, dgram, data) + return n +} + +func (c *Conn) handlePathResponseFrame(now time.Time, space numberSpace, payload []byte) int { + data, n := consumePathResponseFrame(payload) + if n < 0 { + return -1 + } + c.handlePathResponse(now, data) + return n +} + func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte) int { code, _, reason, n := consumeConnectionCloseTransportFrame(payload) if n < 0 { diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 12bcfe308..a87cac232 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -271,6 +271,13 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, return } + // PATH_RESPONSE + if pad, ok := c.appendPathFrames(); !ok { + return + } else if pad { + defer c.w.appendPaddingTo(smallestMaxDatagramSize) + } + // All stream-related frames. This should come last in the packet, // so large amounts of STREAM data don't crowd out other frames // we may need to send. 
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index a8f3fc7fd..16ee3cf2f 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -168,6 +168,7 @@ type testConn struct { sentDatagrams [][]byte sentPackets []*testPacket sentFrames []debugFrame + lastDatagram *testDatagram lastPacket *testPacket recvDatagram chan *datagram @@ -576,6 +577,7 @@ func (tc *testConn) readDatagram() *testDatagram { } p.frames = frames } + tc.lastDatagram = d return d } diff --git a/internal/quic/frame_debug.go b/internal/quic/frame_debug.go index 0902c385f..17234dd7c 100644 --- a/internal/quic/frame_debug.go +++ b/internal/quic/frame_debug.go @@ -77,6 +77,7 @@ func parseDebugFrame(b []byte) (f debugFrame, n int) { // debugFramePadding is a sequence of PADDING frames. type debugFramePadding struct { size int + to int // alternate for writing packets: pad to } func parseDebugFramePadding(b []byte) (f debugFramePadding, n int) { @@ -95,6 +96,10 @@ func (f debugFramePadding) write(w *packetWriter) bool { if w.avail() == 0 { return false } + if f.to > 0 { + w.appendPaddingTo(f.to) + return true + } for i := 0; i < f.size && w.avail() > 0; i++ { w.b = append(w.b, frameTypePadding) } @@ -584,7 +589,7 @@ func (f debugFrameRetireConnectionID) LogValue() slog.Value { // debugFramePathChallenge is a PATH_CHALLENGE frame. type debugFramePathChallenge struct { - data uint64 + data pathChallengeData } func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) { @@ -593,7 +598,7 @@ func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) { } func (f debugFramePathChallenge) String() string { - return fmt.Sprintf("PATH_CHALLENGE Data=%016x", f.data) + return fmt.Sprintf("PATH_CHALLENGE Data=%x", f.data) } func (f debugFramePathChallenge) write(w *packetWriter) bool { @@ -603,13 +608,13 @@ func (f debugFramePathChallenge) write(w *packetWriter) bool { func (f debugFramePathChallenge) LogValue() slog.Value { return slog.GroupValue( slog.String("frame_type", "path_challenge"), - slog.String("data", fmt.Sprintf("%016x", f.data)), + slog.String("data", fmt.Sprintf("%x", f.data)), ) } // debugFramePathResponse is a PATH_RESPONSE frame. 
type debugFramePathResponse struct { - data uint64 + data pathChallengeData } func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) { @@ -618,7 +623,7 @@ func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) { } func (f debugFramePathResponse) String() string { - return fmt.Sprintf("PATH_RESPONSE Data=%016x", f.data) + return fmt.Sprintf("PATH_RESPONSE Data=%x", f.data) } func (f debugFramePathResponse) write(w *packetWriter) bool { @@ -628,7 +633,7 @@ func (f debugFramePathResponse) write(w *packetWriter) bool { func (f debugFramePathResponse) LogValue() slog.Value { return slog.GroupValue( slog.String("frame_type", "path_response"), - slog.String("data", fmt.Sprintf("%016x", f.data)), + slog.String("data", fmt.Sprintf("%x", f.data)), ) } diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go index 475e18c1d..98b3bbb05 100644 --- a/internal/quic/packet_codec_test.go +++ b/internal/quic/packet_codec_test.go @@ -517,7 +517,7 @@ func TestFrameEncodeDecode(t *testing.T) { s: "PATH_CHALLENGE Data=0123456789abcdef", j: `{"frame_type":"path_challenge","data":"0123456789abcdef"}`, f: debugFramePathChallenge{ - data: 0x0123456789abcdef, + data: pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, }, b: []byte{ 0x1a, // Type (i) = 0x1a, @@ -527,7 +527,7 @@ func TestFrameEncodeDecode(t *testing.T) { s: "PATH_RESPONSE Data=0123456789abcdef", j: `{"frame_type":"path_response","data":"0123456789abcdef"}`, f: debugFramePathResponse{ - data: 0x0123456789abcdef, + data: pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, }, b: []byte{ 0x1b, // Type (i) = 0x1b, diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go index 02ef9fb14..feef9eac7 100644 --- a/internal/quic/packet_parser.go +++ b/internal/quic/packet_parser.go @@ -463,18 +463,17 @@ func consumeRetireConnectionIDFrame(b []byte) (seq int64, n int) { return seq, n } -func consumePathChallengeFrame(b []byte) (data uint64, n int) { +func consumePathChallengeFrame(b []byte) (data pathChallengeData, n int) { n = 1 - var nn int - data, nn = consumeUint64(b[n:]) - if nn < 0 { - return 0, -1 + nn := copy(data[:], b[n:]) + if nn != len(data) { + return data, -1 } n += nn return data, n } -func consumePathResponseFrame(b []byte) (data uint64, n int) { +func consumePathResponseFrame(b []byte) (data pathChallengeData, n int) { return consumePathChallengeFrame(b) // identical frame format } diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go index 9ed393502..e4d71e622 100644 --- a/internal/quic/packet_writer.go +++ b/internal/quic/packet_writer.go @@ -243,10 +243,7 @@ func (w *packetWriter) appendPingFrame() (added bool) { return false } w.b = append(w.b, frameTypePing) - // Mark this packet as ack-eliciting and in-flight, - // but there's no need to record the presence of a PING frame in it. - w.sent.ackEliciting = true - w.sent.inFlight = true + w.sent.markAckEliciting() // no need to record the frame itself return true } @@ -495,23 +492,23 @@ func (w *packetWriter) appendRetireConnectionIDFrame(seq int64) (added bool) { return true } -func (w *packetWriter) appendPathChallengeFrame(data uint64) (added bool) { +func (w *packetWriter) appendPathChallengeFrame(data pathChallengeData) (added bool) { if w.avail() < 1+8 { return false } w.b = append(w.b, frameTypePathChallenge) - w.b = binary.BigEndian.AppendUint64(w.b, data) - w.sent.appendAckElicitingFrame(frameTypePathChallenge) + w.b = append(w.b, data[:]...) 
+ w.sent.markAckEliciting() // no need to record the frame itself return true } -func (w *packetWriter) appendPathResponseFrame(data uint64) (added bool) { +func (w *packetWriter) appendPathResponseFrame(data pathChallengeData) (added bool) { if w.avail() < 1+8 { return false } w.b = append(w.b, frameTypePathResponse) - w.b = binary.BigEndian.AppendUint64(w.b, data) - w.sent.appendAckElicitingFrame(frameTypePathResponse) + w.b = append(w.b, data[:]...) + w.sent.markAckEliciting() // no need to record the frame itself return true } diff --git a/internal/quic/path.go b/internal/quic/path.go new file mode 100644 index 000000000..8c237dd45 --- /dev/null +++ b/internal/quic/path.go @@ -0,0 +1,89 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import "time" + +type pathState struct { + // Response to a peer's PATH_CHALLENGE. + // This is not a sentVal, because we don't resend lost PATH_RESPONSE frames. + // We only track the most recent PATH_CHALLENGE. + // If the peer sends a second PATH_CHALLENGE before we respond to the first, + // we'll drop the first response. + sendPathResponse pathResponseType + data pathChallengeData +} + +// pathChallengeData is data carried in a PATH_CHALLENGE or PATH_RESPONSE frame. +type pathChallengeData [64 / 8]byte + +type pathResponseType uint8 + +const ( + pathResponseNotNeeded = pathResponseType(iota) + pathResponseSmall // send PATH_RESPONSE, do not expand datagram + pathResponseExpanded // send PATH_RESPONSE, expand datagram to 1200 bytes +) + +func (c *Conn) handlePathChallenge(_ time.Time, dgram *datagram, data pathChallengeData) { + // A PATH_RESPONSE is sent in a datagram expanded to 1200 bytes, + // except when this would exceed the anti-amplification limit. + // + // Rather than maintaining anti-amplification state for each path + // we may be sending a PATH_RESPONSE on, follow the following heuristic: + // + // If we receive a PATH_CHALLENGE in an expanded datagram, + // respond with an expanded datagram. + // + // If we receive a PATH_CHALLENGE in a non-expanded datagram, + // then the peer is presumably blocked by its own anti-amplification limit. + // Respond with a non-expanded datagram. Receiving this PATH_RESPONSE + // will validate the path to the peer, remove its anti-amplification limit, + // and permit it to send a followup PATH_CHALLENGE in an expanded datagram. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-8.2.1 + if len(dgram.b) >= smallestMaxDatagramSize { + c.path.sendPathResponse = pathResponseExpanded + } else { + c.path.sendPathResponse = pathResponseSmall + } + c.path.data = data +} + +func (c *Conn) handlePathResponse(now time.Time, _ pathChallengeData) { + // "If the content of a PATH_RESPONSE frame does not match the content of + // a PATH_CHALLENGE frame previously sent by the endpoint, + // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4 + // + // We never send PATH_CHALLENGE frames. + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "PATH_RESPONSE received when no PATH_CHALLENGE sent", + }) +} + +// appendPathFrames appends path validation related frames to the current packet. +// If the return value pad is true, then the packet should be padded to 1200 bytes. 
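+// (The pad return value follows from handlePathChallenge above: a PATH_RESPONSE to a
+// challenge that arrived in an expanded datagram must itself be sent in a datagram
+// expanded to 1200 bytes so the peer can complete path validation.)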
+func (c *Conn) appendPathFrames() (pad, ok bool) { + if c.path.sendPathResponse == pathResponseNotNeeded { + return pad, true + } + // We're required to send the PATH_RESPONSE on the path where the + // PATH_CHALLENGE was received (RFC 9000, Section 8.2.2). + // + // At the moment, we don't support path migration and reject packets if + // the peer changes its source address, so just sending the PATH_RESPONSE + // in a regular datagram is fine. + if !c.w.appendPathResponseFrame(c.path.data) { + return pad, false + } + if c.path.sendPathResponse == pathResponseExpanded { + pad = true + } + c.path.sendPathResponse = pathResponseNotNeeded + return pad, true +} diff --git a/internal/quic/path_test.go b/internal/quic/path_test.go new file mode 100644 index 000000000..a309ed14b --- /dev/null +++ b/internal/quic/path_test.go @@ -0,0 +1,66 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "testing" +) + +func TestPathChallengeReceived(t *testing.T) { + for _, test := range []struct { + name string + padTo int + wantPadding int + }{{ + name: "unexpanded", + padTo: 0, + wantPadding: 0, + }, { + name: "expanded", + padTo: 1200, + wantPadding: 1200, + }} { + // "The recipient of [a PATH_CHALLENGE] frame MUST generate + // a PATH_RESPONSE frame [...] containing the same Data value." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.17-7 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + tc.writeFrames(packetType1RTT, debugFramePathChallenge{ + data: data, + }, debugFramePadding{ + to: test.padTo, + }) + tc.wantFrame("response to PATH_CHALLENGE", + packetType1RTT, debugFramePathResponse{ + data: data, + }) + if got, want := tc.lastDatagram.paddedSize, test.wantPadding; got != want { + t.Errorf("PATH_RESPONSE expanded to %v bytes, want %v", got, want) + } + tc.wantIdle("connection is idle") + } +} + +func TestPathResponseMismatchReceived(t *testing.T) { + // "If the content of a PATH_RESPONSE frame does not match the content of + // a PATH_CHALLENGE frame previously sent by the endpoint, + // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.writeFrames(packetType1RTT, debugFramePathResponse{ + data: pathChallengeData{}, + }) + tc.wantFrame("invalid PATH_RESPONSE causes the connection to close", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errProtocolViolation, + }, + ) +} diff --git a/internal/quic/sent_packet.go b/internal/quic/sent_packet.go index 194cdc9fa..226152327 100644 --- a/internal/quic/sent_packet.go +++ b/internal/quic/sent_packet.go @@ -59,6 +59,12 @@ func (sent *sentPacket) reset() { } } +// markAckEliciting marks the packet as containing an ack-eliciting frame. +func (sent *sentPacket) markAckEliciting() { + sent.ackEliciting = true + sent.inFlight = true +} + // The append* methods record information about frames in the packet. 
 func (sent *sentPacket) appendNonAckElicitingFrame(frameType byte) {

From 22cbde9a565f4e40b5060a41d5e5171adcff673e Mon Sep 17 00:00:00 2001
From: Damien Neil
Date: Tue, 20 Feb 2024 14:58:00 -0800
Subject: [PATCH 43/70] quic: set ServerName in client connection TLSConfig

Client connections must set tls.Config.ServerName to authenticate
the identity of the server. (RFC 9001, Section 4.4.)

Previously, we specified a single tls.Config per Endpoint.
Change the Config passed to Listen to only apply to connections
accepted by the endpoint. Add a Config parameter to Endpoint.Dial
to allow specifying a separate config per outbound connection,
allowing the user to set the ServerName field. When the user does
not set ServerName, set it ourselves.

For golang/go#58547

Change-Id: Ie2500ae7c7a85400e6cc1c10cefa2bd4c746e313
Reviewed-on: https://go-review.googlesource.com/c/net/+/565796
Reviewed-by: Jonathan Amsterdam
LUCI-TryBot-Result: Go LUCI
---
 internal/quic/cmd/interop/main.go |  6 ++--
 internal/quic/config.go           |  7 ++++
 internal/quic/conn.go             |  4 +--
 internal/quic/conn_test.go        |  2 ++
 internal/quic/endpoint.go         | 59 ++++++++++++++++++-------------
 internal/quic/endpoint_test.go    | 31 ++++++++++------
 internal/quic/tls.go              | 14 ++++++--
 7 files changed, 81 insertions(+), 42 deletions(-)

diff --git a/internal/quic/cmd/interop/main.go b/internal/quic/cmd/interop/main.go
index 20f737b52..0899e0f1e 100644
--- a/internal/quic/cmd/interop/main.go
+++ b/internal/quic/cmd/interop/main.go
@@ -148,7 +148,7 @@ func basicTest(ctx context.Context, config *quic.Config, urls []string) {
 		g.Add(1)
 		go func() {
 			defer g.Done()
-			fetchFrom(ctx, l, addr, u)
+			fetchFrom(ctx, config, l, addr, u)
 		}()
 	}
 
@@ -221,8 +221,8 @@ func parseURL(s string) (u *url.URL, authority string, err error) {
 	return u, authority, nil
 }
 
-func fetchFrom(ctx context.Context, l *quic.Endpoint, addr string, urls []*url.URL) {
-	conn, err := l.Dial(ctx, "udp", addr)
+func fetchFrom(ctx context.Context, config *quic.Config, l *quic.Endpoint, addr string, urls []*url.URL) {
+	conn, err := l.Dial(ctx, "udp", addr, config)
 	if err != nil {
 		log.Printf("%v: %v", addr, err)
 		return
diff --git a/internal/quic/config.go b/internal/quic/config.go
index b045b7b92..5d420312b 100644
--- a/internal/quic/config.go
+++ b/internal/quic/config.go
@@ -107,6 +107,13 @@ type Config struct {
 	QLogLogger *slog.Logger
 }
 
+// Clone returns a shallow clone of c, or nil if c is nil.
+// It is safe to clone a [Config] that is being used concurrently by a QUIC endpoint.
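+//
+// A rough sketch of per-connection customization (illustrative only; baseConfig,
+// endpoint, ctx, and addr are placeholder names, not part of this package):
+//
+//	c := baseConfig.Clone()
+//	c.TLSConfig = c.TLSConfig.Clone() // Clone is shallow, so copy the TLS config before editing it
+//	c.TLSConfig.ServerName = "example.com"
+//	conn, err := endpoint.Dial(ctx, "udp", addr, c)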
+func (c *Config) Clone() *Config {
+	// Documented behavior: cloning a nil Config yields nil.
+	if c == nil {
+		return nil
+	}
+	n := *c
+	return &n
+}
+
 func configDefault[T ~int64](v, def, limit T) T {
 	switch {
 	case v == 0:
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index d462e9617..38e8fe8f4 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -94,7 +94,7 @@ type newServerConnIDs struct {
 	retrySrcConnID []byte // source from server's Retry
 }
 
-func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) {
+func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) {
 	c := &Conn{
 		side:     side,
 		endpoint: e,
@@ -146,7 +146,7 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip
 	c.lifetimeInit()
 	c.restartIdleTimer(now)
 
-	if err := c.startTLS(now, initialConnID, transportParameters{
+	if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{
 		initialSrcConnID:  c.connIDState.srcConnID(),
 		originalDstConnID: cids.originalDstConnID,
 		retrySrcConnID:    cids.retrySrcConnID,
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index 16ee3cf2f..a765ad60c 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -242,8 +242,10 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
 	endpoint.configTestConn = configTestConn
 	conn, err := endpoint.e.newConn(
 		endpoint.now,
+		config,
 		side,
 		cids,
+		"",
 		netip.MustParseAddrPort("127.0.0.1:443"))
 	if err != nil {
 		t.Fatal(err)
diff --git a/internal/quic/endpoint.go b/internal/quic/endpoint.go
index 6631708b8..a55336b24 100644
--- a/internal/quic/endpoint.go
+++ b/internal/quic/endpoint.go
@@ -22,11 +22,11 @@ import (
 //
 // Multiple goroutines may invoke methods on an Endpoint simultaneously.
 type Endpoint struct {
-	config     *Config
-	packetConn packetConn
-	testHooks  endpointTestHooks
-	resetGen   statelessResetTokenGenerator
-	retry      retryState
+	listenConfig *Config
+	packetConn   packetConn
+	testHooks    endpointTestHooks
+	resetGen     statelessResetTokenGenerator
+	retry        retryState
 
 	acceptQueue queue[*Conn] // new inbound connections
 	connsMap    connsMap     // only accessed by the listen loop
@@ -51,9 +51,11 @@ type packetConn interface {
 }
 
 // Listen listens on a local network address.
-// The configuration config must be non-nil.
-func Listen(network, address string, config *Config) (*Endpoint, error) {
-	if config.TLSConfig == nil {
+//
+// The config is used for connections accepted by the endpoint.
+// If the config is nil, the endpoint will not accept connections.
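+//
+// A minimal sketch of the two modes (illustrative only; serverTLS is a placeholder):
+//
+//	// Accepting endpoint: the config applies to inbound connections.
+//	srv, err := quic.Listen("udp", ":4433", &quic.Config{TLSConfig: serverTLS})
+//
+//	// Dial-only endpoint: a nil config; supply a Config on each Dial call instead.
+//	cli, err := quic.Listen("udp", ":0", nil)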
+func Listen(network, address string, listenConfig *Config) (*Endpoint, error) { + if listenConfig != nil && listenConfig.TLSConfig == nil { return nil, errors.New("TLSConfig is not set") } a, err := net.ResolveUDPAddr(network, address) @@ -68,21 +70,25 @@ func Listen(network, address string, config *Config) (*Endpoint, error) { if err != nil { return nil, err } - return newEndpoint(pc, config, nil) + return newEndpoint(pc, listenConfig, nil) } func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { e := &Endpoint{ - config: config, - packetConn: pc, - testHooks: hooks, - conns: make(map[*Conn]struct{}), - acceptQueue: newQueue[*Conn](), - closec: make(chan struct{}), - } - e.resetGen.init(config.StatelessResetKey) + listenConfig: config, + packetConn: pc, + testHooks: hooks, + conns: make(map[*Conn]struct{}), + acceptQueue: newQueue[*Conn](), + closec: make(chan struct{}), + } + var statelessResetKey [32]byte + if config != nil { + statelessResetKey = config.StatelessResetKey + } + e.resetGen.init(statelessResetKey) e.connsMap.init() - if config.RequireAddressValidation { + if config != nil && config.RequireAddressValidation { if err := e.retry.init(); err != nil { return nil, err } @@ -141,14 +147,15 @@ func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) { } // Dial creates and returns a connection to a network address. -func (e *Endpoint) Dial(ctx context.Context, network, address string) (*Conn, error) { +// The config cannot be nil. +func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) { u, err := net.ResolveUDPAddr(network, address) if err != nil { return nil, err } addr := u.AddrPort() addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) - c, err := e.newConn(time.Now(), clientSide, newServerConnIDs{}, addr) + c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr) if err != nil { return nil, err } @@ -159,13 +166,13 @@ func (e *Endpoint) Dial(ctx context.Context, network, address string) (*Conn, er return c, nil } -func (e *Endpoint) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) { +func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) { e.connsMu.Lock() defer e.connsMu.Unlock() if e.closing { return nil, errors.New("endpoint closed") } - c, err := newConn(now, side, cids, peerAddr, e.config, e) + c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e) if err != nil { return nil, err } @@ -288,11 +295,15 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16 return } + if e.listenConfig == nil { + // We are not configured to accept connections. + return + } cids := newServerConnIDs{ srcConnID: p.srcConnID, dstConnID: p.dstConnID, } - if e.config.RequireAddressValidation { + if e.listenConfig.RequireAddressValidation { var ok bool cids.retrySrcConnID = p.dstConnID cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr) @@ -303,7 +314,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { cids.originalDstConnID = p.dstConnID } var err error - c, err := e.newConn(now, serverSide, cids, m.peerAddr) + c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr) if err != nil { // The accept queue is probably full. 
// We could send a CONNECTION_CLOSE to the peer to reject the connection. diff --git a/internal/quic/endpoint_test.go b/internal/quic/endpoint_test.go index b9fb55fb3..b6669fc83 100644 --- a/internal/quic/endpoint_test.go +++ b/internal/quic/endpoint_test.go @@ -67,7 +67,8 @@ func newLocalConnPair(t testing.TB, conf1, conf2 *Config) (clientConn, serverCon ctx := context.Background() e1 := newLocalEndpoint(t, serverSide, conf1) e2 := newLocalEndpoint(t, clientSide, conf2) - c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String()) + conf2 = makeTestConfig(conf2, clientSide) + c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String(), conf2) if err != nil { t.Fatal(err) } @@ -80,9 +81,24 @@ func newLocalConnPair(t testing.TB, conf1, conf2 *Config) (clientConn, serverCon func newLocalEndpoint(t testing.TB, side connSide, conf *Config) *Endpoint { t.Helper() + conf = makeTestConfig(conf, side) + e, err := Listen("udp", "127.0.0.1:0", conf) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + e.Close(canceledContext()) + }) + return e +} + +func makeTestConfig(conf *Config, side connSide) *Config { + if conf == nil { + return nil + } + newConf := *conf + conf = &newConf if conf.TLSConfig == nil { - newConf := *conf - conf = &newConf conf.TLSConfig = newTestTLSConfig(side) } if conf.QLogLogger == nil { @@ -91,14 +107,7 @@ func newLocalEndpoint(t testing.TB, side connSide, conf *Config) *Endpoint { Dir: *qlogdir, })) } - e, err := Listen("udp", "127.0.0.1:0", conf) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - e.Close(canceledContext()) - }) - return e + return conf } type testEndpoint struct { diff --git a/internal/quic/tls.go b/internal/quic/tls.go index a37e26fb8..e2f2e5bde 100644 --- a/internal/quic/tls.go +++ b/internal/quic/tls.go @@ -11,14 +11,24 @@ import ( "crypto/tls" "errors" "fmt" + "net" "time" ) // startTLS starts the TLS handshake. -func (c *Conn) startTLS(now time.Time, initialConnID []byte, params transportParameters) error { +func (c *Conn) startTLS(now time.Time, initialConnID []byte, peerHostname string, params transportParameters) error { + tlsConfig := c.config.TLSConfig + if a, _, err := net.SplitHostPort(peerHostname); err == nil { + peerHostname = a + } + if tlsConfig.ServerName == "" && peerHostname != "" { + tlsConfig = tlsConfig.Clone() + tlsConfig.ServerName = peerHostname + } + c.keysInitial = initialKeys(initialConnID, c.side) - qconfig := &tls.QUICConfig{TLSConfig: c.config.TLSConfig} + qconfig := &tls.QUICConfig{TLSConfig: tlsConfig} if c.side == clientSide { c.tls = tls.QUICClient(qconfig) } else { From 4bdc6df28ea746166f486314f8848eb9b25b9073 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 14 Feb 2024 11:22:21 -0800 Subject: [PATCH 44/70] quic: expand package docs, and document Stream For golang/go#58547 Change-Id: Ie5dd0ed383ea7a5b3a45103cb730ff62792f62e1 Reviewed-on: https://go-review.googlesource.com/c/net/+/565797 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/doc.go | 42 ++++++++++++++++++++++++++++++++++++++--- internal/quic/stream.go | 15 +++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/internal/quic/doc.go b/internal/quic/doc.go index 2fe17fe22..2fd10f087 100644 --- a/internal/quic/doc.go +++ b/internal/quic/doc.go @@ -2,8 +2,44 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package quic is an experimental, incomplete implementation of the QUIC protocol. 
-// This package is a work in progress, and is not ready for use at this time. +// Package quic implements the QUIC protocol. // -// This package implements (or will implement) RFC 9000, RFC 9001, and RFC 9002. +// This package is a work in progress. +// It is not ready for production usage. +// Its API is subject to change without notice. +// +// This package is low-level. +// Most users will use it indirectly through an HTTP/3 implementation. +// +// # Usage +// +// An [Endpoint] sends and receives traffic on a network address. +// Create an Endpoint to either accept inbound QUIC connections +// or create outbound ones. +// +// A [Conn] is a QUIC connection. +// +// A [Stream] is a QUIC stream, an ordered, reliable byte stream. +// +// # Cancelation +// +// All blocking operations may be canceled using a context.Context. +// When performing an operation with a canceled context, the operation +// will succeed if doing so does not require blocking. For example, +// reading from a stream will return data when buffered data is available, +// even if the stream context is canceled. +// +// # Limitations +// +// This package is a work in progress. +// Known limitations include: +// +// - Performance is untuned. +// - 0-RTT is not supported. +// - Address migration is not supported. +// - Server preferred addresses are not supported. +// - The latency spin bit is not supported. +// - Stream send/receive windows are configurable, +// but are fixed and do not adapt to available throughput. +// - Path MTU discovery is not implemented. package quic diff --git a/internal/quic/stream.go b/internal/quic/stream.go index c5fafdf1d..cb45534f8 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -14,6 +14,21 @@ import ( "math" ) +// A Stream is an ordered byte stream. +// +// Streams may be bidirectional, read-only, or write-only. +// Methods inappropriate for a stream's direction +// (for example, [Write] to a read-only stream) +// return errors. +// +// It is not safe to perform concurrent reads from or writes to a stream. +// It is safe, however, to read and write at the same time. +// +// Reads and writes are buffered. +// It is generally not necessary to wrap a stream in a [bufio.ReadWriter] +// or otherwise apply additional buffering. +// +// To cancel reads or writes, use the [SetReadContext] and [SetWriteContext] methods. type Stream struct { id streamID conn *Conn From 34cc4464c5cb7947126d80f9d75b4c16d229337d Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 23 Feb 2024 08:53:01 -0800 Subject: [PATCH 45/70] quic: temporarily disable networking tests failing on various platforms For golang/go#65906 For golang/go#65907 Change-Id: I5fe83a27f47b6f2337d280465bf134dbd883809d Reviewed-on: https://go-review.googlesource.com/c/net/+/566098 Auto-Submit: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Bryan Mills --- internal/quic/udp_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/quic/udp_test.go b/internal/quic/udp_test.go index 27eddf811..450351b6b 100644 --- a/internal/quic/udp_test.go +++ b/internal/quic/udp_test.go @@ -16,6 +16,7 @@ import ( ) func TestUDPSourceUnspecified(t *testing.T) { + t.Skip("https://go.dev/issue/65906 - temporarily skipped pending fix") // Send datagram with no source address set. 
runUDPTest(t, func(t *testing.T, test udpTest) { data := []byte("source unspecified") @@ -33,6 +34,7 @@ func TestUDPSourceUnspecified(t *testing.T) { } func TestUDPSourceSpecified(t *testing.T) { + t.Skip("https://go.dev/issue/65906 - temporarily skipped pending fix") // Send datagram with source address set. runUDPTest(t, func(t *testing.T, test udpTest) { data := []byte("source specified") @@ -51,6 +53,7 @@ func TestUDPSourceSpecified(t *testing.T) { } func TestUDPSourceInvalid(t *testing.T) { + t.Skip("https://go.dev/issue/65906 - temporarily skipped pending fix") // Send datagram with source address set to an address not associated with the connection. if !udpInvalidLocalAddrIsError { t.Skipf("%v: sending from invalid source succeeds", runtime.GOOS) @@ -74,6 +77,7 @@ func TestUDPSourceInvalid(t *testing.T) { } func TestUDPECN(t *testing.T) { + t.Skip("https://go.dev/issue/65907 - temporarily skipped pending fix") if !udpECNSupport { t.Skipf("%v: no ECN support", runtime.GOOS) } From 591be7f10be18b4b24250868fb61a93c2e5af3f4 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 23 Feb 2024 09:54:56 -0800 Subject: [PATCH 46/70] quic: fix UDP on big-endian Linux, tests on various architectures The following cmsgs contain a native-endian 32-bit integer: - IP_TOS, passed to sendmsg - IPV6_TCLASS, always IP_TOS received from recvmsg contains a single byte, because why not. We were inadvertently assuming little-endian integers in all cases. Add endianness conversion as appropriate. Disable tests that rely on IPv4-in-IPv6 mapped sockets on dragonfly and openbsd, which don't support this feature. (A "udp" socket cannot receive IPv6 packets on these platforms.) Disable IPv6 tests on wasm, where the simulated networking appears to generally not support IPv6. Fixes golang/go#65906 Fixes golang/go#65907 Change-Id: Ie50af12e182a1a5d685ce4fbdf008748f6aee339 Reviewed-on: https://go-review.googlesource.com/c/net/+/566296 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam Reviewed-by: Bryan Mills --- internal/quic/udp_darwin.go | 25 +++++++++++ internal/quic/udp_linux.go | 20 +++++++++ internal/quic/udp_msg.go | 89 ++++++++++++++++++------------------- internal/quic/udp_test.go | 17 +++++-- 4 files changed, 102 insertions(+), 49 deletions(-) diff --git a/internal/quic/udp_darwin.go b/internal/quic/udp_darwin.go index 3868a36a8..2eb2e9f9f 100644 --- a/internal/quic/udp_darwin.go +++ b/internal/quic/udp_darwin.go @@ -6,8 +6,33 @@ package quic +import ( + "encoding/binary" + + "golang.org/x/sys/unix" +) + // See udp.go. const ( udpECNSupport = true udpInvalidLocalAddrIsError = true ) + +// Confusingly, on Darwin the contents of the IP_TOS option differ depending on whether +// it is used as an inbound or outbound cmsg. + +func parseIPTOS(b []byte) (ecnBits, bool) { + // Single byte. The low two bits are the ECN field. + if len(b) != 1 { + return 0, false + } + return ecnBits(b[0] & ecnMask), true +} + +func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { + // 32-bit integer. + // https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/in_tclass.c#L1062-L1073 + b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 4) + binary.NativeEndian.PutUint32(data, uint32(ecn)) + return b +} diff --git a/internal/quic/udp_linux.go b/internal/quic/udp_linux.go index 2ba3e6f2f..6f191ed39 100644 --- a/internal/quic/udp_linux.go +++ b/internal/quic/udp_linux.go @@ -6,8 +6,28 @@ package quic +import ( + "golang.org/x/sys/unix" +) + // See udp.go. 
const ( udpECNSupport = true udpInvalidLocalAddrIsError = false ) + +// The IP_TOS socket option is a single byte containing the IP TOS field. +// The low two bits are the ECN field. + +func parseIPTOS(b []byte) (ecnBits, bool) { + if len(b) != 1 { + return 0, false + } + return ecnBits(b[0] & ecnMask), true +} + +func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 1) + data[0] = byte(ecn) + return b +} diff --git a/internal/quic/udp_msg.go b/internal/quic/udp_msg.go index bdc1b710d..0b600a2b4 100644 --- a/internal/quic/udp_msg.go +++ b/internal/quic/udp_msg.go @@ -7,6 +7,7 @@ package quic import ( + "encoding/binary" "net" "net/netip" "sync" @@ -141,15 +142,11 @@ func parseControl(d *datagram, control []byte) { case unix.IPPROTO_IP: switch hdr.Type { case unix.IP_TOS, unix.IP_RECVTOS: - // Single byte containing the IP TOS field. - // The low two bits are the ECN field. - // // (Linux sets the type to IP_TOS, Darwin to IP_RECVTOS, - // jus check for both.) - if len(data) < 1 { - break + // just check for both.) + if ecn, ok := parseIPTOS(data); ok { + d.ecn = ecn } - d.ecn = ecnBits(data[0] & ecnMask) case unix.IP_PKTINFO: if a, ok := parseInPktinfo(data); ok { d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) @@ -158,12 +155,11 @@ func parseControl(d *datagram, control []byte) { case unix.IPPROTO_IPV6: switch hdr.Type { case unix.IPV6_TCLASS: - // Single byte containing the traffic class field. + // 32-bit integer containing the traffic class field. // The low two bits are the ECN field. - if len(data) < 1 { - break + if ecn, ok := parseIPv6TCLASS(data); ok { + d.ecn = ecn } - d.ecn = ecnBits(data[0] & ecnMask) case unix.IPV6_PKTINFO: if a, ok := parseIn6Pktinfo(data); ok { d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) @@ -173,27 +169,33 @@ func parseControl(d *datagram, control []byte) { } } -func parseInPktinfo(b []byte) (netip.Addr, bool) { - // struct in_pktinfo { - // unsigned int ipi_ifindex; /* send/recv interface index */ - // struct in_addr ipi_spec_dst; /* Local address */ - // struct in_addr ipi_addr; /* IP Header dst address */ - // }; - if len(b) != 12 { - return netip.Addr{}, false +// IPV6_TCLASS is specified by RFC 3542 as an int. + +func parseIPv6TCLASS(b []byte) (ecnBits, bool) { + if len(b) != 4 { + return 0, false } - return netip.AddrFrom4([4]byte(b[8:][:4])), true + return ecnBits(binary.NativeEndian.Uint32(b) & ecnMask), true } -func parseIn6Pktinfo(b []byte) (netip.Addr, bool) { - // struct in6_pktinfo { - // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ - // unsigned int ipi6_ifindex; /* send/recv interface index */ - // }; - if len(b) != 20 { +func appendCmsgECNv6(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 4) + binary.NativeEndian.PutUint32(data, uint32(ecn)) + return b +} + +// struct in_pktinfo { +// unsigned int ipi_ifindex; /* send/recv interface index */ +// struct in_addr ipi_spec_dst; /* Local address */ +// struct in_addr ipi_addr; /* IP Header dst address */ +// }; + +// parseInPktinfo returns the destination address from an IP_PKTINFO. 
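+// The 12-byte in_pktinfo layout above places ipi_ifindex at bytes 0-3, ipi_spec_dst at
+// bytes 4-7, and ipi_addr at bytes 8-11; the address returned here is ipi_addr.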
+func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) { + if len(b) != 12 { return netip.Addr{}, false } - return netip.AddrFrom16([16]byte(b[:16])).Unmap(), true + return netip.AddrFrom4([4]byte(b[8:][:4])), true } // appendCmsgIPSourceAddrV4 appends an IP_PKTINFO setting the source address @@ -210,31 +212,28 @@ func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte { return b } -// appendCmsgIPSourceAddrV6 appends an IP_PKTINFO or IPV6_PKTINFO -// setting the source address for an outbound datagram. +// struct in6_pktinfo { +// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ +// unsigned int ipi6_ifindex; /* send/recv interface index */ +// }; + +// parseIn6Pktinfo returns the destination address from an IPV6_PKTINFO. +func parseIn6Pktinfo(b []byte) (netip.Addr, bool) { + if len(b) != 20 { + return netip.Addr{}, false + } + return netip.AddrFrom16([16]byte(b[:16])).Unmap(), true +} + +// appendCmsgIPSourceAddrV6 appends an IPV6_PKTINFO setting the source address +// for an outbound datagram. func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte { - // struct in6_pktinfo { - // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ - // unsigned int ipi6_ifindex; /* send/recv interface index */ - // }; b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_PKTINFO, 20) ip := src.As16() copy(data[0:], ip[:]) return b } -func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { - b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 4) - data[0] = byte(ecn) - return b -} - -func appendCmsgECNv6(b []byte, ecn ecnBits) []byte { - b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 4) - data[0] = byte(ecn) - return b -} - // appendCmsg appends a cmsg with the given level, type, and size to b. // It returns the new buffer, and the data section of the cmsg. func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) { diff --git a/internal/quic/udp_test.go b/internal/quic/udp_test.go index 450351b6b..d3732c140 100644 --- a/internal/quic/udp_test.go +++ b/internal/quic/udp_test.go @@ -16,9 +16,9 @@ import ( ) func TestUDPSourceUnspecified(t *testing.T) { - t.Skip("https://go.dev/issue/65906 - temporarily skipped pending fix") // Send datagram with no source address set. runUDPTest(t, func(t *testing.T, test udpTest) { + t.Logf("%v", test.dstAddr) data := []byte("source unspecified") if err := test.src.Write(datagram{ b: data, @@ -34,7 +34,6 @@ func TestUDPSourceUnspecified(t *testing.T) { } func TestUDPSourceSpecified(t *testing.T) { - t.Skip("https://go.dev/issue/65906 - temporarily skipped pending fix") // Send datagram with source address set. runUDPTest(t, func(t *testing.T, test udpTest) { data := []byte("source specified") @@ -53,7 +52,6 @@ func TestUDPSourceSpecified(t *testing.T) { } func TestUDPSourceInvalid(t *testing.T) { - t.Skip("https://go.dev/issue/65906 - temporarily skipped pending fix") // Send datagram with source address set to an address not associated with the connection. 
if !udpInvalidLocalAddrIsError { t.Skipf("%v: sending from invalid source succeeds", runtime.GOOS) @@ -77,7 +75,6 @@ func TestUDPSourceInvalid(t *testing.T) { } func TestUDPECN(t *testing.T) { - t.Skip("https://go.dev/issue/65907 - temporarily skipped pending fix") if !udpECNSupport { t.Skipf("%v: no ECN support", runtime.GOOS) } @@ -125,6 +122,18 @@ func runUDPTest(t *testing.T, f func(t *testing.T, u udpTest)) { spec = "unspec" } t.Run(fmt.Sprintf("%v/%v/%v", test.srcNet, test.dstNet, spec), func(t *testing.T) { + // See: https://go.googlesource.com/go/+/refs/tags/go1.22.0/src/net/ipsock.go#47 + // On these platforms, conns with network="udp" cannot accept IPv6. + switch runtime.GOOS { + case "dragonfly", "openbsd": + if test.srcNet == "udp6" && test.dstNet == "udp" { + t.Skipf("%v: no support for mapping IPv4 address to IPv6", runtime.GOOS) + } + } + if runtime.GOARCH == "wasm" && test.srcNet == "udp6" { + t.Skipf("%v: IPv6 tests fail when using wasm fake net", runtime.GOARCH) + } + srcAddr := netip.AddrPortFrom(netip.MustParseAddr(test.srcAddr), 0) srcConn, err := net.ListenUDP(test.srcNet, net.UDPAddrFromAddrPort(srcAddr)) if err != nil { From fa1142799318d3fa3632ecfd9f318ffa040e7c4c Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 22 Feb 2024 17:31:38 -0800 Subject: [PATCH 47/70] quic: move package out of internal For golang/go#58547 Change-Id: I119d820824f82bfdd236c6826f960d0c934745ca Reviewed-on: https://go-review.googlesource.com/c/net/+/566295 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/cmd/interop/main.go | 4 ++-- {internal/quic => quic}/ack_delay.go | 0 {internal/quic => quic}/ack_delay_test.go | 0 {internal/quic => quic}/acks.go | 0 {internal/quic => quic}/acks_test.go | 0 {internal/quic => quic}/atomic_bits.go | 0 {internal/quic => quic}/bench_test.go | 0 {internal/quic => quic}/config.go | 0 {internal/quic => quic}/config_test.go | 0 {internal/quic => quic}/congestion_reno.go | 0 {internal/quic => quic}/congestion_reno_test.go | 0 {internal/quic => quic}/conn.go | 0 {internal/quic => quic}/conn_async_test.go | 0 {internal/quic => quic}/conn_close.go | 0 {internal/quic => quic}/conn_close_test.go | 0 {internal/quic => quic}/conn_flow.go | 0 {internal/quic => quic}/conn_flow_test.go | 0 {internal/quic => quic}/conn_id.go | 0 {internal/quic => quic}/conn_id_test.go | 0 {internal/quic => quic}/conn_loss.go | 0 {internal/quic => quic}/conn_loss_test.go | 0 {internal/quic => quic}/conn_recv.go | 0 {internal/quic => quic}/conn_send.go | 0 {internal/quic => quic}/conn_send_test.go | 0 {internal/quic => quic}/conn_streams.go | 0 {internal/quic => quic}/conn_streams_test.go | 0 {internal/quic => quic}/conn_test.go | 2 +- {internal/quic => quic}/crypto_stream.go | 0 {internal/quic => quic}/crypto_stream_test.go | 0 {internal/quic => quic}/dgram.go | 0 {internal/quic => quic}/doc.go | 0 {internal/quic => quic}/endpoint.go | 0 {internal/quic => quic}/endpoint_test.go | 2 +- {internal/quic => quic}/errors.go | 0 {internal/quic => quic}/files_test.go | 0 {internal/quic => quic}/frame_debug.go | 0 {internal/quic => quic}/gate.go | 0 {internal/quic => quic}/gate_test.go | 0 {internal/quic => quic}/gotraceback_test.go | 0 {internal/quic => quic}/idle.go | 0 {internal/quic => quic}/idle_test.go | 0 {internal/quic => quic}/key_update_test.go | 0 {internal/quic => quic}/log.go | 0 {internal/quic => quic}/loss.go | 0 {internal/quic => quic}/loss_test.go | 0 {internal/quic => quic}/main_test.go | 0 {internal/quic => quic}/math.go | 0 {internal/quic => 
quic}/pacer.go | 0 {internal/quic => quic}/pacer_test.go | 0 {internal/quic => quic}/packet.go | 0 {internal/quic => quic}/packet_codec_test.go | 2 +- {internal/quic => quic}/packet_number.go | 0 {internal/quic => quic}/packet_number_test.go | 0 {internal/quic => quic}/packet_parser.go | 0 {internal/quic => quic}/packet_protection.go | 0 {internal/quic => quic}/packet_protection_test.go | 0 {internal/quic => quic}/packet_test.go | 0 {internal/quic => quic}/packet_writer.go | 0 {internal/quic => quic}/path.go | 0 {internal/quic => quic}/path_test.go | 0 {internal/quic => quic}/ping.go | 0 {internal/quic => quic}/ping_test.go | 0 {internal/quic => quic}/pipe.go | 0 {internal/quic => quic}/pipe_test.go | 0 {internal/quic => quic}/qlog.go | 0 {internal/quic => quic}/qlog/handler.go | 0 {internal/quic => quic}/qlog/json_writer.go | 0 {internal/quic => quic}/qlog/json_writer_test.go | 0 {internal/quic => quic}/qlog/qlog.go | 0 {internal/quic => quic}/qlog/qlog_test.go | 0 {internal/quic => quic}/qlog_test.go | 2 +- {internal/quic => quic}/queue.go | 0 {internal/quic => quic}/queue_test.go | 0 {internal/quic => quic}/quic.go | 0 {internal/quic => quic}/quic_test.go | 0 {internal/quic => quic}/rangeset.go | 0 {internal/quic => quic}/rangeset_test.go | 0 {internal/quic => quic}/retry.go | 0 {internal/quic => quic}/retry_test.go | 0 {internal/quic => quic}/rtt.go | 0 {internal/quic => quic}/rtt_test.go | 0 {internal/quic => quic}/sent_packet.go | 0 {internal/quic => quic}/sent_packet_list.go | 0 {internal/quic => quic}/sent_packet_list_test.go | 0 {internal/quic => quic}/sent_packet_test.go | 0 {internal/quic => quic}/sent_val.go | 0 {internal/quic => quic}/sent_val_test.go | 0 {internal/quic => quic}/stateless_reset.go | 0 {internal/quic => quic}/stateless_reset_test.go | 0 {internal/quic => quic}/stream.go | 0 {internal/quic => quic}/stream_limits.go | 0 {internal/quic => quic}/stream_limits_test.go | 0 {internal/quic => quic}/stream_test.go | 0 {internal/quic => quic}/tls.go | 0 {internal/quic => quic}/tls_test.go | 0 {internal/quic => quic}/tlsconfig_test.go | 0 {internal/quic => quic}/transport_params.go | 0 {internal/quic => quic}/transport_params_test.go | 0 {internal/quic => quic}/udp.go | 0 {internal/quic => quic}/udp_darwin.go | 0 {internal/quic => quic}/udp_linux.go | 0 {internal/quic => quic}/udp_msg.go | 0 {internal/quic => quic}/udp_other.go | 0 {internal/quic => quic}/udp_test.go | 0 {internal/quic => quic}/version_test.go | 0 {internal/quic => quic}/wire.go | 0 {internal/quic => quic}/wire_test.go | 0 107 files changed, 6 insertions(+), 6 deletions(-) rename {internal/quic => quic}/ack_delay.go (100%) rename {internal/quic => quic}/ack_delay_test.go (100%) rename {internal/quic => quic}/acks.go (100%) rename {internal/quic => quic}/acks_test.go (100%) rename {internal/quic => quic}/atomic_bits.go (100%) rename {internal/quic => quic}/bench_test.go (100%) rename {internal/quic => quic}/config.go (100%) rename {internal/quic => quic}/config_test.go (100%) rename {internal/quic => quic}/congestion_reno.go (100%) rename {internal/quic => quic}/congestion_reno_test.go (100%) rename {internal/quic => quic}/conn.go (100%) rename {internal/quic => quic}/conn_async_test.go (100%) rename {internal/quic => quic}/conn_close.go (100%) rename {internal/quic => quic}/conn_close_test.go (100%) rename {internal/quic => quic}/conn_flow.go (100%) rename {internal/quic => quic}/conn_flow_test.go (100%) rename {internal/quic => quic}/conn_id.go (100%) rename {internal/quic => quic}/conn_id_test.go (100%) 
rename {internal/quic => quic}/conn_loss.go (100%) rename {internal/quic => quic}/conn_loss_test.go (100%) rename {internal/quic => quic}/conn_recv.go (100%) rename {internal/quic => quic}/conn_send.go (100%) rename {internal/quic => quic}/conn_send_test.go (100%) rename {internal/quic => quic}/conn_streams.go (100%) rename {internal/quic => quic}/conn_streams_test.go (100%) rename {internal/quic => quic}/conn_test.go (99%) rename {internal/quic => quic}/crypto_stream.go (100%) rename {internal/quic => quic}/crypto_stream_test.go (100%) rename {internal/quic => quic}/dgram.go (100%) rename {internal/quic => quic}/doc.go (100%) rename {internal/quic => quic}/endpoint.go (100%) rename {internal/quic => quic}/endpoint_test.go (99%) rename {internal/quic => quic}/errors.go (100%) rename {internal/quic => quic}/files_test.go (100%) rename {internal/quic => quic}/frame_debug.go (100%) rename {internal/quic => quic}/gate.go (100%) rename {internal/quic => quic}/gate_test.go (100%) rename {internal/quic => quic}/gotraceback_test.go (100%) rename {internal/quic => quic}/idle.go (100%) rename {internal/quic => quic}/idle_test.go (100%) rename {internal/quic => quic}/key_update_test.go (100%) rename {internal/quic => quic}/log.go (100%) rename {internal/quic => quic}/loss.go (100%) rename {internal/quic => quic}/loss_test.go (100%) rename {internal/quic => quic}/main_test.go (100%) rename {internal/quic => quic}/math.go (100%) rename {internal/quic => quic}/pacer.go (100%) rename {internal/quic => quic}/pacer_test.go (100%) rename {internal/quic => quic}/packet.go (100%) rename {internal/quic => quic}/packet_codec_test.go (99%) rename {internal/quic => quic}/packet_number.go (100%) rename {internal/quic => quic}/packet_number_test.go (100%) rename {internal/quic => quic}/packet_parser.go (100%) rename {internal/quic => quic}/packet_protection.go (100%) rename {internal/quic => quic}/packet_protection_test.go (100%) rename {internal/quic => quic}/packet_test.go (100%) rename {internal/quic => quic}/packet_writer.go (100%) rename {internal/quic => quic}/path.go (100%) rename {internal/quic => quic}/path_test.go (100%) rename {internal/quic => quic}/ping.go (100%) rename {internal/quic => quic}/ping_test.go (100%) rename {internal/quic => quic}/pipe.go (100%) rename {internal/quic => quic}/pipe_test.go (100%) rename {internal/quic => quic}/qlog.go (100%) rename {internal/quic => quic}/qlog/handler.go (100%) rename {internal/quic => quic}/qlog/json_writer.go (100%) rename {internal/quic => quic}/qlog/json_writer_test.go (100%) rename {internal/quic => quic}/qlog/qlog.go (100%) rename {internal/quic => quic}/qlog/qlog_test.go (100%) rename {internal/quic => quic}/qlog_test.go (99%) rename {internal/quic => quic}/queue.go (100%) rename {internal/quic => quic}/queue_test.go (100%) rename {internal/quic => quic}/quic.go (100%) rename {internal/quic => quic}/quic_test.go (100%) rename {internal/quic => quic}/rangeset.go (100%) rename {internal/quic => quic}/rangeset_test.go (100%) rename {internal/quic => quic}/retry.go (100%) rename {internal/quic => quic}/retry_test.go (100%) rename {internal/quic => quic}/rtt.go (100%) rename {internal/quic => quic}/rtt_test.go (100%) rename {internal/quic => quic}/sent_packet.go (100%) rename {internal/quic => quic}/sent_packet_list.go (100%) rename {internal/quic => quic}/sent_packet_list_test.go (100%) rename {internal/quic => quic}/sent_packet_test.go (100%) rename {internal/quic => quic}/sent_val.go (100%) rename {internal/quic => quic}/sent_val_test.go (100%) rename 
{internal/quic => quic}/stateless_reset.go (100%) rename {internal/quic => quic}/stateless_reset_test.go (100%) rename {internal/quic => quic}/stream.go (100%) rename {internal/quic => quic}/stream_limits.go (100%) rename {internal/quic => quic}/stream_limits_test.go (100%) rename {internal/quic => quic}/stream_test.go (100%) rename {internal/quic => quic}/tls.go (100%) rename {internal/quic => quic}/tls_test.go (100%) rename {internal/quic => quic}/tlsconfig_test.go (100%) rename {internal/quic => quic}/transport_params.go (100%) rename {internal/quic => quic}/transport_params_test.go (100%) rename {internal/quic => quic}/udp.go (100%) rename {internal/quic => quic}/udp_darwin.go (100%) rename {internal/quic => quic}/udp_linux.go (100%) rename {internal/quic => quic}/udp_msg.go (100%) rename {internal/quic => quic}/udp_other.go (100%) rename {internal/quic => quic}/udp_test.go (100%) rename {internal/quic => quic}/version_test.go (100%) rename {internal/quic => quic}/wire.go (100%) rename {internal/quic => quic}/wire_test.go (100%) diff --git a/internal/quic/cmd/interop/main.go b/internal/quic/cmd/interop/main.go index 0899e0f1e..5b652a2b1 100644 --- a/internal/quic/cmd/interop/main.go +++ b/internal/quic/cmd/interop/main.go @@ -25,8 +25,8 @@ import ( "path/filepath" "sync" - "golang.org/x/net/internal/quic" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic" + "golang.org/x/net/quic/qlog" ) var ( diff --git a/internal/quic/ack_delay.go b/quic/ack_delay.go similarity index 100% rename from internal/quic/ack_delay.go rename to quic/ack_delay.go diff --git a/internal/quic/ack_delay_test.go b/quic/ack_delay_test.go similarity index 100% rename from internal/quic/ack_delay_test.go rename to quic/ack_delay_test.go diff --git a/internal/quic/acks.go b/quic/acks.go similarity index 100% rename from internal/quic/acks.go rename to quic/acks.go diff --git a/internal/quic/acks_test.go b/quic/acks_test.go similarity index 100% rename from internal/quic/acks_test.go rename to quic/acks_test.go diff --git a/internal/quic/atomic_bits.go b/quic/atomic_bits.go similarity index 100% rename from internal/quic/atomic_bits.go rename to quic/atomic_bits.go diff --git a/internal/quic/bench_test.go b/quic/bench_test.go similarity index 100% rename from internal/quic/bench_test.go rename to quic/bench_test.go diff --git a/internal/quic/config.go b/quic/config.go similarity index 100% rename from internal/quic/config.go rename to quic/config.go diff --git a/internal/quic/config_test.go b/quic/config_test.go similarity index 100% rename from internal/quic/config_test.go rename to quic/config_test.go diff --git a/internal/quic/congestion_reno.go b/quic/congestion_reno.go similarity index 100% rename from internal/quic/congestion_reno.go rename to quic/congestion_reno.go diff --git a/internal/quic/congestion_reno_test.go b/quic/congestion_reno_test.go similarity index 100% rename from internal/quic/congestion_reno_test.go rename to quic/congestion_reno_test.go diff --git a/internal/quic/conn.go b/quic/conn.go similarity index 100% rename from internal/quic/conn.go rename to quic/conn.go diff --git a/internal/quic/conn_async_test.go b/quic/conn_async_test.go similarity index 100% rename from internal/quic/conn_async_test.go rename to quic/conn_async_test.go diff --git a/internal/quic/conn_close.go b/quic/conn_close.go similarity index 100% rename from internal/quic/conn_close.go rename to quic/conn_close.go diff --git a/internal/quic/conn_close_test.go b/quic/conn_close_test.go similarity index 100% 
rename from internal/quic/conn_close_test.go rename to quic/conn_close_test.go diff --git a/internal/quic/conn_flow.go b/quic/conn_flow.go similarity index 100% rename from internal/quic/conn_flow.go rename to quic/conn_flow.go diff --git a/internal/quic/conn_flow_test.go b/quic/conn_flow_test.go similarity index 100% rename from internal/quic/conn_flow_test.go rename to quic/conn_flow_test.go diff --git a/internal/quic/conn_id.go b/quic/conn_id.go similarity index 100% rename from internal/quic/conn_id.go rename to quic/conn_id.go diff --git a/internal/quic/conn_id_test.go b/quic/conn_id_test.go similarity index 100% rename from internal/quic/conn_id_test.go rename to quic/conn_id_test.go diff --git a/internal/quic/conn_loss.go b/quic/conn_loss.go similarity index 100% rename from internal/quic/conn_loss.go rename to quic/conn_loss.go diff --git a/internal/quic/conn_loss_test.go b/quic/conn_loss_test.go similarity index 100% rename from internal/quic/conn_loss_test.go rename to quic/conn_loss_test.go diff --git a/internal/quic/conn_recv.go b/quic/conn_recv.go similarity index 100% rename from internal/quic/conn_recv.go rename to quic/conn_recv.go diff --git a/internal/quic/conn_send.go b/quic/conn_send.go similarity index 100% rename from internal/quic/conn_send.go rename to quic/conn_send.go diff --git a/internal/quic/conn_send_test.go b/quic/conn_send_test.go similarity index 100% rename from internal/quic/conn_send_test.go rename to quic/conn_send_test.go diff --git a/internal/quic/conn_streams.go b/quic/conn_streams.go similarity index 100% rename from internal/quic/conn_streams.go rename to quic/conn_streams.go diff --git a/internal/quic/conn_streams_test.go b/quic/conn_streams_test.go similarity index 100% rename from internal/quic/conn_streams_test.go rename to quic/conn_streams_test.go diff --git a/internal/quic/conn_test.go b/quic/conn_test.go similarity index 99% rename from internal/quic/conn_test.go rename to quic/conn_test.go index a765ad60c..f4f1818a6 100644 --- a/internal/quic/conn_test.go +++ b/quic/conn_test.go @@ -21,7 +21,7 @@ import ( "testing" "time" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic/qlog" ) var ( diff --git a/internal/quic/crypto_stream.go b/quic/crypto_stream.go similarity index 100% rename from internal/quic/crypto_stream.go rename to quic/crypto_stream.go diff --git a/internal/quic/crypto_stream_test.go b/quic/crypto_stream_test.go similarity index 100% rename from internal/quic/crypto_stream_test.go rename to quic/crypto_stream_test.go diff --git a/internal/quic/dgram.go b/quic/dgram.go similarity index 100% rename from internal/quic/dgram.go rename to quic/dgram.go diff --git a/internal/quic/doc.go b/quic/doc.go similarity index 100% rename from internal/quic/doc.go rename to quic/doc.go diff --git a/internal/quic/endpoint.go b/quic/endpoint.go similarity index 100% rename from internal/quic/endpoint.go rename to quic/endpoint.go diff --git a/internal/quic/endpoint_test.go b/quic/endpoint_test.go similarity index 99% rename from internal/quic/endpoint_test.go rename to quic/endpoint_test.go index b6669fc83..d5f436e6d 100644 --- a/internal/quic/endpoint_test.go +++ b/quic/endpoint_test.go @@ -16,7 +16,7 @@ import ( "testing" "time" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic/qlog" ) func TestConnect(t *testing.T) { diff --git a/internal/quic/errors.go b/quic/errors.go similarity index 100% rename from internal/quic/errors.go rename to quic/errors.go diff --git a/internal/quic/files_test.go b/quic/files_test.go 
similarity index 100% rename from internal/quic/files_test.go rename to quic/files_test.go diff --git a/internal/quic/frame_debug.go b/quic/frame_debug.go similarity index 100% rename from internal/quic/frame_debug.go rename to quic/frame_debug.go diff --git a/internal/quic/gate.go b/quic/gate.go similarity index 100% rename from internal/quic/gate.go rename to quic/gate.go diff --git a/internal/quic/gate_test.go b/quic/gate_test.go similarity index 100% rename from internal/quic/gate_test.go rename to quic/gate_test.go diff --git a/internal/quic/gotraceback_test.go b/quic/gotraceback_test.go similarity index 100% rename from internal/quic/gotraceback_test.go rename to quic/gotraceback_test.go diff --git a/internal/quic/idle.go b/quic/idle.go similarity index 100% rename from internal/quic/idle.go rename to quic/idle.go diff --git a/internal/quic/idle_test.go b/quic/idle_test.go similarity index 100% rename from internal/quic/idle_test.go rename to quic/idle_test.go diff --git a/internal/quic/key_update_test.go b/quic/key_update_test.go similarity index 100% rename from internal/quic/key_update_test.go rename to quic/key_update_test.go diff --git a/internal/quic/log.go b/quic/log.go similarity index 100% rename from internal/quic/log.go rename to quic/log.go diff --git a/internal/quic/loss.go b/quic/loss.go similarity index 100% rename from internal/quic/loss.go rename to quic/loss.go diff --git a/internal/quic/loss_test.go b/quic/loss_test.go similarity index 100% rename from internal/quic/loss_test.go rename to quic/loss_test.go diff --git a/internal/quic/main_test.go b/quic/main_test.go similarity index 100% rename from internal/quic/main_test.go rename to quic/main_test.go diff --git a/internal/quic/math.go b/quic/math.go similarity index 100% rename from internal/quic/math.go rename to quic/math.go diff --git a/internal/quic/pacer.go b/quic/pacer.go similarity index 100% rename from internal/quic/pacer.go rename to quic/pacer.go diff --git a/internal/quic/pacer_test.go b/quic/pacer_test.go similarity index 100% rename from internal/quic/pacer_test.go rename to quic/pacer_test.go diff --git a/internal/quic/packet.go b/quic/packet.go similarity index 100% rename from internal/quic/packet.go rename to quic/packet.go diff --git a/internal/quic/packet_codec_test.go b/quic/packet_codec_test.go similarity index 99% rename from internal/quic/packet_codec_test.go rename to quic/packet_codec_test.go index 98b3bbb05..3b39795ef 100644 --- a/internal/quic/packet_codec_test.go +++ b/quic/packet_codec_test.go @@ -15,7 +15,7 @@ import ( "testing" "time" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic/qlog" ) func TestParseLongHeaderPacket(t *testing.T) { diff --git a/internal/quic/packet_number.go b/quic/packet_number.go similarity index 100% rename from internal/quic/packet_number.go rename to quic/packet_number.go diff --git a/internal/quic/packet_number_test.go b/quic/packet_number_test.go similarity index 100% rename from internal/quic/packet_number_test.go rename to quic/packet_number_test.go diff --git a/internal/quic/packet_parser.go b/quic/packet_parser.go similarity index 100% rename from internal/quic/packet_parser.go rename to quic/packet_parser.go diff --git a/internal/quic/packet_protection.go b/quic/packet_protection.go similarity index 100% rename from internal/quic/packet_protection.go rename to quic/packet_protection.go diff --git a/internal/quic/packet_protection_test.go b/quic/packet_protection_test.go similarity index 100% rename from 
internal/quic/packet_protection_test.go rename to quic/packet_protection_test.go diff --git a/internal/quic/packet_test.go b/quic/packet_test.go similarity index 100% rename from internal/quic/packet_test.go rename to quic/packet_test.go diff --git a/internal/quic/packet_writer.go b/quic/packet_writer.go similarity index 100% rename from internal/quic/packet_writer.go rename to quic/packet_writer.go diff --git a/internal/quic/path.go b/quic/path.go similarity index 100% rename from internal/quic/path.go rename to quic/path.go diff --git a/internal/quic/path_test.go b/quic/path_test.go similarity index 100% rename from internal/quic/path_test.go rename to quic/path_test.go diff --git a/internal/quic/ping.go b/quic/ping.go similarity index 100% rename from internal/quic/ping.go rename to quic/ping.go diff --git a/internal/quic/ping_test.go b/quic/ping_test.go similarity index 100% rename from internal/quic/ping_test.go rename to quic/ping_test.go diff --git a/internal/quic/pipe.go b/quic/pipe.go similarity index 100% rename from internal/quic/pipe.go rename to quic/pipe.go diff --git a/internal/quic/pipe_test.go b/quic/pipe_test.go similarity index 100% rename from internal/quic/pipe_test.go rename to quic/pipe_test.go diff --git a/internal/quic/qlog.go b/quic/qlog.go similarity index 100% rename from internal/quic/qlog.go rename to quic/qlog.go diff --git a/internal/quic/qlog/handler.go b/quic/qlog/handler.go similarity index 100% rename from internal/quic/qlog/handler.go rename to quic/qlog/handler.go diff --git a/internal/quic/qlog/json_writer.go b/quic/qlog/json_writer.go similarity index 100% rename from internal/quic/qlog/json_writer.go rename to quic/qlog/json_writer.go diff --git a/internal/quic/qlog/json_writer_test.go b/quic/qlog/json_writer_test.go similarity index 100% rename from internal/quic/qlog/json_writer_test.go rename to quic/qlog/json_writer_test.go diff --git a/internal/quic/qlog/qlog.go b/quic/qlog/qlog.go similarity index 100% rename from internal/quic/qlog/qlog.go rename to quic/qlog/qlog.go diff --git a/internal/quic/qlog/qlog_test.go b/quic/qlog/qlog_test.go similarity index 100% rename from internal/quic/qlog/qlog_test.go rename to quic/qlog/qlog_test.go diff --git a/internal/quic/qlog_test.go b/quic/qlog_test.go similarity index 99% rename from internal/quic/qlog_test.go rename to quic/qlog_test.go index 6c79c6cf4..c0b5cd170 100644 --- a/internal/quic/qlog_test.go +++ b/quic/qlog_test.go @@ -17,7 +17,7 @@ import ( "testing" "time" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic/qlog" ) func TestQLogHandshake(t *testing.T) { diff --git a/internal/quic/queue.go b/quic/queue.go similarity index 100% rename from internal/quic/queue.go rename to quic/queue.go diff --git a/internal/quic/queue_test.go b/quic/queue_test.go similarity index 100% rename from internal/quic/queue_test.go rename to quic/queue_test.go diff --git a/internal/quic/quic.go b/quic/quic.go similarity index 100% rename from internal/quic/quic.go rename to quic/quic.go diff --git a/internal/quic/quic_test.go b/quic/quic_test.go similarity index 100% rename from internal/quic/quic_test.go rename to quic/quic_test.go diff --git a/internal/quic/rangeset.go b/quic/rangeset.go similarity index 100% rename from internal/quic/rangeset.go rename to quic/rangeset.go diff --git a/internal/quic/rangeset_test.go b/quic/rangeset_test.go similarity index 100% rename from internal/quic/rangeset_test.go rename to quic/rangeset_test.go diff --git a/internal/quic/retry.go b/quic/retry.go similarity 
index 100% rename from internal/quic/retry.go rename to quic/retry.go diff --git a/internal/quic/retry_test.go b/quic/retry_test.go similarity index 100% rename from internal/quic/retry_test.go rename to quic/retry_test.go diff --git a/internal/quic/rtt.go b/quic/rtt.go similarity index 100% rename from internal/quic/rtt.go rename to quic/rtt.go diff --git a/internal/quic/rtt_test.go b/quic/rtt_test.go similarity index 100% rename from internal/quic/rtt_test.go rename to quic/rtt_test.go diff --git a/internal/quic/sent_packet.go b/quic/sent_packet.go similarity index 100% rename from internal/quic/sent_packet.go rename to quic/sent_packet.go diff --git a/internal/quic/sent_packet_list.go b/quic/sent_packet_list.go similarity index 100% rename from internal/quic/sent_packet_list.go rename to quic/sent_packet_list.go diff --git a/internal/quic/sent_packet_list_test.go b/quic/sent_packet_list_test.go similarity index 100% rename from internal/quic/sent_packet_list_test.go rename to quic/sent_packet_list_test.go diff --git a/internal/quic/sent_packet_test.go b/quic/sent_packet_test.go similarity index 100% rename from internal/quic/sent_packet_test.go rename to quic/sent_packet_test.go diff --git a/internal/quic/sent_val.go b/quic/sent_val.go similarity index 100% rename from internal/quic/sent_val.go rename to quic/sent_val.go diff --git a/internal/quic/sent_val_test.go b/quic/sent_val_test.go similarity index 100% rename from internal/quic/sent_val_test.go rename to quic/sent_val_test.go diff --git a/internal/quic/stateless_reset.go b/quic/stateless_reset.go similarity index 100% rename from internal/quic/stateless_reset.go rename to quic/stateless_reset.go diff --git a/internal/quic/stateless_reset_test.go b/quic/stateless_reset_test.go similarity index 100% rename from internal/quic/stateless_reset_test.go rename to quic/stateless_reset_test.go diff --git a/internal/quic/stream.go b/quic/stream.go similarity index 100% rename from internal/quic/stream.go rename to quic/stream.go diff --git a/internal/quic/stream_limits.go b/quic/stream_limits.go similarity index 100% rename from internal/quic/stream_limits.go rename to quic/stream_limits.go diff --git a/internal/quic/stream_limits_test.go b/quic/stream_limits_test.go similarity index 100% rename from internal/quic/stream_limits_test.go rename to quic/stream_limits_test.go diff --git a/internal/quic/stream_test.go b/quic/stream_test.go similarity index 100% rename from internal/quic/stream_test.go rename to quic/stream_test.go diff --git a/internal/quic/tls.go b/quic/tls.go similarity index 100% rename from internal/quic/tls.go rename to quic/tls.go diff --git a/internal/quic/tls_test.go b/quic/tls_test.go similarity index 100% rename from internal/quic/tls_test.go rename to quic/tls_test.go diff --git a/internal/quic/tlsconfig_test.go b/quic/tlsconfig_test.go similarity index 100% rename from internal/quic/tlsconfig_test.go rename to quic/tlsconfig_test.go diff --git a/internal/quic/transport_params.go b/quic/transport_params.go similarity index 100% rename from internal/quic/transport_params.go rename to quic/transport_params.go diff --git a/internal/quic/transport_params_test.go b/quic/transport_params_test.go similarity index 100% rename from internal/quic/transport_params_test.go rename to quic/transport_params_test.go diff --git a/internal/quic/udp.go b/quic/udp.go similarity index 100% rename from internal/quic/udp.go rename to quic/udp.go diff --git a/internal/quic/udp_darwin.go b/quic/udp_darwin.go similarity index 100% rename from 
internal/quic/udp_darwin.go rename to quic/udp_darwin.go diff --git a/internal/quic/udp_linux.go b/quic/udp_linux.go similarity index 100% rename from internal/quic/udp_linux.go rename to quic/udp_linux.go diff --git a/internal/quic/udp_msg.go b/quic/udp_msg.go similarity index 100% rename from internal/quic/udp_msg.go rename to quic/udp_msg.go diff --git a/internal/quic/udp_other.go b/quic/udp_other.go similarity index 100% rename from internal/quic/udp_other.go rename to quic/udp_other.go diff --git a/internal/quic/udp_test.go b/quic/udp_test.go similarity index 100% rename from internal/quic/udp_test.go rename to quic/udp_test.go diff --git a/internal/quic/version_test.go b/quic/version_test.go similarity index 100% rename from internal/quic/version_test.go rename to quic/version_test.go diff --git a/internal/quic/wire.go b/quic/wire.go similarity index 100% rename from internal/quic/wire.go rename to quic/wire.go diff --git a/internal/quic/wire_test.go b/quic/wire_test.go similarity index 100% rename from internal/quic/wire_test.go rename to quic/wire_test.go From 3dfd003ad338913e62ad1e56020aee316f1ffe59 Mon Sep 17 00:00:00 2001 From: Aleksei Besogonov Date: Fri, 12 Jan 2024 07:38:27 +0000 Subject: [PATCH 48/70] websocket: add support for dialing with context Right now there is no way to pass context.Context to websocket.Dial. In addition, this method can block indefinitely in the NewClient call. Fixes golang/go#57953. Change-Id: Ic52d4b8306cd0850e78d683abb1bf11f0d4247ca GitHub-Last-Rev: 5e8c3a7cbaa324d6165ff40d2ee9ea6c4433b036 GitHub-Pull-Request: golang/net#160 Reviewed-on: https://go-review.googlesource.com/c/net/+/463097 Auto-Submit: Damien Neil Reviewed-by: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Dmitri Shuralyov --- websocket/client.go | 56 +++++++++++++++++++++++++++++++++--------- websocket/dial.go | 11 ++++++--- websocket/dial_test.go | 37 ++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 15 deletions(-) diff --git a/websocket/client.go b/websocket/client.go index 69a4ac7ee..2c737f77a 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -6,10 +6,12 @@ package websocket import ( "bufio" + "context" "io" "net" "net/http" "net/url" + "time" ) // DialError is an error that occurs while dialling a websocket server. @@ -77,30 +79,60 @@ func parseAuthority(location *url.URL) string { return location.Host } -// DialConfig opens a new client connection to a WebSocket with a config. func DialConfig(config *Config) (ws *Conn, err error) { - var client net.Conn + return config.DialContext(context.Background()) +} + +// DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation. 
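A minimal usage sketch of the new DialContext API (illustrative only; the endpoint, origin, timeout, and helper name below are placeholders, not part of this change):

// dialWithTimeout shows how a caller might use Config.DialContext.
// Assumes: import "context", "time", and "golang.org/x/net/websocket".
func dialWithTimeout() (*websocket.Conn, error) {
	config, err := websocket.NewConfig("wss://example.invalid/echo", "https://example.invalid/")
	if err != nil {
		return nil, err
	}
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// If the context is canceled or times out before the handshake completes,
	// DialContext returns a *DialError carrying ctx.Err().
	return config.DialContext(ctx)
}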
+func (config *Config) DialContext(ctx context.Context) (*Conn, error) { if config.Location == nil { return nil, &DialError{config, ErrBadWebSocketLocation} } if config.Origin == nil { return nil, &DialError{config, ErrBadWebSocketOrigin} } + dialer := config.Dialer if dialer == nil { dialer = &net.Dialer{} } - client, err = dialWithDialer(dialer, config) - if err != nil { - goto Error - } - ws, err = NewClient(config, client) + + client, err := dialWithDialer(ctx, dialer, config) if err != nil { - client.Close() - goto Error + return nil, &DialError{config, err} } - return -Error: - return nil, &DialError{config, err} + // Cleanup the connection if we fail to create the websocket successfully + success := false + defer func() { + if !success { + _ = client.Close() + } + }() + + var ws *Conn + var wsErr error + doneConnecting := make(chan struct{}) + go func() { + defer close(doneConnecting) + ws, err = NewClient(config, client) + if err != nil { + wsErr = &DialError{config, err} + } + }() + + // The websocket.NewClient() function can block indefinitely, make sure that we + // respect the deadlines specified by the context. + select { + case <-ctx.Done(): + // Force the pending operations to fail, terminating the pending connection attempt + _ = client.SetDeadline(time.Now()) + <-doneConnecting // Wait for the goroutine that tries to establish the connection to finish + return nil, &DialError{config, ctx.Err()} + case <-doneConnecting: + if wsErr == nil { + success = true // Disarm the deferred connection cleanup + } + return ws, wsErr + } } diff --git a/websocket/dial.go b/websocket/dial.go index 2dab943a4..8a2d83c47 100644 --- a/websocket/dial.go +++ b/websocket/dial.go @@ -5,18 +5,23 @@ package websocket import ( + "context" "crypto/tls" "net" ) -func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) { +func dialWithDialer(ctx context.Context, dialer *net.Dialer, config *Config) (conn net.Conn, err error) { switch config.Location.Scheme { case "ws": - conn, err = dialer.Dial("tcp", parseAuthority(config.Location)) + conn, err = dialer.DialContext(ctx, "tcp", parseAuthority(config.Location)) case "wss": - conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig) + tlsDialer := &tls.Dialer{ + NetDialer: dialer, + Config: config.TlsConfig, + } + conn, err = tlsDialer.DialContext(ctx, "tcp", parseAuthority(config.Location)) default: err = ErrBadScheme } diff --git a/websocket/dial_test.go b/websocket/dial_test.go index aa03e30dd..dd844872c 100644 --- a/websocket/dial_test.go +++ b/websocket/dial_test.go @@ -5,10 +5,13 @@ package websocket import ( + "context" "crypto/tls" + "errors" "fmt" "log" "net" + "net/http" "net/http/httptest" "testing" "time" @@ -41,3 +44,37 @@ func TestDialConfigTLSWithDialer(t *testing.T) { t.Fatalf("expected timeout error, got %#v", neterr) } } + +func TestDialConfigTLSWithTimeouts(t *testing.T) { + t.Parallel() + + finishedRequest := make(chan bool) + + // Context for cancellation + ctx, cancel := context.WithCancel(context.Background()) + + // This is a TLS server that blocks each request indefinitely (and cancels the context) + tlsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cancel() + <-finishedRequest + })) + + tlsServerAddr := tlsServer.Listener.Addr().String() + log.Print("Test TLS WebSocket server listening on ", tlsServerAddr) + defer tlsServer.Close() + defer close(finishedRequest) + + config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", 
tlsServerAddr), "http://localhost") + config.TlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + _, err := config.DialContext(ctx) + dialerr, ok := err.(*DialError) + if !ok { + t.Fatalf("DialError expected, got %#v", err) + } + if !errors.Is(dialerr.Err, context.Canceled) { + t.Fatalf("context.Canceled error expected, got %#v", dialerr.Err) + } +} From 9fb4a8c9216d09f29d58e45a79cc2065d1b5bbf5 Mon Sep 17 00:00:00 2001 From: bestgopher <84328409@qq.com> Date: Tue, 6 Feb 2024 03:08:04 +0000 Subject: [PATCH 49/70] http2: send a FLOW_CONTROL_ERROR when the flow-control window exceeds the maximum According to RFC 9113 (https://www.rfc-editor.org/rfc/rfc9113.html#section-6.9.1-7), if a sender receives a WINDOW_UPDATE that causes a flow-control window to exceed the maximum, it MUST terminate either the stream or the connection, as appropriate. For streams, the sender sends a RST_STREAM with an error code of FLOW_CONTROL_ERROR. Change-Id: I5e14db247012ebc860a23053f73e70b83c7cd85d GitHub-Last-Rev: d1a85d3381f634904fc292c9c0a920dd1341adfd GitHub-Pull-Request: golang/net#204 Reviewed-on: https://go-review.googlesource.com/c/net/+/561035 Auto-Submit: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Carlos Amedee Reviewed-by: Damien Neil --- http2/transport.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/http2/transport.go b/http2/transport.go index df578b86c..c2a5b44b3 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -2911,6 +2911,15 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { fl = &cs.flow } if !fl.add(int32(f.Increment)) { + // For streams, the sender sends RST_STREAM with an error code of FLOW_CONTROL_ERROR + if cs != nil { + rl.endStreamError(cs, StreamError{ + StreamID: f.StreamID, + Code: ErrCodeFlowControl, + }) + return nil + } + return ConnectionError(ErrCodeFlowControl) } cc.cond.Broadcast() From c289c7ab4f437bb502e685daaae72426126d5595 Mon Sep 17 00:00:00 2001 From: Dmitri Shuralyov Date: Sat, 2 Mar 2024 18:37:38 -0500 Subject: [PATCH 50/70] websocket: re-add documentation for DialConfig The comment of the DialConfig function was dropped during CL 463097. There doesn't seem to be a good reason to do that, so bring it back. For golang/go#57953. Change-Id: I3e458b7d18cdab95763f003da5a644c8287b54ad Reviewed-on: https://go-review.googlesource.com/c/net/+/568198 Reviewed-by: Dmitri Shuralyov LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil Auto-Submit: Dmitri Shuralyov --- websocket/client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/websocket/client.go b/websocket/client.go index 2c737f77a..1e64157f3 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -79,6 +79,7 @@ func parseAuthority(location *url.URL) string { return location.Host } +// DialConfig opens a new client connection to a WebSocket with a config. func DialConfig(config *Config) (ws *Conn, err error) { return config.DialContext(context.Background()) } From 7ee34a078aecd23a99f205bded144e5246a27d7c Mon Sep 17 00:00:00 2001 From: Gopher Robot Date: Mon, 4 Mar 2024 18:45:30 +0000 Subject: [PATCH 51/70] go.mod: update golang.org/x dependencies Update golang.org/x dependencies to their latest tagged versions.
Change-Id: I6d2aa8edee71b255fb6970eb5d817a20df7cc357 Reviewed-on: https://go-review.googlesource.com/c/net/+/568895 Auto-Submit: Gopher Robot Reviewed-by: Than McIntosh LUCI-TryBot-Result: Go LUCI Reviewed-by: Michael Knyszek --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 7f512d703..36207106d 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module golang.org/x/net go 1.18 require ( - golang.org/x/crypto v0.19.0 - golang.org/x/sys v0.17.0 - golang.org/x/term v0.17.0 + golang.org/x/crypto v0.21.0 + golang.org/x/sys v0.18.0 + golang.org/x/term v0.18.0 golang.org/x/text v0.14.0 ) diff --git a/go.sum b/go.sum index 683b469d6..69fb10498 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= From ab271c317248ea0f18481852f96d12d5eca05cf8 Mon Sep 17 00:00:00 2001 From: David Bell Date: Tue, 29 Aug 2023 20:50:33 +0000 Subject: [PATCH 52/70] http2: add IdleConnTimeout to http2.Transport Exposes an IdleConnTimeout on http2.Transport directly, rather than rely on configuring it through the underlying http1 transport. For golang/go#57893 Change-Id: Ibe506da39e314aebec1cd6df64937982182a37ca GitHub-Last-Rev: cc8f1710ed543da8e937aa2446b0a3982dec6ce3 GitHub-Pull-Request: golang/net#173 Reviewed-on: https://go-review.googlesource.com/c/net/+/497195 Reviewed-by: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Matthew Dempsky --- http2/transport.go | 14 ++++++++++ http2/transport_test.go | 62 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/http2/transport.go b/http2/transport.go index c2a5b44b3..b599197e7 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -147,6 +147,12 @@ type Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + // ReadIdleTimeout is the timeout after which a health check using ping // frame will be carried out if no frame is received on the connection. 
// Note that a ping response will is considered a received frame, so if @@ -3150,9 +3156,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err } func (t *Transport) idleConnTimeout() time.Duration { + // to keep things backwards compatible, we use non-zero values of + // IdleConnTimeout, followed by using the IdleConnTimeout on the underlying + // http1 transport, followed by 0 + if t.IdleConnTimeout != 0 { + return t.IdleConnTimeout + } + if t.t1 != nil { return t.t1.IdleConnTimeout } + return 0 } diff --git a/http2/transport_test.go b/http2/transport_test.go index a81131f29..6ac8e978b 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -95,6 +95,68 @@ func startH2cServer(t *testing.T) net.Listener { return l } +func TestIdleConnTimeout(t *testing.T) { + for _, test := range []struct { + idleConnTimeout time.Duration + wait time.Duration + baseTransport *http.Transport + wantConns int32 + }{{ + idleConnTimeout: 2 * time.Second, + wait: 1 * time.Second, + baseTransport: nil, + wantConns: 1, + }, { + idleConnTimeout: 1 * time.Second, + wait: 2 * time.Second, + baseTransport: nil, + wantConns: 5, + }, { + idleConnTimeout: 0 * time.Second, + wait: 1 * time.Second, + baseTransport: &http.Transport{ + IdleConnTimeout: 2 * time.Second, + }, + wantConns: 1, + }} { + var gotConns int32 + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + }, optOnlyServer) + defer st.Close() + + tr := &Transport{ + IdleConnTimeout: test.idleConnTimeout, + TLSClientConfig: tlsConfigInsecure, + } + defer tr.CloseIdleConnections() + + for i := 0; i < 5; i++ { + req, _ := http.NewRequest("GET", st.ts.URL, http.NoBody) + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConns, 1) + } + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + _, err := tr.RoundTrip(req) + if err != nil { + t.Fatalf("%v", err) + } + + <-time.After(test.wait) + } + + if gotConns != test.wantConns { + t.Errorf("incorrect gotConns: %d != %d", gotConns, test.wantConns) + } + } +} + func TestTransportH2c(t *testing.T) { l := startH2cServer(t) defer l.Close() From 8c07e20f924fb9dec8d39d2793f72a42c3261a7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E7=8E=AE=E6=96=87?= Date: Sun, 3 Sep 2023 14:21:31 +0800 Subject: [PATCH 53/70] httpproxy: allow any scheme Currently only the http, https, and socks5 schemes are allowed. However, any scheme could be valid if the user provides their own implementation. Specifically, the widely used "socks5h://localhost" is parsed as Scheme="http" Host="socks5h:", which does not make sense because a host name cannot contain ":". This patch allows any scheme to appear in the proxy config, and only falls back to the http scheme if the parsed scheme or host is empty.
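For example, a SOCKS5h proxy URL now passes through ProxyFunc with its scheme intact (an illustrative sketch; the proxy address and target URL are placeholders):

package main

import (
	"fmt"
	"net/url"

	"golang.org/x/net/http/httpproxy"
)

func main() {
	cfg := &httpproxy.Config{HTTPProxy: "socks5h://localhost:1080"}
	proxy := cfg.ProxyFunc()

	target, _ := url.Parse("http://example.com/")
	proxyURL, err := proxy(target)
	if err != nil {
		fmt.Println("proxy error:", err)
		return
	}
	// Previously the scheme check rejected socks5h and the URL was mangled;
	// now the configured scheme is preserved.
	fmt.Println(proxyURL) // socks5h://localhost:1080
}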
url.Parse() result of fallback cases: localhost => Scheme="localhost" localhost:1234 => Scheme="localhost" Opaque="1234" example.com => Path="example.com" Updates golang/go#24135 Change-Id: Ia2c041e37e2ac61be16220fd41d6cb6fabeeca3d Reviewed-on: https://go-review.googlesource.com/c/net/+/525257 LUCI-TryBot-Result: Go LUCI Run-TryBot: Damien Neil Reviewed-by: Michael Knyszek Reviewed-by: Damien Neil TryBot-Result: Gopher Robot Auto-Submit: Damien Neil --- http/httpproxy/proxy.go | 5 +---- http/httpproxy/proxy_test.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go index c3bd9a1ee..6404aaf15 100644 --- a/http/httpproxy/proxy.go +++ b/http/httpproxy/proxy.go @@ -149,10 +149,7 @@ func parseProxy(proxy string) (*url.URL, error) { } proxyURL, err := url.Parse(proxy) - if err != nil || - (proxyURL.Scheme != "http" && - proxyURL.Scheme != "https" && - proxyURL.Scheme != "socks5") { + if err != nil || proxyURL.Scheme == "" || proxyURL.Host == "" { // proxy was bogus. Try prepending "http://" to it and // see if that parses correctly. If not, we fall // through and complain about the original one. diff --git a/http/httpproxy/proxy_test.go b/http/httpproxy/proxy_test.go index d76373295..790afdab7 100644 --- a/http/httpproxy/proxy_test.go +++ b/http/httpproxy/proxy_test.go @@ -68,6 +68,12 @@ var proxyForURLTests = []proxyForURLTest{{ HTTPProxy: "cache.corp.example.com", }, want: "http://cache.corp.example.com", +}, { + // single label domain is recognized as scheme by url.Parse + cfg: httpproxy.Config{ + HTTPProxy: "localhost", + }, + want: "http://localhost", }, { cfg: httpproxy.Config{ HTTPProxy: "https://cache.corp.example.com", @@ -88,6 +94,12 @@ var proxyForURLTests = []proxyForURLTest{{ HTTPProxy: "socks5://127.0.0.1", }, want: "socks5://127.0.0.1", +}, { + // Preserve unknown schemes. + cfg: httpproxy.Config{ + HTTPProxy: "foo://host", + }, + want: "foo://host", }, { // Don't use secure for http cfg: httpproxy.Config{ From ea095bc79b94b4bdc6939ce2dbd0300520089a1f Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Mon, 19 Feb 2024 19:53:31 +0800 Subject: [PATCH 54/70] http2: only set up positive deadlines Fixes golang/go#65785 Change-Id: Icd95d7cae5ed26b8a2fe656daf8365e27a7785d8 Reviewed-on: https://go-review.googlesource.com/c/net/+/565195 LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil Reviewed-by: Carlos Amedee --- http2/server.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/http2/server.go b/http2/server.go index ae94c6408..905206f3e 100644 --- a/http2/server.go +++ b/http2/server.go @@ -434,7 +434,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { // passes the connection off to us with the deadline already set. // Write deadlines are set per stream in serverConn.newStream. // Disarm the net.Conn write deadline here. - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { sc.conn.SetWriteDeadline(time.Time{}) } @@ -2017,7 +2017,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // similar to how the http1 server works. Here it's // technically more like the http1 Server's ReadHeaderTimeout // (in Go 1.8), though. That's a more sane option anyway. 
- if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } @@ -2038,7 +2038,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) { // Disable any read deadline set by the net/http package // prior to the upgrade. - if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) } @@ -2116,7 +2116,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.flow.conn = &sc.flow // link to conn-level counter st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.srv.initialStreamRecvWindowSize()) - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } From 57a6a7a86bc0e47508781ed988adcecbe8ff2580 Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Sun, 25 Feb 2024 15:09:48 +0800 Subject: [PATCH 55/70] http2: prevent uninitialized pipe from being written For golang/go#65927 Change-Id: I6f48706156384e026968cf9a6d9e0ec76b46fabf Reviewed-on: https://go-review.googlesource.com/c/net/+/566675 Reviewed-by: Damien Neil Reviewed-by: Carlos Amedee LUCI-TryBot-Result: Go LUCI --- http2/pipe.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/http2/pipe.go b/http2/pipe.go index 684d984fd..3b9f06b96 100644 --- a/http2/pipe.go +++ b/http2/pipe.go @@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) { } } -var errClosedPipeWrite = errors.New("write on closed buffer") +var ( + errClosedPipeWrite = errors.New("write on closed buffer") + errUninitializedPipeWrite = errors.New("write on uninitialized buffer") +) // Write copies bytes from p into the buffer and wakes a reader. // It is an error to write more data than the buffer can hold. @@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) { if p.err != nil || p.breakErr != nil { return 0, errClosedPipeWrite } + // pipe.setBuffer is never invoked, leaving the buffer uninitialized. + // We shouldn't try to write to an uninitialized pipe, + // but returning an error is better than panicking. + if p.b == nil { + return 0, errUninitializedPipeWrite + } return p.b.Write(d) } From d600ae05799943851536e26ab37ee23294912c3d Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 9 Feb 2024 18:11:52 -0800 Subject: [PATCH 56/70] http2: add testClientConn for testing client RoundTrips Many RoundTrip tests involve testing against a test-defined server with specific behaviors. For example: Testing RoundTrip's behavior when the server violates flow control limits. Existing tests mostly use the clientTester type, which starts separate goroutines for the Transport and a fake server. This results in tests where the control flow bounces around the test function, and requires each test to manage its own synchronization. Introduce a new framework for writing RoundTrip tests. testClientConn allows client tests to be written linearly, with synchronization provided by the test framework. For example, a testClientConn test can, as a linear sequence of actions: - start RoundTrip; - check the request headers sent; - provide data to the request body; - check that a DATA frame is sent; - send response headers from the server to the client; - check that RoundTrip returns. See TestTestClientConn at the top of clientconn_test.go for a full example. 
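Condensed, such a test takes roughly the following shape (an illustrative sketch built from the helpers added in this CL; the test name and request are placeholders):

func TestRoundTripSketch(t *testing.T) {
	tc := newTestClientConn(t)
	tc.greet() // exchange the initial SETTINGS/WINDOW_UPDATE frames

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)        // start RoundTrip; does not block
	tc.wantFrameType(FrameHeaders) // the request headers were written

	tc.writeHeaders(HeadersFrameParam{ // fake server sends the response
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     true,
		BlockFragment: tc.makeHeaderBlockFragment(":status", "200"),
	})

	rt.wantStatus(200) // RoundTrip has returned with a 200 response
}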
To enable synchronization with tests, this CL instruments the RoundTrip path to record when goroutines start, exit, and block waiting for events. This adds a certain amount of noise and bookkeeping to the client implementation, but (in my opinion) this is more than repaid in improved testability. The testClientConn also permits use of synthetic time in tests. At the moment, this is limited to the response header timeout, but extending it to other timeouts (read, 100-continue) should be straightforward. This CL converts a number of existing clientTester tests to use the new framework, but not all. Change-Id: Ief963889969363ec8469cd3c3de0becb2fc548f9 Reviewed-on: https://go-review.googlesource.com/c/net/+/563540 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- http2/clientconn_test.go | 634 +++++++++++++++++++++ http2/testsync.go | 246 ++++++++ http2/transport.go | 145 ++++- http2/transport_test.go | 1154 ++++++++++++++------------------------ 4 files changed, 1414 insertions(+), 765 deletions(-) create mode 100644 http2/clientconn_test.go create mode 100644 http2/testsync.go diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go new file mode 100644 index 000000000..6d94762e5 --- /dev/null +++ b/http2/clientconn_test.go @@ -0,0 +1,634 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Infrastructure for testing ClientConn.RoundTrip. +// Put actual tests in transport_test.go. + +package http2 + +import ( + "bytes" + "fmt" + "io" + "net" + "net/http" + "reflect" + "testing" + "time" + + "golang.org/x/net/http2/hpack" +) + +// TestTestClientConn demonstrates usage of testClientConn. +func TestTestClientConn(t *testing.T) { + // newTestClientConn creates a *ClientConn and surrounding test infrastructure. + tc := newTestClientConn(t) + + // tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames, + // and sends a SETTINGS frame to the client. + // + // Additional settings may be provided as optional parameters to greet. + tc.greet() + + // Request bodies must either be constant (bytes.Buffer, strings.Reader) + // or created with newRequestBody. + body := tc.newRequestBody() + body.writeBytes(10) // 10 arbitrary bytes... + body.closeWithError(io.EOF) // ...followed by EOF. + + // tc.roundTrip calls RoundTrip, but does not wait for it to return. + // It returns a testRoundTrip. + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + // tc has a number of methods to check for expected frames sent. + // Here, we look for headers and the request body. + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: false, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"PUT"}, + ":path": []string{"/"}, + }, + }) + // Expect 10 bytes of request body in DATA frames. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: 10, + }) + + // tc.writeHeaders sends a HEADERS frame back to the client. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + // Now that we've received headers, RoundTrip has finished. 
+ // testRoundTrip has various methods to examine the response, + // or to fetch the response and/or error returned by RoundTrip + rt.wantStatus(200) + rt.wantBody(nil) +} + +// A testClientConn allows testing ClientConn.RoundTrip against a fake server. +// +// A test using testClientConn consists of: +// - actions on the client (calling RoundTrip, making data available to Request.Body); +// - validation of frames sent by the client to the server; and +// - providing frames from the server to the client. +// +// testClientConn manages synchronization, so tests can generally be written as +// a linear sequence of actions and validations without additional synchronization. +type testClientConn struct { + t *testing.T + + tr *Transport + fr *Framer + cc *ClientConn + hooks *testSyncHooks + + encbuf bytes.Buffer + enc *hpack.Encoder + + roundtrips []*testRoundTrip + + rerr error // returned by Read + rbuf bytes.Buffer // sent to the test conn + wbuf bytes.Buffer // sent by the test conn +} + +func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { + t.Helper() + + tr := &Transport{} + for _, o := range opts { + o(tr) + } + + tc := &testClientConn{ + t: t, + tr: tr, + hooks: newTestSyncHooks(), + } + tc.enc = hpack.NewEncoder(&tc.encbuf) + tc.fr = NewFramer(&tc.rbuf, &tc.wbuf) + tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) + tc.fr.SetMaxReadFrameSize(10 << 20) + + t.Cleanup(func() { + if tc.rerr == nil { + tc.rerr = io.EOF + } + tc.sync() + if tc.hooks.total != 0 { + t.Errorf("%v goroutines still running after test completed", tc.hooks.total) + } + + }) + + tc.hooks.newclientconn = func(cc *ClientConn) { + tc.cc = cc + } + const singleUse = false + _, err := tc.tr.newClientConn((*testClientConnNetConn)(tc), singleUse, tc.hooks) + if err != nil { + t.Fatal(err) + } + tc.sync() + tc.hooks.newclientconn = nil + + // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. + buf := make([]byte, len(clientPreface)) + if _, err := io.ReadFull(&tc.wbuf, buf); err != nil { + t.Fatalf("reading preface: %v", err) + } + if !bytes.Equal(buf, clientPreface) { + t.Fatalf("client preface: %q, want %q", buf, clientPreface) + } + + return tc +} + +// sync waits for the ClientConn under test to reach a stable state, +// with all goroutines blocked on some input. +func (tc *testClientConn) sync() { + tc.hooks.waitInactive() +} + +// advance advances synthetic time by a duration. +func (tc *testClientConn) advance(d time.Duration) { + tc.hooks.advance(d) + tc.sync() +} + +// hasFrame reports whether a frame is available to be read. +func (tc *testClientConn) hasFrame() bool { + return tc.wbuf.Len() > 0 +} + +// readFrame reads the next frame from the conn. +func (tc *testClientConn) readFrame() Frame { + if tc.wbuf.Len() == 0 { + return nil + } + fr, err := tc.fr.ReadFrame() + if err != nil { + return nil + } + return fr +} + +// testClientConnReadFrame reads a frame of a specific type from the conn. +func testClientConnReadFrame[T any](tc *testClientConn) T { + tc.t.Helper() + var v T + fr := tc.readFrame() + if fr == nil { + tc.t.Fatalf("got no frame, want frame %v", v) + } + v, ok := fr.(T) + if !ok { + tc.t.Fatalf("got frame %T, want %T", fr, v) + } + return v +} + +// wantFrameType reads the next frame from the conn. +// It produces an error if the frame type is not the expected value. 
+func (tc *testClientConn) wantFrameType(want FrameType) { + fr := tc.readFrame() + if fr == nil { + tc.t.Fatalf("got no frame, want frame %v", want) + } + if got := fr.Header().Type; got != want { + tc.t.Fatalf("got frame %v, want %v", got, want) + } +} + +type wantHeader struct { + streamID uint32 + endStream bool + header http.Header +} + +// wantHeaders reads a HEADERS frame and potential CONTINUATION frames, +// and asserts that they contain the expected headers. +func (tc *testClientConn) wantHeaders(want wantHeader) { + fr := tc.readFrame() + got, ok := fr.(*MetaHeadersFrame) + if !ok { + tc.t.Fatalf("got %v, want HEADERS frame", want) + } + if got, want := got.StreamID, want.streamID; got != want { + tc.t.Fatalf("got stream ID %v, want %v", got, want) + } + if got, want := got.StreamEnded(), want.endStream; got != want { + tc.t.Fatalf("got stream ended %v, want %v", got, want) + } + gotHeader := make(http.Header) + for _, f := range got.Fields { + gotHeader[f.Name] = append(gotHeader[f.Name], f.Value) + } + for k, v := range want.header { + if !reflect.DeepEqual(v, gotHeader[k]) { + tc.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k]) + } + } +} + +type wantData struct { + streamID uint32 + endStream bool + size int +} + +// wantData reads zero or more DATA frames, and asserts that they match the expectation. +func (tc *testClientConn) wantData(want wantData) { + tc.t.Helper() + gotSize := 0 + gotEndStream := false + for tc.hasFrame() && !gotEndStream { + data := testClientConnReadFrame[*DataFrame](tc) + gotSize += len(data.Data()) + if data.StreamEnded() { + gotEndStream = true + } + } + if gotSize != want.size { + tc.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size) + } + if gotEndStream != want.endStream { + tc.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream) + } +} + +// testRequestBody is a Request.Body for use in tests. +type testRequestBody struct { + tc *testClientConn + + // At most one of buf or bytes can be set at any given time: + buf bytes.Buffer // specific bytes to read from the body + bytes int // body contains this many arbitrary bytes + + err error // read error (comes after any available bytes) +} + +func (tc *testClientConn) newRequestBody() *testRequestBody { + b := &testRequestBody{ + tc: tc, + } + return b +} + +// Read is called by the ClientConn to read from a request body. +func (b *testRequestBody) Read(p []byte) (n int, _ error) { + b.tc.cc.syncHooks.blockUntil(func() bool { + return b.buf.Len() > 0 || b.bytes > 0 || b.err != nil + }) + switch { + case b.buf.Len() > 0: + return b.buf.Read(p) + case b.bytes > 0: + if len(p) > b.bytes { + p = p[:b.bytes] + } + b.bytes -= len(p) + for i := range p { + p[i] = 'A' + } + return len(p), nil + default: + return 0, b.err + } +} + +// Close is called by the ClientConn when it is done reading from a request body. +func (b *testRequestBody) Close() error { + return nil +} + +// writeBytes adds n arbitrary bytes to the body. +func (b *testRequestBody) writeBytes(n int) { + b.bytes += n + b.checkWrite() + b.tc.sync() +} + +// Write adds bytes to the body. 
+func (b *testRequestBody) Write(p []byte) (int, error) { + n, err := b.buf.Write(p) + b.checkWrite() + b.tc.sync() + return n, err +} + +func (b *testRequestBody) checkWrite() { + if b.bytes > 0 && b.buf.Len() > 0 { + b.tc.t.Fatalf("can't interleave Write and writeBytes on request body") + } + if b.err != nil { + b.tc.t.Fatalf("can't write to request body after closeWithError") + } +} + +// closeWithError sets an error which will be returned by Read. +func (b *testRequestBody) closeWithError(err error) { + b.err = err + b.tc.sync() +} + +// roundTrip starts a RoundTrip call. +// +// (Note that the RoundTrip won't complete until response headers are received, +// the request times out, or some other terminal condition is reached.) +func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{ + tc: tc, + donec: make(chan struct{}), + } + tc.roundtrips = append(tc.roundtrips, rt) + tc.hooks.newstream = func(cs *clientStream) { rt.cs = cs } + tc.cc.goRun(func() { + defer close(rt.donec) + rt.resp, rt.respErr = tc.cc.RoundTrip(req) + }) + tc.sync() + tc.hooks.newstream = nil + + tc.t.Cleanup(func() { + res, _ := rt.result() + if res != nil { + res.Body.Close() + } + }) + + return rt +} + +func (tc *testClientConn) greet(settings ...Setting) { + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.writeSettings(settings...) + tc.writeSettingsAck() + tc.wantFrameType(FrameSettings) // acknowledgement +} + +func (tc *testClientConn) writeSettings(settings ...Setting) { + tc.t.Helper() + if err := tc.fr.WriteSettings(settings...); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeSettingsAck() { + tc.t.Helper() + if err := tc.fr.WriteSettingsAck(); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeData(streamID uint32, endStream bool, data []byte) { + tc.t.Helper() + if err := tc.fr.WriteData(streamID, endStream, data); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// makeHeaderBlockFragment encodes headers in a form suitable for inclusion +// in a HEADERS or CONTINUATION frame. +// +// It takes a list of alernating names and values. +func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte { + if len(s)%2 != 0 { + tc.t.Fatalf("uneven list of header name/value pairs") + } + tc.encbuf.Reset() + for i := 0; i < len(s); i += 2 { + tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]}) + } + return tc.encbuf.Bytes() +} + +func (tc *testClientConn) writeHeaders(p HeadersFrameParam) { + tc.t.Helper() + if err := tc.fr.WriteHeaders(p); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// writeHeadersMode writes header frames, as modified by mode: +// +// - noHeader: Don't write the header. +// - oneHeader: Write a single HEADERS frame. +// - splitHeader: Write a HEADERS frame and CONTINUATION frame. 
+func (tc *testClientConn) writeHeadersMode(mode headerType, p HeadersFrameParam) { + tc.t.Helper() + switch mode { + case noHeader: + case oneHeader: + tc.writeHeaders(p) + case splitHeader: + if len(p.BlockFragment) < 2 { + panic("too small") + } + contData := p.BlockFragment[1:] + contEnd := p.EndHeaders + p.BlockFragment = p.BlockFragment[:1] + p.EndHeaders = false + tc.writeHeaders(p) + tc.writeContinuation(p.StreamID, contEnd, contData) + default: + panic("bogus mode") + } +} + +func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) { + tc.t.Helper() + if err := tc.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) { + tc.t.Helper() + if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeWindowUpdate(streamID, incr uint32) { + tc.t.Helper() + if err := tc.fr.WriteWindowUpdate(streamID, incr); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// closeWrite causes the net.Conn used by the ClientConn to return a error +// from Read calls. +func (tc *testClientConn) closeWrite(err error) { + tc.rerr = err + tc.sync() +} + +// testRoundTrip manages a RoundTrip in progress. +type testRoundTrip struct { + tc *testClientConn + resp *http.Response + respErr error + donec chan struct{} + cs *clientStream +} + +// streamID returns the HTTP/2 stream ID of the request. +func (rt *testRoundTrip) streamID() uint32 { + return rt.cs.ID +} + +// done reports whether RoundTrip has returned. +func (rt *testRoundTrip) done() bool { + select { + case <-rt.donec: + return true + default: + return false + } +} + +// result returns the result of the RoundTrip. +func (rt *testRoundTrip) result() (*http.Response, error) { + t := rt.tc.t + t.Helper() + select { + case <-rt.donec: + default: + t.Fatalf("RoundTrip (stream %v) is not done; want it to be", rt.streamID()) + } + return rt.resp, rt.respErr +} + +// response returns the response of a successful RoundTrip. +// If the RoundTrip unexpectedly failed, it calls t.Fatal. +func (rt *testRoundTrip) response() *http.Response { + t := rt.tc.t + t.Helper() + resp, err := rt.result() + if err != nil { + t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr) + } + if resp == nil { + t.Fatalf("RoundTrip returned nil *Response and nil error") + } + return resp +} + +// err returns the (possibly nil) error result of RoundTrip. +func (rt *testRoundTrip) err() error { + t := rt.tc.t + t.Helper() + _, err := rt.result() + return err +} + +// wantStatus indicates the expected response StatusCode. +func (rt *testRoundTrip) wantStatus(want int) { + t := rt.tc.t + t.Helper() + if got := rt.response().StatusCode; got != want { + t.Fatalf("got response status %v, want %v", got, want) + } +} + +// body reads the contents of the response body. +func (rt *testRoundTrip) readBody() ([]byte, error) { + t := rt.tc.t + t.Helper() + return io.ReadAll(rt.response().Body) +} + +// wantBody indicates the expected response body. +// (Note that this consumes the body.) 
+func (rt *testRoundTrip) wantBody(want []byte) { + t := rt.tc.t + t.Helper() + got, err := rt.readBody() + if err != nil { + t.Fatalf("unexpected error reading response body: %v", err) + } + if !bytes.Equal(got, want) { + t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want) + } +} + +// wantHeaders indicates the expected response headers. +func (rt *testRoundTrip) wantHeaders(want http.Header) { + t := rt.tc.t + t.Helper() + res := rt.response() + if diff := diffHeaders(res.Header, want); diff != "" { + t.Fatalf("unexpected response headers:\n%v", diff) + } +} + +// wantTrailers indicates the expected response trailers. +func (rt *testRoundTrip) wantTrailers(want http.Header) { + t := rt.tc.t + t.Helper() + res := rt.response() + if diff := diffHeaders(res.Trailer, want); diff != "" { + t.Fatalf("unexpected response trailers:\n%v", diff) + } +} + +func diffHeaders(got, want http.Header) string { + // nil and 0-length non-nil are equal. + if len(got) == 0 && len(want) == 0 { + return "" + } + // We could do a more sophisticated diff here. + // DeepEqual is good enough for now. + if reflect.DeepEqual(got, want) { + return "" + } + return fmt.Sprintf("got: %v\nwant: %v", got, want) +} + +// testClientConnNetConn implements net.Conn. +type testClientConnNetConn testClientConn + +func (nc *testClientConnNetConn) Read(b []byte) (n int, err error) { + nc.cc.syncHooks.blockUntil(func() bool { + return nc.rerr != nil || nc.rbuf.Len() > 0 + }) + if nc.rbuf.Len() > 0 { + return nc.rbuf.Read(b) + } + return 0, nc.rerr +} + +func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) { + return nc.wbuf.Write(b) +} + +func (*testClientConnNetConn) Close() error { + return nil +} + +func (*testClientConnNetConn) LocalAddr() (_ net.Addr) { return } +func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return } +func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil } +func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil } +func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/http2/testsync.go b/http2/testsync.go new file mode 100644 index 000000000..b8335c0fb --- /dev/null +++ b/http2/testsync.go @@ -0,0 +1,246 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package http2 + +import ( + "sync" + "time" +) + +// testSyncHooks coordinates goroutines in tests. +// +// For example, a call to ClientConn.RoundTrip involves several goroutines, including: +// - the goroutine running RoundTrip; +// - the clientStream.doRequest goroutine, which writes the request; and +// - the clientStream.readLoop goroutine, which reads the response. +// +// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines +// are blocked waiting for some condition such as reading the Request.Body or waiting for +// flow control to become available. +// +// The testSyncHooks also manage timers and synthetic time in tests. +// This permits us to, for example, start a request and cause it to time out waiting for +// response headers without resorting to time.Sleep calls. +type testSyncHooks struct { + // active/inactive act as a mutex and condition variable. + // + // - neither chan contains a value: testSyncHooks is locked. 
+ // - active contains a value: unlocked, and at least one goroutine is not blocked + // - inactive contains a value: unlocked, and all goroutines are blocked + active chan struct{} + inactive chan struct{} + + // goroutine counts + total int // total goroutines + condwait map[*sync.Cond]int // blocked in sync.Cond.Wait + blocked []*testBlockedGoroutine // otherwise blocked + + // fake time + now time.Time + timers []*fakeTimer + + // Transport testing: Report various events. + newclientconn func(*ClientConn) + newstream func(*clientStream) +} + +// testBlockedGoroutine is a blocked goroutine. +type testBlockedGoroutine struct { + f func() bool // blocked until f returns true + ch chan struct{} // closed when unblocked +} + +func newTestSyncHooks() *testSyncHooks { + h := &testSyncHooks{ + active: make(chan struct{}, 1), + inactive: make(chan struct{}, 1), + condwait: map[*sync.Cond]int{}, + } + h.inactive <- struct{}{} + h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + return h +} + +// lock acquires the testSyncHooks mutex. +func (h *testSyncHooks) lock() { + select { + case <-h.active: + case <-h.inactive: + } +} + +// waitInactive waits for all goroutines to become inactive. +func (h *testSyncHooks) waitInactive() { + for { + <-h.inactive + if !h.unlock() { + break + } + } +} + +// unlock releases the testSyncHooks mutex. +// It reports whether any goroutines are active. +func (h *testSyncHooks) unlock() (active bool) { + // Look for a blocked goroutine which can be unblocked. + blocked := h.blocked[:0] + unblocked := false + for _, b := range h.blocked { + if !unblocked && b.f() { + unblocked = true + close(b.ch) + } else { + blocked = append(blocked, b) + } + } + h.blocked = blocked + + // Count goroutines blocked on condition variables. + condwait := 0 + for _, count := range h.condwait { + condwait += count + } + + if h.total > condwait+len(blocked) { + h.active <- struct{}{} + return true + } else { + h.inactive <- struct{}{} + return false + } +} + +// goRun starts a new goroutine. +func (h *testSyncHooks) goRun(f func()) { + h.lock() + h.total++ + h.unlock() + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + f() + }() +} + +// blockUntil indicates that a goroutine is blocked waiting for some condition to become true. +// It waits until f returns true before proceeding. +// +// Example usage: +// +// h.blockUntil(func() bool { +// // Is the context done yet? +// select { +// case <-ctx.Done(): +// default: +// return false +// } +// return true +// }) +// // Wait for the context to become done. +// <-ctx.Done() +// +// The function f passed to blockUntil must be non-blocking and idempotent. +func (h *testSyncHooks) blockUntil(f func() bool) { + if f() { + return + } + ch := make(chan struct{}) + h.lock() + h.blocked = append(h.blocked, &testBlockedGoroutine{ + f: f, + ch: ch, + }) + h.unlock() + <-ch +} + +// broadcast is sync.Cond.Broadcast. +func (h *testSyncHooks) condBroadcast(cond *sync.Cond) { + h.lock() + delete(h.condwait, cond) + h.unlock() + cond.Broadcast() +} + +// broadcast is sync.Cond.Wait. +func (h *testSyncHooks) condWait(cond *sync.Cond) { + h.lock() + h.condwait[cond]++ + h.unlock() +} + +// newTimer creates a new timer: A time.Timer if h is nil, or a synthetic timer in tests. 
+func (h *testSyncHooks) newTimer(d time.Duration) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + when: h.now.Add(d), + c: make(chan time.Time), + } + h.timers = append(h.timers, t) + return t +} + +// advance advances time and causes synthetic timers to fire. +func (h *testSyncHooks) advance(d time.Duration) { + h.lock() + defer h.unlock() + h.now = h.now.Add(d) + timers := h.timers[:0] + for _, t := range h.timers { + t.mu.Lock() + switch { + case t.when.After(h.now): + timers = append(timers, t) + case t.when.IsZero(): + // stopped timer + default: + t.when = time.Time{} + close(t.c) + } + t.mu.Unlock() + } + h.timers = timers +} + +// A timer wraps a time.Timer, or a synthetic equivalent in tests. +// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires. +type timer interface { + C() <-chan time.Time + Stop() bool +} + +type timeTimer struct { + t *time.Timer + c chan time.Time +} + +func newTimeTimer(d time.Duration) timer { + ch := make(chan time.Time) + t := time.AfterFunc(d, func() { + close(ch) + }) + return &timeTimer{t, ch} +} + +func (t timeTimer) C() <-chan time.Time { return t.c } +func (t timeTimer) Stop() bool { return t.t.Stop() } + +type fakeTimer struct { + mu sync.Mutex + when time.Time + c chan time.Time +} + +func (t *fakeTimer) C() <-chan time.Time { return t.c } +func (t *fakeTimer) Stop() bool { + t.mu.Lock() + defer t.mu.Unlock() + stopped := t.when.IsZero() + t.when = time.Time{} + return stopped +} diff --git a/http2/transport.go b/http2/transport.go index b599197e7..04db29275 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -350,6 +350,45 @@ type ClientConn struct { werr error // first write error that has occurred hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder + + syncHooks *testSyncHooks // can be nil +} + +// Hook points used for testing. +// Outside of tests, cc.syncHooks is nil and these all have minimal implementations. +// Inside tests, see the testSyncHooks function docs. + +// goRun starts a new goroutine. +func (cc *ClientConn) goRun(f func()) { + if cc.syncHooks != nil { + cc.syncHooks.goRun(f) + return + } + go f() +} + +// condBroadcast is cc.cond.Broadcast. +func (cc *ClientConn) condBroadcast() { + if cc.syncHooks != nil { + cc.syncHooks.condBroadcast(cc.cond) + } + cc.cond.Broadcast() +} + +// condWait is cc.cond.Wait. +func (cc *ClientConn) condWait() { + if cc.syncHooks != nil { + cc.syncHooks.condWait(cc.cond) + } + cc.cond.Wait() +} + +// newTimer creates a new time.Timer, or a synthetic timer in tests. +func (cc *ClientConn) newTimer(d time.Duration) timer { + if cc.syncHooks != nil { + return cc.syncHooks.newTimer(d) + } + return newTimeTimer(d) } // clientStream is the state for a single HTTP/2 stream. One of these @@ -431,7 +470,7 @@ func (cs *clientStream) abortStreamLocked(err error) { // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { // Wake up writeRequestBody if it is waiting on flow control. 
- cs.cc.cond.Broadcast() + cs.cc.condBroadcast() } } @@ -441,7 +480,7 @@ func (cs *clientStream) abortRequestBodyWrite() { defer cc.mu.Unlock() if cs.reqBody != nil && cs.reqBodyClosed == nil { cs.closeReqBodyLocked() - cc.cond.Broadcast() + cc.condBroadcast() } } @@ -451,10 +490,10 @@ func (cs *clientStream) closeReqBodyLocked() { } cs.reqBodyClosed = make(chan struct{}) reqBodyClosed := cs.reqBodyClosed - go func() { + cs.cc.goRun(func() { cs.reqBody.Close() close(reqBodyClosed) - }() + }) } type stickyErrWriter struct { @@ -672,7 +711,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b if err != nil { return nil, err } - return t.newClientConn(tconn, singleUse) + return t.newClientConn(tconn, singleUse, nil) } func (t *Transport) newTLSConfig(host string) *tls.Config { @@ -738,10 +777,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 { } func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives()) + return t.newClientConn(c, t.disableKeepAlives(), nil) } -func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { +func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) { cc := &ClientConn{ t: t, tconn: c, @@ -756,6 +795,10 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), + syncHooks: hooks, + } + if hooks != nil { + hooks.newclientconn(cc) } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d @@ -824,7 +867,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return nil, cc.werr } - go cc.readLoop() + cc.goRun(cc.readLoop) return cc, nil } @@ -1062,7 +1105,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { // Wait for all in-flight streams to complete or connection to close done := make(chan struct{}) cancelled := false // guarded by cc.mu - go func() { + cc.goRun(func() { cc.mu.Lock() defer cc.mu.Unlock() for { @@ -1074,9 +1117,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { if cancelled { break } - cc.cond.Wait() + cc.condWait() } - }() + }) shutdownEnterWaitStateHook() select { case <-done: @@ -1086,7 +1129,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { cc.mu.Lock() // Free the goroutine above cancelled = true - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() return ctx.Err() } @@ -1124,7 +1167,7 @@ func (cc *ClientConn) closeForError(err error) { for _, cs := range cc.streams { cs.abortStreamLocked(err) } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() cc.closeConn() } @@ -1221,6 +1264,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + return cc.roundTrip(req, nil) +} + +func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) { ctx := req.Context() cs := &clientStream{ cc: cc, @@ -1235,9 +1282,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { respHeaderRecv: make(chan struct{}), donec: make(chan struct{}), } - go cs.doRequest(req) + cc.goRun(func() { + cs.doRequest(req) + }) waitDone := func() error { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.donec: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { 
case <-cs.donec: return nil @@ -1298,7 +1359,24 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return err } + if streamf != nil { + streamf(cs) + } + for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.respHeaderRecv: return handleResponseHeaders() @@ -1378,6 +1456,10 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { } cc.mu.Unlock() + if cc.syncHooks != nil { + cc.syncHooks.newstream(cs) + } + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && @@ -1458,15 +1540,30 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) + timer := cc.newTimer(d) defer timer.Stop() - respHeaderTimer = timer.C + respHeaderTimer = timer.C() respHeaderRecv = cs.respHeaderRecv } // Wait until the peer half-closes its end of the stream, // or until the request is aborted (via context, error, or otherwise), // whichever comes first. for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.peerClosed: + case <-respHeaderTimer: + case <-respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.peerClosed: return nil @@ -1615,7 +1712,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { return nil } cc.pendingRequests++ - cc.cond.Wait() + cc.condWait() cc.pendingRequests-- select { case <-cs.abort: @@ -1877,7 +1974,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) cs.flow.take(take) return take, nil } - cc.cond.Wait() + cc.condWait() } } @@ -2149,7 +2246,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) { } // Wake up writeRequestBody via clientStream.awaitFlowControl and // wake up RoundTrip if there is a pending request. 
- cc.cond.Broadcast() + cc.condBroadcast() closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { @@ -2237,7 +2334,7 @@ func (rl *clientConnReadLoop) cleanup() { cs.abortStreamLocked(err) } } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() } @@ -2873,7 +2970,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { for _, cs := range cc.streams { cs.flow.add(delta) } - cc.cond.Broadcast() + cc.condBroadcast() cc.initialWindowSize = s.Val case SettingHeaderTableSize: @@ -2928,7 +3025,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { return ConnectionError(ErrCodeFlowControl) } - cc.cond.Broadcast() + cc.condBroadcast() return nil } @@ -2971,7 +3068,7 @@ func (cc *ClientConn) Ping(ctx context.Context) error { cc.mu.Unlock() } errc := make(chan error, 1) - go func() { + cc.goRun(func() { cc.wmu.Lock() defer cc.wmu.Unlock() if err := cc.fr.WritePing(false, p); err != nil { @@ -2982,7 +3079,7 @@ func (cc *ClientConn) Ping(ctx context.Context) error { errc <- err return } - }() + }) select { case <-c: return nil diff --git a/http2/transport_test.go b/http2/transport_test.go index 6ac8e978b..f889cd12b 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -1014,131 +1014,65 @@ func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyA func testTransportReqBodyAfterResponse(t *testing.T, status int) { const bodySize = 10 << 20 - clientDone := make(chan struct{}) - ct := newClientTester(t) - recvLen := make(chan int64, 1) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - body := &pipe{b: new(bytes.Buffer)} - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - if res.StatusCode != status { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) - } - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) - body.CloseWithError(io.EOF) - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("unexpected body: %q", slurp) - } - res.Body.Close() - if status == 200 { - if got := <-recvLen; got != bodySize { - return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) - } - } else { - if got := <-recvLen; got == 0 || got >= bodySize { - return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) - } - } - return nil + tc := newTestClientConn(t) + tc.greet() + + body := tc.newRequestBody() + body.writeBytes(bodySize / 2) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: false, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"PUT"}, + ":path": []string{"/"}, + }, + }) + + // Provide enough congestion window for the full request body. 
+ tc.writeWindowUpdate(0, bodySize) + tc.writeWindowUpdate(rt.streamID(), bodySize) + + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: false, + size: bodySize / 2, + }) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", strconv.Itoa(status), + ), + }) + + res := rt.response() + if res.StatusCode != status { + t.Fatalf("status code = %v; want %v", res.StatusCode, status) } - ct.server = func() error { - ct.greet() - defer close(recvLen) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var dataRecv int64 - var closed bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - //println(fmt.Sprintf("server got frame: %v", f)) - ended := false - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - if f.StreamEnded() { - return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) - } - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if dataRecv == 0 { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - dataRecv += int64(dataLen) - if !closed && ((status != 200 && dataRecv > 0) || - (status == 200 && f.StreamEnded())) { - closed = true - if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { - return err - } - } + body.writeBytes(bodySize / 2) + body.closeWithError(io.EOF) - if f.StreamEnded() { - ended = true - } - case *RSTStreamFrame: - if status == 200 { - return fmt.Errorf("Unexpected client frame %v", f) - } - ended = true - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - if ended { - select { - case recvLen <- dataRecv: - default: - } - } - } + if status == 200 { + // After a 200 response, client sends the remaining request body. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: bodySize / 2, + }) + } else { + // After a 403 response, client gives up and resets the stream. 
+ tc.wantFrameType(FrameRSTStream) } - ct.run() + + rt.wantBody(nil) } // See golang.org/issue/13444 @@ -1319,121 +1253,74 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy panic("invalid combination") } - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) - if expect100Continue != noHeader { - req.Header.Set("Expect", "100-continue") - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - wantBody := resBody - if !withData { - wantBody = "" - } - if string(slurp) != wantBody { - return fmt.Errorf("body = %q; want %q", slurp, wantBody) - } - if trailers == noHeader { - if len(res.Trailer) > 0 { - t.Errorf("Trailer = %v; want none", res.Trailer) - } - } else { - want := http.Header{"Some-Trailer": {"some-value"}} - if !reflect.DeepEqual(res.Trailer, want) { - t.Errorf("Trailer = %v; want %v", res.Trailer, want) - } - } - return nil + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) + if expect100Continue != noHeader { + req.Header.Set("Expect", "100-continue") } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + rt := tc.roundTrip(req) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - endStream := false - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *DataFrame: - if !f.StreamEnded() { - // No need to send flow control tokens. The test request body is tiny. - continue - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"}) - enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"}) - if trailers != noHeader { - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"}) - } - endStream = withData == false && trailers == noHeader - send(resHeader) - } - if withData { - endStream = trailers == noHeader - ct.fr.WriteData(f.StreamID, endStream, []byte(resBody)) - } - if trailers != noHeader { - endStream = true - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"}) - send(trailers) - } - if endStream { - return nil - } - case *HeadersFrame: - if expect100Continue != noHeader { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) - send(expect100Continue) - } - } - } + tc.wantFrameType(FrameHeaders) + + // Possibly 100-continue, or skip when noHeader. 
+ tc.writeHeadersMode(expect100Continue, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "100", + ), + }) + + // Client sends request body. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: len(reqBody), + }) + + hdr := []string{ + ":status", "200", + "x-foo", "blah", + "x-bar", "more", + } + if trailers != noHeader { + hdr = append(hdr, "trailer", "some-trailer") + } + tc.writeHeadersMode(resHeader, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: withData == false && trailers == noHeader, + BlockFragment: tc.makeHeaderBlockFragment(hdr...), + }) + if withData { + endStream := trailers == noHeader + tc.writeData(rt.streamID(), endStream, []byte(resBody)) + } + tc.writeHeadersMode(trailers, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + "some-trailer", "some-value", + ), + }) + + rt.wantStatus(200) + if !withData { + rt.wantBody(nil) + } else { + rt.wantBody([]byte(resBody)) + } + if trailers == noHeader { + rt.wantTrailers(nil) + } else { + rt.wantTrailers(http.Header{ + "Some-Trailer": {"some-value"}, + }) } - ct.run() } // Issue 26189, Issue 17739: ignore unknown 1xx responses @@ -1445,130 +1332,76 @@ func TestTransportUnknown1xx(t *testing.T) { return nil } - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 204 { - return fmt.Errorf("status code = %v; want 204", res.StatusCode) - } - want := `code=110 header=map[Foo-Bar:[110]] + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + for i := 110; i <= 114; i++ { + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", fmt.Sprint(i), + "foo-bar", fmt.Sprint(i), + ), + }) + } + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) + + res := rt.response() + if res.StatusCode != 204 { + t.Fatalf("status code = %v; want 204", res.StatusCode) + } + want := `code=110 header=map[Foo-Bar:[110]] code=111 header=map[Foo-Bar:[111]] code=112 header=map[Foo-Bar:[112]] code=113 header=map[Foo-Bar:[113]] code=114 header=map[Foo-Bar:[114]] ` - if got := buf.String(); got != want { - t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - for i := 110; i <= 114; i++ { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)}) - enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - 
EndStream: false, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if got := buf.String(); got != want { + t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) } - ct.run() - } func TestTransportReceiveUndeclaredTrailer(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - if _, ok := res.Trailer["Some-Trailer"]; !ok { - return fmt.Errorf("expected Some-Trailer") - } - return nil - } - ct.server = func() error { - ct.greet() - - var n int - var hf *HeadersFrame - for hf == nil && n < 10 { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - hf, _ = f.(*HeadersFrame) - n++ - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - // send headers without Trailer header - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + "some-trailer", "I'm an undeclared Trailer!", + ), + }) - // send trailers - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - ct.run() + rt.wantStatus(200) + rt.wantBody(nil) + rt.wantTrailers(http.Header{ + "Some-Trailer": []string{"I'm an undeclared Trailer!"}, + }) } func TestTransportInvalidTrailer_Pseudo1(t *testing.T) { @@ -1578,10 +1411,10 @@ func TestTransportInvalidTrailer_Pseudo2(t *testing.T) { testTransportInvalidTrailer_Pseudo(t, splitHeader) } func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - }) + testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), + ":colon", "foo", + "foo", "bar", + ) } func TestTransportInvalidTrailer_Capital1(t *testing.T) { @@ -1591,102 +1424,54 @@ func TestTransportInvalidTrailer_Capital2(t *testing.T) { testTransportInvalidTrailer_Capital(t, splitHeader) } func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"}) - }) + testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), + 
"foo", "bar", + "Capital", "bad", + ) } func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"}) - }) + testInvalidTrailer(t, oneHeader, headerFieldNameError(""), + "", "bad", + ) } func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) - }) + testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), + "x", "has\nnewline", + ) } -func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - se, ok := err.(StreamError) - if !ok || se.Cause != wantErr { - return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) +func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "trailer", "declared", + ), + }) + tc.writeHeadersMode(mode, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment(trailers...), + }) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var endStream bool - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"}) - endStream = false - send(oneHeader) - } - // Trailers: - { - endStream = true - buf.Reset() - writeTrailer(enc) - send(trailers) - } - return nil - } - } + rt.wantStatus(200) + body, err := rt.readBody() + se, ok := err.(StreamError) + if !ok || se.Cause != wantErr { + t.Fatalf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", body, err, wantErr, wantErr) + } + if len(body) > 0 { + t.Fatalf("body = %q; want nothing", body) } - ct.run() } // headerListSize returns the HTTP2 header list size 
of h. @@ -1962,115 +1747,80 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { } func TestTransportChecksResponseHeaderListSize(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if e, ok := err.(StreamError); ok { - err = e.Cause - } - if err != errResponseHeaderListSize { - size := int64(0) - if res != nil { - res.Body.Close() - for k, vv := range res.Header { - for _, v := range vv { - size += int64(len(k)) + int64(len(v)) + 32 - } - } - } - return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + hdr := []string{":status", "200"} + large := strings.Repeat("a", 1<<10) + for i := 0; i < 5042; i++ { + hdr = append(hdr, large, large) + } + hbf := tc.makeHeaderBlockFragment(hdr...) + // Note: this number might change if our hpack implementation changes. + // That's fine. This is just a sanity check that our response can fit in a single + // header block fragment frame. + if size, want := len(hbf), 6329; size != want { + t.Fatalf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + } + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: hbf, + }) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - large := strings.Repeat("a", 1<<10) - for i := 0; i < 5042; i++ { - enc.WriteField(hpack.HeaderField{Name: large, Value: large}) - } - if size, want := buf.Len(), 6329; size != want { - // Note: this number might change if - // our hpack implementation - // changes. That's fine. This is - // just a sanity check that our - // response can fit in a single - // header block fragment frame. 
- return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + res, err := rt.result() + if e, ok := err.(StreamError); ok { + err = e.Cause + } + if err != errResponseHeaderListSize { + size := int64(0) + if res != nil { + res.Body.Close() + for k, vv := range res.Header { + for _, v := range vv { + size += int64(len(k)) + int64(len(v)) + 32 } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil } } + t.Fatalf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) } - ct.run() } func TestTransportCookieHeaderSplit(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - req.Header.Add("Cookie", "a=b;c=d; e=f;") - req.Header.Add("Cookie", "e=f;g=h; ") - req.Header.Add("Cookie", "i=j") - _, err := ct.tr.RoundTrip(req) - return err - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - dec := hpack.NewDecoder(initialHeaderTableSize, nil) - hfs, err := dec.DecodeFull(f.HeaderBlockFragment()) - if err != nil { - return err - } - got := []string{} - want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"} - for _, hf := range hfs { - if hf.Name == "cookie" { - got = append(got, hf.Value) - } - } - if !reflect.DeepEqual(got, want) { - t.Errorf("Cookies = %#v, want %#v", got, want) - } + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + req.Header.Add("Cookie", "a=b;c=d; e=f;") + req.Header.Add("Cookie", "e=f;g=h; ") + req.Header.Add("Cookie", "i=j") + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, + header: http.Header{ + "cookie": []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}, + }, + }) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if err := rt.err(); err != nil { + t.Fatalf("RoundTrip = %v, want success", err) } - ct.run() } // Test that the Transport returns a typed error from Response.Body.Read calls @@ -2286,55 +2036,49 @@ func TestTransportResponseHeaderTimeout_Body(t *testing.T) { } func testTransportResponseHeaderTimeout(t *testing.T, body bool) { - ct := newClientTester(t) - ct.tr.t1 = &http.Transport{ - ResponseHeaderTimeout: 5 * time.Millisecond, - } - ct.client = func() error { - c := &http.Client{Transport: ct.tr} - var err error - var n int64 - const bodySize = 4 << 20 - if body { - _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize)) - } else { - _, err = c.Get("https://dummy.tld/") - } - if !isTimeout(err) { - t.Errorf("client expected timeout error; got %#v", err) - } - if body && n != bodySize { - t.Errorf("only read %d bytes of body; want %d", n, bodySize) + const bodySize = 4 << 20 + tc := newTestClientConn(t, func(tr *Transport) { + tr.t1 = &http.Transport{ + ResponseHeaderTimeout: 5 * time.Millisecond, } - return nil + }) + tc.greet() + + 
var req *http.Request + var reqBody *testRequestBody + if body { + reqBody = tc.newRequestBody() + reqBody.writeBytes(bodySize) + reqBody.closeWithError(io.EOF) + req, _ = http.NewRequest("POST", "https://dummy.tld/", reqBody) + req.Header.Set("Content-Type", "text/foo") + } else { + req, _ = http.NewRequest("GET", "https://dummy.tld/", nil) } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - switch f := f.(type) { - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - case *RSTStreamFrame: - if f.StreamID == 1 && f.ErrCode == ErrCodeCancel { - return nil - } - } - } + + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + tc.writeWindowUpdate(0, bodySize) + tc.writeWindowUpdate(rt.streamID(), bodySize) + + if body { + tc.wantData(wantData{ + endStream: true, + size: bodySize, + }) + } + + tc.advance(4 * time.Millisecond) + if rt.done() { + t.Fatalf("RoundTrip is done after 4ms; want still waiting") + } + tc.advance(1 * time.Millisecond) + + if err := rt.err(); !isTimeout(err) { + t.Fatalf("RoundTrip error: %v; want timeout error", err) } - ct.run() } func TestTransportDisableCompression(t *testing.T) { @@ -2720,115 +2464,61 @@ func TestTransportNewTLSConfig(t *testing.T) { // without END_STREAM, followed by a 0-length DATA frame with // END_STREAM. Make sure we don't get confused by that. (We did.) func TestTransportReadHeadResponse(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != 123 { - return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, // as the GFE does - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, nil) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, // as the GFE does + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "123", + ), + }) + tc.writeData(rt.streamID(), true, nil) - <-clientDone - return nil - } + res := rt.response() + if res.ContentLength != 123 { + t.Fatalf("Content-Length = %d; want 123", res.ContentLength) } - ct.run() + rt.wantBody(nil) } func TestTransportReadHeadResponseWithBody(t *testing.T) { - // This test use not 
valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) + // This test uses an invalid response format. + // Discard logger output to not spam tests output. + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) response := "redirecting to /elsewhere" - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != int64(len(response)) { - return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response)) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, []byte(response)) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", strconv.Itoa(len(response)), + ), + }) + tc.writeData(rt.streamID(), true, []byte(response)) - <-clientDone - return nil - } + res := rt.response() + if res.ContentLength != int64(len(response)) { + t.Fatalf("Content-Length = %d; want %d", res.ContentLength, len(response)) } - ct.run() + rt.wantBody(nil) } type neverEnding byte @@ -2953,71 +2643,53 @@ func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { } func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { - ct := newClientTester(t) - clientDone := make(chan struct{}) + tc := newTestClientConn(t) + tc.greet() const goAwayErrCode = ErrCodeHTTP11Required // arbitrary const goAwayDebugData = "some debug data" - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if failMidBody { - if err != nil { - return fmt.Errorf("unexpected client RoundTrip error: %v", err) - } - _, err = io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - } - want := GoAwayError{ - LastStreamID: 5, - ErrCode: goAwayErrCode, - DebugData: goAwayDebugData, - } - if !reflect.DeepEqual(err, want) { - t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want) - } - return nil + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + if failMidBody { + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "123", + ), + }) } - ct.server = func() error { - ct.greet() - for { - f, err := 
ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - if failMidBody { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - // Write two GOAWAY frames, to test that the Transport takes - // the interesting parts of both. - ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) - ct.fr.WriteGoAway(5, goAwayErrCode, nil) - ct.sc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.sc.(*net.TCPConn).Close() - } - <-clientDone - return nil + + // Write two GOAWAY frames, to test that the Transport takes + // the interesting parts of both. + tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) + tc.writeGoAway(5, goAwayErrCode, nil) + tc.closeWrite(io.EOF) + + res, err := rt.result() + whence := "RoundTrip" + if failMidBody { + whence = "Body.Read" + if err != nil { + t.Fatalf("RoundTrip error = %v, want success", err) } + _, err = res.Body.Read(make([]byte, 1)) + } + + want := GoAwayError{ + LastStreamID: 5, + ErrCode: goAwayErrCode, + DebugData: goAwayDebugData, + } + if !reflect.DeepEqual(err, want) { + t.Errorf("%v error = %T: %#v, want %T (%#v)", whence, err, err, want, want) } - ct.run() } func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { @@ -6049,7 +5721,7 @@ func TestClientConnReservations(t *testing.T) { tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() - cc, err := tr.newClientConn(st.cc, false) + cc, err := tr.newClientConn(st.cc, false, nil) if err != nil { t.Fatal(err) } From 12ddef72728707026a3d9adbbc28affa76faf688 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 26 Feb 2024 12:25:37 -0800 Subject: [PATCH 57/70] http2: reject DATA frames after 1xx and before final headers When checking to see if a DATA frame can be accepted, check to see if we have received a non-1xx header, not whether we have received any header. Fixes golang/go#65927 Change-Id: Id4fae1862de6179f8fc95e02dec7d4c47a7640e1 Reviewed-on: https://go-review.googlesource.com/c/net/+/567175 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/transport.go | 2 +- http2/transport_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/http2/transport.go b/http2/transport.go index 04db29275..44845bafd 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -2787,7 +2787,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { }) return nil } - if !cs.firstByte { + if !cs.pastHeaders { cc.logf("protocol error: received DATA before a HEADERS frame") rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, diff --git a/http2/transport_test.go b/http2/transport_test.go index f889cd12b..836d45593 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -6254,3 +6254,32 @@ func TestDialRaceResumesDial(t *testing.T) { case <-successCh: } } + +func TestTransportDataAfter1xxHeader(t *testing.T) { + // Discard logger output to avoid spamming stderr. + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + // https://go.dev/issue/65927 - server sends a 1xx response, followed by a DATA frame. 
+ tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "100", + ), + }) + tc.writeData(rt.streamID(), true, []byte{0}) + err := rt.err() + if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { + t.Errorf("RoundTrip error: %v; want ErrCodeProtocol", err) + } + tc.wantFrameType(FrameRSTStream) +} From 31d9683ed011ab20a0aa6ab62de563611851a2b8 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 12:58:12 -0700 Subject: [PATCH 58/70] http2: mark several testing functions as helpers Change-Id: Ib5519fd882b3692efadd6191fbebbf042c9aa77d Reviewed-on: https://go-review.googlesource.com/c/net/+/572376 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- http2/clientconn_test.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 6d94762e5..9a5b2b013 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -191,7 +191,7 @@ func testClientConnReadFrame[T any](tc *testClientConn) T { var v T fr := tc.readFrame() if fr == nil { - tc.t.Fatalf("got no frame, want frame %v", v) + tc.t.Fatalf("got no frame, want frame %T", v) } v, ok := fr.(T) if !ok { @@ -203,6 +203,7 @@ func testClientConnReadFrame[T any](tc *testClientConn) T { // wantFrameType reads the next frame from the conn. // It produces an error if the frame type is not the expected value. func (tc *testClientConn) wantFrameType(want FrameType) { + tc.t.Helper() fr := tc.readFrame() if fr == nil { tc.t.Fatalf("got no frame, want frame %v", want) @@ -221,11 +222,8 @@ type wantHeader struct { // wantHeaders reads a HEADERS frame and potential CONTINUATION frames, // and asserts that they contain the expected headers. 
func (tc *testClientConn) wantHeaders(want wantHeader) { - fr := tc.readFrame() - got, ok := fr.(*MetaHeadersFrame) - if !ok { - tc.t.Fatalf("got %v, want HEADERS frame", want) - } + tc.t.Helper() + got := testClientConnReadFrame[*MetaHeadersFrame](tc) if got, want := got.StreamID, want.streamID; got != want { tc.t.Fatalf("got stream ID %v, want %v", got, want) } From 9e0498de4d22259990fc8eb8440eafd7c353c19c Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 13:06:45 -0700 Subject: [PATCH 59/70] http2: use synthetic timers for ping timeouts in tests Change-Id: I642890519b066937ade3c13e8387c31d29e912f4 Reviewed-on: https://go-review.googlesource.com/c/net/+/572377 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- http2/clientconn_test.go | 9 +++ http2/testsync.go | 101 +++++++++++++++++++++++++++--- http2/transport.go | 69 +++++++++++++++++---- http2/transport_test.go | 131 ++++++++++++++++++++++++--------------- 4 files changed, 240 insertions(+), 70 deletions(-) diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 9a5b2b013..97f884c66 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -123,6 +123,7 @@ func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { tc.fr.SetMaxReadFrameSize(10 << 20) t.Cleanup(func() { + tc.sync() if tc.rerr == nil { tc.rerr = io.EOF } @@ -459,6 +460,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he tc.sync() } +func (tc *testClientConn) writePing(ack bool, data [8]byte) { + tc.t.Helper() + if err := tc.fr.WritePing(ack, data); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) { tc.t.Helper() if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { diff --git a/http2/testsync.go b/http2/testsync.go index b8335c0fb..61075bd16 100644 --- a/http2/testsync.go +++ b/http2/testsync.go @@ -4,6 +4,7 @@ package http2 import ( + "context" "sync" "time" ) @@ -173,18 +174,56 @@ func (h *testSyncHooks) condWait(cond *sync.Cond) { h.unlock() } -// newTimer creates a new timer: A time.Timer if h is nil, or a synthetic timer in tests. +// newTimer creates a new fake timer. func (h *testSyncHooks) newTimer(d time.Duration) timer { h.lock() defer h.unlock() t := &fakeTimer{ - when: h.now.Add(d), - c: make(chan time.Time), + hooks: h, + when: h.now.Add(d), + c: make(chan time.Time), } h.timers = append(h.timers, t) return t } +// afterFunc creates a new fake AfterFunc timer. +func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + hooks: h, + when: h.now.Add(d), + f: f, + } + h.timers = append(h.timers, t) + return t +} + +func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + t := h.afterFunc(d, cancel) + return ctx, func() { + t.Stop() + cancel() + } +} + +func (h *testSyncHooks) timeUntilEvent() time.Duration { + h.lock() + defer h.unlock() + var next time.Time + for _, t := range h.timers { + if next.IsZero() || t.when.Before(next) { + next = t.when + } + } + if d := next.Sub(h.now); d > 0 { + return d + } + return 0 +} + // advance advances time and causes synthetic timers to fire. 
func (h *testSyncHooks) advance(d time.Duration) { h.lock() @@ -192,6 +231,7 @@ func (h *testSyncHooks) advance(d time.Duration) { h.now = h.now.Add(d) timers := h.timers[:0] for _, t := range h.timers { + t := t // remove after go.mod depends on go1.22 t.mu.Lock() switch { case t.when.After(h.now): @@ -200,7 +240,20 @@ func (h *testSyncHooks) advance(d time.Duration) { // stopped timer default: t.when = time.Time{} - close(t.c) + if t.c != nil { + close(t.c) + } + if t.f != nil { + h.total++ + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + t.f() + }() + } } t.mu.Unlock() } @@ -212,13 +265,16 @@ func (h *testSyncHooks) advance(d time.Duration) { type timer interface { C() <-chan time.Time Stop() bool + Reset(d time.Duration) bool } +// timeTimer implements timer using real time. type timeTimer struct { t *time.Timer c chan time.Time } +// newTimeTimer creates a new timer using real time. func newTimeTimer(d time.Duration) timer { ch := make(chan time.Time) t := time.AfterFunc(d, func() { @@ -227,16 +283,29 @@ func newTimeTimer(d time.Duration) timer { return &timeTimer{t, ch} } -func (t timeTimer) C() <-chan time.Time { return t.c } -func (t timeTimer) Stop() bool { return t.t.Stop() } +// newTimeAfterFunc creates an AfterFunc timer using real time. +func newTimeAfterFunc(d time.Duration, f func()) timer { + return &timeTimer{ + t: time.AfterFunc(d, f), + } +} +func (t timeTimer) C() <-chan time.Time { return t.c } +func (t timeTimer) Stop() bool { return t.t.Stop() } +func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) } + +// fakeTimer implements timer using fake time. type fakeTimer struct { + hooks *testSyncHooks + mu sync.Mutex - when time.Time - c chan time.Time + when time.Time // when the timer will fire + c chan time.Time // closed when the timer fires; mutually exclusive with f + f func() // called when the timer fires; mutually exclusive with c } func (t *fakeTimer) C() <-chan time.Time { return t.c } + func (t *fakeTimer) Stop() bool { t.mu.Lock() defer t.mu.Unlock() @@ -244,3 +313,19 @@ func (t *fakeTimer) Stop() bool { t.when = time.Time{} return stopped } + +func (t *fakeTimer) Reset(d time.Duration) bool { + if t.c != nil || t.f == nil { + panic("fakeTimer only supports Reset on AfterFunc timers") + } + t.mu.Lock() + defer t.mu.Unlock() + t.hooks.lock() + defer t.hooks.unlock() + active := !t.when.IsZero() + t.when = t.hooks.now.Add(d) + if !active { + t.hooks.timers = append(t.hooks.timers, t) + } + return active +} diff --git a/http2/transport.go b/http2/transport.go index 44845bafd..1ce5f125c 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -391,6 +391,21 @@ func (cc *ClientConn) newTimer(d time.Duration) timer { return newTimeTimer(d) } +// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. +func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer { + if cc.syncHooks != nil { + return cc.syncHooks.afterFunc(d, f) + } + return newTimeAfterFunc(d, f) +} + +func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + if cc.syncHooks != nil { + return cc.syncHooks.contextWithTimeout(ctx, d) + } + return context.WithTimeout(ctx, d) +} + // clientStream is the state for a single HTTP/2 stream. One of these // is created for each Transport.RoundTrip call. 
type clientStream struct { @@ -875,7 +890,7 @@ func (cc *ClientConn) healthCheck() { pingTimeout := cc.t.pingTimeout() // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. - ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout) defer cancel() cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) @@ -1432,6 +1447,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { if cc.reqHeaderMu == nil { panic("RoundTrip on uninitialized ClientConn") // for tests } + var newStreamHook func(*clientStream) + if cc.syncHooks != nil { + newStreamHook = cc.syncHooks.newstream + cc.syncHooks.blockUntil(func() bool { + select { + case cc.reqHeaderMu <- struct{}{}: + <-cc.reqHeaderMu + case <-cs.reqCancel: + case <-ctx.Done(): + default: + return false + } + return true + }) + } select { case cc.reqHeaderMu <- struct{}{}: case <-cs.reqCancel: @@ -1456,8 +1486,8 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { } cc.mu.Unlock() - if cc.syncHooks != nil { - cc.syncHooks.newstream(cs) + if newStreamHook != nil { + newStreamHook(cs) } // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? @@ -2369,10 +2399,9 @@ func (rl *clientConnReadLoop) run() error { cc := rl.cc gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout - var t *time.Timer + var t timer if readIdleTimeout != 0 { - t = time.AfterFunc(readIdleTimeout, cc.healthCheck) - defer t.Stop() + t = cc.afterFunc(readIdleTimeout, cc.healthCheck) } for { f, err := cc.fr.ReadFrame() @@ -3067,24 +3096,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error { } cc.mu.Unlock() } - errc := make(chan error, 1) + var pingError error + errc := make(chan struct{}) cc.goRun(func() { cc.wmu.Lock() defer cc.wmu.Unlock() - if err := cc.fr.WritePing(false, p); err != nil { - errc <- err + if pingError = cc.fr.WritePing(false, p); pingError != nil { + close(errc) return } - if err := cc.bw.Flush(); err != nil { - errc <- err + if pingError = cc.bw.Flush(); pingError != nil { + close(errc) return } }) + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-c: + case <-errc: + case <-ctx.Done(): + case <-cc.readerDone: + default: + return false + } + return true + }) + } select { case <-c: return nil - case err := <-errc: - return err + case <-errc: + return pingError case <-ctx.Done(): return ctx.Err() case <-cc.readerDone: diff --git a/http2/transport_test.go b/http2/transport_test.go index 836d45593..bab2472f3 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3310,26 +3310,24 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { } func TestTransportCloseAfterLostPing(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.tr.PingTimeout = 1 * time.Second - ct.tr.ReadIdleTimeout = 1 * time.Second - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - if err == nil || !strings.Contains(err.Error(), "client connection lost") { - return fmt.Errorf("expected to get error about \"connection lost\", got %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - <-clientDone - return nil + tc := newTestClientConn(t, func(tr *Transport) { + 
tr.PingTimeout = 1 * time.Second + tr.ReadIdleTimeout = 1 * time.Second + }) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + tc.advance(1 * time.Second) + tc.wantFrameType(FramePing) + + tc.advance(1 * time.Second) + err := rt.err() + if err == nil || !strings.Contains(err.Error(), "client connection lost") { + t.Fatalf("expected to get error about \"connection lost\", got %v", err) } - ct.run() } func TestTransportPingWriteBlocks(t *testing.T) { @@ -3362,38 +3360,73 @@ func TestTransportPingWriteBlocks(t *testing.T) { } } -func TestTransportPingWhenReading(t *testing.T) { - testCases := []struct { - name string - readIdleTimeout time.Duration - deadline time.Duration - expectedPingCount int - }{ - { - name: "two pings", - readIdleTimeout: 100 * time.Millisecond, - deadline: time.Second, - expectedPingCount: 2, - }, - { - name: "zero ping", - readIdleTimeout: time.Second, - deadline: 200 * time.Millisecond, - expectedPingCount: 0, - }, - { - name: "0 readIdleTimeout means no ping", - readIdleTimeout: 0 * time.Millisecond, - deadline: 500 * time.Millisecond, - expectedPingCount: 0, - }, +func TestTransportPingWhenReadingMultiplePings(t *testing.T) { + tc := newTestClientConn(t, func(tr *Transport) { + tr.ReadIdleTimeout = 1000 * time.Millisecond + }) + tc.greet() + + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + for i := 0; i < 5; i++ { + // No ping yet... + tc.advance(999 * time.Millisecond) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) + } + + // ...ping now. + tc.advance(1 * time.Millisecond) + f := testClientConnReadFrame[*PingFrame](tc) + tc.writePing(true, f.Data) } - for _, tc := range testCases { - tc := tc // capture range variable - t.Run(tc.name, func(t *testing.T) { - testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount) - }) + // Cancel the request, Transport resets it and returns an error from body reads. + cancel() + tc.sync() + + tc.wantFrameType(FrameRSTStream) + _, err := rt.readBody() + if err == nil { + t.Fatalf("Response.Body.Read() = %v, want error", err) + } +} + +func TestTransportPingWhenReadingPingDisabled(t *testing.T) { + tc := newTestClientConn(t, func(tr *Transport) { + tr.ReadIdleTimeout = 0 // PINGs disabled + }) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + // No PING is sent, even after a long delay. 
+ tc.advance(1 * time.Minute) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } } From 6e2c99c943496e33025da68db088edff5dc7d07b Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 13:01:26 -0700 Subject: [PATCH 60/70] http2: allow testing Transports with testSyncHooks Change-Id: Icafc4860ef0691e5133221a0b53bb1d2158346cc Reviewed-on: https://go-review.googlesource.com/c/net/+/572378 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/clientconn_test.go | 202 ++++++++++++++++++++------ http2/transport.go | 35 +++-- http2/transport_test.go | 301 ++++++++++++++------------------------- 3 files changed, 288 insertions(+), 250 deletions(-) diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 97f884c66..73ceefd7b 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -99,62 +99,57 @@ type testClientConn struct { roundtrips []*testRoundTrip - rerr error // returned by Read - rbuf bytes.Buffer // sent to the test conn - wbuf bytes.Buffer // sent by the test conn + rerr error // returned by Read + netConnClosed bool // set when the ClientConn closes the net.Conn + rbuf bytes.Buffer // sent to the test conn + wbuf bytes.Buffer // sent by the test conn } -func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { - t.Helper() - - tr := &Transport{} - for _, o := range opts { - o(tr) - } - +func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn { tc := &testClientConn{ t: t, - tr: tr, - hooks: newTestSyncHooks(), + tr: cc.t, + cc: cc, + hooks: cc.t.syncHooks, } + cc.tconn = (*testClientConnNetConn)(tc) tc.enc = hpack.NewEncoder(&tc.encbuf) tc.fr = NewFramer(&tc.rbuf, &tc.wbuf) tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) tc.fr.SetMaxReadFrameSize(10 << 20) - t.Cleanup(func() { tc.sync() if tc.rerr == nil { tc.rerr = io.EOF } tc.sync() - if tc.hooks.total != 0 { - t.Errorf("%v goroutines still running after test completed", tc.hooks.total) - } - }) + return tc +} - tc.hooks.newclientconn = func(cc *ClientConn) { - tc.cc = cc - } - const singleUse = false - _, err := tc.tr.newClientConn((*testClientConnNetConn)(tc), singleUse, tc.hooks) - if err != nil { - t.Fatal(err) - } - tc.sync() - tc.hooks.newclientconn = nil - +func (tc *testClientConn) readClientPreface() { + tc.t.Helper() // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. buf := make([]byte, len(clientPreface)) if _, err := io.ReadFull(&tc.wbuf, buf); err != nil { - t.Fatalf("reading preface: %v", err) + tc.t.Fatalf("reading preface: %v", err) } if !bytes.Equal(buf, clientPreface) { - t.Fatalf("client preface: %q, want %q", buf, clientPreface) + tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface) } +} - return tc +func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { + t.Helper() + + tt := newTestTransport(t, opts...) + const singleUse = false + _, err := tt.tr.newClientConn(nil, singleUse, tt.tr.syncHooks) + if err != nil { + t.Fatalf("newClientConn: %v", err) + } + + return tt.getConn() } // sync waits for the ClientConn under test to reach a stable state, @@ -349,7 +344,7 @@ func (b *testRequestBody) closeWithError(err error) { // the request times out, or some other terminal condition is reached.) 
func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { rt := &testRoundTrip{ - tc: tc, + t: tc.t, donec: make(chan struct{}), } tc.roundtrips = append(tc.roundtrips, rt) @@ -362,6 +357,9 @@ func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { tc.hooks.newstream = nil tc.t.Cleanup(func() { + if !rt.done() { + return + } res, _ := rt.result() if res != nil { res.Body.Close() @@ -460,6 +458,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he tc.sync() } +func (tc *testClientConn) writeRSTStream(streamID uint32, code ErrCode) { + tc.t.Helper() + if err := tc.fr.WriteRSTStream(streamID, code); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + func (tc *testClientConn) writePing(ack bool, data [8]byte) { tc.t.Helper() if err := tc.fr.WritePing(ack, data); err != nil { @@ -491,9 +497,25 @@ func (tc *testClientConn) closeWrite(err error) { tc.sync() } +// inflowWindow returns the amount of inbound flow control available for a stream, +// or for the connection if streamID is 0. +func (tc *testClientConn) inflowWindow(streamID uint32) int32 { + tc.cc.mu.Lock() + defer tc.cc.mu.Unlock() + if streamID == 0 { + return tc.cc.inflow.avail + tc.cc.inflow.unsent + } + cs := tc.cc.streams[streamID] + if cs == nil { + tc.t.Errorf("no stream with id %v", streamID) + return -1 + } + return cs.inflow.avail + cs.inflow.unsent +} + // testRoundTrip manages a RoundTrip in progress. type testRoundTrip struct { - tc *testClientConn + t *testing.T resp *http.Response respErr error donec chan struct{} @@ -502,6 +524,9 @@ type testRoundTrip struct { // streamID returns the HTTP/2 stream ID of the request. func (rt *testRoundTrip) streamID() uint32 { + if rt.cs == nil { + panic("stream ID unknown") + } return rt.cs.ID } @@ -517,12 +542,12 @@ func (rt *testRoundTrip) done() bool { // result returns the result of the RoundTrip. func (rt *testRoundTrip) result() (*http.Response, error) { - t := rt.tc.t + t := rt.t t.Helper() select { case <-rt.donec: default: - t.Fatalf("RoundTrip (stream %v) is not done; want it to be", rt.streamID()) + t.Fatalf("RoundTrip is not done; want it to be") } return rt.resp, rt.respErr } @@ -530,7 +555,7 @@ func (rt *testRoundTrip) result() (*http.Response, error) { // response returns the response of a successful RoundTrip. // If the RoundTrip unexpectedly failed, it calls t.Fatal. func (rt *testRoundTrip) response() *http.Response { - t := rt.tc.t + t := rt.t t.Helper() resp, err := rt.result() if err != nil { @@ -544,7 +569,7 @@ func (rt *testRoundTrip) response() *http.Response { // err returns the (possibly nil) error result of RoundTrip. func (rt *testRoundTrip) err() error { - t := rt.tc.t + t := rt.t t.Helper() _, err := rt.result() return err @@ -552,7 +577,7 @@ func (rt *testRoundTrip) err() error { // wantStatus indicates the expected response StatusCode. func (rt *testRoundTrip) wantStatus(want int) { - t := rt.tc.t + t := rt.t t.Helper() if got := rt.response().StatusCode; got != want { t.Fatalf("got response status %v, want %v", got, want) @@ -561,7 +586,7 @@ func (rt *testRoundTrip) wantStatus(want int) { // body reads the contents of the response body. func (rt *testRoundTrip) readBody() ([]byte, error) { - t := rt.tc.t + t := rt.t t.Helper() return io.ReadAll(rt.response().Body) } @@ -569,7 +594,7 @@ func (rt *testRoundTrip) readBody() ([]byte, error) { // wantBody indicates the expected response body. // (Note that this consumes the body.) 
func (rt *testRoundTrip) wantBody(want []byte) { - t := rt.tc.t + t := rt.t t.Helper() got, err := rt.readBody() if err != nil { @@ -582,7 +607,7 @@ func (rt *testRoundTrip) wantBody(want []byte) { // wantHeaders indicates the expected response headers. func (rt *testRoundTrip) wantHeaders(want http.Header) { - t := rt.tc.t + t := rt.t t.Helper() res := rt.response() if diff := diffHeaders(res.Header, want); diff != "" { @@ -592,7 +617,7 @@ func (rt *testRoundTrip) wantHeaders(want http.Header) { // wantTrailers indicates the expected response trailers. func (rt *testRoundTrip) wantTrailers(want http.Header) { - t := rt.tc.t + t := rt.t t.Helper() res := rt.response() if diff := diffHeaders(res.Trailer, want); diff != "" { @@ -630,7 +655,8 @@ func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) { return nc.wbuf.Write(b) } -func (*testClientConnNetConn) Close() error { +func (nc *testClientConnNetConn) Close() error { + nc.netConnClosed = true return nil } @@ -639,3 +665,91 @@ func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return } func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil } func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil } func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil } + +// A testTransport allows testing Transport.RoundTrip against fake servers. +// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling +// should use testClientConn instead. +type testTransport struct { + t *testing.T + tr *Transport + + ccs []*testClientConn +} + +func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport { + tr := &Transport{ + syncHooks: newTestSyncHooks(), + } + for _, o := range opts { + o(tr) + } + + tt := &testTransport{ + t: t, + tr: tr, + } + tr.syncHooks.newclientconn = func(cc *ClientConn) { + tt.ccs = append(tt.ccs, newTestClientConnFromClientConn(t, cc)) + } + + t.Cleanup(func() { + tt.sync() + if len(tt.ccs) > 0 { + t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs)) + } + if tt.tr.syncHooks.total != 0 { + t.Errorf("%v goroutines still running after test completed", tt.tr.syncHooks.total) + } + }) + + return tt +} + +func (tt *testTransport) sync() { + tt.tr.syncHooks.waitInactive() +} + +func (tt *testTransport) advance(d time.Duration) { + tt.tr.syncHooks.advance(d) + tt.sync() +} + +func (tt *testTransport) hasConn() bool { + return len(tt.ccs) > 0 +} + +func (tt *testTransport) getConn() *testClientConn { + tt.t.Helper() + if len(tt.ccs) == 0 { + tt.t.Fatalf("no new ClientConns created; wanted one") + } + tc := tt.ccs[0] + tt.ccs = tt.ccs[1:] + tc.sync() + tc.readClientPreface() + return tc +} + +func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{ + t: tt.t, + donec: make(chan struct{}), + } + tt.tr.syncHooks.goRun(func() { + defer close(rt.donec) + rt.resp, rt.respErr = tt.tr.RoundTrip(req) + }) + tt.sync() + + tt.t.Cleanup(func() { + if !rt.done() { + return + } + res, _ := rt.result() + if res != nil { + res.Body.Close() + } + }) + + return rt +} diff --git a/http2/transport.go b/http2/transport.go index 1ce5f125c..bf1dacd35 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -184,6 +184,8 @@ type Transport struct { connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool + + syncHooks *testSyncHooks } func (t *Transport) maxHeaderListSize() uint32 { @@ -597,15 +599,6 @@ func authorityAddr(scheme string, authority 
string) (addr string) { return net.JoinHostPort(host, port) } -var retryBackoffHook func(time.Duration) *time.Timer - -func backoffNewTimer(d time.Duration) *time.Timer { - if retryBackoffHook != nil { - return retryBackoffHook(d) - } - return time.NewTimer(d) -} - // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { @@ -633,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - timer := backoffNewTimer(d) + var tm timer + if t.syncHooks != nil { + tm = t.syncHooks.newTimer(d) + t.syncHooks.blockUntil(func() bool { + select { + case <-tm.C(): + case <-req.Context().Done(): + default: + return false + } + return true + }) + } else { + tm = newTimeTimer(d) + } select { - case <-timer.C: + case <-tm.C(): t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): - timer.Stop() + tm.Stop() err = req.Context().Err() } } @@ -718,6 +725,9 @@ func canRetryError(err error) bool { } func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { + if t.syncHooks != nil { + return t.newClientConn(nil, singleUse, t.syncHooks) + } host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -814,6 +824,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHoo } if hooks != nil { hooks.newclientconn(cc) + c = cc.tconn } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d diff --git a/http2/transport_test.go b/http2/transport_test.go index bab2472f3..5de0ad8c4 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3688,61 +3688,49 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { } func TestTransportRetryHasLimit(t *testing.T) { - // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s. - if testing.Short() { - t.Skip("skipping long test in short mode") - } - retryBackoffHook = func(d time.Duration) *time.Timer { - return time.NewTimer(0) // fires immediately - } - defer func() { - retryBackoffHook = nil - }() - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) + tt := newTestTransport(t) + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // First attempt: Server sends a GOAWAY. 
+ tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + + var totalDelay time.Duration + count := 0 + for streamID := uint32(1); ; streamID += 2 { + count++ + tc.wantHeaders(wantHeader{ + streamID: streamID, + endStream: true, + }) + if streamID == 1 { + tc.writeSettings() + tc.wantFrameType(FrameSettings) // settings ACK } - t.Logf("expected error, got: %v", err) - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - default: - return fmt.Errorf("Unexpected client frame %v", f) + tc.writeRSTStream(streamID, ErrCodeRefusedStream) + + d := tt.tr.syncHooks.timeUntilEvent() + if d == 0 { + if streamID == 1 { + continue } + break + } + totalDelay += d + if totalDelay > 5*time.Minute { + t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay) } + tt.advance(d) + } + if got, want := count, 5; got < count { + t.Errorf("RoundTrip made %v attempts, want at least %v", got, want) + } + if rt.err() == nil { + t.Errorf("RoundTrip succeeded, want error") } - ct.run() } func TestTransportResponseDataBeforeHeaders(t *testing.T) { @@ -5593,155 +5581,80 @@ func TestTransportCloseRequestBody(t *testing.T) { } } -// collectClientsConnPool is a ClientConnPool that wraps lower and -// collects what calls were made on it. -type collectClientsConnPool struct { - lower ClientConnPool - - mu sync.Mutex - getErrs int - got []*ClientConn -} - -func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { - cc, err := p.lower.GetClientConn(req, addr) - p.mu.Lock() - defer p.mu.Unlock() - if err != nil { - p.getErrs++ - return nil, err - } - p.got = append(p.got, cc) - return cc, nil -} - -func (p *collectClientsConnPool) MarkDead(cc *ClientConn) { - p.lower.MarkDead(cc) -} - func TestTransportRetriesOnStreamProtocolError(t *testing.T) { - ct := newClientTester(t) - pool := &collectClientsConnPool{ - lower: &clientConnPool{t: ct.tr}, - } - ct.tr.ConnPool = pool + // This test verifies that + // - receiving a protocol error on a connection does not interfere with + // other requests in flight on that connection; + // - the connection is not reused for further requests; and + // - the failed request is retried on a new connecection. + tt := newTestTransport(t) + + // Start two requests. The first is a long request + // that will finish after the second. The second one + // will result in the protocol error. + + // Request #1: The long request. + req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt1 := tt.roundTrip(req1) + tc1 := tt.getConn() + tc1.wantFrameType(FrameSettings) + tc1.wantFrameType(FrameWindowUpdate) + tc1.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc1.writeSettings() + tc1.wantFrameType(FrameSettings) // settings ACK + + // Request #2(a): The short request. 
+ req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt2 := tt.roundTrip(req2) + tc1.wantHeaders(wantHeader{ + streamID: 3, + endStream: true, + }) - gotProtoError := make(chan bool, 1) - ct.tr.CountError = func(errType string) { - if errType == "recv_rststream_PROTOCOL_ERROR" { - select { - case gotProtoError <- true: - default: - } - } + // Request #2(a) fails with ErrCodeProtocol. + tc1.writeRSTStream(3, ErrCodeProtocol) + if rt1.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress") } - ct.client = func() error { - // Start two requests. The first is a long request - // that will finish after the second. The second one - // will result in the protocol error. We check that - // after the first one closes, the connection then - // shuts down. - - // The long, outer request. - req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil) - res1, err := ct.tr.RoundTrip(req1) - if err != nil { - return err - } - if got, want := res1.Header.Get("Is-Long"), "1"; got != want { - return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want) - } - - req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil) - res, err := ct.tr.RoundTrip(req) - const want = "only one dial allowed in test mode" - if got := fmt.Sprint(err); got != want { - t.Errorf("didn't dial again: got %#q; want %#q", got, want) - } - if res != nil { - res.Body.Close() - } - select { - case <-gotProtoError: - default: - t.Errorf("didn't get stream protocol error") - } - - if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 { - t.Errorf("unexpected body read %v, %v", n, err) - } - - pool.mu.Lock() - defer pool.mu.Unlock() - if pool.getErrs != 1 { - t.Errorf("pool get errors = %v; want 1", pool.getErrs) - } - if len(pool.got) == 2 { - if pool.got[0] != pool.got[1] { - t.Errorf("requests went on different connections") - } - cc := pool.got[0] - cc.mu.Lock() - if !cc.doNotReuse { - t.Error("ClientConn not marked doNotReuse") - } - cc.mu.Unlock() - - select { - case <-cc.readerDone: - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for reader to be done") - } - } else { - t.Errorf("pool get success = %v; want 2", len(pool.got)) - } - return nil + if rt2.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is done; want still in progress") } - ct.server = func() error { - ct.greet() - var sentErr bool - var numHeaders int - var firstStreamID uint32 - var hbuf bytes.Buffer - enc := hpack.NewEncoder(&hbuf) + // Request #2(b): The short request is retried on a new connection. + tc2 := tt.getConn() + tc2.wantFrameType(FrameSettings) + tc2.wantFrameType(FrameWindowUpdate) + tc2.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc2.writeSettings() + tc2.wantFrameType(FrameSettings) // settings ACK - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - // Client hung up on us, as it should at the end. 
- return nil - } - if err != nil { - return nil - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - numHeaders++ - if numHeaders == 1 { - firstStreamID = f.StreamID - hbuf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: hbuf.Bytes(), - }) - continue - } - if !sentErr { - sentErr = true - ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol) - ct.fr.WriteData(firstStreamID, true, nil) - continue - } - } - } - } - ct.run() + // Request #2(b) succeeds. + tc2.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "201", + ), + }) + rt2.wantStatus(201) + + // Request #1 succeeds. + tc1.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) } func TestClientConnReservations(t *testing.T) { From 89f602b7bbf237abe0467031a18b42fc742ced08 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 19 Mar 2024 00:18:26 -0700 Subject: [PATCH 61/70] http2: validate client/outgoing trailers This change is a counterpart to the HTTP/1.1 trailers validation CL 572615. This change will ensure that we validate trailers that were set on the HTTP client before they are transformed to HTTP/2 equivalents. Updates golang/go#64766 Change-Id: Id1afd7f7e9af820ea969ef226bbb16e4de6d57a5 Reviewed-on: https://go-review.googlesource.com/c/net/+/572655 Auto-Submit: Damien Neil TryBot-Result: Gopher Robot Reviewed-by: Damien Neil Run-TryBot: Emmanuel Odeke LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase --- http2/transport.go | 33 ++++++++++++++++++++++----------- http2/transport_test.go | 13 ++++++++++++- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/http2/transport.go b/http2/transport.go index bf1dacd35..ba0956e22 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -2019,6 +2019,22 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) } } +func validateHeaders(hdrs http.Header) string { + for k, vv := range hdrs { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Sprintf("name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // Don't include the value in the error, + // because it may be sensitive. + return fmt.Sprintf("value for header %q", k) + } + } + } + return "" +} + var errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. @@ -2056,19 +2072,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } } - // Check for any invalid headers and return an error before we + // Check for any invalid headers+trailers and return an error before we // potentially pollute our hpack state. (We want to be able to // continue to reuse the hpack encoder for future requests) - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("invalid HTTP header name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - // Don't include the value in the error, because it may be sensitive. 
- return nil, fmt.Errorf("invalid HTTP header value for header %q", k) - } - } + if err := validateHeaders(req.Header); err != "" { + return nil, fmt.Errorf("invalid HTTP header %s", err) + } + if err := validateHeaders(req.Trailer); err != "" { + return nil, fmt.Errorf("invalid HTTP trailer %s", err) } enumerateHeaders := func(f func(name, value string)) { diff --git a/http2/transport_test.go b/http2/transport_test.go index 5de0ad8c4..5226a61f7 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2290,7 +2290,8 @@ func TestTransportRejectsContentLengthWithSign(t *testing.T) { } // golang.org/issue/14048 -func TestTransportFailsOnInvalidHeaders(t *testing.T) { +// golang.org/issue/64766 +func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { var got []string for k := range r.Header { @@ -2303,6 +2304,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { tests := [...]struct { h http.Header + t http.Header wantErr string }{ 0: { @@ -2321,6 +2323,14 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { h: http.Header{"foo": {"foo\x01bar"}}, wantErr: `invalid HTTP header value for header "foo"`, }, + 4: { + t: http.Header{"foo": {"foo\x01bar"}}, + wantErr: `invalid HTTP trailer value for header "foo"`, + }, + 5: { + t: http.Header{"x-\r\nda": {"foo\x01bar"}}, + wantErr: `invalid HTTP trailer name "x-\r\nda"`, + }, } tr := &Transport{TLSClientConfig: tlsConfigInsecure} @@ -2329,6 +2339,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { for i, tt := range tests { req, _ := http.NewRequest("GET", st.ts.URL, nil) req.Header = tt.h + req.Trailer = tt.t res, err := tr.RoundTrip(req) var bad bool if tt.wantErr == "" { From d73acffdc9493532acb85777105bb4a351eea702 Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Sun, 10 Mar 2024 23:40:56 +0800 Subject: [PATCH 62/70] http2: only set up deadline when Server.IdleTimeout is positive Check out https://go-review.googlesource.com/c/go/+/570396 Change-Id: I8bda17acebc27308c2a1723191ea1e4a9e64d585 Reviewed-on: https://go-review.googlesource.com/c/net/+/570455 LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase Reviewed-by: Damien Neil Auto-Submit: Damien Neil --- http2/server.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/http2/server.go b/http2/server.go index 905206f3e..ce2e8b40e 100644 --- a/http2/server.go +++ b/http2/server.go @@ -124,6 +124,7 @@ type Server struct { // IdleTimeout specifies how long until idle clients should be // closed with a GOAWAY frame. PING frames are not considered // activity for the purposes of IdleTimeout. + // If zero or negative, there is no timeout. 
IdleTimeout time.Duration // MaxUploadBufferPerConnection is the size of the initial flow @@ -924,7 +925,7 @@ func (sc *serverConn) serve() { sc.setConnState(http.StateActive) sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() } @@ -1637,7 +1638,7 @@ func (sc *serverConn) closeStream(st *stream, err error) { delete(sc.streams, st.id) if len(sc.streams) == 0 { sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer.Reset(sc.srv.IdleTimeout) } if h1ServerKeepAlivesDisabled(sc.hs) { From d8870b0bf2f2426fc8d19a9332f652da5c25418f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 13:51:41 -0700 Subject: [PATCH 63/70] http2: use synthetic time in TestIdleConnTimeout Rewrite TestIdleConnTimeout to use the new synthetic time and synchronization test facilities, rather than using real time and sleeps. Reduces the test time from 20 seconds to 0. Reduces all package tests on my laptop from 32 seconds to 12. Change-Id: I33838488168450a7acd6a462777b5a4caf7f5307 Reviewed-on: https://go-review.googlesource.com/c/net/+/572379 Reviewed-by: Jonathan Amsterdam Reviewed-by: Emmanuel Odeke LUCI-TryBot-Result: Go LUCI --- http2/transport.go | 4 +- http2/transport_test.go | 90 +++++++++++++++++++++++++---------------- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/http2/transport.go b/http2/transport.go index ba0956e22..ce375c8c7 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -310,7 +310,7 @@ type ClientConn struct { readerErr error // set before readerDone is closed idleTimeout time.Duration // or 0 for never - idleTimer *time.Timer + idleTimer timer mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes @@ -828,7 +828,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHoo } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d - cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) + cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout) } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) diff --git a/http2/transport_test.go b/http2/transport_test.go index 5226a61f7..18d4db3ed 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -97,63 +97,83 @@ func startH2cServer(t *testing.T) net.Listener { func TestIdleConnTimeout(t *testing.T) { for _, test := range []struct { + name string idleConnTimeout time.Duration wait time.Duration baseTransport *http.Transport - wantConns int32 + wantNewConn bool }{{ + name: "NoExpiry", idleConnTimeout: 2 * time.Second, wait: 1 * time.Second, baseTransport: nil, - wantConns: 1, + wantNewConn: false, }, { + name: "H2TransportTimeoutExpires", idleConnTimeout: 1 * time.Second, wait: 2 * time.Second, baseTransport: nil, - wantConns: 5, + wantNewConn: true, }, { + name: "H1TransportTimeoutExpires", idleConnTimeout: 0 * time.Second, wait: 1 * time.Second, baseTransport: &http.Transport{ IdleConnTimeout: 2 * time.Second, }, - wantConns: 1, + wantNewConn: false, }} { - var gotConns int32 - - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, r.RemoteAddr) - }, optOnlyServer) - defer st.Close() + t.Run(test.name, func(t *testing.T) { + tt := newTestTransport(t, func(tr *Transport) { + tr.IdleConnTimeout = test.idleConnTimeout + }) + var tc *testClientConn + for i := 0; i < 3; i++ { + req, _ := 
http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // This request happens on a new conn if it's the first request + // (and there is no cached conn), or if the test timeout is long + // enough that old conns are being closed. + wantConn := i == 0 || test.wantNewConn + if has := tt.hasConn(); has != wantConn { + t.Fatalf("request %v: hasConn=%v, want %v", i, has, wantConn) + } + if wantConn { + tc = tt.getConn() + // Read client's SETTINGS and first WINDOW_UPDATE, + // send our SETTINGS. + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.writeSettings() + } + if tt.hasConn() { + t.Fatalf("request %v: Transport has more than one conn", i) + } - tr := &Transport{ - IdleConnTimeout: test.idleConnTimeout, - TLSClientConfig: tlsConfigInsecure, - } - defer tr.CloseIdleConnections() + // Respond to the client's request. + hf := testClientConnReadFrame[*MetaHeadersFrame](tc) + tc.writeHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) - for i := 0; i < 5; i++ { - req, _ := http.NewRequest("GET", st.ts.URL, http.NoBody) - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - if !connInfo.Reused { - atomic.AddInt32(&gotConns, 1) - } - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + // If this was a newly-accepted conn, read the SETTINGS ACK. + if wantConn { + tc.wantFrameType(FrameSettings) // ACK to our settings + } - _, err := tr.RoundTrip(req) - if err != nil { - t.Fatalf("%v", err) + tt.advance(test.wait) + if got, want := tc.netConnClosed, test.wantNewConn; got != want { + t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want) + } } - - <-time.After(test.wait) - } - - if gotConns != test.wantConns { - t.Errorf("incorrect gotConns: %d != %d", gotConns, test.wantConns) - } + }) } } From c7877ac4213b2f859831366f5a35b353e0dc9f66 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 13:24:51 -0700 Subject: [PATCH 64/70] http2: convert the remaining clientTester tests to testClientConn Change-Id: Ia7f213346baff48504fef6dfdc112575a5459f35 Reviewed-on: https://go-review.googlesource.com/c/net/+/572380 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/clientconn_test.go | 74 ++ http2/transport_test.go | 1651 +++++++++++--------------------------- 2 files changed, 535 insertions(+), 1190 deletions(-) diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 73ceefd7b..4237b1436 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -14,6 +14,7 @@ import ( "net" "net/http" "reflect" + "slices" "testing" "time" @@ -209,6 +210,71 @@ func (tc *testClientConn) wantFrameType(want FrameType) { } } +// wantUnorderedFrames reads frames from the conn until every condition in want has been satisfied. +// +// want is a list of func(*SomeFrame) bool. +// wantUnorderedFrames will call each func with frames of the appropriate type +// until the func returns true. +// It calls t.Fatal if an unexpected frame is received (no func has that frame type, +// or all funcs with that type have returned true), or if the conn runs out of frames +// with unsatisfied funcs. +// +// Example: +// +// // Read a SETTINGS frame, and any number of DATA frames for a stream. +// // The SETTINGS frame may appear anywhere in the sequence. +// // The last DATA frame must indicate the end of the stream. 
+// tc.wantUnorderedFrames( +// func(f *SettingsFrame) bool { +// return true +// }, +// func(f *DataFrame) bool { +// return f.StreamEnded() +// }, +// ) +func (tc *testClientConn) wantUnorderedFrames(want ...any) { + tc.t.Helper() + want = slices.Clone(want) + seen := 0 +frame: + for seen < len(want) && !tc.t.Failed() { + fr := tc.readFrame() + if fr == nil { + break + } + for i, f := range want { + if f == nil { + continue + } + typ := reflect.TypeOf(f) + if typ.Kind() != reflect.Func || + typ.NumIn() != 1 || + typ.NumOut() != 1 || + typ.Out(0) != reflect.TypeOf(true) { + tc.t.Fatalf("expected func(*SomeFrame) bool, got %T", f) + } + if typ.In(0) == reflect.TypeOf(fr) { + out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)}) + if out[0].Bool() { + want[i] = nil + seen++ + } + continue frame + } + } + tc.t.Errorf("got unexpected frame type %T", fr) + } + if seen < len(want) { + for _, f := range want { + if f == nil { + continue + } + tc.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0)) + } + tc.t.Fatalf("did not see %v expected frame types", len(want)-seen) + } +} + type wantHeader struct { streamID uint32 endStream bool @@ -401,6 +467,14 @@ func (tc *testClientConn) writeData(streamID uint32, endStream bool, data []byte tc.sync() } +func (tc *testClientConn) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) { + tc.t.Helper() + if err := tc.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + // makeHeaderBlockFragment encodes headers in a form suitable for inclusion // in a HEADERS or CONTINUATION frame. // diff --git a/http2/transport_test.go b/http2/transport_test.go index 18d4db3ed..855c107ef 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2724,122 +2724,75 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { } func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { - ct := newClientTester(t) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } + tc := newTestClientConn(t) + tc.greet() - if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { - return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) - } - res.Body.Close() // leaving 4999 bytes unread + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - return nil + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "5000", + ), + }) + initialInflow := tc.inflowWindow(0) + + // Two cases: + // - Send one DATA frame with 5000 bytes. + // - Send two DATA frames with 1 and 4999 bytes each. + // + // In both cases, the client should consume one byte of data, + // refund that byte, then refund the following 4999 bytes. + // + // In the second case, the server waits for the client to reset the + // stream before sending the second DATA frame. This tests the case + // where the client receives a DATA frame after it has reset the stream. 
+ const streamNotEnded = false + if oneDataFrame { + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000)) + } else { + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1)) } - ct.server = func() error { - ct.greet() - - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialInflow := ct.inflowWindow(0) - - // Two cases: - // - Send one DATA frame with 5000 bytes. - // - Send two DATA frames with 1 and 4999 bytes each. - // - // In both cases, the client should consume one byte of data, - // refund that byte, then refund the following 4999 bytes. - // - // In the second case, the server waits for the client to reset the - // stream before sending the second DATA frame. This tests the case - // where the client receives a DATA frame after it has reset the stream. - if oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) - } else { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) - } + res := rt.response() + if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { + t.Fatalf("body read = %v, %v; want 1, nil", n, err) + } + res.Body.Close() // leaving 4999 bytes unread + tc.sync() - wantRST := true - wantWUF := true - if !oneDataFrame { - wantWUF = false // flow control update is small, and will not be sent - } - for wantRST || wantWUF { - f, err := ct.readNonSettingsFrame() - if err != nil { - return err + sentAdditionalData := false + tc.wantUnorderedFrames( + func(f *RSTStreamFrame) bool { + if f.ErrCode != ErrCodeCancel { + t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) } - switch f := f.(type) { - case *RSTStreamFrame: - if !wantRST { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.ErrCode != ErrCodeCancel { - return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) - } - wantRST = false - case *WindowUpdateFrame: - if !wantWUF { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.Increment != 5000 { - return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) - } - wantWUF = false - default: - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) + if !oneDataFrame { + // Send the remaining data now. 
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999)) + sentAdditionalData = true } - } - if !oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) - f, err := ct.readNonSettingsFrame() - if err != nil { - return err + return true + }, + func(f *WindowUpdateFrame) bool { + if !oneDataFrame && !sentAdditionalData { + t.Fatalf("Got WindowUpdateFrame, don't expect one yet") } - wuf, ok := f.(*WindowUpdateFrame) - if !ok || wuf.Increment != 5000 { - return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f)) + if f.Increment != 5000 { + t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) } - } - if err := ct.writeReadPing(); err != nil { - return err - } - if got, want := ct.inflowWindow(0), initialInflow; got != want { - return fmt.Errorf("connection flow tokens = %v, want %v", got, want) - } - return nil + return true + }, + ) + + if got, want := tc.inflowWindow(0), initialInflow; got != want { + t.Fatalf("connection flow tokens = %v, want %v", got, want) } - ct.run() } // See golang.org/issue/16481 @@ -2855,199 +2808,124 @@ func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { // Issue 16612: adjust flow control on open streams when transport // receives SETTINGS with INITIAL_WINDOW_SIZE from server. func TestTransportAdjustsFlowControl(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - const bodySize = 1 << 20 - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) + tc := newTestClientConn(t) + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + // Don't write our SETTINGS yet. - req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err + body := tc.newRequestBody() + body.writeBytes(bodySize) + body.closeWithError(io.EOF) + + req, _ := http.NewRequest("POST", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + gotBytes := int64(0) + for { + f := testClientConnReadFrame[*DataFrame](tc) + gotBytes += int64(len(f.Data())) + // After we've got half the client's initial flow control window's worth + // of request body data, give it just enough flow control to finish. + if gotBytes >= initialWindowSize/2 { + break } - res.Body.Close() - return nil } - ct.server = func() error { - _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface))) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - var gotBytes int64 - var sentSettings bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - return nil - default: - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - } - switch f := f.(type) { - case *DataFrame: - gotBytes += int64(len(f.Data())) - // After we've got half the client's - // initial flow control window's worth - // of request body data, give it just - // enough flow control to finish. 
- if gotBytes >= initialWindowSize/2 && !sentSettings { - sentSettings = true - - ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) - ct.fr.WriteWindowUpdate(0, bodySize) - ct.fr.WriteSettingsAck() - } + tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) + tc.writeWindowUpdate(0, bodySize) + tc.writeSettingsAck() - if f.StreamEnded() { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - } - } + tc.wantUnorderedFrames( + func(f *SettingsFrame) bool { return true }, + func(f *DataFrame) bool { + gotBytes += int64(len(f.Data())) + return f.StreamEnded() + }, + ) + + if gotBytes != bodySize { + t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize) } - ct.run() + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) } // See golang.org/issue/16556 func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { - ct := newClientTester(t) + tc := newTestClientConn(t) + tc.greet() - unblockClient := make(chan bool, 1) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - defer res.Body.Close() - <-unblockClient - return nil - } - ct.server = func() error { - ct.greet() + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "5000", + ), + }) - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } + initialConnWindow := tc.inflowWindow(0) + initialStreamWindow := tc.inflowWindow(rt.streamID()) - initialConnWindow := ct.inflowWindow(0) + pad := make([]byte, 5) + tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialStreamWindow := ct.inflowWindow(hf.StreamID) - pad := make([]byte, 5) - ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream - if err := ct.writeReadPing(); err != nil { - return err - } - // Padding flow control should have been returned. - if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want { - t.Errorf("conn inflow window = %v, want %v", got, want) - } - if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want { - t.Errorf("stream inflow window = %v, want %v", got, want) - } - unblockClient <- true - return nil + // Padding flow control should have been returned. 
+ if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want { + t.Errorf("conn inflow window = %v, want %v", got, want) + } + if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want { + t.Errorf("stream inflow window = %v, want %v", got, want) } - ct.run() } // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a // StreamError as a result of the response HEADERS func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { - ct := newClientTester(t) + tc := newTestClientConn(t) + tc.greet() - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - return errors.New("unexpected successful GET") - } - want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} - if !reflect.DeepEqual(want, err) { - t.Errorf("RoundTrip error = %#v; want %#v", err, want) - } - return nil - } - ct.server = func() error { - ct.greet() + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - hf, err := ct.firstHeaders() - if err != nil { - return err - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + " content-type", "bogus", + ), + }) - for { - fr, err := ct.readFrame() - if err != nil { - return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) - } - if _, ok := fr.(*SettingsFrame); ok { - continue - } - if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol { - t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) - } - break - } + err := rt.err() + want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} + if !reflect.DeepEqual(err, want) { + t.Fatalf("RoundTrip error = %#v; want %#v", err, want) + } - return nil + fr := testClientConnReadFrame[*RSTStreamFrame](tc) + if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol { + t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) } - ct.run() } // byteAndEOFReader returns is in an io.Reader which reads one byte @@ -3461,261 +3339,84 @@ func TestTransportPingWhenReadingPingDisabled(t *testing.T) { } } -func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) { - var pingCount int - ct := newClientTester(t) - ct.tr.ReadIdleTimeout = readIdleTimeout - - ctx, cancel := context.WithTimeout(context.Background(), deadline) - defer cancel() - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) - } - _, err = 
ioutil.ReadAll(res.Body) - if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) { - return nil - } +func TestTransportRetryAfterGOAWAY(t *testing.T) { + tt := newTestTransport(t) - cancel() - return err - } + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var streamID uint32 - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-ctx.Done(): - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - streamID = f.StreamID - case *PingFrame: - pingCount++ - if pingCount == expectedPingCount { - if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil { - return err - } - } - if err := ct.fr.WritePing(true, f.Data); err != nil { - return err - } - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + // First attempt: Server sends a GOAWAY. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.writeGoAway(0 /*max id*/, ErrCodeNo, nil) + if rt.done() { + t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying") } - ct.run() -} - -func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) { - ln := newLocalListener(t) - defer ln.Close() - var ( - mu sync.Mutex - count int - conns []net.Conn - ) - var wg sync.WaitGroup - tr := &Transport{ - TLSClientConfig: tlsConfigInsecure, - } - tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { - mu.Lock() - defer mu.Unlock() - count++ - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - return nil, fmt.Errorf("dial error: %v", err) - } - conns = append(conns, cc) - sc, err := ln.Accept() - if err != nil { - return nil, fmt.Errorf("accept error: %v", err) - } - conns = append(conns, sc) - ct := &clientTester{ - t: t, - tr: tr, - cc: cc, - sc: sc, - fr: NewFramer(sc, sc), - } - wg.Add(1) - go func(count int) { - defer wg.Done() - server(count, ct) - }(count) - return cc, nil - } + // Second attempt succeeds on a new connection. 
+ tc = tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) - client(tr) - tr.CloseIdleConnections() - ln.Close() - for _, c := range conns { - c.Close() - } - wg.Wait() + rt.wantStatus(200) } -func TestTransportRetryAfterGOAWAY(t *testing.T) { - client := func(tr *Transport) { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := tr.RoundTrip(req) - if res != nil { - res.Body.Close() - if got := res.Header.Get("Foo"); got != "bar" { - err = fmt.Errorf("foo header = %q; want bar", got) - } - } - if err != nil { - t.Errorf("RoundTrip: %v", err) - } - } - - server := func(count int, ct *clientTester) { - switch count { - case 1: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server1 failed reading HEADERS: %v", err) - return - } - t.Logf("server1 got %v", hf) - if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { - t.Errorf("server1 failed writing GOAWAY: %v", err) - return - } - case 2: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server2 failed reading HEADERS: %v", err) - return - } - t.Logf("server2 got %v", hf) - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - err = ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - if err != nil { - t.Errorf("server2 failed writing response HEADERS: %v", err) - } - default: - t.Errorf("unexpected number of dials") - return - } - } +func TestTransportRetryAfterRefusedStream(t *testing.T) { + tt := newTestTransport(t) - testClientMultipleDials(t, client, server) -} + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) -func TestTransportRetryAfterRefusedStream(t *testing.T) { - clientDone := make(chan struct{}) - client := func(tr *Transport) { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("RoundTrip: %v", err) - return - } - resp.Body.Close() - if resp.StatusCode != 204 { - t.Errorf("Status = %v; want 204", resp.StatusCode) - return - } + // First attempt: Server sends a RST_STREAM. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.wantFrameType(FrameSettings) // settings ACK + tc.writeRSTStream(1, ErrCodeRefusedStream) + if rt.done() { + t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying") } - server := func(_ int, ct *clientTester) { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var count int - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. 
- default: - t.Error(err) - } - return - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - t.Errorf("headers should have END_HEADERS be ended: %v", f) - return - } - count++ - if count == 1 { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - } else { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - default: - t.Errorf("Unexpected client frame %v", f) - return - } - } - } + // Second attempt succeeds on the same connection. + tc.wantHeaders(wantHeader{ + streamID: 3, + endStream: true, + }) + tc.writeSettings() + tc.writeHeaders(HeadersFrameParam{ + StreamID: 3, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) - testClientMultipleDials(t, client, server) + rt.wantStatus(204) } func TestTransportRetryHasLimit(t *testing.T) { @@ -3765,67 +3466,34 @@ func TestTransportRetryHasLimit(t *testing.T) { } func TestTransportResponseDataBeforeHeaders(t *testing.T) { - // This test use not valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) - defer log.SetOutput(os.Stderr) + // Discard log output complaining about protocol error. + log.SetOutput(io.Discard) + t.Cleanup(func() { log.SetOutput(os.Stderr) }) // after other cleanup is done - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req := httptest.NewRequest("GET", "https://dummy.tld/", nil) - // First request is normal to ensure the check is per stream and not per connection. - _, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip expected no error, got: %v", err) - } - // Second request returns a DATA frame with no HEADERS. - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) - } - if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { - return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - return nil - } else if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame: - case *HeadersFrame: - switch f.StreamID { - case 1: - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - case 3: - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - } - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + tc := newTestClientConn(t) + tc.greet() + + // First request is normal to ensure the check is per stream and not per connection. 
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt1 := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt1.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) + + // Second request returns a DATA frame with no HEADERS. + rt2 := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.writeData(rt2.streamID(), true, []byte("payload")) + if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol { + t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err) } - ct.run() } func TestTransportMaxFrameReadSize(t *testing.T) { @@ -3839,30 +3507,17 @@ func TestTransportMaxFrameReadSize(t *testing.T) { maxReadFrameSize: 1024, want: minMaxFrameSize, }} { - ct := newClientTester(t) - ct.tr.MaxReadFrameSize = test.maxReadFrameSize - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) - ct.tr.RoundTrip(req) - return nil - } - ct.server = func() error { - defer ct.cc.(*net.TCPConn).Close() - ct.greet() - var got uint32 - ct.settings.ForeachSetting(func(s Setting) error { - switch s.ID { - case SettingMaxFrameSize: - got = s.Val - } - return nil - }) - if got != test.want { - t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want) - } - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxReadFrameSize = test.maxReadFrameSize + }) + + fr := testClientConnReadFrame[*SettingsFrame](tc) + got, ok := fr.Value(SettingMaxFrameSize) + if !ok { + t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want) + } else if got != test.want { + t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want) } - ct.run() } } @@ -3902,337 +3557,126 @@ func TestTransportRequestsLowServerLimit(t *testing.T) { t.Errorf("StatusCode = %v; want %v", got, want) } if res != nil && res.Body != nil { - res.Body.Close() - } - } - - if connCount != 1 { - t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) - } -} - -// tests Transport.StrictMaxConcurrentStreams -func TestTransportRequestsStallAtServerLimit(t *testing.T) { - const maxConcurrent = 2 - - greet := make(chan struct{}) // server sends initial SETTINGS frame - gotRequest := make(chan struct{}) // server received a request - clientDone := make(chan struct{}) - cancelClientRequest := make(chan struct{}) - - // Collect errors from goroutines. - var wg sync.WaitGroup - errs := make(chan error, 100) - defer func() { - wg.Wait() - close(errs) - for err := range errs { - t.Error(err) - } - }() - - // We will send maxConcurrent+2 requests. This checker goroutine waits for the - // following stages: - // 1. The first maxConcurrent requests are received by the server. - // 2. The client will cancel the next request - // 3. The server is unblocked so it can service the first maxConcurrent requests - // 4. The client will send the final request - wg.Add(1) - unblockClient := make(chan struct{}) - clientRequestCancelled := make(chan struct{}) - unblockServer := make(chan struct{}) - go func() { - defer wg.Done() - // Stage 1. - for k := 0; k < maxConcurrent; k++ { - <-gotRequest - } - // Stage 2. - close(unblockClient) - <-clientRequestCancelled - // Stage 3: give some time for the final RoundTrip call to be scheduled and - // verify that the final request is not sent. 
- time.Sleep(50 * time.Millisecond) - select { - case <-gotRequest: - errs <- errors.New("last request did not stall") - close(unblockServer) - return - default: - } - close(unblockServer) - // Stage 4. - <-gotRequest - }() - - ct := newClientTester(t) - ct.tr.StrictMaxConcurrentStreams = true - ct.client = func() error { - var wg sync.WaitGroup - defer func() { - wg.Wait() - close(clientDone) - ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.cc.(*net.TCPConn).Close() - } - }() - for k := 0; k < maxConcurrent+2; k++ { - wg.Add(1) - go func(k int) { - defer wg.Done() - // Don't send the second request until after receiving SETTINGS from the server - // to avoid a race where we use the default SettingMaxConcurrentStreams, which - // is much larger than maxConcurrent. We have to send the first request before - // waiting because the first request triggers the dial and greet. - if k > 0 { - <-greet - } - // Block until maxConcurrent requests are sent before sending any more. - if k >= maxConcurrent { - <-unblockClient - } - body := newStaticCloseChecker("") - req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body) - if k == maxConcurrent { - // This request will be canceled. - req.Cancel = cancelClientRequest - close(cancelClientRequest) - _, err := ct.tr.RoundTrip(req) - close(clientRequestCancelled) - if err == nil { - errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k) - return - } - } else { - resp, err := ct.tr.RoundTrip(req) - if err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - return - } - ioutil.ReadAll(resp.Body) - resp.Body.Close() - if resp.StatusCode != 204 { - errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode) - return - } - } - if err := body.isClosed(); err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - } - }(k) + res.Body.Close() } - return nil } - ct.server = func() error { - var wg sync.WaitGroup - defer wg.Wait() + if connCount != 1 { + t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) + } +} - ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) +// tests Transport.StrictMaxConcurrentStreams +func TestTransportRequestsStallAtServerLimit(t *testing.T) { + const maxConcurrent = 2 - // Server write loop. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - writeResp := make(chan uint32, maxConcurrent+1) + tc := newTestClientConn(t, func(tr *Transport) { + tr.StrictMaxConcurrentStreams = true + }) + tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) - wg.Add(1) - go func() { - defer wg.Done() - <-unblockServer - for id := range writeResp { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: id, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - }() + cancelClientRequest := make(chan struct{}) - // Server read loop. - var nreq int - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it will have reported any errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame: - case *SettingsFrame: - // Wait for the client SETTINGS ack until ending the greet. 
- close(greet) - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - gotRequest <- struct{}{} - nreq++ - writeResp <- f.StreamID - if nreq == maxConcurrent+1 { - close(writeResp) - } - case *DataFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) + // Start maxConcurrent+2 requests. + // The server does not respond to any of them yet. + var rts []*testRoundTrip + for k := 0; k < maxConcurrent+2; k++ { + req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil) + if k == maxConcurrent { + req.Cancel = cancelClientRequest + } + rt := tc.roundTrip(req) + rts = append(rts, rt) + + if k < maxConcurrent { + // We are under the stream limit, so the client sends the request. + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{fmt.Sprintf("/%d", k)}, + }, + }) + } else { + // We have reached the stream limit, + // so the client cannot send the request. + if fr := tc.readFrame(); fr != nil { + t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr) } } + + if rt.done() { + t.Fatalf("rt %v done", k) + } + } + + // Cancel the maxConcurrent'th request. + // The request should fail. + close(cancelClientRequest) + tc.sync() + if err := rts[maxConcurrent].err(); err == nil { + t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent) + } + + // No requests should be complete, except for the canceled one. + for i, rt := range rts { + if i != maxConcurrent && rt.done() { + t.Fatalf("RoundTrip(%d) is done, but should not be", i) + } } - ct.run() + // Server responds to a request, unblocking the last one. 
+ tc.writeHeaders(HeadersFrameParam{ + StreamID: rts[0].streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.wantHeaders(wantHeader{ + streamID: rts[maxConcurrent+1].streamID(), + endStream: true, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{fmt.Sprintf("/%d", maxConcurrent+1)}, + }, + }) + rts[0].wantStatus(200) } func TestTransportMaxDecoderHeaderTableSize(t *testing.T) { - ct := newClientTester(t) var reqSize, resSize uint32 = 8192, 16384 - ct.tr.MaxDecoderHeaderTableSize = reqSize - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - cc, err := ct.tr.NewClientConn(ct.cc) - if err != nil { - return err - } - _, err = cc.RoundTrip(req) - if err != nil { - return err - } - if got, want := cc.peerMaxHeaderTableSize, resSize; got != want { - return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want) - } - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxDecoderHeaderTableSize = reqSize + }) + + fr := testClientConnReadFrame[*SettingsFrame](tc) + if v, ok := fr.Value(SettingHeaderTableSize); !ok { + t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting") + } else if v != reqSize { + t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize) } - ct.server = func() error { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - sf, ok := f.(*SettingsFrame) - if !ok { - ct.t.Fatalf("wanted client settings frame; got %v", f) - _ = sf // stash it away? - } - var found bool - err = sf.ForeachSetting(func(s Setting) error { - if s.ID == SettingHeaderTableSize { - found = true - if got, want := s.Val, reqSize; got != want { - return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want) - } - } - return nil - }) - if err != nil { - return err - } - if !found { - return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting") - } - if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + tc.writeSettings(Setting{SettingHeaderTableSize, resSize}) + if got, want := tc.cc.peerMaxHeaderTableSize, resSize; got != want { + t.Fatalf("peerHeaderTableSize = %d, want %d", got, want) } - ct.run() } func TestTransportMaxEncoderHeaderTableSize(t *testing.T) { - ct := newClientTester(t) var peerAdvertisedMaxHeaderTableSize uint32 = 16384 - ct.tr.MaxEncoderHeaderTableSize = 8192 - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - cc, err := ct.tr.NewClientConn(ct.cc) - if err != nil { - return err - } - _, err = cc.RoundTrip(req) - if err != nil { - return err - } - if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want { - return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want) - } - return nil - } - 
ct.server = func() error { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - sf, ok := f.(*SettingsFrame) - if !ok { - ct.t.Fatalf("wanted client settings frame; got %v", f) - _ = sf // stash it away? - } - if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxEncoderHeaderTableSize = 8192 + }) + tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if got, want := tc.cc.henc.MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want { + t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want) } - ct.run() } func TestAuthorityAddr(t *testing.T) { @@ -4316,40 +3760,24 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { // Issue 18891: make sure Request.Body == NoBody means no DATA frame // is ever sent, even if empty. func TestTransportNoBodyMeansNoDATA(t *testing.T) { - ct := newClientTester(t) - - unblockClient := make(chan bool) + tc := newTestClientConn(t) + tc.greet() - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) - ct.tr.RoundTrip(req) - <-unblockClient - return nil - } - ct.server = func() error { - defer close(unblockClient) - defer ct.cc.(*net.TCPConn).Close() - ct.greet() + req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) + rt := tc.roundTrip(req) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f := f.(type) { - default: - return fmt.Errorf("Got %T; want HeadersFrame", f) - case *WindowUpdateFrame, *SettingsFrame: - continue - case *HeadersFrame: - if !f.StreamEnded() { - return fmt.Errorf("got headers frame without END_STREAM") - } - return nil - } - } + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, // END_STREAM should be set when body is http.NoBody + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{"/"}, + }, + }) + if fr := tc.readFrame(); fr != nil { + t.Fatalf("unexpected frame after headers: %v", fr) } - ct.run() } func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { @@ -4428,41 +3856,22 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { // Verify transport doesn't crash when receiving bogus response lacking a :status header. // Issue 22880. 
func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - const substr = "malformed response from server: missing status pseudo header" - if !strings.Contains(fmt.Sprint(err), substr) { - return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + tc := newTestClientConn(t) + tc.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, // we'll send some DATA to try to crash the transport - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - return nil - } - } - } - ct.run() + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, // we'll send some DATA to try to crash the transport + BlockFragment: tc.makeHeaderBlockFragment( + "content-type", "text/html", // no :status header + ), + }) + tc.writeData(rt.streamID(), true, []byte("payload")) } func BenchmarkClientRequestHeaders(b *testing.B) { @@ -4810,95 +4219,42 @@ func (r *errReader) Read(p []byte) (int, error) { } func testTransportBodyReadError(t *testing.T, body []byte) { - if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { - // So far we've only seen this be flaky on Windows and Plan 9, - // perhaps due to TCP behavior on shutdowns while - // unread data is in flight. This test should be - // fixed, but a skip is better than annoying people - // for now. 
- t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS) - } - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - - checkNoStreams := func() error { - cp, ok := ct.tr.connPool().(*clientConnPool) - if !ok { - return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool()) - } - cp.mu.Lock() - defer cp.mu.Unlock() - conns, ok := cp.conns["dummy.tld:443"] - if !ok { - return fmt.Errorf("missing connection") - } - if len(conns) != 1 { - return fmt.Errorf("conn pool size: %v; expect 1", len(conns)) - } - if activeStreams(conns[0]) != 0 { - return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0])) - } - return nil - } - bodyReadError := errors.New("body read error") - body := &errReader{body, bodyReadError} - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != bodyReadError { - return fmt.Errorf("err = %v; want %v", err, bodyReadError) - } - if err = checkNoStreams(); err != nil { - return err + tc := newTestClientConn(t) + tc.greet() + + bodyReadError := errors.New("body read error") + b := tc.newRequestBody() + b.Write(body) + b.closeWithError(bodyReadError) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", b) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + var receivedBody []byte +readFrames: + for { + switch f := tc.readFrame().(type) { + case *DataFrame: + receivedBody = append(receivedBody, f.Data()...) + case *RSTStreamFrame: + break readFrames + default: + t.Fatalf("unexpected frame: %v", f) + case nil: + t.Fatalf("transport is idle, want RST_STREAM") } - return nil } - ct.server = func() error { - ct.greet() - var receivedBody []byte - var resetCount int - for { - f, err := ct.fr.ReadFrame() - t.Logf("server: ReadFrame = %v, %v", f, err) - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - if bytes.Compare(receivedBody, body) != 0 { - return fmt.Errorf("body: %q; expected %q", receivedBody, body) - } - if resetCount != 1 { - return fmt.Errorf("stream reset count: %v; expected: 1", resetCount) - } - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - receivedBody = append(receivedBody, f.Data()...) 
- case *RSTStreamFrame: - resetCount++ - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + if !bytes.Equal(receivedBody, body) { + t.Fatalf("body: %q; expected %q", receivedBody, body) + } + + if err := rt.err(); err != bodyReadError { + t.Fatalf("err = %v; want %v", err, bodyReadError) + } + + if got := activeStreams(tc.cc); got != 0 { + t.Fatalf("active streams count: %v; want 0", got) } - ct.run() } func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } @@ -4911,59 +4267,18 @@ func TestTransportBodyEagerEndStream(t *testing.T) { const reqBody = "some request body" const resBody = "some response body" - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - body := strings.NewReader(reqBody) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != nil { - return err - } - return nil - } - ct.server = func() error { - ct.greet() + tc := newTestClientConn(t) + tc.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } + body := strings.NewReader(reqBody) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + tc.roundTrip(req) - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - if !f.StreamEnded() { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - return fmt.Errorf("data frame without END_STREAM %v", f) - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte(resBody)) - return nil - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + tc.wantFrameType(FrameHeaders) + f := testClientConnReadFrame[*DataFrame](tc) + if !f.StreamEnded() { + t.Fatalf("data frame without END_STREAM %v", f) } - ct.run() } type chunkReader struct { @@ -5737,39 +5052,27 @@ func TestClientConnReservations(t *testing.T) { } func TestTransportTimeoutServerHangs(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) + tc := newTestClientConn(t) + tc.greet() - req, err := http.NewRequest("PUT", "https://dummy.tld/", nil) - if err != nil { - return err - } + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - req = req.WithContext(ctx) - req.Header.Add("Big", strings.Repeat("a", 1<<20)) - _, err = ct.tr.RoundTrip(req) - if err == nil { - return errors.New("error should not be nil") - } - if ne, ok := err.(net.Error); !ok || !ne.Timeout() { - return fmt.Errorf("error should be a net error timeout: %v", err) - } - return nil + tc.wantFrameType(FrameHeaders) + tc.advance(5 * time.Second) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } - ct.server = func() error { - ct.greet() - select { - case <-time.After(5 * time.Second): - case <-clientDone: - } - return nil + if 
rt.done() { + t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned") + } + + cancel() + tc.sync() + if rt.err() != context.Canceled { + t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err()) } - ct.run() } func TestTransportContentLengthWithoutBody(t *testing.T) { @@ -5962,20 +5265,6 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { testTransportClosesConnAfterGoAway(t, 1) } -type closeOnceConn struct { - net.Conn - closed uint32 -} - -var errClosed = errors.New("Close of closed connection") - -func (c *closeOnceConn) Close() error { - if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - return c.Conn.Close() - } - return errClosed -} - // testTransportClosesConnAfterGoAway verifies that the transport // closes a connection after reading a GOAWAY from it. // @@ -5983,53 +5272,35 @@ func (c *closeOnceConn) Close() error { // When 0, the transport (unsuccessfully) retries the request (stream 1); // when 1, the transport reads the response after receiving the GOAWAY. func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { - ct := newClientTester(t) - ct.cc = &closeOnceConn{Conn: ct.cc} + tc := newTestClientConn(t) + tc.greet() - var wg sync.WaitGroup - wg.Add(1) - ct.client = func() error { - defer wg.Done() - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - } - if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { - t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) - } - if err = ct.cc.Close(); err != errClosed { - return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err) - } - return nil - } + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - ct.server = func() error { - defer wg.Wait() - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - return fmt.Errorf("server failed reading HEADERS: %v", err) - } - if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil { - return fmt.Errorf("server failed writing GOAWAY: %v", err) - } - if lastStream > 0 { - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - return nil + tc.wantFrameType(FrameHeaders) + tc.writeGoAway(lastStream, ErrCodeNo, nil) + + if lastStream > 0 { + // Send a valid response to first request. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) } - ct.run() + tc.closeWrite(io.EOF) + err := rt.err() + if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { + t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) + } + if !tc.netConnClosed { + t.Errorf("ClientConn did not close its net.Conn, expected it to") + } } type slowCloser struct { From 448c44f9287b6745f958d74aa2a17ec7761c2f13 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 19 Mar 2024 10:37:19 -0700 Subject: [PATCH 65/70] http2: remove clientTester All tests which use clientTester have been converted to use testClientConn, so delete clientTester. 
Change-Id: Id9a88bf7ee6760fada8442d383d5e68455c6dc3e Reviewed-on: https://go-review.googlesource.com/c/net/+/572815 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/transport_test.go | 195 ---------------------------------------- 1 file changed, 195 deletions(-) diff --git a/http2/transport_test.go b/http2/transport_test.go index 855c107ef..11ff67b4c 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -822,53 +822,6 @@ func (fw flushWriter) Write(p []byte) (n int, err error) { return } -type clientTester struct { - t *testing.T - tr *Transport - sc, cc net.Conn // server and client conn - fr *Framer // server's framer - settings *SettingsFrame - client func() error - server func() error -} - -func newClientTester(t *testing.T) *clientTester { - var dialOnce struct { - sync.Mutex - dialed bool - } - ct := &clientTester{ - t: t, - } - ct.tr = &Transport{ - TLSClientConfig: tlsConfigInsecure, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - dialOnce.Lock() - defer dialOnce.Unlock() - if dialOnce.dialed { - return nil, errors.New("only one dial allowed in test mode") - } - dialOnce.dialed = true - return ct.cc, nil - }, - } - - ln := newLocalListener(t) - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - sc, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - ln.Close() - ct.cc = cc - ct.sc = sc - ct.fr = NewFramer(sc, sc) - return ct -} - func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp4", "127.0.0.1:0") if err == nil { @@ -881,154 +834,6 @@ func newLocalListener(t *testing.T) net.Listener { return ln } -func (ct *clientTester) greet(settings ...Setting) { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - ct.t.Fatalf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - ct.t.Fatalf("Reading client settings frame: %v", err) - } - var ok bool - if ct.settings, ok = f.(*SettingsFrame); !ok { - ct.t.Fatalf("Wanted client settings frame; got %v", f) - } - if err := ct.fr.WriteSettings(settings...); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } -} - -func (ct *clientTester) readNonSettingsFrame() (Frame, error) { - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return nil, err - } - if _, ok := f.(*SettingsFrame); ok { - continue - } - return f, nil - } -} - -// writeReadPing sends a PING and immediately reads the PING ACK. -// It will fail if any other unread data was pending on the connection, -// aside from SETTINGS frames. 
-func (ct *clientTester) writeReadPing() error { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - if err := ct.fr.WritePing(false, data); err != nil { - return fmt.Errorf("Error writing PING: %v", err) - } - f, err := ct.readNonSettingsFrame() - if err != nil { - return err - } - p, ok := f.(*PingFrame) - if !ok { - return fmt.Errorf("got a %v, want a PING ACK", f) - } - if p.Flags&FlagPingAck == 0 { - return fmt.Errorf("got a PING, want a PING ACK") - } - if p.Data != data { - return fmt.Errorf("got PING data = %x, want %x", p.Data, data) - } - return nil -} - -func (ct *clientTester) inflowWindow(streamID uint32) int32 { - pool := ct.tr.connPoolOrDef.(*clientConnPool) - pool.mu.Lock() - defer pool.mu.Unlock() - if n := len(pool.keys); n != 1 { - ct.t.Errorf("clientConnPool contains %v keys, expected 1", n) - return -1 - } - for cc := range pool.keys { - cc.mu.Lock() - defer cc.mu.Unlock() - if streamID == 0 { - return cc.inflow.avail + cc.inflow.unsent - } - cs := cc.streams[streamID] - if cs == nil { - ct.t.Errorf("no stream with id %v", streamID) - return -1 - } - return cs.inflow.avail + cs.inflow.unsent - } - return -1 -} - -func (ct *clientTester) cleanup() { - ct.tr.CloseIdleConnections() - - // close both connections, ignore the error if its already closed - ct.sc.Close() - ct.cc.Close() -} - -func (ct *clientTester) run() { - var errOnce sync.Once - var wg sync.WaitGroup - - run := func(which string, fn func() error) { - defer wg.Done() - if err := fn(); err != nil { - errOnce.Do(func() { - ct.t.Errorf("%s: %v", which, err) - ct.cleanup() - }) - } - } - - wg.Add(2) - go run("client", ct.client) - go run("server", ct.server) - wg.Wait() - - errOnce.Do(ct.cleanup) // clean up if no error -} - -func (ct *clientTester) readFrame() (Frame, error) { - return ct.fr.ReadFrame() -} - -func (ct *clientTester) firstHeaders() (*HeadersFrame, error) { - for { - f, err := ct.readFrame() - if err != nil { - return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - hf, ok := f.(*HeadersFrame) - if !ok { - return nil, fmt.Errorf("Got %T; want HeadersFrame", f) - } - return hf, nil - } -} - -type countingReader struct { - n *int64 -} - -func (r countingReader) Read(p []byte) (n int, err error) { - for i := range p { - p[i] = byte(i) - } - atomic.AddInt64(r.n, int64(len(p))) - return len(p), err -} - func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } From 3678185f8a652e52864c44049a9ea96b7bcc066a Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 19 Mar 2024 13:42:48 -0700 Subject: [PATCH 66/70] http2: make TestCanonicalHeaderCacheGrowth faster Lower the number of iterations that this test runs for. Reduces runtime with -race from 27s on my M1 Mac to 0.06s. 
Change-Id: Ibd4b225277c79d9030c0a21b3077173a787cc4c1 Reviewed-on: https://go-review.googlesource.com/c/net/+/572656 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/server_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/http2/server_test.go b/http2/server_test.go index 1fdd191ef..afccd9ecd 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4578,13 +4578,16 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) { sc := &serverConn{ serveG: newGoroutineLock(), } - const count = 1000 - for i := 0; i < count; i++ { - h := fmt.Sprintf("%v-%v", base, i) + count := 0 + added := 0 + for added < 10*maxCachedCanonicalHeadersKeysSize { + h := fmt.Sprintf("%v-%v", base, count) c := sc.canonicalHeader(h) if len(h) != len(c) { t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c) } + count++ + added += len(h) } total := 0 for k, v := range sc.canonHeader { From ebc8168ac8ac742194df729305175940790c55a2 Mon Sep 17 00:00:00 2001 From: vitalmotif Date: Wed, 20 Mar 2024 09:32:28 +0000 Subject: [PATCH 67/70] all: fix some typos Change-Id: I7e2c867efcc960553da77e395b0069ab6776cd9f GitHub-Last-Rev: eaa122d1b6086c22f329227053411b0a73a5215b GitHub-Pull-Request: golang/net#205 Reviewed-on: https://go-review.googlesource.com/c/net/+/572995 Reviewed-by: Emmanuel Odeke Reviewed-by: David Chase Auto-Submit: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil --- dns/dnsmessage/message_test.go | 2 +- quic/rangeset.go | 2 +- quic/retry_test.go | 2 +- quic/stream_limits_test.go | 2 +- quic/version_test.go | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go index e60ec42d9..255530598 100644 --- a/dns/dnsmessage/message_test.go +++ b/dns/dnsmessage/message_test.go @@ -1635,7 +1635,7 @@ func FuzzUnpackPack(f *testing.F) { msgPacked, err := m.Pack() if err != nil { - t.Fatalf("failed to pack message that was succesfully unpacked: %v", err) + t.Fatalf("failed to pack message that was successfully unpacked: %v", err) } var m2 Message diff --git a/quic/rangeset.go b/quic/rangeset.go index 4966a99d2..b8b2e9367 100644 --- a/quic/rangeset.go +++ b/quic/rangeset.go @@ -50,7 +50,7 @@ func (s *rangeset[T]) add(start, end T) { if end <= r.end { return } - // Possibly coalesce subsquent ranges into range i. + // Possibly coalesce subsequent ranges into range i. 
r.end = end j := i + 1 for ; j < len(*s) && r.end >= (*s)[j].start; j++ { diff --git a/quic/retry_test.go b/quic/retry_test.go index 42f2bdd4a..c898ad331 100644 --- a/quic/retry_test.go +++ b/quic/retry_test.go @@ -521,7 +521,7 @@ func TestParseInvalidRetryPackets(t *testing.T) { }} { t.Run(test.name, func(t *testing.T) { if _, ok := parseRetryPacket(test.pkt, originalDstConnID); ok { - t.Errorf("parseRetryPacket succeded, want failure") + t.Errorf("parseRetryPacket succeeded, want failure") } }) } diff --git a/quic/stream_limits_test.go b/quic/stream_limits_test.go index 9c2f71ec1..8fed825d7 100644 --- a/quic/stream_limits_test.go +++ b/quic/stream_limits_test.go @@ -249,7 +249,7 @@ func TestStreamLimitStopSendingDoesNotUpdateMaxStreams(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameStopSending{ id: s.id, }) - tc.wantFrame("recieved STOP_SENDING, send RESET_STREAM", + tc.wantFrame("received STOP_SENDING, send RESET_STREAM", packetType1RTT, debugFrameResetStream{ id: s.id, }) diff --git a/quic/version_test.go b/quic/version_test.go index 92fabd7b3..0bd8bac14 100644 --- a/quic/version_test.go +++ b/quic/version_test.go @@ -39,10 +39,10 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { }) gotPkt := te.read() if gotPkt == nil { - t.Fatalf("got no response; want Version Negotiaion") + t.Fatalf("got no response; want Version Negotiation") } if got := getPacketType(gotPkt); got != packetTypeVersionNegotiation { - t.Fatalf("got packet type %v; want Version Negotiaion", got) + t.Fatalf("got packet type %v; want Version Negotiation", got) } gotDst, gotSrc, versions := parseVersionNegotiation(gotPkt) if got, want := gotDst, srcConnID; !bytes.Equal(got, want) { From ba872109ef2dc8f1da778651bd1fd3792d0e4587 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 10 Jan 2024 13:41:39 -0800 Subject: [PATCH 68/70] http2: close connections when receiving too many headers Maintaining HPACK state requires that we parse and process all HEADERS and CONTINUATION frames on a connection. When a request's headers exceed MaxHeaderBytes, we don't allocate memory to store the excess headers but we do parse them. This permits an attacker to cause an HTTP/2 endpoint to read arbitrary amounts of data, all associated with a request which is going to be rejected. Set a limit on the amount of excess header frames we will process before closing a connection. Thanks to Bartek Nowotarski for reporting this issue. 
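For readers skimming the commit message, here is a minimal, self-contained sketch of the budgeting idea this change introduces. It is illustrative only: the type headerBudget and method accept are invented names for this sketch, not the patch's API; the real check lives inside Framer.readMetaFrame in the diff below and works on the encoded header block fragments directly.

    package main

    import (
    	"errors"
    	"fmt"
    )

    // errHeaderListTooLarge signals that the peer sent far more encoded header
    // data than we are willing to parse; the caller should close the connection.
    var errHeaderListTooLarge = errors.New("header list too large; closing connection")

    // headerBudget tracks how many more bytes of encoded header data we will
    // still feed to the HPACK decoder for a single header block.
    type headerBudget struct {
    	remain int
    }

    // accept decides whether a HEADERS/CONTINUATION fragment of fragLen bytes
    // should be parsed. A fragment more than twice the remaining budget, or any
    // fragment arriving after the budget is already exhausted, is rejected as a
    // connection-level error instead of being parsed and then discarded.
    func (b *headerBudget) accept(fragLen int) error {
    	if fragLen > 2*b.remain {
    		return errHeaderListTooLarge
    	}
    	if fragLen > b.remain {
    		b.remain = 0
    	} else {
    		b.remain -= fragLen
    	}
    	return nil
    }

    func main() {
    	b := &headerBudget{remain: 4096}
    	for _, frag := range []int{3000, 2000, 100000} {
    		if err := b.accept(frag); err != nil {
    			fmt.Println("fragment of", frag, "bytes rejected:", err)
    			return
    		}
    		fmt.Println("parsed fragment of", frag, "bytes; remaining budget", b.remain)
    	}
    }

The effect is that an attacker can no longer force the endpoint to parse an unbounded stream of CONTINUATION frames for a request that is already guaranteed to be rejected.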
Fixes CVE-2023-45288 Fixes golang/go#65051 Change-Id: I15df097268df13bb5a9e9d3a5c04a8a141d850f6 Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/2130527 Reviewed-by: Roland Shoemaker Reviewed-by: Tatiana Bradley Reviewed-on: https://go-review.googlesource.com/c/net/+/576155 Reviewed-by: Dmitri Shuralyov Auto-Submit: Dmitri Shuralyov Reviewed-by: Than McIntosh LUCI-TryBot-Result: Go LUCI --- http2/frame.go | 31 ++++++++++++++++ http2/server_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) diff --git a/http2/frame.go b/http2/frame.go index e2b298d85..a5a94411d 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -1564,6 +1564,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { if size > remainSize { hdec.SetEmitEnabled(false) mh.Truncated = true + remainSize = 0 return } remainSize -= size @@ -1576,6 +1577,36 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { var hc headersOrContinuation = hf for { frag := hc.HeaderBlockFragment() + + // Avoid parsing large amounts of headers that we will then discard. + // If the sender exceeds the max header list size by too much, + // skip parsing the fragment and close the connection. + // + // "Too much" is either any CONTINUATION frame after we've already + // exceeded the max header list size (in which case remainSize is 0), + // or a frame whose encoded size is more than twice the remaining + // header list bytes we're willing to accept. + if int64(len(frag)) > int64(2*remainSize) { + if VerboseLogs { + log.Printf("http2: header list too large") + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the struture of the server's frame writer makes this difficult. + return nil, ConnectionError(ErrCodeProtocol) + } + + // Also close the connection after any CONTINUATION frame following an + // invalid header, since we stop tracking the size of the headers after + // an invalid one. + if invalid != nil { + if VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the struture of the server's frame writer makes this difficult. 
+ return nil, ConnectionError(ErrCodeProtocol) + } + if _, err := hdec.Write(frag); err != nil { return nil, ConnectionError(ErrCodeCompression) } diff --git a/http2/server_test.go b/http2/server_test.go index afccd9ecd..d400990d2 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4786,3 +4786,87 @@ Frames: close(s) } } + +func TestServerContinuationFlood(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.Header) + }, func(ts *httptest.Server) { + ts.Config.MaxHeaderBytes = 4096 + }) + defer st.Close() + + st.writePreface() + st.writeInitialSettings() + st.writeSettingsAck() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + }) + for i := 0; i < 1000; i++ { + st.fr.WriteContinuation(1, false, st.encodeHeaderRaw( + fmt.Sprintf("x-%v", i), "1234567890", + )) + } + st.fr.WriteContinuation(1, true, st.encodeHeaderRaw( + "x-last-header", "1", + )) + + var sawGoAway bool + for { + f, err := st.readFrame() + if err != nil { + break + } + switch f.(type) { + case *GoAwayFrame: + sawGoAway = true + case *HeadersFrame: + t.Fatalf("received HEADERS frame; want GOAWAY") + } + } + if !sawGoAway { + t.Errorf("connection closed with no GOAWAY frame; want one") + } +} + +func TestServerContinuationAfterInvalidHeader(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.Header) + }) + defer st.Close() + + st.writePreface() + st.writeInitialSettings() + st.writeSettingsAck() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + }) + st.fr.WriteContinuation(1, false, st.encodeHeaderRaw( + "x-invalid-header", "\x00", + )) + st.fr.WriteContinuation(1, true, st.encodeHeaderRaw( + "x-valid-header", "1", + )) + + var sawGoAway bool + for { + f, err := st.readFrame() + if err != nil { + break + } + switch f.(type) { + case *GoAwayFrame: + sawGoAway = true + case *HeadersFrame: + t.Fatalf("received HEADERS frame; want GOAWAY") + } + } + if !sawGoAway { + t.Errorf("connection closed with no GOAWAY frame; want one") + } +} From 762b58d1cf6e0779780decad89c6c1523386638d Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Wed, 3 Apr 2024 09:32:37 -0700 Subject: [PATCH 69/70] http2: fix tipos in comment Change-Id: I20cd0f8db534fe2a849306eb7e0c8ee5b434e88f Reviewed-on: https://go-review.googlesource.com/c/net/+/576175 Auto-Submit: Ian Lance Taylor Reviewed-by: Ian Lance Taylor LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil --- http2/frame.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/http2/frame.go b/http2/frame.go index a5a94411d..43557ab7e 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -1591,7 +1591,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { log.Printf("http2: header list too large") } // It would be nice to send a RST_STREAM before sending the GOAWAY, - // but the struture of the server's frame writer makes this difficult. + // but the structure of the server's frame writer makes this difficult. return nil, ConnectionError(ErrCodeProtocol) } @@ -1603,7 +1603,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { log.Printf("http2: invalid header: %v", invalid) } // It would be nice to send a RST_STREAM before sending the GOAWAY, - // but the struture of the server's frame writer makes this difficult. + // but the structure of the server's frame writer makes this difficult. 
return nil, ConnectionError(ErrCodeProtocol) } From c48da131589f122489348be5dfbcb6457640046f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 3 Apr 2024 10:17:18 -0700 Subject: [PATCH 70/70] http2: fix TestServerContinuationFlood flakes This test causes the server to send a GOAWAY and close a connection. The server GOAWAY path writes a GOAWAY frame asynchronously, and closes the connection if the write doesn't complete within 1s. This is causing failures on some builders, when the frame write doesn't complete in time. The important aspect of this test is that the connection be closed. Drop the check for the GOAWAY frame. Change-Id: I099413be9c4dfe71d8fe83d2c6242e82e282293e Reviewed-on: https://go-review.googlesource.com/c/net/+/576235 Reviewed-by: Dmitri Shuralyov Reviewed-by: Dmitri Shuralyov Reviewed-by: Than McIntosh LUCI-TryBot-Result: Go LUCI --- http2/server_test.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/http2/server_test.go b/http2/server_test.go index d400990d2..a931a06e5 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4813,22 +4813,24 @@ func TestServerContinuationFlood(t *testing.T) { "x-last-header", "1", )) - var sawGoAway bool for { f, err := st.readFrame() if err != nil { break } switch f.(type) { - case *GoAwayFrame: - sawGoAway = true case *HeadersFrame: - t.Fatalf("received HEADERS frame; want GOAWAY") + t.Fatalf("received HEADERS frame; want GOAWAY and a closed connection") } } - if !sawGoAway { - t.Errorf("connection closed with no GOAWAY frame; want one") - } + // We expect to have seen a GOAWAY before the connection closes, + // but the server will close the connection after one second + // whether or not it has finished sending the GOAWAY. On windows-amd64-race + // builders, this fairly consistently results in the connection closing without + // the GOAWAY being sent. + // + // Since the server's behavior is inherently racy here and the important thing + // is that the connection is closed, don't check for the GOAWAY having been sent. } func TestServerContinuationAfterInvalidHeader(t *testing.T) {