[mlir][vector] Document ConvertVectorStore + unify var names (nfc) #126422
Conversation
1. Documents `ConvertVectorStore`.
2. As a follow-on for llvm#123527, renames `isAlignedEmulation` to `isFullyAligned` and `numSrcElemsPerDest` to `emulatedPerContainerElem`.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/126422.diff (1 file affected):
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index bf1ecd7d4559caf..bb7449d85f079a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -432,7 +432,86 @@ namespace {
// ConvertVectorStore
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+// Emulate vector.store using a multi-byte container type
+//
+// The container type is obtained through the Op adaptor and would normally
+// be generated via `NarrowTypeEmulationConverter`.
+//
+// EXAMPLE 1
+// (aligned store of i4, emulated using i8)
+//
+// vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
+//
+// is rewritten as:
+//
+// %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+// vector.store %src_bitcast, %dest_bitcast[%idx]
+// : memref<16xi8>, vector<4xi8>
+//
+// EXAMPLE 2
+// (unaligned store of i2, emulated using i8, non-atomic)
+//
+// vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+//
+// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
+// is modelled using 3 bytes:
+//
+// Byte 0 Byte 1 Byte 2
+// +----------+----------+----------+
+// | oooooooo | ooooNNNN | NNoooooo |
+// +----------+----------+----------+
+//
+// N - (N)ew entries (i.e. to be overwritten by vector.store)
+// o - (o)ld entries (to be preserved)
+//
+// The following 2 RMW sequences will be generated:
+//
+// %init = arith.constant dense<0> : vector<4xi2>
+//
+// (RMW sequence for Byte 1)
+// (Mask for 4 x i2 elements, i.e. a byte)
+// %mask_1 = arith.constant dense<[false, false, true, true]>
+// %src_slice_1 = vector.extract_strided_slice %src
+// {offsets = [0], sizes = [2], strides = [1]}
+// : vector<3xi2> to vector<2xi2>
+// %init_with_slice_1 = vector.insert_strided_slice %src_slice_1, %init
+// {offsets = [2], strides = [1]}
+// : vector<2xi2> into vector<4xi2>
+// %dest_byte_1 = vector.load %dest[%c1]
+// %dest_byte_1_as_i2 = vector.bitcast %dest_byte_1
+// : vector<1xi8> to vector<4xi2>
+// %res_byte_1 = arith.select %mask_1, %init_with_slice_1, %dest_byte_1_as_i2
+// %res_byte_1_as_i8 = vector.bitcast %res_byte_1
+// vector.store %res_byte_1_as_i8, %dest[1]
+//
+// (RMW sequence for Byte 2)
+// (Mask for 4 x i2 elements, i.e. a byte)
+// %mask_2 = arith.constant dense<[true, false, false, false]>
+// %src_slice_2 = vector.extract_strided_slice %src
+// {offsets = [2], sizes = [1], strides = [1]}
+// : vector<3xi2> to vector<1xi2>
+// %init_with_slice_2 = vector.insert_strided_slice %src_slice_2, %init
+// {offsets = [0], strides = [1]}
+// : vector<1xi2> into vector<4xi2>
+// %dest_byte_2 = vector.load %dest[%c2]
+// %dest_byte_2_as_i2 = vector.bitcast %dest_byte_2
+// : vector<1xi8> to vector<4xi2>
+// %res_byte_2 = arith.select %mask_2, %init_with_slice_2, %dest_byte_2_as_i2
+// %res_byte_2_as_i8 = vector.bitcast %res_byte_2
+// vector.store %res_byte_2_as_i8, %dest[2]
+//
+// NOTE: Unlike EXAMPLE 1, this case requires index re-calculation.
+// NOTE: This example assumes that `disableAtomicRMW` was set.
+//
+// EXAMPLE 3
+// (unaligned store of i2, emulated using i8, atomic)
+//
+// Similar to EXAMPLE 2, but with each RMW sequence wrapped in
+// `memref.generic_atomic_rmw` to guarantee atomicity. The actual output is
+// skipped for brevity.
+//
+// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
+// `true` to generate non-atomic RMW sequences.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -464,7 +543,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}
- int numSrcElemsPerDest = containerBits / emulatedBits;
+ int emulatedPerContainerElem = containerBits / emulatedBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +559,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>
auto origElements = valueToStore.getType().getNumElements();
- bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+ // Note, per-element-alignment was already verified above.
+ bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -496,7 +576,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedNumFrontPadElems =
- isAlignedEmulation
+ isFullyAligned
? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
@@ -516,10 +596,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
bool emulationRequiresPartialStores =
- !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+ !isFullyAligned || *foldedNumFrontPadElems != 0;
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
- auto numElements = origElements / numSrcElemsPerDest;
+ auto numElements = origElements / emulatedPerContainerElem;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, containerElemTy),
op.getValueToStore());
@@ -567,7 +647,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Build a mask used for rmw.
auto subWidthStoreMaskType =
- VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+ VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
@@ -576,10 +656,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// with the unaligned part so that the rest elements are aligned to width
// boundary.
auto frontSubWidthStoreElem =
- (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+ (emulatedPerContainerElem - *foldedNumFrontPadElems) % emulatedPerContainerElem;
if (frontSubWidthStoreElem > 0) {
- SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
- if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+ SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
+ if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
origElements, true);
frontSubWidthStoreElem = origElements;
@@ -590,7 +670,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto frontMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
- currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+ currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
auto value =
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +694,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// After the previous step, the store address is aligned to the emulated
// width boundary.
int64_t fullWidthStoreSize =
- (origElements - currentSourceIndex) / numSrcElemsPerDest;
- int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+ (origElements - currentSourceIndex) / emulatedPerContainerElem;
+ int64_t numNonFullWidthElements = fullWidthStoreSize * emulatedPerContainerElem;
if (fullWidthStoreSize > 0) {
auto fullWidthStorePart = staticallyExtractSubvector(
rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +704,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto originType = cast<VectorType>(fullWidthStorePart.getType());
auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
auto storeType = VectorType::get(
- {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+ {originType.getNumElements() / emulatedPerContainerElem}, memrefElemType);
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
fullWidthStorePart);
rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +726,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
currentSourceIndex, remainingElements, 0);
// Generate back mask.
- auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+ auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
std::fill_n(maskValues.begin(), remainingElements, 1);
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
@@ -960,7 +1040,8 @@ struct ConvertVectorMaskedLoad final
// subvector at the proper offset after bit-casting.
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
- bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+ // Note, per-element-alignment was already verified above.
+ bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,7 +1056,7 @@ struct ConvertVectorMaskedLoad final
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isAlignedEmulation
+ isFullyAligned
? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
@@ -1001,7 +1082,7 @@ struct ConvertVectorMaskedLoad final
passthru = dynamicallyInsertSubVector(
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
origElements);
- } else if (!isAlignedEmulation) {
+ } else if (!isFullyAligned) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
@@ -1029,7 +1110,7 @@ struct ConvertVectorMaskedLoad final
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
linearizedInfo.intraDataOffset,
origElements);
- } else if (!isAlignedEmulation) {
+ } else if (!isFullyAligned) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}
@@ -1040,7 +1121,7 @@ struct ConvertVectorMaskedLoad final
result = dynamicallyExtractSubVector(
rewriter, loc, result, op.getPassThru(),
linearizedInfo.intraDataOffset, origElements);
- } else if (!isAlignedEmulation) {
+ } else if (!isFullyAligned) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
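EXAMPLE 3 in the new `ConvertVectorStore` comment above notes that the actual output is skipped for brevity. For orientation, here is a rough, hypothetical sketch of the atomic variant of a single RMW sequence (Byte 1 of EXAMPLE 2); the function name is invented, the SSA names are borrowed from that example, and this is not the pass's literal output:

```mlir
// Hypothetical illustration only: the select-based merge for one byte is
// performed inside memref.generic_atomic_rmw when atomic RMW is enabled.
func.func @store_i2_byte_1_atomic(%dest: memref<3xi8>,
                                  %mask_1: vector<4xi1>,
                                  %init_with_slice_1: vector<4xi2>) {
  %c1 = arith.constant 1 : index
  %res = memref.generic_atomic_rmw %dest[%c1] : memref<3xi8> {
  ^bb0(%current: i8):
    // View the currently stored byte as 4 x i2.
    %current_vec = vector.from_elements %current : vector<1xi8>
    %current_as_i2 = vector.bitcast %current_vec : vector<1xi8> to vector<4xi2>
    // Overwrite only the masked (new) i2 elements, preserve the rest.
    %merged = arith.select %mask_1, %init_with_slice_1, %current_as_i2
        : vector<4xi1>, vector<4xi2>
    // Cast back to a byte and yield it as the value to store atomically.
    %merged_as_i8 = vector.bitcast %merged : vector<4xi2> to vector<1xi8>
    %new_byte = vector.extract %merged_as_i8[0] : i8 from vector<1xi8>
    memref.atomic_yield %new_byte : i8
  }
  return
}
```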
@lialan, you've been contributing most patches in this area recently, would you have the cycles to review? Thanks!
Thanks @banach-space, I do have some opinions of my own on how to better comment the code. We can discuss the details.
auto origElements = valueToStore.getType().getNumElements();
- bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+ // Note, per-element-alignment was already verified above.
+ bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
I think that so far we cannot support per-element non-aligned cases? I.e., we cannot support 7-bit emulation?
AFAIK, all patterns require per-element alignment, yes. We should extract that condition somewhere to avoid repeating it.
On a related note, given how quickly this logic is growing, I feel that we should try to avoid emulating i3, i5 and i7, unless that's really required. I am just worried that the potential cost of supporting that would be relatively high.
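To make that constraint concrete, here is a hypothetical input (function and value names invented, not taken from the patch or its tests) that the current patterns reject, since i3 does not divide the 8-bit container evenly:

```mlir
// 8 % 3 != 0, so the conversion bails out with the "bit-wise misalignment"
// match failure shown in the diff above.
func.func @store_i3(%src: vector<8xi3>, %dest: memref<16xi3>, %idx: index) {
  vector.store %src, %dest[%idx] : memref<16xi3>, vector<8xi3>
  return
}
```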
Rather than keeping the long EXAMPLE 2 walk-through in the source comment, I think we can refer the reader to the corresponding test case and try to annotate/comment more precisely in the test case instead.
Thanks for the suggestion! I was wondering how to avoid this long comment, and your suggestion is exactly what we should be doing! 🙏🏻
As this example is taken from "vector-emulate-narrow-type-unaligned-non-atomic.mlir", that's the test file that I've updated to help here. Please check the latest update.
Note, I've made quite a few changes:
- Extended comments.
- Fixed `DOWNCAST` vs `UPCAST`.
- Renamed some variables to avoid generic names (e.g. `%arg0` -> `%src`, `%0` -> `%dest`).
- Added more CHECK lines, e.g. `// CHECK-SAME: : vector<1xi8> to vector<4xi2>`, to make sure that the right casting is generated.
- Followed the formatting style from vectorize-convolution.mlir. IMHO it's a very "readable" style that's particularly handy for complex tests like these ones.
I appreciate that these are quite intrusive changes, but since it's meant as documentation, it felt like the right thing to do. But I am happy to adapt/revert if you feel that this is too much.
Thanks for reviewing!
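For readers who have not seen that test file, here is a rough sketch of the CHECK-line style being described; the labels, ordering and function name are illustrative only, not copied from "vector-emulate-narrow-type-unaligned-non-atomic.mlir":

```mlir
// CHECK-LABEL: func.func @vector_store_i2_unaligned(
// CHECK-SAME:    %[[SRC:.+]]: vector<3xi2>
// CHECK:         %[[LOAD:.+]] = vector.load
// CHECK:         %[[DOWNCAST:.+]] = vector.bitcast %[[LOAD]]
// CHECK-SAME:      : vector<1xi8> to vector<4xi2>
// CHECK:         %[[SELECT:.+]] = arith.select
// CHECK:         %[[UPCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK-SAME:      : vector<4xi2> to vector<1xi8>
// CHECK:         vector.store %[[UPCAST]]
```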
… (nfc) Remove comment from source file, refactor the test file
LGTM!
…lvm#126422)
1. Documents `ConvertVectorStore`. As the generated output is rather complex, I have refined the comments + variable names in "vector-emulate-narrow-type-unaligned-non-atomic.mlir", to serve as reference for this pattern.
2. As a follow-on for llvm#123527, renames `isAlignedEmulation` to `isFullyAligned` and `numSrcElemsPerDest` to `emulatedPerContainerElem`.
1. Documents `ConvertVectorStore`. As the generated output is rather complex, I have refined the comments + variable names in "vector-emulate-narrow-type-unaligned-non-atomic.mlir", to serve as reference for this pattern.
2. As a follow-on for [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N) #123527, renames `isAlignedEmulation` to `isFullyAligned` and `numSrcElemsPerDest` to `emulatedPerContainerElem`.
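As a quick illustration of what the renamed `isFullyAligned` flag captures (hypothetical IR mirroring EXAMPLE 1 and EXAMPLE 2 from the new comment, not part of the patch):

```mlir
// 8 x i4 fills exactly 4 x i8 containers (origElements = 8 is a multiple of
// emulatedPerContainerElem = 2); together with a byte-aligned start offset
// this lowers to the plain bitcast + vector.store of EXAMPLE 1.
func.func @fully_aligned(%src: vector<8xi4>, %dest: memref<4x8xi4>,
                         %i: index, %j: index) {
  vector.store %src, %dest[%i, %j] : memref<4x8xi4>, vector<8xi4>
  return
}

// 3 x i2 covers the i8 containers only partially (3 is not a multiple of
// emulatedPerContainerElem = 4), so this store takes the partial-store (RMW)
// path instead.
func.func @not_fully_aligned(%src: vector<3xi2>, %dest: memref<3x3xi2>) {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
  return
}
```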