Commit 893ef7f

[mlir][GPU] Fixes subgroup reduce lowering (#141825)
Fixes the final reduction steps, which were taken from an implementation of scan rather than reduction, causing lanes earlier in the wave to hold incorrect results due to masking. The lowering now aligns more closely with the Triton implementation: triton-lang/triton#5019

# Hypothetical example

To illustrate the issue with the current implementation, take the simple example of performing a sum over 64 lanes where the initial values are as follows (the first lane has value 1 and all other lanes have value 0):

```
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```

When performing a sum reduction over these 64 lanes, the current implementation performs 6 dpp instructions which, in sequential order, do the following:

1) sum over clusters of 2 contiguous lanes
2) sum over clusters of 4 contiguous lanes
3) sum over clusters of 8 contiguous lanes
4) sum over an entire row (16 lanes)
5) broadcast the result of the last lane in each row to the next row; each lane sums its current value with the incoming value
6) broadcast the result of the 32nd lane to the last two rows; each lane sums its current value with the incoming value

After step 4 the result for the example above looks like this:

```
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```

After step 5 the result looks like this:

```
[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```

After step 6 the result looks like this:

```
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
```

Note that the correct value here is always 1, yet after the `dpp.broadcast` ops some lanes hold incorrect values. The reason is that for these incorrect lanes, such as lanes 0-15 in step 5, the `dpp.broadcast` op does not provide them incoming values from other lanes. Instead, these lanes receive either their own values or 0 (depending on whether `bound_ctrl` is true or false) as the values to sum over. Either way these values are stale, so these lanes should not be used in general.

What this means:

- For a subgroup reduce over 32 lanes (as in step 5), the correct result is stored in lanes 16 to 31.
- For a subgroup reduce over 64 lanes (as in step 6), the correct result is stored in lanes 32 to 63.

However, the current implementation does not specifically read the final value from one of these correct lanes. In some workloads it appears that, without doing so, the stale value from the first lane is returned instead.
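The sketch below is not part of the commit; it is a host-side replay of the six steps above on the example input, assuming a 64-lane wave and assuming that a lane with no valid DPP source simply re-adds its own stale value, as in the worked arrays above. All names are illustrative.

```cpp
#include <array>
#include <cstdio>

int main() {
  constexpr int kWave = 64;
  std::array<int, kWave> v{};
  v[0] = 1; // example input: lane 0 holds 1, every other lane holds 0

  // Steps 1-4: cluster sums of width 2, 4, 8, and 16; afterwards every lane
  // holds the sum of its 16-lane row.
  for (int stride = 1; stride <= 8; stride *= 2) {
    std::array<int, kWave> prev = v;
    for (int lane = 0; lane < kWave; ++lane)
      v[lane] = prev[lane] + prev[lane ^ stride];
  }

  // Step 5: the last lane of each row is broadcast to the next row; row 0
  // receives no incoming value and re-adds its own stale value.
  std::array<int, kWave> prev = v;
  for (int lane = 0; lane < kWave; ++lane) {
    int row = lane / 16;
    v[lane] = prev[lane] + (row >= 1 ? prev[row * 16 - 1] : prev[lane]);
  }

  // Step 6: lane 31 is broadcast to rows 2 and 3; lanes 0-31 receive no
  // incoming value and re-add their own stale values.
  prev = v;
  for (int lane = 0; lane < kWave; ++lane)
    v[lane] = prev[lane] + (lane >= 32 ? prev[31] : prev[lane]);

  // The correct sum (1) survives only in lanes 32-63; lanes 0-15 hold 4 and
  // lanes 16-31 hold 2, which is why the fix reads the result from lane 63
  // (or lane 31 for a 32-lane subgroup).
  for (int lane = 0; lane < kWave; ++lane)
    std::printf("%d%c", v[lane], lane + 1 == kWave ? '\n' : ' ');
  return 0;
}
```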
# Actual failing test

For a specific example of how the current implementation causes issues, consider the IR below, which performs an additive reduction over a dynamic dimension:

```mlir
!matA = tensor<1x?xf16>
!matB = tensor<1xf16>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>

func.func @only_producer_fusion_multiple_result(%arg0: !matA) -> !matB {
  %cst_1 = arith.constant 0.000000e+00 : f16
  %c2_i64 = arith.constant 2 : i64
  %0 = tensor.empty() : !matB
  %2 = linalg.fill ins(%cst_1 : f16) outs(%0 : !matB) -> !matB
  %4 = linalg.generic {indexing_maps = [#map, #map1],
                       iterator_types = ["parallel", "reduction"]}
         ins(%arg0 : !matA) outs(%2 : !matB) {
  ^bb0(%in: f16, %out: f16):
    %7 = arith.addf %in, %out : f16
    linalg.yield %7 : f16
  } -> !matB
  return %4 : !matB
}
```

When this reduction is given an input of type `tensor<1x2xf16>` with values `{0, 1}`, the returned value is consistently 4. By the same analysis as above, the returned value is coming from one of the stale lanes and instead needs to be read from one of the lanes storing the correct result.
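As a sanity check (not part of the commit), the same stale-lane arithmetic reproduces the observed value of 4, assuming the two reduced values land in lanes 0 and 1 of a 64-lane wave and every other lane holds the additive identity 0:

```cpp
#include <cstdio>

int main() {
  // After the row-level steps, every lane of row 0 (including lane 0) holds
  // the row sum 0 + 1 = 1.
  int lane0 = 0 + 1;
  // Step 5: row 0 receives no broadcast value, so lane 0 re-adds its own
  // stale value: 1 + 1 = 2.
  lane0 += lane0;
  // Step 6: lanes 0-31 receive no broadcast value, so lane 0 re-adds its own
  // stale value again: 2 + 2 = 4.
  lane0 += lane0;
  // Without an explicit readlane from a correct lane, this stale value is
  // what the workload observes; the fix reads lane 63 (wave64) or lane 31
  // (wave32), which hold the true sum of 1.
  std::printf("stale lane 0 value = %d, expected reduction result = 1\n", lane0);
  return 0;
}
```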
Signed-off-by: Muzammiluddin Syed <[email protected]>

2 files changed: +24 −18 lines changed

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 23 additions & 17 deletions
```diff
@@ -432,44 +432,50 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                                                   /*bound_ctrl=*/false);
       res = vector::makeArithReduction(
           rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
-      if (ci.subgroupSize == 32) {
-        Value lane0 = rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
-        res =
-            rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
-      }
     } else {
       return rewriter.notifyMatchFailure(
           op, "Subgroup reduce lowering to DPP not currently supported for "
               "this device.");
     }
+    if (ci.subgroupSize == 32) {
+      Value lane31 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
+    }
   }
   if (ci.clusterSize >= 64) {
     if (chipset.majorVersion <= 9) {
       // Broadcast 31st lane value to rows 2 and 3.
-      // Use row mask to avoid polluting rows 0 and 1.
       dpp = rewriter.create<amdgpu::DPPOp>(
           loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
-          rewriter.getUnitAttr(), 0xc, allBanks,
-          /*bound_ctrl*/ false);
+          rewriter.getUnitAttr(), 0xf, allBanks,
+          /*bound_ctrl*/ true);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), dpp, res);
+      // Obtain reduction from last rows, the previous rows are polluted.
+      Value lane63 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
 
     } else if (chipset.majorVersion <= 12) {
       // Assume reduction across 32 lanes has been done.
       // Perform final reduction manually by summing values in lane 0 and
       // lane 32.
-      Value lane0 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
-      Value lane32 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
-      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
-      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+      Value lane31 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
+      Value lane63 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
+      lane31 =
+          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
+      lane63 =
+          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
     } else {
       return rewriter.notifyMatchFailure(
           op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
     }
-    res = vector::makeArithReduction(rewriter, loc,
-                                     gpu::convertReductionKind(mode), res, dpp);
   }
   assert(res.getType() == input.getType());
   return res;
```

mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir

Lines changed: 1 addition & 1 deletion
```diff
@@ -349,7 +349,7 @@ gpu.module @kernels {
     // CHECK-GFX10: %[[A4:.+]] = arith.addi %[[A3]], %[[P0]] : i16
     // CHECK-GFX10: %[[R0:.+]] = rocdl.readlane %[[A4]], %{{.+}} : (i16, i32) -> i16
     // CHECK-GFX10: %[[R1:.+]] = rocdl.readlane %[[A4]], %{{.+}} : (i16, i32) -> i16
-    // CHECK-GFX10: %[[A5:.+]] = arith.addi %[[R1]], %[[R0]] : i16
+    // CHECK-GFX10: %[[A5:.+]] = arith.addi %[[R0]], %[[R1]] : i16
     // CHECK-GFX10: "test.consume"(%[[A5]]) : (i16) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 : (i16) -> i16
     "test.consume"(%sum0) : (i16) -> ()
```
