@@ -143,16 +143,49 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
143
143
defer api .websocketWaitGroup .Done ()
144
144
145
145
workspaceAgent := httpmw .WorkspaceAgent (r )
146
- conn , err := websocket .Accept (rw , r , & websocket.AcceptOptions {
147
- CompressionMode : websocket .CompressionDisabled ,
148
- })
146
+ resource , err := api .Database .GetWorkspaceResourceByID (r .Context (), workspaceAgent .ResourceID )
149
147
if err != nil {
150
148
httpapi .Write (rw , http .StatusBadRequest , httpapi.Response {
151
- Message : fmt .Sprintf ("accept websocket : %s" , err ),
149
+ Message : fmt .Sprintf ("get workspace resource : %s" , err ),
152
150
})
153
151
return
154
152
}
155
- resource , err := api .Database .GetWorkspaceResourceByID (r .Context (), workspaceAgent .ResourceID )
153
+
154
+ build , err := api .Database .GetWorkspaceBuildByJobID (r .Context (), resource .JobID )
155
+ if err != nil {
156
+ httpapi .Write (rw , http .StatusBadRequest , httpapi.Response {
157
+ Message : fmt .Sprintf ("get workspace build job: %s" , err ),
158
+ })
159
+ return
160
+ }
161
+ // Ensure the resource is still valid!
162
+ // We only accept agents for resources on the latest build.
163
+ ensureLatestBuild := func () error {
164
+ latestBuild , err := api .Database .GetLatestWorkspaceBuildByWorkspaceID (r .Context (), build .WorkspaceID )
165
+ if err != nil {
166
+ return err
167
+ }
168
+ if build .ID != latestBuild .ID {
169
+ return xerrors .New ("build is outdated" )
170
+ }
171
+ return nil
172
+ }
173
+
174
+ err = ensureLatestBuild ()
175
+ if err != nil {
176
+ api .Logger .Debug (r .Context (), "agent tried to connect from non-latest built" ,
177
+ slog .F ("resource" , resource ),
178
+ slog .F ("agent" , workspaceAgent ),
179
+ )
180
+ httpapi .Write (rw , http .StatusForbidden , httpapi.Response {
181
+ Message : fmt .Sprintf ("ensure latest build: %s" , err ),
182
+ })
183
+ return
184
+ }
185
+
186
+ conn , err := websocket .Accept (rw , r , & websocket.AcceptOptions {
187
+ CompressionMode : websocket .CompressionDisabled ,
188
+ })
156
189
if err != nil {
157
190
httpapi .Write (rw , http .StatusBadRequest , httpapi.Response {
158
191
Message : fmt .Sprintf ("accept websocket: %s" , err ),
@@ -163,13 +196,15 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
163
196
defer func () {
164
197
_ = conn .Close (websocket .StatusNormalClosure , "" )
165
198
}()
199
+
166
200
config := yamux .DefaultConfig ()
167
201
config .LogOutput = io .Discard
168
202
session , err := yamux .Server (websocket .NetConn (r .Context (), conn , websocket .MessageBinary ), config )
169
203
if err != nil {
170
204
_ = conn .Close (websocket .StatusAbnormalClosure , err .Error ())
171
205
return
172
206
}
207
+
173
208
closer , err := peerbroker .ProxyDial (proto .NewDRPCPeerBrokerClient (provisionersdk .Conn (session )), peerbroker.ProxyOptions {
174
209
ChannelID : workspaceAgent .ID .String (),
175
210
Pubsub : api .Pubsub ,
@@ -180,6 +215,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
180
215
return
181
216
}
182
217
defer closer .Close ()
218
+
183
219
firstConnectedAt := workspaceAgent .FirstConnectedAt
184
220
if ! firstConnectedAt .Valid {
185
221
firstConnectedAt = sql.NullTime {
@@ -204,23 +240,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
204
240
}
205
241
return nil
206
242
}
207
- build , err := api .Database .GetWorkspaceBuildByJobID (r .Context (), resource .JobID )
208
- if err != nil {
209
- _ = conn .Close (websocket .StatusAbnormalClosure , err .Error ())
210
- return
211
- }
212
- // Ensure the resource is still valid!
213
- // We only accept agents for resources on the latest build.
214
- ensureLatestBuild := func () error {
215
- latestBuild , err := api .Database .GetLatestWorkspaceBuildByWorkspaceID (r .Context (), build .WorkspaceID )
216
- if err != nil {
217
- return err
218
- }
219
- if build .ID != latestBuild .ID {
220
- return xerrors .New ("build is outdated" )
221
- }
222
- return nil
223
- }
224
243
225
244
defer func () {
226
245
disconnectedAt = sql.NullTime {
@@ -230,11 +249,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
230
249
_ = updateConnectionTimes ()
231
250
}()
232
251
233
- err = ensureLatestBuild ()
234
- if err != nil {
235
- _ = conn .Close (websocket .StatusGoingAway , "" )
236
- return
237
- }
238
252
err = updateConnectionTimes ()
239
253
if err != nil {
240
254
_ = conn .Close (websocket .StatusAbnormalClosure , err .Error ())
0 commit comments