Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit f8b26db

Browse files
committed
import updates to wsnet pkg
1 parent c37287c commit f8b26db

File tree

11 files changed

+219
-161
lines changed

11 files changed

+219
-161
lines changed

wsnet/cache.go

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ func (d *DialerCache) init() {
5555
// evict removes lost/broken/expired connections from the cache.
5656
func (d *DialerCache) evict() {
5757
var wg sync.WaitGroup
58+
// This lock lasts for just the iteration of the for loop, the actual code
59+
// is in waitgroup'd goroutines so the read lock doesn't persist the whole
60+
// time, but it means we can't defer the unlock sadly.
5861
d.mut.RLock()
62+
5963
for key, dialer := range d.dialers {
6064
wg.Add(1)
6165
key := key
@@ -74,7 +78,11 @@ func (d *DialerCache) evict() {
7478
evict = true
7579
}
7680

77-
if dialer.activeConnections() == 0 && time.Since(d.atime[key]) >= d.ttl {
81+
d.mut.RLock()
82+
atime := d.atime[key]
83+
d.mut.RUnlock()
84+
85+
if dialer.activeConnections() == 0 && time.Since(atime) >= d.ttl {
7886
evict = true
7987
} else {
8088
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
@@ -92,12 +100,12 @@ func (d *DialerCache) evict() {
92100
_ = dialer.Close()
93101
// Ensure after Ping and potential delays that we're still testing against
94102
// the proper dialer.
103+
d.mut.Lock()
104+
defer d.mut.Unlock()
95105
if dialer != d.dialers[key] {
96106
return
97107
}
98108

99-
d.mut.Lock()
100-
defer d.mut.Unlock()
101109
delete(d.atime, key)
102110
delete(d.dialers, key)
103111
}()
@@ -109,7 +117,7 @@ func (d *DialerCache) evict() {
109117
// Dial returns a Dialer from the cache if one exists with the key provided,
110118
// or dials a new connection using the dialerFunc.
111119
// The bool returns whether the connection was found in the cache or not.
112-
func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*Dialer, error)) (*Dialer, bool, error) {
120+
func (d *DialerCache) Dial(_ context.Context, key string, dialerFunc func() (*Dialer, error)) (*Dialer, bool, error) {
113121
select {
114122
case <-d.closed:
115123
return nil, false, errors.New("cache closed")
@@ -136,9 +144,9 @@ func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*
136144
return nil, err
137145
}
138146
d.mut.Lock()
147+
defer d.mut.Unlock()
139148
d.dialers[key] = dialer
140149
d.atime[key] = time.Now()
141-
d.mut.Unlock()
142150

143151
return dialer, nil
144152
})
@@ -159,6 +167,10 @@ func (d *DialerCache) Close() error {
159167
d.mut.Lock()
160168
defer d.mut.Unlock()
161169

170+
if d.isClosed() {
171+
return nil
172+
}
173+
162174
for _, dialer := range d.dialers {
163175
err := dialer.Close()
164176
if err != nil {
@@ -168,3 +180,12 @@ func (d *DialerCache) Close() error {
168180
close(d.closed)
169181
return nil
170182
}
183+
184+
func (d *DialerCache) isClosed() bool {
185+
select {
186+
case <-d.closed:
187+
return true
188+
default:
189+
return false
190+
}
191+
}

wsnet/cache_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@ import (
55
"testing"
66
"time"
77

8-
"cdr.dev/slog/sloggers/slogtest"
98
"github.com/stretchr/testify/assert"
109
"github.com/stretchr/testify/require"
10+
"go.uber.org/goleak"
11+
12+
"cdr.dev/slog/sloggers/slogtest"
1113
)
1214

1315
func TestCache(t *testing.T) {
16+
defer goleak.VerifyNone(t)
1417
dialFunc := func(connectAddr string) func() (*Dialer, error) {
1518
return func() (*Dialer, error) {
1619
return DialWebsocket(context.Background(), connectAddr, nil, nil)
@@ -24,6 +27,7 @@ func TestCache(t *testing.T) {
2427
defer l.Close()
2528

2629
cache := DialCache(time.Hour)
30+
defer cache.Close()
2731
c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
2832
require.NoError(t, err)
2933
require.Equal(t, cached, false)
@@ -40,7 +44,7 @@ func TestCache(t *testing.T) {
4044
defer l.Close()
4145

4246
cache := DialCache(time.Hour)
43-
47+
defer cache.Close()
4448
c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
4549
require.NoError(t, err)
4650
require.Equal(t, cached, false)
@@ -58,7 +62,7 @@ func TestCache(t *testing.T) {
5862
defer l.Close()
5963

6064
cache := DialCache(0)
61-
65+
defer cache.Close()
6266
c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr))
6367
require.NoError(t, err)
6468
require.Equal(t, cached, false)

wsnet/conn.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"github.com/pion/webrtc/v3"
1515
"nhooyr.io/websocket"
1616

17-
"cdr.dev/coder-cli/coder-sdk"
17+
"coder.com/m/product/coder/pkg/codersdk/legacy"
1818
)
1919

2020
const (
@@ -65,7 +65,7 @@ type turnProxyDialer struct {
6565
token string
6666
}
6767

68-
func (t *turnProxyDialer) Dial(network, addr string) (c net.Conn, err error) {
68+
func (t *turnProxyDialer) Dial(_, _ string) (c net.Conn, err error) {
6969
headers := http.Header{}
7070
headers.Set("Session-Token", t.token)
7171

@@ -89,7 +89,7 @@ func (t *turnProxyDialer) Dial(network, addr string) (c net.Conn, err error) {
8989
if err != nil {
9090
if resp != nil {
9191
defer resp.Body.Close()
92-
return nil, coder.NewHTTPError(resp)
92+
return nil, legacy.NewHTTPError(resp)
9393
}
9494
return nil, fmt.Errorf("dial: %w", err)
9595
}
@@ -187,14 +187,14 @@ func (c *dataChannelConn) RemoteAddr() net.Addr {
187187
return c.addr
188188
}
189189

190-
func (c *dataChannelConn) SetDeadline(t time.Time) error {
190+
func (c *dataChannelConn) SetDeadline(_ time.Time) error {
191191
return nil
192192
}
193193

194-
func (c *dataChannelConn) SetReadDeadline(t time.Time) error {
194+
func (c *dataChannelConn) SetReadDeadline(_ time.Time) error {
195195
return nil
196196
}
197197

198-
func (c *dataChannelConn) SetWriteDeadline(t time.Time) error {
198+
func (c *dataChannelConn) SetWriteDeadline(_ time.Time) error {
199199
return nil
200200
}

wsnet/dial.go

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ import (
1515
"github.com/pion/webrtc/v3"
1616
"golang.org/x/net/proxy"
1717
"golang.org/x/xerrors"
18+
"k8s.io/utils/pointer"
1819
"nhooyr.io/websocket"
1920

2021
"cdr.dev/slog"
21-
22-
"cdr.dev/coder-cli/coder-sdk"
22+
"coder.com/m/product/coder/pkg/codersdk/legacy"
2323
)
2424

2525
// DialOptions are configurable options for a wsnet connection.
@@ -35,10 +35,11 @@ type DialOptions struct {
3535
// TURNProxyAuthToken is used to authenticate a TURN proxy request.
3636
TURNProxyAuthToken string
3737

38-
// TURNProxyURL is the URL to proxy all TURN data through.
39-
// This URL is sent to the listener during handshake so both
40-
// ends connect to the same TURN endpoint.
41-
TURNProxyURL *url.URL
38+
// TURNRemoteProxyURL is the URL to proxy listener TURN data through.
39+
TURNRemoteProxyURL *url.URL
40+
41+
// TURNLocalProxyURL is the URL to proxy client TURN data through.
42+
TURNLocalProxyURL *url.URL
4243
}
4344

4445
// DialWebsocket dials the broker with a WebSocket and negotiates a connection.
@@ -60,7 +61,7 @@ func DialWebsocket(ctx context.Context, broker string, netOpts *DialOptions, wsO
6061
defer func() {
6162
_ = resp.Body.Close()
6263
}()
63-
return nil, coder.NewHTTPError(resp)
64+
return nil, legacy.NewHTTPError(resp)
6465
}
6566
return nil, fmt.Errorf("dial websocket: %w", err)
6667
}
@@ -91,9 +92,9 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er
9192
}
9293

9394
var turnProxy proxy.Dialer
94-
if options.TURNProxyURL != nil {
95+
if options.TURNLocalProxyURL != nil {
9596
turnProxy = &turnProxyDialer{
96-
baseURL: options.TURNProxyURL,
97+
baseURL: options.TURNLocalProxyURL,
9798
token: options.TURNProxyAuthToken,
9899
}
99100
}
@@ -107,7 +108,7 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er
107108
defer func() {
108109
if err != nil {
109110
// Wrap our error with some extra details.
110-
err = errWrap{
111+
err = wrapError{
111112
err: err,
112113
iceServers: rtc.GetConfiguration().ICEServers,
113114
rtc: rtc.ConnectionState(),
@@ -128,8 +129,8 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er
128129

129130
log.Debug(ctx, "creating control channel", slog.F("proto", controlChannel))
130131
ctrl, err := rtc.CreateDataChannel(controlChannel, &webrtc.DataChannelInit{
131-
Protocol: stringPtr(controlChannel),
132-
Ordered: boolPtr(true),
132+
Protocol: pointer.String(controlChannel),
133+
Ordered: pointer.Bool(true),
133134
})
134135
if err != nil {
135136
return nil, fmt.Errorf("create control channel: %w", err)
@@ -146,8 +147,8 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er
146147
}
147148

148149
var turnProxyURL string
149-
if options.TURNProxyURL != nil {
150-
turnProxyURL = options.TURNProxyURL.String()
150+
if options.TURNRemoteProxyURL != nil {
151+
turnProxyURL = options.TURNRemoteProxyURL.String()
151152
}
152153

153154
bmsg := BrokerMessage{
@@ -177,7 +178,9 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er
177178

178179
err = dialer.negotiate(ctx)
179180
if err != nil {
180-
return nil, xerrors.Errorf("negotiate rtc connection: %w", err)
181+
// Return the dialer since we have tests that verify things are closed
182+
// if negotiation fails.
183+
return dialer, xerrors.Errorf("negotiate rtc connection: %w", err)
181184
}
182185

183186
return dialer, nil
@@ -290,11 +293,22 @@ func (d *Dialer) negotiate(ctx context.Context) (err error) {
290293
return fmt.Errorf("unhandled message: %+v", msg)
291294
}
292295

293-
return <-errCh
296+
err = <-errCh
297+
if err != nil {
298+
return err
299+
}
300+
301+
proto, err := iceProto(d.rtc)
302+
if err != nil {
303+
return xerrors.Errorf("determine ICE connection protocol: %w", err)
304+
}
305+
d.log.Debug(ctx, "connected", slog.F("ice_proto", proto))
306+
307+
return nil
294308
}
295309

296-
// ActiveConnections returns the amount of active connections.
297-
// DialContext opens a connection, and close will end it.
310+
// ActiveConnections returns the amount of active connections. DialContext
311+
// opens a connection, and close will end it.
298312
func (d *Dialer) activeConnections() int {
299313
stats, ok := d.rtc.GetStats().GetConnectionStats(d.rtc)
300314
if !ok {
@@ -304,6 +318,11 @@ func (d *Dialer) activeConnections() int {
304318
return int(stats.DataChannelsRequested-stats.DataChannelsClosed) - 1
305319
}
306320

321+
// Candidates returns the candidate pair that was chosen for the connection.
322+
func (d *Dialer) Candidates() (*webrtc.ICECandidatePair, error) {
323+
return d.rtc.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
324+
}
325+
307326
// Close closes the RTC connection.
308327
// All data channels dialed will be closed.
309328
func (d *Dialer) Close() error {
@@ -367,7 +386,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
367386

368387
d.log.Debug(ctx, "opening data channel")
369388
dc, err := d.rtc.CreateDataChannel("proxy", &webrtc.DataChannelInit{
370-
Ordered: boolPtr(network != "udp"),
389+
Ordered: pointer.Bool(network != "udp"),
371390
Protocol: &proto,
372391
})
373392
if err != nil {

0 commit comments

Comments
 (0)