Commit 3e5640b

[mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N) (#123526)
This is PR 1 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no major
functional changes are made.

This PR renames `srcBits`/`dstBits` and `oldElementType`/`newElementType` to
improve naming consistency within the file, as illustrated below:

```cpp
// Extracted from VectorEmulateNarrowType.cpp

// BEFORE (mixing old/new and src/dst):
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();

// AFTER (consistently using emulated/container):
Type emulatedElemTy = op.getType().getElementType();
Type containerElemTy = convertedType.getElementType();
int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
int containerBits = containerElemTy.getIntOrFloatBitWidth();
```

It also adds some comments and unifies the related "rewriter notification"
messages.

**GitHub issue to track this work:** #123630
1 parent d68a4b9 commit 3e5640b
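For readers skimming the diff below, the renamed quantities feed a simple divisibility check. A minimal standalone sketch in plain C++ (no MLIR dependencies; the function name is hypothetical, not part of the patch):

```cpp
#include <cassert>

// Hypothetical illustration of the per-element alignment check: emulated
// elements (e.g. i4, 4 bits) can only be packed into container elements
// (e.g. i8, 8 bits) if the container width is an exact multiple of the
// emulated width.
int numEmulatedPerContainer(int emulatedBits, int containerBits) {
  assert(containerBits % emulatedBits == 0 && "bit-wise misalignment");
  return containerBits / emulatedBits; // e.g. 8 / 4 == 2 i4 elements per i8
}
```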

File tree

1 file changed: +66 -54 lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 66 additions & 54 deletions
```diff
@@ -415,18 +415,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
         "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
+
     auto valueToStore = cast<VectorValue>(op.getValueToStore());
-    auto oldElementType = valueToStore.getType().getElementType();
-    auto newElementType =
+    auto containerElemTy =
         cast<MemRefType>(adaptor.getBase().getType()).getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int numSrcElemsPerDest = dstBits / srcBits;
+    int numSrcElemsPerDest = containerBits / emulatedBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -451,7 +454,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -483,7 +486,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // Basic case: storing full bytes.
     auto numElements = origElements / numSrcElemsPerDest;
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
+        loc, VectorType::get(numElements, containerElemTy),
         op.getValueToStore());
     rewriter.replaceOpWithNewOp<vector::StoreOp>(
         op, bitCast.getResult(), memrefBase,
```
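For intuition, the final `vector.bitcast` above repacks, e.g., `vector<8xi4>` into `vector<4xi8>`. An illustrative-only C++ sketch of that packing, assuming low-numbered elements land in the low bits (the helper is not part of the patch):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack pairs of 4-bit values (carried one per byte, low nibble used) into
// single bytes, conceptually mirroring a vector<8xi4> -> vector<4xi8> bitcast.
std::vector<uint8_t> packI4IntoI8(const std::vector<uint8_t> &nibbles) {
  std::vector<uint8_t> bytes(nibbles.size() / 2);
  for (std::size_t i = 0; i < bytes.size(); ++i)
    bytes[i] = (nibbles[2 * i] & 0xF) | ((nibbles[2 * i + 1] & 0xF) << 4);
  return bytes;
}
```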
```diff
@@ -638,18 +641,20 @@ struct ConvertVectorMaskedStore final
         "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
 
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
     int origElements = op.getValueToStore().getType().getNumElements();
     if (origElements % scale != 0)
       return failure();
@@ -660,7 +665,7 @@ struct ConvertVectorMaskedStore final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -706,15 +711,15 @@ struct ConvertVectorMaskedStore final
       return failure();
 
     auto numElements = (origElements + scale - 1) / scale;
-    auto newType = VectorType::get(numElements, newElementType);
+    auto newType = VectorType::get(numElements, containerElemTy);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
 
     auto newLoad = rewriter.create<vector::MaskedLoadOp>(
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
```
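Context for the hunk above: the pattern emulates `vector.maskedstore` as a read-modify-write, masked-loading the container words, bitcasting to emulated elements, and combining new and existing values with `arith.select` on the original mask. A hedged scalar model of the select step (plain C++; names invented for illustration):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of load -> select -> store: lanes enabled by the mask take the
// new value, masked-off lanes keep what is already in memory.
void maskedStoreRMW(std::vector<uint8_t> &memory, std::size_t offset,
                    const std::vector<uint8_t> &newValues,
                    const std::vector<bool> &mask) {
  for (std::size_t i = 0; i < newValues.size(); ++i)
    if (mask[i]) // arith.select per lane
      memory[offset + i] = newValues[i];
}
```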
```diff
@@ -746,17 +751,19 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+    Type emulatedElemTy = op.getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -797,7 +804,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -814,7 +821,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         llvm::divideCeil(maxintraDataOffset + origElements, scale);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
-                           numElements, oldElementType, newElementType);
+                           numElements, emulatedElemTy, containerElemTy);
 
     if (!foldedIntraVectorOffset) {
       auto resultVector = rewriter.create<arith::ConstantOp>(
```
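A load may begin mid-container, hence `numElements = llvm::divideCeil(maxintraDataOffset + origElements, scale)` above. A minimal sketch of that arithmetic with a made-up helper name:

```cpp
#include <cstdint>

// How many container elements must be loaded to cover `origElements` emulated
// elements that begin `intraOffset` emulated elements into the first container
// element (same rounding as llvm::divideCeil above).
int64_t containerElemsToLoad(int64_t intraOffset, int64_t origElements,
                             int64_t scale) {
  return (intraOffset + origElements + scale - 1) / scale;
}
// e.g. loading 3 x i4 starting at nibble 1 of a byte:
// containerElemsToLoad(1, 3, 2) == 2 bytes.
```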
```diff
@@ -848,17 +855,20 @@ struct ConvertVectorMaskedLoad final
         "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+    Type emulatedElemTy = op.getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
+
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -912,7 +922,7 @@ struct ConvertVectorMaskedLoad final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -933,8 +943,8 @@ struct ConvertVectorMaskedLoad final
 
     auto numElements =
         llvm::divideCeil(maxIntraDataOffset + origElements, scale);
-    auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto loadType = VectorType::get(numElements, containerElemTy);
+    auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
```
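After the masked load, `newBitcastType` widens the result from `numElements` container elements back to `numElements * scale` emulated elements. The inverse of the packing sketch shown earlier, again illustrative-only and under the same low-bits-first assumption:

```cpp
#include <cstdint>
#include <vector>

// Split each 8-bit container element into two 4-bit emulated elements,
// low nibble first.
std::vector<uint8_t> unpackI8IntoI4(const std::vector<uint8_t> &bytes) {
  std::vector<uint8_t> nibbles;
  nibbles.reserve(bytes.size() * 2); // scale == 2
  for (uint8_t b : bytes) {
    nibbles.push_back(b & 0xF);        // element 2*i
    nibbles.push_back((b >> 4) & 0xF); // element 2*i + 1
  }
  return nibbles;
}
```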
```diff
@@ -1009,23 +1019,25 @@ struct ConvertVectorTransferRead final
         "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
-    Type oldElementType = op.getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
-
-    if (dstBits % srcBits != 0) {
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getSource().getType()).getElementType();
+    Type emulatedElemTy = op.getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
+
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
     bool isAlignedEmulation = origElements % scale == 0;
 
-    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
+    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
 
     auto stridedMetadata =
@@ -1035,7 +1047,7 @@ struct ConvertVectorTransferRead final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
@@ -1051,12 +1063,12 @@ struct ConvertVectorTransferRead final
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
+        loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+        loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {
```
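One detail specific to `vector.transfer_read` in the hunks above: its padding value has the emulated type, so the pattern widens it to the container type with `arith.extui` before reading. A trivial C++ model of that zero-extension, assuming the i4 value is carried in the low bits of a byte (the helper is hypothetical):

```cpp
#include <cstdint>

// Zero-extend an i4 padding value (low 4 bits) to the i8 container type;
// the high 4 bits of the result are guaranteed to be zero.
uint8_t extUIPaddingToI8(uint8_t paddingI4) { return paddingI4 & 0xF; }
```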
