Skip to content

fix: block writes from gVisor to tailscale instead of dropping #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
# earlier contributions and clarifying whether it's you or your
# company that owns the rights to your contribution.

Coder Technologies, Inc.
Tailscale Inc.
247 changes: 247 additions & 0 deletions wgengine/netstack/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
// based on https://github.com/google/gvisor/blob/74f22885dc45e2866985fe7179103e1000382415/pkg/tcpip/link/channel/channel.go
//
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Modifications from original source are Copyright 2024 Tailscale Inc & AUTHORS

package netstack

import (
"context"

"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// queue is the outbound packet queue used by Endpoint. Writers hand packets
// to readers through the buffered channel c; closedCh lets Close wake up
// writers that are blocked on a full channel.
type queue struct {
	// c is the outbound packet channel.
	c chan *stack.PacketBuffer
	mu sync.RWMutex
	// closed is set once c has been closed, so Write can refuse to send on a
	// closed channel.
	// +checklocks:mu
	closed bool

	// closedChOnce guards the close of closedCh so it happens at most once.
	closedChOnce sync.Once
	// closedCh is closed by Close to unblock pending calls to Write.
	closedCh chan struct{}
}

// Close shuts the queue down: it wakes any blocked writers and closes the
// packet channel. Calling it more than once is harmless.
func (q *queue) Close() {
	// Signal closure before asking for the write lock. A Write that is
	// blocked on a full channel holds the read lock, and it only lets go
	// once closedCh has been closed.
	q.closedChOnce.Do(func() { close(q.closedCh) })

	q.mu.Lock()
	defer q.mu.Unlock()
	if !q.closed {
		close(q.c)
		q.closed = true
	}
}

// Read returns the next queued packet without blocking. It returns nil when
// no packet is immediately available.
func (q *queue) Read() *stack.PacketBuffer {
	var pkt *stack.PacketBuffer
	select {
	case pkt = <-q.c:
	default:
		// Nothing queued; pkt stays nil.
	}
	return pkt
}

// ReadContext waits for the next queued packet and returns it. If ctx is
// done first, it returns nil.
func (q *queue) ReadContext(ctx context.Context) *stack.PacketBuffer {
	var pkt *stack.PacketBuffer
	select {
	case <-ctx.Done():
		// Cancelled; pkt stays nil.
	case pkt = <-q.c:
	}
	return pkt
}

// Write enqueues pkt, taking an additional reference on it for the queue.
// When the channel is full it blocks until space becomes available or the
// queue is closed. It returns ErrClosedForSend if the queue is closed.
func (q *queue) Write(pkt *stack.PacketBuffer) tcpip.Error {
	q.mu.RLock()
	defer q.mu.RUnlock()
	if q.closed {
		return &tcpip.ErrClosedForSend{}
	}

	// Take the queue's reference up front; it is handed to the reader on a
	// successful send and released again if the queue closes first.
	queued := pkt.IncRef()
	select {
	case <-q.closedCh:
		pkt.DecRef()
		return &tcpip.ErrClosedForSend{}
	case q.c <- queued:
		return nil
	}
}

// Num reports how many packets are currently buffered in the channel.
func (q *queue) Num() int {
	buffered := len(q.c)
	return buffered
}

// Compile-time checks that *Endpoint implements the gVisor link interfaces.
var (
	_ stack.LinkEndpoint = (*Endpoint)(nil)
	_ stack.GSOEndpoint  = (*Endpoint)(nil)
)

// Endpoint is a link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets. It is based on gVisor's
// channel.Endpoint; however, when the channel is full, it blocks writes until
// there is space in the channel or until the Endpoint is closed. The gVisor
// version drops packets if the channel is full. This limits TCP throughput,
// as dropped packets need to be retransmitted and are interpreted as a
// congestion event, causing the TCP sender to decrease the congestion window.
// It is much better to apply back-pressure to the TCP stack at the Endpoint.
type Endpoint struct {
	// mtu is the value reported by MTU; it is set by NewEndpoint.
	mtu uint32
	// linkAddr is the value reported by LinkAddress; it is set by NewEndpoint.
	linkAddr tcpip.LinkAddress
	// LinkEPCapabilities is the value reported by Capabilities.
	LinkEPCapabilities stack.LinkEndpointCapabilities
	// SupportedGSOKind is the value reported by SupportedGSO.
	SupportedGSOKind stack.SupportedGSO

	mu sync.RWMutex
	// dispatcher receives packets passed to InjectInbound; it is nil until
	// Attach is called with a non-nil dispatcher.
	// +checklocks:mu
	dispatcher stack.NetworkDispatcher

	// Outbound packet queue.
	q *queue
}

// NewEndpoint creates a new channel endpoint. The outbound queue buffers up
// to size packets; mtu and linkAddr are the values later reported by MTU and
// LinkAddress.
func NewEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
	outbound := &queue{
		c:        make(chan *stack.PacketBuffer, size),
		closedCh: make(chan struct{}),
	}
	return &Endpoint{
		q:        outbound,
		mtu:      mtu,
		linkAddr: linkAddr,
	}
}

