[mlir][gpu] Add subgroup_reduce to shuffle lowering #76530
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes: This supports both the scalar and the vector multi-reduction cases.

Patch is 28.68 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/76530.diff

8 Files Affected:
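Before the file-by-file diff, a minimal sketch of the butterfly expansion this PR implements for a scalar gpu.subgroup_reduce add. This assumes a 4-lane subgroup for brevity (real targets use 32 or 64 lanes), and the SSA names are illustrative, not taken from the patch:

// Per-lane input %x : f32; shuffle offsets and width are i32 constants.
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%c4 = arith.constant 4 : i32
// Step 1: each lane combines with its neighbor at XOR distance 1.
%s0, %p0 = gpu.shuffle xor %x, %c1, %c4 : f32
%r0 = arith.addf %x, %s0 : f32
// Step 2: XOR distance 2; afterwards every lane holds the full sum.
%s1, %p1 = gpu.shuffle xor %r0, %c2, %c4 : f32
%sum = arith.addf %r0, %s1 : f32

In general the expansion takes log2(subgroupSize) shuffle-plus-combine steps, which is why the code below asserts that the subgroup size is a power of two.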
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 6c5bf75d212478..5885facd07541e 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -70,6 +70,13 @@ void populateGpuBreakDownSubgroupReducePatterns(RewritePatternSet &patterns,
unsigned maxShuffleBitwidth = 32,
PatternBenefit benefit = 1);
+/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
+/// ops performed on scalars of bitwidth `shuffleBitwidth`. Assumes that the
+/// subgroup has `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
+void populateGpuLowerSubgroupReduceToShufflePatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+
/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
index a426bee7686dbc..f25c506fd638d8 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
@@ -13,6 +13,8 @@
#ifndef MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_
#define MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Support/LLVM.h"
#include <string>
@@ -28,6 +30,9 @@ class LaunchOp;
/// Returns the default annotation name for GPU binary blobs.
std::string getDefaultGpuBinaryAnnotation();
+
+/// Returns the matching vector combining kind.
+vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode);
} // namespace gpu
/// Get a gpu.func created from outlining the region of a gpu.launch op with the
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 8383e06e6d2478..8f289ce9452e80 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -64,6 +64,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
Transforms/SubgroupReduceLowering.cpp
+ Transforms/Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 608d801ee9bbbe..a75598afe8c72d 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -27,33 +27,6 @@ using namespace mlir;
namespace {
-static vector::CombiningKind
-convertReductionKind(gpu::AllReduceOperation mode) {
- switch (mode) {
-#define MAP_CASE(X) \
- case gpu::AllReduceOperation::X: \
- return vector::CombiningKind::X
-
- MAP_CASE(ADD);
- MAP_CASE(MUL);
- MAP_CASE(MINUI);
- MAP_CASE(MINSI);
- MAP_CASE(MINNUMF);
- MAP_CASE(MAXSI);
- MAP_CASE(MAXUI);
- MAP_CASE(MAXNUMF);
- MAP_CASE(AND);
- MAP_CASE(OR);
- MAP_CASE(XOR);
- MAP_CASE(MINIMUMF);
- MAP_CASE(MAXIMUMF);
-
-#undef MAP_CASE
- }
-
- llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
-}
-
struct GpuAllReduceRewriter {
using AccumulatorFactory = std::function<Value(Value, Value)>;
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 61edce5e2a0862..83eb125f91cfa3 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -13,13 +13,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
+#include <cstdint>
using namespace mlir;
@@ -58,7 +63,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
if (elemBitwidth >= maxShuffleBitwidth)
return rewriter.notifyMatchFailure(
- op, llvm::formatv("element type too large {0}, cannot break down "
+ op, llvm::formatv("element type too large ({0}), cannot break down "
"into vectors of bitwidth {1} or less",
elemBitwidth, maxShuffleBitwidth));
@@ -139,6 +144,164 @@ struct ScalarizeSingleElementReduce final
}
};
+/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
+/// and `unpackFn` to convert to/from the native shuffle type. Assumes that the
+/// subgroup is `subgroupSize` lanes wide and reduces across all of them.
+static Value createSubgroupShuffleReduction(
+ OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
+ unsigned subgroupSize, function_ref<Value(Value)> packFn,
+ function_ref<Value(Value)> unpackFn) {
+ assert(llvm::isPowerOf2_32(subgroupSize));
+ // Lane value always stays in the original type. We use it to perform arith
+ // reductions.
+ Value laneVal = input;
+ // Parallel reduction using butterfly shuffles.
+ for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+ Value shuffled = builder
+ .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
+ /*width=*/subgroupSize,
+ /*mode=*/gpu::ShuffleMode::XOR)
+ .getShuffleResult();
+ laneVal = vector::makeArithReduction(builder, loc,
+ gpu::convertReductionKind(mode),
+ laneVal, unpackFn(shuffled));
+ assert(laneVal.getType() == input.getType());
+ }
+
+ return laneVal;
+}
+
+/// Lowers scalar gpu subgroup reductions to a series of shuffles.
+struct ScalarSubgroupReduceToShuffles final
+ : OpRewritePattern<gpu::SubgroupReduceOp> {
+ ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+ unsigned shuffleBitwidth,
+ PatternBenefit benefit)
+ : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+ shuffleBitwidth(shuffleBitwidth) {}
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ Type valueTy = op.getType();
+ unsigned elemBitwidth =
+ getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
+ if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op, "value type is not a compatible scalar");
+
+ Location loc = op.getLoc();
+ // Since this is already a native shuffle scalar, no packing is necessary.
+ if (elemBitwidth == shuffleBitwidth) {
+ auto identityFn = [](Value v) { return v; };
+ rewriter.replaceOp(op, createSubgroupShuffleReduction(
+ rewriter, loc, op.getValue(), op.getOp(),
+ subgroupSize, identityFn, identityFn));
+ return success();
+ }
+
+ auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+ auto equivIntType = rewriter.getIntegerType(elemBitwidth);
+ auto packFn = [loc, &rewriter, equivIntType,
+ shuffleIntType](Value unpackedVal) -> Value {
+ auto asInt =
+ rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
+ return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
+ };
+ auto unpackFn = [loc, &rewriter, equivIntType,
+ valueTy](Value packedVal) -> Value {
+ auto asInt =
+ rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
+ return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
+ };
+
+ rewriter.replaceOp(op, createSubgroupShuffleReduction(
+ rewriter, loc, op.getValue(), op.getOp(),
+ subgroupSize, packFn, unpackFn));
+ return success();
+ }
+
+private:
+ unsigned subgroupSize = 0;
+ unsigned shuffleBitwidth = 0;
+};
+
+/// Lowers vector gpu subgroup reductions to a series of shuffles.
+struct VectorSubgroupReduceToShuffles final
+ : OpRewritePattern<gpu::SubgroupReduceOp> {
+ VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+ unsigned shuffleBitwidth,
+ PatternBenefit benefit)
+ : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+ shuffleBitwidth(shuffleBitwidth) {}
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ auto vecTy = dyn_cast<VectorType>(op.getType());
+ if (!vecTy)
+ return rewriter.notifyMatchFailure(op, "value type is not a vector");
+
+ unsigned vecBitwidth =
+ vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
+ if (vecBitwidth > shuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op,
+ llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
+ "to shuffles of size {1}",
+ vecBitwidth, shuffleBitwidth));
+
+ unsigned elementsPerShuffle =
+ shuffleBitwidth / vecTy.getElementTypeBitWidth();
+ if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op, "shuffle bitwidth is not a multiple of the element bitwidth");
+
+ Location loc = op.getLoc();
+
+ // If the reduced type is smaller than the native shuffle size, extend it,
+ // perform the shuffles, and extract at the end.
+ auto extendedVecTy = VectorType::get(
+ static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
+ Value extendedInput = op.getValue();
+ if (vecBitwidth < shuffleBitwidth) {
+ auto zero = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(extendedVecTy));
+ extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
+ }
+
+ auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+ auto shuffleVecType = VectorType::get(1, shuffleIntType);
+
+ auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
+ auto asIntVec =
+ rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
+ return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
+ };
+ auto unpackFn = [loc, &rewriter, shuffleVecType,
+ extendedVecTy](Value packedVal) -> Value {
+ auto asIntVec =
+ rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
+ return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
+ };
+
+ Value res =
+ createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
+ subgroupSize, packFn, unpackFn);
+
+ if (vecBitwidth < shuffleBitwidth) {
+ res = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
+ /*strides=*/1);
+ }
+
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+private:
+ unsigned subgroupSize = 0;
+ unsigned shuffleBitwidth = 0;
+};
} // namespace
void mlir::populateGpuBreakDownSubgroupReducePatterns(
@@ -148,3 +311,10 @@ void mlir::populateGpuBreakDownSubgroupReducePatterns(
maxShuffleBitwidth, benefit);
patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}
+
+void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
+ RewritePatternSet &patterns, unsigned subgroupSize,
+ unsigned shuffleBitwidth, PatternBenefit benefit) {
+ patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+ patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+}
diff --git a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
new file mode 100644
index 00000000000000..e91aa18128c7b9
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
@@ -0,0 +1,44 @@
+//===- Utils.cpp - GPU transforms utils -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements GPU dialect transforms utils.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace mlir::gpu {
+
+vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode) {
+ switch (mode) {
+#define MAP_CASE(X) \
+ case gpu::AllReduceOperation::X: \
+ return vector::CombiningKind::X
+
+ MAP_CASE(ADD);
+ MAP_CASE(MUL);
+ MAP_CASE(MINUI);
+ MAP_CASE(MINSI);
+ MAP_CASE(MINNUMF);
+ MAP_CASE(MAXSI);
+ MAP_CASE(MAXUI);
+ MAP_CASE(MAXNUMF);
+ MAP_CASE(AND);
+ MAP_CASE(OR);
+ MAP_CASE(XOR);
+ MAP_CASE(MINIMUMF);
+ MAP_CASE(MAXIMUMF);
+
+#undef MAP_CASE
+ }
+
+ llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
+}
+
+} // namespace mlir::gpu
diff --git a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
index b7146071bf2fd8..f04a01ffe75d3c 100644
--- a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
@@ -1,71 +1,191 @@
-// RUN: mlir-opt --allow-unregistered-dialect --test-gpu-subgroup-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN: --test-gpu-subgroup-reduce-lowering %s \
+// RUN: | FileCheck %s --check-prefix=CHECK-SUB
-// CHECK: gpu.module @kernels {
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN: --test-gpu-subgroup-reduce-lowering="expand-to-shuffles" %s \
+// RUN: | FileCheck %s --check-prefix=CHECK-SHFL
+
+// CHECK-SUB: gpu.module @kernels {
+// CHECK-SHFL: gpu.module @kernels {
gpu.module @kernels {
- // CHECK-LABEL: gpu.func @kernel0(
- // CHECK-SAME: %[[ARG0:.+]]: vector<5xf16>)
+ // CHECK-SUB-LABEL: gpu.func @kernel0(
+ // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<5xf16>)
+ //
+ // CHECK-SHFL-LABEL: gpu.func @kernel0(
gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
- // CHECK: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
- // CHECK: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
- // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
- // CHECK: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
- // CHECK: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
- // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
- // CHECK: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
- // CHECK: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
- // CHECK: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
- // CHECK: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
- // CHECK: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
+ // CHECK-SUB: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
+ // CHECK-SUB: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+ // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
+ // CHECK-SUB: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
+ // CHECK-SUB: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+ // CHECK-SUB: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
+ // CHECK-SUB: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
+ // CHECK-SUB: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
+ // CHECK-SUB: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
+ // CHECK-SUB: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
+ // CHECK-SUB: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
%sum0 = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> (vector<5xf16>)
"test.consume"(%sum0) : (vector<5xf16>) -> ()
-
- // CHECK-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
- // CHECK: "test.consume"
+ // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
+ // CHECK-SUB: "test.consume"
%sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
"test.consume"(%sum1) : (vector<5xf16>) -> ()
- // CHECK: gpu.return
+ // CHECK-SUB: gpu.return
gpu.return
}
- // CHECK-LABEL: gpu.func @kernel1(
- // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>)
+ // CHECK-SUB-LABEL: gpu.func @kernel1(
+ // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<1xf32>)
+ //
+ // CHECK-SHFL-LABEL: gpu.func @kernel1(
gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
- // CHECK: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
- // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
- // CHECK: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
- // CHECK: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
+ // CHECK-SUB: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+ // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
+ // CHECK-SUB: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
+ // CHECK-SUB: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
%sum0 = gpu.subgroup_reduce add %arg0 : (vector<1xf32>) -> (vector<1xf32>)
"test.consume"(%sum0) : (vector<1xf32>) -> ()
- // CHECK: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
- // CHECK: "test.consume"
+ // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
+ // CHECK-SUB: "test.consume"
%sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
"test.consume"(%sum1) : (vector<1xf32>) -> ()
- // CHECK: gpu.return
+ // CHECK-SUB: gpu.return
gpu.return
}
// These vectors fit the native shuffle size and should not be broken down.
//
- // CHECK-LABEL: gpu.func @kernel2(
- // CHECK-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+ // CHECK-SUB-LABEL: gpu.func @kernel2(
+ // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+ //
+ // CHECK-SHFL-LABEL: gpu.func @kernel2(
gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
- // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
- // CHECK: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
+ // CHECK-SUB: %[[R0:...
[truncated]
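The truncation above cuts off most of the CHECK-SHFL expectations. As a rough guide to what the vector pattern emits, here is a hedged sketch of one butterfly step for a vector<2xf16> input packed into 32-bit shuffles; the SSA names are illustrative and not copied from the test file:

// Pack: reinterpret the two f16 lanes as a single i32 for the shuffle.
%c1 = arith.constant 1 : i32
%c32 = arith.constant 32 : i32
%packed = vector.bitcast %v : vector<2xf16> to vector<1xi32>
%scalar = vector.extract %packed[0] : i32 from vector<1xi32>
%shfl, %pred = gpu.shuffle xor %scalar, %c1, %c32 : i32
// Unpack: rebuild the vector and combine with the lane-local value.
%bcast = vector.broadcast %shfl : i32 to vector<1xi32>
%unpacked = vector.bitcast %bcast : vector<1xi32> to vector<2xf16>
%acc = arith.addf %v, %unpacked : vector<2xf16>
// ...repeated with offsets 2, 4, 8, and 16 for a 32-lane subgroup.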
This looks very similar to the IREE VectorReduction to shuffle lowering? Is the plan to lower vector.multi_reduce to gpu.subgroup_reduce and use this lowering instead?
Yes, exactly. I described this lowering flow here: #76015
Woaahhhh, exciting! :)
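For readers following that exchange, a hedged sketch of the two-step flow being discussed; the first rewrite is not part of this PR, so the ops and names below are illustrative assumptions only:

// Hypothetical pre-pass IR (not from this PR): a 1-D reduction whose
// reduced dimension maps onto subgroup lanes.
%red = vector.multi_reduction <add>, %src, %acc [0] : vector<32xf32> to f32
// A (hypothetical) distribution pattern would leave each lane holding one
// element %elem, with the cross-lane combine expressed as a subgroup
// reduction, which this PR's patterns then expand into shuffles:
%red2 = gpu.subgroup_reduce add %elem : (f32) -> f32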
Overall LGTM. I just have some documentation-related comments.
Force-pushed from eaaf353 to c456b09:
This supports both the scalar and the vector multi-reduction cases.