Skip to content

Commit 9fab1bb

Browse files
committed
[mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)
This is PR 2 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no major functional changes are made. This PR renames the variable "scale". Note that "scale" could mean either: * "original-elements-per-emulated-type", or * "emulated-elements-per-original-type". While it is clear from context that it is always the former (the original type is always a sub-byte type and the emulated type is usually `i8`), this PR reduces the cognitive load by making that explicit. **DEPENDS ON:** * #123526. Please only review the [top commit](d40b31b). **GitHub issue to track this work:** #123630
1 parent 3954c8f commit 9fab1bb

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,13 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
290290
int64_t numContainerElemsToLoad,
291291
Type emulatedElemTy,
292292
Type containerElemTy) {
293-
auto scale = containerElemTy.getIntOrFloatBitWidth() /
293+
auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
294294
emulatedElemTy.getIntOrFloatBitWidth();
295295
auto newLoad = rewriter.create<vector::LoadOp>(
296296
loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
297297
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
298298
return rewriter.create<vector::BitCastOp>(
299-
loc, VectorType::get(numContainerElemsToLoad * scale, emulatedElemTy),
299+
loc, VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem, emulatedElemTy),
300300
newLoad);
301301
}
302302

@@ -388,10 +388,10 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
388388
"sliceNumElements * vector element size must be less than or equal to 8");
389389
assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
390390
"vector element must be a valid sub-byte type");
391-
auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
391+
auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
392392
auto emptyByteVector = rewriter.create<arith::ConstantOp>(
393-
loc, VectorType::get({scale}, vectorElementType),
394-
rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
393+
loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
394+
rewriter.getZeroAttr(VectorType::get({emulatedPerContainerElem}, vectorElementType)));
395395
auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
396396
extractOffset, sliceNumElements);
397397
return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
@@ -656,9 +656,9 @@ struct ConvertVectorMaskedStore final
656656
"(bit-wise misalignment)");
657657
}
658658

659-
int scale = containerBits / emulatedBits;
659+
int emulatedPerContainerElem = containerBits / emulatedBits;
660660
int origElements = op.getValueToStore().getType().getNumElements();
661-
if (origElements % scale != 0)
661+
if (origElements % emulatedPerContainerElem != 0)
662662
return failure();
663663

664664
auto stridedMetadata =
@@ -708,11 +708,11 @@ struct ConvertVectorMaskedStore final
708708
// FIXME: Make an example based on the comment above work (see #115460 for
709709
// reproducer).
710710
FailureOr<Operation *> newMask =
711-
getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
711+
getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
712712
if (failed(newMask))
713713
return failure();
714714

715-
auto numElements = (origElements + scale - 1) / scale;
715+
auto numElements = (origElements + emulatedPerContainerElem - 1) / emulatedPerContainerElem;
716716
auto newType = VectorType::get(numElements, containerElemTy);
717717
auto passThru = rewriter.create<arith::ConstantOp>(
718718
loc, newType, rewriter.getZeroAttr(newType));
@@ -721,7 +721,7 @@ struct ConvertVectorMaskedStore final
721721
loc, newType, adaptor.getBase(), linearizedIndices,
722722
newMask.value()->getResult(0), passThru);
723723

724-
auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
724+
auto newBitCastType = VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
725725
Value valueToStore =
726726
rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
727727
valueToStore = rewriter.create<arith::SelectOp>(
@@ -765,7 +765,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
765765
op, "impossible to pack emulated elements into container elements "
766766
"(bit-wise misalignment)");
767767
}
768-
int scale = containerBits / emulatedBits;
768+
int emulatedPerContainerElem = containerBits / emulatedBits;
769769

770770
// Adjust the number of elements to load when emulating narrow types,
771771
// and then cast back to the original type with vector.bitcast op.
@@ -797,7 +797,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
797797
// compile time as they must be constants.
798798

799799
auto origElements = op.getVectorType().getNumElements();
800-
bool isAlignedEmulation = origElements % scale == 0;
800+
bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
801801