// Close closes e. Subsequent writes fail with ErrClosedForSend, and every
// packet still waiting in the outbound queue is released. Close may be called
// concurrently with WritePackets.
func (e *Endpoint) Close() {
	e.q.Close()
	// The count of discarded packets is of no interest here.
	_ = e.Drain()
}

// Read performs a non-blocking read of one packet from the outbound packet
// queue. It returns nil if no packet is available.
func (e *Endpoint) Read() *stack.PacketBuffer {
	pkt := e.q.Read()
	return pkt
}

// ReadContext performs a blocking read of one packet from the outbound packet
// queue. The wait can be cancelled through ctx, in which case nil is returned.
func (e *Endpoint) ReadContext(ctx context.Context) *stack.PacketBuffer {
	pkt := e.q.ReadContext(ctx)
	return pkt
}

// Drain empties the outbound queue, releasing the reference held on each
// packet, and returns how many packets were removed.
func (e *Endpoint) Drain() int {
	drained := 0
	for {
		pkt := e.Read()
		if pkt == nil {
			return drained
		}
		pkt.DecRef()
		drained++
	}
}

// NumQueued returns the number of packets currently queued for outbound.
func (e *Endpoint) NumQueued() int {
	queued := e.q.Num()
	return queued
}

// InjectInbound hands an inbound packet to the attached network dispatcher.
// If no dispatcher is attached, the packet is not delivered.
func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
	// Snapshot the dispatcher so the lock is not held during delivery.
	e.mu.RLock()
	dispatcher := e.dispatcher
	e.mu.RUnlock()

	if dispatcher == nil {
		return
	}
	dispatcher.DeliverNetworkPacket(protocol, pkt)
}

// Attach records the stack's network-layer dispatcher so that packets passed
// to InjectInbound can be delivered to it later.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
	e.mu.Lock()
	e.dispatcher = dispatcher
	e.mu.Unlock()
}

// IsAttached implements stack.LinkEndpoint.IsAttached. It reports whether a
// non-nil dispatcher is currently attached.
func (e *Endpoint) IsAttached() bool {
	e.mu.RLock()
	attached := e.dispatcher != nil
	e.mu.RUnlock()
	return attached
}

// MTU implements stack.LinkEndpoint.MTU. The value is the one supplied to
// NewEndpoint.
func (e *Endpoint) MTU() uint32 {
	mtu := e.mtu
	return mtu
}

// Capabilities implements stack.LinkEndpoint.Capabilities. It reports the
// LinkEPCapabilities field.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
	caps := e.LinkEPCapabilities
	return caps
}

// GSOMaxSize implements stack.GSOEndpoint. It returns a fixed limit of
// 1<<15 (32768).
func (*Endpoint) GSOMaxSize() uint32 {
	const gsoMaxSize = 1 << 15
	return gsoMaxSize
}

// SupportedGSO implements stack.GSOEndpoint. It reports the SupportedGSOKind
// field.
func (e *Endpoint) SupportedGSO() stack.SupportedGSO {
	kind := e.SupportedGSOKind
	return kind
}

// MaxHeaderLength returns the maximum size of the link layer header. This
// endpoint has no link header, so the result is always 0.
func (*Endpoint) MaxHeaderLength() uint16 {
	const noLinkHeader = 0
	return noLinkHeader
}

// LinkAddress returns the link address of this endpoint, as supplied to
// NewEndpoint.
func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
	addr := e.linkAddr
	return addr
}

// WritePackets stores outbound packets into the channel, blocking while the
// channel is full. It returns the number of packets queued; on error that
// count covers only the packets queued before the failing one.
// Multiple concurrent calls are permitted.
func (e *Endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
	batch := pkts.AsSlice()
	for i, pkt := range batch {
		// i packets were queued successfully before this one.
		if err := e.q.Write(pkt); err != nil {
			return i, err
		}
	}
	return len(batch), nil
}

// Wait implements stack.LinkEndpoint.Wait. It is a no-op for this endpoint.
func (*Endpoint) Wait() {
}

// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. It always
// reports header.ARPHardwareNone.
func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
	const hwType = header.ARPHardwareNone
	return hwType
}

// AddHeader implements stack.LinkEndpoint.AddHeader. It leaves the packet
// untouched.
func (*Endpoint) AddHeader(*stack.PacketBuffer) {
}

// ParseHeader implements stack.LinkEndpoint.ParseHeader. It does not inspect
// the packet and always reports success.
func (*Endpoint) ParseHeader(*stack.PacketBuffer) bool {
	return true
}
142 changes: 142 additions & 0 deletions wgengine/netstack/endpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package netstack

