Skip to content

Commit a58b8da

Browse files
committed
Addressing review feedbacks
1 parent 2b5d1dc commit a58b8da

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1971,6 +1971,8 @@ getConvOperationKind(Operation *reduceOp) {
19711971
// is in `buildBinaryFn` helper in the Linalg dialect.
19721972
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
19731973
llvm::IsaPred<BlockArgument>);
1974+
assert(feedValIt != reduceOp->operand_end() &&
1975+
"Expected a non-block argument operand");
19741976
Operation *feedOp = (*feedValIt).getDefiningOp();
19751977
if (isCastOfBlockArgument(feedOp)) {
19761978
return ConvOperationKind::Pool;
@@ -2017,17 +2019,12 @@ static bool isSupportedPoolKind(vector::CombiningKind kind) {
20172019
}
20182020

20192021
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2020-
if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
2021-
return failure();
2022-
2023-
auto lhsShaped = convOp.getDpsInputOperand(0)->get();
2024-
auto rhsShaped = convOp.getDpsInputOperand(1)->get();
2025-
auto resShaped = convOp.getDpsInitOperand(0)->get();
2026-
auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2027-
auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2028-
auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2029-
if (!lhsShapedType || !rhsShapedType || !resShapedType)
2030-
return failure();
2022+
auto getOperandType = [&](auto operand) {
2023+
return dyn_cast<ShapedType>((operand->get()).getType());
2024+
};
2025+
ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2026+
ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2027+
ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
20312028
// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
20322029
// (non-channeled convolution -> LHS and RHS both have single dimensions).
20332030
if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&

0 commit comments

Comments
 (0)