Skip to content

Commit 9991ea2

Browse files
authored
[CostModel][AArch64] Make extractelement, with fmul user, free whenev… (#111479)
…er possible In case of Neon, if there exists extractelement from lane != 0 such that 1. extractelement does not necessitate a move from vector_reg -> GPR 2. extractelement result feeds into fmul 3. Other operand of fmul is a scalar or extractelement from lane 0 or lane equivalent to 0 then the extractelement can be merged with fmul in the backend and it incurs no cost. e.g. ``` define double @foo(<2 x double> %a) { %1 = extractelement <2 x double> %a, i32 0 %2 = extractelement <2 x double> %a, i32 1 %res = fmul double %1, %2 ret double %res } ``` `%2` and `%res` can be merged in the backend to generate: `fmul d0, d0, v0.d[1]` The change was tested with SPEC FP(C/C++) on Neoverse-v2. **Compile time impact**: None **Performance impact**: Observing 1.3-1.7% uplift on lbm benchmark with -flto depending upon the config.
1 parent 95554cb commit 9991ea2

File tree

8 files changed

+244
-28
lines changed

8 files changed

+244
-28
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#define LLVM_ANALYSIS_TARGETTRANSFORMINFO_H
2323

2424
#include "llvm/ADT/APInt.h"
25+
#include "llvm/ADT/ArrayRef.h"
2526
#include "llvm/IR/FMF.h"
2627
#include "llvm/IR/InstrTypes.h"
2728
#include "llvm/IR/PassManager.h"
@@ -1404,6 +1405,20 @@ class TargetTransformInfo {
14041405
unsigned Index = -1, Value *Op0 = nullptr,
14051406
Value *Op1 = nullptr) const;
14061407

1408+
/// \return The expected cost of vector Insert and Extract.
1409+
/// Use -1 to indicate that there is no information on the index value.
1410+
/// This is used when the instruction is not available; a typical use
1411+
/// case is to provision the cost of vectorization/scalarization in
1412+
/// vectorizer passes.
1413+
/// \param ScalarUserAndIdx encodes the information about extracts from a
1414+
/// vector with 'Scalar' being the value being extracted,'User' being the user
1415+
/// of the extract(nullptr if user is not known before vectorization) and
1416+
/// 'Idx' being the extract lane.
1417+
InstructionCost getVectorInstrCost(
1418+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
1419+
Value *Scalar,
1420+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const;
1421+
14071422
/// \return The expected cost of vector Insert and Extract.
14081423
/// This is used when instruction is available, and implementation
14091424
/// asserts 'I' is not nullptr.
@@ -2100,6 +2115,16 @@ class TargetTransformInfo::Concept {
21002115
TTI::TargetCostKind CostKind,
21012116
unsigned Index, Value *Op0,
21022117
Value *Op1) = 0;
2118+
2119+
/// \param ScalarUserAndIdx encodes the information about extracts from a
2120+
/// vector with 'Scalar' being the value being extracted,'User' being the user
2121+
/// of the extract(nullptr if user is not known before vectorization) and
2122+
/// 'Idx' being the extract lane.
2123+
virtual InstructionCost getVectorInstrCost(
2124+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
2125+
Value *Scalar,
2126+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) = 0;
2127+
21032128
virtual InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
21042129
TTI::TargetCostKind CostKind,
21052130
unsigned Index) = 0;
@@ -2785,6 +2810,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
27852810
Value *Op1) override {
27862811
return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
27872812
}
2813+
InstructionCost getVectorInstrCost(
2814+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
2815+
Value *Scalar,
2816+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) override {
2817+
return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Scalar,
2818+
ScalarUserAndIdx);
2819+
}
27882820
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
27892821
TTI::TargetCostKind CostKind,
27902822
unsigned Index) override {

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,17 @@ class TargetTransformInfoImplBase {
700700
return 1;
701701
}
702702

703+
/// \param ScalarUserAndIdx encodes the information about extracts from a
704+
/// vector with 'Scalar' being the value being extracted,'User' being the user
705+
/// of the extract(nullptr if user is not known before vectorization) and
706+
/// 'Idx' being the extract lane.
707+
InstructionCost getVectorInstrCost(
708+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
709+
Value *Scalar,
710+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
711+
return 1;
712+
}
713+
703714
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
704715
TTI::TargetCostKind CostKind,
705716
unsigned Index) const {

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#define LLVM_CODEGEN_BASICTTIIMPL_H
1818

1919
#include "llvm/ADT/APInt.h"
20-
#include "llvm/ADT/ArrayRef.h"
2120
#include "llvm/ADT/BitVector.h"
2221
#include "llvm/ADT/SmallPtrSet.h"
2322
#include "llvm/ADT/SmallVector.h"
@@ -1288,6 +1287,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
12881287
return getRegUsageForType(Val->getScalarType());
12891288
}
12901289

