Skip to content

chore: add immortal streams manager #19225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: mike/immortal-streams-backed-base
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/immortalstreams"

Check failure on line 44 in agent/agent.go

View workflow job for this annotation

GitHub Actions / lint

could not import github.com/coder/coder/v2/agent/immortalstreams (-: # github.com/coder/coder/v2/agent/immortalstreams
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
"github.com/coder/coder/v2/agent/reconnectingpty"
Expand Down Expand Up @@ -280,6 +281,9 @@
devcontainers bool
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API

// Immortal streams
immortalStreamsManager *immortalstreams.Manager
}

func (a *agent) TailnetConn() *tailnet.Conn {
Expand Down Expand Up @@ -347,6 +351,9 @@

a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)

// Initialize immortal streams manager
a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), &net.Dialer{})

a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
a.sshServer,
Expand Down Expand Up @@ -1930,6 +1937,12 @@
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
}

if a.immortalStreamsManager != nil {
if err := a.immortalStreamsManager.Close(); err != nil {
a.logger.Error(a.hardCtx, "immortal streams manager close", slog.Error(err))
}
}

// Wait for the graceful shutdown to complete, but don't wait forever so
// that we don't break user expectations.
go func() {
Expand Down
7 changes: 7 additions & 0 deletions agent/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/google/uuid"

"github.com/coder/coder/v2/agent/immortalstreams"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
)
Expand Down Expand Up @@ -66,6 +67,12 @@ func (a *agent) apiHandler() http.Handler {
r.Get("/debug/manifest", a.HandleHTTPDebugManifest)
r.Get("/debug/prometheus", promHandler.ServeHTTP)

// Mount immortal streams API
if a.immortalStreamsManager != nil {
immortalStreamsHandler := immortalstreams.NewHandler(a.logger, a.immortalStreamsManager)
r.Mount("/api/v0/immortal-stream", immortalStreamsHandler.Routes())
}

return r
}

Expand Down
261 changes: 261 additions & 0 deletions agent/immortalstreams/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
package immortalstreams

import (
"context"
"fmt"
"net/http"
"strconv"
"strings"

"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"golang.org/x/xerrors"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/websocket"
)

// Handler handles immortal stream requests
type Handler struct {
logger slog.Logger
manager *Manager
}

// NewHandler creates a new immortal streams handler
func NewHandler(logger slog.Logger, manager *Manager) *Handler {
return &Handler{
logger: logger,
manager: manager,
}
}

// Routes registers the immortal streams routes
func (h *Handler) Routes() chi.Router {
r := chi.NewRouter()

r.Post("/", h.createStream)
r.Get("/", h.listStreams)
r.Route("/{streamID}", func(r chi.Router) {
r.Use(h.streamMiddleware)
r.Get("/", h.handleStreamRequest)
r.Delete("/", h.deleteStream)
})

return r
}

// streamMiddleware validates and extracts the stream ID
func (*Handler) streamMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
streamIDStr := chi.URLParam(r, "streamID")
streamID, err := uuid.Parse(streamIDStr)
if err != nil {
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
Message: "Invalid stream ID format",
})
return
}

ctx := context.WithValue(r.Context(), streamIDKey{}, streamID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

// createStream creates a new immortal stream
func (h *Handler) createStream(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

var req codersdk.CreateImmortalStreamRequest
if !httpapi.Read(ctx, w, r, &req) {
return
}

stream, err := h.manager.CreateStream(ctx, req.TCPPort)
if err != nil {
if strings.Contains(err.Error(), "too many immortal streams") {
httpapi.Write(ctx, w, http.StatusServiceUnavailable, codersdk.Response{
Message: "Too many Immortal Streams.",
})
return
}
if strings.Contains(err.Error(), "the connection was refused") {
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
Message: "The connection was refused.",
})
return
}
httpapi.InternalServerError(w, err)
return
}

httpapi.Write(ctx, w, http.StatusCreated, stream)
}

// listStreams lists all immortal streams
func (h *Handler) listStreams(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
streams := h.manager.ListStreams()
httpapi.Write(ctx, w, http.StatusOK, streams)
}

