Skip to content

Commit bc2aadf

Browse files
committed
chore: add DERPForcedWebsocket to nodeUpdater
1 parent 8701dbc commit bc2aadf

File tree

2 files changed

+97
-8
lines changed

2 files changed

+97
-8
lines changed

tailnet/node.go

+20-6
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,13 @@ func newNodeUpdater(
8686
id tailcfg.NodeID, np key.NodePublic, dp key.DiscoPublic,
8787
) *nodeUpdater {
8888
u := &nodeUpdater{
89-
phased: phased{Cond: *(sync.NewCond(&sync.Mutex{}))},
90-
logger: logger,
91-
id: id,
92-
key: np,
93-
discoKey: dp,
94-
callback: callback,
89+
phased: phased{Cond: *(sync.NewCond(&sync.Mutex{}))},
90+
logger: logger,
91+
id: id,
92+
key: np,
93+
discoKey: dp,
94+
derpForcedWebsockets: make(map[int]string),
95+
callback: callback,
9596
}
9697
go u.updateLoop()
9798
return u
@@ -132,3 +133,16 @@ func (u *nodeUpdater) setNetInfo(ni *tailcfg.NetInfo) {
132133
u.Broadcast()
133134
}
134135
}
136+
137+
// setDERPForcedWebsocket handles callbacks from the magicConn about DERP regions that are forced to
138+
// use websockets (instead of Upgrade: derp). This information is for debugging only.
139+
func (u *nodeUpdater) setDERPForcedWebsocket(region int, reason string) {
140+
u.L.Lock()
141+
defer u.L.Unlock()
142+
dirty := u.derpForcedWebsockets[region] != reason
143+
u.derpForcedWebsockets[region] = reason
144+
if dirty {
145+
u.dirty = true
146+
u.Broadcast()
147+
}
148+
}

tailnet/node_internal_test.go

+77-2
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,10 @@ func TestNodeUpdater_setNetInfo_same(t *testing.T) {
7474
nodeKey := key.NewNode().Public()
7575
discoKey := key.NewDisco().Public()
7676
nodeCh := make(chan *Node)
77-
goCh := make(chan struct{})
7877
uut := newNodeUpdater(
7978
logger,
8079
func(n *Node) {
8180
nodeCh <- n
82-
<-goCh
8381
},
8482
id, nodeKey, discoKey,
8583
)
@@ -108,3 +106,80 @@ func TestNodeUpdater_setNetInfo_same(t *testing.T) {
108106
}()
109107
_ = testutil.RequireRecvCtx(ctx, t, done)
110108
}
109+
110+
func TestNodeUpdater_setDERPForcedWebsocket_different(t *testing.T) {
111+
t.Parallel()
112+
ctx := testutil.Context(t, testutil.WaitShort)
113+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
114+
id := tailcfg.NodeID(1)
115+
nodeKey := key.NewNode().Public()
116+
discoKey := key.NewDisco().Public()
117+
nodeCh := make(chan *Node)
118+
uut := newNodeUpdater(
119+
logger,
120+
func(n *Node) {
121+
nodeCh <- n
122+
},
123+
id, nodeKey, discoKey,
124+
)
125+
defer uut.close()
126+
127+
// Given: preferred DERP is 1, so we'll send an update
128+
uut.L.Lock()
129+
uut.preferredDERP = 1
130+
uut.L.Unlock()
131+
132+
// When: we set a new forced websocket reason
133+
uut.setDERPForcedWebsocket(1, "test")
134+
135+
// Then: we receive an update with the reason set
136+
node := testutil.RequireRecvCtx(ctx, t, nodeCh)
137+
require.Equal(t, nodeKey, node.Key)
138+
require.Equal(t, discoKey, node.DiscoKey)
139+
require.True(t, maps.Equal(map[int]string{1: "test"}, node.DERPForcedWebsocket))
140+
141+
done := make(chan struct{})
142+
go func() {
143+
defer close(done)
144+
uut.close()
145+
}()
146+
_ = testutil.RequireRecvCtx(ctx, t, done)
147+
}
148+
149+
func TestNodeUpdater_setDERPForcedWebsocket_same(t *testing.T) {
150+
t.Parallel()
151+
ctx := testutil.Context(t, testutil.WaitShort)
152+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
153+
id := tailcfg.NodeID(1)
154+
nodeKey := key.NewNode().Public()
155+
discoKey := key.NewDisco().Public()
156+
nodeCh := make(chan *Node)
157+
uut := newNodeUpdater(
158+
logger,
159+
func(n *Node) {
160+
nodeCh <- n
161+
},
162+
id, nodeKey, discoKey,
163+
)
164+
defer uut.close()
165+
166+
// Then: we don't configure
167+
requireNeverConfigures(ctx, t, &uut.phased)
168+
169+
// Given: preferred DERP is 1, so we would send an update on change &&
170+
// reason for region 1 is set to "test"
171+
uut.L.Lock()
172+
uut.preferredDERP = 1
173+
uut.derpForcedWebsockets[1] = "test"
174+
uut.L.Unlock()
175+
176+
// When: we set region 1 to "test
177+
uut.setDERPForcedWebsocket(1, "test")
178+
179+
done := make(chan struct{})
180+
go func() {
181+
defer close(done)
182+
uut.close()
183+
}()
184+
_ = testutil.RequireRecvCtx(ctx, t, done)
185+
}

0 commit comments

Comments
 (0)