
fix: ensure wsproxy MultiAgent is closed when websocket dies #11414

Merged

coderd/httpapi/websocket.go (6 additions, 4 deletions)

@@ -4,6 +4,7 @@ import (
 	"context"
 	"time"
 
+	"cdr.dev/slog"
 	"nhooyr.io/websocket"
 )

@@ -26,10 +27,10 @@ func Heartbeat(ctx context.Context, conn *websocket.Conn) {
 	}
 }
 
-// Heartbeat loops to ping a WebSocket to keep it alive. It kills the connection
-// on ping failure.
-func HeartbeatClose(ctx context.Context, exit func(), conn *websocket.Conn) {
-	ticker := time.NewTicker(30 * time.Second)
+// Heartbeat loops to ping a WebSocket to keep it alive. It calls `exit` on ping
+// failure.
+func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn) {
+	ticker := time.NewTicker(15 * time.Second)
 	defer ticker.Stop()
 
 	for {
Contributor (review comment, re: line 43): Drop an INFO log here.

@@ -41,6 +42,7 @@ func HeartbeatClose(ctx context.Context, exit func(), conn *websocket.Conn) {
 		err := conn.Ping(ctx)
 		if err != nil {
 			_ = conn.Close(websocket.StatusGoingAway, "Ping failed")
+			logger.Info(ctx, "failed to heartbeat ping", slog.Error(err))
 			exit()
 			return
 		}
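
Since the signature changed, a minimal caller-side sketch may help; the helper name, URL handling, and dial code here are illustrative assumptions, not code from this PR:

// dialWithHeartbeat dials a WebSocket and keeps it alive with the new
// HeartbeatClose signature. Assumed imports: context, cdr.dev/slog,
// nhooyr.io/websocket, golang.org/x/xerrors, and coder's coderd/httpapi.
func dialWithHeartbeat(ctx context.Context, logger slog.Logger, url string) (*websocket.Conn, error) {
	ctx, cancel := context.WithCancel(ctx)

	conn, _, err := websocket.Dial(ctx, url, nil)
	if err != nil {
		cancel()
		return nil, xerrors.Errorf("dial websocket: %w", err)
	}

	// cancel doubles as the exit callback: when a ping fails, HeartbeatClose
	// closes the connection, logs at INFO, and cancels ctx so everything
	// derived from it unwinds as well.
	go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
	return conn, nil
}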

coderd/tailnet.go (7 additions, 0 deletions)

@@ -224,6 +224,7 @@ func (s *ServerTailnet) watchAgentUpdates() {
 		nodes, ok := conn.NextUpdate(s.ctx)
 		if !ok {
 			if conn.IsClosed() && s.ctx.Err() == nil {
+				s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
 				s.reinitCoordinator()
 				continue
 			}
@@ -247,6 +248,7 @@ func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
 }
 
 func (s *ServerTailnet) reinitCoordinator() {
+	start := time.Now()
 	for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(s.ctx); {
 		s.nodesMu.Lock()
 		agentConn, err := s.getMultiAgent(s.ctx)
@@ -264,6 +266,11 @@
 				s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
 			}
 		}
+
+		s.logger.Info(s.ctx, "successfully reinitialized multiagent",
+			slog.F("agents", len(s.agentConnectionTimes)),
+			slog.F("took", time.Since(start)),
+		)
 		s.nodesMu.Unlock()
 		return
 	}

enterprise/wsproxy/wsproxysdk/wsproxysdk.go (18 additions, 6 deletions)

@@ -431,6 +431,7 @@ type CoordinateNodes struct {
 
 func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
 	ctx, cancel := context.WithCancel(ctx)
+	logger := c.SDKClient.Logger().Named("multiagent")
 
 	coordinateURL, err := c.SDKClient.URL.Parse("/api/v2/workspaceproxies/me/coordinate")
 	if err != nil {
@@ -454,12 +455,13 @@
 		return nil, xerrors.Errorf("dial coordinate websocket: %w", err)
 	}
 
-	go httpapi.HeartbeatClose(ctx, cancel, conn)
+	go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
 
 	nc := websocket.NetConn(ctx, conn, websocket.MessageText)
 	rma := remoteMultiAgentHandler{
 		sdk:              c,
 		nc:               nc,
+		cancel:           cancel,
 		legacyAgentCache: map[uuid.UUID]bool{},
 	}

@@ -472,6 +474,11 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
 		OnRemove: func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") },
 	}).Init()
 
+	go func() {
+		<-ctx.Done()
+		ma.Close()
+	}()
+
 	go func() {
 		defer cancel()
 		dec := json.NewDecoder(nc)

Contributor (review comment on the added <-ctx.Done() goroutine):

If I'm understanding this correctly, we're depending on the fact that the reader goroutine below cancels the context on a failed read.

I think we should also tear down the multi-agent on a failed write of subscription messages. It's unlikely that we'd have a failure that leaves the connection half-open (e.g. readable but not writable), but such things are possible, and you don't want the proxy limping along, unable to subscribe to new agents.
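
The write-path changes later in this diff implement exactly that teardown by handing cancel to the handler. As a sketch of the chain being described, using the names from this diff (the wrapper function itself is hypothetical):

// Hypothetical wrapper, sketching how a failed write tears down the
// MultiAgent: cancelling the dial context wakes the <-ctx.Done() goroutine
// above, which calls ma.Close(), so the proxy cannot limp along with a
// connection that still reads but no longer writes.
func (a *remoteMultiAgentHandler) send(data []byte) error {
	if _, err := a.nc.Write(data); err != nil {
		a.cancel()
		return xerrors.Errorf("write message: %w", err)
	}
	return nil
}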

Contributor (review comment, re: line 488): I think it's worth dropping an INFO log here.

@@ -480,16 +487,17 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
 			err := dec.Decode(&msg)
 			if err != nil {
 				if xerrors.Is(err, io.EOF) {
+					logger.Info(ctx, "websocket connection severed", slog.Error(err))
 					return
 				}
 
-				c.SDKClient.Logger().Error(ctx, "failed to decode coordinator nodes", slog.Error(err))
+				logger.Error(ctx, "decode coordinator nodes", slog.Error(err))
 				return
 			}
 
 			err = ma.Enqueue(msg.Nodes)
 			if err != nil {
-				c.SDKClient.Logger().Error(ctx, "enqueue nodes from coordinator", slog.Error(err))
+				logger.Error(ctx, "enqueue nodes from coordinator", slog.Error(err))
 				continue
 			}
 		}
@@ -499,8 +507,9 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
 }
 
 type remoteMultiAgentHandler struct {
-	sdk *Client
-	nc  net.Conn
+	sdk    *Client
+	nc     net.Conn
+	cancel func()
 
 	legacyMu         sync.RWMutex
 	legacyAgentCache map[uuid.UUID]bool
@@ -517,10 +526,12 @@ func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error {
 	// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
 	err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
 	if err != nil {
+		a.cancel()
 		return xerrors.Errorf("set write deadline: %w", err)
 	}
 	_, err = a.nc.Write(data)
 	if err != nil {
+		a.cancel()
 		return xerrors.Errorf("write message: %w", err)
 	}

@@ -531,6 +542,7 @@
 	// our successful write, it is important that we reset the deadline before it fires.
 	err = a.nc.SetWriteDeadline(time.Time{})
 	if err != nil {
+		a.cancel()
 		return xerrors.Errorf("clear write deadline: %w", err)
 	}
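
The comments in this hunk explain the deadline dance; here is a self-contained sketch of the same pattern on a bare net.Conn (the helper name and timeout parameter are illustrative, not from this PR):

// writeWithDeadline bounds a single write so a hung peer surfaces as a
// timeout error instead of blocking the goroutine forever, then clears
// the deadline so the timer cannot fire during a later, unrelated write.
func writeWithDeadline(nc net.Conn, data []byte, timeout time.Duration) error {
	if err := nc.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
		return xerrors.Errorf("set write deadline: %w", err)
	}
	if _, err := nc.Write(data); err != nil {
		return xerrors.Errorf("write message: %w", err)
	}
	// Reset the deadline; otherwise it could expire midway through a
	// future write that would otherwise have succeeded.
	return nc.SetWriteDeadline(time.Time{})
}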

@@ -573,7 +585,7 @@ func (a *remoteMultiAgentHandler) AgentIsLegacy(agentID uuid.UUID) bool {
 		return a.sdk.AgentIsLegacy(ctx, agentID)
 	})
 	if err != nil {
-		a.sdk.SDKClient.Logger().Error(ctx, "failed to check agent legacy status", slog.Error(err))
+		a.sdk.SDKClient.Logger().Error(ctx, "failed to check agent legacy status", slog.F("agent_id", agentID), slog.Error(err))
 
 		// Assume that the agent is legacy since this failed, while less
 		// efficient it will always work.