Skip to content

Commit 8dffb71

Browse files
committed
[mlir][VectorOps] Add lowering for vector.shape_cast of scalable vectors
This adds a lowering similar to the general shape_cast lowering, but instead moves elements a (scalable) subvector at a time via vector.scalable.extract/insert.

It is restricted to the case where both the source and result vector types have a single trailing scalable dimension (due to limitations of the insert/extract ops).

The current lowerings are now disabled for scalable vectors, as they produce incorrect results at runtime (due to assuming a fixed number of elements).

Examples of casts that now work:

    // Flattening:
    %v = vector.shape_cast %arg0 : vector<4x[8]xi8> to vector<[32]xi8>

    // Un-flattening:
    %v = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>

Reviewed By: awarzynski, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D159217
1 parent 6255157 commit 8dffb71

File tree

2 files changed

+387
-8
lines changed

2 files changed

+387
-8
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Lines changed: 173 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ class ShapeCastOp2DDownCastRewritePattern
5454
PatternRewriter &rewriter) const override {
5555
auto sourceVectorType = op.getSourceVectorType();
5656
auto resultVectorType = op.getResultVectorType();
57+
58+
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
59+
return failure();
60+
5761
if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
5862
return failure();
5963

@@ -87,6 +91,10 @@ class ShapeCastOp2DUpCastRewritePattern
8791
PatternRewriter &rewriter) const override {
8892
auto sourceVectorType = op.getSourceVectorType();
8993
auto resultVectorType = op.getResultVectorType();
94+
95+
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
96+
return failure();
97+
9098
if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
9199
return failure();
92100

@@ -106,6 +114,20 @@ class ShapeCastOp2DUpCastRewritePattern
106114
}
107115
};
108116

117+
static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
118+
int dimIdx, int initialStep = 1) {
119+
int step = initialStep;
120+
for (int d = dimIdx; d >= 0; d--) {
121+
idx[d] += step;
122+
if (idx[d] >= tp.getDimSize(d)) {
123+
idx[d] = 0;
124+
step = 1;
125+
} else {
126+
break;
127+
}
128+
}
129+
}
130+
109131
// We typically should not lower general shape cast operations into data
110132
// movement instructions, since the assumption is that these casts are
111133
// optimized away during progressive lowering. For completeness, however,
@@ -121,6 +143,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
121143
auto sourceVectorType = op.getSourceVectorType();
122144
auto resultVectorType = op.getResultVectorType();
123145

