Skip to content

chore: refactor agent connection updates #11301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 15 additions & 202 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ import (
"net/http"
"net/netip"
"net/url"
"runtime/pprof"
"sort"
"strconv"
"strings"
"sync/atomic"
"time"

"github.com/google/uuid"
Expand All @@ -42,7 +40,6 @@ import (
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/prometheusmetrics"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet"
Expand Down Expand Up @@ -1084,21 +1081,10 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}

build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error fetching workspace build job.",
Detail: err.Error(),
})
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent)
if !ok {
return
}

Expand All @@ -1120,32 +1106,6 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
return
}

// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
if err != nil {
return err
}
if build.ID != latestBuild.ID {
return xerrors.New("build is outdated")
}
return nil
}

err = ensureLatestBuild()
if err != nil {
api.Logger.Debug(ctx, "agent tried to connect from non-latest build",
slog.F("resource", resource),
slog.F("agent", workspaceAgent),
)
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "Agent trying to connect from non-latest build.",
Detail: err.Error(),
})
return
}

conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Expand All @@ -1158,109 +1118,10 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()

// We use a custom heartbeat routine here instead of `httpapi.Heartbeat`
// because we want to log the agent's last ping time.
var lastPing atomic.Pointer[time.Time]
lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive.

go pprof.Do(ctx, pprof.Labels("agent", workspaceAgent.ID.String()), func(ctx context.Context) {
// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
t := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer t.Stop()

for {
select {
case <-t.C:
case <-ctx.Done():
return
}

// We don't need a context that times out here because the ping will
// eventually go through. If the context times out, then other
// websocket read operations will receive an error, obfuscating the
// actual problem.
err := conn.Ping(ctx)
if err != nil {
return
}
lastPing.Store(ptr.Ref(time.Now()))
}
})

firstConnectedAt := workspaceAgent.FirstConnectedAt
if !firstConnectedAt.Valid {
firstConnectedAt = sql.NullTime{
Time: dbtime.Now(),
Valid: true,
}
}
lastConnectedAt := sql.NullTime{
Time: dbtime.Now(),
Valid: true,
}
disconnectedAt := workspaceAgent.DisconnectedAt
updateConnectionTimes := func(ctx context.Context) error {
//nolint:gocritic // We only update ourself.
err = api.Database.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{
ID: workspaceAgent.ID,
FirstConnectedAt: firstConnectedAt,
LastConnectedAt: lastConnectedAt,
DisconnectedAt: disconnectedAt,
UpdatedAt: dbtime.Now(),
LastConnectedReplicaID: uuid.NullUUID{
UUID: api.ID,
Valid: true,
},
})
if err != nil {
return err
}
return nil
}

defer func() {
// If connection closed then context will be canceled, try to
// ensure our final update is sent. By waiting at most the agent
// inactive disconnect timeout we ensure that we don't block but
// also guarantee that the agent will be considered disconnected
// by normal status check.
//
// Use a system context as the agent has disconnected and that token
// may no longer be valid.
//nolint:gocritic
ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(api.ctx), api.AgentInactiveDisconnectTimeout)
defer cancel()

// Only update timestamp if the disconnect is new.
if !disconnectedAt.Valid {
disconnectedAt = sql.NullTime{
Time: dbtime.Now(),
Valid: true,
}
}
err := updateConnectionTimes(ctx)
if err != nil {
// This is a bug with unit tests that cancel the app context and
// cause this error log to be generated. We should fix the unit tests
// as this is a valid log.
//
// The pq error occurs when the server is shutting down.
if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
api.Logger.Error(ctx, "failed to update agent disconnect time",
slog.Error(err),
slog.F("workspace_id", build.WorkspaceID),
)
}
}
api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
}()

err = updateConnectionTimes(ctx)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, err.Error())
return
}
api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
closeCtx, closeCtxCancel := context.WithCancel(ctx)
defer closeCtxCancel()
monitor := api.startAgentWebsocketMonitor(closeCtx, workspaceAgent, build, conn)
defer monitor.close()

api.Logger.Debug(ctx, "accepting agent",
slog.F("owner", owner.Username),
Expand All @@ -1271,61 +1132,13 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request

defer conn.Close(websocket.StatusNormalClosure, "")

closeChan := make(chan struct{})
go func() {
defer close(closeChan)
err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID,
fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
)
if err != nil {
api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
}
}()
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
for {
select {
case <-closeChan:
return
case <-ticker.C:
}

lastPing := *lastPing.Load()

var connectionStatusChanged bool
if time.Since(lastPing) > api.AgentInactiveDisconnectTimeout {
if !disconnectedAt.Valid {
connectionStatusChanged = true
disconnectedAt = sql.NullTime{
Time: dbtime.Now(),
Valid: true,
}
}
} else {
connectionStatusChanged = disconnectedAt.Valid
// TODO(mafredri): Should we update it here or allow lastConnectedAt to shadow it?
disconnectedAt = sql.NullTime{}
lastConnectedAt = sql.NullTime{
Time: dbtime.Now(),
Valid: true,
}
}
err = updateConnectionTimes(ctx)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, err.Error())
return
}
if connectionStatusChanged {
api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
}
err := ensureLatestBuild()
if err != nil {
// Disconnect agents that are no longer valid.
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
err = (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID,
fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
)
if err != nil {
api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
}
}

Expand Down
Loading