[mlir][vector][nfc] Improve comments in getCompressedMaskOp
#115663
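@llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
Changes
This PR updates and expands the high-level comment for `getCompressedMaskOp` and renames input variables with more descriptive names. The current variable names are somewhat unclear (e.g., `scale`) or derived from `memref` terminology (e.g., `intraDataOffset` from `LinearizedMemRefInfo`). The updated names in this PR aim to better align with the context and usage in the vector domain.
Full diff: https://github.com/llvm/llvm-project/pull/115663.diff
1 Files Affected: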
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..ef6f270b44cd62 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -36,24 +36,33 @@ using namespace mlir;
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
-/// Returns a compressed mask. The mask value is set only if any mask is present
-/// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
-/// equals to 1 (intraDataOffset strictly smaller than scale), the following
-/// mask:
+/// Returns a compressed mask. For example, when emulating `i8` with `i32` and
+/// when the number of source elements spans two `i32` elements, this method
+/// will compress `vector<8xi1>` into `vector<2xi1>`.
+///
+/// The compressed/output mask value is set iff any mask in the corresponding
+/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
+/// `numSrcElemsPerDest` equals to 2, and `numFrontPadElems` equals to 1, the
+/// following mask:
///
/// %mask = [1, 1, 0, 0, 0, 0]
///
-/// will first be padded with number of `intraDataOffset` zeros:
+/// will first be padded with number of `numFrontPadElems` zeros:
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
///
/// then it will return the following new compressed mask:
///
/// %mask = [1, 1, 0, 0]
+///
+/// `numFrontPadElems` is assumed to be strictly smaller than
+/// `numSrcElemsPerDest`.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
- int origElements, int scale,
- int intraDataOffset = 0) {
- auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
+ int numSrcElems,
+ int numSrcElemsPerDest,
+ int numFrontPadElems = 0) {
+ auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+ numSrcElemsPerDest;
Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
@@ -81,8 +90,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
size_t numMaskOperands = maskOperands.size();
AffineExpr s0;
bindSymbols(rewriter.getContext(), s0);
- s0 = s0 + scale - 1;
- s0 = s0.floorDiv(scale);
+ s0 = s0 + numSrcElemsPerDest - 1;
+ s0 = s0.floorDiv(numSrcElemsPerDest);
OpFoldResult origIndex =
getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
OpFoldResult maskIndex =
@@ -96,18 +105,18 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
size_t numMaskOperands = maskDimSizes.size();
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
- int64_t startIndex = intraDataOffset / scale;
- int64_t maskIndex = llvm::divideCeil(intraDataOffset + origIndex, scale);
+ int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+ int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
// TODO: we only want the mask between [startIndex, maskIndex] to be true,
// the rest are false.
- if (intraDataOffset != 0 && maskDimSizes.size() > 1)
+ if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
return failure();
SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndex);
- if (intraDataOffset == 0) {
+ if (numFrontPadElems == 0) {
newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
newMaskDimSizes);
} else {
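The compression rule in the updated comment can be modelled outside MLIR. The following is a minimal sketch on plain `std::vector<bool>` values, not the code from this PR: `compressMask` is a hypothetical name, and only the parameter names `numSrcElemsPerDest` and `numFrontPadElems` are taken from the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical standalone model of the mask compression described in the
// comment above. It works on plain bools rather than MLIR values, and
// compressMask is an illustrative name, not an MLIR function.
std::vector<bool> compressMask(const std::vector<bool> &mask,
                               int numSrcElemsPerDest,
                               int numFrontPadElems = 0) {
  assert(numSrcElemsPerDest > 0 && "group size must be positive");
  assert(numFrontPadElems >= 0 && numFrontPadElems < numSrcElemsPerDest &&
         "front padding must be strictly smaller than the group size");
  const std::size_t group = static_cast<std::size_t>(numSrcElemsPerDest);

  // Pad the front with zeros so mask bits line up with container elements.
  std::vector<bool> padded(static_cast<std::size_t>(numFrontPadElems), false);
  padded.insert(padded.end(), mask.begin(), mask.end());

  // Same ceiling division as `numElements` in the diff.
  const std::size_t numElements = (padded.size() + group - 1) / group;

  // An output bit is set iff any input bit in its group is set.
  std::vector<bool> compressed(numElements, false);
  for (std::size_t i = 0; i < padded.size(); ++i)
    if (padded[i])
      compressed[i / group] = true;
  return compressed;
}
```

With the mask `[1, 1, 0, 0, 0, 0]`, `numSrcElemsPerDest = 2` and `numFrontPadElems = 1`, this model produces `[1, 1, 0, 0]`, matching the example in the comment.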
You can test this locally with the following command:
git-clang-format --diff 0baa6a7272970257fd6f527e95eb7cb18ba3361c afc03145f6c7354d9f42c8364ad6efcb34352236 --extensions cpp -- mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e5f2a84799..4958a31799 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -75,7 +75,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
int numSrcElemsPerDest,
int numFrontPadElems = 0) {
- assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+ assert(numFrontPadElems < numSrcElemsPerDest &&
+ "intraDataOffset must be less than scale");
auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
numSrcElemsPerDest;
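The `numElements` expression in this diff is a ceiling division guarded by the assertion. As a rough sketch only (the helper name `numCompressedElems` is hypothetical and not part of the PR), the arithmetic can be isolated like this:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the `numElements` expression in the diff:
// the number of container elements needed to hold the front padding plus
// all source elements, i.e. ceil((pad + n) / perDest).
int64_t numCompressedElems(int64_t numSrcElems, int64_t numSrcElemsPerDest,
                           int64_t numFrontPadElems = 0) {
  assert(numSrcElemsPerDest > 0 && "group size must be positive");
  // Same precondition as the assert in the diff: if the padding were as
  // large as a group, whole leading container elements would hold padding
  // only.
  assert(numFrontPadElems >= 0 && numFrontPadElems < numSrcElemsPerDest &&
         "front padding must be strictly smaller than the group size");
  return (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
         numSrcElemsPerDest;
}
```

For the example from the function comment (6 source elements, 2 per destination element, 1 front padding element) this gives 4. For 8 `i8` elements in `i32` containers with no padding it gives 2.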
5967f2d to adbd169
///
/// %mask = [1, 1, 0, 0, 0, 0]
///
-/// will first be padded with number of `intraDataOffset` zeros:
+/// will first be padded with number of `numFrontPadElems` zeros:
This is a much appreciated change! It makes the variable easier to follow.
Thanks, I also like the new name better because it is more meaningful!
+/// Returns a compressed mask. For example, when emulating `i8` with `i32` and
+/// when the number of source elements spans two `i32` elements, this method
+/// will compress `vector<8xi1>` into `vector<2xi1>`.
-/// Returns a compressed mask. For example, when emulating `i8` with `i32` and
-/// when the number of source elements spans two `i32` elements, this method
-/// will compress `vector<8xi1>` into `vector<2xi1>`.
+/// Returns a compressed mask for the emulated vector. For example, when
+/// emulating an eight element `i8` vector with `i32` and when the number of
+/// source elements spans two `i32` elements, this method will compress
+/// `vector<8xi1>` into `vector<2xi1>`.
> when emulating an eight element `i8` vector with `i32` and when the number of source elements spans two `i32` elements, this method will compress `vector<8xi1>` into `vector<2xi1>`.
Thanks! I want to make sure that I understand your suggestion correctly and that we are on the same page :)
Below is some ASCII explaining what I had in mind.
CASE 1
In my comment, I was thinking about this example of 2 `i8` elements (`value 1` and `value 2`) occupying 2 `i32` elements:
32-bit Integer 1: | 00000000 | 00000000 | 00000000 | 00001010 |
                  |          |          |          | value 1  |
32-bit Integer 2: | 00001111 | 00000000 | 00000000 | 00000000 |
                  | value 2  |          |          |          |
In this case, the uncompressed mask would be `vector<2xi1> = {1, 1}` (2 x `i8`), and the compressed one would ... also be `vector<2xi1> = {1, 1}` (2 x `i32`).
CASE 2
Here's a similar example, but the `i8` values are distributed differently:
32-bit Integer 1: | 00000000 | 00000000 | 00001111 | 00001010 |
                  |          |          | value 2  | value 1  |
In this case, the uncompressed mask would be `vector<2xi1> = {1, 1}` (2 x `i8`), and the compressed one would be `vector<1xi1> = {1}` (1 x `i32`).
QUESTION 1
Is the above consistent with how you understand all of this?
QUESTION 2
In your suggestion you mentioned "an eight element `i8` vector" - are you proposing to build a comment around `vector<8xi8>`, as opposed to `vector<2xi8>` as I did? I don't mind, I mostly wanted to keep things simple and straightforward 😅
A1: As far as I understand, the Case 1 and Case 2 examples are both correct.
A2: I think it is easier to understand because `8xi8` can bitcast to `2xi32`, which is consistent with the idea.
I don't think we support arbitrary distribution formats. Elements are either packed or one per byte.
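For the packed case mentioned here, the mapping from narrow elements to container elements is plain integer division. The sketch below is only an illustration under that assumption: `elemsPerContainer` and `containerIndex` are hypothetical names, and the container width is assumed to be a multiple of the narrow width.

```cpp
#include <cassert>

// Hypothetical helpers; the names are illustrative and not MLIR API.
// Number of narrow elements packed into one container element, assuming
// destBits is a multiple of srcBits (e.g. 32 / 8 == 4 for i8 in i32).
constexpr int elemsPerContainer(int srcBits, int destBits) {
  return destBits / srcBits;
}

// Index of the container element that holds packed source element srcIndex.
constexpr int containerIndex(int srcIndex, int srcBits, int destBits) {
  return srcIndex / elemsPerContainer(srcBits, destBits);
}

// Two packed i8 values share the first i32, so their mask compresses from
// vector<2xi1> to vector<1xi1>.
static_assert(containerIndex(1, 8, 32) == 0, "second i8 is in the first i32");
// Eight packed i8 values span two i32 elements, so vector<8xi1> compresses
// to vector<2xi1>.
static_assert(containerIndex(7, 8, 32) == 1, "eighth i8 is in the second i32");
```

Under this packed model the layout drawn in CASE 1 above, with two `i8` values placed in different `i32` elements, would only arise through front padding rather than through an arbitrary distribution.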
Thank you for the discussion! @lialan, I've incorporated your suggestion (with additional tweaks from me).
I'll leave the approval to @lialan who knows the details better.
adbd169 to b79654d
This PR updates and expands the high-level comment for `getCompressedMaskOp` and renames input variables with more descriptive names. The current variable names are somewhat unclear (e.g., `scale`) or derived from `memref` terminology (e.g., `intraDataOffset` from `LinearizedMemRefInfo`). The updated names in this PR aim to better align with the context and usage in the vector domain.
Incorporate PR suggestions
…askOp` Final tweaks
b79654d to afc0314