Skip to content

[MLIR] Vector: turn the ExtractStridedSlice rewrite pattern from #111541 into a canonicalization #111614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 9, 2024

Conversation

bjacob
Copy link
Contributor

@bjacob bjacob commented Oct 9, 2024

This is a reasonable canonicalization because extract is more constrained than extract_strided_slices, so there is no loss of semantics here, just lifting an op to a special-case higher/constrained op. And the additional shape_cast is merely adding leading unit dims to match the original result type.

Context: discussion on #111541. I wasn't sure how this would turn out, but in the process of writing this PR, I discovered at least 2 bugs in the pattern introduced in #111541, which shows the value of shared canonicalization patterns which are exercised on a high number of testcases.

@llvmbot
Copy link
Member

llvmbot commented Oct 9, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Benoit Jacob (bjacob)

Changes

This is just experimenting with what it might look like if we followed the suggestion in the discussion on #111541.

This was worthwhile no matter what the outcome, as suddenly that pattern code ran on a larger number of testcases, discovering at least 2 bugs, covered here by additional canonicalize.mlir test cases.

WDYT @joker-eph @MaheshRavishankar ?


Full diff: https://github.com/llvm/llvm-project/pull/111614.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (-5)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+88-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (-69)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+49)
  • (removed) mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir (-24)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (-23)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ec1de7fa66aa07..a59f06f3c1ef1b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -235,11 +235,6 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
     PatternBenefit benefit = 1);
 
-/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
-/// slice is contiguous, into extract and shape_cast.
-void populateVectorContiguousExtractStridedSliceToExtractPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit = 1);
-
 /// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
 /// based on the destination vector shape. Bitcasts from a lower bitwidth
 /// element type to a higher bitwidth one are extracted from the lower bitwidth
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index dc92bea09dc160..cda31706474e27 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3772,6 +3772,92 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
   }
 };
 
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+///
+/// Example:
+///     Before:
+///         %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
+///         sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
+///         vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+///     After:
+///         %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
+///         vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
+///         vector<1x1x1x1x8xi8>
+///
+class ContiguousExtractStridedSliceToExtract final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.hasNonUnitStrides()) {
+      return failure();
+    }
+    Value source = op.getOperand();
+    auto sourceType = cast<VectorType>(source.getType());
+    if (sourceType.isScalable() || sourceType.getRank() == 0) {
+      return failure();
+    }
+
+    // Compute the number of offsets to pass to ExtractOp::build. That is the
+    // difference between the source rank and the desired slice rank. We walk
+    // the dimensions from innermost out, and stop when the next slice dimension
+    // is not full-size.
+    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
+    int numOffsets;
+    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
+      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
+        break;
+      }
+    }
+
+    // If the created extract op would have no offsets, then this whole
+    // extract_strided_slice is the identity and should have been handled by
+    // other canonicalizations.
+    if (numOffsets == 0) {
+      return failure();
+    }
+
+    // If not even the inner-most dimension is full-size, this op can't be
+    // rewritten as an ExtractOp.
+    if (numOffsets == sourceType.getRank() &&
+        static_cast<int>(sizes.size()) == sourceType.getRank()) {
+      return failure();
+    }
+
+    // The outer dimensions must have unit size.
+    for (int i = 0; i < numOffsets; ++i) {
+      if (sizes[i] != 1) {
+        return failure();
+      }
+    }
+
+    // Avoid generating slices that have leading unit dimensions. The shape_cast
+    // op that we create below would take bad generic fallback patterns
+    // (ShapeCastOpRewritePattern).
+    while (sizes[numOffsets] == 1 &&
+           numOffsets < static_cast<int>(sizes.size()) - 1) {
+      ++numOffsets;
+    }
+    // After exhausting the list of slice sizes, we keep checking for unit
+    // dimensions in the source shape, to remove corner cases where the result
+    // would have a leading unit dimension.
+    while (sourceType.getDimSize(numOffsets) == 1 &&
+           numOffsets < sourceType.getRank() - 1) {
+      ++numOffsets;
+    }
+
+    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
+    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
+    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
+                                                       extractOffsets);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
@@ -3780,7 +3866,8 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
   results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
               StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
-              StridedSliceSplat>(context);
+              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ad845608f18d10..ec2ef3fc7501c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -329,81 +329,12 @@ class DecomposeNDExtractStridedSlice
   }
 };
 
