1
1
package coderd
2
2
3
3
import (
4
+ "context"
4
5
"database/sql"
5
6
"encoding/json"
6
7
"errors"
7
8
"fmt"
8
9
"io"
10
+ "net"
9
11
"net/http"
10
12
"strings"
11
13
@@ -94,12 +96,14 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
94
96
// @Success 101
95
97
// @Router /organizations/{organization}/provisionerdaemons/serve [get]
96
98
func (api * API ) provisionerDaemonServe (rw http.ResponseWriter , r * http.Request ) {
99
+ ctx := r .Context ()
100
+
97
101
tags := map [string ]string {}
98
102
if r .URL .Query ().Has ("tag" ) {
99
103
for _ , tag := range r .URL .Query ()["tag" ] {
100
104
parts := strings .SplitN (tag , "=" , 2 )
101
105
if len (parts ) < 2 {
102
- httpapi .Write (r . Context () , rw , http .StatusBadRequest , codersdk.Response {
106
+ httpapi .Write (ctx , rw , http .StatusBadRequest , codersdk.Response {
103
107
Message : fmt .Sprintf ("Invalid format for tag %q. Key and value must be separated with =." , tag ),
104
108
})
105
109
return
@@ -108,7 +112,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
108
112
}
109
113
}
110
114
if ! r .URL .Query ().Has ("provisioner" ) {
111
- httpapi .Write (r . Context () , rw , http .StatusBadRequest , codersdk.Response {
115
+ httpapi .Write (ctx , rw , http .StatusBadRequest , codersdk.Response {
112
116
Message : "The provisioner query parameter must be specified." ,
113
117
})
114
118
return
@@ -122,7 +126,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
122
126
case string (codersdk .ProvisionerTypeTerraform ):
123
127
provisionersMap [codersdk .ProvisionerTypeTerraform ] = struct {}{}
124
128
default :
125
- httpapi .Write (r . Context () , rw , http .StatusBadRequest , codersdk.Response {
129
+ httpapi .Write (ctx , rw , http .StatusBadRequest , codersdk.Response {
126
130
Message : fmt .Sprintf ("Unknown provisioner type %q" , provisioner ),
127
131
})
128
132
return
@@ -137,7 +141,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
137
141
138
142
if tags [provisionerdserver .TagScope ] == provisionerdserver .ScopeOrganization {
139
143
if ! api .AGPL .Authorize (r , rbac .ActionCreate , rbac .ResourceProvisionerDaemon ) {
140
- httpapi .Write (r . Context () , rw , http .StatusForbidden , codersdk.Response {
144
+ httpapi .Write (ctx , rw , http .StatusForbidden , codersdk.Response {
141
145
Message : "You aren't allowed to create provisioner daemons for the organization." ,
142
146
})
143
147
return
@@ -155,15 +159,15 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
155
159
}
156
160
157
161
name := namesgenerator .GetRandomName (1 )
158
- daemon , err := api .Database .InsertProvisionerDaemon (r . Context () , database.InsertProvisionerDaemonParams {
162
+ daemon , err := api .Database .InsertProvisionerDaemon (ctx , database.InsertProvisionerDaemonParams {
159
163
ID : uuid .New (),
160
164
CreatedAt : database .Now (),
161
165
Name : name ,
162
166
Provisioners : provisioners ,
163
167
Tags : tags ,
164
168
})
165
169
if err != nil {
166
- httpapi .Write (r . Context () , rw , http .StatusInternalServerError , codersdk.Response {
170
+ httpapi .Write (ctx , rw , http .StatusInternalServerError , codersdk.Response {
167
171
Message : "Internal error writing provisioner daemon." ,
168
172
Detail : err .Error (),
169
173
})
@@ -172,7 +176,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
172
176
173
177
rawTags , err := json .Marshal (daemon .Tags )
174
178
if err != nil {
175
- httpapi .Write (r . Context () , rw , http .StatusInternalServerError , codersdk.Response {
179
+ httpapi .Write (ctx , rw , http .StatusInternalServerError , codersdk.Response {
176
180
Message : "Internal error marshaling daemon tags." ,
177
181
Detail : err .Error (),
178
182
})
@@ -189,7 +193,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
189
193
CompressionMode : websocket .CompressionDisabled ,
190
194
})
191
195
if err != nil {
192
- httpapi .Write (r . Context () , rw , http .StatusBadRequest , codersdk.Response {
196
+ httpapi .Write (ctx , rw , http .StatusBadRequest , codersdk.Response {
193
197
Message : "Internal error accepting websocket connection." ,
194
198
Detail : err .Error (),
195
199
})
@@ -203,7 +207,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
203
207
// the same connection.
204
208
config := yamux .DefaultConfig ()
205
209
config .LogOutput = io .Discard
206
- session , err := yamux .Server (websocket .NetConn (r .Context (), conn , websocket .MessageBinary ), config )
210
+ ctx , wsNetConn := websocketNetConn (ctx , conn , websocket .MessageBinary )
211
+ defer wsNetConn .Close ()
212
+ session , err := yamux .Server (wsNetConn , config )
207
213
if err != nil {
208
214
_ = conn .Close (websocket .StatusInternalError , httpapi .WebsocketCloseSprintf ("multiplex server: %s" , err ))
209
215
return
@@ -229,12 +235,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
229
235
if xerrors .Is (err , io .EOF ) {
230
236
return
231
237
}
232
- api .Logger .Debug (r . Context () , "drpc server error" , slog .Error (err ))
238
+ api .Logger .Debug (ctx , "drpc server error" , slog .Error (err ))
233
239
},
234
240
})
235
- err = server .Serve (r . Context () , session )
241
+ err = server .Serve (ctx , session )
236
242
if err != nil && ! xerrors .Is (err , io .EOF ) {
237
- api .Logger .Debug (r . Context () , "provisioner daemon disconnected" , slog .Error (err ))
243
+ api .Logger .Debug (ctx , "provisioner daemon disconnected" , slog .Error (err ))
238
244
_ = conn .Close (websocket .StatusInternalError , httpapi .WebsocketCloseSprintf ("serve: %s" , err ))
239
245
return
240
246
}
@@ -254,3 +260,44 @@ func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.Provis
254
260
}
255
261
return result
256
262
}
263
+
264
+ // wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
265
+ // is called if a read or write error is encountered.
266
+ type wsNetConn struct {
267
+ cancel context.CancelFunc
268
+ net.Conn
269
+ }
270
+
271
+ func (c * wsNetConn ) Read (b []byte ) (n int , err error ) {
272
+ n , err = c .Conn .Read (b )
273
+ if err != nil {
274
+ c .cancel ()
275
+ }
276
+ return n , err
277
+ }
278
+
279
+ func (c * wsNetConn ) Write (b []byte ) (n int , err error ) {
280
+ n , err = c .Conn .Write (b )
281
+ if err != nil {
282
+ c .cancel ()
283
+ }
284
+ return n , err
285
+ }
286
+
287
+ func (c * wsNetConn ) Close () error {
288
+ defer c .cancel ()
289
+ return c .Conn .Close ()
290
+ }
291
+
292
+ // websocketNetConn wraps websocket.NetConn and returns a context that
293
+ // is tied to the parent context and the lifetime of the conn. Any error
294
+ // during read or write will cancel the context, but not close the
295
+ // conn. Close should be called to release context resources.
296
+ func websocketNetConn (ctx context.Context , conn * websocket.Conn , msgType websocket.MessageType ) (context.Context , net.Conn ) {
297
+ ctx , cancel := context .WithCancel (ctx )
298
+ nc := websocket .NetConn (ctx , conn , msgType )
299
+ return ctx , & wsNetConn {
300
+ cancel : cancel ,
301
+ Conn : nc ,
302
+ }
303
+ }
0 commit comments