import (
"context"
"testing"
"time"

"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// TestEndpointBlockingWrites checks that WritePackets blocks while the
// outbound queue is full and completes once a reader frees a slot.
func TestEndpointBlockingWrites(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	// A queue size of 1 means the second write cannot complete until the
	// first packet has been read.
	linkEP := NewEndpoint(1, 1500, "")
	pb1 := stack.NewPacketBuffer(stack.PacketBufferOptions{})
	defer pb1.DecRef()
	pb2 := stack.NewPacketBuffer(stack.PacketBufferOptions{})
	defer pb2.DecRef()
	// Buffered so the writer goroutine never blocks on reporting results.
	numWrites := make(chan int, 2)
	go func() {
		bl := stack.PacketBufferList{}
		bl.PushBack(pb1)
		n, err := linkEP.WritePackets(bl)
		if err != nil {
			t.Errorf("expected no error, got %s", err)
		} else {
			// The queue took its own reference in Write, so release one here.
			pb1.DecRef()
		}
		numWrites <- n
		bl = stack.PacketBufferList{}
		bl.PushBack(pb2)
		n, err = linkEP.WritePackets(bl)
		if err != nil {
			t.Errorf("expected no error, got %s", err)
		} else {
			// Same bookkeeping as for pb1 above.
			pb2.DecRef()
		}
		numWrites <- n
	}()

	select {
	case n := <-numWrites:
		if n != 1 {
			t.Fatalf("expected 1 write got %d", n)
		}
	case <-ctx.Done():
		t.Fatal("timed out waiting for 1st write")
	}

	// second write should block
	select {
	case <-numWrites:
		t.Fatalf("expected write to block")
	case <-time.After(50 * time.Millisecond):
		// OK
	}

	// Reading pb1 frees the only slot in the queue.
	pbg := linkEP.ReadContext(ctx)
	if pbg != pb1 {
		t.Fatalf("expected pb1")
	}
	// Read unblocks the 2nd write
	select {
	case n := <-numWrites:
		if n != 1 {
			t.Fatalf("expected 1 write got %d", n)
		}
	case <-ctx.Done():
		t.Fatal("timed out waiting for 2nd write")
	}
	pbg = linkEP.ReadContext(ctx)
	if pbg != pb2 {
		t.Fatalf("expected pb2")
	}
}

// TestEndpointCloseUnblocksWrites checks that closing the endpoint wakes up a
// WritePackets call blocked on a full queue, and that the call then reports
// zero packets written together with ErrClosedForSend.
func TestEndpointCloseUnblocksWrites(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	// A queue size of 1 means the second write blocks until Close is called.
	linkEP := NewEndpoint(1, 1500, "")
	// NOTE(review): pb1 has no deferred DecRef, presumably because it is
	// still queued when Close runs and Close drains (and DecRefs) the queue.
	pb1 := stack.NewPacketBuffer(stack.PacketBufferOptions{})
	pb2 := stack.NewPacketBuffer(stack.PacketBufferOptions{})
	defer pb2.DecRef()
	// Both channels are buffered so the writer goroutine never blocks on
	// reporting results.
	numWrites := make(chan int, 2)
	errors := make(chan tcpip.Error, 1)
	go func() {
		bl := stack.PacketBufferList{}
		bl.PushBack(pb1)
		n, err := linkEP.WritePackets(bl)
		if err != nil {
			t.Errorf("expected no error, got %s", err)
		} else {
			// The queue took its own reference in Write, so release one here.
			pb1.DecRef()
		}
		numWrites <- n
		bl = stack.PacketBufferList{}
		bl.PushBack(pb2)
		// This write is expected to block until the endpoint is closed.
		n, err = linkEP.WritePackets(bl)
		numWrites <- n
		errors <- err
	}()

	select {
	case n := <-numWrites:
		if n != 1 {
			t.Fatalf("expected 1 write got %d", n)
		}
	case <-ctx.Done():
		t.Fatal("timed out waiting for 1st write")
	}

	// second write should block
	select {
	case <-numWrites:
		t.Fatalf("expected write to block")
	case <-time.After(50 * time.Millisecond):
		// OK
	}

	// close must unblock pending writes without deadlocking
	linkEP.Close()
	select {
	case n := <-numWrites:
		if n != 0 {
			t.Fatalf("expected 0 writes got %d", n)
		}
	case <-ctx.Done():
		t.Fatal("timed out waiting for 2nd write num")
	}
	select {
	case err := <-errors:
		if _, ok := err.(*tcpip.ErrClosedForSend); !ok {
			t.Fatalf("expected ErrClosedForSend got %s", err)
		}
	case <-ctx.Done():
		t.Fatal("timed out for 2nd write error")
	}
}
Loading
Loading