@@ -32,6 +32,79 @@ using namespace mlir;
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

+/// Returns a compressed mask. The mask value is set only if any mask is present
+/// in the scale range. E.g., if `scale` equals to 2, the following mask:
+///
+///   %mask = [1, 1, 1, 0, 0, 0]
+///
+/// will return the following new compressed mask:
+///
+///   %mask = [1, 1, 0]
+static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
+                                                  Location loc, Value mask,
+                                                  int origElements, int scale) {
+  auto numElements = (origElements + scale - 1) / scale;
+
+  Operation *maskOp = mask.getDefiningOp();
+  SmallVector<vector::ExtractOp, 2> extractOps;
+  // Finding the mask creation operation.
+  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
+      maskOp = extractOp.getVector().getDefiningOp();
+      extractOps.push_back(extractOp);
+    }
+  }
+  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
+  if (!createMaskOp && !constantMaskOp)
+    return failure();
+
+  // Computing the "compressed" mask. All the emulation logic (i.e. computing
+  // new mask index) only happens on the last dimension of the vectors.
+  Operation *newMask = nullptr;
+  SmallVector<int64_t> shape(
+      maskOp->getResultTypes()[0].cast<VectorType>().getShape());
+  shape.back() = numElements;
+  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
+  if (createMaskOp) {
+    OperandRange maskOperands = createMaskOp.getOperands();
+    size_t numMaskOperands = maskOperands.size();
+    AffineExpr s0;
+    bindSymbols(rewriter.getContext(), s0);
+    s0 = s0 + scale - 1;
+    s0 = s0.floorDiv(scale);
+    OpFoldResult origIndex =
+        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+    OpFoldResult maskIndex =
+        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+    SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+    newMaskOperands.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                    newMaskOperands);
+  } else if (constantMaskOp) {
+    ArrayRef<Attribute> maskDimSizes =
+        constantMaskOp.getMaskDimSizes().getValue();
+    size_t numMaskOperands = maskDimSizes.size();
+    auto origIndex =
+        cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+    IntegerAttr maskIndexAttr =
+        rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+    SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
+    newMaskDimSizes.push_back(maskIndexAttr);
+    newMask = rewriter.create<vector::ConstantMaskOp>(
+        loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+  }
+
+  while (!extractOps.empty()) {
+    newMask = rewriter.create<vector::ExtractOp>(
+        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+    extractOps.pop_back();
+  }
+
+  return newMask;
+}
+
namespace {

//===----------------------------------------------------------------------===//
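The compressed bound computed by getCompressedMaskOp is a ceiling division: for a create_mask or constant_mask bound b and a packing factor scale, the new bound is (b + scale - 1) / scale, the same expression used for numElements. A minimal standalone C++ sketch of that index math (illustrative only, not part of this patch):

// Standalone illustration of the compressed-mask bound math used in
// getCompressedMaskOp above (not part of the patch).
#include <cstdio>
#include <initializer_list>

int main() {
  const int scale = 2; // dstBits / srcBits, e.g. packing i4 into i8
  for (int bound : {0, 1, 2, 3, 6}) {
    int compressed = (bound + scale - 1) / scale;
    // bound = 3 over a 6-element mask [1, 1, 1, 0, 0, 0] yields 2, i.e. the
    // 3-element compressed mask [1, 1, 0] from the doc comment above.
    std::printf("original bound %d -> compressed bound %d\n", bound,
                compressed);
  }
  return 0;
}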
@@ -99,6 +172,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  }
};

+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedStore final
+    : OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    int scale = dstBits / srcBits;
+    int origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+    OpFoldResult linearizedIndicesOfr;
+    std::tie(std::ignore, linearizedIndicesOfr) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+    Value linearizedIndices =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
+
+    // Load the whole data and use arith.select to handle the corner cases.
+    // E.g., given these input values:
+    //
+    //   %mask = [1, 1, 1, 0, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+    //   %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //
+    // we'll have
+    //
+    //    expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+    //
+    //    %new_mask = [1, 1, 0]
+    //    %maskedload = [0x12, 0x34, 0x0]
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
+    //    %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
+    //    %packed_data = [0x78, 0x94, 0x00]
+    //
+    // Using the new mask to store %packed_data results in expected output.
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    if (failed(newMask))
+      return failure();
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newType = VectorType::get(numElements, newElementType);
+    auto passThru = rewriter.create<arith::ConstantOp>(
+        loc, newType, rewriter.getZeroAttr(newType));
+
+    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
+        loc, newType, adaptor.getBase(), linearizedIndices,
+        newMask.value()->getResult(0), passThru);
+
+    Value valueToStore = rewriter.create<vector::BitCastOp>(
+        loc, op.getValueToStore().getType(), newLoad);
+    valueToStore = rewriter.create<arith::SelectOp>(
+        loc, op.getMask(), op.getValueToStore(), valueToStore);
+    valueToStore =
+        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
+
+    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
+        valueToStore);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//
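Because several narrow elements share one byte after emulation, the masked store is widened into a read-modify-write: load the packed bytes under the compressed mask, select the stored narrow values under the original mask, repack, and store under the compressed mask again. A standalone C++ model of those steps on plain arrays (illustrative only, not part of the patch; the nibble ordering is an assumption made for the example):

// Scalar model of the masked-store emulation above, reproducing the worked
// example in the pattern's comment. The high/low nibble layout is assumed.
#include <cstdint>
#include <cstdio>

int main() {
  const int scale = 2;                               // i4 packed into i8
  uint8_t mem[3] = {0x12, 0x34, 0x56};               // holds i4 [1,2,3,4,5,6]
  bool mask[6] = {true, true, true, false, false, false};
  uint8_t store[6] = {0x7, 0x8, 0x9, 0xA, 0xB, 0xC}; // %value_to_store

  // Compressed mask: a byte is touched if any of its nibbles is masked on.
  bool newMask[3];
  for (int i = 0; i < 3; ++i)
    newMask[i] = mask[scale * i] || mask[scale * i + 1];

  // vector.maskedload with a zero pass-through, then "bitcast" to nibbles.
  uint8_t nibbles[6] = {};
  for (int i = 0; i < 3; ++i) {
    uint8_t byte = newMask[i] ? mem[i] : 0;
    nibbles[scale * i] = byte >> 4;      // assumed nibble order
    nibbles[scale * i + 1] = byte & 0xF;
  }

  // arith.select under the *original* mask, then repack into bytes.
  for (int i = 0; i < 6; ++i)
    if (mask[i])
      nibbles[i] = store[i];
  uint8_t packed[3];
  for (int i = 0; i < 3; ++i)
    packed[i] = uint8_t(nibbles[scale * i] << 4) | nibbles[scale * i + 1];

  // vector.maskedstore under the compressed mask.
  for (int i = 0; i < 3; ++i)
    if (newMask[i])
      mem[i] = packed[i];

  // Prints "78 94 56", i.e. the expected i4 contents [7,8,9,4,5,6].
  std::printf("%02X %02X %02X\n", unsigned(mem[0]), unsigned(mem[1]),
              unsigned(mem[2]));
  return 0;
}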
@@ -236,15 +397,13 @@ struct ConvertVectorMaskedLoad final
    // TODO: Currently, only the even number of elements loading is supported.
    // To deal with the odd number of elements, one has to extract the
    // subvector at the proper offset after bit-casting.
-
    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
-
    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
@@ -254,74 +413,21 @@ struct ConvertVectorMaskedLoad final
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

-    auto numElements = (origElements + scale - 1) / scale;
-    auto newType = VectorType::get(numElements, newElementType);
-
-    auto maskOp = op.getMask().getDefiningOp();
-    SmallVector<vector::ExtractOp, 2> extractOps;
-    // Finding the mask creation operation.
-    while (maskOp &&
-           !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
-      if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
-        maskOp = extractOp.getVector().getDefiningOp();
-        extractOps.push_back(extractOp);
-      }
-    }
-    auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-    auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-    if (!createMaskOp && !constantMaskOp)
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    if (failed(newMask))
      return failure();

-    // Computing the "compressed" mask. All the emulation logic (i.e. computing
-    // new mask index) only happens on the last dimension of the vectors.
-    Operation *newMask = nullptr;
-    auto shape = llvm::to_vector(
-        maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
-    shape.push_back(numElements);
-    auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-    if (createMaskOp) {
-      auto maskOperands = createMaskOp.getOperands();
-      auto numMaskOperands = maskOperands.size();
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      s0 = s0 + scale - 1;
-      s0 = s0.floorDiv(scale);
-      OpFoldResult origIndex =
-          getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-      OpFoldResult maskIndex =
-          affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-      auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
-      newMaskOperands.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-      newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                      newMaskOperands);
-    } else if (constantMaskOp) {
-      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
-      auto numMaskOperands = maskDimSizes.size();
-      auto origIndex =
-          cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
-      auto maskIndex =
-          rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
-      auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
-      newMaskDimSizes.push_back(maskIndex);
-      newMask = rewriter.create<vector::ConstantMaskOp>(
-          loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
-    }
-
-    while (!extractOps.empty()) {
-      newMask = rewriter.create<vector::ExtractOp>(
-          loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
-      extractOps.pop_back();
-    }
-
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newType = VectorType::get(numElements, newElementType);
    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());

    // Generating the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
-        newMask->getResult(0), newPassThru);
+        newMask.value()->getResult(0), newPassThru);

    // Setting the part that originally was not effectively loaded from memory
    // to pass through.
@@ -821,7 +927,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(

  // Populate `vector.*` conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
-               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
+               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
+      typeConverter, patterns.getContext());
}

void vector::populateVectorNarrowTypeRewritePatterns(
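The new pattern becomes active wherever populateVectorNarrowTypeEmulationPatterns is used. A rough sketch of how a pass might drive these patterns, assuming an 8-bit storage bitwidth and leaving the rest of the conversion-target setup to the enclosing pass (the includes and the helper name here are illustrative, not taken from this diff):

// Hypothetical driver sketch; emulateNarrowTypes is an assumed helper name.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult emulateNarrowTypes(Operation *root) {
  MLIRContext *ctx = root->getContext();
  // Emulate sub-byte element types on top of 8-bit storage.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);

  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  // Registers ConvertVectorMaskedStore (added in this change) along with the
  // other vector.* narrow-type emulation patterns.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);

  return applyPartialConversion(root, target, std::move(patterns));
}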