Skip to content

Commit 2096ae8

Browse files
committed
[AArch64] Fix llvm#94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2)
This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: llvm#91924 This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, 2^N)`. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
1 parent d57b867 commit 2096ae8

File tree

3 files changed

+285
-0
lines changed

3 files changed

+285
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,14 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
487487
bool SelectCVTFixedPosRecipOperand(SDValue N, SDValue &FixedPos,
488488
unsigned Width);
489489

490+
template <unsigned FloatWidth>
491+
bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos) {
492+
return SelectCVTFixedPosRecipOperandVec(N, FixedPos, FloatWidth);
493+
}
494+
495+
bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos,
496+
unsigned Width);
497+
490498
bool SelectCMP_SWAP(SDNode *N);
491499

492500
bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
@@ -3952,6 +3960,129 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
39523960
return true;
39533961
}
39543962

3963+
static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
3964+
SDValue N,
3965+
SDValue &FixedPos,
3966+
unsigned FloatWidth,
3967+
bool isReciprocal) {
3968+
3969+
// N must be a bitcast/nvcast of a vector float type.
3970+
if (!((N.getOpcode() == ISD::BITCAST ||
3971+
N.getOpcode() == AArch64ISD::NVCAST) &&
3972+
N.getValueType().isVector() && N.getValueType().isFloatingPoint())) {
3973+
return false;
3974+
}
3975+
3976+
if (N.getNumOperands() == 0)
3977+
return false;
3978+
SDValue ImmediateNode = N.getOperand(0);
3979+
3980+
bool isSplatConfirmed = false;
3981+
3982+
if (ImmediateNode.getOpcode() == AArch64ISD::DUP ||
3983+
ImmediateNode.getOpcode() == ISD::SPLAT_VECTOR) {
3984+
// These opcodes inherently mean a splat.
3985+
isSplatConfirmed = true;
3986+
} else if (ImmediateNode.getOpcode() == ISD::BUILD_VECTOR) {
3987+
// For BUILD_VECTOR, we must explicitly check if it's a constant splat.
3988+
BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(ImmediateNode.getNode());
3989+
APInt SplatValue;
3990+
APInt SplatUndef;
3991+
unsigned SplatBitSize;
3992+
bool HasAnyUndefs;
3993+
if (BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
3994+
HasAnyUndefs)) {
3995+
isSplatConfirmed = true;
3996+
} else {
3997+
return false;
3998+
}
3999+
} else if (ImmediateNode.getOpcode() == AArch64ISD::MOVIshift) {
4000+
// This implies that the DAG structure was (DUP (MOVIshift C)) or
4001+
// (BUILD_VECTOR (MOVIshift C)).
4002+
isSplatConfirmed = true;
4003+
} else {
4004+
return false;
4005+
}
4006+
4007+
// If we reached here, isSplatConfirmed should be true and ScalarSourceNode
4008+
// should be set. But just in case ...
4009+
if (!isSplatConfirmed)
4010+
return false;
4011+
4012+
// --- Extract the actual constant value ---
4013+
auto ScalarSourceNode = ImmediateNode.getOperand(0);
4014+
APFloat FVal(0.0);
4015+
if (auto *CFP = dyn_cast<ConstantFPSDNode>(ScalarSourceNode)) {
4016+
// Scalar source is a floating-point constant.
4017+
FVal = CFP->getValueAPF();
4018+
} else if (auto *CI = dyn_cast<ConstantSDNode>(ScalarSourceNode)) {
4019+
// Scalar source is an integer constant; interpret its bits as
4020+
// floating-point.
4021+
EVT FloatEltVT = N.getValueType().getVectorElementType();
4022+
4023+
if (FloatEltVT == MVT::f32) {
4024+
FVal = APFloat(APFloat::IEEEsingle(), CI->getAPIntValue());
4025+
} else if (FloatEltVT == MVT::f64) {
4026+
FVal = APFloat(APFloat::IEEEdouble(), CI->getAPIntValue());
4027+
} else if (FloatEltVT == MVT::f16) {
4028+
auto *ShiftAmountConst =
4029+
dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(1));
4030+
4031+
if (!ShiftAmountConst) {
4032+
return false;
4033+
}
4034+
APInt ImmediateVal = CI->getAPIntValue();
4035+
unsigned ShiftAmount = ShiftAmountConst->getAPIntValue().getZExtValue();
4036+
APInt EffectiveBits = ImmediateVal.trunc(16).shl(ShiftAmount);
4037+
FVal = APFloat(APFloat::IEEEhalf(), EffectiveBits);
4038+
} else {
4039+
// Unsupported floating-point element type.
4040+
return false;
4041+
}
4042+
} else {
4043+
// ScalarSourceNode is not a recognized constant type.
4044+
return false;
4045+
}
4046+
4047+
// --- Perform fixed-point reciprocal check and power-of-2 validation on FVal
4048+
// --- Normalize f16 to f32 if needed for consistent APFloat operations.
4049+
if (N.getValueType().getVectorElementType() == MVT::f16) {
4050+
bool ignored;
4051+
FVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
4052+
}
4053+
4054+
// Handle reciprocal case.
4055+
if (isReciprocal) {
4056+
if (!FVal.getExactInverse(&FVal))
4057+
// Not an exact reciprocal, or reciprocal not a power of 2.
4058+
return false;
4059+
}
4060+
4061+
bool IsExact;
4062+
unsigned TargetIntBits =
4063+
N.getValueType().getVectorElementType().getSizeInBits();
4064+
APSInt IntVal(
4065+
TargetIntBits + 1,
4066+
true); // Use TargetIntBits + 1 for sufficient bits for conversion
4067+
4068+
FVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
4069+
4070+
if (!IsExact || !IntVal.isPowerOf2())
4071+
return false;
4072+
4073+
unsigned FBits = IntVal.logBase2();
4074+
// FBits must be non-zero (implies actual scaling) and within the range
4075+
// supported by the instruction (typically 1 to 64 for AArch64 FCVTZS/FCVTZU).
4076+
// FloatWidth should ideally be the width of the *integer elements* in the
4077+
// vector (16, 32, 64).
4078+
if (FBits == 0 || FBits > FloatWidth)
4079+
return false;
4080+
4081+
// Set FixedPos to the extracted FBits as an i32 constant SDValue.
4082+
FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
4083+
return true;
4084+
}
4085+
39554086
bool AArch64DAGToDAGISel::SelectCVTFixedPosOperand(SDValue N, SDValue &FixedPos,
39564087
unsigned RegWidth) {
39574088
return checkCVTFixedPointOperandWithFBits(CurDAG, N, FixedPos, RegWidth,
@@ -3965,6 +4096,12 @@ bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperand(SDValue N,
39654096
true);
39664097
}
39674098

4099+
bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperandVec(
4100+
SDValue N, SDValue &FixedPos, unsigned FloatWidth) {
4101+
return checkCVTFixedPointOperandWithFBitsForVectors(CurDAG, N, FixedPos,
4102+
FloatWidth, true);
4103+
}
4104+
39684105
// Inspects a register string of the form o0:op1:CRn:CRm:op2 gets the fields
39694106
// of the string and obtains the integer values from them and combines these
39704107
// into a single value to be used in the MRS/MSR instruction.

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8473,6 +8473,58 @@ def : Pat<(v8f16 (sint_to_fp (v8i16 (AArch64vashr_exact v8i16:$Vn, i32:$shift)))
84738473
(SCVTFv8i16_shift $Vn, vecshiftR16:$shift)>;
84748474
}
84758475

8476+
// Select fmul(sitofp(x), C) where C is a constant reciprocal of a power of two.
8477+
// For both scalar and vector inputs, if we have sitofp(X) * C (where C is
8478+
// 1/2^N), this can be optimized to scvtf(X, 2^N).
8479+
class fixedpoint_recip_vec_i16<ValueType FloatVT>
8480+
: ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<16>", []>;
8481+
class fixedpoint_recip_vec_i32<ValueType FloatVT>
8482+
: ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<32>", []>;
8483+
class fixedpoint_recip_vec_i64<ValueType FloatVT>
8484+
: ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<64>", []>;
8485+
def fixedpoint_recip_vec_xform : SDNodeXForm<timm, [{
8486+
// Suppress the unused variable warning by explicitly using N.
8487+
// The actual value needed for the pattern is already in V.
8488+
(void)N;
8489+
return V;
8490+
}]>;
8491+
8492+
def fixedpoint_recip_v2f32_v2i32 : fixedpoint_recip_vec_i32<v2f32>;
8493+
def fixedpoint_recip_v4f32_v4i32 : fixedpoint_recip_vec_i32<v4f32>;
8494+
def fixedpoint_recip_v2f64_v2i64 : fixedpoint_recip_vec_i64<v2f64>;
8495+
8496+
def fixedpoint_recip_v4f16_v4i16 : fixedpoint_recip_vec_i16<v4f16>;
8497+
def fixedpoint_recip_v8f16_v8i16 : fixedpoint_recip_vec_i16<v8f16>;
8498+
8499+
let Predicates = [HasNEON] in {
8500+
def : Pat<(v2f32(fmul(sint_to_fp(v2i32 V64:$Rn)),
8501+
fixedpoint_recip_v2f32_v2i32:$scale)),
8502+
(v2f32(SCVTFv2i32_shift(v2i32 V64:$Rn),
8503+
(fixedpoint_recip_vec_xform fixedpoint_recip_v2f32_v2i32:$scale)))>;
8504+
8505+
def : Pat<(v4f32(fmul(sint_to_fp(v4i32 FPR128:$Rn)),
8506+
fixedpoint_recip_v4f32_v4i32:$scale)),
8507+
(v4f32(SCVTFv4i32_shift(v4i32 FPR128:$Rn),
8508+
(fixedpoint_recip_vec_xform fixedpoint_recip_v4f32_v4i32:$scale)))>;
8509+
8510+
def : Pat<(v2f64(fmul(sint_to_fp(v2i64 FPR128:$Rn)),
8511+
fixedpoint_recip_v2f64_v2i64:$scale)),
8512+
(v2f64(SCVTFv2i64_shift(v2i64 FPR128:$Rn),
8513+
(fixedpoint_recip_vec_xform fixedpoint_recip_v2f64_v2i64:$scale)))>;
8514+
}
8515+
8516+
let Predicates = [HasNEON, HasFullFP16] in {
8517+
def : Pat<(v4f16(fmul(sint_to_fp(v4i16 V64:$Rn)),
8518+
fixedpoint_recip_v4f16_v4i16:$scale)),
8519+
(v4f16(SCVTFv4i16_shift(v4i16 V64:$Rn),
8520+
(fixedpoint_recip_vec_xform fixedpoint_recip_v4f16_v4i16:$scale)))>;
8521+
8522+
def : Pat<(v8f16(fmul(sint_to_fp(v8i16 FPR128:$Rn)),
8523+
fixedpoint_recip_v8f16_v8i16:$scale)),
8524+
(v8f16(SCVTFv8i16_shift(v8i16 FPR128:$Rn),
8525+
(fixedpoint_recip_vec_xform fixedpoint_recip_v8f16_v8i16:$scale)))>;
8526+
}
8527+
84768528
// X << 1 ==> X + X
84778529
class SHLToADDPat<ValueType ty, RegisterClass regtype>
84788530
: Pat<(ty (AArch64vshl (ty regtype:$Rn), (i32 1))),
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -mattr=+fullfp16 -o - %s | FileCheck %s
2+
3+
; This test file verifies that fdiv(sitofp(x), C)
4+
; where C is a constant power of two,
5+
; is optimized to scvtf(X, shift_amount).
6+
; This typically involves an implicit fdiv -> fmul_reciprocal transformation.
7+
8+
; --- Scalar Tests ---
9+
10+
; Scalar f32 (from i32)
11+
define float @test_f32_div_const(i32 %in) {
12+
; CHECK-LABEL: test_f32_div_const:
13+
; CHECK: // %bb.0: // %entry
14+
; CHECK-NEXT: scvtf s0, w0, #4
15+
; CHECK-NEXT: ret
16+
entry:
17+
%vcvt.i = sitofp i32 %in to float
18+
%div.i = fdiv float %vcvt.i, 16.0
19+
ret float %div.i
20+
}
21+
22+
; Scalar f64 (from i64)
23+
define double @test_f64_div_const(i64 %in) {
24+
; CHECK-LABEL: test_f64_div_const:
25+
; CHECK: // %bb.0: // %entry
26+
; CHECK-NEXT: scvtf d0, x0, #4
27+
; CHECK-NEXT: ret
28+
entry:
29+
%vcvt.i = sitofp i64 %in to double
30+
%div.i = fdiv double %vcvt.i, 16.0
31+
ret double %div.i
32+
}
33+
34+
; --- Vector Tests ---
35+
36+
; Vector v2f32 (from v2i32)
37+
define <2 x float> @testv_v2f32_div_const(<2 x i32> %in) {
38+
; CHECK-LABEL: testv_v2f32_div_const:
39+
; CHECK: // %bb.0: // %entry
40+
; CHECK-NEXT: scvtf.2s v0, v0, #4
41+
; CHECK-NEXT: ret
42+
entry:
43+
%vcvt.i = sitofp <2 x i32> %in to <2 x float>
44+
%div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
45+
ret <2 x float> %div.i
46+
}
47+
48+
; Vector v4f32 (from v4i32)
49+
define <4 x float> @testv_v4f32_div_const(<4 x i32> %in) {
50+
; CHECK-LABEL: testv_v4f32_div_const:
51+
; CHECK: // %bb.0: // %entry
52+
; CHECK-NEXT: scvtf.4s v0, v0, #4
53+
; CHECK-NEXT: ret
54+
entry:
55+
%vcvt.i = sitofp <4 x i32> %in to <4 x float>
56+
%div.i = fdiv <4 x float> %vcvt.i, <float 16.0, float 16.0, float 16.0, float 16.0>
57+
ret <4 x float> %div.i
58+
}
59+
60+
; Vector v2f64 (from v2i64)
61+
define <2 x double> @testv_v2f64_div_const(<2 x i64> %in) {
62+
; CHECK-LABEL: testv_v2f64_div_const:
63+
; CHECK: // %bb.0: // %entry
64+
; CHECK-NEXT: scvtf.2d v0, v0, #4
65+
; CHECK-NEXT: ret
66+
entry:
67+
%vcvt.i = sitofp <2 x i64> %in to <2 x double>
68+
%div.i = fdiv <2 x double> %vcvt.i, <double 16.0, double 16.0>
69+
ret <2 x double> %div.i
70+
}
71+
72+
; --- f16 Tests (assuming fullfp16 is enabled) ---
73+
74+
; Vector v4f16 (from v4i16)
75+
define <4 x half> @testv_v4f16_div_const(<4 x i16> %in) {
76+
; CHECK-LABEL: testv_v4f16_div_const:
77+
; CHECK: // %bb.0: // %entry
78+
; CHECK-NEXT: scvtf.4h v0, v0, #4
79+
; CHECK-NEXT: ret
80+
entry:
81+
%vcvt.i = sitofp <4 x i16> %in to <4 x half>
82+
%div.i = fdiv <4 x half> %vcvt.i, <half 16.0, half 16.0, half 16.0, half 16.0> ; 16.0 in half-precision
83+
ret <4 x half> %div.i
84+
}
85+
86+
; Vector v8f16 (from v8i16)
87+
define <8 x half> @testv_v8f16_div_const(<8 x i16> %in) {
88+
; CHECK-LABEL: testv_v8f16_div_const:
89+
; CHECK: // %bb.0: // %entry
90+
; CHECK-NEXT: scvtf.8h v0, v0, #4
91+
; CHECK-NEXT: ret
92+
entry:
93+
%vcvt.i = sitofp <8 x i16> %in to <8 x half>
94+
%div.i = fdiv <8 x half> %vcvt.i, <half 16.0, half 16.0, half 16.0, half 16.0, half 16.0, half 16.0, half 16.0, half 16.0> ; 16.0 in half-precision
95+
ret <8 x half> %div.i
96+
}

0 commit comments

Comments
 (0)