diff --git a/go.mod b/go.mod index 21deffd4b..8ab3f40e1 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module golang.org/x/net go 1.18 require ( - golang.org/x/crypto v0.15.0 - golang.org/x/sys v0.14.0 - golang.org/x/term v0.14.0 + golang.org/x/crypto v0.16.0 + golang.org/x/sys v0.15.0 + golang.org/x/term v0.15.0 golang.org/x/text v0.14.0 ) diff --git a/go.sum b/go.sum index 54759e489..bb6ed68a0 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= -golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8= -golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/internal/quic/cmd/interop/main.go b/internal/quic/cmd/interop/main.go index cc5292e9e..20f737b52 100644 --- a/internal/quic/cmd/interop/main.go +++ b/internal/quic/cmd/interop/main.go @@ -18,6 +18,7 @@ import ( "fmt" "io" "log" + "log/slog" "net" "net/url" "os" @@ -25,14 +26,16 @@ import ( "sync" "golang.org/x/net/internal/quic" + "golang.org/x/net/internal/quic/qlog" ) var ( - listen = flag.String("listen", "", "listen address") - cert = flag.String("cert", "", "certificate") - pkey = flag.String("key", "", "private key") - root = flag.String("root", "", "serve files from this root") - output = flag.String("output", "", "directory to write files to") + listen = flag.String("listen", "", "listen address") + cert = flag.String("cert", "", "certificate") + pkey = flag.String("key", "", "private key") + root = flag.String("root", "", "serve files from this root") + output = flag.String("output", "", "directory to write files to") + qlogdir = flag.String("qlog", "", "directory to write qlog output to") ) func main() { @@ -48,6 +51,10 @@ func main() { }, MaxBidiRemoteStreams: -1, MaxUniRemoteStreams: -1, + QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{ + Level: quic.QLogLevelFrame, + Dir: *qlogdir, + })), } if *cert != "" { c, err := tls.LoadX509KeyPair(*cert, *pkey) @@ -150,7 +157,7 @@ func basicTest(ctx context.Context, config *quic.Config, urls []string) { } } -func serve(ctx context.Context, l *quic.Listener) error { +func serve(ctx context.Context, l *quic.Endpoint) error { for { c, err := l.Accept(ctx) if err != nil { @@ -214,7 +221,7 @@ func parseURL(s string) (u *url.URL, authority string, err error) { return u, authority, nil } -func fetchFrom(ctx context.Context, l *quic.Listener, addr string, urls []*url.URL) { +func fetchFrom(ctx context.Context, l *quic.Endpoint, addr string, urls []*url.URL) { conn, err := l.Dial(ctx, "udp", addr) if err != nil { log.Printf("%v: %v", addr, err) diff --git a/internal/quic/cmd/interop/run_endpoint.sh b/internal/quic/cmd/interop/run_endpoint.sh index d72335d8e..442039bc0 100644 --- a/internal/quic/cmd/interop/run_endpoint.sh +++ b/internal/quic/cmd/interop/run_endpoint.sh @@ -11,7 +11,7 @@ if [ "$ROLE" == "client" ]; then # Wait for the simulator to start up. /wait-for-it.sh sim:57832 -s -t 30 - ./interop -output=/downloads $CLIENT_PARAMS $REQUESTS + ./interop -output=/downloads -qlog=$QLOGDIR $CLIENT_PARAMS $REQUESTS elif [ "$ROLE" == "server" ]; then - ./interop -cert=/certs/cert.pem -key=/certs/priv.key -listen=:443 -root=/www "$@" $SERVER_PARAMS + ./interop -cert=/certs/cert.pem -key=/certs/priv.key -qlog=$QLOGDIR -listen=:443 -root=/www "$@" $SERVER_PARAMS fi diff --git a/internal/quic/config.go b/internal/quic/config.go index 6278bf89c..b045b7b92 100644 --- a/internal/quic/config.go +++ b/internal/quic/config.go @@ -8,6 +8,9 @@ package quic import ( "crypto/tls" + "log/slog" + "math" + "time" ) // A Config structure configures a QUIC endpoint. @@ -72,9 +75,39 @@ type Config struct { // // If this field is left as zero, stateless reset is disabled. StatelessResetKey [32]byte + + // HandshakeTimeout is the maximum time in which a connection handshake must complete. + // If zero, the default of 10 seconds is used. + // If negative, there is no handshake timeout. + HandshakeTimeout time.Duration + + // MaxIdleTimeout is the maximum time after which an idle connection will be closed. + // If zero, the default of 30 seconds is used. + // If negative, idle connections are never closed. + // + // The idle timeout for a connection is the minimum of the maximum idle timeouts + // of the endpoints. + MaxIdleTimeout time.Duration + + // KeepAlivePeriod is the time after which a packet will be sent to keep + // an idle connection alive. + // If zero, keep alive packets are not sent. + // If greater than zero, the keep alive period is the smaller of KeepAlivePeriod and + // half the connection idle timeout. + KeepAlivePeriod time.Duration + + // QLogLogger receives qlog events. + // + // Events currently correspond to the definitions in draft-ietf-qlog-quic-events-03. + // This is not the latest version of the draft, but is the latest version supported + // by common event log viewers as of the time this paragraph was written. + // + // The qlog package contains a slog.Handler which serializes qlog events + // to a standard JSON representation. + QLogLogger *slog.Logger } -func configDefault(v, def, limit int64) int64 { +func configDefault[T ~int64](v, def, limit T) T { switch { case v == 0: return def @@ -104,3 +137,15 @@ func (c *Config) maxStreamWriteBufferSize() int64 { func (c *Config) maxConnReadBufferSize() int64 { return configDefault(c.MaxConnReadBufferSize, 1<<20, maxVarint) } + +func (c *Config) handshakeTimeout() time.Duration { + return configDefault(c.HandshakeTimeout, defaultHandshakeTimeout, math.MaxInt64) +} + +func (c *Config) maxIdleTimeout() time.Duration { + return configDefault(c.MaxIdleTimeout, defaultMaxIdleTimeout, math.MaxInt64) +} + +func (c *Config) keepAlivePeriod() time.Duration { + return configDefault(c.KeepAlivePeriod, defaultKeepAlivePeriod, math.MaxInt64) +} diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 1292f2b20..31e789b1d 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "errors" "fmt" + "log/slog" "net/netip" "time" ) @@ -20,27 +21,22 @@ import ( // Multiple goroutines may invoke methods on a Conn simultaneously. type Conn struct { side connSide - listener *Listener + endpoint *Endpoint config *Config testHooks connTestHooks peerAddr netip.AddrPort - msgc chan any - donec chan struct{} // closed when conn loop exits - exited bool // set to make the conn loop exit immediately + msgc chan any + donec chan struct{} // closed when conn loop exits w packetWriter acks [numberSpaceCount]ackState // indexed by number space lifetime lifetimeState + idle idleState connIDState connIDState loss lossState streams streamsState - // idleTimeout is the time at which the connection will be closed due to inactivity. - // https://www.rfc-editor.org/rfc/rfc9000#section-10.1 - maxIdleTimeout time.Duration - idleTimeout time.Time - // Packet protection keys, CRYPTO streams, and TLS state. keysInitial fixedKeyPair keysHandshake fixedKeyPair @@ -60,6 +56,8 @@ type Conn struct { // Tests only: Send a PING in a specific number space. testSendPingSpace numberSpace testSendPing sentVal + + log *slog.Logger } // connTestHooks override conn behavior in tests. @@ -94,25 +92,31 @@ type newServerConnIDs struct { retrySrcConnID []byte // source from server's Retry } -func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, l *Listener) (*Conn, error) { +func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) { c := &Conn{ side: side, - listener: l, + endpoint: e, config: config, peerAddr: peerAddr, msgc: make(chan any, 1), donec: make(chan struct{}), - maxIdleTimeout: defaultMaxIdleTimeout, - idleTimeout: now.Add(defaultMaxIdleTimeout), peerAckDelayExponent: -1, } + defer func() { + // If we hit an error in newConn, close donec so tests don't get stuck waiting for it. + // This is only relevant if we've got a bug, but it makes tracking that bug down + // much easier. + if conn == nil { + close(c.donec) + } + }() // A one-element buffer allows us to wake a Conn's event loop as a // non-blocking operation. c.msgc = make(chan any, 1) - if l.testHooks != nil { - l.testHooks.newConn(c) + if e.testHooks != nil { + e.testHooks.newConn(c) } // initialConnID is the connection ID used to generate Initial packet protection keys. @@ -132,13 +136,13 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip } } - // The smallest allowed maximum QUIC datagram size is 1200 bytes. // TODO: PMTU discovery. - const maxDatagramSize = 1200 + c.logConnectionStarted(cids.originalDstConnID, peerAddr) c.keysAppData.init() - c.loss.init(c.side, maxDatagramSize, now) + c.loss.init(c.side, smallestMaxDatagramSize, now) c.streamsInit() c.lifetimeInit() + c.restartIdleTimer(now) if err := c.startTLS(now, initialConnID, transportParameters{ initialSrcConnID: c.connIDState.srcConnID(), @@ -183,13 +187,14 @@ func (c *Conn) confirmHandshake(now time.Time) { if c.side == serverSide { // When the server confirms the handshake, it sends a HANDSHAKE_DONE. c.handshakeConfirmed.setUnsent() - c.listener.serverConnEstablished(c) + c.endpoint.serverConnEstablished(c) } else { // The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed // to the received state, indicating that the handshake is confirmed and we // don't need to send anything. c.handshakeConfirmed.setReceived() } + c.restartIdleTimer(now) c.loss.confirmHandshake() // "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed" // https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1 @@ -220,6 +225,7 @@ func (c *Conn) receiveTransportParameters(p transportParameters) error { c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni + c.receivePeerMaxIdleTimeout(p.maxIdleTimeout) c.peerAckDelayExponent = p.ackDelayExponent c.loss.setMaxAckDelay(p.maxAckDelay) if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil { @@ -236,7 +242,6 @@ func (c *Conn) receiveTransportParameters(p transportParameters) error { return err } } - // TODO: max_idle_timeout // TODO: stateless_reset_token // TODO: max_udp_payload_size // TODO: disable_active_migration @@ -249,6 +254,8 @@ type ( wakeEvent struct{} ) +var errIdleTimeout = errors.New("idle timeout") + // loop is the connection main loop. // // Except where otherwise noted, all connection state is owned by the loop goroutine. @@ -258,7 +265,8 @@ type ( func (c *Conn) loop(now time.Time) { defer close(c.donec) defer c.tls.Close() - defer c.listener.connDrained(c) + defer c.endpoint.connDrained(c) + defer c.logConnectionClosed() // The connection timer sends a message to the connection loop on expiry. // We need to give it an expiry when creating it, so set the initial timeout to @@ -275,14 +283,14 @@ func (c *Conn) loop(now time.Time) { defer timer.Stop() } - for !c.exited { + for c.lifetime.state != connStateDone { sendTimeout := c.maybeSend(now) // try sending // Note that we only need to consider the ack timer for the App Data space, // since the Initial and Handshake spaces always ack immediately. nextTimeout := sendTimeout - nextTimeout = firstTime(nextTimeout, c.idleTimeout) - if !c.isClosingOrDraining() { + nextTimeout = firstTime(nextTimeout, c.idle.nextTimeout) + if c.isAlive() { nextTimeout = firstTime(nextTimeout, c.loss.timer) nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck) } else { @@ -316,11 +324,9 @@ func (c *Conn) loop(now time.Time) { m.recycle() case timerEvent: // A connection timer has expired. - if !now.Before(c.idleTimeout) { - // "[...] the connection is silently closed and - // its state is discarded [...]" - // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1 - c.exited = true + if c.idleAdvance(now) { + // The connection idle timer has expired. + c.abortImmediately(now, errIdleTimeout) return } c.loss.advance(now, c.handleAckOrLoss) diff --git a/internal/quic/conn_close.go b/internal/quic/conn_close.go index a9ef0db5e..246a12638 100644 --- a/internal/quic/conn_close.go +++ b/internal/quic/conn_close.go @@ -12,33 +12,54 @@ import ( "time" ) +// connState is the state of a connection. +type connState int + +const ( + // A connection is alive when it is first created. + connStateAlive = connState(iota) + + // The connection has received a CONNECTION_CLOSE frame from the peer, + // and has not yet sent a CONNECTION_CLOSE in response. + // + // We will send a CONNECTION_CLOSE, and then enter the draining state. + connStatePeerClosed + + // The connection is in the closing state. + // + // We will send CONNECTION_CLOSE frames to the peer + // (once upon entering the closing state, and possibly again in response to peer packets). + // + // If we receive a CONNECTION_CLOSE from the peer, we will enter the draining state. + // Otherwise, we will eventually time out and move to the done state. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.1 + connStateClosing + + // The connection is in the draining state. + // + // We will neither send packets nor process received packets. + // When the drain timer expires, we move to the done state. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.2 + connStateDraining + + // The connection is done, and the conn loop will exit. + connStateDone +) + // lifetimeState tracks the state of a connection. // // This is fairly coupled to the rest of a Conn, but putting it in a struct of its own helps // reason about operations that cause state transitions. type lifetimeState struct { - readyc chan struct{} // closed when TLS handshake completes - drainingc chan struct{} // closed when entering the draining state + state connState + + readyc chan struct{} // closed when TLS handshake completes + donec chan struct{} // closed when finalErr is set - // Possible states for the connection: - // - // Alive: localErr and finalErr are both nil. - // - // Closing: localErr is non-nil and finalErr is nil. - // We have sent a CONNECTION_CLOSE to the peer or are about to - // (if connCloseSentTime is zero) and are waiting for the peer to respond. - // drainEndTime is set to the time the closing state ends. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.1 - // - // Draining: finalErr is non-nil. - // If localErr is nil, we're waiting for the user to provide us with a final status - // to send to the peer. - // Otherwise, we've either sent a CONNECTION_CLOSE to the peer or are about to - // (if connCloseSentTime is zero). - // drainEndTime is set to the time the draining state ends. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 localErr error // error sent to the peer - finalErr error // error sent by the peer, or transport error; always set before draining + finalErr error // error sent by the peer, or transport error; set before closing donec connCloseSentTime time.Time // send time of last CONNECTION_CLOSE frame connCloseDelay time.Duration // delay until next CONNECTION_CLOSE frame sent @@ -47,7 +68,7 @@ type lifetimeState struct { func (c *Conn) lifetimeInit() { c.lifetime.readyc = make(chan struct{}) - c.lifetime.drainingc = make(chan struct{}) + c.lifetime.donec = make(chan struct{}) } var errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE") @@ -60,13 +81,25 @@ func (c *Conn) lifetimeAdvance(now time.Time) (done bool) { // The connection drain period has ended, and we can shut down. // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-7 c.lifetime.drainEndTime = time.Time{} - if c.lifetime.finalErr == nil { - // The peer never responded to our CONNECTION_CLOSE. - c.enterDraining(now, errNoPeerResponse) + if c.lifetime.state != connStateDraining { + // We were in the closing state, waiting for a CONNECTION_CLOSE from the peer. + c.setFinalError(errNoPeerResponse) } + c.setState(now, connStateDone) return true } +// setState sets the conn state. +func (c *Conn) setState(now time.Time, state connState) { + switch state { + case connStateClosing, connStateDraining: + if c.lifetime.drainEndTime.IsZero() { + c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) + } + } + c.lifetime.state = state +} + // confirmHandshake is called when the TLS handshake completes. func (c *Conn) handshakeDone() { close(c.lifetime.readyc) @@ -81,44 +114,66 @@ func (c *Conn) handshakeDone() { // // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 func (c *Conn) isDraining() bool { - return c.lifetime.finalErr != nil + switch c.lifetime.state { + case connStateDraining, connStateDone: + return true + } + return false } -// isClosingOrDraining reports whether the conn is in the closing or draining states. -func (c *Conn) isClosingOrDraining() bool { - return c.lifetime.localErr != nil || c.lifetime.finalErr != nil +// isAlive reports whether the conn is handling packets. +func (c *Conn) isAlive() bool { + return c.lifetime.state == connStateAlive } // sendOK reports whether the conn can send frames at this time. func (c *Conn) sendOK(now time.Time) bool { - if !c.isClosingOrDraining() { + switch c.lifetime.state { + case connStateAlive: return true - } - // We are closing or draining. - if c.lifetime.localErr == nil { - // We're waiting for the user to close the connection, providing us with - // a final status to send to the peer. + case connStatePeerClosed: + if c.lifetime.localErr == nil { + // We're waiting for the user to close the connection, providing us with + // a final status to send to the peer. + return false + } + // We should send a CONNECTION_CLOSE. + return true + case connStateClosing: + if c.lifetime.connCloseSentTime.IsZero() { + return true + } + maxRecvTime := c.acks[initialSpace].maxRecvTime + if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) { + maxRecvTime = t + } + if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) { + maxRecvTime = t + } + if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) { + // After sending CONNECTION_CLOSE, ignore packets from the peer for + // a delay. On the next packet received after the delay, send another + // CONNECTION_CLOSE. + return false + } + return true + case connStateDraining: + // We are in the draining state, and will send no more packets. return false + case connStateDone: + return false + default: + panic("BUG: unhandled connection state") } - // Past this point, returning true will result in the conn sending a CONNECTION_CLOSE - // due to localErr being set. - if c.lifetime.drainEndTime.IsZero() { - // The closing and draining states should last for at least three times - // the current PTO interval. We currently use exactly that minimum. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-5 - // - // The drain period begins when we send or receive a CONNECTION_CLOSE, - // whichever comes first. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2-3 - c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) +} + +// sendConnectionClose reports that the conn has sent a CONNECTION_CLOSE to the peer. +func (c *Conn) sentConnectionClose(now time.Time) { + switch c.lifetime.state { + case connStatePeerClosed: + c.enterDraining(now) } if c.lifetime.connCloseSentTime.IsZero() { - // We haven't sent a CONNECTION_CLOSE yet. Do so. - // Either we're initiating an immediate close - // (and will enter the closing state as soon as we send CONNECTION_CLOSE), - // or we've read a CONNECTION_CLOSE from our peer - // (and may send one CONNECTION_CLOSE before entering the draining state). - // // Set the initial delay before we will send another CONNECTION_CLOSE. // // RFC 9000 states that we should rate limit CONNECTION_CLOSE frames, @@ -126,65 +181,56 @@ func (c *Conn) sendOK(now time.Time) bool { // with the same delay as the PTO timer (RFC 9002, Section 6.2.1), // not including max_ack_delay, and double it on every CONNECTION_CLOSE sent. c.lifetime.connCloseDelay = c.loss.rtt.smoothedRTT + max(4*c.loss.rtt.rttvar, timerGranularity) - c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) - return true - } - if c.isDraining() { - // We are in the draining state, and will send no more packets. - return false - } - maxRecvTime := c.acks[initialSpace].maxRecvTime - if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) { - maxRecvTime = t - } - if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) { - maxRecvTime = t - } - if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) { - // After sending CONNECTION_CLOSE, ignore packets from the peer for - // a delay. On the next packet received after the delay, send another - // CONNECTION_CLOSE. - return false + } else if !c.lifetime.connCloseSentTime.Equal(now) { + // If connCloseSentTime == now, we're sending two CONNECTION_CLOSE frames + // coalesced into the same datagram. We only want to increase the delay once. + c.lifetime.connCloseDelay *= 2 } c.lifetime.connCloseSentTime = now - c.lifetime.connCloseDelay *= 2 - return true } -// enterDraining enters the draining state. -func (c *Conn) enterDraining(now time.Time, err error) { - if c.isDraining() { - return +// handlePeerConnectionClose handles a CONNECTION_CLOSE from the peer. +func (c *Conn) handlePeerConnectionClose(now time.Time, err error) { + c.setFinalError(err) + switch c.lifetime.state { + case connStateAlive: + c.setState(now, connStatePeerClosed) + case connStatePeerClosed: + // Duplicate CONNECTION_CLOSE, ignore. + case connStateClosing: + if c.lifetime.connCloseSentTime.IsZero() { + c.setState(now, connStatePeerClosed) + } else { + c.setState(now, connStateDraining) + } + case connStateDraining: + case connStateDone: } - if err == errStatelessReset { - // If we've received a stateless reset, then we must not send a CONNECTION_CLOSE. - // Setting connCloseSentTime here prevents us from doing so. - c.lifetime.finalErr = errStatelessReset - c.lifetime.localErr = errStatelessReset - c.lifetime.connCloseSentTime = now - } else if e, ok := c.lifetime.localErr.(localTransportError); ok && e.code != errNo { - // If we've terminated the connection due to a peer protocol violation, - // record the final error on the connection as our reason for termination. - c.lifetime.finalErr = c.lifetime.localErr - } else { - c.lifetime.finalErr = err +} + +// setFinalError records the final connection status we report to the user. +func (c *Conn) setFinalError(err error) { + select { + case <-c.lifetime.donec: + return // already set + default: } - close(c.lifetime.drainingc) - c.streams.queue.close(c.lifetime.finalErr) + c.lifetime.finalErr = err + close(c.lifetime.donec) } func (c *Conn) waitReady(ctx context.Context) error { select { case <-c.lifetime.readyc: return nil - case <-c.lifetime.drainingc: + case <-c.lifetime.donec: return c.lifetime.finalErr default: } select { case <-c.lifetime.readyc: return nil - case <-c.lifetime.drainingc: + case <-c.lifetime.donec: return c.lifetime.finalErr case <-ctx.Done(): return ctx.Err() @@ -199,7 +245,7 @@ func (c *Conn) waitReady(ctx context.Context) error { // err := conn.Wait(context.Background()) func (c *Conn) Close() error { c.Abort(nil) - <-c.lifetime.drainingc + <-c.lifetime.donec return c.lifetime.finalErr } @@ -213,7 +259,7 @@ func (c *Conn) Close() error { // containing the peer's error code and reason. // If the peer closes the connection with any other status, Wait returns a non-nil error. func (c *Conn) Wait(ctx context.Context) error { - if err := c.waitOnDone(ctx, c.lifetime.drainingc); err != nil { + if err := c.waitOnDone(ctx, c.lifetime.donec); err != nil { return err } return c.lifetime.finalErr @@ -229,30 +275,46 @@ func (c *Conn) Abort(err error) { err = localTransportError{code: errNo} } c.sendMsg(func(now time.Time, c *Conn) { - c.abort(now, err) + c.enterClosing(now, err) }) } // abort terminates a connection with an error. func (c *Conn) abort(now time.Time, err error) { - if c.lifetime.localErr != nil { - return // already closing - } - c.lifetime.localErr = err + c.setFinalError(err) // this error takes precedence over the peer's CONNECTION_CLOSE + c.enterClosing(now, err) } // abortImmediately terminates a connection. // The connection does not send a CONNECTION_CLOSE, and skips the draining period. func (c *Conn) abortImmediately(now time.Time, err error) { - c.abort(now, err) - c.enterDraining(now, err) - c.exited = true + c.setFinalError(err) + c.setState(now, connStateDone) +} + +// enterClosing starts an immediate close. +// We will send a CONNECTION_CLOSE to the peer and wait for their response. +func (c *Conn) enterClosing(now time.Time, err error) { + switch c.lifetime.state { + case connStateAlive: + c.lifetime.localErr = err + c.setState(now, connStateClosing) + case connStatePeerClosed: + c.lifetime.localErr = err + } +} + +// enterDraining moves directly to the draining state, without sending a CONNECTION_CLOSE. +func (c *Conn) enterDraining(now time.Time) { + switch c.lifetime.state { + case connStateAlive, connStatePeerClosed, connStateClosing: + c.setState(now, connStateDraining) + } } // exit fully terminates a connection immediately. func (c *Conn) exit() { c.sendMsg(func(now time.Time, c *Conn) { - c.enterDraining(now, errors.New("connection closed")) - c.exited = true + c.abortImmediately(now, errors.New("connection closed")) }) } diff --git a/internal/quic/conn_close_test.go b/internal/quic/conn_close_test.go index d583ae92a..49881e62f 100644 --- a/internal/quic/conn_close_test.go +++ b/internal/quic/conn_close_test.go @@ -70,7 +70,8 @@ func TestConnCloseResponseBackoff(t *testing.T) { } func TestConnCloseWithPeerResponse(t *testing.T) { - tc := newTestConn(t, clientSide) + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, qr.config) tc.handshake() tc.conn.Abort(nil) @@ -99,10 +100,19 @@ func TestConnCloseWithPeerResponse(t *testing.T) { if err := tc.conn.Wait(canceledContext()); !errors.Is(err, wantErr) { t.Errorf("non-blocking conn.Wait() = %v, want %v", err, wantErr) } + + tc.advance(1 * time.Second) // long enough to exit the draining state + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": "application", + }, + }) } func TestConnClosePeerCloses(t *testing.T) { - tc := newTestConn(t, clientSide) + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, qr.config) tc.handshake() wantErr := &ApplicationError{ @@ -128,6 +138,14 @@ func TestConnClosePeerCloses(t *testing.T) { code: 9, reason: "because", }) + + tc.advance(1 * time.Second) // long enough to exit the draining state + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": "application", + }, + }) } func TestConnCloseReceiveInInitial(t *testing.T) { @@ -187,13 +205,13 @@ func TestConnCloseReceiveInHandshake(t *testing.T) { tc.wantIdle("no more frames to send") } -func TestConnCloseClosedByListener(t *testing.T) { +func TestConnCloseClosedByEndpoint(t *testing.T) { ctx := canceledContext() tc := newTestConn(t, clientSide) tc.handshake() - tc.listener.l.Close(ctx) - tc.wantFrame("listener closes connection before exiting", + tc.endpoint.e.Close(ctx) + tc.wantFrame("endpoint closes connection before exiting", packetType1RTT, debugFrameConnectionCloseTransport{ code: errNo, }) diff --git a/internal/quic/conn_flow_test.go b/internal/quic/conn_flow_test.go index 03e0757a6..39c879346 100644 --- a/internal/quic/conn_flow_test.go +++ b/internal/quic/conn_flow_test.go @@ -262,6 +262,7 @@ func TestConnOutflowBlocked(t *testing.T) { if n != len(data) || err != nil { t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data)) } + s.Flush() tc.wantFrame("stream writes data up to MAX_DATA limit", packetType1RTT, debugFrameStream{ @@ -310,6 +311,7 @@ func TestConnOutflowMaxDataDecreases(t *testing.T) { if n != len(data) || err != nil { t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data)) } + s.Flush() tc.wantFrame("stream writes data up to MAX_DATA limit", packetType1RTT, debugFrameStream{ @@ -337,7 +339,9 @@ func TestConnOutflowMaxDataRoundRobin(t *testing.T) { } s1.Write(make([]byte, 10)) + s1.Flush() s2.Write(make([]byte, 10)) + s2.Flush() tc.writeFrames(packetType1RTT, debugFrameMaxData{ max: 1, @@ -378,6 +382,7 @@ func TestConnOutflowMetaAndData(t *testing.T) { data := makeTestData(32) s.Write(data) + s.Flush() s.CloseRead() tc.wantFrame("CloseRead sends a STOP_SENDING, not flow controlled", @@ -405,6 +410,7 @@ func TestConnOutflowResentData(t *testing.T) { data := makeTestData(15) s.Write(data[:8]) + s.Flush() tc.wantFrame("data is under MAX_DATA limit, all sent", packetType1RTT, debugFrameStream{ id: s.id, @@ -421,6 +427,7 @@ func TestConnOutflowResentData(t *testing.T) { }) s.Write(data[8:]) + s.Flush() tc.wantFrame("new data is sent up to the MAX_DATA limit", packetType1RTT, debugFrameStream{ id: s.id, diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go index 439c22123..2efe8d6b5 100644 --- a/internal/quic/conn_id.go +++ b/internal/quic/conn_id.go @@ -76,7 +76,7 @@ func (s *connIDState) initClient(c *Conn) error { cid: locid, }) s.nextLocalSeq = 1 - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addConnID(c, locid) }) @@ -117,7 +117,7 @@ func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error { cid: locid, }) s.nextLocalSeq = 1 - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addConnID(c, dstConnID) conns.addConnID(c, locid) }) @@ -194,7 +194,7 @@ func (s *connIDState) issueLocalIDs(c *Conn) error { s.needSend = true toIssue-- } - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { for _, cid := range newIDs { conns.addConnID(c, cid) } @@ -247,7 +247,7 @@ func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p trans } token := statelessResetToken(p.statelessResetToken) s.remote[0].resetToken = token - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addResetToken(c, token) }) } @@ -276,7 +276,7 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) // the client. Discard the transient, client-chosen connection ID used // for Initial packets; the client will never send it again. cid := s.local[0].cid - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.retireConnID(c, cid) }) s.local = append(s.local[:0], s.local[1:]...) @@ -314,7 +314,7 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re rcid := &s.remote[i] if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo { s.retireRemote(rcid) - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.retireResetToken(c, rcid.resetToken) }) } @@ -350,7 +350,7 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re s.retireRemote(&s.remote[len(s.remote)-1]) } else { active++ - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addResetToken(c, resetToken) }) } @@ -399,7 +399,7 @@ func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error { for i := range s.local { if s.local[i].seq == seq { cid := s.local[i].cid - c.listener.connsMap.updateConnIDs(func(conns *connsMap) { + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.retireConnID(c, cid) }) s.local = append(s.local[:i], s.local[i+1:]...) @@ -463,7 +463,7 @@ func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool { s.local[i].seq, retireBefore, s.local[i].cid, - c.listener.resetGen.tokenForConnID(s.local[i].cid), + c.endpoint.resetGen.tokenForConnID(s.local[i].cid), ) { return false } diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go index 314a6b384..d44472e81 100644 --- a/internal/quic/conn_id_test.go +++ b/internal/quic/conn_id_test.go @@ -651,16 +651,16 @@ func TestConnIDsCleanedUpAfterClose(t *testing.T) { // Wait for the conn to drain. // Then wait for the conn loop to exit, // and force an immediate sync of the connsMap updates - // (normally only done by the listener read loop). + // (normally only done by the endpoint read loop). tc.advanceToTimer() <-tc.conn.donec - tc.listener.l.connsMap.applyUpdates() + tc.endpoint.e.connsMap.applyUpdates() - if got := len(tc.listener.l.connsMap.byConnID); got != 0 { - t.Errorf("%v conn ids in listener map after closing, want 0", got) + if got := len(tc.endpoint.e.connsMap.byConnID); got != 0 { + t.Errorf("%v conn ids in endpoint map after closing, want 0", got) } - if got := len(tc.listener.l.connsMap.byResetToken); got != 0 { - t.Errorf("%v reset tokens in listener map after closing, want 0", got) + if got := len(tc.endpoint.e.connsMap.byResetToken); got != 0 { + t.Errorf("%v reset tokens in endpoint map after closing, want 0", got) } }) } diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go index 5144be6ac..818816335 100644 --- a/internal/quic/conn_loss_test.go +++ b/internal/quic/conn_loss_test.go @@ -183,7 +183,7 @@ func TestLostStreamFrameEmpty(t *testing.T) { if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + c.Flush() // open the stream tc.wantFrame("created bidirectional stream 0", packetType1RTT, debugFrameStream{ id: newStreamID(clientSide, bidiStream, 0), @@ -213,6 +213,7 @@ func TestLostStreamWithData(t *testing.T) { p.initialMaxStreamDataUni = 1 << 20 }) s.Write(data[:4]) + s.Flush() tc.wantFrame("send [0,4)", packetType1RTT, debugFrameStream{ id: s.id, @@ -220,6 +221,7 @@ func TestLostStreamWithData(t *testing.T) { data: data[:4], }) s.Write(data[4:8]) + s.Flush() tc.wantFrame("send [4,8)", packetType1RTT, debugFrameStream{ id: s.id, @@ -263,6 +265,7 @@ func TestLostStreamPartialLoss(t *testing.T) { }) for i := range data { s.Write(data[i : i+1]) + s.Flush() tc.wantFrame(fmt.Sprintf("send STREAM frame with byte %v", i), packetType1RTT, debugFrameStream{ id: s.id, diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index 896c6d74e..156ef5dd5 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -61,7 +61,7 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { // Invalid data at the end of a datagram is ignored. break } - c.idleTimeout = now.Add(c.maxIdleTimeout) + c.idleHandlePacketReceived(now) buf = buf[n:] } } @@ -525,7 +525,7 @@ func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte if n < 0 { return -1 } - c.enterDraining(now, peerTransportError{code: code, reason: reason}) + c.handlePeerConnectionClose(now, peerTransportError{code: code, reason: reason}) return n } @@ -534,7 +534,7 @@ func (c *Conn) handleConnectionCloseApplicationFrame(now time.Time, payload []by if n < 0 { return -1 } - c.enterDraining(now, &ApplicationError{Code: code, Reason: reason}) + c.handlePeerConnectionClose(now, &ApplicationError{Code: code, Reason: reason}) return n } @@ -548,7 +548,7 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa }) return -1 } - if !c.isClosingOrDraining() { + if c.isAlive() { c.confirmHandshake(now) } return 1 @@ -560,5 +560,6 @@ func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToke if !c.connIDState.isValidStatelessResetToken(resetToken) { return } - c.enterDraining(now, errStatelessReset) + c.setFinalError(errStatelessReset) + c.enterDraining(now) } diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 22e780479..4065474d2 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -77,6 +77,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p) if sentInitial != nil { + c.idleHandlePacketSent(now, sentInitial) // Client initial packets and ack-eliciting server initial packaets // need to be sent in a datagram padded to at least 1200 bytes. // We can't add the padding yet, however, since we may want to @@ -104,6 +105,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload()) } if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil { + c.idleHandlePacketSent(now, sent) c.loss.packetSent(now, handshakeSpace, sent) if c.side == clientSide { // "[...] a client MUST discard Initial keys when it first @@ -131,6 +133,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload()) } if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil { + c.idleHandlePacketSent(now, sent) c.loss.packetSent(now, appDataSpace, sent) } } @@ -167,7 +170,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } } - c.listener.sendDatagram(buf, c.peerAddr) + c.endpoint.sendDatagram(buf, c.peerAddr) } } @@ -261,6 +264,10 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, if !c.appendStreamFrames(&c.w, pnum, pto) { return } + + if !c.appendKeepAlive(now) { + return + } } // If this is a PTO probe and we haven't added an ack-eliciting frame yet, @@ -325,7 +332,7 @@ func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool { } func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err error) { - c.lifetime.connCloseSentTime = now + c.sentConnectionClose(now) switch e := err.(type) { case localTransportError: c.w.appendConnectionCloseTransportFrame(e.code, 0, e.reason) @@ -342,11 +349,12 @@ func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err // TLS alerts are sent using error codes [0x0100,0x01ff). // https://www.rfc-editor.org/rfc/rfc9000#section-20.1-2.36.1 var alert tls.AlertError - if errors.As(err, &alert) { + switch { + case errors.As(err, &alert): // tls.AlertError is a uint8, so this can't exceed 0x01ff. code := errTLSBase + transportError(alert) c.w.appendConnectionCloseTransportFrame(code, 0, "") - } else { + default: c.w.appendConnectionCloseTransportFrame(errInternal, 0, "") } } diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index 69f982c3a..c90354db8 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -19,33 +19,33 @@ func TestStreamsCreate(t *testing.T) { tc := newTestConn(t, clientSide, permissiveTransportParameters) tc.handshake() - c, err := tc.conn.NewStream(ctx) + s, err := tc.conn.NewStream(ctx) if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("created bidirectional stream 0", packetType1RTT, debugFrameStream{ id: 0, // client-initiated, bidi, number 0 data: []byte{}, }) - c, err = tc.conn.NewSendOnlyStream(ctx) + s, err = tc.conn.NewSendOnlyStream(ctx) if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("created unidirectional stream 0", packetType1RTT, debugFrameStream{ id: 2, // client-initiated, uni, number 0 data: []byte{}, }) - c, err = tc.conn.NewStream(ctx) + s, err = tc.conn.NewStream(ctx) if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("created bidirectional stream 1", packetType1RTT, debugFrameStream{ id: 4, // client-initiated, uni, number 4 @@ -177,11 +177,11 @@ func TestStreamsStreamSendOnly(t *testing.T) { tc := newTestConn(t, serverSide, permissiveTransportParameters) tc.handshake() - c, err := tc.conn.NewSendOnlyStream(ctx) + s, err := tc.conn.NewSendOnlyStream(ctx) if err != nil { t.Fatalf("NewStream: %v", err) } - c.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("created unidirectional stream 0", packetType1RTT, debugFrameStream{ id: 3, // server-initiated, uni, number 0 diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index c70c58ef0..c57ba1487 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -25,6 +25,7 @@ var testVV = flag.Bool("vv", false, "even more verbose test output") func TestConnTestConn(t *testing.T) { tc := newTestConn(t, serverSide) + tc.handshake() if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want { t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want) } @@ -33,12 +34,12 @@ func TestConnTestConn(t *testing.T) { tc.conn.runOnLoop(func(now time.Time, c *Conn) { ranAt = now }) - if !ranAt.Equal(tc.listener.now) { - t.Errorf("func ran on loop at %v, want %v", ranAt, tc.listener.now) + if !ranAt.Equal(tc.endpoint.now) { + t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now) } tc.wait() - nextTime := tc.listener.now.Add(defaultMaxIdleTimeout / 2) + nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2) tc.advanceTo(nextTime) tc.conn.runOnLoop(func(now time.Time, c *Conn) { ranAt = now @@ -49,8 +50,8 @@ func TestConnTestConn(t *testing.T) { tc.wait() tc.advanceToTimer() - if !tc.conn.exited { - t.Errorf("after advancing to idle timeout, exited = false, want true") + if got := tc.conn.lifetime.state; got != connStateDone { + t.Errorf("after advancing to idle timeout, conn state = %v, want done", got) } } @@ -116,7 +117,7 @@ const maxTestKeyPhases = 3 type testConn struct { t *testing.T conn *Conn - listener *testListener + endpoint *testEndpoint timer time.Time timerLastFired time.Time idlec chan struct{} // only accessed on the conn's loop @@ -198,6 +199,7 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { // The initial connection ID for the server is chosen by the client. cids.srcConnID = testPeerConnID(0) cids.dstConnID = testPeerConnID(-1) + cids.originalDstConnID = cids.dstConnID } var configTransportParams []func(*transportParameters) var configTestConn []func(*testConn) @@ -218,27 +220,27 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { } } - listener := newTestListener(t, config) - listener.configTransportParams = configTransportParams - listener.configTestConn = configTestConn - conn, err := listener.l.newConn( - listener.now, + endpoint := newTestEndpoint(t, config) + endpoint.configTransportParams = configTransportParams + endpoint.configTestConn = configTestConn + conn, err := endpoint.e.newConn( + endpoint.now, side, cids, netip.MustParseAddrPort("127.0.0.1:443")) if err != nil { t.Fatal(err) } - tc := listener.conns[conn] + tc := endpoint.conns[conn] tc.wait() return tc } -func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testConn { +func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn { t.Helper() tc := &testConn{ t: t, - listener: listener, + endpoint: endpoint, conn: conn, peerConnID: testPeerConnID(0), ignoreFrames: map[byte]bool{ @@ -249,14 +251,14 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC recvDatagram: make(chan *datagram), } t.Cleanup(tc.cleanup) - for _, f := range listener.configTestConn { + for _, f := range endpoint.configTestConn { f(tc) } conn.testHooks = (*testConnHooks)(tc) - if listener.peerTLSConn != nil { - tc.peerTLSConn = listener.peerTLSConn - listener.peerTLSConn = nil + if endpoint.peerTLSConn != nil { + tc.peerTLSConn = endpoint.peerTLSConn + endpoint.peerTLSConn = nil return tc } @@ -265,7 +267,7 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC if conn.side == clientSide { peerProvidedParams.originalDstConnID = testLocalConnID(-1) } - for _, f := range listener.configTransportParams { + for _, f := range endpoint.configTransportParams { f(&peerProvidedParams) } @@ -284,13 +286,13 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC // advance causes time to pass. func (tc *testConn) advance(d time.Duration) { tc.t.Helper() - tc.listener.advance(d) + tc.endpoint.advance(d) } // advanceTo sets the current time. func (tc *testConn) advanceTo(now time.Time) { tc.t.Helper() - tc.listener.advanceTo(now) + tc.endpoint.advanceTo(now) } // advanceToTimer sets the current time to the time of the Conn's next timer event. @@ -305,10 +307,10 @@ func (tc *testConn) timerDelay() time.Duration { if tc.timer.IsZero() { return math.MaxInt64 // infinite } - if tc.timer.Before(tc.listener.now) { + if tc.timer.Before(tc.endpoint.now) { return 0 } - return tc.timer.Sub(tc.listener.now) + return tc.timer.Sub(tc.endpoint.now) } const infiniteDuration = time.Duration(math.MaxInt64) @@ -318,10 +320,10 @@ func (tc *testConn) timeUntilEvent() time.Duration { if tc.timer.IsZero() { return infiniteDuration } - if tc.timer.Before(tc.listener.now) { + if tc.timer.Before(tc.endpoint.now) { return 0 } - return tc.timer.Sub(tc.listener.now) + return tc.timer.Sub(tc.endpoint.now) } // wait blocks until the conn becomes idle. @@ -398,7 +400,7 @@ func logDatagram(t *testing.T, text string, d *testDatagram) { // write sends the Conn a datagram. func (tc *testConn) write(d *testDatagram) { tc.t.Helper() - tc.listener.writeDatagram(d) + tc.endpoint.writeDatagram(d) } // writeFrame sends the Conn a datagram containing the given frames. @@ -464,11 +466,11 @@ func (tc *testConn) readDatagram() *testDatagram { tc.wait() tc.sentPackets = nil tc.sentFrames = nil - buf := tc.listener.read() + buf := tc.endpoint.read() if buf == nil { return nil } - d := parseTestDatagram(tc.t, tc.listener, tc, buf) + d := parseTestDatagram(tc.t, tc.endpoint, tc, buf) // Log the datagram before removing ignored frames. // When things go wrong, it's useful to see all the frames. logDatagram(tc.t, "-> conn under test sends", d) @@ -769,7 +771,7 @@ func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte return w.datagram() } -func parseTestDatagram(t *testing.T, tl *testListener, tc *testConn, buf []byte) *testDatagram { +func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram { t.Helper() bufSize := len(buf) d := &testDatagram{} @@ -782,7 +784,7 @@ func parseTestDatagram(t *testing.T, tl *testListener, tc *testConn, buf []byte) ptype := getPacketType(buf) switch ptype { case packetTypeRetry: - retry, ok := parseRetryPacket(buf, tl.lastInitialDstConnID) + retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID) if !ok { t.Fatalf("could not parse %v packet", ptype) } @@ -936,7 +938,7 @@ func (tc *testConnHooks) init() { tc.keysInitial.r = tc.conn.keysInitial.w tc.keysInitial.w = tc.conn.keysInitial.r if tc.conn.side == serverSide { - tc.listener.acceptQueue = append(tc.listener.acceptQueue, (*testConn)(tc)) + tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc)) } } @@ -1037,20 +1039,20 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) { tc.timer = timer for { - if !timer.IsZero() && !timer.After(tc.listener.now) { + if !timer.IsZero() && !timer.After(tc.endpoint.now) { if timer.Equal(tc.timerLastFired) { // If the connection timer fires at time T, the Conn should take some // action to advance the timer into the future. If the Conn reschedules // the timer for the same time, it isn't making progress and we have a bug. - tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.listener.now, timer) + tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer) } else { tc.timerLastFired = timer - return tc.listener.now, timerEvent{} + return tc.endpoint.now, timerEvent{} } } select { case m := <-msgc: - return tc.listener.now, m + return tc.endpoint.now, m default: } if !tc.wakeAsync() { @@ -1064,7 +1066,7 @@ func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.T close(idlec) } m = <-msgc - return tc.listener.now, m + return tc.endpoint.now, m } func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) { @@ -1072,7 +1074,7 @@ func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) { } func (tc *testConnHooks) timeNow() time.Time { - return tc.listener.now + return tc.endpoint.now } // testLocalConnID returns the connection ID with a given sequence number diff --git a/internal/quic/listener.go b/internal/quic/endpoint.go similarity index 72% rename from internal/quic/listener.go rename to internal/quic/endpoint.go index ca8f9b25a..82a08a18c 100644 --- a/internal/quic/listener.go +++ b/internal/quic/endpoint.go @@ -17,14 +17,14 @@ import ( "time" ) -// A Listener listens for QUIC traffic on a network address. +// An Endpoint handles QUIC traffic on a network address. // It can accept inbound connections or create outbound ones. // -// Multiple goroutines may invoke methods on a Listener simultaneously. -type Listener struct { +// Multiple goroutines may invoke methods on an Endpoint simultaneously. +type Endpoint struct { config *Config udpConn udpConn - testHooks listenerTestHooks + testHooks endpointTestHooks resetGen statelessResetTokenGenerator retry retryState @@ -37,7 +37,7 @@ type Listener struct { closec chan struct{} // closed when the listen loop exits } -type listenerTestHooks interface { +type endpointTestHooks interface { timeNow() time.Time newConn(c *Conn) } @@ -53,7 +53,7 @@ type udpConn interface { // Listen listens on a local network address. // The configuration config must be non-nil. -func Listen(network, address string, config *Config) (*Listener, error) { +func Listen(network, address string, config *Config) (*Endpoint, error) { if config.TLSConfig == nil { return nil, errors.New("TLSConfig is not set") } @@ -65,11 +65,11 @@ func Listen(network, address string, config *Config) (*Listener, error) { if err != nil { return nil, err } - return newListener(udpConn, config, nil) + return newEndpoint(udpConn, config, nil) } -func newListener(udpConn udpConn, config *Config, hooks listenerTestHooks) (*Listener, error) { - l := &Listener{ +func newEndpoint(udpConn udpConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { + e := &Endpoint{ config: config, udpConn: udpConn, testHooks: hooks, @@ -77,70 +77,70 @@ func newListener(udpConn udpConn, config *Config, hooks listenerTestHooks) (*Lis acceptQueue: newQueue[*Conn](), closec: make(chan struct{}), } - l.resetGen.init(config.StatelessResetKey) - l.connsMap.init() + e.resetGen.init(config.StatelessResetKey) + e.connsMap.init() if config.RequireAddressValidation { - if err := l.retry.init(); err != nil { + if err := e.retry.init(); err != nil { return nil, err } } - go l.listen() - return l, nil + go e.listen() + return e, nil } // LocalAddr returns the local network address. -func (l *Listener) LocalAddr() netip.AddrPort { - a, _ := l.udpConn.LocalAddr().(*net.UDPAddr) +func (e *Endpoint) LocalAddr() netip.AddrPort { + a, _ := e.udpConn.LocalAddr().(*net.UDPAddr) return a.AddrPort() } -// Close closes the listener. -// Any blocked operations on the Listener or associated Conns and Stream will be unblocked +// Close closes the Endpoint. +// Any blocked operations on the Endpoint or associated Conns and Stream will be unblocked // and return errors. // // Close aborts every open connection. // Data in stream read and write buffers is discarded. // It waits for the peers of any open connection to acknowledge the connection has been closed. -func (l *Listener) Close(ctx context.Context) error { - l.acceptQueue.close(errors.New("listener closed")) - l.connsMu.Lock() - if !l.closing { - l.closing = true - for c := range l.conns { +func (e *Endpoint) Close(ctx context.Context) error { + e.acceptQueue.close(errors.New("endpoint closed")) + e.connsMu.Lock() + if !e.closing { + e.closing = true + for c := range e.conns { c.Abort(localTransportError{code: errNo}) } - if len(l.conns) == 0 { - l.udpConn.Close() + if len(e.conns) == 0 { + e.udpConn.Close() } } - l.connsMu.Unlock() + e.connsMu.Unlock() select { - case <-l.closec: + case <-e.closec: case <-ctx.Done(): - l.connsMu.Lock() - for c := range l.conns { + e.connsMu.Lock() + for c := range e.conns { c.exit() } - l.connsMu.Unlock() + e.connsMu.Unlock() return ctx.Err() } return nil } -// Accept waits for and returns the next connection to the listener. -func (l *Listener) Accept(ctx context.Context) (*Conn, error) { - return l.acceptQueue.get(ctx, nil) +// Accept waits for and returns the next connection. +func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) { + return e.acceptQueue.get(ctx, nil) } // Dial creates and returns a connection to a network address. -func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, error) { +func (e *Endpoint) Dial(ctx context.Context, network, address string) (*Conn, error) { u, err := net.ResolveUDPAddr(network, address) if err != nil { return nil, err } addr := u.AddrPort() addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) - c, err := l.newConn(time.Now(), clientSide, newServerConnIDs{}, addr) + c, err := e.newConn(time.Now(), clientSide, newServerConnIDs{}, addr) if err != nil { return nil, err } @@ -151,29 +151,29 @@ func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, er return c, nil } -func (l *Listener) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) { - l.connsMu.Lock() - defer l.connsMu.Unlock() - if l.closing { - return nil, errors.New("listener closed") +func (e *Endpoint) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) { + e.connsMu.Lock() + defer e.connsMu.Unlock() + if e.closing { + return nil, errors.New("endpoint closed") } - c, err := newConn(now, side, cids, peerAddr, l.config, l) + c, err := newConn(now, side, cids, peerAddr, e.config, e) if err != nil { return nil, err } - l.conns[c] = struct{}{} + e.conns[c] = struct{}{} return c, nil } // serverConnEstablished is called by a conn when the handshake completes // for an inbound (serverSide) connection. -func (l *Listener) serverConnEstablished(c *Conn) { - l.acceptQueue.put(c) +func (e *Endpoint) serverConnEstablished(c *Conn) { + e.acceptQueue.put(c) } // connDrained is called by a conn when it leaves the draining state, // either when the peer acknowledges connection closure or the drain timeout expires. -func (l *Listener) connDrained(c *Conn) { +func (e *Endpoint) connDrained(c *Conn) { var cids [][]byte for i := range c.connIDState.local { cids = append(cids, c.connIDState.local[i].cid) @@ -182,7 +182,7 @@ func (l *Listener) connDrained(c *Conn) { for i := range c.connIDState.remote { tokens = append(tokens, c.connIDState.remote[i].resetToken) } - l.connsMap.updateConnIDs(func(conns *connsMap) { + e.connsMap.updateConnIDs(func(conns *connsMap) { for _, cid := range cids { conns.retireConnID(c, cid) } @@ -190,60 +190,60 @@ func (l *Listener) connDrained(c *Conn) { conns.retireResetToken(c, token) } }) - l.connsMu.Lock() - defer l.connsMu.Unlock() - delete(l.conns, c) - if l.closing && len(l.conns) == 0 { - l.udpConn.Close() + e.connsMu.Lock() + defer e.connsMu.Unlock() + delete(e.conns, c) + if e.closing && len(e.conns) == 0 { + e.udpConn.Close() } } -func (l *Listener) listen() { - defer close(l.closec) +func (e *Endpoint) listen() { + defer close(e.closec) for { m := newDatagram() // TODO: Read and process the ECN (explicit congestion notification) field. // https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4 - n, _, _, addr, err := l.udpConn.ReadMsgUDPAddrPort(m.b, nil) + n, _, _, addr, err := e.udpConn.ReadMsgUDPAddrPort(m.b, nil) if err != nil { - // The user has probably closed the listener. + // The user has probably closed the endpoint. // We currently don't surface errors from other causes; - // we could check to see if the listener has been closed and + // we could check to see if the endpoint has been closed and // record the unexpected error if it has not. return } if n == 0 { continue } - if l.connsMap.updateNeeded.Load() { - l.connsMap.applyUpdates() + if e.connsMap.updateNeeded.Load() { + e.connsMap.applyUpdates() } m.addr = addr m.b = m.b[:n] - l.handleDatagram(m) + e.handleDatagram(m) } } -func (l *Listener) handleDatagram(m *datagram) { +func (e *Endpoint) handleDatagram(m *datagram) { dstConnID, ok := dstConnIDForDatagram(m.b) if !ok { m.recycle() return } - c := l.connsMap.byConnID[string(dstConnID)] + c := e.connsMap.byConnID[string(dstConnID)] if c == nil { // TODO: Move this branch into a separate goroutine to avoid blocking - // the listener while processing packets. - l.handleUnknownDestinationDatagram(m) + // the endpoint while processing packets. + e.handleUnknownDestinationDatagram(m) return } - // TODO: This can block the listener while waiting for the conn to accept the dgram. + // TODO: This can block the endpoint while waiting for the conn to accept the dgram. // Think about buffering between the receive loop and the conn. c.sendMsg(m) } -func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { +func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { defer func() { if m != nil { m.recycle() @@ -254,15 +254,15 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { return } var now time.Time - if l.testHooks != nil { - now = l.testHooks.timeNow() + if e.testHooks != nil { + now = e.testHooks.timeNow() } else { now = time.Now() } // Check to see if this is a stateless reset. var token statelessResetToken copy(token[:], m.b[len(m.b)-len(token):]) - if c := l.connsMap.byResetToken[token]; c != nil { + if c := e.connsMap.byResetToken[token]; c != nil { c.sendMsg(func(now time.Time, c *Conn) { c.handleStatelessReset(now, token) }) @@ -271,7 +271,7 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { // If this is a 1-RTT packet, there's nothing productive we can do with it. // Send a stateless reset if possible. if !isLongHeader(m.b[0]) { - l.maybeSendStatelessReset(m.b, m.addr) + e.maybeSendStatelessReset(m.b, m.addr) return } p, ok := parseGenericLongHeaderPacket(m.b) @@ -285,7 +285,7 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { return default: // Unknown version. - l.sendVersionNegotiation(p, m.addr) + e.sendVersionNegotiation(p, m.addr) return } if getPacketType(m.b) != packetTypeInitial { @@ -300,10 +300,10 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { srcConnID: p.srcConnID, dstConnID: p.dstConnID, } - if l.config.RequireAddressValidation { + if e.config.RequireAddressValidation { var ok bool cids.retrySrcConnID = p.dstConnID - cids.originalDstConnID, ok = l.validateInitialAddress(now, p, m.addr) + cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.addr) if !ok { return } @@ -311,7 +311,7 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { cids.originalDstConnID = p.dstConnID } var err error - c, err := l.newConn(now, serverSide, cids, m.addr) + c, err := e.newConn(now, serverSide, cids, m.addr) if err != nil { // The accept queue is probably full. // We could send a CONNECTION_CLOSE to the peer to reject the connection. @@ -323,8 +323,8 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { m = nil // don't recycle, sendMsg takes ownership } -func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { - if !l.resetGen.canReset { +func (e *Endpoint) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { + if !e.resetGen.canReset { // Config.StatelessResetKey isn't set, so we don't send stateless resets. return } @@ -339,7 +339,7 @@ func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { } // TODO: Rate limit stateless resets. cid := b[1:][:connIDLen] - token := l.resetGen.tokenForConnID(cid) + token := e.resetGen.tokenForConnID(cid) // We want to generate a stateless reset that is as short as possible, // but long enough to be difficult to distinguish from a 1-RTT packet. // @@ -364,17 +364,17 @@ func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { b[0] &^= headerFormLong // clear long header bit b[0] |= fixedBit // set fixed bit copy(b[len(b)-statelessResetTokenLen:], token[:]) - l.sendDatagram(b, addr) + e.sendDatagram(b, addr) } -func (l *Listener) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) { +func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) { m := newDatagram() m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1) - l.sendDatagram(m.b, addr) + e.sendDatagram(m.b, addr) m.recycle() } -func (l *Listener) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) { +func (e *Endpoint) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) { keys := initialKeys(in.dstConnID, serverSide) var w packetWriter p := longPacket{ @@ -393,15 +393,15 @@ func (l *Listener) sendConnectionClose(in genericLongPacket, addr netip.AddrPort if len(buf) == 0 { return } - l.sendDatagram(buf, addr) + e.sendDatagram(buf, addr) } -func (l *Listener) sendDatagram(p []byte, addr netip.AddrPort) error { - _, err := l.udpConn.WriteToUDPAddrPort(p, addr) +func (e *Endpoint) sendDatagram(p []byte, addr netip.AddrPort) error { + _, err := e.udpConn.WriteToUDPAddrPort(p, addr) return err } -// A connsMap is a listener's mapping of conn ids and reset tokens to conns. +// A connsMap is an endpoint's mapping of conn ids and reset tokens to conns. type connsMap struct { byConnID map[string]*Conn byResetToken map[statelessResetToken]*Conn diff --git a/internal/quic/listener_test.go b/internal/quic/endpoint_test.go similarity index 53% rename from internal/quic/listener_test.go rename to internal/quic/endpoint_test.go index 037fb21b4..f9fc80152 100644 --- a/internal/quic/listener_test.go +++ b/internal/quic/endpoint_test.go @@ -64,39 +64,39 @@ func TestStreamTransfer(t *testing.T) { func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) { t.Helper() ctx := context.Background() - l1 := newLocalListener(t, serverSide, conf1) - l2 := newLocalListener(t, clientSide, conf2) - c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String()) + e1 := newLocalEndpoint(t, serverSide, conf1) + e2 := newLocalEndpoint(t, clientSide, conf2) + c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String()) if err != nil { t.Fatal(err) } - c1, err := l1.Accept(ctx) + c1, err := e1.Accept(ctx) if err != nil { t.Fatal(err) } return c2, c1 } -func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener { +func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint { t.Helper() if conf.TLSConfig == nil { newConf := *conf conf = &newConf conf.TLSConfig = newTestTLSConfig(side) } - l, err := Listen("udp", "127.0.0.1:0", conf) + e, err := Listen("udp", "127.0.0.1:0", conf) if err != nil { t.Fatal(err) } t.Cleanup(func() { - l.Close(context.Background()) + e.Close(context.Background()) }) - return l + return e } -type testListener struct { +type testEndpoint struct { t *testing.T - l *Listener + e *Endpoint now time.Time recvc chan *datagram idlec chan struct{} @@ -109,8 +109,8 @@ type testListener struct { lastInitialDstConnID []byte // for parsing Retry packets } -func newTestListener(t *testing.T, config *Config) *testListener { - tl := &testListener{ +func newTestEndpoint(t *testing.T, config *Config) *testEndpoint { + te := &testEndpoint{ t: t, now: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), recvc: make(chan *datagram), @@ -118,52 +118,52 @@ func newTestListener(t *testing.T, config *Config) *testListener { conns: make(map[*Conn]*testConn), } var err error - tl.l, err = newListener((*testListenerUDPConn)(tl), config, (*testListenerHooks)(tl)) + te.e, err = newEndpoint((*testEndpointUDPConn)(te), config, (*testEndpointHooks)(te)) if err != nil { t.Fatal(err) } - t.Cleanup(tl.cleanup) - return tl + t.Cleanup(te.cleanup) + return te } -func (tl *testListener) cleanup() { - tl.l.Close(canceledContext()) +func (te *testEndpoint) cleanup() { + te.e.Close(canceledContext()) } -func (tl *testListener) wait() { +func (te *testEndpoint) wait() { select { - case tl.idlec <- struct{}{}: - case <-tl.l.closec: + case te.idlec <- struct{}{}: + case <-te.e.closec: } - for _, tc := range tl.conns { + for _, tc := range te.conns { tc.wait() } } -// accept returns a server connection from the listener. -// Unlike Listener.Accept, connections are available as soon as they are created. -func (tl *testListener) accept() *testConn { - if len(tl.acceptQueue) == 0 { - tl.t.Fatalf("accept: expected available conn, but found none") +// accept returns a server connection from the endpoint. +// Unlike Endpoint.Accept, connections are available as soon as they are created. +func (te *testEndpoint) accept() *testConn { + if len(te.acceptQueue) == 0 { + te.t.Fatalf("accept: expected available conn, but found none") } - tc := tl.acceptQueue[0] - tl.acceptQueue = tl.acceptQueue[1:] + tc := te.acceptQueue[0] + te.acceptQueue = te.acceptQueue[1:] return tc } -func (tl *testListener) write(d *datagram) { - tl.recvc <- d - tl.wait() +func (te *testEndpoint) write(d *datagram) { + te.recvc <- d + te.wait() } var testClientAddr = netip.MustParseAddrPort("10.0.0.1:8000") -func (tl *testListener) writeDatagram(d *testDatagram) { - tl.t.Helper() - logDatagram(tl.t, "<- listener under test receives", d) +func (te *testEndpoint) writeDatagram(d *testDatagram) { + te.t.Helper() + logDatagram(te.t, "<- endpoint under test receives", d) var buf []byte for _, p := range d.packets { - tc := tl.connForDestination(p.dstConnID) + tc := te.connForDestination(p.dstConnID) if p.ptype != packetTypeRetry && tc != nil { space := spaceForPacketType(p.ptype) if p.num >= tc.peerNextPacketNum[space] { @@ -171,13 +171,13 @@ func (tl *testListener) writeDatagram(d *testDatagram) { } } if p.ptype == packetTypeInitial { - tl.lastInitialDstConnID = p.dstConnID + te.lastInitialDstConnID = p.dstConnID } pad := 0 if p.ptype == packetType1RTT { pad = d.paddedSize - len(buf) } - buf = append(buf, encodeTestPacket(tl.t, tc, p, pad)...) + buf = append(buf, encodeTestPacket(te.t, tc, p, pad)...) } for len(buf) < d.paddedSize { buf = append(buf, 0) @@ -186,14 +186,14 @@ func (tl *testListener) writeDatagram(d *testDatagram) { if !addr.IsValid() { addr = testClientAddr } - tl.write(&datagram{ + te.write(&datagram{ b: buf, addr: addr, }) } -func (tl *testListener) connForDestination(dstConnID []byte) *testConn { - for _, tc := range tl.conns { +func (te *testEndpoint) connForDestination(dstConnID []byte) *testConn { + for _, tc := range te.conns { for _, loc := range tc.conn.connIDState.local { if bytes.Equal(loc.cid, dstConnID) { return tc @@ -203,8 +203,8 @@ func (tl *testListener) connForDestination(dstConnID []byte) *testConn { return nil } -func (tl *testListener) connForSource(srcConnID []byte) *testConn { - for _, tc := range tl.conns { +func (te *testEndpoint) connForSource(srcConnID []byte) *testConn { + for _, tc := range te.conns { for _, loc := range tc.conn.connIDState.remote { if bytes.Equal(loc.cid, srcConnID) { return tc @@ -214,106 +214,106 @@ func (tl *testListener) connForSource(srcConnID []byte) *testConn { return nil } -func (tl *testListener) read() []byte { - tl.t.Helper() - tl.wait() - if len(tl.sentDatagrams) == 0 { +func (te *testEndpoint) read() []byte { + te.t.Helper() + te.wait() + if len(te.sentDatagrams) == 0 { return nil } - d := tl.sentDatagrams[0] - tl.sentDatagrams = tl.sentDatagrams[1:] + d := te.sentDatagrams[0] + te.sentDatagrams = te.sentDatagrams[1:] return d } -func (tl *testListener) readDatagram() *testDatagram { - tl.t.Helper() - buf := tl.read() +func (te *testEndpoint) readDatagram() *testDatagram { + te.t.Helper() + buf := te.read() if buf == nil { return nil } p, _ := parseGenericLongHeaderPacket(buf) - tc := tl.connForSource(p.dstConnID) - d := parseTestDatagram(tl.t, tl, tc, buf) - logDatagram(tl.t, "-> listener under test sends", d) + tc := te.connForSource(p.dstConnID) + d := parseTestDatagram(te.t, te, tc, buf) + logDatagram(te.t, "-> endpoint under test sends", d) return d } -// wantDatagram indicates that we expect the Listener to send a datagram. -func (tl *testListener) wantDatagram(expectation string, want *testDatagram) { - tl.t.Helper() - got := tl.readDatagram() +// wantDatagram indicates that we expect the Endpoint to send a datagram. +func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) { + te.t.Helper() + got := te.readDatagram() if !reflect.DeepEqual(got, want) { - tl.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) + te.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) } } -// wantIdle indicates that we expect the Listener to not send any more datagrams. -func (tl *testListener) wantIdle(expectation string) { - if got := tl.readDatagram(); got != nil { - tl.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got) +// wantIdle indicates that we expect the Endpoint to not send any more datagrams. +func (te *testEndpoint) wantIdle(expectation string) { + if got := te.readDatagram(); got != nil { + te.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got) } } // advance causes time to pass. -func (tl *testListener) advance(d time.Duration) { - tl.t.Helper() - tl.advanceTo(tl.now.Add(d)) +func (te *testEndpoint) advance(d time.Duration) { + te.t.Helper() + te.advanceTo(te.now.Add(d)) } // advanceTo sets the current time. -func (tl *testListener) advanceTo(now time.Time) { - tl.t.Helper() - if tl.now.After(now) { - tl.t.Fatalf("time moved backwards: %v -> %v", tl.now, now) +func (te *testEndpoint) advanceTo(now time.Time) { + te.t.Helper() + if te.now.After(now) { + te.t.Fatalf("time moved backwards: %v -> %v", te.now, now) } - tl.now = now - for _, tc := range tl.conns { - if !tc.timer.After(tl.now) { + te.now = now + for _, tc := range te.conns { + if !tc.timer.After(te.now) { tc.conn.sendMsg(timerEvent{}) tc.wait() } } } -// testListenerHooks implements listenerTestHooks. -type testListenerHooks testListener +// testEndpointHooks implements endpointTestHooks. +type testEndpointHooks testEndpoint -func (tl *testListenerHooks) timeNow() time.Time { - return tl.now +func (te *testEndpointHooks) timeNow() time.Time { + return te.now } -func (tl *testListenerHooks) newConn(c *Conn) { - tc := newTestConnForConn(tl.t, (*testListener)(tl), c) - tl.conns[c] = tc +func (te *testEndpointHooks) newConn(c *Conn) { + tc := newTestConnForConn(te.t, (*testEndpoint)(te), c) + te.conns[c] = tc } -// testListenerUDPConn implements UDPConn. -type testListenerUDPConn testListener +// testEndpointUDPConn implements UDPConn. +type testEndpointUDPConn testEndpoint -func (tl *testListenerUDPConn) Close() error { - close(tl.recvc) +func (te *testEndpointUDPConn) Close() error { + close(te.recvc) return nil } -func (tl *testListenerUDPConn) LocalAddr() net.Addr { +func (te *testEndpointUDPConn) LocalAddr() net.Addr { return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443")) } -func (tl *testListenerUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) { +func (te *testEndpointUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) { for { select { - case d, ok := <-tl.recvc: + case d, ok := <-te.recvc: if !ok { return 0, 0, 0, netip.AddrPort{}, io.EOF } n = copy(b, d.b) return n, 0, 0, d.addr, nil - case <-tl.idlec: + case <-te.idlec: } } } -func (tl *testListenerUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - tl.sentDatagrams = append(tl.sentDatagrams, append([]byte(nil), b...)) +func (te *testEndpointUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), b...)) return len(b), nil } diff --git a/internal/quic/idle.go b/internal/quic/idle.go new file mode 100644 index 000000000..f5b2422ad --- /dev/null +++ b/internal/quic/idle.go @@ -0,0 +1,170 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "time" +) + +// idleState tracks connection idle events. +// +// Before the handshake is confirmed, the idle timeout is Config.HandshakeTimeout. +// +// After the handshake is confirmed, the idle timeout is +// the minimum of Config.MaxIdleTimeout and the peer's max_idle_timeout transport parameter. +// +// If KeepAlivePeriod is set, keep-alive pings are sent. +// Keep-alives are only sent after the handshake is confirmed. +// +// https://www.rfc-editor.org/rfc/rfc9000#section-10.1 +type idleState struct { + // idleDuration is the negotiated idle timeout for the connection. + idleDuration time.Duration + + // idleTimeout is the time at which the connection will be closed due to inactivity. + idleTimeout time.Time + + // nextTimeout is the time of the next idle event. + // If nextTimeout == idleTimeout, this is the idle timeout. + // Otherwise, this is the keep-alive timeout. + nextTimeout time.Time + + // sentSinceLastReceive is set if we have sent an ack-eliciting packet + // since the last time we received and processed a packet from the peer. + sentSinceLastReceive bool +} + +// receivePeerMaxIdleTimeout handles the peer's max_idle_timeout transport parameter. +func (c *Conn) receivePeerMaxIdleTimeout(peerMaxIdleTimeout time.Duration) { + localMaxIdleTimeout := c.config.maxIdleTimeout() + switch { + case localMaxIdleTimeout == 0: + c.idle.idleDuration = peerMaxIdleTimeout + case peerMaxIdleTimeout == 0: + c.idle.idleDuration = localMaxIdleTimeout + default: + c.idle.idleDuration = min(localMaxIdleTimeout, peerMaxIdleTimeout) + } +} + +func (c *Conn) idleHandlePacketReceived(now time.Time) { + if !c.handshakeConfirmed.isSet() { + return + } + // "An endpoint restarts its idle timer when a packet from its peer is + // received and processed successfully." + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3 + c.idle.sentSinceLastReceive = false + c.restartIdleTimer(now) +} + +func (c *Conn) idleHandlePacketSent(now time.Time, sent *sentPacket) { + // "An endpoint also restarts its idle timer when sending an ack-eliciting packet + // if no other ack-eliciting packets have been sent since + // last receiving and processing a packet." + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3 + if c.idle.sentSinceLastReceive || !sent.ackEliciting || !c.handshakeConfirmed.isSet() { + return + } + c.idle.sentSinceLastReceive = true + c.restartIdleTimer(now) +} + +func (c *Conn) restartIdleTimer(now time.Time) { + if !c.isAlive() { + // Connection is closing, disable timeouts. + c.idle.idleTimeout = time.Time{} + c.idle.nextTimeout = time.Time{} + return + } + var idleDuration time.Duration + if c.handshakeConfirmed.isSet() { + idleDuration = c.idle.idleDuration + } else { + idleDuration = c.config.handshakeTimeout() + } + if idleDuration == 0 { + c.idle.idleTimeout = time.Time{} + } else { + // "[...] endpoints MUST increase the idle timeout period to be + // at least three times the current Probe Timeout (PTO)." + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-4 + idleDuration = max(idleDuration, 3*c.loss.ptoPeriod()) + c.idle.idleTimeout = now.Add(idleDuration) + } + // Set the time of our next event: + // The idle timer if no keep-alive is set, or the keep-alive timer if one is. + c.idle.nextTimeout = c.idle.idleTimeout + keepAlive := c.config.keepAlivePeriod() + switch { + case !c.handshakeConfirmed.isSet(): + // We do not send keep-alives before the handshake is complete. + case keepAlive <= 0: + // Keep-alives are not enabled. + case c.idle.sentSinceLastReceive: + // We have sent an ack-eliciting packet to the peer. + // If they don't acknowledge it, loss detection will follow up with PTO probes, + // which will function as keep-alives. + // We don't need to send further pings. + case idleDuration == 0: + // The connection does not have a negotiated idle timeout. + // Send keep-alives anyway, since they may be required to keep middleboxes + // from losing state. + c.idle.nextTimeout = now.Add(keepAlive) + default: + // Schedule our next keep-alive. + // If our configured keep-alive period is greater than half the negotiated + // connection idle timeout, we reduce the keep-alive period to half + // the idle timeout to ensure we have time for the ping to arrive. + c.idle.nextTimeout = now.Add(min(keepAlive, idleDuration/2)) + } +} + +func (c *Conn) appendKeepAlive(now time.Time) bool { + if c.idle.nextTimeout.IsZero() || c.idle.nextTimeout.After(now) { + return true // timer has not expired + } + if c.idle.nextTimeout.Equal(c.idle.idleTimeout) { + return true // no keepalive timer set, only idle + } + if c.idle.sentSinceLastReceive { + return true // already sent an ack-eliciting packet + } + if c.w.sent.ackEliciting { + return true // this packet is already ack-eliciting + } + // Send an ack-eliciting PING frame to the peer to keep the connection alive. + return c.w.appendPingFrame() +} + +var errHandshakeTimeout error = localTransportError{ + code: errConnectionRefused, + reason: "handshake timeout", +} + +func (c *Conn) idleAdvance(now time.Time) (shouldExit bool) { + if c.idle.idleTimeout.IsZero() || now.Before(c.idle.idleTimeout) { + return false + } + c.idle.idleTimeout = time.Time{} + c.idle.nextTimeout = time.Time{} + if !c.handshakeConfirmed.isSet() { + // Handshake timeout has expired. + // If we're a server, we're refusing the too-slow client. + // If we're a client, we're giving up. + // In either case, we're going to send a CONNECTION_CLOSE frame and + // enter the closing state rather than unceremoniously dropping the connection, + // since the peer might still be trying to complete the handshake. + c.abort(now, errHandshakeTimeout) + return false + } + // Idle timeout has expired. + // + // "[...] the connection is silently closed and its state is discarded [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1 + return true +} diff --git a/internal/quic/idle_test.go b/internal/quic/idle_test.go new file mode 100644 index 000000000..18f6a690a --- /dev/null +++ b/internal/quic/idle_test.go @@ -0,0 +1,225 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "crypto/tls" + "fmt" + "testing" + "time" +) + +func TestHandshakeTimeoutExpiresServer(t *testing.T) { + const timeout = 5 * time.Second + tc := newTestConn(t, serverSide, func(c *Config) { + c.HandshakeTimeout = timeout + }) + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypeNewConnectionID) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + // Server starts its end of the handshake. + // Client acks these packets to avoid starting the PTO timer. + tc.wantFrameType("server sends Initial CRYPTO flight", + packetTypeInitial, debugFrameCrypto{}) + tc.writeAckForAll() + tc.wantFrameType("server sends Handshake CRYPTO flight", + packetTypeHandshake, debugFrameCrypto{}) + tc.writeAckForAll() + + if got, want := tc.timerDelay(), timeout; got != want { + t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want) + } + + // Client sends a packet, but this does not extend the handshake timer. + tc.advance(1 * time.Second) + tc.writeFrames(packetTypeHandshake, debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][:1], // partial data + }) + tc.wantIdle("handshake is not complete") + + tc.advance(timeout - 1*time.Second) + tc.wantFrame("server closes connection after handshake timeout", + packetTypeHandshake, debugFrameConnectionCloseTransport{ + code: errConnectionRefused, + }) +} + +func TestHandshakeTimeoutExpiresClient(t *testing.T) { + const timeout = 5 * time.Second + tc := newTestConn(t, clientSide, func(c *Config) { + c.HandshakeTimeout = timeout + }) + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypeNewConnectionID) + // Start the handshake. + // The client always sets a PTO timer until it gets an ack for a handshake packet + // or confirms the handshake, so proceed far enough through the handshake to + // let us not worry about PTO. + tc.wantFrameType("client sends Initial CRYPTO flight", + packetTypeInitial, debugFrameCrypto{}) + tc.writeAckForAll() + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrameType("client sends Handshake CRYPTO flight", + packetTypeHandshake, debugFrameCrypto{}) + tc.writeAckForAll() + tc.wantIdle("client is waiting for end of handshake") + + if got, want := tc.timerDelay(), timeout; got != want { + t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want) + } + tc.advance(timeout) + tc.wantFrame("client closes connection after handshake timeout", + packetTypeHandshake, debugFrameConnectionCloseTransport{ + code: errConnectionRefused, + }) +} + +func TestIdleTimeoutExpires(t *testing.T) { + for _, test := range []struct { + localMaxIdleTimeout time.Duration + peerMaxIdleTimeout time.Duration + wantTimeout time.Duration + }{{ + localMaxIdleTimeout: 10 * time.Second, + peerMaxIdleTimeout: 20 * time.Second, + wantTimeout: 10 * time.Second, + }, { + localMaxIdleTimeout: 20 * time.Second, + peerMaxIdleTimeout: 10 * time.Second, + wantTimeout: 10 * time.Second, + }, { + localMaxIdleTimeout: 0, + peerMaxIdleTimeout: 10 * time.Second, + wantTimeout: 10 * time.Second, + }, { + localMaxIdleTimeout: 10 * time.Second, + peerMaxIdleTimeout: 0, + wantTimeout: 10 * time.Second, + }} { + name := fmt.Sprintf("local=%v/peer=%v", test.localMaxIdleTimeout, test.peerMaxIdleTimeout) + t.Run(name, func(t *testing.T) { + tc := newTestConn(t, serverSide, func(p *transportParameters) { + p.maxIdleTimeout = test.peerMaxIdleTimeout + }, func(c *Config) { + c.MaxIdleTimeout = test.localMaxIdleTimeout + }) + tc.handshake() + if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want { + t.Errorf("new conn timeout=%v, want %v (idle timeout)", got, want) + } + tc.advance(test.wantTimeout - 1) + tc.wantIdle("connection is idle and alive prior to timeout") + ctx := canceledContext() + if err := tc.conn.Wait(ctx); err != context.Canceled { + t.Fatalf("conn.Wait() = %v, want Canceled", err) + } + tc.advance(1) + tc.wantIdle("connection exits after timeout") + if err := tc.conn.Wait(ctx); err != errIdleTimeout { + t.Fatalf("conn.Wait() = %v, want errIdleTimeout", err) + } + }) + } +} + +func TestIdleTimeoutKeepAlive(t *testing.T) { + for _, test := range []struct { + idleTimeout time.Duration + keepAlive time.Duration + wantTimeout time.Duration + }{{ + idleTimeout: 30 * time.Second, + keepAlive: 10 * time.Second, + wantTimeout: 10 * time.Second, + }, { + idleTimeout: 10 * time.Second, + keepAlive: 30 * time.Second, + wantTimeout: 5 * time.Second, + }, { + idleTimeout: -1, // disabled + keepAlive: 30 * time.Second, + wantTimeout: 30 * time.Second, + }} { + name := fmt.Sprintf("idle_timeout=%v/keepalive=%v", test.idleTimeout, test.keepAlive) + t.Run(name, func(t *testing.T) { + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxIdleTimeout = test.idleTimeout + c.KeepAlivePeriod = test.keepAlive + }) + tc.handshake() + if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want { + t.Errorf("new conn timeout=%v, want %v (keepalive timeout)", got, want) + } + tc.advance(test.wantTimeout - 1) + tc.wantIdle("connection is idle prior to timeout") + tc.advance(1) + tc.wantFrameType("keep-alive ping is sent", packetType1RTT, + debugFramePing{}) + }) + } +} + +func TestIdleLongTermKeepAliveSent(t *testing.T) { + // This test examines a connection sitting idle and sending periodic keep-alive pings. + const keepAlivePeriod = 30 * time.Second + tc := newTestConn(t, clientSide, func(c *Config) { + c.KeepAlivePeriod = keepAlivePeriod + c.MaxIdleTimeout = -1 + }) + tc.handshake() + // The handshake will have completed a little bit after the point at which the + // keepalive timer was set. Send two PING frames to the conn, triggering an immediate ack + // and resetting the timer. + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.wantFrameType("conn acks received pings", packetType1RTT, debugFrameAck{}) + for i := 0; i < 10; i++ { + tc.wantIdle("conn has nothing more to send") + if got, want := tc.timeUntilEvent(), keepAlivePeriod; got != want { + t.Errorf("i=%v conn timeout=%v, want %v (keepalive timeout)", i, got, want) + } + tc.advance(keepAlivePeriod) + tc.wantFrameType("keep-alive ping is sent", packetType1RTT, + debugFramePing{}) + tc.writeAckForAll() + } +} + +func TestIdleLongTermKeepAliveReceived(t *testing.T) { + // This test examines a connection sitting idle, but receiving periodic peer + // traffic to keep the connection alive. + const idleTimeout = 30 * time.Second + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxIdleTimeout = idleTimeout + }) + tc.handshake() + for i := 0; i < 10; i++ { + tc.advance(idleTimeout - 1*time.Second) + tc.writeFrames(packetType1RTT, debugFramePing{}) + if got, want := tc.timeUntilEvent(), maxAckDelay-timerGranularity; got != want { + t.Errorf("i=%v conn timeout=%v, want %v (max_ack_delay)", i, got, want) + } + tc.advanceToTimer() + tc.wantFrameType("conn acks received ping", packetType1RTT, debugFrameAck{}) + } + // Connection is still alive. + ctx := canceledContext() + if err := tc.conn.Wait(ctx); err != context.Canceled { + t.Fatalf("conn.Wait() = %v, want Canceled", err) + } +} diff --git a/internal/quic/loss.go b/internal/quic/loss.go index c0f915b42..4a0767bd0 100644 --- a/internal/quic/loss.go +++ b/internal/quic/loss.go @@ -431,12 +431,15 @@ func (c *lossState) scheduleTimer(now time.Time) { c.timer = time.Time{} return } - // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 - pto := c.ptoBasePeriod() << c.ptoBackoffCount - c.timer = last.Add(pto) + c.timer = last.Add(c.ptoPeriod()) c.ptoTimerArmed = true } +func (c *lossState) ptoPeriod() time.Duration { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 + return c.ptoBasePeriod() << c.ptoBackoffCount +} + func (c *lossState) ptoBasePeriod() time.Duration { // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 pto := c.rtt.smoothedRTT + max(4*c.rtt.rttvar, timerGranularity) diff --git a/internal/quic/qlog.go b/internal/quic/qlog.go new file mode 100644 index 000000000..ea53cab1e --- /dev/null +++ b/internal/quic/qlog.go @@ -0,0 +1,147 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "encoding/hex" + "log/slog" + "net/netip" +) + +// Log levels for qlog events. +const ( + // QLogLevelFrame includes per-frame information. + // When this level is enabled, packet_sent and packet_received events will + // contain information on individual frames sent/received. + QLogLevelFrame = slog.Level(-6) + + // QLogLevelPacket events occur at most once per packet sent or received. + // + // For example: packet_sent, packet_received. + QLogLevelPacket = slog.Level(-4) + + // QLogLevelConn events occur multiple times over a connection's lifetime, + // but less often than the frequency of individual packets. + // + // For example: connection_state_updated. + QLogLevelConn = slog.Level(-2) + + // QLogLevelEndpoint events occur at most once per connection. + // + // For example: connection_started, connection_closed. + QLogLevelEndpoint = slog.Level(0) +) + +func (c *Conn) logEnabled(level slog.Level) bool { + return c.log != nil && c.log.Enabled(context.Background(), level) +} + +// slogHexstring returns a slog.Attr for a value of the hexstring type. +// +// https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-1.1.1 +func slogHexstring(key string, value []byte) slog.Attr { + return slog.String(key, hex.EncodeToString(value)) +} + +func slogAddr(key string, value netip.Addr) slog.Attr { + return slog.String(key, value.String()) +} + +func (c *Conn) logConnectionStarted(originalDstConnID []byte, peerAddr netip.AddrPort) { + if c.config.QLogLogger == nil || + !c.config.QLogLogger.Enabled(context.Background(), QLogLevelEndpoint) { + return + } + var vantage string + if c.side == clientSide { + vantage = "client" + originalDstConnID = c.connIDState.originalDstConnID + } else { + vantage = "server" + } + // A qlog Trace container includes some metadata (title, description, vantage_point) + // and a list of Events. The Trace also includes a common_fields field setting field + // values common to all events in the trace. + // + // Trace = { + // ? title: text + // ? description: text + // ? configuration: Configuration + // ? common_fields: CommonFields + // ? vantage_point: VantagePoint + // events: [* Event] + // } + // + // To map this into slog's data model, we start each per-connection trace with a With + // call that includes both the trace metadata and the common fields. + // + // This means that in slog's model, each trace event will also include + // the Trace metadata fields (vantage_point), which is a divergence from the qlog model. + c.log = c.config.QLogLogger.With( + // The group_id permits associating traces taken from different vantage points + // for the same connection. + // + // We use the original destination connection ID as the group ID. + // + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-3.4.6 + slogHexstring("group_id", originalDstConnID), + slog.Group("vantage_point", + slog.String("name", "go quic"), + slog.String("type", vantage), + ), + ) + localAddr := c.endpoint.LocalAddr() + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.2 + c.log.LogAttrs(context.Background(), QLogLevelEndpoint, + "connectivity:connection_started", + slogAddr("src_ip", localAddr.Addr()), + slog.Int("src_port", int(localAddr.Port())), + slogHexstring("src_cid", c.connIDState.local[0].cid), + slogAddr("dst_ip", peerAddr.Addr()), + slog.Int("dst_port", int(peerAddr.Port())), + slogHexstring("dst_cid", c.connIDState.remote[0].cid), + ) +} + +func (c *Conn) logConnectionClosed() { + if !c.logEnabled(QLogLevelEndpoint) { + return + } + err := c.lifetime.finalErr + trigger := "error" + switch e := err.(type) { + case *ApplicationError: + // TODO: Distinguish between peer and locally-initiated close. + trigger = "application" + case localTransportError: + switch err { + case errHandshakeTimeout: + trigger = "handshake_timeout" + default: + if e.code == errNo { + trigger = "clean" + } + } + case peerTransportError: + if e.code == errNo { + trigger = "clean" + } + default: + switch err { + case errIdleTimeout: + trigger = "idle_timeout" + case errStatelessReset: + trigger = "stateless_reset" + } + } + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.3 + c.log.LogAttrs(context.Background(), QLogLevelEndpoint, + "connectivity:connection_closed", + slog.String("trigger", trigger), + ) +} diff --git a/internal/quic/qlog/handler.go b/internal/quic/qlog/handler.go new file mode 100644 index 000000000..35a66cf8b --- /dev/null +++ b/internal/quic/qlog/handler.go @@ -0,0 +1,76 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package qlog + +import ( + "context" + "log/slog" +) + +type withAttrsHandler struct { + attrs []slog.Attr + h slog.Handler +} + +func withAttrs(h slog.Handler, attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return h + } + return &withAttrsHandler{attrs: attrs, h: h} +} + +func (h *withAttrsHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.h.Enabled(ctx, level) +} + +func (h *withAttrsHandler) Handle(ctx context.Context, r slog.Record) error { + r.AddAttrs(h.attrs...) + return h.h.Handle(ctx, r) +} + +func (h *withAttrsHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return withAttrs(h, attrs) +} + +func (h *withAttrsHandler) WithGroup(name string) slog.Handler { + return withGroup(h, name) +} + +type withGroupHandler struct { + name string + h slog.Handler +} + +func withGroup(h slog.Handler, name string) slog.Handler { + if name == "" { + return h + } + return &withGroupHandler{name: name, h: h} +} + +func (h *withGroupHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.h.Enabled(ctx, level) +} + +func (h *withGroupHandler) Handle(ctx context.Context, r slog.Record) error { + var attrs []slog.Attr + r.Attrs(func(a slog.Attr) bool { + attrs = append(attrs, a) + return true + }) + nr := slog.NewRecord(r.Time, r.Level, r.Message, r.PC) + nr.Add(slog.Any(h.name, slog.GroupValue(attrs...))) + return h.h.Handle(ctx, nr) +} + +func (h *withGroupHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return withAttrs(h, attrs) +} + +func (h *withGroupHandler) WithGroup(name string) slog.Handler { + return withGroup(h, name) +} diff --git a/internal/quic/qlog/json_writer.go b/internal/quic/qlog/json_writer.go new file mode 100644 index 000000000..50cf33bc5 --- /dev/null +++ b/internal/quic/qlog/json_writer.go @@ -0,0 +1,194 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package qlog + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "strconv" + "sync" + "time" +) + +// A jsonWriter writes JSON-SEQ (RFC 7464). +// +// A JSON-SEQ file consists of a series of JSON text records, +// each beginning with an RS (0x1e) character and ending with LF (0x0a). +type jsonWriter struct { + mu sync.Mutex + w io.WriteCloser + buf bytes.Buffer +} + +// writeRecordStart writes the start of a JSON-SEQ record. +func (w *jsonWriter) writeRecordStart() { + w.mu.Lock() + w.buf.WriteByte(0x1e) + w.buf.WriteByte('{') +} + +// writeRecordEnd finishes writing a JSON-SEQ record. +func (w *jsonWriter) writeRecordEnd() { + w.buf.WriteByte('}') + w.buf.WriteByte('\n') + w.w.Write(w.buf.Bytes()) + w.buf.Reset() + w.mu.Unlock() +} + +// writeAttrsField writes a []slog.Attr as an object field. +func (w *jsonWriter) writeAttrsField(name string, attrs []slog.Attr) { + w.writeName(name) + w.buf.WriteByte('{') + for _, a := range attrs { + w.writeAttr(a) + } + w.buf.WriteByte('}') +} + +// writeAttr writes a slog.Attr as an object field. +func (w *jsonWriter) writeAttr(a slog.Attr) { + v := a.Value.Resolve() + switch v.Kind() { + case slog.KindAny: + w.writeStringField(a.Key, fmt.Sprint(v.Any())) + case slog.KindBool: + w.writeBoolField(a.Key, v.Bool()) + case slog.KindDuration: + w.writeDurationField(a.Key, v.Duration()) + case slog.KindFloat64: + w.writeFloat64Field(a.Key, v.Float64()) + case slog.KindInt64: + w.writeInt64Field(a.Key, v.Int64()) + case slog.KindString: + w.writeStringField(a.Key, v.String()) + case slog.KindTime: + w.writeTimeField(a.Key, v.Time()) + case slog.KindUint64: + w.writeUint64Field(a.Key, v.Uint64()) + case slog.KindGroup: + w.writeAttrsField(a.Key, v.Group()) + default: + w.writeString("unhandled kind") + } +} + +// writeName writes an object field name followed by a colon. +func (w *jsonWriter) writeName(name string) { + if b := w.buf.Bytes(); len(b) > 0 && b[len(b)-1] != '{' { + // Add the comma separating this from the previous field. + w.buf.WriteByte(',') + } + w.writeString(name) + w.buf.WriteByte(':') +} + +// writeObject writes an object-valued object field. +// The function f is called to write the contents. +func (w *jsonWriter) writeObjectField(name string, f func()) { + w.writeName(name) + w.buf.WriteByte('{') + f() + w.buf.WriteByte('}') +} + +// writeRawField writes an field with a raw JSON value. +func (w *jsonWriter) writeRawField(name, v string) { + w.writeName(name) + w.buf.WriteString(v) +} + +// writeBoolField writes a bool-valued object field. +func (w *jsonWriter) writeBoolField(name string, v bool) { + w.writeName(name) + if v { + w.buf.WriteString("true") + } else { + w.buf.WriteString("false") + } +} + +// writeDurationField writes a millisecond duration-valued object field. +func (w *jsonWriter) writeDurationField(name string, v time.Duration) { + w.writeName(name) + fmt.Fprintf(&w.buf, "%d.%06d", v.Milliseconds(), v%time.Millisecond) +} + +// writeFloat64Field writes an float64-valued object field. +func (w *jsonWriter) writeFloat64Field(name string, v float64) { + w.writeName(name) + w.buf.Write(strconv.AppendFloat(w.buf.AvailableBuffer(), v, 'f', -1, 64)) +} + +// writeInt64Field writes an int64-valued object field. +func (w *jsonWriter) writeInt64Field(name string, v int64) { + w.writeName(name) + w.buf.Write(strconv.AppendInt(w.buf.AvailableBuffer(), v, 10)) +} + +// writeUint64Field writes a uint64-valued object field. +func (w *jsonWriter) writeUint64Field(name string, v uint64) { + w.writeName(name) + w.buf.Write(strconv.AppendUint(w.buf.AvailableBuffer(), v, 10)) +} + +// writeStringField writes a string-valued object field. +func (w *jsonWriter) writeStringField(name, v string) { + w.writeName(name) + w.writeString(v) +} + +// writeTimeField writes a time-valued object field. +func (w *jsonWriter) writeTimeField(name string, v time.Time) { + w.writeName(name) + fmt.Fprintf(&w.buf, "%d.%06d", v.UnixMilli(), v.Nanosecond()%int(time.Millisecond)) +} + +func jsonSafeSet(c byte) bool { + // mask is a 128-bit bitmap with 1s for allowed bytes, + // so that the byte c can be tested with a shift and an and. + // If c > 128, then 1<>64)) != 0 +} + +func jsonNeedsEscape(s string) bool { + for i := range s { + if !jsonSafeSet(s[i]) { + return true + } + } + return false +} + +// writeString writes an ASCII string. +// +// qlog fields should never contain anything that isn't ASCII, +// so we do the bare minimum to avoid producing invalid output if we +// do write something unexpected. +func (w *jsonWriter) writeString(v string) { + w.buf.WriteByte('"') + if !jsonNeedsEscape(v) { + w.buf.WriteString(v) + } else { + for i := range v { + if jsonSafeSet(v[i]) { + w.buf.WriteByte(v[i]) + } else { + fmt.Fprintf(&w.buf, `\u%04x`, v[i]) + } + } + } + w.buf.WriteByte('"') +} diff --git a/internal/quic/qlog/json_writer_test.go b/internal/quic/qlog/json_writer_test.go new file mode 100644 index 000000000..7ba5e1737 --- /dev/null +++ b/internal/quic/qlog/json_writer_test.go @@ -0,0 +1,186 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package qlog + +import ( + "bytes" + "errors" + "fmt" + "log/slog" + "strings" + "sync" + "testing" + "time" +) + +type testJSONOut struct { + bytes.Buffer +} + +func (o *testJSONOut) Close() error { return nil } + +func newTestJSONWriter() *jsonWriter { + return &jsonWriter{w: &testJSONOut{}} +} + +func wantJSONRecord(t *testing.T, w *jsonWriter, want string) { + t.Helper() + want = "\x1e" + want + "\n" + got := w.w.(*testJSONOut).String() + if got != want { + t.Errorf("jsonWriter contains unexpected output\ngot: %q\nwant: %q", got, want) + } +} + +func TestJSONWriterWriteConcurrentRecords(t *testing.T) { + w := newTestJSONWriter() + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + w.writeRecordStart() + w.writeInt64Field("field", 0) + w.writeRecordEnd() + }() + } + wg.Wait() + wantJSONRecord(t, w, strings.Join([]string{ + `{"field":0}`, + `{"field":0}`, + `{"field":0}`, + }, "\n\x1e")) +} + +func TestJSONWriterAttrs(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeAttrsField("field", []slog.Attr{ + slog.Any("any", errors.New("value")), + slog.Bool("bool", true), + slog.Duration("duration", 1*time.Second), + slog.Float64("float64", 1), + slog.Int64("int64", 1), + slog.String("string", "value"), + slog.Time("time", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), + slog.Uint64("uint64", 1), + slog.Group("group", "a", 1), + }) + w.writeRecordEnd() + wantJSONRecord(t, w, + `{"field":{`+ + `"any":"value",`+ + `"bool":true,`+ + `"duration":1000.000000,`+ + `"float64":1,`+ + `"int64":1,`+ + `"string":"value",`+ + `"time":946684800000.000000,`+ + `"uint64":1,`+ + `"group":{"a":1}`+ + `}}`) +} + +func TestJSONWriterObjectEmpty(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeObjectField("field", func() {}) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":{}}`) +} + +func TestJSONWriterObjectFields(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeObjectField("field", func() { + w.writeStringField("a", "value") + w.writeInt64Field("b", 10) + }) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":{"a":"value","b":10}}`) +} + +func TestJSONWriterRawField(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeRawField("field", `[1]`) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":[1]}`) +} + +func TestJSONWriterBoolField(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeBoolField("true", true) + w.writeBoolField("false", false) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"true":true,"false":false}`) +} + +func TestJSONWriterDurationField(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeDurationField("field", (10*time.Millisecond)+(2*time.Nanosecond)) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":10.000002}`) +} + +func TestJSONWriterFloat64Field(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeFloat64Field("field", 1.1) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":1.1}`) +} + +func TestJSONWriterInt64Field(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeInt64Field("field", 1234) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":1234}`) +} + +func TestJSONWriterUint64Field(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeUint64Field("field", 1234) + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":1234}`) +} + +func TestJSONWriterStringField(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeStringField("field", "value") + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":"value"}`) +} + +func TestJSONWriterStringFieldEscaped(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeStringField("field", "va\x00ue") + w.writeRecordEnd() + wantJSONRecord(t, w, `{"field":"va\u0000ue"}`) +} + +func TestJSONWriterStringEscaping(t *testing.T) { + for c := 0; c <= 0xff; c++ { + w := newTestJSONWriter() + w.writeRecordStart() + w.writeStringField("field", string([]byte{byte(c)})) + w.writeRecordEnd() + var want string + if (c >= 0x20 && c <= 0x21) || (c >= 0x23 && c <= 0x5b) || (c >= 0x5d && c <= 0x7e) { + want = fmt.Sprintf(`%c`, c) + } else { + want = fmt.Sprintf(`\u%04x`, c) + } + wantJSONRecord(t, w, `{"field":"`+want+`"}`) + } +} diff --git a/internal/quic/qlog/qlog.go b/internal/quic/qlog/qlog.go new file mode 100644 index 000000000..0e71d71aa --- /dev/null +++ b/internal/quic/qlog/qlog.go @@ -0,0 +1,267 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +// Package qlog serializes qlog events. +package qlog + +import ( + "bytes" + "context" + "errors" + "io" + "log/slog" + "os" + "path/filepath" + "sync" + "time" +) + +// Vantage is the vantage point of a trace. +type Vantage string + +const ( + // VantageEndpoint traces contain events not specific to a single connection. + VantageEndpoint = Vantage("endpoint") + + // VantageClient traces follow a connection from the client's perspective. + VantageClient = Vantage("client") + + // VantageClient traces follow a connection from the server's perspective. + VantageServer = Vantage("server") +) + +// TraceInfo contains information about a trace. +type TraceInfo struct { + // Vantage is the vantage point of the trace. + Vantage Vantage + + // GroupID identifies the logical group the trace belongs to. + // For a connection trace, the group will be the same for + // both the client and server vantage points. + GroupID string +} + +// HandlerOptions are options for a JSONHandler. +type HandlerOptions struct { + // Level reports the minimum record level that will be logged. + // If Level is nil, the handler assumes QLogLevelEndpoint. + Level slog.Leveler + + // Dir is the directory in which to create trace files. + // The handler will create one file per connection. + // If NewTrace is non-nil or Dir is "", the handler will not create files. + Dir string + + // NewTrace is called to create a new trace. + // If NewTrace is nil and Dir is set, + // the handler will create a new file in Dir for each trace. + NewTrace func(TraceInfo) (io.WriteCloser, error) +} + +type endpointHandler struct { + opts HandlerOptions + + traceOnce sync.Once + trace *jsonTraceHandler +} + +// NewJSONHandler returns a handler which serializes qlog events to JSON. +// +// The handler will write an endpoint-wide trace, +// and a separate trace for each connection. +// The HandlerOptions control the location traces are written. +// +// It uses the streamable JSON Text Sequences mapping (JSON-SEQ) +// defined in draft-ietf-quic-qlog-main-schema-04, Section 6.2. +// +// A JSONHandler may be used as the handler for a quic.Config.QLogLogger. +// It is not a general-purpose slog handler, +// and may not properly handle events from other sources. +func NewJSONHandler(opts HandlerOptions) slog.Handler { + if opts.Dir == "" && opts.NewTrace == nil { + return slogDiscard{} + } + return &endpointHandler{ + opts: opts, + } +} + +func (h *endpointHandler) Enabled(ctx context.Context, level slog.Level) bool { + return enabled(h.opts.Level, level) +} + +func (h *endpointHandler) Handle(ctx context.Context, r slog.Record) error { + h.traceOnce.Do(func() { + h.trace, _ = newJSONTraceHandler(h.opts, nil) + }) + if h.trace != nil { + h.trace.Handle(ctx, r) + } + return nil +} + +func (h *endpointHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + // Create a new trace output file for each top-level WithAttrs. + tr, err := newJSONTraceHandler(h.opts, attrs) + if err != nil { + return withAttrs(h, attrs) + } + return tr +} + +func (h *endpointHandler) WithGroup(name string) slog.Handler { + return withGroup(h, name) +} + +type jsonTraceHandler struct { + level slog.Leveler + w jsonWriter + start time.Time + buf bytes.Buffer +} + +func newJSONTraceHandler(opts HandlerOptions, attrs []slog.Attr) (*jsonTraceHandler, error) { + w, err := newTraceWriter(opts, traceInfoFromAttrs(attrs)) + if err != nil { + return nil, err + } + + // For testing, it might be nice to set the start time used for relative timestamps + // to the time of the first event. + // + // At the expense of some additional complexity here, we could defer writing + // the reference_time header field until the first event is processed. + // + // Just use the current time for now. + start := time.Now() + + h := &jsonTraceHandler{ + w: jsonWriter{w: w}, + level: opts.Level, + start: start, + } + h.writeHeader(attrs) + return h, nil +} + +func traceInfoFromAttrs(attrs []slog.Attr) TraceInfo { + info := TraceInfo{ + Vantage: VantageEndpoint, // default if not specified + } + for _, a := range attrs { + if a.Key == "group_id" && a.Value.Kind() == slog.KindString { + info.GroupID = a.Value.String() + } + if a.Key == "vantage_point" && a.Value.Kind() == slog.KindGroup { + for _, aa := range a.Value.Group() { + if aa.Key == "type" && aa.Value.Kind() == slog.KindString { + info.Vantage = Vantage(aa.Value.String()) + } + } + } + } + return info +} + +func newTraceWriter(opts HandlerOptions, info TraceInfo) (io.WriteCloser, error) { + var w io.WriteCloser + var err error + if opts.NewTrace != nil { + w, err = opts.NewTrace(info) + } else if opts.Dir != "" { + var filename string + if info.GroupID != "" { + filename = info.GroupID + "_" + } + filename += string(info.Vantage) + ".sqlog" + if !filepath.IsLocal(filename) { + return nil, errors.New("invalid trace filename") + } + w, err = os.Create(filepath.Join(opts.Dir, filename)) + } else { + err = errors.New("no log destination") + } + return w, err +} + +func (h *jsonTraceHandler) writeHeader(attrs []slog.Attr) { + h.w.writeRecordStart() + defer h.w.writeRecordEnd() + + // At the time of writing this comment the most recent version is 0.4, + // but qvis only supports up to 0.3. + h.w.writeStringField("qlog_version", "0.3") + h.w.writeStringField("qlog_format", "JSON-SEQ") + + // The attrs flatten both common trace event fields and Trace fields. + // This identifies the fields that belong to the Trace. + isTraceSeqField := func(s string) bool { + switch s { + case "title", "description", "configuration", "vantage_point": + return true + } + return false + } + + h.w.writeObjectField("trace", func() { + h.w.writeObjectField("common_fields", func() { + h.w.writeRawField("protocol_type", `["QUIC"]`) + h.w.writeStringField("time_format", "relative") + h.w.writeTimeField("reference_time", h.start) + for _, a := range attrs { + if !isTraceSeqField(a.Key) { + h.w.writeAttr(a) + } + } + }) + for _, a := range attrs { + if isTraceSeqField(a.Key) { + h.w.writeAttr(a) + } + } + }) +} + +func (h *jsonTraceHandler) Enabled(ctx context.Context, level slog.Level) bool { + return enabled(h.level, level) +} + +func (h *jsonTraceHandler) Handle(ctx context.Context, r slog.Record) error { + h.w.writeRecordStart() + defer h.w.writeRecordEnd() + h.w.writeDurationField("time", r.Time.Sub(h.start)) + h.w.writeStringField("name", r.Message) + h.w.writeObjectField("data", func() { + r.Attrs(func(a slog.Attr) bool { + h.w.writeAttr(a) + return true + }) + }) + return nil +} + +func (h *jsonTraceHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return withAttrs(h, attrs) +} + +func (h *jsonTraceHandler) WithGroup(name string) slog.Handler { + return withGroup(h, name) +} + +func enabled(leveler slog.Leveler, level slog.Level) bool { + var minLevel slog.Level + if leveler != nil { + minLevel = leveler.Level() + } + return level >= minLevel +} + +type slogDiscard struct{} + +func (slogDiscard) Enabled(context.Context, slog.Level) bool { return false } +func (slogDiscard) Handle(ctx context.Context, r slog.Record) error { return nil } +func (slogDiscard) WithAttrs(attrs []slog.Attr) slog.Handler { return slogDiscard{} } +func (slogDiscard) WithGroup(name string) slog.Handler { return slogDiscard{} } diff --git a/internal/quic/qlog/qlog_test.go b/internal/quic/qlog/qlog_test.go new file mode 100644 index 000000000..7575cd890 --- /dev/null +++ b/internal/quic/qlog/qlog_test.go @@ -0,0 +1,151 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package qlog + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log/slog" + "reflect" + "testing" + "time" +) + +// QLog tests are mostly in the quic package, where we can test event generation +// and serialization together. + +func TestQLogHandlerEvents(t *testing.T) { + for _, test := range []struct { + name string + f func(*slog.Logger) + want []map[string]any // events, not counting the trace header + }{{ + name: "various types", + f: func(log *slog.Logger) { + log.Info("message", + "bool", true, + "duration", time.Duration(1*time.Second), + "float", 0.0, + "int", 0, + "string", "value", + "uint", uint64(0), + slog.Group("group", + "a", 0, + ), + ) + }, + want: []map[string]any{{ + "name": "message", + "data": map[string]any{ + "bool": true, + "duration": float64(1000), + "float": float64(0.0), + "int": float64(0), + "string": "value", + "uint": float64(0), + "group": map[string]any{ + "a": float64(0), + }, + }, + }}, + }, { + name: "WithAttrs", + f: func(log *slog.Logger) { + log = log.With( + "with_a", "a", + "with_b", "b", + ) + log.Info("m1", "field", "1") + log.Info("m2", "field", "2") + }, + want: []map[string]any{{ + "name": "m1", + "data": map[string]any{ + "with_a": "a", + "with_b": "b", + "field": "1", + }, + }, { + "name": "m2", + "data": map[string]any{ + "with_a": "a", + "with_b": "b", + "field": "2", + }, + }}, + }, { + name: "WithGroup", + f: func(log *slog.Logger) { + log = log.With( + "with_a", "a", + "with_b", "b", + ) + log.Info("m1", "field", "1") + log.Info("m2", "field", "2") + }, + want: []map[string]any{{ + "name": "m1", + "data": map[string]any{ + "with_a": "a", + "with_b": "b", + "field": "1", + }, + }, { + "name": "m2", + "data": map[string]any{ + "with_a": "a", + "with_b": "b", + "field": "2", + }, + }}, + }} { + var out bytes.Buffer + opts := HandlerOptions{ + Level: slog.LevelDebug, + NewTrace: func(TraceInfo) (io.WriteCloser, error) { + return nopCloseWriter{&out}, nil + }, + } + h, err := newJSONTraceHandler(opts, []slog.Attr{ + slog.String("group_id", "group"), + slog.Group("vantage_point", + slog.String("type", "client"), + ), + }) + if err != nil { + t.Fatal(err) + } + log := slog.New(h) + test.f(log) + got := []map[string]any{} + for i, e := range bytes.Split(out.Bytes(), []byte{0x1e}) { + // i==0: empty string before the initial record separator + // i==1: trace header; not part of this test + if i < 2 { + continue + } + var val map[string]any + if err := json.Unmarshal(e, &val); err != nil { + panic(fmt.Errorf("log unmarshal failure: %v\n%q", err, string(e))) + } + delete(val, "time") + got = append(got, val) + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("event mismatch\ngot: %v\nwant: %v", got, test.want) + } + } + +} + +type nopCloseWriter struct { + io.Writer +} + +func (nopCloseWriter) Close() error { return nil } diff --git a/internal/quic/qlog_test.go b/internal/quic/qlog_test.go new file mode 100644 index 000000000..119f5d16a --- /dev/null +++ b/internal/quic/qlog_test.go @@ -0,0 +1,202 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log/slog" + "reflect" + "testing" + "time" + + "golang.org/x/net/internal/quic/qlog" +) + +func TestQLogHandshake(t *testing.T) { + testSides(t, "", func(t *testing.T, side connSide) { + qr := &qlogRecord{} + tc := newTestConn(t, side, qr.config) + tc.handshake() + tc.conn.Abort(nil) + tc.wantFrame("aborting connection generates CONN_CLOSE", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errNo, + }) + tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{}) + tc.advanceToTimer() // let the conn finish draining + + var src, dst []byte + if side == clientSide { + src = testLocalConnID(0) + dst = testLocalConnID(-1) + } else { + src = testPeerConnID(-1) + dst = testPeerConnID(0) + } + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_started", + "data": map[string]any{ + "src_cid": hex.EncodeToString(src), + "dst_cid": hex.EncodeToString(dst), + }, + }, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": "clean", + }, + }) + }) +} + +func TestQLogConnectionClosedTrigger(t *testing.T) { + for _, test := range []struct { + trigger string + connOpts []any + f func(*testConn) + }{{ + trigger: "clean", + f: func(tc *testConn) { + tc.handshake() + tc.conn.Abort(nil) + }, + }, { + trigger: "handshake_timeout", + connOpts: []any{ + func(c *Config) { + c.HandshakeTimeout = 5 * time.Second + }, + }, + f: func(tc *testConn) { + tc.ignoreFrame(frameTypeCrypto) + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypePing) + tc.advance(5 * time.Second) + }, + }, { + trigger: "idle_timeout", + connOpts: []any{ + func(c *Config) { + c.MaxIdleTimeout = 5 * time.Second + }, + }, + f: func(tc *testConn) { + tc.handshake() + tc.advance(5 * time.Second) + }, + }, { + trigger: "error", + f: func(tc *testConn) { + tc.handshake() + tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{ + code: errProtocolViolation, + }) + tc.conn.Abort(nil) + }, + }} { + t.Run(test.trigger, func(t *testing.T) { + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, append(test.connOpts, qr.config)...) + test.f(tc) + fr, ptype := tc.readFrame() + switch fr := fr.(type) { + case debugFrameConnectionCloseTransport: + tc.writeFrames(ptype, fr) + case nil: + default: + t.Fatalf("unexpected frame: %v", fr) + } + tc.wantIdle("connection should be idle while closing") + tc.advance(5 * time.Second) // long enough for the drain timer to expire + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": test.trigger, + }, + }) + }) + } +} + +type nopCloseWriter struct { + io.Writer +} + +func (nopCloseWriter) Close() error { return nil } + +type jsonEvent map[string]any + +func (j jsonEvent) String() string { + b, _ := json.MarshalIndent(j, "", " ") + return string(b) +} + +// eventPartialEqual verifies that every field set in want matches the corresponding field in got. +// It ignores additional fields in got. +func eventPartialEqual(got, want jsonEvent) bool { + for k := range want { + ge, gok := got[k].(map[string]any) + we, wok := want[k].(map[string]any) + if gok && wok { + if !eventPartialEqual(ge, we) { + return false + } + } else { + if !reflect.DeepEqual(got[k], want[k]) { + return false + } + } + } + return true +} + +// A qlogRecord records events. +type qlogRecord struct { + ev []jsonEvent +} + +func (q *qlogRecord) Write(b []byte) (int, error) { + // This relies on the property that the Handler always makes one Write call per event. + if len(b) < 1 || b[0] != 0x1e { + panic(fmt.Errorf("trace Write should start with record separator, got %q", string(b))) + } + var val map[string]any + if err := json.Unmarshal(b[1:], &val); err != nil { + panic(fmt.Errorf("log unmarshal failure: %v\n%v", err, string(b))) + } + q.ev = append(q.ev, val) + return len(b), nil +} + +func (q *qlogRecord) Close() error { return nil } + +// config may be passed to newTestConn to configure the conn to use this logger. +func (q *qlogRecord) config(c *Config) { + c.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{ + NewTrace: func(info qlog.TraceInfo) (io.WriteCloser, error) { + return q, nil + }, + })) +} + +// wantEvents checks that every event in want occurs in the order specified. +func (q *qlogRecord) wantEvents(t *testing.T, want ...jsonEvent) { + t.Helper() + got := q.ev + unseen := want + for _, g := range got { + if eventPartialEqual(g, unseen[0]) { + unseen = unseen[1:] + if len(unseen) == 0 { + return + } + } + } + t.Fatalf("got events:\n%v\n\nwant events:\n%v", got, want) +} diff --git a/internal/quic/quic.go b/internal/quic/quic.go index 084887be6..e4d0d77c7 100644 --- a/internal/quic/quic.go +++ b/internal/quic/quic.go @@ -54,14 +54,24 @@ const ( maxPeerActiveConnIDLimit = 4 ) +// Time limit for completing the handshake. +const defaultHandshakeTimeout = 10 * time.Second + +// Keep-alive ping frequency. +const defaultKeepAlivePeriod = 0 + // Local timer granularity. // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2-6 const timerGranularity = 1 * time.Millisecond +// The smallest allowed maximum datagram size. +// https://www.rfc-editor.org/rfc/rfc9000#section-14 +const smallestMaxDatagramSize = 1200 + // Minimum size of a UDP datagram sent by a client carrying an Initial packet, // or a server containing an ack-eliciting Initial packet. // https://www.rfc-editor.org/rfc/rfc9000#section-14.1 -const paddedInitialDatagramSize = 1200 +const paddedInitialDatagramSize = smallestMaxDatagramSize // Maximum number of streams of a given type which may be created. // https://www.rfc-editor.org/rfc/rfc9000.html#section-4.6-2 diff --git a/internal/quic/retry.go b/internal/quic/retry.go index e3d9f4d7d..31cb57b88 100644 --- a/internal/quic/retry.go +++ b/internal/quic/retry.go @@ -39,7 +39,7 @@ var ( // retryTokenValidityPeriod is how long we accept a Retry packet token after sending it. const retryTokenValidityPeriod = 5 * time.Second -// retryState generates and validates a listener's retry tokens. +// retryState generates and validates an endpoint's retry tokens. type retryState struct { aead cipher.AEAD } @@ -139,7 +139,7 @@ func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []by return additional } -func (l *Listener) validateInitialAddress(now time.Time, p genericLongPacket, addr netip.AddrPort) (origDstConnID []byte, ok bool) { +func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, addr netip.AddrPort) (origDstConnID []byte, ok bool) { // The retry token is at the start of an Initial packet's data. token, n := consumeUint8Bytes(p.data) if n < 0 { @@ -151,22 +151,22 @@ func (l *Listener) validateInitialAddress(now time.Time, p genericLongPacket, ad if len(token) == 0 { // The sender has not provided a token. // Send a Retry packet to them with one. - l.sendRetry(now, p, addr) + e.sendRetry(now, p, addr) return nil, false } - origDstConnID, ok = l.retry.validateToken(now, token, p.srcConnID, p.dstConnID, addr) + origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, addr) if !ok { // This does not seem to be a valid token. // Close the connection with an INVALID_TOKEN error. // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5 - l.sendConnectionClose(p, addr, errInvalidToken) + e.sendConnectionClose(p, addr, errInvalidToken) return nil, false } return origDstConnID, true } -func (l *Listener) sendRetry(now time.Time, p genericLongPacket, addr netip.AddrPort) { - token, srcConnID, err := l.retry.makeToken(now, p.srcConnID, p.dstConnID, addr) +func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, addr netip.AddrPort) { + token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, addr) if err != nil { return } @@ -175,7 +175,7 @@ func (l *Listener) sendRetry(now time.Time, p genericLongPacket, addr netip.Addr srcConnID: srcConnID, token: token, }) - l.sendDatagram(b, addr) + e.sendDatagram(b, addr) } type retryPacket struct { diff --git a/internal/quic/retry_test.go b/internal/quic/retry_test.go index f754270a5..4a21a4ca1 100644 --- a/internal/quic/retry_test.go +++ b/internal/quic/retry_test.go @@ -16,7 +16,7 @@ import ( ) type retryServerTest struct { - tl *testListener + te *testEndpoint originalSrcConnID []byte originalDstConnID []byte retry retryPacket @@ -32,16 +32,16 @@ func newRetryServerTest(t *testing.T) *retryServerTest { TLSConfig: newTestTLSConfig(serverSide), RequireAddressValidation: true, } - tl := newTestListener(t, config) + te := newTestEndpoint(t, config) srcID := testPeerConnID(0) dstID := testLocalConnID(-1) params := defaultTransportParameters() params.initialSrcConnID = srcID - initialCrypto := initialClientCrypto(t, tl, params) + initialCrypto := initialClientCrypto(t, te, params) // Initial packet with no Token. // Server responds with a Retry containing a token. - tl.writeDatagram(&testDatagram{ + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 0, @@ -56,7 +56,7 @@ func newRetryServerTest(t *testing.T) *retryServerTest { }}, paddedSize: 1200, }) - got := tl.readDatagram() + got := te.readDatagram() if len(got.packets) != 1 || got.packets[0].ptype != packetTypeRetry { t.Fatalf("got datagram: %v\nwant Retry", got) } @@ -66,7 +66,7 @@ func newRetryServerTest(t *testing.T) *retryServerTest { } return &retryServerTest{ - tl: tl, + te: te, originalSrcConnID: srcID, originalDstConnID: dstID, retry: retryPacket{ @@ -80,9 +80,9 @@ func newRetryServerTest(t *testing.T) *retryServerTest { func TestRetryServerSucceeds(t *testing.T) { rt := newRetryServerTest(t) - tl := rt.tl - tl.advance(retryTokenValidityPeriod) - tl.writeDatagram(&testDatagram{ + te := rt.te + te.advance(retryTokenValidityPeriod) + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, @@ -98,7 +98,7 @@ func TestRetryServerSucceeds(t *testing.T) { }}, paddedSize: 1200, }) - tc := tl.accept() + tc := te.accept() initial := tc.readPacket() if initial == nil || initial.ptype != packetTypeInitial { t.Fatalf("got packet:\n%v\nwant: Initial", initial) @@ -124,8 +124,8 @@ func TestRetryServerTokenInvalid(t *testing.T) { // INVALID_TOKEN error." // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5 rt := newRetryServerTest(t) - tl := rt.tl - tl.writeDatagram(&testDatagram{ + te := rt.te + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, @@ -141,7 +141,7 @@ func TestRetryServerTokenInvalid(t *testing.T) { }}, paddedSize: 1200, }) - tl.wantDatagram("server closes connection after Initial with invalid Retry token", + te.wantDatagram("server closes connection after Initial with invalid Retry token", initialConnectionCloseDatagram( rt.retry.srcConnID, rt.originalSrcConnID, @@ -152,9 +152,9 @@ func TestRetryServerTokenTooOld(t *testing.T) { // "[...] a token SHOULD have an expiration time [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.3-3 rt := newRetryServerTest(t) - tl := rt.tl - tl.advance(retryTokenValidityPeriod + time.Second) - tl.writeDatagram(&testDatagram{ + te := rt.te + te.advance(retryTokenValidityPeriod + time.Second) + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, @@ -170,7 +170,7 @@ func TestRetryServerTokenTooOld(t *testing.T) { }}, paddedSize: 1200, }) - tl.wantDatagram("server closes connection after Initial with expired token", + te.wantDatagram("server closes connection after Initial with expired token", initialConnectionCloseDatagram( rt.retry.srcConnID, rt.originalSrcConnID, @@ -182,8 +182,8 @@ func TestRetryServerTokenWrongIP(t *testing.T) { // to verify that the source IP address and port in client packets remain constant." // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.4-3 rt := newRetryServerTest(t) - tl := rt.tl - tl.writeDatagram(&testDatagram{ + te := rt.te + te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, @@ -200,7 +200,7 @@ func TestRetryServerTokenWrongIP(t *testing.T) { paddedSize: 1200, addr: netip.MustParseAddrPort("10.0.0.2:8000"), }) - tl.wantDatagram("server closes connection after Initial from wrong address", + te.wantDatagram("server closes connection after Initial from wrong address", initialConnectionCloseDatagram( rt.retry.srcConnID, rt.originalSrcConnID, @@ -435,7 +435,7 @@ func TestRetryClientIgnoresRetryWithInvalidIntegrityTag(t *testing.T) { token: []byte{1, 2, 3, 4}, }) pkt[len(pkt)-1] ^= 1 // invalidate the integrity tag - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: pkt, addr: testClientAddr, }) @@ -527,14 +527,14 @@ func TestParseInvalidRetryPackets(t *testing.T) { } } -func initialClientCrypto(t *testing.T, l *testListener, p transportParameters) []byte { +func initialClientCrypto(t *testing.T, e *testEndpoint, p transportParameters) []byte { t.Helper() config := &tls.QUICConfig{TLSConfig: newTestTLSConfig(clientSide)} tlsClient := tls.QUICClient(config) tlsClient.SetTransportParameters(marshalTransportParameters(p)) tlsClient.Start(context.Background()) //defer tlsClient.Close() - l.peerTLSConn = tlsClient + e.peerTLSConn = tlsClient var data []byte for { e := tlsClient.NextEvent() diff --git a/internal/quic/stateless_reset_test.go b/internal/quic/stateless_reset_test.go index 8a16597c4..45a49e81e 100644 --- a/internal/quic/stateless_reset_test.go +++ b/internal/quic/stateless_reset_test.go @@ -68,7 +68,7 @@ func TestStatelessResetSentSizes(t *testing.T) { StatelessResetKey: testStatelessResetKey, } addr := netip.MustParseAddr("127.0.0.1") - tl := newTestListener(t, config) + te := newTestEndpoint(t, config) for i, test := range []struct { reqSize int wantSize int @@ -105,9 +105,9 @@ func TestStatelessResetSentSizes(t *testing.T) { cid := testLocalConnID(int64(i)) token := testStatelessResetToken(cid) addrport := netip.AddrPortFrom(addr, uint16(8000+i)) - tl.write(newDatagramForReset(cid, test.reqSize, addrport)) + te.write(newDatagramForReset(cid, test.reqSize, addrport)) - got := tl.read() + got := te.read() if len(got) != test.wantSize { t.Errorf("got %v-byte response to %v-byte req, want %v", len(got), test.reqSize, test.wantSize) @@ -130,7 +130,8 @@ func TestStatelessResetSentSizes(t *testing.T) { func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { // "[...] Stateless Reset Token field values from [...] NEW_CONNECTION_ID frames [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-1 - tc := newTestConn(t, clientSide) + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, qr.config) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -148,7 +149,7 @@ func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { resetToken := testPeerStatelessResetToken(1) // provided during handshake dgram := append(make([]byte, 100), resetToken[:]...) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: dgram, }) @@ -158,6 +159,13 @@ func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { tc.wantIdle("closed connection is idle in draining") tc.advance(1 * time.Second) // long enough to exit the draining state tc.wantIdle("closed connection is idle after draining") + + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:connection_closed", + "data": map[string]any{ + "trigger": "stateless_reset", + }, + }) } func TestStatelessResetSuccessfulTransportParameter(t *testing.T) { @@ -171,7 +179,7 @@ func TestStatelessResetSuccessfulTransportParameter(t *testing.T) { tc.handshake() dgram := append(make([]byte, 100), resetToken[:]...) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: dgram, }) @@ -235,7 +243,7 @@ func TestStatelessResetSuccessfulPrefix(t *testing.T) { dgram = append(dgram, byte(len(dgram))) // semi-random junk } dgram = append(dgram, resetToken[:]...) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: dgram, }) if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) { @@ -270,7 +278,7 @@ func TestStatelessResetRetiredConnID(t *testing.T) { // Receive a stateless reset for connection ID 0. dgram := append(make([]byte, 100), resetToken[:]...) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: dgram, }) diff --git a/internal/quic/stream.go b/internal/quic/stream.go index 58d84ed1b..36c80f6af 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -38,10 +38,11 @@ type Stream struct { // the write will fail. outgate gate out pipe // buffered data to send + outflushed int64 // offset of last flush call outwin int64 // maximum MAX_STREAM_DATA received from the peer outmaxsent int64 // maximum data offset we've sent to the peer outmaxbuf int64 // maximum amount of data we will buffer - outunsent rangeset[int64] // ranges buffered but not yet sent + outunsent rangeset[int64] // ranges buffered but not yet sent (only flushed data) outacked rangeset[int64] // ranges sent and acknowledged outopened sentVal // set if we should open the stream outclosed sentVal // set by CloseWrite @@ -240,8 +241,6 @@ func (s *Stream) Write(b []byte) (n int, err error) { // WriteContext writes data to the stream write buffer. // Buffered data is only sent when the buffer is sufficiently full. // Call the Flush method to ensure buffered data is sent. -// -// TODO: Implement Flush. func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) { if s.IsReadOnly() { return 0, errors.New("write to read-only stream") @@ -269,10 +268,6 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) s.outUnlock() return n, errors.New("write to closed stream") } - // We set outopened here rather than below, - // so if this is a zero-length write we still - // open the stream despite not writing any data to it. - s.outopened.set() if len(b) == 0 { break } @@ -282,13 +277,26 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) // Amount to write is min(the full buffer, data up to the write limit). // This is a number of bytes. nn := min(int64(len(b)), lim-s.out.end) - // Copy the data into the output buffer and mark it as unsent. - if s.out.end <= s.outwin { - s.outunsent.add(s.out.end, min(s.out.end+nn, s.outwin)) - } + // Copy the data into the output buffer. s.out.writeAt(b[:nn], s.out.end) b = b[nn:] n += int(nn) + // Possibly flush the output buffer. + // We automatically flush if: + // - We have enough data to consume the send window. + // Sending this data may cause the peer to extend the window. + // - We have buffered as much data as we're willing do. + // We need to send data to clear out buffer space. + // - We have enough data to fill a 1-RTT packet using the smallest + // possible maximum datagram size (1200 bytes, less header byte, + // connection ID, packet number, and AEAD overhead). + const autoFlushSize = smallestMaxDatagramSize - 1 - connIDLen - 1 - aeadOverhead + shouldFlush := s.out.end >= s.outwin || // peer send window is full + s.out.end >= lim || // local send buffer is full + (s.out.end-s.outflushed) >= autoFlushSize // enough data buffered + if shouldFlush { + s.flushLocked() + } if s.out.end > s.outwin { // We're blocked by flow control. // Send a STREAM_DATA_BLOCKED frame to let the peer know. @@ -301,6 +309,23 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) return n, nil } +// Flush flushes data written to the stream. +// It does not wait for the peer to acknowledge receipt of the data. +// Use CloseContext to wait for the peer's acknowledgement. +func (s *Stream) Flush() { + s.outgate.lock() + defer s.outUnlock() + s.flushLocked() +} + +func (s *Stream) flushLocked() { + s.outopened.set() + if s.outflushed < s.outwin { + s.outunsent.add(s.outflushed, min(s.outwin, s.out.end)) + } + s.outflushed = s.out.end +} + // Close closes the stream. // See CloseContext for more details. func (s *Stream) Close() error { @@ -363,6 +388,7 @@ func (s *Stream) CloseWrite() { s.outgate.lock() defer s.outUnlock() s.outclosed.set() + s.flushLocked() } // Reset aborts writes on the stream and notifies the peer @@ -612,8 +638,8 @@ func (s *Stream) handleMaxStreamData(maxStreamData int64) error { if maxStreamData <= s.outwin { return nil } - if s.out.end > s.outwin { - s.outunsent.add(s.outwin, min(maxStreamData, s.out.end)) + if s.outflushed > s.outwin { + s.outunsent.add(s.outwin, min(maxStreamData, s.outflushed)) } s.outwin = maxStreamData if s.out.end > s.outwin { @@ -741,10 +767,11 @@ func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto b } for { // STREAM - off, size := dataToSend(min(s.out.start, s.outwin), min(s.out.end, s.outwin), s.outunsent, s.outacked, pto) + off, size := dataToSend(min(s.out.start, s.outwin), min(s.outflushed, s.outwin), s.outunsent, s.outacked, pto) if end := off + size; end > s.outmaxsent { // This will require connection-level flow control to send. end = min(end, s.outmaxsent+s.conn.streams.outflow.avail()) + end = max(end, off) size = end - off } fin := s.outclosed.isSet() && off+size == s.out.end diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go index 9bf2b5871..93c8839ff 100644 --- a/internal/quic/stream_test.go +++ b/internal/quic/stream_test.go @@ -38,6 +38,7 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { if n != writeBufferSize || err != context.Canceled { t.Fatalf("s.WriteContext() = %v, %v; want %v, context.Canceled", n, err, writeBufferSize) } + s.Flush() tc.wantFrame("first write buffer of data sent", packetType1RTT, debugFrameStream{ id: s.id, @@ -47,7 +48,9 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { // Blocking write, which must wait for buffer space. w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, want[writeBufferSize:]) + n, err := s.WriteContext(ctx, want[writeBufferSize:]) + s.Flush() + return n, err }) tc.wantIdle("write buffer is full, no more data can be sent") @@ -170,6 +173,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { t.Fatal(err) } s.WriteContext(ctx, want[:1]) + s.Flush() tc.wantFrame("sent data (1 byte) fits within flow control limit", packetType1RTT, debugFrameStream{ id: s.id, @@ -723,7 +727,7 @@ func testStreamSendFrameInvalidState(t *testing.T, f func(sid streamID) debugFra if err != nil { t.Fatal(err) } - s.Write(nil) // open the stream + s.Flush() // open the stream tc.wantFrame("new stream is opened", packetType1RTT, debugFrameStream{ id: sid, @@ -968,7 +972,9 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { want := make([]byte, 4096) rand.Read(want) // doesn't need to be crypto/rand, but non-deprecated and harmless w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, want) + n, err := s.WriteContext(ctx, want) + s.Flush() + return n, err }) got := make([]byte, 0, len(want)) for { @@ -998,6 +1004,7 @@ func TestStreamCloseWaitsForAcks(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) s.WriteContext(ctx, data) + s.Flush() tc.wantFrame("conn sends data for the stream", packetType1RTT, debugFrameStream{ id: s.id, @@ -1064,6 +1071,7 @@ func TestStreamCloseUnblocked(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) s.WriteContext(ctx, data) + s.Flush() tc.wantFrame("conn sends data for the stream", packetType1RTT, debugFrameStream{ id: s.id, @@ -1228,6 +1236,7 @@ func TestStreamPeerStopSendingForActiveStream(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, styp, permissiveTransportParameters) for i := 0; i < 4; i++ { s.Write([]byte{byte(i)}) + s.Flush() tc.wantFrame("write sends a STREAM frame to peer", packetType1RTT, debugFrameStream{ id: s.id, @@ -1271,6 +1280,99 @@ func TestStreamReceiveDataBlocked(t *testing.T) { tc.wantIdle("no response to STREAM_DATA_BLOCKED and DATA_BLOCKED") } +func TestStreamFlushExplicit(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + tc, s := newTestConnAndLocalStream(t, clientSide, styp, permissiveTransportParameters) + want := []byte{0, 1, 2, 3} + n, err := s.Write(want) + if n != len(want) || err != nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want)) + } + tc.wantIdle("unflushed data is not sent") + s.Flush() + tc.wantFrame("data is sent after flush", + packetType1RTT, debugFrameStream{ + id: s.id, + data: want, + }) + }) +} + +func TestStreamFlushImplicitExact(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + const writeBufferSize = 4 + tc, s := newTestConnAndLocalStream(t, clientSide, styp, + permissiveTransportParameters, + func(c *Config) { + c.MaxStreamWriteBufferSize = writeBufferSize + }) + want := []byte{0, 1, 2, 3, 4, 5, 6} + + // This write doesn't quite fill the output buffer. + n, err := s.Write(want[:3]) + if n != 3 || err != nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want)) + } + tc.wantIdle("unflushed data is not sent") + + // This write fills the output buffer exactly. + n, err = s.Write(want[3:4]) + if n != 1 || err != nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want)) + } + tc.wantFrame("data is sent after write buffer fills", + packetType1RTT, debugFrameStream{ + id: s.id, + data: want[0:4], + }) + + }) +} + +func TestStreamFlushImplicitLargerThanBuffer(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + const writeBufferSize = 4 + tc, s := newTestConnAndLocalStream(t, clientSide, styp, + permissiveTransportParameters, + func(c *Config) { + c.MaxStreamWriteBufferSize = writeBufferSize + }) + want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + w := runAsync(tc, func(ctx context.Context) (int, error) { + n, err := s.WriteContext(ctx, want) + return n, err + }) + + tc.wantFrame("data is sent after write buffer fills", + packetType1RTT, debugFrameStream{ + id: s.id, + data: want[0:4], + }) + tc.writeAckForAll() + tc.wantFrame("ack permits sending more data", + packetType1RTT, debugFrameStream{ + id: s.id, + off: 4, + data: want[4:8], + }) + tc.writeAckForAll() + + tc.wantIdle("write buffer is not full") + if n, err := w.result(); n != len(want) || err != nil { + t.Fatalf("Write() = %v, %v; want %v, nil", n, err, len(want)) + } + + s.Flush() + tc.wantFrame("flush sends last buffer of data", + packetType1RTT, debugFrameStream{ + id: s.id, + off: 8, + data: want[8:], + }) + }) +} + type streamSide string const ( diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go index fa339b9fa..14f74a00a 100644 --- a/internal/quic/tls_test.go +++ b/internal/quic/tls_test.go @@ -36,7 +36,7 @@ func (tc *testConn) handshake() { for { if i == len(dgrams)-1 { if tc.conn.side == clientSide { - want := tc.listener.now.Add(maxAckDelay - timerGranularity) + want := tc.endpoint.now.Add(maxAckDelay - timerGranularity) if !tc.timer.Equal(want) { t.Fatalf("want timer = %v (max_ack_delay), got %v", want, tc.timer) } @@ -85,7 +85,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { testPeerConnID(0), testPeerConnID(1), } - localResetToken := tc.listener.l.resetGen.tokenForConnID(localConnIDs[1]) + localResetToken := tc.endpoint.e.resetGen.tokenForConnID(localConnIDs[1]) peerResetToken := testPeerStatelessResetToken(1) if tc.conn.side == clientSide { clientConnIDs = localConnIDs diff --git a/internal/quic/version_test.go b/internal/quic/version_test.go index 830e0e1c8..92fabd7b3 100644 --- a/internal/quic/version_test.go +++ b/internal/quic/version_test.go @@ -17,7 +17,7 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { config := &Config{ TLSConfig: newTestTLSConfig(serverSide), } - tl := newTestListener(t, config) + te := newTestEndpoint(t, config) // Packet of unknown contents for some unrecognized QUIC version. dstConnID := []byte{1, 2, 3, 4} @@ -34,10 +34,10 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { pkt = append(pkt, 0) } - tl.write(&datagram{ + te.write(&datagram{ b: pkt, }) - gotPkt := tl.read() + gotPkt := te.read() if gotPkt == nil { t.Fatalf("got no response; want Version Negotiaion") } @@ -59,7 +59,7 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { func TestVersionNegotiationClientAborts(t *testing.T) { tc := newTestConn(t, clientSide) p := tc.readPacket() // client Initial packet - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10), }) tc.wantIdle("connection does not send a CONNECTION_CLOSE") @@ -76,7 +76,7 @@ func TestVersionNegotiationClientIgnoresAfterProcessingPacket(t *testing.T) { debugFrameCrypto{ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], }) - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10), }) if err := tc.conn.waitReady(canceledContext()); err != context.Canceled { @@ -94,7 +94,7 @@ func TestVersionNegotiationClientIgnoresMismatchingSourceConnID(t *testing.T) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) p := tc.readPacket() // client Initial packet - tc.listener.write(&datagram{ + tc.endpoint.write(&datagram{ b: appendVersionNegotiation(nil, p.srcConnID, []byte("mismatch"), 10), }) tc.writeFrames(packetTypeInitial,