diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index f8f40d362e999..53a67fc3d217f 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -732,10 +732,13 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R }) return } + ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) + defer wsNetConn.Close() + go httpapi.Heartbeat(ctx, conn) defer conn.Close(websocket.StatusNormalClosure, "") - err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID) + err = (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID) if err != nil { _ = conn.Close(websocket.StatusInternalError, err.Error()) return diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index 83d48923be815..d0344eb7f07b8 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -159,6 +159,8 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) { return nil, codersdk.ReadBodyAsError(res) } + ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) + // Ping once every 30 seconds to ensure that the websocket is alive. If we // don't get a response within 30s we kill the websocket and reconnect. // See: https://github.com/coder/coder/pull/5824 @@ -195,7 +197,7 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) { } }() - return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil + return wsNetConn, nil } type PostAppHealthsRequest struct { @@ -529,3 +531,44 @@ type closeFunc func() error func (c closeFunc) Close() error { return c() } + +// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func +// is called if a read or write error is encountered. +type wsNetConn struct { + cancel context.CancelFunc + net.Conn +} + +func (c *wsNetConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Close() error { + defer c.cancel() + return c.Conn.Close() +} + +// websocketNetConn wraps websocket.NetConn and returns a context that +// is tied to the parent context and the lifetime of the conn. Any error +// during read or write will cancel the context, but not close the +// conn. Close should be called to release context resources. +func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { + ctx, cancel := context.WithCancel(ctx) + nc := websocket.NetConn(ctx, conn, msgType) + return ctx, &wsNetConn{ + cancel: cancel, + Conn: nc, + } +} diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index 2a7ed58bb0ba0..814f8bf57ff37 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/http/cookiejar" "net/url" @@ -143,8 +144,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after return nil, nil, ReadBodyAsError(res) } logs := make(chan ProvisionerJobLog) - decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText)) closed := make(chan struct{}) + ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText) + decoder := json.NewDecoder(wsNetConn) go func() { defer close(closed) defer close(logs) @@ -163,13 +165,15 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after } }() return logs, closeFunc(func() error { - _ = conn.Close(websocket.StatusNormalClosure, "") + _ = wsNetConn.Close() <-closed return nil }), nil } -// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation. +// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon +// implementation. The context is during dial, not during the lifetime of the +// client. Client should be closed after use. func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) { serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization)) if err != nil { @@ -210,9 +214,55 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.U config := yamux.DefaultConfig() config.LogOutput = io.Discard - session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) + // Use background context because caller should close the client. + _, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary) + session, err := yamux.Client(wsNetConn, config) if err != nil { + _ = conn.Close(websocket.StatusGoingAway, "") + _ = wsNetConn.Close() return nil, xerrors.Errorf("multiplex client: %w", err) } return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil } + +// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func +// is called if a read or write error is encountered. +// @typescript-ignore wsNetConn +type wsNetConn struct { + cancel context.CancelFunc + net.Conn +} + +func (c *wsNetConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Close() error { + defer c.cancel() + return c.Conn.Close() +} + +// websocketNetConn wraps websocket.NetConn and returns a context that +// is tied to the parent context and the lifetime of the conn. Any error +// during read or write will cancel the context, but not close the +// conn. Close should be called to release context resources. +func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { + ctx, cancel := context.WithCancel(ctx) + nc := websocket.NetConn(ctx, conn, msgType) + return ctx, &wsNetConn{ + cancel: cancel, + Conn: nc, + } +} diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 96ae353bb63fb..9fbb9eb9200c6 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -257,7 +257,7 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec } return nil, ReadBodyAsError(res) } - return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil + return websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil } // WorkspaceAgentListeningPorts returns a list of ports that are currently being diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 7fbc5b42b17c1..057579fcfee8f 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -1,11 +1,13 @@ package coderd import ( + "context" "database/sql" "encoding/json" "errors" "fmt" "io" + "net" "net/http" "strings" @@ -94,12 +96,14 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { // @Success 101 // @Router /organizations/{organization}/provisionerdaemons/serve [get] func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + tags := map[string]string{} if r.URL.Query().Has("tag") { for _, tag := range r.URL.Query()["tag"] { parts := strings.SplitN(tag, "=", 2) if len(parts) < 2 { - httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag), }) return @@ -108,7 +112,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) } } if !r.URL.Query().Has("provisioner") { - httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "The provisioner query parameter must be specified.", }) return @@ -122,7 +126,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) case string(codersdk.ProvisionerTypeTerraform): provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{} default: - httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: fmt.Sprintf("Unknown provisioner type %q", provisioner), }) return @@ -137,7 +141,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization { if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) { - httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{ Message: "You aren't allowed to create provisioner daemons for the organization.", }) return @@ -155,7 +159,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) } name := namesgenerator.GetRandomName(1) - daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{ + daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{ ID: uuid.New(), CreatedAt: database.Now(), Name: name, @@ -163,7 +167,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) Tags: tags, }) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error writing provisioner daemon.", Detail: err.Error(), }) @@ -172,7 +176,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) rawTags, err := json.Marshal(daemon.Tags) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error marshaling daemon tags.", Detail: err.Error(), }) @@ -189,7 +193,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) CompressionMode: websocket.CompressionDisabled, }) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Internal error accepting websocket connection.", Detail: err.Error(), }) @@ -203,7 +207,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) // the same connection. config := yamux.DefaultConfig() config.LogOutput = io.Discard - session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) + ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) + defer wsNetConn.Close() + session, err := yamux.Server(wsNetConn, config) if err != nil { _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err)) return @@ -229,12 +235,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) if xerrors.Is(err, io.EOF) { return } - api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err)) + api.Logger.Debug(ctx, "drpc server error", slog.Error(err)) }, }) - err = server.Serve(r.Context(), session) + err = server.Serve(ctx, session) if err != nil && !xerrors.Is(err, io.EOF) { - api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err)) + api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err)) _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err)) return } @@ -254,3 +260,44 @@ func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.Provis } return result } + +// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func +// is called if a read or write error is encountered. +type wsNetConn struct { + cancel context.CancelFunc + net.Conn +} + +func (c *wsNetConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Close() error { + defer c.cancel() + return c.Conn.Close() +} + +// websocketNetConn wraps websocket.NetConn and returns a context that +// is tied to the parent context and the lifetime of the conn. Any error +// during read or write will cancel the context, but not close the +// conn. Close should be called to release context resources. +func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { + ctx, cancel := context.WithCancel(ctx) + nc := websocket.NetConn(ctx, conn, msgType) + return ctx, &wsNetConn{ + cancel: cancel, + Conn: nc, + } +}