Commit e12b621
fix(coderd): ensure agent WebSocket conn is cleaned up (coder#19711)
When clients disconnected from the /containers/watch endpoint, the WebSocket connection between coderd and the agent stayed open. This caused heartbeat traffic every 15s that was incorrectly counted as workspace activity, extending workspace lifetimes indefinitely. Now properly cancels the agent connection context when the client disconnects.
1 parent 8f72538 commit e12b621
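
The pattern the fix relies on can be seen in isolation below. This is a minimal, self-contained sketch, not coderd's code: `watchUpstream`, the `event` type, and the `/watch` route are hypothetical stand-ins, and it detects the client going away via the context returned by `CloseRead` rather than coderd's heartbeat helper. The idea is the same, though: the upstream watch is bound to a context that is canceled the moment the client disconnects, so no agent-side traffic outlives the client.

```go
package main

import (
	"context"
	"net/http"
	"time"

	"github.com/coder/websocket"
	"github.com/coder/websocket/wsjson"
)

// event stands in for the container-list payload forwarded to clients.
type event struct {
	Msg string `json:"msg"`
}

// watchUpstream mimics the coderd -> agent watch: it emits events until ctx is
// canceled, then stops and closes its channel.
func watchUpstream(ctx context.Context) <-chan event {
	ch := make(chan event)
	go func() {
		defer close(ch)
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}
			select {
			case ch <- event{Msg: "tick"}:
			case <-ctx.Done():
				return
			}
		}
	}()
	return ch
}

func watchHandler(w http.ResponseWriter, r *http.Request) {
	// Derive a context we control: canceling it tears down the upstream watch.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	conn, err := websocket.Accept(w, r, nil)
	if err != nil {
		return
	}
	defer conn.Close(websocket.StatusNormalClosure, "handler exited")

	// Close the read side so the library answers pings/close frames, and use
	// the returned context (canceled when the client closes the connection)
	// to cancel our watch context so nothing upstream keeps running.
	go func() {
		defer cancel()
		<-conn.CloseRead(ctx).Done()
	}()

	// Forward upstream events to the client until either side disconnects.
	for ev := range watchUpstream(ctx) {
		if err := wsjson.Write(ctx, conn, ev); err != nil {
			return
		}
	}
}

func main() {
	http.HandleFunc("/watch", watchHandler)
	_ = http.ListenAndServe("127.0.0.1:8080", nil)
}
```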

2 files changed: +144 -6 lines

coderd/workspaceagents.go

Lines changed: 8 additions & 5 deletions

@@ -817,12 +817,13 @@ func (api *API) watchWorkspaceAgentContainers(rw http.ResponseWriter, r *http.Request)
 	var (
 		ctx            = r.Context()
 		workspaceAgent = httpmw.WorkspaceAgentParam(r)
+		logger         = api.Logger.Named("agent_container_watcher").With(slog.F("agent_id", workspaceAgent.ID))
 	)
 
 	// If the agent is unreachable, the request will hang. Assume that if we
 	// don't get a response after 30s that the agent is unreachable.
-	dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
-	defer cancel()
+	dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
+	defer dialCancel()
 	apiAgent, err := db2sdk.WorkspaceAgent(
 		api.DERPMap(),
 		*api.TailnetCoordinator.Load(),
@@ -857,8 +858,7 @@ func (api *API) watchWorkspaceAgentContainers(rw http.ResponseWriter, r *http.Request)
 	}
 	defer release()
 
-	watcherLogger := api.Logger.Named("agent_container_watcher").With(slog.F("agent_id", workspaceAgent.ID))
-	containersCh, closer, err := agentConn.WatchContainers(ctx, watcherLogger)
+	containersCh, closer, err := agentConn.WatchContainers(ctx, logger)
 	if err != nil {
 		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
 			Message: "Internal error watching agent's containers.",
@@ -877,14 +877,17 @@ func (api *API) watchWorkspaceAgentContainers(rw http.ResponseWriter, r *http.Request)
 		return
 	}
 
+	ctx, cancel := context.WithCancel(r.Context())
+	defer cancel()
+
 	// Here we close the websocket for reading, so that the websocket library will handle pings and
 	// close frames.
 	_ = conn.CloseRead(context.Background())
 
 	ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
 	defer wsNetConn.Close()
 
-	go httpapi.Heartbeat(ctx, conn)
+	go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
 
 	encoder := json.NewEncoder(wsNetConn)
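
The key behavioral change in the hunk above is swapping `httpapi.Heartbeat` for `httpapi.HeartbeatClose` and handing it the new `cancel` function. The sketch below is only a rough approximation of that kind of helper, not coderd's actual implementation: it pings the client on an interval (the 15s figure comes from the commit message; the timeout and logging are assumptions) and calls the supplied exit function as soon as a ping fails, which in the handler above cancels the watch context and so tears down the agent connection instead of letting it keep generating activity.

```go
package sketch

import (
	"context"
	"time"

	"cdr.dev/slog"

	"github.com/coder/websocket"
)

// heartbeatClose pings the client WebSocket on an interval and, whenever the
// loop ends (context canceled or a ping failed), runs exit so dependent work
// such as the agent container watch is stopped.
func heartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn) {
	defer exit()

	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		pingCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
		err := conn.Ping(pingCtx)
		cancel()
		if err != nil {
			// The client is gone or unresponsive: stop pinging and let the
			// deferred exit cancel the watch context.
			logger.Debug(ctx, "failed to ping watch client", slog.Error(err))
			return
		}
	}
}
```
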
coderd/workspaceagents_internal_test.go

Lines changed: 136 additions & 1 deletion

@@ -59,10 +59,145 @@ func (fakeAgentProvider) Close() error {
 	return nil
 }
 
+type channelCloser struct {
+	closeFn func()
+}
+
+func (c *channelCloser) Close() error {
+	c.closeFn()
+	return nil
+}
+
 func TestWatchAgentContainers(t *testing.T) {
 	t.Parallel()
 
-	t.Run("WebSocketClosesProperly", func(t *testing.T) {
+	t.Run("CoderdWebSocketCanHandleClientClosing", func(t *testing.T) {
+		t.Parallel()
+
+		// This test ensures that the agent containers `/watch` websocket can gracefully
+		// handle the client websocket closing. This test was created in
+		// response to this issue: https://github.com/coder/coder/issues/19449
+
+		var (
+			ctx    = testutil.Context(t, testutil.WaitLong)
+			logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd")
+
+			mCtrl        = gomock.NewController(t)
+			mDB          = dbmock.NewMockStore(mCtrl)
+			mCoordinator = tailnettest.NewMockCoordinator(mCtrl)
+			mAgentConn   = agentconnmock.NewMockAgentConn(mCtrl)
+
+			fAgentProvider = fakeAgentProvider{
+				agentConn: func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) {
+					return mAgentConn, func() {}, nil
+				},
+			}
+
+			workspaceID = uuid.New()
+			agentID     = uuid.New()
+			resourceID  = uuid.New()
+			jobID       = uuid.New()
+			buildID     = uuid.New()
+
+			containersCh = make(chan codersdk.WorkspaceAgentListContainersResponse)
+
+			r = chi.NewMux()
+
+			api = API{
+				ctx: ctx,
+				Options: &Options{
+					AgentInactiveDisconnectTimeout: testutil.WaitShort,
+					Database:                       mDB,
+					Logger:                         logger,
+					DeploymentValues:               &codersdk.DeploymentValues{},
+					TailnetCoordinator:             tailnettest.NewFakeCoordinator(),
+				},
+			}
+		)
+
+		var tailnetCoordinator tailnet.Coordinator = mCoordinator
+		api.TailnetCoordinator.Store(&tailnetCoordinator)
+		api.agentProvider = fAgentProvider
+
+		// Setup: Allow `ExtractWorkspaceAgentParams` to complete.
+		mDB.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(database.WorkspaceAgent{
+			ID:               agentID,
+			ResourceID:       resourceID,
+			LifecycleState:   database.WorkspaceAgentLifecycleStateReady,
+			FirstConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
+			LastConnectedAt:  sql.NullTime{Valid: true, Time: dbtime.Now()},
+		}, nil)
+		mDB.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).Return(database.WorkspaceResource{
+			ID:    resourceID,
+			JobID: jobID,
+		}, nil)
+		mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(database.ProvisionerJob{
+			ID:   jobID,
+			Type: database.ProvisionerJobTypeWorkspaceBuild,
+		}, nil)
+		mDB.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), jobID).Return(database.WorkspaceBuild{
+			WorkspaceID: workspaceID,
+			ID:          buildID,
+		}, nil)
+
+		// And: Allow `db2sdk.WorkspaceAgent` to complete.
+		mCoordinator.EXPECT().Node(gomock.Any()).Return(nil)
+
+		// And: Allow `WatchContainers` to be called, returning our `containersCh` channel.
+		mAgentConn.EXPECT().WatchContainers(gomock.Any(), gomock.Any()).
+			DoAndReturn(func(_ context.Context, _ slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) {
+				return containersCh, &channelCloser{closeFn: func() {
+					close(containersCh)
+				}}, nil
+			})
+
+		// And: We mount the HTTP Handler
+		r.With(httpmw.ExtractWorkspaceAgentParam(mDB)).
+			Get("/workspaceagents/{workspaceagent}/containers/watch", api.watchWorkspaceAgentContainers)
+
+		// Given: We create the HTTP server
+		srv := httptest.NewServer(r)
+		defer srv.Close()
+
+		// And: Dial the WebSocket
+		wsURL := strings.Replace(srv.URL, "http://", "ws://", 1)
+		conn, resp, err := websocket.Dial(ctx, fmt.Sprintf("%s/workspaceagents/%s/containers/watch", wsURL, agentID), nil)
+		require.NoError(t, err)
+		if resp.Body != nil {
+			defer resp.Body.Close()
+		}
+
+		// And: Create a streaming decoder
+		decoder := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger)
+		defer decoder.Close()
+		decodeCh := decoder.Chan()
+
+		// And: We can successfully send through the channel.
+		testutil.RequireSend(ctx, t, containersCh, codersdk.WorkspaceAgentListContainersResponse{
+			Containers: []codersdk.WorkspaceAgentContainer{{
+				ID: "test-container-id",
+			}},
+		})
+
+		// And: Receive the data.
+		containerResp := testutil.RequireReceive(ctx, t, decodeCh)
+		require.Len(t, containerResp.Containers, 1)
+		require.Equal(t, "test-container-id", containerResp.Containers[0].ID)
+
+		// When: We close the WebSocket
+		conn.Close(websocket.StatusNormalClosure, "test closing connection")
+
+		// Then: We expect `containersCh` to be closed.
+		select {
+		case <-ctx.Done():
+			t.Fail()
+
+		case _, ok := <-containersCh:
+			require.False(t, ok, "channel is expected to be closed")
+		}
+	})
+
+	t.Run("CoderdWebSocketCanHandleAgentClosing", func(t *testing.T) {
 		t.Parallel()
 
 		// This test ensures that the agent containers `/watch` websocket can gracefully
