Skip to content

Commit 5deca07

Browse files
blaine-rister authored
and pytorchmergebot committed
[Inductor] Represent tiling as a dict (pytorch#141751)
# Summary

Preparatory refactor for pytorch#137243. This makes it easier to generalize to multi-dimensional reductions. This diff refactors `self.numels` from a tuple like `(8, 16)` to a dict like `{"x": 8, "r": 16}`.

Note: this is based off of pytorch#141738, which enables `tree.is_reduction`. That PR should land first.

# Test plan

The existing CI provides good coverage.

Pull Request resolved: pytorch#141751
Approved by: https://github.com/jansel
1 parent 96be048 commit 5deca07

File tree

6 files changed

+83
-44
lines changed

6 files changed

+83
-44
lines changed

torch/_inductor/codegen/halide.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,10 +673,10 @@ class HalideKernel(SIMDKernel):
673673

674674
def __init__(
675675
self,
676-
*groups,
676+
tiling: Dict[str, sympy.Expr],
677677
**kwargs,
678678
) -> None:
679-
super().__init__(*groups, **kwargs)
679+
super().__init__(tiling, **kwargs)
680680
# For halide, we just write directly to the body
681681
self.compute = self.body
682682
self.loads = self.body

torch/_inductor/codegen/simd.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import torch
2929
import torch._logging
30+
from torch.fx.immutable_collections import immutable_dict
3031
from torch.utils._ordered_set import OrderedSet
3132
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
3233
from torch.utils._sympy.symbol import (
@@ -339,7 +340,7 @@ class SIMDKernel(Kernel):
339340

340341
def __init__(
341342
self,
342-
*groups,
343+
tiling: Dict[str, sympy.Expr],
343344
features: SIMDKernelFeatures,
344345
pid_cache=None,
345346
override_persistent_reduction=None,
@@ -352,11 +353,13 @@ def __init__(
352353
self.mutations = features.get_mutations()
353354
self.body = IndentedBuffer()
354355
self.indexing_code = IndentedBuffer()
355-
self.numels = [V.graph.sizevars.simplify(s) for s in groups]
356+
self.numels = {
357+
prefix: V.graph.sizevars.simplify(val) for prefix, val in tiling.items()
358+
}
356359
self.range_trees: List[IterationRangesRoot] = []
357360
self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {}
358361
self.iter_vars_count = itertools.count()
359-
self.inside_reduction = self.numels[-1] != 1
362+
self.inside_reduction = self.numels["r"] != 1
360363
self.cooperative_reduction: bool = (
361364
override_cooperative_reduction
362365
if override_cooperative_reduction is not None
@@ -393,7 +396,7 @@ def want_no_x_dim(self):
393396
return False
394397

395398
def initialize_range_tree(self, pid_cache):
396-
no_r_dim = not self.inside_reduction or self.numels[-1] == 1
399+
no_r_dim = not self.inside_reduction or self.numels["r"] == 1
397400

398401
prefixes = "zyxr"
399402
active_prefixes = prefixes[-len(self.numels) :]
@@ -416,7 +419,7 @@ def initialize_range_tree(self, pid_cache):
416419
self.range_trees.append(
417420
IterationRangesRoot(
418421
f"{prefix}index",
419-
self.numels[i],
422+
self.numels[prefix],
420423
prefix,
421424
index,
422425
self,
@@ -525,7 +528,7 @@ def disable_reduction(self):
525528

526529
@contextlib.contextmanager
527530
def ctx():
528-
if self.numels[-1] == 1:
531+
if self.numels["r"] == 1:
529532
assert not self.inside_reduction
530533
yield
531534
return
@@ -688,7 +691,7 @@ def is_broadcasted(self, index: sympy.Expr):
688691
simplify = V.graph.sizevars.simplify
689692
return any(
690693
simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type]
691-
for idx_range, iter_range in zip(index_numels, self.numels)
694+
for idx_range, iter_range in zip(index_numels, self.numels.values())
692695
)
693696

694697
def index_to_str(self, index: sympy.Expr) -> str:
@@ -855,7 +858,7 @@ def estimate_kernel_num_bytes(self):
855858
# for the "cat". However, I think it might be a bit overwhelming that
856859
# we add such complexity only for handling some particular cases for
857860
# benchmarking.
858-
out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels))
861+
out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels.values()))
859862
for i, arg in enumerate(call_args):
860863
# "buf" may be narrowed. In this case, the number of memory accesses
861864
# should be estimated based on the reinterpreted layout.
@@ -960,7 +963,7 @@ def warn_mix_layout(self, kernel_name):
960963
def welford_reduce_fallback(self, dtype, value):
961964
sum_ = ops.reduction(dtype, dtype, "sum", value)
962965
self.inside_reduction = False
963-
rnumel = ops.index_expr(self.numels[-1], dtype)
966+
rnumel = ops.index_expr(self.numels["r"], dtype)
964967
mean = ops.truediv(sum_, rnumel)
965968

966969
self.inside_reduction = True
@@ -1081,8 +1084,8 @@ def can_fuse(self, node1, node2):
10811084
config.triton.tiling_prevents_reduction_fusion
10821085
and not node1.is_template()
10831086
):
1084-
is_reduction_tiling_valid = self.select_tiling(
1085-
node1.get_nodes(), numel1
1087+
is_reduction_tiling_valid = tuple(
1088+
self.select_tiling(node1.get_nodes(), numel1).values()
10861089
) in (
10871090
(numel1, 1),
10881091
(numel2, rnumel2, 1),
@@ -1246,11 +1249,11 @@ def can_use_32bit_indexing(
12461249

12471250
def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
12481251
node_schedule = kernel_features.node_schedule
1249-
tiled_groups = self.select_tiling(
1252+
tiling = self.select_tiling(
12501253
node_schedule, kernel_features.numel, kernel_features.reduction_numel
12511254
)
12521255
kernels = self.create_kernel_choices(
1253-
kernel_features, tiled_groups, {"features": kernel_features}
1256+
kernel_features, [tiling], {"features": kernel_features}
12541257
)
12551258
for kernel in kernels:
12561259
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
@@ -1426,10 +1429,10 @@ def generate_combo_kernel_code(
14261429
for pn, nodes in zip(subkernel_nodes, fused_node_lists):
14271430
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
14281431
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
1429-
tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
1430-
node_schedule_map[pn] = node_schedule, tiled_groups, numel, rnumel
1432+
tiling = self.select_tiling(node_schedule, numel, rnumel)
1433+
node_schedule_map[pn] = node_schedule, tiling, numel, rnumel
14311434
subkernel_map[pn] = ComboKernel.create_triton_kernel(
1432-
*tiled_groups,
1435+
tiling,
14331436
features=SIMDKernelFeatures(node_schedule, numel, rnumel),
14341437
optimize_mask=not mixed_sizes,
14351438
)
@@ -1562,7 +1565,23 @@ def candidate_tilings(node):
15621565
return tilings
15631566

15641567
@classmethod
1565-
def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One):
1568+
def create_tiling(
1569+
cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr]
1570+
) -> Dict[str, sympy.Expr]:
1571+
"""
1572+
Create a tiling dict from pointwise and reduction splits.
1573+
"""
1574+
pw_prefixes = ["z", "y", "x"][-len(pw_tiling) :]
1575+
reduction_prefixes = ["r"][: len(reduction_tiling)]
1576+
return immutable_dict(
1577+
list(zip(pw_prefixes, pw_tiling))
1578+
+ list(zip(reduction_prefixes, reduction_tiling))
1579+
)
1580+
1581+
@classmethod
1582+
def select_tiling(
1583+
cls, node_schedule, numel, reduction_numel=sympy.S.One
1584+
) -> Dict[str, sympy.Expr]:
15661585
"""
15671586
Heuristics to decide how to tile kernels.
15681587
Currently, we tile based on stride-1 dimensions.
@@ -1571,6 +1590,7 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One):
15711590
`(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`
15721591
15731592
"""
1593+
default_tiling = cls.create_tiling([numel], [reduction_numel])
15741594
if reduction_numel != 1 or config.triton.max_tiles <= 1:
15751595
# TODO(jansel): should we tile reductions?
15761596
# do perf hint here if stride-1 dim is not being reduced
@@ -1579,7 +1599,7 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One):
15791599
if len(cls.candidate_tilings(node)) > 0:
15801600
perf_hint_log.info("reduction over non-contiguous dims")
15811601
break
1582-
return (numel, reduction_numel)
1602+
return default_tiling
15831603

15841604
seen_names: OrderedSet[str] = OrderedSet()
15851605
candidate_tiles: Counter[Any] = collections.Counter()
@@ -1647,9 +1667,9 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One):
16471667
for node in node_schedule
16481668
if isinstance(node, scheduler.SchedulerNode)
16491669
):
1650-
return new_groups
1670+
return cls.create_tiling(tiled_groups, [reduction_numel])
16511671

1652-
return (numel, reduction_numel)
1672+
return default_tiling
16531673

16541674
def flush(self):
16551675
pass
@@ -1661,9 +1681,9 @@ def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
16611681
if not nodes[0].is_template():
16621682
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
16631683
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
1664-
tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
1684+
tiling = self.select_tiling(node_schedule, numel, rnumel)
16651685
kernel = self.kernel_type(
1666-
*tiled_groups,
1686+
tiling,
16671687
features=SIMDKernelFeatures(node_schedule, numel, rnumel),
16681688
)
16691689
self.codegen_node_schedule_with_kernel(node_schedule, kernel)

torch/_inductor/codegen/triton.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def remove_dims(it):
340340
if (
341341
not V.kernel.inside_reduction
342342
and len(params.strides) == len(V.kernel.numels) - 1
343-
and V.kernel.numels[-1] != 1
343+
and V.kernel.numels["r"] != 1
344344
):
345345
# Need to expand rank by 1 to match rank when self.inside_reduction=True
346346
final_shape.append(sympy.S.One)
@@ -1419,15 +1419,15 @@ class TritonKernel(SIMDKernel):
14191419

14201420
def __init__(
14211421
self,
1422-
*groups,
1422+
tiling: Dict[str, sympy.Expr],
14231423
min_elem_per_thread=0,
14241424
optimize_mask=True,
14251425
fixed_config: Optional[FixedTritonConfig] = None,
14261426
**kwargs,
14271427
) -> None:
14281428
self.optimize_mask: bool = optimize_mask
14291429
self.fixed_config = fixed_config
1430-
super().__init__(*groups, **kwargs)
1430+
super().__init__(tiling, **kwargs)
14311431
self.cse = TritonCSE(self.newvar_prefix, self.suffix)
14321432
self.post_loop_combine: IndentedBuffer = IndentedBuffer()
14331433
self.post_loop_store: IndentedBuffer = IndentedBuffer()
@@ -1463,7 +1463,7 @@ def init_cooperative_reduction(self):
14631463
if tree.grid_dim is not None:
14641464
tree.grid_dim += 1
14651465

1466-
sem_count, _ = self.numels
1466+
sem_count = self.numels["x"]
14671467
if self.fixed_config:
14681468
sem_count = CeilDiv(sem_count, self.fixed_config["XBLOCK"])
14691469
self.semaphores_name = self.args.semaphores(sem_count)
@@ -2440,7 +2440,7 @@ def codegen_cooperative_reduction_peer_combine(self, result_var, dtype):
24402440
column. After the barrier, every thread block loads the completed value so that it can compute the final
24412441
value independently.
24422442
"""
2443-
xnumel, rnumel = self.numels
2443+
xnumel = self.numels["x"]
24442444
mask = "xindex < xnumel" if xnumel != 1 and not self.no_x_dim else None
24452445
expand = "" if self.no_x_dim else "[None,:]"
24462446

@@ -2946,7 +2946,7 @@ def codegen_kernel(self, name=None):
29462946
code = IndentedBuffer()
29472947

29482948
size_hints = []
2949-
for numel in self.numels:
2949+
for numel in self.numels.values():
29502950
numel_hint = V.graph.sizevars.symbolic_hint(numel)
29512951
if not isinstance(numel_hint, (int, sympy.Integer)):
29522952
# This default heuristic hint was picked carefully: it is
@@ -3384,7 +3384,7 @@ def iteration_ranges_codegen_header(self, entry, code):
33843384

33853385

33863386
class TritonScheduling(SIMDScheduling):
3387-
kernel_type = TritonKernel
3387+
kernel_type: Type[Any] = TritonKernel
33883388
backend_features = dict.fromkeys( # dict for deterministic order
33893389
[
33903390
BackendFeature.FOREACH,
@@ -3642,7 +3642,7 @@ def add_multi_kernel_choices(
36423642
)
36433643
)
36443644
if optional_cooperative:
3645-
_, rnumel = kernel.numels
3645+
rnumel = kernel.numels["r"]
36463646
# for larger sizes non-cooperative gets very slow
36473647
if V.graph.sizevars.statically_known_leq(rnumel, 65536):
36483648
kernels.append(

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Union,
1717
)
1818

19+
import sympy
1920
from sympy import Integer, Symbol
2021

2122
from .. import config, metrics
@@ -102,7 +103,7 @@ def _default_custom_combo_kernel_horizontal_partition(
102103
for n in not_reduction
103104
if not kernel_map[n].inside_reduction
104105
and len(kernel_map[n].numels) == 2
105-
and V.graph.sizevars.size_hint(kernel_map[n].numels[0]) > LARGE_NUMELS
106+
and V.graph.sizevars.size_hint(kernel_map[n].numels["x"]) > LARGE_NUMELS
106107
]
107108
if large_pointwise:
108109
# TODO benchmark the performance when large pointwise nodes combining with others
@@ -216,7 +217,7 @@ def _base_horizontal_partition(
216217
ndim = len(tiled_groups)
217218
assert ndim >= 2, f"Combokernel not support tile {tiled_groups}"
218219
if not mixed_sizes and ndim == 3:
219-
y_elem = tiled_groups[0]
220+
y_elem = tiled_groups["y"]
220221
partition_state = yelem_to_partition_state[y_elem]
221222
ComboKernel._update_partition(
222223
partition_state, read_write_count, node_info
@@ -463,7 +464,7 @@ def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel:
463464

464465
@staticmethod
465466
def create_triton_kernel(
466-
*groups: Any,
467+
tiling: Dict[str, sympy.Expr],
467468
features: SIMDKernelFeatures,
468469
optimize_mask: bool,
469470
) -> TritonKernel:
@@ -472,7 +473,7 @@ def create_triton_kernel(
472473
2) numels except x dimension are the same for each sub kernel.
473474
"""
474475
return TritonKernel(
475-
*groups,
476+
tiling,
476477
features=features,
477478
pid_cache={"tl.program_id(0)": "pid_offset"},
478479
optimize_mask=optimize_mask,
@@ -564,7 +565,7 @@ def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None:
564565
def select_heuristics(self, sub_kernel: TritonKernel) -> Tuple[str, List[int]]:
565566
size_hints = [
566567
next_power_of_2(V.graph.sizevars.size_hint(numel))
567-
for numel in sub_kernel.numels
568+
for numel in sub_kernel.numels.values()
568569
]
569570
if sub_kernel.persistent_reduction:
570571
assert sub_kernel.inside_reduction

torch/_inductor/codegen/triton_split_scan.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
# mypy: allow-untyped-defs
22
import functools
3+
from typing import Dict
4+
5+
import sympy
36

47
from torch._inductor import config
58
from torch._inductor.codegen.simd import IterationRangesRoot
69
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
710
from torch._inductor.runtime.triton_heuristics import split_scan_grid
8-
from torch._prims_common import prod
911
from torch.utils._sympy.functions import CeilDiv
1012

13+
from ..utils import sympy_product
14+
from .simd import prefix_is_reduction
15+
1116

1217
class TritonSplitScanKernel(TritonKernel):
1318
"""Generates a triton kernel that supports ops.scan calls while also splitting
@@ -27,15 +32,15 @@ class TritonSplitScanKernel(TritonKernel):
2732

2833
def __init__(
2934
self,
30-
*groups,
35+
tiling: Dict[str, sympy.Expr],
3136
pid_cache=None,
3237
fixed_config=None,
3338
**kwargs,
3439
) -> None:
3540
assert pid_cache is None, "not supported"
3641
assert fixed_config is None, "not supported"
3742
super().__init__(
38-
*groups,
43+
tiling,
3944
**kwargs,
4045
)
4146
self.no_x_dim = True
@@ -54,7 +59,8 @@ def initialize_range_tree(self, pid_cache):
5459
active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
5560

5661
grid_dims = "rxy"
57-
for numel, prefix in zip(self.numels, active_prefixes):
62+
for prefix in active_prefixes:
63+
numel = self.numels[prefix]
5864
is_reduction = prefix == "r"
5965
tensor_dim = 0 if is_reduction else None
6066
grid_dim = grid_dims.find(prefix)
@@ -99,7 +105,17 @@ def scan(self, dtypes, combine_fn, values):
99105

100106
assert len(self.numels) == 2, "Unexpected tiling"
101107
min_rblock = config.triton.min_split_scan_rblock
102-
max_blocks = prod(self.numels[:-1]) * CeilDiv(self.numels[-1], min_rblock)
108+
reduction_numel = sympy_product(
109+
numel
110+
for prefix, numel in self.numels.items()
111+
if prefix_is_reduction(prefix)
112+
)
113+
pointwise_numel = sympy_product(
114+
numel
115+
for prefix, numel in self.numels.items()
116+
if not prefix_is_reduction(prefix)
117+
)
118+
max_blocks = pointwise_numel * CeilDiv(reduction_numel, min_rblock)
103119
nbytes = scratch_nbytes_per_block * max_blocks
104120
scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True)
105121
if offset != 0:

0 commit comments

Comments
 (0)