Skip to content

[mlir] Fix type transformation in DropUnitDimFromElementwiseOps #75430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 14, 2023

Conversation

pzread
Copy link
Member

@pzread pzread commented Dec 14, 2023

Use operand and result types to build the corresponding new types in DropUnitDimFromElementwiseOps.

Elementwise ops only guarantee to have the same shape on their operands and results, but don't guarantee to have the same element type.

This change also enhances the preconditions.

@llvmbot
Copy link
Member

llvmbot commented Dec 14, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Jerry Wu (pzread)

Changes

Use operand and result types to build the corresponding new types in DropUnitDimFromElementwiseOps.

Elementwise ops only guarantee to have the same shape on their operands and results, but don't guarantee to have the same element type.


Full diff: https://github.com/llvm/llvm-project/pull/75430.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+22-12)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+15-1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 45eb7274cd2d3c..1175da921d7ba1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1480,17 +1480,24 @@ struct DropUnitDimFromElementwiseOps final
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (op->getNumResults() != 1)
+    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
       return failure();
 
-    // Check the pre-condiitions. For `Elementwise` Ops all operands
-    // are guaranteed to have identical shapes and it suffices to only check the
-    // first one.
-    auto op1 = op->getOperands()[0];
-    auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
-    if (!sourceVectorType)
+    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultVectorType)
       return failure();
 
+    if (llvm::any_of(op->getOperands(), [](auto operand) {
+          return !isa<VectorType>(operand.getType());
+        })) {
+      return failure();
+    }
+
+    // Check the pre-conditions. For `Elementwise` Ops all operands are
+    // guaranteed to have identical shapes and it suffices to only check the
+    // first one.
+    auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
+
     if (sourceVectorType.getRank() < 2)
       return failure();
 
@@ -1506,23 +1513,26 @@ struct DropUnitDimFromElementwiseOps final
     // Drop leading/trailing unit dim by applying vector.shape_cast to all
     // operands
     int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-    VectorType newVType = VectorType::Builder(sourceVectorType).dropDim(dim);
 
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
+      auto opVectorType = cast<VectorType>(operand.getType());
+      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
       newOperands.push_back(opSC);
     }
 
+    VectorType newResultVectorType =
+        VectorType::Builder(resultVectorType).dropDim(dim);
     // Create an updated elementwise Op without leading/trailing unit dim
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
-                        newVType, op->getAttrs());
+                        newResultVectorType, op->getAttrs());
 
-    // Restore the leading/trailing unit dim by applying vector.shape_cast to
-    // the result
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
+    // Restore the leading/trailing unit dim by applying vector.shape_cast
+    // to the result
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                              elementwiseOp->getResult(0));
 
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index b81491b9c07404..3708d741141be0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -321,6 +321,21 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
 
 // -----
 
+func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
+   %add = arith.sitofp %sc_arg0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
+   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+   return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_sitofp(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
+// CHECK:           return %[[VAL_2]] : vector<8x[2]xf32>
+
+// -----
+
 // All shape casts are folded away
 
 func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
@@ -341,4 +356,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
-

@pzread pzread requested a review from dcaballe December 14, 2023 07:18
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed this and thanks for the fix!

LGTM, modulo a couple of questions.

@pzread pzread force-pushed the enhance-drop-unit-dim-from-elementwise branch from 7eabd61 to 9f8506a Compare December 14, 2023 08:05
@pzread pzread merged commit 2c9ba9c into llvm:main Dec 14, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants