Skip to content

Commit 1ea5d0e

Browse files
committed
Support pack with no padding value
1 parent 55fadd5 commit 1ea5d0e

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,16 +1458,20 @@ static LogicalResult
14581458
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
14591459
ArrayRef<int64_t> inputVectorSizes,
14601460
SmallVectorImpl<Value> &newResults) {
1461-
auto padValue = packOp.getPaddingValue();
1461+
OpBuilder::InsertionGuard g(rewriter);
1462+
rewriter.setInsertionPoint(packOp);
1463+
14621464
Location loc = packOp.getLoc();
1465+
auto padValue = packOp.getPaddingValue();
1466+
if (!padValue) {
1467+
padValue = rewriter.create<arith::ConstantOp>(
1468+
loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1469+
}
14631470
int64_t inputRank = inputVectorSizes.size();
14641471
int64_t outputRank = packOp.getDestRank();
14651472
auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
14661473
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
14671474

1468-
OpBuilder::InsertionGuard g(rewriter);
1469-
rewriter.setInsertionPoint(packOp);
1470-
14711475
ReifiedRankedShapedTypeDims reifiedReturnShapes;
14721476
LogicalResult status =
14731477
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
@@ -1502,14 +1506,6 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15021506
/*source=*/emptyOp,
15031507
/*indices=*/SmallVector<Value>(outputRank, zero),
15041508
/*inBounds=*/SmallVector<bool>(outputRank, true));
1505-
// bool needMaskForWrite = llvm::any_of(
1506-
// llvm::zip_equal(inputVectorSizes, packOp.getResultType().getShape()),
1507-
// [](auto it) { return std::get<0>(it) != std::get<1>(it); });
1508-
// if (needMaskForWrite) {
1509-
// Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
1510-
// loc, maskType, reifiedReturnShapes[0]);
1511-
// write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
1512-
// }
15131509
newResults.push_back(write->getResult(0));
15141510
return success();
15151511
}
@@ -1710,7 +1706,7 @@ static LogicalResult
17101706
vectorizePackOpPrecondition(tensor::PackOp packOp,
17111707
ArrayRef<int64_t> inputVectorSizes) {
17121708
auto padValue = packOp.getPaddingValue();
1713-
if (!padValue) {
1709+
if (padValue && getConstantIntValue(padValue) != std::nullopt) {
17141710
LDBG("pad value is not constant: " << packOp << "\n");
17151711
return failure();
17161712
}

0 commit comments

Comments
 (0)