Skip to content

[MLIR] Vector: turn the ExtractStridedSlice rewrite pattern from #111541 into a canonicalization #111614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,6 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
PatternBenefit benefit = 1);

/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
void populateVectorContiguousExtractStridedSliceToExtractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth
Expand Down
79 changes: 78 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3772,6 +3772,82 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
}
};

/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
///
/// Example:
///     Before:
///         %1 = vector.extract_strided_slice %arg0 {
///                offsets = [0, 0, 0, 0, 0],
///                sizes = [1, 1, 1, 1, 8],
///                strides = [1, 1, 1, 1, 1]
///              } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
///     After:
///         %0 = vector.extract %arg0[0, 0, 0, 0]
///                : vector<8xi8> from vector<8x1x1x2x8xi8>
///         %1 = vector.shape_cast %0
///                : vector<8xi8> to vector<1x1x1x1x8xi8>
///
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Non-unit strides never describe a contiguous slice.
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build. That is the
    // difference between the source rank and the desired slice rank. We walk
    // the dimensions from innermost out, and stop when the next slice dimension
    // is not full-size. Note: `sizes` may be shorter than the source rank
    // (trailing dimensions are implicitly taken whole), so iterate over
    // `sizes.size()`, not the rank.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the inner-most dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions. The shape_cast
    // op that we create below would take bad generic fallback patterns
    // (ShapeCastOpRewritePattern). The bound must be checked BEFORE indexing
    // `sizes`: when the size list is shorter than the source rank, numOffsets
    // can equal sizes.size() here and `sizes[numOffsets]` would read out of
    // bounds.
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
                                                       extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};

} // namespace

void ExtractStridedSliceOp::getCanonicalizationPatterns(
Expand All @@ -3780,7 +3856,8 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
StridedSliceSplat>(context);
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
context);
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,81 +329,12 @@ class DecomposeNDExtractStridedSlice
}
};

/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
///
/// Example:
///     Before:
///         %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
///              sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
///              vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
///     After:
///         %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
///              vector<8x1x1x2x8xi8>
///         %1 = vector.shape_cast %0 : vector<8xi8> to vector<1x1x1x1x8xi8>
///
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Non-unit strides never describe a contiguous slice.
    if (op.hasNonUnitStrides()) {
      return failure();
    }
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable()) {
      return failure();
    }

    // Compute the number of offsets to pass to ExtractOp::build. That is the
    // difference between the source rank and the desired slice rank. We walk
    // the dimensions from innermost out, and stop when the next slice dimension
    // is not full-size. NOTE: iterate over the size list, not the source rank;
    // the op's `sizes` attribute may be shorter than the rank, and indexing
    // `sizes` by rank would read out of bounds.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
        break;
      }
    }

    // numOffsets == 0 means the whole op is the identity; don't emit an
    // ExtractOp with no positions — leave that case to other rewrites.
    if (numOffsets == 0) {
      return failure();
    }

    // If not even the inner-most dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank()) {
      return failure();
    }

    // The dimensions sliced away must have unit size, otherwise the prefix
    // extract below would drop elements and the shape_cast would be invalid.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1) {
        return failure();
      }
    }

    // Avoid generating slices that have unit outer dimensions. The shape_cast
    // op that we create below would take bad generic fallback patterns
    // (ShapeCastOpRewritePattern). Bound-check before indexing: numOffsets may
    // equal sizes.size() at this point.
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
                                                       extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};

/// Populate `patterns` with the strided-slice decomposition patterns:
/// DecomposeDifferentRankInsertStridedSlice and DecomposeNDExtractStridedSlice,
/// both registered with the given `benefit`.
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DecomposeDifferentRankInsertStridedSlice,
               DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
}

/// Populate `patterns` with ContiguousExtractStridedSliceToExtract (see its
/// class comment above), which rewrites a contiguous extract_strided_slice
/// into extract + shape_cast, registered with the given `benefit`.
void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
                                                       benefit);
}

void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
RewritePatternSet &patterns,
std::function<bool(ExtractStridedSliceOp)> controlFn,
Expand Down
49 changes: 49 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2742,3 +2742,52 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
%1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
return %1 : vector<4xi8>
}

// -----

// Positive test: every sliced dimension has unit size and the innermost listed
// size (4) equals the source dim, so the slice canonicalizes to a single
// vector.extract; the trailing shape_cast in the input then folds away.
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
//  CHECK:     %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
//  CHECK-NEXT: return %[[EXTRACT]] :  vector<4xi32>
func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
  return %2 : vector<4xi32>
}

// -----

// Positive test with a `sizes` list shorter than the source rank (5 entries for
// a rank-6 source): trailing dims are implicitly taken whole, so the pattern
// emits an extract with 4 positions yielding a vector<1x4xi32>.
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list
//  CHECK:     %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<1x4xi32> from vector<8x1x2x1x1x4xi32>
//  CHECK-NEXT: return %[[EXTRACT]] :  vector<1x4xi32>
func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x4xi32> {
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1], strides = [1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<1x4xi32>
  return %2 : vector<1x4xi32>
}

// -----

// Negative test: the outermost sliced dimension has size 8 (not 1), so the
// pattern must not fire and the extract_strided_slice is left as-is.
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size
//  CHECK-NEXT: vector.extract_strided_slice
func.func @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<8x1x1x1x1x4xi32> {
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<8x1x1x1x1x4xi32>
  return %1 : vector<8x1x1x1x1x4xi32>
}

// -----

// Negative test: the innermost size (2) is smaller than the source dim (4), so
// the slice is not contiguous and the pattern must not fire.
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_size
//  CHECK-NEXT: vector.extract_strided_slice
func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
  return %1 : vector<1x1x1x1x1x2xi32>
}

// -----

// Negative test: the innermost size (1) is smaller than the source dim (4)
// while an interior dim keeps full size 2, so the kept elements are not
// contiguous and the pattern must not fire.
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size
//  CHECK-NEXT: vector.extract_strided_slice
func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
  return %1 : vector<1x1x2x1x1x1xi32>
}

This file was deleted.

23 changes: 0 additions & 23 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,27 +709,6 @@ struct TestVectorExtractStridedSliceLowering
}
};

/// Test pass that applies only the ContiguousExtractStridedSliceToExtract
/// rewrite, exposed under the
/// `-test-vector-contiguous-extract-strided-slice-to-extract` flag.
struct TestVectorContiguousExtractStridedSliceToExtract
    : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
                         OperationPass<func::FuncOp>> {
  // Fix: the TypeID must be defined for THIS pass class. The previous
  // argument, TestVectorExtractStridedSliceLowering, was a copy-paste error
  // that made the two passes share a TypeID.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorContiguousExtractStridedSliceToExtract)

  StringRef getArgument() const final {
    return "test-vector-contiguous-extract-strided-slice-to-extract";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that rewrite simple cases of N-D "
           "extract_strided_slice, where the slice is contiguous, into extract "
           "and shape_cast";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorBreakDownBitCast
: public PassWrapper<TestVectorBreakDownBitCast,
OperationPass<func::FuncOp>> {
Expand Down Expand Up @@ -956,8 +935,6 @@ void registerTestVectorLowerings() {

PassRegistration<TestVectorExtractStridedSliceLowering>();

PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();

PassRegistration<TestVectorBreakDownBitCast>();

PassRegistration<TestCreateVectorBroadcast>();
Expand Down
Loading