Skip to content

Commit 512adb2

Browse files
committed
migrate more simple gso checks
ghstack-source-id: 4a17433 Pull-Request: #160253
1 parent c79806b commit 512adb2

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

torch/_decomp/decompositions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,9 +1780,9 @@ def _fused_rms_norm_backward(
17801780

17811781
N = prod(inner_dims) # type: ignore[arg-type]
17821782
M = prod(outer_dims) # type: ignore[arg-type]
1783-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
1783+
from torch.fx.experimental.symbolic_shapes import guard_or_false
17841784

1785-
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
1785+
if guard_or_false(M == 0) or guard_or_false(N == 0):
17861786
return (
17871787
input.new_zeros(input_shape) if output_mask[0] else None,
17881788
input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
@@ -3987,9 +3987,9 @@ def _unsafe_masked_index(x, mask, indices, fill):
39873987
lambda: "tensors used as masks must be bool tensors",
39883988
)
39893989

3990-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
3990+
from torch.fx.experimental.symbolic_shapes import guard_or_false
39913991

3992-
if guard_size_oblivious(x.numel() == 0):
3992+
if guard_or_false(x.numel() == 0):
39933993
meta_result = torch._meta_registrations.meta_index_Tensor(x, indices)
39943994
return x.new_full(meta_result.shape, fill)
39953995

torch/_prims_common/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,10 +1879,12 @@ def compute_required_storage_length(
18791879
40
18801880
18811881
"""
1882-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
1882+
from torch.fx.experimental.symbolic_shapes import guard_or_false
18831883

18841884
# Short-circuits if the shape has no elements
1885-
if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0):
1885+
# Note: we are unsafely assuming tensor is not empty here, without
1886+
# runtime assertions.
1887+
if guard_or_false(reduce(operator.mul, shape, 1) == 0):
18861888
return 0
18871889

18881890
max_offset = sum((x - 1) * y for x, y in zip(shape, strides))

torch/_refs/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def _broadcast_shapes(*_shapes):
412412
for arg_idx, shape in enumerate(shapes):
413413
for idx in range(-1, -1 - len(shape), -1):
414414
if is_nested_int(shape[idx]):
415-
# maintian nested int behaviour added in PR 145957.
415+
# maintain nested int behaviour added in PR 145957.
416416
if is_nested_int(common_shape[idx]) and guard_or_false(
417417
shape[idx] == common_shape[idx]
418418
):
@@ -4194,7 +4194,7 @@ def index_select(x: TensorLike, dim: int, index: TensorLike):
41944194

41954195
@register_decomposition(aten.squeeze.dims)
41964196
def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
4197-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4197+
from torch.fx.experimental.symbolic_shapes import guard_or_false
41984198

41994199
if dim is None:
42004200
dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1)
@@ -4209,7 +4209,8 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType
42094209
return prims.view_of(a)
42104210

42114211
# Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
4212-
dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1))
4212+
# would it be better if we just did not allow 1 for unbacked at runtime?
4213+
dims = tuple(d for d in dims if guard_or_false(a.shape[d] == 1))
42134214
if len(dims) == 0:
42144215
return prims.view_of(a)
42154216
if len(dims) == 1:

0 commit comments

Comments
 (0)