Skip to content

Commit 012c13b

Browse files
jataylo authored and dnikolaev-amd committed
Scale XBLOCK in triton reduction configs to avoid hitting max grid (#1434)
1 parent 0c2f97c commit 012c13b

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,14 +1324,31 @@ def triton_config_reduction(size_hints, x, r, num_stages=1, num_warps=None) -> C
13241324
while r < size_hints[1] and conditional_product(x, r) < target:
13251325
r *= 2
13261326

1327-
cfg = {"XBLOCK": x, "RBLOCK": r}
13281327
if num_warps is None:
13291328
num_warps = conditional_product(x, r) // 128
13301329
# On AMD GPU each warp has 64 lanes which is double the size on NV GPU,
1331-
# therefore using half the number of warps here correspondingly.
1330+
# therefore using half the number of warps here correspondingly.
13321331
default_num_warps = 4 if torch.version.hip else 8
13331332
min_num_warps = 1 if torch.version.hip else 2
13341333
num_warps = next_power_of_2(min(max(num_warps, min_num_warps), default_num_warps))
1334+
1335+
# Check if maxGridSize is exceeded - if so then must scale XBLOCK further
1336+
max_grid_x = 4294967295 if torch.version.hip else 2147483647
1337+
warp_size = 64 if torch.version.hip else 32
1338+
num_blocks = int((size_hints[0] + x - 1) // x)
1339+
while(num_blocks * num_warps * warp_size) > max_grid_x:
1340+
if (x >= TRITON_MAX_BLOCK["X"]):
1341+
if num_warps == 1:
1342+
break # If no more scaling possible then break
1343+
num_warps = int(num_warps / 2) # If max XBLOCK then scale down warps as last resort
1344+
x *= 2 # Scale up XBLOCK if grid exceeds limits
1345+
num_blocks = int(num_blocks / 2)
1346+
while conditional_product(x, r) > target:
1347+
r = int(r / 2)
1348+
if r == 1:
1349+
break
1350+
1351+
cfg = {"XBLOCK": x, "RBLOCK": r}
13351352
check_config(cfg, xnumel=size_hints[0])
13361353
assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['r'] to {r}"
13371354
return Config(cfg, num_warps=num_warps, num_stages=num_stages)

0 commit comments

Comments
 (0)