@@ -400,6 +400,9 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//

+// /
+// /
+
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;

@@ -443,7 +446,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>

     auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
+    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -459,9 +462,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedNumFrontPadElems =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);

     if (!foldedNumFrontPadElems) {
       return failure("subbyte store emulation: dynamic front padding size is "
@@ -472,13 +475,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

     // Shortcut: conditions when subbyte emulated store at the front is not
     // needed:
-    // 1. The source vector size is multiple of byte size
-    // 2. The address of the store is aligned to the emulated width boundary
+    // 1. The source vector size (in bits) is a multiple of byte size.
+    // 2. The address of the store is aligned to the emulated width boundary.
     //
     // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
-    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+    if (isAlignedEmulation && *foldedNumFrontPadElems == 0) {
       auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
@@ -489,17 +492,50 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }

-    // The index into the target memref we are storing to
+    // Next, handle the case when sub-byte read-modify-write sequences are
+    // needed to emulate a vector store.
+    // Here is an example:
+    //
+    // Vector to store: vector<7xi2>
+    // Value to store: 11 11 11 11 11 11 11 (all ones)
+    //
+    // Destination: memref<12xi2>
+    // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
+    //
+    // MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
+    //
+    // Destination memref before:
+    //
+    //    Byte 0     Byte 1     Byte 2
+    // +----------+----------+----------+
+    // | 00000000 | 00000000 | 00000000 |
+    // +----------+----------+----------+
+    //
+    // Destination memref after:
+    //
+    //    Byte 0     Byte 1     Byte 2
+    // +----------+----------+----------+
+    // | 00001111 | 11111111 | 11000000 |
+    // +----------+----------+----------+
+    //
+    // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
+    // need for atomicity). Stores to Byte 0 and Byte 2 are "partial", hence
+    // requiring RMW access (atomicity is required).
+
+    // The index into the target memref we are storing to.
     Value currentDestIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // Build a mask used for RMW.
     auto subWidthStoreMaskType =
         VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
-    // The index into the source vector we are currently processing
-    auto currentSourceIndex = 0;

-    // 1. Partial width store for the first byte, when the store address is not
-    // aligned to emulated width boundary, deal with the unaligned part so that
-    // the rest elements are aligned to width boundary.
+    // 1. Partial width store for the leading byte.
+    // When the store address is not aligned to the emulated width boundary,
+    // deal with the unaligned part so that the rest of the elements are
+    // aligned to the width boundary.
     auto frontSubWidthStoreElem =
         (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
     if (frontSubWidthStoreElem > 0) {
@@ -535,8 +571,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     currentDestIndex = rewriter.create<arith::AddIOp>(
         loc, rewriter.getIndexType(), currentDestIndex, constantOne);

-    // 2. Full width store. After the previous step, the store address is
-    // aligned to the emulated width boundary.
+    // 2. Full width store for the inner output bytes.
+    // After the previous step, the store address is aligned to the emulated
+    // width boundary.
     int64_t fullWidthStoreSize =
         (origElements - currentSourceIndex) / numSrcElemsPerDest;
     int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
@@ -560,15 +597,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
     }

-    // 3. Deal with trailing elements that are aligned to the emulated width,
-    // but their length is smaller than the emulated width.
+    // 3. Partial width store for the trailing output byte.
+    // It is needed when the residual length is smaller than the emulated
+    // width, which is not covered in step 2 above.
     auto remainingElements = origElements - currentSourceIndex;
     if (remainingElements != 0) {
       auto subWidthStorePart =
           extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
                                currentSourceIndex, remainingElements, 0);

-      // Generate back mask
+      // Generate back mask.
       auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
@@ -751,7 +789,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.

     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isAlignedEmulation = origElements % scale == 0;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -767,9 +805,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);

     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
@@ -785,7 +823,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, cast<VectorValue>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -867,7 +905,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isAlignedEmulation = origElements % scale == 0;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -882,9 +920,9 @@ struct ConvertVectorMaskedLoad final
             getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);

     int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     FailureOr<Operation *> newMask = getCompressedMaskOp(
@@ -905,7 +943,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, cast<VectorValue>(passthru), emptyVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -933,7 +971,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(
           rewriter, loc, cast<VectorValue>(mask), emptyMask,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -944,7 +982,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -986,7 +1024,7 @@ struct ConvertVectorTransferRead final

     auto origElements = op.getVectorType().getNumElements();

-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isAlignedEmulation = origElements % scale == 0;

     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -1005,9 +1043,9 @@ struct ConvertVectorTransferRead final
             getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);

     int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     auto numElements =
@@ -1028,7 +1066,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
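
For reference, a minimal MLIR sketch of the kind of input that hits the unaligned (read-modify-write) store path documented in the new comment block above. The function name `@store_unaligned` is illustrative only; `%val`, `%dest`, and `%c2` mirror the names used in that comment, and the IR produced after emulation is not shown here.

```mlir
// Storing 7 x i2 at element offset 2 (4 bits into the first emulated i8 byte):
// the leading and trailing bytes need partial (RMW) stores, the middle byte is
// covered by a full-width store.
func.func @store_unaligned(%val: vector<7xi2>, %dest: memref<12xi2>) {
  %c2 = arith.constant 2 : index
  vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
  return
}
```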