Skip to content

Commit b4026d3

Browse files
committed
gso cat ops
ghstack-source-id: 667e71e Pull-Request: #160250
1 parent aeb5321 commit b4026d3

File tree

5 files changed

+9
-10
lines changed

5 files changed

+9
-10
lines changed

aten/src/ATen/WrapDimUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ inline int64_t legacy_cat_wrap_dim_symint(
121121
const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
122122
for (auto& sizes : tensor_sizes) {
123123
if (sizes.size() == 1) {
124-
if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[0].sym_eq(0))) {
124+
if (TORCH_GUARD_OR_FALSE(sizes[0].sym_eq(0))) {
125125
continue;
126126
}
127127
}
@@ -135,7 +135,7 @@ inline int64_t legacy_cat_wrap_dim(
135135
const MaterializedITensorListRef& tensors) {
136136
for (const Tensor& tensor : tensors) {
137137
if (tensor.dim() == 1) {
138-
if (TORCH_GUARD_SIZE_OBLIVIOUS(tensor.sym_sizes()[0].sym_eq(0))) {
138+
if (TORCH_GUARD_OR_FALSE(tensor.sym_sizes()[0].sym_eq(0))) {
139139
continue;
140140
}
141141
}

torch/_inductor/decomposition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,10 +404,10 @@ def non_empty_tensor(x: torch.Tensor) -> bool:
404404
# runtime assert forcing u0 to be zero. So if this hasn't happened,
405405
# we know that the unbacked SymInt has appropriate size and there are
406406
# no problems.
407-
if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0):
407+
if len(x.shape) == 1 and guard_or_false(x.shape[0] == 0):
408408
return False
409409

410-
if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0):
410+
if dim < len(x.shape) and guard_or_false(x.shape[dim] == 0):
411411
return False
412412

413413
return True

torch/_inductor/fx_passes/split_cat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111
from torch._dynamo.utils import counters
12-
from torch.fx.experimental.symbolic_shapes import free_symbols
12+
from torch.fx.experimental.symbolic_shapes import free_symbols, guard_or_false
1313
from torch.utils._ordered_set import OrderedSet
1414

1515
from ..pattern_matcher import (
@@ -307,7 +307,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
307307
pass_dict=construct_pattern_matcher_pass("normalization_pass"),
308308
)
309309
def normalize_cat_default(match: Match, *args, **kwargs):
310-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
310+
from torch.fx.experimental.symbolic_shapes import guard_or_false
311311

312312
cat_node = match.nodes[0]
313313
graph = match.graph
@@ -333,7 +333,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
333333
def is_empty_tensor(x):
334334
# special case where torch.cat supports cat'ing with an empty tensor
335335
x_shape = x.meta["example_value"].shape
336-
return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0)
336+
return len(x_shape) == 1 and guard_or_false(x_shape[0] == 0)
337337

338338
assert all(
339339
ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors

torch/_refs/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2782,7 +2782,6 @@ def cat_compute_output_memory_format(inputs):
27822782

27832783
from torch.fx.experimental.symbolic_shapes import (
27842784
guard_or_false,
2785-
guard_size_oblivious,
27862785
)
27872786

27882787
# This is a bit tricky. Naively, you would expect to just pick one
@@ -2837,7 +2836,7 @@ def cat_compute_output_memory_format(inputs):
28372836
# through), and is load bearing for our Inductor lowerings
28382837
# (which assume that size oblivious tests are OK to determine
28392838
# if a shape is permissibly zero.)
2840-
guard_size_oblivious(tensor.shape[0] == 0),
2839+
guard_or_false(tensor.shape[0] == 0),
28412840
lambda: f"Number of dimensions of tensors must match. "
28422841
f"Expected {example.ndim}-D tensors, but got 1-D for "
28432842
f"tensor number {tensor_idx} in the list",

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,7 @@ std::vector<Tensor> cat_tensors_backward(
10781078
auto& shape = sizes[i];
10791079
// If input was empty tensor, gradInput should be empty tensor.
10801080
if (shape.size() == 1) {
1081-
if (TORCH_GUARD_SIZE_OBLIVIOUS(shape[0].sym_eq(0))) {
1081+
if (TORCH_GUARD_OR_FALSE(shape[0].sym_eq(0))) {
10821082
grad_inputs[i] = at::zeros({0}, grad_val.options());
10831083
continue;
10841084
}

0 commit comments

Comments
 (0)