Add new_empty (with dtype argument only) to torch::stable #159508
base: gh/mikaylagawarecki/331/base
Changes from all commits: cbcdcfe, 6fced91, 648adaa, ae40182
@@ -203,3 +203,15 @@ def my_narrow(t, dim, start, length) -> Tensor:
    Returns: Narrowed tensor
    """
    return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length)


def my_new_empty(t) -> Tensor:
Review comment: let's name this to something more specific which indicates the test better
""" | ||||||
Returns a new empty tensor with shape [2, 5] and dtype bfloat16 | ||||||
Args: | ||||||
t: Input tensor used as a reference for device and other properties | ||||||
Returns: New empty tensor with shape [2, 5] and dtype bfloat16 | ||||||
""" | ||||||
return torch.ops.libtorch_agnostic.my_new_empty.default(t) |
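For context, here is a hedged sketch of the C++ side this wrapper dispatches to. The kernel source is not shown in this diff, so everything below is an assumption based on the boxed-kernel pattern used elsewhere in the libtorch_agnostic extension and on the torch::stable::new_empty wrapper added later in this PR; the include path and the CompositeExplicitAutograd key are likewise assumed:

#include <torch/csrc/stable/library.h>

using torch::stable::Tensor;

// Allocate a [2, 5] bfloat16 tensor on the same device as the input,
// via the stable new_empty wrapper introduced in this PR.
Tensor my_new_empty(Tensor t) {
  return new_empty(t, {2, 5}, std::make_optional(c10::ScalarType::BFloat16));
}

// Boxed entry point: unpack the input tensor from the stable stack,
// run the kernel, and pack the result back.
void boxed_my_new_empty(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  Tensor res = my_new_empty(to<Tensor>(stack[0]));
  stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_new_empty(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_new_empty", &boxed_my_new_empty);
}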
@@ -252,6 +252,21 @@ def test_my_narrow(self, device):
        expected0 = torch.narrow(t, dim0, start0, length0)
        self.assertEqual(out0, expected0)

    def test_my_new_empty(self, device):
        import libtorch_agnostic

        deterministic = torch.are_deterministic_algorithms_enabled()
        try:
            # Enable deterministic algorithms so uninitialized memory is
            # filled with a known value, making the comparison below meaningful.
Review comment: nit (suggested change not shown).
            torch.use_deterministic_algorithms(True)
            t = torch.randn(3, 4, device=device)
            out = libtorch_agnostic.ops.my_new_empty(t)
            ref_out = t.new_empty((2, 5), dtype=torch.bfloat16)

            self.assertEqual(out, ref_out, exact_device=True)
        finally:
            torch.use_deterministic_algorithms(deterministic)

instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)


if __name__ == "__main__":
@@ -8,6 +8,7 @@
#include <vector>

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/headeronly/core/ScalarType.h>

using torch::stable::Tensor;

@@ -51,6 +52,45 @@ inline Tensor narrow(Tensor& self, int64_t dim, int64_t start, int64_t length) {
  return Tensor(ret0);
}

// We expect this to be a stable version of the new_empty op that takes in
// only dtype information.
inline Tensor new_empty(
    const Tensor& self,
    std::vector<int64_t> size,
    std::optional<c10::ScalarType> dtype = std::nullopt) {
  int32_t device_type;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));

  int32_t device_index;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_device_index(self.get(), &device_index));

  // Handle dtype - use the input tensor's dtype if not specified
  int32_t target_dtype;
  if (dtype.has_value()) {
    target_dtype = static_cast<int32_t>(dtype.value());
Review comment: hmmm lol is this actually stable... this looks to expose a detail of how dtype is encapsulated. Should we/I write that translation layer from the headeronly ScalarType to the corresponding shim aoti_torch dtypes? Maybe there's no other way to hide this.

Reply: Yes, it would be nice if you could write that! Can you describe in more detail what the translation layer would do?
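To make the discussion concrete, here is a minimal sketch of what such a translation layer might look like, assuming the existing aoti_torch_dtype_* shim functions; the helper name to_shim_dtype and its dtype coverage are hypothetical, not an agreed-upon API:

#include <stdexcept>

// Hypothetical helper: map the headeronly ScalarType enum to the shim's
// stable int32_t dtype codes, so callers never depend on how the enum
// values happen to be encoded. Only a few dtypes are shown.
inline int32_t to_shim_dtype(c10::ScalarType t) {
  switch (t) {
    case c10::ScalarType::Float:
      return aoti_torch_dtype_float32();
    case c10::ScalarType::Double:
      return aoti_torch_dtype_float64();
    case c10::ScalarType::Half:
      return aoti_torch_dtype_float16();
    case c10::ScalarType::BFloat16:
      return aoti_torch_dtype_bfloat16();
    case c10::ScalarType::Long:
      return aoti_torch_dtype_int64();
    case c10::ScalarType::Bool:
      return aoti_torch_dtype_bool();
    // ... remaining dtypes elided ...
    default:
      throw std::runtime_error("unsupported ScalarType");
  }
}

With such a helper, the static_cast above would become target_dtype = to_shim_dtype(dtype.value()).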
  } else {
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
  }

  int32_t layout;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));

  AtenTensorHandle ret0 = nullptr;
Review comment: nit (suggested change not shown).
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
      self.get(),
      size.data(),
      static_cast<int64_t>(size.size()),
      &target_dtype,
      &layout,
      &device_type,
      device_index,
      nullptr, // pin_memory (nullptr for default)
      &ret0));

  return Tensor(ret0);
}

// We expect this to be the stable version of the pad.default op.
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
// use std::vector<int64_t> because
@@ -185,4 +185,5 @@
    "aten.fill_.Scalar": {},
    "aten.pad.default": {},
    "aten.narrow.default": {},
    "aten.new_empty.default": {},
}
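Registering aten.new_empty.default here is what makes torchgen emit the aoti_torch_aten_new_empty shim called from the stable header above. For reference, a sketch of the declaration one would expect in the generated c_shim_aten.h; the parameter types are inferred from that call site, not copied from the generated file:

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(
    AtenTensorHandle self,   // reference tensor for defaults
    const int64_t* size,     // sizes of the new tensor
    int64_t size_len_,       // number of entries in size
    int32_t* dtype,          // optional dtype (nullptr => tensor's dtype)
    int32_t* layout,         // optional layout
    int32_t* device_type,    // optional device type
    int32_t device_index,
    int32_t* pin_memory,     // optional pin_memory (nullptr => default)
    AtenTensorHandle* ret0); // out: newly allocated tensor handle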
@@ -24,6 +24,7 @@
    OperatorName,
    OptionalType,
    Type,
    Variant,
)
from torchgen.utils import FileManager, mapMaybe

@@ -396,7 +397,23 @@ def gen_static_dispatch_backend_call(
) -> str:
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)

    if backend_index is None:
        # Check if this is a symint function and if the function only has method variants
Review comment: Are there any other ops we expect to go through this branch? Maybe list new_empty as an example here?

Reply: Hm, it's listed below (when all the if checks pass) on 412.
        if sig.symint and f.func.has_symint():
            # Check if the function has standalone function variants
Review comment: what are standalone function variants?
            has_function_variant = Variant.function in f.variants

            if not has_function_variant:
                # Functions with both function and method variants can use the
                # at::{*}_symint version (e.g., narrow -> at::narrow_symint),
                # BUT method-only functions with symint parameters should use
                # the at::symint:: namespace. Strip the _symint suffix, since
                # the at::symint:: namespace uses the base name
                # (e.g., new_empty -> at::symint::new_empty<c10::SymInt>).
                base_name = cpp_sig.name()
                base_name = base_name.removesuffix("_symint")
                return f"at::symint::{base_name}<c10::SymInt>"
Review comment (on lines +407 to +416): This is the relevant part here: lines 717 to 739 in 6d91d6d.
return f"at::{cpp_sig.name()}" | ||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}" | ||||||||||||||||||||||||||||||||||||||||||||||||