Skip to content

Commit f6e6e71

Browse files
coadler authored and kylecarbs committed
fix: ensure listen websocket isn't opened for non-latest agents (#2002)
Exponential backoff is only enabled if the websocket fails to open. If the websocket is opened but immediately killed, the agent will try to immediately reconnect. This is desirable in cases where coderd is being replaced or network conditions cause the connection to die, but not for permanent errors.
1 parent 38b3d60 commit f6e6e71

File tree

2 files changed

+160
-68
lines changed

2 files changed

+160
-68
lines changed

coderd/workspaceagents.go

+41-27
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,49 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
143143
defer api.websocketWaitGroup.Done()
144144

145145
workspaceAgent := httpmw.WorkspaceAgent(r)
146-
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
147-
CompressionMode: websocket.CompressionDisabled,
148-
})
146+
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
149147
if err != nil {
150148
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
151-
Message: fmt.Sprintf("accept websocket: %s", err),
149+
Message: fmt.Sprintf("get workspace resource: %s", err),
152150
})
153151
return
154152
}
155-
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
153+
154+
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
155+
if err != nil {
156+
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
157+
Message: fmt.Sprintf("get workspace build job: %s", err),
158+
})
159+
return
160+
}
161+
// Ensure the resource is still valid!
162+
// We only accept agents for resources on the latest build.
163+
ensureLatestBuild := func() error {
164+
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
165+
if err != nil {
166+
return err
167+
}
168+
if build.ID != latestBuild.ID {
169+
return xerrors.New("build is outdated")
170+
}
171+
return nil
172+
}
173+
174+
err = ensureLatestBuild()
175+
if err != nil {
176+
api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built",
177+
slog.F("resource", resource),
178+
slog.F("agent", workspaceAgent),
179+
)
180+
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
181+
Message: fmt.Sprintf("ensure latest build: %s", err),
182+
})
183+
return
184+
}
185+
186+
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
187+
CompressionMode: websocket.CompressionDisabled,
188+
})
156189
if err != nil {
157190
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
158191
Message: fmt.Sprintf("accept websocket: %s", err),
@@ -163,13 +196,15 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
163196
defer func() {
164197
_ = conn.Close(websocket.StatusNormalClosure, "")
165198
}()
199+
166200
config := yamux.DefaultConfig()
167201
config.LogOutput = io.Discard
168202
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
169203
if err != nil {
170204
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
171205
return
172206
}
207+
173208
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
174209
ChannelID: workspaceAgent.ID.String(),
175210
Pubsub: api.Pubsub,
@@ -180,6 +215,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
180215
return
181216
}
182217
defer closer.Close()
218+
183219
firstConnectedAt := workspaceAgent.FirstConnectedAt
184220
if !firstConnectedAt.Valid {
185221
firstConnectedAt = sql.NullTime{
@@ -204,23 +240,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
204240
}
205241
return nil
206242
}
207-
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
208-
if err != nil {
209-
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
210-
return
211-
}
212-
// Ensure the resource is still valid!
213-
// We only accept agents for resources on the latest build.
214-
ensureLatestBuild := func() error {
215-
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
216-
if err != nil {
217-
return err
218-
}
219-
if build.ID != latestBuild.ID {
220-
return xerrors.New("build is outdated")
221-
}
222-
return nil
223-
}
224243

