[MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern #92938
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open

nujaa wants to merge 3 commits into llvm:main from nujaa:hugo.DropUnitDimFromBroadcastOp (base: main)
Changes from all 3 commits:

```diff
@@ -1703,6 +1703,67 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// Drops non-scalable unit dimensions that are shared between the source and
+/// result of a broadcastOp, using shape_casts. The newly inserted shape_cast
+/// ops fold away (before the op) and then restore the unit dims after the op.
+/// The source type is required to be a vector.
+///
+/// Ex:
+/// ```
+///  %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+///
+/// Gets converted to:
+///
+/// ```
+///  %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///  %bc = vector.broadcast %sc_arg : vector<4xf32> to vector<1x3x4xf32>
+///  %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+///    vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+///    vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast can be folded away.
+struct DropUnitDimFromBroadcastOp final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcVecTy)
+      return failure();
+    auto resVecTy = broadcastOp.getResultVectorType();
+    auto srcVecTyBuilder = VectorType::Builder(srcVecTy);
+    auto resVecTyBuilder = VectorType::Builder(resVecTy);
+    auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
+    // Reversing allows us to remove dims from the back without keeping
+    // track of removed dimensions.
+    for (const auto [reversedIndex, dim] :
+         llvm::enumerate(llvm::reverse(srcVecTy.getShape()))) {
+      unsigned srcDimIndex = srcVecTy.getRank() - reversedIndex - 1;
+      unsigned resDimIndex = resVecTy.getRank() - reversedIndex - 1;
+      if (dim == 1 && !srcVecTy.getScalableDims()[srcDimIndex] &&
+          !broadcastedUnitDims.contains(srcDimIndex)) {
+        srcVecTyBuilder.dropDim(srcDimIndex);
+        resVecTyBuilder.dropDim(resDimIndex);
+      }
+    }
+
+    if (VectorType(srcVecTyBuilder) == srcVecTy)
+      return failure();
+    auto loc = broadcastOp->getLoc();
+    auto newSource = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType(srcVecTyBuilder), broadcastOp.getSource());
+    auto newOp = rewriter.create<vector::BroadcastOp>(
+        loc, VectorType(resVecTyBuilder), newSource);
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVecTy,
+                                             newOp.getResult());
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
```

Review comment on lines +1754 to +1758: nit: avoid constructing the new vector type twice.
```diff
@@ -1827,8 +1888,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
+               ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
```
```diff
@@ -535,6 +535,60 @@ func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
 
 // -----
 
+func.func @drop_broadcast_unit_dim(%arg0 : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
+  %bc = vector.broadcast %arg0 : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
+  return %bc : vector<4x1x[1]x3x1xf128>
+}
+
+// CHECK-LABEL: func.func @drop_broadcast_unit_dim(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
+// CHECK:       %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
+// CHECK:       %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
+// CHECK:       %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
+// CHECK:       return %[[VAL_3]] : vector<4x1x[1]x3x1xf128>
+
+// -----
+
+func.func @drop_broadcasted_only_unit_dim(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
+  %bc = vector.broadcast %arg0 : vector<1xf32> to vector<1x1xf32>
+  return %bc : vector<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @drop_broadcasted_only_unit_dim(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<1xf32>) -> vector<1x1xf32> {
+// CHECK:       %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
+// CHECK:       %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<f32> to vector<1xf32>
+// CHECK:       %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1xf32> to vector<1x1xf32>
+// CHECK:       return %[[VAL_3]] : vector<1x1xf32>
+
+// -----
+
+// Unit dimensions generated by the broadcast itself are not dropped, as we
+// prefer a single broadcast over a broadcast plus a shape_cast.
+func.func @drop_broadcast_generated_unit_dim(%arg0 : vector<4xf32>) -> vector<3x1x4xf32> {
+  %bc = vector.broadcast %arg0 : vector<4xf32> to vector<3x1x4xf32>
+  return %bc : vector<3x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @drop_broadcast_generated_unit_dim(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
+// CHECK:       %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<4xf32> to vector<3x1x4xf32>
+// CHECK:       return %[[VAL_1]] : vector<3x1x4xf32>
+
+// -----
+
+// A broadcasted unit dimension cannot be dropped, to prevent a type mismatch.
+func.func @drop_broadcasted_unit_dim(%arg0 : vector<2x1x4xf32>) -> vector<2x3x4xf32> {
+  %bc = vector.broadcast %arg0 : vector<2x1x4xf32> to vector<2x3x4xf32>
+  return %bc : vector<2x3x4xf32>
+}
+// CHECK-LABEL: func.func @drop_broadcasted_unit_dim(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<2x1x4xf32>{{.*}}-> vector<2x3x4xf32> {
+// CHECK:       %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x4xf32> to vector<2x3x4xf32>
+// CHECK:       return %[[VAL_1]] : vector<2x3x4xf32>
+
+// -----
+
 func.func @negative_out_of_bound_transfer_read(
     %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index
```

Review comment on the CHECK lines above: nit: Use better names than the generated […]
Builders may be less efficient than just appending the dims not dropped to a new vector (but this is probably not much of a concern given the number of dims is normally < 5-ish).
(no need to change this -- just a note)
I agree. I tried implementing it with the shapes appended instead, but considering that one also has to rebuild the scalableDims for both the new source and result type, I found it generated a lot of code that dropDim keeps hidden.
Also, for some reason, I was able to generate the base of the new resultShape by creating a subvector of it, but for the scalable dims I get errors, and I don't think I should be changing the behaviour of SmallVector; I suspect it comes from the way scalable dims are defined. To fix it, I needed to create an ugly vector with explicit casts.
If you want, I can push my solution on top, and we can revert it if we prefer it as it currently is.
Those `static_cast`s look very suspect 😅. It looks like that's just going to make a SmallVector of two (likely true) bools. I think keeping it as-is is fine, as it's simpler and likely not really a performance concern (vector types are normally small).
Btw, I've been trying to make writing code like this easier for scalable dims (for a while now 😅). With my current attempt #96236, I think you'd be able to rewrite this much more simply (untested!!!).
Please take a look at the PR if you think it'd be useful :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That indeed looks way simpler; very fortunate PR you submitted. I probably won't be able to review it at the moment, as I need to sort some things out today, but I hope to give it a look. (I'll be off for a while, so do not wait for my review.)