Skip to content

Commit 5edc342

Browse files
committed
[mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)
This is PR 1 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no major functional changes are made/added. This PR renames `srcBits/dstBits` to `oldBits/newBits` to improve consistency in naming within the file. This is illustrated below: ```cpp // Extracted from VectorEmulateNarrowType.cpp Type oldElementType = op.getType().getElementType(); Type newElementType = convertedType.getElementType(); // BEFORE (mixing old/new and src/dst): // int srcBits = oldElementType.getIntOrFloatBitWidth(); // int dstBits = newElementType.getIntOrFloatBitWidth(); // AFTER (consistently using old/new): int oldBits = oldElementType.getIntOrFloatBitWidth(); int newBits = newElementType.getIntOrFloatBitWidth(); ``` Also adds some comments and unifies related "rewriter notification" messages.
1 parent 04034f0 commit 5edc342

File tree

1 file changed

+35
-35
lines changed

1 file changed

+35
-35
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -314,14 +314,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
314314
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
315315
Type oldElementType = op.getValueToStore().getType().getElementType();
316316
Type newElementType = convertedType.getElementType();
317-
int srcBits = oldElementType.getIntOrFloatBitWidth();
318-
int dstBits = newElementType.getIntOrFloatBitWidth();
317+
int oldBits = oldElementType.getIntOrFloatBitWidth();
318+
int newBits = newElementType.getIntOrFloatBitWidth();
319319

