diff --git a/go.mod b/go.mod index 833454182a93c..ace5013149946 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ replace github.com/hashicorp/terraform-config-inspect => github.com/kylecarbs/te require ( cdr.dev/slog v1.4.1 + github.com/coder/retry v1.3.0 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/render v1.0.1 github.com/go-playground/validator/v10 v10.10.0 @@ -18,6 +19,7 @@ require ( github.com/hashicorp/go-version v1.4.0 github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f github.com/hashicorp/terraform-exec v0.15.0 + github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 github.com/justinas/nosurf v1.1.1 github.com/lib/pq v1.10.4 github.com/moby/moby v20.10.12+incompatible diff --git a/go.sum b/go.sum index 1c1ada5be0ae6..583c2a9236c3a 100644 --- a/go.sum +++ b/go.sum @@ -232,6 +232,8 @@ github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWH github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/cockroach-go/v2 v2.1.1/go.mod h1:7NtUnP6eK+l6k483WSYNrq3Kb23bWV10IRV1TyeSpwM= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= +github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY= +github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY= github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE= github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU= github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU= @@ -688,6 +690,8 @@ github.com/hashicorp/serf v0.9.5/go.mod h1:UWDWwZeL5cuWDJdl0C6wrvrUwEqtQ4ZKBKKEN github.com/hashicorp/serf v0.9.6/go.mod h1:TXZNMjZQijwlDvp+r0b63xZ45H7JmCmgg4gpTwn9UV4= github.com/hashicorp/terraform-json v0.13.0 h1:Li9L+lKD1FO5RVFRM1mMMIBDoUHslOniyEi5CM+FWGY= github.com/hashicorp/terraform-json v0.13.0/go.mod h1:y5OdLBCT+rxbwnpxZs9kGL7R9ExU76+cpdY8zHwoazk= +github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 h1:xixZ2bWeofWV68J+x6AzmKuVM/JWCQwkWm6GW/MUR6I= +github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/peerbroker/dial_test.go b/peerbroker/dial_test.go index 30066b8d82397..537dc157e0a79 100644 --- a/peerbroker/dial_test.go +++ b/peerbroker/dial_test.go @@ -7,7 +7,6 @@ import ( "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" "go.uber.org/goleak" - "storj.io/drpc/drpcconn" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -37,7 +36,7 @@ func TestDial(t *testing.T) { }) require.NoError(t, err) - api := proto.NewDRPCPeerBrokerClient(drpcconn.New(client)) + api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) stream, err := api.NegotiateConnection(ctx) require.NoError(t, err) clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{ diff --git a/peerbroker/listen.go b/peerbroker/listen.go index fa68023689c7e..8e92fe7a7c82d 100644 --- a/peerbroker/listen.go +++ b/peerbroker/listen.go @@ -4,12 +4,12 @@ import ( "context" 
"errors" "io" + "net" "reflect" "sync" "github.com/pion/webrtc/v3" "golang.org/x/xerrors" - "storj.io/drpc" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -19,7 +19,7 @@ import ( // Listen consumes the transport as the server-side of the PeerBroker dRPC service. // The Accept function must be serviced, or new connections will hang. -func Listen(transport drpc.Transport, opts *peer.ConnOptions) (*Listener, error) { +func Listen(connListener net.Listener, opts *peer.ConnOptions) (*Listener, error) { ctx, cancelFunc := context.WithCancel(context.Background()) listener := &Listener{ connectionChannel: make(chan *peer.Conn), @@ -39,7 +39,7 @@ func Listen(transport drpc.Transport, opts *peer.ConnOptions) (*Listener, error) } srv := drpcserver.New(mux) go func() { - err := srv.ServeOne(ctx, transport) + err := srv.Serve(ctx, connListener) _ = listener.closeWithError(err) }() diff --git a/peerbroker/listen_test.go b/peerbroker/listen_test.go index c66d8a480a176..81582a91d4b84 100644 --- a/peerbroker/listen_test.go +++ b/peerbroker/listen_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "storj.io/drpc/drpcconn" "github.com/coder/coder/peerbroker" "github.com/coder/coder/peerbroker/proto" @@ -27,7 +26,7 @@ func TestListen(t *testing.T) { listener, err := peerbroker.Listen(server, nil) require.NoError(t, err) - api := proto.NewDRPCPeerBrokerClient(drpcconn.New(client)) + api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) stream, err := api.NegotiateConnection(ctx) require.NoError(t, err) clientConn, err := peerbroker.Dial(stream, nil, nil) diff --git a/provisioner/terraform/parse_test.go b/provisioner/terraform/parse_test.go index e678d1d36c674..9d5bec03338f8 100644 --- a/provisioner/terraform/parse_test.go +++ b/provisioner/terraform/parse_test.go @@ -10,7 +10,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "storj.io/drpc/drpcconn" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk/proto" @@ -30,12 +29,12 @@ func TestParse(t *testing.T) { go func() { err := Serve(ctx, &ServeOptions{ ServeOptions: &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }, }) require.NoError(t, err) }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client)) for _, testCase := range []struct { Name string diff --git a/provisioner/terraform/provision_test.go b/provisioner/terraform/provision_test.go index 07ac1bde9dace..27117daa8464a 100644 --- a/provisioner/terraform/provision_test.go +++ b/provisioner/terraform/provision_test.go @@ -10,7 +10,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "storj.io/drpc/drpcconn" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk/proto" @@ -29,12 +28,12 @@ func TestProvision(t *testing.T) { go func() { err := Serve(ctx, &ServeOptions{ ServeOptions: &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }, }) require.NoError(t, err) }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client)) for _, testCase := range []struct { Name string diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go new file mode 100644 index 0000000000000..3fe7ce793c65e --- /dev/null +++ b/provisionerd/provisionerd.go @@ -0,0 +1,505 @@ +package provisionerd + +import ( + "archive/tar" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "reflect" + 
"strings" + "sync" + "time" + + "go.uber.org/atomic" + + "cdr.dev/slog" + "github.com/coder/coder/provisionerd/proto" + sdkproto "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/retry" +) + +// Dialer represents the function to create a daemon client connection. +type Dialer func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) + +// Provisioners maps provisioner ID to implementation. +type Provisioners map[string]sdkproto.DRPCProvisionerClient + +// Options provides customizations to the behavior of a provisioner daemon. +type Options struct { + Logger slog.Logger + + PollInterval time.Duration + Provisioners Provisioners + WorkDirectory string +} + +// New creates and starts a provisioner daemon. +func New(clientDialer Dialer, opts *Options) io.Closer { + if opts.PollInterval == 0 { + opts.PollInterval = 5 * time.Second + } + ctx, ctxCancel := context.WithCancel(context.Background()) + daemon := &provisionerDaemon{ + clientDialer: clientDialer, + opts: opts, + + closeContext: ctx, + closeCancel: ctxCancel, + closed: make(chan struct{}), + } + go daemon.connect(ctx) + return daemon +} + +type provisionerDaemon struct { + opts *Options + + clientDialer Dialer + connectMutex sync.Mutex + client proto.DRPCProvisionerDaemonClient + updateStream proto.DRPCProvisionerDaemon_UpdateJobClient + + // Only use for ending a job. + closeContext context.Context + closeCancel context.CancelFunc + closed chan struct{} + closeMutex sync.Mutex + closeError error + + // Lock on acquiring a job so two can't happen at once...? + // If a single cancel can happen, but an acquire could happen? + + // Lock on acquire + // Use atomic for checking if we are running a job + // Use atomic for checking if we are canceling job + // If we're running a job, wait for the done chan in + // close. + + acquiredJob *proto.AcquiredJob + acquiredJobMutex sync.Mutex + acquiredJobCancel context.CancelFunc + acquiredJobCancelled atomic.Bool + acquiredJobRunning atomic.Bool + acquiredJobDone chan struct{} +} + +// Connnect establishes a connection to coderd. +func (p *provisionerDaemon) connect(ctx context.Context) { + p.connectMutex.Lock() + defer p.connectMutex.Unlock() + + var err error + // An exponential back-off occurs when the connection is failing to dial. + // This is to prevent server spam in case of a coderd outage. + for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { + p.client, err = p.clientDialer(ctx) + if err != nil { + // Warn + p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) + continue + } + p.updateStream, err = p.client.UpdateJob(ctx) + if err != nil { + p.opts.Logger.Warn(context.Background(), "create update job stream", slog.Error(err)) + continue + } + p.opts.Logger.Debug(context.Background(), "connected") + break + } + + go func() { + if p.isClosed() { + return + } + select { + case <-p.closed: + return + case <-p.updateStream.Context().Done(): + // We use the update stream to detect when the connection + // has been interrupted. This works well, because logs need + // to buffer if a job is running in the background. 
+ p.opts.Logger.Debug(context.Background(), "update stream ended", slog.Error(p.updateStream.Context().Err())) + p.connect(ctx) + } + }() + + go func() { + if p.isClosed() { + return + } + ticker := time.NewTicker(p.opts.PollInterval) + defer ticker.Stop() + for { + select { + case <-p.closed: + return + case <-p.updateStream.Context().Done(): + return + case <-ticker.C: + p.acquireJob(ctx) + } + } + }() +} + +// Locks a job in the database, and runs it! +func (p *provisionerDaemon) acquireJob(ctx context.Context) { + p.acquiredJobMutex.Lock() + defer p.acquiredJobMutex.Unlock() + if p.isRunningJob() { + p.opts.Logger.Debug(context.Background(), "skipping acquire; job is already running") + return + } + var err error + p.acquiredJob, err = p.client.AcquireJob(ctx, &proto.Empty{}) + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + p.opts.Logger.Warn(context.Background(), "acquire job", slog.Error(err)) + return + } + if p.isClosed() { + return + } + if p.acquiredJob.JobId == "" { + p.opts.Logger.Debug(context.Background(), "no jobs available") + return + } + ctx, p.acquiredJobCancel = context.WithCancel(ctx) + p.acquiredJobCancelled.Store(false) + p.acquiredJobRunning.Store(true) + p.acquiredJobDone = make(chan struct{}) + + p.opts.Logger.Info(context.Background(), "acquired job", + slog.F("organization_name", p.acquiredJob.OrganizationName), + slog.F("project_name", p.acquiredJob.ProjectName), + slog.F("username", p.acquiredJob.UserName), + slog.F("provisioner", p.acquiredJob.Provisioner), + ) + + go p.runJob(ctx) +} + +func (p *provisionerDaemon) isRunningJob() bool { + return p.acquiredJobRunning.Load() +} + +func (p *provisionerDaemon) runJob(ctx context.Context) { + go func() { + select { + case <-p.closed: + case <-ctx.Done(): + } + + // Cleanup the work directory after execution. + err := os.RemoveAll(p.opts.WorkDirectory) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("remove all from %q directory: %s", p.opts.WorkDirectory, err)) + return + } + p.opts.Logger.Debug(ctx, "cleaned up work directory") + p.acquiredJobMutex.Lock() + defer p.acquiredJobMutex.Unlock() + p.acquiredJobRunning.Store(false) + close(p.acquiredJobDone) + }() + // It's safe to cast this ProvisionerType. This data is coming directly from coderd. 
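+	// Look up the provisioner registered for this job; an unknown
+	// provisioner type fails the job immediately.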
+ provisioner, hasProvisioner := p.opts.Provisioners[p.acquiredJob.Provisioner] + if !hasProvisioner { + p.cancelActiveJob(fmt.Sprintf("provisioner %q not registered", p.acquiredJob.Provisioner)) + return + } + + err := os.MkdirAll(p.opts.WorkDirectory, 0600) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("create work directory %q: %s", p.opts.WorkDirectory, err)) + return + } + + p.opts.Logger.Info(ctx, "unpacking project source archive", slog.F("size_bytes", len(p.acquiredJob.ProjectSourceArchive))) + reader := tar.NewReader(bytes.NewBuffer(p.acquiredJob.ProjectSourceArchive)) + for { + header, err := reader.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + p.cancelActiveJob(fmt.Sprintf("read project source archive: %s", err)) + return + } + // #nosec + path := filepath.Join(p.opts.WorkDirectory, header.Name) + if !strings.HasPrefix(path, filepath.Clean(p.opts.WorkDirectory)) { + p.cancelActiveJob("tar attempts to target relative upper directory") + return + } + mode := header.FileInfo().Mode() + if mode == 0 { + mode = 0600 + } + switch header.Typeflag { + case tar.TypeDir: + err = os.MkdirAll(path, mode) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("mkdir %q: %s", path, err)) + return + } + p.opts.Logger.Debug(context.Background(), "extracted directory", slog.F("path", path)) + case tar.TypeReg: + file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, mode) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("create file %q: %s", path, err)) + return + } + // Max file size of 10MB. + size, err := io.CopyN(file, reader, (1<<20)*10) + if errors.Is(err, io.EOF) { + err = nil + } + if err != nil { + p.cancelActiveJob(fmt.Sprintf("copy file %q: %s", path, err)) + return + } + err = file.Close() + if err != nil { + p.cancelActiveJob(fmt.Sprintf("close file %q: %s", path, err)) + return + } + p.opts.Logger.Debug(context.Background(), "extracted file", + slog.F("size_bytes", size), + slog.F("path", path), + slog.F("mode", mode), + ) + } + } + + switch jobType := p.acquiredJob.Type.(type) { + case *proto.AcquiredJob_ProjectImport_: + p.opts.Logger.Debug(context.Background(), "acquired job is project import", + slog.F("project_history_name", jobType.ProjectImport.ProjectHistoryName), + ) + + p.runProjectImport(ctx, provisioner, jobType) + case *proto.AcquiredJob_WorkspaceProvision_: + p.opts.Logger.Debug(context.Background(), "acquired job is workspace provision", + slog.F("workspace_name", jobType.WorkspaceProvision.WorkspaceName), + slog.F("state_length", len(jobType.WorkspaceProvision.State)), + slog.F("parameters", jobType.WorkspaceProvision.ParameterValues), + ) + + p.runWorkspaceProvision(ctx, provisioner, jobType) + default: + p.cancelActiveJob(fmt.Sprintf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(p.acquiredJob.Type).String())) + return + } + + p.acquiredJobCancel() + p.opts.Logger.Info(context.Background(), "completed job") +} + +func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_ProjectImport_) { + stream, err := provisioner.Parse(ctx, &sdkproto.Parse_Request{ + Directory: p.opts.WorkDirectory, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("parse source: %s", err)) + return + } + defer stream.Close() + for { + msg, err := stream.Recv() + if err != nil { + p.cancelActiveJob(fmt.Sprintf("recv parse source: %s", err)) + return + } + switch msgType := msg.Type.(type) { + case *sdkproto.Parse_Response_Log: + 
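+			// Relay provisioner log lines to coderd over the job update stream.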
p.opts.Logger.Debug(context.Background(), "parse job logged", + slog.F("level", msgType.Log.Level), + slog.F("output", msgType.Log.Output), + slog.F("project_history_id", job.ProjectImport.ProjectHistoryId), + ) + + err = p.updateStream.Send(&proto.JobUpdate{ + JobId: p.acquiredJob.JobId, + ProjectImportLogs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + Level: msgType.Log.Level, + CreatedAt: time.Now().UTC().UnixMilli(), + Output: msgType.Log.Output, + }}, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("update job: %s", err)) + return + } + case *sdkproto.Parse_Response_Complete: + _, err = p.client.CompleteJob(ctx, &proto.CompletedJob{ + JobId: p.acquiredJob.JobId, + Type: &proto.CompletedJob_ProjectImport_{ + ProjectImport: &proto.CompletedJob_ProjectImport{ + ParameterSchemas: msgType.Complete.ParameterSchemas, + }, + }, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("complete job: %s", err)) + return + } + // Return so we stop looping! + return + default: + p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner", + reflect.TypeOf(msg.Type).String())) + return + } + } +} + +func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) { + stream, err := provisioner.Provision(ctx, &sdkproto.Provision_Request{ + Directory: p.opts.WorkDirectory, + ParameterValues: job.WorkspaceProvision.ParameterValues, + State: job.WorkspaceProvision.State, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("provision: %s", err)) + return + } + defer stream.Close() + + for { + msg, err := stream.Recv() + if err != nil { + p.cancelActiveJob(fmt.Sprintf("recv workspace provision: %s", err)) + return + } + switch msgType := msg.Type.(type) { + case *sdkproto.Provision_Response_Log: + p.opts.Logger.Debug(context.Background(), "workspace provision job logged", + slog.F("level", msgType.Log.Level), + slog.F("output", msgType.Log.Output), + slog.F("workspace_history_id", job.WorkspaceProvision.WorkspaceHistoryId), + ) + + err = p.updateStream.Send(&proto.JobUpdate{ + JobId: p.acquiredJob.JobId, + WorkspaceProvisionLogs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + Level: msgType.Log.Level, + CreatedAt: time.Now().UTC().UnixMilli(), + Output: msgType.Log.Output, + }}, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("send job update: %s", err)) + return + } + case *sdkproto.Provision_Response_Complete: + p.opts.Logger.Info(context.Background(), "provision successful; marking job as complete", + slog.F("resource_count", len(msgType.Complete.Resources)), + slog.F("resources", msgType.Complete.Resources), + slog.F("state_length", len(msgType.Complete.State)), + ) + + // Complete job may need to be async if we disconnected... + // When we reconnect we can flush any of these cached values. + _, err = p.client.CompleteJob(ctx, &proto.CompletedJob{ + JobId: p.acquiredJob.JobId, + Type: &proto.CompletedJob_WorkspaceProvision_{ + WorkspaceProvision: &proto.CompletedJob_WorkspaceProvision{ + State: msgType.Complete.State, + Resources: msgType.Complete.Resources, + }, + }, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("complete job: %s", err)) + return + } + // Return so we stop looping! 
+ return + default: + p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner", + reflect.TypeOf(msg.Type).String())) + return + } + } +} + +func (p *provisionerDaemon) cancelActiveJob(errMsg string) { + if !p.isRunningJob() { + p.opts.Logger.Warn(context.Background(), "skipping job cancel; none running", slog.F("error_message", errMsg)) + return + } + if p.acquiredJobCancelled.Load() { + return + } + p.acquiredJobCancelled.Store(true) + p.acquiredJobCancel() + p.opts.Logger.Info(context.Background(), "canceling running job", + slog.F("error_message", errMsg), + slog.F("job_id", p.acquiredJob.JobId), + ) + _, err := p.client.CancelJob(p.closeContext, &proto.CancelledJob{ + JobId: p.acquiredJob.JobId, + Error: fmt.Sprintf("provisioner daemon: %s", errMsg), + }) + if err != nil { + p.opts.Logger.Warn(context.Background(), "failed to notify of cancel; job is no longer running", slog.Error(err)) + return + } + p.opts.Logger.Debug(context.Background(), "canceled running job") +} + +// isClosed returns whether the API is closed or not. +func (p *provisionerDaemon) isClosed() bool { + select { + case <-p.closed: + return true + default: + return false + } +} + +// Close ends the provisioner. It will mark any running jobs as canceled. +func (p *provisionerDaemon) Close() error { + return p.closeWithError(nil) +} + +// closeWithError closes the provisioner; subsequent reads/writes will return the error err. +func (p *provisionerDaemon) closeWithError(err error) error { + p.closeMutex.Lock() + defer p.closeMutex.Unlock() + if p.isClosed() { + return p.closeError + } + + if p.isRunningJob() { + errMsg := "provisioner daemon was shutdown gracefully" + if err != nil { + errMsg = err.Error() + } + if !p.acquiredJobCancelled.Load() { + p.cancelActiveJob(errMsg) + } + <-p.acquiredJobDone + } + + p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err)) + p.closeError = err + close(p.closed) + p.closeCancel() + + if p.updateStream != nil { + _ = p.client.DRPCConn().Close() + _ = p.updateStream.Close() + } + + return err +} diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go new file mode 100644 index 0000000000000..8148c5369d938 --- /dev/null +++ b/provisionerd/provisionerd_test.go @@ -0,0 +1,421 @@ +package provisionerd_test + +import ( + "archive/tar" + "bytes" + "context" + "errors" + "io" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "go.uber.org/goleak" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/provisionerd" + "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionersdk" + sdkproto "github.com/coder/coder/provisionersdk/proto" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestProvisionerd(t *testing.T) { + t.Parallel() + + noopUpdateJob := func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + <-stream.Context().Done() + return nil + } + + t.Run("InstantClose", func(t *testing.T) { + t.Parallel() + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{}), nil + }, provisionerd.Provisioners{}) + require.NoError(t, closer.Close()) + }) + + t.Run("ConnectErrorClose", func(t *testing.T) { + t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, 
func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + defer close(completeChan) + return nil, errors.New("an error") + }, provisionerd.Provisioners{}) + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("AcquireEmptyJob", func(t *testing.T) { + // The provisioner daemon is supposed to skip the job acquire if + // the job provided is empty. This is to show it successfully + // tried to get a job, but none were available. + t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + acquireJobAttempt := 0 + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + if acquireJobAttempt == 1 { + close(completeChan) + } + acquireJobAttempt++ + return &proto.AcquiredJob{}, nil + }, + updateJob: noopUpdateJob, + }), nil + }, provisionerd.Provisioners{}) + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("CloseCancelsJob", func(t *testing.T) { + t.Parallel() + completeChan := make(chan struct{}) + var closer io.Closer + var closerMutex sync.Mutex + closerMutex.Lock() + closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: noopUpdateJob, + cancelJob: func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + close(completeChan) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + closerMutex.Lock() + defer closerMutex.Unlock() + return closer.Close() + }, + }), + }) + closerMutex.Unlock() + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("MaliciousTar", func(t *testing.T) { + // Ensures tars with "../../../etc/passwd" as the path + // are not allowed to run, and will fail the job. 
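+		// The daemon must refuse to extract files outside its work directory
+		// and cancel the job instead of running the provisioner.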
+ t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "../../../etc/passwd": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: noopUpdateJob, + cancelJob: func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + close(completeChan) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{}), + }) + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("ProjectImport", func(t *testing.T) { + t.Parallel() + var ( + didComplete atomic.Bool + didLog atomic.Bool + didAcquireJob atomic.Bool + ) + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + if didAcquireJob.Load() { + close(completeChan) + return &proto.AcquiredJob{}, nil + } + didAcquireJob.Store(true) + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + for { + msg, err := stream.Recv() + if err != nil { + return err + } + if len(msg.ProjectImportLogs) == 0 { + continue + } + + didLog.Store(true) + } + }, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + didComplete.Store(true) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + data, err := os.ReadFile(filepath.Join(request.Directory, "test.txt")) + require.NoError(t, err) + require.Equal(t, "content", string(data)) + + err = stream.Send(&sdkproto.Parse_Response{ + Type: &sdkproto.Parse_Response_Log{ + Log: &sdkproto.Log{ + Level: sdkproto.LogLevel_INFO, + Output: "hello", + }, + }, + }) + require.NoError(t, err) + + err = stream.Send(&sdkproto.Parse_Response{ + Type: &sdkproto.Parse_Response_Complete{ + Complete: &sdkproto.Parse_Complete{ + ParameterSchemas: []*sdkproto.ParameterSchema{}, + }, + }, + }) + require.NoError(t, err) + return nil + }, + }), + }) + <-completeChan + require.True(t, didLog.Load()) + require.True(t, didComplete.Load()) + require.NoError(t, closer.Close()) + }) + + t.Run("WorkspaceProvision", func(t *testing.T) { + t.Parallel() + var ( + didComplete atomic.Bool + didLog atomic.Bool + didAcquireJob atomic.Bool + ) + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ 
*proto.Empty) (*proto.AcquiredJob, error) { + if didAcquireJob.Load() { + close(completeChan) + return &proto.AcquiredJob{}, nil + } + didAcquireJob.Store(true) + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_WorkspaceProvision_{ + WorkspaceProvision: &proto.AcquiredJob_WorkspaceProvision{}, + }, + }, nil + }, + updateJob: func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + for { + msg, err := stream.Recv() + if err != nil { + return err + } + if len(msg.WorkspaceProvisionLogs) == 0 { + continue + } + + didLog.Store(true) + } + }, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + didComplete.Store(true) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + provision: func(request *sdkproto.Provision_Request, stream sdkproto.DRPCProvisioner_ProvisionStream) error { + err := stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Log{ + Log: &sdkproto.Log{ + Level: sdkproto.LogLevel_DEBUG, + Output: "wow", + }, + }, + }) + require.NoError(t, err) + + err = stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Complete{ + Complete: &sdkproto.Provision_Complete{}, + }, + }) + require.NoError(t, err) + return nil + }, + }), + }) + <-completeChan + require.True(t, didLog.Load()) + require.True(t, didComplete.Load()) + require.NoError(t, closer.Close()) + }) +} + +// Creates an in-memory tar of the files provided. +func createTar(t *testing.T, files map[string]string) []byte { + var buffer bytes.Buffer + writer := tar.NewWriter(&buffer) + for path, content := range files { + err := writer.WriteHeader(&tar.Header{ + Name: path, + Size: int64(len(content)), + }) + require.NoError(t, err) + + _, err = writer.Write([]byte(content)) + require.NoError(t, err) + } + + err := writer.Flush() + require.NoError(t, err) + return buffer.Bytes() +} + +// Creates a provisionerd implementation with the provided dialer and provisioners. +func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners provisionerd.Provisioners) io.Closer { + closer := provisionerd.New(dialer, &provisionerd.Options{ + Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), + PollInterval: 50 * time.Millisecond, + Provisioners: provisioners, + WorkDirectory: t.TempDir(), + }) + t.Cleanup(func() { + _ = closer.Close() + }) + return closer +} + +// Creates a provisionerd protobuf client that's connected +// to the server implementation provided. +func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient { + clientPipe, serverPipe := provisionersdk.TransportPipe() + t.Cleanup(func() { + _ = clientPipe.Close() + _ = serverPipe.Close() + }) + mux := drpcmux.New() + err := proto.DRPCRegisterProvisionerDaemon(mux, &server) + require.NoError(t, err) + srv := drpcserver.New(mux) + go func() { + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + _ = srv.Serve(ctx, serverPipe) + }() + return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientPipe)) +} + +// Creates a provisioner protobuf client that's connected +// to the server implementation provided. 
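+// Calls are multiplexed over an in-memory yamux pipe, mirroring the
+// yamux-over-stdio transport that provisionersdk.Serve uses by default.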
+func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkproto.DRPCProvisionerClient { + clientPipe, serverPipe := provisionersdk.TransportPipe() + t.Cleanup(func() { + _ = clientPipe.Close() + _ = serverPipe.Close() + }) + mux := drpcmux.New() + err := sdkproto.DRPCRegisterProvisioner(mux, &server) + require.NoError(t, err) + srv := drpcserver.New(mux) + go func() { + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + _ = srv.Serve(ctx, serverPipe) + }() + return sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(clientPipe)) +} + +type provisionerTestServer struct { + parse func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error + provision func(request *sdkproto.Provision_Request, stream sdkproto.DRPCProvisioner_ProvisionStream) error +} + +func (p *provisionerTestServer) Parse(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + return p.parse(request, stream) +} + +func (p *provisionerTestServer) Provision(request *sdkproto.Provision_Request, stream sdkproto.DRPCProvisioner_ProvisionStream) error { + return p.provision(request, stream) +} + +// Fulfills the protobuf interface for a ProvisionerDaemon with +// passable functions for dynamic functionality. +type provisionerDaemonTestServer struct { + acquireJob func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) + updateJob func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error + cancelJob func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) + completeJob func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) +} + +func (p *provisionerDaemonTestServer) AcquireJob(ctx context.Context, empty *proto.Empty) (*proto.AcquiredJob, error) { + return p.acquireJob(ctx, empty) +} + +func (p *provisionerDaemonTestServer) UpdateJob(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + return p.updateJob(stream) +} + +func (p *provisionerDaemonTestServer) CancelJob(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + return p.cancelJob(ctx, job) +} + +func (p *provisionerDaemonTestServer) CompleteJob(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + return p.completeJob(ctx, job) +} diff --git a/provisionersdk/serve.go b/provisionersdk/serve.go index 9b7952001f96c..1fbe50d506850 100644 --- a/provisionersdk/serve.go +++ b/provisionersdk/serve.go @@ -4,19 +4,22 @@ import ( "context" "errors" "io" + "net" + "os" "golang.org/x/xerrors" - "storj.io/drpc" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" + "github.com/hashicorp/yamux" + "github.com/coder/coder/provisionersdk/proto" ) // ServeOptions are configurations to serve a provisioner. type ServeOptions struct { - // Transport specifies a custom transport to serve the dRPC connection. - Transport drpc.Transport + // Conn specifies a custom transport to serve the dRPC connection. + Listener net.Listener } // Serve starts a dRPC connection for the provisioner and transport provided. @@ -25,8 +28,17 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser options = &ServeOptions{} } // Default to using stdio. 
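+	// Stdin/stdout are wrapped in a yamux server session so a single stdio
+	// pipe can carry multiple concurrent dRPC streams.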
- if options.Transport == nil { - options.Transport = TransportStdio() + if options.Listener == nil { + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + stdio, err := yamux.Server(readWriteCloser{ + ReadCloser: os.Stdin, + Writer: os.Stdout, + }, config) + if err != nil { + return xerrors.Errorf("create yamux: %w", err) + } + options.Listener = stdio } // dRPC is a drop-in replacement for gRPC with less generated code, and faster transports. @@ -40,16 +52,15 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser // Only serve a single connection on the transport. // Transports are not multiplexed, and provisioners are // short-lived processes that can be executed concurrently. - err = srv.ServeOne(ctx, options.Transport) + err = srv.Serve(ctx, options.Listener) if err != nil { if errors.Is(err, context.Canceled) { return nil } if errors.Is(err, io.ErrClosedPipe) { - // This may occur if the transport on either end is - // closed before the context. It's fine to return - // nil here, since the server has nothing to - // communicate with. + return nil + } + if errors.Is(err, yamux.ErrSessionShutdown) { return nil } return xerrors.Errorf("serve transport: %w", err) diff --git a/provisionersdk/serve_test.go b/provisionersdk/serve_test.go index 08ac393eb8dfc..cf2dd7517df82 100644 --- a/provisionersdk/serve_test.go +++ b/provisionersdk/serve_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" - "storj.io/drpc/drpcconn" "storj.io/drpc/drpcerr" "github.com/coder/coder/provisionersdk" @@ -29,12 +28,12 @@ func TestProvisionerSDK(t *testing.T) { defer cancelFunc() go func() { err := provisionersdk.Serve(ctx, &proto.DRPCProvisionerUnimplementedServer{}, &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }) require.NoError(t, err) }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client)) stream, err := api.Parse(context.Background(), &proto.Parse_Request{}) require.NoError(t, err) _, err = stream.Recv() @@ -47,7 +46,7 @@ func TestProvisionerSDK(t *testing.T) { _ = server.Close() err := provisionersdk.Serve(context.Background(), &proto.DRPCProvisionerUnimplementedServer{}, &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }) require.NoError(t, err) }) diff --git a/provisionersdk/transport.go b/provisionersdk/transport.go index c01a7ab8269e9..7fd87839d174b 100644 --- a/provisionersdk/transport.go +++ b/provisionersdk/transport.go @@ -1,44 +1,74 @@ package provisionersdk import ( + "context" "io" - "os" + "github.com/hashicorp/yamux" "storj.io/drpc" + "storj.io/drpc/drpcconn" ) -// Transport creates a dRPC transport using stdin and stdout. -func TransportStdio() drpc.Transport { - return &transport{ - in: os.Stdin, - out: os.Stdout, +// TransportPipe creates an in-memory pipe for dRPC transport. +func TransportPipe() (*yamux.Session, *yamux.Session) { + clientReader, clientWriter := io.Pipe() + serverReader, serverWriter := io.Pipe() + yamuxConfig := yamux.DefaultConfig() + yamuxConfig.LogOutput = io.Discard + client, err := yamux.Client(&readWriteCloser{ + ReadCloser: clientReader, + Writer: serverWriter, + }, yamuxConfig) + if err != nil { + panic(err) + } + + server, err := yamux.Server(&readWriteCloser{ + ReadCloser: serverReader, + Writer: clientWriter, + }, yamuxConfig) + if err != nil { + panic(err) } + return client, server } -// TransportPipe creates an in-memory pipe for dRPC transport. 
-func TransportPipe() (drpc.Transport, drpc.Transport) { - clientReader, serverWriter := io.Pipe() - serverReader, clientWriter := io.Pipe() - clientTransport := &transport{clientReader, clientWriter} - serverTransport := &transport{serverReader, serverWriter} +// Conn returns a multiplexed dRPC connection from a yamux session. +func Conn(session *yamux.Session) drpc.Conn { + return &multiplexedDRPC{session} +} - return clientTransport, serverTransport +type readWriteCloser struct { + io.ReadCloser + io.Writer } -// transport wraps an input and output to pipe data. -type transport struct { - in io.ReadCloser - out io.Writer +// Allows concurrent requests on a single dRPC connection. +// Required for calling functions concurrently. +type multiplexedDRPC struct { + session *yamux.Session } -func (s *transport) Read(data []byte) (int, error) { - return s.in.Read(data) +func (m *multiplexedDRPC) Close() error { + return m.session.Close() } -func (s *transport) Write(data []byte) (int, error) { - return s.out.Write(data) +func (m *multiplexedDRPC) Closed() <-chan struct{} { + return m.session.CloseChan() } -func (s *transport) Close() error { - return s.in.Close() +func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { + conn, err := m.session.Open() + if err != nil { + return err + } + return drpcconn.New(conn).Invoke(ctx, rpc, enc, in, out) +} + +func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { + conn, err := m.session.Open() + if err != nil { + return nil, err + } + return drpcconn.New(conn).NewStream(ctx, rpc, enc) }
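
As a wiring sketch for reviewers, following the patterns in the test files above: the Terraform provisioner is served in-process over the new yamux pipe, and its client is registered with a provisioner daemon. The function name newTerraformDaemon and the work directory are illustrative; how the Dialer reaches coderd is left out, since this change does not wire that up.

package example

import (
	"context"
	"io"
	"os"
	"path/filepath"
	"time"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/sloghuman"

	"github.com/coder/coder/provisioner/terraform"
	"github.com/coder/coder/provisionerd"
	"github.com/coder/coder/provisionersdk"
	sdkproto "github.com/coder/coder/provisionersdk/proto"
)

// newTerraformDaemon serves the Terraform provisioner over an in-memory
// yamux pipe and registers it with a provisioner daemon.
func newTerraformDaemon(ctx context.Context, dialer provisionerd.Dialer) io.Closer {
	// In-memory transport between the daemon and the provisioner,
	// identical to what the tests use.
	client, server := provisionersdk.TransportPipe()
	go func() {
		// ServeOptions now takes a net.Listener; a *yamux.Session satisfies it.
		_ = terraform.Serve(ctx, &terraform.ServeOptions{
			ServeOptions: &provisionersdk.ServeOptions{Listener: server},
		})
	}()

	return provisionerd.New(dialer, &provisionerd.Options{
		Logger:       slog.Make(sloghuman.Sink(os.Stderr)).Named("provisionerd"),
		PollInterval: time.Second,
		Provisioners: provisionerd.Provisioners{
			// The client side of the pipe, multiplexed for concurrent RPCs.
			"terraform": sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(client)),
		},
		// Placeholder path; the daemon removes this directory after each job.
		WorkDirectory: filepath.Join(os.TempDir(), "provisionerd"),
	})
}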