Skip to content

Commit 748fe1f

Browse files
committed
feat: Add peerbroker proxy for agent connections
Agents will connect using this proxy. Eventually we'll intercept some of these messages for validation, but that's not necessary right now.
1 parent 1789ba0 commit 748fe1f

File tree

2 files changed

+323
-0
lines changed

2 files changed

+323
-0
lines changed

peerbroker/proxy.go

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
package peerbroker
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"io"
8+
"net"
9+
"sync"
10+
11+
"github.com/google/uuid"
12+
"github.com/hashicorp/yamux"
13+
"golang.org/x/xerrors"
14+
protobuf "google.golang.org/protobuf/proto"
15+
"storj.io/drpc/drpcmux"
16+
"storj.io/drpc/drpcserver"
17+
18+
"cdr.dev/slog"
19+
"github.com/coder/coder/database"
20+
"github.com/coder/coder/peerbroker/proto"
21+
)
22+
23+
var (
24+
// Each NegotiationConnection() function call spawns a new stream.
25+
streamIDLength = len(uuid.NewString())
26+
// We shouldn't PubSub anything larger than this!
27+
maxPayloadSizeBytes = 8192
28+
)
29+
30+
// ProxyOptions provides values to configure a proxy.
31+
type ProxyOptions struct {
32+
ChannelID string
33+
Logger slog.Logger
34+
Pubsub database.Pubsub
35+
}
36+
37+
// ProxyDial writes client negotiation streams over PubSub.
38+
//
39+
// PubSub is used to geodistribute in a simple way. All message payloads
40+
// are small in size <=8KB, and we don't require delivery guarantees.
41+
func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) {
42+
proxyDial := &proxyDial{
43+
channelID: options.ChannelID,
44+
logger: options.Logger,
45+
pubsub: options.Pubsub,
46+
connection: client,
47+
streams: make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient),
48+
}
49+
return proxyDial, proxyDial.listen()
50+
}
51+
52+
// ProxyListen accepts client negotiation streams over PubSub and writes them to the listener
53+
// as new NegotiateConnection() streams.
54+
func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error {
55+
mux := drpcmux.New()
56+
err := proto.DRPCRegisterPeerBroker(mux, &proxyListen{
57+
channelID: options.ChannelID,
58+
pubsub: options.Pubsub,
59+
logger: options.Logger,
60+
})
61+
if err != nil {
62+
return xerrors.Errorf("register peer broker: %w", err)
63+
}
64+
server := drpcserver.New(mux)
65+
err = server.Serve(ctx, connListener)
66+
if err != nil {
67+
if errors.Is(err, yamux.ErrSessionShutdown) {
68+
return nil
69+
}
70+
return xerrors.Errorf("serve: %w", err)
71+
}
72+
return nil
73+
}
74+
75+
type proxyListen struct {
76+
channelID string
77+
pubsub database.Pubsub
78+
logger slog.Logger
79+
}
80+
81+
func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
82+
streamID := uuid.NewString()
83+
var err error
84+
closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) {
85+
err := p.onServerToClientMessage(streamID, stream, message)
86+
if err != nil {
87+
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
88+
}
89+
})
90+
if err != nil {
91+
return xerrors.Errorf("subscribe: %w", err)
92+
}
93+
defer closeSubscribe()
94+
for {
95+
clientToServerMessage, err := stream.Recv()
96+
if err != nil {
97+
if errors.Is(err, io.EOF) {
98+
break
99+
}
100+
return xerrors.Errorf("recv: %w", err)
101+
}
102+
data, err := protobuf.Marshal(clientToServerMessage)
103+
if err != nil {
104+
return xerrors.Errorf("marshal: %w", err)
105+
}
106+
if len(data) > maxPayloadSizeBytes {
107+
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
108+
}
109+
data = append([]byte(streamID), data...)
110+
err = p.pubsub.Publish(proxyOutID(p.channelID), data)
111+
if err != nil {
112+
return xerrors.Errorf("publish: %w", err)
113+
}
114+
}
115+
return nil
116+
}
117+
118+
func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error {
119+
if len(message) < streamIDLength {
120+
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
121+
}
122+
serverStreamID := string(message[0:streamIDLength])
123+
if serverStreamID != streamID {
124+
// It's not trying to communicate with this stream!
125+
return nil
126+
}
127+
var msg proto.NegotiateConnection_ServerToClient
128+
err := protobuf.Unmarshal(message[streamIDLength:], &msg)
129+
if err != nil {
130+
return xerrors.Errorf("unmarshal message: %w", err)
131+
}
132+
err = stream.Send(&msg)
133+
if err != nil {
134+
return xerrors.Errorf("send message: %w", err)
135+
}
136+
return nil
137+
}
138+
139+
type proxyDial struct {
140+
channelID string
141+
pubsub database.Pubsub
142+
logger slog.Logger
143+
144+
connection proto.DRPCPeerBrokerClient
145+
closeSubscribe func()
146+
streamMutex sync.Mutex
147+
streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient
148+
}
149+
150+
func (p *proxyDial) listen() error {
151+
var err error
152+
p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), func(ctx context.Context, message []byte) {
153+
err := p.onClientToServerMessage(ctx, message)
154+
if err != nil {
155+
p.logger.Debug(ctx, "failed to accept client message", slog.Error(err))
156+
}
157+
})
158+
if err != nil {
159+
return err
160+
}
161+
return nil
162+
}
163+
164+
func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error {
165+
if len(message) < streamIDLength {
166+
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
167+
}
168+
var err error
169+
streamID := string(message[0:streamIDLength])
170+
p.streamMutex.Lock()
171+
stream, ok := p.streams[streamID]
172+
if !ok {
173+
stream, err = p.connection.NegotiateConnection(ctx)
174+
if err != nil {
175+
p.streamMutex.Unlock()
176+
return xerrors.Errorf("negotiate connection: %w", err)
177+
}
178+
p.streams[streamID] = stream
179+
go func() {
180+
defer stream.Close()
181+
182+
err = p.onServerToClientMessage(streamID, stream)
183+
if err != nil {
184+
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
185+
}
186+
}()
187+
go func() {
188+
<-stream.Context().Done()
189+
p.streamMutex.Lock()
190+
delete(p.streams, streamID)
191+
p.streamMutex.Unlock()
192+
}()
193+
}
194+
p.streamMutex.Unlock()
195+
196+
var msg proto.NegotiateConnection_ClientToServer
197+
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
198+
if err != nil {
199+
return xerrors.Errorf("unmarshal message: %w", err)
200+
}
201+
err = stream.Send(&msg)
202+
if err != nil {
203+
return xerrors.Errorf("write message: %w", err)
204+
}
205+
return nil
206+
}
207+
208+
func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error {
209+
for {
210+
serverToClientMessage, err := stream.Recv()
211+
if err != nil {
212+
if errors.Is(err, io.EOF) {
213+
break
214+
}
215+
if errors.Is(err, context.Canceled) {
216+
break
217+
}
218+
return xerrors.Errorf("recv: %w", err)
219+
}
220+
data, err := protobuf.Marshal(serverToClientMessage)
221+
if err != nil {
222+
return xerrors.Errorf("marshal: %w", err)
223+
}
224+
if len(data) > maxPayloadSizeBytes {
225+
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
226+
}
227+
data = append([]byte(streamID), data...)
228+
err = p.pubsub.Publish(proxyInID(p.channelID), data)
229+
if err != nil {
230+
return xerrors.Errorf("publish: %w", err)
231+
}
232+
}
233+
return nil
234+
}
235+
236+
func (p *proxyDial) Close() error {
237+
p.streamMutex.Lock()
238+
defer p.streamMutex.Unlock()
239+
p.closeSubscribe()
240+
return nil
241+
}
242+
243+
func proxyOutID(channelID string) string {
244+
return fmt.Sprintf("%s-out", channelID)
245+
}
246+
247+
func proxyInID(channelID string) string {
248+
return fmt.Sprintf("%s-in", channelID)
249+
}

