From ebf01b5057f9c5a5810d27c62fba091938b8aa94 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 24 Apr 2022 23:02:47 +0000 Subject: [PATCH] fix: Add resiliency to daemon connections Connections could fail when massive payloads were transmitted. This fixes an upstream bug in dRPC where the connection would end with a context canceled if a message was too large. This adds retransmission of completion and failures too. If Coder somehow loses connection with a provisioner daemon, upon the next connection the state will be properly reported. --- coderd/provisionerdaemons.go | 17 ++- coderd/provisionerdaemons_test.go | 48 ++++++++ codersdk/provisionerdaemons.go | 4 +- go.mod | 3 + go.sum | 4 +- provisionerd/provisionerd.go | 179 +++++++++++++++++++++--------- provisionerd/provisionerd_test.go | 151 ++++++++++++++++++++++++- provisionersdk/transport.go | 6 + 8 files changed, 349 insertions(+), 63 deletions(-) create mode 100644 coderd/provisionerdaemons_test.go diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index e4832aa06b9bb..e727dfcea50a2 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -17,6 +17,7 @@ import ( "github.com/moby/moby/pkg/namesgenerator" "github.com/tabbed/pqtype" "golang.org/x/xerrors" + protobuf "google.golang.org/protobuf/proto" "nhooyr.io/websocket" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -27,6 +28,7 @@ import ( "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/parameter" "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionersdk" sdkproto "github.com/coder/coder/provisionersdk/proto" ) @@ -47,6 +49,8 @@ func (api *api) provisionerDaemonsListen(rw http.ResponseWriter, r *http.Request }) return } + // Align with the frame size of yamux. 
+ conn.SetReadLimit(256 * 1024) daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{ ID: uuid.New(), @@ -82,9 +86,17 @@ func (api *api) provisionerDaemonsListen(rw http.ResponseWriter, r *http.Request _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err)) return } - server := drpcserver.New(mux) + server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + if xerrors.Is(err, io.EOF) { + return + } + api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err)) + }, + }) err = server.Serve(r.Context(), session) if err != nil { + api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err)) _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err)) return } @@ -253,6 +265,9 @@ func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty default: return nil, failJob(fmt.Sprintf("unsupported storage method: %s", job.StorageMethod)) } + if protobuf.Size(protoJob) > provisionersdk.MaxMessageSize { + return nil, failJob(fmt.Sprintf("payload was too big: %d > %d", protobuf.Size(protoJob), provisionersdk.MaxMessageSize)) + } return protoJob, err } diff --git a/coderd/provisionerdaemons_test.go b/coderd/provisionerdaemons_test.go new file mode 100644 index 0000000000000..01e9b2dd1abc8 --- /dev/null +++ b/coderd/provisionerdaemons_test.go @@ -0,0 +1,48 @@ +package coderd_test + +import ( + "context" + "crypto/rand" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/provisionersdk" +) + +func TestProvisionerDaemons(t *testing.T) { + t.Parallel() + t.Run("PayloadTooBig", func(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + // Takes too long to allocate memory on Windows! + t.Skip() + } + client := coderdtest.New(t, nil) + user := coderdtest.CreateFirstUser(t, client) + coderdtest.NewProvisionerDaemon(t, client) + data := make([]byte, provisionersdk.MaxMessageSize) + rand.Read(data) + resp, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data) + require.NoError(t, err) + t.Log(resp.Hash) + + version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{ + StorageMethod: database.ProvisionerStorageMethodFile, + StorageSource: resp.Hash, + Provisioner: database.ProvisionerTypeEcho, + }) + require.NoError(t, err) + require.Eventually(t, func() bool { + var err error + version, err = client.TemplateVersion(context.Background(), version.ID) + require.NoError(t, err) + return version.Job.Error != "" + }, 5*time.Second, 25*time.Millisecond) + }) +} diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index de59ee9e99b16..fba4a97cf0353 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -70,8 +70,8 @@ func (c *Client) ListenProvisionerDaemon(ctx context.Context) (proto.DRPCProvisi } return nil, readBodyAsError(res) } - // Allow _somewhat_ large payloads. - conn.SetReadLimit((1 << 20) * 2) + // Align with the frame size of yamux. 
+ conn.SetReadLimit(256 * 1024) config := yamux.DefaultConfig() config.LogOutput = io.Discard diff --git a/go.mod b/go.mod index 56615fd79dc7a..e6d375eb63969 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,9 @@ replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220 // Required until https://github.com/briandowns/spinner/pull/136 is merged. replace github.com/briandowns/spinner => github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e +// Required until https://github.com/storj/drpc/pull/31 is merged. +replace storj.io/drpc => github.com/kylecarbs/drpc v0.0.31-0.20220424193521-8ebbaf48bdff + // opencensus-go leaks a goroutine by default. replace go.opencensus.io => github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b diff --git a/go.sum b/go.sum index 3a237b6c71cc8..1740e48fab3a0 100644 --- a/go.sum +++ b/go.sum @@ -1107,6 +1107,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/ktrysmt/go-bitbucket v0.6.4/go.mod h1:9u0v3hsd2rqCHRIpbir1oP7F58uo5dq19sBYvuMoyQ4= +github.com/kylecarbs/drpc v0.0.31-0.20220424193521-8ebbaf48bdff h1:7qg425aXdULnZWCCQNPOzHO7c+M6BpbTfOUJLrk5+3w= +github.com/kylecarbs/drpc v0.0.31-0.20220424193521-8ebbaf48bdff/go.mod h1:6rcOyR/QQkSTX/9L5ZGtlZaE2PtXTTZl8d+ulSeeYEg= github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b h1:1Y1X6aR78kMEQE1iCjQodB3lA7VO4jB88Wf8ZrzXSsA= github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8/go.mod h1:n/KX1BZoN1m9EwoXkn/xAV4fd3k8c++gGBsgLONaPOY= @@ -2544,5 +2546,3 @@ sigs.k8s.io/structured-merge-diff/v4 v4.0.3/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK sigs.k8s.io/structured-merge-diff/v4 v4.1.0/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw= sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= -storj.io/drpc v0.0.30 h1:jqPe4T9KEu3CDBI05A2hCMgMSHLtd/E0N0yTF9QreIE= -storj.io/drpc v0.0.30/go.mod h1:6rcOyR/QQkSTX/9L5ZGtlZaE2PtXTTZl8d+ulSeeYEg= diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index f36893122457d..add6f9b4c2423 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -68,8 +68,8 @@ func New(clientDialer Dialer, opts *Options) *Server { clientDialer: clientDialer, opts: opts, - closeCancel: ctxCancel, - closed: make(chan struct{}), + closeContext: ctx, + closeCancel: ctxCancel, shutdown: make(chan struct{}), @@ -87,13 +87,13 @@ type Server struct { opts *Options clientDialer Dialer - client proto.DRPCProvisionerDaemonClient + clientValue atomic.Value // Locked when closing the daemon. - closeMutex sync.Mutex - closeCancel context.CancelFunc - closed chan struct{} - closeError error + closeMutex sync.Mutex + closeContext context.Context + closeCancel context.CancelFunc + closeError error shutdownMutex sync.Mutex shutdown chan struct{} @@ -108,11 +108,10 @@ type Server struct { // Connect establishes a connection to coderd. func (p *Server) connect(ctx context.Context) { - var err error // An exponential back-off occurs when the connection is failing to dial. // This is to prevent server spam in case of a coderd outage. 
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { - p.client, err = p.clientDialer(ctx) + client, err := p.clientDialer(ctx) if err != nil { if errors.Is(err, context.Canceled) { return @@ -126,6 +125,7 @@ func (p *Server) connect(ctx context.Context) { p.closeMutex.Unlock() continue } + p.clientValue.Store(client) p.opts.Logger.Debug(context.Background(), "connected") break } @@ -139,10 +139,14 @@ func (p *Server) connect(ctx context.Context) { if p.isClosed() { return } + client, ok := p.client() + if !ok { + return + } select { - case <-p.closed: + case <-p.closeContext.Done(): return - case <-p.client.DRPCConn().Closed(): + case <-client.DRPCConn().Closed(): // We use the update stream to detect when the connection // has been interrupted. This works well, because logs need // to buffer if a job is running in the background. @@ -158,10 +162,14 @@ func (p *Server) connect(ctx context.Context) { ticker := time.NewTicker(p.opts.PollInterval) defer ticker.Stop() for { + client, ok := p.client() + if !ok { + return + } select { - case <-p.closed: + case <-p.closeContext.Done(): return - case <-p.client.DRPCConn().Closed(): + case <-client.DRPCConn().Closed(): return case <-ticker.C: p.acquireJob(ctx) @@ -170,6 +178,15 @@ func (p *Server) connect(ctx context.Context) { }() } +func (p *Server) client() (proto.DRPCProvisionerDaemonClient, bool) { + rawClient := p.clientValue.Load() + if rawClient == nil { + return nil, false + } + client, ok := rawClient.(proto.DRPCProvisionerDaemonClient) + return client, ok +} + func (p *Server) isRunningJob() bool { select { case <-p.jobRunning: @@ -195,7 +212,11 @@ func (p *Server) acquireJob(ctx context.Context) { return } var err error - job, err := p.client.AcquireJob(ctx, &proto.Empty{}) + client, ok := p.client() + if !ok { + return + } + job, err := client.AcquireJob(ctx, &proto.Empty{}) if err != nil { if errors.Is(err, context.Canceled) { return @@ -231,7 +252,7 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) { defer ticker.Stop() for { select { - case <-p.closed: + case <-p.closeContext.Done(): return case <-ctx.Done(): return @@ -241,9 +262,16 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) { return case <-ticker.C: } - resp, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + client, ok := p.client() + if !ok { + continue + } + resp, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.JobId, }) + if errors.Is(err, yamux.ErrSessionShutdown) || errors.Is(err, io.EOF) { + continue + } if err != nil { p.failActiveJobf("send periodic update: %s", err) return @@ -297,7 +325,12 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) { return } - _, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + client, ok := p.client() + if !ok { + p.failActiveJobf("client disconnected") + return + } + _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.GetJobId(), Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -387,10 +420,14 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) { return } + client, ok = p.client() + if !ok { + return + } // Ensure the job is still running to output. // It's possible the job has failed. 
if p.isRunningJob() { - _, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.GetJobId(), Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -409,7 +446,12 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) { } func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) { - _, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + client, ok := p.client() + if !ok { + p.failActiveJobf("client disconnected") + return + } + _, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.GetJobId(), Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -429,7 +471,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd return } - updateResponse, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + updateResponse, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.JobId, ParameterSchemas: parameterSchemas, }) @@ -450,7 +492,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd } } - _, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.GetJobId(), Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -471,7 +513,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd p.failActiveJobf("template import provision for start: %s", err) return } - _, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.GetJobId(), Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -493,7 +535,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd return } - _, err = p.client.CompleteJob(ctx, &proto.CompletedJob{ + p.completeJob(&proto.CompletedJob{ JobId: job.JobId, Type: &proto.CompletedJob_TemplateImport_{ TemplateImport: &proto.CompletedJob_TemplateImport{ @@ -502,14 +544,14 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd }, }, }) - if err != nil { - p.failActiveJobf("complete job: %s", err) - return - } } // Parses parameter schemas from source. 
func (p *Server) runTemplateImportParse(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) ([]*sdkproto.ParameterSchema, error) { + client, ok := p.client() + if !ok { + return nil, xerrors.New("client disconnected") + } stream, err := provisioner.Parse(ctx, &sdkproto.Parse_Request{ Directory: p.opts.WorkDirectory, }) @@ -529,7 +571,7 @@ func (p *Server) runTemplateImportParse(ctx context.Context, provisioner sdkprot slog.F("output", msgType.Log.Output), ) - _, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER, @@ -599,8 +641,11 @@ func (p *Server) runTemplateImportProvision(ctx, shutdown context.Context, provi slog.F("level", msgType.Log.Level), slog.F("output", msgType.Log.Output), ) - - _, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + client, ok := p.client() + if !ok { + continue + } + _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER, @@ -638,7 +683,12 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd stage = "Destroying workspace" } - _, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + client, ok := p.client() + if !ok { + p.failActiveJobf("client disconnected") + return + } + _, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.GetJobId(), Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -699,7 +749,7 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd slog.F("workspace_build_id", job.GetWorkspaceBuild().WorkspaceBuildId), ) - _, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{ + _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER, @@ -729,15 +779,7 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd return } - p.opts.Logger.Info(context.Background(), "provision successful; marking job as complete", - slog.F("resource_count", len(msgType.Complete.Resources)), - slog.F("resources", msgType.Complete.Resources), - slog.F("state_length", len(msgType.Complete.State)), - ) - - // Complete job may need to be async if we disconnected... - // When we reconnect we can flush any of these cached values. - _, err = p.client.CompleteJob(ctx, &proto.CompletedJob{ + p.completeJob(&proto.CompletedJob{ JobId: job.JobId, Type: &proto.CompletedJob_WorkspaceBuild_{ WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{ @@ -746,11 +788,12 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd }, }, }) - if err != nil { - p.failActiveJobf("complete job: %s", err) - return - } - // Return so we stop looping! + p.opts.Logger.Info(context.Background(), "provision successful; marked job as complete", + slog.F("resource_count", len(msgType.Complete.Resources)), + slog.F("resources", msgType.Complete.Resources), + slog.F("state_length", len(msgType.Complete.State)), + ) + // Stop looping! 
return default: p.failActiveJobf("invalid message type %T received from provisioner", msg.Type) @@ -759,6 +802,26 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd } } +func (p *Server) completeJob(job *proto.CompletedJob) { + for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); { + client, ok := p.client() + if !ok { + continue + } + // Complete job may need to be async if we disconnected... + // When we reconnect we can flush any of these cached values. + _, err := client.CompleteJob(p.closeContext, job) + if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) { + continue + } + if err != nil { + p.opts.Logger.Warn(p.closeContext, "failed to complete job", slog.Error(err)) + return + } + break + } +} + func (p *Server) failActiveJobf(format string, args ...interface{}) { p.failActiveJob(&proto.FailedJob{ Error: fmt.Sprintf(format, args...), @@ -786,18 +849,31 @@ func (p *Server) failActiveJob(failedJob *proto.FailedJob) { slog.F("job_id", p.jobID), ) failedJob.JobId = p.jobID - _, err := p.client.FailJob(context.Background(), failedJob) - if err != nil { - p.opts.Logger.Warn(context.Background(), "failed to notify of error; job is no longer running", slog.Error(err)) + for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); { + client, ok := p.client() + if !ok { + continue + } + _, err := client.FailJob(p.closeContext, failedJob) + if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) { + continue + } + if err != nil { + if p.isClosed() { + return + } + p.opts.Logger.Warn(context.Background(), "failed to notify of error; job is no longer running", slog.Error(err)) + return + } + p.opts.Logger.Debug(context.Background(), "marked running job as failed") return } - p.opts.Logger.Debug(context.Background(), "marked running job as failed") } // isClosed returns whether the API is closed or not. func (p *Server) isClosed() bool { select { - case <-p.closed: + case <-p.closeContext.Done(): return true default: return false @@ -847,7 +923,6 @@ func (p *Server) closeWithError(err error) error { return p.closeError } p.closeError = err - close(p.closed) errMsg := "provisioner daemon was shutdown gracefully" if err != nil { diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index c56c3a6378119..bae6de46933af 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/hashicorp/yamux" "github.com/stretchr/testify/require" "go.uber.org/atomic" "go.uber.org/goleak" @@ -126,6 +127,7 @@ func TestProvisionerd(t *testing.T) { // Ensures tars with "../../../etc/passwd" as the path // are not allowed to run, and will fail the job. 
t.Parallel() + var complete sync.Once completeChan := make(chan struct{}) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ @@ -145,7 +147,9 @@ func TestProvisionerd(t *testing.T) { }, updateJob: noopUpdateJob, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { - close(completeChan) + complete.Do(func() { + close(completeChan) + }) return &proto.Empty{}, nil }, }), nil @@ -158,6 +162,7 @@ func TestProvisionerd(t *testing.T) { t.Run("RunningPeriodicUpdate", func(t *testing.T) { t.Parallel() + var complete sync.Once completeChan := make(chan struct{}) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ @@ -176,11 +181,9 @@ func TestProvisionerd(t *testing.T) { }, nil }, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { - select { - case <-completeChan: - default: + complete.Do(func() { close(completeChan) - } + }) return &proto.UpdateJobResponse{}, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { @@ -492,6 +495,7 @@ func TestProvisionerd(t *testing.T) { t.Run("ShutdownFromJob", func(t *testing.T) { t.Parallel() + var updated sync.Once updateChan := make(chan struct{}) completeChan := make(chan struct{}) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { @@ -513,7 +517,9 @@ func TestProvisionerd(t *testing.T) { updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { if len(update.Logs) > 0 && update.Logs[0].Source == proto.LogSource_PROVISIONER { // Close on a log so we know when the job is in progress! 
- close(updateChan) + updated.Do(func() { + close(updateChan) + }) } return &proto.UpdateJobResponse{ Canceled: true, @@ -558,6 +564,139 @@ func TestProvisionerd(t *testing.T) { <-completeChan require.NoError(t, server.Close()) }) + + t.Run("ReconnectAndFail", func(t *testing.T) { + t.Parallel() + var second atomic.Bool + failChan := make(chan struct{}) + failedChan := make(chan struct{}) + completeChan := make(chan struct{}) + server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + if second.Load() { + return &proto.AcquiredJob{}, nil + } + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + TemplateSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ + Metadata: &sdkproto.Provision_Metadata{}, + }, + }, + }, nil + }, + updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + return &proto.UpdateJobResponse{}, nil + }, + failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { + if second.Load() { + close(completeChan) + return &proto.Empty{}, nil + } + close(failChan) + <-failedChan + return &proto.Empty{}, nil + }, + }) + if !second.Load() { + go func() { + <-failChan + _ = client.DRPCConn().Close() + second.Store(true) + close(failedChan) + }() + } + return client, nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { + // Ignore the first provision message! 
+ _, _ = stream.Recv() + return stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Complete{ + Complete: &sdkproto.Provision_Complete{ + Error: "some error", + }, + }, + }) + }, + }), + }) + <-completeChan + require.NoError(t, server.Close()) + }) + + t.Run("ReconnectAndComplete", func(t *testing.T) { + t.Parallel() + var second atomic.Bool + failChan := make(chan struct{}) + failedChan := make(chan struct{}) + completeChan := make(chan struct{}) + server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + if second.Load() { + close(completeChan) + return &proto.AcquiredJob{}, nil + } + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + TemplateSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ + Metadata: &sdkproto.Provision_Metadata{}, + }, + }, + }, nil + }, + failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { + return nil, yamux.ErrSessionShutdown + }, + updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + return &proto.UpdateJobResponse{}, nil + }, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + if second.Load() { + return &proto.Empty{}, nil + } + close(failChan) + <-failedChan + return &proto.Empty{}, nil + }, + }) + if !second.Load() { + go func() { + <-failChan + _ = client.DRPCConn().Close() + second.Store(true) + close(failedChan) + }() + } + return client, nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { + // Ignore the first provision message! + _, _ = stream.Recv() + return stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Complete{ + Complete: &sdkproto.Provision_Complete{}, + }, + }) + }, + }), + }) + <-completeChan + require.NoError(t, server.Close()) + }) } // Creates an in-memory tar of the files provided. diff --git a/provisionersdk/transport.go b/provisionersdk/transport.go index 3933aeb5efd7b..8e1a0069cf17a 100644 --- a/provisionersdk/transport.go +++ b/provisionersdk/transport.go @@ -9,6 +9,12 @@ import ( "storj.io/drpc/drpcconn" ) +const ( + // MaxMessageSize is the maximum payload size that can be + // transported without error. + MaxMessageSize = 4 << 20 +) + // TransportPipe creates an in-memory pipe for dRPC transport. func TransportPipe() (*yamux.Session, *yamux.Session) { clientReader, clientWriter := io.Pipe()
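
Note on the retry pattern this patch introduces: completeJob and failActiveJob now loop on retry.New(25*time.Millisecond, 5*time.Second), treat transient transport errors (yamux.ErrSessionShutdown, io.EOF) as retryable, and re-fetch whichever client the reconnect loop most recently stored in the atomic.Value. The following is a minimal, self-contained sketch of that idea for illustration only; Client, Sender, fakeClient, and the fixed 25ms backoff are hypothetical simplifications, not part of the patch, which uses the generated dRPC client and the coder retry package and also checks yamux session shutdown.

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync/atomic"
	"time"
)

// Client is a hypothetical stand-in for the generated dRPC client in the patch.
type Client interface {
	CompleteJob(ctx context.Context, jobID string) error
}

// Sender mirrors the provisionerd pattern: the reconnect loop stores the latest
// client in an atomic.Value, and senders retry against whatever is stored.
type Sender struct {
	client atomic.Value // holds Client
}

func (s *Sender) store(c Client) { s.client.Store(c) }

func (s *Sender) load() (Client, bool) {
	c, ok := s.client.Load().(Client)
	return c, ok
}

// completeJob retries transient failures, using io.EOF as the stand-in for a
// dropped session (the real code also matches yamux.ErrSessionShutdown).
func (s *Sender) completeJob(ctx context.Context, jobID string) error {
	const backoff = 25 * time.Millisecond // simplified fixed backoff
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(backoff):
		}
		client, ok := s.load()
		if !ok {
			continue // not reconnected yet; try again after the backoff
		}
		err := client.CompleteJob(ctx, jobID)
		if errors.Is(err, io.EOF) {
			continue // session dropped; wait for a fresh client
		}
		return err
	}
}

// fakeClient simulates a connection that drops a few times before succeeding.
type fakeClient struct{ failsLeft int }

func (f *fakeClient) CompleteJob(_ context.Context, _ string) error {
	if f.failsLeft > 0 {
		f.failsLeft--
		return io.EOF // simulate a dropped yamux session
	}
	return nil
}

func main() {
	s := &Sender{}
	s.store(&fakeClient{failsLeft: 2})
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println("complete:", s.completeJob(ctx, "job-1"))
}

Running the sketch prints "complete: <nil>" once the simulated disconnects pass, mirroring how the daemon flushes a completed or failed job after coderd becomes reachable again.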