From d1e4f7564e9492fa90e708282b1c4997500911b2 Mon Sep 17 00:00:00 2001
From: Laith Sakka
Date: Sat, 9 Aug 2025 08:56:13 -0700
Subject: [PATCH 1/4] Update

[ghstack-poisoned]
---
 aten/src/ATen/WrapDimUtils.h            | 4 ++--
 torch/_inductor/decomposition.py        | 4 ++--
 torch/_inductor/fx_passes/split_cat.py  | 6 +++---
 torch/_refs/__init__.py                 | 3 +--
 torch/csrc/autograd/FunctionsManual.cpp | 2 +-
 5 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h
index 3b4b0ae02becf..aa000b118daa2 100644
--- a/aten/src/ATen/WrapDimUtils.h
+++ b/aten/src/ATen/WrapDimUtils.h
@@ -121,7 +121,7 @@ inline int64_t legacy_cat_wrap_dim_symint(
     const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
   for (auto& sizes : tensor_sizes) {
     if (sizes.size() == 1) {
-      if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[0].sym_eq(0))) {
+      if (TORCH_GUARD_OR_FALSE(sizes[0].sym_eq(0))) {
         continue;
       }
     }
@@ -135,7 +135,7 @@ inline int64_t legacy_cat_wrap_dim(
     const MaterializedITensorListRef& tensors) {
   for (const Tensor& tensor : tensors) {
     if (tensor.dim() == 1) {
-      if (TORCH_GUARD_SIZE_OBLIVIOUS(tensor.sym_sizes()[0].sym_eq(0))) {
+      if (TORCH_GUARD_OR_FALSE(tensor.sym_sizes()[0].sym_eq(0))) {
         continue;
       }
     }
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index d903d851ee872..8d05731b469b6 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -404,10 +404,10 @@ def non_empty_tensor(x: torch.Tensor) -> bool:
         # runtime assert forcing u0 to be zero. So if this hasn't happened,
         # we know that the unbacked SymInt has appropriate size and there are
         # no problems.
-        if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0):
+        if len(x.shape) == 1 and guard_or_false(x.shape[0] == 0):
             return False
 
-        if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0):
+        if dim < len(x.shape) and guard_or_false(x.shape[dim] == 0):
             return False
 
         return True
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index 327f96ae34ac7..3aa58fbf8d8c1 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -9,7 +9,7 @@
 
 import torch
 from torch._dynamo.utils import counters
-from torch.fx.experimental.symbolic_shapes import free_symbols
+from torch.fx.experimental.symbolic_shapes import free_symbols, guard_or_false
 from torch.utils._ordered_set import OrderedSet
 
 from ..pattern_matcher import (
@@ -307,7 +307,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
     pass_dict=construct_pattern_matcher_pass("normalization_pass"),
 )
 def normalize_cat_default(match: Match, *args, **kwargs):
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     cat_node = match.nodes[0]
     graph = match.graph
@@ -333,7 +333,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
     def is_empty_tensor(x):
         # special case where torch.cat supports cat'ing with an empty tensor
         x_shape = x.meta["example_value"].shape
-        return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0)
+        return len(x_shape) == 1 and guard_or_false(x_shape[0] == 0)
 
     assert all(
         ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index c82d7aaecb853..c5e175e0d7df0 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2782,7 +2782,6 @@ def cat_compute_output_memory_format(inputs):
 
     from torch.fx.experimental.symbolic_shapes import (
         guard_or_false,
-        guard_size_oblivious,
     )
 
     # This is a bit tricky. Naively, you would expect to just pick one
@@ -2837,7 +2836,7 @@ def cat_compute_output_memory_format(inputs):
                 # through), and is load bearing for our Inductor lowerings
                 # (which assume that size oblivious tests are OK to determine
                 # if a shape is permissibly zero.)
-                guard_size_oblivious(tensor.shape[0] == 0),
+                guard_or_false(tensor.shape[0] == 0),
                 lambda: f"Number of dimensions of tensors must match. "
                 f"Expected {example.ndim}-D tensors, but got 1-D for "
                 f"tensor number {tensor_idx} in the list",
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 8e13d4267edb5..29dd6702ce31d 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1078,7 +1078,7 @@ std::vector<Tensor> cat_tensors_backward(
     auto& shape = sizes[i];
     // If input was empty tensor, gradInput should be empty tensor.
     if (shape.size() == 1) {
-      if (TORCH_GUARD_SIZE_OBLIVIOUS(shape[0].sym_eq(0))) {
+      if (TORCH_GUARD_OR_FALSE(shape[0].sym_eq(0))) {
         grad_inputs[i] = at::zeros({0}, grad_val.options());
         continue;
       }

From c174a70e3210a857d9b5d20ac387fd5ffccbb719 Mon Sep 17 00:00:00 2001
From: Laith Sakka
Date: Sat, 9 Aug 2025 09:51:25 -0700
Subject: [PATCH 2/4] Update

[ghstack-poisoned]
---
 torch/_inductor/decomposition.py       | 6 +-----
 torch/_inductor/fx_passes/split_cat.py | 2 --
 2 files changed, 1 insertion(+), 7 deletions(-)

diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 8d05731b469b6..e7f8e3c319646 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -34,11 +34,7 @@
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     type_to_dtype,
 )
-from torch.fx.experimental.symbolic_shapes import (
-    guard_or_false,
-    guard_size_oblivious,
-    statically_known_true,
-)
+from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true
 
 from . import config, inductor_prims
 from .utils import (
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index 3aa58fbf8d8c1..a5fc96a5ec4ed 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -307,8 +307,6 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
     pass_dict=construct_pattern_matcher_pass("normalization_pass"),
 )
 def normalize_cat_default(match: Match, *args, **kwargs):
-    from torch.fx.experimental.symbolic_shapes import guard_or_false
-
     cat_node = match.nodes[0]
     graph = match.graph
     tensors = get_arg_value(cat_node, 0, "tensors")

From 55997bb0aac5c7962688b5d1be47c39e8c98cb8d Mon Sep 17 00:00:00 2001
From: Laith Sakka
Date: Sat, 9 Aug 2025 10:37:00 -0700
Subject: [PATCH 3/4] Update

[ghstack-poisoned]
---
 torch/_refs/__init__.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index c5e175e0d7df0..989fac45e1c45 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2780,9 +2780,7 @@ def cat_compute_output_memory_format(inputs):
 
     utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     # This is a bit tricky. Naively, you would expect to just pick one
     # arbitrary tensor and check that all tensors match this tensor. However,

From ed37641139920436a395b02aa954e0b850d98453 Mon Sep 17 00:00:00 2001
From: Laith Sakka
Date: Sat, 9 Aug 2025 10:43:08 -0700
Subject: [PATCH 4/4] Update

[ghstack-poisoned]
---
 torch/_refs/__init__.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 989fac45e1c45..c5e175e0d7df0 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2780,7 +2780,9 @@ def cat_compute_output_memory_format(inputs):
 
     utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
 
-    from torch.fx.experimental.symbolic_shapes import guard_or_false
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+    )
 
     # This is a bit tricky. Naively, you would expect to just pick one
     # arbitrary tensor and check that all tensors match this tensor. However,
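
Note (not part of the patch series above): a minimal Python sketch of the guard semantics the series relies on, assuming a PyTorch build that exposes torch.fx.experimental.symbolic_shapes.guard_or_false. The helper name is_skippable_empty_1d is hypothetical; it mirrors the "ignore empty 1-D tensors" special case that the patched cat code paths test for.

import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false


def is_skippable_empty_1d(x: torch.Tensor) -> bool:
    # guard_or_false(expr) evaluates expr when it can be decided from known
    # facts; if deciding it would require guarding on a data-dependent
    # (unbacked) size, it returns False instead of raising. An unbacked 1-D
    # size is therefore treated as non-empty, the conservative answer for the
    # legacy "skip empty 1-D tensors" behavior of torch.cat.
    return x.dim() == 1 and guard_or_false(x.shape[0] == 0)


# Eager usage: concrete shapes are decided directly.
assert is_skippable_empty_1d(torch.empty(0))
assert not is_skippable_empty_1d(torch.ones(3))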