
Commit c3dfa64

remove default gso from normal contiguity checks
ghstack-source-id: 6c9ccfd Pull-Request: #159197
1 parent 316c188 commit c3dfa64
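
In short: "gso" is guard_size_oblivious. Before this commit, the Python contiguity helpers in torch/_prims_common routed their conditions through guard_size_oblivious even on the plain eager path; after it, the eager path evaluates the condition directly with bool(), and guard_or_false / guard_or_true are used only when false_if_dde is requested. The snippet below is a minimal sketch of that dispatch pattern, not code from this commit; the function name maybe_small_numel is illustrative, and it assumes a PyTorch build where guard_or_false is available.

# Minimal sketch (illustrative only) of the dispatch pattern this commit changes.
from torch.fx.experimental.symbolic_shapes import guard_or_false


def maybe_small_numel(numel, false_if_dde=False):
    # Before: the eager branch used guard_size_oblivious(numel < 2).
    # After: the condition is simply evaluated with bool().
    def eval_eager(x):
        return bool(x)

    check = guard_or_false if false_if_dde else eval_eager
    return check(numel < 2)


print(maybe_small_numel(1))  # True: concrete value, evaluated eagerly
print(maybe_small_numel(8))  # False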

16 files changed: +123 additions, -30 deletions


aten/src/ATen/native/TensorProperties.cpp

Lines changed: 6 additions & 0 deletions
@@ -57,6 +57,12 @@ c10::SymInt sym_size(const Tensor& self, int64_t dim) {
   return self.sym_size(dim);
 }
 
+c10::SymBool sym_is_contiguous(
+    const Tensor& self,
+    c10::MemoryFormat memory_format) {
+  return self.sym_is_contiguous(memory_format);
+}
+
 c10::SymInt sym_stride(const Tensor& self, int64_t dim) {
   return self.sym_stride(dim);
 }

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 0 deletions
@@ -5496,6 +5496,13 @@
   tags: core
   manual_cpp_binding: True
 
+- func: sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
 - func: sym_numel(Tensor self) -> SymInt
   variants: function
   device_check: NoCheck
c10/core/TensorImpl.cpp

Lines changed: 1 addition & 1 deletion
@@ -313,7 +313,7 @@ void TensorImpl::throw_data_ptr_access_error() const {
 c10::SymBool TensorImpl::sym_is_contiguous_custom(
     at::MemoryFormat memory_format) const {
   if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
-    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
+    return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(
         this, memory_format);
   }
 

c10/core/impl/PyInterpreter.cpp

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
   bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override {
     PANIC(is_contiguous);
   }
+  c10::SymBool sym_is_contiguous(const TensorImpl* self, at::MemoryFormat)
+      const override {
+    PANIC(sym_is_contiguous);
+  }
   bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
       const override {
     PANIC(is_strides_like);

c10/core/impl/PyInterpreter.h

Lines changed: 3 additions & 0 deletions
@@ -168,6 +168,9 @@ struct C10_API PyInterpreterVTable {
 
   virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
       const = 0;
+  virtual c10::SymBool sym_is_contiguous(
+      const TensorImpl* self,
+      at::MemoryFormat) const = 0;
   virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
       const = 0;
   virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;

test/functorch/test_vmap_registrations.py

Lines changed: 1 addition & 0 deletions
@@ -209,6 +209,7 @@
     "aten::subtract_.Tensor",
     "aten::svd.U",
     "aten::sym_size.int",
+    "aten::sym_is_contiguous",
     "aten::sym_stride.int",
     "aten::sym_numel",
     "aten::sym_storage_offset",

tools/autograd/gen_python_functions.py

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@
     "sym_size",
     "sym_stride",
     "sym_storage_offset",
+    "sym_is_contiguous",
    "sym_numel",
     ".*_backward",
     ".*_backward_(out|input|weight|bias)",

torch/_prims_common/__init__.py

Lines changed: 24 additions & 21 deletions
@@ -265,12 +265,14 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
     from torch.fx.experimental.symbolic_shapes import (
         guard_or_false,
         guard_or_true,
-        guard_size_oblivious,
         is_nested_int,
     )
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    def eval_eager(x):
+        return bool(x)
+
+    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
+    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
 
     if maybe_guard_or_false(a.numel() < 2):
         return True
@@ -305,14 +307,13 @@ def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
     if a.ndim != 4:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+
+    def eval_eager(x):
+        return bool(x)
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
+    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
 
     expected_stride = 1
     for idx in (1, 3, 2, 0):
@@ -334,14 +335,13 @@ def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
     if a.ndim != 5:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+
+    def eval_eager(x):
+        return bool(x)
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
+    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
 
     expected_stride = 1
     for idx in (1, 4, 3, 2, 0):
@@ -406,7 +406,7 @@ def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool:
 
 
 # similar to is_contiguous_for_memory_format but return false on data dependency.
-def contiguous_for_memory_format_or_false(  # type: ignore[return]
+def is_contiguous_for_memory_format_or_false(  # type: ignore[return]
     a: Tensor, *, memory_format: torch.memory_format
 ) -> bool:
     return is_contiguous_for_memory_format(
@@ -547,11 +547,14 @@ def compute_elementwise_output_logical_to_physical_perm(
     is_contiguous = True
     is_channels_last = True
     for t in tensors:
-        is_contiguous = is_contiguous and contiguous_for_memory_format_or_false(
+        is_contiguous = is_contiguous and is_contiguous_for_memory_format_or_false(
             t, memory_format=torch.contiguous_format
         )
-        is_channels_last = is_channels_last and contiguous_for_memory_format_or_false(
-            t, memory_format=torch.channels_last
+        is_channels_last = (
+            is_channels_last
+            and is_contiguous_for_memory_format_or_false(
+                t, memory_format=torch.channels_last
+            )
         )
 
     if is_contiguous and not is_channels_last:
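
The recurring edit above is the heart of the commit: when false_if_dde is not requested, the helpers now evaluate the condition directly with bool() instead of going through guard_size_oblivious, while guard_or_false / guard_or_true remain the data-dependent-safe path. Below is a simplified, self-contained sketch of the channels-last (NCHW) stride walk used by is_channels_last_contiguous_2d, operating on plain int shapes and strides so only the eager bool() branch applies; the loop body is reconstructed from the standard helper and is not shown verbatim in this hunk.

# Simplified sketch of the channels-last stride walk, eager branch only.
def channels_last_contiguous_2d(shape, strides):
    if len(shape) != 4:
        return False

    def eval_eager(x):
        return bool(x)

    check_false = eval_eager  # guard_or_false would be used when false_if_dde=True
    check_true = eval_eager   # guard_or_true likewise

    expected_stride = 1
    for idx in (1, 3, 2, 0):  # C, W, H, N
        length = shape[idx]
        if check_false(length == 1):
            continue  # size-1 dims can have any stride
        if check_true(strides[idx] != expected_stride):
            return False
        expected_stride *= length
    return True


# A (2, 3, 4, 5) channels-last tensor has strides (60, 1, 15, 3).
print(channels_last_contiguous_2d((2, 3, 4, 5), (60, 1, 15, 3)))  # True
print(channels_last_contiguous_2d((2, 3, 4, 5), (60, 20, 5, 1)))  # False (row-major)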

torch/_refs/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,6 @@
 from torch import sym_float, sym_int
 from torch._prims_common import (
     BoolLike,
-    contiguous_for_memory_format_or_false,
     DeviceLikeType,
     Dim,
     DimsSequenceType,
@@ -29,6 +28,7 @@
     FloatLike,
     FloatWithoutSymFloat,
     IntLike,
+    is_contiguous_for_memory_format_or_false,
     is_contiguous_or_false,
     is_weakly_lesser_type,
     Number,
@@ -2991,7 +2991,7 @@ def contiguous(
     )
 
     # TODO: make logic consistent with aten contiguous
-    if contiguous_for_memory_format_or_false(a, memory_format=memory_format):
+    if is_contiguous_for_memory_format_or_false(a, memory_format=memory_format):
         return a
 
     return torch.clone(a, memory_format=memory_format)
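
The rename only adds the is_ prefix so the helper matches is_contiguous_or_false; behavior is unchanged. As a hedged sanity check (an assumption about post-commit behavior, not taken from the commit itself), on tensors with concrete shapes the "_or_false" variant should agree with Tensor.is_contiguous; it only diverges under unbacked, data-dependent shapes, where it falls back to False instead of guarding.

# Hedged usage check of the renamed helper on concrete shapes.
import torch
from torch._prims_common import is_contiguous_for_memory_format_or_false

x = torch.empty(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)
print(is_contiguous_for_memory_format_or_false(x, memory_format=torch.channels_last))     # True
print(is_contiguous_for_memory_format_or_false(x, memory_format=torch.contiguous_format)) # False
print(x.is_contiguous(memory_format=torch.channels_last))  # True, matches the helper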

torch/_subclasses/fake_impls.py

Lines changed: 3 additions & 3 deletions
@@ -15,11 +15,11 @@
 from torch._dispatch.python import no_python_dispatcher
 from torch._ops import OpOverload
 from torch._prims_common import (
-    contiguous_for_memory_format_or_false,
     elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     is_boolean_dtype,
     is_contiguous,
+    is_contiguous_for_memory_format_or_false,
     is_contiguous_or_false,
     is_float_dtype,
     is_integer_dtype,
@@ -1242,13 +1242,13 @@ def slow(msg):
             continue
         definitely_contiguous = (
             definitely_contiguous
-            and contiguous_for_memory_format_or_false(
+            and is_contiguous_for_memory_format_or_false(
                 op, memory_format=torch.contiguous_format
             )
         )
         definitely_channels_last = (
             definitely_channels_last
-            and contiguous_for_memory_format_or_false(
+            and is_contiguous_for_memory_format_or_false(
                 op, memory_format=torch.channels_last
            )
         )
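
The fake-tensor fast path accumulates definitely_contiguous and definitely_channels_last flags across operands with the conservative "_or_false" helper, so an operand whose contiguity cannot be established simply clears the flag rather than installing a guard. The snippet below is a standalone sketch of that accumulation on concrete eager tensors (the operand list and printed outcome are illustrative assumptions, not taken from fake_impls.py).

# Sketch of the conservative flag accumulation shown above.
import torch
from torch._prims_common import is_contiguous_for_memory_format_or_false

operands = [
    torch.empty(2, 3, 4, 5),                                                 # row-major
    torch.empty(2, 3, 4, 5).contiguous(memory_format=torch.channels_last),   # channels-last
]

definitely_contiguous = True
definitely_channels_last = True
for op in operands:
    definitely_contiguous = definitely_contiguous and is_contiguous_for_memory_format_or_false(
        op, memory_format=torch.contiguous_format
    )
    definitely_channels_last = definitely_channels_last and is_contiguous_for_memory_format_or_false(
        op, memory_format=torch.channels_last
    )

# Mixed layouts: neither flag survives, so the fast path would be skipped.
print(definitely_contiguous, definitely_channels_last)  # False False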
