
Commit 905f1d8

[mlir][AMDGPU] Implement gpu.subgroup_reduce with DPP intrinsics on AMD GPUs (#133204)
When performing cross-lane reductions using `gpu.subgroup_reduce` ops across contiguous lanes on AMD GPUs, lower to Data Parallel Primitives (DPP) ops when possible. This reduces latency on applicable devices. See related [Issue](iree-org/iree#20007).

To do:
- Improve lowering to subgroup_reduce in compatible matvecs (these get directly lowered to gpu.shuffles in an earlier pass).

Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent fc7fee8 commit 905f1d8

4 files changed, +364 −2 lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 15 additions & 0 deletions
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/IR/PatternMatch.h"
@@ -68,6 +69,20 @@ void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
 
+/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `amdgpu.dpp`
+/// ops over scalar types. Assumes that the subgroup has
+/// `subgroupSize` lanes. Applicable only to AMD GPUs.
+void populateGpuLowerSubgroupReduceToDPPPatterns(RewritePatternSet &patterns,
+                                                 unsigned subgroupSize,
+                                                 amdgpu::Chipset chipset,
+                                                 PatternBenefit benefit = 1);
+
+/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToDPPPatterns`
+/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`.
+void populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
+    PatternBenefit benefit = 1);
+
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
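
For context, here is a minimal sketch of how a downstream pass might use these new entry points, assuming MLIR's greedy rewrite driver and a chipset name such as "gfx942". The driver function and its name are hypothetical and not part of this commit, and the exact greedy-driver API (applyPatternsGreedily vs. the older applyPatternsAndFoldGreedily) depends on the MLIR revision in use:

#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: lowers gpu.subgroup_reduce ops nested under `root`
// to amdgpu.dpp-based sequences for the given chip and subgroup size.
static LogicalResult lowerSubgroupReducesToDPP(Operation *root,
                                               StringRef chipName,
                                               unsigned subgroupSize) {
  // Parse the target chipset so the patterns can choose between the gfx9
  // (row_bcast) and gfx10-12 (permlanex16/readlane) sequences.
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipName);
  if (failed(maybeChipset))
    return failure();

  RewritePatternSet patterns(root->getContext());
  // Register both the unclustered and the clustered DPP lowerings.
  populateGpuLowerSubgroupReduceToDPPPatterns(patterns, subgroupSize,
                                              *maybeChipset);
  populateGpuLowerClusteredSubgroupReduceToDPPPatterns(patterns, subgroupSize,
                                                       *maybeChipset);
  return applyPatternsGreedily(root, std::move(patterns));
}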

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 177 additions & 0 deletions
@@ -10,15 +10,19 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/GPU/Utils/GPUUtils.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
@@ -362,6 +366,163 @@ struct VectorSubgroupReduceToShuffles final
   unsigned shuffleBitwidth = 0;
   bool matchClustered = false;
 };
+
+static FailureOr<Value>
+createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
+                           Value input, gpu::AllReduceOperation mode,
+                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
+  Location loc = op.getLoc();
+  Value dpp;
+  Value res = input;
+  constexpr int allRows = 0xf;
+  constexpr int allBanks = 0xf;
+  const bool boundCtrl = true;
+  if (ci.clusterSize >= 2) {
+    // Perform reduction between all lanes N <-> N+1.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+
+  if (ci.clusterSize >= 4) {
+    // Perform reduction between all lanes N <-> N+2.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 8) {
+    // Perform reduction between all lanes N <-> 7-N,
+    // e.g lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 16) {
+    // Perform reduction between all lanes N <-> 15-N,
+    // e.g lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 32) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast last value from each row to next row.
+      // Use row mask to avoid polluting rows 1 and 3.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
+          rewriter.getUnitAttr(), 0xa, allBanks,
+          /*bound_ctrl*/ false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+    } else if (chipset.majorVersion <= 12) {
+      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
+      Value uint32Max = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
+      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
+                                                  uint32Max, uint32Max,
+                                                  /*fi=*/true,
+                                                  /*bound_ctrl=*/false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+      if (ci.subgroupSize == 32) {
+        Value lane0 = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+        res =
+            rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+      }
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
+    }
+  }
+  if (ci.clusterSize >= 64) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast 31st lane value to rows 2 and 3.
+      // Use row mask to avoid polluting rows 0 and 1.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
+          rewriter.getUnitAttr(), 0xc, allBanks,
+          /*bound_ctrl*/ false);
+
+    } else if (chipset.majorVersion <= 12) {
+      // Assume reduction across 32 lanes has been done.
+      // Perform final reduction manually by summing values in lane 0 and
+      // lane 32.
+      Value lane0 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+      Value lane32 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
+      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
+    }
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  assert(res.getType() == input.getType());
+  return res;
+}
+
+/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `amdgpu.dpp`
+/// ops over scalar types. Assumes that the subgroup has
+/// `subgroupSize` lanes. Applicable only to AMD GPUs.
+struct ScalarSubgroupReduceToDPP final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
+                            bool matchClustered, amdgpu::Chipset chipset,
+                            PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        matchClustered(matchClustered), chipset(chipset) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("op is {0}clustered but pattern is configured to "
+                            "only match {1}clustered ops",
+                            matchClustered ? "non-" : "",
+                            matchClustered ? "" : "non-"));
+    }
+    auto ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
+
+    if (ci->clusterStride != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reductions using DPP are currently only available for "
+              "clusters of contiguous lanes.");
+
+    Type valueTy = op.getType();
+    if (!valueTy.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "Value type is not a compatible scalar.");
+
+    FailureOr<Value> dpp = createSubgroupDPPReduction(
+        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
+    if (failed(dpp))
+      return failure();
+
+    rewriter.replaceOp(op, dpp.value());
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  bool matchClustered = false;
+  amdgpu::Chipset chipset;
+};
 } // namespace
 
 void mlir::populateGpuBreakDownSubgroupReducePatterns(
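
To make the reduction ladder in createSubgroupDPPReduction easier to follow, here is an illustrative host-side model (not part of the commit) of what the per-row portion computes for an add reduction over a 64-lane wavefront; the cross-row steps are only described in comments:

#include <array>
#include <cstdint>

// Combine each lane with the lane whose index differs in the given bit.
// Applied with bit = 1, 2, 4, 8 in that order, every lane in a block of
// 2 * bit lanes ends up holding that block's partial sum.
static void combineWithXor(std::array<int64_t, 64> &lanes, unsigned bit) {
  std::array<int64_t, 64> src = lanes;
  for (unsigned l = 0; l < 64; ++l)
    lanes[l] = src[l] + src[l ^ bit];
}

int64_t rowReduceAdd(std::array<int64_t, 64> lanes) {
  combineWithXor(lanes, 1); // quad_perm {1, 0, 3, 2}: N <-> N+1
  combineWithXor(lanes, 2); // quad_perm {2, 3, 0, 1}: N <-> N+2
  combineWithXor(lanes, 4); // row_half_mirror: N <-> 7-N within 8 lanes
  combineWithXor(lanes, 8); // row_mirror: N <-> 15-N within 16 lanes
  // Every lane now holds the sum of its 16-lane row. The remaining steps in
  // the lowering (row_bcast_15/row_bcast_31 on gfx9, or permlanex16 plus
  // readlane on gfx10-12) fold these per-row partials into the final value.
  return lanes[0]; // sum of lanes 0-15
}

The actual DPP ops use mirror permutations (7-N, 15-N) rather than an XOR of the lane index, but because every lane in a sub-cluster already holds the same partial after the preceding step, the combined values are identical. For clustered reductions the ladder simply stops early, which is why each step in the lowering is guarded by a `ci.clusterSize >= N` check.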
@@ -372,6 +533,22 @@ void mlir::populateGpuBreakDownSubgroupReducePatterns(
   patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
 }
 
+void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
+    PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
+                                          /*matchClustered=*/false, chipset,
+                                          benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
+    PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
+                                          /*matchClustered=*/true, chipset,
+                                          benefit);
+}
+
 void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth, PatternBenefit benefit) {
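
Because both the DPP patterns and the existing shuffle-based patterns can match the same scalar `gpu.subgroup_reduce` ops, a user registering both would typically give the DPP patterns a higher PatternBenefit so they are tried first. A sketch, assuming an MLIRContext *ctx and a parsed amdgpu::Chipset chipset are in scope; the benefit values are illustrative and not taken from this commit:

RewritePatternSet patterns(ctx);
// Prefer the DPP lowering where it applies (benefit 2)...
populateGpuLowerSubgroupReduceToDPPPatterns(patterns, /*subgroupSize=*/64,
                                            chipset, /*benefit=*/2);
// ...and keep the shuffle-based lowering as a lower-benefit fallback.
populateGpuLowerSubgroupReduceToShufflePatterns(patterns, /*subgroupSize=*/64,
                                                /*shuffleBitwidth=*/32,
                                                /*benefit=*/1);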
