[AArch64] Guard against getRegisterBitWidth returning zero in vector instr cost. #117749
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-backend-aarch64
Author: David Green (davemgreen)
Changes
If getRegisterBitWidth returns zero (such as in SME streaming functions), we could hit a crash from using % RegWidth. It took a while to figure out what was going wrong, so there are a few other minor cleanups here too.
Full diff: https://github.com/llvm/llvm-project/pull/117749.diff
2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7a1e401bca18cb..a6b595d71bfe04 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3248,19 +3248,18 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
// Check if the extractelement user is scalar fmul.
auto IsUserFMulScalarTy = [](const Value *EEUser) {
// Check if the user is scalar fmul.
- const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+ const auto *BO = dyn_cast<BinaryOperator>(EEUser);
return BO && BO->getOpcode() == BinaryOperator::FMul &&
!BO->getType()->isVectorTy();
};
// Check if the extract index is from lane 0 or lane equivalent to 0 for a
// certain scalar type and a certain vector register width.
- auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
- const unsigned &EltSz) {
+ auto IsExtractLaneEquivalentToZero = [&](unsigned Idx, unsigned EltSz) {
auto RegWidth =
getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
.getFixedValue();
- return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+ return RegWidth != 0 && (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
};
// Check if the type constraints on input vector type and result scalar type
@@ -3277,13 +3276,15 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
// important.
UserToExtractIdx[U];
}
+ if (UserToExtractIdx.empty())
+ return false;
for (auto &[S, U, L] : ScalarUserAndIdx) {
for (auto *U : S->users()) {
if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
auto *FMul = cast<BinaryOperator>(U);
auto *Op0 = FMul->getOperand(0);
auto *Op1 = FMul->getOperand(1);
- if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
+ if ((Op0 == S && Op1 == S) || Op0 != S || Op1 != S) {
UserToExtractIdx[U] = L;
break;
}
diff --git a/llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll b/llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll
new file mode 100644
index 00000000000000..84502abceed3b0
--- /dev/null
+++ b/llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll
@@ -0,0 +1,18 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes="print<cost-model>" 2>&1 -disable-output -mtriple=aarch64-unknown-linux -mattr=+sme | FileCheck %s
+
+define double @extract_case7(<4 x double> %a) "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: 'extract_case7'
+; CHECK-NEXT: Cost Model: Found an estimated cost of 2 for instruction: %0 = extractelement <4 x double> %a, i32 1
+; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %1 = extractelement <4 x double> %a, i32 2
+; CHECK-NEXT: Cost Model: Found an estimated cost of 2 for instruction: %res = fmul double %0, %1
+; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret double %res
+;
+entry:
+ %1 = extractelement <4 x double> %a, i32 1
+ %2 = extractelement <4 x double> %a, i32 2
+ %res = fmul double %1, %2
+ ret double %res
+}
+
+declare void @foo(double)
return BO && BO->getOpcode() == BinaryOperator::FMul &&
!BO->getType()->isVectorTy();
};
// Check if the extract index is from lane 0 or lane equivalent to 0 for a
// certain scalar type and a certain vector register width.
- auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
- const unsigned &EltSz) {
+ auto IsExtractLaneEquivalentToZero = [&](unsigned Idx, unsigned EltSz) {
Could you land the NFC changes separately?
Yeah will do.
@@ -3277,13 +3276,15 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
// important.
UserToExtractIdx[U];
}
+ if (UserToExtractIdx.empty())
I don't see the link with the change in IsExtractLaneEquivalentToZero. Is this related?
@@ -3248,19 +3248,18 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
// Check if the extractelement user is scalar fmul.
auto IsUserFMulScalarTy = [](const Value *EEUser) {
// Check if the user is scalar fmul.
- const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+ const auto *BO = dyn_cast<BinaryOperator>(EEUser);
EEUser can be null, so please don't change this.
This function is called from users() in two cases. Can you explain why it can be null? Do you have a test case that shows it?
Sorry, my bad. You are right. (if_present was required in one of the earlier revisions of the original patch and stayed in until the end.)
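For context on this exchange: LLVM's dyn_cast expects a non-null pointer, while dyn_cast_if_present also accepts null and returns null for it. A rough standalone analogue, assuming standard dynamic_cast in place of LLVM's casting machinery (Value, BinaryOperator, Other, castStrict, castIfPresent and checkCasts below are hypothetical stand-ins, not LLVM's classes or API):

```cpp
#include <cassert>

// Hypothetical stand-ins for an IR class hierarchy.
struct Value { virtual ~Value() = default; };
struct BinaryOperator : Value {};
struct Other : Value {};

// Analogue of dyn_cast: the caller promises a non-null pointer.
template <typename To, typename From>
const To *castStrict(const From *V) {
  assert(V && "castStrict used on a null pointer");
  return dynamic_cast<const To *>(V);
}

// Analogue of dyn_cast_if_present: null in, null out.
template <typename To, typename From>
const To *castIfPresent(const From *V) {
  return V ? dynamic_cast<const To *>(V) : nullptr;
}

// Exercises both helpers on a matching type, a non-matching type, and null.
void checkCasts() {
  BinaryOperator BO;
  Other O;
  const Value *IsBO = &BO;
  const Value *NotBO = &O;
  const Value *Null = nullptr;
  assert(castStrict<BinaryOperator>(IsBO) != nullptr);
  assert(castStrict<BinaryOperator>(NotBO) == nullptr);
  assert(castIfPresent<BinaryOperator>(Null) == nullptr);
}
```

When every caller passes a pointer that cannot be null, as the thread concludes for values obtained from users(), the strict form is sufficient.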
auto RegWidth =
getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
.getFixedValue();
- return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+ return RegWidth != 0 && (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
Could you please elaborate on why RegWidth can be 0 in SME streaming functions? If there is an article I can read about this, please point me to it.
Can you push the check inside?
return Idx == 0 || (RegWidth != 0 && (Idx * EltSz) % RegWidth == 0)
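The two guard placements discussed here differ only when Idx is 0 and RegWidth is 0. A minimal standalone sketch, assuming plain unsigned arguments in place of the getRegisterBitWidth query (guardOutside, guardInside and checkGuards are hypothetical names, not part of the patch):

```cpp
#include <cassert>

// Form in the patch as posted: a zero register width rejects every lane,
// including lane 0.
bool guardOutside(unsigned Idx, unsigned EltSz, unsigned RegWidth) {
  return RegWidth != 0 && (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
}

// Form suggested in review: lane 0 is always accepted, and the modulo is
// only evaluated when RegWidth is non-zero.
bool guardInside(unsigned Idx, unsigned EltSz, unsigned RegWidth) {
  return Idx == 0 || (RegWidth != 0 && (Idx * EltSz) % RegWidth == 0);
}

// Exercises both forms; neither ever evaluates "% 0".
void checkGuards() {
  // Zero register width: the forms disagree only for lane 0.
  assert(!guardOutside(0, 64, 0));
  assert(guardInside(0, 64, 0));
  assert(!guardOutside(2, 64, 0) && !guardInside(2, 64, 0));
  // 128-bit registers, 64-bit elements: lane 2 starts a new register.
  assert(guardOutside(2, 64, 128) && guardInside(2, 64, 128));
  assert(!guardOutside(1, 64, 128) && !guardInside(1, 64, 128));
}
```

Both forms avoid the modulo when RegWidth is zero. Which answer is appropriate for lane 0 in that case is a cost-model decision that this sketch does not settle.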
LGTM
8abe6b1 to 7fe8ef4
7fe8ef4 to f649f01
Thanks. The first part was split out into d106a39.