Skip to content

Commit f272e0a

Browse files
ColinPepplerpytorchmergebot
authored andcommitted
[inductor] support unbacked symint divisors in vars_and_sizes (pytorch#130595)
Scenario: ``` >>> nodes IterationRangesEntry( x2, divisor=192*u0 + 192576, length=s1, (xindex//(192*u0 + 192576)), {x0: 192, x1: u0 + 1003, x2: s1, x3: 192*s1*u0 + 192576*s1, x4: 192*u0 + 192576}) IterationRangesEntry( x1, divisor=192, length=u0 + 1003, ModularIndexing(xindex, 192, u0 + 1003), {x0: 192, x1: u0 + 1003, x2: s1, x3: 192*s1*u0 + 192576*s1, x4: 192*u0 + 192576}) IterationRangesEntry( x0, divisor=1, length=192, ModularIndexing(xindex, 1, 192), {x0: 192, x1: u0 + 1003, x2: s1, x3: 192*s1*u0 + 192576*s1, x4: 192*u0 + 192576}) ``` Think about whether using fallback is safe here. I think it's safe because the divisor of one IterationRangesEntry should be the product of the lengths of the preceding IterationRangesEntry? Unless, one of the lengths divides by an unbacked symint? Pull Request resolved: pytorch#130595 Approved by: https://github.com/aakhundov, https://github.com/ezyang
1 parent 2b43d33 commit f272e0a

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

test/inductor/test_unbacked_symints.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,23 @@ def fn(x, w, repeats, is_bmm):
252252
expected = fn(*example_inputs)
253253
torch.testing.assert_close(actual, expected)
254254

255+
@torch._dynamo.config.patch(capture_scalar_outputs=True)
256+
def test_unbacked_range_tree_divisor(self, device):
257+
def fn(x, num):
258+
u0 = num.item()
259+
torch._check_is_size(u0)
260+
zeros = torch.zeros(u0, device=device, dtype=torch.int)
261+
return (torch.ops.aten.index(x, [None, zeros]),)
262+
263+
example_inputs = (
264+
torch.randn(16, 16, device=device),
265+
torch.tensor(3, device=device),
266+
)
267+
268+
actual = torch.compile(fn, fullgraph=True)(*example_inputs)
269+
expected = fn(*example_inputs)
270+
torch.testing.assert_close(actual, expected)
271+
255272

256273
instantiate_device_type_tests(
257274
TestUnbackedSymints, globals(), only_for=(GPU_TYPE, "cpu")

torch/_inductor/codegen/simd.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,11 @@ def vars_and_sizes(self, index: sympy.Expr):
200200
"""Figure out vars from this tree used in index"""
201201
nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
202202
nodes = [n for n in nodes if n and n.prefix == self.prefix]
203-
nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor))
203+
nodes.sort(
204+
key=lambda x: V.graph.sizevars.size_hint(
205+
x.divisor, fallback=config.unbacked_symint_fallback
206+
)
207+
)
204208
divisor = sympy.Integer(1)
205209
index_vars = []
206210
sizes = []

0 commit comments

Comments
 (0)