
Commit 9e500e3

unify broadcast_shapes functions and avoid duplicates

ghstack-source-id: 797e040
Pull-Request: #160251
Parent: 3dff4ef
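For context, `torch.broadcast_shapes` computes the result shape of broadcasting its arguments together without allocating any tensors. A quick illustration of the standard broadcasting semantics both of the unified implementations must agree on (an example for orientation, not code from this commit):

```python
import torch

# Dimensions align from the right; missing leading dims count as 1,
# and a size-1 dim stretches to match the other operands.
print(torch.broadcast_shapes((2,), (3, 1), (1, 1, 1)))  # torch.Size([1, 3, 2])
```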

2 files changed: +22 −61 lines changed

torch/_refs/__init__.py

Lines changed: 19 additions & 9 deletions
```diff
@@ -385,7 +385,11 @@ def handle_noncontiguous_outputs(input_tlist, output):
 
 
 def _broadcast_shapes(*_shapes):
-    from torch.fx.experimental.symbolic_shapes import guard_or_false
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_size_oblivious,
+        is_nested_int,
+    )
 
     shapes = tuple(
         (x,) if isinstance(x, IntLike) else x
@@ -407,16 +411,24 @@ def _broadcast_shapes(*_shapes):
     ] * reduce(max, (len(shape) for shape in shapes))
     for arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
-            # if both 1, or statically known the same, we rather pick non-broadcast path.
-            if guard_or_false(common_shape[idx] == shape[idx]):
-                continue
-            elif guard_or_false(common_shape[idx] == 1):
+            if is_nested_int(shape[idx]):
+                # maintain nested int behaviour added in PR 145957.
+                if is_nested_int(common_shape[idx]) and guard_or_false(
+                    shape[idx] == common_shape[idx]
+                ):
+                    continue
+            else:
+                if guard_or_false(shape[idx] == common_shape[idx]):
+                    continue
+
+            if guard_or_false(common_shape[idx] == 1):
                 if shape[idx] < 0:
                     raise ValueError(
                         "Attempting to broadcast a dimension with negative length!"
                     )
                 common_shape[idx] = shape[idx]
-            elif guard_or_false(shape[idx] == 1):
+
+            if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1):
                 # broadcast case .
                 continue
             else:
@@ -2780,9 +2792,7 @@ def cat_compute_output_memory_format(inputs):
 
     utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     # This is a bit tricky. Naively, you would expect to just pick one
     # arbitrary tensor and check that all tensors match this tensor. However,
```
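To make the rewritten resolution loop concrete, here is a minimal sketch using plain Python ints: `guard_or_false(expr)` is modeled as an ordinary boolean test (with concrete sizes the guard resolves statically), and the nested-int branch is omitted. An illustration of the control flow, not the actual implementation:

```python
from functools import reduce

def broadcast_shapes_sketch(*shapes):
    # Resolve against a result initialized to all 1s, aligning from the right.
    common = [1] * reduce(max, (len(s) for s in shapes))
    for shape in shapes:
        for idx in range(-1, -1 - len(shape), -1):
            if shape[idx] == common[idx]:
                continue  # sizes already agree: non-broadcast path
            if common[idx] == 1:
                if shape[idx] < 0:
                    raise ValueError(
                        "Attempting to broadcast a dimension with negative length!"
                    )
                common[idx] = shape[idx]  # stretch the size-1 slot
            elif shape[idx] != 1:
                # neither side is 1 and they differ: not broadcastable
                raise RuntimeError(
                    "Shape mismatch: objects cannot be broadcast to a single shape"
                )
    return tuple(common)

print(broadcast_shapes_sketch((3, 1), (1, 4), (4,)))  # (3, 4)
```

The real `_broadcast_shapes` also short-circuits to `None` when given no shapes, which matters for the `torch/functional.py` change below.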

torch/functional.py

Lines changed: 3 additions & 52 deletions
```diff
@@ -105,58 +105,9 @@ def broadcast_shapes(*shapes):
     # This wrapper exists to support variadic args.
     # TODO Move this to C++ once the jit has better support for torch.Size.
     if not torch.jit.is_tracing():
-        max_len = 0
-        for shape in shapes:
-            if isinstance(shape, (int, torch.SymInt)):
-                if max_len < 1:
-                    max_len = 1
-            elif isinstance(shape, (tuple, list)):
-                s = len(shape)
-                if max_len < s:
-                    max_len = s
-        result = [1] * max_len
-
-        from torch.fx.experimental.symbolic_shapes import (
-            guard_size_oblivious,
-            is_nested_int,
-        )
-
-        for shape in shapes:
-            if isinstance(shape, (int, torch.SymInt)):
-                shape = (shape,)
-            if isinstance(shape, (tuple, list)):
-                for i in range(-1, -1 - len(shape), -1):
-                    if shape[i] < 0:
-                        raise RuntimeError(
-                            f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
-                        )
-
-                    # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
-                    if is_nested_int(shape[i]):
-                        # Broadcasting is allowed for (j0, 1) or (j0, j0);
-                        # not (j0, j1), (j0, 5), etc.
-                        if is_nested_int(result[i]) and guard_size_oblivious(
-                            shape[i] == result[i]
-                        ):
-                            continue
-                    else:
-                        # NB: result is initialized to 1 so this is effectively an
-                        # equals one test
-                        if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
-                            shape[i] == result[i]
-                        ):
-                            continue
-
-                    if result[i] != 1:
-                        raise RuntimeError(
-                            "Shape mismatch: objects cannot be broadcast to a single shape"
-                        )
-                    result[i] = shape[i]
-            else:
-                raise RuntimeError(
-                    "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
-                    shape,
-                )
+        result = torch._refs._broadcast_shapes(*shapes)
+        if result is None:
+            return torch.Size([])
         return torch.Size(result)
     else:
         # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
```
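A quick check of the delegated path, based on the diff above (the `result is None` branch covers `_broadcast_shapes`'s short-circuit when it receives no shapes):

```python
import torch

# The wrapper now delegates to torch._refs._broadcast_shapes; observable
# behavior is intended to be unchanged.
assert torch.broadcast_shapes((5, 1, 3), (4, 1)) == torch.Size([5, 4, 3])

# With no inputs, _broadcast_shapes returns None and the wrapper maps it
# to an empty torch.Size.
assert torch.broadcast_shapes() == torch.Size([])
```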