146+
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
147+
return failure();
148+
124149
// Special case 2D / 1D lowerings with better implementations.
125150
// TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
126151
int64_t srcRank = sourceVectorType.getRank();
@@ -175,21 +200,161 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
175200
rewriter.replaceOp(op, result);
176201
return success();
177202
}
203+
};
204+
205+
/// A shape_cast lowering for scalable vectors with a single trailing scalable
206+
/// dimension. This is similar to the general shape_cast lowering but makes use
207+
/// of vector.scalable.insert and vector.scalable.extract to move elements a
208+
/// subvector at a time.
209+
///
210+
/// E.g.:
211+
/// ```
212+
/// // Flatten scalable vector
213+
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
214+
/// ```
215+
/// is rewritten to:
216+
/// ```
217+
/// // Flatten scalable vector
218+
/// %c = arith.constant dense<0> : vector<[8]xi32>
219+
/// %0 = vector.extract %arg0[0, 0] : vector<2x1x[4]xi32>
220+
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
221+
/// %2 = vector.extract %arg0[1, 0] : vector<2x1x[4]xi32>
222+
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
223+
/// ```
224+
/// or:
225+
/// ```
226+
/// // Un-flatten scalable vector
227+
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
228+
/// ```
229+
/// is rewritten to:
230+
/// ```
231+
/// // Un-flatten scalable vector
232+
/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
233+
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
234+
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
235+
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
236+
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
237+
/// ```
238+
class ScalableShapeCastOpRewritePattern
239+
: public OpRewritePattern<vector::ShapeCastOp> {
240+
public:
241+
using OpRewritePattern::OpRewritePattern;
242+
243+
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
244+
PatternRewriter &rewriter) const override {
245+
246+
Location loc = op.getLoc();
247+
auto sourceVectorType = op.getSourceVectorType();
248+
auto resultVectorType = op.getResultVectorType();
249+
auto srcRank = sourceVectorType.getRank();
250+
auto resRank = resultVectorType.getRank();
251+
252+
// This can only lower shape_casts where both the source and result types
253+
// have a single trailing scalable dimension. This is because there are no
254+
// legal representation of other scalable types in LLVM (and likely won't be
255+
// soon). There are also (currently) no operations that can index or extract
256+
// from >= 2D scalable vectors or scalable vectors of fixed vectors.
257+
if (!isTrailingDimScalable(sourceVectorType) ||
258+
!isTrailingDimScalable(resultVectorType)) {
259+
return failure();
260+
}
261+
262+
// The sizes of the trailing dimension of the source and result vectors, the
263+
// size of subvector to move, and the number of elements in the vectors.
264+
// These are "min" sizes as they are the size when vscale == 1.
265+
auto minSourceTrailingSize = sourceVectorType.getShape().back();
266+
auto minResultTrailingSize = resultVectorType.getShape().back();
267+
auto minExtractionSize =
268+
std::min(minSourceTrailingSize, minResultTrailingSize);
269+
int64_t minNumElts = 1;
270+
for (auto size : sourceVectorType.getShape())
271+
minNumElts *= size;
272+
273+
// The subvector type to move from the source to the result. Note that this
274+
// is a scalable vector. This rewrite will generate code in terms of the
275+
// "min" size (vscale == 1 case), that scales to any vscale.
276+
auto extractionVectorType = VectorType::get(
277+
{minExtractionSize}, sourceVectorType.getElementType(), {true});
278+
279+
Value result = rewriter.create<arith::ConstantOp>(
280+
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
281+
282+
SmallVector<int64_t> srcIdx(srcRank);
283+
SmallVector<int64_t> resIdx(resRank);
284+
285+
// TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
286+
// once D150000 lands.
287+
Value currentResultScalableVector;
288+
Value currentSourceScalableVector;
289+
for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
290+
// 1. Extract a scalable subvector from the source vector.
291+
if (!currentSourceScalableVector) {
292+
if (srcRank != 1) {
293+
currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
294+
loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
295+
} else {
296+
currentSourceScalableVector = op.getSource();
297+
}
298+
}
299+
Value sourceSubVector = currentSourceScalableVector;
300+
if (minExtractionSize < minSourceTrailingSize) {
301+
sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
302+
loc, extractionVectorType, sourceSubVector, srcIdx.back());
303+
}
178304

179-
private:
180-
static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
181-
assert(0 <= r && r < tp.getRank());
182-
if (++idx[r] == tp.getDimSize(r)) {
183-
idx[r] = 0;
184-
incIdx(idx, tp, r - 1);
305+
// 2. Insert the scalable subvector into the result vector.
306+
if (!currentResultScalableVector) {
307+
if (minExtractionSize == minResultTrailingSize) {
308+
currentResultScalableVector = sourceSubVector;
309+
} else if (resRank != 1) {
310+
currentResultScalableVector = rewriter.create<vector::ExtractOp>(
311+
loc, result, llvm::ArrayRef(resIdx).drop_back());
312+
} else {
313+
currentResultScalableVector = result;
314+
}
315+
}
316+
if (minExtractionSize < minResultTrailingSize) {
317+
currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
318+
loc, sourceSubVector, currentResultScalableVector, resIdx.back());
319+
}
320+
321+
// 3. Update the source and result scalable vectors if needed.
322+
if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
323+
currentResultScalableVector != result) {
324+
// Finished row of result. Insert complete scalable vector into result
325+
// (n-D) vector.
326+
result = rewriter.create<vector::InsertOp>(
327+
loc, currentResultScalableVector, result,
328+
llvm::ArrayRef(resIdx).drop_back());
329+
currentResultScalableVector = {};
330+
}
331+
if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
332+
// Finished row of source.
333+
currentSourceScalableVector = {};
334+
}
335+
336+
// 4. Increment the insert/extract indices, stepping by minExtractionSize
337+
// for the trailing dimensions.
338+
incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
339+
incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
185340
}
341+
342+
rewriter.replaceOp(op, result);
343+
return success();
344+
}
345+
346+
static bool isTrailingDimScalable(VectorType type) {
347+
return type.getRank() >= 1 && type.getScalableDims().back() &&
348+
!llvm::is_contained(type.getScalableDims().drop_back(), true);
186349
}
187350
};
351+
188352
} // namespace
189353

190354
void mlir::vector::populateVectorShapeCastLoweringPatterns(
191355
RewritePatternSet &patterns, PatternBenefit benefit) {
192356
patterns.add<ShapeCastOp2DDownCastRewritePattern,
193-
ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
194-
patterns.getContext(), benefit);
357+
ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
358+
ScalableShapeCastOpRewritePattern>(patterns.getContext(),
359+
benefit);
195360
}

0 commit comments

Comments
 (0)