
Data corruption when reading data as CUDA tensor from a different process #134273

@richmengsix

Description


🐛 Describe the bug

For performance reasons, I want to offload I/O to a separate process that reads video frames and moves them to the GPU, so that the main process can focus solely on the processing logic. However, I'm encountering data corruption when passing CUDA tensors between processes.

Here's a minimal example that reproduces it:

```python
import torch
import cv2
import os
import torch.multiprocessing as mp
import numpy as np

def producer(queue, event, video_path):
    cap = cv2.VideoCapture(video_path)
    frame_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Convert frame to tensor and move to CUDA
        tensor_frame = torch.from_numpy(frame).cuda()

        # Put the tensor into the queue
        queue.put(tensor_frame)
        print(f'Frame {frame_count} added to queue')
        frame_count += 1
    # Indicate end of stream
    queue.put(None)

    # Do not exit until the consumer signals that it has read every frame
    event.wait()
    cap.release()

def consumer(queue, event, output_folder):
    frame_count = 0
    while True:
        tensor_frame = queue.get()
        if tensor_frame is None:
            break

        # Move the tensor back to CPU and convert to numpy array
        frame = tensor_frame.cpu().numpy()

        # Detect if frame is all 0s
        if np.all(frame == 0):
            print(f"CORRUPTED: Frame {frame_count} is all 0s")

        # Write the frame to disk
        output_path = f"{output_folder}/frame_{frame_count:04d}.png"
        cv2.imwrite(output_path, frame)
        print(f'Frame {frame_count} read from queue')
        frame_count += 1

    # Tell the producer it can exit
    event.set()

if __name__ == "__main__":
    mp.freeze_support()
    mp.set_start_method("spawn", force=True)

    video_path = "input.mp4"
    output_folder = "output_frames"
    # Create the output folder if it does not exist
    os.makedirs(output_folder, exist_ok=True)

    queue = mp.Queue()
    event = mp.Event()

    # Launch producer and consumer processes
    producer_process = mp.Process(target=producer, args=(queue, event, video_path))
    consumer_process = mp.Process(target=consumer, args=(queue, event, output_folder))

    producer_process.start()
    consumer_process.start()

    producer_process.join()
    consumer_process.join()
```

I used this video for testing. In the output folder, some frames are saved correctly, but others are corrupted: their data is all zeros. Which frames are corrupted varies from run to run, but there seems to be a pattern.

Here is a visual representation of the issue: [screenshot not preserved]
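To describe the pattern more precisely, it can help to list the corrupted frames as runs of consecutive indices. The following is a minimal sketch using only the standard library; `is_all_zero` and `zero_frame_runs` are hypothetical helper names, and the per-frame flags are assumed to come from the same `np.all(frame == 0)` check the consumer above already performs.

```python
from typing import Iterable, List, Optional, Tuple


def is_all_zero(frame_bytes: bytes) -> bool:
    """True if every byte of the frame is zero (same test as np.all(frame == 0))."""
    return not any(frame_bytes)


def zero_frame_runs(flags: Iterable[bool]) -> List[Tuple[int, int]]:
    """Group the indices of all-zero frames into inclusive (first, last) runs."""
    runs: List[Tuple[int, int]] = []
    start: Optional[int] = None
    last_index = -1
    for index, is_zero in enumerate(flags):
        last_index = index
        if is_zero and start is None:
            start = index  # a new run of corrupted frames begins
        elif not is_zero and start is not None:
            runs.append((start, index - 1))  # the run ended on the previous frame
            start = None
    if start is not None:
        runs.append((start, last_index))  # the stream ended inside a run
    return runs
```

For example, the flags `[False, True, True, False, True]` give `[(1, 2), (4, 4)]`, meaning frames 1 to 2 and frame 4 were all zeros. Comparing these runs across several executions may make the pattern easier to report.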

If I remove .cuda() from the producer, the output frames are not corrupted: [screenshot not preserved]
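Since the frames arrive intact when the producer sends CPU data, one workaround (not a fix for the underlying issue) is to keep every `.cuda()` call in the consuming process. The sketch below assumes frames are any picklable CPU-side object, such as the NumPy arrays returned by `cap.read()`. `produce_host_frames`, `consume_and_upload`, `round_trip`, and the `to_device` hook are hypothetical names. `to_device` defaults to the identity so the sketch runs without a GPU; in the setup from this issue it would do something like `torch.from_numpy(a).cuda()`.

```python
import queue
from typing import Any, Callable, Iterable, List

# End-of-stream marker, matching the `None` the repro puts on the queue.
END_OF_STREAM = None


def _identity(frame: Any) -> Any:
    return frame


def produce_host_frames(q: Any, frames: Iterable[Any]) -> int:
    """Producer side: enqueue CPU-side frames only; no .cuda() call here."""
    count = 0
    for frame in frames:
        q.put(frame)  # only host data crosses the process boundary
        count += 1
    q.put(END_OF_STREAM)
    return count


def consume_and_upload(q: Any, to_device: Callable[[Any], Any] = _identity) -> List[Any]:
    """Consumer side: receive CPU frames and move each one to the device locally."""
    uploaded = []
    while True:
        frame = q.get()
        if frame is END_OF_STREAM:
            break
        # Hypothetical hook: in the issue's setup the .cuda() transfer would run here.
        uploaded.append(to_device(frame))  # collected only so the sketch is testable
    return uploaded


def round_trip(frames: Iterable[Any], to_device: Callable[[Any], Any] = _identity) -> List[Any]:
    """Single-process smoke check using queue.Queue instead of a process queue."""
    q = queue.Queue()
    produce_host_frames(q, frames)
    return consume_and_upload(q, to_device)
```

Across processes, the two functions would receive a `torch.multiprocessing.Queue` as in the repro; `round_trip` only exercises them in one process. With the spawn start method, any `to_device` passed through `args` must be a module-level function, because lambdas cannot be pickled.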

According to the documentation, I believe I am passing the data correctly with mp.Queue(). I also tried keeping the CUDA tensors alive in the producer with an extra tensor_cache list, but that did not reduce the corruption:

```python
def producer(queue, event, video_path):
    cap = cv2.VideoCapture(video_path)
    frame_count = 0
    tensor_cache = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Convert frame to tensor and move to CUDA
        tensor_frame = torch.from_numpy(frame).cuda()
        # Keep a reference so the CUDA tensor stays alive in the producer
        tensor_cache.append(tensor_frame)

        # Put the tensor into the queue
        queue.put(tensor_frame)
        print(f'Frame {frame_count} added to queue')
        frame_count += 1
    # Indicate end of stream
    queue.put(None)

    event.wait()
    cap.release()
```

Thanks a lot for your help!

Versions

PyTorch version: 2.4.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.11 (tags/v3.9.11:2de452f, Mar 16 2022, 14:33:45) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 536.23
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=3401
DeviceID=CPU0
Family=107
L2CacheSize=8192
L2CacheSpeed=
Manufacturer=AuthenticAMD
MaxClockSpeed=3401
Name=AMD Ryzen 9 5950X 16-Core Processor
ProcessorType=3
Revision=8448

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.4.0+cu118
[pip3] torchaudio==2.4.0+cu118
[pip3] torchvision==0.19.0+cu118
[conda] Could not collect

cc @peterjc123 @mszhanyi @skyline75489 @nbcsm @iremyux @Blackhex @VitalyFedyunin @albanD @pragupta

Metadata


Assignees

Labels

module: multiprocessing (Related to torch.multiprocessing)
module: windows (Windows support for PyTorch)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type

No type

Projects

Status

In Progress

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
