@@ -1324,14 +1324,31 @@ def triton_config_reduction(size_hints, x, r, num_stages=1, num_warps=None) -> C
1324
1324
while r < size_hints [1 ] and conditional_product (x , r ) < target :
1325
1325
r *= 2
1326
1326
1327
- cfg = {"XBLOCK" : x , "RBLOCK" : r }
1328
1327
if num_warps is None :
1329
1328
num_warps = conditional_product (x , r ) // 128
1330
1329
# On AMD GPU each warp has 64 lanes which is double the size on NV GPU,
1331
- # therefore using half the number of warps here correspondingly.
1330
+ # therefore using half the number of warps here correspondingly.
1332
1331
default_num_warps = 4 if torch .version .hip else 8
1333
1332
min_num_warps = 1 if torch .version .hip else 2
1334
1333
num_warps = next_power_of_2 (min (max (num_warps , min_num_warps ), default_num_warps ))
1334
+
1335
+ # Check if maxGridSize is exceeded - if so then must scale XBLOCK further
1336
+ max_grid_x = 4294967295 if torch .version .hip else 2147483647
1337
+ warp_size = 64 if torch .version .hip else 32
1338
+ num_blocks = int ((size_hints [0 ] + x - 1 ) // x )
1339
+ while (num_blocks * num_warps * warp_size ) > max_grid_x :
1340
+ if (x >= TRITON_MAX_BLOCK ["X" ]):
1341
+ if num_warps == 1 :
1342
+ break # If no more scaling possible then break
1343
+ num_warps = int (num_warps / 2 ) # If max XBLOCK then scale down warps as last resort
1344
+ x *= 2 # Scale up XBLOCK if grid exceeds limits
1345
+ num_blocks = int (num_blocks / 2 )
1346
+ while conditional_product (x , r ) > target :
1347
+ r = int (r / 2 )
1348
+ if r == 1 :
1349
+ break
1350
+
1351
+ cfg = {"XBLOCK" : x , "RBLOCK" : r }
1335
1352
check_config (cfg , xnumel = size_hints [0 ])
1336
1353
assert r <= TRITON_MAX_BLOCK ["R" ], f"increase TRITON_MAX_BLOCK['r'] to { r } "
1337
1354
return Config (cfg , num_warps = num_warps , num_stages = num_stages )
0 commit comments