 from .. import config, metrics
 from ..dtype_propagation import DtypePropagationOpsHandler
 from ..ops_handler import BasicMathOpsMixin, DefaultHandler
+from ..shape_propagation import ShapePropagationOpsHandler
 from ..utils import (
     boolean_ops,
     DeferredLineBase,

 from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
 from ..loop_body import LoopBody
 from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
+from ..shape_propagation import BlockShapeType
 from .wrapper import PythonWrapperCodegen
 
 _T = TypeVar("_T")
@@ -1770,13 +1772,15 @@ def __init__(
         name: str,
         bounds: ValueRanges[Any],
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ):
         super().__init__()
         assert isinstance(bounds, ValueRanges), type(bounds)
         self.name = name
         self.bounds = bounds
         self.use_count = 1  # track how many times this expression is used
         self.dtype = dtype
+        self.shape = shape
 
     def __str__(self) -> str:
         return self.name
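
Aside (not part of the diff): assuming BlockShapeType behaves like an optional tuple of per-dimension block sizes — an inference from how it is used in the later hunks, not the actual definition in ..shape_propagation — the new field distinguishes "shape unknown" (None) from "scalar" (the empty tuple). A tiny standalone illustration with a hypothetical stand-in type:

from typing import Optional, Union

# Hypothetical stand-in for BlockShapeType, inferred from usage in this diff:
#   None              -> shape not known / not propagated yet
#   ()                -> a scalar value (zero block dimensions)
#   ("YBLOCK", "XBLOCK") -> a value laid out over symbolic block dimensions
ToyBlockShape = Optional[tuple[Union[int, str], ...]]

unknown: ToyBlockShape = None
scalar: ToyBlockShape = ()
blocked: ToyBlockShape = ("YBLOCK", "XBLOCK")

# Only None means "we could not infer a shape"; () is a perfectly known shape.
assert scalar is not None and len(scalar) == 0
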
@@ -1886,6 +1890,7 @@ def generate(
         write: bool = True,
         assignment: bool = True,
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ) -> CSEVariableType:
         if isinstance(expr, OpsValue):
             expr = expr.value
@@ -1906,8 +1911,12 @@ def generate(
         assert isinstance(expr, str)
         cache_key = expr
         var = self.try_get(cache_key)
+        if shape is None and not assignment:
+            # since there's no assignment to a variable, use any shape here
+            # other than None to avoid the unknown shape failures
+            shape = ()
         if not var:
-            var = self.newvar(bounds, dtype)
+            var = self.newvar(bounds, dtype, shape)
             self.put(cache_key, var)
             if write:
                 if V.kernel.current_node:
@@ -1953,9 +1962,10 @@ def newvar(
         self,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ) -> CSEVariableType:
         var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
-        var = V.kernel.create_cse_var(var_name, bounds, dtype)
+        var = V.kernel.create_cse_var(var_name, bounds, dtype, shape)
         self.varname_map[var_name] = var
         return var
 
@@ -1964,11 +1974,12 @@ def namedvar(
         name: str,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
     ) -> CSEVariableType:
         torch._check_value(
             name not in self.varname_map, lambda: f"duplicate name: {name}"
         )
-        var = V.kernel.create_cse_var(name, bounds, dtype)
+        var = V.kernel.create_cse_var(name, bounds, dtype, shape)
         self.varname_map[name] = var
         return var
 
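
Aside (not part of the diff): a deliberately simplified toy of this plumbing — ToyVar and ToyCSE are hypothetical stand-ins, not the real CSEVariable/CSE classes or the create_cse_var API — showing how shape now rides along with dtype from generate() into variable creation and survives cache hits, including the scalar fallback from the generate() hunk above:

from dataclasses import dataclass
from typing import Optional, Union

BlockShape = Optional[tuple[Union[int, str], ...]]

@dataclass
class ToyVar:
    name: str
    dtype: Optional[str] = None
    shape: BlockShape = None  # None means "shape not yet known"

class ToyCSE:
    def __init__(self) -> None:
        self.cache: dict[str, ToyVar] = {}
        self.counter = 0

    def newvar(self, dtype: Optional[str] = None, shape: BlockShape = None) -> ToyVar:
        name = f"tmp{self.counter}"
        self.counter += 1
        # the shape is forwarded to the new variable exactly like dtype
        return ToyVar(name, dtype=dtype, shape=shape)

    def generate(
        self,
        expr: str,
        dtype: Optional[str] = None,
        shape: BlockShape = None,
        assignment: bool = True,
    ) -> ToyVar:
        if shape is None and not assignment:
            # mirrors the hunk above: nothing is assigned, so fall back to a
            # known scalar shape () rather than leaving the shape unknown
            shape = ()
        var = self.cache.get(expr)
        if var is None:
            var = self.newvar(dtype=dtype, shape=shape)
            self.cache[expr] = var
        return var

cse = ToyCSE()
a = cse.generate("x + y", dtype="float32", shape=("XBLOCK",))
b = cse.generate("x + y")  # CSE hit: the cached variable keeps its shape
assert a is b and a.shape == ("XBLOCK",)
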
@@ -2424,45 +2435,64 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) ->
 
         value = getattr(self.parent_handler, name)(*args, **kwargs)
         dtype_handler = DtypePropagationOpsHandler()
+        shape_handler = ShapePropagationOpsHandler()
 
         backend = get_current_backend()
 
+        shape_op = getattr(shape_handler, name)
         output_dtype = None
+        output_shape = None
+
         if name == "masked" and backend == "triton":
             output_dtype = value.dtype
+            output_shape = value.shape
         elif name == "masked" and backend == "cpp":
             output_dtype = V.interpreter.current_node.meta.get(
                 OptimizationContext.key, None
             ).dtype
+            # TODO: fix me
+            output_shape = None
         elif backend in ("triton", "cpp", "mps"):
             dtype_op = getattr(dtype_handler, name)
             output_dtype = dtype_op(*args, **kwargs)
+            output_shape = shape_op(*args, **kwargs)
 
         if backend in ("triton", "cpp"):
             # maybe there are some exceptions on mps?
             assert output_dtype is not None
 
         output_idx = 0
 
-        def do_cse(v: str) -> CSEVariable:
+        def do_cse(v: Union[str, CSEVariable]) -> CSEVariable:
             # we tree_map over the output, so we need to fetch corresponding dtype
             nonlocal output_idx
             var_dtype: Optional[torch.dtype] = (
                 output_dtype[output_idx]
                 if isinstance(output_dtype, (list, tuple))
                 else output_dtype
             )
+            var_shape: BlockShapeType = (
+                output_shape[output_idx]  # type: ignore[assignment]
+                if isinstance(output_shape, (list, tuple))
+                and len(output_shape) > 0
+                and isinstance(output_shape[0], (list, tuple))
+                else output_shape
+            )
             output_idx += 1
 
             # some cpp op implementations don't set the dtype
-            if backend == "cpp" and isinstance(v, CSEVariable) and v.dtype is None:
-                v.dtype = var_dtype
+            if isinstance(v, CSEVariable):
+                if backend == "cpp" and v.dtype is None:
+                    v.dtype = var_dtype
+                if v.shape is None:
+                    v.shape = var_shape
 
             csevar = V.kernel.cse.generate(
                 V.kernel.compute,
                 v,
                 bounds=bounds,
                 dtype=output_dtype,
+                shape=output_shape,
             )
 
             csevar.update_on_args(name, args, kwargs)
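
Aside (not part of the diff): for elementwise ops, a shape handler of this kind would essentially apply right-aligned broadcasting over the operand block shapes, with any unknown operand making the result unknown. The sketch below illustrates that rule with hypothetical symbolic dimension names; it is not the actual ShapePropagationOpsHandler implementation:

from typing import Optional, Union

# Hypothetical block-shape stand-in: None is "unknown", otherwise a tuple of
# per-dimension sizes (concrete ints or symbolic names like "XBLOCK").
ToyBlockShape = Optional[tuple[Union[int, str], ...]]

def broadcast_block_shapes(*shapes: ToyBlockShape) -> ToyBlockShape:
    # Any unknown operand shape makes the result unknown.
    if any(s is None for s in shapes):
        return None
    ndim = max((len(s) for s in shapes), default=0)
    out = []
    for i in range(ndim):
        dim: Union[int, str] = 1
        for s in shapes:
            # Right-align the shapes, treating missing leading dims as size 1.
            j = len(s) - ndim + i
            d = s[j] if j >= 0 else 1
            if d != 1:
                if dim not in (1, d):
                    raise ValueError(f"incompatible dims: {dim} vs {d}")
                dim = d
        out.append(dim)
    return tuple(out)

# A scalar shape () broadcasts against a blocked operand; None stays unknown.
assert broadcast_block_shapes(("XBLOCK",), ()) == ("XBLOCK",)
assert broadcast_block_shapes(("YBLOCK", "XBLOCK"), (1, "XBLOCK")) == ("YBLOCK", "XBLOCK")
assert broadcast_block_shapes(("XBLOCK",), None) is None
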
@@ -2559,7 +2589,13 @@ def indirect_indexing(
                     pos = var.bounds & ValueRanges(0, int_oo)
                     new_bounds = new_bounds | pos
 
-            var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds)
+            var = self.kernel.cse.generate(
+                self.kernel.compute,
+                stm,
+                bounds=new_bounds,
+                dtype=var.dtype,
+                shape=var.shape,
+            )
 
         sympy_var = self.parent_handler.indirect_indexing(var, size, check)
         if generate_assert(check):