@@ -1148,6 +1148,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1148
1148
setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
1149
1149
ISD::FP_TO_UINT_SAT, ISD::FADD});
1150
1150
1151
+ // Try to fmul -> scvtf for powers of 2
1152
+ setTargetDAGCombine(ISD::FMUL);
1153
+
1151
1154
// Try and combine setcc with csel
1152
1155
setTargetDAGCombine(ISD::SETCC);
1153
1156
@@ -19250,6 +19253,153 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
19250
19253
return FixConv;
19251
19254
}
19252
19255
19256
+ /// Try to extract a log2 exponent from a uniform constant FP splat.
19257
+ /// Returns -1 if the value is not a power-of-two float.
19258
+ static int getUniformFPSplatLog2(const BuildVectorSDNode *BV,
19259
+ unsigned MaxExponent) {
19260
+ SDValue FirstElt = BV->getOperand(0);
19261
+ if (!isa<ConstantFPSDNode>(FirstElt))
19262
+ return -1;
19263
+
19264
+ const ConstantFPSDNode *FirstConst = cast<ConstantFPSDNode>(FirstElt);
19265
+ const APFloat &FirstVal = FirstConst->getValueAPF();
19266
+ const fltSemantics &Sem = FirstVal.getSemantics();
19267
+
19268
+ // Check all elements are the same
19269
+ for (unsigned i = 1, e = BV->getNumOperands(); i != e; ++i) {
19270
+ SDValue Elt = BV->getOperand(i);
19271
+ if (!isa<ConstantFPSDNode>(Elt))
19272
+ return -1;
19273
+ const APFloat &Val = cast<ConstantFPSDNode>(Elt)->getValueAPF();
19274
+ if (!Val.bitwiseIsEqual(FirstVal))
19275
+ return -1;
19276
+ }
19277
+
19278
+ // Reject zero, NaN, or negative values
19279
+ if (FirstVal.isZero() || FirstVal.isNaN() || FirstVal.isNegative())
19280
+ return -1;
19281
+
19282
+ // Get raw bits
19283
+ APInt Bits = FirstVal.bitcastToAPInt();
19284
+
19285
+ int ExponentBias = 0;
19286
+ unsigned ExponentBits = 0;
19287
+ unsigned MantissaBits = 0;
19288
+
19289
+ if (&Sem == &APFloat::IEEEsingle()) {
19290
+ ExponentBias = 127;
19291
+ ExponentBits = 8;
19292
+ MantissaBits = 23;
19293
+ } else if (&Sem == &APFloat::IEEEdouble()) {
19294
+ ExponentBias = 1023;
19295
+ ExponentBits = 11;
19296
+ MantissaBits = 52;
19297
+ } else {
19298
+ // Unsupported type
19299
+ return -1;
19300
+ }
19301
+
19302
+ // Mask out mantissa and check it's zero (i.e., power of two)
19303
+ APInt MantissaMask = APInt::getLowBitsSet(Bits.getBitWidth(), MantissaBits);
19304
+ if ((Bits & MantissaMask) != 0)
19305
+ return -1;
19306
+
19307
+ // Extract exponent
19308
+ unsigned ExponentShift = MantissaBits;
19309
+ APInt ExponentMask = APInt::getBitsSet(Bits.getBitWidth(), ExponentShift,
19310
+ ExponentShift + ExponentBits);
19311
+ int Exponent = (Bits & ExponentMask).lshr(ExponentShift).getZExtValue();
19312
+ int Log2 = ExponentBias - Exponent;
19313
+
19314
+ if (static_cast<unsigned>(Log2) > MaxExponent)
19315
+ return -1;
19316
+
19317
+ return Log2;
19318
+ }
19319
+
19320
+ /// Fold a floating-point multiply by power of two into fixed-point to
19321
+ /// floating-point conversion.
19322
+ static SDValue performFMulCombine(SDNode *N, SelectionDAG &DAG,
19323
+ TargetLowering::DAGCombinerInfo &DCI,
19324
+ const AArch64Subtarget *Subtarget) {
19325
+
19326
+ if (!Subtarget->hasNEON())
19327
+ return SDValue();
19328
+
19329
+ // N is the FMUL node.
19330
+ if (N->getOpcode() != ISD::FMUL)
19331
+ return SDValue();
19332
+
19333
+ // SINT_TO_FP or UINT_TO_FP
19334
+ SDValue Op = N->getOperand(0);
19335
+ unsigned Opc = Op->getOpcode();
19336
+ if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
19337
+ !Op.getOperand(0).getValueType().isSimple() ||
19338
+ (Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP))
19339
+ return SDValue();
19340
+
19341
+ SDValue ConstVec = N->getOperand(1);
19342
+ if (!isa<BuildVectorSDNode>(ConstVec))
19343
+ return SDValue();
19344
+
19345
+ MVT IntTy = Op.getOperand(0).getSimpleValueType().getVectorElementType();
19346
+ int32_t IntBits = IntTy.getSizeInBits();
19347
+ if (IntBits != 16 && IntBits != 32 && IntBits != 64)
19348
+ return SDValue();
19349
+
19350
+ MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
19351
+ int32_t FloatBits = FloatTy.getSizeInBits();
19352
+ if (FloatBits != 32 && FloatBits != 64)
19353
+ return SDValue();
19354
+
19355
+ if (IntBits > FloatBits)
19356
+ return SDValue();
19357
+
19358
+ BitVector UndefElements;
19359
+ BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
19360
+ int32_t IntrinsicC = getUniformFPSplatLog2(BV, FloatBits + 1);
19361
+
19362
+ // Handle cases where it's not a power of two, or is 2^0.
19363
+ if (IntrinsicC == -1 || IntrinsicC == 0)
19364
+ return SDValue();
19365
+
19366
+ // Check if IntrinsicC is within the valid range [1, FloatBits].
19367
+ // The 's' value must be in [1, FloatBits].
19368
+ if (IntrinsicC <= 0 || IntrinsicC > FloatBits)
19369
+ return SDValue();
19370
+
19371
+ MVT ResTy;
19372
+ unsigned NumLanes = Op.getValueType().getVectorNumElements();
19373
+ switch (NumLanes) {
19374
+ default:
19375
+ return SDValue();
19376
+ case 2:
19377
+ ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
19378
+ break;
19379
+ case 4:
19380
+ ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
19381
+ break;
19382
+ }
19383
+
19384
+ if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
19385
+ return SDValue();
19386
+
19387
+ SDLoc DL(N);
19388
+ SDValue ConvInput = Op.getOperand(0);
19389
+ bool IsSigned = Opc == ISD::SINT_TO_FP;
19390
+
19391
+ if (IntBits < FloatBits)
19392
+ ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
19393
+ ResTy, ConvInput);
19394
+
19395
+ unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
19396
+ : Intrinsic::aarch64_neon_vcvtfxu2fp;
19397
+
19398
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
19399
+ DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
19400
+ DAG.getConstant(IntrinsicC, DL, MVT::i32));
19401
+ }
19402
+
19253
19403
static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19254
19404
const AArch64TargetLowering &TLI) {
19255
19405
EVT VT = N->getValueType(0);
@@ -26693,6 +26843,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
26693
26843
case ISD::FP_TO_SINT_SAT:
26694
26844
case ISD::FP_TO_UINT_SAT:
26695
26845
return performFpToIntCombine(N, DAG, DCI, Subtarget);
26846
+ case ISD::FMUL:
26847
+ return performFMulCombine(N, DAG, DCI, Subtarget);
26696
26848
case ISD::OR:
26697
26849
return performORCombine(N, DCI, Subtarget, *this);
26698
26850
case ISD::AND:
0 commit comments