Commit 6b5ea62

actually fixed with extra tests

1 parent 90f619d

2 files changed: +249 -45 lines changed

vpn/tunnel.go

Lines changed: 51 additions & 37 deletions
@@ -619,47 +619,61 @@ func (u *updater) recordLatencies() {
 	for _, agent := range u.agents {
 		agentsIDsToPing = append(agentsIDsToPing, agent.ID)
 	}
+	conn := u.conn
 	u.mu.Unlock()

-	for _, agentID := range agentsIDsToPing {
-		go func() {
-			pingCtx, cancelFunc := context.WithTimeout(u.ctx, 5*time.Second)
-			defer cancelFunc()
-			pingDur, didP2p, pingResult, err := u.conn.Ping(pingCtx, agentID)
-			if err != nil {
-				u.logger.Warn(u.ctx, "failed to ping agent", slog.F("agent_id", agentID), slog.Error(err))
-				return
-			}
+	if conn == nil {
+		u.logger.Debug(u.ctx, "skipping pings as tunnel is not connected")
+		return
+	}

-			u.mu.Lock()
-			defer u.mu.Unlock()
-			if u.conn == nil {
-				u.logger.Debug(u.ctx, "ignoring ping result as connection is closed", slog.F("agent_id", agentID))
-				return
-			}
-			node := u.conn.Node()
-			derpMap := u.conn.DERPMap()
-			derpLatencies := tailnet.ExtractDERPLatency(node, derpMap)
-			preferredDerp := tailnet.ExtractPreferredDERPName(pingResult, node, derpMap)
-			var preferredDerpLatency *time.Duration
-			if derpLatency, ok := derpLatencies[preferredDerp]; ok {
-				preferredDerpLatency = &derpLatency
-			} else {
-				u.logger.Debug(u.ctx, "preferred DERP not found in DERP latency map", slog.F("preferred_derp", preferredDerp))
-			}
-			if agent, ok := u.agents[agentID]; ok {
-				agent.lastPing = &lastPing{
-					pingDur:              pingDur,
-					didP2p:               didP2p,
-					preferredDerp:        preferredDerp,
-					preferredDerpLatency: preferredDerpLatency,
+	go func() {
+		// We need a waitgroup to cancel the context after all pings are done.
+		var wg sync.WaitGroup
+		pingCtx, cancelFunc := context.WithTimeout(u.ctx, 5*time.Second)
+		defer cancelFunc()
+		for _, agentID := range agentsIDsToPing {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+
+				pingDur, didP2p, pingResult, err := conn.Ping(pingCtx, agentID)
+				if err != nil {
+					u.logger.Warn(u.ctx, "failed to ping agent", slog.F("agent_id", agentID), slog.Error(err))
+					return
 				}
-				u.agents[agentID] = agent
-			} else {
-				u.logger.Debug(u.ctx, "ignoring ping result for unknown agent", slog.F("agent_id", agentID))
-			}
-		}()
-	}
+
+				// We fetch the Node and DERPMap after each ping, as it may have
+				// changed.
+				node := conn.Node()
+				derpMap := conn.DERPMap()
+				derpLatencies := tailnet.ExtractDERPLatency(node, derpMap)
+				preferredDerp := tailnet.ExtractPreferredDERPName(pingResult, node, derpMap)
+				var preferredDerpLatency *time.Duration
+				if derpLatency, ok := derpLatencies[preferredDerp]; ok {
+					preferredDerpLatency = &derpLatency
+				} else {
+					u.logger.Debug(u.ctx, "preferred DERP not found in DERP latency map", slog.F("preferred_derp", preferredDerp))
+				}
+
+				// Write back results
+				u.mu.Lock()
+				defer u.mu.Unlock()
+				if agent, ok := u.agents[agentID]; ok {
+					agent.lastPing = &lastPing{
+						pingDur:              pingDur,
+						didP2p:               didP2p,
+						preferredDerp:        preferredDerp,
+						preferredDerpLatency: preferredDerpLatency,
+					}
+					u.agents[agentID] = agent
+				} else {
+					u.logger.Debug(u.ctx, "ignoring ping result for unknown agent", slog.F("agent_id", agentID))
+				}
+			}()
+		}
+		wg.Wait()
+	}()
 }

 // processSnapshotUpdate handles the logic when a full state update is received.
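
The shape of the fix, as called out in the new comments: snapshot u.conn while still holding the lock (the removed code read u.conn from the ping goroutines after unlocking), bail out early when the tunnel is not connected, fan the pings out in per-agent goroutines under one shared 5-second context whose deferred cancel fires only after a sync.WaitGroup drains, and retake the lock just to write results back. Below is a minimal, self-contained sketch of that fan-out pattern; pingAll, pingOnce, and the integer agent IDs are stand-ins for illustration, not the tunnel's real API.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// pingOnce stands in for conn.Ping: it respects the shared context so a
// timeout or shutdown cancels pings that are still in flight.
func pingOnce(ctx context.Context, id int) (time.Duration, error) {
	select {
	case <-time.After(100 * time.Millisecond): // simulated round-trip
		return 100 * time.Millisecond, nil
	case <-ctx.Done():
		return 0, ctx.Err()
	}
}

// pingAll mirrors the committed shape: one outer goroutine owns the shared
// timeout, inner goroutines each ping one target, and the WaitGroup keeps
// the deferred cancel from firing while any ping is still in flight.
func pingAll(ctx context.Context, ids []int) {
	go func() {
		var wg sync.WaitGroup
		pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
		defer cancel() // runs only once wg.Wait() returns
		for _, id := range ids {
			wg.Add(1)
			go func(id int) {
				defer wg.Done()
				if dur, err := pingOnce(pingCtx, id); err != nil {
					fmt.Printf("agent %d: ping failed: %v\n", id, err)
				} else {
					fmt.Printf("agent %d: %v\n", id, dur)
				}
			}(id)
		}
		wg.Wait()
	}()
}

func main() {
	pingAll(context.Background(), []int{1, 2, 3})
	time.Sleep(time.Second) // crude wait so the sketch's goroutines finish
}

Cancelling only after wg.Wait() is the point: the removed version gave each goroutine its own context and deferred cancel, but once a single context is shared across all pings, a cancel that ran before the slowest ping finished would cut it off early.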

vpn/tunnel_internal_test.go

Lines changed: 198 additions & 8 deletions
@@ -60,11 +60,17 @@ func newFakeConn(state tailnet.WorkspaceUpdate, hsTime time.Time) *fakeConn {
 	}
 }

+func (f *fakeConn) withManualPings() *fakeConn {
+	f.returnPing = make(chan struct{})
+	return f
+}
+
 type fakeConn struct {
-	state   tailnet.WorkspaceUpdate
-	hsTime  time.Time
-	closed  chan struct{}
-	doClose sync.Once
+	state      tailnet.WorkspaceUpdate
+	returnPing chan struct{}
+	hsTime     time.Time
+	closed     chan struct{}
+	doClose    sync.Once
 }

 func (*fakeConn) DERPMap() *tailcfg.DERPMap {
@@ -90,10 +96,22 @@ func (*fakeConn) Node() *tailnet.Node {

 var _ Conn = (*fakeConn)(nil)

-func (*fakeConn) Ping(ctx context.Context, agentID uuid.UUID) (time.Duration, bool, *ipnstate.PingResult, error) {
-	return time.Millisecond * 100, true, &ipnstate.PingResult{
-		DERPRegionID: 999,
-	}, nil
+func (f *fakeConn) Ping(ctx context.Context, agentID uuid.UUID) (time.Duration, bool, *ipnstate.PingResult, error) {
+	if f.returnPing == nil {
+		return time.Millisecond * 100, true, &ipnstate.PingResult{
+			DERPRegionID: 999,
+		}, nil
+	}
+
+	select {
+	case <-ctx.Done():
+		return 0, false, nil, ctx.Err()
+	case <-f.returnPing:
+		return time.Millisecond * 100, true, &ipnstate.PingResult{
+			DERPRegionID: 999,
+		}, nil
+	}
+
 }

 func (f *fakeConn) CurrentWorkspaceState() (tailnet.WorkspaceUpdate, error) {
@@ -759,6 +777,178 @@ func TestTunnel_sendAgentUpdateWorkspaceReconnect(t *testing.T) {
 	require.Equal(t, wID1[:], peerUpdate.DeletedWorkspaces[0].Id)
 }

+func TestTunnel_slowPing(t *testing.T) {
+	t.Parallel()
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+
+	mClock := quartz.NewMock(t)
+
+	wID1 := uuid.UUID{1}
+	aID1 := uuid.UUID{2}
+	hsTime := time.Now().Add(-time.Minute).UTC()
+
+	client := newFakeClient(ctx, t)
+	conn := newFakeConn(tailnet.WorkspaceUpdate{}, hsTime).withManualPings()
+
+	tun, mgr := setupTunnel(t, ctx, client, mClock)
+	errCh := make(chan error, 1)
+	var resp *TunnelMessage
+	go func() {
+		r, err := mgr.unaryRPC(ctx, &ManagerMessage{
+			Msg: &ManagerMessage_Start{
+				Start: &StartRequest{
+					TunnelFileDescriptor: 2,
+					CoderUrl:             "https://coder.example.com",
+					ApiToken:             "fakeToken",
+				},
+			},
+		})
+		resp = r
+		errCh <- err
+	}()
+	testutil.RequireSend(ctx, t, client.ch, conn)
+	err := testutil.TryReceive(ctx, t, errCh)
+	require.NoError(t, err)
+	_, ok := resp.Msg.(*TunnelMessage_Start)
+	require.True(t, ok)
+
+	// Inform the tunnel of the initial state
+	err = tun.Update(tailnet.WorkspaceUpdate{
+		UpsertedWorkspaces: []*tailnet.Workspace{
+			{
+				ID: wID1, Name: "w1", Status: proto.Workspace_STARTING,
+			},
+		},
+		UpsertedAgents: []*tailnet.Agent{
+			{
+				ID:          aID1,
+				Name:        "agent1",
+				WorkspaceID: wID1,
+				Hosts: map[dnsname.FQDN][]netip.Addr{
+					"agent1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
+				},
+			},
+		},
+	})
+	require.NoError(t, err)
+	req := testutil.TryReceive(ctx, t, mgr.requests)
+	require.Nil(t, req.msg.Rpc)
+	require.NotNil(t, req.msg.GetPeerUpdate())
+	require.Len(t, req.msg.GetPeerUpdate().UpsertedAgents, 1)
+	require.Equal(t, aID1[:], req.msg.GetPeerUpdate().UpsertedAgents[0].Id)
+
+	// We can't check that it *never* pings, so the best we can do is
+	// check it doesn't ping even with 5 goroutines attempting to,
+	// and that updates are received as normal
+	for range 5 {
+		mClock.AdvanceNext()
+		require.Nil(t, req.msg.GetPeerUpdate().UpsertedAgents[0].LastPing)
+	}
+
+	// Provided that it hasn't been 5 seconds since the last AdvanceNext call,
+	// there'll be a ping in-flight that will return with this message
+	testutil.RequireSend(ctx, t, conn.returnPing, struct{}{})
+	// Which will mean we'll eventually receive a PeerUpdate with the ping
+	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
+		mClock.AdvanceNext()
+		req = testutil.TryReceive(ctx, t, mgr.requests)
+		if len(req.msg.GetPeerUpdate().UpsertedAgents) == 0 {
+			return false
+		}
+		if req.msg.GetPeerUpdate().UpsertedAgents[0].LastPing == nil {
+			return false
+		}
+		if req.msg.GetPeerUpdate().UpsertedAgents[0].LastPing.Latency.AsDuration().Milliseconds() != 100 {
+			return false
+		}
+		return req.msg.GetPeerUpdate().UpsertedAgents[0].LastPing.PreferredDerp == "Coder Region"
+	}, testutil.IntervalFast)
+}
+
+func TestTunnel_stopMidPing(t *testing.T) {
+	t.Parallel()
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+
+	mClock := quartz.NewMock(t)
+
+	wID1 := uuid.UUID{1}
+	aID1 := uuid.UUID{2}
+	hsTime := time.Now().Add(-time.Minute).UTC()
+
+	client := newFakeClient(ctx, t)
+	conn := newFakeConn(tailnet.WorkspaceUpdate{}, hsTime).withManualPings()
+
+	tun, mgr := setupTunnel(t, ctx, client, mClock)
+	errCh := make(chan error, 1)
+	var resp *TunnelMessage
+	go func() {
+		r, err := mgr.unaryRPC(ctx, &ManagerMessage{
+			Msg: &ManagerMessage_Start{
+				Start: &StartRequest{
+					TunnelFileDescriptor: 2,
+					CoderUrl:             "https://coder.example.com",
+					ApiToken:             "fakeToken",
+				},
+			},
+		})
+		resp = r
+		errCh <- err
+	}()
+	testutil.RequireSend(ctx, t, client.ch, conn)
+	err := testutil.TryReceive(ctx, t, errCh)
+	require.NoError(t, err)
+	_, ok := resp.Msg.(*TunnelMessage_Start)
+	require.True(t, ok)
+
+	// Inform the tunnel of the initial state
+	err = tun.Update(tailnet.WorkspaceUpdate{
+		UpsertedWorkspaces: []*tailnet.Workspace{
+			{
+				ID: wID1, Name: "w1", Status: proto.Workspace_STARTING,
+			},
+		},
+		UpsertedAgents: []*tailnet.Agent{
+			{
+				ID:          aID1,
+				Name:        "agent1",
+				WorkspaceID: wID1,
+				Hosts: map[dnsname.FQDN][]netip.Addr{
+					"agent1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
+				},
+			},
+		},
+	})
+	require.NoError(t, err)
+	req := testutil.TryReceive(ctx, t, mgr.requests)
+	require.Nil(t, req.msg.Rpc)
+	require.NotNil(t, req.msg.GetPeerUpdate())
+	require.Len(t, req.msg.GetPeerUpdate().UpsertedAgents, 1)
+	require.Equal(t, aID1[:], req.msg.GetPeerUpdate().UpsertedAgents[0].Id)
+
+	// We'll have some pings in flight when we stop
+	for range 5 {
+		mClock.AdvanceNext()
+		req = testutil.TryReceive(ctx, t, mgr.requests)
+		require.Nil(t, req.msg.GetPeerUpdate().UpsertedAgents[0].LastPing)
+	}
+
+	// Stop the tunnel
+	go func() {
+		r, err := mgr.unaryRPC(ctx, &ManagerMessage{
+			Msg: &ManagerMessage_Stop{},
+		})
+		resp = r
+		errCh <- err
+	}()
+	testutil.TryReceive(ctx, t, conn.closed)
+	err = testutil.TryReceive(ctx, t, errCh)
+	require.NoError(t, err)
+	_, ok = resp.Msg.(*TunnelMessage_Stop)
+	require.True(t, ok)
+}
+
 //nolint:revive // t takes precedence
 func setupTunnel(t *testing.T, ctx context.Context, client *fakeClient, mClock *quartz.Mock) (*Tunnel, *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage]) {
 	mp, tp := net.Pipe()