1
1
package wsnet
2
2
3
3
import (
4
+ "context"
4
5
"fmt"
5
6
"net"
7
+ "net/http"
6
8
"net/url"
7
9
"sync"
8
10
"time"
9
11
10
12
"github.com/pion/datachannel"
11
13
"github.com/pion/webrtc/v3"
14
+ "golang.org/x/net/proxy"
15
+ "nhooyr.io/websocket"
16
+
17
+ "cdr.dev/coder-cli/coder-sdk"
12
18
)
13
19
14
20
const (
@@ -22,16 +28,6 @@ const (
22
28
maxMessageLength = 32 * 1024 // 32 KB
23
29
)
24
30
25
- // TURNEndpoint returns the TURN address for a Coder baseURL.
26
- func TURNEndpoint (baseURL * url.URL ) string {
27
- turnScheme := "turns"
28
- if baseURL .Scheme == httpScheme {
29
- turnScheme = "turn"
30
- }
31
-
32
- return fmt .Sprintf ("%s:%s:5349?transport=tcp" , turnScheme , baseURL .Hostname ())
33
- }
34
-
35
31
// ListenEndpoint returns the Coder endpoint to listen for workspace connections.
36
32
func ListenEndpoint (baseURL * url.URL , token string ) string {
37
33
wsScheme := "wss"
@@ -50,7 +46,80 @@ func ConnectEndpoint(baseURL *url.URL, workspace, token string) string {
50
46
return fmt .Sprintf ("%s://%s%s%s%s%s" , wsScheme , baseURL .Host , "/api/private/envagent/" , workspace , "/connect?session_token=" , token )
51
47
}
52
48
53
- type conn struct {
49
+ // TURNWebSocketICECandidate returns a valid relay ICEServer that can be used to
50
+ // trigger a TURNWebSocketDialer.
51
+ func TURNProxyICECandidate () webrtc.ICEServer {
52
+ return webrtc.ICEServer {
53
+ URLs : []string {"turn:127.0.0.1:3478?transport=tcp" },
54
+ Username : "~magicalusername~" ,
55
+ Credential : "~magicalpassword~" ,
56
+ CredentialType : webrtc .ICECredentialTypePassword ,
57
+ }
58
+ }
59
+
60
+ // TURNWebSocketDialer proxies all TURN traffic through a WebSocket.
61
+ func TURNProxyWebSocket (baseURL * url.URL , token string ) proxy.Dialer {
62
+ return & turnProxyDialer {
63
+ baseURL : baseURL ,
64
+ token : token ,
65
+ }
66
+ }
67
+
68
+ // Proxies all TURN ICEServer traffic through this dialer.
69
+ // References Coder APIs with a specific token.
70
+ type turnProxyDialer struct {
71
+ baseURL * url.URL
72
+ token string
73
+ }
74
+
75
+ func (t * turnProxyDialer ) Dial (network , addr string ) (c net.Conn , err error ) {
76
+ headers := http.Header {}
77
+ headers .Set ("Session-Token" , t .token )
78
+
79
+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 15 )
80
+ defer cancel ()
81
+
82
+ // Copy the baseURL so we can adjust path.
83
+ url := * t .baseURL
84
+ url .Scheme = "wss"
85
+ if url .Scheme == httpScheme {
86
+ url .Scheme = "ws"
87
+ }
88
+ url .Path = "/api/private/turn"
89
+ conn , resp , err := websocket .Dial (ctx , url .String (), & websocket.DialOptions {
90
+ HTTPHeader : headers ,
91
+ })
92
+ if err != nil {
93
+ if resp != nil {
94
+ defer resp .Body .Close ()
95
+ return nil , coder .NewHTTPError (resp )
96
+ }
97
+ return nil , fmt .Errorf ("dial: %w" , err )
98
+ }
99
+
100
+ return & turnProxyConn {
101
+ websocket .NetConn (context .Background (), conn , websocket .MessageBinary ),
102
+ }, nil
103
+ }
104
+
105
+ // turnProxyConn is a net.Conn wrapper that returns a TCPAddr for the
106
+ // LocalAddr function. pion/ice unsafely checks the types. See:
107
+ // https://github.com/pion/ice/blob/e78f26fb435987420546c70369ade5d713beca39/gather.go#L448
108
+ type turnProxyConn struct {
109
+ net.Conn
110
+ }
111
+
112
+ // The LocalAddr specified here doesn't really matter,
113
+ // it just has to be of type "TCPAddr".
114
+ func (* turnProxyConn ) LocalAddr () net.Addr {
115
+ return & net.TCPAddr {
116
+ IP : net .IPv4 (127 , 0 , 0 , 1 ),
117
+ Port : 0 ,
118
+ }
119
+ }
120
+
121
+ // Properly buffers data for data channel connections.
122
+ type dataChannelConn struct {
54
123
addr * net.UnixAddr
55
124
dc * webrtc.DataChannel
56
125
rw datachannel.ReadWriteCloser
@@ -62,7 +131,7 @@ type conn struct {
62
131
writeMutex sync.Mutex
63
132
}
64
133
65
- func (c * conn ) init () {
134
+ func (c * dataChannelConn ) init () {
66
135
c .sendMore = make (chan struct {}, 1 )
67
136
c .dc .SetBufferedAmountLowThreshold (bufferedAmountLowThreshold )
68
137
c .dc .OnBufferedAmountLow (func () {
@@ -78,11 +147,11 @@ func (c *conn) init() {
78
147
})
79
148
}
80
149
81
- func (c * conn ) Read (b []byte ) (n int , err error ) {
150
+ func (c * dataChannelConn ) Read (b []byte ) (n int , err error ) {
82
151
return c .rw .Read (b )
83
152
}
84
153
85
- func (c * conn ) Write (b []byte ) (n int , err error ) {
154
+ func (c * dataChannelConn ) Write (b []byte ) (n int , err error ) {
86
155
c .writeMutex .Lock ()
87
156
defer c .writeMutex .Unlock ()
88
157
if len (b ) > maxMessageLength {
@@ -101,7 +170,7 @@ func (c *conn) Write(b []byte) (n int, err error) {
101
170
return c .rw .Write (b )
102
171
}
103
172
104
- func (c * conn ) Close () error {
173
+ func (c * dataChannelConn ) Close () error {
105
174
c .closedMutex .Lock ()
106
175
defer c .closedMutex .Unlock ()
107
176
if ! c .closed {
@@ -111,22 +180,22 @@ func (c *conn) Close() error {
111
180
return c .dc .Close ()
112
181
}
113
182
114
- func (c * conn ) LocalAddr () net.Addr {
183
+ func (c * dataChannelConn ) LocalAddr () net.Addr {
115
184
return c .addr
116
185
}
117
186
118
- func (c * conn ) RemoteAddr () net.Addr {
187
+ func (c * dataChannelConn ) RemoteAddr () net.Addr {
119
188
return c .addr
120
189
}
121
190
122
- func (c * conn ) SetDeadline (t time.Time ) error {
191
+ func (c * dataChannelConn ) SetDeadline (t time.Time ) error {
123
192
return nil
124
193
}
125
194
126
- func (c * conn ) SetReadDeadline (t time.Time ) error {
195
+ func (c * dataChannelConn ) SetReadDeadline (t time.Time ) error {
127
196
return nil
128
197
}
129
198
130
- func (c * conn ) SetWriteDeadline (t time.Time ) error {
199
+ func (c * dataChannelConn ) SetWriteDeadline (t time.Time ) error {
131
200
return nil
132
201
}
0 commit comments