@@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
72
72
auto resType =
73
73
getTypeConverter ()->convertType <VectorType>(constOp.getType ());
74
74
75
+ if (!resType)
76
+ return rewriter.notifyMatchFailure (loc, " can't convert return type" );
77
+
75
78
if (resType.isScalable () && !isa<SplatElementsAttr>(constOp.getValue ()))
76
79
return rewriter.notifyMatchFailure (
77
80
loc,
78
81
" Cannot linearize a constant scalable vector that's not a splat" );
79
82
80
- if (!resType)
81
- return rewriter.notifyMatchFailure (loc, " can't convert return type" );
82
83
if (!isLessThanTargetBitWidth (constOp, targetVectorBitWidth))
83
84
return rewriter.notifyMatchFailure (
84
85
loc, " Can't flatten since targetBitWidth <= OpSize" );
@@ -459,6 +460,45 @@ struct LinearizeVectorInsert final
459
460
private:
460
461
unsigned targetVectorBitWidth;
461
462
};
463
+
464
+ // / This pattern converts the BitCastOp that works on nD (n > 1)
465
+ // / vectors to a BitCastOp that works on linearized vectors.
466
+ // / Following,
467
+ // / vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
468
+ // / is converted to :
469
+ // / %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
470
+ // / %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
471
+ // / %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
472
+ struct LinearizeVectorBitCast final
473
+ : public OpConversionPattern<vector::BitCastOp> {
474
+ using OpConversionPattern::OpConversionPattern;
475
+ LinearizeVectorBitCast (
476
+ const TypeConverter &typeConverter, MLIRContext *context,
477
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
478
+ PatternBenefit benefit = 1 )
479
+ : OpConversionPattern(typeConverter, context, benefit),
480
+ targetVectorBitWidth (targetVectBitWidth) {}
481
+ LogicalResult
482
+ matchAndRewrite (vector::BitCastOp castOp, OpAdaptor adaptor,
483
+ ConversionPatternRewriter &rewriter) const override {
484
+ Location loc = castOp.getLoc ();
485
+ auto resType = getTypeConverter ()->convertType (castOp.getType ());
486
+ if (!resType)
487
+ return rewriter.notifyMatchFailure (loc, " can't convert return type." );
488
+
489
+ if (!isLessThanTargetBitWidth (castOp, targetVectorBitWidth))
490
+ return rewriter.notifyMatchFailure (
491
+ loc, " Can't flatten since targetBitWidth <= OpSize" );
492
+
493
+ rewriter.replaceOpWithNewOp <vector::BitCastOp>(castOp, resType,
494
+ adaptor.getSource ());
495
+ return mlir::success ();
496
+ }
497
+
498
+ private:
499
+ unsigned targetVectorBitWidth;
500
+ };
501
+
462
502
} // namespace
463
503
464
504
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality (
@@ -485,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
485
525
typeConverter.addTargetMaterialization (materializeCast);
486
526
target.markUnknownOpDynamicallyLegal (
487
527
[=](Operation *op) -> std::optional<bool > {
488
- if ((isa<arith::ConstantOp>(op) ||
528
+ if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
489
529
op->hasTrait <OpTrait::Vectorizable>())) {
490
530
return (isLessThanTargetBitWidth (op, targetBitWidth)
491
531
? typeConverter.isLegal (op)
@@ -494,8 +534,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
494
534
return std::nullopt;
495
535
});
496
536
497
- patterns.add <LinearizeConstant, LinearizeVectorizable>(
498
- typeConverter, patterns.getContext (), targetBitWidth);
537
+ patterns
538
+ .add <LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
539
+ typeConverter, patterns.getContext (), targetBitWidth);
499
540
}
500
541
501
542
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments