
[inductor] propagate shapes in CSEVariable #152198

Open · wants to merge 18 commits into base: gh/isuruf/141/base
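For orientation, here is a minimal standalone sketch (plain Python, not the real inductor classes) of the idea this PR implements: a CSE variable now carries an optional block shape next to its dtype, and the CSE cache's newvar()/generate() thread that shape through so cache hits preserve it. The class names and the simplified BlockShapeType alias below are illustrative assumptions; the real BlockShapeType comes from torch/_inductor/shape_propagation and presumably also admits symbolic sizes.

from typing import Optional, Tuple, Union
import itertools

# Assumed simplification; the real alias likely also allows sympy expressions.
BlockShapeType = Optional[Tuple[Union[int, str], ...]]


class CSEVariableSketch:
    """Stand-in for CSEVariable: a named value with an optional dtype and block shape."""

    def __init__(self, name: str, dtype=None, shape: BlockShapeType = None):
        self.name = name
        self.dtype = dtype
        self.shape = shape  # None = unknown shape, () = scalar


class CSESketch:
    """Stand-in for the CSE cache: newvar()/generate() now accept a shape."""

    def __init__(self):
        self._ids = itertools.count()
        self._cache = {}

    def newvar(self, dtype=None, shape: BlockShapeType = None) -> CSEVariableSketch:
        return CSEVariableSketch(f"tmp{next(self._ids)}", dtype, shape)

    def generate(self, expr: str, dtype=None, shape: BlockShapeType = None) -> CSEVariableSketch:
        var = self._cache.get(expr)
        if var is None:
            var = self.newvar(dtype, shape)
            self._cache[expr] = var
        return var


cse = CSESketch()
a = cse.generate("tl.load(ptr + x)", dtype="float32", shape=("XBLOCK",))
b = cse.generate("tl.load(ptr + x)")  # cache hit: the shape recorded at creation is preserved
assert a is b and b.shape == ("XBLOCK",)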
50 changes: 43 additions & 7 deletions torch/_inductor/codegen/common.py
@@ -44,6 +44,7 @@
from .. import config, metrics
from ..dtype_propagation import DtypePropagationOpsHandler
from ..ops_handler import BasicMathOpsMixin, DefaultHandler
from ..shape_propagation import ShapePropagationOpsHandler
from ..utils import (
boolean_ops,
DeferredLineBase,
@@ -70,6 +71,7 @@
from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
from ..loop_body import LoopBody
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
from ..shape_propagation import BlockShapeType
from .wrapper import PythonWrapperCodegen

_T = TypeVar("_T")
@@ -1770,13 +1772,15 @@ def __init__(
name: str,
bounds: ValueRanges[Any],
dtype: Optional[torch.dtype] = None,
shape: BlockShapeType = None,
):
super().__init__()
assert isinstance(bounds, ValueRanges), type(bounds)
self.name = name
self.bounds = bounds
self.use_count = 1 # track how many times this expression is used
self.dtype = dtype
self.shape = shape

def __str__(self) -> str:
return self.name
@@ -1886,6 +1890,7 @@ def generate(
write: bool = True,
assignment: bool = True,
dtype: Optional[torch.dtype] = None,
shape: BlockShapeType = None,
) -> CSEVariableType:
if isinstance(expr, OpsValue):
expr = expr.value
@@ -1906,8 +1911,12 @@
assert isinstance(expr, str)
cache_key = expr
var = self.try_get(cache_key)
if shape is None and not assignment:
# since there's no assignment to a variable, use any shape here
# other than None to avoid the unknown shape failures
shape = ()

Contributor: Would it make sense to make the shape the assigned variable's shape?

Collaborator (Author): We could if expr was a CSEVariable, but not in other cases.
if not var:
var = self.newvar(bounds, dtype)
var = self.newvar(bounds, dtype, shape)
self.put(cache_key, var)
if write:
if V.kernel.current_node:
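Regarding the review exchange above: when generate() is called with assignment=False, the expression is emitted in place rather than bound to a named temporary, so no concrete block shape is needed; the patch substitutes the scalar shape () instead of leaving shape=None, which downstream code treats as "unknown shape". A hedged plain-Python restatement of that fallback (resolve_shape is a hypothetical helper, not part of the PR):

def resolve_shape(shape, assignment: bool):
    # Mirror of the branch above: only the no-assignment case gets a placeholder.
    if shape is None and not assignment:
        return ()  # scalar placeholder; nothing is assigned, so the shape is never read
    return shape


assert resolve_shape(None, assignment=False) == ()
assert resolve_shape(None, assignment=True) is None  # stays "unknown"
assert resolve_shape(("XBLOCK",), assignment=False) == ("XBLOCK",)  # an explicit shape wins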
@@ -1953,9 +1962,10 @@ def newvar(
self,
bounds: ValueRanges[Any] = ValueRanges.unknown(),
dtype: Optional[torch.dtype] = None,
shape: BlockShapeType = None,
) -> CSEVariableType:
var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
var = V.kernel.create_cse_var(var_name, bounds, dtype)
var = V.kernel.create_cse_var(var_name, bounds, dtype, shape)
self.varname_map[var_name] = var
return var

@@ -1964,11 +1974,12 @@ def namedvar(
name: str,
bounds: ValueRanges[Any] = ValueRanges.unknown(),
dtype: Optional[torch.dtype] = None,
shape: BlockShapeType = None,
) -> CSEVariableType:
torch._check_value(
name not in self.varname_map, lambda: f"duplicate name: {name}"
)
var = V.kernel.create_cse_var(name, bounds, dtype)
var = V.kernel.create_cse_var(name, bounds, dtype, shape)
self.varname_map[name] = var
return var

@@ -2424,45 +2435,64 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) ->

value = getattr(self.parent_handler, name)(*args, **kwargs)
dtype_handler = DtypePropagationOpsHandler()
shape_handler = ShapePropagationOpsHandler()

backend = get_current_backend()

shape_op = getattr(shape_handler, name)
output_dtype = None
output_shape = None

if name == "masked" and backend == "triton":
output_dtype = value.dtype
output_shape = value.shape
elif name == "masked" and backend == "cpp":
output_dtype = V.interpreter.current_node.meta.get(
OptimizationContext.key, None
).dtype
# TODO: fix me
output_shape = None
elif backend in ("triton", "cpp", "mps"):
dtype_op = getattr(dtype_handler, name)
output_dtype = dtype_op(*args, **kwargs)
output_shape = shape_op(*args, **kwargs)

if backend in ("triton", "cpp"):
# maybe there are some exceptions on mps?
assert output_dtype is not None

output_idx = 0

def do_cse(v: str) -> CSEVariable:
def do_cse(v: Union[str, CSEVariable]) -> CSEVariable:
# we tree_map over the output, so we need to fetch corresponding dtype
nonlocal output_idx
var_dtype: Optional[torch.dtype] = (
output_dtype[output_idx]
if isinstance(output_dtype, (list, tuple))
else output_dtype
)
var_shape: BlockShapeType = (
output_shape[output_idx] # type: ignore[assignment]
if isinstance(output_shape, (list, tuple))
and len(output_shape) > 0
and isinstance(output_shape[0], (list, tuple))
else output_shape
)
output_idx += 1

# some cpp op implementations don't set the dtype
if backend == "cpp" and isinstance(v, CSEVariable) and v.dtype is None:
v.dtype = var_dtype
if isinstance(v, CSEVariable):
if backend == "cpp" and v.dtype is None:
v.dtype = var_dtype
if v.shape is None:
v.shape = var_shape

csevar = V.kernel.cse.generate(
V.kernel.compute,
v,
bounds=bounds,
dtype=output_dtype,
shape=output_shape,
)

csevar.update_on_args(name, args, kwargs)
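The var_shape selection inside do_cse above handles multi-output ops: for ops such as frexp the shape handler returns one shape per output, and since a single shape is itself a tuple, the code distinguishes "tuple of shapes" from "one shape" by checking whether the first element is itself a list or tuple. A hedged standalone restatement of that check (select_shape is a hypothetical helper):

def select_shape(output_shape, output_idx):
    # A tuple whose first element is itself a tuple/list is a per-output collection.
    if (
        isinstance(output_shape, (list, tuple))
        and len(output_shape) > 0
        and isinstance(output_shape[0], (list, tuple))
    ):
        return output_shape[output_idx]
    return output_shape


single = ("XBLOCK",)                 # one output with a 1-D block shape
multi = (("XBLOCK",), ("XBLOCK",))   # e.g. frexp: mantissa and exponent
assert select_shape(single, 0) == ("XBLOCK",)
assert select_shape(multi, 1) == ("XBLOCK",)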
@@ -2559,7 +2589,13 @@ def indirect_indexing(
pos = var.bounds & ValueRanges(0, int_oo)
new_bounds = new_bounds | pos

var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds)
var = self.kernel.cse.generate(
self.kernel.compute,
stm,
bounds=new_bounds,
dtype=var.dtype,
shape=var.shape,
)

sympy_var = self.parent_handler.indirect_indexing(var, size, check)
if generate_assert(check):
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/cpp.py
@@ -933,8 +933,8 @@ def frexp(x):
return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys)

code = BracesBuffer()
exponent = V.kernel.cse.newvar(dtype=torch.int32)
mantissa = V.kernel.cse.newvar(dtype=x.dtype)
exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape)
mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape)
code.writeline(f"int32_t {exponent};")
code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});")
V.kernel.compute.splice(code)
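Passing shape=x.shape to both newvar calls reflects that frexp is elementwise: each input element yields a mantissa and an integer exponent, so both outputs share the input's block shape. A plain-Python illustration of that per-element behaviour (not inductor code):

import math

xs = [0.5, 3.0, 10.0]
mantissas, exponents = zip(*(math.frexp(v) for v in xs))
assert len(mantissas) == len(exponents) == len(xs)  # both outputs are elementwise with x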
4 changes: 3 additions & 1 deletion torch/_inductor/codegen/cpp_utils.py
@@ -22,6 +22,7 @@
from ..dependencies import Dep
from ..loop_body import LoopBody
from ..scheduler import BaseSchedulerNode, SchedulerBuffer
from ..shape_propagation import BlockShapeType
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
from ..virtualized import ops, OpsValue, V
from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext
@@ -145,8 +146,9 @@ def __init__(
name,
bounds: ValueRanges[Any],
dtype: Optional[torch.dtype] = None,
shape: BlockShapeType = None,
) -> None:
super().__init__(name, bounds, dtype)
super().__init__(name, bounds, dtype, shape=shape)
self.is_vec = False
self.dependent_itervars = OrderedSet[sympy.Symbol]()

29 changes: 20 additions & 9 deletions torch/_inductor/codegen/halide.py
@@ -54,6 +54,7 @@
from collections.abc import Sequence

from ..ops_handler import ReductionType, StoreMode
from ..shape_propagation import BlockShapeType

log = logging.getLogger(__name__)

@@ -556,6 +557,7 @@ def masked(mask, body, other):
f"hl.cast({result.name}.type(), {halide_constant(other)})",
[],
bounds=ValueRanges.wrap(other),
shape=result.shape,
)
# TODO(jansel): look into removing the where in the same places triton does
return ops.where(new_mask, result, other)
@@ -576,8 +578,9 @@
name,
bounds: ValueRanges[Any],
dtype: Optional[torch.dtype] = None,
shape: BlockShapeType = None,
) -> None:
super().__init__(name, bounds, dtype)
super().__init__(name, bounds, dtype, shape=shape)
self.used_dims: Optional[list[sympy.Symbol]] = None

def update_on_args(self, name, args, kwargs):
@@ -702,9 +705,9 @@ def __init__(
def dtype_to_str(self, dtype: torch.dtype) -> str:
return halide_type(dtype)

def create_cse_var(self, name, bounds=None, dtype=None):
def create_cse_var(self, name, bounds=None, dtype=None, shape=None):
self.body.writeline(f"{name} = hl.Func({name!r})")
return HalideCSEVariable(name, bounds, dtype)
return HalideCSEVariable(name, bounds, dtype, shape)

def finalize_indexing(self, indices: Sequence[sympy.Expr]):
"""
@@ -1196,12 +1199,13 @@ def reduction(
assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
reduction_vars = OrderedSet(self.reduction_renames)
result_var = self.newfunc(
[v for v in value.used_dims if v not in reduction_vars]
[v for v in value.used_dims if v not in reduction_vars],
)
if reduction_vars - OrderedSet(value.used_dims):
value = self.genfunc(
f"{value}",
self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))),
shape=value.shape,
)
value_str = value.subs_str(self.reduction_renames)
default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
@@ -1291,7 +1295,9 @@ def scan(
else:
values.append(
self.genfunc(
f"{value}", [*value.used_dims, [*self.reduction_renames][:1]]
f"{value}",
[*value.used_dims, [*self.reduction_renames][:1]],
shape=value.shape,
)
)
all_used_dims.update(value.used_dims)
@@ -1355,15 +1361,20 @@ def maybe_tuple(x):
return tuple(unpack_vars)

def genfunc(
self, line, used_dims, *, bounds=ValueRanges.unknown()
self,
line,
used_dims,
*,
bounds=ValueRanges.unknown(),
shape: BlockShapeType = None,
) -> HalideCSEVariable:
var = self.cse.generate(self.body, line, bounds=bounds)
var = self.cse.generate(self.body, line, bounds=bounds, shape=shape)
assert isinstance(var, HalideCSEVariable)
var.used_dims = used_dims
return var

def newfunc(self, used_dims) -> HalideCSEVariable:
var = self.cse.newvar()
def newfunc(self, used_dims, *, shape: BlockShapeType = None) -> HalideCSEVariable:
var = self.cse.newvar(shape=shape)
assert isinstance(var, HalideCSEVariable)
var.used_dims = used_dims
return var