@@ -76,7 +76,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                    int numSrcElemsPerDest,
                                                    int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
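
The `numElements` expression above is the usual integer ceiling division, rounding the padded element count up to a whole number of destination elements. A minimal standalone sketch of the same arithmetic (the function name and the sample values are illustrative, not part of this patch):

#include <cassert>

// Same rounding as in getCompressedMaskOp: how many destination elements are
// touched when numSrcElems sub-byte elements are stored behind
// numFrontPadElems elements of front padding.
static int computeNumDestElements(int numFrontPadElems, int numSrcElems,
                                  int numSrcElemsPerDest) {
  return (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
         numSrcElemsPerDest;
}

int main() {
  // i2 elements packed into i8: four source elements per destination element.
  // ceil((1 + 6) / 4) == 2 destination bytes.
  assert(computeNumDestElements(/*numFrontPadElems=*/1, /*numSrcElems=*/6,
                                /*numSrcElemsPerDest=*/4) == 2);
  return 0;
}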
@@ -256,23 +257,11 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                  newLoad);
 }
 
-static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
-                           Value memref, Value index, Value value) {
-  auto originType = dyn_cast<VectorType>(value.getType());
-  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
-  auto scale = memrefElemType.getIntOrFloatBitWidth() /
-               originType.getElementType().getIntOrFloatBitWidth();
-  auto storeType =
-      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
-  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
-  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
-}
-
-/// atomically store a subbyte-sized value to memory, with a mask.
-static Value atomicStore(OpBuilder &rewriter, Location loc,
-                         Value emulatedMemref, Value emulatedIndex,
-                         TypedValue<VectorType> value, Value mask,
-                         int64_t scale) {
+/// Atomically store a subbyte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &rewriter, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t scale) {
   auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
       loc, emulatedMemref, ValueRange{emulatedIndex});
   OpBuilder builder =
@@ -291,9 +280,31 @@ static Value atomicStore(OpBuilder &rewriter, Location loc,
   auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
   auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
   builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
-  return atomicOp;
 }
 
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have same function type.");
+
 // Extract a slice of a vector, and insert it into a byte vector.
 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                   Location loc, TypedValue<VectorType> vector,
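
For readers unfamiliar with the read-modify-write idea behind `rmwStore`, here is a conceptual model in plain C++ for four i2 elements packed into one byte: load the containing byte, unpack it into lanes, blend lanes according to a mask, repack, and store. The names, lane order, and mask polarity are illustrative assumptions; the actual pattern builds the equivalent vector.load / vector.bitcast / arith.select / vector.store IR shown above.

#include <array>
#include <cassert>
#include <cstdint>

// Conceptual read-modify-write of sub-byte lanes inside a single byte.
// keepOriginal[i] == true keeps the lane already in memory; false takes the
// corresponding lane from newLanes.
static void rmwStoreModel(uint8_t &destByte,
                          const std::array<uint8_t, 4> &newLanes,
                          const std::array<bool, 4> &keepOriginal) {
  std::array<uint8_t, 4> lanes;
  for (int i = 0; i < 4; ++i) // "bitcast" the byte into four i2 lanes
    lanes[i] = (destByte >> (2 * i)) & 0x3;
  for (int i = 0; i < 4; ++i) // per-lane "select" between old and new values
    lanes[i] = keepOriginal[i] ? lanes[i] : (newLanes[i] & 0x3);
  destByte = 0;
  for (int i = 0; i < 4; ++i) // "bitcast" the lanes back into one byte
    destByte |= static_cast<uint8_t>(lanes[i] << (2 * i));
}

int main() {
  uint8_t byte = 0xFF; // all four lanes start as 0b11
  rmwStoreModel(byte, /*newLanes=*/{0, 0, 0, 0},
                /*keepOriginal=*/{true, false, false, true});
  assert(byte == 0xC3); // only the two middle lanes were overwritten
  return 0;
}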
@@ -322,6 +333,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -343,7 +358,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -359,7 +374,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -374,62 +389,68 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedIntraVectorOffset =
+    auto foldedNumFrontPadElems =
         isUnalignedEmulation
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic front padding size
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
       return failure();
     }
 
-    // conditions when atomic stores and all that are not needed:
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: conditions when a subbyte store at the front is not needed:
     // 1. The source vector size is multiple of byte size
-    // 2. The address of the store is byte aligned
-    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
-      auto numElements = origElements / scale;
+    // 2. The address of the store is aligned to the emulated width boundary
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
           op.getValueToStore());
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
-          op, bitCast.getResult(), adaptor.getBase(),
+          op, bitCast.getResult(), emulatedMemref,
          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
       return llvm::success();
     }
 
-    Value emulatedMemref = adaptor.getBase();
-    // the index into the target memref we are storing to
+    // The index into the target memref we are storing to
     Value currentDestIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
-    // the index into the source vector we are currently processing
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing
     auto currentSourceIndex = 0;
 
-    // 1. atomic store for the first byte
-    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
-    if (frontAtomicStoreElem != 0) {
-      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
-      if (*foldedIntraVectorOffset + origElements < scale) {
-        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+    // 1. Partial-width store for the first byte, when the store address is not
+    // aligned to the emulated width boundary: deal with the unaligned part so
+    // that the remaining elements are aligned to the width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
-        frontAtomicStoreElem = origElements;
+        frontSubWidthStoreElem = origElements;
       } else {
-        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
-                    *foldedIntraVectorOffset, true);
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    *foldedNumFrontPadElems, true);
       }
       auto frontMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
       auto value = extractSliceIntoByte(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
-          frontAtomicStoreElem, *foldedIntraVectorOffset);
+          frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
-                  scale);
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
 
       currentDestIndex = rewriter.create<arith::AddIOp>(
           loc, rewriter.getIndexType(), currentDestIndex, constantOne);
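
To make the front-store bookkeeping above concrete, a small worked example with assumed values (i2 elements stored into i8 memory, so numSrcElemsPerDest is 4, and one element of front padding):

#include <cassert>

int main() {
  int numSrcElemsPerDest = 4;     // four i2 elements per i8 destination element
  int foldedNumFrontPadElems = 1; // the store starts one element into the byte

  // Same formula as in the pattern: number of elements the first sub-width
  // store has to cover so that the rest of the store is aligned.
  int frontSubWidthStoreElem =
      (numSrcElemsPerDest - foldedNumFrontPadElems) % numSrcElemsPerDest;
  assert(frontSubWidthStoreElem == 3);

  // After the front store, the source index points at the first element that
  // lines up with a destination element boundary.
  int currentSourceIndex = numSrcElemsPerDest - foldedNumFrontPadElems;
  assert(currentSourceIndex == 3);
  return 0;
}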
@@ -440,44 +461,66 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // 2. non-atomic store
-    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
-    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
-    if (nonAtomicStoreSize != 0) {
-      auto nonAtomicStorePart = staticallyExtractSubvector(
+    // 2. Full-width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
-          currentSourceIndex, numNonAtomicElements);
-
-      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                     nonAtomicStorePart);
-
-      currentSourceIndex += numNonAtomicElements;
+          currentSourceIndex, numNonFullWidthElements);
+
+      auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
       currentDestIndex = rewriter.create<arith::AddIOp>(
           loc, rewriter.getIndexType(), currentDestIndex,
-          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
     }
 
-    // 3. atomic store for the last byte
+    // 3. Deal with trailing elements that are aligned to the emulated width
+    // boundary but whose length is smaller than the emulated width.
     auto remainingElements = origElements - currentSourceIndex;
     if (remainingElements != 0) {
-      auto atomicStorePart = extractSliceIntoByte(
+      auto subWidthStorePart = extractSliceIntoByte(
          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
          currentSourceIndex, remainingElements, 0);
 
-      // back mask
-      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      // Generate the back mask
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(atomicStorePart),
-                  backMask.getResult(), scale);
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
     }
 
     rewriter.eraseOp(op);
     return success();
   }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
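
The `subEmulatedWidthStore` helper above relies on `atomicStore` and `rmwStore` having the same function type, which is exactly what the earlier `static_assert` checks; that lets a `std::function<decltype(atomicStore)>` hold either one. A self-contained sketch of that dispatch shape with placeholder functions (these are not the MLIR helpers):

#include <functional>
#include <iostream>
#include <type_traits>
#include <utility>

static void storeAtomic(int value) { std::cout << "atomic " << value << "\n"; }
static void storeRMW(int value) { std::cout << "rmw " << value << "\n"; }

static_assert(std::is_same_v<decltype(storeAtomic), decltype(storeRMW)>,
              "both candidates must have the same function type");

// Pick one implementation at runtime and forward the arguments unchanged,
// mirroring subEmulatedWidthStore.
template <typename... Args>
void dispatchStore(bool useAtomic, Args &&...args) {
  std::function<decltype(storeAtomic)> storeFunc =
      useAtomic ? storeAtomic : storeRMW;
  storeFunc(std::forward<Args>(args)...);
}

int main() {
  dispatchStore(true, 7);  // prints "atomic 7"
  dispatchStore(false, 7); // prints "rmw 7"
  return 0;
}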
@@ -1673,12 +1716,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose to
+  // avoid emitting atomic operations and reduce them to a load-modify-write
+  // sequence for stores if it is known there is no thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
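
A hedged sketch of how a caller might use the new parameter to opt out of atomic writes. The surrounding function, the header paths, and the pipeline assumption are illustrative; only the third argument to `populateVectorNarrowTypeEmulationPatterns` comes from this patch, and the conversion driver itself is omitted.

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Collect the narrow-type emulation patterns for a pipeline that guarantees
// no concurrent writers to the emulated buffers, so the cheaper
// read-modify-write lowering is requested instead of atomic operations.
static void collectEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  vector::populateVectorNarrowTypeEmulationPatterns(
      typeConverter, patterns, /*useAtomicWrites=*/false);
}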