[mlir][vector][nfc] Improve comments in getCompressedMaskOp #115663

Merged

banach-space
Contributor

This PR updates and expands the high-level comment for
`getCompressedMaskOp` and renames input variables with more descriptive
names.

The current variable names are somewhat unclear (e.g., `scale`) or
derived from `memref` terminology (e.g., `intraDataOffset` from
`LinearizedMemRefInfo`). The updated names in this PR aim to better
align with the context and usage in the vector domain.

@llvmbot
Member

llvmbot commented Nov 10, 2024

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

This PR updates and expands the high-level comment for
`getCompressedMaskOp` and renames input variables with more descriptive
names.

The current variable names are somewhat unclear (e.g., `scale`) or
derived from `memref` terminology (e.g., `intraDataOffset` from
`LinearizedMemRefInfo`). The updated names in this PR aim to better
align with the context and usage in the vector domain.


Full diff: https://github.com/llvm/llvm-project/pull/115663.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+23-14)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..ef6f270b44cd62 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -36,24 +36,33 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
-/// Returns a compressed mask. The mask value is set only if any mask is present
-/// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
-/// equals to 1 (intraDataOffset strictly smaller than scale), the following
-/// mask:
+/// Returns a compressed mask. For example, when emulating `i8` with `i32` and
+/// when the number of source elements spans two `i32` elements, this method
+/// will compress `vector<8xi1>` into `vector<2xi1>`.
+///
+/// The compressed/output mask value is set iff any mask in the corresponding
+/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
+/// `numSrcElemsPerDest` equals to 2, and `numFrontPadElems` equals to 1, the
+/// following mask:
 ///
 ///   %mask = [1, 1, 0, 0, 0, 0]
 ///
-/// will first be padded with number of `intraDataOffset` zeros:
+/// will first be padded with number of `numFrontPadElems` zeros:
 ///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
 ///
 ///   %mask = [1, 1, 0, 0]
+///
+/// `numFrontPadElems` is assumed to be strictly smaller than
+/// `numSrcElemsPerDest`.
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
-                                                  int origElements, int scale,
-                                                  int intraDataOffset = 0) {
-  auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
+                                                  int numSrcElems,
+                                                  int numSrcElemsPerDest,
+                                                  int numFrontPadElems = 0) {
+  auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+                     numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
@@ -81,8 +90,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     size_t numMaskOperands = maskOperands.size();
     AffineExpr s0;
     bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + scale - 1;
-    s0 = s0.floorDiv(scale);
+    s0 = s0 + numSrcElemsPerDest - 1;
+    s0 = s0.floorDiv(numSrcElemsPerDest);
     OpFoldResult origIndex =
         getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
     OpFoldResult maskIndex =
@@ -96,18 +105,18 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
     int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = intraDataOffset / scale;
-    int64_t maskIndex = llvm::divideCeil(intraDataOffset + origIndex, scale);
+    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+    int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
 
     // TODO: we only want the mask between [startIndex, maskIndex] to be true,
     // the rest are false.
-    if (intraDataOffset != 0 && maskDimSizes.size() > 1)
+    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
       return failure();
 
     SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
     newMaskDimSizes.push_back(maskIndex);
 
-    if (intraDataOffset == 0) {
+    if (numFrontPadElems == 0) {
       newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                         newMaskDimSizes);
     } else {
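
The compression rule described in the updated comment can be sketched on plain `std::vector<bool>` values. This is only a reference model under stated assumptions: `compressMaskModel` is a hypothetical name, it handles 1-D masks only, and it does not use or reproduce any MLIR API. It pads the front with `numFrontPadElems` zeros, then sets each output bit iff any input bit in the corresponding group of `numSrcElemsPerDest` bits is set.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference model (not part of MLIR): compresses a 1-D mask
// following the rule described in the updated comment.
std::vector<bool> compressMaskModel(const std::vector<bool> &mask,
                                    int numSrcElemsPerDest,
                                    int numFrontPadElems = 0) {
  assert(numSrcElemsPerDest > 0 && "group size must be positive");
  assert(numFrontPadElems >= 0 && numFrontPadElems < numSrcElemsPerDest &&
         "front padding must be strictly smaller than the group size");

  // Step 1: pad the front with `numFrontPadElems` zeros.
  std::vector<bool> padded(numFrontPadElems, false);
  padded.insert(padded.end(), mask.begin(), mask.end());

  // Step 2: one output bit per group of `numSrcElemsPerDest` input bits
  // (ceiling division, so a partial trailing group still gets a bit).
  std::size_t numGroups =
      (padded.size() + numSrcElemsPerDest - 1) / numSrcElemsPerDest;
  std::vector<bool> compressed(numGroups, false);

  // Step 3: an output bit is set iff any input bit in its group is set.
  for (std::size_t i = 0; i < padded.size(); ++i)
    if (padded[i])
      compressed[i / numSrcElemsPerDest] = true;
  return compressed;
}
```

Under these assumptions, calling `compressMaskModel` on `[1, 1, 0, 0, 0, 0]` with `numSrcElemsPerDest = 2` and `numFrontPadElems = 1` gives `[1, 1, 0, 0]`, the same result as the example in the comment.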

github-actions bot commented Nov 10, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 0baa6a7272970257fd6f527e95eb7cb18ba3361c afc03145f6c7354d9f42c8364ad6efcb34352236 --extensions cpp -- mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e5f2a84799..4958a31799 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -75,7 +75,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
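
The size arithmetic in the diffs above can be checked in isolation. The sketch below mirrors only the two ceiling divisions (`numElements` and the `maskIndex` of the constant-mask branch); `ceilDiv`, `CompressedMaskSizes` and `computeCompressedMaskSizes` are hypothetical names introduced for illustration, and `ceilDiv` stands in for `llvm::divideCeil` so that the example has no LLVM dependency.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: ceiling division for non-negative operands. It stands
// in for llvm::divideCeil so that the sketch has no LLVM dependency.
constexpr int64_t ceilDiv(int64_t numerator, int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

// Hypothetical result type, used only for this illustration.
struct CompressedMaskSizes {
  int64_t numElements; // element count computed at the top of the function
  int64_t maskIndex;   // trailing size used in the constant-mask branch
};

CompressedMaskSizes computeCompressedMaskSizes(int64_t numSrcElems,
                                               int64_t origIndex,
                                               int64_t numSrcElemsPerDest,
                                               int64_t numFrontPadElems = 0) {
  // Same precondition as the assert shown in the clang-format diff.
  assert(numFrontPadElems < numSrcElemsPerDest &&
         "front padding must be smaller than the number of elements per dest");
  return {ceilDiv(numFrontPadElems + numSrcElems, numSrcElemsPerDest),
          ceilDiv(numFrontPadElems + origIndex, numSrcElemsPerDest)};
}
```

With `numSrcElems = 6`, `origIndex = 2`, `numSrcElemsPerDest = 2` and `numFrontPadElems = 1`, this yields `numElements = 4` and `maskIndex = 2`, which is consistent with the `[1, 1, 0, 0]` example in the function comment.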

 ///
 ///   %mask = [1, 1, 0, 0, 0, 0]
 ///
-/// will first be padded with number of `intraDataOffset` zeros:
+/// will first be padded with number of `numFrontPadElems` zeros:
Member

@lialan lialan Nov 11, 2024

This is a much-appreciated change! It makes the variables easier to follow.

Contributor

Thanks, I also like the new name better because it is more meaningful!

Comment on lines 39 to 41
/// Returns a compressed mask. For example, when emulating `i8` with `i32` and
/// when the number of source elements spans two `i32` elements, this method
/// will compress `vector<8xi1>` into `vector<2xi1>`.
Member

Suggested change
-/// Returns a compressed mask. For example, when emulating `i8` with `i32` and
-/// when the number of source elements spans two `i32` elements, this method
-/// will compress `vector<8xi1>` into `vector<2xi1>`.
+/// Returns a compressed mask for the emulated vector. For example, when
+/// emulating an eight element `i8` vector with `i32` and when the number of
+/// source elements spans two `i32` elements, this method will compress
+/// `vector<8xi1>` into `vector<2xi1>`.

Contributor Author

when emulating an eight element i8 vector with i32 and when the number of source elements spans two i32 elements, this method will compress vector<8xi1> into vector<2xi1>.

Thanks! I want to make sure that I understand your suggestion correctly and that we are on the same page :)

Below is some ASCII explaining what I had in mind.

CASE 1

In my comment, I was thinking about this example of 2 i8 elements (value 1 and value 2) occupying 2 i32 elements:

32-bit Integer 1: | 00000000 | 00000000 | 00000000 | 00001010 |
                  |          |          |          |  value 1 |

32-bit Integer 2: | 00001111 | 00000000 | 00000000 | 00000000 |
                  | value 2  |          |          |          |

In this case, the uncompressed mask would be vector<2xi1> = {1, 1} (2 x i8), and the compressed one would ... also be vector<2xi1> = {1, 1} (2 x i32).

CASE 2
Here's a similar example, but the i8 values are distributed differently:

32-bit Integer 1: | 00000000 | 00000000 | 00001111 | 00001010 |
                  |          |          | value 2  | value 1  |

In this case, the uncompressed mask would be vector<2xi1> = {1, 1} (2 x i8), and the compressed one would be vector<1xi1> = {1} (1 x i32).

QUESTION 1

Is the above consistent with how you understand all of this?

QUESTION 2
In your suggestion you mentioned an eight element i8 vector. Are you proposing to build the comment around vector<8xi8>, as opposed to vector<2xi8> as I did? I don't mind, I mostly wanted to keep things simple and straightforward 😅

Member

@banach-space

A1: As far as I understand, the Case 1 and Case 2 examples are both correct.
A2: I think it is easier to understand with 8xi8, since it can be bitcast to 2xi32, which is consistent with the idea.
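
A host-side analogy for the 8xi8 to 2xi32 bitcast mentioned in A2, offered as a sketch rather than as MLIR semantics: eight bytes are reinterpreted as two 32-bit words, so a mask over eight byte lanes compresses to a mask over two word lanes. `packBytes` and `kBytesPerWord` are hypothetical names, and the position of each byte inside a word depends on host endianness.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

// Number of i8 lanes that fit in one i32 lane (4 on the usual platforms).
constexpr int kBytesPerWord = sizeof(std::uint32_t) / sizeof(std::uint8_t);

// Hypothetical helper: reinterprets eight bytes as two 32-bit words, in the
// spirit of bitcasting vector<8xi8> to vector<2xi32>. Which byte lands in
// which bit position of a word depends on host endianness.
std::array<std::uint32_t, 2> packBytes(const std::array<std::uint8_t, 8> &bytes) {
  static_assert(sizeof(std::array<std::uint32_t, 2>) ==
                    sizeof(std::array<std::uint8_t, 8>),
                "source and destination must have the same size");
  std::array<std::uint32_t, 2> words{};
  std::memcpy(words.data(), bytes.data(), sizeof(words));
  return words;
}
```

Bytes 0 to 3 end up in the first word and bytes 4 to 7 in the second, so each compressed mask bit covers `kBytesPerWord` consecutive source lanes.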

Contributor

I don't think we support arbitrary distribution formats. Elements are either packed or one per byte.

Contributor Author

Thank you for the discussion! @lialan, I've incorporated your suggestion (with additional tweaks from me).

Contributor

@hanhanW hanhanW left a comment

I'll leave the approval to @lialan who knows the details better.


@banach-space banach-space force-pushed the andrzej/emulate_narrow_type_update_3 branch from adbd169 to b79654d Compare November 12, 2024 15:56
@banach-space banach-space force-pushed the andrzej/emulate_narrow_type_update_3 branch from b79654d to afc0314 Compare November 13, 2024 16:40
@banach-space banach-space merged commit 7a31f3c into llvm:main Nov 13, 2024
5 of 7 checks passed
@banach-space banach-space deleted the andrzej/emulate_narrow_type_update_3 branch November 14, 2024 08:54
5 participants