Skip to content

chore: apply the 4mb max limit on drpc protocol message size #17771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agent/agenttest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func NewClient(t testing.TB,
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
Expand Down
2 changes: 2 additions & 0 deletions coderd/agentapi/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
Expand Down Expand Up @@ -209,6 +210,7 @@ func (a *API) Server(ctx context.Context) (*drpcserver.Server, error) {

return drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux},
drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
Expand Down
3 changes: 2 additions & 1 deletion coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"tailscale.com/util/singleflight"

"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/quartz"
"github.com/coder/serpent"

Expand Down Expand Up @@ -84,7 +85,6 @@ import (
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
Expand Down Expand Up @@ -1803,6 +1803,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
}
server := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux},
drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
Expand Down
28 changes: 24 additions & 4 deletions codersdk/drpcsdk/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/valyala/fasthttp/fasthttputil"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
"storj.io/drpc/drpcmanager"

"github.com/coder/coder/v2/coderd/tracing"
)
Expand All @@ -19,6 +20,17 @@ const (
MaxMessageSize = 4 << 20
)

func DefaultDRPCOptions(options *drpcmanager.Options) drpcmanager.Options {
if options == nil {
options = &drpcmanager.Options{}
}

if options.Reader.MaximumBufferSize == 0 {
options.Reader.MaximumBufferSize = MaxMessageSize
}
return *options
}

// MultiplexedConn returns a multiplexed dRPC connection from a yamux Session.
func MultiplexedConn(session *yamux.Session) drpc.Conn {
return &multiplexedDRPC{session}
Expand All @@ -43,7 +55,9 @@ func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encod
if err != nil {
return err
}
dConn := drpcconn.New(conn)
dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{
Manager: DefaultDRPCOptions(nil),
})
defer func() {
_ = dConn.Close()
}()
Expand All @@ -55,7 +69,9 @@ func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.En
if err != nil {
return nil, err
}
dConn := drpcconn.New(conn)
dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{
Manager: DefaultDRPCOptions(nil),
})
stream, err := dConn.NewStream(ctx, rpc, enc)
if err == nil {
go func() {
Expand Down Expand Up @@ -97,7 +113,9 @@ func (m *memDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inM
return err
}

dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{
Manager: DefaultDRPCOptions(nil),
})}
defer func() {
_ = dConn.Close()
_ = conn.Close()
Expand All @@ -110,7 +128,9 @@ func (m *memDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding)
if err != nil {
return nil, err
}
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{
Manager: DefaultDRPCOptions(nil),
})}
stream, err := dConn.NewStream(ctx, rpc, enc)
if err != nil {
_ = dConn.Close()
Expand Down
2 changes: 2 additions & 0 deletions enterprise/coderd/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/websocket"
Expand Down Expand Up @@ -370,6 +371,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
return
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
Expand Down
7 changes: 5 additions & 2 deletions enterprise/provisionerd/remoteprovisioners.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/provisioner/echo"
agpl "github.com/coder/coder/v2/provisionerd"
"github.com/coder/coder/v2/provisionerd/proto"
Expand Down Expand Up @@ -188,8 +189,10 @@ func (r *remoteConnector) handleConn(conn net.Conn) {
logger.Info(r.ctx, "provisioner connected")
closeConn = false // we're passing the conn over the channel
w.respCh <- agpl.ConnectResponse{
Job: w.job,
Client: sdkproto.NewDRPCProvisionerClient(drpcconn.New(tlsConn)),
Job: w.job,
Client: sdkproto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(tlsConn, drpcconn.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
})),
}
}

Expand Down
77 changes: 76 additions & 1 deletion provisionerd/provisionerd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,79 @@ func TestProvisionerd(t *testing.T) {
require.NoError(t, closer.Close())
})

// LargePayloads sends a 3mb tar file to the provisioner. The provisioner also
// returns large payload messages back. The limit should be 4mb, so all
// these messages should work.
t.Run("LargePayloads", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
largeSize = 3 * 1024 * 1024
completeChan = make(chan struct{})
completeOnce sync.Once
acq = newAcquireOne(t, &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"toolarge.txt": string(make([]byte, largeSize)),
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Metadata{},
},
},
})
)

closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: acq.acquireWithCancel,
updateJob: noopUpdateJob,
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
completeOnce.Do(func() { close(completeChan) })
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(
s *provisionersdk.Session,
_ *sdkproto.ParseRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.ParseComplete {
return &sdkproto.ParseComplete{
// 6mb readme
Readme: make([]byte, largeSize),
}
},
plan: func(
_ *provisionersdk.Session,
_ *sdkproto.PlanRequest,
_ <-chan struct{},
) *sdkproto.PlanComplete {
return &sdkproto.PlanComplete{
Resources: []*sdkproto.Resource{},
Plan: make([]byte, largeSize),
}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
return &sdkproto.ApplyComplete{
State: make([]byte, largeSize),
}
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})

t.Run("RunningPeriodicUpdate", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
Expand Down Expand Up @@ -1115,7 +1188,9 @@ func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server pr
mux := drpcmux.New()
err := proto.DRPCRegisterProvisionerDaemon(mux, &server)
require.NoError(t, err)
srv := drpcserver.New(mux)
srv := drpcserver.NewWithOptions(mux, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
})
ctx, cancelFunc := context.WithCancel(context.Background())
closed := make(chan struct{})
go func() {
Expand Down
5 changes: 4 additions & 1 deletion provisionersdk/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"storj.io/drpc/drpcserver"

"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk/drpcsdk"

"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/provisionersdk/proto"
Expand Down Expand Up @@ -81,7 +82,9 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error {
if err != nil {
return xerrors.Errorf("register provisioner: %w", err)
}
srv := drpcserver.New(&tracing.DRPCHandler{Handler: mux})
srv := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
})

if options.Listener != nil {
err = srv.Serve(ctx, options.Listener)
Expand Down
4 changes: 3 additions & 1 deletion provisionersdk/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ func TestProvisionerSDK(t *testing.T) {
srvErr <- err
}()

api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
api := proto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(client, drpcconn.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
}))
s, err := api.Session(ctx)
require.NoError(t, err)
err = s.Send(&proto.Request{Type: &proto.Request_Config{Config: &proto.Config{}}})
Expand Down
2 changes: 2 additions & 0 deletions tailnet/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"cdr.dev/slog"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
)
Expand Down Expand Up @@ -92,6 +93,7 @@ func NewClientService(options ClientServiceOptions) (
return nil, xerrors.Errorf("register DRPC service: %w", err)
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) ||
xerrors.Is(err, context.Canceled) ||
Expand Down
Loading