|
| 1 | +//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// Implements gradual lowering of `gpu.subgroup_reduce` ops. |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 14 | +#include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 15 | +#include "mlir/Dialect/GPU/Transforms/Passes.h" |
| 16 | +#include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 17 | +#include "mlir/IR/Location.h" |
| 18 | +#include "mlir/IR/PatternMatch.h" |
| 19 | +#include "mlir/Support/LogicalResult.h" |
| 20 | +#include "llvm/Support/FormatVariadic.h" |
| 21 | +#include "llvm/Support/MathExtras.h" |
| 22 | +#include <cassert> |
| 23 | + |
| 24 | +using namespace mlir; |
| 25 | + |
| 26 | +namespace { |
| 27 | + |
/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]}: vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]}: vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
| 42 | +struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> { |
| 43 | + BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth, |
| 44 | + PatternBenefit benefit) |
| 45 | + : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) { |
| 46 | + } |
| 47 | + |
| 48 | + LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, |
| 49 | + PatternRewriter &rewriter) const override { |
| 50 | + auto vecTy = dyn_cast<VectorType>(op.getType()); |
| 51 | + if (!vecTy || vecTy.getNumElements() < 2) |
| 52 | + return rewriter.notifyMatchFailure(op, "not a multi-element reduction"); |
| 53 | + |
| 54 | + assert(vecTy.getRank() == 1 && "Unexpected vector type"); |
| 55 | + assert(!vecTy.isScalable() && "Unexpected vector type"); |
| 56 | + |
| 57 | + Type elemTy = vecTy.getElementType(); |
| 58 | + unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); |
| 59 | + if (elemBitwidth >= maxShuffleBitwidth) |
| 60 | + return rewriter.notifyMatchFailure( |
| 61 | + op, llvm::formatv("element type too large {0}, cannot break down " |
| 62 | + "into vectors of bitwidth {1} or less", |
| 63 | + elemBitwidth, maxShuffleBitwidth)); |
| 64 | + |
| 65 | + unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth; |
| 66 | + assert(elementsPerShuffle >= 1); |
| 67 | + |
| 68 | + unsigned numNewReductions = |
| 69 | + llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle); |
| 70 | + assert(numNewReductions >= 1); |
| 71 | + if (numNewReductions == 1) |
| 72 | + return rewriter.notifyMatchFailure(op, "nothing to break down"); |
| 73 | + |
| 74 | + Location loc = op.getLoc(); |
| 75 | + Value res = |
| 76 | + rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy)); |
| 77 | + |
| 78 | + for (unsigned i = 0; i != numNewReductions; ++i) { |
| 79 | + int64_t startIdx = i * elementsPerShuffle; |
| 80 | + int64_t endIdx = |
| 81 | + std::min(startIdx + elementsPerShuffle, vecTy.getNumElements()); |
| 82 | + int64_t numElems = endIdx - startIdx; |
| 83 | + |
| 84 | + Value extracted; |
| 85 | + if (numElems == 1) { |
| 86 | + extracted = |
| 87 | + rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx); |
| 88 | + } else { |
| 89 | + extracted = rewriter.create<vector::ExtractStridedSliceOp>( |
| 90 | + loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems, |
| 91 | + /*strides=*/1); |
| 92 | + } |
| 93 | + |
| 94 | + Value reduce = rewriter.create<gpu::SubgroupReduceOp>( |
| 95 | + loc, extracted, op.getOp(), op.getUniform()); |
| 96 | + if (numElems == 1) { |
| 97 | + res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx); |
| 98 | + continue; |
| 99 | + } |
| 100 | + |
| 101 | + res = rewriter.create<vector::InsertStridedSliceOp>( |
| 102 | + loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1); |
| 103 | + } |
| 104 | + |
| 105 | + rewriter.replaceOp(op, res); |
| 106 | + return success(); |
| 107 | + } |
| 108 | + |
| 109 | +private: |
| 110 | + unsigned maxShuffleBitwidth = 0; |
| 111 | +}; |
| 112 | + |
| 113 | +/// Example: |
| 114 | +/// ``` |
| 115 | +/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32> |
| 116 | +/// ==> |
| 117 | +/// %e0 = vector.extract %x[0] : f32 from vector<1xf32> |
| 118 | +/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32 |
| 119 | +/// %a = vector.broadcast %r0 : f32 to vector<1xf32> |
| 120 | +/// ``` |
| 121 | +struct ScalarizeSingleElementReduce final |
| 122 | + : OpRewritePattern<gpu::SubgroupReduceOp> { |
| 123 | + using OpRewritePattern::OpRewritePattern; |
| 124 | + |
| 125 | + LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, |
| 126 | + PatternRewriter &rewriter) const override { |
| 127 | + auto vecTy = dyn_cast<VectorType>(op.getType()); |
| 128 | + if (!vecTy || vecTy.getNumElements() != 1) |
| 129 | + return rewriter.notifyMatchFailure(op, "not a single-element reduction"); |
| 130 | + |
| 131 | + assert(vecTy.getRank() == 1 && "Unexpected vector type"); |
| 132 | + assert(!vecTy.isScalable() && "Unexpected vector type"); |
| 133 | + Location loc = op.getLoc(); |
| 134 | + Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0); |
| 135 | + Value reduce = rewriter.create<gpu::SubgroupReduceOp>( |
| 136 | + loc, extracted, op.getOp(), op.getUniform()); |
| 137 | + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce); |
| 138 | + return success(); |
| 139 | + } |
| 140 | +}; |
| 141 | + |
| 142 | +} // namespace |
| 143 | + |
| 144 | +void mlir::populateGpuBreakDownSubgrupReducePatterns( |
| 145 | + RewritePatternSet &patterns, unsigned maxShuffleBitwidth, |
| 146 | + PatternBenefit benefit) { |
| 147 | + patterns.add<BreakDownSubgroupReduce>(patterns.getContext(), |
| 148 | + maxShuffleBitwidth, benefit); |
| 149 | + patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit); |
| 150 | +} |