Commit a9ebdbb

[MLIR] Vector: turn the ExtractStridedSlice rewrite pattern from #111541 into a canonicalization (#111614)
This is a reasonable canonicalization because `extract` is more constrained than `extract_strided_slice`, so there is no loss of semantics here; we merely lift the op to a more constrained special case. The additional `shape_cast` only adds leading unit dims to match the original result type.

Context: discussion on #111541. I wasn't sure how this would turn out, but in the process of writing this PR, I discovered at least two bugs in the pattern introduced in #111541, which shows the value of shared canonicalization patterns that are exercised on a large number of test cases.

---------

Signed-off-by: Benoit Jacob <[email protected]>
1 parent 1e357cd commit a9ebdbb

File tree

6 files changed, +127 -122 lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 0 additions & 5 deletions
@@ -235,11 +235,6 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
     PatternBenefit benefit = 1);
 
-/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
-/// slice is contiguous, into extract and shape_cast.
-void populateVectorContiguousExtractStridedSliceToExtractPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit = 1);
-
 /// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
 /// based on the destination vector shape. Bitcasts from a lower bitwidth
 /// element type to a higher bitwidth one are extracted from the lower bitwidth
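
With the dedicated populate* entry point removed from the header, the rewrite is no longer opt-in. Below is a minimal sketch of how client code can still apply it, since the pattern now ships with ExtractStridedSliceOp's canonicalization set. This is illustrative only and not part of this commit; the helper name runExtractStridedSliceCanonicalizations is made up.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: collect the op's canonicalization patterns, which now
// include ContiguousExtractStridedSliceToExtract, and apply them greedily.
static mlir::LogicalResult
runExtractStridedSliceCanonicalizations(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::vector::ExtractStridedSliceOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}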

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 78 additions & 1 deletion
@@ -3772,6 +3772,82 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
   }
 };
 
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+///
+/// Example:
+/// Before:
+///     %1 = vector.extract_strided_slice %arg0 {
+///            offsets = [0, 0, 0, 0, 0],
+///            sizes = [1, 1, 1, 1, 8],
+///            strides = [1, 1, 1, 1, 1]
+///          } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+/// After:
+///     %0 = vector.extract %arg0[0, 0, 0, 0]
+///            : vector<8xi8> from vector<8x1x1x2x8xi8>
+///     %1 = vector.shape_cast %0
+///            : vector<8xi8> to vector<1x1x1x1x8xi8>
+///
+class ContiguousExtractStridedSliceToExtract final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.hasNonUnitStrides())
+      return failure();
+    Value source = op.getOperand();
+    auto sourceType = cast<VectorType>(source.getType());
+    if (sourceType.isScalable() || sourceType.getRank() == 0)
+      return failure();
+
+    // Compute the number of offsets to pass to ExtractOp::build. That is the
+    // difference between the source rank and the desired slice rank. We walk
+    // the dimensions from innermost out, and stop when the next slice dimension
+    // is not full-size.
+    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
+    int numOffsets;
+    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
+      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
+        break;
+    }
+
+    // If the created extract op would have no offsets, then this whole
+    // extract_strided_slice is the identity and should have been handled by
+    // other canonicalizations.
+    if (numOffsets == 0)
+      return failure();
+
+    // If not even the inner-most dimension is full-size, this op can't be
+    // rewritten as an ExtractOp.
+    if (numOffsets == sourceType.getRank() &&
+        static_cast<int>(sizes.size()) == sourceType.getRank())
+      return failure();
+
+    // The outer dimensions must have unit size.
+    for (int i = 0; i < numOffsets; ++i) {
+      if (sizes[i] != 1)
+        return failure();
+    }
+
+    // Avoid generating slices that have leading unit dimensions. The shape_cast
+    // op that we create below would take bad generic fallback patterns
+    // (ShapeCastOpRewritePattern).
+    while (sizes[numOffsets] == 1 &&
+           numOffsets < static_cast<int>(sizes.size()) - 1) {
+      ++numOffsets;
+    }
+
+    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
+    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
+    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
+                                                       extractOffsets);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
@@ -3780,7 +3856,8 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
   results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
               StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
-              StridedSliceSplat>(context);
+              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
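
To make the offset-counting step in the pattern above concrete, here is a simplified standalone sketch; it is an illustration under assumptions (plain C++ containers instead of MLIR types, and the helper name countInitialOffsets is invented), not code from this commit.

#include <cstdint>
#include <vector>

// Walk the slice sizes from the innermost dimension outwards and stop at the
// first dimension whose slice size is not the full source dimension; the
// remaining count is the initial number of offsets for the vector.extract.
static int countInitialOffsets(const std::vector<int64_t> &sizes,
                               const std::vector<int64_t> &sourceDims) {
  int numOffsets;
  for (numOffsets = static_cast<int>(sizes.size()); numOffsets > 0;
       --numOffsets) {
    if (sizes[numOffsets - 1] != sourceDims[numOffsets - 1])
      break;
  }
  return numOffsets;
}

For sizes {1, 1, 1, 1, 1, 4} over the 8x1x2x1x1x4 source used in the new tests, this returns 3; the pattern then skips the remaining leading unit slice dimensions at indices 3 and 4, so the final vector.extract takes five offsets and yields vector<4xi32>, which is what the first new CHECK line in canonicalize.mlir expects.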

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 0 additions & 69 deletions
@@ -329,81 +329,12 @@ class DecomposeNDExtractStridedSlice
   }
 };
 
-/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
-/// slice is contiguous, into extract and shape_cast.
-///
-/// Example:
-/// Before:
-///     %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
-///          sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
-///          vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
-/// After:
-///     %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
-///     vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
-///     vector<1x1x1x1x8xi8>
-///
-class ContiguousExtractStridedSliceToExtract final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.hasNonUnitStrides()) {
-      return failure();
-    }
-    Value source = op.getOperand();
-    auto sourceType = cast<VectorType>(source.getType());
-    if (sourceType.isScalable()) {
-      return failure();
-    }
-
-    // Compute the number of offsets to pass to ExtractOp::build. That is the
-    // difference between the source rank and the desired slice rank. We walk
-    // the dimensions from innermost out, and stop when the next slice dimension
-    // is not full-size.
-    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
-    int numOffsets;
-    for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
-      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
-        break;
-      }
-    }
-
-    // If not even the inner-most dimension is full-size, this op can't be
-    // rewritten as an ExtractOp.
-    if (numOffsets == sourceType.getRank()) {
-      return failure();
-    }
-
-    // Avoid generating slices that have unit outer dimensions. The shape_cast
-    // op that we create below would take bad generic fallback patterns
-    // (ShapeCastOpRewritePattern).
-    while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
-      ++numOffsets;
-    }
-
-    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
-    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
-    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
-                                                       extractOffsets);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
-    return success();
-  }
-};
-
 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DecomposeDifferentRankInsertStridedSlice,
               DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
 }
 
-void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
-                                                       benefit);
-}
-
 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     RewritePatternSet &patterns,
     std::function<bool(ExtractStridedSliceOp)> controlFn,

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 49 additions & 0 deletions
@@ -2742,3 +2742,52 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
   %1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
   return %1 : vector<4xi8>
 }
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
+// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
+func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
+  return %2 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list
+// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<1x4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK-NEXT: return %[[EXTRACT]] : vector<1x4xi32>
+func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1], strides = [1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<1x4xi32>
+  return %2 : vector<1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size
+// CHECK-NEXT: vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<8x1x1x1x1x4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<8x1x1x1x1x4xi32>
+  return %1 : vector<8x1x1x1x1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_size
+// CHECK-NEXT: vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
+  return %1 : vector<1x1x1x1x1x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size
+// CHECK-NEXT: vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
+  return %1 : vector<1x1x2x1x1x1xi32>
+}

mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir

Lines changed: 0 additions & 24 deletions
This file was deleted.

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 0 additions & 23 deletions
@@ -709,27 +709,6 @@ struct TestVectorExtractStridedSliceLowering
   }
 };
 
-struct TestVectorContiguousExtractStridedSliceToExtract
-    : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestVectorExtractStridedSliceLowering)
-
-  StringRef getArgument() const final {
-    return "test-vector-contiguous-extract-strided-slice-to-extract";
-  }
-  StringRef getDescription() const final {
-    return "Test lowering patterns that rewrite simple cases of N-D "
-           "extract_strided_slice, where the slice is contiguous, into extract "
-           "and shape_cast";
-  }
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-  }
-};
-
 struct TestVectorBreakDownBitCast
     : public PassWrapper<TestVectorBreakDownBitCast,
                          OperationPass<func::FuncOp>> {
@@ -956,8 +935,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorExtractStridedSliceLowering>();
 
-  PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
-
   PassRegistration<TestVectorBreakDownBitCast>();
 
   PassRegistration<TestCreateVectorBroadcast>();
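
With the dedicated test pass removed, the rewrite is now exercised through the generic canonicalizer and the new tests in canonicalize.mlir above. A minimal sketch of reaching it through a pass pipeline follows; this is an illustrative assumption, not part of this commit, and the helper name canonicalizeModule is invented.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// The standard canonicalizer pass picks up ContiguousExtractStridedSliceToExtract
// along with every other registered canonicalization pattern.
static mlir::LogicalResult canonicalizeModule(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createCanonicalizerPass());
  return pm.run(module);
}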
