[mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract #111541
@@ -329,12 +329,70 @@ class DecomposeNDExtractStridedSlice
  }
};

/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
Comment on lines +332 to +333:
Nice-to-have: MLIR example with "before" and "after".
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides()) {
      return failure();
    }

    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable()) {
      return failure();
    }

    // Compute the number of offsets to pass to ExtractOp::build. That is the
    // difference between the source rank and the desired slice rank. We walk
    // the dimensions from innermost out, and stop when the next slice dimension
    // is not full-size.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
        break;
      }
    }

    // If not even the inner-most dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank()) {
      return failure();
    }

    // Avoid generating slices that have unit outer dimensions. The shape_cast
    // op that we create below would take bad generic fallback patterns
    // (ShapeCastOpRewritePattern).
    while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
                                                       extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};

void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DecomposeDifferentRankInsertStridedSlice,
               DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
}

void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
                                                       benefit);
}

void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
    RewritePatternSet &patterns,
    std::function<bool(ExtractStridedSliceOp)> controlFn,
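To make the offset-count logic in `matchAndRewrite` easier to follow, here is a minimal standalone sketch of the same arithmetic. It is not part of the PR: `computeNumOffsets` is a hypothetical name, plain `std::vector` shapes stand in for `VectorType` and the `sizes` attribute, and it assumes one slice size per source dimension with unit sizes on every dimension that ends up indexed by the extract (as in the PR's tests).

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of the offset-count computation; not an MLIR API.
// Returns the number of leading offsets the vector.extract would take,
// or std::nullopt when the pattern would bail out.
std::optional<int> computeNumOffsets(const std::vector<int64_t> &sourceShape,
                                     const std::vector<int64_t> &sizes) {
  // Assumption: rank >= 1 and one slice size per source dimension.
  assert(!sourceShape.empty() && sizes.size() == sourceShape.size());
  const int rank = static_cast<int>(sourceShape.size());

  // Walk from the innermost dimension outwards while the slice is full-size.
  int numOffsets = rank;
  while (numOffsets > 0 && sizes[numOffsets - 1] == sourceShape[numOffsets - 1])
    --numOffsets;

  // Not even the innermost dimension is full-size: no rewrite.
  if (numOffsets == rank)
    return std::nullopt;

  // Also index through unit-sized dimensions so the extracted vector has no
  // leading unit dims; the innermost dimension is always kept.
  while (numOffsets < rank - 1 && sizes[numOffsets] == 1)
    ++numOffsets;

  return numOffsets;
}
```

For a source shape of 8x1x2x1x1x4 and sizes [1, 1, 1, 1, 1, 4], the sketch yields 5, which matches the five indices in the i32 test's CHECK line. For sizes [1, 1, 1, 1, 1, 2] it yields no value, which corresponds to the first non-contiguous test where the op is left untouched.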
You can drop
Also, how about a test with non-unit strides?
Re non-unit strides: I thought about it, but because the pattern checks for full-size slices (which is really what it means by "contiguous", a slightly misleading term here), that implies unit strides. So the pattern's check for unit strides is redundant, but I left it because otherwise I would have had to add a comment explaining that.
I would remove it then. The presence of that check suggests that it's something significant, but from your explanation I see that it isn't (unless we were able to find an edge case where it matters).
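The claim that a full-size slice implies a unit stride can be checked with plain index arithmetic. The sketch below is an illustrative model only, not the op's verifier: `sliceInBounds` is a hypothetical helper, and it assumes a strided slice along one dimension reads indices `offset`, `offset + stride`, ..., `offset + (size - 1) * stride`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper modelling one dimension of a strided slice.
// The slice is in bounds when the last index it reads is below dimSize.
bool sliceInBounds(int64_t dimSize, int64_t offset, int64_t size,
                   int64_t stride) {
  assert(dimSize >= 1); // Assumption: static, non-empty dimension.
  if (offset < 0 || size < 1 || stride < 1)
    return false;
  return offset + (size - 1) * stride < dimSize;
}
```

Under this model, a full-size slice of a dimension with more than one element is only in bounds with stride 1, for example `sliceInBounds(4, 0, 4, 1)` holds while `sliceInBounds(4, 0, 4, 2)` does not. For a unit-sized slice the stride never contributes to the index, so its value has no effect on which element is read.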
@@ -0,0 +1,35 @@
// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s

// CHECK-LABEL: @extract_strided_slice_to_extract_i8
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
// CHECK: return %[[EXTRACT]] : vector<8xi8>
func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
  %2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
Why
The pattern itself generates
I see, makes sense. Then, could you either add
  return %2 : vector<8xi8>
}

// CHECK-LABEL: @extract_strided_slice_to_extract_i32
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
// CHECK: return %[[EXTRACT]] : vector<4xi32>
func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
  return %2 : vector<4xi32>
}
Comment on lines +15 to +19:
I think that you can skip this test. The pattern that you added doesn't really care about the element type, so this is just repeating the test above.
Dropped the i8 one. These tests are subtly different in the distribution of unit dims, but the i32 one is the more interesting one to keep.

// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
// CHECK: vector.extract_strided_slice
func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
It wasn't immediately obvious to me what was wrong with this case, so I suggest encoding that info in the test name.
Suggested change
Renamed.
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32>
  return %2 : vector<2xi32>
}

// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
// CHECK: vector.extract_strided_slice
func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
  %2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
  return %2 : vector<2xi32>
}
Can you expand on why isn't this a good canonicalization?
Two things (that happened in the testcase I looked at, where these ops were extracting parts from a matrix tile to feed into GPU matrix multiplication intrinsics):
1. `extract` is more constrained than `extract_strided_slice`, so it is more likely to have a good lowering.
2. `extract_strided_slice` producing a vector with leading unit dims, followed by a `shape_cast` dropping the unit dims. That `shape_cast` was hitting the fallback lowering pattern, `ShapeCastOpRewritePattern`. Now that the `extract_strided_slice` is rewritten into a pair (`extract`, `shape_cast`), the two `shape_cast` ops fold.
Sorry @joker-eph, I mis-parsed your question --- read "is" instead of "isn't".
No opinion about whether this should be a "canonicalization". I wasn't too sure that I wanted to enter that debate; my pattern is replacing 1 op with 2 ops so I expected a nontrivial debate. Feel free!
I would like this kind of consideration to be carefully done before adding random patterns to the codebase.
I am really concerned about the lack of design coming with adding single pattern with single "populateXXX" methods. This can't scale and does not help defining a cohesive system.
Yeah individual patterns and populate is a problem. The proliferation of such methods is difficult to keep track of.
I don't think this is a necessary criterion, for example @bjacob above mentioned:
These are enough of a motivation to justify a canonicalization to me. Then on top of this is the question of whether the transformation is potentially losing semantics that can't be trivially reconstructed (such an aspect would likely make it clearly not suitable for canonicalization).
Make it a canonicalization.
If this can't be grouped in a cohesive pass that achieves something meaningful that we can reason about, then yeah please keep all these patterns out-of-tree.
At the risk of duplication in all downstream projects? I am fine with that if that is the consensus, but IMO many times, the "pattern" across patterns only appears when they are all put in the same place, i.e. it is not always possible to have overarching plans for everything from the get go, but rather you build that intuition over time.
I don't have many concerns about this: if it is really valuable, then it may get upstreamed properly eventually. We don't need to take on any random pattern without organization or rationale, just because it fitted some particular downstream flow.
This does not really match the way I approach the design. Do you have prior examples of a success story that would support this? Right now I see an abuse of populateXXXPattern without much convergence.
Not for patterns, but that has been the case for a couple of interfaces that were initially just patterns that did similar things and were then consolidated to use an interface.
Good discussion! To help keep it concrete, I gave the canonicalizer idea a try: #111614.