Skip to content

Commit 6866aae

Browse files
committed
wgengine/magicsock: factor out receiveIPv4 & receiveIPv6 common code
Updates tailscale#2331 Change-Id: I801df38b217f5d17203e8dc3b8654f44747e0f4b Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
1 parent c889254 commit 6866aae

File tree

2 files changed

+55
-73
lines changed

2 files changed

+55
-73
lines changed

wgengine/magicsock/magicsock.go

Lines changed: 49 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -322,11 +322,6 @@ type Conn struct {
322322
// bind is the wireguard-go conn.Bind for Conn.
323323
bind *connBind
324324

325-
// ippEndpoint4 and ippEndpoint6 are owned by receiveIPv4 and
326-
// receiveIPv6, respectively, to cache an IPPort->endpoint for
327-
// hot flows.
328-
ippEndpoint4, ippEndpoint6 ippEndpointCache
329-
330325
// ============================================================
331326
// Fields that must be accessed via atomic load/stores.
332327

@@ -1851,80 +1846,64 @@ func (c *Conn) putReceiveBatch(batch *receiveBatch) {
18511846
c.receiveBatchPool.Put(batch)
18521847
}
18531848

1854-
func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
1855-
health.ReceiveIPv6.Enter()
1856-
defer health.ReceiveIPv6.Exit()
1849+
// receiveIPv4 creates an IPv4 ReceiveFunc reading from c.pconn4.
1850+
func (c *Conn) receiveIPv4() conn.ReceiveFunc {
1851+
return c.mkReceiveFunc(&c.pconn4, &health.ReceiveIPv4, metricRecvDataIPv4)
1852+
}
18571853

1858-
batch := c.getReceiveBatchForBuffs(buffs)
1859-
defer c.putReceiveBatch(batch)
1860-
for {
1861-
numMsgs, err := c.pconn6.ReadBatch(batch.msgs[:len(buffs)], 0)
1862-
if err != nil {
1863-
if neterror.PacketWasTruncated(err) {
1864-
// TODO(raggi): discuss whether to log?
1865-
continue
1866-
}
1867-
return 0, err
1868-
}
1854+
// receiveIPv6 creates an IPv6 ReceiveFunc reading from c.pconn6.
1855+
func (c *Conn) receiveIPv6() conn.ReceiveFunc {
1856+
return c.mkReceiveFunc(&c.pconn6, &health.ReceiveIPv6, metricRecvDataIPv6)
1857+
}
18691858

1870-
reportToCaller := false
1871-
for i, msg := range batch.msgs[:numMsgs] {
1872-
if msg.N == 0 {
1873-
sizes[i] = 0
1874-
continue
1875-
}
1876-
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
1877-
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok {
1878-
metricRecvDataIPv6.Add(1)
1879-
eps[i] = ep
1880-
sizes[i] = msg.N
1881-
reportToCaller = true
1882-
} else {
1883-
sizes[i] = 0
1884-
}
1885-
}
1859+
// mkReceiveFunc creates a ReceiveFunc reading from ruc.
1860+
// The provided healthItem and metric are updated if non-nil.
1861+
func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, metric *clientmetric.Metric) conn.ReceiveFunc {
1862+
// epCache caches an IPPort->endpoint for hot flows.
1863+
var epCache ippEndpointCache
18861864

1887-
if reportToCaller {
1888-
return numMsgs, nil
1865+
return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
1866+
if healthItem != nil {
1867+
healthItem.Enter()
1868+
defer healthItem.Exit()
1869+
}
1870+
if ruc == nil {
1871+
panic("nil RebindingUDPConn")
18891872
}
1890-
}
1891-
}
1892-
1893-
func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
1894-
health.ReceiveIPv4.Enter()
1895-
defer health.ReceiveIPv4.Exit()
18961873

1897-
batch := c.getReceiveBatchForBuffs(buffs)
1898-
defer c.putReceiveBatch(batch)
1899-
for {
1900-
numMsgs, err := c.pconn4.ReadBatch(batch.msgs[:len(buffs)], 0)
1901-
if err != nil {
1902-
if neterror.PacketWasTruncated(err) {
1903-
// TODO(raggi): discuss whether to log?
1904-
continue
1874+
batch := c.getReceiveBatchForBuffs(buffs)
1875+
defer c.putReceiveBatch(batch)
1876+
for {
1877+
numMsgs, err := ruc.ReadBatch(batch.msgs[:len(buffs)], 0)
1878+
if err != nil {
1879+
if neterror.PacketWasTruncated(err) {
1880+
continue
1881+
}
1882+
return 0, err
19051883
}
1906-
return 0, err
1907-
}
19081884

1909-
reportToCaller := false
1910-
for i, msg := range batch.msgs[:numMsgs] {
1911-
if msg.N == 0 {
1912-
sizes[i] = 0
1913-
continue
1885+
reportToCaller := false
1886+
for i, msg := range batch.msgs[:numMsgs] {
1887+
if msg.N == 0 {
1888+
sizes[i] = 0
1889+
continue
1890+
}
1891+
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
1892+
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok {
1893+
if metric != nil {
1894+
metric.Add(1)
1895+
}
1896+
eps[i] = ep
1897+
sizes[i] = msg.N
1898+
reportToCaller = true
1899+
} else {
1900+
sizes[i] = 0
1901+
}
19141902
}
1915-
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
1916-
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok {
1917-
metricRecvDataIPv4.Add(1)
1918-
eps[i] = ep
1919-
sizes[i] = msg.N
1920-
reportToCaller = true
1921-
} else {
1922-
sizes[i] = 0
1903+
if reportToCaller {
1904+
return numMsgs, nil
19231905
}
19241906
}
1925-
if reportToCaller {
1926-
return numMsgs, nil
1927-
}
19281907
}
19291908
}
19301909

@@ -3044,7 +3023,7 @@ func (c *connBind) Open(ignoredPort uint16) ([]conn.ReceiveFunc, uint16, error)
30443023
return nil, 0, errors.New("magicsock: connBind already open")
30453024
}
30463025
c.closed = false
3047-
fns := []conn.ReceiveFunc{c.receiveIPv4, c.receiveIPv6, c.receiveDERP}
3026+
fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP}
30483027
if runtime.GOOS == "js" {
30493028
fns = []conn.ReceiveFunc{c.receiveDERP}
30503029
}

wgengine/magicsock/magicsock_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,9 @@ func TestNewConn(t *testing.T) {
374374
sizes := make([]int, 1)
375375
eps := make([]wgconn.Endpoint, 1)
376376
pkts[0] = make([]byte, 64<<10)
377+
receiveIPv4 := conn.receiveIPv4()
377378
for {
378-
_, err := conn.receiveIPv4(pkts, sizes, eps)
379+
_, err := receiveIPv4(pkts, sizes, eps)
379380
if err != nil {
380381
return
381382
}
@@ -1284,11 +1285,12 @@ func setUpReceiveFrom(tb testing.TB) (roundTrip func()) {
12841285
buffs[0] = make([]byte, 2<<10)
12851286
sizes := make([]int, 1)
12861287
eps := make([]wgconn.Endpoint, 1)
1288+
receiveIPv4 := conn.receiveIPv4()
12871289
return func() {
12881290
if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil {
12891291
tb.Fatalf("WriteTo: %v", err)
12901292
}
1291-
n, err := conn.receiveIPv4(buffs, sizes, eps)
1293+
n, err := receiveIPv4(buffs, sizes, eps)
12921294
if err != nil {
12931295
tb.Fatal(err)
12941296
}
@@ -1513,8 +1515,9 @@ func TestRebindStress(t *testing.T) {
15131515
sizes := make([]int, 1)
15141516
eps := make([]wgconn.Endpoint, 1)
15151517
buffs[0] = make([]byte, 1500)
1518+
receiveIPv4 := conn.receiveIPv4()
15161519
for {
1517-
_, err := conn.receiveIPv4(buffs, sizes, eps)
1520+
_, err := receiveIPv4(buffs, sizes, eps)
15181521
if ctx.Err() != nil {
15191522
errc <- nil
15201523
return

0 commit comments

Comments
 (0)