@@ -1410,12 +1410,6 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
                                           tensor::UnPackOp unpackOp,
                                           ArrayRef<int64_t> inputVectorSizes,
                                           SmallVectorImpl<Value> &newResults) {
-  // Handling this case requires a bit more change. Right now
-  // just the required attributes are handled.
-  if (!unpackOp.getOuterDimsPerm().empty()) {
-    LDBG("outer dimensions perms NYI for: " << unpackOp);
-    return failure();
-  }
 
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
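
The check removed here is not dropped: it moves into vectorizeUnPackOpPreCondition (see the last hunk), so unsupported unpacks are now rejected before any IR is created. For reference, a hypothetical unpack that keeps hitting this NYI path (all names and shapes invented for illustration) looks like:

  // Still rejected by the precondition: outer_dims_perm is present (NYI).
  %0 = tensor.unpack %src outer_dims_perm = [1, 0]
      inner_dims_pos = [0, 1] inner_tiles = [16, 2]
      into %dest : tensor<4x2x16x2xf32> -> tensor<32x8xf32>
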
@@ -1442,18 +1436,19 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
     return failure();
   }
   int64_t unpackRank = unpackTensorType.getRank();
+  Location loc = unpackOp->getLoc();
   arith::ConstantIndexOp zeroOp =
-      rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+      rewriter.create<arith::ConstantIndexOp>(loc, 0);
 
   vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
-      unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+      loc, vectorType, unpackOp.getSource(),
       SmallVector<Value>(unpackRank, zeroOp),
       rewriter.getMultiDimIdentityMap(unpackRank));
 
   auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
   Value mask = rewriter.create<vector::CreateMaskOp>(
-      unpackOp.getLoc(), readMaskType,
-      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+      loc, readMaskType,
+      tensor::getMixedSizes(rewriter, loc, unpackOp.getSource()));
   vector::MaskOp maskedOp =
       cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
 
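
Hoisting the location into loc only deduplicates the repeated unpackOp.getLoc() calls; behavior is unchanged. The masked read this block builds comes out roughly as the sketch below, with SSA names, shapes, and the %d0/%d1 source sizes (from tensor::getMixedSizes) all assumed for illustration:

  %c0 = arith.constant 0 : index
  %mask = vector.create_mask %d0, %d1, %c16, %c2 : vector<2x4x16x2xi1>
  %read = vector.mask %mask {
    vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
        : tensor<?x?x16x2xf32>, vector<2x4x16x2xf32>
  } : vector<2x4x16x2xi1> -> vector<2x4x16x2xf32>
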
@@ -1474,25 +1469,23 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
       RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
           .setShape(stripMineShape);
 
-  // Collapse the tensor to the size required by result.
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMineTensorType, packMetadata.reassociations);
-  auto vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
-
   // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
-      unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+      loc, maskedOp.getResult(0), lastDimToInsertPosPerm);
 
+  // Collapse the vector to the size required by result.
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  mlir::VectorType vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
-      unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
-  tensor::EmptyOp emptyOp =
-      rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
-                                       unpackTensorType.getElementType());
+      loc, vecCollapsedType, transposeOp->getResult(0));
+  tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedRetShapes[0], unpackTensorType.getElementType());
 
   int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
   Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
-      unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+      loc, shapeCastOp->getResult(0), emptyOp,
       SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
   auto resultShape = unpackOp.getResult().getType().getShape();
 
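
Besides the loc change, this hunk moves the collapsed-type computation down next to its only user, the vector.shape_cast, and fixes its comment (it collapses a vector, not a tensor). Continuing the sketch above with the same assumed shapes and an illustrative [0, 2, 1, 3] permutation standing in for lastDimToInsertPosPerm:

  %t = vector.transpose %read, [0, 2, 1, 3]
      : vector<2x4x16x2xf32> to vector<2x16x4x2xf32>
  %sc = vector.shape_cast %t : vector<2x16x4x2xf32> to vector<32x8xf32>
  %empty = tensor.empty(%s0, %s1) : tensor<?x?xf32>
  %w = vector.transfer_write %sc, %empty[%c0, %c0]
      : vector<32x8xf32>, tensor<?x?xf32>
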
@@ -1516,7 +1509,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   // WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * value
   auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
   Value writeMask = rewriter.create<vector::CreateMaskOp>(
-      unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
+      loc, writeMaskType, reifiedRetShapes[0]);
   Operation *writeOpWithMask =
       mlir::vector::maskOperation(rewriter, writeOp, writeMask);
   result = writeOpWithMask->getResult(0);
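
The write-side mask mirrors the read side but is built from the reified result shape, so only the in-bounds region of the destination is written. A rough continuation of the same sketch (names and shapes still assumed):

  %wmask = vector.create_mask %s0, %s1 : vector<32x8xi1>
  %res = vector.mask %wmask {
    vector.transfer_write %sc, %empty[%c0, %c0]
        : vector<32x8xf32>, tensor<?x?xf32>
  } : vector<32x8xi1> -> tensor<?x?xf32>
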
@@ -1783,12 +1776,25 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
 static LogicalResult
 vectorizeUnPackOpPreCondition(tensor::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
+
+  // Handling this case requires a bit more change. Right now
+  // just the required attributes are handled.
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
     return failure();
   }
+  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  if (!inputVectorSizes.empty() &&
+      failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
+    return failure();
+
   return success();
 }
 
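With the precondition extended, both new requirements are visible on small examples. A hypothetical unpack with a non-constant inner tile (the %tile value is invented for illustration) fails the first check:

  // Rejected: inner tile %tile is not a compile-time constant.
  %0 = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [%tile]
      into %dest : tensor<8x?xf32> -> tensor<?xf32>

The second check only applies when the caller passes explicit vector sizes: they are validated against the shape of the unpack destination via isValidMaskedInputVector, while an empty inputVectorSizes skips the check.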