Skip to content

Commit c9b7d61

Browse files
authored
chore: refactor agent connection updates (coder#11301)
Refactors the code that monitors an agent websocket with pings and updates the agent's connection times in the database, and consolidates the v1 and v2 agent APIs onto that same code path. One substantive change (not _just_ a refactor): we now actually disconnect if the agent fails to respond to our pings. The old behavior was to update the database without tearing down the websocket.
1 parent 520c3a8 commit c9b7d61

File tree

3 files changed

+661
-343
lines changed

3 files changed

+661
-343
lines changed

coderd/workspaceagents.go

Lines changed: 15 additions & 202 deletions
Original file line number | Diff line number | Diff line change
@@ -12,11 +12,9 @@ import (
1212
"net/http"
1313
"net/netip"
1414
"net/url"
15-
"runtime/pprof"
1615
"sort"
1716
"strconv"
1817
"strings"
19-
"sync/atomic"
2018
"time"
2119

2220
"github.com/google/uuid"
@@ -42,7 +40,6 @@ import (
4240
"github.com/coder/coder/v2/coderd/httpmw"
4341
"github.com/coder/coder/v2/coderd/prometheusmetrics"
4442
"github.com/coder/coder/v2/coderd/rbac"
45-
"github.com/coder/coder/v2/coderd/util/ptr"
4643
"github.com/coder/coder/v2/codersdk"
4744
"github.com/coder/coder/v2/codersdk/agentsdk"
4845
"github.com/coder/coder/v2/tailnet"
@@ -1084,21 +1081,10 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
10841081
api.WebsocketWaitMutex.Unlock()
10851082
defer api.WebsocketWaitGroup.Done()
10861083
workspaceAgent := httpmw.WorkspaceAgent(r)
1087-
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
1088-
if err != nil {
1089-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
1090-
Message: "Failed to accept websocket.",
1091-
Detail: err.Error(),
1092-
})
1093-
return
1094-
}
1095-
1096-
build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
1097-
if err != nil {
1098-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
1099-
Message: "Internal error fetching workspace build job.",
1100-
Detail: err.Error(),
1101-
})
1084+
// Ensure the resource is still valid!
1085+
// We only accept agents for resources on the latest build.
1086+
build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent)
1087+
if !ok {
11021088
return
11031089
}
11041090

@@ -1120,32 +1106,6 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
11201106
return
11211107
}
11221108

