Skip to content

Commit ec47ab2

Browse files
committed
[AArch64] Fix llvm#94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2)
This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: llvm#91924. This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, #N)`, i.e. a fixed-point conversion with N fractional bits (for example, `C = 0.0625 = 1/2^4` gives `scvtf ..., #4`). This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
1 parent c19a3cb commit ec47ab2

File tree

3 files changed

+301
-0
lines changed

3 files changed

+301
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3907,6 +3907,62 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
39073907
unsigned RegWidth,
39083908
bool isReciprocal) {
39093909
APFloat FVal(0.0);
3910+
3911+
if (N.getOpcode() == ISD::BUILD_VECTOR) {
3912+
EVT VT = N.getValueType();
3913+
EVT EltVT = VT.getVectorElementType();
3914+
3915+
unsigned NumElts = N.getNumOperands();
3916+
SDValue FirstOp = N.getOperand(0);
3917+
3918+
ConstantFPSDNode *FirstCN = dyn_cast<ConstantFPSDNode>(FirstOp);
3919+
if (!FirstCN)
3920+
return false;
3921+
3922+
APFloat FirstVal = FirstCN->getValueAPF();
3923+
if (EltVT == MVT::f16) {
3924+
bool ignored;
3925+
FirstVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
3926+
}
3927+
3928+
// Handle reciprocal case if needed
3929+
if (isReciprocal) {
3930+
if (!FirstVal.getExactInverse(&FirstVal))
3931+
return false;
3932+
}
3933+
3934+
bool IsExact;
3935+
APSInt IntVal(65, true);
3936+
FirstVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
3937+
3938+
if (!IsExact || !IntVal.isPowerOf2())
3939+
return false;
3940+
3941+
unsigned FBits = IntVal.logBase2();
3942+
if (FBits == 0 || FBits > RegWidth)
3943+
return false;
3944+
3945+
APInt FirstBits = FirstVal.bitcastToAPInt();
3946+
3947+
for (unsigned i = 1; i < NumElts; ++i) {
3948+
ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N.getOperand(i));
3949+
if (!CN)
3950+
return false;
3951+
3952+
APFloat ElemVal = CN->getValueAPF();
3953+
if (EltVT == MVT::f16) {
3954+
bool ignored;
3955+
ElemVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
3956+
}
3957+
3958+
if (ElemVal.bitcastToAPInt() != FirstBits)
3959+
return false;
3960+
}
3961+
3962+
FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
3963+
return true;
3964+
}
3965+
39103966
if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
39113967
FVal = CN->getValueAPF();
39123968
else if (LoadSDNode *LN = dyn_cast<LoadSDNode>(N)) {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
11481148
setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
11491149
ISD::FP_TO_UINT_SAT, ISD::FADD});
11501150

1151+
// Try to fmul -> scvtf for powers of 2
1152+
setTargetDAGCombine(ISD::FMUL);
1153+
11511154
// Try and combine setcc with csel
11521155
setTargetDAGCombine(ISD::SETCC);
11531156

@@ -19250,6 +19253,153 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
1925019253
return FixConv;
1925119254
}
1925219255

