Commit 2af186f

[mlir][gpu] Add patterns to break down subgroup reduce (#76271)
The new patterns break down subgroup reduce ops with vector values into a sequence of subgroup reductions that fit the native shuffle size. The maximum/native shuffle size is parametrized. The overall goal is to be able to perform multi-element reductions with a sequence of `gpu.shuffle` ops.
1 parent 8076ee9 commit 2af186f
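
To make the intended decomposition concrete, the sketch below (not part of the commit; plain standalone C++) reproduces the slice arithmetic the new pattern performs, using the `vector<5xf16>` case from the commit's test file.

```cpp
#include <algorithm>
#include <cstdio>

// Sketch only: the slice boundaries BreakDownSubgroupReduce derives when
// splitting one multi-element reduction into shuffle-sized sub-reductions.
int main() {
  const int numElements = 5;         // e.g. gpu.subgroup_reduce on vector<5xf16>
  const int elemBitwidth = 16;       // f16
  const int maxShuffleBitwidth = 32; // native gpu.shuffle width

  // How many elements fit in one shuffle, and how many sub-reductions result.
  const int elementsPerShuffle = maxShuffleBitwidth / elemBitwidth; // 2
  const int numNewReductions =
      (numElements + elementsPerShuffle - 1) / elementsPerShuffle;  // ceil(5/2) == 3

  for (int i = 0; i != numNewReductions; ++i) {
    const int startIdx = i * elementsPerShuffle;
    const int endIdx = std::min(startIdx + elementsPerShuffle, numElements);
    std::printf("[%d, %d) ", startIdx, endIdx); // prints: [0, 2) [2, 4) [4, 5)
  }
  std::printf("\n");
}
```

That is two `vector<2xf16>` sub-reductions plus one scalar `f16` reduction, exactly the shape checked in `@kernel0` of the new test file below.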

7 files changed: +261 −5 lines

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 8 additions & 0 deletions

@@ -15,6 +15,7 @@
 
 #include "Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>
 
@@ -62,6 +63,13 @@ void populateGpuShufflePatterns(RewritePatternSet &patterns);
 /// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
 void populateGpuAllReducePatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to break down subgroup_reduce ops into smaller
+/// ones supported by the target, of `size <= maxShuffleBitwidth`, where `size`
+/// is the subgroup_reduce value bitwidth.
+void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns,
+                                               unsigned maxShuffleBitwidth = 32,
+                                               PatternBenefit benefit = 1);
+
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
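
As a usage note, a client pass would feed these patterns to the greedy rewrite driver; the test pass added later in this commit does exactly that. A minimal sketch, where `runBreakDown` and its signature are illustrative rather than anything the commit adds:

```cpp
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch: collect the break-down patterns and apply them greedily to `root`,
// e.g. from some pass's runOnOperation().
void runBreakDown(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::populateGpuBreakDownSubgrupReducePatterns(patterns,
                                                  /*maxShuffleBitwidth=*/32);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```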

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 3 additions & 2 deletions

@@ -50,19 +50,20 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/AsyncRegionRewriter.cpp
   Transforms/BufferDeallocationOpInterfaceImpl.cpp
   Transforms/DecomposeMemrefs.cpp
+  Transforms/EliminateBarriers.cpp
   Transforms/GlobalIdRewriter.cpp
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ModuleToBinary.cpp
   Transforms/NVVMAttachTarget.cpp
   Transforms/ParallelLoopMapper.cpp
+  Transforms/ROCDLAttachTarget.cpp
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
-  Transforms/ROCDLAttachTarget.cpp
-  Transforms/EliminateBarriers.cpp
+  Transforms/SubgroupReduceLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
+//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements gradual lowering of `gpu.subgroup_reduce` ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+namespace {
+
+/// Example, assumes `maxShuffleBitwidth` equal to 32:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
+///  ==>
+/// %v0 = arith.constant dense<0.0> : vector<3xf16>
+/// %e0 = vector.extract_strided_slice %x
+///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
+/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
+/// %v1 = vector.insert_strided_slice %r0, %v0
+///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
+/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
+/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
+/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
+/// ```
+struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
+  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
+                          PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
+  }
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() < 2)
+      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
+
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+
+    Type elemTy = vecTy.getElementType();
+    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+    if (elemBitwidth >= maxShuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("element type too large {0}, cannot break down "
+                            "into vectors of bitwidth {1} or less",
+                            elemBitwidth, maxShuffleBitwidth));
+
+    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
+    assert(elementsPerShuffle >= 1);
+
+    unsigned numNewReductions =
+        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
+    assert(numNewReductions >= 1);
+    if (numNewReductions == 1)
+      return rewriter.notifyMatchFailure(op, "nothing to break down");
+
+    Location loc = op.getLoc();
+    Value res =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
+
+    for (unsigned i = 0; i != numNewReductions; ++i) {
+      int64_t startIdx = i * elementsPerShuffle;
+      int64_t endIdx =
+          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
+      int64_t numElems = endIdx - startIdx;
+
+      Value extracted;
+      if (numElems == 1) {
+        extracted =
+            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
+      } else {
+        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
+            /*strides=*/1);
+      }
+
+      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+          loc, extracted, op.getOp(), op.getUniform());
+      if (numElems == 1) {
+        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
+        continue;
+      }
+
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  unsigned maxShuffleBitwidth = 0;
+};
+
+/// Example:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
+///  ==>
+/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
+/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
+/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
+/// ```
+struct ScalarizeSingleElementReduce final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(op, "not a single-element reduction");
+
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+    Location loc = op.getLoc();
+    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
+    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+        loc, extracted, op.getOp(), op.getUniform());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateGpuBreakDownSubgrupReducePatterns(
+    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
+    PatternBenefit benefit) {
+  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
+                                        maxShuffleBitwidth, benefit);
+  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
+}
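
Taken together, the two patterns split the input space three ways: scalarize single-element vectors, break down vectors that span multiple shuffles, and leave vectors that already fit one shuffle alone. A hedged sketch of that gating logic (names are illustrative, not commit code), checked against the kernels in the new test file that follows:

```cpp
#include <cassert>

// Sketch only: which pattern, if any, fires on a 1-D vector
// gpu.subgroup_reduce of numElements x elemBitwidth-bit elements.
enum class Rewrite { None, BreakDown, Scalarize };

Rewrite classify(int numElements, int elemBitwidth, int maxShuffleBitwidth) {
  if (numElements == 1)
    return Rewrite::Scalarize; // vector<1xT> -> scalar reduce + broadcast
  if (elemBitwidth >= maxShuffleBitwidth)
    return Rewrite::None; // the "element type too large" match failure above
  int elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
  int numNewReductions =
      (numElements + elementsPerShuffle - 1) / elementsPerShuffle;
  // One sub-reduction means the vector already fits a single shuffle.
  return numNewReductions == 1 ? Rewrite::None : Rewrite::BreakDown;
}

int main() {
  assert(classify(5, 16, 32) == Rewrite::BreakDown); // @kernel0: vector<5xf16>
  assert(classify(1, 32, 32) == Rewrite::Scalarize); // @kernel1: vector<1xf32>
  assert(classify(3, 8, 32) == Rewrite::None);       // @kernel2: vector<3xi8>
  assert(classify(4, 8, 32) == Rewrite::None);       // @kernel2: vector<4xi8>
}
```
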
mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
+// RUN: mlir-opt --allow-unregistered-dialect --test-gpu-subgroup-reduce-lowering %s | FileCheck %s
+
+// CHECK: gpu.module @kernels {
+gpu.module @kernels {
+
+  // CHECK-LABEL: gpu.func @kernel0(
+  // CHECK-SAME: %[[ARG0:.+]]: vector<5xf16>)
+  gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
+    // CHECK: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
+    // CHECK: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
+    // CHECK: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
+    // CHECK: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
+    // CHECK: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> (vector<5xf16>)
+    "test.consume"(%sum0) : (vector<5xf16>) -> ()
+
+
+    // CHECK-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
+    // CHECK: "test.consume"
+    %sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
+    "test.consume"(%sum1) : (vector<5xf16>) -> ()
+
+    // CHECK: gpu.return
+    gpu.return
+  }
+
+  // CHECK-LABEL: gpu.func @kernel1(
+  // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>)
+  gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
+    // CHECK: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
+    // CHECK: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
+    // CHECK: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<1xf32>) -> (vector<1xf32>)
+    "test.consume"(%sum0) : (vector<1xf32>) -> ()
+
+    // CHECK: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
+    // CHECK: "test.consume"
+    %sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
+    "test.consume"(%sum1) : (vector<1xf32>) -> ()
+
+    // CHECK: gpu.return
+    gpu.return
+  }
+
+  // These vectors fit the native shuffle size and should not be broken down.
+  //
+  // CHECK-LABEL: gpu.func @kernel2(
+  // CHECK-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+  gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
+    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
+    // CHECK: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
+    "test.consume"(%sum0) : (vector<3xi8>) -> ()
+
+    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
+    // CHECK: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
+    %sum1 = gpu.subgroup_reduce add %arg1 : (vector<4xi8>) -> (vector<4xi8>)
+    "test.consume"(%sum1) : (vector<4xi8>) -> ()
+
+    // CHECK: gpu.return
+    gpu.return
+  }
+
+}

mlir/test/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ set(LIBS
   MLIRTransforms
   MLIRTransformUtils
   MLIRTranslateLib
+  MLIRVectorDialect
   MLIRVectorToLLVMPass
 )
 

mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp

Lines changed: 26 additions & 1 deletion

@@ -15,6 +15,7 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -39,10 +40,34 @@ struct TestGpuRewritePass
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
+
+struct TestGpuSubgroupReduceLoweringPass
+    : public PassWrapper<TestGpuSubgroupReduceLoweringPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestGpuSubgroupReduceLoweringPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, vector::VectorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-gpu-subgroup-reduce-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Applies gpu.subgroup_reduce lowering patterns.";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateGpuBreakDownSubgrupReducePatterns(patterns,
+                                              /*maxShuffleBitwidth=*/32);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
-void registerTestAllReduceLoweringPass() {
+void registerTestGpuLoweringPasses() {
   PassRegistration<TestGpuRewritePass>();
+  PassRegistration<TestGpuSubgroupReduceLoweringPass>();
 }
 } // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 2 deletions

@@ -47,7 +47,7 @@ void registerTestAffineReifyValueBoundsPass();
 void registerTestBytecodeRoundtripPasses();
 void registerTestDecomposeAffineOpPass();
 void registerTestAffineLoopUnswitchingPass();
-void registerTestAllReduceLoweringPass();
+void registerTestGpuLoweringPasses();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestLoopPermutationPass();
@@ -167,7 +167,7 @@ void registerTestPasses() {
   registerTestAffineReifyValueBoundsPass();
   registerTestDecomposeAffineOpPass();
   registerTestAffineLoopUnswitchingPass();
-  registerTestAllReduceLoweringPass();
+  registerTestGpuLoweringPasses();
   registerTestBytecodeRoundtripPasses();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();
