From 872fb462e1ad34cfc80e1e528ae395cfa6a375c2 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 15 Dec 2022 11:53:40 +0000 Subject: [PATCH] test: Fix data race in loadtest/reconnectingpty --- loadtest/reconnectingpty/run.go | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/loadtest/reconnectingpty/run.go b/loadtest/reconnectingpty/run.go index 8d046fac1eacd..c1252f96f1ee6 100644 --- a/loadtest/reconnectingpty/run.go +++ b/loadtest/reconnectingpty/run.go @@ -83,8 +83,8 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { } copyCtx, copyCancel := context.WithTimeout(ctx, time.Duration(copyTimeout)) + defer copyCancel() matched, err := copyContext(copyCtx, copyOutput, conn, r.cfg.ExpectOutput) - copyCancel() if r.cfg.ExpectTimeout { if err == nil { return xerrors.Errorf("expected timeout, but the command exited successfully") @@ -107,11 +107,27 @@ func copyContext(ctx context.Context, dst io.Writer, src io.Reader, expectOutput copyErr = make(chan error) matched = expectOutput == "" ) + + // Guard goroutine for loop body to ensure reading `matched` is safe on + // context cancellation and that `dst` won't be written to after we + // return from this function. + processing := make(chan struct{}, 1) + processing <- struct{}{} + go func() { + defer close(processing) defer close(copyErr) scanner := bufio.NewScanner(src) for scanner.Scan() { + select { + case <-processing: + default: + } + if ctx.Err() != nil { + return + } + if expectOutput != "" && strings.Contains(scanner.Text(), expectOutput) { matched = true } @@ -121,6 +137,7 @@ func copyContext(ctx context.Context, dst io.Writer, src io.Reader, expectOutput copyErr <- xerrors.Errorf("write to logs: %w", err) return } + processing <- struct{}{} } if scanner.Err() != nil { copyErr <- xerrors.Errorf("read from reconnecting PTY: %w", scanner.Err()) @@ -130,6 +147,10 @@ func copyContext(ctx context.Context, dst io.Writer, src io.Reader, expectOutput select { case <-ctx.Done(): + select { + case <-processing: + case <-copyErr: + } return matched, ctx.Err() case err := <-copyErr: return matched, err