[mlir][vector] Document ConvertVectorStore + unify var names (nfc) #126422


Merged
banach-space merged 2 commits into llvm:main from andrzej/refactor_narrow_type_5 on Feb 15, 2025

Conversation

banach-space (Contributor) commented on Feb 9, 2025

  1. Documents ConvertVectorStore. As the generated output is rather complex, I have
     refined the comments and variable names in
     "vector-emulate-narrow-type-unaligned-non-atomic.mlir" so that it can serve as a
     reference for this pattern.
  2. As a follow-on for [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N) #123527,
     renames isAlignedEmulation to isFullyAligned and numSrcElemsPerDest to
     emulatedPerContainerElem (see the sketch below for what these quantities mean).
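For reference, a minimal sketch of what the two renamed quantities mean for the i2-in-i8 case used throughout this patch (the function and value names below are illustrative only, not taken from the patch):

    // i2 emulated via i8: emulatedPerContainerElem = 8 / 2 = 4 emulated elements
    // fit in one container element.
    func.func @aligned_vs_unaligned(%val_8: vector<8xi2>, %val_3: vector<3xi2>,
        %dest_1d: memref<16xi2>, %dest_2d: memref<3x3xi2>, %i: index, %j: index) {
      // 8 % 4 == 0 -> isFullyAligned: lowers to a plain bitcast + i8-wide store.
      vector.store %val_8, %dest_1d[%i] : memref<16xi2>, vector<8xi2>
      // 3 % 4 != 0 -> not fully aligned: takes the partial-store (RMW) path.
      vector.store %val_3, %dest_2d[%i, %j] : memref<3x3xi2>, vector<3xi2>
      return
    }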

llvmbot (Member) commented on Feb 9, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  1. Documents ConvertVectorStore.

  2. As a follow-on for #123527, renames isAlignedEmulation to
    isFullyAligned and numSrcElemsPerDest to
    emulatedPerContainerElem.


Full diff: https://github.com/llvm/llvm-project/pull/126422.diff

1 file affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+101-20)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index bf1ecd7d4559caf..bb7449d85f079a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -432,7 +432,86 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+// Emulate vector.store using a multi-byte container type
+//
+// The container type is obtained through the Op adaptor and would normally be
+// generated via `NarrowTypeEmulationConverter`.
+//
+// EXAMPLE 1
+// (aligned store of i4, emulated using i8)
+//
+//      vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
+//
+// is rewritten as:
+//
+//      %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+//      vector.store %src_bitcast, %dest_bitcast[%idx]
+//        : memref<16xi8>, vector<4xi8>
+//
+// EXAMPLE 2
+// (unaligned store of i2, emulated using i8, non-atomic)
+//
+//    vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+//
+// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
+// is modelled using 3 bytes:
+//
+//    Byte 0     Byte 1     Byte 2
+// +----------+----------+----------+
+// | oooooooo | ooooNNNN | NNoooooo |
+// +----------+----------+----------+
+//
+// N - (N)ew entries (i.e. to be overwritten by vector.store)
+// o - (o)ld entries (to be preserved)
+//
+// The following 2 RMW sequences will be generated:
+//
+//    %init = arith.constant dense<0> : vector<4xi2>
+//
+//    (RMW sequence for Byte 1)
+//    (Mask for 4 x i2 elements, i.e. a byte)
+//    %mask_1 = arith.constant dense<[false, false, true, true]>
+//    %src_slice_1 = vector.extract_strided_slice %src
+//      {offsets = [0], sizes = [2], strides = [1]}
+//      : vector<3xi2> to vector<2xi2>
+//    %init_with_slice_1 = vector.insert_strided_slice %src_slice_1, %init
+//      {offsets = [2], strides = [1]}
+//      : vector<2xi2> into vector<4xi2>
+//    %dest_byte_1 = vector.load %dest[%c1]
+//    %dest_byte_1_as_i2 = vector.bitcast %dest_byte_1
+//      : vector<1xi8> to vector<4xi2>
+//    %res_byte_1 = arith.select %mask_1, %init_with_slice_1, %dest_byte_1_as_i2
+//    %res_byte_1_as_i8 = vector.bitcast %res_byte_1
+//    vector.store %res_byte_1_as_i8, %dest[1]
+//
+//    (RMW sequence for Byte 2)
+//    (Mask for 4 x i2 elements, i.e. a byte)
+//    %mask_2 = arith.constant dense<[true, false, false, false]>
+//    %src_slice_2 = vector.extract_strided_slice %src
+//      {offsets = [2], sizes = [1], strides = [1]}
+//      : vector<3xi2> to vector<1xi2>
+//    %init_with_slice_2 = vector.insert_strided_slice %src_slice_2, %init
+//      {offsets = [0], strides = [1]}
+//      : vector<1xi2> into vector<4xi2>
+//    %dest_byte_2 = vector.load %dest[%c2]
+//    %dest_byte_2_as_i2 = vector.bitcast %dest_byte_2
+//      : vector<1xi8> to vector<4xi2>
+//    %res_byte_2 = arith.select %mask_2, %init_with_slice_2, %dest_byte_2_as_i2
+//    %res_byte_2_as_i8 = vector.bitcast %res_byte_2
+//    vector.store %res_byte_2_as_i8, %dest[2]
+//
+// NOTE: Unlike EXAMPLE 1, this case requires index re-calculation.
+// NOTE: This example assumes that `disableAtomicRMW` was set.
+//
+// EXAMPLE 3
+// (unaligned store of i2, emulated using i8, atomic)
+//
+// Similar to EXAMPLE 2, but the RMW sequences are wrapped in
+// `memref.generic_atomic_rmw` to guarantee atomicity. The actual output is
+// skipped for brevity.
+//
+// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
+// `true` to generate non-atomic RMW sequences.
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -464,7 +543,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int numSrcElemsPerDest = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +559,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -496,7 +576,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isAlignedEmulation
+        isFullyAligned
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
@@ -516,10 +596,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
     bool emulationRequiresPartialStores =
-        !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+        !isFullyAligned || *foldedNumFrontPadElems != 0;
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
-      auto numElements = origElements / numSrcElemsPerDest;
+      auto numElements = origElements / emulatedPerContainerElem;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, containerElemTy),
           op.getValueToStore());
@@ -567,7 +647,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     // Build a mask used for rmw.
     auto subWidthStoreMaskType =
-        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
 
     auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
 
@@ -576,10 +656,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // with the unaligned part so that the rest elements are aligned to width
     // boundary.
     auto frontSubWidthStoreElem =
-        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+        (emulatedPerContainerElem - *foldedNumFrontPadElems) % emulatedPerContainerElem;
     if (frontSubWidthStoreElem > 0) {
-      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
-      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
+      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
         std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
         frontSubWidthStoreElem = origElements;
@@ -590,7 +670,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto frontMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
       auto value =
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +694,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // After the previous step, the store address is aligned to the emulated
     // width boundary.
     int64_t fullWidthStoreSize =
-        (origElements - currentSourceIndex) / numSrcElemsPerDest;
-    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+        (origElements - currentSourceIndex) / emulatedPerContainerElem;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * emulatedPerContainerElem;
     if (fullWidthStoreSize > 0) {
       auto fullWidthStorePart = staticallyExtractSubvector(
           rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +704,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto originType = cast<VectorType>(fullWidthStorePart.getType());
       auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
       auto storeType = VectorType::get(
-          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+          {originType.getNumElements() / emulatedPerContainerElem}, memrefElemType);
       auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                         fullWidthStorePart);
       rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +726,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
                                currentSourceIndex, remainingElements, 0);
 
       // Generate back mask.
-      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
@@ -960,7 +1040,8 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,7 +1056,7 @@ struct ConvertVectorMaskedLoad final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
+        isFullyAligned
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
@@ -1001,7 +1082,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
           origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -1029,7 +1110,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                         linearizedInfo.intraDataOffset,
                                         origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -1040,7 +1121,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, result, op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }

banach-space (Contributor, Author):
@lialan, you've been contributing most of the patches in this area recently; would you have the cycles to review? Thanks!

lialan (Member) left a comment

Thanks @banach-space, I do have some opinions of my own on how to better comment the code. We can discuss the details.

@@ -480,7 +559,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>

auto origElements = valueToStore.getType().getNumElements();
bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
// Note, per-element-alignment was already verified above.
lialan (Member):

I think so far we cannot support cases that are not aligned per element? I.e. we cannot support 7-bit emulation?

banach-space (Contributor, Author):

AFAIK, all patterns require per-element alignment, yes. We should extract that condition somewhere to avoid repeating it.

On a related note, given how quickly this logic is growing, I feel that we should try to avoid emulating i3, i5 and i7 unless that's really required. I am just worried that the potential cost of supporting them would be relatively high.
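To make the per-element alignment constraint concrete, here is a hypothetical input (not taken from this patch) that the emulation patterns would reject: i7 cannot be packed bit-exactly into an i8 container, since 8 % 7 != 0.

    // Illustrative only: 7-bit elements emulated with an 8-bit container.
    func.func @store_i7(%src: vector<8xi7>, %dest: memref<16xi7>, %idx: index) {
      // Rejected with the "impossible to pack emulated elements into container
      // elements (bit-wise misalignment)" match failure shown in the diff above.
      vector.store %src, %dest[%idx] : memref<16xi7>, vector<8xi7>
      return
    }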

Comment on lines 468 to 496
//
// %init = arith.constant dense<0> : vector<4xi2>
//
// (RMW sequence for Byte 1)
// (Mask for 4 x i2 elements, i.e. a byte)
// %mask_1 = arith.constant dense<[false, false, true, true]>
// %src_slice_1 = vector.extract_strided_slice %src
// {offsets = [0], sizes = [2], strides = [1]}
// : vector<3xi2> to vector<2xi2>
// %init_with_slice_1 = vector.insert_strided_slice %src_slice_1, %init
// {offsets = [2], strides = [1]}
// : vector<2xi2> into vector<4xi2>
// %dest_byte_1 = vector.load %dest[%c1]
// %dest_byte_1_as_i2 = vector.bitcast %dest_byte_1
// : vector<1xi8> to vector<4xi2>
// %res_byte_1 = arith.select %mask_1, %init_with_slice_1, %dest_byte_1_as_i2
// %res_byte_1_as_i8 = vector.bitcast %res_byte_1
// vector.store %res_byte_1_as_i8, %dest[1]

// (RMW sequence for Byte 2)
// (Mask for 4 x i2 elements, i.e. a byte)
// %mask_2 = arith.constant dense<[true, false, false, false]>
// %src_slice_2 = vector.extract_strided_slice %src
// {offsets = [2], sizes = [1], strides = [1]}
// : vector<3xi2> to vector<1xi2>
// %init_with_slice_2 = vector.insert_strided_slice %src_slice_2, %init
// {offsets = [0], strides = [1]}
// : vector<1xi2> into vector<4xi2>
// %dest_byte_2 = vector.load %dest[%c2]
lialan (Member):

I think we can refer the reader to take a look at the corresponding test case, and we try to annotate/comment more precisely in the best case instead.

banach-space (Contributor, Author):

Thanks for the suggestion! I was wondering how to avoid this long comment and your suggestion is exactly what we should be doing! 🙏🏻

As this example is taken from "vector-emulate-narrow-type-unaligned-non-atomic.mlir", that's the test file that I've updated to help here. Please check the latest update.

Note that I've made quite a few changes:

  • Extended comments.
  • Fixed DOWNCAST vs UPCAST.
  • Renamed some variables to avoid generic names (e.g. %arg0 -> %src, %0 -> %dest).
  • Added more CHECK lines, e.g. `// CHECK-SAME: : vector<1xi8> to vector<4xi2>`, to make sure that the right casting is generated.
  • Followed the formatting style from vectorize-convolution.mlir (a rough sketch follows below). IMHO it's a very "readable" style that's particularly handy for complex tests like these.

I appreciate that these are quite intrusive changes, but since the file is meant as documentation, it felt like the right thing to do. I am happy to adapt/revert if you feel that this is too much.

Thanks for reviewing!
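For reference, a rough sketch of the CHECK-line style mentioned above. The function name, SSA captures and shapes below are illustrative assumptions, not the actual content of "vector-emulate-narrow-type-unaligned-non-atomic.mlir"; only the `: vector<1xi8> to vector<4xi2>` CHECK-SAME line is quoted from the comment above.

    // Illustrative sketch only; not the real test file.
    // CHECK-LABEL:   func.func @vector_store_i2_two_partial_stores(
    // CHECK:           %[[DEST_BYTE:.+]] = vector.load %[[DEST:.+]][%[[C1:.+]]]
    // CHECK-SAME:        : memref<3xi8>, vector<1xi8>
    // CHECK:           %[[DEST_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE]]
    // CHECK-SAME:        : vector<1xi8> to vector<4xi2>
    // CHECK:           %[[RES:.+]] = arith.select %{{.+}}, %{{.+}}, %[[DEST_AS_I2]]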

Commit added: … (nfc)
Remove comment from source file, refactor the test file
lialan (Member) commented on Feb 14, 2025

LGTM!

banach-space merged commit ad948fa into llvm:main on Feb 15, 2025
8 checks passed
banach-space deleted the andrzej/refactor_narrow_type_5 branch on February 15, 2025 at 20:16
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025