@@ -1971,6 +1971,8 @@ getConvOperationKind(Operation *reduceOp) {
   // is in `buildBinaryFn` helper in the Linalg dialect.
   auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                      llvm::IsaPred<BlockArgument>);
+  assert(feedValIt != reduceOp->operand_end() &&
+         "Expected a non-block argument operand");
   Operation *feedOp = (*feedValIt).getDefiningOp();
   if (isCastOfBlockArgument(feedOp)) {
     return ConvOperationKind::Pool;
@@ -2017,17 +2019,12 @@ static bool isSupportedPoolKind(vector::CombiningKind kind) {
 }
 
 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
-  if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
-    return failure();
-
-  auto lhsShaped = convOp.getDpsInputOperand(0)->get();
-  auto rhsShaped = convOp.getDpsInputOperand(1)->get();
-  auto resShaped = convOp.getDpsInitOperand(0)->get();
-  auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
-  auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
-  auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
-  if (!lhsShapedType || !rhsShapedType || !resShapedType)
-    return failure();
+  auto getOperandType = [&](auto operand) {
+    return dyn_cast<ShapedType>((operand->get()).getType());
+  };
+  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
+  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
+  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
   // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
   // (non-channeled convolution -> LHS and RHS both have single dimensions).
   if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
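
The second hunk replaces three repeated fetch-and-cast lines with a single generic lambda. For illustration only, here is a minimal standalone C++ sketch of that same pattern outside MLIR, where `Type`, `ShapedType`, and `Operand` are hypothetical stand-ins and `dynamic_cast` plays the role of `llvm::dyn_cast`; none of this is MLIR's actual API:

    #include <cassert>
    #include <memory>
    #include <vector>

    struct Type {
      virtual ~Type() = default;
    };
    struct ShapedType : Type {
      int rank = 3;
    };
    // Hypothetical stand-in for an op operand: get() returns the value's type.
    struct Operand {
      std::shared_ptr<Type> type;
      Type *get() const { return type.get(); }
    };

    int main() {
      std::vector<Operand> operands = {{std::make_shared<ShapedType>()},
                                       {std::make_shared<ShapedType>()}};

      // One helper instead of a separate cast per operand, as in the diff.
      auto getOperandType = [](const Operand &operand) {
        return dynamic_cast<ShapedType *>(operand.get());
      };

      ShapedType *lhsShapedType = getOperandType(operands[0]);
      ShapedType *rhsShapedType = getOperandType(operands[1]);
      assert(lhsShapedType && rhsShapedType && "expected shaped operands");
      return (lhsShapedType->rank == rhsShapedType->rank) ? 0 : 1;
    }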