diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index 3b4b0ae02bec..aa000b118daa 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -121,7 +121,7 @@ inline int64_t legacy_cat_wrap_dim_symint( const std::vector>& tensor_sizes) { for (auto& sizes : tensor_sizes) { if (sizes.size() == 1) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[0].sym_eq(0))) { + if (TORCH_GUARD_OR_FALSE(sizes[0].sym_eq(0))) { continue; } } @@ -135,7 +135,7 @@ inline int64_t legacy_cat_wrap_dim( const MaterializedITensorListRef& tensors) { for (const Tensor& tensor : tensors) { if (tensor.dim() == 1) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(tensor.sym_sizes()[0].sym_eq(0))) { + if (TORCH_GUARD_OR_FALSE(tensor.sym_sizes()[0].sym_eq(0))) { continue; } } diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index d903d851ee87..e7f8e3c31964 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -34,11 +34,7 @@ ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype, ) -from torch.fx.experimental.symbolic_shapes import ( - guard_or_false, - guard_size_oblivious, - statically_known_true, -) +from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true from . import config, inductor_prims from .utils import ( @@ -404,10 +400,10 @@ def non_empty_tensor(x: torch.Tensor) -> bool: # runtime assert forcing u0 to be zero. So if this hasn't happened, # we know that the unbacked SymInt has appropriate size and there are # no problems. - if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0): + if len(x.shape) == 1 and guard_or_false(x.shape[0] == 0): return False - if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0): + if dim < len(x.shape) and guard_or_false(x.shape[dim] == 0): return False return True diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 327f96ae34ac..a5fc96a5ec4e 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -9,7 +9,7 @@ import torch from torch._dynamo.utils import counters -from torch.fx.experimental.symbolic_shapes import free_symbols +from torch.fx.experimental.symbolic_shapes import free_symbols, guard_or_false from torch.utils._ordered_set import OrderedSet from ..pattern_matcher import ( @@ -307,8 +307,6 @@ def normalize_unbind_default(match: Match, *args, **kwargs): pass_dict=construct_pattern_matcher_pass("normalization_pass"), ) def normalize_cat_default(match: Match, *args, **kwargs): - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious - cat_node = match.nodes[0] graph = match.graph tensors = get_arg_value(cat_node, 0, "tensors") @@ -333,7 +331,7 @@ def normalize_cat_default(match: Match, *args, **kwargs): def is_empty_tensor(x): # special case where torch.cat supports cat'ing with an empty tensor x_shape = x.meta["example_value"].shape - return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0) + return len(x_shape) == 1 and guard_or_false(x_shape[0] == 0) assert all( ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index c82d7aaecb85..c5e175e0d7df 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -2782,7 +2782,6 @@ def cat_compute_output_memory_format(inputs): from torch.fx.experimental.symbolic_shapes import ( guard_or_false, - guard_size_oblivious, ) # This is a bit tricky. Naively, you would expect to just pick one @@ -2837,7 +2836,7 @@ def cat_compute_output_memory_format(inputs): # through), and is load bearing for our Inductor lowerings # (which assume that size oblivious tests are OK to determine # if a shape is permissibly zero.) - guard_size_oblivious(tensor.shape[0] == 0), + guard_or_false(tensor.shape[0] == 0), lambda: f"Number of dimensions of tensors must match. " f"Expected {example.ndim}-D tensors, but got 1-D for " f"tensor number {tensor_idx} in the list", diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 8e13d4267edb..29dd6702ce31 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1078,7 +1078,7 @@ std::vector cat_tensors_backward( auto& shape = sizes[i]; // If input was empty tensor, gradInput should be empty tensor. if (shape.size() == 1) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(shape[0].sym_eq(0))) { + if (TORCH_GUARD_OR_FALSE(shape[0].sym_eq(0))) { grad_inputs[i] = at::zeros({0}, grad_val.options()); continue; }