@@ -1480,17 +1480,17 @@ struct DropUnitDimFromElementwiseOps final
1480
1480
using OpTraitRewritePattern::OpTraitRewritePattern;
1481
1481
LogicalResult matchAndRewrite (Operation *op,
1482
1482
PatternRewriter &rewriter) const override {
1483
- if (op->getNumResults () != 1 )
1483
+ if (op->getNumResults () != 1 || op-> getNumRegions () != 0 )
1484
1484
return failure ();
1485
1485
1486
- // Check the pre-condiitions. For `Elementwise` Ops all operands
1487
- // are guaranteed to have identical shapes and it suffices to only check the
1488
- // first one.
1489
- auto op1 = op->getOperands ()[0 ];
1490
- auto sourceVectorType = dyn_cast<VectorType>(op1.getType ());
1491
- if (!sourceVectorType)
1486
+ auto resultVectorType = dyn_cast<VectorType>(op->getResult (0 ).getType ());
1487
+ if (!resultVectorType)
1492
1488
return failure ();
1493
1489
1490
+ // Check the pre-conditions. For `Elementwise` Ops all operands are
1491
+ // guaranteed to have identical shapes and it suffices to only check the
1492
+ // first one.
1493
+ auto sourceVectorType = cast<VectorType>(op->getOperands ()[0 ].getType ());
1494
1494
if (sourceVectorType.getRank () < 2 )
1495
1495
return failure ();
1496
1496
@@ -1506,23 +1506,26 @@ struct DropUnitDimFromElementwiseOps final
1506
1506
// Drop leading/trailing unit dim by applying vector.shape_cast to all
1507
1507
// operands
1508
1508
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank () - 1 ;
1509
- VectorType newVType = VectorType::Builder (sourceVectorType).dropDim (dim);
1510
1509
1511
1510
SmallVector<Value> newOperands;
1512
1511
auto loc = op->getLoc ();
1513
1512
for (auto operand : op->getOperands ()) {
1513
+ auto opVectorType = cast<VectorType>(operand.getType ());
1514
+ VectorType newVType = VectorType::Builder (opVectorType).dropDim (dim);
1514
1515
auto opSC = rewriter.create <vector::ShapeCastOp>(loc, newVType, operand);
1515
1516
newOperands.push_back (opSC);
1516
1517
}
1517
1518
1519
+ VectorType newResultVectorType =
1520
+ VectorType::Builder (resultVectorType).dropDim (dim);
1518
1521
// Create an updated elementwise Op without leading/trailing unit dim
1519
1522
Operation *elementwiseOp =
1520
1523
rewriter.create (loc, op->getName ().getIdentifier (), newOperands,
1521
- newVType , op->getAttrs ());
1524
+ newResultVectorType , op->getAttrs ());
1522
1525
1523
- // Restore the leading/trailing unit dim by applying vector.shape_cast to
1524
- // the result
1525
- rewriter.replaceOpWithNewOp <ShapeCastOp>(op, sourceVectorType ,
1526
+ // Restore the leading/trailing unit dim by applying vector.shape_cast
1527
+ // to the result
1528
+ rewriter.replaceOpWithNewOp <ShapeCastOp>(op, resultVectorType ,
1526
1529
elementwiseOp->getResult (0 ));
1527
1530
1528
1531
return success ();
0 commit comments