Skip to content

Commit 80f5b9c

Browse files
committed
canonicalize
1 parent c80f484 commit 80f5b9c

File tree

6 files changed

+137
-122
lines changed

6 files changed

+137
-122
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,6 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
235235
std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
236236
PatternBenefit benefit = 1);
237237

238-
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
239-
/// slice is contiguous, into extract and shape_cast.
240-
void populateVectorContiguousExtractStridedSliceToExtractPatterns(
241-
RewritePatternSet &patterns, PatternBenefit benefit = 1);
242-
243238
/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
244239
/// based on the destination vector shape. Bitcasts from a lower bitwidth
245240
/// element type to a higher bitwidth one are extracted from the lower bitwidth

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3772,6 +3772,92 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
37723772
}
37733773
};
37743774

3775+
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
3776+
/// slice is contiguous, into extract and shape_cast.
3777+
///
3778+
/// Example:
3779+
/// Before:
3780+
/// %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
3781+
/// sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
3782+
/// vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3783+
/// After:
3784+
/// %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
3785+
/// vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
3786+
/// vector<1x1x1x1x8xi8>
3787+
///
3788+
class ContiguousExtractStridedSliceToExtract final
3789+
: public OpRewritePattern<ExtractStridedSliceOp> {
3790+
public:
3791+
using OpRewritePattern::OpRewritePattern;
3792+
3793+
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3794+
PatternRewriter &rewriter) const override {
3795+
if (op.hasNonUnitStrides()) {
3796+
return failure();
3797+
}
3798+
Value source = op.getOperand();
3799+
auto sourceType = cast<VectorType>(source.getType());
3800+
if (sourceType.isScalable() || sourceType.getRank() == 0) {
3801+
return failure();
3802+
}
3803+
3804+
// Compute the number of offsets to pass to ExtractOp::build. That is the
3805+
// difference between the source rank and the desired slice rank. We walk
3806+
// the dimensions from innermost out, and stop when the next slice dimension
3807+
// is not full-size.
3808+
SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
3809+
int numOffsets;
3810+
for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3811+
if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
3812+
break;
3813+
}
3814+
}
3815+
3816+
// If the created extract op would have no offsets, then this whole
3817+
// extract_strided_slice is the identity and should have been handled by
3818+
// other canonicalizations.
3819+
if (numOffsets == 0) {
3820+
return failure();
3821+
}
3822+
3823+
// If not even the inner-most dimension is full-size, this op can't be
3824+
// rewritten as an ExtractOp.
3825+
if (numOffsets == sourceType.getRank() &&
3826+
static_cast<int>(sizes.size()) == sourceType.getRank()) {
3827+
return failure();
3828+
}
3829+
3830+
// The outer dimensions must have unit size.
3831+
for (int i = 0; i < numOffsets; ++i) {
3832+
if (sizes[i] != 1) {
3833+
return failure();
3834+
}
3835+
}
3836+
3837+
// Avoid generating slices that have leading unit dimensions. The shape_cast
3838+
// op that we create below would take bad generic fallback patterns
3839+
// (ShapeCastOpRewritePattern).
3840+
while (sizes[numOffsets] == 1 &&
3841+
numOffsets < static_cast<int>(sizes.size()) - 1) {
3842+
++numOffsets;
3843+
}
3844+
// After exhausting the list of slice sizes, we keep checking for unit
3845+
// dimensions in the source shape, to remove corner cases where the result
3846+
// would have a leading unit dimension.
3847+
while (sourceType.getDimSize(numOffsets) == 1 &&
3848+
numOffsets < sourceType.getRank() - 1) {
3849+
++numOffsets;
3850+
}
3851+
3852+
SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
3853+
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
3854+
Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
3855+
extractOffsets);
3856+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
3857+
return success();
3858+
}
3859+
};
3860+
37753861
} // namespace
37763862

37773863
void ExtractStridedSliceOp::getCanonicalizationPatterns(
@@ -3780,7 +3866,8 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
37803866
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
37813867
results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
37823868
StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3783-
StridedSliceSplat>(context);
3869+
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
3870+
context);
37843871
}
37853872

37863873
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -329,81 +329,12 @@ class DecomposeNDExtractStridedSlice
329329
}
330330
};
331331

