Skip to content

Commit c041459

Browse files
committed
chore: adjustable drpc protocol message size limit
1 parent ea2cae0 commit c041459

File tree

10 files changed

+122
-11
lines changed

10 files changed

+122
-11
lines changed

agent/agenttest/client.go

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ func NewClient(t testing.TB,
6060
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
6161
require.NoError(t, err)
6262
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
63+
Manager: drpcsdk.DefaultDRPCOptions(nil),
6364
Log: func(err error) {
6465
if xerrors.Is(err, io.EOF) {
6566
return

coderd/agentapi/api.go

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/coder/coder/v2/coderd/wspubsub"
3131
"github.com/coder/coder/v2/codersdk"
3232
"github.com/coder/coder/v2/codersdk/agentsdk"
33+
"github.com/coder/coder/v2/codersdk/drpcsdk"
3334
"github.com/coder/coder/v2/tailnet"
3435
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
3536
"github.com/coder/quartz"
@@ -209,6 +210,7 @@ func (a *API) Server(ctx context.Context) (*drpcserver.Server, error) {
209210

210211
return drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux},
211212
drpcserver.Options{
213+
Manager: drpcsdk.DefaultDRPCOptions(nil),
212214
Log: func(err error) {
213215
if xerrors.Is(err, io.EOF) {
214216
return

coderd/coderd.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import (
3838
"tailscale.com/util/singleflight"
3939

4040
"cdr.dev/slog"
41+
"github.com/coder/coder/v2/codersdk/drpcsdk"
4142
"github.com/coder/quartz"
4243
"github.com/coder/serpent"
4344

@@ -84,7 +85,6 @@ import (
8485
"github.com/coder/coder/v2/coderd/workspaceapps"
8586
"github.com/coder/coder/v2/coderd/workspacestats"
8687
"github.com/coder/coder/v2/codersdk"
87-
"github.com/coder/coder/v2/codersdk/drpcsdk"
8888
"github.com/coder/coder/v2/codersdk/healthsdk"
8989
"github.com/coder/coder/v2/provisionerd/proto"
9090
"github.com/coder/coder/v2/provisionersdk"
@@ -1803,6 +1803,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
18031803
}
18041804
server := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux},
18051805
drpcserver.Options{
1806+
Manager: drpcsdk.DefaultDRPCOptions(nil),
18061807
Log: func(err error) {
18071808
if xerrors.Is(err, io.EOF) {
18081809
return

codersdk/drpcsdk/transport.go

+25-5
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,28 @@ import (
99
"github.com/valyala/fasthttp/fasthttputil"
1010
"storj.io/drpc"
1111
"storj.io/drpc/drpcconn"
12+
"storj.io/drpc/drpcmanager"
1213

1314
"github.com/coder/coder/v2/coderd/tracing"
1415
)
1516

1617
const (
1718
// MaxMessageSize is the maximum payload size that can be
1819
// transported without error.
19-
MaxMessageSize = 4 << 20
20+
MaxMessageSize = 10 << 20
2021
)
2122

23+
func DefaultDRPCOptions(options *drpcmanager.Options) drpcmanager.Options {
24+
if options == nil {
25+
options = &drpcmanager.Options{}
26+
}
27+
28+
if options.Reader.MaximumBufferSize == 0 {
29+
options.Reader.MaximumBufferSize = MaxMessageSize
30+
}
31+
return *options
32+
}
33+
2234
// MultiplexedConn returns a multiplexed dRPC connection from a yamux Session.
2335
func MultiplexedConn(session *yamux.Session) drpc.Conn {
2436
return &multiplexedDRPC{session}
@@ -43,7 +55,9 @@ func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encod
4355
if err != nil {
4456
return err
4557
}
46-
dConn := drpcconn.New(conn)
58+
dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{
59+
Manager: DefaultDRPCOptions(nil),
60+
})
4761
defer func() {
4862
_ = dConn.Close()
4963
}()
@@ -55,7 +69,9 @@ func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.En
5569
if err != nil {
5670
return nil, err
5771
}
58-
dConn := drpcconn.New(conn)
72+
dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{
73+
Manager: DefaultDRPCOptions(nil),
74+
})
5975
stream, err := dConn.NewStream(ctx, rpc, enc)
6076
if err == nil {
6177
go func() {
@@ -97,7 +113,9 @@ func (m *memDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inM
97113
return err
98114
}
99115

100-
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
116+
dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{
117+
Manager: DefaultDRPCOptions(nil),
118+
})}
101119
defer func() {
102120
_ = dConn.Close()
103121
_ = conn.Close()
@@ -110,7 +128,9 @@ func (m *memDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding)
110128
if err != nil {
111129
return nil, err
112130
}
113-
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
131+
dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{
132+
Manager: DefaultDRPCOptions(nil),
133+
})}
114134
stream, err := dConn.NewStream(ctx, rpc, enc)
115135
if err != nil {
116136
_ = dConn.Close()

enterprise/coderd/provisionerdaemons.go

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"github.com/coder/coder/v2/coderd/telemetry"
3232
"github.com/coder/coder/v2/coderd/util/ptr"
3333
"github.com/coder/coder/v2/codersdk"
34+
"github.com/coder/coder/v2/codersdk/drpcsdk"
3435
"github.com/coder/coder/v2/provisionerd/proto"
3536
"github.com/coder/coder/v2/provisionersdk"
3637
"github.com/coder/websocket"
@@ -370,6 +371,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
370371
return
371372
}
372373
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
374+
Manager: drpcsdk.DefaultDRPCOptions(nil),
373375
Log: func(err error) {
374376
if xerrors.Is(err, io.EOF) {
375377
return

enterprise/provisionerd/remoteprovisioners.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727

2828
"cdr.dev/slog"
2929
"github.com/coder/coder/v2/coderd/database"
30+
"github.com/coder/coder/v2/codersdk/drpcsdk"
3031
"github.com/coder/coder/v2/provisioner/echo"
3132
agpl "github.com/coder/coder/v2/provisionerd"
3233
"github.com/coder/coder/v2/provisionerd/proto"
@@ -188,8 +189,10 @@ func (r *remoteConnector) handleConn(conn net.Conn) {
188189
logger.Info(r.ctx, "provisioner connected")
189190
closeConn = false // we're passing the conn over the channel
190191
w.respCh <- agpl.ConnectResponse{
191-
Job: w.job,
192-
Client: sdkproto.NewDRPCProvisionerClient(drpcconn.New(tlsConn)),
192+
Job: w.job,
193+
Client: sdkproto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(tlsConn, drpcconn.Options{
194+
Manager: drpcsdk.DefaultDRPCOptions(nil),
195+
})),
193196
}
194197
}
195198

provisionerd/provisionerd_test.go

+76-1
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,79 @@ func TestProvisionerd(t *testing.T) {
178178
require.NoError(t, closer.Close())
179179
})
180180

181+
// LargePayloads sends a 6mb tar file to the provisioner. The provisioner also
182+
// returns large payload messages back. The limit should be 10mb, so all
183+
// these messages should work.
184+
t.Run("LargePayloads", func(t *testing.T) {
185+
t.Parallel()
186+
done := make(chan struct{})
187+
t.Cleanup(func() {
188+
close(done)
189+
})
190+
var (
191+
largeSize = 6 * 1024 * 1024
192+
completeChan = make(chan struct{})
193+
completeOnce sync.Once
194+
acq = newAcquireOne(t, &proto.AcquiredJob{
195+
JobId: "test",
196+
Provisioner: "someprovisioner",
197+
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
198+
"toolarge.txt": string(make([]byte, largeSize)),
199+
}),
200+
Type: &proto.AcquiredJob_TemplateImport_{
201+
TemplateImport: &proto.AcquiredJob_TemplateImport{
202+
Metadata: &sdkproto.Metadata{},
203+
},
204+
},
205+
})
206+
)
207+
208+
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
209+
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
210+
acquireJobWithCancel: acq.acquireWithCancel,
211+
updateJob: noopUpdateJob,
212+
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
213+
completeOnce.Do(func() { close(completeChan) })
214+
return &proto.Empty{}, nil
215+
},
216+
}), nil
217+
}, provisionerd.LocalProvisioners{
218+
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
219+
parse: func(
220+
s *provisionersdk.Session,
221+
_ *sdkproto.ParseRequest,
222+
cancelOrComplete <-chan struct{},
223+
) *sdkproto.ParseComplete {
224+
return &sdkproto.ParseComplete{
225+
// 6mb readme
226+
Readme: make([]byte, largeSize),
227+
}
228+
},
229+
plan: func(
230+
_ *provisionersdk.Session,
231+
_ *sdkproto.PlanRequest,
232+
_ <-chan struct{},
233+
) *sdkproto.PlanComplete {
234+
return &sdkproto.PlanComplete{
235+
Resources: []*sdkproto.Resource{},
236+
Plan: make([]byte, largeSize),
237+
}
238+
},
239+
apply: func(
240+
_ *provisionersdk.Session,
241+
_ *sdkproto.ApplyRequest,
242+
_ <-chan struct{},
243+
) *sdkproto.ApplyComplete {
244+
return &sdkproto.ApplyComplete{
245+
State: make([]byte, largeSize),
246+
}
247+
},
248+
}),
249+
})
250+
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
251+
require.NoError(t, closer.Close())
252+
})
253+
181254
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
182255
t.Parallel()
183256
done := make(chan struct{})
@@ -1115,7 +1188,9 @@ func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server pr
11151188
mux := drpcmux.New()
11161189
err := proto.DRPCRegisterProvisionerDaemon(mux, &server)
11171190
require.NoError(t, err)
1118-
srv := drpcserver.New(mux)
1191+
srv := drpcserver.NewWithOptions(mux, drpcserver.Options{
1192+
Manager: drpcsdk.DefaultDRPCOptions(nil),
1193+
})
11191194
ctx, cancelFunc := context.WithCancel(context.Background())
11201195
closed := make(chan struct{})
11211196
go func() {

provisionersdk/serve.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"storj.io/drpc/drpcserver"
1616

1717
"cdr.dev/slog"
18+
"github.com/coder/coder/v2/codersdk/drpcsdk"
1819

1920
"github.com/coder/coder/v2/coderd/tracing"
2021
"github.com/coder/coder/v2/provisionersdk/proto"
@@ -81,7 +82,9 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error {
8182
if err != nil {
8283
return xerrors.Errorf("register provisioner: %w", err)
8384
}
84-
srv := drpcserver.New(&tracing.DRPCHandler{Handler: mux})
85+
srv := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, drpcserver.Options{
86+
Manager: drpcsdk.DefaultDRPCOptions(nil),
87+
})
8588

8689
if options.Listener != nil {
8790
err = srv.Serve(ctx, options.Listener)

provisionersdk/serve_test.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ func TestProvisionerSDK(t *testing.T) {
9494
srvErr <- err
9595
}()
9696

97-
api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
97+
api := proto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(client, drpcconn.Options{
98+
Manager: drpcsdk.DefaultDRPCOptions(nil),
99+
}))
98100
s, err := api.Session(ctx)
99101
require.NoError(t, err)
100102
err = s.Send(&proto.Request{Type: &proto.Request_Config{Config: &proto.Config{}}})

tailnet/service.go

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717

1818
"cdr.dev/slog"
1919
"github.com/coder/coder/v2/apiversion"
20+
"github.com/coder/coder/v2/codersdk/drpcsdk"
2021
"github.com/coder/coder/v2/tailnet/proto"
2122
"github.com/coder/quartz"
2223
)
@@ -92,6 +93,7 @@ func NewClientService(options ClientServiceOptions) (
9293
return nil, xerrors.Errorf("register DRPC service: %w", err)
9394
}
9495
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
96+
Manager: drpcsdk.DefaultDRPCOptions(nil),
9597
Log: func(err error) {
9698
if xerrors.Is(err, io.EOF) ||
9799
xerrors.Is(err, context.Canceled) ||

0 commit comments

Comments
 (0)