[mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N) #123526
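@llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
Changes
This is PR 1 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This PR renames srcBits/dstBits to oldBits/newBits to improve consistency in naming within the file:
// Extracted from VectorEmulateNarrowType.cpp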
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
// BEFORE (mixing old/new and src/dst):
// int srcBits = oldElementType.getIntOrFloatBitWidth();
// int dstBits = newElementType.getIntOrFloatBitWidth();
// AFTER (consistently using old/new):
int oldBits = oldElementType.getIntOrFloatBitWidth();
int newBits = newElementType.getIntOrFloatBitWidth();
Also adds some comments and unifies related "rewriter notification" messages.
Full diff: https://github.com/llvm/llvm-project/pull/123526.diff
1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..70d50e1d48040c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -314,14 +314,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -346,7 +346,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -385,15 +385,15 @@ struct ConvertVectorMaskedStore final
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
int origElements = op.getValueToStore().getType().getNumElements();
if (origElements % scale != 0)
return failure();
@@ -404,7 +404,7 @@ struct ConvertVectorMaskedStore final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndicesOfr) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -493,14 +493,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -541,7 +541,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -596,14 +596,14 @@ struct ConvertVectorMaskedLoad final
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -657,7 +657,7 @@ struct ConvertVectorMaskedLoad final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -758,14 +758,14 @@ struct ConvertVectorTransferRead final
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
auto origElements = op.getVectorType().getNumElements();
@@ -781,7 +781,7 @@ struct ConvertVectorTransferRead final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
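The per-element alignment check repeated in each pattern above can be sketched outside MLIR as follows. This is a minimal standalone illustration, not code from the patch: the helper name `elementsPerContainer` is hypothetical, and `std::optional` stands in for the `notifyMatchFailure` path.

```cpp
#include <cassert>
#include <optional>

// Hypothetical helper (not part of MLIR). Given the bit width of the original
// (narrow) element type and of the converted (wider) element type, return how
// many narrow elements fit in one wide element. Returns std::nullopt when the
// widths do not divide evenly, mirroring the `newBits % oldBits != 0` bail-out.
std::optional<int> elementsPerContainer(int oldBits, int newBits) {
  if (oldBits <= 0 || newBits <= 0)
    return std::nullopt; // Guard against division by zero / nonsensical input.
  if (newBits % oldBits != 0)
    return std::nullopt; // Unaligned element types, e.g. i3 stored in i8.
  return newBits / oldBits; // Corresponds to `scale` in the diff.
}
```

For example, `elementsPerContainer(4, 8)` yields 2 (two i4 values per i8), while `elementsPerContainer(3, 8)` yields no value because 8 is not a multiple of 3.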
Not sure old/new is clearer than src/dst. Could we rename old/newElementType to src/dstElementType?
Yeah I concur with this idea. The other way around sounds better to me.
I have thought about it and I find "source" and "destination" ambiguous. My goal is to reduce ambiguity. Please see Proposal 2 here: Ultimately, we should be using names like "emulated type" and "container type". WDYT?
I approved because old and new work better for me, but please wait for consensus.
Thanks! Just to clarify, I propose that over time we converge towards sth like
Now I realized indeed we need to use src/dst in other places. So I guess old/new is good.
They are all okay to me. I'm more like looking for consistency.
Replied here: #123630
Folks, following up on the discussion with Diego: I’ve updated the names to use
For example, here
vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
If you are OK with this update, please 👍🏻. Otherwise, please leave a comment :)
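To make the "emulated" vs. "container" distinction concrete, here is a rough standalone sketch of what a bitcast from vector<8xi4> to vector<4xi8> does to the data, with i4 as the emulated type and i8 as the container type. The function name `packNibbles` is hypothetical, and the nibble order (lower-indexed element in the low 4 bits) is an assumption made for the example; the real layout depends on the lowering and target.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical illustration (not MLIR code): pack eight 4-bit values, each
// held in the low nibble of a uint8_t, into four container bytes.
// ASSUMPTION: element 2*i lands in the low nibble of byte i and element
// 2*i+1 in the high nibble.
std::array<uint8_t, 4> packNibbles(const std::array<uint8_t, 8> &emulated) {
  std::array<uint8_t, 4> container{};
  for (std::size_t i = 0; i < container.size(); ++i) {
    uint8_t lo = emulated[2 * i] & 0x0F;     // Keep only the 4 payload bits.
    uint8_t hi = emulated[2 * i + 1] & 0x0F;
    container[i] = static_cast<uint8_t>(lo | (hi << 4));
  }
  return container;
}
```

Under that assumption, the eight values 1 through 8 pack into the bytes 0x21, 0x43, 0x65, 0x87: the element count halves while the total number of bits (32) stays the same.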
This is PR 1 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no major functional changes are made/added.
This PR renames srcBits/dstBits + oldElementType/newElementType to improve consistency in naming within the file. This is illustrated below:
```cpp
// Extracted from VectorEmulateNarrowType.cpp
// BEFORE (mixing old/new and src/dst):
// Type oldElementType = op.getType().getElementType();
// Type newElementType = convertedType.getElementType();
// int srcBits = oldElementType.getIntOrFloatBitWidth();
// int dstBits = newElementType.getIntOrFloatBitWidth();
// AFTER (consistently using emulated/container):
Type emulatedElemTy = op.getType().getElementType();
Type containerElemTy = convertedType.getElementType();
int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
int containerBits = containerElemTy.getIntOrFloatBitWidth();
```
Also adds some comments and unifies related "rewriter notification" messages.
cda648e to 2b7a611
Updates `emulatedVectorLoad` that was introduced in llvm#115922. Specifically, ATM `emulatedVectorLoad` mixes "emulated type" and "container type". This only became clear after llvm#123526 in which the concepts of "emulated" and "container" types were introduced. This is an NFC change and simply updates the variable naming.
This is PR 2 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no major functional changes are made/added.
This PR renames the variable "scale". Note, "scale" could mean either:
* "original-elements-per-emulated-type", or
* "emulated-elements-per-original-type".
While from the context it is clear that it's always the former (original type is always a sub-byte type and the emulated type is usually `i8`), this PR reduces the cognitive load by making this clear.
**DEPENDS ON:**
* llvm#123526
Please only review the [top commit](llvm@d40b31b).
**GitHub issue to track this work**: llvm#123630
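The ambiguity around "scale" can be illustrated with a small sketch in which the direction of the ratio is part of the name. Everything here is hypothetical and chosen for the example (`NarrowTypeLayout`, `narrowBits`, `wideBits`, and the method names do not come from the patch); rounding up in `numWideElems` is also an assumption, since some patterns in the file instead bail out when the element count is not a multiple of the ratio.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration (not MLIR code). `narrowBits` is the width of the
// sub-byte type (e.g. 4 for i4); `wideBits` is the width of the type it is
// stored in (e.g. 8 for i8). Assumes wideBits is a multiple of narrowBits.
struct NarrowTypeLayout {
  int narrowBits;
  int wideBits;

  // The reading of "scale" used in the file: how many narrow elements fit in
  // one wide element.
  int narrowElemsPerWideElem() const { return wideBits / narrowBits; }

  // Number of wide elements needed to hold `numNarrow` narrow elements.
  // ASSUMPTION: round up when the count is not an exact multiple.
  int64_t numWideElems(int64_t numNarrow) const {
    int64_t ratio = narrowElemsPerWideElem();
    return (numNarrow + ratio - 1) / ratio;
  }
};
```

With `NarrowTypeLayout{4, 8}`, the ratio is 2, so a vector of 8 narrow elements needs 4 wide elements. A name like `narrowElemsPerWideElem` cannot be misread as the inverse ratio, which is the point the description makes about "scale".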
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/94/builds/4225 Here is the relevant piece of the build log for the reference
This is PR 1 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no major functional changes are made/added.
This PR renames srcBits/dstBits + oldElementType/newElementType to improve consistency in naming within the file. This is illustrated below:
Also adds some comments and unifies related "rewriter notification" messages.
GitHub issue to track this work: