
feat: improve coder connect tunnel handling on reconnect #17598

Merged
initial implementation
ibetitsmike committed Apr 29, 2025
commit 52f1c2b7952cb91cbbb51d0dab3a6a336f40faa4
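
The change threads a new FreshState flag through tailnet.WorkspaceUpdate: it is set only on the first update received after a stream (re)connect, and the VPN tunnel uses it to treat that update as a complete snapshot, synthesizing deletions for agents and workspaces that disappeared while the connection was down. A minimal sketch of the contract this gives consumers — apply and known are this sketch's own names, and it assumes Agent.Clone returns a value, as the tunnel's own agent-map usage in the diff below suggests:

package sketch

import (
	"github.com/google/uuid"

	"github.com/coder/coder/v2/tailnet"
)

// apply illustrates the FreshState contract: a fresh update is a full
// snapshot, so any previously known agent absent from it is stale.
func apply(known map[uuid.UUID]tailnet.Agent, u tailnet.WorkspaceUpdate) {
	if u.FreshState {
		fresh := make(map[uuid.UUID]struct{}, len(u.UpsertedAgents))
		for _, a := range u.UpsertedAgents {
			fresh[a.ID] = struct{}{}
		}
		for id := range known {
			if _, ok := fresh[id]; !ok {
				delete(known, id) // peer vanished during the disconnect
			}
		}
	}
	for _, a := range u.UpsertedAgents {
		known[a.ID] = a.Clone()
	}
	for _, a := range u.DeletedAgents {
		delete(known, a.ID)
	}
}

In the diff itself this reconciliation happens inside the tunnel (vpn/tunnel.go below), so downstream consumers of PeerUpdate receive explicit DeletedAgents and DeletedWorkspaces rather than having to infer them.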
10 changes: 8 additions & 2 deletions tailnet/controllers.go
@@ -1049,6 +1049,7 @@ func (t *tunnelUpdater) recvLoop() {
t.logger.Debug(context.Background(), "tunnel updater recvLoop started")
defer t.logger.Debug(context.Background(), "tunnel updater recvLoop done")
defer close(t.recvLoopDone)
freshState := true
for {
update, err := t.client.Recv()
if err != nil {
@@ -1061,8 +1062,10 @@ func (t *tunnelUpdater) recvLoop() {
}
t.logger.Debug(context.Background(), "got workspace update",
slog.F("workspace_update", update),
slog.F("fresh_state", freshState),
)
err = t.handleUpdate(update)
err = t.handleUpdate(update, freshState)
freshState = false
if err != nil {
t.logger.Critical(context.Background(), "failed to handle workspace Update", slog.Error(err))
cErr := t.client.Close()
@@ -1083,6 +1086,7 @@ type WorkspaceUpdate struct {
UpsertedAgents []*Agent
DeletedWorkspaces []*Workspace
DeletedAgents []*Agent
FreshState bool
}

func (w *WorkspaceUpdate) Clone() WorkspaceUpdate {
@@ -1091,6 +1095,7 @@ func (w *WorkspaceUpdate) Clone() WorkspaceUpdate {
UpsertedAgents: make([]*Agent, len(w.UpsertedAgents)),
DeletedWorkspaces: make([]*Workspace, len(w.DeletedWorkspaces)),
DeletedAgents: make([]*Agent, len(w.DeletedAgents)),
FreshState: w.FreshState,
}
for i, ws := range w.UpsertedWorkspaces {
clone.UpsertedWorkspaces[i] = &Workspace{
@@ -1115,7 +1120,7 @@ func (w *WorkspaceUpdate) Clone() WorkspaceUpdate {
return clone
}

func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate) error {
func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate, freshState bool) error {
t.Lock()
defer t.Unlock()

@@ -1124,6 +1129,7 @@ func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate) error {
UpsertedAgents: []*Agent{},
DeletedWorkspaces: []*Workspace{},
DeletedAgents: []*Agent{},
FreshState: freshState,
}

for _, uw := range update.UpsertedWorkspaces {
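
The hunk above is truncated in this view; condensed, the recvLoop change is small — the flag is latched true for exactly one received update per stream. A self-contained sketch with stand-in types (the real method logs and tears down the client on error):

package sketch

// Stand-ins for the real types, just enough to make the sketch compile.
type workspaceUpdate struct{}

type updateClient interface {
	Recv() (*workspaceUpdate, error)
}

type tunnelUpdater struct {
	client updateClient
}

// handleUpdate stands in for the real method, which converts the protocol
// update into a tailnet.WorkspaceUpdate and stamps FreshState onto it.
func (t *tunnelUpdater) handleUpdate(_ *workspaceUpdate, _ bool) error {
	return nil
}

// recvLoop mirrors the change above: freshState is true only for the
// first update received on a newly (re)established stream.
func (t *tunnelUpdater) recvLoop() {
	freshState := true
	for {
		update, err := t.client.Recv()
		if err != nil {
			return // the real loop logs and propagates the error
		}
		err = t.handleUpdate(update, freshState)
		freshState = false // every later update is an incremental delta
		if err != nil {
			return // the real loop logs at Critical and closes the client
		}
	}
}

As in the diff, freshState is cleared before the error check, though on error the real loop shuts the stream down anyway.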
10 changes: 8 additions & 2 deletions tailnet/controllers_test.go
@@ -1611,13 +1611,15 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
},
DeletedWorkspaces: []*tailnet.Workspace{},
DeletedAgents: []*tailnet.Agent{},
FreshState: true,
}

// And the callback
cbUpdate := testutil.TryReceive(ctx, t, fUH.ch)
require.Equal(t, currentState, cbUpdate)

// Current recvState should match
// Current recvState should match but shouldn't be a fresh state
currentState.FreshState = false
recvState, err := updateCtrl.CurrentState()
require.NoError(t, err)
slices.SortFunc(recvState.UpsertedWorkspaces, func(a, b *tailnet.Workspace) int {
@@ -1692,12 +1694,14 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
},
DeletedWorkspaces: []*tailnet.Workspace{},
DeletedAgents: []*tailnet.Agent{},
FreshState: true,
}

cbUpdate := testutil.TryReceive(ctx, t, fUH.ch)
require.Equal(t, initRecvUp, cbUpdate)

// Current state should match initial
// Current state should match initial but shouldn't be a fresh state
initRecvUp.FreshState = false
state, err := updateCtrl.CurrentState()
require.NoError(t, err)
require.Equal(t, initRecvUp, state)
@@ -1753,6 +1757,7 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
"w1.coder.": {ws1a1IP},
}},
},
FreshState: false,
}
require.Equal(t, sndRecvUpdate, cbUpdate)

@@ -1771,6 +1776,7 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
},
DeletedWorkspaces: []*tailnet.Workspace{},
DeletedAgents: []*tailnet.Agent{},
FreshState: false,
}, state)
}

57 changes: 45 additions & 12 deletions vpn/tunnel.go
@@ -397,14 +397,57 @@ func (u *updater) sendUpdateResponse(req *request[*TunnelMessage, *ManagerMessag
// createPeerUpdateLocked creates a PeerUpdate message from a workspace update, populating
// the network status of the agents.
func (u *updater) createPeerUpdateLocked(update tailnet.WorkspaceUpdate) *PeerUpdate {
// this flag is true on the first update after a reconnect
if update.FreshState {
// ignoredWorkspaces is initially populated with the workspaces that are
// in the current update. Later on we populate it with the deleted workspaces too
// so that we don't send duplicate updates. Same applies to ignoredAgents.
ignoredWorkspaces := make(map[uuid.UUID]struct{}, len(update.UpsertedWorkspaces))
ignoredAgents := make(map[uuid.UUID]struct{}, len(update.UpsertedAgents))

for _, workspace := range update.UpsertedWorkspaces {
ignoredWorkspaces[workspace.ID] = struct{}{}
}
for _, agent := range update.UpsertedAgents {
ignoredAgents[agent.ID] = struct{}{}
}
for _, agent := range u.agents {
if _, ok := ignoredAgents[agent.ID]; !ok {
// delete any current agents that are not in the new update
update.DeletedAgents = append(update.DeletedAgents, &tailnet.Agent{
ID: agent.ID,
Name: agent.Name,
WorkspaceID: agent.WorkspaceID,
})
// if the workspace connected to an agent we're deleting
// is not present in the fresh state, add it to the deleted workspaces
if _, ok := ignoredWorkspaces[agent.WorkspaceID]; !ok {
update.DeletedWorkspaces = append(update.DeletedWorkspaces, &tailnet.Workspace{
// other fields cannot be populated because the tunnel
// only stores agents and corresponding workspaceIDs
ID: agent.WorkspaceID,
})
ignoredWorkspaces[agent.WorkspaceID] = struct{}{}
}
}
}
}

out := &PeerUpdate{
UpsertedWorkspaces: make([]*Workspace, len(update.UpsertedWorkspaces)),
UpsertedAgents: make([]*Agent, len(update.UpsertedAgents)),
DeletedWorkspaces: make([]*Workspace, len(update.DeletedWorkspaces)),
DeletedAgents: make([]*Agent, len(update.DeletedAgents)),
}

u.saveUpdateLocked(update)
// save the workspace update to the tunnel's state, such that it can
// be used to populate automated peer updates.
for _, agent := range update.UpsertedAgents {
u.agents[agent.ID] = agent.Clone()
}
for _, agent := range update.DeletedAgents {
delete(u.agents, agent.ID)
}

for i, ws := range update.UpsertedWorkspaces {
out.UpsertedWorkspaces[i] = &Workspace{
@@ -413,6 +456,7 @@ func (u *updater) createPeerUpdateLocked(update tailnet.WorkspaceUpdate) *PeerUp
Status: Workspace_Status(ws.Status),
}
}

upsertedAgents := u.convertAgentsLocked(update.UpsertedAgents)
out.UpsertedAgents = upsertedAgents
for i, ws := range update.DeletedWorkspaces {
@@ -472,17 +516,6 @@ func (u *updater) convertAgentsLocked(agents []*tailnet.Agent) []*Agent {
return out
}

// saveUpdateLocked saves the workspace update to the tunnel's state, such that it can
// be used to populate automated peer updates.
func (u *updater) saveUpdateLocked(update tailnet.WorkspaceUpdate) {
for _, agent := range update.UpsertedAgents {
u.agents[agent.ID] = agent.Clone()
}
for _, agent := range update.DeletedAgents {
delete(u.agents, agent.ID)
}
}

// setConn sets the `conn` and returns false if there's already a connection set.
func (u *updater) setConn(conn Conn) bool {
u.mu.Lock()
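
Extracted from the diff above for readability, the fresh-state pruning reduces to a small reconciliation pass. A self-contained sketch using the sketch's own names and stripped-down types (the real code works on tailnet.Agent and tailnet.Workspace and appends to the update in place):

package sketch

import "github.com/google/uuid"

type agent struct {
	ID          uuid.UUID
	Name        string
	WorkspaceID uuid.UUID
}

type workspace struct {
	ID uuid.UUID
}

// reconcileFreshState returns the agents and workspaces that must be
// reported as deleted because they are absent from a fresh snapshot.
func reconcileFreshState(
	current map[uuid.UUID]agent,
	upsertedWorkspaces []workspace,
	upsertedAgents []agent,
) (deletedAgents []agent, deletedWorkspaces []workspace) {
	keepWorkspaces := make(map[uuid.UUID]struct{}, len(upsertedWorkspaces))
	keepAgents := make(map[uuid.UUID]struct{}, len(upsertedAgents))
	for _, w := range upsertedWorkspaces {
		keepWorkspaces[w.ID] = struct{}{}
	}
	for _, a := range upsertedAgents {
		keepAgents[a.ID] = struct{}{}
	}
	for _, a := range current {
		if _, ok := keepAgents[a.ID]; ok {
			continue // agent survived the reconnect
		}
		deletedAgents = append(deletedAgents, a)
		if _, ok := keepWorkspaces[a.WorkspaceID]; !ok {
			// the agent's workspace is gone too; report it once, by ID,
			// since the tunnel only stores agents and workspace IDs
			deletedWorkspaces = append(deletedWorkspaces, workspace{ID: a.WorkspaceID})
			keepWorkspaces[a.WorkspaceID] = struct{}{}
		}
	}
	return deletedAgents, deletedWorkspaces
}

Note also the removal of saveUpdateLocked above: its two loops now run directly in createPeerUpdateLocked after the deletions are synthesized, so the stored agent map always reflects what was last sent to the manager.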
98 changes: 98 additions & 0 deletions vpn/tunnel_internal_test.go
@@ -486,6 +486,104 @@ func TestTunnel_sendAgentUpdate(t *testing.T) {
require.Equal(t, hsTime, req.msg.GetPeerUpdate().UpsertedAgents[0].LastHandshake.AsTime())
}

func TestTunnel_sendAgentUpdateReconnect(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)

mClock := quartz.NewMock(t)

wID1 := uuid.UUID{1}
aID1 := uuid.UUID{2}
aID2 := uuid.UUID{3}
hsTime := time.Now().Add(-time.Minute).UTC()

client := newFakeClient(ctx, t)
conn := newFakeConn(tailnet.WorkspaceUpdate{}, hsTime)

tun, mgr := setupTunnel(t, ctx, client, mClock)
errCh := make(chan error, 1)
var resp *TunnelMessage
go func() {
r, err := mgr.unaryRPC(ctx, &ManagerMessage{
Msg: &ManagerMessage_Start{
Start: &StartRequest{
TunnelFileDescriptor: 2,
CoderUrl: "https://coder.example.com",
ApiToken: "fakeToken",
},
},
})
resp = r
errCh <- err
}()
testutil.RequireSend(ctx, t, client.ch, conn)
err := testutil.TryReceive(ctx, t, errCh)
require.NoError(t, err)
_, ok := resp.Msg.(*TunnelMessage_Start)
require.True(t, ok)

// Inform the tunnel of the initial state
err = tun.Update(tailnet.WorkspaceUpdate{
UpsertedWorkspaces: []*tailnet.Workspace{
{
ID: wID1, Name: "w1", Status: proto.Workspace_STARTING,
},
},
UpsertedAgents: []*tailnet.Agent{
{
ID: aID1,
Name: "agent1",
WorkspaceID: wID1,
Hosts: map[dnsname.FQDN][]netip.Addr{
"agent1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
},
},
},
})
require.NoError(t, err)
req := testutil.TryReceive(ctx, t, mgr.requests)
require.Nil(t, req.msg.Rpc)
require.NotNil(t, req.msg.GetPeerUpdate())
require.Len(t, req.msg.GetPeerUpdate().UpsertedAgents, 1)
require.Equal(t, aID1[:], req.msg.GetPeerUpdate().UpsertedAgents[0].Id)

// Upsert a new agent simulating a reconnect
err = tun.Update(tailnet.WorkspaceUpdate{
UpsertedWorkspaces: []*tailnet.Workspace{
{
ID: wID1, Name: "w1", Status: proto.Workspace_STARTING,
},
},
UpsertedAgents: []*tailnet.Agent{
{
ID: aID2,
Name: "agent2",
WorkspaceID: wID1,
Hosts: map[dnsname.FQDN][]netip.Addr{
"agent2.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
},
},
},
FreshState: true,
})
require.NoError(t, err)

// The new update only contains the new agent
mClock.AdvanceNext()
req = testutil.TryReceive(ctx, t, mgr.requests)
require.Nil(t, req.msg.Rpc)
peerUpdate := req.msg.GetPeerUpdate()
require.NotNil(t, peerUpdate)
require.Len(t, peerUpdate.UpsertedAgents, 1)
require.Len(t, peerUpdate.DeletedAgents, 1)

require.Equal(t, aID2[:], peerUpdate.UpsertedAgents[0].Id)
require.Equal(t, hsTime, peerUpdate.UpsertedAgents[0].LastHandshake.AsTime())

require.Equal(t, aID1[:], peerUpdate.DeletedAgents[0].Id)
}

//nolint:revive // t takes precedence
func setupTunnel(t *testing.T, ctx context.Context, client *fakeClient, mClock quartz.Clock) (*Tunnel, *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage]) {
mp, tp := net.Pipe()