Skip to content

Commit f795d1a

Browse files
[AArch64][LV][SLP] Vectorizers use call cost for vectorized frem (#82488)
getArithmeticInstrCost is used by both LoopVectorizer and SLPVectorizer to compute the cost of frem, which becomes a call cost on AArch64 when TLI has a vector library function. Add tests that do SLP vectorization for code that contains 2x double and 4x float frem instructions.
1 parent 611c62b commit f795d1a

File tree

5 files changed

+79
-20
lines changed

5 files changed

+79
-20
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1247,13 +1247,16 @@ class TargetTransformInfo {
12471247
/// cases or optimizations based on those values.
12481248
/// \p CxtI is the optional original context instruction, if one exists, to
12491249
/// provide even more information.
1250+
/// \p TLibInfo is used to search for platform specific vector library
1251+
/// functions for instructions that might be converted to calls (e.g. frem).
12501252
InstructionCost getArithmeticInstrCost(
12511253
unsigned Opcode, Type *Ty,
12521254
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
12531255
TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
12541256
TTI::OperandValueInfo Opd2Info = {TTI::OK_AnyValue, TTI::OP_None},
12551257
ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
1256-
const Instruction *CxtI = nullptr) const;
1258+
const Instruction *CxtI = nullptr,
1259+
const TargetLibraryInfo *TLibInfo = nullptr) const;
12571260

12581261
/// Returns the cost estimation for alternating opcode pattern that can be
12591262
/// lowered to a single instruction on the target. In X86 this is for the

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "llvm/Analysis/TargetTransformInfo.h"
1010
#include "llvm/Analysis/CFG.h"
1111
#include "llvm/Analysis/LoopIterator.h"
12+
#include "llvm/Analysis/TargetLibraryInfo.h"
1213
#include "llvm/Analysis/TargetTransformInfoImpl.h"
1314
#include "llvm/IR/CFG.h"
1415
#include "llvm/IR/Dominators.h"
@@ -874,7 +875,22 @@ TargetTransformInfo::getOperandInfo(const Value *V) {
874875
InstructionCost TargetTransformInfo::getArithmeticInstrCost(
875876
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
876877
OperandValueInfo Op1Info, OperandValueInfo Op2Info,
877-
ArrayRef<const Value *> Args, const Instruction *CxtI) const {
878+
ArrayRef<const Value *> Args, const Instruction *CxtI,
879+
const TargetLibraryInfo *TLibInfo) const {
880+
881+
// Use call cost for frem intructions that have platform specific vector math
882+
// functions, as those will be replaced with calls later by SelectionDAG or
883+
// ReplaceWithVecLib pass.
884+
if (TLibInfo && Opcode == Instruction::FRem) {
885+
VectorType *VecTy = dyn_cast<VectorType>(Ty);
886+
LibFunc Func;
887+
if (VecTy &&
888+
TLibInfo->getLibFunc(Instruction::FRem, Ty->getScalarType(), Func) &&
889+
TLibInfo->isFunctionVectorizable(TLibInfo->getName(Func),
890+
VecTy->getElementCount()))
891+
return getCallInstrCost(nullptr, VecTy, {VecTy, VecTy}, CostKind);
892+
}
893+
878894
InstructionCost Cost =
879895
TTIImpl->getArithmeticInstrCost(Opcode, Ty, CostKind,
880896
Op1Info, Op2Info,

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6911,25 +6911,10 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
69116911
Op2Info.Kind = TargetTransformInfo::OK_UniformValue;
69126912

69136913
SmallVector<const Value *, 4> Operands(I->operand_values());
6914-
auto InstrCost = TTI.getArithmeticInstrCost(
6914+
return TTI.getArithmeticInstrCost(
69156915
I->getOpcode(), VectorTy, CostKind,
69166916
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
6917-
Op2Info, Operands, I);
6918-
6919-
// Some targets can replace frem with vector library calls.
6920-
InstructionCost VecCallCost = InstructionCost::getInvalid();
6921-
if (I->getOpcode() == Instruction::FRem) {
6922-
LibFunc Func;
6923-
if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
6924-
TLI->isFunctionVectorizable(TLI->getName(Func), VF)) {
6925-
SmallVector<Type *, 4> OpTypes;
6926-
for (auto &Op : I->operands())
6927-
OpTypes.push_back(Op->getType());
6928-
VecCallCost =
6929-
TTI.getCallInstrCost(nullptr, VectorTy, OpTypes, CostKind);
6930-
}
6931-
}
6932-
return std::min(InstrCost, VecCallCost);
6917+
Op2Info, Operands, I, TLI);
69336918
}
69346919
case Instruction::FNeg: {
69356920
return TTI.getArithmeticInstrCost(

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8902,7 +8902,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
89028902
TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
89038903
TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
89048904
return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
8905-
Op2Info) +
8905+
Op2Info, std::nullopt, nullptr, TLI) +
89068906
CommonCost;
89078907
};
89088908
return GetCostDiff(GetScalarCost, GetVectorCost);
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt < %s -S -mtriple=aarch64 -vector-library=ArmPL -passes=slp-vectorizer | FileCheck %s
3+
4+
@a = common global ptr null, align 8
5+
6+
define void @frem_v2double() {
7+
; CHECK-LABEL: define void @frem_v2double() {
8+
; CHECK-NEXT: entry:
9+
; CHECK-NEXT: [[TMP0:%.*]] = load <2 x double>, ptr @a, align 8
10+
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x double>, ptr @a, align 8
11+
; CHECK-NEXT: [[TMP2:%.*]] = frem <2 x double> [[TMP0]], [[TMP1]]
12+
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr @a, align 8
13+
; CHECK-NEXT: ret void
14+
;
15+
entry:
16+
%a0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
17+
%a1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
18+
%b0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
19+
%b1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
20+
%r0 = frem double %a0, %b0
21+
%r1 = frem double %a1, %b1
22+
store double %r0, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
23+
store double %r1, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
24+
ret void
25+
}
26+
27+
define void @frem_v4float() {
28+
; CHECK-LABEL: define void @frem_v4float() {
29+
; CHECK-NEXT: entry:
30+
; CHECK-NEXT: [[TMP0:%.*]] = load <4 x float>, ptr @a, align 8
31+
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x float>, ptr @a, align 8
32+
; CHECK-NEXT: [[TMP2:%.*]] = frem <4 x float> [[TMP0]], [[TMP1]]
33+
; CHECK-NEXT: store <4 x float> [[TMP2]], ptr @a, align 8
34+
; CHECK-NEXT: ret void
35+
;
36+
entry:
37+
%a0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
38+
%a1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
39+
%a2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
40+
%a3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
41+
%b0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
42+
%b1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
43+
%b2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
44+
%b3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
45+
%r0 = frem float %a0, %b0
46+
%r1 = frem float %a1, %b1
47+
%r2 = frem float %a2, %b2
48+
%r3 = frem float %a3, %b3
49+
store float %r0, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
50+
store float %r1, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
51+
store float %r2, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
52+
store float %r3, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
53+
ret void
54+
}
55+

0 commit comments

Comments
 (0)