Add beginnings of torch::stable::accelerator #159679
base: gh/mikaylagawarecki/332/base
First file (a `setup.py`-style build script for the `libtorch_agnostic` test extension):

```diff
@@ -4,7 +4,8 @@
 from setuptools import find_packages, setup

-from torch.utils.cpp_extension import BuildExtension, CppExtension
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

 ROOT_DIR = Path(__file__).parent

@@ -35,10 +36,16 @@ def get_extension():
         "cxx": ["-fdiagnostics-color=always"],
     }

+    extension = CppExtension
+    # allow including <cuda_runtime.h>
+    if torch.cuda.is_available():
+        extra_compile_args["cxx"].append("-DLAE_USE_CUDA")
+        extension = CUDAExtension
+
     sources = list(CSRC_DIR.glob("**/*.cpp"))

     return [
-        CppExtension(
+        extension(
             "libtorch_agnostic._C",
             sources=sorted(str(s) for s in sources),
             py_limited_api=True,
```

**Reviewer:** Hmmm, maybe this would be a good call to move the CUDA stuff into its own C++ file and build them separately.

**Author:** Hmmm, I don't see the issue with the current approach, so gonna keep it as is unless there's something specific you're concerned about! :)
Second file (the AOTI Torch C shim implementation):

```diff
@@ -24,6 +24,10 @@
 #include <iostream>
 #include <vector>

+#include <c10/core/Device.h>
+#include <c10/core/DeviceGuard.h>
+#include <c10/core/Stream.h>
+
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #else

@@ -1612,3 +1616,55 @@ AOTITorchError aoti_torch_call_dispatcher(
     }
   });
 }
+
+AOTITorchError aoti_torch_create_device_guard(
+    int32_t device_index,
+    DeviceGuardHandle* ret_guard // returns new reference
+) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    // checked=true will fail if no accelerator is available
+    const auto device_type =
+        at::accelerator::getAccelerator(/*checked=*/true).value();
+    c10::Device device(device_type, device_index);
+    c10::DeviceGuard* guard = new c10::DeviceGuard(device);
+    *ret_guard = reinterpret_cast<DeviceGuardHandle>(guard);
+  });
+}
+
+AOTITorchError aoti_torch_delete_device_guard(DeviceGuardHandle guard) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { delete reinterpret_cast<c10::DeviceGuard*>(guard); });
+}
+
+AOTITorchError aoti_torch_device_guard_set_index(
+    DeviceGuardHandle guard,
+    int32_t device_index) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { reinterpret_cast<c10::DeviceGuard*>(guard)->set_index(device_index); });
+}
+
+AOTITorchError aoti_torch_delete_stream(StreamHandle stream) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { delete reinterpret_cast<c10::Stream*>(stream); });
+}
+
+AOTITorchError aoti_torch_stream_id(
+    StreamHandle stream,
+    int64_t* ret_stream_id) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    c10::Stream* stream_ptr = reinterpret_cast<c10::Stream*>(stream);
+    *ret_stream_id = stream_ptr->id();
+  });
+}
+
+// This function creates a new Stream object and makes StreamHandle point to it.
+// The caller is responsible for managing the object's lifecycle.
+AOTITorchError aoti_torch_get_current_stream(
+    int32_t device_index,
+    StreamHandle* ret_stream) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    c10::Stream stream = at::accelerator::getCurrentStream(device_index);
+    c10::Stream* stream_ptr = new c10::Stream(stream);
+    *ret_stream = reinterpret_cast<StreamHandle>(stream_ptr);
+  });
+}
```

**Reviewer** (on `aoti_torch_create_device_guard`): I'm guessing this code is inspired by https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/shim_cuda.cpp#L8?

**Reviewer** (on `aoti_torch_get_current_stream`): Creating a new stream on the heap is gonna leak memory; we should be able to assign the stream in line 1642 into the pointer (tho idk the cast semantics).

**Reviewer (follow-up):** Ahhh, after reading the rest of the code, I see what the user flow is intended to be. That said, I'm not sure we want `get_current_stream` to have the semantics of creating a new stream on the heap (this is likely not expected from the user POV), and if we do end up sticking with this, we need to loudly document it. I think we do not want to mess with the memory of the stream at all... cc @albanD for thoughts.

**Author:** Hmm, looks like @albanD had no thoughts :b I'm not sure what to do here. This one looks different from `torch/csrc/inductor/aoti_torch/shim_cuda.cpp` (lines 49 to 55 at 5f5f508), as that is just directly returning the current stream (see `aten/src/ATen/DeviceAccelerator.cpp`, lines 106 to 110 at 2ee22e4). I think I can add a comment for sure, but would be curious how else we can tweak this in a way that's more user/memory friendly. Would it make more sense to just return the `.id()` directly?
Third file (new header defining `torch::stable::accelerator`, 71 added lines):

```cpp
#pragma once

#include <torch/csrc/inductor/aoti_runtime/utils.h>

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

namespace torch::stable::accelerator {

namespace {
inline void delete_device_guard(void* ptr) {
  AOTI_TORCH_ERROR_CODE_CHECK(
      aoti_torch_delete_device_guard(reinterpret_cast<DeviceGuardHandle>(ptr)));
}
} // namespace

// this is bigger than DeviceIndex in c10/core/Device.h but it is the type we
// can converge on in this world as DeviceIndex in libtorch is not stable.
using DeviceIndex = int32_t;
using StreamId = int64_t; // this is from c10/core/Stream.h

class DeviceGuard {
 public:
  explicit DeviceGuard() = delete;
  explicit DeviceGuard(DeviceIndex device_index)
      : guard_(nullptr, delete_device_guard) {
    DeviceGuardHandle ptr = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(
        aoti_torch_create_device_guard(device_index, &ptr));
    guard_.reset(ptr);
  }

  void set_index(DeviceIndex device_index) {
    AOTI_TORCH_ERROR_CODE_CHECK(
        aoti_torch_device_guard_set_index(guard_.get(), device_index));
  }

 private:
  std::unique_ptr<DeviceGuardOpaque, aot_inductor::DeleterFnPtr> guard_;
};

class Stream {
 public:
  explicit Stream() = delete;

  // Construct a stable::Stream from a StreamHandle
  // Steals ownership from the StreamHandle
  explicit Stream(StreamHandle stream)
      : stream_(stream, [](StreamHandle stream) {
          AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_stream(stream));
        }) {}

  StreamId id() const {
    StreamId stream_id;
    AOTI_TORCH_ERROR_CODE_CHECK(
        aoti_torch_stream_id(stream_.get(), &stream_id));
    return stream_id;
  }

 private:
  std::shared_ptr<StreamOpaque> stream_;
};

Stream getCurrentStream(DeviceIndex device_index) {
  StreamHandle stream = nullptr;
  AOTI_TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_current_stream(device_index, &stream));
  return Stream(stream);
}

} // namespace torch::stable::accelerator
```

**Reviewer** (on `class DeviceGuard`): Something I'm not sure about: do the copy/move semantics need to match the real device guard, even though this is just a wrapper holding a pointer to one? (See `c10/core/DeviceGuard.h`, lines 37 to 46 at 444e238.)

**Reply:** It doesn't need to copy the existing `DeviceGuard`. For example, I agree with deleting the default constructor for this API. We can disallow copy and move if that's easiest to maintain anyway.

**Reviewer** (on `explicit Stream(StreamHandle)`): Ah, I see, so the expected use case is that `getCurrentStream` constructs the owning `Stream` from the handle it just received.
**Reviewer:** Check that before the guard you have another device than this one, because if there is only 1 GPU, I'm not sure this will ever do anything.

**Author:** The test that exercises this from Python is decorated with `@deviceCountAtLeast(2)`, so this situation won't happen, I think :)