Commit 2b5b2bf
[mlir][gpu] Add DecomposeMemrefsPass
Some GPU backends (SPIR-V) lower memrefs to bare pointers, so the lowering fails for dynamically sized/strided memrefs. This pass extracts sizes and strides via `memref.extract_strided_metadata` outside the `gpu.launch` body, does the index/offset calculation explicitly, and then reconstructs the memrefs via `memref.reinterpret_cast`. The `memref.reinterpret_cast` is then lowered via https://reviews.llvm.org/D155011

Differential Revision: https://reviews.llvm.org/D155247
1 parent f9ebcb4 commit 2b5b2bf
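
To illustrate the rewrite, here is a schematic before/after sketch (not taken from the commit's tests; the value names, the exact folded affine maps, and the 0-d result layout are illustrative) for a dynamically sized store inside `gpu.launch`:

// Before: the store needs the full memref descriptor inside the launch body.
func.func @store(%arg0: memref<?x?xf32>, %val: f32, %i: index, %j: index) {
  %c1 = arith.constant 1 : index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
    memref.store %val, %arg0[%i, %j] : memref<?x?xf32>
    gpu.terminator
  }
  return
}

// After (schematic): the strided metadata is materialized outside the launch,
// the flat offset is computed explicitly, and a 0-d view is rebuilt inside the
// body, so only the base pointer and index values cross the launch boundary.
func.func @store(%arg0: memref<?x?xf32>, %val: f32, %i: index, %j: index) {
  %c1 = arith.constant 1 : index
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
      : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
    // offset + %i * stride0 + %j * stride1; the static offset 0 of the
    // identity layout folds away.
    %flat_idx = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
        ()[%i, %strides#0, %j, %strides#1]
    %flat = memref.reinterpret_cast %base to offset: [%flat_idx], sizes: [], strides: []
        : memref<f32> to memref<f32, strided<[], offset: ?>>
    memref.store %val, %flat[] : memref<f32, strided<[], offset: ?>>
    gpu.terminator
  }
  return
}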

File tree

7 files changed: +433 −1 lines changed


mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 6 additions & 0 deletions
@@ -150,6 +150,12 @@ std::unique_ptr<Pass> createGpuSerializeToHsacoPass(StringRef triple,
                                                     StringRef features,
                                                     int optLevel);
 
+/// Collect a set of patterns to decompose memref ops.
+void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
+
+/// Pass that decomposes memref ops inside the `gpu.launch` body.
+std::unique_ptr<Pass> createGpuDecomposeMemrefsPass();
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"

mlir/include/mlir/Dialect/GPU/Transforms/Passes.td

Lines changed: 18 additions & 0 deletions
@@ -37,4 +37,22 @@ def GpuMapParallelLoopsPass
   let dependentDialects = ["mlir::gpu::GPUDialect"];
 }
 
+def GpuDecomposeMemrefsPass : Pass<"gpu-decompose-memrefs"> {
+  let summary = "Decomposes memref index computation into explicit ops.";
+  let description = [{
+    This pass decomposes memref index computation into explicit computations on
+    sizes/strides, obtained from `memref.extract_strided_metadata`, which it
+    tries to place outside of the `gpu.launch` body. Memrefs are then
+    reconstructed using `memref.reinterpret_cast`.
+    This is needed because some targets (SPIR-V) lower memrefs to bare pointers,
+    and sizes/strides of dynamically-sized memrefs are not available inside
+    `gpu.launch`.
+  }];
+  let constructor = "mlir::createGpuDecomposeMemrefsPass()";
+  let dependentDialects = [
+    "mlir::gpu::GPUDialect", "mlir::memref::MemRefDialect",
+    "mlir::affine::AffineDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_GPU_PASSES

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 6 additions & 0 deletions
@@ -229,6 +229,12 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);
 
+/// Compute linear index from provided strides and indices, assuming strided
+/// layout.
+OpFoldResult computeLinearIndex(OpBuilder &builder, Location loc,
+                                OpFoldResult sourceOffset,
+                                ArrayRef<OpFoldResult> strides,
+                                ArrayRef<OpFoldResult> indices);
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -47,14 +47,15 @@ add_mlir_dialect_library(MLIRGPUDialect
 add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/AllReduceLowering.cpp
   Transforms/AsyncRegionRewriter.cpp
+  Transforms/DecomposeMemrefs.cpp
   Transforms/GlobalIdRewriter.cpp
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ParallelLoopMapper.cpp
-  Transforms/ShuffleRewriter.cpp
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
+  Transforms/ShuffleRewriter.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU

mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp (new file)

Lines changed: 232 additions & 0 deletions

//===- DecomposeMemrefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements decompose memrefs pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

  OpFoldResult finalOffset =
      computeLinearIndex(rewriter, loc, origOffset, origStrides, subOffsets);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  auto retType = cast<MemRefType>(base.getType());
  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                    std::nullopt, std::nullopt);
}

static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}

std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
  return std::make_unique<GpuDecomposeMemrefsPass>();
}
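
For the FlattenSubview pattern, a schematic sketch of the intended result (shown outside its `gpu.launch` context so the fragment is self-contained; the names, the folded affine maps, and the result layout are illustrative): a dynamic `memref.subview` becomes a `memref.reinterpret_cast` of the extracted base buffer, with the offset linearized, each subview stride multiplied by the corresponding source stride, and the subview sizes passed through.

func.func @subview(%arg0: memref<?x?xf32>, %o0: index, %o1: index,
                   %sz0: index, %sz1: index, %st0: index, %st1: index)
    -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
      : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
  // New offset: o0 * stride0 + o1 * stride1 (the static source offset folds away).
  %new_off = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
      ()[%o0, %strides#0, %o1, %strides#1]
  // New strides: subview stride times source stride, per dimension.
  %new_str0 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%st0, %strides#0]
  %new_str1 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%st1, %strides#1]
  %view = memref.reinterpret_cast %base to offset: [%new_off],
      sizes: [%sz0, %sz1], strides: [%new_str0, %new_str1]
      : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %view : memref<?x?xf32, strided<[?, ?], offset: ?>>
}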

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 32 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -261,3 +262,34 @@ SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
     res.push_back((*it).getValue().getSExtValue());
   return res;
 }
+
+OpFoldResult mlir::computeLinearIndex(OpBuilder &builder, Location loc,
+                                      OpFoldResult sourceOffset,
+                                      ArrayRef<OpFoldResult> strides,
+                                      ArrayRef<OpFoldResult> indices) {
+  assert(strides.size() == indices.size());
+  auto sourceRank = static_cast<unsigned>(strides.size());
+
+  // Hold the affine symbols and values for the computation of the offset.
+  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+  AffineExpr expr = symbols.front();
+  values[0] = sourceOffset;
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    // Compute the stride.
+    OpFoldResult origStride = strides[i];
+
+    // Build up the computation of the offset.
+    unsigned baseIdxForDim = 1 + 2 * i;
+    unsigned subOffsetForDim = baseIdxForDim;
+    unsigned origStrideForDim = baseIdxForDim + 1;
+    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
+    values[subOffsetForDim] = indices[i];
+    values[origStrideForDim] = origStride;
+  }
+
+  return affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+}
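
In effect, computeLinearIndex folds offset + indices[0] * strides[0] + ... + indices[n-1] * strides[n-1] into a single affine.apply. A minimal rank-2 sketch of the resulting IR before further folding (the function and SSA names are illustrative; makeComposedFoldedAffineApply may simplify or constant-fold the expression):

func.func @linearize(%off: index, %i0: index, %i1: index, %str0: index, %str1: index) -> index {
  // %off + %i0 * %str0 + %i1 * %str1 as one composed affine.apply.
  %linear = affine.apply affine_map<()[s0, s1, s2, s3, s4] -> (s0 + s1 * s2 + s3 * s4)>
      ()[%off, %i0, %str0, %i1, %str1]
  return %linear : index
}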
