Skip to content

Commit 7a31f3c

Browse files
authored
[mlir][vector][nfc] Improve comments in getCompressedMaskOp (#115663)
1 parent 2ca25ab commit 7a31f3c

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,40 @@ using namespace mlir;
4545
#define DBGSNL() (llvm::dbgs() << "\n")
4646
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
4747

48-
/// Returns a compressed mask. The mask value is set only if any mask is present
49-
/// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
50-
/// equals to 1 (intraDataOffset strictly smaller than scale), the following
51-
/// mask:
48+
/// Returns a compressed mask for the emulated vector. For example, when
49+
/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
50+
/// elements span two dest elements), this method compresses `vector<8xi1>`
51+
/// into `vector<2xi1>`.
52+
///
53+
/// The compressed/output mask value is set iff any mask in the corresponding
54+
/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
55+
/// `numSrcElemsPerDest` equals 2 and `numFrontPadElems` equals 1, the
56+
/// following mask:
5257
///
5358
/// %mask = [1, 1, 0, 0, 0, 0]
5459
///
55-
/// will first be padded in the front with number of `intraDataOffset` zeros,
56-
/// and pad zeros in the back to make the number of elements a multiple of
57-
/// `scale` (just to make it easier to compute). The new mask will be:
60+
/// will first be padded in the front with `numFrontPadElems` zeros, and zeros
61+
/// will be added in the back to make the number of elements a multiple of
62+
/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
63+
///
5864
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
5965
///
6066
/// then it will return the following new compressed mask:
6167
///
6268
/// %mask = [1, 1, 0, 0]
69+
///
70+
/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
71+
/// `numSrcElemsPerDest`.
6372
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
6473
Location loc, Value mask,
65-
int origElements, int scale,
66-
int intraDataOffset = 0) {
67-
assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
68-
auto numElements = llvm::divideCeil(intraDataOffset + origElements, scale);
74+
int numSrcElems,
75+
int numSrcElemsPerDest,
76+
int numFrontPadElems = 0) {
77+
78+
assert(numFrontPadElems < numSrcElemsPerDest && "numFrontPadElems must be less than numSrcElemsPerDest");
79+
80+
auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
81+
numSrcElemsPerDest;
6982

7083
Operation *maskOp = mask.getDefiningOp();
7184
SmallVector<vector::ExtractOp, 2> extractOps;
@@ -93,8 +106,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
93106
size_t numMaskOperands = maskOperands.size();
94107
AffineExpr s0;
95108
bindSymbols(rewriter.getContext(), s0);
96-
s0 = s0 + scale - 1;
97-
s0 = s0.floorDiv(scale);
109+
s0 = s0 + numSrcElemsPerDest - 1;
110+
s0 = s0.floorDiv(numSrcElemsPerDest);
98111
OpFoldResult origIndex =
99112
getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
100113
OpFoldResult maskIndex =
@@ -108,18 +121,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
108121
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
109122
size_t numMaskOperands = maskDimSizes.size();
110123
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
111-
int64_t startIndex = intraDataOffset / scale;
112-
int64_t maskIndex = llvm::divideCeil(intraDataOffset + origIndex, scale);
124+
int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
125+
int64_t maskIndex =
126+
llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
113127

114128
// TODO: we only want the mask between [startIndex, maskIndex] to be true,
115129
// the rest are false.
116-
if (intraDataOffset != 0 && maskDimSizes.size() > 1)
130+
if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
117131
return failure();
118132

119133
SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
120134
newMaskDimSizes.push_back(maskIndex);
121135

122-
if (intraDataOffset == 0) {
136+
if (numFrontPadElems == 0) {
123137
newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
124138
newMaskDimSizes);
125139
} else {

0 commit comments

Comments
 (0)