Skip to content

Commit b2fe1b9

Browse files
jansel authored and pytorchmergebot committed
[inductor] Fix 3d tiling (pytorch#141709)
Fixes pytorch#141121 Pull Request resolved: pytorch#141709 Approved by: https://github.com/eellison
1 parent 90f19fe commit b2fe1b9

File tree

8 files changed

+50
-17
lines changed

8 files changed

+50
-17
lines changed

test/inductor/test_cuda_repro.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,33 @@ def foo(x, y):
14951495
device_stats2["active.all.peak"] <= device_stats["active.all.peak"]
14961496
)
14971497

1498+
@config.patch(
1499+
{
1500+
"triton.prefer_nd_tiling": True,
1501+
"triton.max_tiles": 3,
1502+
}
1503+
)
1504+
def test_3d_tiling(self):
1505+
full_size, view_size, num_block_pointers, num_tiles = (
1506+
(5, 5, 5, 5, 5),
1507+
(3, 3, 5, 3, 5),
1508+
1,
1509+
2,
1510+
)
1511+
GPU_TYPE = "cuda"
1512+
1513+
def get_input() -> torch.Tensor:
1514+
device = torch.device(GPU_TYPE)
1515+
full = torch.randn(full_size).to(device)
1516+
return torch.as_strided(full, view_size, full.stride())
1517+
1518+
a, b = get_input(), get_input()
1519+
1520+
opt_fn = torch.compile(functools.partial(torch.add))
1521+
result, (code,) = run_and_get_code(opt_fn, a, b)
1522+
self.assertEqual(result, a + b)
1523+
self.assertIn("znumel", code)
1524+
14981525
def test_repeated_masked_load(self):
14991526
target_size = (8, 2)
15001527
mem_eff_temporal_upsampling_interp_chunks = 2

test/inductor/test_debug_trace.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def fn(a, b):
8080
arg0_1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
8181
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
8282
class op0_loop_body:
83-
var_ranges = {z0: 256}
84-
index0 = z0
83+
var_ranges = {p0: 256}
84+
index0 = p0
8585
def body(self, ops):
8686
get_index = self.get_index('index0')
8787
load = ops.load('arg0_1', get_index)
@@ -107,8 +107,8 @@ def body(self, ops):
107107
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
108108
buf1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
109109
class op1_loop_body:
110-
var_ranges = {z0: 256}
111-
index0 = z0
110+
var_ranges = {p0: 256}
111+
index0 = p0
112112
def body(self, ops):
113113
get_index = self.get_index('index0')
114114
load = ops.load('buf0', get_index)
@@ -161,8 +161,8 @@ def body(self, ops):
161161
arg0_1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
162162
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
163163
class op0_loop_body:
164-
var_ranges = {z0: 256}
165-
index0 = z0
164+
var_ranges = {p0: 256}
165+
index0 = p0
166166
def body(self, ops):
167167
get_index = self.get_index('index0')
168168
load = ops.load('arg0_1', get_index)
@@ -187,8 +187,8 @@ def body(self, ops):
187187
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
188188
buf1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
189189
class op1_loop_body:
190-
var_ranges = {z0: 256}
191-
index0 = z0
190+
var_ranges = {p0: 256}
191+
index0 = p0
192192
def body(self, ops):
193193
get_index = self.get_index('index0')
194194
load = ops.load('buf0', get_index)

test/inductor/test_loop_ordering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,10 @@ def test_reorder_twice(self):
117117
snode = SchedulerNode(V.graph.scheduler, buf)
118118
snode.apply_new_loop_order([1, 0])
119119
prefix1 = self._get_snode_body_sym_prefix(snode)
120-
self.assertTrue(prefix1 == "z")
120+
self.assertTrue(prefix1 == "p")
121121
snode.apply_new_loop_order([1, 0])
122122
prefix2 = self._get_snode_body_sym_prefix(snode)
123-
self.assertTrue(prefix2 == "z")
123+
self.assertTrue(prefix2 == "p")
124124

125125
def test_reorder_and_merge_loops(self):
126126
sizes = (1024, 2048)
@@ -163,7 +163,7 @@ def inner_fn(index):
163163
_, body = buf.simplify_and_reorder()
164164
new_body = body.reorder_iter_loops([1, 2, 3, 0])
165165

166-
z0, z1, z2, z3 = (sympy_index_symbol(f"z{i}") for i in range(4))
166+
z0, z1, z2, z3 = (sympy_index_symbol(f"p{i}") for i in range(4))
167167
self.assertEqual(body.var_ranges, {z0: 128, z1: 4, z2: 49, z3: 49})
168168
self.assertEqual(
169169
body.indexing_exprs["index0"],

torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,14 @@ class TritonSymbols:
155155

156156
block_offsets = {
157157
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
158-
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
158+
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
159159
}
160160

161161
block_sizes = {
162162
symt: sympy.Symbol(
163163
f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True
164164
)
165-
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
165+
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
166166
}
167167

168168
@classmethod
@@ -1564,7 +1564,7 @@ def indexing(
15641564
else:
15651565
# var is one of xN, yN or rN
15661566
assert symbol_is_type(
1567-
var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK)
1567+
var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK)
15681568
), var.name
15691569
mask_vars.add(f"{var.name[0]}mask")
15701570

torch/_inductor/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,10 @@ class triton:
963963
dense_indexing = False
964964

965965
# limit tiling dimensions
966+
# - max_tiles=1 disables tiling
967+
# - max_tiles=2 is the default
968+
# - max_tiles=3 is experimental and may have bugs
969+
# higher values are unsupported
966970
max_tiles = 2
967971

968972
# Prefer higher dimensional tilings. This simplifies indexing expressions, making

torch/_inductor/ir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4155,7 +4155,7 @@ def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): # type:
41554155
(iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
41564156
iter_ranges,
41574157
reduce_ranges,
4158-
prefix="z",
4158+
prefix="p",
41594159
)
41604160
body = LoopBody(
41614161
body,

torch/_inductor/loop_body.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def merge_loops(self) -> LoopBody:
215215
# use the original symbol prefix
216216
# Can try to optimize if this is a bottleneck for compilation time
217217
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
218-
iter_sizes, reduce_sizes, prefix="z"
218+
iter_sizes, reduce_sizes, prefix="p"
219219
)
220220
new_body2 = LoopBody(
221221
new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
@@ -259,7 +259,7 @@ def new_body(*indices: Sequence[sympy.Expr]) -> Any:
259259

260260
# use the original symbol prefix so we can do multiple round of reordering
261261
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
262-
*new_sizes, prefix="z" # type: ignore[arg-type]
262+
*new_sizes, prefix="p" # type: ignore[arg-type]
263263
)
264264
new_body = LoopBody(
265265
loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2

torch/utils/_sympy/symbol.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class SymT(Enum):
4747
# Inductor: iteration domain for blockIdx.x/blockIdx.y
4848
XBLOCK = auto()
4949
YBLOCK = auto()
50+
ZBLOCK = auto()
5051
# Inductor: this is used solely for dynamic_reshape_indexer
5152
VIEW = auto()
5253
# Alternate (non-modular) indexing used in halide kernels
@@ -70,6 +71,7 @@ class SymT(Enum):
7071
SymT.TEMPLATE_INDEX: "idx",
7172
SymT.XBLOCK: "x",
7273
SymT.YBLOCK: "y",
74+
SymT.ZBLOCK: "z",
7375
SymT.INDIRECT: "indirect", # false aliasing?
7476
SymT.VIEW: "view",
7577
SymT.HALIDE: "h",

0 commit comments

Comments
 (0)