332-
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
333-
/// slice is contiguous, into extract and shape_cast.
334-
///
335-
/// Example:
336-
/// Before:
337-
/// %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
338-
/// sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
339-
/// vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
340-
/// After:
341-
/// %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
342-
/// vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
343-
/// vector<1x1x1x1x8xi8>
344-
///
345-
class ContiguousExtractStridedSliceToExtract final
346-
: public OpRewritePattern<ExtractStridedSliceOp> {
347-
public:
348-
using OpRewritePattern::OpRewritePattern;
349-
350-
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
351-
PatternRewriter &rewriter) const override {
352-
if (op.hasNonUnitStrides()) {
353-
return failure();
354-
}
355-
Value source = op.getOperand();
356-
auto sourceType = cast<VectorType>(source.getType());
357-
if (sourceType.isScalable()) {
358-
return failure();
359-
}
360-
361-
// Compute the number of offsets to pass to ExtractOp::build. That is the
362-
// difference between the source rank and the desired slice rank. We walk
363-
// the dimensions from innermost out, and stop when the next slice dimension
364-
// is not full-size.
365-
SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
366-
int numOffsets;
367-
for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
368-
if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
369-
break;
370-
}
371-
}
372-
373-
// If not even the inner-most dimension is full-size, this op can't be
374-
// rewritten as an ExtractOp.
375-
if (numOffsets == sourceType.getRank()) {
376-
return failure();
377-
}
378-
379-
// Avoid generating slices that have unit outer dimensions. The shape_cast
380-
// op that we create below would take bad generic fallback patterns
381-
// (ShapeCastOpRewritePattern).
382-
while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
383-
++numOffsets;
384-
}
385-
386-
SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
387-
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
388-
Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
389-
extractOffsets);
390-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
391-
return success();
392-
}
393-
};
394-
395332
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
396333
RewritePatternSet &patterns, PatternBenefit benefit) {
397334
patterns.add<DecomposeDifferentRankInsertStridedSlice,
398335
DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
399336
}
400337

401-
void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
402-
RewritePatternSet &patterns, PatternBenefit benefit) {
403-
patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
404-
benefit);
405-
}
406-
407338
void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
408339
RewritePatternSet &patterns,
409340
std::function<bool(ExtractStridedSliceOp)> controlFn,

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2742,3 +2742,52 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
27422742
%1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
27432743
return %1 : vector<4xi8>
27442744
}
2745+
2746+
// -----
2747+
2748+
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
2749+
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
2750+
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
2751+
func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
2752+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
2753+
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
2754+
return %2 : vector<4xi32>
2755+
}
2756+
2757+
// -----
2758+
2759+
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list
2760+
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
2761+
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
2762+
func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
2763+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1], strides = [1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
2764+
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
2765+
return %2 : vector<4xi32>
2766+
}
2767+
2768+
// -----
2769+
2770+
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size
2771+
// CHECK-NEXT: vector.extract_strided_slice
2772+
func.func @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<8x1x1x1x1x4xi32> {
2773+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<8x1x1x1x1x4xi32>
2774+
return %1 : vector<8x1x1x1x1x4xi32>
2775+
}
2776+
2777+
// -----
2778+
2779+
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_size
2780+
// CHECK-NEXT: vector.extract_strided_slice
2781+
func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
2782+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
2783+
return %1 : vector<1x1x1x1x1x2xi32>
2784+
}
2785+
2786+
// -----
2787+
2788+
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size
2789+
// CHECK-NEXT: vector.extract_strided_slice
2790+
func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
2791+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
2792+
return %1 : vector<1x1x2x1x1x1xi32>
2793+
}

mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir

Lines changed: 0 additions & 24 deletions
This file was deleted.

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -709,27 +709,6 @@ struct TestVectorExtractStridedSliceLowering
709709
}
710710
};
711711

712-
struct TestVectorContiguousExtractStridedSliceToExtract
713-
: public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
714-
OperationPass<func::FuncOp>> {
715-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
716-
TestVectorExtractStridedSliceLowering)
717-
718-
StringRef getArgument() const final {
719-
return "test-vector-contiguous-extract-strided-slice-to-extract";
720-
}
721-
StringRef getDescription() const final {
722-
return "Test lowering patterns that rewrite simple cases of N-D "
723-
"extract_strided_slice, where the slice is contiguous, into extract "
724-
"and shape_cast";
725-
}
726-
void runOnOperation() override {
727-
RewritePatternSet patterns(&getContext());
728-
populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
729-
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
730-
}
731-
};
732-
733712
struct TestVectorBreakDownBitCast
734713
: public PassWrapper<TestVectorBreakDownBitCast,
735714
OperationPass<func::FuncOp>> {
@@ -956,8 +935,6 @@ void registerTestVectorLowerings() {
956935

957936
PassRegistration<TestVectorExtractStridedSliceLowering>();
958937

959-
PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
960-
961938
PassRegistration<TestVectorBreakDownBitCast>();
962939

963940
PassRegistration<TestCreateVectorBroadcast>();

0 commit comments

Comments
 (0)