Skip to content

Commit 75e849e

Browse files
[mlir][GPU] Fixes subgroup reduce lowering
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent b035580 commit 75e849e

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -432,44 +432,50 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
432432
/*bound_ctrl=*/false);
433433
res = vector::makeArithReduction(
434434
rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
435-
if (ci.subgroupSize == 32) {
436-
Value lane0 = rewriter.create<arith::ConstantOp>(
437-
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
438-
res =
439-
rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
440-
}
441435
} else {
442436
return rewriter.notifyMatchFailure(
443437
op, "Subgroup reduce lowering to DPP not currently supported for "
444438
"this device.");
445439
}
440+
if (ci.subgroupSize == 32) {
441+
Value lane31 = rewriter.create<arith::ConstantOp>(
442+
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
443+
res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
444+
}
446445
}
447446
if (ci.clusterSize >= 64) {
448447
if (chipset.majorVersion <= 9) {
449448
// Broadcast 31st lane value to rows 2 and 3.
450-
// Use row mask to avoid polluting rows 0 and 1.
451449
dpp = rewriter.create<amdgpu::DPPOp>(
452450
loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
453-
rewriter.getUnitAttr(), 0xc, allBanks,
454-
/*bound_ctrl*/ false);
451+
rewriter.getUnitAttr(), 0xf, allBanks,
452+
/*bound_ctrl*/ true);
453+
res = vector::makeArithReduction(
454+
rewriter, loc, gpu::convertReductionKind(mode), dpp, res);
455+
// Obtain the reduction from the last rows; the previous rows are polluted.
456+
Value lane63 = rewriter.create<arith::ConstantOp>(
457+
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
458+
res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
455459

456460
} else if (chipset.majorVersion <= 12) {
457461
// Assume reduction across 32 lanes has been done.
458462
// Perform final reduction manually by combining values in lane 31 and
459463
// lane 63.
460-
Value lane0 = rewriter.create<arith::ConstantOp>(
461-
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
462-
Value lane32 = rewriter.create<arith::ConstantOp>(
463-
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
464-
dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
465-
res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
464+
Value lane31 = rewriter.create<arith::ConstantOp>(
465+
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
466+
Value lane63 = rewriter.create<arith::ConstantOp>(
467+
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
468+
lane31 =
469+
rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
470+
lane63 =
471+
rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
472+
res = vector::makeArithReduction(
473+
rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
466474
} else {
467475
return rewriter.notifyMatchFailure(
468476
op, "Subgroup reduce lowering to DPP not currently supported for "
469477
"this device.");
470478
}
471-
res = vector::makeArithReduction(rewriter, loc,
472-
gpu::convertReductionKind(mode), res, dpp);
473479
}
474480
assert(res.getType() == input.getType());
475481
return res;

mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ gpu.module @kernels {
349349
// CHECK-GFX10: %[[A4:.+]] = arith.addi %[[A3]], %[[P0]] : i16
350350
// CHECK-GFX10: %[[R0:.+]] = rocdl.readlane %[[A4]], %{{.+}} : (i16, i32) -> i16
351351
// CHECK-GFX10: %[[R1:.+]] = rocdl.readlane %[[A4]], %{{.+}} : (i16, i32) -> i16
352-
// CHECK-GFX10: %[[A5:.+]] = arith.addi %[[R1]], %[[R0]] : i16
352+
// CHECK-GFX10: %[[A5:.+]] = arith.addi %[[R0]], %[[R1]] : i16
353353
// CHECK-GFX10: "test.consume"(%[[A5]]) : (i16) -> ()
354354
%sum0 = gpu.subgroup_reduce add %arg0 : (i16) -> i16
355355
"test.consume"(%sum0) : (i16) -> ()

0 commit comments

Comments
 (0)