-/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
-/// slice is contiguous, into extract and shape_cast.
-///
-/// Example:
-///     Before:
-///         %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
-///         sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
-///         vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
-///     After:
-///         %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
-///         vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
-///         vector<1x1x1x1x8xi8>
-///
-class ContiguousExtractStridedSliceToExtract final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.hasNonUnitStrides()) {
-      return failure();
-    }
-    Value source = op.getOperand();
-    auto sourceType = cast<VectorType>(source.getType());
-    if (sourceType.isScalable()) {
-      return failure();
-    }
-
-    // Compute the number of offsets to pass to ExtractOp::build. That is the
-    // difference between the source rank and the desired slice rank. We walk
-    // the dimensions from innermost out, and stop when the next slice dimension
-    // is not full-size.
-    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
-    int numOffsets;
-    for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
-      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
-        break;
-      }
-    }
-
-    // If not even the inner-most dimension is full-size, this op can't be
-    // rewritten as an ExtractOp.
-    if (numOffsets == sourceType.getRank()) {
-      return failure();
-    }
-
-    // Avoid generating slices that have unit outer dimensions. The shape_cast
-    // op that we create below would take bad generic fallback patterns
-    // (ShapeCastOpRewritePattern).
-    while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
-      ++numOffsets;
-    }
-
-    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
-    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
-    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
-                                                       extractOffsets);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
-    return success();
-  }
-};
-
 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DecomposeDifferentRankInsertStridedSlice,
                DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
 }
 
-void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
-                                                       benefit);
-}
-
 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     RewritePatternSet &patterns,
     std::function<bool(ExtractStridedSliceOp)> controlFn,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7c78de4b5bd89..15b77c91439cfe 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2742,3 +2742,52 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
   %1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
   return %1 : vector<4xi8>
 }
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
+// CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
+func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
+  return %2 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list
+// CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
+func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1], strides = [1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
+  return %2 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size
+// CHECK-NEXT:   vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<8x1x1x1x1x4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<8x1x1x1x1x4xi32>
+  return %1 : vector<8x1x1x1x1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_size
+// CHECK-NEXT:   vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
+  return %1 : vector<1x1x1x1x1x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size
+// CHECK-NEXT:    vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
+  return %1 : vector<1x1x2x1x1x1xi32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
deleted file mode 100644
index d1401ad7853fc9..00000000000000
--- a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
+++ /dev/null
@@ -1,24 +0,0 @@
-// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
-
-// CHECK-LABEL: @contiguous
-// CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
-// CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
-func.func @contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
-  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
-  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
-  return %2 : vector<4xi32>
-}
-
-// CHECK-LABEL: @non_full_size
-// CHECK-NEXT:   vector.extract_strided_slice
-func.func @non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
-  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
-  return %1 : vector<1x1x1x1x1x2xi32>
-}
-
-// CHECK-LABEL: @non_full_inner_size
-// CHECK-NEXT:    vector.extract_strided_slice
-func.func @non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
-  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
-  return %1 : vector<1x1x2x1x1x1xi32>
-}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d91e955b70641e..72aaa7dc4f8973 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -709,27 +709,6 @@ struct TestVectorExtractStridedSliceLowering
   }
 };
 
-struct TestVectorContiguousExtractStridedSliceToExtract
-    : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestVectorExtractStridedSliceLowering)
-
-  StringRef getArgument() const final {
-    return "test-vector-contiguous-extract-strided-slice-to-extract";
-  }
-  StringRef getDescription() const final {
-    return "Test lowering patterns that rewrite simple cases of N-D "
-           "extract_strided_slice, where the slice is contiguous, into extract "
-           "and shape_cast";
-  }
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-  }
-};
-
 struct TestVectorBreakDownBitCast
     : public PassWrapper<TestVectorBreakDownBitCast,
                          OperationPass<func::FuncOp>> {
@@ -956,8 +935,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorExtractStridedSliceLowering>();
 
-  PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
-
   PassRegistration<TestVectorBreakDownBitCast>();
 
   PassRegistration<TestCreateVectorBroadcast>();

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, but just please update the description to justify why it is a good canonicalization and state that we don't see any loss of semantics in the transformation here.

Thanks!

@bjacob bjacob force-pushed the canonicalize-extract-strided-slice branch from 80f5b9c to 17d9d8e Compare October 9, 2024 11:04
@bjacob
Copy link
Contributor Author

bjacob commented Oct 9, 2024

The 2nd commit pushed here addresses review comments AND fixes a bug (thankfully just by deleting a few lines). I was over-zealously trying to prevent leading unit dims in the extract result type and in the process I was causing numOffsets to exceed offsets.size() which would cause an out-of-bounds read. The test updates reflects that, there is now a leading unit dim there.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks!

@bjacob bjacob merged commit a9ebdbb into llvm:main Oct 9, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants