From e76b59e8635f18e1a50694a7e42aac4e8d8a58ad Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 14 Feb 2023 15:55:30 +0000 Subject: [PATCH 1/2] fix: Prevent race between provisionerd connect and close --- provisionerd/provisionerd.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 583826c54c194..ce9fe9f8788fb 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -177,7 +177,17 @@ func (p *Server) connect(ctx context.Context) { p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) continue } + // Ensure connection is not left hanging during a race between + // close and dial succeeding. + p.mutex.Lock() + if p.isClosed() { + client.DRPCConn().Close() + p.mutex.Unlock() + break + } p.clientValue.Store(client) + p.mutex.Unlock() + p.opts.Logger.Debug(context.Background(), "connected") break } @@ -390,7 +400,8 @@ func retryable(err error) bool { // is not retryable() or the context expires. func (p *Server) clientDoWithRetries( ctx context.Context, f func(context.Context, proto.DRPCProvisionerDaemonClient) (any, error)) ( - any, error) { + any, error, +) { for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); { client, ok := p.client() if !ok { From 763c10e1d4d84689e25e48ff2e4a8f06493b224c Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 14 Feb 2023 15:56:40 +0000 Subject: [PATCH 2/2] test: Add detection for provisioner creation after test completion --- provisionerd/provisionerd_test.go | 144 +++++++++++++++++++++++------- 1 file changed, 114 insertions(+), 30 deletions(-) diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index ca2899dd3fc1a..e15a79cf28d4f 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -55,14 +55,22 @@ func TestProvisionerd(t *testing.T) { t.Run("InstantClose", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{}), nil + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{}), nil }, provisionerd.Provisioners{}) require.NoError(t, closer.Close()) }) t.Run("ConnectErrorClose", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) completeChan := make(chan struct{}) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { defer close(completeChan) @@ -77,10 +85,14 @@ func TestProvisionerd(t *testing.T) { // the job provided is empty. This is to show it successfully // tried to get a job, but none were available. t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) completeChan := make(chan struct{}) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { acquireJobAttempt := 0 - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { if acquireJobAttempt == 1 { close(completeChan) @@ -97,13 +109,17 @@ func TestProvisionerd(t *testing.T) { t.Run("CloseCancelsJob", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) completeChan := make(chan struct{}) var completed sync.Once var closer io.Closer var closerMutex sync.Mutex closerMutex.Lock() closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { return &proto.AcquiredJob{ JobId: "test", @@ -127,7 +143,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { closerMutex.Lock() defer closerMutex.Unlock() @@ -144,13 +160,17 @@ func TestProvisionerd(t *testing.T) { // Ensures tars with "../../../etc/passwd" as the path // are not allowed to run, and will fail the job. t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var ( completeChan = make(chan struct{}) completeOnce sync.Once ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { return &proto.AcquiredJob{ JobId: "test", @@ -172,7 +192,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{}), + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}), }) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, closer.Close()) @@ -180,13 +200,17 @@ func TestProvisionerd(t *testing.T) { t.Run("RunningPeriodicUpdate", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var ( completeChan = make(chan struct{}) completeOnce sync.Once ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { return &proto.AcquiredJob{ JobId: "test", @@ -210,7 +234,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { <-stream.Context().Done() return nil @@ -223,6 +247,10 @@ func TestProvisionerd(t *testing.T) { t.Run("TemplateImport", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var ( didComplete atomic.Bool didLog atomic.Bool @@ -234,7 +262,7 @@ func TestProvisionerd(t *testing.T) { ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { if !didAcquireJob.CAS(false, true) { completeOnce.Do(func() { close(completeChan) }) @@ -270,7 +298,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { data, err := os.ReadFile(filepath.Join(request.Directory, "test.txt")) require.NoError(t, err) @@ -332,6 +360,10 @@ func TestProvisionerd(t *testing.T) { t.Run("TemplateDryRun", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var ( didComplete atomic.Bool didLog atomic.Bool @@ -355,7 +387,7 @@ func TestProvisionerd(t *testing.T) { ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { if !didAcquireJob.CAS(false, true) { completeOnce.Do(func() { close(completeChan) }) @@ -394,7 +426,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { err := stream.Send(&sdkproto.Provision_Response{ Type: &sdkproto.Provision_Response_Complete{ @@ -417,6 +449,10 @@ func TestProvisionerd(t *testing.T) { t.Run("WorkspaceBuild", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var ( didComplete atomic.Bool didLog atomic.Bool @@ -426,7 +462,7 @@ func TestProvisionerd(t *testing.T) { ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { if !didAcquireJob.CAS(false, true) { completeOnce.Do(func() { close(completeChan) }) @@ -458,7 +494,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { err := stream.Send(&sdkproto.Provision_Response{ Type: &sdkproto.Provision_Response_Log{ @@ -488,6 +524,10 @@ func TestProvisionerd(t *testing.T) { t.Run("WorkspaceBuildQuotaExceeded", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var ( didComplete atomic.Bool didLog atomic.Bool @@ -498,7 +538,7 @@ func TestProvisionerd(t *testing.T) { ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { if !didAcquireJob.CAS(false, true) { completeOnce.Do(func() { close(completeChan) }) @@ -539,7 +579,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { err := stream.Send(&sdkproto.Provision_Response{ Type: &sdkproto.Provision_Response_Log{ @@ -579,6 +619,10 @@ func TestProvisionerd(t *testing.T) { t.Run("WorkspaceBuildFailComplete", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var ( didFail atomic.Bool didAcquireJob atomic.Bool @@ -587,7 +631,7 @@ func TestProvisionerd(t *testing.T) { ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { if !didAcquireJob.CAS(false, true) { completeOnce.Do(func() { close(completeChan) }) @@ -614,7 +658,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { return stream.Send(&sdkproto.Provision_Response{ Type: &sdkproto.Provision_Response_Complete{ @@ -633,12 +677,16 @@ func TestProvisionerd(t *testing.T) { t.Run("Shutdown", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var updated sync.Once var completed sync.Once updateChan := make(chan struct{}) completeChan := make(chan struct{}) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { return &proto.AcquiredJob{ JobId: "test", @@ -676,7 +724,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { // Ignore the first provision message! _, _ = stream.Recv() @@ -714,12 +762,16 @@ func TestProvisionerd(t *testing.T) { t.Run("ShutdownFromJob", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var completed sync.Once var updated sync.Once updateChan := make(chan struct{}) completeChan := make(chan struct{}) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { return &proto.AcquiredJob{ JobId: "test", @@ -765,7 +817,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { // Ignore the first provision message! _, _ = stream.Recv() @@ -801,6 +853,10 @@ func TestProvisionerd(t *testing.T) { t.Run("ReconnectAndFail", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var ( second atomic.Bool failChan = make(chan struct{}) @@ -811,7 +867,7 @@ func TestProvisionerd(t *testing.T) { completeOnce sync.Once ) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { if second.Load() { return &proto.AcquiredJob{}, nil @@ -854,7 +910,7 @@ func TestProvisionerd(t *testing.T) { } return client, nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { // Ignore the first provision message! _, _ = stream.Recv() @@ -874,6 +930,10 @@ func TestProvisionerd(t *testing.T) { t.Run("ReconnectAndComplete", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) var ( second atomic.Bool failChan = make(chan struct{}) @@ -884,7 +944,7 @@ func TestProvisionerd(t *testing.T) { completeOnce sync.Once ) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { if second.Load() { completeOnce.Do(func() { close(completeChan) }) @@ -929,7 +989,7 @@ func TestProvisionerd(t *testing.T) { } return client, nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { // Ignore the first provision message! _, _ = stream.Recv() @@ -947,6 +1007,10 @@ func TestProvisionerd(t *testing.T) { t.Run("UpdatesBeforeComplete", func(t *testing.T) { t.Parallel() + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) logger := slogtest.Make(t, nil) m := sync.Mutex{} var ops []string @@ -954,7 +1018,7 @@ func TestProvisionerd(t *testing.T) { completeOnce := sync.Once{} server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { m.Lock() defer m.Unlock() @@ -1004,7 +1068,7 @@ func TestProvisionerd(t *testing.T) { }, }), nil }, provisionerd.Provisioners{ - "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { err := stream.Send(&sdkproto.Provision_Response{ Type: &sdkproto.Provision_Response_Log{ @@ -1070,7 +1134,7 @@ func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners p // Creates a provisionerd protobuf client that's connected // to the server implementation provided. -func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient { +func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient { t.Helper() if server.failJob == nil { // Default to asserting the error from the failure, otherwise @@ -1098,13 +1162,23 @@ func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestSer t.Cleanup(func() { cancelFunc() <-closed + select { + case <-done: + t.Error("createProvisionerDaemonClient cleanup after test was done!") + default: + } }) + select { + case <-done: + t.Error("called createProvisionerDaemonClient after test was done!") + default: + } return proto.NewDRPCProvisionerDaemonClient(clientPipe) } // Creates a provisioner protobuf client that's connected // to the server implementation provided. -func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkproto.DRPCProvisionerClient { +func createProvisionerClient(t *testing.T, done <-chan struct{}, server provisionerTestServer) sdkproto.DRPCProvisionerClient { t.Helper() clientPipe, serverPipe := provisionersdk.MemTransportPipe() t.Cleanup(func() { @@ -1124,7 +1198,17 @@ func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkprot t.Cleanup(func() { cancelFunc() <-closed + select { + case <-done: + t.Error("createProvisionerClient cleanup after test was done!") + default: + } }) + select { + case <-done: + t.Error("called createProvisionerClient after test was done!") + default: + } return sdkproto.NewDRPCProvisionerClient(clientPipe) }