[mlir][gpu] Add patterns to break down subgroup reduce #76271
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-gpu

Author: Jakub Kuderski (kuhar)

Changes

The new patterns break down subgroup reduce ops with vector values into a sequence of subgroup reductions that fit the native shuffle size. The maximum/native shuffle size is parametrized. The overall goal is to be able to perform multi-element reductions with a sequence of `gpu.shuffle` ops (see the worked example after the diff).

Full diff: https://github.com/llvm/llvm-project/pull/76271.diff

7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index c6c02ccaafbcf4..b905ef2e02aee0 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -15,6 +15,7 @@
#include "Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include <optional>
@@ -62,6 +63,12 @@ void populateGpuShufflePatterns(RewritePatternSet &patterns);
/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
void populateGpuAllReducePatterns(RewritePatternSet &patterns);
+/// Collect a set of patterns to break down subgroup_reduce ops into smaller
+/// ones supported by the target of size <= `maxShuffleBitwidth`.
+void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns,
+ unsigned maxShuffleBitwidth = 32,
+ PatternBenefit benefit = 1);
+
/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index ab6834cb262fb5..8383e06e6d2478 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -50,19 +50,20 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/AsyncRegionRewriter.cpp
Transforms/BufferDeallocationOpInterfaceImpl.cpp
Transforms/DecomposeMemrefs.cpp
+ Transforms/EliminateBarriers.cpp
Transforms/GlobalIdRewriter.cpp
Transforms/KernelOutlining.cpp
Transforms/MemoryPromotion.cpp
Transforms/ModuleToBinary.cpp
Transforms/NVVMAttachTarget.cpp
Transforms/ParallelLoopMapper.cpp
+ Transforms/ROCDLAttachTarget.cpp
Transforms/SerializeToBlob.cpp
Transforms/SerializeToCubin.cpp
Transforms/SerializeToHsaco.cpp
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
- Transforms/ROCDLAttachTarget.cpp
- Transforms/EliminateBarriers.cpp
+ Transforms/SubgroupReduceLowering.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
new file mode 100644
index 00000000000000..07700cfa3c2a22
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -0,0 +1,139 @@
+//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements gradual lowering of `gpu.subgroup_reduce` ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+namespace {
+
+/// Example:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
+/// ==>
+/// %v0 = arith.constant dense<0.0> : vector<3xf16>
+/// %e0 = vector.extract_strided_slice %x
+///     {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
+/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
+/// %v1 = vector.insert_strided_slice %r0, %v0
+///     {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
+/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
+/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
+/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
+/// ```
+struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
+ BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
+ PatternBenefit benefit)
+ : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
+ }
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ auto vecTy = dyn_cast<VectorType>(op.getType());
+ if (!vecTy || vecTy.getNumElements() < 2)
+ return rewriter.notifyMatchFailure(op, "not a multireduction");
+
+ assert(vecTy.getRank() == 1 && "Unexpected vector type");
+ assert(!vecTy.isScalable() && "Unexpected vector type");
+
+ Type elemTy = vecTy.getElementType();
+ unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+ if (elemBitwidth >= maxShuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op, "large element type, nothing to break down");
+
+ unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
+ assert(elementsPerShuffle >= 1);
+
+ unsigned numNewReductions =
+ llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
+ assert(numNewReductions >= 1);
+ if (numNewReductions == 1)
+ return rewriter.notifyMatchFailure(op, "nothing to break down");
+
+ Location loc = op.getLoc();
+ Value res =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
+
+ for (unsigned i = 0; i != numNewReductions; ++i) {
+ int64_t startIdx = i * elementsPerShuffle;
+ int64_t endIdx =
+ std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
+ int64_t numElems = endIdx - startIdx;
+
+ Value extracted;
+ if (numElems == 1) {
+ extracted =
+ rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
+ } else {
+ extracted = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
+ /*strides=*/1);
+ }
+
+ Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+ loc, extracted, op.getOp(), op.getUniform());
+ if (numElems == 1) {
+ res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
+ continue;
+ }
+
+ res = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
+ }
+
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+ private:
+ unsigned maxShuffleBitwidth = 0;
+};
+
+struct ScalarizeSignleElementReduce final
+ : OpRewritePattern<gpu::SubgroupReduceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ auto vecTy = dyn_cast<VectorType>(op.getType());
+ if (!vecTy || vecTy.getNumElements() != 1)
+ return rewriter.notifyMatchFailure(op, "not a single-element reduction");
+
+ assert(vecTy.getRank() == 1 && "Unexpected vector type");
+ assert(!vecTy.isScalable() && "Unexpected vector type");
+ Location loc = op.getLoc();
+ Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
+ Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+ loc, extracted, op.getOp(), op.getUniform());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::populateGpuBreakDownSubgrupReducePatterns(
+ RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
+ PatternBenefit benefit) {
+ patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
+ maxShuffleBitwidth, benefit);
+ patterns.add<ScalarizeSignleElementReduce>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
new file mode 100644
index 00000000000000..b7146071bf2fd8
--- /dev/null
+++ b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt --allow-unregistered-dialect --test-gpu-subgroup-reduce-lowering %s | FileCheck %s
+
+// CHECK: gpu.module @kernels {
+gpu.module @kernels {
+
+ // CHECK-LABEL: gpu.func @kernel0(
+ // CHECK-SAME: %[[ARG0:.+]]: vector<5xf16>)
+ gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
+ // CHECK: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
+ // CHECK: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+ // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
+ // CHECK: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
+ // CHECK: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+ // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
+ // CHECK: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
+ // CHECK: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
+ // CHECK: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
+ // CHECK: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
+ // CHECK: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> (vector<5xf16>)
+ "test.consume"(%sum0) : (vector<5xf16>) -> ()
+
+
+ // CHECK-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
+ // CHECK: "test.consume"
+ %sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
+ "test.consume"(%sum1) : (vector<5xf16>) -> ()
+
+ // CHECK: gpu.return
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @kernel1(
+ // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>)
+ gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
+ // CHECK: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+ // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
+ // CHECK: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
+ // CHECK: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (vector<1xf32>) -> (vector<1xf32>)
+ "test.consume"(%sum0) : (vector<1xf32>) -> ()
+
+ // CHECK: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
+ // CHECK: "test.consume"
+ %sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
+ "test.consume"(%sum1) : (vector<1xf32>) -> ()
+
+ // CHECK: gpu.return
+ gpu.return
+ }
+
+ // These vectors fit the native shuffle size and should not be broken down.
+ //
+ // CHECK-LABEL: gpu.func @kernel2(
+ // CHECK-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+ gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
+ // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
+ // CHECK: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
+ "test.consume"(%sum0) : (vector<3xi8>) -> ()
+
+ // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
+ // CHECK: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
+ %sum1 = gpu.subgroup_reduce add %arg1 : (vector<4xi8>) -> (vector<4xi8>)
+ "test.consume"(%sum1) : (vector<4xi8>) -> ()
+
+ // CHECK: gpu.return
+ gpu.return
+ }
+
+}
diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
index aa94bce275eafb..48cbc4ad5505b0 100644
--- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
@@ -27,6 +27,7 @@ set(LIBS
MLIRTransforms
MLIRTransformUtils
MLIRTranslateLib
+ MLIRVectorDialect
MLIRVectorToLLVMPass
)
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index db65f3bccec52d..4e8f0cc6667524 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -39,10 +40,35 @@ struct TestGpuRewritePass
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
+
+struct TestGpuSubgroupReduceLoweringPass
+ : public PassWrapper<TestGpuSubgroupReduceLoweringPass,
+ OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestGpuSubgroupReduceLoweringPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, func::FuncDialect, index::IndexDialect,
+ memref::MemRefDialect, vector::VectorDialect>();
+ }
+ StringRef getArgument() const final {
+ return "test-gpu-subgroup-reduce-lowering";
+ }
+ StringRef getDescription() const final {
+ return "Applies gpu.subgroup_reduce lowering patterns.";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateGpuBreakDownSubgrupReducePatterns(patterns,
+ /*maxShuffleBitwidth=*/32);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
} // namespace
namespace mlir {
-void registerTestAllReduceLoweringPass() {
+void registerTestGpuLoweringPasses() {
PassRegistration<TestGpuRewritePass>();
+ PassRegistration<TestGpuSubgroupReduceLoweringPass>();
}
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index eedade691c6c39..dc4121dc46bb9b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -47,7 +47,7 @@ void registerTestAffineReifyValueBoundsPass();
void registerTestBytecodeRoundtripPasses();
void registerTestDecomposeAffineOpPass();
void registerTestAffineLoopUnswitchingPass();
-void registerTestAllReduceLoweringPass();
+void registerTestGpuLoweringPasses();
void registerTestFunc();
void registerTestGpuMemoryPromotionPass();
void registerTestLoopPermutationPass();
@@ -167,7 +167,7 @@ void registerTestPasses() {
registerTestAffineReifyValueBoundsPass();
registerTestDecomposeAffineOpPass();
registerTestAffineLoopUnswitchingPass();
- registerTestAllReduceLoweringPass();
+ registerTestGpuLoweringPasses();
registerTestBytecodeRoundtripPasses();
registerTestFunc();
registerTestGpuMemoryPromotionPass();
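For a concrete picture of what the new `BreakDownSubgroupReduce` pattern produces, here is a small sketch for a `vector<5xf16>` reduction with the default 32-bit native shuffle width: each shuffle can carry two `f16` elements, so the op is split into ceil(5 / 2) = 3 smaller reductions. The IR below follows the before/`==>`/after style of the in-file doc comment and mirrors the CHECK lines in the new test; the SSA value names are illustrative.

```mlir
%a = gpu.subgroup_reduce add %x : (vector<5xf16>) -> vector<5xf16>
// ==> with maxShuffleBitwidth = 32:

// Zero-initialized accumulator that the partial results are inserted into.
%v0 = arith.constant dense<0.0> : vector<5xf16>

// Elements [0, 2): a 2 x f16 slice fits in one 32-bit shuffle.
%e0 = vector.extract_strided_slice %x
        {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
%r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
%v1 = vector.insert_strided_slice %r0, %v0
        {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>

// Elements [2, 4): the second 2 x f16 slice.
%e1 = vector.extract_strided_slice %x
        {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
%r1 = gpu.subgroup_reduce add %e1 : (vector<2xf16>) -> vector<2xf16>
%v2 = vector.insert_strided_slice %r1, %v1
        {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>

// The leftover element [4] becomes a scalar reduction.
%e2 = vector.extract %x[4] : f16 from vector<5xf16>
%r2 = gpu.subgroup_reduce add %e2 : (f16) -> f16
%a = vector.insert %r2, %v2 [4] : f16 into vector<5xf16>
```

Vectors that already fit within 32 bits (e.g. `vector<4xi8>` or `vector<2xf16>`) are left untouched, and single-element vectors are handled by the separate scalarization pattern, as exercised by `@kernel2` and `@kernel1` in the test above.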
✅ With the latest revision this PR passed the C/C++ code formatter.
LGTM, mostly minor documentation related comments.
The new patterns break down subgroup reduce ops with vector values into a sequence of subgroup reductions that fit the native shuffle size. The maximum/native shuffle size is parametrized. The overall goal is to be able to perform multi-element reductions with a sequence of `gpu.shuffle` ops.

Branch force-pushed from b3e39fb to 111ead9.
Branch force-pushed from 111ead9 to 63fe83c.