 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -211,13 +212,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
  auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
         "subvector out of bounds");
@@ -228,9 +226,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
  auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
  return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
      ->getResult(0);
}
@@ -309,6 +310,76 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                             newLoad);
}

+/// Atomically store a subbyte-sized value to memory, with a mask.
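+/// The emitted IR is, roughly (sketch only; element widths are illustrative):
+///
+///   memref.generic_atomic_rmw %memref[%index] : memref<?xi8> {
+///   ^bb0(%current : i8):
+///     ... bitcast %current to a sub-byte vector, arith.select the masked
+///         lanes from `value`, bitcast the result back to i8 ...
+///     memref.atomic_yield %updated : i8
+///   }
+///
+/// The trailing unnamed int64_t parameter is unused; it only keeps this
+/// signature identical to `rmwStore` (see the static_assert below).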
+static void atomicStore(OpBuilder &builder, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value linearizedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t) {
+  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{linearizedIndex});
+  Value origValue = atomicOp.getCurrentValue();
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(atomicOp.getBody());
+
+  // i8 -> vector<1xi8> -> vector<numSrcElemsPerDest x iN>
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
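+/// Unlike `atomicStore`, this emits a plain vector.load of the containing
+/// byte, combines the loaded lanes with `value` under `mask` via arith.select,
+/// and writes the byte back with vector.store. It is only safe when the
+/// caller knows no other thread writes to the same byte concurrently.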
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref,
+                     Value linearizedIndex, TypedValue<VectorType> value,
+                     Value mask, int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{linearizedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   linearizedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have the same signature, "
+              "since which one to call is determined by the "
+              "`useAtomicWrites` flag.");
+
+// Extract a slice of a vector, and insert it into a byte vector.
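+// For example (illustrative), with i2 elements (scale = 4), sliceOffset = 2,
+// sliceNumElements = 2 and byteOffset = 0, the result is a vector<4xi2>
+// holding {v[2], v[3], 0, 0}.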
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
namespace {

//===----------------------------------------------------------------------===//
@@ -318,6 +389,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
@@ -329,16 +404,17 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
-    Type newElementType = convertedType.getElementType();
+    auto valueToStore = cast<TypedValue<VectorType>>(op.getValueToStore());
+    auto oldElementType = valueToStore.getType().getElementType();
+    auto newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;

    // Adjust the number of elements to store when emulating narrow types.
    // Here only the 1-D vector store is considered, and the N-D memref types
@@ -353,32 +429,153 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
    // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
    // vector<4xi8>

-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedNumFrontPadElems =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;

-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0.
+      return failure();
+    }
+
+    auto emulatedMemref = cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: a sub-byte store at the front is not needed when:
+    // 1. The source vector size is a multiple of the byte size, and
+    // 2. The store address is aligned to the emulated width boundary.
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), emulatedMemref,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // 1. Partial-width store for the first byte: when the store address is
+    // not aligned to the emulated width boundary, handle the unaligned part
+    // first so that the remaining elements are aligned to the width boundary.
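+    // For example (illustrative): storing vector<3xi2> into an i8 memref with
+    // one element of front padding (numSrcElemsPerDest = 4) writes only the
+    // first destination byte here, with the three source elements shifted to
+    // sub-byte positions 1..3 by extractSliceIntoByte.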
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    *foldedNumFrontPadElems, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value =
+          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
+                               frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Full width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
+          rewriter, loc, valueToStore, currentSourceIndex,
+          numNonFullWidthElements);
+
+      auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+    }
+
+    // 3. Deal with trailing elements that are aligned to the emulated width,
+    // but their length is smaller than the emulated width.
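+    // For example (illustrative): storing vector<6xi2> at an aligned address
+    // (numSrcElemsPerDest = 4), step 2 covers elements 0..3 and this step
+    // writes elements 4..5 into the last byte with backMask = [1, 1, 0, 0].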
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto subWidthStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Generate back mask
+      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      std::fill_n(maskValues.begin(), remainingElements, 1);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
+    }
+
+    rewriter.eraseOp(op);
    return success();
  }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
};

//===----------------------------------------------------------------------===//
@@ -584,9 +781,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
          linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
@@ -745,9 +941,8 @@ struct ConvertVectorMaskedLoad final
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
@@ -830,9 +1025,8 @@ struct ConvertVectorTransferRead final
                                        linearizedInfo.intraDataOffset,
                                        origElements);
    } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
@@ -1577,12 +1771,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {

-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` conversion patterns, except for `vector.store`.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
+
+  // Populate the `vector.store` conversion pattern. The caller can choose to
+  // avoid emitting atomic operations and instead emit a load-modify-write
+  // sequence for stores when it is known there is no thread contention.
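+  // An illustrative call site (names are hypothetical):
+  //   vector::populateVectorNarrowTypeEmulationPatterns(
+  //       typeConverter, patterns, /*useAtomicWrites=*/false);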
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
}

void vector::populateVectorNarrowTypeRewritePatterns(