[mlir][gpu] Add subgroup_reduce to shuffle lowering #76530

Merged 3 commits on Jan 2, 2024. The diff below shows changes from all commits.
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -70,6 +70,13 @@ void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns,
unsigned maxShuffleBitwidth = 32,
PatternBenefit benefit = 1);

/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
void populateGpuLowerSubgroupReduceToShufflePatterns(
RewritePatternSet &patterns, unsigned subgroupSize,
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
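
For orientation, here is a minimal sketch of the two ops involved, assuming `f32` values (already a native shuffle scalar, so no packing); `%offset` and `%width` are illustrative placeholders for the constants the lowering materializes:

```mlir
// The op being lowered:
%sum = gpu.subgroup_reduce add %x : (f32) -> f32

// The building block it lowers to, applied log2(subgroupSize) times:
%shfl, %valid = gpu.shuffle xor %x, %offset, %width : f32
```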

/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
@@ -13,6 +13,8 @@
#ifndef MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_
#define MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Support/LLVM.h"

#include <string>
@@ -28,6 +30,9 @@ class LaunchOp;

/// Returns the default annotation name for GPU binary blobs.
std::string getDefaultGpuBinaryAnnotation();

/// Returns the matching vector combining kind.
vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode);
} // namespace gpu

/// Get a gpu.func created from outlining the region of a gpu.launch op with the
1 change: 1 addition & 0 deletions mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -64,6 +64,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
Transforms/SubgroupReduceLowering.cpp
Transforms/Utils.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
27 changes: 0 additions & 27 deletions mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -27,33 +27,6 @@ using namespace mlir;

namespace {

static vector::CombiningKind
convertReductionKind(gpu::AllReduceOperation mode) {
switch (mode) {
#define MAP_CASE(X) \
case gpu::AllReduceOperation::X: \
return vector::CombiningKind::X

MAP_CASE(ADD);
MAP_CASE(MUL);
MAP_CASE(MINUI);
MAP_CASE(MINSI);
MAP_CASE(MINNUMF);
MAP_CASE(MAXSI);
MAP_CASE(MAXUI);
MAP_CASE(MAXNUMF);
MAP_CASE(AND);
MAP_CASE(OR);
MAP_CASE(XOR);
MAP_CASE(MINIMUMF);
MAP_CASE(MAXIMUMF);

#undef MAP_CASE
}

llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
}

struct GpuAllReduceRewriter {
using AccumulatorFactory = std::function<Value(Value, Value)>;

174 changes: 173 additions & 1 deletion mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -13,13 +13,17 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

@@ -58,7 +62,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
if (elemBitwidth >= maxShuffleBitwidth)
return rewriter.notifyMatchFailure(
op, llvm::formatv("element type too large {0}, cannot break down "
op, llvm::formatv("element type too large ({0}), cannot break down "
"into vectors of bitwidth {1} or less",
elemBitwidth, maxShuffleBitwidth));

@@ -139,6 +143,167 @@ struct ScalarizeSingleElementReduce final
}
};

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
static Value createSubgroupShuffleReduction(
OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
unsigned subgroupSize, function_ref<Value(Value)> packFn,
function_ref<Value(Value)> unpackFn) {
assert(llvm::isPowerOf2_32(subgroupSize));
// Lane value always stays in the original type. We use it to perform arith
// reductions.
Value laneVal = input;
// Parallel reduction using butterfly shuffles.
for (unsigned i = 1; i < subgroupSize; i <<= 1) {
Value shuffled = builder
.create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
/*width=*/subgroupSize,
/*mode=*/gpu::ShuffleMode::XOR)
.getShuffleResult();
laneVal = vector::makeArithReduction(builder, loc,
gpu::convertReductionKind(mode),
laneVal, unpackFn(shuffled));
assert(laneVal.getType() == input.getType());
}

return laneVal;
}
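
As a concrete sketch, assume `subgroupSize = 4`, an `add` reduction, and `f32` values with identity `packFn`/`unpackFn`; the loop above would then emit roughly the following (SSA names are illustrative):

```mlir
%width = arith.constant 4 : i32
// Step 1: exchange values with the lane whose id differs in bit 0.
%o1 = arith.constant 1 : i32
%s1, %v1 = gpu.shuffle xor %x, %o1, %width : f32
%r1 = arith.addf %x, %s1 : f32
// Step 2: exchange values with the lane whose id differs in bit 1.
%o2 = arith.constant 2 : i32
%s2, %v2 = gpu.shuffle xor %r1, %o2, %width : f32
// Every lane now holds the reduction over all four lanes.
%r2 = arith.addf %r1, %s2 : f32
```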

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
: OpRewritePattern<gpu::SubgroupReduceOp> {
ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
unsigned shuffleBitwidth,
PatternBenefit benefit)
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
shuffleBitwidth(shuffleBitwidth) {}

LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
Type valueTy = op.getType();
unsigned elemBitwidth =
getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
return rewriter.notifyMatchFailure(
op, "value type is not a compatible scalar");

Location loc = op.getLoc();
// Since this is already a native shuffle scalar, no packing is necessary.
if (elemBitwidth == shuffleBitwidth) {
auto identityFn = [](Value v) { return v; };
rewriter.replaceOp(op, createSubgroupShuffleReduction(
rewriter, loc, op.getValue(), op.getOp(),
subgroupSize, identityFn, identityFn));
return success();
}

auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
auto equivIntType = rewriter.getIntegerType(elemBitwidth);
auto packFn = [loc, &rewriter, equivIntType,
shuffleIntType](Value unpackedVal) -> Value {
auto asInt =
rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
};
auto unpackFn = [loc, &rewriter, equivIntType,
valueTy](Value packedVal) -> Value {
auto asInt =
rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
};

rewriter.replaceOp(op, createSubgroupShuffleReduction(
rewriter, loc, op.getValue(), op.getOp(),
subgroupSize, packFn, unpackFn));
return success();
}

private:
unsigned subgroupSize = 0;
unsigned shuffleBitwidth = 0;
};
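
A hedged sketch of one butterfly step this pattern produces for an `f16` `add` reduction with a 32-bit shuffle type (`%offset` and `%width` again stand in for the materialized constants):

```mlir
// packFn: reinterpret f16 as i16, then zero-extend to the i32 shuffle type.
%p0 = arith.bitcast %lane : f16 to i16
%p1 = arith.extui %p0 : i16 to i32
%shfl, %valid = gpu.shuffle xor %p1, %offset, %width : i32
// unpackFn: truncate back to i16 and reinterpret as f16.
%u0 = arith.trunci %shfl : i32 to i16
%u1 = arith.bitcast %u0 : i16 to f16
// Reduce in the original element type.
%next = arith.addf %lane, %u1 : f16
```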

/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
: OpRewritePattern<gpu::SubgroupReduceOp> {
VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
unsigned shuffleBitwidth,
PatternBenefit benefit)
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
shuffleBitwidth(shuffleBitwidth) {}

LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy)
return rewriter.notifyMatchFailure(op, "value type is not a vector");

unsigned vecBitwidth =
vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
if (vecBitwidth > shuffleBitwidth)
return rewriter.notifyMatchFailure(
op,
llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
"to shuffles of size {1}",
vecBitwidth, shuffleBitwidth));

unsigned elementsPerShuffle =
shuffleBitwidth / vecTy.getElementTypeBitWidth();
if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
return rewriter.notifyMatchFailure(
op, "shuffle bitwidth is not a multiple of the element bitwidth");

Location loc = op.getLoc();

// If the reduced type is smaller than the native shuffle size, extend it,
// perform the shuffles, and extract at the end.
auto extendedVecTy = VectorType::get(
static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
Value extendedInput = op.getValue();
if (vecBitwidth < shuffleBitwidth) {
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(extendedVecTy));
extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
}

auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
auto shuffleVecType = VectorType::get(1, shuffleIntType);

auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
auto asIntVec =
rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
};
auto unpackFn = [loc, &rewriter, shuffleVecType,
extendedVecTy](Value packedVal) -> Value {
auto asIntVec =
rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
};

Value res =
createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
subgroupSize, packFn, unpackFn);

if (vecBitwidth < shuffleBitwidth) {
res = rewriter.create<vector::ExtractStridedSliceOp>(
loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
/*strides=*/1);
}

rewriter.replaceOp(op, res);
return success();
}

private:
unsigned subgroupSize = 0;
unsigned shuffleBitwidth = 0;
};
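
And a hedged sketch of one butterfly step for a `vector<2xf16>` `add` reduction with a 32-bit shuffle type (the `vector.extract` result-type syntax is approximate; `%offset` and `%width` are placeholders):

```mlir
// packFn: bitcast vector<2xf16> to vector<1xi32> and extract the i32 scalar.
%iv = vector.bitcast %lane : vector<2xf16> to vector<1xi32>
%sc = vector.extract %iv[0] : i32 from vector<1xi32>
%shfl, %valid = gpu.shuffle xor %sc, %offset, %width : i32
// unpackFn: rebuild the single-element vector and bitcast back to f16s.
%bv = vector.broadcast %shfl : i32 to vector<1xi32>
%uv = vector.bitcast %bv : vector<1xi32> to vector<2xf16>
// Elementwise reduction in the original vector type.
%next = arith.addf %lane, %uv : vector<2xf16>
```

Inputs narrower than the shuffle width (e.g., `vector<1xf16>`) are first zero-padded to the extended vector type with `vector.insert_strided_slice`, and the original elements are recovered at the end with `vector.extract_strided_slice`.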
} // namespace

void mlir::populateGpuBreakDownSubgrupReducePatterns(
@@ -148,3 +313,10 @@ void mlir::populateGpuBreakDownSubgrupReducePatterns(
maxShuffleBitwidth, benefit);
patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
RewritePatternSet &patterns, unsigned subgroupSize,
unsigned shuffleBitwidth, PatternBenefit benefit) {
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
}
44 changes: 44 additions & 0 deletions mlir/lib/Dialect/GPU/Transforms/Utils.cpp
@@ -0,0 +1,44 @@
//===- Utils.cpp - GPU transforms utils -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements GPU dialect transforms utils.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir::gpu {

vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode) {
switch (mode) {
#define MAP_CASE(X) \
case gpu::AllReduceOperation::X: \
return vector::CombiningKind::X

MAP_CASE(ADD);
MAP_CASE(MUL);
MAP_CASE(MINUI);
MAP_CASE(MINSI);
MAP_CASE(MINNUMF);
MAP_CASE(MAXSI);
MAP_CASE(MAXUI);
MAP_CASE(MAXNUMF);
MAP_CASE(AND);
MAP_CASE(OR);
MAP_CASE(XOR);
MAP_CASE(MINIMUMF);
MAP_CASE(MAXIMUMF);

#undef MAP_CASE
}

llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
}

} // namespace mlir::gpu