@@ -415,18 +415,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
+
     auto valueToStore = cast<VectorValue>(op.getValueToStore());
-    auto oldElementType = valueToStore.getType().getElementType();
-    auto newElementType =
+    auto containerElemTy =
         cast<MemRefType>(adaptor.getBase().getType()).getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int numSrcElemsPerDest = dstBits / srcBits;
+    int numSrcElemsPerDest = containerBits / emulatedBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
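The arithmetic these renamed variables encode, as a minimal standalone sketch in plain C++ (not MLIR; the i2/i8 widths and the vector length are assumed values for illustration):

    #include <cassert>
    #include <cstdio>

    int main() {
      int emulatedBits = 2;  // width of the narrow element type, e.g. i2
      int containerBits = 8; // width of the converted memref element, e.g. i8
      // Per-element alignment: emulated elements must tile one container
      // element exactly, otherwise the pattern reports a match failure.
      assert(containerBits % emulatedBits == 0 && "bit-wise misalignment");
      int numSrcElemsPerDest = containerBits / emulatedBits; // 4 x i2 per i8
      int origElements = 8; // storing a vector<8xi2> ...
      int numElements = origElements / numSrcElemsPerDest; // ... as 2 x i8
      std::printf("%d emulated per container, %d container elems\n",
                  numSrcElemsPerDest, numElements);
      return 0;
    }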
@@ -451,7 +454,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -483,7 +486,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       // Basic case: storing full bytes.
       auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
-          loc, VectorType::get(numElements, newElementType),
+          loc, VectorType::get(numElements, containerElemTy),
           op.getValueToStore());
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           op, bitCast.getResult(), memrefBase,
@@ -638,18 +641,20 @@ struct ConvertVectorMaskedStore final
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
 
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
     int origElements = op.getValueToStore().getType().getNumElements();
     if (origElements % scale != 0)
       return failure();
@@ -660,7 +665,7 @@ struct ConvertVectorMaskedStore final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -706,15 +711,15 @@ struct ConvertVectorMaskedStore final
       return failure();
 
     auto numElements = (origElements + scale - 1) / scale;
-    auto newType = VectorType::get(numElements, newElementType);
+    auto newType = VectorType::get(numElements, containerElemTy);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
 
     auto newLoad = rewriter.create<vector::MaskedLoadOp>(
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
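For context on the hunk above: the masked-store lowering is a read-modify-write over container elements: masked-load the containers, bitcast to the emulated type, select new versus old values per element, then store the merged vector back. A scalar sketch of the same merge in plain C++, assuming i4 elements packed two per i8 container (names and values are illustrative, not the pattern's own):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint8_t container = 0xAB;     // two i4 elements in memory: 0xA, 0xB
      uint8_t newValue = 0x5C;      // incoming elements: 0x5, 0xC
      bool mask[2] = {true, false}; // overwrite only the high nibble
      uint8_t result = 0;
      for (int i = 0; i < 2; ++i) {
        int shift = (1 - i) * 4;    // nibble position, high nibble first
        uint8_t src = mask[i] ? newValue : container; // per-element select
        result |= ((src >> shift) & 0xF) << shift;
      }
      std::printf("0x%02X\n", result); // 0x5B: high nibble new, low kept
      return 0;
    }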
@@ -746,17 +751,19 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+    Type emulatedElemTy = op.getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -797,7 +804,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -814,7 +821,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         llvm::divideCeil(maxintraDataOffset + origElements, scale);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
-                           numElements, oldElementType, newElementType);
+                           numElements, emulatedElemTy, containerElemTy);
 
     if (!foldedIntraVectorOffset) {
       auto resultVector = rewriter.create<arith::ConstantOp>(
@@ -848,17 +855,20 @@ struct ConvertVectorMaskedLoad final
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+    Type emulatedElemTy = op.getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
+
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -912,7 +922,7 @@ struct ConvertVectorMaskedLoad final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -933,8 +943,8 @@ struct ConvertVectorMaskedLoad final
 
     auto numElements =
         llvm::divideCeil(maxIntraDataOffset + origElements, scale);
-    auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto loadType = VectorType::get(numElements, containerElemTy);
+    auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -1009,23 +1019,25 @@ struct ConvertVectorTransferRead final
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
-    Type oldElementType = op.getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
-
-    if (dstBits % srcBits != 0) {
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getSource().getType()).getElementType();
+    Type emulatedElemTy = op.getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
+
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
     bool isAlignedEmulation = origElements % scale == 0;
 
-    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
+    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
 
     auto stridedMetadata =
@@ -1035,7 +1047,7 @@ struct ConvertVectorTransferRead final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -1051,12 +1063,12 @@ struct ConvertVectorTransferRead final
         llvm::divideCeil(maxIntraDataOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
+        loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+        loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {
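On the load-side hunks, when the linearized start index is not aligned to a container element the intra-element offset is folded into the sizing, so the wide load may fetch one extra container element and the narrow result is extracted after the bitcast. A small plain-C++ sketch of that ceiling division (scale, offset, and length are assumed values):

    #include <cstdio>

    // Mirrors llvm::divideCeil for small non-negative ints.
    int divideCeil(int a, int b) { return (a + b - 1) / b; }

    int main() {
      int scale = 4;              // 4 x i2 per i8 container, as above
      int origElements = 6;       // loading a vector<6xi2>
      int maxIntraDataOffset = 3; // start is 3 emulated elems into a byte
      int numElements = divideCeil(maxIntraDataOffset + origElements, scale);
      std::printf("load %d x i8, bitcast to %d x i2, extract 6\n",
                  numElements, numElements * scale); // 3 x i8 -> 12 x i2
      return 0;
    }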