Skip to content

Commit 304cd11

Browse files
authored
fix(wgengine): Guard endpoint dispatcher via wrapper (#7)
1 parent 1e5e724 commit 304cd11

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

wgengine/netstack/netstack.go

+40-11
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ type Impl struct {
9797

9898
ipstack *stack.Stack
9999
epMu sync.RWMutex
100-
linkEP *channel.Endpoint
100+
linkEP *protectedLinkEndpoint
101101
tundev *tstun.Wrapper
102102
e wgengine.Engine
103103
mc *magicsock.Conn
@@ -159,7 +159,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
159159
if tcpipErr != nil {
160160
return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
161161
}
162-
linkEP := channel.New(512, mtu, "")
162+
linkEP := &protectedLinkEndpoint{Endpoint: channel.New(512, mtu, "")}
163163
if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
164164
return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
165165
}
@@ -202,9 +202,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
202202
func (ns *Impl) Close() error {
203203
ns.ctxCancel()
204204
ns.ipstack.Close()
205-
ns.epMu.Lock()
206205
ns.ipstack.Wait()
207-
ns.epMu.Unlock()
208206
return nil
209207
}
210208

@@ -689,13 +687,7 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons
689687
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
690688
Payload: bufferv2.MakeWithData(bytes.Clone(p.Buffer())),
691689
})
692-
693-
if ns.epMu.TryRLock() {
694-
if ns.linkEP.IsAttached() {
695-
ns.linkEP.InjectInbound(pn, packetBuf)
696-
}
697-
ns.epMu.RUnlock()
698-
}
690+
ns.linkEP.InjectInbound(pn, packetBuf)
699691
packetBuf.DecRef()
700692

701693
// We've now delivered this to netstack, so we're done.
@@ -1216,3 +1208,40 @@ func ipPortOfNetstackAddr(a tcpip.Address, port uint16) (ipp netip.AddrPort, ok
12161208
return ipp, false
12171209
}
12181210
}
1211+
1212+
// protectedLinkEndpoint guards use of the dispatcher via mutex and forwards
1213+
// everything except Attach/InjectInbound to the underlying *channel.Endpoint.
1214+
type protectedLinkEndpoint struct {
1215+
mu sync.RWMutex
1216+
dispatcher stack.NetworkDispatcher
1217+
*channel.Endpoint
1218+
}
1219+
1220+
// InjectInbound injects an inbound packet.
1221+
func (e *protectedLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) {
1222+
e.mu.RLock()
1223+
dispatcher := e.dispatcher
1224+
e.mu.RUnlock()
1225+
if dispatcher != nil {
1226+
dispatcher.DeliverNetworkPacket(protocol, pkt)
1227+
}
1228+
}
1229+
1230+
// Attach saves the stack network-layer dispatcher for use later when packets
1231+
// are injected.
1232+
func (e *protectedLinkEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
1233+
e.mu.Lock()
1234+
defer e.mu.Unlock()
1235+
// No need to attach the underlying channel.Endpoint, since we hijack
1236+
// InjectInbound.
1237+
e.dispatcher = dispatcher
1238+
}
1239+
1240+
// IsAttached implements stack.LinkEndpoint.IsAttached.
1241+
func (e *protectedLinkEndpoint) IsAttached() bool {
1242+
e.mu.RLock()
1243+
defer e.mu.RUnlock()
1244+
return e.dispatcher != nil
1245+
}
1246+
1247+
var _ stack.LinkEndpoint = (*protectedLinkEndpoint)(nil)

0 commit comments

Comments
 (0)