19256+
/// Try to extract a log2 exponent from a uniform constant FP splat.
19257+
/// Returns -1 if the value is not a power-of-two float.
19258+
static int getUniformFPSplatLog2(const BuildVectorSDNode *BV,
19259+
unsigned MaxExponent) {
19260+
SDValue FirstElt = BV->getOperand(0);
19261+
if (!isa<ConstantFPSDNode>(FirstElt))
19262+
return -1;
19263+
19264+
const ConstantFPSDNode *FirstConst = cast<ConstantFPSDNode>(FirstElt);
19265+
const APFloat &FirstVal = FirstConst->getValueAPF();
19266+
const fltSemantics &Sem = FirstVal.getSemantics();
19267+
19268+
// Check all elements are the same
19269+
for (unsigned i = 1, e = BV->getNumOperands(); i != e; ++i) {
19270+
SDValue Elt = BV->getOperand(i);
19271+
if (!isa<ConstantFPSDNode>(Elt))
19272+
return -1;
19273+
const APFloat &Val = cast<ConstantFPSDNode>(Elt)->getValueAPF();
19274+
if (!Val.bitwiseIsEqual(FirstVal))
19275+
return -1;
19276+
}
19277+
19278+
// Reject zero, NaN, or negative values
19279+
if (FirstVal.isZero() || FirstVal.isNaN() || FirstVal.isNegative())
19280+
return -1;
19281+
19282+
// Get raw bits
19283+
APInt Bits = FirstVal.bitcastToAPInt();
19284+
19285+
int ExponentBias = 0;
19286+
unsigned ExponentBits = 0;
19287+
unsigned MantissaBits = 0;
19288+
19289+
if (&Sem == &APFloat::IEEEsingle()) {
19290+
ExponentBias = 127;
19291+
ExponentBits = 8;
19292+
MantissaBits = 23;
19293+
} else if (&Sem == &APFloat::IEEEdouble()) {
19294+
ExponentBias = 1023;
19295+
ExponentBits = 11;
19296+
MantissaBits = 52;
19297+
} else {
19298+
// Unsupported type
19299+
return -1;
19300+
}
19301+
19302+
// Mask out mantissa and check it's zero (i.e., power of two)
19303+
APInt MantissaMask = APInt::getLowBitsSet(Bits.getBitWidth(), MantissaBits);
19304+
if ((Bits & MantissaMask) != 0)
19305+
return -1;
19306+
19307+
// Extract exponent
19308+
unsigned ExponentShift = MantissaBits;
19309+
APInt ExponentMask = APInt::getBitsSet(Bits.getBitWidth(), ExponentShift,
19310+
ExponentShift + ExponentBits);
19311+
int Exponent = (Bits & ExponentMask).lshr(ExponentShift).getZExtValue();
19312+
int Log2 = ExponentBias - Exponent;
19313+
19314+
if (static_cast<unsigned>(Log2) > MaxExponent)
19315+
return -1;
19316+
19317+
return Log2;
19318+
}
19319+
19320+
/// Fold a floating-point multiply by power of two into fixed-point to
19321+
/// floating-point conversion.
19322+
static SDValue performFMulCombine(SDNode *N, SelectionDAG &DAG,
19323+
TargetLowering::DAGCombinerInfo &DCI,
19324+
const AArch64Subtarget *Subtarget) {
19325+
19326+
if (!Subtarget->hasNEON())
19327+
return SDValue();
19328+
19329+
// N is the FMUL node.
19330+
if (N->getOpcode() != ISD::FMUL)
19331+
return SDValue();
19332+
19333+
// SINT_TO_FP or UINT_TO_FP
19334+
SDValue Op = N->getOperand(0);
19335+
unsigned Opc = Op->getOpcode();
19336+
if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
19337+
!Op.getOperand(0).getValueType().isSimple() ||
19338+
(Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP))
19339+
return SDValue();
19340+
19341+
SDValue ConstVec = N->getOperand(1);
19342+
if (!isa<BuildVectorSDNode>(ConstVec))
19343+
return SDValue();
19344+
19345+
MVT IntTy = Op.getOperand(0).getSimpleValueType().getVectorElementType();
19346+
int32_t IntBits = IntTy.getSizeInBits();
19347+
if (IntBits != 16 && IntBits != 32 && IntBits != 64)
19348+
return SDValue();
19349+
19350+
MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
19351+
int32_t FloatBits = FloatTy.getSizeInBits();
19352+
if (FloatBits != 32 && FloatBits != 64)
19353+
return SDValue();
19354+
19355+
if (IntBits > FloatBits)
19356+
return SDValue();
19357+
19358+
BitVector UndefElements;
19359+
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
19360+
int32_t IntrinsicC = getUniformFPSplatLog2(BV, FloatBits + 1);
19361+
19362+
// Handle cases where it's not a power of two, or is 2^0.
19363+
if (IntrinsicC == -1 || IntrinsicC == 0)
19364+
return SDValue();
19365+
19366+
// Check if IntrinsicC is within the valid range [1, FloatBits].
19367+
// The 's' value must be in [1, FloatBits].
19368+
if (IntrinsicC <= 0 || IntrinsicC > FloatBits)
19369+
return SDValue();
19370+
19371+
MVT ResTy;
19372+
unsigned NumLanes = Op.getValueType().getVectorNumElements();
19373+
switch (NumLanes) {
19374+
default:
19375+
return SDValue();
19376+
case 2:
19377+
ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
19378+
break;
19379+
case 4:
19380+
ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
19381+
break;
19382+
}
19383+
19384+
if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
19385+
return SDValue();
19386+
19387+
SDLoc DL(N);
19388+
SDValue ConvInput = Op.getOperand(0);
19389+
bool IsSigned = Opc == ISD::SINT_TO_FP;
19390+
19391+
if (IntBits < FloatBits)
19392+
ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
19393+
ResTy, ConvInput);
19394+
19395+
unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
19396+
: Intrinsic::aarch64_neon_vcvtfxu2fp;
19397+
19398+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
19399+
DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
19400+
DAG.getConstant(IntrinsicC, DL, MVT::i32));
19401+
}
19402+
1925319403
static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1925419404
const AArch64TargetLowering &TLI) {
1925519405
EVT VT = N->getValueType(0);
@@ -26693,6 +26843,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2669326843
case ISD::FP_TO_SINT_SAT:
2669426844
case ISD::FP_TO_UINT_SAT:
2669526845
return performFpToIntCombine(N, DAG, DCI, Subtarget);
26846+
case ISD::FMUL:
26847+
return performFMulCombine(N, DAG, DCI, Subtarget);
2669626848
case ISD::OR:
2669726849
return performORCombine(N, DCI, Subtarget, *this);
2669826850
case ISD::AND:
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+fullfp16 -aarch64-neon-syntax=apple -verify-machineinstrs -o - %s | FileCheck %s
2+
3+
; Scalar fdiv by 16.0 (f32)
; Input: sitofp i32 -> float, then a divide by 16.0 (2^4).
; Expected output: a single fixed-point `scvtf s0, w0, #4` immediately followed
; by `ret`, i.e. no separate divide or multiply instruction.
4+
define float @tests_f32_div(i32 %in) {
5+
; CHECK-LABEL: tests_f32_div:
6+
; CHECK: // %bb.0: // %entry
7+
; CHECK-NEXT: scvtf s0, w0, #4
8+
; CHECK-NEXT: ret
9+
entry:
10+
%vcvt.i = sitofp i32 %in to float
11+
%div.i = fdiv float %vcvt.i, 16.0
12+
ret float %div.i
13+
}
14+
15+
; Scalar fmul by (2^-4) (f32)
; Input: sitofp i32 -> float, then a multiply by 0.0625 (2^-4).
; Expected output: a single fixed-point `scvtf s0, w0, #4` immediately followed
; by `ret`, the same code as the equivalent divide by 16.0.
; NOTE(review): attribute group #0 is referenced but its definition is not
; visible in this view of the file - confirm it is intended.
16+
define float @testsmul_f32_mul(i32 %in) local_unnamed_addr #0 {
17+
; CHECK-LABEL: testsmul_f32_mul:
18+
; CHECK: // %bb.0:
19+
; CHECK-NEXT: scvtf s0, w0, #4
20+
; CHECK-NEXT: ret
21+
%vcvt.i = sitofp i32 %in to float
22+
%div.i = fmul float %vcvt.i, 6.250000e-02 ; 0.0625 is 2^-4
23+
ret float %div.i
24+
}
25+
26+
; Vector fdiv by 16.0 (v2f32)
; Input: sitofp <2 x i32> -> <2 x float>, then a divide by a uniform vector of
; 16.0 (2^4).
; Expected output: a single vector fixed-point `scvtf.2s v0, v0, #4`
; immediately followed by `ret`.
27+
define <2 x float> @testv_v2f32_div(<2 x i32> %in) {
28+
; CHECK-LABEL: testv_v2f32_div:
29+
; CHECK: // %bb.0: // %entry
30+
; CHECK-NEXT: scvtf.2s v0, v0, #4
31+
; CHECK-NEXT: ret
32+
entry:
33+
%vcvt.i = sitofp <2 x i32> %in to <2 x float>
34+
%div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
35+
ret <2 x float> %div.i
36+
}
37+
38+
; Vector fmul by 2^-4 (v2f32)
; Input: sitofp <2 x i32> -> <2 x float>, then a multiply by a splat of 0.0625
; (2^-4).
; Expected output: a single vector fixed-point `scvtf.2s v0, v0, #4`
; immediately followed by `ret`.
; NOTE(review): attribute group #0 is referenced but its definition is not
; visible in this view of the file - confirm it is intended.
39+
define <2 x float> @testvmul_v2f32_mul(<2 x i32> %in) local_unnamed_addr #0 {
40+
; CHECK-LABEL: testvmul_v2f32_mul:
41+
; CHECK: // %bb.0:
42+
; CHECK-NEXT: scvtf.2s v0, v0, #4
43+
; CHECK-NEXT: ret
44+
%vcvt.i = sitofp <2 x i32> %in to <2 x float>
45+
%div.i = fmul <2 x float> %vcvt.i, splat (float 6.250000e-02) ; 0.0625 is 2^-4
46+
ret <2 x float> %div.i
47+
}
48+
49+
; Scalar fdiv by 16.0 (f64)
; Input: sitofp i64 -> double, then a divide by 16.0 (2^4).
; Expected output: a single fixed-point `scvtf d0, x0, #4` immediately followed
; by `ret`.
50+
define double @tests_f64_div(i64 %in) {
51+
; CHECK-LABEL: tests_f64_div:
52+
; CHECK: // %bb.0: // %entry
53+
; CHECK-NEXT: scvtf d0, x0, #4
54+
; CHECK-NEXT: ret
55+
entry:
56+
%vcvt.i = sitofp i64 %in to double
57+
%div.i = fdiv double %vcvt.i, 1.600000e+01 ; 16.0 in double-precision
58+
ret double %div.i
59+
}
60+
61+
; Scalar fmul by (2^-4) (f64)
; Input: sitofp i64 -> double, then a multiply by 0.0625 (2^-4).
; Expected output: a single fixed-point `scvtf d0, x0, #4` immediately followed
; by `ret`, the same code as the equivalent divide by 16.0.
; NOTE(review): attribute group #0 is referenced but its definition is not
; visible in this view of the file - confirm it is intended.
62+
define double @testsmul_f64_mul(i64 %in) local_unnamed_addr #0 {
63+
; CHECK-LABEL: testsmul_f64_mul:
64+
; CHECK: // %bb.0:
65+
; CHECK-NEXT: scvtf d0, x0, #4
66+
; CHECK-NEXT: ret
67+
%vcvt.i = sitofp i64 %in to double
68+
%div.i = fmul double %vcvt.i, 6.250000e-02 ; 0.0625 is 2^-4 in double-precision
69+
ret double %div.i
70+
}
71+
72+
; Vector fdiv by 16.0 (v2f64)
; Input: sitofp <2 x i64> -> <2 x double>, then a divide by a uniform vector of
; 16.0 (2^4).
; Expected output: a single vector fixed-point `scvtf.2d v0, v0, #4`
; immediately followed by `ret`.
73+
define <2 x double> @testv_v2f64_div(<2 x i64> %in) {
74+
; CHECK-LABEL: testv_v2f64_div:
75+
; CHECK: // %bb.0: // %entry
76+
; CHECK-NEXT: scvtf.2d v0, v0, #4
77+
; CHECK-NEXT: ret
78+
entry:
79+
%vcvt.i = sitofp <2 x i64> %in to <2 x double>
80+
%div.i = fdiv <2 x double> %vcvt.i, <double 1.600000e+01, double 1.600000e+01>
81+
ret <2 x double> %div.i
82+
}
83+
84+
; Vector fmul by 2^-4 (v2f64)
; Input: sitofp <2 x i64> -> <2 x double>, then a multiply by a splat of 0.0625
; (2^-4).
; Expected output: a single vector fixed-point `scvtf.2d v0, v0, #4`
; immediately followed by `ret`.
; NOTE(review): attribute group #0 is referenced but its definition is not
; visible in this view of the file - confirm it is intended.
85+
define <2 x double> @testvmul_v2f64_mul(<2 x i64> %in) local_unnamed_addr #0 {
86+
; CHECK-LABEL: testvmul_v2f64_mul:
87+
; CHECK: // %bb.0:
88+
; CHECK-NEXT: scvtf.2d v0, v0, #4
89+
; CHECK-NEXT: ret
90+
%vcvt.i = sitofp <2 x i64> %in to <2 x double>
91+
%div.i = fmul <2 x double> %vcvt.i, splat (double 6.250000e-02) ; 0.0625 is 2^-4 in double-precision
92+
ret <2 x double> %div.i
93+
}

0 commit comments

Comments
 (0)