Skip to content

fix: Allow terraform provisions to be gracefully cancelled #3526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 18, 2022
95 changes: 66 additions & 29 deletions provisioner/terraform/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (e executor) basicEnv() []string {
return env
}

func (e executor) execWriteOutput(ctx context.Context, args, env []string, stdOutWriter, stdErrWriter io.WriteCloser) (err error) {
func (e executor) execWriteOutput(ctx, killCtx context.Context, args, env []string, stdOutWriter, stdErrWriter io.WriteCloser) (err error) {
defer func() {
closeErr := stdOutWriter.Close()
if err == nil && closeErr != nil {
Expand All @@ -52,8 +52,12 @@ func (e executor) execWriteOutput(ctx context.Context, args, env []string, stdOu
err = closeErr
}
}()
if ctx.Err() != nil {
return ctx.Err()
}

// #nosec
cmd := exec.CommandContext(ctx, e.binaryPath, args...)
cmd := exec.CommandContext(killCtx, e.binaryPath, args...)
cmd.Dir = e.workdir
cmd.Env = env

Expand All @@ -63,19 +67,36 @@ func (e executor) execWriteOutput(ctx context.Context, args, env []string, stdOu
cmd.Stdout = syncWriter{mut, stdOutWriter}
cmd.Stderr = syncWriter{mut, stdErrWriter}

return cmd.Run()
err = cmd.Start()
if err != nil {
return err
}
interruptCommandOnCancel(ctx, killCtx, cmd)

return cmd.Wait()
}

func (e executor) execParseJSON(ctx context.Context, args, env []string, v interface{}) error {
func (e executor) execParseJSON(ctx, killCtx context.Context, args, env []string, v interface{}) error {
if ctx.Err() != nil {
return ctx.Err()
}

// #nosec
cmd := exec.CommandContext(ctx, e.binaryPath, args...)
cmd := exec.CommandContext(killCtx, e.binaryPath, args...)
cmd.Dir = e.workdir
cmd.Env = env
out := &bytes.Buffer{}
stdErr := &bytes.Buffer{}
cmd.Stdout = out
cmd.Stderr = stdErr
err := cmd.Run()

err := cmd.Start()
if err != nil {
return err
}
interruptCommandOnCancel(ctx, killCtx, cmd)

err = cmd.Wait()
if err != nil {
errString, _ := io.ReadAll(stdErr)
return xerrors.Errorf("%s: %w", errString, err)
Expand Down Expand Up @@ -109,6 +130,10 @@ func (e executor) version(ctx context.Context) (*version.Version, error) {
}

func versionFromBinaryPath(ctx context.Context, binaryPath string) (*version.Version, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}

// #nosec
cmd := exec.CommandContext(ctx, binaryPath, "version", "-json")
out, err := cmd.Output()
Expand All @@ -130,7 +155,7 @@ func versionFromBinaryPath(ctx context.Context, binaryPath string) (*version.Ver
return version.NewVersion(vj.Version)
}

func (e executor) init(ctx context.Context, logr logger) error {
func (e executor) init(ctx, killCtx context.Context, logr logger) error {
outWriter, doneOut := logWriter(logr, proto.LogLevel_DEBUG)
errWriter, doneErr := logWriter(logr, proto.LogLevel_ERROR)
defer func() {
Expand All @@ -156,11 +181,11 @@ func (e executor) init(ctx context.Context, logr logger) error {
defer e.initMu.Unlock()
}

return e.execWriteOutput(ctx, args, e.basicEnv(), outWriter, errWriter)
return e.execWriteOutput(ctx, killCtx, args, e.basicEnv(), outWriter, errWriter)
}

// revive:disable-next-line:flag-parameter
func (e executor) plan(ctx context.Context, env, vars []string, logr logger, destroy bool) (*proto.Provision_Response, error) {
func (e executor) plan(ctx, killCtx context.Context, env, vars []string, logr logger, destroy bool) (*proto.Provision_Response, error) {
planfilePath := filepath.Join(e.workdir, "terraform.tfplan")
args := []string{
"plan",
Expand All @@ -184,11 +209,11 @@ func (e executor) plan(ctx context.Context, env, vars []string, logr logger, des
<-doneErr
}()

err := e.execWriteOutput(ctx, args, env, outWriter, errWriter)
err := e.execWriteOutput(ctx, killCtx, args, env, outWriter, errWriter)
if err != nil {
return nil, xerrors.Errorf("terraform plan: %w", err)
}
resources, err := e.planResources(ctx, planfilePath)
resources, err := e.planResources(ctx, killCtx, planfilePath)
if err != nil {
return nil, err
}
Expand All @@ -201,40 +226,52 @@ func (e executor) plan(ctx context.Context, env, vars []string, logr logger, des
}, nil
}

func (e executor) planResources(ctx context.Context, planfilePath string) ([]*proto.Resource, error) {
plan, err := e.showPlan(ctx, planfilePath)
func (e executor) planResources(ctx, killCtx context.Context, planfilePath string) ([]*proto.Resource, error) {
plan, err := e.showPlan(ctx, killCtx, planfilePath)
if err != nil {
return nil, xerrors.Errorf("show terraform plan file: %w", err)
}

rawGraph, err := e.graph(ctx)
rawGraph, err := e.graph(ctx, killCtx)
if err != nil {
return nil, xerrors.Errorf("graph: %w", err)
}
return ConvertResources(plan.PlannedValues.RootModule, rawGraph)
}

func (e executor) showPlan(ctx context.Context, planfilePath string) (*tfjson.Plan, error) {
func (e executor) showPlan(ctx, killCtx context.Context, planfilePath string) (*tfjson.Plan, error) {
args := []string{"show", "-json", "-no-color", planfilePath}
p := new(tfjson.Plan)
err := e.execParseJSON(ctx, args, e.basicEnv(), p)
err := e.execParseJSON(ctx, killCtx, args, e.basicEnv(), p)
return p, err
}

func (e executor) graph(ctx context.Context) (string, error) {
// #nosec
cmd := exec.CommandContext(ctx, e.binaryPath, "graph")
func (e executor) graph(ctx, killCtx context.Context) (string, error) {
if ctx.Err() != nil {
return "", ctx.Err()
}

var out bytes.Buffer
cmd := exec.CommandContext(killCtx, e.binaryPath, "graph") // #nosec
cmd.Stdout = &out
cmd.Dir = e.workdir
cmd.Env = e.basicEnv()
out, err := cmd.Output()

err := cmd.Start()
if err != nil {
return "", err
}
interruptCommandOnCancel(ctx, killCtx, cmd)

err = cmd.Wait()
if err != nil {
return "", xerrors.Errorf("graph: %w", err)
}
return string(out), nil
return out.String(), nil
}

// revive:disable-next-line:flag-parameter
func (e executor) apply(ctx context.Context, env, vars []string, logr logger, destroy bool,
func (e executor) apply(ctx, killCtx context.Context, env, vars []string, logr logger, destroy bool,
) (*proto.Provision_Response, error) {
args := []string{
"apply",
Expand All @@ -258,11 +295,11 @@ func (e executor) apply(ctx context.Context, env, vars []string, logr logger, de
<-doneErr
}()

err := e.execWriteOutput(ctx, args, env, outWriter, errWriter)
err := e.execWriteOutput(ctx, killCtx, args, env, outWriter, errWriter)
if err != nil {
return nil, xerrors.Errorf("terraform apply: %w", err)
}
resources, err := e.stateResources(ctx)
resources, err := e.stateResources(ctx, killCtx)
if err != nil {
return nil, err
}
Expand All @@ -281,12 +318,12 @@ func (e executor) apply(ctx context.Context, env, vars []string, logr logger, de
}, nil
}

func (e executor) stateResources(ctx context.Context) ([]*proto.Resource, error) {
state, err := e.state(ctx)
func (e executor) stateResources(ctx, killCtx context.Context) ([]*proto.Resource, error) {
state, err := e.state(ctx, killCtx)
if err != nil {
return nil, err
}
rawGraph, err := e.graph(ctx)
rawGraph, err := e.graph(ctx, killCtx)
if err != nil {
return nil, xerrors.Errorf("get terraform graph: %w", err)
}
Expand All @@ -300,10 +337,10 @@ func (e executor) stateResources(ctx context.Context) ([]*proto.Resource, error)
return resources, nil
}

func (e executor) state(ctx context.Context) (*tfjson.State, error) {
func (e executor) state(ctx, killCtx context.Context) (*tfjson.State, error) {
args := []string{"show", "-json"}
state := &tfjson.State{}
err := e.execParseJSON(ctx, args, e.basicEnv(), state)
err := e.execParseJSON(ctx, killCtx, args, e.basicEnv(), state)
if err != nil {
return nil, xerrors.Errorf("terraform show state: %w", err)
}
Expand Down
19 changes: 19 additions & 0 deletions provisioner/terraform/executor_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//go:build !windows

package terraform

import (
"context"
"os"
"os/exec"
)

// interruptCommandOnCancel starts a goroutine that waits for either context
// to finish. cmd must already have been started, because cmd.Process is
// dereferenced when the signal is sent.
//
// If ctx is done, the process is sent os.Interrupt so it can shut down
// gracefully. If killCtx is done, the goroutine exits without sending
// anything. The error returned by Signal is discarded.
func interruptCommandOnCancel(ctx, killCtx context.Context, cmd *exec.Cmd) {
	graceful, forced := ctx.Done(), killCtx.Done()
	go func() {
		select {
		case <-forced:
			// No signal is sent on this path; just stop watching.
			return
		case <-graceful:
		}
		_ = cmd.Process.Signal(os.Interrupt)
	}()
}
19 changes: 19 additions & 0 deletions provisioner/terraform/executor_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//go:build windows

package terraform

import (
"context"
"os/exec"
)

// interruptCommandOnCancel starts a goroutine that waits for either context
// to finish. cmd must already have been started, because cmd.Process is
// dereferenced when the process is killed.
//
// If ctx is done, the process is killed outright. If killCtx is done, the
// goroutine exits without doing anything. The error returned by Kill is
// discarded.
func interruptCommandOnCancel(ctx, killCtx context.Context, cmd *exec.Cmd) {
	graceful, forced := ctx.Done(), killCtx.Done()
	go func() {
		select {
		case <-forced:
			// Nothing is done on this path; just stop watching.
			return
		case <-graceful:
		}
		// On Windows we can't send an interrupt, so we just kill the process.
		_ = cmd.Process.Kill()
	}()
}
78 changes: 55 additions & 23 deletions provisioner/terraform/provision.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"path/filepath"
"strings"
"time"

"golang.org/x/xerrors"

Expand All @@ -14,11 +15,7 @@ import (
)

// Provision executes `terraform apply` or `terraform plan` for dry runs.
func (t *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
logr := streamLogger{stream: stream}
shutdown, shutdownFunc := context.WithCancel(stream.Context())
defer shutdownFunc()

func (s *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
request, err := stream.Recv()
if err != nil {
return err
Expand All @@ -30,36 +27,71 @@ func (t *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
if request.GetStart() == nil {
return nil
}

// Create a context for graceful cancellation bound to the stream
// context. This ensures that we will perform graceful cancellation
// even on connection loss.
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()

// Create a separate context for forceful cancellation not tied to
// the stream so that we can control when to terminate the process.
killCtx, kill := context.WithCancel(context.Background())
defer kill()

// Ensure processes are eventually cleaned up on graceful
// cancellation or disconnect.
go func() {
<-stream.Context().Done()

// TODO(mafredri): We should track this provision request as
// part of graceful server shutdown procedure. Waiting on a
// process here should delay provisioner/coder shutdown.
select {
case <-time.After(time.Minute):
kill()
case <-killCtx.Done():
}
}()

go func() {
for {
request, err := stream.Recv()
if err != nil {
return
}
if request.GetCancel() == nil {
// This is only to process cancels!
continue

rc := request.GetCancel()
switch {
case rc == nil:
case rc.GetForce():
// Likely not needed, but this ensures
// cancel happens before kill.
cancel()
kill()
return
default:
cancel()
// We will continue waiting for forceful cancellation
// (or until the stream is closed).
}
shutdownFunc()
return
}
}()

logr := streamLogger{stream: stream}
start := request.GetStart()

if err != nil {
return xerrors.Errorf("create new terraform executor: %w", err)
}
e := t.executor(start.Directory)
if err := e.checkMinVersion(stream.Context()); err != nil {
e := s.executor(start.Directory)
if err = e.checkMinVersion(stream.Context()); err != nil {
return err
}
if err := logTerraformEnvVars(logr); err != nil {
if err = logTerraformEnvVars(logr); err != nil {
return err
}

statefilePath := filepath.Join(start.Directory, "terraform.tfstate")
if len(start.State) > 0 {
err := os.WriteFile(statefilePath, start.State, 0600)
err = os.WriteFile(statefilePath, start.State, 0o600)
if err != nil {
return xerrors.Errorf("write statefile %q: %w", statefilePath, err)
}
Expand Down Expand Up @@ -87,12 +119,12 @@ func (t *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
})
}

t.logger.Debug(shutdown, "running initialization")
err = e.init(stream.Context(), logr)
s.logger.Debug(ctx, "running initialization")
err = e.init(ctx, killCtx, logr)
if err != nil {
return xerrors.Errorf("initialize terraform: %w", err)
}
t.logger.Debug(shutdown, "ran initialization")
s.logger.Debug(ctx, "ran initialization")

env, err := provisionEnv(start)
if err != nil {
Expand All @@ -104,15 +136,15 @@ func (t *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
}
var resp *proto.Provision_Response
if start.DryRun {
resp, err = e.plan(shutdown, env, vars, logr,
resp, err = e.plan(ctx, killCtx, env, vars, logr,
start.Metadata.WorkspaceTransition == proto.WorkspaceTransition_DESTROY)
} else {
resp, err = e.apply(shutdown, env, vars, logr,
resp, err = e.apply(ctx, killCtx, env, vars, logr,
start.Metadata.WorkspaceTransition == proto.WorkspaceTransition_DESTROY)
}
if err != nil {
if start.DryRun {
if shutdown.Err() != nil {
if ctx.Err() != nil {
return stream.Send(&proto.Provision_Response{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Expand Down
Loading