 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -208,13 +209,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");

@@ -225,9 +223,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }

@@ -306,6 +307,73 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
               newLoad);
 }

+/// Atomically store a subbyte-sized value to memory, with a mask.
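+///
+/// For illustration, assuming i4 elements packed two per i8 byte (scale == 2),
+/// the atomic region built below roughly corresponds to (names are made up):
+///
+///   %res = memref.generic_atomic_rmw %mem[%idx] : memref<?xi8> {
+///   ^bb0(%byte: i8):
+///     %vec  = vector.from_elements %byte : vector<1xi8>
+///     %sub  = vector.bitcast %vec : vector<1xi8> to vector<2xi4>
+///     %sel  = arith.select %mask, %value, %sub : vector<2xi1>, vector<2xi4>
+///     %pack = vector.bitcast %sel : vector<2xi4> to vector<1xi8>
+///     %out  = vector.extract %pack[0] : i8 from vector<1xi8>
+///     memref.atomic_yield %out : i8
+///   }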
+static void atomicStore(OpBuilder &rewriter, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector<1xi8> -> vector<scale x sub-byte type>.
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  // Take the new value for the masked lanes and the original value elsewhere,
+  // then pack the result back into a single byte and yield it.
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
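+/// The merge is the same masked byte update as in `atomicStore`, but it is
+/// emitted as a plain vector.load / arith.select / vector.store sequence, so
+/// it is only correct when no other thread can write the same byte
+/// concurrently.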
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  // Keep the same operand order as `atomicStore`: masked lanes take the new
+  // value, unmasked lanes keep the loaded value.
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)>,
+              "`atomicStore` and `rmwStore` must have the same function type.");
+
+/// Extract a slice of a vector and insert it into a zero-initialized,
+/// byte-sized vector at `byteOffset`.
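+///
+/// For example, assuming i2 elements (scale == 4), extracting two elements at
+/// `sliceOffset` 1 with `byteOffset` 1 produces:
+///   [e1, e2] -> [0, e1, e2, 0] : vector<4xi2>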
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
376
+
309
377
namespace {
310
378
311
379
// ===----------------------------------------------------------------------===//
@@ -315,6 +383,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;

+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -326,7 +398,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -335,7 +408,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;

     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -350,32 +423,154 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     //    vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     //    vector<4xi8>

-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));

-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedNumFrontPadElems =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case: dynamic front padding size != 0.
+      return failure();
+    }
+
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: no sub-byte store at the front is needed when:
+    // 1. The source vector size is a multiple of the byte size, and
+    // 2. The store address is aligned to the emulated width boundary.
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), emulatedMemref,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // 1. Partial-width store for the leading byte. When the store address is
+    // not aligned to the emulated width boundary, handle the unaligned part
+    // first so that the remaining elements are aligned to the width boundary.
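+    // For example (illustrative only), storing vector<7xi4> with one element
+    // of front padding (numSrcElemsPerDest == 2) emits a masked sub-byte
+    // store of the first source element into the second lane of the leading
+    // byte, using the mask [false, true].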
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        // The entire source vector fits inside the leading byte.
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    frontSubWidthStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Full-width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
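+    // For example (illustrative only), with i2 elements (four per byte) and
+    // 10 source elements left, fullWidthStoreSize == 2: 8 elements are
+    // bitcast to vector<2xi8> and written with a single vector.store.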
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonFullWidthElements);
+
+      auto originType = cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType = emulatedMemref.getType().getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+    }
+
+    // 3. Deal with trailing elements that are aligned to the emulated width
+    // boundary but shorter than a full emulated word.
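+    // For example (illustrative only), with i2 elements and 3 trailing source
+    // elements, the back mask is [true, true, true, false], so only the first
+    // three lanes of the final byte are overwritten.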
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto subWidthStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Generate back mask.
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      std::fill_n(maskValues.begin(), remainingElements, true);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
+    }
+
+    rewriter.eraseOp(op);
     return success();
   }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };

 //===----------------------------------------------------------------------===//
@@ -581,9 +776,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -742,9 +936,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
@@ -827,9 +1020,8 @@ struct ConvertVectorTransferRead final
                                         linearizedInfo.intraDataOffset,
                                         origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
@@ -1574,12 +1766,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {

-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` conversion patterns (all but vector.store).
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate the `vector.store` conversion pattern. The caller can choose to
+  // avoid emitting atomic operations and reduce them to read-modify-write
+  // sequences when it is known that there is no thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }

 void vector::populateVectorNarrowTypeRewritePatterns(