-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Fix type transformation in DropUnitDimFromElementwiseOps #75430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Fix type transformation in DropUnitDimFromElementwiseOps #75430
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Jerry Wu (pzread) ChangesUse operand and result types to build the corresponding new types in Elementwise ops only guarantee to have the same shape on their operands and results, but don't guarantee to have the same element type. Full diff: https://github.com/llvm/llvm-project/pull/75430.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 45eb7274cd2d3c..1175da921d7ba1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1480,17 +1480,24 @@ struct DropUnitDimFromElementwiseOps final
using OpTraitRewritePattern::OpTraitRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (op->getNumResults() != 1)
+ if (op->getNumResults() != 1 || op->getNumRegions() != 0)
return failure();
- // Check the pre-condiitions. For `Elementwise` Ops all operands
- // are guaranteed to have identical shapes and it suffices to only check the
- // first one.
- auto op1 = op->getOperands()[0];
- auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
- if (!sourceVectorType)
+ auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultVectorType)
return failure();
+ if (llvm::any_of(op->getOperands(), [](auto operand) {
+ return !isa<VectorType>(operand.getType());
+ })) {
+ return failure();
+ }
+
+ // Check the pre-conditions. For `Elementwise` Ops all operands are
+ // guaranteed to have identical shapes and it suffices to only check the
+ // first one.
+ auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
+
if (sourceVectorType.getRank() < 2)
return failure();
@@ -1506,23 +1513,26 @@ struct DropUnitDimFromElementwiseOps final
// Drop leading/trailing unit dim by applying vector.shape_cast to all
// operands
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
- VectorType newVType = VectorType::Builder(sourceVectorType).dropDim(dim);
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
+ auto opVectorType = cast<VectorType>(operand.getType());
+ VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
newOperands.push_back(opSC);
}
+ VectorType newResultVectorType =
+ VectorType::Builder(resultVectorType).dropDim(dim);
// Create an updated elementwise Op without leading/trailing unit dim
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
- newVType, op->getAttrs());
+ newResultVectorType, op->getAttrs());
- // Restore the leading/trailing unit dim by applying vector.shape_cast to
- // the result
- rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
+ // Restore the leading/trailing unit dim by applying vector.shape_cast
+ // to the result
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));
return success();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index b81491b9c07404..3708d741141be0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -321,6 +321,21 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
// -----
+func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
+ %add = arith.sitofp %sc_arg0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
+ %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+ return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dim_sitofp(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
+// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
+// CHECK: return %[[VAL_2]] : vector<8x[2]xf32>
+
+// -----
+
// All shape casts are folded away
func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
@@ -341,4 +356,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>
-
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I missed this and thanks for the fix!
LGTM, modulo a couple of questions.
7eabd61
to
9f8506a
Compare
Use operand and result types to build the corresponding new types in
DropUnitDimFromElementwiseOps
.Elementwise ops only guarantee to have the same shape on their operands and results, but don't guarantee to have the same element type.
This change also enhances the preconditions.