802802
auto stridedMetadata =
803803
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -818,9 +818,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
818818
: getConstantIntValue(linearizedInfo.intraDataOffset);
819819

820820
// Always load enough elements which can cover the original elements.
821-
int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
821+
int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
822822
auto numElements =
823-
llvm::divideCeil(maxintraDataOffset + origElements, scale);
823+
llvm::divideCeil(maxintraDataOffset + origElements, emulatedPerContainerElem);
824824
Value result =
825825
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
826826
numElements, emulatedElemTy, containerElemTy);
@@ -870,7 +870,7 @@ struct ConvertVectorMaskedLoad final
870870
op, "impossible to pack emulated elements into container elements "
871871
"(bit-wise misalignment)");
872872
}
873-
int scale = containerBits / emulatedBits;
873+
int emulatedPerContainerElem = containerBits / emulatedBits;
874874

875875
// Adjust the number of elements to load when emulating narrow types,
876876
// and then cast back to the original type with vector.bitcast op.
@@ -916,7 +916,7 @@ struct ConvertVectorMaskedLoad final
916916
// subvector at the proper offset after bit-casting.
917917
auto origType = op.getVectorType();
918918
auto origElements = origType.getNumElements();
919-
bool isAlignedEmulation = origElements % scale == 0;
919+
bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
920920

921921
auto stridedMetadata =
922922
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -935,18 +935,18 @@ struct ConvertVectorMaskedLoad final
935935
? 0
936936
: getConstantIntValue(linearizedInfo.intraDataOffset);
937937

938-
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
938+
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
939939
FailureOr<Operation *> newMask = getCompressedMaskOp(
940-
rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
940+
rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem, maxIntraDataOffset);
941941
if (failed(newMask))
942942
return failure();
943943

944944
Value passthru = op.getPassThru();
945945

946946
auto numElements =
947-
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
947+
llvm::divideCeil(maxIntraDataOffset + origElements, emulatedPerContainerElem);
948948
auto loadType = VectorType::get(numElements, containerElemTy);
949-
auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
949+
auto newBitcastType = VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
950950

951951
auto emptyVector = rewriter.create<arith::ConstantOp>(
952952
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -974,7 +974,7 @@ struct ConvertVectorMaskedLoad final
974974

975975
Value mask = op.getMask();
976976
auto newSelectMaskType =
977-
VectorType::get(numElements * scale, rewriter.getI1Type());
977+
VectorType::get(numElements * emulatedPerContainerElem, rewriter.getI1Type());
978978
// TODO: try to fold if op's mask is constant
979979
auto emptyMask = rewriter.create<arith::ConstantOp>(
980980
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -1033,11 +1033,11 @@ struct ConvertVectorTransferRead final
10331033
op, "impossible to pack emulated elements into container elements "
10341034
"(bit-wise misalignment)");
10351035
}
1036-
int scale = containerBits / emulatedBits;
1036+
int emulatedPerContainerElem = containerBits / emulatedBits;
10371037

10381038
auto origElements = op.getVectorType().getNumElements();
10391039

1040-
bool isAlignedEmulation = origElements % scale == 0;
1040+
bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
10411041

10421042
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
10431043
adaptor.getPadding());
@@ -1060,17 +1060,17 @@ struct ConvertVectorTransferRead final
10601060
? 0
10611061
: getConstantIntValue(linearizedInfo.intraDataOffset);
10621062

1063-
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
1063+
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
10641064
auto numElements =
1065-
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
1065+
llvm::divideCeil(maxIntraDataOffset + origElements, emulatedPerContainerElem);
10661066

10671067
auto newRead = rewriter.create<vector::TransferReadOp>(
10681068
loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
10691069
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
10701070
newPadding);
10711071

10721072
auto bitCast = rewriter.create<vector::BitCastOp>(
1073-
loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
1073+
loc, VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy), newRead);
10741074

10751075
Value result = bitCast->getResult(0);
10761076
if (!foldedIntraVectorOffset) {

0 commit comments

Comments (0)