
Commit 3cb34d0

[inductor] propagate shapes in CSEVariable
ghstack-source-id: 5f115d1
Pull Request resolved: #152198
1 parent 556e2a7 commit 3cb34d0
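
The commit threads a per-variable block shape alongside the existing dtype through inductor's common-subexpression-elimination (CSE) machinery, so each CSEVariable can carry the shape inferred for the block of values it names. Below is a minimal, self-contained sketch of that idea; the class and names are illustrative simplifications, not the actual torch._inductor types.

from typing import Optional

# Hypothetical simplified analogue of a CSE variable that records a block
# shape next to its dtype (None means "shape unknown"), mirroring the new
# `shape` field added to CSEVariable in this commit.
BlockShape = Optional[tuple]

class SketchCSEVariable:
    def __init__(self, name: str, dtype: Optional[str] = None,
                 shape: BlockShape = None) -> None:
        self.name = name
        self.dtype = dtype
        self.shape = shape

    def __repr__(self) -> str:
        return f"{self.name}(dtype={self.dtype}, shape={self.shape})"

# An elementwise result can simply inherit the shape of its input block.
tmp0 = SketchCSEVariable("tmp0", dtype="float32", shape=("XBLOCK",))
tmp1 = SketchCSEVariable("tmp1", dtype="float32", shape=tmp0.shape)
print(tmp0, tmp1)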

9 files changed: +476 -113 lines

torch/_inductor/codegen/common.py

Lines changed: 43 additions & 7 deletions
@@ -44,6 +44,7 @@
 from .. import config, metrics
 from ..dtype_propagation import DtypePropagationOpsHandler
 from ..ops_handler import BasicMathOpsMixin, DefaultHandler
+from ..shape_propagation import ShapePropagationOpsHandler
 from ..utils import (
     boolean_ops,
     DeferredLineBase,
@@ -70,6 +71,7 @@
     from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
     from ..loop_body import LoopBody
     from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
+    from ..shape_propagation import BlockShapeType
     from .wrapper import PythonWrapperCodegen

 _T = TypeVar("_T")
@@ -1770,13 +1772,15 @@ def __init__(
         name: str,
         bounds: ValueRanges[Any],
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ):
         super().__init__()
         assert isinstance(bounds, ValueRanges), type(bounds)
         self.name = name
         self.bounds = bounds
         self.use_count = 1  # track how many times this expression is used
         self.dtype = dtype
+        self.shape = shape

     def __str__(self) -> str:
         return self.name
@@ -1886,6 +1890,7 @@ def generate(
         write: bool = True,
         assignment: bool = True,
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ) -> CSEVariableType:
         if isinstance(expr, OpsValue):
             expr = expr.value
@@ -1906,8 +1911,12 @@ def generate(
         assert isinstance(expr, str)
         cache_key = expr
         var = self.try_get(cache_key)
+        if shape is None and not assignment:
+            # since there's no assignment to a variable, use any shape here
+            # other than None to avoid the unknown shape failures
+            shape = ()
         if not var:
-            var = self.newvar(bounds, dtype)
+            var = self.newvar(bounds, dtype, shape)
             self.put(cache_key, var)
         if write:
             if V.kernel.current_node:
@@ -1953,9 +1962,10 @@ def newvar(
         self,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ) -> CSEVariableType:
         var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
-        var = V.kernel.create_cse_var(var_name, bounds, dtype)
+        var = V.kernel.create_cse_var(var_name, bounds, dtype, shape)
         self.varname_map[var_name] = var
         return var

@@ -1964,11 +1974,12 @@ def namedvar(
         name: str,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ) -> CSEVariableType:
         torch._check_value(
             name not in self.varname_map, lambda: f"duplicate name: {name}"
         )
-        var = V.kernel.create_cse_var(name, bounds, dtype)
+        var = V.kernel.create_cse_var(name, bounds, dtype, shape)
         self.varname_map[name] = var
         return var
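
The generate/newvar/namedvar changes above thread the new shape argument into the CSE cache: a cached expression keeps whatever shape it was first created with, and expressions that are never assigned to a variable get a placeholder shape of () so later consumers do not trip over an unknown shape. A rough sketch of that behaviour, with made-up names and none of the real kernel plumbing:

from typing import Dict, Optional

class Var:
    def __init__(self, name: str, dtype, shape) -> None:
        self.name, self.dtype, self.shape = name, dtype, shape

class TinyCSE:
    # Hypothetical stand-in for the real CSE cache: expressions are keyed by
    # their source text and a cache hit reuses the variable (and its shape).
    def __init__(self) -> None:
        self.cache: Dict[str, Var] = {}
        self.counter = 0

    def generate(self, expr: str, dtype=None, shape=None, assignment=True) -> Var:
        if shape is None and not assignment:
            # no variable is assigned, so use a placeholder shape rather than
            # None to avoid "unknown shape" failures downstream
            shape = ()
        var = self.cache.get(expr)
        if var is None:
            self.counter += 1
            var = Var(f"tmp{self.counter}", dtype, shape)
            self.cache[expr] = var
        return var

cse = TinyCSE()
a = cse.generate("tmp_in + 1", dtype="float32", shape=("XBLOCK",))
b = cse.generate("tmp_in + 1")  # cache hit: same variable, same shape
assert a is b and b.shape == ("XBLOCK",)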

@@ -2424,45 +2435,64 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) ->

         value = getattr(self.parent_handler, name)(*args, **kwargs)
         dtype_handler = DtypePropagationOpsHandler()
+        shape_handler = ShapePropagationOpsHandler()

         backend = get_current_backend()

+        shape_op = getattr(shape_handler, name)
         output_dtype = None
+        output_shape = None
+
         if name == "masked" and backend == "triton":
             output_dtype = value.dtype
+            output_shape = value.shape
         elif name == "masked" and backend == "cpp":
             output_dtype = V.interpreter.current_node.meta.get(
                 OptimizationContext.key, None
             ).dtype
+            # TODO: fix me
+            output_shape = None
         elif backend in ("triton", "cpp", "mps"):
             dtype_op = getattr(dtype_handler, name)
             output_dtype = dtype_op(*args, **kwargs)
+            output_shape = shape_op(*args, **kwargs)

         if backend in ("triton", "cpp"):
             # maybe there are some exceptions on mps?
             assert output_dtype is not None

         output_idx = 0

-        def do_cse(v: str) -> CSEVariable:
+        def do_cse(v: Union[str, CSEVariable]) -> CSEVariable:
             # we tree_map over the output, so we need to fetch corresponding dtype
             nonlocal output_idx
             var_dtype: Optional[torch.dtype] = (
                 output_dtype[output_idx]
                 if isinstance(output_dtype, (list, tuple))
                 else output_dtype
             )
+            var_shape: BlockShapeType = (
+                output_shape[output_idx]  # type: ignore[assignment]
+                if isinstance(output_shape, (list, tuple))
+                and len(output_shape) > 0
+                and isinstance(output_shape[0], (list, tuple))
+                else output_shape
+            )
             output_idx += 1

             # some cpp op implementations don't set the dtype
-            if backend == "cpp" and isinstance(v, CSEVariable) and v.dtype is None:
-                v.dtype = var_dtype
+            if isinstance(v, CSEVariable):
+                if backend == "cpp" and v.dtype is None:
+                    v.dtype = var_dtype
+                if v.shape is None:
+                    v.shape = var_shape

             csevar = V.kernel.cse.generate(
                 V.kernel.compute,
                 v,
                 bounds=bounds,
                 dtype=output_dtype,
+                shape=output_shape,
             )

             csevar.update_on_args(name, args, kwargs)
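
In the _default handler above, output_shape is produced by ShapePropagationOpsHandler in the same way output_dtype comes from DtypePropagationOpsHandler, and do_cse attaches the per-output shape to each resulting variable. The commit does not show the handler's rules; for an elementwise op one plausible rule is right-aligned broadcasting over the operand block shapes, with None standing for an unknown shape. The helper below only illustrates that assumption and is not the real handler.

from typing import Optional, Tuple

BlockShape = Optional[Tuple[object, ...]]  # None means "shape unknown"

def broadcast_block_shapes(*shapes: BlockShape) -> BlockShape:
    # Illustrative broadcast rule: unknown inputs make the output unknown;
    # otherwise align dimensions from the right, treating 1 as expandable.
    if any(s is None for s in shapes):
        return None
    rank = max(len(s) for s in shapes)
    out = []
    for i in range(1, rank + 1):
        dims = [s[-i] for s in shapes if len(s) >= i]
        non_one = [d for d in dims if d != 1]
        out.append(non_one[0] if non_one else 1)
    return tuple(reversed(out))

# e.g. a row block combined with a column block broadcasts to a 2D block,
# while any unknown operand shape keeps the result unknown.
print(broadcast_block_shapes(("XBLOCK", 1), (1, "RBLOCK")))  # ('XBLOCK', 'RBLOCK')
print(broadcast_block_shapes(("XBLOCK",), None))             # None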
@@ -2559,7 +2589,13 @@ def indirect_indexing(
             pos = var.bounds & ValueRanges(0, int_oo)
             new_bounds = new_bounds | pos

-        var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds)
+        var = self.kernel.cse.generate(
+            self.kernel.compute,
+            stm,
+            bounds=new_bounds,
+            dtype=var.dtype,
+            shape=var.shape,
+        )

         sympy_var = self.parent_handler.indirect_indexing(var, size, check)
         if generate_assert(check):

torch/_inductor/codegen/cpp.py

Lines changed: 2 additions & 2 deletions
@@ -933,8 +933,8 @@ def frexp(x):
            return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys)

        code = BracesBuffer()
-       exponent = V.kernel.cse.newvar(dtype=torch.int32)
-       mantissa = V.kernel.cse.newvar(dtype=x.dtype)
+       exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape)
+       mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape)
        code.writeline(f"int32_t {exponent};")
        code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});")
        V.kernel.compute.splice(code)
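
frexp is a two-output op: both outputs cover the same block of elements as the input, which is why both CSE variables get shape=x.shape while their dtypes differ (an int32 exponent versus the input's float dtype). A plain-Python analogue of that relationship, unrelated to the inductor code itself:

import math

# Pretend block of three elements, "shape (3,)".
xs = [0.5, 3.0, 10.0]
mantissa = [math.frexp(v)[0] for v in xs]  # float values, same length as xs
exponent = [math.frexp(v)[1] for v in xs]  # int values, same length as xs

# Both outputs keep the input's shape; only the element types differ.
assert len(mantissa) == len(exponent) == len(xs)
print(mantissa, exponent)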

torch/_inductor/codegen/cpp_utils.py

Lines changed: 3 additions & 1 deletion
@@ -22,6 +22,7 @@
 from ..dependencies import Dep
 from ..loop_body import LoopBody
 from ..scheduler import BaseSchedulerNode, SchedulerBuffer
+from ..shape_propagation import BlockShapeType
 from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
 from ..virtualized import ops, OpsValue, V
 from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext
@@ -145,8 +146,9 @@ def __init__(
         name,
         bounds: ValueRanges[Any],
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ) -> None:
-        super().__init__(name, bounds, dtype)
+        super().__init__(name, bounds, dtype, shape=shape)
         self.is_vec = False
         self.dependent_itervars = OrderedSet[sympy.Symbol]()

torch/_inductor/codegen/halide.py

Lines changed: 20 additions & 9 deletions
@@ -54,6 +54,7 @@
     from collections.abc import Sequence

     from ..ops_handler import ReductionType, StoreMode
+    from ..shape_propagation import BlockShapeType

 log = logging.getLogger(__name__)

@@ -556,6 +557,7 @@ def masked(mask, body, other):
                 f"hl.cast({result.name}.type(), {halide_constant(other)})",
                 [],
                 bounds=ValueRanges.wrap(other),
+                shape=result.shape,
             )
         # TODO(jansel): look into removing the where in the same places triton does
         return ops.where(new_mask, result, other)
@@ -576,8 +578,9 @@ def __init__(
         name,
         bounds: ValueRanges[Any],
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ) -> None:
-        super().__init__(name, bounds, dtype)
+        super().__init__(name, bounds, dtype, shape=shape)
         self.used_dims: Optional[list[sympy.Symbol]] = None

     def update_on_args(self, name, args, kwargs):
@@ -702,9 +705,9 @@ def __init__(
     def dtype_to_str(self, dtype: torch.dtype) -> str:
         return halide_type(dtype)

-    def create_cse_var(self, name, bounds=None, dtype=None):
+    def create_cse_var(self, name, bounds=None, dtype=None, shape=None):
         self.body.writeline(f"{name} = hl.Func({name!r})")
-        return HalideCSEVariable(name, bounds, dtype)
+        return HalideCSEVariable(name, bounds, dtype, shape)

     def finalize_indexing(self, indices: Sequence[sympy.Expr]):
         """
@@ -1196,12 +1199,13 @@ def reduction(
         assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
         reduction_vars = OrderedSet(self.reduction_renames)
         result_var = self.newfunc(
-            [v for v in value.used_dims if v not in reduction_vars]
+            [v for v in value.used_dims if v not in reduction_vars],
         )
         if reduction_vars - OrderedSet(value.used_dims):
             value = self.genfunc(
                 f"{value}",
                 self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))),
+                shape=value.shape,
             )
         value_str = value.subs_str(self.reduction_renames)
         default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
@@ -1291,7 +1295,9 @@ def scan(
            else:
                values.append(
                    self.genfunc(
-                       f"{value}", [*value.used_dims, [*self.reduction_renames][:1]]
+                       f"{value}",
+                       [*value.used_dims, [*self.reduction_renames][:1]],
+                       shape=value.shape,
                    )
                )
            all_used_dims.update(value.used_dims)
@@ -1355,15 +1361,20 @@ def maybe_tuple(x):
         return tuple(unpack_vars)

     def genfunc(
-        self, line, used_dims, *, bounds=ValueRanges.unknown()
+        self,
+        line,
+        used_dims,
+        *,
+        bounds=ValueRanges.unknown(),
+        shape: BlockShapeType = None,
     ) -> HalideCSEVariable:
-        var = self.cse.generate(self.body, line, bounds=bounds)
+        var = self.cse.generate(self.body, line, bounds=bounds, shape=shape)
         assert isinstance(var, HalideCSEVariable)
         var.used_dims = used_dims
         return var

-    def newfunc(self, used_dims) -> HalideCSEVariable:
-        var = self.cse.newvar()
+    def newfunc(self, used_dims, *, shape: BlockShapeType = None) -> HalideCSEVariable:
+        var = self.cse.newvar(shape=shape)
         assert isinstance(var, HalideCSEVariable)
         var.used_dims = used_dims
         return var