@@ -432,7 +432,45 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//

-// TODO: Document-me
+// Emulate `vector.store` using a multi-byte container type.
+//
+// The container type is obtained through the Op adaptor and would normally be
+// generated via `NarrowTypeEmulationConverter`.
+//
+// EXAMPLE 1
+// (aligned store of i4, emulated using i8 as the container type)
+//
+//      vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
+//
+// is rewritten as:
+//
+//      %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+//      vector.store %src_bitcast, %dest_bitcast[%idx]
+//        : memref<16xi8>, vector<4xi8>
+//
+// EXAMPLE 2
+// (unaligned store of i2, emulated using i8 as the container type)
+//
+//      vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+//
+// The i2 store is emulated through 2 x RMW sequences. The destination i2
+// memref is modelled using 3 bytes:
+//
+//    Byte 0     Byte 1     Byte 2
+// +----------+----------+----------+
+// | oooooooo | ooooNNNN | NNoooooo |
+// +----------+----------+----------+
+//
+// N - (N)ew entries (i.e. to be overwritten by vector.store)
+// o - (o)ld entries (to be preserved)
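+//
+// For Byte 1 above, one (non-atomic) RMW sequence looks roughly like this.
+// This is an illustrative sketch rather than verbatim output; the SSA names
+// are made up (`%dest_i8` stands for the linearized i8 view of `%dest`, and
+// `%src_slice` for the 2 front source elements shifted to positions 2-3):
+//
+//      %mask = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+//      %old = vector.load %dest_i8[%c1] : memref<3xi8>, vector<1xi8>
+//      %old_i2 = vector.bitcast %old : vector<1xi8> to vector<4xi2>
+//      %new_i2 = arith.select %mask, %src_slice, %old_i2
+//        : vector<4xi1>, vector<4xi2>
+//      %new = vector.bitcast %new_i2 : vector<4xi2> to vector<1xi8>
+//      vector.store %new, %dest_i8[%c1] : memref<3xi8>, vector<1xi8>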
+//
+// For the generated output in the non-atomic case, see:
+//  * `@vector_store_i2_const_index_two_partial_stores`
+// in:
+//  * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
+//
+// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
+// `true` to generate non-atomic RMW sequences.
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;

@@ -464,7 +502,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int numSrcElemsPerDest = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
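+    // E.g. for i2 emulated via an i8 container: 8 / 2 = 4 emulated elements
+    // per container element (see EXAMPLE 2 above).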

     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +518,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>

     auto origElements = valueToStore.getType().getNumElements();
-    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
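+    // E.g. vector<8xi4> in an i8 container is fully aligned (8 % 2 == 0),
+    // whereas vector<3xi2> is not (3 % 4 != 0) - cf. the examples above.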

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -496,9 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
         getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedNumFrontPadElems =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);

     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -516,10 +554,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
     bool emulationRequiresPartialStores =
-        !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+        !isFullyAligned || *foldedNumFrontPadElems != 0;
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
-      auto numElements = origElements / numSrcElemsPerDest;
+      auto numElements = origElements / emulatedPerContainerElem;
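+      // E.g. in EXAMPLE 1, storing vector<8xi4> via i8 gives 8 / 2 = 4
+      // container elements, i.e. the vector<4xi8> produced by the bitcast
+      // below.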
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, containerElemTy),
           op.getValueToStore());
@@ -567,7 +605,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

     // Build a mask used for rmw.
     auto subWidthStoreMaskType =
-        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

     auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;

@@ -576,10 +614,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // with the unaligned part so that the rest elements are aligned to width
     // boundary.
     auto frontSubWidthStoreElem =
-        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
+        emulatedPerContainerElem;
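+    // E.g. in EXAMPLE 2 the store starts 2 elements into Byte 1, so
+    // (4 - 2) % 4 = 2 front elements are written by the first RMW sequence.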
     if (frontSubWidthStoreElem > 0) {
-      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
-      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
+      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
         std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
         frontSubWidthStoreElem = origElements;
@@ -590,7 +629,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto frontMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

-      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
       auto value =
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +653,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // After the previous step, the store address is aligned to the emulated
     // width boundary.
     int64_t fullWidthStoreSize =
-        (origElements - currentSourceIndex) / numSrcElemsPerDest;
-    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+        (origElements - currentSourceIndex) / emulatedPerContainerElem;
+    int64_t numNonFullWidthElements =
+        fullWidthStoreSize * emulatedPerContainerElem;
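+    // E.g. in EXAMPLE 2, (3 - 2) / 4 = 0 full-width stores: the remaining
+    // element is handled entirely by the trailing RMW sequence below.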
     if (fullWidthStoreSize > 0) {
       auto fullWidthStorePart = staticallyExtractSubvector(
           rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +664,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto originType = cast<VectorType>(fullWidthStorePart.getType());
       auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
       auto storeType = VectorType::get(
-          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+          {originType.getNumElements() / emulatedPerContainerElem},
+          memrefElemType);
       auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                         fullWidthStorePart);
       rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +687,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
                                currentSourceIndex, remainingElements, 0);

       // Generate back mask.
-      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
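+      // E.g. in EXAMPLE 2 one element remains, so after the fill below the
+      // back mask is [1, 0, 0, 0]: only the first i2 slot of Byte 2 is
+      // overwritten.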
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
@@ -960,7 +1001,8 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,9 +1017,8 @@ struct ConvertVectorMaskedLoad final
         getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);

     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1001,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
           origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -1029,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                         linearizedInfo.intraDataOffset,
                                         origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -1040,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, result, op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }