@@ -138,7 +138,7 @@ func (s *ServerTailnet) updateNode(id uuid.UUID, node *tailnet.Node) {
138
138
}
139
139
}
140
140
141
- func (s * ServerTailnet ) ReverseProxy (targetURL , dashboardURL * url.URL , agentID uuid.UUID ) (* httputil.ReverseProxy , func (), error ) {
141
+ func (s * ServerTailnet ) ReverseProxy (targetURL , dashboardURL * url.URL , agentID uuid.UUID ) (_ * httputil.ReverseProxy , release func (), _ error ) {
142
142
proxy := httputil .NewSingleHostReverseProxy (targetURL )
143
143
proxy .ErrorHandler = func (w http.ResponseWriter , r * http.Request , err error ) {
144
144
site .RenderStaticErrorPage (w , r , site.ErrorPageData {
@@ -257,7 +257,7 @@ func (*ServerTailnet) nodeIsLegacy(node *tailnet.Node) bool {
257
257
return node .Addresses [0 ].Addr () == codersdk .WorkspaceAgentIP
258
258
}
259
259
260
- func (s * ServerTailnet ) AgentConn (ctx context.Context , agentID uuid.UUID ) (* codersdk.WorkspaceAgentConn , func (), error ) {
260
+ func (s * ServerTailnet ) AgentConn (ctx context.Context , agentID uuid.UUID ) (_ * codersdk.WorkspaceAgentConn , release func (), _ error ) {
261
261
node , err := s .awaitNodeExists (ctx , agentID , 5 * time .Second )
262
262
if err != nil {
263
263
return nil , nil , xerrors .Errorf ("get agent node: %w" , err )
@@ -297,8 +297,6 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
297
297
if err != nil {
298
298
return nil , xerrors .Errorf ("acquire agent conn: %w" , err )
299
299
}
300
- defer release ()
301
- defer conn .Close ()
302
300
303
301
node , err := s .getNode (agentID )
304
302
if err != nil {
@@ -309,13 +307,29 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
309
307
port , _ := strconv .ParseUint (rawPort , 10 , 16 )
310
308
ipp := netip .AddrPortFrom (node .Addresses [0 ].Addr (), uint16 (port ))
311
309
310
+ var nc net.Conn
312
311
if network == "tcp" {
313
- return conn .DialContextTCP (ctx , ipp )
312
+ nc , err = conn .DialContextTCP (ctx , ipp )
314
313
} else if network == "udp" {
315
- return conn .DialContextUDP (ctx , ipp )
314
+ nc , err = conn .DialContextUDP (ctx , ipp )
316
315
} else {
317
316
return nil , xerrors .Errorf ("unknown network %q" , network )
318
317
}
318
+
319
+ return & netConnCloser {Conn : nc , close : func () {
320
+ release ()
321
+ conn .Close ()
322
+ }}, err
323
+ }
324
+
325
+ type netConnCloser struct {
326
+ net.Conn
327
+ close func ()
328
+ }
329
+
330
+ func (c * netConnCloser ) Close () error {
331
+ c .close ()
332
+ return c .Conn .Close ()
319
333
}
320
334
321
335
func (s * ServerTailnet ) Close () error {
0 commit comments