@@ -3,7 +3,6 @@ package cli
3
3
import (
4
4
"context"
5
5
"fmt"
6
- "net"
7
6
"os"
8
7
"os/signal"
9
8
"syscall"
@@ -20,24 +19,6 @@ import (
20
19
"github.com/coder/serpent"
21
20
)
22
21
23
- // cliDialer adapts workspacesdk.AgentConn to portforward.Dialer
24
- type cliDialer struct {
25
- conn * workspacesdk.AgentConn
26
- }
27
-
28
- func (d * cliDialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
29
- return d .conn .DialContext (ctx , network , address )
30
- }
31
-
32
- // cliListener adapts serpent.Invocation.Net to portforward.Listener
33
- type cliListener struct {
34
- inv * serpent.Invocation
35
- }
36
-
37
- func (l * cliListener ) Listen (network , address string ) (net.Listener , error ) {
38
- return l .inv .Net .Listen (network , address )
39
- }
40
-
41
22
func (r * RootCmd ) portForward () * serpent.Command {
42
23
var (
43
24
tcpForwards []string // <port>:<port>
@@ -132,24 +113,18 @@ func (r *RootCmd) portForward() *serpent.Command {
132
113
}
133
114
defer conn .Close ()
134
115
135
- // Create port forwarding options
116
+ // Create port forwarding manager
136
117
pfOpts := portforward.Options {
137
118
Logger : logger ,
138
- Dialer : & cliDialer { conn : conn } ,
139
- Listener : & cliListener { inv : inv } ,
119
+ Dialer : conn ,
120
+ Listener : inv . Net ,
140
121
}
141
-
142
- // Start all forwarders.
143
- var (
144
- forwarders = make ([]portforward.Forwarder , 0 , len (specs ))
145
- closeAllForwarders = func () {
146
- logger .Debug (ctx , "closing all forwarders" )
147
- for _ , f := range forwarders {
148
- _ = f .Stop ()
149
- }
122
+ manager := portforward .NewManager (pfOpts )
123
+ defer func () {
124
+ if stopErr := manager .Stop (); stopErr != nil {
125
+ logger .Error (ctx , "failed to stop port forwarding manager" , slog .Error (stopErr ))
150
126
}
151
- )
152
- defer closeAllForwarders ()
127
+ }()
153
128
154
129
// Create a signal handler for graceful shutdown
155
130
shutdownCh := make (chan struct {})
@@ -165,30 +140,32 @@ func (r *RootCmd) portForward() *serpent.Command {
165
140
// first, opportunistically try to listen on IPv6
166
141
spec6 := spec
167
142
spec6 .ListenHost = portforward .IPv6Loopback
168
- f6 := portforward .NewForwarder (spec6 , pfOpts )
169
- err6 := f6 .Start (ctx )
170
- if err6 != nil {
171
- logger .Info (ctx , "failed to opportunistically listen on IPv6" , slog .F ("spec" , spec ), slog .Error (err6 ))
143
+ _ , err := manager .Add (spec6 )
144
+ if err != nil {
145
+ logger .Info (ctx , "failed to opportunistically add IPv6 forwarder" , slog .F ("spec" , spec ), slog .Error (err ))
172
146
} else {
173
- forwarders = append (forwarders , f6 )
174
147
_ , _ = fmt .Fprintf (inv .Stderr , "Forwarding '%s://[%s]:%d' locally to '%s://127.0.0.1:%d' in the workspace\n " ,
175
148
spec6 .Network , spec6 .ListenHost , spec6 .ListenPort , spec6 .Network , spec6 .DialPort )
176
149
}
177
150
spec .ListenHost = portforward .IPv4Loopback
178
151
}
179
152
180
- f := portforward .NewForwarder (spec , pfOpts )
181
- err := f .Start (ctx )
153
+ _ , err := manager .Add (spec )
182
154
if err != nil {
183
- logger .Error (ctx , "failed to listen " , slog .F ("spec" , spec ), slog .Error (err ))
155
+ logger .Error (ctx , "failed to add forwarder " , slog .F ("spec" , spec ), slog .Error (err ))
184
156
return err
185
157
}
186
158
187
- forwarders = append (forwarders , f )
188
159
_ , _ = fmt .Fprintf (inv .Stderr , "Forwarding '%s://%s:%d' locally to '%s://127.0.0.1:%d' in the workspace\n " ,
189
160
spec .Network , spec .ListenHost , spec .ListenPort , spec .Network , spec .DialPort )
190
161
}
191
162
163
+ // Start all forwarders at once
164
+ err = manager .Start (ctx )
165
+ if err != nil {
166
+ return xerrors .Errorf ("start port forwarding: %w" , err )
167
+ }
168
+
192
169
conn .AwaitReachable (ctx )
193
170
logger .Debug (ctx , "ready to accept connections to forward" )
194
171
_ , _ = fmt .Fprintln (inv .Stderr , "Ready!" )
0 commit comments