Add new_empty (with dtype argument only) to torch::stable #159508

Open: wants to merge 4 commits into base branch gh/mikaylagawarecki/331/base
Changes from all commits
@@ -4,6 +4,8 @@
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/util/Exception.h>

#include <torch/headeronly/core/ScalarType.h>

#include <optional>

void inline sgd_math(
@@ -322,18 +324,31 @@ void boxed_my_narrow(
stack[0] = from(res);
}

Tensor my_new_empty(Tensor t) {
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(at::ScalarType::BFloat16);
Contributor (suggested change):
- auto dtype = std::make_optional(at::ScalarType::BFloat16);
+ auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);

return new_empty(t, sizes, dtype);
}

void boxed_my_new_empty(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_new_empty(to<Tensor>(stack[0]));
stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
m.def("my_empty_like(Tensor t) -> Tensor");
m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
m.def("my_pad(Tensor t) -> Tensor");
m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
m.def("my_new_empty(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_transpose", &boxed_my_transpose);
m.impl("my_empty_like", &boxed_empty_like);
m.impl("fill_infinity", &boxed_fill_infinity);
m.impl("my_new_empty", &boxed_my_new_empty);
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
@@ -203,3 +203,15 @@ def my_narrow(t, dim, start, length) -> Tensor:
    Returns: Narrowed tensor
    """
    return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length)


def my_new_empty(t) -> Tensor:
Contributor: let's name this something more specific that indicates the test better.
Suggested change:
- def my_new_empty(t) -> Tensor:
+ def my_new_empty_dtype_variant(t) -> Tensor:

"""
Returns a new empty tensor with shape [2, 5] and dtype bfloat16
Args:
t: Input tensor used as a reference for device and other properties
Returns: New empty tensor with shape [2, 5] and dtype bfloat16
"""
return torch.ops.libtorch_agnostic.my_new_empty.default(t)
@@ -252,6 +252,21 @@ def test_my_narrow(self, device):
        expected0 = torch.narrow(t, dim0, start0, length0)
        self.assertEqual(out0, expected0)

    def test_my_new_empty(self, device):
        import libtorch_agnostic

        deterministic = torch.are_deterministic_algorithms_enabled()
        try:
            # set use_deterministic_algorithms to fill unintialized memory
Contributor (suggested change, nit):
- # set use_deterministic_algorithms to fill unintialized memory
+ # set use_deterministic_algorithms to fill uninitialized memory

            torch.use_deterministic_algorithms(True)
            t = torch.randn(3, 4, device=device)
            out = libtorch_agnostic.ops.my_new_empty(t)
            ref_out = t.new_empty((2, 5), dtype=torch.bfloat16)

            self.assertEqual(out, ref_out, exact_device=True)
        finally:
            torch.use_deterministic_algorithms(deterministic)

instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)

if __name__ == "__main__":
3 changes: 3 additions & 0 deletions torch/csrc/inductor/aoti_torch/c/shim.h
@@ -220,6 +220,9 @@ aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t* ret_device_type);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_layout(AtenTensorHandle tensor, int32_t* ret_layout);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
AtenTensorHandle tensor,
int64_t* ret_storage_offset);
1 change: 1 addition & 0 deletions torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h
@@ -16,6 +16,7 @@ extern "C" {

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);

#ifdef __cplusplus
9 changes: 9 additions & 0 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -384,6 +384,15 @@ AOTITorchError aoti_torch_get_device_index(
});
}

AOTITorchError aoti_torch_get_layout(
AtenTensorHandle tensor,
int32_t* ret_layout) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
*ret_layout = static_cast<int32_t>(t->layout());
});
}

AOTITorchError aoti_torch_get_storage_offset(
AtenTensorHandle tensor,
int64_t* ret_storage_offset) {
40 changes: 40 additions & 0 deletions torch/csrc/stable/ops.h
@@ -8,6 +8,7 @@
#include <vector>

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/headeronly/core/ScalarType.h>

using torch::stable::Tensor;

@@ -51,6 +52,45 @@ inline Tensor narrow(Tensor& self, int64_t dim, int64_t start, int64_t length) {
return Tensor(ret0);
}

// We expect this to be a stable version of the new_empty op that takes in
// only dtype information.
inline Tensor new_empty(
const Tensor& self,
std::vector<int64_t> size,
std::optional<c10::ScalarType> dtype = std::nullopt) {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));

int32_t device_index;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_index(self.get(), &device_index));

// Handle dtype - use input tensor's dtype if not specified
int32_t target_dtype;
if (dtype.has_value()) {
target_dtype = static_cast<int32_t>(dtype.value());
Contributor: hmmm lol is this actually stable... this looks to expose a detail of how dtype is encapsulated. Should we/I write that translation layer from the headeronly ScalarType to their corresponding shim aoti_torch_get_dtype... maybe there's no other way to hide this.

Contributor Author (@mikaylagawarecki, Aug 12, 2025): Yes, it would be nice if you could write that! Can you describe in more detail what the translation layer would do?
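A minimal sketch of what such a translation layer could look like, assuming only the aoti_torch_dtype_* getters already declared in torch/csrc/inductor/aoti_torch/c/shim.h; the helper name to_shim_dtype is hypothetical and not part of this PR:

#include <stdexcept>

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/core/ScalarType.h>

// Hypothetical helper: map the headeronly ScalarType to the shim's int32_t
// dtype code so callers never depend on how dtype is encoded across the ABI.
inline int32_t to_shim_dtype(torch::headeronly::ScalarType dtype) {
  using torch::headeronly::ScalarType;
  switch (dtype) {
    case ScalarType::Float:
      return aoti_torch_dtype_float32();
    case ScalarType::Double:
      return aoti_torch_dtype_float64();
    case ScalarType::Half:
      return aoti_torch_dtype_float16();
    case ScalarType::BFloat16:
      return aoti_torch_dtype_bfloat16();
    case ScalarType::Long:
      return aoti_torch_dtype_int64();
    case ScalarType::Bool:
      return aoti_torch_dtype_bool();
    // ... remaining dtypes elided for brevity ...
    default:
      throw std::runtime_error("to_shim_dtype: unsupported ScalarType");
  }
}

The new_empty wrapper in this file could then take std::optional<torch::headeronly::ScalarType> and call such a helper instead of the static_cast above, keeping the encoding detail out of stable extension code.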

} else {
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
}

int32_t layout;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));

AtenTensorHandle ret0 = nullptr;
Contributor (suggested change, nit):
- AtenTensorHandle ret0 = nullptr;
+ AtenTensorHandle ret0;

TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
self.get(),
size.data(),
static_cast<int64_t>(size.size()),
&target_dtype,
&layout,
&device_type,
device_index,
nullptr, // pin_memory (nullptr for default)
&ret0));

return Tensor(ret0);
}
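
For context, a short usage sketch (not part of this diff) of how extension code would call the new_empty wrapper above, covering both the defaulted and the explicit dtype path; make_scratch is a hypothetical caller:

#include <optional>

#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>

torch::stable::Tensor make_scratch(const torch::stable::Tensor& self, bool bf16) {
  if (bf16) {
    // Pass a dtype: the result is bfloat16 regardless of self's dtype.
    return new_empty(
        self, {2, 5}, std::make_optional(torch::headeronly::ScalarType::BFloat16));
  }
  // Omit dtype: the new 2x5 tensor picks up self's dtype, layout and device.
  return new_empty(self, {2, 5});
}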

// We expect this to be the stable version of the pad.default op.
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
// use std::vector<int64_t> because
1 change: 1 addition & 0 deletions torchgen/aoti/fallback_ops.py
@@ -185,4 +185,5 @@
"aten.fill_.Scalar": {},
"aten.pad.default": {},
"aten.narrow.default": {},
"aten.new_empty.default": {},
}
17 changes: 17 additions & 0 deletions torchgen/gen_aoti_c_shim.py
@@ -24,6 +24,7 @@
OperatorName,
OptionalType,
Type,
Variant,
)
from torchgen.utils import FileManager, mapMaybe

@@ -396,7 +397,23 @@ def gen_static_dispatch_backend_call(
) -> str:
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)

    if backend_index is None:
        # Check if this is a symint function and if the function only has method variants
Contributor: Are there any other ops we expect to go through this branch? Maybe list new_empty as an example here?

Contributor Author: Hm, it's listed below (when all the if checks pass) on 412.

        if sig.symint and f.func.has_symint():
            # Check if the function has standalone function variants
Contributor: what are standalone function variants?

            has_function_variant = Variant.function in f.variants

            if not has_function_variant:
                # Functions with both function and method variants can use the at::{*}_symint version
                # (e.g., narrow -> at::narrow_symint), BUT
                # Method-only functions with symint parameters should use at::symint:: namespace
                # Remove the _symint suffix since at::symint:: namespace uses the base name
                # (e.g., new_empty -> at::symint::new_empty<c10::SymInt>)
                base_name = cpp_sig.name()
                base_name = base_name.removesuffix("_symint")  # Remove "_symint" suffix
                return f"at::symint::{base_name}<c10::SymInt>"

Comment on lines +407 to +416, Contributor Author: This is the relevant part here:

pytorch/torchgen/gen.py, lines 717 to 739 in 6d91d6d:
if Variant.function in f.variants:
result += f"""
// aten::{f.func}
inline {sig.decl()} {{
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}"""
# The template function can be used from template situations
# where you want to switch between the symint or not version
# depending on a template argument
#
# NB: we ALWAYS generate this even for methods. But we put it in
# this header so it can take advantage of per-op headers
if has_symint:
result += f"""
namespace symint {{
template <typename T, typename = std::enable_if_t<std::is_same_v<T, {intlike_t}>>>
{sig.decl(suppress_symint_suffix=True)} {{
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
}}
"""
return result
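
To make the distinction concrete, an illustrative sketch (argument values assumed; the conversions performed by the generated shim are elided) of the two call spellings the generated code resolves to:

#include <ATen/ATen.h>
#include <optional>

// narrow has both function and method variants, so the generated fallback can
// call the free function carrying the _symint suffix.
at::Tensor call_narrow(const at::Tensor& self) {
  return at::narrow_symint(self, /*dim=*/0, c10::SymInt(0), c10::SymInt(1));
}

// new_empty is method-only, so the generated fallback goes through the
// at::symint:: template namespace using the base name (no _symint suffix).
at::Tensor call_new_empty(const at::Tensor& self) {
  return at::symint::new_empty<c10::SymInt>(
      self,
      {c10::SymInt(2), c10::SymInt(5)},
      /*dtype=*/std::nullopt,
      /*layout=*/std::nullopt,
      /*device=*/std::nullopt,
      /*pin_memory=*/std::nullopt);
}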

return f"at::{cpp_sig.name()}"
else:
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"