Multi-threaded concurrent fetching in Dataloader for high-latency storage. #158218
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158218
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 333fc77 with merge base 0f21fa8.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/utils/data/_utils/fetch.py (outdated)

```python
                )
                await result_queue.put((index, result))
            except Exception as e:
                print(f"Exception during fetch: {e}")
```
Please use a Python logger object and an exception category, etc.
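For reference, a minimal sketch of what the suggested change could look like, assuming the fetch loop keeps the structure shown in the fragment above (the helper name and queue type here are illustrative, not the PR's actual code):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def _fetch_one(dataset, index, result_queue: asyncio.Queue):
    # Hypothetical helper mirroring the fetch loop in the diff fragment:
    # fetch a single item and hand it to the consumer via an asyncio queue.
    try:
        result = dataset[index]
        await result_queue.put((index, result))
    except Exception:
        # logger.exception records the exception type and full traceback and
        # respects whatever handlers/levels the caller configured, unlike print().
        logger.exception("Exception during fetch for index %s", index)
```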
Good point! I just made the changes following your suggestion.
Note that we have limited maintenance bandwidth on dataloading, so such a big change might take quite a bit of time to get reviewed.
Thanks for the PR @Kai-46
Using multiple threads to improve data access seems like a very good fit at the dataset level. The dataset can provide a `__getitems__` implementation, which works well with a `batch_sampler` (https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/fetch.py#L49); the async calls to the cloud storage can then live inside the dataset class.
The S3 plugins built by AWS come to mind: https://aws.amazon.com/blogs/machine-learning/announcing-the-amazon-s3-plugin-for-pytorch/ These are dataset implementations for cloud storage access, and I think they are multi-threaded.
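For illustration, a rough sketch of the dataset-level approach described above: a dataset exposing `__getitems__` that fans per-index reads out to its own thread pool. The class name, the `num_threads` default, and the `_read_from_remote` placeholder are all hypothetical; the real integration point is just the `__getitems__` hook used by the default fetcher.

```python
from concurrent.futures import ThreadPoolExecutor

from torch.utils.data import Dataset


class RemoteDataset(Dataset):
    """Hypothetical dataset that hides storage latency with its own thread pool."""

    def __init__(self, keys, num_threads=32):
        self.keys = keys
        self._pool = ThreadPoolExecutor(max_workers=num_threads)

    def _read_from_remote(self, key):
        # Placeholder for a blocking S3/NFS read; swap in real storage access.
        raise NotImplementedError

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        return self._read_from_remote(self.keys[index])

    def __getitems__(self, indices):
        # Called by the default fetcher when a whole batch of indices arrives
        # (e.g. via a batch_sampler), so the reads overlap instead of running
        # one after another.
        return list(self._pool.map(self.__getitem__, indices))
```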
Thanks for the feedback. One quick question: do we prefer to have users take care of the multi-threading in their own dataset's `__getitems__` implementation?
The multi-threading used for dataset access should live in the dataset itself, because that implementation is often tuned to the data storage and access pattern, which the dataloader layer cannot always control. That said, there is a world where we have threading-based workers in place of multiprocessing inside the dataloader, although that is a bigger topic and something we are thinking about.
Sorry for the late reply. Here are my two cents: having threading support in the dataloader does not hurt, as it is backward-compatible; meanwhile, it can quickly unblock use cases where people find it cumbersome or error-prone to implement their own multi-threading logic in dataset.py. Please let me know what you think.
Generally aligned with you, @Kai-46. We are testing out an implementation that uses threads at the dataloader level directly (#158714 (comment)). Would love to test it out with remote multi-threaded dataset access.
Currently, each worker spawned by the DataLoader runs as a separate Python process, but the workload within each worker is single-threaded. In some cases, especially when streaming from high-latency storage like S3 or NFS, using multiple threads within each worker can significantly improve throughput, since fetching a single data item takes much longer than it does from low-latency storage like a local SSD.
This PR enables multi-threading inside each worker. It exposes a `num_threads` option in the `DataLoader` constructor, allowing users to easily tweak the concurrency setup. By default, we use `num_threads=1` for backward compatibility.

To demonstrate the effectiveness of this threaded dataloader, I wrote a simple script that loops over an S3-based dataloader with a fake training fwd/bwd time of 1s. On my AWS machine, I got the following results:

Setup 1: `python simple_data_loop.py --num_threads=1` leads to `per_iter_time: 2.317s`, indicating a data streaming latency of 1.3s.

Setup 2: `python simple_data_loop.py --num_threads=32` leads to `per_iter_time: 1.021s`, indicating a data streaming latency of 0.021s.

We can see that multi-threading hides the data streaming latency almost completely.

`simple_data_loop.py`:
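The script itself is not reproduced in this view. Below is a minimal sketch of what such a benchmark loop might look like; the sleep-based fake dataset, class names, and batch/worker sizes are illustrative, and the `num_threads` argument assumes the change proposed by this PR is installed.

```python
import argparse
import time

import torch
from torch.utils.data import DataLoader, Dataset


class FakeS3Dataset(Dataset):
    # Stand-in for an S3-backed dataset: each read sleeps to mimic storage latency.
    def __init__(self, length=1024, latency_s=1.3):
        self.length = length
        self.latency_s = latency_s

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        time.sleep(self.latency_s)  # pretend this is a blocking S3 GET
        return torch.zeros(3, 224, 224)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_threads", type=int, default=1)
    args = parser.parse_args()

    loader = DataLoader(
        FakeS3Dataset(),
        batch_size=32,
        num_workers=4,
        num_threads=args.num_threads,  # new DataLoader argument proposed by this PR
    )

    start = time.time()
    for step, _batch in enumerate(loader, start=1):
        time.sleep(1.0)  # fake training fwd/bwd time of 1s
        print(f"per_iter_time: {(time.time() - start) / step:.3f}s")


if __name__ == "__main__":
    main()
```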