1123-
// Ensure the resource is still valid!
1124-
// We only accept agents for resources on the latest build.
1125-
ensureLatestBuild := func() error {
1126-
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
1127-
if err != nil {
1128-
return err
1129-
}
1130-
if build.ID != latestBuild.ID {
1131-
return xerrors.New("build is outdated")
1132-
}
1133-
return nil
1134-
}
1135-
1136-
err = ensureLatestBuild()
1137-
if err != nil {
1138-
api.Logger.Debug(ctx, "agent tried to connect from non-latest build",
1139-
slog.F("resource", resource),
1140-
slog.F("agent", workspaceAgent),
1141-
)
1142-
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
1143-
Message: "Agent trying to connect from non-latest build.",
1144-
Detail: err.Error(),
1145-
})
1146-
return
1147-
}
1148-
11491109
conn, err := websocket.Accept(rw, r, nil)
11501110
if err != nil {
11511111
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@@ -1158,109 +1118,10 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
11581118
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
11591119
defer wsNetConn.Close()
11601120

1161-
// We use a custom heartbeat routine here instead of `httpapi.Heartbeat`
1162-
// because we want to log the agent's last ping time.
1163-
var lastPing atomic.Pointer[time.Time]
1164-
lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive.
1165-
1166-
go pprof.Do(ctx, pprof.Labels("agent", workspaceAgent.ID.String()), func(ctx context.Context) {
1167-
// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
1168-
t := time.NewTicker(api.AgentConnectionUpdateFrequency)
1169-
defer t.Stop()
1170-
1171-
for {
1172-
select {
1173-
case <-t.C:
1174-
case <-ctx.Done():
1175-
return
1176-
}
1177-
1178-
// We don't need a context that times out here because the ping will
1179-
// eventually go through. If the context times out, then other
1180-
// websocket read operations will receive an error, obfuscating the
1181-
// actual problem.
1182-
err := conn.Ping(ctx)
1183-
if err != nil {
1184-
return
1185-
}
1186-
lastPing.Store(ptr.Ref(time.Now()))
1187-
}
1188-
})
1189-
1190-
firstConnectedAt := workspaceAgent.FirstConnectedAt
1191-
if !firstConnectedAt.Valid {
1192-
firstConnectedAt = sql.NullTime{
1193-
Time: dbtime.Now(),
1194-
Valid: true,
1195-
}
1196-
}
1197-
lastConnectedAt := sql.NullTime{
1198-
Time: dbtime.Now(),
1199-
Valid: true,
1200-
}
1201-
disconnectedAt := workspaceAgent.DisconnectedAt
1202-
updateConnectionTimes := func(ctx context.Context) error {
1203-
//nolint:gocritic // We only update ourself.
1204-
err = api.Database.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{
1205-
ID: workspaceAgent.ID,
1206-
FirstConnectedAt: firstConnectedAt,
1207-
LastConnectedAt: lastConnectedAt,
1208-
DisconnectedAt: disconnectedAt,
1209-
UpdatedAt: dbtime.Now(),
1210-
LastConnectedReplicaID: uuid.NullUUID{
1211-
UUID: api.ID,
1212-
Valid: true,
1213-
},
1214-
})
1215-
if err != nil {
1216-
return err
1217-
}
1218-
return nil
1219-
}
1220-
1221-
defer func() {
1222-
// If connection closed then context will be canceled, try to
1223-
// ensure our final update is sent. By waiting at most the agent
1224-
// inactive disconnect timeout we ensure that we don't block but
1225-
// also guarantee that the agent will be considered disconnected
1226-
// by normal status check.
1227-
//
1228-
// Use a system context as the agent has disconnected and that token
1229-
// may no longer be valid.
1230-
//nolint:gocritic
1231-
ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(api.ctx), api.AgentInactiveDisconnectTimeout)
1232-
defer cancel()
1233-
1234-
// Only update timestamp if the disconnect is new.
1235-
if !disconnectedAt.Valid {
1236-
disconnectedAt = sql.NullTime{
1237-
Time: dbtime.Now(),
1238-
Valid: true,
1239-
}
1240-
}
1241-
err := updateConnectionTimes(ctx)
1242-
if err != nil {
1243-
// This is a bug with unit tests that cancel the app context and
1244-
// cause this error log to be generated. We should fix the unit tests
1245-
// as this is a valid log.
1246-
//
1247-
// The pq error occurs when the server is shutting down.
1248-
if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
1249-
api.Logger.Error(ctx, "failed to update agent disconnect time",
1250-
slog.Error(err),
1251-
slog.F("workspace_id", build.WorkspaceID),
1252-
)
1253-
}
1254-
}
1255-
api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
1256-
}()
1257-
1258-
err = updateConnectionTimes(ctx)
1259-
if err != nil {
1260-
_ = conn.Close(websocket.StatusGoingAway, err.Error())
1261-
return
1262-
}
1263-
api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
1121+
closeCtx, closeCtxCancel := context.WithCancel(ctx)
1122+
defer closeCtxCancel()
1123+
monitor := api.startAgentWebsocketMonitor(closeCtx, workspaceAgent, build, conn)
1124+
defer monitor.close()
12641125

12651126
api.Logger.Debug(ctx, "accepting agent",
12661127
slog.F("owner", owner.Username),
@@ -1271,61 +1132,13 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
12711132

12721133
defer conn.Close(websocket.StatusNormalClosure, "")
12731134

1274-
closeChan := make(chan struct{})
1275-
go func() {
1276-
defer close(closeChan)
1277-
err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID,
1278-
fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
1279-
)
1280-
if err != nil {
1281-
api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
1282-
_ = conn.Close(websocket.StatusInternalError, err.Error())
1283-
return
1284-
}
1285-
}()
1286-
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
1287-
defer ticker.Stop()
1288-
for {
1289-
select {
1290-
case <-closeChan:
1291-
return
1292-
case <-ticker.C:
1293-
}
1294-
1295-
lastPing := *lastPing.Load()
1296-
1297-
var connectionStatusChanged bool
1298-
if time.Since(lastPing) > api.AgentInactiveDisconnectTimeout {
1299-
if !disconnectedAt.Valid {
1300-
connectionStatusChanged = true
1301-
disconnectedAt = sql.NullTime{
1302-
Time: dbtime.Now(),
1303-
Valid: true,
1304-
}
1305-
}
1306-
} else {
1307-
connectionStatusChanged = disconnectedAt.Valid
1308-
// TODO(mafredri): Should we update it here or allow lastConnectedAt to shadow it?
1309-
disconnectedAt = sql.NullTime{}
1310-
lastConnectedAt = sql.NullTime{
1311-
Time: dbtime.Now(),
1312-
Valid: true,
1313-
}
1314-
}
1315-
err = updateConnectionTimes(ctx)
1316-
if err != nil {
1317-
_ = conn.Close(websocket.StatusGoingAway, err.Error())
1318-
return
1319-
}
1320-
if connectionStatusChanged {
1321-
api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
1322-
}
1323-
err := ensureLatestBuild()
1324-
if err != nil {
1325-
// Disconnect agents that are no longer valid.
1326-
_ = conn.Close(websocket.StatusGoingAway, "")
1327-
return
1328-
}
1135+
err = (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID,
1136+
fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
1137+
)
1138+
if err != nil {
1139+
api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
1140+
_ = conn.Close(websocket.StatusInternalError, err.Error())
1141+
return
13291142
}
13301143
}
13311144

0 commit comments

Comments
 (0)