Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit c86a67f

Browse files
committed
Centralize webrtc dial logic into xwebrtc
1 parent 43edc2f commit c86a67f

File tree

10 files changed

+403
-268
lines changed

10 files changed

+403
-268
lines changed

agent/server.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ func (s *Server) Run(ctx context.Context) error {
5959
}),
6060
).Run(
6161
func() error {
62-
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
63-
defer cancelFunc()
6462
s.log.Info(ctx, "connecting to coder", slog.F("url", s.listenURL.String()))
6563
conn, resp, err := websocket.Dial(ctx, s.listenURL.String(), nil)
6664
if err != nil && resp == nil {
@@ -71,7 +69,8 @@ func (s *Server) Run(ctx context.Context) error {
7169
Response: resp,
7270
}
7371
}
74-
nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
72+
73+
nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
7574
session, err := yamux.Server(nc, nil)
7675
if err != nil {
7776
return fmt.Errorf("open: %w", err)

agent/stream.go

+10-5
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ import (
77
"io"
88
"net"
99

10+
"cdr.dev/coder-cli/xwebrtc"
11+
1012
"cdr.dev/slog"
1113
"github.com/hashicorp/yamux"
1214
"github.com/pion/webrtc/v3"
1315
"golang.org/x/xerrors"
1416

15-
"cdr.dev/coder-cli/internal/x/xwebrtc"
1617
"cdr.dev/coder-cli/pkg/proto"
1718
)
1819

@@ -128,6 +129,10 @@ func (s *stream) processMessage(msg proto.Message) {
128129
}
129130

130131
func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
132+
if channel.Protocol() == "control" {
133+
return
134+
}
135+
131136
if channel.Protocol() == "ping" {
132137
channel.OnOpen(func() {
133138
rw, err := channel.Detach()
@@ -149,7 +154,7 @@ func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
149154
return
150155
}
151156

152-
prto, port, err := xwebrtc.ParseProxyDataChannel(channel)
157+
prto, addr, err := xwebrtc.ParseProxyDataChannel(channel)
153158
if err != nil {
154159
s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err))
155160
return
@@ -159,14 +164,14 @@ func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
159164
return
160165
}
161166

162-
conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port))
167+
conn, err := net.Dial(prto, addr)
163168
if err != nil {
164-
s.fatal(fmt.Errorf("failed to dial client port: %d", port))
169+
s.fatal(fmt.Errorf("failed to dial client addr: %s", addr))
165170
return
166171
}
167172

168173
channel.OnOpen(func() {
169-
s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port))
174+
s.logger.Debug(context.Background(), "proxying data channel", slog.F("addr", addr))
170175
rw, err := channel.Detach()
171176
if err != nil {
172177
_ = channel.Close()

go.mod

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ require (
1818
github.com/rjeczalik/notify v0.9.2
1919
github.com/spf13/cobra v1.1.3
2020
go.coder.com/retry v1.2.0
21+
golang.org/x/net v0.0.0-20210420210106-798c2154c571
2122
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9
2223
golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe
2324
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1

internal/cmd/tunnel.go

+39-184
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,21 @@ package cmd
22

33
import (
44
"context"
5-
"encoding/json"
65
"fmt"
76
"io"
87
"net"
98
"net/url"
109
"os"
1110
"strconv"
12-
"time"
1311

1412
"cdr.dev/slog"
1513
"cdr.dev/slog/sloggers/sloghuman"
16-
"github.com/pion/webrtc/v3"
1714
"github.com/spf13/cobra"
1815
"golang.org/x/xerrors"
19-
"nhooyr.io/websocket"
2016

2117
"cdr.dev/coder-cli/coder-sdk"
2218
"cdr.dev/coder-cli/internal/x/xcobra"
23-
"cdr.dev/coder-cli/internal/x/xwebrtc"
24-
"cdr.dev/coder-cli/pkg/proto"
19+
"cdr.dev/coder-cli/xwebrtc"
2520
)
2621

2722
func tunnelCmd() *cobra.Command {
@@ -41,26 +36,26 @@ coder tunnel my-dev 3000 3000
4136

4237
remotePort, err := strconv.ParseUint(args[1], 10, 16)
4338
if err != nil {
44-
log.Fatal(ctx, "parse remote port", slog.Error(err))
39+
return xerrors.Errorf("parse remote port: %w", err)
4540
}
4641

4742
var localPort uint64
4843
if args[2] != "stdio" {
4944
localPort, err = strconv.ParseUint(args[2], 10, 16)
5045
if err != nil {
51-
log.Fatal(ctx, "parse local port", slog.Error(err))
46+
return xerrors.Errorf("parse local port: %w", err)
5247
}
5348
}
5449

5550
sdk, err := newClient(ctx)
5651
if err != nil {
57-
return err
52+
return xerrors.Errorf("getting coder client: %w", err)
5853
}
5954
baseURL := sdk.BaseURL()
6055

6156
envs, err := getEnvs(ctx, sdk, coder.Me)
6257
if err != nil {
63-
return err
58+
return xerrors.Errorf("get workspaces: %w", err)
6459
}
6560

6661
var envID string
@@ -74,20 +69,19 @@ coder tunnel my-dev 3000 3000
7469
return xerrors.Errorf("No workspace found by name '%s'", args[0])
7570
}
7671

77-
c := &client{
78-
id: envID,
79-
stdio: args[2] == "stdio",
80-
localPort: uint16(localPort),
81-
remotePort: uint16(remotePort),
82-
ctx: context.Background(),
83-
logger: log.Leveled(slog.LevelDebug),
84-
brokerAddr: baseURL,
85-
token: sdk.Token(),
72+
c := &tunnneler{
73+
log: log.Leveled(slog.LevelDebug),
74+
brokerAddr: &baseURL,
75+
token: sdk.Token(),
76+
workspaceID: envID,
77+
stdio: args[2] == "stdio",
78+
localPort: uint16(localPort),
79+
remotePort: uint16(remotePort),
8680
}
8781

88-
err = c.start()
82+
err = c.start(ctx)
8983
if err != nil {
90-
log.Fatal(ctx, err.Error())
84+
return xerrors.Errorf("running tunnel: %w", err)
9185
}
9286

9387
return nil
@@ -97,197 +91,58 @@ coder tunnel my-dev 3000 3000
9791
return cmd
9892
}
9993

100-
type client struct {
101-
ctx context.Context
102-
brokerAddr url.URL
103-
token string
104-
logger slog.Logger
105-
id string
106-
remotePort uint16
107-
localPort uint16
108-
stdio bool
94+
type tunnneler struct {
95+
log slog.Logger
96+
brokerAddr *url.URL
97+
token string
98+
workspaceID string
99+
remotePort uint16
100+
localPort uint16
101+
stdio bool
109102
}
110103

111-
func (c *client) start() error {
112-
url := fmt.Sprintf("%s%s%s%s%s", c.brokerAddr.String(), "/api/private/envagent/", c.id, "/connect?session_token=", c.token)
113-
turnScheme := "turns"
114-
if c.brokerAddr.Scheme == "http" {
115-
turnScheme = "turn"
116-
}
117-
tcpProxy := fmt.Sprintf("%s:%s:5349?transport=tcp", turnScheme, c.brokerAddr.Host)
118-
c.logger.Info(c.ctx, "connecting to broker", slog.F("url", url), slog.F("tcp-proxy", tcpProxy))
119-
conn, resp, err := websocket.Dial(c.ctx, url, nil)
120-
if err != nil && resp == nil {
121-
return fmt.Errorf("dial: %w", err)
122-
}
123-
if err != nil && resp != nil {
124-
return &coder.HTTPError{
125-
Response: resp,
126-
}
127-
}
128-
nconn := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
129-
130-
// Only enabled under a private feature flag for now,
131-
// so insecure connections are entirely fine to allow.
132-
servers := []webrtc.ICEServer{{
133-
URLs: []string{tcpProxy},
134-
Username: "insecure",
135-
Credential: "pass",
136-
CredentialType: webrtc.ICECredentialTypePassword,
137-
}}
138-
rtc, err := xwebrtc.NewPeerConnection(servers)
139-
if err != nil {
140-
return fmt.Errorf("create connection: %w", err)
141-
}
142-
143-
rtc.OnNegotiationNeeded(func() {
144-
c.logger.Debug(context.Background(), "negotiation needed...")
145-
})
146-
147-
rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
148-
c.logger.Info(context.Background(), "connection state changed", slog.F("state", pcs))
149-
})
150-
151-
channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
152-
if err != nil {
153-
return fmt.Errorf("create data channel: %w", err)
154-
}
155-
flushCandidates := proto.ProxyICECandidates(rtc, nconn)
156-
157-
localDesc, err := rtc.CreateOffer(&webrtc.OfferOptions{})
158-
if err != nil {
159-
return fmt.Errorf("create offer: %w", err)
160-
}
161-
162-
err = rtc.SetLocalDescription(localDesc)
163-
if err != nil {
164-
return fmt.Errorf("set local desc: %w", err)
165-
}
166-
167-
c.logger.Debug(context.Background(), "writing offer")
168-
b, _ := json.Marshal(&proto.Message{
169-
Offer: &localDesc,
170-
Servers: servers,
171-
})
172-
_, err = nconn.Write(b)
104+
func (c *tunnneler) start(ctx context.Context) error {
105+
wd, err := xwebrtc.NewWorkspaceDialer(ctx, c.log, c.brokerAddr, c.token, c.workspaceID)
173106
if err != nil {
174-
return fmt.Errorf("write offer: %w", err)
175-
}
176-
flushCandidates()
177-
178-
go func() {
179-
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
180-
if err != nil {
181-
c.logger.Fatal(context.Background(), "waiting for data channel open", slog.Error(err))
182-
}
183-
_ = conn.Close(websocket.StatusNormalClosure, "rtc connected")
184-
}()
185-
186-
decoder := json.NewDecoder(nconn)
187-
for {
188-
var msg proto.Message
189-
err = decoder.Decode(&msg)
190-
if err == io.EOF {
191-
break
192-
}
193-
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
194-
break
195-
}
196-
if err != nil {
197-
return fmt.Errorf("read msg: %w", err)
198-
}
199-
if msg.Candidate != "" {
200-
c.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate))
201-
err = proto.AcceptICECandidate(rtc, &msg)
202-
if err != nil {
203-
return fmt.Errorf("accept ice: %w", err)
204-
}
205-
}
206-
if msg.Answer != nil {
207-
c.logger.Debug(context.Background(), "got answer", slog.F("answer", msg.Answer))
208-
err = rtc.SetRemoteDescription(*msg.Answer)
209-
if err != nil {
210-
return fmt.Errorf("set remote: %w", err)
211-
}
212-
}
107+
return xerrors.Errorf("creating workspace dialer: %w", wd)
213108
}
214-
215-
// Once we're open... let's test out the ping.
216-
pingProto := "ping"
217-
pingChannel, err := rtc.CreateDataChannel("pinger", &webrtc.DataChannelInit{
218-
Protocol: &pingProto,
219-
})
109+
nc, err := wd.DialContext(ctx, xwebrtc.NetworkTCP, fmt.Sprintf("localhost:%d", c.remotePort))
220110
if err != nil {
221-
return fmt.Errorf("create ping channel")
111+
return xerrors.Errorf("dial: %w", err)
222112
}
223-
pingChannel.OnOpen(func() {
224-
defer func() {
225-
_ = pingChannel.Close()
226-
}()
227-
t1 := time.Now()
228-
rw, _ := pingChannel.Detach()
229-
defer func() {
230-
_ = rw.Close()
231-
}()
232-
_, _ = rw.Write([]byte("hello"))
233-
b := make([]byte, 64)
234-
_, _ = rw.Read(b)
235-
c.logger.Info(c.ctx, "your latency directly to the agent", slog.F("ms", time.Since(t1).Milliseconds()))
236-
})
237113

114+
// proxy via stdio
238115
if c.stdio {
239-
// At this point the RTC is connected and data channel is opened...
240-
rw, err := channel.Detach()
241-
if err != nil {
242-
return fmt.Errorf("detach channel: %w", err)
243-
}
244116
go func() {
245-
_, _ = io.Copy(rw, os.Stdin)
117+
_, _ = io.Copy(nc, os.Stdin)
246118
}()
247-
_, err = io.Copy(os.Stdout, rw)
119+
_, err = io.Copy(os.Stdout, nc)
248120
if err != nil {
249-
return fmt.Errorf("copy: %w", err)
121+
return xerrors.Errorf("copy: %w", err)
250122
}
251123
return nil
252124
}
253125

126+
// proxy via tcp listener
254127
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", c.localPort))
255128
if err != nil {
256-
return fmt.Errorf("listen: %w", err)
129+
return xerrors.Errorf("listen: %w", err)
257130
}
258131

259132
for {
260-
conn, err := listener.Accept()
133+
lc, err := listener.Accept()
261134
if err != nil {
262-
return fmt.Errorf("accept: %w", err)
135+
return xerrors.Errorf("accept: %w", err)
263136
}
264137
go func() {
265138
defer func() {
266-
_ = conn.Close()
267-
}()
268-
channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
269-
if err != nil {
270-
c.logger.Warn(context.Background(), "create data channel for proxying", slog.Error(err))
271-
return
272-
}
273-
defer func() {
274-
_ = channel.Close()
139+
_ = lc.Close()
275140
}()
276-
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
277-
if err != nil {
278-
c.logger.Warn(context.Background(), "wait for data channel open", slog.Error(err))
279-
return
280-
}
281-
rw, err := channel.Detach()
282-
if err != nil {
283-
c.logger.Warn(context.Background(), "detach channel", slog.Error(err))
284-
return
285-
}
286141

287142
go func() {
288-
_, _ = io.Copy(conn, rw)
143+
_, _ = io.Copy(lc, nc)
289144
}()
290-
_, _ = io.Copy(rw, conn)
145+
_, _ = io.Copy(nc, lc)
291146
}()
292147
}
293148
}

0 commit comments

Comments
 (0)