peerbroker/proxy_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package peerbroker_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/pion/webrtc/v3"
8+
"github.com/stretchr/testify/require"
9+
10+
"cdr.dev/slog"
11+
"cdr.dev/slog/sloggers/slogtest"
12+
"github.com/coder/coder/database"
13+
"github.com/coder/coder/peer"
14+
"github.com/coder/coder/peerbroker"
15+
"github.com/coder/coder/peerbroker/proto"
16+
"github.com/coder/coder/provisionersdk"
17+
)
18+
19+
func TestProxy(t *testing.T) {
20+
t.Parallel()
21+
ctx := context.Background()
22+
channelID := "hello"
23+
pubsub := database.NewPubsubInMemory()
24+
dialerClient, dialerServer := provisionersdk.TransportPipe()
25+
defer dialerClient.Close()
26+
defer dialerServer.Close()
27+
listenerClient, listenerServer := provisionersdk.TransportPipe()
28+
defer listenerClient.Close()
29+
defer listenerServer.Close()
30+
31+
listener, err := peerbroker.Listen(listenerServer, &peer.ConnOptions{
32+
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
33+
})
34+
require.NoError(t, err)
35+
36+
proxyCloser, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(listenerClient)), peerbroker.ProxyOptions{
37+
ChannelID: channelID,
38+
Logger: slogtest.Make(t, nil).Named("proxy-listen").Leveled(slog.LevelDebug),
39+
Pubsub: pubsub,
40+
})
41+
require.NoError(t, err)
42+
t.Cleanup(func() {
43+
_ = proxyCloser.Close()
44+
})
45+
46+
go func() {
47+
err = peerbroker.ProxyListen(ctx, dialerServer, peerbroker.ProxyOptions{
48+
ChannelID: channelID,
49+
Logger: slogtest.Make(t, nil).Named("proxy-dial").Leveled(slog.LevelDebug),
50+
Pubsub: pubsub,
51+
})
52+
require.NoError(t, err)
53+
}()
54+
55+
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(dialerClient))
56+
stream, err := api.NegotiateConnection(ctx)
57+
require.NoError(t, err)
58+
clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
59+
URLs: []string{"stun:stun.l.google.com:19302"},
60+
}}, &peer.ConnOptions{
61+
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
62+
})
63+
require.NoError(t, err)
64+
defer clientConn.Close()
65+
66+
serverConn, err := listener.Accept()
67+
require.NoError(t, err)
68+
defer serverConn.Close()
69+
_, err = serverConn.Ping()
70+
require.NoError(t, err)
71+
72+
_, err = clientConn.Ping()
73+
require.NoError(t, err)
74+
}

0 commit comments

Comments
 (0)