Skip to content

Update torch::stable::Tensor() default constructor #159507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: gh/mikaylagawarecki/330/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,38 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("my_zero_", &boxed_my_zero_);
}

bool test_default_constructor(bool defined) {
Tensor out;
if (defined) {
AtenTensorHandle defined_ath;
int64_t sizes[] = {2, 3};
int64_t strides[] = {3, 1};
aoti_torch_empty_strided(
2,
sizes,
strides,
aoti_torch_dtype_float32(),
aoti_torch_device_type_cpu(),
0,
&defined_ath);
out = Tensor(defined_ath);
}
return out.defined();
}

void boxed_test_default_constructor(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
bool res = test_default_constructor(to<bool>(stack[0]));
stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("test_default_constructor(bool undefined) -> bool");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_default_constructor", &boxed_test_default_constructor);
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,15 @@ def fill_infinity(t) -> Tensor:
Returns: The modified tensor (same as input)
"""
return torch.ops.libtorch_agnostic.fill_infinity.default(t)


def test_default_constructor(defined) -> bool:
"""
Tests the default constructor for torch::stable::Tensor.

Args:
defined: bool - if True, tests defined tensor; if False, tests undefined tensor

Returns: bool - result of calling .defined() on the tensor
"""
return torch.ops.libtorch_agnostic.test_default_constructor.default(defined)
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,20 @@ def test_fill_infinity(self, device):
expected = torch.full_like(t, math.inf)
self.assertEqual(out, expected)

@onlyCPU
def test_default_constructor(self):
import libtorch_agnostic

defined_tensor_is_defined = libtorch_agnostic.ops.test_default_constructor(
True
)
self.assertTrue(defined_tensor_is_defined)

undefined_tensor_is_defined = (
libtorch_agnostic.ops.test_default_constructor(False)
)
self.assertFalse(undefined_tensor_is_defined)

instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)

if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/inductor/aoti_torch/c/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_is_defined(AtenTensorHandle tensor, bool* ret_is_defined);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle(
AtenTensorHandle orig_handle,
AtenTensorHandle* new_handle);
Expand Down
12 changes: 10 additions & 2 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,15 @@ AOTITorchError aoti_torch_is_contiguous(
});
}

AOTITorchError aoti_torch_is_defined(
AtenTensorHandle tensor,
bool* ret_is_defined) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
*ret_is_defined = t->defined();
});
}

AOTITorchError aoti_torch_new_tensor_handle(
AtenTensorHandle orig_handle,
AtenTensorHandle* new_handle) {
Expand Down Expand Up @@ -1204,8 +1213,7 @@ void aoti_torch_print_tensor_handle(AtenTensorHandle self, const char* msg) {
if (msg) {
std::cout << " " << msg;
}
std::cout << " "
<< "]:" << '\n';
std::cout << " " << "]:" << '\n';

// Print exact tensor values for small size tensors
const int64_t numel = t->numel();
Expand Down
16 changes: 15 additions & 1 deletion torch/csrc/stable/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,15 @@ class Tensor {
std::shared_ptr<AtenTensorOpaque> ath_;

public:
Tensor() = delete;
// Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH)
// Steals ownership from the ATH
Tensor() {
AtenTensorHandle ret;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret));
ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) {
TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
});
}

// Construct a stable::Tensor from an AtenTensorHandle (ATH)
// Steals ownership from the ATH
Expand Down Expand Up @@ -115,6 +123,12 @@ class Tensor {
return size;
}

bool defined() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same semantic as the one in TensorBase.h right?

And to make sure I'm understanding correctly: an undefined tensor is an uninitialized tensor which I know has no memory but is it basically a nullptr?

Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes to the first

AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* out_tensor = new at::Tensor();
*ret = tensor_pointer_to_tensor_handle(out_tensor);
});
}

Tensor() = default;

TensorBase() = default;

My understanding is TensorBase() = default constructor would initialize its c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_ to a nullptr indeed

I think there is a special case where UndefinedTensorImpl is not a nullptr but has defined its bool() method return False when defined() is called, but we're not using that here.

bool defined;
TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined));
return defined;
}

// =============================================================================
// END of C-shimified TensorBase APIs
// =============================================================================
Expand Down
Loading