Skip to content

Commit 10054ba

Browse files
bjacob and kuhar authored
[mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (#111541)
Co-authored-by: Jakub Kuderski <[email protected]>
1 parent d079743 commit 10054ba

File tree

4 files changed

+121
-0
lines changed

4 files changed

+121
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
235235
std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
236236
PatternBenefit benefit = 1);
237237

238+
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
239+
/// slice is contiguous, into extract and shape_cast.
240+
void populateVectorContiguousExtractStridedSliceToExtractPatterns(
241+
RewritePatternSet &patterns, PatternBenefit benefit = 1);
242+
238243
/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
239244
/// based on the destination vector shape. Bitcasts from a lower bitwidth
240245
/// element type to a higher bitwidth one are extracted from the lower bitwidth

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,12 +329,70 @@ class DecomposeNDExtractStridedSlice
329329
}
330330
};
331331

332+
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
333+
/// slice is contiguous, into extract and shape_cast.
334+
class ContiguousExtractStridedSliceToExtract final
335+
: public OpRewritePattern<ExtractStridedSliceOp> {
336+
public:
337+
using OpRewritePattern::OpRewritePattern;
338+
339+
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
340+
PatternRewriter &rewriter) const override {
341+
if (op.hasNonUnitStrides()) {
342+
return failure();
343+
}
344+
Value source = op.getOperand();
345+
auto sourceType = cast<VectorType>(source.getType());
346+
if (sourceType.isScalable()) {
347+
return failure();
348+
}
349+
350+
// Compute the number of offsets to pass to ExtractOp::build. That is the
351+
// difference between the source rank and the desired slice rank. We walk
352+
// the dimensions from innermost out, and stop when the next slice dimension
353+
// is not full-size.
354+
SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
355+
int numOffsets;
356+
for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
357+
if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
358+
break;
359+
}
360+
}
361+
362+
// If not even the inner-most dimension is full-size, this op can't be
363+
// rewritten as an ExtractOp.
364+
if (numOffsets == sourceType.getRank()) {
365+
return failure();
366+
}
367+
368+
// Avoid generating slices that have unit outer dimensions. The shape_cast
369+
// op that we create below would take bad generic fallback patterns
370+
// (ShapeCastOpRewritePattern).
371+
while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
372+
++numOffsets;
373+
}
374+
375+
SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
376+
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
377+
Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
378+
extractOffsets);
379+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
380+
return success();
381+
}
382+
};
383+
332384
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
333385
RewritePatternSet &patterns, PatternBenefit benefit) {
334386
patterns.add<DecomposeDifferentRankInsertStridedSlice,
335387
DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
336388
}
337389

390+
/// Adds the ContiguousExtractStridedSliceToExtract pattern to `patterns`,
/// constructed with the set's context and the given `benefit`.
void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  auto *context = patterns.getContext();
  patterns.add<ContiguousExtractStridedSliceToExtract>(context, benefit);
}
395+
338396
void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
339397
RewritePatternSet &patterns,
340398
std::function<bool(ExtractStridedSliceOp)> controlFn,
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
2+
3+
// CHECK-LABEL: @extract_strided_slice_to_extract_i8
4+
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
5+
// CHECK: return %[[EXTRACT]] : vector<8xi8>
6+
func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
7+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
8+
%2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
9+
return %2 : vector<8xi8>
10+
}
11+
12+
// CHECK-LABEL: @extract_strided_slice_to_extract_i32
13+
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
14+
// CHECK: return %[[EXTRACT]] : vector<4xi32>
15+
func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
16+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
17+
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
18+
return %2 : vector<4xi32>
19+
}
20+
21+
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
22+
// CHECK: vector.extract_strided_slice
23+
func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
24+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
25+
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32>
26+
return %2 : vector<2xi32>
27+
}
28+
29+
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
30+
// CHECK: vector.extract_strided_slice
31+
func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
32+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
33+
%2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
34+
return %2 : vector<2xi32>
35+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,27 @@ struct TestVectorExtractStridedSliceLowering
709709
}
710710
};
711711

712+
struct TestVectorContiguousExtractStridedSliceToExtract
713+
: public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
714+
OperationPass<func::FuncOp>> {
715+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
716+
TestVectorExtractStridedSliceLowering)
717+
718+
StringRef getArgument() const final {
719+
return "test-vector-contiguous-extract-strided-slice-to-extract";
720+
}
721+
StringRef getDescription() const final {
722+
return "Test lowering patterns that rewrite simple cases of N-D "
723+
"extract_strided_slice, where the slice is contiguous, into extract "
724+
"and shape_cast";
725+
}
726+
void runOnOperation() override {
727+
RewritePatternSet patterns(&getContext());
728+
populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
729+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
730+
}
731+
};
732+
712733
struct TestVectorBreakDownBitCast
713734
: public PassWrapper<TestVectorBreakDownBitCast,
714735
OperationPass<func::FuncOp>> {
@@ -935,6 +956,8 @@ void registerTestVectorLowerings() {
935956

936957
PassRegistration<TestVectorExtractStridedSliceLowering>();
937958

959+
PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
960+
938961
PassRegistration<TestVectorBreakDownBitCast>();
939962

940963
PassRegistration<TestCreateVectorBroadcast>();

0 commit comments

Comments
 (0)