225244
defer func() {
226245
disconnectedAt = sql.NullTime{
@@ -230,11 +249,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
230249
_ = updateConnectionTimes()
231250
}()
232251

233-
err = ensureLatestBuild()
234-
if err != nil {
235-
_ = conn.Close(websocket.StatusGoingAway, "")
236-
return
237-
}
238252
err = updateConnectionTimes()
239253
if err != nil {
240254
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())

coderd/workspaceagents_test.go

+119-41
Original file line numberDiff line numberDiff line change
@@ -68,52 +68,130 @@ func TestWorkspaceAgent(t *testing.T) {
6868

6969
func TestWorkspaceAgentListen(t *testing.T) {
7070
t.Parallel()
71-
client, coderAPI := coderdtest.NewWithAPI(t, nil)
72-
user := coderdtest.CreateFirstUser(t, client)
73-
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
74-
authToken := uuid.NewString()
75-
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
76-
Parse: echo.ParseComplete,
77-
ProvisionDryRun: echo.ProvisionComplete,
78-
Provision: []*proto.Provision_Response{{
79-
Type: &proto.Provision_Response_Complete{
80-
Complete: &proto.Provision_Complete{
81-
Resources: []*proto.Resource{{
82-
Name: "example",
83-
Type: "aws_instance",
84-
Agents: []*proto.Agent{{
85-
Id: uuid.NewString(),
86-
Auth: &proto.Agent_Token{
87-
Token: authToken,
88-
},
71+
72+
t.Run("Connect", func(t *testing.T) {
73+
t.Parallel()
74+
75+
client, coderAPI := coderdtest.NewWithAPI(t, nil)
76+
user := coderdtest.CreateFirstUser(t, client)
77+
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
78+
authToken := uuid.NewString()
79+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
80+
Parse: echo.ParseComplete,
81+
ProvisionDryRun: echo.ProvisionComplete,
82+
Provision: []*proto.Provision_Response{{
83+
Type: &proto.Provision_Response_Complete{
84+
Complete: &proto.Provision_Complete{
85+
Resources: []*proto.Resource{{
86+
Name: "example",
87+
Type: "aws_instance",
88+
Agents: []*proto.Agent{{
89+
Id: uuid.NewString(),
90+
Auth: &proto.Agent_Token{
91+
Token: authToken,
92+
},
93+
}},
8994
}},
90-
}},
95+
},
9196
},
92-
},
93-
}},
94-
})
95-
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
96-
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
97-
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
98-
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
99-
daemonCloser.Close()
97+
}},
98+
})
99+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
100+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
101+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
102+
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
103+
daemonCloser.Close()
100104

101-
agentClient := codersdk.New(client.URL)
102-
agentClient.SessionToken = authToken
103-
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
104-
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
105-
})
106-
t.Cleanup(func() {
107-
_ = agentCloser.Close()
105+
agentClient := codersdk.New(client.URL)
106+
agentClient.SessionToken = authToken
107+
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
108+
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
109+
})
110+
t.Cleanup(func() {
111+
_ = agentCloser.Close()
112+
})
113+
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
114+
conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
115+
require.NoError(t, err)
116+
t.Cleanup(func() {
117+
_ = conn.Close()
118+
})
119+
_, err = conn.Ping()
120+
require.NoError(t, err)
108121
})
109-
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
110-
conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
111-
require.NoError(t, err)
112-
t.Cleanup(func() {
113-
_ = conn.Close()
122+
123+
t.Run("FailNonLatestBuild", func(t *testing.T) {
124+
t.Parallel()
125+
126+
ctx := context.Background()
127+
client, coderAPI := coderdtest.NewWithAPI(t, nil)
128+
user := coderdtest.CreateFirstUser(t, client)
129+
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
130+
defer daemonCloser.Close()
131+
132+
authToken := uuid.NewString()
133+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
134+
Parse: echo.ParseComplete,
135+
ProvisionDryRun: echo.ProvisionComplete,
136+
Provision: []*proto.Provision_Response{{
137+
Type: &proto.Provision_Response_Complete{
138+
Complete: &proto.Provision_Complete{
139+
Resources: []*proto.Resource{{
140+
Name: "example",
141+
Type: "aws_instance",
142+
Agents: []*proto.Agent{{
143+
Id: uuid.NewString(),
144+
Auth: &proto.Agent_Token{
145+
Token: authToken,
146+
},
147+
}},
148+
}},
149+
},
150+
},
151+
}},
152+
})
153+
154+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
155+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
156+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
157+
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
158+
159+
version = coderdtest.UpdateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
160+
Parse: echo.ParseComplete,
161+
ProvisionDryRun: echo.ProvisionComplete,
162+
Provision: []*proto.Provision_Response{{
163+
Type: &proto.Provision_Response_Complete{
164+
Complete: &proto.Provision_Complete{
165+
Resources: []*proto.Resource{{
166+
Name: "example",
167+
Type: "aws_instance",
168+
Agents: []*proto.Agent{{
169+
Id: uuid.NewString(),
170+
Auth: &proto.Agent_Token{
171+
Token: uuid.NewString(),
172+
},
173+
}},
174+
}},
175+
},
176+
},
177+
}},
178+
}, template.ID)
179+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
180+
181+
stopBuild, err := client.CreateWorkspaceBuild(context.Background(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
182+
TemplateVersionID: version.ID,
183+
Transition: codersdk.WorkspaceTransitionStop,
184+
})
185+
require.NoError(t, err)
186+
coderdtest.AwaitWorkspaceBuildJob(t, client, stopBuild.ID)
187+
188+
agentClient := codersdk.New(client.URL)
189+
agentClient.SessionToken = authToken
190+
191+
_, _, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil))
192+
require.Error(t, err)
193+
require.ErrorContains(t, err, "build is outdated")
114194
})
115-
_, err = conn.Ping()
116-
require.NoError(t, err)
117195
}
118196

119197
func TestWorkspaceAgentTURN(t *testing.T) {

0 commit comments

Comments
 (0)