Skip to content

[AArch64] Guard against getRegisterBitWidth returning zero in vector instr cost. #117749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 29, 2024

Conversation

davemgreen
Copy link
Collaborator

@davemgreen davemgreen commented Nov 26, 2024

If the getRegisterBitWidth is zero (such as in sme streaming functions), then we could hit a crash from using % RegWidth.

@llvmbot llvmbot added backend:AArch64 llvm:analysis Includes value tracking, cost tables and constant folding labels Nov 26, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-aarch64

Author: David Green (davemgreen)

Changes

If the getRegisterBitWidth is zero (such as in sme streaming functions), then we could hit a crash from using % RegWidth.

It took a while to figure out what was going wrong so there are a few other minor cleanups here too.


Full diff: https://github.com/llvm/llvm-project/pull/117749.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+6-5)
  • (added) llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll (+18)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7a1e401bca18cb..a6b595d71bfe04 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3248,19 +3248,18 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
     // Check if the extractelement user is scalar fmul.
     auto IsUserFMulScalarTy = [](const Value *EEUser) {
       // Check if the user is scalar fmul.
-      const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+      const auto *BO = dyn_cast<BinaryOperator>(EEUser);
       return BO && BO->getOpcode() == BinaryOperator::FMul &&
              !BO->getType()->isVectorTy();
     };
 
     // Check if the extract index is from lane 0 or lane equivalent to 0 for a
     // certain scalar type and a certain vector register width.
-    auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
-                                             const unsigned &EltSz) {
+    auto IsExtractLaneEquivalentToZero = [&](unsigned Idx, unsigned EltSz) {
       auto RegWidth =
           getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
               .getFixedValue();
-      return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+      return RegWidth != 0 && (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
     };
 
     // Check if the type constraints on input vector type and result scalar type
@@ -3277,13 +3276,15 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
         // important.
         UserToExtractIdx[U];
       }
+      if (UserToExtractIdx.empty())
+        return false;
       for (auto &[S, U, L] : ScalarUserAndIdx) {
         for (auto *U : S->users()) {
           if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
             auto *FMul = cast<BinaryOperator>(U);
             auto *Op0 = FMul->getOperand(0);
             auto *Op1 = FMul->getOperand(1);
-            if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
+            if ((Op0 == S && Op1 == S) || Op0 != S || Op1 != S) {
               UserToExtractIdx[U] = L;
               break;
             }
diff --git a/llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll b/llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll
new file mode 100644
index 00000000000000..84502abceed3b0
--- /dev/null
+++ b/llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll
@@ -0,0 +1,18 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes="print<cost-model>" 2>&1 -disable-output -mtriple=aarch64-unknown-linux -mattr=+sme | FileCheck %s
+
+define double @extract_case7(<4 x double> %a) "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: 'extract_case7'
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %0 = extractelement <4 x double> %a, i32 1
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %1 = extractelement <4 x double> %a, i32 2
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %res = fmul double %0, %1
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret double %res
+;
+entry:
+  %1 = extractelement <4 x double> %a, i32 1
+  %2 = extractelement <4 x double> %a, i32 2
+  %res = fmul double %1, %2
+  ret double %res
+}
+
+declare void @foo(double)

return BO && BO->getOpcode() == BinaryOperator::FMul &&
!BO->getType()->isVectorTy();
};

// Check if the extract index is from lane 0 or lane equivalent to 0 for a
// certain scalar type and a certain vector register width.
auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
const unsigned &EltSz) {
auto IsExtractLaneEquivalentToZero = [&](unsigned Idx, unsigned EltSz) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you land the NFC changes separately?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah will do.

@@ -3277,13 +3276,15 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
// important.
UserToExtractIdx[U];
}
if (UserToExtractIdx.empty())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see the link with the change in IsExtractLaneEquivalentToZero, is this related?

@@ -3248,19 +3248,18 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
// Check if the extractelement user is scalar fmul.
auto IsUserFMulScalarTy = [](const Value *EEUser) {
// Check if the user is scalar fmul.
const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
const auto *BO = dyn_cast<BinaryOperator>(EEUser);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EEUser can be null and hence, dont change this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is called from users() in two cases. Can you explain why it can be null? Do you have a test case that shows it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, my bad. You are right. (if_present was required in one of the earlier revisions of the original patch and that lasted till end)

auto RegWidth =
getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
.getFixedValue();
return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
return RegWidth != 0 && (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate why the RegWidth can be 0 in sme streaming functions? If there is some article which I can read about, you can point that.

Can you push the check inside?
return Idx == 0 || (RegWidth != 0 && (Idx * EltSz) % RegWidth == 0)

Copy link
Contributor

@sushgokh sushgokh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@davemgreen davemgreen force-pushed the gh-a64-vectorcostregwidthzero branch from 8abe6b1 to 7fe8ef4 Compare November 29, 2024 01:45
…instr cost.

If the getRegisterBitWidth is zero (such as in sme streaming functions), then
we could hit a crash from using % RegWidth.

It took a while to figure out what was going wrong so there are a few other
minor cleanups here too.
@davemgreen davemgreen force-pushed the gh-a64-vectorcostregwidthzero branch from 7fe8ef4 to f649f01 Compare November 29, 2024 01:51
@davemgreen
Copy link
Collaborator Author

Thanks. The first part was split out into d106a39.

@davemgreen davemgreen merged commit d714b22 into llvm:main Nov 29, 2024
8 checks passed
@davemgreen davemgreen deleted the gh-a64-vectorcostregwidthzero branch November 29, 2024 04:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 llvm:analysis Includes value tracking, cost tables and constant folding
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants