Skip to content

Commit b58e168

Browse files
authored
feat: Add peerbroker proxy for agent connections (#349)
* feat: Add peerbroker proxy for agent connections. Agents will connect using this proxy. Eventually we'll intercept some of these messages for validation, but that's not necessary right now.
* Add ASCII chart.
1 parent a053fe8 commit b58e168

File tree

2 files changed

+341
-0
lines changed

2 files changed

+341
-0
lines changed

peerbroker/proxy.go

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
package peerbroker
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"io"
8+
"net"
9+
"sync"
10+
11+
"github.com/google/uuid"
12+
"github.com/hashicorp/yamux"
13+
"golang.org/x/xerrors"
14+
protobuf "google.golang.org/protobuf/proto"
15+
"storj.io/drpc/drpcmux"
16+
"storj.io/drpc/drpcserver"
17+
18+
"cdr.dev/slog"
19+
"github.com/coder/coder/database"
20+
"github.com/coder/coder/peerbroker/proto"
21+
)
22+
23+
var (
24+
// Each NegotiateConnection() function call spawns a new stream.
25+
streamIDLength = len(uuid.NewString())
26+
// We shouldn't PubSub anything larger than this!
27+
maxPayloadSizeBytes = 8192
28+
)
29+
30+
// ProxyOptions provides values to configure a proxy.
31+
type ProxyOptions struct {
32+
ChannelID string
33+
Logger slog.Logger
34+
Pubsub database.Pubsub
35+
}
36+
37+
// ProxyDial writes client negotiation streams over PubSub.
38+
//
39+
// PubSub is used to geodistribute WebRTC handshakes. All negotiation
40+
// messages are small in size (<=8KB), and we don't require delivery
41+
// guarantees because connections can always be renegotiated.
42+
// ┌────────────────────┐ ┌─────────────────────────────┐
43+
// │ coderd │ │ coderd │
44+
// ┌─────────────────────┐ │/<agent-id>/connect │ │ /<agent-id>/listen │
45+
// │ client │ │ │ │ │ ┌─────┐
46+
// │ ├──►│Creates a stream ID │◄─►│Subscribe() to the <agent-id>│◄──┤agent│
47+
// │NegotiateConnection()│ │and Publish() to the│ │channel. Parse the stream ID │ └─────┘
48+
// └─────────────────────┘ │<agent-id> channel: │ │from payloads to create new │
49+
// │ │ │NegotiateConnection() streams│
50+
// │<stream-id><payload>│ │or write to existing ones. │
51+
// └────────────────────┘ └─────────────────────────────┘
52+
func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) {
53+
proxyDial := &proxyDial{
54+
channelID: options.ChannelID,
55+
logger: options.Logger,
56+
pubsub: options.Pubsub,
57+
connection: client,
58+
streams: make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient),
59+
}
60+
return proxyDial, proxyDial.listen()
61+
}
62+
63+
// ProxyListen accepts client negotiation streams over PubSub and writes them to the listener
64+
// as new NegotiateConnection() streams.
65+
func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error {
66+
mux := drpcmux.New()
67+
err := proto.DRPCRegisterPeerBroker(mux, &proxyListen{
68+
channelID: options.ChannelID,
69+
pubsub: options.Pubsub,
70+
logger: options.Logger,
71+
})
72+
if err != nil {
73+
return xerrors.Errorf("register peer broker: %w", err)
74+
}
75+
server := drpcserver.New(mux)
76+
err = server.Serve(ctx, connListener)
77+
if err != nil {
78+
if errors.Is(err, yamux.ErrSessionShutdown) {
79+
return nil
80+
}
81+
return xerrors.Errorf("serve: %w", err)
82+
}
83+
return nil
84+
}
85+
86+
type proxyListen struct {
87+
channelID string
88+
pubsub database.Pubsub
89+
logger slog.Logger
90+
}
91+
92+
func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
93+
streamID := uuid.NewString()
94+
var err error
95+
closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) {
96+
err := p.onServerToClientMessage(streamID, stream, message)
97+
if err != nil {
98+
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
99+
}
100+
})
101+
if err != nil {
102+
return xerrors.Errorf("subscribe: %w", err)
103+
}
104+
defer closeSubscribe()
105+
for {
106+
clientToServerMessage, err := stream.Recv()
107+
if err != nil {
108+
if errors.Is(err, io.EOF) {
109+
break
110+
}
111+
return xerrors.Errorf("recv: %w", err)
112+
}
113+
data, err := protobuf.Marshal(clientToServerMessage)
114+
if err != nil {
115+
return xerrors.Errorf("marshal: %w", err)
116+
}
117+
if len(data) > maxPayloadSizeBytes {
118+
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
119+
}
120+
data = append([]byte(streamID), data...)
121+
err = p.pubsub.Publish(proxyOutID(p.channelID), data)
122+
if err != nil {
123+
return xerrors.Errorf("publish: %w", err)
124+
}
125+
}
126+
return nil
127+
}
128+
129+
func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error {
130+
if len(message) < streamIDLength {
131+
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
132+
}
133+
serverStreamID := string(message[0:streamIDLength])
134+
if serverStreamID != streamID {
135+
// It's not trying to communicate with this stream!
136+
return nil
137+
}
138+
var msg proto.NegotiateConnection_ServerToClient
139+
err := protobuf.Unmarshal(message[streamIDLength:], &msg)
140+
if err != nil {
141+
return xerrors.Errorf("unmarshal message: %w", err)
142+
}
143+
err = stream.Send(&msg)
144+
if err != nil {
145+
return xerrors.Errorf("send message: %w", err)
146+
}
147+
return nil
148+
}
149+
150+
type proxyDial struct {
151+
channelID string
152+
pubsub database.Pubsub
153+
logger slog.Logger
154+
155+
connection proto.DRPCPeerBrokerClient
156+
closeSubscribe func()
157+
streamMutex sync.Mutex
158+
streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient
159+
}
160+
161+
func (p *proxyDial) listen() error {
162+
var err error
163+
p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), func(ctx context.Context, message []byte) {
164+
err := p.onClientToServerMessage(ctx, message)
165+
if err != nil {
166+
p.logger.Debug(ctx, "failed to accept client message", slog.Error(err))
167+
}
168+
})
169+
if err != nil {
170+
return err
171+
}
172+
return nil
173+
}
174+
175+
func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error {
176+
if len(message) < streamIDLength {
177+
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
178+
}
179+
var err error
180+
streamID := string(message[0:streamIDLength])
181+
p.streamMutex.Lock()
182+
stream, ok := p.streams[streamID]
183+
if !ok {
184+
stream, err = p.connection.NegotiateConnection(ctx)
185+
if err != nil {
186+
p.streamMutex.Unlock()
187+
return xerrors.Errorf("negotiate connection: %w", err)
188+
}
189+
p.streams[streamID] = stream
190+
go func() {
191+
defer stream.Close()
192+
193+
err = p.onServerToClientMessage(streamID, stream)
194+
if err != nil {
195+
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
196+
}
197+
}()
198+
go func() {
199+
<-stream.Context().Done()
200+
p.streamMutex.Lock()
201+
delete(p.streams, streamID)
202+
p.streamMutex.Unlock()
203+
}()
204+
}
205+
p.streamMutex.Unlock()
206+
207+
var msg proto.NegotiateConnection_ClientToServer
208+
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
209+
if err != nil {
210+
return xerrors.Errorf("unmarshal message: %w", err)
211+
}
212+
err = stream.Send(&msg)
213+
if err != nil {
214+
return xerrors.Errorf("write message: %w", err)
215+
}
216+
return nil
217+
}
218+
219+
func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error {
220+
for {
221+
serverToClientMessage, err := stream.Recv()
222+
if err != nil {
223+
if errors.Is(err, io.EOF) {
224+
break
225+
}
226+
if errors.Is(err, context.Canceled) {
227+
break
228+
}
229+
return xerrors.Errorf("recv: %w", err)
230+
}
231+
data, err := protobuf.Marshal(serverToClientMessage)
232+
if err != nil {
233+
return xerrors.Errorf("marshal: %w", err)
234+
}
235+
if len(data) > maxPayloadSizeBytes {
236+
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
237+
}
238+
data = append([]byte(streamID), data...)
239+
err = p.pubsub.Publish(proxyInID(p.channelID), data)
240+
if err != nil {
241+
return xerrors.Errorf("publish: %w", err)
242+
}
243+
}
244+
return nil
245+
}
246+
247+
func (p *proxyDial) Close() error {
248+
p.streamMutex.Lock()
249+
defer p.streamMutex.Unlock()
250+
p.closeSubscribe()
251+
return nil
252+
}
253+
254+
// proxyOutID returns the name of the PubSub channel that carries
// client-to-server messages: channelID followed by "-out".
func proxyOutID(channelID string) string {
	const format = "%s-out"
	return fmt.Sprintf(format, channelID)
}
257+
258+
// proxyInID returns the name of the PubSub channel that carries
// server-to-client messages: channelID followed by "-in".
func proxyInID(channelID string) string {
	const format = "%s-in"
	return fmt.Sprintf(format, channelID)
}

peerbroker/proxy_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package peerbroker_test
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
8+
"github.com/pion/webrtc/v3"
9+
"github.com/stretchr/testify/require"
10+
11+
"cdr.dev/slog"
12+
"cdr.dev/slog/sloggers/slogtest"
13+
"github.com/coder/coder/database"
14+
"github.com/coder/coder/peer"
15+
"github.com/coder/coder/peerbroker"
16+
"github.com/coder/coder/peerbroker/proto"
17+
"github.com/coder/coder/provisionersdk"
18+
)
19+
20+
func TestProxy(t *testing.T) {
21+
t.Parallel()
22+
ctx := context.Background()
23+
channelID := "hello"
24+
pubsub := database.NewPubsubInMemory()
25+
dialerClient, dialerServer := provisionersdk.TransportPipe()
26+
defer dialerClient.Close()
27+
defer dialerServer.Close()
28+
listenerClient, listenerServer := provisionersdk.TransportPipe()
29+
defer listenerClient.Close()
30+
defer listenerServer.Close()
31+
32+
listener, err := peerbroker.Listen(listenerServer, &peer.ConnOptions{
33+
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
34+
})
35+
require.NoError(t, err)
36+
37+
proxyCloser, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(listenerClient)), peerbroker.ProxyOptions{
38+
ChannelID: channelID,
39+
Logger: slogtest.Make(t, nil).Named("proxy-listen").Leveled(slog.LevelDebug),
40+
Pubsub: pubsub,
41+
})
42+
require.NoError(t, err)
43+
t.Cleanup(func() {
44+
_ = proxyCloser.Close()
45+
})
46+
47+
var wg sync.WaitGroup
48+
wg.Add(1)
49+
go func() {
50+
defer wg.Done()
51+
err = peerbroker.ProxyListen(ctx, dialerServer, peerbroker.ProxyOptions{
52+
ChannelID: channelID,
53+
Logger: slogtest.Make(t, nil).Named("proxy-dial").Leveled(slog.LevelDebug),
54+
Pubsub: pubsub,
55+
})
56+
require.NoError(t, err)
57+
}()
58+
59+
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(dialerClient))
60+
stream, err := api.NegotiateConnection(ctx)
61+
require.NoError(t, err)
62+
clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
63+
URLs: []string{"stun:stun.l.google.com:19302"},
64+
}}, &peer.ConnOptions{
65+
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
66+
})
67+
require.NoError(t, err)
68+
defer clientConn.Close()
69+
70+
serverConn, err := listener.Accept()
71+
require.NoError(t, err)
72+
defer serverConn.Close()
73+
_, err = serverConn.Ping()
74+
require.NoError(t, err)
75+
76+
_, err = clientConn.Ping()
77+
require.NoError(t, err)
78+
79+
_ = dialerServer.Close()
80+
wg.Wait()
81+
}

0 commit comments

Comments
 (0)