From c04145963b95957fbf685945028172e905c72378 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 12 May 2025 10:04:17 -0500 Subject: [PATCH 1/3] chore: adjustable drpc protocol message size limit --- agent/agenttest/client.go | 1 + coderd/agentapi/api.go | 2 + coderd/coderd.go | 3 +- codersdk/drpcsdk/transport.go | 30 ++++++-- enterprise/coderd/provisionerdaemons.go | 2 + enterprise/provisionerd/remoteprovisioners.go | 7 +- provisionerd/provisionerd_test.go | 77 ++++++++++++++++++- provisionersdk/serve.go | 5 +- provisionersdk/serve_test.go | 4 +- tailnet/service.go | 2 + 10 files changed, 122 insertions(+), 11 deletions(-) diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index 73fb40e826519..24658c44d6e18 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -60,6 +60,7 @@ func NewClient(t testing.TB, err = agentproto.DRPCRegisterAgent(mux, fakeAAPI) require.NoError(t, err) server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), Log: func(err error) { if xerrors.Is(err, io.EOF) { return diff --git a/coderd/agentapi/api.go b/coderd/agentapi/api.go index 1b2b8d92a10ef..8a0871bc083d4 100644 --- a/coderd/agentapi/api.go +++ b/coderd/agentapi/api.go @@ -30,6 +30,7 @@ import ( "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/coder/v2/tailnet" tailnetproto "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/quartz" @@ -209,6 +210,7 @@ func (a *API) Server(ctx context.Context) (*drpcserver.Server, error) { return drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), Log: func(err error) { if xerrors.Is(err, io.EOF) { return diff --git a/coderd/coderd.go b/coderd/coderd.go index 1b4b5746b7f7e..86db50d9559a4 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -38,6 +38,7 @@ import ( "tailscale.com/util/singleflight" "cdr.dev/slog" + "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/quartz" "github.com/coder/serpent" @@ -84,7 +85,6 @@ import ( "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/coder/v2/codersdk/healthsdk" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionersdk" @@ -1803,6 +1803,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n } server := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), Log: func(err error) { if xerrors.Is(err, io.EOF) { return diff --git a/codersdk/drpcsdk/transport.go b/codersdk/drpcsdk/transport.go index 84cef5e6d8db1..2c037dc512b8f 100644 --- a/codersdk/drpcsdk/transport.go +++ b/codersdk/drpcsdk/transport.go @@ -9,6 +9,7 @@ import ( "github.com/valyala/fasthttp/fasthttputil" "storj.io/drpc" "storj.io/drpc/drpcconn" + "storj.io/drpc/drpcmanager" "github.com/coder/coder/v2/coderd/tracing" ) @@ -16,9 +17,20 @@ import ( const ( // MaxMessageSize is the maximum payload size that can be // transported without error. - MaxMessageSize = 4 << 20 + MaxMessageSize = 10 << 20 ) +func DefaultDRPCOptions(options *drpcmanager.Options) drpcmanager.Options { + if options == nil { + options = &drpcmanager.Options{} + } + + if options.Reader.MaximumBufferSize == 0 { + options.Reader.MaximumBufferSize = MaxMessageSize + } + return *options +} + // MultiplexedConn returns a multiplexed dRPC connection from a yamux Session. func MultiplexedConn(session *yamux.Session) drpc.Conn { return &multiplexedDRPC{session} @@ -43,7 +55,9 @@ func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encod if err != nil { return err } - dConn := drpcconn.New(conn) + dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{ + Manager: DefaultDRPCOptions(nil), + }) defer func() { _ = dConn.Close() }() @@ -55,7 +69,9 @@ func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.En if err != nil { return nil, err } - dConn := drpcconn.New(conn) + dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{ + Manager: DefaultDRPCOptions(nil), + }) stream, err := dConn.NewStream(ctx, rpc, enc) if err == nil { go func() { @@ -97,7 +113,9 @@ func (m *memDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inM return err } - dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)} + dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{ + Manager: DefaultDRPCOptions(nil), + })} defer func() { _ = dConn.Close() _ = conn.Close() @@ -110,7 +128,9 @@ func (m *memDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) if err != nil { return nil, err } - dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)} + dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{ + Manager: DefaultDRPCOptions(nil), + })} stream, err := dConn.NewStream(ctx, rpc, enc) if err != nil { _ = dConn.Close() diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 6ffa15851214d..b9a93144f4931 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -31,6 +31,7 @@ import ( "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionersdk" "github.com/coder/websocket" @@ -370,6 +371,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) return } server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), Log: func(err error) { if xerrors.Is(err, io.EOF) { return diff --git a/enterprise/provisionerd/remoteprovisioners.go b/enterprise/provisionerd/remoteprovisioners.go index 26c93322e662a..1ae02f00312e9 100644 --- a/enterprise/provisionerd/remoteprovisioners.go +++ b/enterprise/provisionerd/remoteprovisioners.go @@ -27,6 +27,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/coder/v2/provisioner/echo" agpl "github.com/coder/coder/v2/provisionerd" "github.com/coder/coder/v2/provisionerd/proto" @@ -188,8 +189,10 @@ func (r *remoteConnector) handleConn(conn net.Conn) { logger.Info(r.ctx, "provisioner connected") closeConn = false // we're passing the conn over the channel w.respCh <- agpl.ConnectResponse{ - Job: w.job, - Client: sdkproto.NewDRPCProvisionerClient(drpcconn.New(tlsConn)), + Job: w.job, + Client: sdkproto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(tlsConn, drpcconn.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), + })), } } diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index a9418d9391c8f..119fb932da581 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -178,6 +178,79 @@ func TestProvisionerd(t *testing.T) { require.NoError(t, closer.Close()) }) + // LargePayloads sends a 6mb tar file to the provisioner. The provisioner also + // returns large payload messages back. The limit should be 10mb, so all + // these messages should work. + t.Run("LargePayloads", func(t *testing.T) { + t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) + var ( + largeSize = 6 * 1024 * 1024 + completeChan = make(chan struct{}) + completeOnce sync.Once + acq = newAcquireOne(t, &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + TemplateSourceArchive: testutil.CreateTar(t, map[string]string{ + "toolarge.txt": string(make([]byte, largeSize)), + }), + Type: &proto.AcquiredJob_TemplateImport_{ + TemplateImport: &proto.AcquiredJob_TemplateImport{ + Metadata: &sdkproto.Metadata{}, + }, + }, + }) + ) + + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ + acquireJobWithCancel: acq.acquireWithCancel, + updateJob: noopUpdateJob, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + completeOnce.Do(func() { close(completeChan) }) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.LocalProvisioners{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ + parse: func( + s *provisionersdk.Session, + _ *sdkproto.ParseRequest, + cancelOrComplete <-chan struct{}, + ) *sdkproto.ParseComplete { + return &sdkproto.ParseComplete{ + // 6mb readme + Readme: make([]byte, largeSize), + } + }, + plan: func( + _ *provisionersdk.Session, + _ *sdkproto.PlanRequest, + _ <-chan struct{}, + ) *sdkproto.PlanComplete { + return &sdkproto.PlanComplete{ + Resources: []*sdkproto.Resource{}, + Plan: make([]byte, largeSize), + } + }, + apply: func( + _ *provisionersdk.Session, + _ *sdkproto.ApplyRequest, + _ <-chan struct{}, + ) *sdkproto.ApplyComplete { + return &sdkproto.ApplyComplete{ + State: make([]byte, largeSize), + } + }, + }), + }) + require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) + require.NoError(t, closer.Close()) + }) + t.Run("RunningPeriodicUpdate", func(t *testing.T) { t.Parallel() done := make(chan struct{}) @@ -1115,7 +1188,9 @@ func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server pr mux := drpcmux.New() err := proto.DRPCRegisterProvisionerDaemon(mux, &server) require.NoError(t, err) - srv := drpcserver.New(mux) + srv := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), + }) ctx, cancelFunc := context.WithCancel(context.Background()) closed := make(chan struct{}) go func() { diff --git a/provisionersdk/serve.go b/provisionersdk/serve.go index b91329d0665fe..c652cfa94949d 100644 --- a/provisionersdk/serve.go +++ b/provisionersdk/serve.go @@ -15,6 +15,7 @@ import ( "storj.io/drpc/drpcserver" "cdr.dev/slog" + "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/provisionersdk/proto" @@ -81,7 +82,9 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error { if err != nil { return xerrors.Errorf("register provisioner: %w", err) } - srv := drpcserver.New(&tracing.DRPCHandler{Handler: mux}) + srv := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), + }) if options.Listener != nil { err = srv.Serve(ctx, options.Listener) diff --git a/provisionersdk/serve_test.go b/provisionersdk/serve_test.go index a78a573e11b02..4fc7342b1eed2 100644 --- a/provisionersdk/serve_test.go +++ b/provisionersdk/serve_test.go @@ -94,7 +94,9 @@ func TestProvisionerSDK(t *testing.T) { srvErr <- err }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(client, drpcconn.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), + })) s, err := api.Session(ctx) require.NoError(t, err) err = s.Send(&proto.Request{Type: &proto.Request_Config{Config: &proto.Config{}}}) diff --git a/tailnet/service.go b/tailnet/service.go index abb91acef8772..c3a6b9f2b282b 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -17,6 +17,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/apiversion" + "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/quartz" ) @@ -92,6 +93,7 @@ func NewClientService(options ClientServiceOptions) ( return nil, xerrors.Errorf("register DRPC service: %w", err) } server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), Log: func(err error) { if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || From 915457f295bbf3e851e644f2bb9d20b0faa1290f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 12 May 2025 13:39:52 -0500 Subject: [PATCH 2/3] revert back to 4mb --- codersdk/drpcsdk/transport.go | 2 +- provisionerd/provisionerd_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/codersdk/drpcsdk/transport.go b/codersdk/drpcsdk/transport.go index 2c037dc512b8f..82a0921b41057 100644 --- a/codersdk/drpcsdk/transport.go +++ b/codersdk/drpcsdk/transport.go @@ -17,7 +17,7 @@ import ( const ( // MaxMessageSize is the maximum payload size that can be // transported without error. - MaxMessageSize = 10 << 20 + MaxMessageSize = 4 << 20 ) func DefaultDRPCOptions(options *drpcmanager.Options) drpcmanager.Options { diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index 119fb932da581..d8a3fa394edca 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -178,7 +178,7 @@ func TestProvisionerd(t *testing.T) { require.NoError(t, closer.Close()) }) - // LargePayloads sends a 6mb tar file to the provisioner. The provisioner also + // LargePayloads sends a 3mb tar file to the provisioner. The provisioner also // returns large payload messages back. The limit should be 10mb, so all // these messages should work. t.Run("LargePayloads", func(t *testing.T) { @@ -188,7 +188,7 @@ func TestProvisionerd(t *testing.T) { close(done) }) var ( - largeSize = 6 * 1024 * 1024 + largeSize = 3 * 1024 * 1024 completeChan = make(chan struct{}) completeOnce sync.Once acq = newAcquireOne(t, &proto.AcquiredJob{ From 83de69c988bb91268e2bcb2e99b6c619025506ab Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 13 May 2025 08:44:33 -0500 Subject: [PATCH 3/3] Update provisionerd/provisionerd_test.go Co-authored-by: Spike Curtis --- provisionerd/provisionerd_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index d8a3fa394edca..7a5d714befa05 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -179,7 +179,7 @@ func TestProvisionerd(t *testing.T) { }) // LargePayloads sends a 3mb tar file to the provisioner. The provisioner also - // returns large payload messages back. The limit should be 10mb, so all + // returns large payload messages back. The limit should be 4mb, so all // these messages should work. t.Run("LargePayloads", func(t *testing.T) { t.Parallel()