From a0f63ce85f7c796387379b1459682c5f4dd5867d Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 1 Feb 2022 14:12:48 +0000 Subject: [PATCH] feat: Add provisionerd service This brings an async service that parses and provisions to life! It's separated from coderd intentionally to allow for simpler testing. Integration with coderd will come in another PR! --- go.mod | 2 + go.sum | 4 + peerbroker/dial_test.go | 3 +- peerbroker/listen.go | 6 +- peerbroker/listen_test.go | 3 +- provisioner/terraform/parse_test.go | 5 +- provisioner/terraform/provision_test.go | 5 +- provisionerd/provisionerd.go | 505 ++++++++++++++++++++++++ provisionerd/provisionerd_test.go | 421 ++++++++++++++++++++ provisionersdk/serve.go | 31 +- provisionersdk/serve_test.go | 7 +- provisionersdk/transport.go | 76 ++-- 12 files changed, 1018 insertions(+), 50 deletions(-) create mode 100644 provisionerd/provisionerd.go create mode 100644 provisionerd/provisionerd_test.go diff --git a/go.mod b/go.mod index 45cf8c8cfca3a..68afc59bf71ac 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ replace github.com/hashicorp/terraform-config-inspect => github.com/kylecarbs/te require ( cdr.dev/slog v1.4.1 + github.com/coder/retry v1.3.0 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/render v1.0.1 github.com/go-playground/validator/v10 v10.10.0 @@ -18,6 +19,7 @@ require ( github.com/hashicorp/go-version v1.4.0 github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f github.com/hashicorp/terraform-exec v0.15.0 + github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 github.com/justinas/nosurf v1.1.1 github.com/lib/pq v1.10.4 github.com/moby/moby v20.10.12+incompatible diff --git a/go.sum b/go.sum index cf80329fa859d..fbba19d150b6b 100644 --- a/go.sum +++ b/go.sum @@ -232,6 +232,8 @@ github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWH github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/cockroach-go/v2 v2.1.1/go.mod h1:7NtUnP6eK+l6k483WSYNrq3Kb23bWV10IRV1TyeSpwM= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= +github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY= +github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY= github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE= github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU= github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU= @@ -688,6 +690,8 @@ github.com/hashicorp/serf v0.9.5/go.mod h1:UWDWwZeL5cuWDJdl0C6wrvrUwEqtQ4ZKBKKEN github.com/hashicorp/serf v0.9.6/go.mod h1:TXZNMjZQijwlDvp+r0b63xZ45H7JmCmgg4gpTwn9UV4= github.com/hashicorp/terraform-json v0.13.0 h1:Li9L+lKD1FO5RVFRM1mMMIBDoUHslOniyEi5CM+FWGY= github.com/hashicorp/terraform-json v0.13.0/go.mod h1:y5OdLBCT+rxbwnpxZs9kGL7R9ExU76+cpdY8zHwoazk= +github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 h1:xixZ2bWeofWV68J+x6AzmKuVM/JWCQwkWm6GW/MUR6I= +github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/ianlancetaylor/demangle 
v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/peerbroker/dial_test.go b/peerbroker/dial_test.go index 30066b8d82397..537dc157e0a79 100644 --- a/peerbroker/dial_test.go +++ b/peerbroker/dial_test.go @@ -7,7 +7,6 @@ import ( "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" "go.uber.org/goleak" - "storj.io/drpc/drpcconn" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -37,7 +36,7 @@ func TestDial(t *testing.T) { }) require.NoError(t, err) - api := proto.NewDRPCPeerBrokerClient(drpcconn.New(client)) + api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) stream, err := api.NegotiateConnection(ctx) require.NoError(t, err) clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{ diff --git a/peerbroker/listen.go b/peerbroker/listen.go index fa68023689c7e..8e92fe7a7c82d 100644 --- a/peerbroker/listen.go +++ b/peerbroker/listen.go @@ -4,12 +4,12 @@ import ( "context" "errors" "io" + "net" "reflect" "sync" "github.com/pion/webrtc/v3" "golang.org/x/xerrors" - "storj.io/drpc" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -19,7 +19,7 @@ import ( // Listen consumes the transport as the server-side of the PeerBroker dRPC service. // The Accept function must be serviced, or new connections will hang. -func Listen(transport drpc.Transport, opts *peer.ConnOptions) (*Listener, error) { +func Listen(connListener net.Listener, opts *peer.ConnOptions) (*Listener, error) { ctx, cancelFunc := context.WithCancel(context.Background()) listener := &Listener{ connectionChannel: make(chan *peer.Conn), @@ -39,7 +39,7 @@ func Listen(transport drpc.Transport, opts *peer.ConnOptions) (*Listener, error) } srv := drpcserver.New(mux) go func() { - err := srv.ServeOne(ctx, transport) + err := srv.Serve(ctx, connListener) _ = listener.closeWithError(err) }() diff --git a/peerbroker/listen_test.go b/peerbroker/listen_test.go index c66d8a480a176..81582a91d4b84 100644 --- a/peerbroker/listen_test.go +++ b/peerbroker/listen_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "storj.io/drpc/drpcconn" "github.com/coder/coder/peerbroker" "github.com/coder/coder/peerbroker/proto" @@ -27,7 +26,7 @@ func TestListen(t *testing.T) { listener, err := peerbroker.Listen(server, nil) require.NoError(t, err) - api := proto.NewDRPCPeerBrokerClient(drpcconn.New(client)) + api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) stream, err := api.NegotiateConnection(ctx) require.NoError(t, err) clientConn, err := peerbroker.Dial(stream, nil, nil) diff --git a/provisioner/terraform/parse_test.go b/provisioner/terraform/parse_test.go index e678d1d36c674..9d5bec03338f8 100644 --- a/provisioner/terraform/parse_test.go +++ b/provisioner/terraform/parse_test.go @@ -10,7 +10,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "storj.io/drpc/drpcconn" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk/proto" @@ -30,12 +29,12 @@ func TestParse(t *testing.T) { go func() { err := Serve(ctx, &ServeOptions{ ServeOptions: &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }, }) require.NoError(t, err) }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client)) for _, testCase := range []struct { Name string diff --git a/provisioner/terraform/provision_test.go b/provisioner/terraform/provision_test.go index 07ac1bde9dace..27117daa8464a 100644 --- 
a/provisioner/terraform/provision_test.go +++ b/provisioner/terraform/provision_test.go @@ -10,7 +10,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "storj.io/drpc/drpcconn" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk/proto" @@ -29,12 +28,12 @@ func TestProvision(t *testing.T) { go func() { err := Serve(ctx, &ServeOptions{ ServeOptions: &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }, }) require.NoError(t, err) }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client)) for _, testCase := range []struct { Name string diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go new file mode 100644 index 0000000000000..3fe7ce793c65e --- /dev/null +++ b/provisionerd/provisionerd.go @@ -0,0 +1,505 @@ +package provisionerd + +import ( + "archive/tar" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "reflect" + "strings" + "sync" + "time" + + "go.uber.org/atomic" + + "cdr.dev/slog" + "github.com/coder/coder/provisionerd/proto" + sdkproto "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/retry" +) + +// Dialer represents the function to create a daemon client connection. +type Dialer func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) + +// Provisioners maps provisioner ID to implementation. +type Provisioners map[string]sdkproto.DRPCProvisionerClient + +// Options provides customizations to the behavior of a provisioner daemon. +type Options struct { + Logger slog.Logger + + PollInterval time.Duration + Provisioners Provisioners + WorkDirectory string +} + +// New creates and starts a provisioner daemon. +func New(clientDialer Dialer, opts *Options) io.Closer { + if opts.PollInterval == 0 { + opts.PollInterval = 5 * time.Second + } + ctx, ctxCancel := context.WithCancel(context.Background()) + daemon := &provisionerDaemon{ + clientDialer: clientDialer, + opts: opts, + + closeContext: ctx, + closeCancel: ctxCancel, + closed: make(chan struct{}), + } + go daemon.connect(ctx) + return daemon +} + +type provisionerDaemon struct { + opts *Options + + clientDialer Dialer + connectMutex sync.Mutex + client proto.DRPCProvisionerDaemonClient + updateStream proto.DRPCProvisionerDaemon_UpdateJobClient + + // Only use for ending a job. + closeContext context.Context + closeCancel context.CancelFunc + closed chan struct{} + closeMutex sync.Mutex + closeError error + + // acquiredJobMutex ensures only one job is acquired at a time. The atomics + // below track whether that job is still running or has been canceled, and + // Close waits on acquiredJobDone while a job is in flight. + + acquiredJob *proto.AcquiredJob + acquiredJobMutex sync.Mutex + acquiredJobCancel context.CancelFunc + acquiredJobCancelled atomic.Bool + acquiredJobRunning atomic.Bool + acquiredJobDone chan struct{} +} + +// connect establishes a connection to coderd. +func (p *provisionerDaemon) connect(ctx context.Context) { + p.connectMutex.Lock() + defer p.connectMutex.Unlock() + + var err error + // An exponential back-off occurs when the connection is failing to dial. + // This is to prevent server spam in case of a coderd outage.
+ for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { + p.client, err = p.clientDialer(ctx) + if err != nil { + p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) + continue + } + p.updateStream, err = p.client.UpdateJob(ctx) + if err != nil { + p.opts.Logger.Warn(context.Background(), "create update job stream", slog.Error(err)) + continue + } + p.opts.Logger.Debug(context.Background(), "connected") + break + } + + go func() { + if p.isClosed() { + return + } + select { + case <-p.closed: + return + case <-p.updateStream.Context().Done(): + // We use the update stream to detect when the connection + // has been interrupted. This works well, because logs need + // to buffer if a job is running in the background. + p.opts.Logger.Debug(context.Background(), "update stream ended", slog.Error(p.updateStream.Context().Err())) + p.connect(ctx) + } + }() + + go func() { + if p.isClosed() { + return + } + ticker := time.NewTicker(p.opts.PollInterval) + defer ticker.Stop() + for { + select { + case <-p.closed: + return + case <-p.updateStream.Context().Done(): + return + case <-ticker.C: + p.acquireJob(ctx) + } + } + }() +} + +// acquireJob locks a job in the database and runs it! +func (p *provisionerDaemon) acquireJob(ctx context.Context) { + p.acquiredJobMutex.Lock() + defer p.acquiredJobMutex.Unlock() + if p.isRunningJob() { + p.opts.Logger.Debug(context.Background(), "skipping acquire; job is already running") + return + } + var err error + p.acquiredJob, err = p.client.AcquireJob(ctx, &proto.Empty{}) + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + p.opts.Logger.Warn(context.Background(), "acquire job", slog.Error(err)) + return + } + if p.isClosed() { + return + } + if p.acquiredJob.JobId == "" { + p.opts.Logger.Debug(context.Background(), "no jobs available") + return + } + ctx, p.acquiredJobCancel = context.WithCancel(ctx) + p.acquiredJobCancelled.Store(false) + p.acquiredJobRunning.Store(true) + p.acquiredJobDone = make(chan struct{}) + + p.opts.Logger.Info(context.Background(), "acquired job", + slog.F("organization_name", p.acquiredJob.OrganizationName), + slog.F("project_name", p.acquiredJob.ProjectName), + slog.F("username", p.acquiredJob.UserName), + slog.F("provisioner", p.acquiredJob.Provisioner), + ) + + go p.runJob(ctx) +} + +func (p *provisionerDaemon) isRunningJob() bool { + return p.acquiredJobRunning.Load() +} + +func (p *provisionerDaemon) runJob(ctx context.Context) { + go func() { + select { + case <-p.closed: + case <-ctx.Done(): + } + + // Clean up the work directory after execution. + err := os.RemoveAll(p.opts.WorkDirectory) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("remove all from %q directory: %s", p.opts.WorkDirectory, err)) + return + } + p.opts.Logger.Debug(ctx, "cleaned up work directory") + p.acquiredJobMutex.Lock() + defer p.acquiredJobMutex.Unlock() + p.acquiredJobRunning.Store(false) + close(p.acquiredJobDone) + }() + // The provisioner type comes directly from coderd, so it's safe to look it up among the registered provisioners.
+ provisioner, hasProvisioner := p.opts.Provisioners[p.acquiredJob.Provisioner] + if !hasProvisioner { + p.cancelActiveJob(fmt.Sprintf("provisioner %q not registered", p.acquiredJob.Provisioner)) + return + } + + err := os.MkdirAll(p.opts.WorkDirectory, 0600) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("create work directory %q: %s", p.opts.WorkDirectory, err)) + return + } + + p.opts.Logger.Info(ctx, "unpacking project source archive", slog.F("size_bytes", len(p.acquiredJob.ProjectSourceArchive))) + reader := tar.NewReader(bytes.NewBuffer(p.acquiredJob.ProjectSourceArchive)) + for { + header, err := reader.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + p.cancelActiveJob(fmt.Sprintf("read project source archive: %s", err)) + return + } + // #nosec + path := filepath.Join(p.opts.WorkDirectory, header.Name) + if !strings.HasPrefix(path, filepath.Clean(p.opts.WorkDirectory)) { + p.cancelActiveJob("tar attempts to target relative upper directory") + return + } + mode := header.FileInfo().Mode() + if mode == 0 { + mode = 0600 + } + switch header.Typeflag { + case tar.TypeDir: + err = os.MkdirAll(path, mode) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("mkdir %q: %s", path, err)) + return + } + p.opts.Logger.Debug(context.Background(), "extracted directory", slog.F("path", path)) + case tar.TypeReg: + file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, mode) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("create file %q: %s", path, err)) + return + } + // Max file size of 10MB. + size, err := io.CopyN(file, reader, (1<<20)*10) + if errors.Is(err, io.EOF) { + err = nil + } + if err != nil { + p.cancelActiveJob(fmt.Sprintf("copy file %q: %s", path, err)) + return + } + err = file.Close() + if err != nil { + p.cancelActiveJob(fmt.Sprintf("close file %q: %s", path, err)) + return + } + p.opts.Logger.Debug(context.Background(), "extracted file", + slog.F("size_bytes", size), + slog.F("path", path), + slog.F("mode", mode), + ) + } + } + + switch jobType := p.acquiredJob.Type.(type) { + case *proto.AcquiredJob_ProjectImport_: + p.opts.Logger.Debug(context.Background(), "acquired job is project import", + slog.F("project_history_name", jobType.ProjectImport.ProjectHistoryName), + ) + + p.runProjectImport(ctx, provisioner, jobType) + case *proto.AcquiredJob_WorkspaceProvision_: + p.opts.Logger.Debug(context.Background(), "acquired job is workspace provision", + slog.F("workspace_name", jobType.WorkspaceProvision.WorkspaceName), + slog.F("state_length", len(jobType.WorkspaceProvision.State)), + slog.F("parameters", jobType.WorkspaceProvision.ParameterValues), + ) + + p.runWorkspaceProvision(ctx, provisioner, jobType) + default: + p.cancelActiveJob(fmt.Sprintf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(p.acquiredJob.Type).String())) + return + } + + p.acquiredJobCancel() + p.opts.Logger.Info(context.Background(), "completed job") +} + +func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_ProjectImport_) { + stream, err := provisioner.Parse(ctx, &sdkproto.Parse_Request{ + Directory: p.opts.WorkDirectory, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("parse source: %s", err)) + return + } + defer stream.Close() + for { + msg, err := stream.Recv() + if err != nil { + p.cancelActiveJob(fmt.Sprintf("recv parse source: %s", err)) + return + } + switch msgType := msg.Type.(type) { + case *sdkproto.Parse_Response_Log: + 
p.opts.Logger.Debug(context.Background(), "parse job logged", + slog.F("level", msgType.Log.Level), + slog.F("output", msgType.Log.Output), + slog.F("project_history_id", job.ProjectImport.ProjectHistoryId), + ) + + err = p.updateStream.Send(&proto.JobUpdate{ + JobId: p.acquiredJob.JobId, + ProjectImportLogs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + Level: msgType.Log.Level, + CreatedAt: time.Now().UTC().UnixMilli(), + Output: msgType.Log.Output, + }}, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("update job: %s", err)) + return + } + case *sdkproto.Parse_Response_Complete: + _, err = p.client.CompleteJob(ctx, &proto.CompletedJob{ + JobId: p.acquiredJob.JobId, + Type: &proto.CompletedJob_ProjectImport_{ + ProjectImport: &proto.CompletedJob_ProjectImport{ + ParameterSchemas: msgType.Complete.ParameterSchemas, + }, + }, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("complete job: %s", err)) + return + } + // Return so we stop looping! + return + default: + p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner", + reflect.TypeOf(msg.Type).String())) + return + } + } +} + +func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) { + stream, err := provisioner.Provision(ctx, &sdkproto.Provision_Request{ + Directory: p.opts.WorkDirectory, + ParameterValues: job.WorkspaceProvision.ParameterValues, + State: job.WorkspaceProvision.State, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("provision: %s", err)) + return + } + defer stream.Close() + + for { + msg, err := stream.Recv() + if err != nil { + p.cancelActiveJob(fmt.Sprintf("recv workspace provision: %s", err)) + return + } + switch msgType := msg.Type.(type) { + case *sdkproto.Provision_Response_Log: + p.opts.Logger.Debug(context.Background(), "workspace provision job logged", + slog.F("level", msgType.Log.Level), + slog.F("output", msgType.Log.Output), + slog.F("workspace_history_id", job.WorkspaceProvision.WorkspaceHistoryId), + ) + + err = p.updateStream.Send(&proto.JobUpdate{ + JobId: p.acquiredJob.JobId, + WorkspaceProvisionLogs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + Level: msgType.Log.Level, + CreatedAt: time.Now().UTC().UnixMilli(), + Output: msgType.Log.Output, + }}, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("send job update: %s", err)) + return + } + case *sdkproto.Provision_Response_Complete: + p.opts.Logger.Info(context.Background(), "provision successful; marking job as complete", + slog.F("resource_count", len(msgType.Complete.Resources)), + slog.F("resources", msgType.Complete.Resources), + slog.F("state_length", len(msgType.Complete.State)), + ) + + // Complete job may need to be async if we disconnected... + // When we reconnect we can flush any of these cached values. + _, err = p.client.CompleteJob(ctx, &proto.CompletedJob{ + JobId: p.acquiredJob.JobId, + Type: &proto.CompletedJob_WorkspaceProvision_{ + WorkspaceProvision: &proto.CompletedJob_WorkspaceProvision{ + State: msgType.Complete.State, + Resources: msgType.Complete.Resources, + }, + }, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("complete job: %s", err)) + return + } + // Return so we stop looping! 
+ return + default: + p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner", + reflect.TypeOf(msg.Type).String())) + return + } + } +} + +func (p *provisionerDaemon) cancelActiveJob(errMsg string) { + if !p.isRunningJob() { + p.opts.Logger.Warn(context.Background(), "skipping job cancel; none running", slog.F("error_message", errMsg)) + return + } + if p.acquiredJobCancelled.Load() { + return + } + p.acquiredJobCancelled.Store(true) + p.acquiredJobCancel() + p.opts.Logger.Info(context.Background(), "canceling running job", + slog.F("error_message", errMsg), + slog.F("job_id", p.acquiredJob.JobId), + ) + _, err := p.client.CancelJob(p.closeContext, &proto.CancelledJob{ + JobId: p.acquiredJob.JobId, + Error: fmt.Sprintf("provisioner daemon: %s", errMsg), + }) + if err != nil { + p.opts.Logger.Warn(context.Background(), "failed to notify of cancel; job is no longer running", slog.Error(err)) + return + } + p.opts.Logger.Debug(context.Background(), "canceled running job") +} + +// isClosed returns whether the API is closed or not. +func (p *provisionerDaemon) isClosed() bool { + select { + case <-p.closed: + return true + default: + return false + } +} + +// Close ends the provisioner. It will mark any running jobs as canceled. +func (p *provisionerDaemon) Close() error { + return p.closeWithError(nil) +} + +// closeWithError closes the provisioner; subsequent reads/writes will return the error err. +func (p *provisionerDaemon) closeWithError(err error) error { + p.closeMutex.Lock() + defer p.closeMutex.Unlock() + if p.isClosed() { + return p.closeError + } + + if p.isRunningJob() { + errMsg := "provisioner daemon was shutdown gracefully" + if err != nil { + errMsg = err.Error() + } + if !p.acquiredJobCancelled.Load() { + p.cancelActiveJob(errMsg) + } + <-p.acquiredJobDone + } + + p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err)) + p.closeError = err + close(p.closed) + p.closeCancel() + + if p.updateStream != nil { + _ = p.client.DRPCConn().Close() + _ = p.updateStream.Close() + } + + return err +} diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go new file mode 100644 index 0000000000000..8148c5369d938 --- /dev/null +++ b/provisionerd/provisionerd_test.go @@ -0,0 +1,421 @@ +package provisionerd_test + +import ( + "archive/tar" + "bytes" + "context" + "errors" + "io" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "go.uber.org/goleak" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/provisionerd" + "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionersdk" + sdkproto "github.com/coder/coder/provisionersdk/proto" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestProvisionerd(t *testing.T) { + t.Parallel() + + noopUpdateJob := func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + <-stream.Context().Done() + return nil + } + + t.Run("InstantClose", func(t *testing.T) { + t.Parallel() + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{}), nil + }, provisionerd.Provisioners{}) + require.NoError(t, closer.Close()) + }) + + t.Run("ConnectErrorClose", func(t *testing.T) { + t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, 
func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + defer close(completeChan) + return nil, errors.New("an error") + }, provisionerd.Provisioners{}) + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("AcquireEmptyJob", func(t *testing.T) { + // The provisioner daemon is supposed to skip the job acquire if + // the job provided is empty. This is to show it successfully + // tried to get a job, but none were available. + t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + acquireJobAttempt := 0 + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + if acquireJobAttempt == 1 { + close(completeChan) + } + acquireJobAttempt++ + return &proto.AcquiredJob{}, nil + }, + updateJob: noopUpdateJob, + }), nil + }, provisionerd.Provisioners{}) + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("CloseCancelsJob", func(t *testing.T) { + t.Parallel() + completeChan := make(chan struct{}) + var closer io.Closer + var closerMutex sync.Mutex + closerMutex.Lock() + closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: noopUpdateJob, + cancelJob: func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + close(completeChan) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + closerMutex.Lock() + defer closerMutex.Unlock() + return closer.Close() + }, + }), + }) + closerMutex.Unlock() + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("MaliciousTar", func(t *testing.T) { + // Ensures tars with "../../../etc/passwd" as the path + // are not allowed to run, and will fail the job. 
+ t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "../../../etc/passwd": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: noopUpdateJob, + cancelJob: func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + close(completeChan) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{}), + }) + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("ProjectImport", func(t *testing.T) { + t.Parallel() + var ( + didComplete atomic.Bool + didLog atomic.Bool + didAcquireJob atomic.Bool + ) + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + if didAcquireJob.Load() { + close(completeChan) + return &proto.AcquiredJob{}, nil + } + didAcquireJob.Store(true) + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + for { + msg, err := stream.Recv() + if err != nil { + return err + } + if len(msg.ProjectImportLogs) == 0 { + continue + } + + didLog.Store(true) + } + }, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + didComplete.Store(true) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + data, err := os.ReadFile(filepath.Join(request.Directory, "test.txt")) + require.NoError(t, err) + require.Equal(t, "content", string(data)) + + err = stream.Send(&sdkproto.Parse_Response{ + Type: &sdkproto.Parse_Response_Log{ + Log: &sdkproto.Log{ + Level: sdkproto.LogLevel_INFO, + Output: "hello", + }, + }, + }) + require.NoError(t, err) + + err = stream.Send(&sdkproto.Parse_Response{ + Type: &sdkproto.Parse_Response_Complete{ + Complete: &sdkproto.Parse_Complete{ + ParameterSchemas: []*sdkproto.ParameterSchema{}, + }, + }, + }) + require.NoError(t, err) + return nil + }, + }), + }) + <-completeChan + require.True(t, didLog.Load()) + require.True(t, didComplete.Load()) + require.NoError(t, closer.Close()) + }) + + t.Run("WorkspaceProvision", func(t *testing.T) { + t.Parallel() + var ( + didComplete atomic.Bool + didLog atomic.Bool + didAcquireJob atomic.Bool + ) + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ 
*proto.Empty) (*proto.AcquiredJob, error) { + if didAcquireJob.Load() { + close(completeChan) + return &proto.AcquiredJob{}, nil + } + didAcquireJob.Store(true) + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_WorkspaceProvision_{ + WorkspaceProvision: &proto.AcquiredJob_WorkspaceProvision{}, + }, + }, nil + }, + updateJob: func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + for { + msg, err := stream.Recv() + if err != nil { + return err + } + if len(msg.WorkspaceProvisionLogs) == 0 { + continue + } + + didLog.Store(true) + } + }, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + didComplete.Store(true) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + provision: func(request *sdkproto.Provision_Request, stream sdkproto.DRPCProvisioner_ProvisionStream) error { + err := stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Log{ + Log: &sdkproto.Log{ + Level: sdkproto.LogLevel_DEBUG, + Output: "wow", + }, + }, + }) + require.NoError(t, err) + + err = stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Complete{ + Complete: &sdkproto.Provision_Complete{}, + }, + }) + require.NoError(t, err) + return nil + }, + }), + }) + <-completeChan + require.True(t, didLog.Load()) + require.True(t, didComplete.Load()) + require.NoError(t, closer.Close()) + }) +} + +// Creates an in-memory tar of the files provided. +func createTar(t *testing.T, files map[string]string) []byte { + var buffer bytes.Buffer + writer := tar.NewWriter(&buffer) + for path, content := range files { + err := writer.WriteHeader(&tar.Header{ + Name: path, + Size: int64(len(content)), + }) + require.NoError(t, err) + + _, err = writer.Write([]byte(content)) + require.NoError(t, err) + } + + err := writer.Flush() + require.NoError(t, err) + return buffer.Bytes() +} + +// Creates a provisionerd implementation with the provided dialer and provisioners. +func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners provisionerd.Provisioners) io.Closer { + closer := provisionerd.New(dialer, &provisionerd.Options{ + Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), + PollInterval: 50 * time.Millisecond, + Provisioners: provisioners, + WorkDirectory: t.TempDir(), + }) + t.Cleanup(func() { + _ = closer.Close() + }) + return closer +} + +// Creates a provisionerd protobuf client that's connected +// to the server implementation provided. +func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient { + clientPipe, serverPipe := provisionersdk.TransportPipe() + t.Cleanup(func() { + _ = clientPipe.Close() + _ = serverPipe.Close() + }) + mux := drpcmux.New() + err := proto.DRPCRegisterProvisionerDaemon(mux, &server) + require.NoError(t, err) + srv := drpcserver.New(mux) + go func() { + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + _ = srv.Serve(ctx, serverPipe) + }() + return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientPipe)) +} + +// Creates a provisioner protobuf client that's connected +// to the server implementation provided. 
+func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkproto.DRPCProvisionerClient { + clientPipe, serverPipe := provisionersdk.TransportPipe() + t.Cleanup(func() { + _ = clientPipe.Close() + _ = serverPipe.Close() + }) + mux := drpcmux.New() + err := sdkproto.DRPCRegisterProvisioner(mux, &server) + require.NoError(t, err) + srv := drpcserver.New(mux) + go func() { + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + _ = srv.Serve(ctx, serverPipe) + }() + return sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(clientPipe)) +} + +type provisionerTestServer struct { + parse func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error + provision func(request *sdkproto.Provision_Request, stream sdkproto.DRPCProvisioner_ProvisionStream) error +} + +func (p *provisionerTestServer) Parse(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + return p.parse(request, stream) +} + +func (p *provisionerTestServer) Provision(request *sdkproto.Provision_Request, stream sdkproto.DRPCProvisioner_ProvisionStream) error { + return p.provision(request, stream) +} + +// Fulfills the protobuf interface for a ProvisionerDaemon with +// passable functions for dynamic functionality. +type provisionerDaemonTestServer struct { + acquireJob func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) + updateJob func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error + cancelJob func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) + completeJob func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) +} + +func (p *provisionerDaemonTestServer) AcquireJob(ctx context.Context, empty *proto.Empty) (*proto.AcquiredJob, error) { + return p.acquireJob(ctx, empty) +} + +func (p *provisionerDaemonTestServer) UpdateJob(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + return p.updateJob(stream) +} + +func (p *provisionerDaemonTestServer) CancelJob(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + return p.cancelJob(ctx, job) +} + +func (p *provisionerDaemonTestServer) CompleteJob(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + return p.completeJob(ctx, job) +} diff --git a/provisionersdk/serve.go b/provisionersdk/serve.go index 9b7952001f96c..1fbe50d506850 100644 --- a/provisionersdk/serve.go +++ b/provisionersdk/serve.go @@ -4,19 +4,22 @@ import ( "context" "errors" "io" + "net" + "os" "golang.org/x/xerrors" - "storj.io/drpc" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" + "github.com/hashicorp/yamux" + "github.com/coder/coder/provisionersdk/proto" ) // ServeOptions are configurations to serve a provisioner. type ServeOptions struct { - // Transport specifies a custom transport to serve the dRPC connection. - Transport drpc.Transport + // Listener specifies a custom listener to serve the dRPC connection on. + Listener net.Listener } // Serve starts a dRPC connection for the provisioner and transport provided. @@ -25,8 +28,17 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser options = &ServeOptions{} } // Default to using stdio.
- if options.Transport == nil { - options.Transport = TransportStdio() + if options.Listener == nil { + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + stdio, err := yamux.Server(readWriteCloser{ + ReadCloser: os.Stdin, + Writer: os.Stdout, + }, config) + if err != nil { + return xerrors.Errorf("create yamux: %w", err) + } + options.Listener = stdio } // dRPC is a drop-in replacement for gRPC with less generated code, and faster transports. @@ -40,16 +52,15 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser // Only serve a single connection on the transport. // Transports are not multiplexed, and provisioners are // short-lived processes that can be executed concurrently. - err = srv.ServeOne(ctx, options.Transport) + err = srv.Serve(ctx, options.Listener) if err != nil { if errors.Is(err, context.Canceled) { return nil } if errors.Is(err, io.ErrClosedPipe) { - // This may occur if the transport on either end is - // closed before the context. It's fine to return - // nil here, since the server has nothing to - // communicate with. + return nil + } + if errors.Is(err, yamux.ErrSessionShutdown) { return nil } return xerrors.Errorf("serve transport: %w", err) diff --git a/provisionersdk/serve_test.go b/provisionersdk/serve_test.go index 08ac393eb8dfc..cf2dd7517df82 100644 --- a/provisionersdk/serve_test.go +++ b/provisionersdk/serve_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" - "storj.io/drpc/drpcconn" "storj.io/drpc/drpcerr" "github.com/coder/coder/provisionersdk" @@ -29,12 +28,12 @@ func TestProvisionerSDK(t *testing.T) { defer cancelFunc() go func() { err := provisionersdk.Serve(ctx, &proto.DRPCProvisionerUnimplementedServer{}, &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }) require.NoError(t, err) }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client)) stream, err := api.Parse(context.Background(), &proto.Parse_Request{}) require.NoError(t, err) _, err = stream.Recv() @@ -47,7 +46,7 @@ func TestProvisionerSDK(t *testing.T) { _ = server.Close() err := provisionersdk.Serve(context.Background(), &proto.DRPCProvisionerUnimplementedServer{}, &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }) require.NoError(t, err) }) diff --git a/provisionersdk/transport.go b/provisionersdk/transport.go index c01a7ab8269e9..7fd87839d174b 100644 --- a/provisionersdk/transport.go +++ b/provisionersdk/transport.go @@ -1,44 +1,74 @@ package provisionersdk import ( + "context" "io" - "os" + "github.com/hashicorp/yamux" "storj.io/drpc" + "storj.io/drpc/drpcconn" ) -// Transport creates a dRPC transport using stdin and stdout. -func TransportStdio() drpc.Transport { - return &transport{ - in: os.Stdin, - out: os.Stdout, +// TransportPipe creates an in-memory pipe for dRPC transport. +func TransportPipe() (*yamux.Session, *yamux.Session) { + clientReader, clientWriter := io.Pipe() + serverReader, serverWriter := io.Pipe() + yamuxConfig := yamux.DefaultConfig() + yamuxConfig.LogOutput = io.Discard + client, err := yamux.Client(&readWriteCloser{ + ReadCloser: clientReader, + Writer: serverWriter, + }, yamuxConfig) + if err != nil { + panic(err) + } + + server, err := yamux.Server(&readWriteCloser{ + ReadCloser: serverReader, + Writer: clientWriter, + }, yamuxConfig) + if err != nil { + panic(err) } + return client, server } -// TransportPipe creates an in-memory pipe for dRPC transport. 
-func TransportPipe() (drpc.Transport, drpc.Transport) { - clientReader, serverWriter := io.Pipe() - serverReader, clientWriter := io.Pipe() - clientTransport := &transport{clientReader, clientWriter} - serverTransport := &transport{serverReader, serverWriter} +// Conn returns a multiplexed dRPC connection from a yamux session. +func Conn(session *yamux.Session) drpc.Conn { + return &multiplexedDRPC{session} +} - return clientTransport, serverTransport +type readWriteCloser struct { + io.ReadCloser + io.Writer } -// transport wraps an input and output to pipe data. -type transport struct { - in io.ReadCloser - out io.Writer +// Allows concurrent requests on a single dRPC connection. +// Required for calling functions concurrently. +type multiplexedDRPC struct { + session *yamux.Session } -func (s *transport) Read(data []byte) (int, error) { - return s.in.Read(data) +func (m *multiplexedDRPC) Close() error { + return m.session.Close() } -func (s *transport) Write(data []byte) (int, error) { - return s.out.Write(data) +func (m *multiplexedDRPC) Closed() <-chan struct{} { + return m.session.CloseChan() } -func (s *transport) Close() error { - return s.in.Close() +func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { + conn, err := m.session.Open() + if err != nil { + return err + } + return drpcconn.New(conn).Invoke(ctx, rpc, enc, in, out) +} + +func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { + conn, err := m.session.Open() + if err != nil { + return nil, err + } + return drpcconn.New(conn).NewStream(ctx, rpc, enc) }
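
For reviewers, a minimal sketch (not part of the patch) of how a caller could wire the daemon together with the APIs added here. dialCoderd, terraformClient, and the work directory path are hypothetical placeholders, since the coderd dialer and provisioner wiring land in follow-up work.

package main

import (
	"context"
	"errors"
	"time"

	"cdr.dev/slog"

	"github.com/coder/coder/provisionerd"
	"github.com/coder/coder/provisionerd/proto"
	sdkproto "github.com/coder/coder/provisionersdk/proto"
)

func main() {
	// dialCoderd is a hypothetical placeholder: the real dialer that connects to
	// coderd and returns its daemon client arrives with the coderd integration.
	dialCoderd := provisionerd.Dialer(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
		return nil, errors.New("coderd integration is not wired up yet")
	})

	// terraformClient is a hypothetical placeholder for a dRPC provisioner client,
	// e.g. one served over provisionersdk.TransportPipe as the tests above do.
	var terraformClient sdkproto.DRPCProvisionerClient

	// New begins polling for jobs immediately; failed dials are retried with
	// exponential back-off.
	closer := provisionerd.New(dialCoderd, &provisionerd.Options{
		Logger:        slog.Make(), // no sinks attached, for brevity
		PollInterval:  time.Second,
		Provisioners:  provisionerd.Provisioners{"terraform": terraformClient},
		WorkDirectory: "/tmp/provisionerd",
	})
	defer closer.Close()
}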