diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go index 0b0548cfd0c72..079b1891346eb 100644 --- a/enterprise/cli/provisionerdaemons.go +++ b/enterprise/cli/provisionerdaemons.go @@ -239,6 +239,12 @@ func (r *RootCmd) provisionerDaemonStart() *serpent.Command { return xerrors.Errorf("shutdown: %w", err) } + // Shutdown does not call close. Must call it manually. + err = srv.Close() + if err != nil { + return xerrors.Errorf("close server: %w", err) + } + cancel() if xerrors.Is(exitErr, context.Canceled) { return nil diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 3e49648700f2f..deac80466b48d 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -236,6 +236,9 @@ func (p *Server) client() (proto.DRPCProvisionerDaemonClient, bool) { select { case <-p.closeContext.Done(): return nil, false + case <-p.shuttingDownCh: + // Shutting down should return a nil client and unblock + return nil, false case client := <-p.clientCh: return client, true } diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index 2031fa6c3939e..bca072707f491 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -597,6 +597,38 @@ func TestProvisionerd(t *testing.T) { assert.True(t, didFail.Load(), "should fail the job") }) + // Simulates when there is no coderd to connect to. So the client connection + // will never be established. + t.Run("ShutdownNoCoderd", func(t *testing.T) { + t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) + + connectAttemptedClose := sync.Once{} + connectAttempted := make(chan struct{}) + server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + // This is the dial out to Coderd, which in this unit test will always fail. + connectAttemptedClose.Do(func() { close(connectAttempted) }) + return nil, fmt.Errorf("client connection always fails") + }, provisionerd.LocalProvisioners{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}), + }) + + // Wait for at least 1 attempt to connect to ensure the connect go routine + // is running. + require.Condition(t, closedWithin(connectAttempted, testutil.WaitShort)) + + // The test is ensuring this Shutdown call does not block indefinitely. + // If it does, the context will return with an error, and the test will + // fail. + shutdownCtx := testutil.Context(t, testutil.WaitShort) + err := server.Shutdown(shutdownCtx, true) + require.NoError(t, err, "shutdown did not unblock. Failed to close the server gracefully.") + require.NoError(t, server.Close()) + }) + t.Run("Shutdown", func(t *testing.T) { t.Parallel() done := make(chan struct{})