Skip to content

Commit 2c9ba9c

Browse files
author
Jerry Wu
authored
[mlir] Fix type transformation in DropUnitDimFromElementwiseOps (#75430)
Use operand and result types to build the corresponding new types in `DropUnitDimFromElementwiseOps`.
1 parent 726830f commit 2c9ba9c

File tree

2 files changed

+30
-13
lines changed

2 files changed

+30
-13
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,17 +1480,17 @@ struct DropUnitDimFromElementwiseOps final
14801480
using OpTraitRewritePattern::OpTraitRewritePattern;
14811481
LogicalResult matchAndRewrite(Operation *op,
14821482
PatternRewriter &rewriter) const override {
1483-
if (op->getNumResults() != 1)
1483+
if (op->getNumResults() != 1 || op->getNumRegions() != 0)
14841484
return failure();
14851485

1486-
// Check the pre-condiitions. For `Elementwise` Ops all operands
1487-
// are guaranteed to have identical shapes and it suffices to only check the
1488-
// first one.
1489-
auto op1 = op->getOperands()[0];
1490-
auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
1491-
if (!sourceVectorType)
1486+
auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1487+
if (!resultVectorType)
14921488
return failure();
14931489

1490+
// Check the pre-conditions. For `Elementwise` Ops all operands are
1491+
// guaranteed to have identical shapes and it suffices to only check the
1492+
// first one.
1493+
auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
14941494
if (sourceVectorType.getRank() < 2)
14951495
return failure();
14961496

@@ -1506,23 +1506,26 @@ struct DropUnitDimFromElementwiseOps final
15061506
// Drop leading/trailing unit dim by applying vector.shape_cast to all
15071507
// operands
15081508
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
1509-
VectorType newVType = VectorType::Builder(sourceVectorType).dropDim(dim);
15101509

15111510
SmallVector<Value> newOperands;
15121511
auto loc = op->getLoc();
15131512
for (auto operand : op->getOperands()) {
1513+
auto opVectorType = cast<VectorType>(operand.getType());
1514+
VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
15141515
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
15151516
newOperands.push_back(opSC);
15161517
}
15171518

1519+
VectorType newResultVectorType =
1520+
VectorType::Builder(resultVectorType).dropDim(dim);
15181521
// Create an updated elementwise Op without leading/trailing unit dim
15191522
Operation *elementwiseOp =
15201523
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
1521-
newVType, op->getAttrs());
1524+
newResultVectorType, op->getAttrs());
15221525

1523-
// Restore the leading/trailing unit dim by applying vector.shape_cast to
1524-
// the result
1525-
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
1526+
// Restore the leading/trailing unit dim by applying vector.shape_cast
1527+
// to the result
1528+
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
15261529
elementwiseOp->getResult(0));
15271530

15281531
return success();

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,21 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
321321

322322
// -----
323323

324+
func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
325+
%sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
326+
%add = arith.sitofp %sc_arg0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
327+
%res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
328+
return %res : vector<8x[2]xf32>
329+
}
330+
331+
// CHECK-LABEL: func.func @fold_unit_dim_sitofp(
332+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
333+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
334+
// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
335+
// CHECK: return %[[VAL_2]] : vector<8x[2]xf32>
336+
337+
// -----
338+
324339
// All shape casts are folded away
325340

326341
func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
@@ -341,4 +356,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
341356
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
342357
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
343358
// CHECK: return %[[VAL_4]] : vector<8xi32>
344-

0 commit comments

Comments
 (0)