Skip to content

Commit 86d37e2

Browse files
committed
review fixes
1 parent e204a7b commit 86d37e2

File tree

2 files changed

+38
-61
lines changed

2 files changed

+38
-61
lines changed

cli/portforward.go

Lines changed: 19 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package cli
33
import (
44
"context"
55
"fmt"
6-
"net"
76
"os"
87
"os/signal"
98
"syscall"
@@ -20,24 +19,6 @@ import (
2019
"github.com/coder/serpent"
2120
)
2221

23-
// cliDialer adapts workspacesdk.AgentConn to portforward.Dialer
24-
type cliDialer struct {
25-
conn *workspacesdk.AgentConn
26-
}
27-
28-
func (d *cliDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
29-
return d.conn.DialContext(ctx, network, address)
30-
}
31-
32-
// cliListener adapts serpent.Invocation.Net to portforward.Listener
33-
type cliListener struct {
34-
inv *serpent.Invocation
35-
}
36-
37-
func (l *cliListener) Listen(network, address string) (net.Listener, error) {
38-
return l.inv.Net.Listen(network, address)
39-
}
40-
4122
func (r *RootCmd) portForward() *serpent.Command {
4223
var (
4324
tcpForwards []string // <port>:<port>
@@ -132,24 +113,18 @@ func (r *RootCmd) portForward() *serpent.Command {
132113
}
133114
defer conn.Close()
134115

135-
// Create port forwarding options
116+
// Create port forwarding manager
136117
pfOpts := portforward.Options{
137118
Logger: logger,
138-
Dialer: &cliDialer{conn: conn},
139-
Listener: &cliListener{inv: inv},
119+
Dialer: conn,
120+
Listener: inv.Net,
140121
}
141-
142-
// Start all forwarders.
143-
var (
144-
forwarders = make([]portforward.Forwarder, 0, len(specs))
145-
closeAllForwarders = func() {
146-
logger.Debug(ctx, "closing all forwarders")
147-
for _, f := range forwarders {
148-
_ = f.Stop()
149-
}
122+
manager := portforward.NewManager(pfOpts)
123+
defer func() {
124+
if stopErr := manager.Stop(); stopErr != nil {
125+
logger.Error(ctx, "failed to stop port forwarding manager", slog.Error(stopErr))
150126
}
151-
)
152-
defer closeAllForwarders()
127+
}()
153128

154129
// Create a signal handler for graceful shutdown
155130
shutdownCh := make(chan struct{})
@@ -165,30 +140,32 @@ func (r *RootCmd) portForward() *serpent.Command {
165140
// first, opportunistically try to listen on IPv6
166141
spec6 := spec
167142
spec6.ListenHost = portforward.IPv6Loopback
168-
f6 := portforward.NewForwarder(spec6, pfOpts)
169-
err6 := f6.Start(ctx)
170-
if err6 != nil {
171-
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
143+
_, err := manager.Add(spec6)
144+
if err != nil {
145+
logger.Info(ctx, "failed to opportunistically add IPv6 forwarder", slog.F("spec", spec), slog.Error(err))
172146
} else {
173-
forwarders = append(forwarders, f6)
174147
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://[%s]:%d' locally to '%s://127.0.0.1:%d' in the workspace\n",
175148
spec6.Network, spec6.ListenHost, spec6.ListenPort, spec6.Network, spec6.DialPort)
176149
}
177150
spec.ListenHost = portforward.IPv4Loopback
178151
}
179152

180-
f := portforward.NewForwarder(spec, pfOpts)
181-
err := f.Start(ctx)
153+
_, err := manager.Add(spec)
182154
if err != nil {
183-
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
155+
logger.Error(ctx, "failed to add forwarder", slog.F("spec", spec), slog.Error(err))
184156
return err
185157
}
186158

187-
forwarders = append(forwarders, f)
188159
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s:%d' locally to '%s://127.0.0.1:%d' in the workspace\n",
189160
spec.Network, spec.ListenHost, spec.ListenPort, spec.Network, spec.DialPort)
190161
}
191162

163+
// Start all forwarders at once
164+
err = manager.Start(ctx)
165+
if err != nil {
166+
return xerrors.Errorf("start port forwarding: %w", err)
167+
}
168+
192169
conn.AwaitReachable(ctx)
193170
logger.Debug(ctx, "ready to accept connections to forward")
194171
_, _ = fmt.Fprintln(inv.Stderr, "Ready!")

portforward/forwarder.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ type Forwarder interface {
3737

3838
// Manager manages multiple port forwards.
3939
type Manager interface {
40-
// AddForward adds a new port forward.
41-
AddForward(spec Spec) (Forwarder, error)
42-
// RemoveForward removes an existing port forward.
43-
RemoveForward(spec Spec) error
44-
// ListForwards returns all active port forwards.
45-
ListForwards() []Forwarder
40+
// Add adds a new port forward.
41+
Add(spec Spec) (Forwarder, error)
42+
// Remove removes an existing port forward.
43+
Remove(spec Spec) error
44+
// List returns all active port forwards.
45+
List() []Forwarder
4646
// Start starts all port forwards.
4747
Start(ctx context.Context) error
4848
// Stop stops all port forwards.
@@ -66,8 +66,8 @@ type Options struct {
6666
Listener Listener
6767
}
6868

69-
// localForwarder implements a single port forward from local to remote.
70-
type localForwarder struct {
69+
// LocalForwarder implements a single port forward from local to remote.
70+
type LocalForwarder struct {
7171
spec Spec
7272
opts Options
7373
listener net.Listener
@@ -76,15 +76,15 @@ type localForwarder struct {
7676
wg sync.WaitGroup
7777
}
7878

79-
// NewForwarder creates a new port forwarder.
80-
func NewForwarder(spec Spec, opts Options) Forwarder {
81-
return &localForwarder{
79+
// NewLocal creates a new local port forwarder.
80+
func NewLocal(spec Spec, opts Options) *LocalForwarder {
81+
return &LocalForwarder{
8282
spec: spec,
8383
opts: opts,
8484
}
8585
}
8686

87-
func (f *localForwarder) Start(ctx context.Context) error {
87+
func (f *LocalForwarder) Start(ctx context.Context) error {
8888
if f.active.Load() {
8989
return xerrors.New("forwarder is already active")
9090
}
@@ -157,7 +157,7 @@ func (f *localForwarder) Start(ctx context.Context) error {
157157
return nil
158158
}
159159

160-
func (f *localForwarder) Stop() error {
160+
func (f *LocalForwarder) Stop() error {
161161
if !f.active.Load() {
162162
return nil
163163
}
@@ -172,11 +172,11 @@ func (f *localForwarder) Stop() error {
172172
return nil
173173
}
174174

175-
func (f *localForwarder) IsActive() bool {
175+
func (f *LocalForwarder) IsActive() bool {
176176
return f.active.Load()
177177
}
178178

179-
func (f *localForwarder) Spec() Spec {
179+
func (f *LocalForwarder) Spec() Spec {
180180
return f.spec
181181
}
182182

@@ -195,7 +195,7 @@ func NewManager(opts Options) Manager {
195195
}
196196
}
197197

198-
func (m *manager) AddForward(spec Spec) (Forwarder, error) {
198+
func (m *manager) Add(spec Spec) (Forwarder, error) {
199199
m.mu.Lock()
200200
defer m.mu.Unlock()
201201

@@ -204,12 +204,12 @@ func (m *manager) AddForward(spec Spec) (Forwarder, error) {
204204
return nil, xerrors.Errorf("forwarder already exists for %s", key)
205205
}
206206

207-
forwarder := NewForwarder(spec, m.opts)
207+
forwarder := NewLocal(spec, m.opts)
208208
m.forwarders[key] = forwarder
209209
return forwarder, nil
210210
}
211211

212-
func (m *manager) RemoveForward(spec Spec) error {
212+
func (m *manager) Remove(spec Spec) error {
213213
m.mu.Lock()
214214
defer m.mu.Unlock()
215215

@@ -224,7 +224,7 @@ func (m *manager) RemoveForward(spec Spec) error {
224224
return err
225225
}
226226

227-
func (m *manager) ListForwards() []Forwarder {
227+
func (m *manager) List() []Forwarder {
228228
m.mu.RLock()
229229
defer m.mu.RUnlock()
230230

0 commit comments

Comments
 (0)