@@ -76,7 +76,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
@@ -256,23 +257,11 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                              newLoad);
 }
 
-static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
-                           Value memref, Value index, Value value) {
-  auto originType = dyn_cast<VectorType>(value.getType());
-  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
-  auto scale = memrefElemType.getIntOrFloatBitWidth() /
-               originType.getElementType().getIntOrFloatBitWidth();
-  auto storeType =
-      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
-  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
-  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
-}
-
-/// atomically store a subbyte-sized value to memory, with a mask.
+/// Atomically store a subbyte-sized value to memory, with a mask.
 static Value atomicStore(OpBuilder &rewriter, Location loc,
-                         Value emulatedMemref, Value emulatedIndex,
-                         TypedValue<VectorType> value, Value mask,
-                         int64_t scale) {
+                         TypedValue<MemRefType> emulatedMemref,
+                         Value emulatedIndex, TypedValue<VectorType> value,
+                         Value mask, int64_t scale) {
   auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
       loc, emulatedMemref, ValueRange{emulatedIndex});
   OpBuilder builder =
@@ -294,6 +283,27 @@ static Value atomicStore(OpBuilder &rewriter, Location loc,
   return atomicOp;
 }
 
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static Value rmwStore(OpBuilder &rewriter, Location loc,
+                      TypedValue<MemRefType> emulatedMemref,
+                      Value emulatedIndex, TypedValue<VectorType> value,
+                      Value mask, int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  return rewriter
+      .create<vector::StoreOp>(loc, toBitcast, emulatedMemref, emulatedIndex)
+      ->getResult(0);
+}
+
 // Extract a slice of a vector, and insert it into a byte vector.
 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                   Location loc, TypedValue<VectorType> vector,
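For readers unfamiliar with the emulation scheme, the sequence built by `rmwStore` above (load the containing byte, bitcast it to sub-byte elements, select the masked lanes, bitcast back, store) can be modeled with plain scalar code. The snippet below is an illustrative sketch only, assuming i4 elements packed two-per-i8 byte with the first element in the low nibble; it is not part of the patch.

```cpp
#include <cstdint>

// Scalar model of the read-modify-write sequence emitted by rmwStore for
// i4-in-i8 emulation (assumption for illustration: element 0 lives in the
// low nibble). `mask` selects which i4 lanes of the byte are overwritten.
uint8_t rmwModel(uint8_t memByte, uint8_t newNibbles, const bool mask[2]) {
  uint8_t result = 0;
  for (int i = 0; i < 2; ++i) {
    uint8_t oldNib = (memByte >> (4 * i)) & 0xF;    // vector.bitcast (unpack)
    uint8_t newNib = (newNibbles >> (4 * i)) & 0xF; // value to store
    uint8_t sel = mask[i] ? newNib : oldNib;        // arith.select
    result |= static_cast<uint8_t>(sel << (4 * i)); // vector.bitcast (repack)
  }
  return result; // vector.store writes this byte back, non-atomically
}
```

Because the load and the store are separate operations, another thread touching the same byte between them can lose its update, which is why this path is only taken when the caller opts out of atomic writes.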
@@ -322,6 +332,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -343,7 +357,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -359,7 +373,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -374,21 +388,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedIntraVectorOffset =
+    auto foldedNumFrontPadElems =
         isUnalignedEmulation
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic front padding size
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
       return failure();
     }
 
-    // conditions when atomic stores and all that are not needed:
+    // Conditions when atomic stores and all that are not needed:
     // 1. The source vector size is multiple of byte size
     // 2. The address of the store is byte aligned
-    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
-      auto numElements = origElements / scale;
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
           op.getValueToStore());
@@ -398,38 +412,41 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return llvm::success();
     }
 
-    Value emulatedMemref = adaptor.getBase();
-    // the index into the target memref we are storing to
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+    // The index into the target memref we are storing to
     Value currentDestIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
-    // the index into the source vector we are currently processing
+    auto atomicMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing
     auto currentSourceIndex = 0;
 
-    // 1. atomic store for the first byte
-    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    // 1. Atomic store for the first byte
+    auto frontAtomicStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
     if (frontAtomicStoreElem != 0) {
-      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
-      if (*foldedIntraVectorOffset + origElements < scale) {
-        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
         frontAtomicStoreElem = origElements;
       } else {
         std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
-                    *foldedIntraVectorOffset, true);
+                    *foldedNumFrontPadElems, true);
       }
       auto frontMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
 
-      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
       auto value = extractSliceIntoByte(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
-          frontAtomicStoreElem, *foldedIntraVectorOffset);
+          frontAtomicStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
-                  scale);
+      subByteStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                   cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                   numSrcElemsPerDest);
 
       currentDestIndex = rewriter.create<arith::AddIOp>(
           loc, rewriter.getIndexType(), currentDestIndex, constantOne);
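To make the front-byte handling above concrete, the following standalone snippet replays the mask computation with small illustrative values: a `vector<2xi2>` written starting at the second i2 slot of a byte, so `numSrcElemsPerDest` is 4 and the folded front padding is 1. The values and the helper are hypothetical, chosen only to show which lanes the masked sub-byte store touches; they are not part of the patch.

```cpp
#include <algorithm>
#include <array>
#include <cstdio>

// Illustrative replay of the front-byte mask logic for a vector<2xi2> store
// that begins one i2 slot into its first byte (assumed example values).
int main() {
  constexpr int numSrcElemsPerDest = 4; // four i2 elements per i8 byte
  constexpr int numFrontPadElems = 1;   // store begins at the second i2 slot
  constexpr int origElements = 2;       // vector<2xi2> being stored

  std::array<bool, numSrcElemsPerDest> frontMask{}; // all false
  int frontStoreElems =
      (numSrcElemsPerDest - numFrontPadElems) % numSrcElemsPerDest; // 3
  if (numFrontPadElems + origElements < numSrcElemsPerDest) {
    // The whole source fits inside this single byte.
    std::fill_n(frontMask.begin() + numFrontPadElems, origElements, true);
    frontStoreElems = origElements;
  } else {
    std::fill_n(frontMask.end() - frontStoreElems, numFrontPadElems, true);
  }
  // Prints: mask = 0 1 1 0, elements stored = 2
  std::printf("mask =");
  for (bool b : frontMask)
    std::printf(" %d", b);
  std::printf(", elements stored = %d\n", frontStoreElems);
}
```

With these values `currentSourceIndex` advances to 4 - 1 = 3, past the two source elements, so the pattern takes the early-return path visible at the top of the next hunk and never reaches the non-atomic middle section.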
@@ -440,44 +457,62 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // 2. non-atomic store
-    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
-    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    // 2. Non-atomic store
+    int64_t nonAtomicStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * numSrcElemsPerDest;
     if (nonAtomicStoreSize != 0) {
       auto nonAtomicStorePart = staticallyExtractSubvector(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
          currentSourceIndex, numNonAtomicElements);
 
-      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                     nonAtomicStorePart);
+      auto originType = dyn_cast<VectorType>(nonAtomicStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        nonAtomicStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
 
       currentSourceIndex += numNonAtomicElements;
       currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
     }
 
-    // 3. atomic store for the last byte
+    // 3. Atomic store for the last byte
    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto atomicStorePart = extractSliceIntoByte(
          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
          currentSourceIndex, remainingElements, 0);
 
-      // back mask
-      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      // Generate back mask
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, 0);
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(atomicStorePart),
-                  backMask.getResult(), scale);
+      subByteStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                   cast<TypedValue<VectorType>>(atomicStorePart),
+                   backMask.getResult(), numSrcElemsPerDest);
    }
 
    rewriter.eraseOp(op);
    return success();
  }
+
+  template <typename... Args>
+  Value subByteStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    return storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
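A note on the dispatch in `subByteStore` above: routing both helpers through `std::function<decltype(atomicStore)>` only works because `atomicStore` and `rmwStore` were given identical signatures in this patch. An equivalent way to express the same choice without the `std::function` indirection might look like the sketch below; it is illustrative only and not part of the change.

```cpp
// Illustrative alternative (not part of the patch): pick the helper with a
// plain conditional and forward the arguments. Both helpers must still accept
// exactly the same parameter list.
template <typename... Args>
Value subByteStoreDirect(bool useAtomicWrites, Args &&...args) {
  return useAtomicWrites ? atomicStore(std::forward<Args>(args)...)
                         : rmwStore(std::forward<Args>(args)...);
}
```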
@@ -1673,12 +1708,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and reduce them to a load-modify-write
+  // sequence for stores if it is known there is no thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
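For context, a caller of the extended entry point forwards its own knowledge about thread contention through the new `useAtomicWrites` argument. The sketch below shows one plausible shape for such a caller; the pass class, the `atomicStoresEnabled` option, and the surrounding conversion boilerplate are assumptions for illustration (headers omitted) and are not part of this patch.

```cpp
// Hypothetical pass body wiring a pass option into the new flag.
// (Headers omitted for brevity; names other than the populate function and
// the type converter are assumed for this sketch.)
void EmulateNarrowTypePass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  // Emulate sub-byte element types on top of 8-bit storage.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);

  RewritePatternSet patterns(ctx);
  // `atomicStoresEnabled` is an assumed bool pass option: when false, partial
  // byte stores are lowered to the non-atomic read-modify-write sequence.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    atomicStoresEnabled);

  ConversionTarget target(*ctx);
  target.addDynamicallyLegalDialect<vector::VectorDialect>(
      [&](Operation *op) { return typeConverter.isLegal(op); });
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    signalPassFailure();
}
```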