From ed9dacca8126dc21def981b021e0b559e90b9e5d Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 25 Apr 2025 17:11:04 +0000 Subject: [PATCH 01/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/common.py | 24 +++- torch/_inductor/codegen/cpp.py | 4 +- torch/_inductor/codegen/halide.py | 25 ++-- torch/_inductor/codegen/triton.py | 124 +++++++++++++++---- torch/_inductor/codegen/triton_split_scan.py | 8 ++ torch/_inductor/select_algorithm.py | 3 +- torch/_inductor/shape_propagation.py | 108 ++++++++++++++++ 7 files changed, 258 insertions(+), 38 deletions(-) create mode 100644 torch/_inductor/shape_propagation.py diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index b85cfa778a8c..b889579823c8 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -74,6 +74,8 @@ # causes typing errors in subclasses (defined in other files). OpVarT = str + ShapeType = Optional[Sequence[Union[int, str]]] + schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") log = logging.getLogger(__name__) @@ -1645,6 +1647,7 @@ def __init__( name: str, bounds: ValueRanges[Any], dtype: Optional[torch.dtype] = None, + shape: Optional[ShapeType] = None, ): super().__init__() assert isinstance(bounds, ValueRanges) @@ -1652,6 +1655,7 @@ def __init__( self.bounds = bounds self.use_count = 1 # track how many times this expression is used self.dtype = dtype + self.shape = shape def __str__(self) -> str: return self.name @@ -1761,6 +1765,7 @@ def generate( write: bool = True, assignment: bool = True, dtype: Optional[torch.dtype] = None, + shape: Optional[ShapeType] = None, ) -> CSEVariableType: if isinstance(expr, OpsValue): expr = expr.value @@ -1782,7 +1787,7 @@ def generate( cache_key = expr var = self.try_get(cache_key) if not var: - var = self.newvar(bounds, dtype) + var = self.newvar(bounds, dtype, shape) self.put(cache_key, var) if write: if V.kernel.current_node: @@ -1828,9 +1833,10 @@ def newvar( self, bounds: ValueRanges[Any] = ValueRanges.unknown(), dtype: Optional[torch.dtype] = None, + shape: Optional[ShapeType] = None, ) -> CSEVariableType: var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" - var = V.kernel.create_cse_var(var_name, bounds, dtype) + var = V.kernel.create_cse_var(var_name, bounds, dtype, shape) self.varname_map[var_name] = var return var @@ -1839,11 +1845,12 @@ def namedvar( name: str, bounds: ValueRanges[Any] = ValueRanges.unknown(), dtype: Optional[torch.dtype] = None, + shape: Optional[ShapeType] = None, ) -> CSEVariableType: torch._check_value( name not in self.varname_map, lambda: f"duplicate name: {name}" ) - var = V.kernel.create_cse_var(name, bounds, dtype) + var = V.kernel.create_cse_var(name, bounds, dtype, shape) self.varname_map[name] = var return var @@ -2319,7 +2326,7 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> output_idx = 0 - def do_cse(v: str) -> CSEVariable: + def do_cse(v: Union[str, CSEVariable]) -> CSEVariable: # we tree_map over the output, so we need to fetch corresponding dtype nonlocal output_idx var_dtype: torch.dtype = ( @@ -2338,6 +2345,7 @@ def do_cse(v: str) -> CSEVariable: v, bounds=bounds, dtype=output_dtype, + shape=getattr(v, "shape", None), ) csevar.update_on_args(name, args, kwargs) @@ -2427,7 +2435,13 @@ def indirect_indexing( pos = var.bounds & ValueRanges(0, int_oo) new_bounds = new_bounds | pos - var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds) + var = self.kernel.cse.generate( + 
self.kernel.compute, + stm, + bounds=new_bounds, + dtype=var.dtype, + shape=var.shape, + ) sympy_var = self.parent_handler.indirect_indexing(var, size, check) if generate_assert(check): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index f966e9cf8dd9..cd1d1db2f04f 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -902,8 +902,8 @@ def frexp(x): return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys) code = BracesBuffer() - exponent = V.kernel.cse.newvar(dtype=torch.int32) - mantissa = V.kernel.cse.newvar(dtype=x.dtype) + exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape) + mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape) code.writeline(f"int32_t {exponent};") code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") V.kernel.compute.splice(code) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 1339c99aa479..1a4a2196dd01 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -42,6 +42,7 @@ KernelArgType, OpOverrides, PythonPrinter, + ShapeType, SizeArg, TensorArg, ) @@ -556,6 +557,7 @@ def masked(mask, body, other): f"hl.cast({result.name}.type(), {halide_constant(other)})", [], bounds=ValueRanges.wrap(other), + shape=result.shape, ) # TODO(jansel): look into removing the where in the same places triton does return ops.where(new_mask, result, other) @@ -576,8 +578,9 @@ def __init__( name, bounds: ValueRanges[Any], dtype: Optional[torch.dtype] = None, + shape: Optional[ShapeType] = None, ) -> None: - super().__init__(name, bounds, dtype) + super().__init__(name, bounds, dtype, shape=shape) self.used_dims: Optional[list[sympy.Symbol]] = None def update_on_args(self, name, args, kwargs): @@ -1196,12 +1199,13 @@ def reduction( assert isinstance(value, HalideCSEVariable) and value.used_dims is not None reduction_vars = OrderedSet(self.reduction_renames) result_var = self.newfunc( - [v for v in value.used_dims if v not in reduction_vars] + [v for v in value.used_dims if v not in reduction_vars], ) if reduction_vars - OrderedSet(value.used_dims): value = self.genfunc( f"{value}", self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))), + shape=value.shape, ) value_str = value.subs_str(self.reduction_renames) default = ir.Reduction.default_accumulator(reduction_type, src_dtype) @@ -1291,7 +1295,9 @@ def scan( else: values.append( self.genfunc( - f"{value}", [*value.used_dims, [*self.reduction_renames][:1]] + f"{value}", + [*value.used_dims, [*self.reduction_renames][:1]], + shape=value.shape, ) ) all_used_dims.update(value.used_dims) @@ -1355,15 +1361,20 @@ def maybe_tuple(x): return tuple(unpack_vars) def genfunc( - self, line, used_dims, *, bounds=ValueRanges.unknown() + self, + line, + used_dims, + *, + bounds=ValueRanges.unknown(), + shape=None, ) -> HalideCSEVariable: - var = self.cse.generate(self.body, line, bounds=bounds) + var = self.cse.generate(self.body, line, bounds=bounds, shape=shape) assert isinstance(var, HalideCSEVariable) var.used_dims = used_dims return var - def newfunc(self, used_dims) -> HalideCSEVariable: - var = self.cse.newvar() + def newfunc(self, used_dims, *, shape=None) -> HalideCSEVariable: + var = self.cse.newvar(shape=shape) assert isinstance(var, HalideCSEVariable) var.used_dims = used_dims return var diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index d9a3fae9220b..351c1fd6933d 100644 --- a/torch/_inductor/codegen/triton.py 
+++ b/torch/_inductor/codegen/triton.py @@ -208,6 +208,7 @@ class IndexingOptions: expand_str: Optional[str] _has_rindex: bool index: sympy.Expr + expand_shape: Optional[Sequence[Union[int, str]]] def has_mask(self) -> bool: return bool(self.mask_vars) @@ -745,8 +746,10 @@ def low_precision_fp_var(var: Union[CSEVariable, Any]) -> bool: class TritonCSEVariable(CSEVariable): - def __init__(self, name, bounds: ValueRanges[Any], dtype: torch.dtype) -> None: - super().__init__(name, bounds, dtype) + def __init__( + self, name, bounds: ValueRanges[Any], dtype: torch.dtype, shape=None + ) -> None: + super().__init__(name, bounds, dtype, shape=shape) # We'll use this to track which masks the variable needs when used for indirect indexing self.mask_vars = OrderedSet[str]() assert dtype is not None, "TritonCSEVariable must have dtype" @@ -1370,6 +1373,7 @@ def index_expr(cls, expr, dtype): indexing.index_str, bounds=get_bounds_index_expr(expr), dtype=dtype, + shape=[], ) finally: config.test_configs.runtime_triton_dtype_assert = orig @@ -1379,6 +1383,7 @@ def index_expr(cls, expr, dtype): V.kernel.compute, cls.to_dtype(var, dtype), dtype=upcast_compute_type(dtype), + shape=var.shape, ) else: # TODO: we are not always consistent in enforcing that the output of the index expr printing @@ -1397,6 +1402,7 @@ def index_expr(cls, expr, dtype): V.kernel.compute, cls.to_dtype(var, index_dtype), dtype=index_dtype, + shape=var.shape, ) var.mask_vars = indexing.mask_vars @@ -1409,6 +1415,7 @@ def masked(mask, body, other): V.kernel.compute, f"{mask}.to(tl.int1)", dtype=torch.bool, + shape=mask.shape, ) nodes = body.graph.find_nodes(op="output") @@ -1439,6 +1446,7 @@ def masked(mask, body, other): f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", bounds=ValueRanges.wrap(other), dtype=result.dtype, + shape=result.shape, ) ret = ops.where(new_mask, result, other) else: @@ -1460,8 +1468,8 @@ def frexp(x): if cse_val := V.kernel.cse.try_get(cache_key): return cse_val - mantissa = V.kernel.cse.newvar(dtype=x.dtype) - exponent = V.kernel.cse.newvar(dtype=torch.int32) + mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape) + exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape) V.kernel.compute.writeline( f"{mantissa}, {exponent} = triton_helpers.frexp({x})" ) @@ -2018,9 +2026,11 @@ def match_block_pointer() -> Optional[BlockPtrOptions]: return options expand_str = None + expand_shape = None index_str = self.index_to_str(index) if isinstance(index, sympy.Integer): expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + expand_shape = None if copy_shape else self.dense_size_list() index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" if self.fixed_config and not self._has_constant_xmask(): mask_vars = OrderedSet(["xmask"]) @@ -2028,10 +2038,18 @@ def match_block_pointer() -> Optional[BlockPtrOptions]: mask_vars = OrderedSet() if self._load_mask: mask_vars.add(self._load_mask) - return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index) + return IndexingOptions( + index_str, + mask_vars, + expand_str, + has_rindex, + index, + expand_shape=expand_shape, + ) if need_dense and not have_dense: expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + expand_shape = None if copy_shape else self.dense_size_list() index_str = f"tl.broadcast_to({index_str}, {expand_str})" mask_vars = dense_mask_vars elif not have_loop_vars and copy_shape: @@ -2046,7 +2064,14 @@ def match_block_pointer() -> Optional[BlockPtrOptions]: 
self.filter_masks(mask_vars) - return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index) + return IndexingOptions( + index_str, + mask_vars, + expand_str, + has_rindex, + index, + expand_shape=expand_shape, + ) def codegen_block_ptr( self, name: str, var: str, indexing: BlockPtrOptions, other="" @@ -2231,6 +2256,7 @@ def decide_later(): cachemod = ", cache_modifier='.cg'" append_broadcast = None + shape = None dtype = V.graph.get_dtype(name) if should_unwrap_unspec_arg(name): @@ -2247,11 +2273,14 @@ def decide_later(): line = indexing.codegen_broadcast_and_reshape( line, indexing.block_shape, indexing.final_shape, True ) + shape = indexing.final_shape elif isinstance(original_index, sympy.Integer): line = f"tl.load({var} + ({original_index}))" append_broadcast = indexing.expand_str + shape = None else: line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})" + shape = indexing.expand_shape if ( dtype in (torch.float16, torch.bfloat16) @@ -2267,7 +2296,9 @@ def decide_later(): dtype = torch.bool load_buffer = self.get_load_buffer(indexing) - result_var = self.cse.generate(load_buffer, make_line(line), dtype=dtype) + result_var = self.cse.generate( + load_buffer, make_line(line), dtype=dtype, shape=shape + ) if result_var.use_count > 1: load_counts[name] -= 1 # don't double count cache hit assert isinstance(result_var, TritonCSEVariable) @@ -2275,7 +2306,9 @@ def decide_later(): if append_broadcast: line = f"tl.broadcast_to({result_var}, {append_broadcast})" - result_var = self.cse.generate(load_buffer, line, dtype=dtype) + result_var = self.cse.generate( + load_buffer, line, dtype=dtype, shape=indexing.expand_shape + ) if indexing.mask_vars: if dtype.is_floating_point: zero = "0.0" @@ -2287,7 +2320,9 @@ def decide_later(): constant_repr(self._load_other) if self._load_other else zero ) line = f"tl.where({indexing.mask_str}, {result_var}, {other_val})" - result_var = self.cse.generate(load_buffer, line, dtype=dtype) + result_var = self.cse.generate( + load_buffer, line, dtype=dtype, shape=result_var.shape + ) if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): self.outside_loop_vars.add(result_var) @@ -2392,6 +2427,7 @@ def bucketize( f"{sorter_indices}, " ")", dtype=indexing_dtype, # type: ignore[attr-defined] + shape=values.shape, ) return result @@ -2418,7 +2454,10 @@ def reduction_collapse_dims(self, buffer, value: str, dtype: torch.dtype) -> str target_shape = initial_shape[:target_ndim] + ["RBLOCK"] return str( self.cse.generate( - buffer, triton_reshape(value, initial_shape, target_shape), dtype=dtype + buffer, + triton_reshape(value, initial_shape, target_shape), + dtype=dtype, + shape=target_shape, ) ) @@ -2470,6 +2509,7 @@ def maybe_upcast(value: CSEVariable) -> CSEVariable: self.compute, f"tl.broadcast_to({v}, {dense_size_str})", dtype=v.dtype, + shape=self.dense_size_list(), ), value, ) @@ -2531,7 +2571,9 @@ def final_argreduce(buffer, result_var, value, index): acc_type = triton_acc_type(src_dtype) torch_acc_type = upcast_acc_dtype(src_dtype) - result_var: Any = self.cse.newvar(dtype=torch_acc_type) + result_shape = list(self.dense_size_list()) + del result_shape[dim] + result_var: Any = self.cse.newvar(dtype=torch_acc_type, shape=result_shape) result_var.mask_vars = OrderedSet( var for var in masks if not prefix_is_reduction(var[0]) ) @@ -2548,7 +2590,10 @@ def where_cond(tval, fval): def _mask_value(value, default) -> CSEVariable: return self.cse.generate( - self.compute, where_cond(value, 
default), dtype=value.dtype + self.compute, + where_cond(value, default), + dtype=value.dtype, + shape=value.shape, ) masked_value: Union[CSEVariable, Sequence[CSEVariable]] @@ -2562,12 +2607,14 @@ def _mask_value(value, default) -> CSEVariable: masked_value = _mask_value(value, default) if reduction_type in ("argmax", "argmin"): + assert isinstance(masked_value, CSEVariable) accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype() accumulator_index = str( self.cse.generate( self.compute, f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", dtype=accumulator_dtype, + shape=masked_value.shape, ) ) root_op = {"argmax": "max", "argmin": "min"}[reduction_type] @@ -2590,7 +2637,9 @@ def _mask_value(value, default) -> CSEVariable: assert isinstance(masked_value, Sequence) (mean, m2, weight) = masked_value result_var = tuple( - self.cse.generate(self.compute, value, dtype=dtype) + self.cse.generate( + self.compute, value, dtype=dtype, shape=value.shape + ) for value in self._welford( self.compute, mean, m2, weight, dim, dtype ) @@ -2605,9 +2654,12 @@ def _mask_value(value, default) -> CSEVariable: self.compute, final_reduction(self.compute, str(masked_value), None), dtype=masked_value.dtype, + shape=result_shape, ) else: - accumulator = self.cse.namedvar(f"_{result_var}", dtype=torch_acc_type) + accumulator = self.cse.namedvar( + f"_{result_var}", dtype=torch_acc_type, shape=self.dense_size_list() + ) default = ir.Reduction.default_accumulator(reduction_type, src_dtype) default = self._map_tuple_or_scalar(constant_repr, default) if not isinstance(default, tuple): @@ -2674,7 +2726,7 @@ def _mask_value(value, default) -> CSEVariable: # reduce. Similar to the final reduction for coopereative # reduction result_max = result_var - result_sum = self.cse.newvar(dtype=dtype) + result_sum = self.cse.newvar(dtype=dtype, shape=result_max.shape) result_var = self.online_softmax_reduce_final_reduction( self.post_loop_combine, @@ -3030,7 +3082,9 @@ def _lift_helper(self, fn, num_args, dtypes: tuple[torch.dtype, ...]) -> str: helper_name = "_triton_helper_fn" from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + from torch._inductor.shape_propagation import ShapePropagationOpsHandler + shape_handler = ShapePropagationOpsHandler() dtype_handler = DtypePropagationOpsHandler() class CSEProxy(DefaultHandler): @@ -3045,10 +3099,16 @@ def _default( name, )(*args, **kwargs) + output_shape = getattr( + shape_handler, + name, + )(*args, **kwargs) + return cse.generate( helper, getattr(overrides, name)(*args, **kwargs), dtype=output_dtype, + shape=output_shape, ) with helper.indent(), V.set_ops_handler(CSEProxy()): @@ -3086,25 +3146,27 @@ def scan( self.compute, f"{value}.to({triton_compute_type(dtype)})", dtype=dtype, + shape=value.shape, ) value = self.cse.generate( self.compute, f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})", dtype=dtype, + shape=self.dense_size_list(), ) broadcasted_values.append(value) acc_type = triton_acc_type(dtype) if not self.persistent_reduction: - accumulator = self.cse.newvar(dtype=dtype) reduced_size = self.dense_size_list() reduced_size[-1] = "1" - reduced_size = f"[{', '.join(reduced_size)}]" + accumulator = self.cse.newvar(dtype=dtype, shape=reduced_size) + reduced_size_str = f"[{', '.join(reduced_size)}]" default = "float('nan')" if dtype.is_floating_point else "-1" self.body.writeline( - f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})" + f"{accumulator} = tl.full({reduced_size_str}, {default}, {acc_type})" ) 
accumulators.append(accumulator) @@ -3117,7 +3179,10 @@ def cse_multiple(line, values, masks, dtypes): cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] if all(self.cse.contains(cache_key) for cache_key in cache_keys): return [self.cse.get(cache_key) for cache_key in cache_keys] - result_vars = [self.cse.newvar(dtype=_dtype) for _dtype in dtypes] + result_vars = [ + self.cse.newvar(dtype=_dtype, shape=_value.shape) + for (_dtype, _value) in zip(dtypes, values) + ] self.compute.writeline( f"{csv(result_vars)} = {line}", ) @@ -3138,10 +3203,16 @@ def cse_multiple(line, values, masks, dtypes): # tl.reduce doesn't work for non-commutative operators, so instead # of repeating the scan op as a reduction, we use sum to select the # last scan value + def _partial_scan_shape(var): + shape = list(var.shape) + shape[-1] = "1" + return shape + partial_reduce_vars = [ cse_compute( f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)", dtype=upcast_compute_type(partial_scan_var.dtype), + shape=_partial_scan_shape(partial_scan_var), ) for partial_scan_var in partial_scan_vars ] @@ -3151,6 +3222,7 @@ def cse_multiple(line, values, masks, dtypes): cse_compute( f"tl.where(roffset > 0, {full_scan}, {partial_scan})", dtype=partial_scan.dtype, + shape=full_scan.shape, ) for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars) ] @@ -3193,7 +3265,9 @@ def sort( assert len(dtypes) == len(values) broadcasted_values = [ cse_compute( - f"tl.broadcast_to({value}, {self.dense_size_str()})", dtype=dtypes[i] + f"tl.broadcast_to({value}, {self.dense_size_str()})", + dtype=dtypes[i], + shape=self.dense_size_list(), ) for i, value in enumerate(values) ] @@ -3201,11 +3275,15 @@ def sort( def csv(values): return " ".join(f"{value}," for value in values) - def cse_multiple(line, n, masks, dtypes): + def cse_multiple(line, broadcasted_values, masks, dtypes): + n = len(broadcasted_values) cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] if all(self.cse.contains(cache_key) for cache_key in cache_keys): return [self.cse.get(cache_key) for cache_key in cache_keys] - result_vars = [self.cse.newvar(dtype=dtypes[i]) for i in range(n)] # type: ignore[attr-defined] + result_vars = [ + self.cse.newvar(dtype=dtype, shape=value.shape) + for dtype, value in zip(dtypes, broadcasted_values) + ] # type: ignore[attr-defined] self.compute.writeline( f"{csv(result_vars)} = {line}", ) @@ -3223,7 +3301,7 @@ def cse_multiple(line, n, masks, dtypes): f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]}," f" {rnumel}, {dim}, stable={stable}, descending={descending})" ) - result_vars = cse_multiple(line, len(values), masks, dtypes) + result_vars = cse_multiple(line, broadcasted_values, masks, dtypes) else: raise AssertionError("Unhandled sort") diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 23ee1e38d18b..435c83994888 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -137,19 +137,24 @@ def scan(self, dtypes, combine_fn, values): value = cse_compute( f"{value}.to({compute_type})", dtype=dtype, + shape=value.shape, ) value = cse_compute( f"tl.broadcast_to({value}, {self.dense_size_str()})", dtype=dtype, + shape=self.dense_size_list(), ) combine_helper_fn = self._lift_helper(combine_fn, 1, (dtype,)) dim = self.triton_tensor_ndim() - 1 assert dim == 0, "" + shape = list(self.dense_size_list()) + del shape[dim] block_sum = 
cse_compute( f"tl.reduce({value}, {dim}, {combine_helper_fn})", dtype=dtype, + shape=shape, ) exclusive_prefix = self.cse.newvar( dtype=dtype, @@ -188,15 +193,18 @@ def scan(self, dtypes, combine_fn, values): block_scan = cse_compute( f"tl.associative_scan({value}, {dim}, {combine_helper_fn})", dtype=dtype, + shape=shape, ) combined_result = cse_compute( f"{combine_helper_fn}({exclusive_prefix}, {block_scan})", dtype=dtype, + shape=shape, ) return ( cse_compute( f"tl.where(roffset == 0, {block_scan}, {combined_result})", dtype=dtype, + shape=block_scan.shape, ), ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index e344f9599346..1eeb8ece3d6e 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -235,7 +235,8 @@ def load(self, name: str, index: sympy.Expr): if name not in self.fixed_inputs: index_str = self._process_indexing(index) var = self._add_kernel_input(name) - var_dtype = V.graph.get_buffer(name).dtype + buffer = V.graph.get_buffer(name) + var_dtype = buffer.dtype line = f"tl.load({var} + {index_str})" if ( diff --git a/torch/_inductor/shape_propagation.py b/torch/_inductor/shape_propagation.py new file mode 100644 index 000000000000..a255b8626a61 --- /dev/null +++ b/torch/_inductor/shape_propagation.py @@ -0,0 +1,108 @@ +import functools +from collections.abc import Sequence +from typing import Callable, Optional, Protocol, Union + +import sympy + +import torch + +from .virtualized import OpsValue + + +ShapeType = Optional[Sequence[Union[int, str]]] + + +class ShapeVar(Protocol): + @property + def shape(self) -> ShapeType: ... + + +ShapeArg = Union[ShapeVar, torch.types.Number, str, OpsValue] + +# Inputs need to be cacheable (e.g., not a CSEVar) in order for the cache to be effective +# So first decompose CSEVars -> tuple before calling this + + +@functools.lru_cache(None) +def get_broadcasted_shape(a: ShapeType, b: ShapeType) -> ShapeType: + if a is None: + return b + if b is None: + return a + assert isinstance(a, Sequence) + assert isinstance(b, Sequence) + if len(a) > len(b): + return get_broadcasted_shape(a, tuple(list(b) + list(a[len(b) :]))) + elif len(a) < len(b): + b, a = a, b + return get_broadcasted_shape(a, tuple(list(b) + list(a[len(b) :]))) + else: + + def _get_broadcasted_dim( + d1: Union[int, str], d2: Union[int, str] + ) -> Union[int, str]: + if str(d1) == "1": + return d2 + elif str(d2) == "1": + return d1 + assert str(d1) == str(d2) + return d1 + + return [_get_broadcasted_dim(d1, d2) for d1, d2 in zip(a, b)] + + +def broadcast_shapes_for_args( + args: Sequence[ShapeArg], +) -> ShapeType: + result_shape = None + + for arg in args: + if shape := getattr(arg, "shape", None): + result_shape = get_broadcasted_shape(result_shape, shape) + + return result_shape + + +class ShapePropagationOpsHandler: + """ + Propagate shape from args to output + """ + + @staticmethod + def constant(value: torch.types.Number, dtype: torch.dtype) -> ShapeType: + return [] + + @staticmethod + def store_reduction(name: str, index: int, value: ShapeArg) -> None: + return None + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: str, + value: Union[ShapeArg, tuple[ShapeArg, ...]], + ) -> Union[ShapeType, tuple[ShapeType, ...]]: + raise NotImplementedError + + @staticmethod + def store( + name: str, index: int, value: ShapeArg, mode: Optional[str] = None + ) -> None: + return None + + @staticmethod + def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> ShapeType: 
+ return [] + + @staticmethod + def indirect_indexing( + var: ShapeArg, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg: bool = True, + ) -> None: + return None + + def __getattr__(self, name: str) -> Callable[..., ShapeType]: + return lambda *args, **kwargs: broadcast_shapes_for_args(args) From 425726df0bf4ae314f834080ca4e754125953f59 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 25 Apr 2025 17:53:21 +0000 Subject: [PATCH 02/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_utils.py | 4 +++- torch/_inductor/codegen/halide.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 03f780f20b2e..c1748f5e58ae 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -22,6 +22,7 @@ from ..dependencies import Dep from ..loop_body import LoopBody from ..scheduler import BaseSchedulerNode, SchedulerBuffer +from ..shape_propagation import ShapeType from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs from ..virtualized import ops, OpsValue, V from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext @@ -144,8 +145,9 @@ def __init__( name, bounds: ValueRanges[Any], dtype: Optional[torch.dtype] = None, + shape: Optional[ShapeType] = None, ) -> None: - super().__init__(name, bounds, dtype) + super().__init__(name, bounds, dtype, shape=shape) self.is_vec = False self.dependent_itervars = OrderedSet[sympy.Symbol]() diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 1a4a2196dd01..0ce9aa8e194b 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -42,7 +42,6 @@ KernelArgType, OpOverrides, PythonPrinter, - ShapeType, SizeArg, TensorArg, ) @@ -55,6 +54,7 @@ from collections.abc import Sequence from ..ops_handler import ReductionType, StoreMode + from ..shape_propagation import ShapeType log = logging.getLogger(__name__) From e3870bf548e47dd24e010b6b976eb7b2bda5750a Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 2 May 2025 15:01:55 +0000 Subject: [PATCH 03/14] use more CSEVariables [ghstack-poisoned] --- torch/_inductor/codegen/triton.py | 147 +++++++++++++++++---------- torch/_inductor/shape_propagation.py | 4 +- 2 files changed, 93 insertions(+), 58 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 351c1fd6933d..a03abf6f0689 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -104,6 +104,7 @@ from torch._inductor.dtype_propagation import DtypePropagationOpsHandler from ..ir import IRNode + from .common import ShapeType from .simd_kernel_features import SIMDKernelFeatures _T = TypeVar("_T") @@ -2441,7 +2442,23 @@ def reduction_resize(self, value) -> str: sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce return f"{value}[{', '.join(sizes)}]" - def reduction_collapse_dims(self, buffer, value: str, dtype: torch.dtype) -> str: + def reduction_resize_and_shape(self, value: CSEVariable) -> tuple[str, ShapeType]: + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})", value.shape + + nreduce = self.num_reduction_dims + sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce + new_shape = ( + (*value.shape[: (ndims - nreduce)], *[1] * nreduce) + if value.shape is not None + else None + ) + return f"{value}[{', '.join(sizes)}]", new_shape + + def reduction_collapse_dims( + self, 
buffer, value: CSEVariable, dtype: torch.dtype + ) -> CSEVariable: """ Reshape to RBLOCK, collapsing all reduction dims. """ @@ -2452,13 +2469,11 @@ def reduction_collapse_dims(self, buffer, value: str, dtype: torch.dtype) -> str target_ndim = self.triton_tensor_ndim() - self.num_reduction_dims initial_shape = self.dense_size_list() target_shape = initial_shape[:target_ndim] + ["RBLOCK"] - return str( - self.cse.generate( - buffer, - triton_reshape(value, initial_shape, target_shape), - dtype=dtype, - shape=target_shape, - ) + return self.cse.generate( + buffer, + triton_reshape(str(value), initial_shape, target_shape), + dtype=dtype, + shape=target_shape, ) def reduction( @@ -2519,9 +2534,9 @@ def maybe_upcast(value: CSEVariable) -> CSEVariable: def final_reduction( buffer, - value: str, + value: CSEVariable, result_type: Optional[str], - ) -> str: + ) -> CSEVariable: """ Helper to generate a reduction call, e.g. tl.sum. """ @@ -2530,23 +2545,23 @@ def final_reduction( value = self.reduction_collapse_dims(buffer, value, dtype) if reduction_type in ("max", "min"): - value = self.reduction_resize( + result = self.reduction_resize( f"{module}.{reduction_type}2({value}, {dim})" ) else: - value = self.reduction_resize( + result = self.reduction_resize( f"{module}.{reduction_type}({value}, {dim})" ) if result_type is not None: - value = f"{value}.to({result_type})" + result = f"{result}.to({result_type})" - return value + return self.cse.generate(buffer, result, dtype=dtype, shape=value.shape) def final_reduction_define( buffer, result_var: str, - value: str, + value: CSEVariable, result_type: Optional[str], ) -> None: """ @@ -2637,10 +2652,8 @@ def _mask_value(value, default) -> CSEVariable: assert isinstance(masked_value, Sequence) (mean, m2, weight) = masked_value result_var = tuple( - self.cse.generate( - self.compute, value, dtype=dtype, shape=value.shape - ) - for value in self._welford( + self.cse.generate(self.compute, value, dtype=dtype, shape=shape) + for value, shape in self._welford( self.compute, mean, m2, weight, dim, dtype ) ) @@ -2650,12 +2663,7 @@ def _mask_value(value, default) -> CSEVariable: result_var = self.prepare_softmax_twopass_fallback(dtype, value) else: assert isinstance(masked_value, CSEVariable) - result_var = self.cse.generate( - self.compute, - final_reduction(self.compute, str(masked_value), None), - dtype=masked_value.dtype, - shape=result_shape, - ) + result_var = final_reduction(self.compute, masked_value, None) else: accumulator = self.cse.namedvar( f"_{result_var}", dtype=torch_acc_type, shape=self.dense_size_list() @@ -2751,18 +2759,17 @@ def _mask_value(value, default) -> CSEVariable: # to # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1) # which is needed because tl.reduce doesn't support tl.int1 - accumulator_casted_str = f"{accumulator}.to(tl.int8)" - result_type = triton_compute_type(dtype) - final_reduction_define( - self.post_loop_combine, - str(result_var), - accumulator_casted_str, - result_type, - ) - else: - final_reduction_define( - self.post_loop_combine, str(result_var), str(accumulator), None + accumulator = self.cse.generate( + self.compute, + f"{accumulator}.to(tl.int8)", + dtype=accumulator.dtype, + shape=accumulator.shape, ) + result_type = triton_compute_type(dtype) + + final_reduction_define( + self.post_loop_combine, result_var, accumulator, None + ) if self.cooperative_reduction: default = ir.Reduction.default_accumulator(reduction_type, src_dtype) @@ -2828,9 +2835,7 @@ def _mask_value(value, default) -> 
CSEVariable: peers = self.codegen_cooperative_reduction_peer_combine( result_var, upcast_acc_dtype(src_dtype), default ) - final_reduction_define( - self.post_loop_store, str(result_var), peers, None - ) + final_reduction_define(self.post_loop_store, result_var, peers, None) exit_stack.close() self.cse.reduction_cache[cache_key] = result_var @@ -2890,11 +2895,21 @@ def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype): for value in (mean, m2, weight) ) welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" - welford_results = [str(self.cse.newvar(dtype=dtype)) for _ in range(3)] - buffer.writeline(f"{', '.join(welford_results)} = {welford}") - result_values = tuple(self.reduction_resize(value) for value in welford_results) - return result_values + def reduced_shape(shape): + result = list(shape) + del result[dim] + return tuple(result) + + welford_results = [ + self.cse.newvar(dtype=dtype, shape=reduced_shape(value.shape)) + for value in (mean, m2, weight) + ] + buffer.writeline(f"{', '.join([str(r) for r in welford_results])} = {welford}") + + return tuple( + self.reduction_resize_and_shape(value) for value in welford_results + ) def welford_reduce( self, result_var, reduction_type, value, where_cond, acc_type, dtype @@ -2902,9 +2917,24 @@ def welford_reduce( """Helper to codegen a welford reduction""" dim = self.triton_tensor_ndim() - self.num_reduction_dims - accumulator = f"{result_var}_mean" - accumulator_m2 = f"{result_var}_m2" - accumulator_weight = f"{result_var}_weight" + accumulator = TritonCSEVariable( + f"{result_var}_mean", + shape=self.dense_size_list(), + dtype=acc_type, + bounds=ValueRanges.unknown(), + ) + accumulator_m2 = TritonCSEVariable( + f"{result_var}_m2", + shape=self.dense_size_list(), + dtype=acc_type, + bounds=ValueRanges.unknown(), + ) + accumulator_weight = TritonCSEVariable( + f"{result_var}_weight", + shape=self.dense_size_list(), + dtype=acc_type, + bounds=ValueRanges.unknown(), + ) self.body.writeline( f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})" ) @@ -2941,13 +2971,11 @@ def welford_reduce( """ ) result_mean = result_var - result_m2 = self.cse.newvar(dtype=dtype) - result_weight = self.cse.newvar(dtype=dtype) return self.welford_reduce_final_reduction( self.post_loop_combine, result_mean, - result_m2, - result_weight, + None, + None, accumulator, accumulator_m2, accumulator_weight, @@ -2968,12 +2996,16 @@ def welford_reduce_final_reduction( dtype, ): """Helper to codegen call to triton_helpers.welford""" - values = self._welford(buffer, mean, m2, weight, dim, dtype) + values = list(self._welford(buffer, mean, m2, weight, dim, dtype)) + result_exprs = [result_mean, result_m2, result_weight] - for result_expr, value in zip(result_exprs, values): + for i, (result_expr, (value, shape)) in enumerate(zip(result_exprs, values)): + if result_expr is None: + result_expr = self.cse.newvar(dtype=dtype, shape=shape) + result_exprs[i] = result_expr buffer.splice(f"{result_expr} = {value}") - return result_mean, result_m2, result_weight + return tuple(result_exprs) def online_softmax_reduce_final_reduction( self, buffer, result_max, result_sum, peer_max, peer_sum, dim, dtype @@ -3204,9 +3236,12 @@ def cse_multiple(line, values, masks, dtypes): # of repeating the scan op as a reduction, we use sum to select the # last scan value def _partial_scan_shape(var): - shape = list(var.shape) - shape[-1] = "1" - return shape + if var.shape is None: + return None + else: + shape = list(var.shape) + shape[-1] = "1" + return shape 
partial_reduce_vars = [ cse_compute( diff --git a/torch/_inductor/shape_propagation.py b/torch/_inductor/shape_propagation.py index a255b8626a61..16c311480fb3 100644 --- a/torch/_inductor/shape_propagation.py +++ b/torch/_inductor/shape_propagation.py @@ -32,10 +32,10 @@ def get_broadcasted_shape(a: ShapeType, b: ShapeType) -> ShapeType: assert isinstance(a, Sequence) assert isinstance(b, Sequence) if len(a) > len(b): - return get_broadcasted_shape(a, tuple(list(b) + list(a[len(b) :]))) + return get_broadcasted_shape(a, (*b, *a[len(b) :])) elif len(a) < len(b): b, a = a, b - return get_broadcasted_shape(a, tuple(list(b) + list(a[len(b) :]))) + return get_broadcasted_shape(a, (*b, *a[len(b) :])) else: def _get_broadcasted_dim( From a97a6a91c427ad32b4b2accdc1465d9b77572abd Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 2 May 2025 18:09:09 +0000 Subject: [PATCH 04/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/triton.py | 37 ++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index a03abf6f0689..f3114ddeca1a 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2535,7 +2535,7 @@ def maybe_upcast(value: CSEVariable) -> CSEVariable: def final_reduction( buffer, value: CSEVariable, - result_type: Optional[str], + result_type: Optional[torch.dtype], ) -> CSEVariable: """ Helper to generate a reduction call, e.g. tl.sum. @@ -2554,15 +2554,19 @@ def final_reduction( ) if result_type is not None: - result = f"{result}.to({result_type})" + result = f"{result}.to({self.dtype_to_str(result_type)})" + else: + result_type = value.dtype - return self.cse.generate(buffer, result, dtype=dtype, shape=value.shape) + return self.cse.generate( + buffer, result, dtype=result_type, shape=value.shape + ) def final_reduction_define( buffer, - result_var: str, + result_var: CSEVariable, value: CSEVariable, - result_type: Optional[str], + result_type: Optional[torch.dtype], ) -> None: """ Generate a reduction and assign it to an existing variable. @@ -2663,7 +2667,9 @@ def _mask_value(value, default) -> CSEVariable: result_var = self.prepare_softmax_twopass_fallback(dtype, value) else: assert isinstance(masked_value, CSEVariable) - result_var = final_reduction(self.compute, masked_value, None) + result_var = final_reduction( + self.compute, masked_value, masked_value.dtype + ) else: accumulator = self.cse.namedvar( f"_{result_var}", dtype=torch_acc_type, shape=self.dense_size_list() @@ -2760,12 +2766,11 @@ def _mask_value(value, default) -> CSEVariable: # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1) # which is needed because tl.reduce doesn't support tl.int1 accumulator = self.cse.generate( - self.compute, + self.post_loop_combine, f"{accumulator}.to(tl.int8)", - dtype=accumulator.dtype, + dtype=torch.int8, shape=accumulator.shape, ) - result_type = triton_compute_type(dtype) final_reduction_define( self.post_loop_combine, result_var, accumulator, None @@ -3024,7 +3029,7 @@ def max_rsplit(self): def codegen_cooperative_reduction_peer_combine( self, result_var, dtype, default_val - ): + ) -> CSEVariable: """ Generate code to save a [XBLOCK, RSPLIT] temporary workspace, where each thread block writes a different column. 
After the barrier, every thread block loads the completed value so that it can compute the final @@ -3043,11 +3048,17 @@ def codegen_cooperative_reduction_peer_combine( """, strip=True, ) + peers = self.create_cse_var( + f"{result_var}_peers", + shape=["XBLOCK", "RSPLIT"], + dtype=dtype, + bounds=ValueRanges.unknown(), + ) self.post_loop_store.writeline( - f"{result_var}_peers = tl.load({result_var}_ws + (xindex * RSPLIT + rsplit_arange), " + f"{peers} = tl.load({result_var}_ws + (xindex * RSPLIT + rsplit_arange), " f"rsplit_mask, eviction_policy='evict_first', other=triton_helpers.if_mask(rsplit_mask, {constant_repr(default_val)}))" ) - return f"{result_var}_peers" + return peers def store_reduction( self, @@ -3257,7 +3268,7 @@ def _partial_scan_shape(var): cse_compute( f"tl.where(roffset > 0, {full_scan}, {partial_scan})", dtype=partial_scan.dtype, - shape=full_scan.shape, + shape=partial_scan.shape, ) for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars) ] From 033bfb7061e2a9746767871199dd070d9ca2a410 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 6 May 2025 22:30:56 +0000 Subject: [PATCH 05/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/halide.py | 10 ++++--- torch/_inductor/codegen/triton.py | 13 ++++++--- torch/_inductor/ir.py | 44 +++++++------------------------ 3 files changed, 26 insertions(+), 41 deletions(-) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 0ce9aa8e194b..13f21f0a8318 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -705,9 +705,9 @@ def __init__( def dtype_to_str(self, dtype: torch.dtype) -> str: return halide_type(dtype) - def create_cse_var(self, name, bounds=None, dtype=None): + def create_cse_var(self, name, bounds=None, dtype=None, shape=None): self.body.writeline(f"{name} = hl.Func({name!r})") - return HalideCSEVariable(name, bounds, dtype) + return HalideCSEVariable(name, bounds, dtype, shape) def finalize_indexing(self, indices: Sequence[sympy.Expr]): """ @@ -1366,14 +1366,16 @@ def genfunc( used_dims, *, bounds=ValueRanges.unknown(), - shape=None, + shape: Optional[ShapeType] = None, ) -> HalideCSEVariable: var = self.cse.generate(self.body, line, bounds=bounds, shape=shape) assert isinstance(var, HalideCSEVariable) var.used_dims = used_dims return var - def newfunc(self, used_dims, *, shape=None) -> HalideCSEVariable: + def newfunc( + self, used_dims, *, shape: Optional[ShapeType] = None + ) -> HalideCSEVariable: var = self.cse.newvar(shape=shape) assert isinstance(var, HalideCSEVariable) var.used_dims = used_dims diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 71835ec2f89e..136a8da7c67a 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -748,7 +748,11 @@ def low_precision_fp_var(var: Union[CSEVariable, Any]) -> bool: class TritonCSEVariable(CSEVariable): def __init__( - self, name, bounds: ValueRanges[Any], dtype: torch.dtype, shape=None + self, + name: str, + bounds: ValueRanges[Any], + dtype: torch.dtype, + shape: Optional[ShapeType] = None, ) -> None: super().__init__(name, bounds, dtype, shape=shape) # We'll use this to track which masks the variable needs when used for indirect indexing @@ -3169,6 +3173,9 @@ def scan( ], values: tuple[CSEVariable, ...], ) -> tuple[CSEVariable, ...]: + """ + Perform an associative scan on 'values'. 
+ """ assert self.inside_reduction assert not self.cooperative_reduction, "TODO" masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) @@ -3223,8 +3230,8 @@ def cse_multiple(line, values, masks, dtypes): if all(self.cse.contains(cache_key) for cache_key in cache_keys): return [self.cse.get(cache_key) for cache_key in cache_keys] result_vars = [ - self.cse.newvar(dtype=_dtype, shape=_value.shape) - for (_dtype, _value) in zip(dtypes, values) + self.cse.newvar(dtype=dtype, shape=value.shape) + for (dtype, value) in zip(dtypes, values) ] self.compute.writeline( f"{csv(result_vars)} = {line}", diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 54f451ad5843..94b66d77daaa 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -55,7 +55,6 @@ rebind_unbacked, resolve_unbacked_bindings, ShapeEnv, - statically_known_true, SymTypes, ) from torch.utils._ordered_set import OrderedSet @@ -87,7 +86,6 @@ convert_shape_to_inductor, convert_shape_to_symint, developer_warning, - get_dtype_size, get_kernel_metadata, GPU_ALIGN_BYTES, ir_dataclass, @@ -2597,24 +2595,6 @@ def is_stride_order_storage_and_layout( return False -def is_unaligned(node: IRNode) -> bool: - if isinstance(node, (TensorBox, StorageBox)): - return is_unaligned(node.data) - - if isinstance(node, ReinterpretView): - layout = node.layout - has_unaligned_layout = not statically_known_true( - layout.offset * get_dtype_size(layout.dtype) % GPU_ALIGN_BYTES == 0 - ) - return is_unaligned(node.data) or has_unaligned_layout - - if isinstance(node, Buffer): - return node.get_name() in V.graph.unaligned_buffers - - # assume to be aligned otherwise - return False - - @ir_dataclass class BaseView(IRNode): data: IRNode @@ -5720,7 +5700,6 @@ def codegen_size_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def] return size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) - wrapper.writeline( f"assert_size_stride({self.get_name()}, {size}, {stride})" ) @@ -6016,6 +5995,7 @@ def __init__( self.subgraph = V.graph.make_subgraph( self.gm, self.example_inputs, subgraph_name ) + import torch._inductor.config as inductor_config with V.set_graph_handler(self.subgraph): @@ -6033,11 +6013,9 @@ def __init__(self, graph: GraphLowering): self.graph = graph self.name = graph.name - outer_inputs = [t.codegen_reference() for t in self.inputs] - - wrapper.codegen_subgraph_with_flattened_outputs( + wrapper.codegen_subgraph( CodegenGraph(self.subgraph), - outer_inputs, + [*[buffer.get_name() for buffer in self.inputs]], [self.name], ) @@ -7012,7 +6990,9 @@ def create(cls, kernel, *args, **kwargs): # type: ignore[no-untyped-def] # We need this extra check for input alignment since the example # inputs we created are always aligned. 
- has_unaligned_input = any(is_unaligned(arg) for arg in tensor_args) + has_unaligned_input = any( + arg.get_name() in V.graph.unaligned_buffers for arg in tensor_args + ) device = cls.find_device(tensor_args, example_output) @@ -7130,22 +7110,19 @@ def get_device(self) -> Optional[torch.device]: class MultiOutput(ExternKernel): def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] wrapper.codegen_multi_output(self) - if not self.skip_size_stride_alignment_checks: - self.codegen_size_asserts(wrapper) - self.codegen_alignment_asserts(wrapper) + self.codegen_size_asserts(wrapper) + self.codegen_alignment_asserts(wrapper) def __init__( # type: ignore[no-untyped-def] self, layout: OutputSpec, input, indices: list[tuple[Any, ...]], - skip_size_stride_alignment_checks=False, ) -> None: super().__init__(None, layout, [input], ()) self.name = V.graph.register_buffer(self) V.graph.register_operation(self) self.indices = indices - self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks def get_free_symbol_uses( self, unbacked_only: bool = False @@ -7448,9 +7425,9 @@ def __init__( V.graph.register_operation(self) @classmethod - def create(cls, subgraph: Subgraph, *operands): # type: ignore[no-untyped-def] + def create(cls, subgraph: Subgraph, operands): # type: ignore[no-untyped-def] # TODO(anijain2305) - Support sym expr as operands in future. - fx_operands = V.graph.current_node.args[2:] + fx_operands = V.graph.current_node.args[-1] fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] # Realize the inputs. Also intermediates can have different strides than @@ -7512,7 +7489,6 @@ def create_output(output: IRNode, ind: int): ), invoke_subgraph, [(list, ind)], - skip_size_stride_alignment_checks=True, ) outputs = [create_output(output, i) for i, output in enumerate(outputs)] From 8cf6f34c5731f2e22243e16552406d978cd81fdf Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 6 Jun 2025 17:21:08 +0000 Subject: [PATCH 06/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/common.py | 45 ++++++++++++++++++++-------- torch/_inductor/codegen/cpp.py | 4 +-- torch/_inductor/codegen/cpp_utils.py | 2 +- torch/_inductor/codegen/halide.py | 12 ++++---- torch/_inductor/codegen/triton.py | 21 +++++++++---- torch/_inductor/shape_propagation.py | 30 ++++++++++++------- 6 files changed, 74 insertions(+), 40 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index b889579823c8..5fc731e10b68 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -40,6 +40,7 @@ from .. 
import config, metrics from ..dtype_propagation import DtypePropagationOpsHandler from ..ops_handler import BasicMathOpsMixin, DefaultHandler +from ..shape_propagation import ShapePropagationOpsHandler from ..utils import ( boolean_ops, DeferredLineBase, @@ -1354,9 +1355,9 @@ def input(self, name: str) -> str: name = V.graph.scheduler.mutation_real_name.get(name, name) assert name not in V.graph.removed_buffers, name if name in self.output_buffers: - return cast(str, self.output_buffers[name]) + return cast("str", self.output_buffers[name]) if name in self.inplace_buffers: - return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name + return cast("InplacedBuffer", self.inplace_buffers[name]).inner_name if name.startswith("seed"): return self._lookup("seed", self.input_buffers, name) return self._lookup("in_ptr", self.input_buffers, name) @@ -1366,7 +1367,7 @@ def output(self, name: str) -> str: name = V.graph.scheduler.mutation_real_name.get(name, name) assert name not in V.graph.removed_buffers, name if name in self.inplace_buffers: - return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name + return cast("InplacedBuffer", self.inplace_buffers[name]).inner_name return self._lookup("out_ptr", self.output_buffers, name) def make_inplace(self, input_name: str, output_name: str) -> None: @@ -1612,7 +1613,7 @@ def aliases(self) -> Iterator[tuple[str, str]]: if other in self.input_buffers: yield self.input_buffers[other], inplaced.inner_name if other in self.output_buffers: - yield cast(str, self.output_buffers[other]), inplaced.inner_name + yield cast("str", self.output_buffers[other]), inplaced.inner_name def is_removed(self, name: str) -> bool: return isinstance( @@ -1647,7 +1648,7 @@ def __init__( name: str, bounds: ValueRanges[Any], dtype: Optional[torch.dtype] = None, - shape: Optional[ShapeType] = None, + shape: ShapeType = None, ): super().__init__() assert isinstance(bounds, ValueRanges) @@ -1742,7 +1743,7 @@ def scoped_copy(self) -> typing.Self: def augment_key(self, cache_key: str) -> AugmentedKeyT: "Override this method to augment cache key with backend specifics" - return cast(AugmentedKeyT, cache_key) + return cast("AugmentedKeyT", cache_key) def put(self, cache_key: str, val: CSEVariableType) -> None: self._cache[self.augment_key(cache_key)] = val @@ -1765,7 +1766,7 @@ def generate( write: bool = True, assignment: bool = True, dtype: Optional[torch.dtype] = None, - shape: Optional[ShapeType] = None, + shape: ShapeType = None, ) -> CSEVariableType: if isinstance(expr, OpsValue): expr = expr.value @@ -1777,7 +1778,7 @@ def generate( # with the loose ValueRanges.unknown(), so we need to tighten the bounds expr.bounds = expr.bounds.tighten(bounds) expr.use_count += 1 - return cast(CSEVariableType, expr) + return cast("CSEVariableType", expr) elif isinstance(expr, IndentedBuffer): cache_key = expr.getvalue() elif isinstance(expr, DeferredLineBase): @@ -1833,7 +1834,7 @@ def newvar( self, bounds: ValueRanges[Any] = ValueRanges.unknown(), dtype: Optional[torch.dtype] = None, - shape: Optional[ShapeType] = None, + shape: ShapeType = None, ) -> CSEVariableType: var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" var = V.kernel.create_cse_var(var_name, bounds, dtype, shape) @@ -1845,7 +1846,7 @@ def namedvar( name: str, bounds: ValueRanges[Any] = ValueRanges.unknown(), dtype: Optional[torch.dtype] = None, - shape: Optional[ShapeType] = None, + shape: ShapeType = None, ) -> CSEVariableType: torch._check_value( name not in self.varname_map, lambda: f"duplicate name: 
{name}" @@ -2306,10 +2307,12 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type] dtype_handler = DtypePropagationOpsHandler() + shape_handler = ShapePropagationOpsHandler() backend = get_current_backend() output_dtype = None + if name == "masked" and backend == "triton": output_dtype = value.dtype elif name == "masked" and backend == "cpp": @@ -2320,6 +2323,12 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dtype_op = getattr(dtype_handler, name) output_dtype = dtype_op(*args, **kwargs) + shape_op = getattr(shape_handler, name) + if name == "masked": + output_shape = value.shape + else: + output_shape = shape_op(*args, **kwargs) + if backend in ("triton", "cpp"): # maybe there are some exceptions on mps? assert output_dtype is not None @@ -2334,18 +2343,28 @@ def do_cse(v: Union[str, CSEVariable]) -> CSEVariable: if isinstance(output_dtype, (list, tuple)) else output_dtype ) + var_shape: ShapeType = ( + output_shape[output_idx] # type: ignore[assignment] + if isinstance(output_shape, (list, tuple)) + and len(output_shape) > 0 + and isinstance(output_shape[0], (list, tuple)) + else output_shape + ) output_idx += 1 # some cpp op implementations don't set the dtype - if backend == "cpp" and isinstance(v, CSEVariable) and v.dtype is None: - v.dtype = var_dtype + if isinstance(v, CSEVariable): + if backend == "cpp" and v.dtype is None: + v.dtype = var_dtype + if v.shape is None: + v.shape = var_shape csevar = V.kernel.cse.generate( V.kernel.compute, v, bounds=bounds, dtype=output_dtype, - shape=getattr(v, "shape", None), + shape=output_shape, ) csevar.update_on_args(name, args, kwargs) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index d1269f07aa72..3ce4f7c144b6 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -5090,10 +5090,10 @@ def codegen_template( assert self.is_cpp_template(template_node), ( "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" ) - template_node = cast(SchedulerNode, template_node) + template_node = cast("SchedulerNode", template_node) _, (_, rnumel) = template_node.group assert rnumel == () - ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) + ctb: ir.CppTemplateBuffer = cast("ir.CppTemplateBuffer", template_node.node) epilogue_ir_nodes: list[Optional[ir.Operation]] = [ n.node for n in epilogue_nodes ] diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index c1748f5e58ae..f76907c48f36 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -145,7 +145,7 @@ def __init__( name, bounds: ValueRanges[Any], dtype: Optional[torch.dtype] = None, - shape: Optional[ShapeType] = None, + shape: ShapeType = None, ) -> None: super().__init__(name, bounds, dtype, shape=shape) self.is_vec = False diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 13f21f0a8318..91cf19febfa9 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -578,7 +578,7 @@ def __init__( name, bounds: ValueRanges[Any], dtype: Optional[torch.dtype] = None, - shape: Optional[ShapeType] = None, + shape: ShapeType = None, ) -> None: super().__init__(name, bounds, dtype, shape=shape) self.used_dims: Optional[list[sympy.Symbol]] = None @@ -942,7 +942,7 @@ def 
indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool): # group the expression by variables used offset = sympy.S.Zero - split_expr = {s: sympy.S.Zero for s in symbols} + split_expr = dict.fromkeys(symbols, sympy.S.Zero) split_failed: list[tuple[list[sympy.Symbol], sympy.Expr]] = [] index = sympy.expand(self.rename_indexing(index)) for part in index.args if isinstance(index, sympy.Add) else [index]: @@ -1366,16 +1366,14 @@ def genfunc( used_dims, *, bounds=ValueRanges.unknown(), - shape: Optional[ShapeType] = None, + shape: ShapeType = None, ) -> HalideCSEVariable: var = self.cse.generate(self.body, line, bounds=bounds, shape=shape) assert isinstance(var, HalideCSEVariable) var.used_dims = used_dims return var - def newfunc( - self, used_dims, *, shape: Optional[ShapeType] = None - ) -> HalideCSEVariable: + def newfunc(self, used_dims, *, shape: ShapeType = None) -> HalideCSEVariable: var = self.cse.newvar(shape=shape) assert isinstance(var, HalideCSEVariable) var.used_dims = used_dims @@ -1548,7 +1546,7 @@ def generate(g): code.splice(self.indexing_code) def update_index(m): - var = cast(HalideCSEVariable, self.cse.varname_map[m.group(1)]) + var = cast("HalideCSEVariable", self.cse.varname_map[m.group(1)]) assert var.used_dims is not None, var return str(var) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 136a8da7c67a..35d8025a7f4b 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2031,7 +2031,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]: return options expand_str = None - expand_shape = None + expand_shape: ShapeType = None index_str = self.index_to_str(index) if isinstance(index, sympy.Integer): expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() @@ -2061,6 +2061,12 @@ def match_block_pointer() -> Optional[BlockPtrOptions]: index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" mask_vars = dense_mask_vars + if expand_shape is None: + if need_dense: + expand_shape = None if copy_shape else self.dense_size_list() + else: + expand_shape = () + if override_mask: mask_vars = OrderedSet([override_mask]) @@ -2261,7 +2267,7 @@ def decide_later(): cachemod = ", cache_modifier='.cg'" append_broadcast = None - shape = None + shape: ShapeType = None dtype = V.graph.get_dtype(name) if should_unwrap_unspec_arg(name): @@ -2270,6 +2276,7 @@ def decide_later(): # see triton_utils.py:signature_of if dtype in (torch.float16, torch.bfloat16): dtype = torch.float32 + shape = () else: if isinstance(indexing, BlockPtrOptions): @@ -2282,7 +2289,7 @@ def decide_later(): elif isinstance(original_index, sympy.Integer): line = f"tl.load({var} + ({original_index}))" append_broadcast = indexing.expand_str - shape = None + shape = () else: line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})" shape = indexing.expand_shape @@ -2616,7 +2623,7 @@ def _mask_value(value, default) -> CSEVariable: self.compute, where_cond(value, default), dtype=value.dtype, - shape=value.shape, + shape=value.shape if value.shape is not None else default.shape, ) masked_value: Union[CSEVariable, Sequence[CSEVariable]] @@ -3674,7 +3681,7 @@ def codegen_kernel(self, name=None): if isinstance(arg, SizeArg): # mypy is unhappy about the sympy.Expr # type for the key of the dict below - symbol = cast(sympy.Symbol, arg.expr) + symbol = cast("sympy.Symbol", arg.expr) if symbol in V.graph.sizevars.inv_precomputed_replacements: signature[i] = SizeArg( arg.name, 
V.graph.sizevars.inv_precomputed_replacements[symbol] @@ -3690,7 +3697,9 @@ def codegen_kernel(self, name=None): and mutation not in self.removed_buffers ): mutated_args.add( - cast(InplacedBuffer, self.args.inplace_buffers[mutation]).inner_name + cast( + "InplacedBuffer", self.args.inplace_buffers[mutation] + ).inner_name ) if mutation in self.args.output_buffers: mutation_arg = self.args.output_buffers[mutation] diff --git a/torch/_inductor/shape_propagation.py b/torch/_inductor/shape_propagation.py index 16c311480fb3..1ca25524f8e9 100644 --- a/torch/_inductor/shape_propagation.py +++ b/torch/_inductor/shape_propagation.py @@ -25,10 +25,6 @@ def shape(self) -> ShapeType: ... @functools.lru_cache(None) def get_broadcasted_shape(a: ShapeType, b: ShapeType) -> ShapeType: - if a is None: - return b - if b is None: - return a assert isinstance(a, Sequence) assert isinstance(b, Sequence) if len(a) > len(b): @@ -48,17 +44,28 @@ def _get_broadcasted_dim( assert str(d1) == str(d2) return d1 - return [_get_broadcasted_dim(d1, d2) for d1, d2 in zip(a, b)] + return tuple(_get_broadcasted_dim(d1, d2) for d1, d2 in zip(a, b)) def broadcast_shapes_for_args( - args: Sequence[ShapeArg], + args: Sequence[ShapeArg], assume_equal_shapes: bool = False ) -> ShapeType: - result_shape = None + result_shape: ShapeType = None for arg in args: - if shape := getattr(arg, "shape", None): - result_shape = get_broadcasted_shape(result_shape, shape) + if hasattr(arg, "shape"): + shape = arg.shape + if shape is None: + if assume_equal_shapes: + continue + else: + return None + elif result_shape is None: + result_shape = tuple(shape) + else: + result_shape = get_broadcasted_shape(result_shape, tuple(shape)) + else: + return None return result_shape @@ -70,7 +77,7 @@ class ShapePropagationOpsHandler: @staticmethod def constant(value: torch.types.Number, dtype: torch.dtype) -> ShapeType: - return [] + return () @staticmethod def store_reduction(name: str, index: int, value: ShapeArg) -> None: @@ -93,7 +100,8 @@ def store( @staticmethod def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> ShapeType: - return [] + # TODO: fix me + return () @staticmethod def indirect_indexing( From 3e1a619c12b03cd689f2d87f051997f78ed0ddb5 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 9 Jun 2025 19:45:51 +0000 Subject: [PATCH 07/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/common.py | 10 ++++----- torch/_inductor/codegen/triton.py | 32 ++++++++++++++++------------ torch/_inductor/shape_propagation.py | 11 ++++++++-- 3 files changed, 32 insertions(+), 21 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 5fc731e10b68..be670de7f477 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -2311,22 +2311,22 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> backend = get_current_backend() + shape_op = getattr(shape_handler, name) output_dtype = None + output_shape = None if name == "masked" and backend == "triton": output_dtype = value.dtype + output_shape = value.shape elif name == "masked" and backend == "cpp": output_dtype = V.interpreter.current_node.meta.get( OptimizationContext.key, None ).dtype + # TODO: fix me + output_shape = None elif backend in ("triton", "cpp"): dtype_op = getattr(dtype_handler, name) output_dtype = dtype_op(*args, **kwargs) - - shape_op = getattr(shape_handler, name) - if name == "masked": - output_shape = value.shape - else: output_shape = shape_op(*args, **kwargs) if backend 
in ("triton", "cpp"): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 35d8025a7f4b..ee8f62b2afa2 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1378,7 +1378,7 @@ def index_expr(cls, expr, dtype): indexing.index_str, bounds=get_bounds_index_expr(expr), dtype=dtype, - shape=[], + shape=indexing.expand_shape, ) finally: config.test_configs.runtime_triton_dtype_assert = orig @@ -2035,7 +2035,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]: index_str = self.index_to_str(index) if isinstance(index, sympy.Integer): expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() - expand_shape = None if copy_shape else self.dense_size_list() + expand_shape = None if copy_shape else tuple(self.dense_size_list()) index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" if self.fixed_config and not self._has_constant_xmask(): mask_vars = OrderedSet(["xmask"]) @@ -2054,7 +2054,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]: if need_dense and not have_dense: expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() - expand_shape = None if copy_shape else self.dense_size_list() + expand_shape = None if copy_shape else tuple(self.dense_size_list()) index_str = f"tl.broadcast_to({index_str}, {expand_str})" mask_vars = dense_mask_vars elif not have_loop_vars and copy_shape: @@ -2062,8 +2062,8 @@ def match_block_pointer() -> Optional[BlockPtrOptions]: mask_vars = dense_mask_vars if expand_shape is None: - if need_dense: - expand_shape = None if copy_shape else self.dense_size_list() + if need_dense or have_dense: + expand_shape = None if copy_shape else tuple(self.dense_size_list()) else: expand_shape = () @@ -2484,7 +2484,7 @@ def reduction_collapse_dims( buffer, triton_reshape(str(value), initial_shape, target_shape), dtype=dtype, - shape=target_shape, + shape=tuple(target_shape), ) def reduction( @@ -2535,7 +2535,7 @@ def maybe_upcast(value: CSEVariable) -> CSEVariable: self.compute, f"tl.broadcast_to({v}, {dense_size_str})", dtype=v.dtype, - shape=self.dense_size_list(), + shape=tuple(self.dense_size_list()), ), value, ) @@ -2603,7 +2603,9 @@ def final_argreduce(buffer, result_var, value, index): torch_acc_type = upcast_acc_dtype(src_dtype) result_shape = list(self.dense_size_list()) del result_shape[dim] - result_var: Any = self.cse.newvar(dtype=torch_acc_type, shape=result_shape) + result_var: Any = self.cse.newvar( + dtype=torch_acc_type, shape=tuple(result_shape) + ) result_var.mask_vars = OrderedSet( var for var in masks if not prefix_is_reduction(var[0]) ) @@ -2683,7 +2685,9 @@ def _mask_value(value, default) -> CSEVariable: ) else: accumulator = self.cse.namedvar( - f"_{result_var}", dtype=torch_acc_type, shape=self.dense_size_list() + f"_{result_var}", + dtype=torch_acc_type, + shape=tuple(self.dense_size_list()), ) default = ir.Reduction.default_accumulator(reduction_type, src_dtype) default = self._map_tuple_or_scalar(constant_repr, default) @@ -2935,19 +2939,19 @@ def welford_reduce( accumulator = TritonCSEVariable( f"{result_var}_mean", - shape=self.dense_size_list(), + shape=tuple(self.dense_size_list()), dtype=acc_type, bounds=ValueRanges.unknown(), ) accumulator_m2 = TritonCSEVariable( f"{result_var}_m2", - shape=self.dense_size_list(), + shape=tuple(self.dense_size_list()), dtype=acc_type, bounds=ValueRanges.unknown(), ) accumulator_weight = TritonCSEVariable( f"{result_var}_weight", - shape=self.dense_size_list(), + 
shape=tuple(self.dense_size_list()), dtype=acc_type, bounds=ValueRanges.unknown(), ) @@ -3209,7 +3213,7 @@ def scan( self.compute, f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})", dtype=dtype, - shape=self.dense_size_list(), + shape=tuple(self.dense_size_list()), ) broadcasted_values.append(value) @@ -3327,7 +3331,7 @@ def sort( cse_compute( f"tl.broadcast_to({value}, {self.dense_size_str()})", dtype=dtypes[i], - shape=self.dense_size_list(), + shape=tuple(self.dense_size_list()), ) for i, value in enumerate(values) ] diff --git a/torch/_inductor/shape_propagation.py b/torch/_inductor/shape_propagation.py index 1ca25524f8e9..8378477b0397 100644 --- a/torch/_inductor/shape_propagation.py +++ b/torch/_inductor/shape_propagation.py @@ -6,7 +6,7 @@ import torch -from .virtualized import OpsValue +from .virtualized import OpsValue, V ShapeType = Optional[Sequence[Union[int, str]]] @@ -77,7 +77,14 @@ class ShapePropagationOpsHandler: @staticmethod def constant(value: torch.types.Number, dtype: torch.dtype) -> ShapeType: - return () + # See implementation of constant for triton for the reason + from torch._inductor.codegen.triton import TritonKernel + + if isinstance(V.kernel, TritonKernel): + ndim = V.kernel.triton_tensor_ndim() + return tuple([1] * ndim) + else: + return () @staticmethod def store_reduction(name: str, index: int, value: ShapeArg) -> None: From a5392a6a0b3bae11a629a24c7fdc491a07b48895 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 10 Jun 2025 19:00:24 +0000 Subject: [PATCH 08/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/triton.py | 2 +- torch/_inductor/shape_propagation.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 2990d63a057b..e2c2787ed7b5 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2599,7 +2599,7 @@ def final_argreduce(buffer, result_var, value, index): acc_type = triton_acc_type(src_dtype) torch_acc_type = upcast_acc_dtype(src_dtype) result_shape = list(self.dense_size_list()) - del result_shape[dim] + result_shape[dim] = "1" result_var: Any = self.cse.newvar( dtype=torch_acc_type, shape=tuple(result_shape) ) diff --git a/torch/_inductor/shape_propagation.py b/torch/_inductor/shape_propagation.py index a27c0ca45a0f..80014734e0be 100644 --- a/torch/_inductor/shape_propagation.py +++ b/torch/_inductor/shape_propagation.py @@ -64,6 +64,9 @@ def broadcast_shapes_for_args( result_shape = tuple(shape) else: result_shape = get_broadcasted_shape(result_shape, tuple(shape)) + elif isinstance(arg, (int, float)): + if result_shape is None: + result_shape = () else: return None From 88526ac06cc32142b64b6f4ea064e3d56b9e1ad8 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 10 Jun 2025 19:01:20 +0000 Subject: [PATCH 09/14] Update [ghstack-poisoned] --- torch/_inductor/shape_propagation.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/shape_propagation.py b/torch/_inductor/shape_propagation.py index 80014734e0be..779f21cdcdad 100644 --- a/torch/_inductor/shape_propagation.py +++ b/torch/_inductor/shape_propagation.py @@ -47,19 +47,14 @@ def _get_broadcasted_dim( return tuple(_get_broadcasted_dim(d1, d2) for d1, d2 in zip(a, b)) -def broadcast_shapes_for_args( - args: Sequence[ShapeArg], assume_equal_shapes: bool = False -) -> ShapeType: +def broadcast_shapes_for_args(args: Sequence[ShapeArg]) -> ShapeType: result_shape: ShapeType = None for arg in args: if 
hasattr(arg, "shape"): shape = arg.shape if shape is None: - if assume_equal_shapes: - continue - else: - return None + return None elif result_shape is None: result_shape = tuple(shape) else: From 4b8b8a9d60b1f39ce2ac67f3a83f9150ec9ca61f Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Thu, 12 Jun 2025 19:12:02 +0000 Subject: [PATCH 10/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/common.py | 4 +++ torch/_inductor/codegen/triton.py | 27 ++++++++++++++------ torch/_inductor/codegen/triton_split_scan.py | 14 +++++++--- torch/_inductor/shape_propagation.py | 17 +++++++++++- 4 files changed, 49 insertions(+), 13 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 5aebf84ca189..3bc5835aba55 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1854,6 +1854,10 @@ def generate( assert isinstance(expr, str) cache_key = expr var = self.try_get(cache_key) + if shape is None and not assignment: + # since there's no assignment to a variable, use any shape here + # other than None to avoid the unknown shape failures + shape = () if not var: var = self.newvar(bounds, dtype, shape) self.put(cache_key, var) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e2c2787ed7b5..5e47a69bc4c1 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -759,6 +759,7 @@ def __init__( # We'll use this to track which masks the variable needs when used for indirect indexing self.mask_vars: OrderedSet[str] = OrderedSet() assert dtype is not None, "TritonCSEVariable must have dtype" + assert shape is not None, "TritonCSEVariable must have shape" def update_on_args(self, name, args, kwargs): for arg in args: @@ -3028,11 +3029,16 @@ def welford_reduce_final_reduction( def online_softmax_reduce_final_reduction( self, buffer, result_max, result_sum, peer_max, peer_sum, dim, dtype ): - values = self._online_softmax_reduce(buffer, peer_max, peer_sum, dim, dtype) - result_exprs = [result_max, result_sum] - for result_expr, value in zip(result_exprs, values): - buffer.splice(f"{result_expr} = {value}") - + accumulator_max = self.reduction_collapse_dims(buffer, peer_max, dtype) + accumulator_sum = self.reduction_collapse_dims(buffer, peer_sum, dtype) + buffer.splice( + f""" + {result_max}, {result_sum} = triton_helpers.online_softmax_reduce( + {accumulator_max}, {accumulator_sum}, {dim}, {config.use_fast_math}) + {result_max} = {self.reduction_resize(f"{result_max}")} + {result_sum} = {self.reduction_resize(f"{result_sum}")} + """ + ) return result_max, result_sum def max_rsplit(self): @@ -3115,7 +3121,9 @@ def store_reduction( exit_stack.close() - def _lift_helper(self, fn, num_args, dtypes: tuple[torch.dtype, ...]) -> str: + def _lift_helper( + self, fn, values: tuple[CSEVariable, ...], dtypes: tuple[torch.dtype, ...] 
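+        # Note: passing the operand CSE variables themselves (rather than a
+        # bare argument count) lets each lifted helper argument below carry
+        # the matching dtype and block shape.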
+ ) -> str: # Lift IR function for scan operations into a triton function # in the global namespace helper = IndentedBuffer() @@ -3123,7 +3131,10 @@ def _lift_helper(self, fn, num_args, dtypes: tuple[torch.dtype, ...]) -> str: cse = CSE() args = [ - tuple(cse.namedvar(f"arg{i}_{n}", dtype=dtypes[n]) for n in range(num_args)) + tuple( + cse.namedvar(f"arg{i}_{n}", dtype=dtype, shape=value.shape) + for n, (value, dtype) in enumerate(zip(values, dtypes)) + ) for i in range(2) ] signature = ", ".join(str(x) for x in itertools.chain.from_iterable(args)) @@ -3197,7 +3208,7 @@ def scan( dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes) cse_compute = functools.partial(self.cse.generate, self.compute) - combine_helper_fn = self._lift_helper(combine_fn, len(values), dtypes) + combine_helper_fn = self._lift_helper(combine_fn, values, dtypes) dim = self.triton_tensor_ndim() - self.num_reduction_dims for value, dtype in zip(values, dtypes): diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 435c83994888..9c21b08498ef 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -123,11 +123,16 @@ def scan(self, dtypes, combine_fn, values): scratch_base: Union[str, TritonCSEVariable] scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True) if offset != 0: - scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}") - runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})") + scratch_base = cse_load( + f"{scratch_base} + {self.index_to_str(offset)}", shape=() + ) + runtime_rblocks = cse_load( + f"tl.num_programs({self.range_trees[-1].index})", shape=() + ) scratch_base = cse_load( f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * " - f"{scratch_elems_per_block} * {runtime_rblocks}" + f"{scratch_elems_per_block} * {runtime_rblocks}", + shape=(), ) masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) @@ -145,7 +150,7 @@ def scan(self, dtypes, combine_fn, values): shape=self.dense_size_list(), ) - combine_helper_fn = self._lift_helper(combine_fn, 1, (dtype,)) + combine_helper_fn = self._lift_helper(combine_fn, (value,), (dtype,)) dim = self.triton_tensor_ndim() - 1 assert dim == 0, "" shape = list(self.dense_size_list()) @@ -158,6 +163,7 @@ def scan(self, dtypes, combine_fn, values): ) exclusive_prefix = self.cse.newvar( dtype=dtype, + shape=shape, ) if element_nbits == 64: self.compute.splice( diff --git a/torch/_inductor/shape_propagation.py b/torch/_inductor/shape_propagation.py index 779f21cdcdad..1927bd5f31b6 100644 --- a/torch/_inductor/shape_propagation.py +++ b/torch/_inductor/shape_propagation.py @@ -62,6 +62,8 @@ def broadcast_shapes_for_args(args: Sequence[ShapeArg]) -> ShapeType: elif isinstance(arg, (int, float)): if result_shape is None: result_shape = () + elif isinstance(arg, torch.dtype): + continue else: return None @@ -103,9 +105,22 @@ def store( ) -> None: return None + @staticmethod + def to_dtype( + value: ShapeVar, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> ShapeType: + return value.shape + @staticmethod def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> ShapeType: - # TODO: fix me + # shape is implicitly embedded in expr. 
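+        # Returning None reports "unknown" rather than guessing: the block
+        # shape of an index expression depends on which range-tree symbols
+        # occur in it. This is distinct from (), which asserts a known
+        # scalar shape (compare load_seed below).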
+ return None + + @staticmethod + def load_seed(name: str, offset: int) -> ShapeType: return () @staticmethod From 19a37947712af1ca0f120aa8b2010914dd7b50b6 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 13 Jun 2025 15:31:57 +0000 Subject: [PATCH 11/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/triton.py | 3 ++- torch/_inductor/shape_propagation.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 5e47a69bc4c1..5de4d16792b4 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -759,7 +759,8 @@ def __init__( # We'll use this to track which masks the variable needs when used for indirect indexing self.mask_vars: OrderedSet[str] = OrderedSet() assert dtype is not None, "TritonCSEVariable must have dtype" - assert shape is not None, "TritonCSEVariable must have shape" + # TODO: uncomment this and fix the few failures left + # assert shape is not None, "TritonCSEVariable must have shape" def update_on_args(self, name, args, kwargs): for arg in args: diff --git a/torch/_inductor/shape_propagation.py b/torch/_inductor/shape_propagation.py index 1927bd5f31b6..772ac1d1881f 100644 --- a/torch/_inductor/shape_propagation.py +++ b/torch/_inductor/shape_propagation.py @@ -17,7 +17,7 @@ class ShapeVar(Protocol): def shape(self) -> ShapeType: ... -ShapeArg = Union[ShapeVar, torch.types.Number, str, OpsValue] +ShapeArg = Union[ShapeVar, torch.types.Number, str, OpsValue, torch.dtype] # Inputs need to be cacheable (e.g., not a CSEVar) in order for the cache to be effective # So first decompose CSEVars -> tuple before calling this From 053435f897b1ecebb7bd17199e1592145826d42d Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 13 Jun 2025 21:11:39 +0000 Subject: [PATCH 12/14] Update [ghstack-poisoned] --- torch/_inductor/codegen/triton_split_scan.py | 3 +++ torch/_inductor/kernel/flex_attention.py | 6 ++++-- torch/_inductor/select_algorithm.py | 14 +++++++++++--- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 9c21b08498ef..b36d26ec08bf 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -86,6 +86,9 @@ def reduction(self, dtype, src_dtype, reduction_type, value): raise NotImplementedError("NYI TritonSplitDimKernel reductions") def scan(self, dtypes, combine_fn, values): + """ + Perform an associative scan on 'values'. 
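+
+        'combine_fn' must be associative, since the split scan combines
+        partial results across thread blocks in a data-dependent order; it
+        is lifted into a global @triton.jit helper via _lift_helper.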
+ """ import triton.language as tl (dtype,) = dtypes diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index f590a7a9194b..f743c5937115 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -595,7 +595,8 @@ def load_checked_2d( mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) - {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask", val_shape=["BLOCK_M", "V_HEAD_DIM_ROUNDED"])}} if OUTPUT_LOGSUMEXP: off_hz = tl.program_id(1) @@ -2049,7 +2050,8 @@ def flex_attention_backward_grid( # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] - {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", val_shape=["BLOCK_N1", "QK_HEAD_DIM_ROUNDED"], indent_width=8)}} @triton.jit def bwd_dq_inner( diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 20e40639576d..a187a67cd3fa 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -251,11 +251,16 @@ def load(self, name: str, index: sympy.Expr): line += ".to(tl.float32)" var_dtype = torch.float32 - out = self.kernel.cse.generate(self.kernel.compute, line, dtype=var_dtype) + out = self.kernel.cse.generate( + self.kernel.compute, line, dtype=var_dtype, shape=() + ) return out return self.kernel.cse.generate( - self.kernel.compute, f"({self.fixed_inputs[name]})", dtype=torch.float32 + self.kernel.compute, + f"({self.fixed_inputs[name]})", + dtype=torch.float32, + shape=(), ) def indirect_indexing(self, index_var: str, size, check, wrap_neg=True): @@ -946,6 +951,7 @@ def store_output( val: str, mask: Optional[str] = None, indent_width: int = 4, + val_shape: Optional[list[str]] = None, ): """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away. 
@@ -996,7 +1002,9 @@ def store_output(
                 if "ACC_TYPE" in self.meta
                 else torch.float32
             )
-            epilogue_args = [V.kernel.cse.namedvar(val, dtype=acc_dtype)]
+            epilogue_args = [
+                V.kernel.cse.namedvar(val, dtype=acc_dtype, shape=val_shape)
+            ]
             for input_node in itertools.chain(
                 self.input_nodes[: self.prefix_args],
                 self.input_nodes[len(self.input_nodes) - self.suffix_args :],

From 27f03bb049a29ad8661f48b90f912796c3186f80 Mon Sep 17 00:00:00 2001
From: Isuru Fernando
Date: Mon, 11 Aug 2025 22:44:08 +0000
Subject: [PATCH 13/14] Update

[ghstack-poisoned]
---
 torch/_inductor/codegen/triton.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 0962f9601aa6..c9aa08669e2d 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -833,7 +833,7 @@ def __init__(
         name: str,
         bounds: ValueRanges[Any],
         dtype: torch.dtype,
-        shape: Optional[BlockShapeType] = None,
+        shape: BlockShapeType = None,
     ) -> None:
         super().__init__(name, bounds, dtype, shape=shape)
         # We'll use this to track which masks the variable needs when used for indirect indexing
@@ -2873,7 +2873,7 @@ def final_reduction(
             buffer,
             value: CSEVariable,
             result_type: Optional[torch.dtype],
-        ) -> CSEVariable:
+        ) -> Tuple[str, torch.dtype, BlockShapeType]:
             """
             Helper to generate a reduction call, e.g. tl.sum.
             """
@@ -2895,7 +2895,7 @@ def final_reduction(
             else:
                 result_type = value.dtype

-            return self.cse.generate(buffer, result, dtype=result_type, shape=shape)
+            return result, result_type, shape

         def final_reduction_define(
             buffer,
@@ -2906,7 +2906,7 @@ def final_reduction_define(
             """
             Generate a reduction and assign it to an existing variable.
             """
-            value = final_reduction(buffer, value, result_type)
+            value, _, _ = final_reduction(buffer, value, result_type)
             buffer.splice(f"{result_var} = {value}")

         def final_argreduce(buffer, result_var, value, index):
@@ -3004,9 +3004,10 @@ def _mask_value(value, default) -> CSEVariable:
                 result_var = self.prepare_softmax_twopass_fallback(dtype, value)
             else:
                 assert isinstance(masked_value, CSEVariable)
-                result_var = final_reduction(
+                _result, _dtype, _shape = final_reduction(
                     self.compute, masked_value, masked_value.dtype
                 )
+                result_var = self.cse.generate(self.compute, _result, dtype=_dtype, shape=_shape)
             else:
                 accumulator = self.cse.namedvar(
                     f"_{result_var}",

From 576d5a5f784b4ca50fb86d151ccf67f9725796d5 Mon Sep 17 00:00:00 2001
From: Isuru Fernando
Date: Mon, 11 Aug 2025 23:04:15 +0000
Subject: [PATCH 14/14] fix docstring linter

[ghstack-poisoned]
---
 torch/_inductor/codegen/triton.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index c9aa08669e2d..ff493ead975d 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1862,6 +1862,10 @@

 class TritonKernel(SIMDKernel[TritonCSEVariable]):
+    """A class to represent a triton kernel, with helpers to generate
+    triton kernels programmatically.
+    """
+
     overrides = TritonKernelOverrides  # type: ignore[assignment]
     helper_functions: HelperFunctions
     kexpr: Callable[[sympy.Expr], str] = texpr
@@ -2501,6 +2505,9 @@ def get_load_buffer(self, indexing):
             return self.loads

     def load(self, name: str, index: sympy.Expr):
+        """
+        Load from the memory location 'name', offset by some indexing expression 'index'.
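+        Returns a CSE variable annotated with the dtype and block shape of
+        the loaded value.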
+ """ var = self.args.input(name) load_counts = self._load_counts load_counts[name] += 1 @@ -2873,7 +2880,7 @@ def final_reduction( buffer, value: CSEVariable, result_type: Optional[torch.dtype], - ) -> Tuple[str, torch.dtype, BlockShapeType]: + ) -> tuple[str, Optional[torch.dtype], BlockShapeType]: """ Helper to generate a reduction call, e.g. tl.sum. """ @@ -3007,7 +3014,9 @@ def _mask_value(value, default) -> CSEVariable: _result, _dtype, _shape = final_reduction( self.compute, masked_value, masked_value.dtype ) - result_var = self.cse.generate(self.compute, _result, dtype=_dtype, shape=_shape) + result_var = self.cse.generate( + self.compute, _result, dtype=_dtype, shape=_shape + ) else: accumulator = self.cse.namedvar( f"_{result_var}",