
Commit bcc7e43

Updating implementation to support gfx 10+
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent a8c410d commit bcc7e43


mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 90 additions & 58 deletions
@@ -22,6 +22,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
@@ -371,72 +372,103 @@ std::optional<Value> createSubgroupDPPReduction(OpBuilder &b, Location loc,
                                                 gpu::AllReduceOperation mode,
                                                 const ClusterInfo &ci,
                                                 amdgpu::Chipset chipset) {
-  Value dppResult;
   Value result = input;
   constexpr int allRows = 0xf;
   constexpr int allBanks = 0xf;
   const bool boundCtrl = true;
-  Value lane31 =
-      b.create<arith::ConstantOp>(loc, b.getI32Type(), b.getI32IntegerAttr(31));
-  Value lane63 =
-      b.create<arith::ConstantOp>(loc, b.getI32Type(), b.getI32IntegerAttr(63));
-  if (ci.clusterSize >= 2) {
-    auto permArg = b.getI32ArrayAttr({1, 0, 3, 2});
-    dppResult = b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
-                                        amdgpu::DPPPerm::quad_perm, permArg,
-                                        allRows, allBanks, boundCtrl);
-    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
-                                        result, dppResult);
-  }
-
-  if (ci.clusterSize >= 4) {
-    auto permArg = b.getI32ArrayAttr({2, 3, 0, 1});
-    dppResult = b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
-                                        amdgpu::DPPPerm::quad_perm, permArg,
-                                        allRows, allBanks, boundCtrl);
-    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
-                                        result, dppResult);
-  }
-
-  if (ci.clusterSize >= 8) {
-    dppResult = b.create<amdgpu::DPPOp>(
-        loc, result.getType(), result, result, amdgpu::DPPPerm::row_half_mirror,
-        b.getUnitAttr(), allRows, allBanks, boundCtrl);
-    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
-                                        result, dppResult);
-  }
-
-  if (ci.clusterSize >= 16) {
-    dppResult = b.create<amdgpu::DPPOp>(
-        loc, result.getType(), result, result, amdgpu::DPPPerm::row_mirror,
-        b.getUnitAttr(), allRows, allBanks, boundCtrl);
-    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
-                                        result, dppResult);
-  }
-
-  if (ci.clusterSize >= 32) {
-    if (chipset.majorVersion <= 9) {
+  Value lane0 =
+      b.create<arith::ConstantOp>(loc, b.getI32Type(), b.getI32IntegerAttr(0));
+  Value lane32 =
+      b.create<arith::ConstantOp>(loc, b.getI32Type(), b.getI32IntegerAttr(32));
+
+  auto dppReduceAcrossLanes = [&](int numLanes,
+                                  Value res) -> std::optional<Value> {
+    Value dppResult, laneVal;
+
+    switch (numLanes) {
+    case 2:
+      // Perform reduction between all lanes N <-> N+1.
+      dppResult = b.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+          b.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
+      break;
+    case 4:
+      // Perform reduction between all lanes N <-> N+2.
       dppResult = b.create<amdgpu::DPPOp>(
-          loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
-          b.getUnitAttr(), 0xa, allBanks, /*bound_ctrl*/ false);
-    } else {
+          loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+          b.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
+      break;
+    case 8:
+      // Perform reduction between all lanes N <-> 7-N,
+      // e.g lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
+      dppResult = b.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
+          b.getUnitAttr(), allRows, allBanks, boundCtrl);
+      break;
+    case 16:
+      // Perform reduction between all lanes N <-> 15-N,
+      // e.g lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
+      dppResult = b.create<amdgpu::DPPOp>(
+          loc, result.getType(), res, res, amdgpu::DPPPerm::row_mirror,
+          b.getUnitAttr(), allRows, allBanks, boundCtrl);
+      break;
+    case 32:
+      if (chipset.majorVersion <= 9) {
+        // Broadcast last value from each row to next row.
+        // Use row mask to avoid polluting rows 1 and 3.
+        dppResult = b.create<amdgpu::DPPOp>(loc, res.getType(), res, res,
+                                            amdgpu::DPPPerm::row_bcast_15,
+                                            b.getUnitAttr(), 0xa, allBanks,
+                                            /*bound_ctrl*/ false);
+      } else if (chipset.majorVersion <= 12) {
+        // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
+        dppResult = b.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
+                                                   -1, -1, /*fi=*/true,
+                                                   /*bound_ctrl=*/false);
+        if (ci.subgroupSize == 32) {
+          dppResult =
+              b.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+        }
+      } else {
+        return std::nullopt;
+      }
+      break;
+    case 64:
+      if (chipset.majorVersion <= 9) {
+        // Broadcast 31st lane value to rows 2 and 3.
+        // Use row mask to avoid polluting rows 0 and 1.
+        dppResult = b.create<amdgpu::DPPOp>(loc, res.getType(), res, res,
+                                            amdgpu::DPPPerm::row_bcast_31,
+                                            b.getUnitAttr(), 0xc, allBanks,
+                                            /*bound_ctrl*/ false);
+      } else if (chipset.majorVersion <= 12) {
+        // Assume reduction across 32 lanes has been done.
+        // Perform final reduction manually by summing values in lane 0 and
+        // lane 32.
+        dppResult =
+            b.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
+        laneVal = b.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+        return vector::makeArithReduction(
+            b, loc, gpu::convertReductionKind(mode), dppResult, laneVal);
+      } else {
+        return std::nullopt;
+      }
+      break;
+    default:
+      // Should never reach here given previous validation of ClusterInfo.
+      llvm_unreachable("ERROR: Unexpected cluster size.");
       return std::nullopt;
     }
-    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
-                                        result, dppResult);
-    if (ci.subgroupSize == 32) {
-      result =
-          b.create<ROCDL::ReadlaneOp>(loc, input.getType(), result, lane31);
+    return vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                      res, dppResult);
+  };
+
+  for (unsigned cs = 2; cs <= ci.clusterSize; cs = cs << 1) {
+    if (auto dpp = dppReduceAcrossLanes(cs, result)) {
+      result = *dpp;
+      continue;
    }
-  }
-
-  if (ci.clusterSize == 64) {
-    dppResult = b.create<amdgpu::DPPOp>(
-        loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_31,
-        b.getUnitAttr(), 0xc, allBanks, /*bound_ctrl*/ false);
-    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
-                                        result, dppResult);
-    result = b.create<ROCDL::ReadlaneOp>(loc, input.getType(), result, lane63);
+    return std::nullopt;
   }
 
   assert(result.getType() == input.getType());
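
Note on the pattern: the DPP sequence above is a butterfly (tree) reduction over subgroup lanes. Each quad_perm / row_half_mirror / row_mirror step, plus a row-crossing step (row_bcast on gfx9, PermlaneX16 and ReadLane on gfx10-12), pairs every lane with a partner lane at a doubling stride and combines the two values, so after log2(clusterSize) steps each cluster holds the full cluster result. The sketch below is a minimal, self-contained model of that exchange pattern on plain integers. It is illustrative only: the file name, the butterflyReduce helper, and the XOR-based partner selection are assumptions made for the example, not the DPP semantics, which are restricted to the row/bank exchanges listed in the patch.

// butterfly_reduce_model.cpp -- hypothetical standalone example, not part of
// the MLIR patch above. Models the lane-exchange reduction that the DPP
// lowering performs in hardware.
#include <array>
#include <cstdio>
#include <functional>

// Model a subgroup as an array of per-lane values. At each step every lane
// combines its value with the value held by its partner lane (lane XOR
// stride). After log2(clusterSize) steps, every lane of a cluster holds the
// reduction of the whole cluster.
template <size_t N>
std::array<int, N> butterflyReduce(std::array<int, N> lanes,
                                   unsigned clusterSize,
                                   const std::function<int(int, int)> &combine) {
  for (unsigned stride = 1; stride < clusterSize; stride <<= 1) {
    std::array<int, N> partner = lanes;
    for (size_t lane = 0; lane < N; ++lane)
      partner[lane] = lanes[lane ^ stride]; // value exchanged from partner lane
    for (size_t lane = 0; lane < N; ++lane)
      lanes[lane] = combine(lanes[lane], partner[lane]);
  }
  return lanes;
}

int main() {
  std::array<int, 32> wave;
  for (size_t i = 0; i < wave.size(); ++i)
    wave[i] = static_cast<int>(i); // lane i holds value i

  // Reduce clusters of 8 lanes with addition: lanes 0..7 each end up with
  // 0+1+...+7 = 28, lanes 8..15 with 8+...+15 = 92, and so on.
  auto reduced = butterflyReduce(wave, /*clusterSize=*/8,
                                 [](int a, int b) { return a + b; });
  std::printf("lane 0 = %d, lane 8 = %d, lane 31 = %d\n", reduced[0],
              reduced[8], reduced[31]);
  return 0;
}

Running the model prints one sum per 8-lane cluster, which mirrors what the lowered gpu.subgroup_reduce leaves in each lane of a cluster; the patch obtains the same pairings with the cheapest exchange the hardware offers at each stride.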
