@@ -45,27 +45,40 @@ using namespace mlir;
45
45
#define DBGSNL () (llvm::dbgs() << " \n " )
46
46
#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
47
47
48
- // / Returns a compressed mask. The mask value is set only if any mask is present
49
- // / in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
50
- // / equals to 1 (intraDataOffset strictly smaller than scale), the following
51
- // / mask:
48
+ // / Returns a compressed mask for the emulated vector. For example, when
49
+ // / emulating an eight-element `i8` vector with `i32` (i.e. when the source
50
+ // / elements span two dest elements), this method compresses `vector<8xi1>`
51
+ // / into `vector<2xi1>`.
52
+ // /
53
+ // / The compressed/output mask value is set iff any mask in the corresponding
54
+ // / `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
55
+ // / `numSrcElemsPerDest` equals to 2, and `numFrontPadElems` equals to 1, the
56
+ // / following mask:
52
57
// /
53
58
// / %mask = [1, 1, 0, 0, 0, 0]
54
59
// /
55
- // / will first be padded in the front with number of `intraDataOffset` zeros,
56
- // / and pad zeros in the back to make the number of elements a multiple of
57
- // / `scale` (just to make it easier to compute). The new mask will be:
60
+ // / will first be padded in the front with `numFrontPadElems` zeros, and zeros
61
+ // / will be added in the back to make the number of elements a multiple of
62
+ // / `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
63
+ // /
58
64
// / %mask = [0, 1, 1, 0, 0, 0, 0, 0]
59
65
// /
60
66
// / then it will return the following new compressed mask:
61
67
// /
62
68
// / %mask = [1, 1, 0, 0]
69
+ // /
70
+ // / NOTE: `numFrontPadElems` is assumed to be strictly smaller than
71
+ // / `numSrcElemsPerDest`.
63
72
static FailureOr<Operation *> getCompressedMaskOp (OpBuilder &rewriter,
64
73
Location loc, Value mask,
65
- int origElements, int scale,
66
- int intraDataOffset = 0 ) {
67
- assert (intraDataOffset < scale && " intraDataOffset must be less than scale" );
68
- auto numElements = llvm::divideCeil (intraDataOffset + origElements, scale);
74
+ int numSrcElems,
75
+ int numSrcElemsPerDest,
76
+ int numFrontPadElems = 0 ) {
77
+
78
+ assert (numFrontPadElems < numSrcElemsPerDest && " intraDataOffset must be less than scale" );
79
+
80
+ auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1 ) /
81
+ numSrcElemsPerDest;
69
82
70
83
Operation *maskOp = mask.getDefiningOp ();
71
84
SmallVector<vector::ExtractOp, 2 > extractOps;
@@ -93,8 +106,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
93
106
size_t numMaskOperands = maskOperands.size ();
94
107
AffineExpr s0;
95
108
bindSymbols (rewriter.getContext (), s0);
96
- s0 = s0 + scale - 1 ;
97
- s0 = s0.floorDiv (scale );
109
+ s0 = s0 + numSrcElemsPerDest - 1 ;
110
+ s0 = s0.floorDiv (numSrcElemsPerDest );
98
111
OpFoldResult origIndex =
99
112
getAsOpFoldResult (maskOperands[numMaskOperands - 1 ]);
100
113
OpFoldResult maskIndex =
@@ -108,18 +121,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
108
121
ArrayRef<int64_t > maskDimSizes = constantMaskOp.getMaskDimSizes ();
109
122
size_t numMaskOperands = maskDimSizes.size ();
110
123
int64_t origIndex = maskDimSizes[numMaskOperands - 1 ];
111
- int64_t startIndex = intraDataOffset / scale;
112
- int64_t maskIndex = llvm::divideCeil (intraDataOffset + origIndex, scale);
124
+ int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
125
+ int64_t maskIndex =
126
+ llvm::divideCeil (numFrontPadElems + origIndex, numSrcElemsPerDest);
113
127
114
128
// TODO: we only want the mask between [startIndex, maskIndex] to be true,
115
129
// the rest are false.
116
- if (intraDataOffset != 0 && maskDimSizes.size () > 1 )
130
+ if (numFrontPadElems != 0 && maskDimSizes.size () > 1 )
117
131
return failure ();
118
132
119
133
SmallVector<int64_t > newMaskDimSizes (maskDimSizes.drop_back ());
120
134
newMaskDimSizes.push_back (maskIndex);
121
135
122
- if (intraDataOffset == 0 ) {
136
+ if (numFrontPadElems == 0 ) {
123
137
newMask = rewriter.create <vector::ConstantMaskOp>(loc, newMaskType,
124
138
newMaskDimSizes);
125
139
} else {
0 commit comments