-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. #92934
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
8c33c6c
to
9acb3da
Compare
@llvm/pr-subscribers-mlir-vector Author: Hugo Trachino (nujaa) Changes: Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions (discussed here). Full diff: https://github.com/llvm/llvm-project/pull/92934.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..e772d4bbea311 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1607,7 +1607,23 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
-/// For vectors with either leading or trailing unit dim, replaces:
+FailureOr<VectorType> dropNonScalableUnitDimType(VectorType VT) {
+ VectorType newVT = VT;
+ int removed = 0;
+ auto shape = VT.getShape();
+ for (unsigned i = 0; i < shape.size(); i++) {
+ if (shape[i] == 1 && !VT.getScalableDims()[i]) {
+ newVT = VectorType::Builder(newVT).dropDim(i - removed);
+ removed++;
+ }
+ }
+
+ if (removed == 0)
+ return failure();
+ return newVT;
+}
+
+/// For vectors with at least an unit dim, replaces:
/// elementwise(a, b)
/// with:
/// sc_a = shape_cast(a)
@@ -1652,42 +1668,30 @@ struct DropUnitDimFromElementwiseOps final
// guaranteed to have identical shapes (with some exceptions such as
// `arith.select`) and it suffices to only check one of them.
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
- if (!sourceVectorType)
- return failure();
- if (sourceVectorType.getRank() < 2)
- return failure();
-
- bool hasTrailingDimUnitFixed =
- ((sourceVectorType.getShape().back() == 1) &&
- (!sourceVectorType.getScalableDims().back()));
- bool hasLeadingDimUnitFixed =
- ((sourceVectorType.getShape().front() == 1) &&
- (!sourceVectorType.getScalableDims().front()));
- if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
+ if (!sourceVectorType || sourceVectorType.getRank() < 2)
return failure();
- // Drop leading/trailing unit dim by applying vector.shape_cast to all
- // operands
- int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
auto opVectorType = cast<VectorType>(operand.getType());
- VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
- auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
+ auto newVType = dropNonScalableUnitDimType(opVectorType);
+ if (failed(newVType)) {
+ return failure();
+ }
+ auto opSC =
+ rewriter.create<vector::ShapeCastOp>(loc, newVType.value(), operand);
newOperands.push_back(opSC);
}
VectorType newResultVectorType =
- VectorType::Builder(resultVectorType).dropDim(dim);
- // Create an updated elementwise Op without leading/trailing unit dim
+ dropNonScalableUnitDimType(resultVectorType).value();
+ // Create an updated elementwise Op without unit dim
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
newResultVectorType, op->getAttrs());
- // Restore the leading/trailing unit dim by applying vector.shape_cast
- // to the result
+ // Restore the unit dim by applying vector.shape_cast to the result
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..03c19742355bf 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -459,6 +459,26 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
// CHECK-128B-NOT: memref.collapse_shape
+// -----
+
+func.func @fold_unit_center_dim_scalable(%arg0 : vector<8x1x[1]xf128>,
+ %arg1 : vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
+ %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]xf128> to vector<8x1x[1]xf128>
+ %add = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]xf128>
+ %res = vector.shape_cast %add : vector<8x1x[1]xf128> to vector<8x[1]xf128>
+ return %res : vector<8x[1]xf128>
+}
+
+// CHECK-LABEL: func.func @fold_unit_center_dim_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]xf128>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]xf128> to vector<8x[1]xf128>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]xf128> to vector<8x[1]xf128>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]xf128>
+// CHECK: return %[[VAL_4]] : vector<8x[1]xf128>
+
+
+
// -----
|
@llvm/pr-subscribers-mlir Author: Hugo Trachino (nujaa) Changes: Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions (discussed here). Full diff: https://github.com/llvm/llvm-project/pull/92934.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..e772d4bbea311 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1607,7 +1607,23 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
-/// For vectors with either leading or trailing unit dim, replaces:
+FailureOr<VectorType> dropNonScalableUnitDimType(VectorType VT) {
+ VectorType newVT = VT;
+ int removed = 0;
+ auto shape = VT.getShape();
+ for (unsigned i = 0; i < shape.size(); i++) {
+ if (shape[i] == 1 && !VT.getScalableDims()[i]) {
+ newVT = VectorType::Builder(newVT).dropDim(i - removed);
+ removed++;
+ }
+ }
+
+ if (removed == 0)
+ return failure();
+ return newVT;
+}
+
+/// For vectors with at least an unit dim, replaces:
/// elementwise(a, b)
/// with:
/// sc_a = shape_cast(a)
@@ -1652,42 +1668,30 @@ struct DropUnitDimFromElementwiseOps final
// guaranteed to have identical shapes (with some exceptions such as
// `arith.select`) and it suffices to only check one of them.
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
- if (!sourceVectorType)
- return failure();
- if (sourceVectorType.getRank() < 2)
- return failure();
-
- bool hasTrailingDimUnitFixed =
- ((sourceVectorType.getShape().back() == 1) &&
- (!sourceVectorType.getScalableDims().back()));
- bool hasLeadingDimUnitFixed =
- ((sourceVectorType.getShape().front() == 1) &&
- (!sourceVectorType.getScalableDims().front()));
- if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
+ if (!sourceVectorType || sourceVectorType.getRank() < 2)
return failure();
- // Drop leading/trailing unit dim by applying vector.shape_cast to all
- // operands
- int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
auto opVectorType = cast<VectorType>(operand.getType());
- VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
- auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
+ auto newVType = dropNonScalableUnitDimType(opVectorType);
+ if (failed(newVType)) {
+ return failure();
+ }
+ auto opSC =
+ rewriter.create<vector::ShapeCastOp>(loc, newVType.value(), operand);
newOperands.push_back(opSC);
}
VectorType newResultVectorType =
- VectorType::Builder(resultVectorType).dropDim(dim);
- // Create an updated elementwise Op without leading/trailing unit dim
+ dropNonScalableUnitDimType(resultVectorType).value();
+ // Create an updated elementwise Op without unit dim
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
newResultVectorType, op->getAttrs());
- // Restore the leading/trailing unit dim by applying vector.shape_cast
- // to the result
+ // Restore the unit dim by applying vector.shape_cast to the result
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..03c19742355bf 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -459,6 +459,26 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
// CHECK-128B-NOT: memref.collapse_shape
+// -----
+
+func.func @fold_unit_center_dim_scalable(%arg0 : vector<8x1x[1]xf128>,
+ %arg1 : vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
+ %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]xf128> to vector<8x1x[1]xf128>
+ %add = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]xf128>
+ %res = vector.shape_cast %add : vector<8x1x[1]xf128> to vector<8x[1]xf128>
+ return %res : vector<8x[1]xf128>
+}
+
+// CHECK-LABEL: func.func @fold_unit_center_dim_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]xf128>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]xf128> to vector<8x[1]xf128>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]xf128> to vector<8x[1]xf128>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]xf128>
+// CHECK: return %[[VAL_4]] : vector<8x[1]xf128>
+
+
+
// -----
|
Hi, @banach-space and @MacDue , let's get started with the drop unit dimension MRs. |
I don’t see unit dims mentioned in that particular thread. Did you mean some other thread? |
Indeed, thanks for checking it out. updated to https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa |
Thanks, now I see what you meant :) Nit - I would add a bit more context in your PR summary to make it more self-contained (I find it super useful when commit msgs contain all the context). Specifically, I’d add that this was suggested in the discussion on improving the lowerings for ArmSME. And then add a link for folks interested in the discussion. |
Sorry for the delay, was OOO last week. I've finally managed to catch-up with the context and I have one high-level comment/question. The pattern that you are updating was designed to help with specific scenarios that are documented here: llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp Lines 1610 to 1638 in b0b3596
However, those cases look very different to what you are trying to "fix":
(copied from Ben's "Canonical form") %lhsCast = vector.shape_cast %inputLHS : vector<[4]xf32> to vector<[4]x1xf32>
%lhsBcast = vector.broadcast %lhsCast : vector<[4]x1xf32> to vector<[4]x[4]x1xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0, 2] : vector<[4]x[4]x1xf32> to vector<[4]x[4]x1xf32>
%rhsCast = vector.shape_cast %inputRHS : vector<[4]xf32> to vector<1x[4]xf32>
%rhsBcast = vector.broadcast %rhsCast : vector<1x[4]xf32> to vector<[4]x1x[4]xf32>
%rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[4]x1x[4]xf32> to vector<[4]x[4]x1xf32>
%mul = arith.mulf %lhsT, %rhs : vector<[4]x[4]x1xf32>
%tileMask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
%dropDim = vector.shape_cast %mul : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32>
%addAcc = arith.addf %acc, %dropDim : vector<[4]x[4]xf32>
%applyMask = arith.select %tileMask, %acc, %addAcc : vector<[4]x[4]xi1>, vector<[4]x[4]xf32> In the example above there aren't that many internal unit dims. Here are 2 examples: %rhsBcast = vector.broadcast %rhsCast : vector<1x[4]xf32> to vector<[4]x1x[4]xf32>
%rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[4]x1x[4]xf32> to vector<[4]x[4]x1xf32> Would EDIT Btw, "vector-transfer-flatten.mlir" is failing with this change - not sure why pre-commit CI didn't capture that. In fact, looks like the tests didn't run 🤔 |
I simply found it weird that elementwise ops would only have their outer unit dims cleaned up when other similar patterns are not limited in that way. It is simply some normalization. Also, I originally thought I could reuse this method for the Broadcast case, but I realised that unit dims that are not broadcasted were an issue.
Hah, committed a fix directly inside Github but apparently I got too confident. 👼 |
newOperands.push_back(opSC); | ||
} | ||
|
||
VectorType newResultVectorType = | ||
VectorType::Builder(resultVectorType).dropDim(dim); | ||
// Create an updated elementwise Op without leading/trailing unit dim | ||
dropNonScalableUnitDimType(resultVectorType).value(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good example of a case where we expect it to always return a valid value.
f5da261
to
5e1c3bd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good to me, just a few nits! Thanks for pushing on this!
if (dim != 1 || isScalable) { | ||
newShape.push_back(dim); | ||
newScalableDims.push_back(isScalable); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style nit: prefer using early-exit and continue to simplify code. It also saves the levels of nesting for us.
if (dim != 1 || isScalable) { | |
newShape.push_back(dim); | |
newScalableDims.push_back(isScalable); | |
} | |
if (dim == 1 && !isScalable) | |
continue; | |
newShape.push_back(dim); | |
newScalableDims.push_back(isScalable); |
https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just one nit. Thanks!
Rename test functions. Add scalable specific test.
Added punctuation, removed inline code '`' in Doxygen
36ef826
to
6b2204f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
@MacDue Do you have any other comments?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just one nit:
DCE. Co-authored-by: Benjamin Maxwell <[email protected]>
…n leading / trailing dimensions. (llvm#92934)" This reverts commit 2c06fb8.
…n leading / trailing dimensions. (llvm#92934)" This reverts commit 2c06fb8.
…n leading / trailing dimensions. (llvm#92934)" This reverts commit 2c06fb8.
The commit breaks a downstream project (iree-org/iree#17778). Here is the repro: func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> {
%cst = arith.constant dense<0.000000e+00> : vector<1x1xf32>
%0 = arith.mulf %arg0, %cst : vector<1x1xf32>
return %0 : vector<1x1xf32>
} Error:
The commit seems problematic for corner cases, can we revert it and re-land it with a fix? https://llvm.org/docs/DeveloperPolicy.html#patch-reversion-policy |
Any upstream repro is enough for a revert. |
got it, thanks! |
…n leading / trailing dimensions." (#97652) Reverts #92934 because it breaks some lowering. To repro: `mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir` ```mlir func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> { %cst = arith.constant dense<0.000000e+00> : vector<1x1xf32> %0 = arith.mulf %arg0, %cst : vector<1x1xf32> return %0 : vector<1x1xf32> } ```
…n leading / trailing dimensions." (llvm#97652) Reverts llvm#92934 because it breaks some lowering. To repro: `mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir` ```mlir func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> { %cst = arith.constant dense<0.000000e+00> : vector<1x1xf32> %0 = arith.mulf %arg0, %cst : vector<1x1xf32> return %0 : vector<1x1xf32> } ```
…g / trailing dimensions. (llvm#92934) Generalizes `DropUnitDimFromElementwiseOps` to support inner unit dimensions. This change stems from improving lowering of contractionOps for Arm SME. Where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts. discussed [here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa). --------- Co-authored-by: Benjamin Maxwell <[email protected]>
…g / trailing dimensions. (llvm#92934) Generalizes `DropUnitDimFromElementwiseOps` to support inner unit dimensions. This change stems from improving lowering of contractionOps for Arm SME. Where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts. discussed [here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa). --------- Co-authored-by: Benjamin Maxwell <[email protected]>
Generalizes `DropUnitDimFromElementwiseOps` to support inner unit dimensions.
This change stems from improving the lowering of contractionOps for Arm SME, where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts.
discussed here.