1290+
/// \param ScalarUserAndIdx encodes the information about extracts from a
1291+
/// vector with 'Scalar' being the value being extracted,'User' being the user
1292+
/// of the extract(nullptr if user is not known before vectorization) and
1293+
/// 'Idx' being the extract lane.
1294+
InstructionCost getVectorInstrCost(
1295+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
1296+
Value *Scalar,
1297+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
1298+
return thisT()->getVectorInstrCost(Opcode, Val, CostKind, Index, nullptr,
1299+
nullptr);
1300+
}
1301+
12911302
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
12921303
TTI::TargetCostKind CostKind,
12931304
unsigned Index) {

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,15 +1047,28 @@ InstructionCost TargetTransformInfo::getCmpSelInstrCost(
10471047
InstructionCost TargetTransformInfo::getVectorInstrCost(
10481048
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
10491049
Value *Op0, Value *Op1) const {
1050-
// FIXME: Assert that Opcode is either InsertElement or ExtractElement.
1051-
// This is mentioned in the interface description and respected by all
1052-
// callers, but never asserted upon.
1050+
assert((Opcode == Instruction::InsertElement ||
1051+
Opcode == Instruction::ExtractElement) &&
1052+
"Expecting Opcode to be insertelement/extractelement.");
10531053
InstructionCost Cost =
10541054
TTIImpl->getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
10551055
assert(Cost >= 0 && "TTI should not produce negative costs!");
10561056
return Cost;
10571057
}
10581058

1059+
InstructionCost TargetTransformInfo::getVectorInstrCost(
1060+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
1061+
Value *Scalar,
1062+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
1063+
assert((Opcode == Instruction::InsertElement ||
1064+
Opcode == Instruction::ExtractElement) &&
1065+
"Expecting Opcode to be insertelement/extractelement.");
1066+
InstructionCost Cost = TTIImpl->getVectorInstrCost(
1067+
Opcode, Val, CostKind, Index, Scalar, ScalarUserAndIdx);
1068+
assert(Cost >= 0 && "TTI should not produce negative costs!");
1069+
return Cost;
1070+
}
1071+
10591072
InstructionCost
10601073
TargetTransformInfo::getVectorInstrCost(const Instruction &I, Type *Val,
10611074
TTI::TargetCostKind CostKind,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 129 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "AArch64PerfectShuffle.h"
1212
#include "MCTargetDesc/AArch64AddressingModes.h"
1313
#include "Utils/AArch64SMEAttributes.h"
14+
#include "llvm/ADT/DenseMap.h"
1415
#include "llvm/Analysis/IVDescriptors.h"
1516
#include "llvm/Analysis/LoopInfo.h"
1617
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -3177,10 +3178,10 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
31773178
return 0;
31783179
}
31793180

3180-
InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
3181-
Type *Val,
3182-
unsigned Index,
3183-
bool HasRealUse) {
3181+
InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
3182+
unsigned Opcode, Type *Val, unsigned Index, bool HasRealUse,
3183+
const Instruction *I, Value *Scalar,
3184+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
31843185
assert(Val->isVectorTy() && "This must be a vector type");
31853186

31863187
if (Index != -1U) {
@@ -3226,6 +3227,119 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
32263227
// compile-time considerations.
32273228
}
32283229

3230+
// In case of Neon, if there exists extractelement from lane != 0 such that
3231+
// 1. extractelement does not necessitate a move from vector_reg -> GPR.
3232+
// 2. extractelement result feeds into fmul.
3233+
// 3. Other operand of fmul is an extractelement from lane 0 or lane
3234+
// equivalent to 0.
3235+
// then the extractelement can be merged with fmul in the backend and it
3236+
// incurs no cost.
3237+
// e.g.
3238+
// define double @foo(<2 x double> %a) {
3239+
// %1 = extractelement <2 x double> %a, i32 0
3240+
// %2 = extractelement <2 x double> %a, i32 1
3241+
// %res = fmul double %1, %2
3242+
// ret double %res
3243+
// }
3244+
// %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
3245+
auto ExtractCanFuseWithFmul = [&]() {
3246+
// We bail out if the extract is from lane 0.
3247+
if (Index == 0)
3248+
return false;
3249+
3250+
// Check if the scalar element type of the vector operand of ExtractElement
3251+
// instruction is one of the allowed types.
3252+
auto IsAllowedScalarTy = [&](const Type *T) {
3253+
return T->isFloatTy() || T->isDoubleTy() ||
3254+
(T->isHalfTy() && ST->hasFullFP16());
3255+
};
3256+
3257+
// Check if the extractelement user is scalar fmul.
3258+
auto IsUserFMulScalarTy = [](const Value *EEUser) {
3259+
// Check if the user is scalar fmul.
3260+
const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
3261+
return BO && BO->getOpcode() == BinaryOperator::FMul &&
3262+
!BO->getType()->isVectorTy();
3263+
};
3264+
3265+
// Check if the extract index is from lane 0 or lane equivalent to 0 for a
3266+
// certain scalar type and a certain vector register width.
3267+
auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
3268+
const unsigned &EltSz) {
3269+
auto RegWidth =
3270+
getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
3271+
.getFixedValue();
3272+
return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
3273+
};
3274+
3275+
// Check if the type constraints on input vector type and result scalar type
3276+
// of extractelement instruction are satisfied.
3277+
if (!isa<FixedVectorType>(Val) || !IsAllowedScalarTy(Val->getScalarType()))
3278+
return false;
3279+
3280+
if (Scalar) {
3281+
DenseMap<User *, unsigned> UserToExtractIdx;
3282+
for (auto *U : Scalar->users()) {
3283+
if (!IsUserFMulScalarTy(U))
3284+
return false;
3285+
// Recording entry for the user is important. Index value is not
3286+
// important.
3287+
UserToExtractIdx[U];
3288+
}
3289+
for (auto &[S, U, L] : ScalarUserAndIdx) {
3290+
for (auto *U : S->users()) {
3291+
if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
3292+
auto *FMul = cast<BinaryOperator>(U);
3293+
auto *Op0 = FMul->getOperand(0);
3294+
auto *Op1 = FMul->getOperand(1);
3295+
if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
3296+
UserToExtractIdx[U] = L;
3297+
break;
3298+
}
3299+
}
3300+
}
3301+
}
3302+
for (auto &[U, L] : UserToExtractIdx) {
3303+
if (!IsExtractLaneEquivalentToZero(Index, Val->getScalarSizeInBits()) &&
3304+
!IsExtractLaneEquivalentToZero(L, Val->getScalarSizeInBits()))
3305+
return false;
3306+
}
3307+
} else {
3308+
const auto *EE = cast<ExtractElementInst>(I);
3309+
3310+
const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
3311+
if (!IdxOp)
3312+
return false;
3313+
3314+
return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
3315+
if (!IsUserFMulScalarTy(U))
3316+
return false;
3317+
3318+
// Check if the other operand of extractelement is also extractelement
3319+
// from lane equivalent to 0.
3320+
const auto *BO = cast<BinaryOperator>(U);
3321+
const auto *OtherEE = dyn_cast<ExtractElementInst>(
3322+
BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
3323+
if (OtherEE) {
3324+
const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
3325+
if (!IdxOp)
3326+
return false;
3327+
return IsExtractLaneEquivalentToZero(
3328+
cast<ConstantInt>(OtherEE->getIndexOperand())
3329+
->getValue()
3330+
.getZExtValue(),
3331+
OtherEE->getType()->getScalarSizeInBits());
3332+
}
3333+
return true;
3334+
});
3335+
}
3336+
return true;
3337+
};
3338+
3339+
if (Opcode == Instruction::ExtractElement && (I || Scalar) &&
3340+
ExtractCanFuseWithFmul())
3341+
return 0;
3342+
32293343
// All other insert/extracts cost this much.
32303344
return ST->getVectorInsertExtractBaseCost();
32313345
}
@@ -3236,14 +3350,23 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
32363350
Value *Op1) {
32373351
bool HasRealUse =
32383352
Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
3239-
return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
3353+
return getVectorInstrCostHelper(Opcode, Val, Index, HasRealUse);
3354+
}
3355+
3356+
InstructionCost AArch64TTIImpl::getVectorInstrCost(
3357+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3358+
Value *Scalar,
3359+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
3360+
return getVectorInstrCostHelper(Opcode, Val, Index, false, nullptr, Scalar,
3361+
ScalarUserAndIdx);
32403362
}
32413363

