Skip to content

Commit 32e81c5

Browse files
committed
helper function does not return failure.
1 parent fad986b commit 32e81c5

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,20 +1607,20 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16071607
}
16081608
};
16091609

1610-
FailureOr<VectorType> dropNonScalableUnitDimType(VectorType inVecTy) {
1611-
int numUnitDimsDropped = 0;
1612-
auto inVecShape = inVecTy.getShape();
1610+
VectorType dropNonScalableUnitDimType(VectorType inVecTy) {
16131611
auto newVecBuilder = VectorType::Builder(inVecTy);
1614-
for (unsigned i = 0; i < inVecShape.size(); i++) {
1615-
if (inVecShape[i] == 1 && !inVecTy.getScalableDims()[i]) {
1616-
newVecBuilder.dropDim(i - numUnitDimsDropped);
1617-
numUnitDimsDropped++;
1612+
auto inVecShape = inVecTy.getShape();
1613+
SmallVector<int64_t> newShape;
1614+
SmallVector<bool> newScalableDims;
1615+
for (auto [dim, isScalable] :
1616+
llvm::zip(inVecShape, inVecTy.getScalableDims())) {
1617+
if (dim != 1 || isScalable) {
1618+
newShape.push_back(dim);
1619+
newScalableDims.push_back(isScalable);
16181620
}
16191621
}
16201622

1621-
if (numUnitDimsDropped == 0)
1622-
return failure();
1623-
return VectorType(newVecBuilder);
1623+
return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
16241624
}
16251625

16261626
/// For vectors with at least an unit dim, replaces:
@@ -1676,16 +1676,15 @@ struct DropUnitDimFromElementwiseOps final
16761676
for (auto operand : op->getOperands()) {
16771677
auto opVectorType = cast<VectorType>(operand.getType());
16781678
auto newVType = dropNonScalableUnitDimType(opVectorType);
1679-
if (failed(newVType)) {
1680-
return failure();
1679+
if (newVType == opVectorType) {
1680+
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
16811681
}
1682-
auto opSC =
1683-
rewriter.create<vector::ShapeCastOp>(loc, newVType.value(), operand);
1682+
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
16841683
newOperands.push_back(opSC);
16851684
}
16861685

16871686
VectorType newResultVectorType =
1688-
dropNonScalableUnitDimType(resultVectorType).value();
1687+
dropNonScalableUnitDimType(resultVectorType);
16891688
// Create an updated elementwise Op without unit dim
16901689
Operation *elementwiseOp =
16911690
rewriter.create(loc, op->getName().getIdentifier(), newOperands,

0 commit comments

Comments
 (0)