diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
index e3dfc581179a..f012b10656c5 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
+++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
@@ -4,6 +4,8 @@
 #include
 #include
+#include <optional>
+
 #include
 
 void inline sgd_math(
@@ -322,18 +324,31 @@ void boxed_my_narrow(
   stack[0] = from(res);
 }
 
+Tensor my_new_empty(Tensor t) {
+  std::vector<int64_t> sizes = {2, 5};
+  auto dtype = std::make_optional(at::ScalarType::BFloat16);
+  return new_empty(t, sizes, dtype);
+}
+
+void boxed_my_new_empty(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  auto res = my_new_empty(to<Tensor>(stack[0]));
+  stack[0] = from(res);
+}
+
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
   m.def("my_empty_like(Tensor t) -> Tensor");
   m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
   m.def("my_pad(Tensor t) -> Tensor");
   m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
+  m.def("my_new_empty(Tensor t) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_transpose", &boxed_my_transpose);
   m.impl("my_empty_like", &boxed_empty_like);
   m.impl("fill_infinity", &boxed_fill_infinity);
+  m.impl("my_new_empty", &boxed_my_new_empty);
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
index 817732371060..b73637ad0da4 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
+++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
@@ -203,3 +203,15 @@ def my_narrow(t, dim, start, length) -> Tensor:
     Returns: Narrowed tensor
     """
     return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length)
+
+
+def my_new_empty(t) -> Tensor:
+    """
+    Returns a new empty tensor with shape [2, 5] and dtype bfloat16.
+
+    Args:
+        t: Input tensor used as a reference for device and other properties
+
+    Returns: New empty tensor with shape [2, 5] and dtype bfloat16
+    """
+    return torch.ops.libtorch_agnostic.my_new_empty.default(t)
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
index ae3c2767627f..e6738704939b 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
+++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
@@ -252,6 +252,21 @@ def test_my_narrow(self, device):
         expected0 = torch.narrow(t, dim0, start0, length0)
         self.assertEqual(out0, expected0)
 
+    def test_my_new_empty(self, device):
+        import libtorch_agnostic
+
+        deterministic = torch.are_deterministic_algorithms_enabled()
+        try:
+            # set use_deterministic_algorithms to fill uninitialized memory
+            torch.use_deterministic_algorithms(True)
+            t = torch.randn(3, 4, device=device)
+            out = libtorch_agnostic.ops.my_new_empty(t)
+            ref_out = t.new_empty((2, 5), dtype=torch.bfloat16)
+
+            self.assertEqual(out, ref_out, exact_device=True)
+        finally:
+            torch.use_deterministic_algorithms(deterministic)
+
 instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
 if __name__ == "__main__":
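For orientation: the extension pieces above follow the stable boxed-kernel convention, where each schema argument arrives as a StableIValue in stack[i], is unpacked with to<T>(), and the single output is written back into stack[0] via from(). A minimal sketch of that convention for a hypothetical two-argument op follows; my_scale and boxed_my_scale are illustrations only, not part of this diff, and the sketch assumes to<int64_t> is available for integer arguments just as to<Tensor> is for tensors:

    // Hypothetical unboxed kernel; any Tensor-and-int op would look the same.
    Tensor my_scale(Tensor t, int64_t factor);

    void boxed_my_scale(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
      // stack[i] holds the i-th schema argument, unpacked with to<T>().
      auto res = my_scale(to<Tensor>(stack[0]), to<int64_t>(stack[1]));
      // The single output overwrites stack[0].
      stack[0] = from(res);
    }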
"__main__": diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index b1446318dd34..34c903b5d723 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -220,6 +220,9 @@ aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t* ret_device_type); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_layout(AtenTensorHandle tensor, int32_t* ret_layout); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset( AtenTensorHandle tensor, int64_t* ret_storage_offset); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h index d5bc50750fc7..dd4f3c30f9fa 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h @@ -16,6 +16,7 @@ extern "C" { AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); #ifdef __cplusplus diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 868da9831e76..ac0b663d434e 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -384,6 +384,15 @@ AOTITorchError aoti_torch_get_device_index( }); } +AOTITorchError aoti_torch_get_layout( + AtenTensorHandle tensor, + int32_t* ret_layout) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); + *ret_layout = static_cast(t->layout()); + }); +} + AOTITorchError aoti_torch_get_storage_offset( AtenTensorHandle tensor, int64_t* ret_storage_offset) { diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 7ce25af14d3f..815169062b56 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -8,6 +8,7 @@ #include #include +#include using torch::stable::Tensor; @@ -51,6 +52,45 @@ inline Tensor narrow(Tensor& self, int64_t dim, int64_t start, int64_t length) { return Tensor(ret0); } +// We expect this to be a stable version of the new_empty op that takes in +// only dtype information. 
+inline Tensor new_empty(
+    const Tensor& self,
+    std::vector<int64_t> size,
+    std::optional<at::ScalarType> dtype = std::nullopt) {
+  int32_t device_type;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
+
+  int32_t device_index;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_device_index(self.get(), &device_index));
+
+  // Handle dtype: use the input tensor's dtype if none is specified
+  int32_t target_dtype;
+  if (dtype.has_value()) {
+    target_dtype = static_cast<int32_t>(dtype.value());
+  } else {
+    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
+  }
+
+  int32_t layout;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));
+
+  AtenTensorHandle ret0 = nullptr;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
+      self.get(),
+      size.data(),
+      static_cast<int64_t>(size.size()),
+      &target_dtype,
+      &layout,
+      &device_type,
+      device_index,
+      nullptr, // pin_memory (nullptr for default)
+      &ret0));
+
+  return Tensor(ret0);
+}
+
 // We expect this to be the stable version of the pad.default op.
 // pad.default takes in a SymInt[] as the pad argument, however pad here is
 // typed to use std::vector<int64_t> because
diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py
index be00c49d7b1f..a60b3379e034 100644
--- a/torchgen/aoti/fallback_ops.py
+++ b/torchgen/aoti/fallback_ops.py
@@ -185,4 +185,5 @@
     "aten.fill_.Scalar": {},
     "aten.pad.default": {},
     "aten.narrow.default": {},
+    "aten.new_empty.default": {},
 }
diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py
index 36db26bb5ea6..28dd5c05bb5a 100644
--- a/torchgen/gen_aoti_c_shim.py
+++ b/torchgen/gen_aoti_c_shim.py
@@ -24,6 +24,7 @@
     OperatorName,
     OptionalType,
     Type,
+    Variant,
 )
 
 from torchgen.utils import FileManager, mapMaybe
@@ -396,7 +397,23 @@ def gen_static_dispatch_backend_call(
 ) -> str:
     sig = DispatcherSignature.from_schema(f.func)
     cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
+
     if backend_index is None:
+        # Handle symint functions that only have method variants.
+        if sig.symint and f.func.has_symint():
+            # Check whether the function has a standalone function variant.
+            has_function_variant = Variant.function in f.variants
+
+            if not has_function_variant:
+                # Functions with a function variant can be called through the
+                # at::{name}_symint form (e.g., narrow -> at::narrow_symint),
+                # but method-only functions with symint parameters must go
+                # through the at::symint:: namespace, which uses the base name
+                # (e.g., new_empty -> at::symint::new_empty), so strip the suffix.
+                base_name = cpp_sig.name()
+                base_name = base_name.removesuffix("_symint")
+                return f"at::symint::{base_name}"
+
         return f"at::{cpp_sig.name()}"
     else:
         return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
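Putting the pieces together, an extension kernel can now allocate uninitialized memory through the stable API along these lines. This is a sketch mirroring my_new_empty above; make_bf16_buffer is a hypothetical name, and the same headers as kernel.cpp in this diff are assumed:

    using torch::stable::Tensor;

    // Allocate an uninitialized [4, 8] bfloat16 tensor on ref's device.
    Tensor make_bf16_buffer(const Tensor& ref) {
      std::vector<int64_t> sizes = {4, 8};
      // Passing std::nullopt instead would inherit ref's dtype; layout and
      // device always come from ref via the shim getters in ops.h.
      return new_empty(ref, sizes, std::make_optional(at::ScalarType::BFloat16));
    }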