[mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract #111541


Merged: 4 commits into llvm:main on Oct 8, 2024

Conversation

@bjacob (Contributor) commented Oct 8, 2024

No description provided.

@llvmbot (Member) commented Oct 8, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Benoit Jacob (bjacob)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/111541.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+6)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+64)
  • (added) mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir (+40)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+23)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..9ad78cc282b674 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -235,6 +235,12 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
     PatternBenefit benefit = 1);
 
+/// Populate `patterns` with a pattern to rewrite simple cases of N-D
+/// extract_strided_slice, where the slice is contiguous, into extract and
+/// shape_cast.
+void populateVectorContiguousExtractStridedSliceToExtractPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
 /// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
 /// based on the destination vector shape. Bitcasts from a lower bitwidth
 /// element type to a higher bitwidth one are extracted from the lower bitwidth
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..324c7b84ebfa0d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -329,12 +329,76 @@ class DecomposeNDExtractStridedSlice
   }
 };
 
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+                                       SmallVectorImpl<int64_t> &results) {
+  for (auto attr : arrayAttr)
+    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+}
+
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+class ContiguousExtractStridedSliceToExtract final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.hasNonUnitStrides()) {
+      return failure();
+    }
+    SmallVector<int64_t> sizes;
+    populateFromInt64AttrArray(op.getSizes(), sizes);
+    Value source = op.getOperand();
+    ShapedType sourceType = cast<ShapedType>(source.getType());
+
+    // Compute the number of offsets to pass to ExtractOp::build. That is the
+    // difference between the source rank and the desired slice rank. We walk
+    // the dimensions from innermost out, and stop when the next slice dimension
+    // is not full-size.
+    int numOffsets;
+    for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
+      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
+        break;
+      }
+    }
+
+    // If not even the inner-most dimension is full-size, this op can't be
+    // rewritten as an ExtractOp.
+    if (numOffsets == sourceType.getRank()) {
+      return failure();
+    }
+
+    // Avoid generating slices that have unit outer dimensions. The shape_cast
+    // op that we create below would take bad generic fallback patterns
+    // (ShapeCastOpRewritePattern).
+    while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
+      ++numOffsets;
+    }
+
+    SmallVector<int64_t> offsets;
+    populateFromInt64AttrArray(op.getOffsets(), offsets);
+    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
+    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
+                                                       extractOffsets);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op->getResultTypes()[0], extract);
+    return success();
+  }
+};
+
 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DecomposeDifferentRankInsertStridedSlice,
                DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
 }
 
+void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
+                                                       benefit);
+}
+
 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     RewritePatternSet &patterns,
     std::function<bool(ExtractStridedSliceOp)> controlFn,
diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
new file mode 100644
index 00000000000000..da8ff492431629
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -split-input-file -test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i8
+// CHECK:       vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
+
+func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
+  return %2 : vector<8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i32
+// CHECK:        vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
+  return %2 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
+// CHECK:        vector.extract_strided_slice
+func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32>
+  return %2 : vector<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
+// CHECK:        vector.extract_strided_slice
+func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
+  return %2 : vector<2xi32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 72aaa7dc4f8973..d91e955b70641e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -709,6 +709,27 @@ struct TestVectorExtractStridedSliceLowering
   }
 };
 
+struct TestVectorContiguousExtractStridedSliceToExtract
+    : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorContiguousExtractStridedSliceToExtract)
+
+  StringRef getArgument() const final {
+    return "test-vector-contiguous-extract-strided-slice-to-extract";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that rewrite simple cases of N-D "
+           "extract_strided_slice, where the slice is contiguous, into extract "
+           "and shape_cast";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestVectorBreakDownBitCast
     : public PassWrapper<TestVectorBreakDownBitCast,
                          OperationPass<func::FuncOp>> {
@@ -935,6 +956,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorExtractStridedSliceLowering>();
 
+  PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
+
   PassRegistration<TestVectorBreakDownBitCast>();
 
   PassRegistration<TestCreateVectorBroadcast>();

@bjacob bjacob requested a review from kuhar October 8, 2024 15:04
@kuhar kuhar changed the title [MLIR] Vector: add pattern to rewrite contiguous ExtractStridedSlice into Extract. [mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract Oct 8, 2024
@bjacob bjacob requested a review from kuhar October 8, 2024 15:37
@kuhar (Member) left a comment:

LGTM % the RUN line

@bjacob bjacob merged commit 10054ba into llvm:main Oct 8, 2024
5 of 6 checks passed
@dcaballe (Contributor) left a comment:

Thanks!

@banach-space (Contributor) left a comment:

I see that this was landed while I was reviewing ...

Please, could you address my comments post-commit?

Comment on lines +338 to +339
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
Contributor:

Nice-to-have - MLIR example with "before" and "after"

Contributor:

You can drop @extract_strided_slice_to_extract from test function names. Similar info is already provided in the file name, so it's just noise. Instead, I'd "encode" what's unique about every test. One specific suggestion further down :)

Also, how about a test with non-unit strides?

bjacob (Contributor Author):

Re non-unit strides: I thought about it, but because the pattern checks for full-size slices (which is really what it means by "contiguous", a slightly misleading term here), that implies unit strides. So the pattern's check for unit strides is redundant, but I left it because otherwise I would have had to add a comment explaining that.

Contributor:

So the pattern's check for unit strides is redundant but I left it because otherwise I would have had to add a comment explaining that.

I would remove it then. The presence of that check suggests that it's something significant, but from your explanation I see that it isn't (unless we were able to find an edge case where it matters).

Comment on lines +16 to +20
func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
return %2 : vector<4xi32>
}
Contributor:

I think that you can skip this test. The pattern that you added doesn't really care about the element type, so this is just repeating the test above.

bjacob (Contributor Author):

Dropped the i8 one. These tests are subtly different in the distribution of unit dims, but the i32 one is the more interesting one to keep.


func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
%2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
Contributor:

Why vector.shape_cast if the pattern doesn't care about anything apart from vector.extract_strided_slice? Removing it would reduce noise.

bjacob (Contributor Author):

The pattern itself generates extract + shape_cast. The generated shape_cast folds with the one here in the test source. So by including the shape_cast in the source, I remove the shape_cast in the output (that I would otherwise have to CHECK for) and I have the test ensure that the folding happens.
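As an illustration of that folding on the i8 test case (the intermediate IR below is reconstructed from the pattern's logic, not copied from actual pass output):

```mlir
// Test input: a contiguous slice, followed by a shape_cast that drops
// the leading unit dims.
%1 = vector.extract_strided_slice %arg0
    {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]}
    : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
%2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>

// The pattern rewrites %1 into an extract plus a shape_cast back to the
// original slice type:
%e = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
%c = vector.shape_cast %e : vector<8xi8> to vector<1x1x1x1x8xi8>

// %c then folds with %2 into a no-op vector<8xi8> -> vector<8xi8> cast,
// so only the extract survives, which is what the CHECK line verifies.
```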

Contributor:

I see, makes sense. Then, could you either add CHECK-NOT: vector.shape_cast or replace CHECK with CHECK-NEXT? Both would make sure that there's no shape_cast in the output and everything worked as expected :)


// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
// CHECK: vector.extract_strided_slice
func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
Contributor:

It wasn't immediately obvious to me what was wrong with this case, so I suggest encoding that info in the test name.

Suggested change
func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
func.func @non_contiguous_no_full_dim_slice(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {

bjacob (Contributor Author):

Renamed.

@bjacob (Contributor Author) commented Oct 8, 2024

@banach-space , sent #111552

/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
void populateVectorContiguousExtractStridedSliceToExtractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
joker-eph (Collaborator):

Can you expand on why this isn't a good canonicalization?

bjacob (Contributor Author):

Two things (both of which happened in the testcase I looked at, where these ops were extracting parts from a matrix tile to feed into GPU matrix multiplication intrinsics):

  1. extract is more constrained than extract_strided_slice, so it is more likely to have a good lowering.
  2. My use case was, similarly to the test added in this PR, an extract_strided_slice producing a vector with leading unit dims, followed by a shape_cast dropping the unit dims. That shape_cast was hitting the fallback lowering pattern, ShapeCastOpRewritePattern. Now that the extract_strided_slice is rewritten into a pair (extract, shape_cast), the two shape_cast ops fold.
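For point 2, sketched on the i32 testcase from this PR (the rewritten IR below is reconstructed from the pattern's logic, not copied from pass output):

```mlir
// Before: the slice keeps five leading unit dims, and the trailing
// shape_cast would hit the generic ShapeCastOpRewritePattern fallback.
%s = vector.extract_strided_slice %v
    {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]}
    : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
%c = vector.shape_cast %s : vector<1x1x1x1x1x4xi32> to vector<4xi32>

// After: the pattern emits (extract, shape_cast), the emitted shape_cast
// folds with %c, and only a plain extract is left.
%e = vector.extract %v[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
```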

bjacob (Contributor Author):

Sorry @joker-eph, I mis-parsed your question --- read "is" instead of "isn't".

No opinion about whether this should be a "canonicalization". I wasn't too sure that I wanted to enter that debate; my pattern is replacing 1 op with 2 ops so I expected a nontrivial debate. Feel free!

joker-eph (Collaborator):

I would like this kind of consideration to be carefully done before adding random patterns to the codebase.

I am really concerned about the lack of design that comes with adding single patterns through single "populateXXX" methods. This can't scale and does not help define a cohesive system.

Contributor:

Yeah, individual patterns and populate methods are a problem. The proliferation of such methods is difficult to keep track of.

  1. I don't have enough to go by on whether this is a canonicalization or not. In the absence of evidence, and (at least some) justification, I'd bias towards not a canonicalization. Basically, the question is whether this pattern will unlock other canonicalizations; if so, that's a strong signal that this is a canonicalization.
  2. @joker-eph, what's your solution for such cases? This kind of thing is a very targeted point fix; there is not really an over-arching problem being solved. So one option is that we don't add it to MLIR itself, but keep it downstream. That will only make such things be duplicated across all MLIR projects. Where/how do we house such one-off things in core?

joker-eph (Collaborator):

Basically question is if this pattern will unlock other canonicalizations,

I don't think this is a necessary criterion; for example, @bjacob above mentioned:

  • extract is more constrained than extract_strided_slice, so it is more likely to have a good lowering.
  • a extract_strided_slice producing a vector with leading unit dims, followed by a shape_cast dropping the unit dims. That shape_cast was hitting the fallback lowering pattern, ShapeCastOpRewritePattern.

These are enough of a motivation to justify a canonicalization to me. Then on top of this is the question of whether the transformation is potentially losing semantics that can't be trivially reconstructed (such an aspect would likely make it clearly not suitable for canonicalization).

@joker-eph whats your solution for such cases.

Make it a canonicalization.

So one option is, we dont add it to MLIR itself, but keep it downstream.

If this can't be grouped into a cohesive pass that achieves something meaningful that we can reason about, then yeah, please keep all these patterns out-of-tree.

Contributor:

If this can't be grouped into a cohesive pass that achieves something meaningful that we can reason about, then yeah, please keep all these patterns out-of-tree.

At the risk of duplication in all downstream projects? I am fine with that if that is the consensus, but IMO many times the "pattern" across patterns only appears when they are all put in the same place; i.e., it is not always possible to have overarching plans for everything from the get-go, but rather you build that intuition over time.

@joker-eph (Collaborator) commented Oct 8, 2024:

At the risk of duplication in all downstream projects?

I don't have much concerns about this: if it is really valuable, then it may get upstreamed properly eventually. We don't need to take on any random pattern without organization or rationale, just because it fitted some particular downstream flow.

but IMO many times, the "pattern" across patterns only appears when they are all put in the same place

This does not really match the way I approach design. Do you have prior examples of success stories that would support this? Right now I see an abuse of populateXXXPattern without much convergence.

Contributor:

This does not really match the way I approach the design. Do you have prior examples of success story that would support this? Right now I see an abuse of populateXXXPattern without much convergence.

Not for patterns, but it has been the case for a couple of interfaces that were initially just patterns doing similar things and were then consolidated to use an interface.

bjacob (Contributor Author):

Good discussion! To help keep it concrete, I gave the canonicalizer idea a try: #111614.

bjacob added a commit that referenced this pull request Oct 9, 2024
 into a canonicalization (#111614)

This is a reasonable canonicalization because `extract` is more
constrained than `extract_strided_slices`, so there is no loss of
semantics here, just lifting an op to a special-case higher/constrained
op. And the additional `shape_cast` is merely adding leading unit dims
to match the original result type.

Context: discussion on #111541. I wasn't sure how this would turn out,
but in the process of writing this PR, I discovered at least 2 bugs in
the pattern introduced in #111541, which shows the value of shared
canonicalization patterns which are exercised on a high number of
testcases.

---------

Signed-off-by: Benoit Jacob <[email protected]>
7 participants