Commit ad948fa

[mlir][vector] Document ConvertVectorStore + unify var names (nfc) (#126422)
1. Documents `ConvertVectorStore`. As the generated output is rather complex,
   I have refined the comments and variable names in
   "vector-emulate-narrow-type-unaligned-non-atomic.mlir" so that it can serve
   as a reference for this pattern.
2. As a follow-on to #123527, renames `isAlignedEmulation` to `isFullyAligned`
   and `numSrcElemsPerDest` to `emulatedPerContainerElem`.
1 parent: 02fb976

File tree: 2 files changed (+185, -119 lines)

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 65 additions & 24 deletions
@@ -432,7 +432,45 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+// Emulate `vector.store` using a multi-byte container type.
+//
+// The container type is obtained through the Op adaptor and would normally be
+// generated via `NarrowTypeEmulationConverter`.
+//
+// EXAMPLE 1
+// (aligned store of i4, emulated using i8 as the container type)
+//
+//   vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
+//
+// is rewritten as:
+//
+//   %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+//   vector.store %src_bitcast, %dest_bitcast[%idx]
+//     : memref<16xi8>, vector<4xi8>
+//
+// EXAMPLE 2
+// (unaligned store of i2, emulated using i8 as the container type)
+//
+//   vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+//
+// The i2 store is emulated through 2 x RMW sequences. The destination i2
+// memref is modelled using 3 bytes:
+//
+//    Byte 0     Byte 1     Byte 2
+// +----------+----------+----------+
+// | oooooooo | ooooNNNN | NNoooooo |
+// +----------+----------+----------+
+//
+// N - (N)ew entries (i.e. to be overwritten by vector.store)
+// o - (o)ld entries (to be preserved)
+//
+// For the generated output in the non-atomic case, see:
+//   * `@vector_store_i2_const_index_two_partial_stores`
+// in:
+//   * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
+//
+// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
+// `true` to generate non-atomic RMW sequences.
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
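For orientation, below is a minimal hand-written sketch of the kind of non-atomic RMW sequence this pattern produces for a single partial byte (the back store of Byte 2 in EXAMPLE 2). The value names (`%dest_1d`, `%mask`, `%new_i2`) are illustrative assumptions; the exact generated IR lives in the test file referenced in the comment above.

```mlir
// RMW on Byte 2: load the container byte, merge the new i2 entries under a
// mask, and store the byte back. Non-atomic: a plain load/select/store.
%old_byte = vector.load %dest_1d[%c2] : memref<3xi8>, vector<1xi8>
%old_i2   = vector.bitcast %old_byte : vector<1xi8> to vector<4xi2>
// %mask = [1, 0, 0, 0]: only the first i2 slot of this byte receives a new
// value (the "NN" bits in the diagram above).
%merged   = arith.select %mask, %new_i2, %old_i2 : vector<4xi1>, vector<4xi2>
%new_byte = vector.bitcast %merged : vector<4xi2> to vector<1xi8>
vector.store %new_byte, %dest_1d[%c2] : memref<3xi8>, vector<1xi8>
```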
@@ -464,7 +502,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int numSrcElemsPerDest = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
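For example, with the i2/i8 pairing from EXAMPLE 2 above, `emulatedPerContainerElem` = 8 / 2 = 4, i.e. four emulated i2 elements fit in one i8 container element.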
@@ -480,7 +518,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     //      vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
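Using the documented examples: the `vector<8xi4>` store packs 2 elements per byte and 8 % 2 == 0, so it passes this check, while the `vector<3xi2>` store packs 4 elements per byte and 3 % 4 != 0, so it takes the partial-store path below.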
@@ -496,9 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -516,10 +554,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
     bool emulationRequiresPartialStores =
-        !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+        !isFullyAligned || *foldedNumFrontPadElems != 0;
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
-      auto numElements = origElements / numSrcElemsPerDest;
+      auto numElements = origElements / emulatedPerContainerElem;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, containerElemTy),
           op.getValueToStore());
@@ -567,7 +605,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     // Build a mask used for rmw.
     auto subWidthStoreMaskType =
-        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
 
     auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
 
@@ -576,10 +614,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // with the unaligned part so that the rest elements are aligned to width
     // boundary.
     auto frontSubWidthStoreElem =
-        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
+        emulatedPerContainerElem;
     if (frontSubWidthStoreElem > 0) {
-      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
-      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
+      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
         std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
         frontSubWidthStoreElem = origElements;
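Plugging in the EXAMPLE 2 numbers: with `emulatedPerContainerElem` = 4 and 2 front pad elements, `frontSubWidthStoreElem` = (4 - 2) % 4 = 2, i.e. the first RMW writes the 2 new elements occupying the `NNNN` half of Byte 1 in the diagram.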
@@ -590,7 +629,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto frontMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
       auto value =
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +653,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // After the previous step, the store address is aligned to the emulated
     // width boundary.
     int64_t fullWidthStoreSize =
-        (origElements - currentSourceIndex) / numSrcElemsPerDest;
-    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+        (origElements - currentSourceIndex) / emulatedPerContainerElem;
+    int64_t numNonFullWidthElements =
+        fullWidthStoreSize * emulatedPerContainerElem;
     if (fullWidthStoreSize > 0) {
       auto fullWidthStorePart = staticallyExtractSubvector(
           rewriter, loc, valueToStore, currentSourceIndex,
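Continuing with EXAMPLE 2: after the front partial store, `currentSourceIndex` = 4 - 2 = 2, so `fullWidthStoreSize` = (3 - 2) / 4 = 0 and no full-byte store is emitted; the single remaining element is handled by the back partial store below.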
@@ -624,7 +664,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto originType = cast<VectorType>(fullWidthStorePart.getType());
       auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
       auto storeType = VectorType::get(
-          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+          {originType.getNumElements() / emulatedPerContainerElem},
+          memrefElemType);
       auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                         fullWidthStorePart);
       rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +687,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
                                currentSourceIndex, remainingElements, 0);
 
       // Generate back mask.
-      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
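For the back store in EXAMPLE 2, one element remains, so with `emulatedPerContainerElem` = 4 the back mask is [1, 0, 0, 0]: only the leading `NN` bits of Byte 2 are overwritten and the rest of the byte is preserved.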
@@ -960,7 +1001,8 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,9 +1017,8 @@ struct ConvertVectorMaskedLoad final
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1001,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
           origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -1029,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                         linearizedInfo.intraDataOffset,
                                         origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -1040,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
          rewriter, loc, result, op.getPassThru(),
          linearizedInfo.intraDataOffset, origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
