Skip to content

Commit c33642b

Browse files
committed
Added some changes proposed by HanHan.
1 parent 70cc122 commit c33642b

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,12 +1410,6 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
14101410
tensor::UnPackOp unpackOp,
14111411
ArrayRef<int64_t> inputVectorSizes,
14121412
SmallVectorImpl<Value> &newResults) {
1413-
// Handling this case requires a bit more change. Right now
1414-
// just the required attributes are handled.
1415-
if (!unpackOp.getOuterDimsPerm().empty()) {
1416-
LDBG("outer dimensions perms NYI for: " << unpackOp);
1417-
return failure();
1418-
}
14191413

14201414
OpBuilder::InsertionGuard g(rewriter);
14211415
rewriter.setInsertionPoint(unpackOp);
@@ -1442,18 +1436,19 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
14421436
return failure();
14431437
}
14441438
int64_t unpackRank = unpackTensorType.getRank();
1439+
Location loc = unpackOp->getLoc();
14451440
arith::ConstantIndexOp zeroOp =
1446-
rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
1441+
rewriter.create<arith::ConstantIndexOp>(loc, 0);
14471442

14481443
vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
1449-
unpackOp.getLoc(), vectorType, unpackOp.getSource(),
1444+
loc, vectorType, unpackOp.getSource(),
14501445
SmallVector<Value>(unpackRank, zeroOp),
14511446
rewriter.getMultiDimIdentityMap(unpackRank));
14521447

14531448
auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
14541449
Value mask = rewriter.create<vector::CreateMaskOp>(
1455-
unpackOp.getLoc(), readMaskType,
1456-
tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
1450+
loc, readMaskType,
1451+
tensor::getMixedSizes(rewriter, loc, unpackOp.getSource()));
14571452
vector::MaskOp maskedOp =
14581453
cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
14591454

@@ -1474,25 +1469,23 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
14741469
RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
14751470
.setShape(stripMineShape);
14761471

1477-
// Collapse the tensor to the size required by result.
1478-
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1479-
stripMineTensorType, packMetadata.reassociations);
1480-
auto vecCollapsedType =
1481-
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1482-
14831472
// Transpose the appropriate rows to match output.
14841473
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1485-
unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
1474+
loc, maskedOp.getResult(0), lastDimToInsertPosPerm);
14861475

1476+
// Collapse the vector to the size required by result.
1477+
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1478+
stripMineTensorType, packMetadata.reassociations);
1479+
mlir::VectorType vecCollapsedType =
1480+
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
14871481
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1488-
unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
1489-
tensor::EmptyOp emptyOp =
1490-
rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
1491-
unpackTensorType.getElementType());
1482+
loc, vecCollapsedType, transposeOp->getResult(0));
1483+
tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
1484+
loc, reifiedRetShapes[0], unpackTensorType.getElementType());
14921485

14931486
int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
14941487
Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
1495-
unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
1488+
loc, shapeCastOp->getResult(0), emptyOp,
14961489
SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
14971490
auto resultShape = unpackOp.getResult().getType().getShape();
14981491

@@ -1516,7 +1509,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
15161509
// WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * value
15171510
auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
15181511
Value writeMask = rewriter.create<vector::CreateMaskOp>(
1519-
unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
1512+
loc, writeMaskType, reifiedRetShapes[0]);
15201513
Operation *writeOpWithMask =
15211514
mlir::vector::maskOperation(rewriter, writeOp, writeMask);
15221515
result = writeOpWithMask->getResult(0);
@@ -1783,12 +1776,25 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
17831776
static LogicalResult
17841777
vectorizeUnPackOpPreCondition(tensor::UnPackOp unpackOp,
17851778
ArrayRef<int64_t> inputVectorSizes) {
1779+
1780+
// Handling this case requires a bit more change. Right now
1781+
// just the required attributes are handled.
1782+
if (!unpackOp.getOuterDimsPerm().empty()) {
1783+
LDBG("outer dimensions perms NYI for: " << unpackOp);
1784+
return failure();
1785+
}
1786+
17861787
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
17871788
return !getConstantIntValue(res).has_value();
17881789
})) {
17891790
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
17901791
return failure();
17911792
}
1793+
llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1794+
if (inputVectorSizes.empty() == false &&
1795+
failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
1796+
return failure();
1797+
17921798
return success();
17931799
}
17941800

0 commit comments

Comments
 (0)