Commit 5b95cf9

gso cat ops
ghstack-source-id: c29de49
Pull-Request: #160250
1 parent: aeb5321

5 files changed (+10, -19 lines)


aten/src/ATen/WrapDimUtils.h

Lines changed: 2 additions & 2 deletions
@@ -121,7 +121,7 @@ inline int64_t legacy_cat_wrap_dim_symint(
     const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
   for (auto& sizes : tensor_sizes) {
     if (sizes.size() == 1) {
-      if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[0].sym_eq(0))) {
+      if (TORCH_GUARD_OR_FALSE(sizes[0].sym_eq(0))) {
         continue;
       }
     }
@@ -135,7 +135,7 @@ inline int64_t legacy_cat_wrap_dim(
     const MaterializedITensorListRef& tensors) {
   for (const Tensor& tensor : tensors) {
     if (tensor.dim() == 1) {
-      if (TORCH_GUARD_SIZE_OBLIVIOUS(tensor.sym_sizes()[0].sym_eq(0))) {
+      if (TORCH_GUARD_OR_FALSE(tensor.sym_sizes()[0].sym_eq(0))) {
         continue;
       }
     }
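
A quick eager-mode illustration of what these wrap-dim helpers are skipping (a sketch, not part of the commit): torch.cat treats a 1-D zero-element tensor as a legacy "empty" input and ignores it when resolving the concatenation dimension, even when the other inputs have a different rank.

import torch

# The empty 1-D tensor is skipped by the legacy cat wrap-dim logic,
# so dim=1 is resolved against the 2-D inputs only.
a = torch.randn(2, 3)
empty = torch.empty(0)                 # 1-D, zero elements
out = torch.cat([a, empty, a], dim=1)
print(out.shape)                       # torch.Size([2, 6])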

torch/_inductor/decomposition.py

Lines changed: 3 additions & 7 deletions
@@ -34,11 +34,7 @@
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     type_to_dtype,
 )
-from torch.fx.experimental.symbolic_shapes import (
-    guard_or_false,
-    guard_size_oblivious,
-    statically_known_true,
-)
+from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true
 
 from . import config, inductor_prims
 from .utils import (
@@ -404,10 +400,10 @@ def non_empty_tensor(x: torch.Tensor) -> bool:
         # runtime assert forcing u0 to be zero. So if this hasn't happened,
         # we know that the unbacked SymInt has appropriate size and there are
         # no problems.
-        if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0):
+        if len(x.shape) == 1 and guard_or_false(x.shape[0] == 0):
             return False
 
-        if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0):
+        if dim < len(x.shape) and guard_or_false(x.shape[dim] == 0):
             return False
 
         return True
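
The non_empty_tensor filter above now uses guard_or_false, which passes plain booleans through and, for symbolic shape expressions, falls back to False when the answer cannot be determined instead of raising a data-dependent error. A minimal standalone sketch of the same filter (in the actual decomposition, dim is a closure variable rather than a parameter):

import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false

def non_empty_tensor(x: torch.Tensor, dim: int) -> bool:
    # Drop legacy empty 1-D inputs from the cat decomposition...
    if len(x.shape) == 1 and guard_or_false(x.shape[0] == 0):
        return False
    # ...and inputs that are empty along the concatenation dimension.
    if dim < len(x.shape) and guard_or_false(x.shape[dim] == 0):
        return False
    return True

print(non_empty_tensor(torch.empty(0), dim=0))     # False
print(non_empty_tensor(torch.randn(2, 3), dim=1))  # True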

torch/_inductor/fx_passes/split_cat.py

Lines changed: 2 additions & 4 deletions
@@ -9,7 +9,7 @@
 
 import torch
 from torch._dynamo.utils import counters
-from torch.fx.experimental.symbolic_shapes import free_symbols
+from torch.fx.experimental.symbolic_shapes import free_symbols, guard_or_false
 from torch.utils._ordered_set import OrderedSet
 
 from ..pattern_matcher import (
@@ -307,8 +307,6 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
     pass_dict=construct_pattern_matcher_pass("normalization_pass"),
 )
 def normalize_cat_default(match: Match, *args, **kwargs):
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
     cat_node = match.nodes[0]
     graph = match.graph
     tensors = get_arg_value(cat_node, 0, "tensors")
@@ -333,7 +331,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
     def is_empty_tensor(x):
         # special case where torch.cat supports cat'ing with an empty tensor
         x_shape = x.meta["example_value"].shape
-        return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0)
+        return len(x_shape) == 1 and guard_or_false(x_shape[0] == 0)
 
     assert all(
         ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
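
The is_empty_tensor special case above only covers 1-D zero-element inputs; an empty tensor of any other rank still has to match the other inputs' dimensionality. A small eager-mode check of that behavior (illustrative only, not from the commit):

import torch

a = torch.randn(2, 3)
print(torch.cat([a, torch.empty(0)]).shape)  # 1-D empty input is tolerated -> torch.Size([2, 3])
try:
    torch.cat([a, torch.empty(0, 3, 4)])     # empty but 3-D: rank mismatch still errors
except RuntimeError as e:
    print(e)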

torch/_refs/__init__.py

Lines changed: 2 additions & 5 deletions
@@ -2780,10 +2780,7 @@ def cat_compute_output_memory_format(inputs):
 
     utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_size_oblivious,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     # This is a bit tricky. Naively, you would expect to just pick one
     # arbitrary tensor and check that all tensors match this tensor. However,
@@ -2837,7 +2834,7 @@ def cat_compute_output_memory_format(inputs):
             # through), and is load bearing for our Inductor lowerings
             # (which assume that size oblivious tests are OK to determine
             # if a shape is permissibly zero.)
-            guard_size_oblivious(tensor.shape[0] == 0),
+            guard_or_false(tensor.shape[0] == 0),
             lambda: f"Number of dimensions of tensors must match. "
             f"Expected {example.ndim}-D tensors, but got 1-D for "
             f"tensor number {tensor_idx} in the list",

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 1 addition & 1 deletion
@@ -1078,7 +1078,7 @@ std::vector<Tensor> cat_tensors_backward(
     auto& shape = sizes[i];
     // If input was empty tensor, gradInput should be empty tensor.
     if (shape.size() == 1) {
-      if (TORCH_GUARD_SIZE_OBLIVIOUS(shape[0].sym_eq(0))) {
+      if (TORCH_GUARD_OR_FALSE(shape[0].sym_eq(0))) {
         grad_inputs[i] = at::zeros({0}, grad_val.options());
         continue;
       }
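
The branch above gives a legacy empty 1-D input an empty gradient in cat's backward pass, which is observable from Python through autograd. A small sketch (assumed example, not part of the commit):

import torch

a = torch.randn(2, 3, requires_grad=True)
empty = torch.empty(0, requires_grad=True)
torch.cat([a, empty]).sum().backward()
print(a.grad.shape)      # torch.Size([2, 3])
print(empty.grad.shape)  # torch.Size([0]) -- the empty input gets an empty grad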