32423364
InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
32433365
Type *Val,
32443366
TTI::TargetCostKind CostKind,
32453367
unsigned Index) {
3246-
return getVectorInstrCostHelper(&I, Val, Index, true /* HasRealUse */);
3368+
return getVectorInstrCostHelper(I.getOpcode(), Val, Index,
3369+
true /* HasRealUse */, &I);
32473370
}
32483371

32493372
InstructionCost AArch64TTIImpl::getScalarizationOverhead(

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "AArch64.h"
2020
#include "AArch64Subtarget.h"
2121
#include "AArch64TargetMachine.h"
22-
#include "llvm/ADT/ArrayRef.h"
2322
#include "llvm/Analysis/TargetTransformInfo.h"
2423
#include "llvm/CodeGen/BasicTTIImpl.h"
2524
#include "llvm/IR/Function.h"
@@ -66,8 +65,14 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
6665
// 'Val' and 'Index' are forwarded from 'getVectorInstrCost'; 'HasRealUse'
6766
// indicates whether the vector instruction is available in the input IR or
6867
// just imaginary in vectorizer passes.
69-
InstructionCost getVectorInstrCostHelper(const Instruction *I, Type *Val,
70-
unsigned Index, bool HasRealUse);
68+
/// \param ScalarUserAndIdx encodes the information about extracts from a
69+
/// vector with 'Scalar' being the value being extracted,'User' being the user
70+
/// of the extract(nullptr if user is not known before vectorization) and
71+
/// 'Idx' being the extract lane.
72+
InstructionCost getVectorInstrCostHelper(
73+
unsigned Opcode, Type *Val, unsigned Index, bool HasRealUse,
74+
const Instruction *I = nullptr, Value *Scalar = nullptr,
75+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx = {});
7176

7277
public:
7378
explicit AArch64TTIImpl(const AArch64TargetMachine *TM, const Function &F)
@@ -185,6 +190,16 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
185190
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
186191
TTI::TargetCostKind CostKind,
187192
unsigned Index, Value *Op0, Value *Op1);
193+
194+
/// \param ScalarUserAndIdx encodes the information about extracts from a
195+
/// vector with 'Scalar' being the value being extracted,'User' being the user
196+
/// of the extract(nullptr if user is not known before vectorization) and
197+
/// 'Idx' being the extract lane.
198+
InstructionCost getVectorInstrCost(
199+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
200+
Value *Scalar,
201+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx);
202+
188203
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
189204
TTI::TargetCostKind CostKind,
190205
unsigned Index);

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12192,6 +12192,13 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
1219212192
std::optional<DenseMap<Value *, unsigned>> ValueToExtUses;
1219312193
DenseMap<const TreeEntry *, DenseSet<Value *>> ExtractsCount;
1219412194
SmallPtrSet<Value *, 4> ScalarOpsFromCasts;
12195+
// Keep track {Scalar, Index, User} tuple.
12196+
// On AArch64, this helps in fusing a mov instruction, associated with
12197+
// extractelement, with fmul in the backend so that extractelement is free.
12198+
SmallVector<std::tuple<Value *, User *, int>, 4> ScalarUserAndIdx;
12199+
for (ExternalUser &EU : ExternalUses) {
12200+
ScalarUserAndIdx.emplace_back(EU.Scalar, EU.User, EU.Lane);
12201+
}
1219512202
for (ExternalUser &EU : ExternalUses) {
1219612203
// Uses by ephemeral values are free (because the ephemeral value will be
1219712204
// removed prior to code generation, and so the extraction will be
@@ -12304,8 +12311,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
1230412311
ExtraCost = TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(),
1230512312
VecTy, EU.Lane);
1230612313
} else {
12307-
ExtraCost = TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
12308-
CostKind, EU.Lane);
12314+
ExtraCost =
12315+
TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
12316+
EU.Lane, EU.Scalar, ScalarUserAndIdx);
1230912317
}
1231012318
// Leave the scalar instructions as is if they are cheaper than extracts.
1231112319
if (Entry->Idx != 0 || Entry->getOpcode() == Instruction::GetElementPtr ||

0 commit comments

Comments
 (0)