[mlir][vector] Document ConvertVectorStore + unify var names (nfc) #126422

Merged · 2 commits · Feb 15, 2025

89 changes: 65 additions & 24 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -432,7 +432,45 @@ namespace {
// ConvertVectorStore
//===----------------------------------------------------------------------===//

// TODO: Document-me
// Emulate `vector.store` using a multi-byte container type.
//
// The container type is obtained through the Op adaptor and would normally be
// generated via `NarrowTypeEmulationConverter`.
//
// EXAMPLE 1
// (aligned store of i4, emulated using i8 as the container type)
//
// vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
//
// is rewritten as:
//
// %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
// vector.store %src_bitcast, %dest_bitcast[%idx]
// : memref<16xi8>, vector<4xi8>
//
// EXAMPLE 2
// (unaligned store of i2, emulated using i8 as the container type)
//
// vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
//
// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
// is modelled using 3 bytes:
//
// Byte 0 Byte 1 Byte 2
// +----------+----------+----------+
// | oooooooo | ooooNNNN | NNoooooo |
// +----------+----------+----------+
//
// N - (N)ew entries (i.e. to be overwritten by vector.store)
// o - (o)ld entries (to be preserved)
//
// For the generated output in the non-atomic case, see:
// * `@vector_store_i2_const_index_two_partial_stores`
// in:
// * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
//
// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
// `true` to generate non-atomic RMW sequences.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;

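For readers who do not want to open the referenced test file, the sketch below approximates the non-atomic lowering of EXAMPLE 2 (two partial stores). It is hand-written and simplified, not the verbatim output of this pattern: the destination is assumed to be already linearized to `memref<3xi8>`, the indices are folded to constants, and all value and function names are invented for readability.

```mlir
// Hedged sketch: two non-atomic RMW partial stores for EXAMPLE 2.
// %dest_1d models the linearized view of memref<3x3xi2> (9 x i2 = 3 bytes);
// %src holds the 3 i2 values being stored at element offset 6.
func.func @two_partial_stores_sketch(%dest_1d: memref<3xi8>, %src: vector<3xi2>) {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %zeros = arith.constant dense<0> : vector<4xi2>

  // Partial store 1 (byte 1): the first 2 source elements land in i2 lanes [2, 3].
  %front_mask = arith.constant dense<[false, false, true, true]> : vector<4xi1>
  %front_slice = vector.extract_strided_slice %src
      {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
  %front_new = vector.insert_strided_slice %front_slice, %zeros
      {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
  %old_b1 = vector.load %dest_1d[%c1] : memref<3xi8>, vector<1xi8>
  %old_b1_i2 = vector.bitcast %old_b1 : vector<1xi8> to vector<4xi2>
  %merged_b1 = arith.select %front_mask, %front_new, %old_b1_i2
      : vector<4xi1>, vector<4xi2>
  %new_b1 = vector.bitcast %merged_b1 : vector<4xi2> to vector<1xi8>
  vector.store %new_b1, %dest_1d[%c1] : memref<3xi8>, vector<1xi8>

  // Partial store 2 (byte 2): the last source element lands in i2 lane [0].
  %back_mask = arith.constant dense<[true, false, false, false]> : vector<4xi1>
  %back_slice = vector.extract_strided_slice %src
      {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
  %back_new = vector.insert_strided_slice %back_slice, %zeros
      {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
  %old_b2 = vector.load %dest_1d[%c2] : memref<3xi8>, vector<1xi8>
  %old_b2_i2 = vector.bitcast %old_b2 : vector<1xi8> to vector<4xi2>
  %merged_b2 = arith.select %back_mask, %back_new, %old_b2_i2
      : vector<4xi1>, vector<4xi2>
  %new_b2 = vector.bitcast %merged_b2 : vector<4xi2> to vector<1xi8>
  vector.store %new_b2, %dest_1d[%c2] : memref<3xi8>, vector<1xi8>
  return
}
```

In the default atomic configuration, the same merge would instead happen inside an atomic read-modify-write of each container byte.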
@@ -464,7 +502,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}
int numSrcElemsPerDest = containerBits / emulatedBits;
int emulatedPerContainerElem = containerBits / emulatedBits;

// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +518,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>

auto origElements = valueToStore.getType().getNumElements();
bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
// Note, per-element-alignment was already verified above.
Member:
I think so far we cannot support non-aligned cases per element, i.e. we cannot support 7-bit emulation?

Contributor Author:
AFAIK, all patterns require per-element alignment, yes. We should extract that condition somewhere to avoid repeating it.

On a related note, given how quickly this logic is growing, I feel that we should try to avoid emulating i3, i5 and i7 unless that's really required. I am just worried that the potential cost of supporting that would be relatively high.

bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
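As a concrete illustration of the index adjustment described above (a hand-written sketch, not code from the pass): for EXAMPLE 1 the `memref<4x8xi4>` destination is viewed as `memref<16xi8>`, so an element index (i, j) becomes byte index (i * 8 + j) / 2, since two i4 elements share one i8 container byte. The helper name below is invented.

```mlir
// Illustrative only: byte index for element (i, j) of memref<4x8xi4> once the
// memref is linearized to memref<16xi8> (2 x i4 per container byte).
func.func @linearized_byte_index(%i: index, %j: index) -> index {
  %c8 = arith.constant 8 : index  // inner dimension size
  %c2 = arith.constant 2 : index  // i4 elements per i8 container byte
  %row = arith.muli %i, %c8 : index
  %elem = arith.addi %row, %j : index
  %byte = arith.divui %elem, %c2 : index
  return %byte : index
}
```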
@@ -496,9 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));

std::optional<int64_t> foldedNumFrontPadElems =
isAlignedEmulation
? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
isFullyAligned ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);

if (!foldedNumFrontPadElems) {
return rewriter.notifyMatchFailure(
@@ -516,10 +554,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
bool emulationRequiresPartialStores =
!isAlignedEmulation || *foldedNumFrontPadElems != 0;
!isFullyAligned || *foldedNumFrontPadElems != 0;
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
auto numElements = origElements / numSrcElemsPerDest;
auto numElements = origElements / emulatedPerContainerElem;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, containerElemTy),
op.getValueToStore());
@@ -567,7 +605,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

// Build a mask used for rmw.
auto subWidthStoreMaskType =
VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;

@@ -576,10 +614,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// with the unaligned part so that the rest elements are aligned to width
// boundary.
auto frontSubWidthStoreElem =
(numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
(emulatedPerContainerElem - *foldedNumFrontPadElems) %
emulatedPerContainerElem;
if (frontSubWidthStoreElem > 0) {
SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
origElements, true);
frontSubWidthStoreElem = origElements;
@@ -590,7 +629,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto frontMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
auto value =
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
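To make the front-store arithmetic concrete, here are the numbers for EXAMPLE 2 from the documentation comment (a worked illustration only; the function name is invented). The store begins at linearized element offset 2 * 3 + 0 = 6, so with emulatedPerContainerElem = 8 / 2 = 4 we get *foldedNumFrontPadElems = 6 % 4 = 2, frontSubWidthStoreElem = (4 - 2) % 4 = 2, and currentSourceIndex = 4 - 2 = 2.

```mlir
// Front mask for EXAMPLE 2: 2 pad lanes followed by 2 lanes of new data,
// i.e. only i2 lanes [2, 3] of the first touched byte are overwritten.
func.func @front_mask_for_example_2() -> vector<4xi1> {
  %front_mask = arith.constant dense<[false, false, true, true]> : vector<4xi1>
  return %front_mask : vector<4xi1>
}
```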
@@ -614,8 +653,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// After the previous step, the store address is aligned to the emulated
// width boundary.
int64_t fullWidthStoreSize =
(origElements - currentSourceIndex) / numSrcElemsPerDest;
int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
(origElements - currentSourceIndex) / emulatedPerContainerElem;
int64_t numNonFullWidthElements =
fullWidthStoreSize * emulatedPerContainerElem;
if (fullWidthStoreSize > 0) {
auto fullWidthStorePart = staticallyExtractSubvector(
rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +664,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto originType = cast<VectorType>(fullWidthStorePart.getType());
auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
auto storeType = VectorType::get(
{originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
{originType.getNumElements() / emulatedPerContainerElem},
memrefElemType);
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
fullWidthStorePart);
rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +687,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
currentSourceIndex, remainingElements, 0);

// Generate back mask.
auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
std::fill_n(maskValues.begin(), remainingElements, 1);
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
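Continuing the EXAMPLE 2 numbers for this back mask (again just a worked illustration with an invented function name): fullWidthStoreSize = (3 - 2) / 4 = 0, so no full-width store is emitted, and remainingElements = 3 - 2 - 0 = 1, leaving a single i2 lane for the final partial store.

```mlir
// Back mask for EXAMPLE 2: only i2 lane [0] of the last touched byte is
// overwritten; the remaining lanes keep their old contents.
func.func @back_mask_for_example_2() -> vector<4xi1> {
  %back_mask = arith.constant dense<[true, false, false, false]> : vector<4xi1>
  return %back_mask : vector<4xi1>
}
```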
@@ -960,7 +1001,8 @@ struct ConvertVectorMaskedLoad final
// subvector at the proper offset after bit-casting.
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
// Note, per-element-alignment was already verified above.
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,9 +1017,8 @@ struct ConvertVectorMaskedLoad final
getAsOpFoldResult(adaptor.getIndices()));

std::optional<int64_t> foldedIntraVectorOffset =
isAlignedEmulation
? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
isFullyAligned ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);

int64_t maxIntraDataOffset =
foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1001,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
passthru = dynamicallyInsertSubVector(
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
origElements);
} else if (!isAlignedEmulation) {
} else if (!isFullyAligned) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
@@ -1029,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
linearizedInfo.intraDataOffset,
origElements);
} else if (!isAlignedEmulation) {
} else if (!isFullyAligned) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}
@@ -1040,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
result = dynamicallyExtractSubVector(
rewriter, loc, result, op.getPassThru(),
linearizedInfo.intraDataOffset, origElements);
} else if (!isAlignedEmulation) {
} else if (!isFullyAligned) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}