320-
if (dstBits % srcBits != 0) {
321-
return rewriter.notifyMatchFailure(
322-
op, "only dstBits % srcBits == 0 supported");
320+
// Check per-element alignment.
321+
if (newBits % oldBits != 0) {
322+
return rewriter.notifyMatchFailure(op, "unalagined element types");
323323
}
324-
int scale = dstBits / srcBits;
324+
int scale = newBits / oldBits;
325325

326326
// Adjust the number of elements to store when emulating narrow types.
327327
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -346,7 +346,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
346346
OpFoldResult linearizedIndices;
347347
std::tie(std::ignore, linearizedIndices) =
348348
memref::getLinearizedMemRefOffsetAndSize(
349-
rewriter, loc, srcBits, dstBits,
349+
rewriter, loc, oldBits, newBits,
350350
stridedMetadata.getConstifiedMixedOffset(),
351351
stridedMetadata.getConstifiedMixedSizes(),
352352
stridedMetadata.getConstifiedMixedStrides(),
@@ -385,15 +385,15 @@ struct ConvertVectorMaskedStore final
385385
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
386386
Type oldElementType = op.getValueToStore().getType().getElementType();
387387
Type newElementType = convertedType.getElementType();
388-
int srcBits = oldElementType.getIntOrFloatBitWidth();
389-
int dstBits = newElementType.getIntOrFloatBitWidth();
388+
int oldBits = oldElementType.getIntOrFloatBitWidth();
389+
int newBits = newElementType.getIntOrFloatBitWidth();
390390

391-
if (dstBits % srcBits != 0) {
392-
return rewriter.notifyMatchFailure(
393-
op, "only dstBits % srcBits == 0 supported");
391+
// Check per-element alignment.
392+
if (newBits % oldBits != 0) {
393+
return rewriter.notifyMatchFailure(op, "unalagined element types");
394394
}
395395

396-
int scale = dstBits / srcBits;
396+
int scale = newBits / oldBits;
397397
int origElements = op.getValueToStore().getType().getNumElements();
398398
if (origElements % scale != 0)
399399
return failure();
@@ -404,7 +404,7 @@ struct ConvertVectorMaskedStore final
404404
memref::LinearizedMemRefInfo linearizedInfo;
405405
std::tie(linearizedInfo, linearizedIndicesOfr) =
406406
memref::getLinearizedMemRefOffsetAndSize(
407-
rewriter, loc, srcBits, dstBits,
407+
rewriter, loc, oldBits, newBits,
408408
stridedMetadata.getConstifiedMixedOffset(),
409409
stridedMetadata.getConstifiedMixedSizes(),
410410
stridedMetadata.getConstifiedMixedStrides(),
@@ -493,14 +493,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
493493
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
494494
Type oldElementType = op.getType().getElementType();
495495
Type newElementType = convertedType.getElementType();
496-
int srcBits = oldElementType.getIntOrFloatBitWidth();
497-
int dstBits = newElementType.getIntOrFloatBitWidth();
496+
int oldBits = oldElementType.getIntOrFloatBitWidth();
497+
int newBits = newElementType.getIntOrFloatBitWidth();
498498

499-
if (dstBits % srcBits != 0) {
500-
return rewriter.notifyMatchFailure(
501-
op, "only dstBits % srcBits == 0 supported");
499+
// Check per-element alignment.
500+
if (newBits % oldBits != 0) {
501+
return rewriter.notifyMatchFailure(op, "unalagined element types");
502502
}
503-
int scale = dstBits / srcBits;
503+
int scale = newBits / oldBits;
504504

505505
// Adjust the number of elements to load when emulating narrow types,
506506
// and then cast back to the original type with vector.bitcast op.
@@ -541,7 +541,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
541541
memref::LinearizedMemRefInfo linearizedInfo;
542542
std::tie(linearizedInfo, linearizedIndices) =
543543
memref::getLinearizedMemRefOffsetAndSize(
544-
rewriter, loc, srcBits, dstBits,
544+
rewriter, loc, oldBits, newBits,
545545
stridedMetadata.getConstifiedMixedOffset(),
546546
stridedMetadata.getConstifiedMixedSizes(),
547547
stridedMetadata.getConstifiedMixedStrides(),
@@ -596,14 +596,14 @@ struct ConvertVectorMaskedLoad final
596596
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
597597
Type oldElementType = op.getType().getElementType();
598598
Type newElementType = convertedType.getElementType();
599-
int srcBits = oldElementType.getIntOrFloatBitWidth();
600-
int dstBits = newElementType.getIntOrFloatBitWidth();
599+
int oldBits = oldElementType.getIntOrFloatBitWidth();
600+
int newBits = newElementType.getIntOrFloatBitWidth();
601601

602-
if (dstBits % srcBits != 0) {
603-
return rewriter.notifyMatchFailure(
604-
op, "only dstBits % srcBits == 0 supported");
602+
// Check per-element alignment.
603+
if (newBits % oldBits != 0) {
604+
return rewriter.notifyMatchFailure(op, "unalagined element types");
605605
}
606-
int scale = dstBits / srcBits;
606+
int scale = newBits / oldBits;
607607

608608
// Adjust the number of elements to load when emulating narrow types,
609609
// and then cast back to the original type with vector.bitcast op.
@@ -657,7 +657,7 @@ struct ConvertVectorMaskedLoad final
657657
memref::LinearizedMemRefInfo linearizedInfo;
658658
std::tie(linearizedInfo, linearizedIndices) =
659659
memref::getLinearizedMemRefOffsetAndSize(
660-
rewriter, loc, srcBits, dstBits,
660+
rewriter, loc, oldBits, newBits,
661661
stridedMetadata.getConstifiedMixedOffset(),
662662
stridedMetadata.getConstifiedMixedSizes(),
663663
stridedMetadata.getConstifiedMixedStrides(),
@@ -758,14 +758,14 @@ struct ConvertVectorTransferRead final
758758
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
759759
Type oldElementType = op.getType().getElementType();
760760
Type newElementType = convertedType.getElementType();
761-
int srcBits = oldElementType.getIntOrFloatBitWidth();
762-
int dstBits = newElementType.getIntOrFloatBitWidth();
761+
int oldBits = oldElementType.getIntOrFloatBitWidth();
762+
int newBits = newElementType.getIntOrFloatBitWidth();
763763

764-
if (dstBits % srcBits != 0) {
765-
return rewriter.notifyMatchFailure(
766-
op, "only dstBits % srcBits == 0 supported");
764+
// Check per-element alignment.
765+
if (newBits % oldBits != 0) {
766+
return rewriter.notifyMatchFailure(op, "unalagined element types");
767767
}
768-
int scale = dstBits / srcBits;
768+
int scale = newBits / oldBits;
769769

770770
auto origElements = op.getVectorType().getNumElements();
771771

@@ -781,7 +781,7 @@ struct ConvertVectorTransferRead final
781781
memref::LinearizedMemRefInfo linearizedInfo;
782782
std::tie(linearizedInfo, linearizedIndices) =
783783
memref::getLinearizedMemRefOffsetAndSize(
784-
rewriter, loc, srcBits, dstBits,
784+
rewriter, loc, oldBits, newBits,
785785
stridedMetadata.getConstifiedMixedOffset(),
786786
stridedMetadata.getConstifiedMixedSizes(),
787787
stridedMetadata.getConstifiedMixedStrides(),

0 commit comments

Comments
 (0)