Skip to content

Commit 2bd85a0

Browse files
changing lowering pattern
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent ee726f5 commit 2bd85a0

File tree

3 files changed

+130
-13
lines changed

3 files changed

+130
-13
lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ void populateGpuLowerSubgroupReduceToShufflePatterns(
6262
RewritePatternSet &patterns, unsigned subgroupSize,
6363
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
6464

65+
/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `amdgpu.dpp`
66+
/// ops over scalar types. Assumes that the subgroup has
67+
/// `subgroupSize` lanes. Applicable only to AMD GPUs.
68+
void populateGpuLowerSubgroupReduceToDPPPatterns(RewritePatternSet &patterns,
69+
unsigned subgroupSize,
70+
PatternBenefit benefit = 1);
71+
6572
/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToShufflePatterns`
6673
/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`.
6774
void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(

mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
6767
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 1);
6868
Value dppResult =
6969
b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
70-
amdgpu::DPPPerm::row_shr, permArg);
70+
amdgpu::DPPPerm::row_shl, permArg);
7171
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
7272
result, dppResult);
7373
}
@@ -76,39 +76,41 @@ Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
7676
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 2);
7777
Value dppResult =
7878
b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
79-
amdgpu::DPPPerm::row_shr, permArg);
79+
amdgpu::DPPPerm::row_shl, permArg);
8080
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
8181
result, dppResult);
8282
}
8383

8484
if (ci.clusterSize >= 8) {
85-
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 4);
86-
Value dppResult =
87-
b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
88-
amdgpu::DPPPerm::row_shr, permArg);
85+
Value dppResult = b.create<amdgpu::DPPOp>(
86+
loc, result.getType(), result, result, amdgpu::DPPPerm::row_half_mirror,
87+
b.getUnitAttr());
8988
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
9089
result, dppResult);
9190
}
9291

9392
if (ci.clusterSize >= 16) {
94-
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 8);
9593
Value dppResult =
9694
b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
97-
amdgpu::DPPPerm::row_shr, permArg);
95+
amdgpu::DPPPerm::row_mirror, b.getUnitAttr());
9896
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
9997
result, dppResult);
10098
}
10199

102100
const int allRows = 0xf;
103101
const int allBanks = 0xf;
104-
102+
auto int32Type = IntegerType::get(b.getContext(), 32);
105103
if (ci.clusterSize >= 32) {
106104
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 15);
107105
Value dppResult = b.create<amdgpu::DPPOp>(
108106
loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
109107
b.getUnitAttr(), 0xa, allBanks, false);
110108
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
111-
result, dppResult);
109+
result, dppResult);
110+
if (ci.subgroupSize == 32) {
111+
Value lane01 = b.create<LLVM::ConstantOp>(loc, int32Type, 1);
112+
result = b.create<ROCDL::ReadlaneOp>(loc, input.getType(), result, lane01);
113+
}
112114
}
113115

114116
if (ci.clusterSize == 64) {
@@ -118,11 +120,10 @@ Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
118120
b.getUnitAttr(), allRows, allBanks, false);
119121
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
120122
result, dppResult);
123+
Value lane63 = b.create<LLVM::ConstantOp>(loc, int32Type, 63);
124+
result = b.create<ROCDL::ReadlaneOp>(loc, input.getType(), result, lane63);
121125
}
122126

123-
auto int32Type = IntegerType::get(b.getContext(), 32);
124-
Value lane63 = b.create<LLVM::ConstantOp>(loc, int32Type, 63);
125-
result = b.create<ROCDL::ReadlaneOp>(loc, input.getType(), result, lane63);
126127
assert(result.getType() == input.getType());
127128
return result;
128129
}

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
16+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1517
#include "mlir/Dialect/GPU/Transforms/Passes.h"
1618
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
1719
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -362,6 +364,106 @@ struct VectorSubgroupReduceToShuffles final
362364
unsigned shuffleBitwidth = 0;
363365
bool matchClustered = false;
364366
};
367+
368+
/// Lowers a scalar `gpu.subgroup_reduce` to a log-step reduction built from
/// `amdgpu.dpp` cross-lane permutations, combining each permuted value into
/// the running result with `vector::makeArithReduction`.
///
/// Each guarded step doubles the cluster width reduced so far:
///   >= 2 / >= 4 : `row_shl` by 1 and 2 lanes,
///   >= 8        : `row_half_mirror`,
///   >= 16       : `row_mirror`,
///   >= 32       : `row_bcast_15` (row mask 0xa so only the receiving rows
///                 are written),
///   == 64       : `row_bcast_31` across all rows/banks, then the final value
///                 is fetched with `rocdl.readlane` from lane 63.
///
/// \param b     Builder used to create the lowered ops.
/// \param loc   Location attached to every created op.
/// \param input Scalar value to reduce; the returned value has the same type.
/// \param mode  `gpu` reduction kind (add, min, max, ...).
/// \param ci    Validated cluster info; `ci.clusterSize` selects how many of
///              the steps above run. Assumes an AMD subgroup of
///              `ci.subgroupSize` lanes.
Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
                                 gpu::AllReduceOperation mode,
                                 const ClusterInfo &ci) {
  Value result = input;
  if (ci.clusterSize >= 2) {
    // Combine with the value one lane over within each row.
    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 1);
    Value dppResult =
        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
                                amdgpu::DPPPerm::row_shl, permArg);
    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
                                        result, dppResult);
  }

  if (ci.clusterSize >= 4) {
    // Combine with the value two lanes over; lanes now hold 4-wide partials.
    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 2);
    Value dppResult =
        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
                                amdgpu::DPPPerm::row_shl, permArg);
    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
                                        result, dppResult);
  }

  if (ci.clusterSize >= 8) {
    // Mirror within each half-row to fold 4-wide partials into 8-wide ones.
    Value dppResult = b.create<amdgpu::DPPOp>(
        loc, result.getType(), result, result, amdgpu::DPPPerm::row_half_mirror,
        b.getUnitAttr());
    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
                                        result, dppResult);
  }

  if (ci.clusterSize >= 16) {
    // Mirror the whole row to fold 8-wide partials into 16-wide ones.
    Value dppResult =
        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
                                amdgpu::DPPPerm::row_mirror, b.getUnitAttr());
    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
                                        result, dppResult);
  }

  const int allRows = 0xf;
  const int allBanks = 0xf;
  auto int32Type = IntegerType::get(b.getContext(), 32);
  if (ci.clusterSize >= 32) {
    // Broadcast each row's lane 15 onward; row mask 0xa limits which rows are
    // overwritten. (The unused `permArg` local from the original was dropped:
    // `row_bcast_15` takes a unit attribute, not an immediate.)
    Value dppResult = b.create<amdgpu::DPPOp>(
        loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
        b.getUnitAttr(), 0xa, allBanks, false);
    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
                                        result, dppResult);
    if (ci.subgroupSize == 32) {
      // NOTE(review): reads lane 1 as the lane holding the full wave32
      // reduction — confirm this index against the DPP step above.
      Value lane01 = b.create<LLVM::ConstantOp>(loc, int32Type, 1);
      result =
          b.create<ROCDL::ReadlaneOp>(loc, input.getType(), result, lane01);
    }
  }

  if (ci.clusterSize == 64) {
    // Final fold across the whole wave64; the complete reduction ends up in
    // lane 63, which is then broadcast to all lanes via readlane.
    // (The unused `permArg` local from the original was dropped here too.)
    Value dppResult = b.create<amdgpu::DPPOp>(
        loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_31,
        b.getUnitAttr(), allRows, allBanks, false);
    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
                                        result, dppResult);
    Value lane63 = b.create<LLVM::ConstantOp>(loc, int32Type, 63);
    result = b.create<ROCDL::ReadlaneOp>(loc, input.getType(), result, lane63);
  }

  assert(result.getType() == input.getType());
  return result;
}
437+
438+
struct ScalarSubgroupReduceToDPP final
439+
: OpRewritePattern<gpu::SubgroupReduceOp> {
440+
ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
441+
bool matchClustered, PatternBenefit benefit)
442+
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
443+
matchClustered(matchClustered) {}
444+
445+
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
446+
PatternRewriter &rewriter) const override {
447+
if (op.getClusterSize().has_value() != matchClustered) {
448+
return rewriter.notifyMatchFailure(
449+
op, llvm::formatv("op is {0}clustered but pattern is configured to "
450+
"only match {1}clustered ops",
451+
matchClustered ? "non-" : "",
452+
matchClustered ? "" : "non-"));
453+
}
454+
auto ci = getAndValidateClusterInfo(op, subgroupSize);
455+
if (failed(ci))
456+
return failure();
457+
Location loc = op.getLoc();
458+
rewriter.replaceOp(op, createSubgroupDPPReduction(
459+
rewriter, loc, op.getValue(), op.getOp(), *ci));
460+
return success();
461+
}
462+
463+
private:
464+
unsigned subgroupSize = 0;
465+
bool matchClustered = false;
466+
};
365467
} // namespace
366468

367469
void mlir::populateGpuBreakDownSubgroupReducePatterns(
@@ -372,6 +474,13 @@ void mlir::populateGpuBreakDownSubgroupReducePatterns(
372474
patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
373475
}
374476

477+
/// Registers `ScalarSubgroupReduceToDPP` so `gpu.subgroup_reduce` ops can be
/// lowered to `amdgpu.dpp`-based reductions on AMD GPUs.
/// NOTE(review): this registers the pattern with /*matchClustered=*/true, so
/// only ops carrying a `cluster_size` are matched, even though the function
/// name lacks "Clustered" — the shuffle-based populate functions use a
/// separate `...Clustered...` entry point for that case. Confirm the intended
/// matching mode.
void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ScalarSubgroupReduceToDPP>(ctx, subgroupSize,
                                          /*matchClustered=*/true, benefit);
}
483+
375484
void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
376485
RewritePatternSet &patterns, unsigned subgroupSize,
377486
unsigned shuffleBitwidth, PatternBenefit benefit) {

0 commit comments

Comments (0)