diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 90396235993cb..f26ebe92d8283 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1,6 +1,7 @@ package coderd import ( + "context" "database/sql" "encoding/json" "fmt" @@ -16,6 +17,7 @@ import ( "nhooyr.io/websocket" "cdr.dev/slog" + "github.com/coder/coder/agent" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" @@ -69,17 +71,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) { }) return } - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "") - }() + + ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) + defer wsNetConn.Close() // Also closes conn. + config := yamux.DefaultConfig() config.LogOutput = io.Discard - session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) + session, err := yamux.Server(wsNetConn, config) if err != nil { _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } - err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{ + err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{ ChannelID: workspaceAgent.ID.String(), Logger: api.Logger.Named("peerbroker-proxy-dial"), Pubsub: api.Pubsub, @@ -193,13 +196,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { return } - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "") - }() + ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) + defer wsNetConn.Close() // Also closes conn. 
config := yamux.DefaultConfig() config.LogOutput = io.Discard - session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) + session, err := yamux.Server(wsNetConn, config) if err != nil { _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return @@ -229,7 +231,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { } disconnectedAt := workspaceAgent.DisconnectedAt updateConnectionTimes := func() error { - err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{ + err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{ ID: workspaceAgent.ID, FirstConnectedAt: firstConnectedAt, LastConnectedAt: lastConnectedAt, @@ -255,7 +257,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { return } - api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent)) + api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent)) ticker := time.NewTicker(api.AgentConnectionUpdateFrequency) defer ticker.Stop() @@ -324,16 +326,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) { }) return } - defer func() { - _ = wsConn.Close(websocket.StatusNormalClosure, "") - }() - netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary) - api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) + + ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary) + defer wsNetConn.Close() // Also closes conn. 
+ + api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) select { - case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed(): - case <-r.Context().Done(): + case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed(): + case <-ctx.Done(): } - api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) + api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) } // workspaceAgentPTY spawns a PTY and pipes it over a WebSocket. @@ -384,12 +386,11 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { }) return } - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "ended") - }() - // Accept text connections, because it's more developer friendly. - wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary) - agentConn, err := api.dialWorkspaceAgent(r, workspaceAgent.ID) + + ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) + defer wsNetConn.Close() // Also closes conn. + + agentConn, err := api.dialWorkspaceAgent(ctx, r, workspaceAgent.ID) if err != nil { _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err)) return @@ -408,11 +409,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { _, _ = io.Copy(ptNetConn, wsNetConn) } -// dialWorkspaceAgent connects to a workspace agent by ID. -func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) { +// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on +// r.Context() for cancellation if its use is safe or r.Hijack() has +// not been performed. 
+func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID uuid.UUID) (*agent.Conn, error) { client, server := provisionersdk.TransportPipe() go func() { - _ = peerbroker.ProxyListen(r.Context(), server, peerbroker.ProxyOptions{ + _ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{ ChannelID: agentID.String(), Logger: api.Logger.Named("peerbroker-proxy-dial"), Pubsub: api.Pubsub, @@ -422,7 +425,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C }() peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := peerClient.NegotiateConnection(r.Context()) + stream, err := peerClient.NegotiateConnection(ctx) if err != nil { return nil, xerrors.Errorf("negotiate: %w", err) } @@ -434,7 +437,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) { clientPipe, serverPipe := net.Pipe() go func() { - <-r.Context().Done() + <-ctx.Done() _ = clientPipe.Close() _ = serverPipe.Close() }() @@ -515,3 +518,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency return workspaceAgent, nil } + +// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func +// is called if a read or write error is encountered. +type wsNetConn struct { + cancel context.CancelFunc + net.Conn +} + +func (c *wsNetConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Close() error { + defer c.cancel() + return c.Conn.Close() +} + +// websocketNetConn wraps websocket.NetConn and returns a context that +// is tied to the parent context and the lifetime of the conn. 
Any error +// during read or write will cancel the context, but not close the +// conn. Close should be called to release context resources. +func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { + ctx, cancel := context.WithCancel(ctx) + nc := websocket.NetConn(ctx, conn, msgType) + return ctx, &wsNetConn{ + cancel: cancel, + Conn: nc, + } +}