// handleStreamRequest handles GET requests for a specific stream and returns stream info or handles WebSocket upgrades
func (h *Handler) handleStreamRequest(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
streamID := getStreamID(ctx)

// Check if this is a WebSocket upgrade request by looking for WebSocket headers
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
h.handleUpgrade(w, r)
return
}

// Otherwise, return stream info
stream, ok := h.manager.GetStream(streamID)
if !ok {
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
Message: "Stream not found",
})
return
}

httpapi.Write(ctx, w, http.StatusOK, stream.ToAPI())
}

// deleteStream deletes a stream
func (h *Handler) deleteStream(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
streamID := getStreamID(ctx)

err := h.manager.DeleteStream(streamID)
if err != nil {
if strings.Contains(err.Error(), "stream not found") {
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
Message: "Stream not found",
})
return
}
httpapi.InternalServerError(w, err)
return
}

w.WriteHeader(http.StatusNoContent)
}

// handleUpgrade handles WebSocket upgrade for immortal stream connections
func (h *Handler) handleUpgrade(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
streamID := getStreamID(ctx)

// Get sequence numbers from headers
readSeqNum, err := parseSequenceNumber(r.Header.Get(codersdk.HeaderImmortalStreamSequenceNum))
if err != nil {
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid sequence number: %v", err),
})
return
}

// Check if stream exists
_, ok := h.manager.GetStream(streamID)
if !ok {
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
Message: "Stream not found",
})
return
}

// Upgrade to WebSocket
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
h.logger.Error(ctx, "failed to accept websocket", slog.Error(err))
return
}

// Create a context that we can cancel to clean up the connection
connCtx, cancel := context.WithCancel(ctx)
defer cancel()

// Ensure WebSocket is closed when this function returns
defer func() {
conn.Close(websocket.StatusNormalClosure, "connection closed")
}()

// Create a WebSocket adapter
wsConn := &wsConn{
conn: conn,
logger: h.logger,
ctx: connCtx,
cancel: cancel,
}

// Handle the reconnection - this establishes the connection
// BackedPipe only needs the reader sequence number for replay
err = h.manager.HandleConnection(streamID, wsConn, readSeqNum)
if err != nil {
h.logger.Error(ctx, "failed to handle connection", slog.Error(err))
conn.Close(websocket.StatusInternalError, err.Error())
return
}

// Keep the connection open until the context is canceled
// The wsConn will handle connection closure through its Read/Write methods
// When the connection is closed, the backing pipe will detect it and the context should be canceled
<-connCtx.Done()
}

// wsConn adapts a WebSocket connection to io.ReadWriteCloser
type wsConn struct {
conn *websocket.Conn
logger slog.Logger
ctx context.Context
cancel context.CancelFunc
}

func (c *wsConn) Read(p []byte) (n int, err error) {
typ, data, err := c.conn.Read(c.ctx)
if err != nil {
// Cancel the context when read fails (connection closed)
c.cancel()
return 0, err
}
if typ != websocket.MessageBinary {
return 0, xerrors.Errorf("unexpected message type: %v", typ)
}
n = copy(p, data)
return n, nil
}

func (c *wsConn) Write(p []byte) (n int, err error) {
err = c.conn.Write(c.ctx, websocket.MessageBinary, p)
if err != nil {
// Cancel the context when write fails (connection closed)
c.cancel()
return 0, err
}
return len(p), nil
}

func (c *wsConn) Close() error {
c.cancel() // Cancel the context when explicitly closed
return c.conn.Close(websocket.StatusNormalClosure, "")
}

// parseSequenceNumber parses a sequence number from a string
func parseSequenceNumber(s string) (uint64, error) {
if s == "" {
return 0, nil
}
return strconv.ParseUint(s, 10, 64)
}

// getStreamID gets the stream ID from the context
func getStreamID(ctx context.Context) uuid.UUID {
id, _ := ctx.Value(streamIDKey{}).(uuid.UUID)
return id
}

type streamIDKey struct{}
Loading
Loading