Multi-threaded concurrent fetching in Dataloader for high-latency storage. #158218

Open: Kai-46 wants to merge 6 commits into main

Conversation

@Kai-46 Kai-46 commented Jul 14, 2025

Currently, each worker spawned by the DataLoader runs as a separate Python process, but the workload within each worker is single-threaded. In some cases, especially when streaming from high-latency storage such as S3 or NFS, using multiple threads within each worker can significantly improve throughput, since fetching a single data item takes much longer than it would from low-latency storage such as a local SSD.

This PR enables multi-threading inside each worker. It exposes a num_threads option in the DataLoader constructor, allowing users to easily tune the concurrency setup. By default, num_threads=1 is used for backward compatibility.
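
For illustration, a minimal usage sketch of the proposed option (num_threads is the keyword added by this PR; MyHighLatencyDataset is a hypothetical placeholder for any map-style dataset whose __getitem__ does slow I/O):

import torch.utils.data as torch_data

loader = torch_data.DataLoader(
    MyHighLatencyDataset(),  # placeholder: __getitem__ reads from S3/NFS
    batch_size=256,
    num_workers=8,      # worker processes, same semantics as before
    num_threads=32,     # added by this PR: threads per worker fetching items concurrently
    prefetch_factor=16,
)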

To demonstrate the effectiveness of this threaded DataLoader, I wrote a simple script that loops over an S3-based DataLoader with a fake training fwd/bwd time of 1 s. On my AWS machine, I got the following results:

Setup 1: python simple_data_loop.py --num_threads=1 leads to per_iter_time: 2.317s, indicating a data streaming latency of 1.3s.

Setup 2: python simple_data_loop.py --num_threads=32 leads to per_iter_time: 1.021s, indicating a data streaming latency of 0.021s.

We can see that multi-threading hides the data streaming latency almost completely.

simple_data_loop.py:

import numpy as np
import torch
import torch.utils.data as torch_data


import logging
import os
import random
import time
from io import BytesIO

import boto3
import pandas as pd

from boto3.s3.transfer import TransferConfig
from botocore import UNSIGNED
from botocore.config import Config
from PIL import Image
from tqdm import tqdm


logger = logging.getLogger(__name__)


class S3ImgDataset(torch_data.Dataset):
    def __init__(self, s3_img_path_list: list[str]):
        super().__init__()

        # https://github.com/pytorch/pytorch/issues/13246#issuecomment-905703662
        self.s3_img_path_list = pd.array(s3_img_path_list, dtype="string")

        self.s3_transfer_config = TransferConfig(multipart_threshold=5 * 1024**3)
        self.s3_client = boto3.client(
            "s3",
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_token=os.getenv("AWS_SESSION_TOKEN"),
            config=Config(
                region_name=os.getenv("AWS_REGION", "us-west-2"),
                signature_version=UNSIGNED,
            ),
        )

    def __len__(self):
        return len(self.s3_img_path_list)

    def __getitem__(self, idx):
        # each path is like s3://<bucket>/<prefix>
        try:
            s3_img_path = self.s3_img_path_list[idx]
            bucket, prefix = s3_img_path.replace("s3://", "").split("/", 1)

            bytesio = BytesIO()
            self.s3_client.download_fileobj(
                bucket, prefix, bytesio, Config=self.s3_transfer_config
            )
            bytesio.seek(0)

            img = torch.from_numpy(np.array(Image.open(bytesio)))
        except Exception as e:
            logger.error(f"Error loading image {s3_img_path}: {e}")
            return self.__getitem__(random.randint(0, len(self) - 1))

        return img


def benchmark(
    s3_img_path_list: list[str],
    fake_training_time,
    batch_size,
    num_workers,
    num_threads,
    prefetch_factor,
):
    dataset = S3ImgDataset(s3_img_path_list)
    dataloader = torch_data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        num_threads=num_threads,
        prefetch_factor=prefetch_factor,
        persistent_workers=True,
        pin_memory=True,
        drop_last=True,  # Ensure consistent batch sizes
    )

    num_warmup_batches, num_test_batches = 20, 100

    for batch_idx, batch in tqdm(enumerate(dataloader)):
        if batch_idx == num_warmup_batches:
            tic = time.time()

        # fake training
        time.sleep(fake_training_time)


        if batch_idx > num_test_batches + num_warmup_batches:
            break

    per_iter_time = (time.time() - tic) / num_test_batches
    return per_iter_time


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--fake_training_time",
        type=float,
        default=1.0,
        help="fake training time in seconds",
    )
    parser.add_argument(
        "--batch_size", type=int, default=256, help="how many objects to load per batch"
    )
    parser.add_argument(
        "--num_workers", type=int, default=8, help="how many workers to use"
    )
    parser.add_argument(
        "--num_threads", type=int, default=32, help="how many threads to use"
    )
    parser.add_argument(
        "--prefetch_factor", type=int, default=16, help="how many batches to prefetch"
    )

    args = parser.parse_args()

    with open("s3_img_path_list.txt", "r") as f:
        s3_img_path_list = [line.strip() for line in f.readlines() if line.strip()]

    # repeat s3_img_path_list 1000 times
    s3_img_path_list = s3_img_path_list * 1000

    per_iter_time = benchmark(
        s3_img_path_list,
        args.fake_training_time,
        args.batch_size,
        args.num_workers,
        args.num_threads,
        args.prefetch_factor,
    )
    print(f"args: {args}")
    print(f"per_iter_time: {per_iter_time:.3f}s")

pytorch-bot bot commented Jul 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158218

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 333fc77 with merge base 0f21fa8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


@pytorch-bot pytorch-bot bot added the release notes: dataloader release notes category label Jul 14, 2025
Excerpt from the PR diff:

            )
            await result_queue.put((index, result))
        except Exception as e:
            print(f"Exception during fetch: {e}")

A collaborator commented on this excerpt:
Please use a Python logger object and exception category etc...

@Kai-46 (Author) commented Jul 15, 2025

Good point! I just made the changes following your suggestion.

@albanD albanD added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jul 14, 2025
@albanD (Collaborator) commented Jul 14, 2025

Note that we have limited maintenance on dataloading so such a big change might take quite a bit of time to get reviewed.

@divyanshk (Contributor) left a comment

Thanks for the PR @Kai-46

Using multiple threads to improve data access seems like a very good fit at the dataset level. The dataset can have a __getitems__ implementation, which works well with a batch_sampler (https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/fetch.py#L49). The async calls to the cloud storage can then live inside the dataset class.
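
For illustration, a minimal sketch of that dataset-level approach, assuming the S3ImgDataset from the benchmark script above and the __getitems__ batch hook that the default fetcher checks for; the ThreadedS3ImgDataset name and its num_threads argument are made up for this example:

from concurrent.futures import ThreadPoolExecutor


class ThreadedS3ImgDataset(S3ImgDataset):
    def __init__(self, s3_img_path_list, num_threads: int = 32):
        super().__init__(s3_img_path_list)
        self._num_threads = num_threads
        self._pool = None  # created lazily inside each worker process

    def __getitems__(self, indices):
        # The default fetcher passes the whole batch of indices when the
        # dataset defines __getitems__, so the per-item S3 reads can run
        # concurrently; the threads spend most of their time blocked on
        # network I/O, so the GIL is not the bottleneck here.
        if self._pool is None:
            self._pool = ThreadPoolExecutor(max_workers=self._num_threads)
        return list(self._pool.map(self.__getitem__, indices))

With a dataset like this, the stock DataLoader (with batch_size set) fetches every batch concurrently inside each worker, without any change to the DataLoader itself.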

@divyanshk (Contributor)

The S3 plugins built by AWS come to mind: https://aws.amazon.com/blogs/machine-learning/announcing-the-amazon-s3-plugin-for-pytorch/

These are dataset implementations for cloud storage access, and I believe they are multi-threaded.

@Kai-46 (Author) commented Jul 16, 2025

Thanks for the feedback. One quick question: do we prefer to have users take care of the multi-threading in their own dataset's __getitems__ implementation, or to have it live in PyTorch in order to avoid boilerplate? I feel that having it as part of the PyTorch DataLoader can come in particularly handy for fast prototyping.

@divyanshk (Contributor)

The multi-threading used for dataset access should live in the dataset itself, because that implementation is often tuned to the data storage and access patterns, which the dataloader layer cannot always control. That said, there is a world where we could have threading-based workers in place of multiprocessing inside the dataloader, although that is a bigger topic and something we are thinking about.

@Kai-46 (Author) commented Aug 7, 2025

Sorry for the late reply. Here are my two cents: having threading support in the DataLoader does not hurt, as it is backward-compatible; meanwhile, it can quickly unblock certain use cases where people might find it cumbersome or error-prone to implement their own multi-threading logic in dataset.py. Please let me know what you think.

@divyanshk (Contributor)

Generally aligned with you @Kai-46. We are testing out an implementation that uses threads in the dataloader, at the dataloader level directly (#158714 (comment)). Would love to test it out with remote multi-threaded dataset access.

Labels
open source, release notes: dataloader, triaged