27
27
28
28
import torch
29
29
import torch ._logging
30
+ from torch .fx .immutable_collections import immutable_dict
30
31
from torch .utils ._ordered_set import OrderedSet
31
32
from torch .utils ._sympy .functions import FloorDiv , Identity , ModularIndexing
32
33
from torch .utils ._sympy .symbol import (
@@ -339,7 +340,7 @@ class SIMDKernel(Kernel):
339
340
340
341
def __init__ (
341
342
self ,
342
- * groups ,
343
+ tiling : Dict [ str , sympy . Expr ] ,
343
344
features : SIMDKernelFeatures ,
344
345
pid_cache = None ,
345
346
override_persistent_reduction = None ,
@@ -352,11 +353,13 @@ def __init__(
352
353
self .mutations = features .get_mutations ()
353
354
self .body = IndentedBuffer ()
354
355
self .indexing_code = IndentedBuffer ()
355
- self .numels = [V .graph .sizevars .simplify (s ) for s in groups ]
356
+ self .numels = {
357
+ prefix : V .graph .sizevars .simplify (val ) for prefix , val in tiling .items ()
358
+ }
356
359
self .range_trees : List [IterationRangesRoot ] = []
357
360
self .range_tree_nodes : Dict [sympy .Symbol , IterationRangesEntry ] = {}
358
361
self .iter_vars_count = itertools .count ()
359
- self .inside_reduction = self .numels [- 1 ] != 1
362
+ self .inside_reduction = self .numels ["r" ] != 1
360
363
self .cooperative_reduction : bool = (
361
364
override_cooperative_reduction
362
365
if override_cooperative_reduction is not None
@@ -393,7 +396,7 @@ def want_no_x_dim(self):
393
396
return False
394
397
395
398
def initialize_range_tree (self , pid_cache ):
396
- no_r_dim = not self .inside_reduction or self .numels [- 1 ] == 1
399
+ no_r_dim = not self .inside_reduction or self .numels ["r" ] == 1
397
400
398
401
prefixes = "zyxr"
399
402
active_prefixes = prefixes [- len (self .numels ) :]
@@ -416,7 +419,7 @@ def initialize_range_tree(self, pid_cache):
416
419
self .range_trees .append (
417
420
IterationRangesRoot (
418
421
f"{ prefix } index" ,
419
- self .numels [i ],
422
+ self .numels [prefix ],
420
423
prefix ,
421
424
index ,
422
425
self ,
@@ -525,7 +528,7 @@ def disable_reduction(self):
525
528
526
529
@contextlib .contextmanager
527
530
def ctx ():
528
- if self .numels [- 1 ] == 1 :
531
+ if self .numels ["r" ] == 1 :
529
532
assert not self .inside_reduction
530
533
yield
531
534
return
@@ -688,7 +691,7 @@ def is_broadcasted(self, index: sympy.Expr):
688
691
simplify = V .graph .sizevars .simplify
689
692
return any (
690
693
simplify (idx_range ) != simplify (iter_range ) # type: ignore[arg-type]
691
- for idx_range , iter_range in zip (index_numels , self .numels )
694
+ for idx_range , iter_range in zip (index_numels , self .numels . values () )
692
695
)
693
696
694
697
def index_to_str (self , index : sympy .Expr ) -> str :
@@ -855,7 +858,7 @@ def estimate_kernel_num_bytes(self):
855
858
# for the "cat". However, I think it might be a bit overwhelming that
856
859
# we add such complexity only for handling some particular cases for
857
860
# benchmarking.
858
- out_numel = V .graph .sizevars .size_hint (sympy_product (self .numels ))
861
+ out_numel = V .graph .sizevars .size_hint (sympy_product (self .numels . values () ))
859
862
for i , arg in enumerate (call_args ):
860
863
# "buf" may be narrowed. In this case, the number of memory accesses
861
864
# should be estimated based on the reinterpreted layout.
@@ -960,7 +963,7 @@ def warn_mix_layout(self, kernel_name):
960
963
def welford_reduce_fallback (self , dtype , value ):
961
964
sum_ = ops .reduction (dtype , dtype , "sum" , value )
962
965
self .inside_reduction = False
963
- rnumel = ops .index_expr (self .numels [- 1 ], dtype )
966
+ rnumel = ops .index_expr (self .numels ["r" ], dtype )
964
967
mean = ops .truediv (sum_ , rnumel )
965
968
966
969
self .inside_reduction = True
@@ -1081,8 +1084,8 @@ def can_fuse(self, node1, node2):
1081
1084
config .triton .tiling_prevents_reduction_fusion
1082
1085
and not node1 .is_template ()
1083
1086
):
1084
- is_reduction_tiling_valid = self . select_tiling (
1085
- node1 .get_nodes (), numel1
1087
+ is_reduction_tiling_valid = tuple (
1088
+ self . select_tiling ( node1 .get_nodes (), numel1 ). values ()
1086
1089
) in (
1087
1090
(numel1 , 1 ),
1088
1091
(numel2 , rnumel2 , 1 ),
@@ -1246,11 +1249,11 @@ def can_use_32bit_indexing(
1246
1249
1247
1250
def codegen_node_schedule (self , kernel_features : SIMDKernelFeatures ):
1248
1251
node_schedule = kernel_features .node_schedule
1249
- tiled_groups = self .select_tiling (
1252
+ tiling = self .select_tiling (
1250
1253
node_schedule , kernel_features .numel , kernel_features .reduction_numel
1251
1254
)
1252
1255
kernels = self .create_kernel_choices (
1253
- kernel_features , tiled_groups , {"features" : kernel_features }
1256
+ kernel_features , [ tiling ] , {"features" : kernel_features }
1254
1257
)
1255
1258
for kernel in kernels :
1256
1259
self .codegen_node_schedule_with_kernel (node_schedule , kernel )
@@ -1426,10 +1429,10 @@ def generate_combo_kernel_code(
1426
1429
for pn , nodes in zip (subkernel_nodes , fused_node_lists ):
1427
1430
_ , (numel , rnumel ) = max (nodes , key = lambda x : int (x .is_reduction ())).group
1428
1431
node_schedule = self .generate_node_schedule (nodes , numel , rnumel )
1429
- tiled_groups = self .select_tiling (node_schedule , numel , rnumel )
1430
- node_schedule_map [pn ] = node_schedule , tiled_groups , numel , rnumel
1432
+ tiling = self .select_tiling (node_schedule , numel , rnumel )
1433
+ node_schedule_map [pn ] = node_schedule , tiling , numel , rnumel
1431
1434
subkernel_map [pn ] = ComboKernel .create_triton_kernel (
1432
- * tiled_groups ,
1435
+ tiling ,
1433
1436
features = SIMDKernelFeatures (node_schedule , numel , rnumel ),
1434
1437
optimize_mask = not mixed_sizes ,
1435
1438
)
@@ -1562,7 +1565,23 @@ def candidate_tilings(node):
1562
1565
return tilings
1563
1566
1564
1567
@classmethod
1565
- def select_tiling (cls , node_schedule , numel , reduction_numel = sympy .S .One ):
1568
+ def create_tiling (
1569
+ cls , pw_tiling : Sequence [sympy .Expr ], reduction_tiling : Sequence [sympy .Expr ]
1570
+ ) -> Dict [str , sympy .Expr ]:
1571
+ """
1572
+ Create a tiling dict from pointwise and reduction splits.
1573
+ """
1574
+ pw_prefixes = ["z" , "y" , "x" ][- len (pw_tiling ) :]
1575
+ reduction_prefixes = ["r" ][: len (reduction_tiling )]
1576
+ return immutable_dict (
1577
+ list (zip (pw_prefixes , pw_tiling ))
1578
+ + list (zip (reduction_prefixes , reduction_tiling ))
1579
+ )
1580
+
1581
+ @classmethod
1582
+ def select_tiling (
1583
+ cls , node_schedule , numel , reduction_numel = sympy .S .One
1584
+ ) -> Dict [str , sympy .Expr ]:
1566
1585
"""
1567
1586
Heuristics to decide how to tile kernels.
1568
1587
Currently, we tile based on stride-1 dimensions.
@@ -1571,6 +1590,7 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One):
1571
1590
`(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`
1572
1591
1573
1592
"""
1593
+ default_tiling = cls .create_tiling ([numel ], [reduction_numel ])
1574
1594
if reduction_numel != 1 or config .triton .max_tiles <= 1 :
1575
1595
# TODO(jansel): should we tile reductions?
1576
1596
# do perf hint here if stride-1 dim is not being reduced
@@ -1579,7 +1599,7 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One):
1579
1599
if len (cls .candidate_tilings (node )) > 0 :
1580
1600
perf_hint_log .info ("reduction over non-contiguous dims" )
1581
1601
break
1582
- return ( numel , reduction_numel )
1602
+ return default_tiling
1583
1603
1584
1604
seen_names : OrderedSet [str ] = OrderedSet ()
1585
1605
candidate_tiles : Counter [Any ] = collections .Counter ()
@@ -1647,9 +1667,9 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One):
1647
1667
for node in node_schedule
1648
1668
if isinstance (node , scheduler .SchedulerNode )
1649
1669
):
1650
- return new_groups
1670
+ return cls . create_tiling ( tiled_groups , [ reduction_numel ])
1651
1671
1652
- return ( numel , reduction_numel )
1672
+ return default_tiling
1653
1673
1654
1674
def flush (self ):
1655
1675
pass
@@ -1661,9 +1681,9 @@ def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
1661
1681
if not nodes [0 ].is_template ():
1662
1682
_ , (numel , rnumel ) = max (nodes , key = lambda x : int (x .is_reduction ())).group
1663
1683
node_schedule = self .generate_node_schedule (nodes , numel , rnumel )
1664
- tiled_groups = self .select_tiling (node_schedule , numel , rnumel )
1684
+ tiling = self .select_tiling (node_schedule , numel , rnumel )
1665
1685
kernel = self .kernel_type (
1666
- * tiled_groups ,
1686
+ tiling ,
1667
1687
features = SIMDKernelFeatures (node_schedule , numel , rnumel ),
1668
1688
)
1669
1689
self .codegen_node_schedule_with_kernel (node_schedule , kernel )
0 commit comments