Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix WebSocket not closing
  • Loading branch information
kylecarbs committed Sep 19, 2022
commit 0b6b47072ce4df108e7922b76fc1333634f45fc5
9 changes: 7 additions & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
return
}
ssh, err := agentConn.SSH()
if !assert.NoError(t, err) {
if err != nil {
_ = conn.Close()
return
}
Expand Down Expand Up @@ -581,11 +581,16 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
},
CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
closed := make(chan struct{})
t.Cleanup(func() {
_ = serverConn.Close()
_ = clientConn.Close()
<-closed
})
go coordinator.ServeAgent(serverConn, agentID)
go func() {
_ = coordinator.ServeAgent(serverConn, agentID)
close(closed)
}()
return clientConn, nil
},
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
Expand Down
1 change: 1 addition & 0 deletions codersdk/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
go func() {
defer close(logs)
defer conn.Close(websocket.StatusGoingAway, "")
var log ProvisionerJobLog
for {
err = decoder.Decode(&log)
Expand Down
7 changes: 7 additions & 0 deletions codersdk/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,12 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
CompressionMode: websocket.CompressionDisabled,
})
if errors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusAbnormalClosure, "")
return
}
if err != nil {
logger.Debug(ctx, "failed to dial", slog.Error(err))
_ = ws.Close(websocket.StatusAbnormalClosure, "")
continue
}
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
Expand All @@ -294,12 +296,15 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
logger.Debug(ctx, "serving coordinator")
err = <-errChan
if errors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusAbnormalClosure, "")
return
}
if err != nil {
logger.Debug(ctx, "error serving coordinator", slog.Error(err))
_ = ws.Close(websocket.StatusAbnormalClosure, "")
continue
}
_ = ws.Close(websocket.StatusAbnormalClosure, "")
}
}()
return &agent.Conn{
Expand Down Expand Up @@ -423,6 +428,7 @@ func (c *Client) AgentReportStats(
var req AgentStatsReportRequest
err := wsjson.Read(ctx, conn, &req)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, "")
return err
}

Expand All @@ -436,6 +442,7 @@ func (c *Client) AgentReportStats(

err = wsjson.Write(ctx, conn, resp)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, "")
return err
}
}
Expand Down