@@ -437,8 +437,10 @@ struct OneDimMultiReductionToTwoDim
437
437
auto loc = multiReductionOp.getLoc ();
438
438
auto srcVectorType = multiReductionOp.getSourceVectorType ();
439
439
auto srcShape = srcVectorType.getShape ();
440
- auto castedType = VectorType::get (ArrayRef<int64_t >{1 , srcShape.back ()},
441
- srcVectorType.getElementType ());
440
+ auto castedType = VectorType::get (
441
+ ArrayRef<int64_t >{1 , srcShape.back ()}, srcVectorType.getElementType (),
442
+ ArrayRef<bool >{false , srcVectorType.getScalableDims ().back ()});
443
+
442
444
auto accType =
443
445
VectorType::get (ArrayRef<int64_t >{1 }, srcVectorType.getElementType ());
444
446
assert (!llvm::isa<VectorType>(multiReductionOp.getDestType ()) &&
@@ -455,10 +457,11 @@ struct OneDimMultiReductionToTwoDim
455
457
loc, accType, multiReductionOp.getAcc ());
456
458
Value castMask;
457
459
if (maskableOp.isMasked ()) {
458
- auto maskType = llvm::cast<ShapedType>(mask.getType ());
459
- auto castMaskType =
460
- VectorType::get (ArrayRef<int64_t >{1 , maskType.getShape ().back ()},
461
- maskType.getElementType ());
460
+ auto maskType = llvm::cast<VectorType>(mask.getType ());
461
+ auto castMaskType = VectorType::get (
462
+ ArrayRef<int64_t >{1 , maskType.getShape ().back ()},
463
+ maskType.getElementType (),
464
+ ArrayRef<bool >{false , maskType.getScalableDims ().back ()});
462
465
castMask = rewriter.create <vector::BroadcastOp>(loc, castMaskType, mask);
463
466
}
464
467
0 commit comments