@@ -75,83 +75,134 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                    int numSrcElemsPerDest,
                    int numFrontPadElems = 0) {
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "numFrontPadElems must be less than numSrcElemsPerDest");

-  auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
-                     numSrcElemsPerDest;
+  auto numDestElems =
+      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+      numSrcElemsPerDest;
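Note: the ceiling division above computes how many destination elements (e.g. bytes) the emulated access touches; front padding shifts the window right, and a partial trailing group still occupies a whole destination element. A standalone sketch of the arithmetic (not part of the patch; the i4-on-i8 numbers are illustrative):

#include <cassert>

// Same ceiling-division formula as `numDestElems` in the hunk above.
static int computeNumDestElems(int numFrontPadElems, int numSrcElems,
                               int numSrcElemsPerDest) {
  return (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
         numSrcElemsPerDest;
}

int main() {
  // Emulating i4 on i8 storage: 2 source elements per destination byte.
  assert(computeNumDestElems(/*pad=*/1, /*src=*/5, /*perDest=*/2) == 3);
  assert(computeNumDestElems(/*pad=*/0, /*src=*/8, /*perDest=*/2) == 4);
  return 0;
}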

  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
+  // TODO: add support for `vector.splat`.
  // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    }
  }
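This loop peels any `vector.extract` ops off the mask and records them, so `maskOp` ends up at the op that actually materializes the mask. For example (an illustration consistent with the code, not taken from the patch): if the mask is `%mask = vector.extract %big[1] : vector<8xi1> from vector<2x8xi1>` with `%big` produced by `vector.create_mask`, the walk records the extract, the mask is compressed for the full `2x8` shape (to `2x4` when `numSrcElemsPerDest` is 2), and the loop at the end of the function re-applies the extract to the compressed mask.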
-  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+
+  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+          maskOp))
    return failure();

  // Computing the "compressed" mask. All the emulation logic (i.e. computing
  // new mask index) only happens on the last dimension of the vectors.
-  Operation *newMask = nullptr;
-  SmallVector<int64_t> shape(
+  SmallVector<int64_t> maskShape(
      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
-  shape.back() = numElements;
-  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-  if (createMaskOp) {
-    OperandRange maskOperands = createMaskOp.getOperands();
-    size_t numMaskOperands = maskOperands.size();
-    AffineExpr s0;
-    bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + numSrcElemsPerDest - 1;
-    s0 = s0.floorDiv(numSrcElemsPerDest);
-    OpFoldResult origIndex =
-        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-    OpFoldResult maskIndex =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-    SmallVector<Value> newMaskOperands(maskOperands.drop_back());
-    newMaskOperands.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                    newMaskOperands);
-  } else if (constantMaskOp) {
-    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-    size_t numMaskOperands = maskDimSizes.size();
-    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-    int64_t maskIndex =
-        llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
-
-    // TODO: we only want the mask between [startIndex, maskIndex] to be true,
-    // the rest are false.
-    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-      return failure();
-
-    SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-    newMaskDimSizes.push_back(maskIndex);
-
-    if (numFrontPadElems == 0) {
-      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                        newMaskDimSizes);
-    } else {
-      SmallVector<bool> newMaskValues;
-      for (int64_t i = 0; i < numElements; ++i)
-        newMaskValues.push_back(i >= startIndex && i < maskIndex);
-      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
-      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
-    }
-  }
+  maskShape.back() = numDestElems;
+  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
+  std::optional<Operation *> newMask =
+      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
+          .Case<vector::CreateMaskOp>(
+              [&](auto createMaskOp) -> std::optional<Operation *> {
+                OperandRange maskOperands = createMaskOp.getOperands();
+                size_t numMaskOperands = maskOperands.size();
+                AffineExpr s0;
+                bindSymbols(rewriter.getContext(), s0);
+                s0 = s0 + numSrcElemsPerDest - 1;
+                s0 = s0.floorDiv(numSrcElemsPerDest);
+                OpFoldResult origIndex =
+                    getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
+                    rewriter, loc, s0, origIndex);
+                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+                newMaskOperands.push_back(
+                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                             newMaskOperands);
+              })
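The affine map built in this case is `(s0 + numSrcElemsPerDest - 1) floordiv numSrcElemsPerDest`, i.e. a ceiling division of the original `create_mask` bound. A minimal standalone check of that index math (my example, assuming four i2 elements per byte, not code from the patch):

#include <cassert>

int main() {
  int numSrcElemsPerDest = 4; // e.g. four i2 values per i8 byte
  int origIndex = 5;          // trailing operand of the original create_mask
  // What the composed affine apply folds to for a constant operand:
  int maskIndex = (origIndex + numSrcElemsPerDest - 1) / numSrcElemsPerDest;
  assert(maskIndex == 2); // 5 active i2 elements span 2 bytes
  return 0;
}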
+          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+                                            -> std::optional<Operation *> {
+            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+            size_t numMaskOperands = maskDimSizes.size();
+            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
+                                                 numSrcElemsPerDest);
+
+            // TODO: we only want the mask between [startIndex, maskIndex]
+            // to be true, the rest are false.
+            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
+              return std::nullopt;
+
+            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+            newMaskDimSizes.push_back(maskIndex);
+
+            if (numFrontPadElems == 0)
+              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                             newMaskDimSizes);
+
+            SmallVector<bool> newMaskValues;
+            for (int64_t i = 0; i < numDestElems; ++i)
+              newMaskValues.push_back(i >= startIndex && i < maskIndex);
+            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
+            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
+                                                      newMask);
+          })
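Worked numbers for the dense fallback in this case (mine, not from the patch): with `numSrcElemsPerDest = 2`, `numFrontPadElems = 1`, and an original `vector.constant_mask` whose last dimension size is 3, `startIndex` is 0 and `maskIndex = divideCeil(1 + 3, 2) = 2`, so with `numDestElems = 3` the result is `[1, 1, 0]`. A standalone sketch:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  int numSrcElemsPerDest = 2, numFrontPadElems = 1;
  int64_t origIndex = 3;    // leading true elements in the original mask
  int64_t numDestElems = 3; // e.g. 5 source elements: ceil((1 + 5) / 2)
  int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
  int64_t maskIndex = (numFrontPadElems + origIndex + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest; // llvm::divideCeil equivalent
  std::vector<bool> newMaskValues;
  for (int64_t i = 0; i < numDestElems; ++i)
    newMaskValues.push_back(i >= startIndex && i < maskIndex);
  assert((newMaskValues == std::vector<bool>{true, true, false}));
  return 0;
}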
+          .Case<arith::ConstantOp>([&](auto constantOp)
+                                       -> std::optional<Operation *> {
+            // TODO: Support multiple dimensions.
+            if (maskShape.size() != 1)
+              return std::nullopt;
+            // Rearrange the original mask values to cover the whole potential
+            // loading region. For example, in the case of using byte-size for
+            // emulation, given the following mask:
+            //
+            //   %mask = [0, 1, 0, 1, 0, 0]
+            //
+            // With a front offset of 1, the mask will be padded with 0s in
+            // the front and back so that:
+            // 1. It is aligned with the effective loading bits
+            // 2. Its length is a multiple of `numSrcElemsPerDest` (and the
+            //    total coverage size is a multiple of bytes). The new mask
+            //    will look like this before compressing:
+            //
+            //   %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
+            auto originalMask =
+                cast<DenseIntElementsAttr>(constantOp.getValue());
+            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
+            paddedMaskValues.append(originalMask.template value_begin<bool>(),
+                                    originalMask.template value_end<bool>());
+            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
+
+            // Compressing by combining every `numSrcElemsPerDest` elements:
+            SmallVector<bool> compressedMaskValues;
+            for (size_t i = 0; i < paddedMaskValues.size();
+                 i += numSrcElemsPerDest) {
+              bool combinedValue = false;
+              for (int j = 0; j < numSrcElemsPerDest; ++j) {
+                combinedValue |= paddedMaskValues[i + j];
+              }
+              compressedMaskValues.push_back(combinedValue);
+            }
+            return rewriter.create<arith::ConstantOp>(
+                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
+          });
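Completing the comment's example (my numbers, assuming `numSrcElemsPerDest = 4` and therefore `numDestElems = 2`): the padded mask `[0, 0, 1, 0, 1, 0, 0, 0]` OR-reduces in groups of four to `[1, 1]`. A standalone sketch of the pad-and-compress step, not the patch's code:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  int numSrcElemsPerDest = 4, numFrontPadElems = 1, numDestElems = 2;
  std::vector<bool> originalMask = {false, true, false, true, false, false};

  // Pad with `false` in front (alignment) and back (whole destination bytes).
  std::vector<bool> padded(numFrontPadElems, false);
  padded.insert(padded.end(), originalMask.begin(), originalMask.end());
  padded.resize(numDestElems * numSrcElemsPerDest, false);

  // OR-combine every `numSrcElemsPerDest` elements into one mask bit.
  std::vector<bool> compressed;
  for (size_t i = 0; i < padded.size(); i += numSrcElemsPerDest) {
    bool combined = false;
    for (int j = 0; j < numSrcElemsPerDest; ++j)
      combined = combined || padded[i + j];
    compressed.push_back(combined);
  }
  assert((compressed == std::vector<bool>{true, true}));
  return 0;
}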
+
+  if (!newMask)
+    return failure();

  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
-        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

-  return newMask;
+  return *newMask;
}
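End-to-end (my illustration of the logic above, assuming `numSrcElems = 8`, `numSrcElemsPerDest = 2`, and no front padding): a mask `vector.create_mask %c5 : vector<8xi1>` compresses to `vector.create_mask %c3 : vector<4xi1>`; the last mask dimension shrinks from 8 to ceil(8/2) = 4 and the bound from 5 to ceil(5/2) = 3, after which any recorded `vector.extract` ops are replayed on the result.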

/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
@@ -185,12 +236,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
/// `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
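This second hunk swaps the `(void)` casts for C++17 `[[maybe_unused]]`: the two locals are only read inside `assert`, which compiles away under NDEBUG, and the attribute silences the resulting unused-variable warning at the declaration itself. A minimal illustration of the idiom (mine, not from the patch):

#include <cassert>

int main() {
  // Only used in the assert below; without the attribute, -Wunused-variable
  // fires in NDEBUG builds where assert() expands to nothing.
  [[maybe_unused]] int rank = 1;
  assert(rank == 1 && "expected rank-1 vector");
  return 0;
}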