Commit 149deda

update again to address comments
1 parent cfe16db

2 files changed: +79 -41 lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 73 additions & 35 deletions
@@ -400,6 +400,9 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+///
+///
+
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -443,7 +446,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
+    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -459,9 +462,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return failure("subbyte store emulation: dynamic front padding size is "
@@ -472,13 +475,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     // Shortcut: conditions when subbyte emulated store at the front is not
     // needed:
-    // 1. The source vector size is multiple of byte size
-    // 2. The address of the store is aligned to the emulated width boundary
+    // 1. The source vector size (in bits) is a multiple of byte size.
+    // 2. The address of the store is aligned to the emulated width boundary.
     //
     // For example, storing a vector<4xi2> to <13xi2> at offset 4 does not
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
-    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+    if (isAlignedEmulation && *foldedNumFrontPadElems == 0) {
       auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
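
The aligned fast path is just a bitcast plus a full-width store: groups of four i2 elements collapse into whole bytes. A standalone sketch of that packing, assuming the first element occupies the byte's high bits, consistent with the byte diagrams in the next hunk:

#include <cstdint>
#include <cstdio>

// Pack four 2-bit elements into one byte, first element in the high bits.
// This mimics what vector.bitcast vector<4xi2> -> vector<1xi8> does to the
// data before the plain (non-RMW) store.
static uint8_t packI2x4(const uint8_t elems[4]) {
  uint8_t byte = 0;
  for (int i = 0; i < 4; ++i)
    byte = static_cast<uint8_t>((byte << 2) | (elems[i] & 0x3));
  return byte;
}

int main() {
  const uint8_t allOnes[4] = {3, 3, 3, 3}; // 11 11 11 11
  printf("0x%02X\n", packI2x4(allOnes));   // 0xFF, stored as one full byte
}
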
@@ -489,17 +492,50 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // The index into the target memref we are storing to
+    // Next, handle the case when sub-byte read-modify-write
+    // sequences are needed to emulate a vector store.
+    // Here is an example:
+    //
+    // Vector to store: vector<7xi2>
+    // Value to store: 11 11 11 11 11 11 11 (all ones)
+    //
+    // Destination: memref<12xi2>
+    // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
+    //
+    // MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
+    //
+    // Destination memref before:
+    //
+    //    Byte 0     Byte 1     Byte 2
+    // +----------+----------+----------+
+    // | 00000000 | 00000000 | 00000000 |
+    // +----------+----------+----------+
+    //
+    // Destination memref after:
+    //
+    //    Byte 0     Byte 1     Byte 2
+    // +----------+----------+----------+
+    // | 00001111 | 11111111 | 11000000 |
+    // +----------+----------+----------+
+    //
+    // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
+    // need for atomicity). Stores to Byte 0 and Byte 2 are "partial", hence
+    // requiring RMW access (atomicity is required).
+
+    // The index into the target memref we are storing to.
     Value currentDestIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // Build a mask used for RMW.
     auto subWidthStoreMaskType =
         VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
-    // The index into the source vector we are currently processing
-    auto currentSourceIndex = 0;
 
-    // 1. Partial width store for the first byte, when the store address is not
-    // aligned to emulated width boundary, deal with the unaligned part so that
-    // the rest elements are aligned to width boundary.
+    // 1. Partial width store for the leading byte.
+    // When the store address is not aligned to emulated width boundary, deal
+    // with the unaligned part so that the rest elements are aligned to width
+    // boundary.
     auto frontSubWidthStoreElem =
         (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
     if (frontSubWidthStoreElem > 0) {
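
A quick standalone check of the arithmetic behind the worked example above (same variable names, assumed constants; not the pass's code):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t numSrcElemsPerDest = 4;     // i2 elements per i8 container
  const int64_t origElements = 7;           // vector<7xi2>
  const int64_t foldedNumFrontPadElems = 2; // store offset 2 within byte 0
  int64_t currentSourceIndex = 0;

  // 1. Leading partial byte: elements that complete byte 0.
  const int64_t frontSubWidthStoreElem =
      (numSrcElemsPerDest - foldedNumFrontPadElems) % numSrcElemsPerDest;
  currentSourceIndex += frontSubWidthStoreElem;

  // 2. Full-width stores for the aligned middle.
  const int64_t fullWidthStoreSize =
      (origElements - currentSourceIndex) / numSrcElemsPerDest;
  currentSourceIndex += fullWidthStoreSize * numSrcElemsPerDest;

  // 3. Trailing partial byte.
  const int64_t remainingElements = origElements - currentSourceIndex;

  printf("front=%lld full=%lld back=%lld\n",
         (long long)frontSubWidthStoreElem, (long long)fullWidthStoreSize,
         (long long)remainingElements);
  // front=2 full=1 back=1: two partial (RMW) stores and one full-width
  // store, matching the Byte 0 / Byte 1 / Byte 2 picture above.
}
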
@@ -535,8 +571,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     currentDestIndex = rewriter.create<arith::AddIOp>(
         loc, rewriter.getIndexType(), currentDestIndex, constantOne);
 
-    // 2. Full width store. After the previous step, the store address is
-    // aligned to the emulated width boundary.
+    // 2. Full width store for the inner output bytes.
+    // After the previous step, the store address is aligned to the emulated
+    // width boundary.
     int64_t fullWidthStoreSize =
         (origElements - currentSourceIndex) / numSrcElemsPerDest;
     int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
@@ -560,15 +597,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
     }
 
-    // 3. Deal with trailing elements that are aligned to the emulated width,
-    // but their length is smaller than the emulated width.
+    // 3. Partial width store for the trailing output byte.
+    // It is needed when the residual length is smaller than the emulated width,
+    // which is not covered in step 2 above.
     auto remainingElements = origElements - currentSourceIndex;
     if (remainingElements != 0) {
       auto subWidthStorePart =
           extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
                                currentSourceIndex, remainingElements, 0);
 
-      // Generate back mask
+      // Generate back mask.
       auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
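
The back mask is a leading-ones i1 vector over the container's lanes: it selects the new value for the first remainingElements lanes and preserves the rest of the destination byte during the RMW. A standalone sketch, with std::vector<bool> standing in for the vector<4xi1> constant:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int numSrcElemsPerDest = 4; // i1 mask lanes per container byte
  const int remainingElements = 1;  // one trailing i2 element left to store

  std::vector<bool> backMask(numSrcElemsPerDest, false);
  std::fill_n(backMask.begin(), remainingElements, true);

  for (bool selectNew : backMask)
    printf("%d ", selectNew ? 1 : 0); // 1 0 0 0
  printf("\n");
  // Lane 0 maps to the byte's high bits, so Byte 2 ends up as 11000000 in
  // the example above.
}
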
@@ -751,7 +789,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isAlignedEmulation = origElements % scale == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -767,9 +805,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
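
For unaligned loads the pass over-reads so that the requested elements are fully covered, then slices the subvector back out. A sketch of the assumed element-count computation (a ceiling division mirroring the comment above; not quoted from the pass):

#include <cstdint>
#include <cstdio>

static int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t scale = 4;              // i2 elements per i8 container
  const int64_t origElements = 5;       // vector<5xi2>
  const int64_t maxIntraDataOffset = 1; // folded intra-byte offset

  // Loading 5 i2 elements starting 1 element into a byte touches
  // ceil((5 + 1) / 4) = 2 container bytes; the extra elements are discarded
  // afterwards by the subvector extraction below.
  printf("%lld\n",
         (long long)divideCeil(origElements + maxIntraDataOffset, scale));
}
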
@@ -785,7 +823,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, cast<VectorValue>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -867,7 +905,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isAlignedEmulation = origElements % scale == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -882,9 +920,9 @@ struct ConvertVectorMaskedLoad final
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     FailureOr<Operation *> newMask = getCompressedMaskOp(
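
getCompressedMaskOp (defined elsewhere in this file) has to turn the element-granularity mask into a container-granularity one. A hypothetical illustration of that semantics for scale = 4; the pass's actual implementation builds IR ops rather than computing values:

#include <cstdio>
#include <vector>

int main() {
  const size_t scale = 4; // i2 lanes per i8 container

  // Element-level mask for a vector<8xi2> masked load.
  const std::vector<bool> mask = {1, 1, 1, 0, 0, 0, 0, 0};

  // A container byte must be loaded if any of its sub-byte lanes is on.
  std::vector<bool> compressed(mask.size() / scale, false);
  for (size_t i = 0; i < mask.size(); ++i)
    if (mask[i])
      compressed[i / scale] = true;

  for (bool on : compressed)
    printf("%d ", on ? 1 : 0); // 1 0
  printf("\n");
}
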
@@ -905,7 +943,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, cast<VectorValue>(passthru), emptyVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -933,7 +971,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(
           rewriter, loc, cast<VectorValue>(mask), emptyMask,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -944,7 +982,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -986,7 +1024,7 @@ struct ConvertVectorTransferRead final
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isAlignedEmulation = origElements % scale == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -1005,9 +1043,9 @@ struct ConvertVectorTransferRead final
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     auto numElements =
auto numElements =
@@ -1028,7 +1066,7 @@ struct ConvertVectorTransferRead final
10281066
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
10291067
linearizedInfo.intraDataOffset,
10301068
origElements);
1031-
} else if (isUnalignedEmulation) {
1069+
} else if (!isAlignedEmulation) {
10321070
result = staticallyExtractSubvector(
10331071
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
10341072
}

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 6 additions & 6 deletions
@@ -361,7 +361,7 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
 /// vector.store
 ///----------------------------------------------------------------------------------------
 
-func.func @vector_store_i2_const_index_two_atomic_rmw(%arg0: vector<3xi2>) {
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
   %src = memref.alloc() : memref<3x3xi2>
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
@@ -374,7 +374,7 @@ func.func @vector_store_i2_const_index_two_atomic_rmw(%arg0: vector<3xi2>) {
 // Note, sizeof(%src) = 18 bits. This is modelled as %src_as_bytes:
 // <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out)
 
-// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic_rmw(
+// CHECK-LABEL: func @vector_store_i2_const_index_two_partial_stores(
 // CHECK-SAME:    %[[ARG0:.+]]: vector<3xi2>)
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -413,7 +413,7 @@ func.func @vector_store_i2_const_index_two_atomic_rmw(%arg0: vector<3xi2>) {
 
 // -----
 
-func.func @vector_store_i2_atomic_rmw(%arg0: vector<7xi2>) {
+func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
   %0 = memref.alloc() : memref<3x7xi2>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -422,7 +422,7 @@ func.func @vector_store_i2_atomic_rmw(%arg0: vector<7xi2>) {
 }
 
 // In this example, emit 2 atomic RMWs and 1 non-atomic store:
-// CHECK-LABEL: func @vector_store_i2_atomic_rmw(
+// CHECK-LABEL: func @vector_store_i2_two_partial_one_full_stores(
 // CHECK-SAME:    %[[ARG0:.+]]: vector<7xi2>)
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -469,7 +469,7 @@ func.func @vector_store_i2_atomic_rmw(%arg0: vector<7xi2>) {
 
 // -----
 
-func.func @vector_store_i2_const_index_one_atomic_rmw(%arg0: vector<1xi2>) {
+func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
   %0 = memref.alloc() : memref<4x1xi2>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -478,7 +478,7 @@ func.func @vector_store_i2_const_index_one_atomic_rmw(%arg0: vector<1xi2>) {
 }
 
 // In this example, only emit 1 atomic store
-// CHECK-LABEL: func @vector_store_i2_const_index_one_atomic_rmw(
+// CHECK-LABEL: func @vector_store_i2_const_index_one_partial_store(
 // CHECK-SAME:    %[[ARG0:.+]]: vector<1xi2>)
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
 // CHECK: %[[C0:.+]] = arith.constant 0 : index
