Skip to content

Commit 8abe6b1

Browse files
committed
[AArch64] Guard against getRegisterBitWidth returning zero in vector instr cost.
If the getRegisterBitWidth is zero (such as in sme streaming functions), then we could hit a crash from using % RegWidth. It took a while to figure out what was going wrong so there are a few other minor cleanups here too.
1 parent f7dc1d0 commit 8abe6b1

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3248,19 +3248,18 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
32483248
// Check if the extractelement user is scalar fmul.
32493249
auto IsUserFMulScalarTy = [](const Value *EEUser) {
32503250
// Check if the user is scalar fmul.
3251-
const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
3251+
const auto *BO = dyn_cast<BinaryOperator>(EEUser);
32523252
return BO && BO->getOpcode() == BinaryOperator::FMul &&
32533253
!BO->getType()->isVectorTy();
32543254
};
32553255

32563256
// Check if the extract index is from lane 0 or lane equivalent to 0 for a
32573257
// certain scalar type and a certain vector register width.
3258-
auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
3259-
const unsigned &EltSz) {
3258+
auto IsExtractLaneEquivalentToZero = [&](unsigned Idx, unsigned EltSz) {
32603259
auto RegWidth =
32613260
getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
32623261
.getFixedValue();
3263-
return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
3262+
return RegWidth != 0 && (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
32643263
};
32653264

32663265
// Check if the type constraints on input vector type and result scalar type
@@ -3277,13 +3276,15 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
32773276
// important.
32783277
UserToExtractIdx[U];
32793278
}
3279+
if (UserToExtractIdx.empty())
3280+
return false;
32803281
for (auto &[S, U, L] : ScalarUserAndIdx) {
32813282
for (auto *U : S->users()) {
32823283
if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
32833284
auto *FMul = cast<BinaryOperator>(U);
32843285
auto *Op0 = FMul->getOperand(0);
32853286
auto *Op1 = FMul->getOperand(1);
3286-
if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
3287+
if ((Op0 == S && Op1 == S) || Op0 != S || Op1 != S) {
32873288
UserToExtractIdx[U] = L;
32883289
break;
32893290
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt < %s -passes="print<cost-model>" 2>&1 -disable-output -mtriple=aarch64-unknown-linux -mattr=+sme | FileCheck %s
3+
4+
define double @extract_case7(<4 x double> %a) "aarch64_pstate_sm_enabled" {
5+
; CHECK-LABEL: 'extract_case7'
6+
; CHECK-NEXT: Cost Model: Found an estimated cost of 2 for instruction: %0 = extractelement <4 x double> %a, i32 1
7+
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %1 = extractelement <4 x double> %a, i32 2
8+
; CHECK-NEXT: Cost Model: Found an estimated cost of 2 for instruction: %res = fmul double %0, %1
9+
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret double %res
10+
;
11+
entry:
12+
%1 = extractelement <4 x double> %a, i32 1
13+
%2 = extractelement <4 x double> %a, i32 2
14+
%res = fmul double %1, %2
15+
ret double %res
16+
}
17+
18+
declare void @foo(double)

0 commit comments

Comments
 (0)