[mlir][gpu] Add patterns to break down subgroup reduce #76271

Merged
merged 4 commits on Dec 28, 2023

Conversation

kuhar
Member

@kuhar kuhar commented Dec 22, 2023

The new patterns break down subgroup reduce ops with vector values into a sequence of subgroup reductions that fit the native shuffle size. The maximum/native shuffle size is parametrized.

The overall goal is to be able to perform multi-element reductions with a sequence of gpu.shuffle ops.
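
For illustration, here is roughly what the break-down produces for a `vector<5xf16>` reduction with a 32-bit native shuffle width (this mirrors the `@kernel0` test case added in this PR):

```mlir
%sum = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> vector<5xf16>
// ==> (with maxShuffleBitwidth = 32, i.e. two f16 elements per shuffle)
%cst = arith.constant dense<0.000000e+00> : vector<5xf16>
%e0 = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
%r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
%v0 = vector.insert_strided_slice %r0, %cst {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
%e1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
%r1 = gpu.subgroup_reduce add %e1 : (vector<2xf16>) -> vector<2xf16>
%v1 = vector.insert_strided_slice %r1, %v0 {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
%e2 = vector.extract %arg0[4] : f16 from vector<5xf16>
%r2 = gpu.subgroup_reduce add %e2 : (f16) -> f16
%sum = vector.insert %r2, %v1 [4] : f16 into vector<5xf16>
```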

@llvmbot
Member

llvmbot commented Dec 22, 2023

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Jakub Kuderski (kuhar)

Changes

The new patterns break down subgroup reduce ops with vector values into a sequence of subgroup reductions that fit the native shuffle size. The maximum/native shuffle size is parametrized.

The overall goal is to be able to perform multi-element reductions with a sequence of gpu.shuffle ops.


Full diff: https://github.com/llvm/llvm-project/pull/76271.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/Transforms/Passes.h (+7)
  • (modified) mlir/lib/Dialect/GPU/CMakeLists.txt (+3-2)
  • (added) mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp (+139)
  • (added) mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir (+71)
  • (modified) mlir/test/lib/Dialect/GPU/CMakeLists.txt (+1)
  • (modified) mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp (+27-1)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2-2)
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index c6c02ccaafbcf4..b905ef2e02aee0 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -15,6 +15,7 @@
 
 #include "Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>
 
@@ -62,6 +63,12 @@ void populateGpuShufflePatterns(RewritePatternSet &patterns);
 /// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
 void populateGpuAllReducePatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to break down subgroup_reduce ops into smaller
+/// ones whose size is supported by the target, i.e., <= `maxShuffleBitwidth`.
+void populateGpuBreakDownSubgroupReducePatterns(
+    RewritePatternSet &patterns, unsigned maxShuffleBitwidth = 32,
+    PatternBenefit benefit = 1);
+
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index ab6834cb262fb5..8383e06e6d2478 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -50,19 +50,20 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/AsyncRegionRewriter.cpp
   Transforms/BufferDeallocationOpInterfaceImpl.cpp
   Transforms/DecomposeMemrefs.cpp
+  Transforms/EliminateBarriers.cpp
   Transforms/GlobalIdRewriter.cpp
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ModuleToBinary.cpp
   Transforms/NVVMAttachTarget.cpp
   Transforms/ParallelLoopMapper.cpp
+  Transforms/ROCDLAttachTarget.cpp
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
-  Transforms/ROCDLAttachTarget.cpp
-  Transforms/EliminateBarriers.cpp
+  Transforms/SubgroupReduceLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
new file mode 100644
index 00000000000000..07700cfa3c2a22
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -0,0 +1,139 @@
+//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements gradual lowering of `gpu.subgroup_reduce` ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+namespace {
+
+/// Example:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
+///  ==>
+/// %v0 = arith.constant dense<0.0> : vector<3xf16>
+/// %e0 = vector.extract_strided_slice %x
+///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
+/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
+/// %v1 = vector.insert_strided_slice %r0, %v0
+///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
+/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
+/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
+/// %a  = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
+/// ```
+struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
+  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
+                          PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
+  }
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() < 2)
+      return rewriter.notifyMatchFailure(op, "not a multireduction");
+
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+
+    Type elemTy = vecTy.getElementType();
+    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+    if (elemBitwidth >= maxShuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "large element type, nothing to break down");
+
+    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
+    assert(elementsPerShuffle >= 1);
+
+    unsigned numNewReductions =
+        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
+    assert(numNewReductions >= 1);
+    if (numNewReductions == 1)
+      return rewriter.notifyMatchFailure(op, "nothing to break down");
+
+    Location loc = op.getLoc();
+    Value res =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
+
+    for (unsigned i = 0; i != numNewReductions; ++i) {
+      int64_t startIdx = i * elementsPerShuffle;
+      int64_t endIdx =
+          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
+      int64_t numElems = endIdx - startIdx;
+
+      Value extracted;
+      if (numElems == 1) {
+        extracted =
+            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
+      } else {
+        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
+            /*strides=*/1);
+      }
+
+      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+          loc, extracted, op.getOp(), op.getUniform());
+      if (numElems == 1) {
+        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
+        continue;
+      }
+
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  unsigned maxShuffleBitwidth = 0;
+};
+
+struct ScalarizeSingleElementReduce final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(op, "not a single-element reduction");
+
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+    Location loc = op.getLoc();
+    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
+    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+        loc, extracted, op.getOp(), op.getUniform());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateGpuBreakDownSubgroupReducePatterns(
+    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
+    PatternBenefit benefit) {
+  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
+                                        maxShuffleBitwidth, benefit);
+  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
new file mode 100644
index 00000000000000..b7146071bf2fd8
--- /dev/null
+++ b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt --allow-unregistered-dialect --test-gpu-subgroup-reduce-lowering %s | FileCheck %s
+
+// CHECK: gpu.module @kernels {
+gpu.module @kernels {
+
+  // CHECK-LABEL: gpu.func @kernel0(
+  // CHECK-SAME: %[[ARG0:.+]]: vector<5xf16>)
+  gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
+    // CHECK: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
+    // CHECK: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
+    // CHECK: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
+    // CHECK: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
+    // CHECK: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> (vector<5xf16>)
+    "test.consume"(%sum0) : (vector<5xf16>) -> ()
+
+
+    // CHECK-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
+    // CHECK: "test.consume"
+    %sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
+    "test.consume"(%sum1) : (vector<5xf16>) -> ()
+
+    // CHECK: gpu.return
+    gpu.return
+  }
+
+  // CHECK-LABEL: gpu.func @kernel1(
+  // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>)
+  gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
+    // CHECK: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
+    // CHECK: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
+    // CHECK: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<1xf32>) -> (vector<1xf32>)
+    "test.consume"(%sum0) : (vector<1xf32>) -> ()
+
+    // CHECK: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
+    // CHECK: "test.consume"
+    %sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
+    "test.consume"(%sum1) : (vector<1xf32>) -> ()
+
+    // CHECK: gpu.return
+    gpu.return
+  }
+
+  // These vectors fit the native shuffle size and should not be broken down.
+  //
+  // CHECK-LABEL: gpu.func @kernel2(
+  // CHECK-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+  gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
+    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
+    // CHECK: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
+    "test.consume"(%sum0) : (vector<3xi8>) -> ()
+
+    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
+    // CHECK: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
+    %sum1 = gpu.subgroup_reduce add %arg1 : (vector<4xi8>) -> (vector<4xi8>)
+    "test.consume"(%sum1) : (vector<4xi8>) -> ()
+
+    // CHECK: gpu.return
+    gpu.return
+  }
+
+}
diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
index aa94bce275eafb..48cbc4ad5505b0 100644
--- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
@@ -27,6 +27,7 @@ set(LIBS
   MLIRTransforms
   MLIRTransformUtils
   MLIRTranslateLib
+  MLIRVectorDialect
   MLIRVectorToLLVMPass
   )
 
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index db65f3bccec52d..4e8f0cc6667524 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -39,10 +40,35 @@ struct TestGpuRewritePass
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
+
+struct TestGpuSubgroupReduceLoweringPass
+    : public PassWrapper<TestGpuSubgroupReduceLoweringPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestGpuSubgroupReduceLoweringPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect, index::IndexDialect,
+                    memref::MemRefDialect, vector::VectorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-gpu-subgroup-reduce-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Applies gpu.subgroup_reduce lowering patterns.";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateGpuBreakDownSubgroupReducePatterns(patterns,
+                                               /*maxShuffleBitwidth=*/32);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
-void registerTestAllReduceLoweringPass() {
+void registerTestGpuLoweringPasses() {
   PassRegistration<TestGpuRewritePass>();
+  PassRegistration<TestGpuSubgroupReduceLoweringPass>();
 }
 } // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index eedade691c6c39..dc4121dc46bb9b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -47,7 +47,7 @@ void registerTestAffineReifyValueBoundsPass();
 void registerTestBytecodeRoundtripPasses();
 void registerTestDecomposeAffineOpPass();
 void registerTestAffineLoopUnswitchingPass();
-void registerTestAllReduceLoweringPass();
+void registerTestGpuLoweringPasses();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestLoopPermutationPass();
@@ -167,7 +167,7 @@ void registerTestPasses() {
   registerTestAffineReifyValueBoundsPass();
   registerTestDecomposeAffineOpPass();
   registerTestAffineLoopUnswitchingPass();
-  registerTestAllReduceLoweringPass();
+  registerTestGpuLoweringPasses();
   registerTestBytecodeRoundtripPasses();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();

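For downstream users, a minimal sketch of how these patterns could be wired into a pass, modeled on the `TestGpuSubgroupReduceLoweringPass` in the diff above (the free-standing helper and its name are illustrative, not part of this PR):

```cpp
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: collect the new breakdown patterns and apply them
// greedily over a module, mirroring what the test pass in this PR does.
static LogicalResult breakDownSubgroupReduces(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  // Pick the target's native shuffle width; 32 bits is the default here.
  populateGpuBreakDownSubgroupReducePatterns(patterns,
                                             /*maxShuffleBitwidth=*/32);
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```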
@llvmbot
Member

llvmbot commented Dec 22, 2023

@llvm/pr-subscribers-mlir-gpu

@kuhar kuhar requested a review from raikonenfnu December 22, 2023 22:09

github-actions bot commented Dec 22, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@kuhar kuhar requested a review from Groverkss December 28, 2023 18:08
Member

@Groverkss Groverkss left a comment


LGTM, mostly minor documentation related comments.

The new patterns break down subgroup reduce ops with vector values into
a sequence of subgroup reductions that fit the native shuffle size.
The maximum/native shuffle size is parametrized.

The overall goal is to be able to perform multi-element reductions with
a sequence of `gpu.shuffle` ops.
Labels: mlir:core (MLIR Core Infrastructure), mlir:gpu, mlir