@@ -487,6 +487,14 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
487
487
bool SelectCVTFixedPosRecipOperand (SDValue N, SDValue &FixedPos,
488
488
unsigned Width);
489
489
490
+ template <unsigned FloatWidth>
491
+ bool SelectCVTFixedPosRecipOperandVec (SDValue N, SDValue &FixedPos) {
492
+ return SelectCVTFixedPosRecipOperandVec (N, FixedPos, FloatWidth);
493
+ }
494
+
495
+ bool SelectCVTFixedPosRecipOperandVec (SDValue N, SDValue &FixedPos,
496
+ unsigned Width);
497
+
490
498
bool SelectCMP_SWAP (SDNode *N);
491
499
492
500
bool SelectSVEAddSubImm (SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
@@ -3952,6 +3960,129 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
3952
3960
return true ;
3953
3961
}
3954
3962
3963
+ static bool checkCVTFixedPointOperandWithFBitsForVectors (SelectionDAG *CurDAG,
3964
+ SDValue N,
3965
+ SDValue &FixedPos,
3966
+ unsigned FloatWidth,
3967
+ bool isReciprocal) {
3968
+
3969
+ // N must be a bitcast/nvcast of a vector float type.
3970
+ if (!((N.getOpcode () == ISD::BITCAST ||
3971
+ N.getOpcode () == AArch64ISD::NVCAST) &&
3972
+ N.getValueType ().isVector () && N.getValueType ().isFloatingPoint ())) {
3973
+ return false ;
3974
+ }
3975
+
3976
+ if (N.getNumOperands () == 0 )
3977
+ return false ;
3978
+ SDValue ImmediateNode = N.getOperand (0 );
3979
+
3980
+ bool isSplatConfirmed = false ;
3981
+
3982
+ if (ImmediateNode.getOpcode () == AArch64ISD::DUP ||
3983
+ ImmediateNode.getOpcode () == ISD::SPLAT_VECTOR) {
3984
+ // These opcodes inherently mean a splat.
3985
+ isSplatConfirmed = true ;
3986
+ } else if (ImmediateNode.getOpcode () == ISD::BUILD_VECTOR) {
3987
+ // For BUILD_VECTOR, we must explicitly check if it's a constant splat.
3988
+ BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(ImmediateNode.getNode ());
3989
+ APInt SplatValue;
3990
+ APInt SplatUndef;
3991
+ unsigned SplatBitSize;
3992
+ bool HasAnyUndefs;
3993
+ if (BVN->isConstantSplat (SplatValue, SplatUndef, SplatBitSize,
3994
+ HasAnyUndefs)) {
3995
+ isSplatConfirmed = true ;
3996
+ } else {
3997
+ return false ;
3998
+ }
3999
+ } else if (ImmediateNode.getOpcode () == AArch64ISD::MOVIshift) {
4000
+ // This implies that the DAG structure was (DUP (MOVIshift C)) or
4001
+ // (BUILD_VECTOR (MOVIshift C)).
4002
+ isSplatConfirmed = true ;
4003
+ } else {
4004
+ return false ;
4005
+ }
4006
+
4007
+ // If we reached here, isSplatConfirmed should be true and ScalarSourceNode
4008
+ // should be set. But just in case ...
4009
+ if (!isSplatConfirmed)
4010
+ return false ;
4011
+
4012
+ // --- Extract the actual constant value ---
4013
+ auto ScalarSourceNode = ImmediateNode.getOperand (0 );
4014
+ APFloat FVal (0.0 );
4015
+ if (auto *CFP = dyn_cast<ConstantFPSDNode>(ScalarSourceNode)) {
4016
+ // Scalar source is a floating-point constant.
4017
+ FVal = CFP->getValueAPF ();
4018
+ } else if (auto *CI = dyn_cast<ConstantSDNode>(ScalarSourceNode)) {
4019
+ // Scalar source is an integer constant; interpret its bits as
4020
+ // floating-point.
4021
+ EVT FloatEltVT = N.getValueType ().getVectorElementType ();
4022
+
4023
+ if (FloatEltVT == MVT::f32 ) {
4024
+ FVal = APFloat (APFloat::IEEEsingle (), CI->getAPIntValue ());
4025
+ } else if (FloatEltVT == MVT::f64 ) {
4026
+ FVal = APFloat (APFloat::IEEEdouble (), CI->getAPIntValue ());
4027
+ } else if (FloatEltVT == MVT::f16 ) {
4028
+ auto *ShiftAmountConst =
4029
+ dyn_cast<ConstantSDNode>(ImmediateNode.getOperand (1 ));
4030
+
4031
+ if (!ShiftAmountConst) {
4032
+ return false ;
4033
+ }
4034
+ APInt ImmediateVal = CI->getAPIntValue ();
4035
+ unsigned ShiftAmount = ShiftAmountConst->getAPIntValue ().getZExtValue ();
4036
+ APInt EffectiveBits = ImmediateVal.trunc (16 ).shl (ShiftAmount);
4037
+ FVal = APFloat (APFloat::IEEEhalf (), EffectiveBits);
4038
+ } else {
4039
+ // Unsupported floating-point element type.
4040
+ return false ;
4041
+ }
4042
+ } else {
4043
+ // ScalarSourceNode is not a recognized constant type.
4044
+ return false ;
4045
+ }
4046
+
4047
+ // --- Perform fixed-point reciprocal check and power-of-2 validation on FVal
4048
+ // --- Normalize f16 to f32 if needed for consistent APFloat operations.
4049
+ if (N.getValueType ().getVectorElementType () == MVT::f16 ) {
4050
+ bool ignored;
4051
+ FVal.convert (APFloat::IEEEsingle (), APFloat::rmNearestTiesToEven, &ignored);
4052
+ }
4053
+
4054
+ // Handle reciprocal case.
4055
+ if (isReciprocal) {
4056
+ if (!FVal.getExactInverse (&FVal))
4057
+ // Not an exact reciprocal, or reciprocal not a power of 2.
4058
+ return false ;
4059
+ }
4060
+
4061
+ bool IsExact;
4062
+ unsigned TargetIntBits =
4063
+ N.getValueType ().getVectorElementType ().getSizeInBits ();
4064
+ APSInt IntVal (
4065
+ TargetIntBits + 1 ,
4066
+ true ); // Use TargetIntBits + 1 for sufficient bits for conversion
4067
+
4068
+ FVal.convertToInteger (IntVal, APFloat::rmTowardZero, &IsExact);
4069
+
4070
+ if (!IsExact || !IntVal.isPowerOf2 ())
4071
+ return false ;
4072
+
4073
+ unsigned FBits = IntVal.logBase2 ();
4074
+ // FBits must be non-zero (implies actual scaling) and within the range
4075
+ // supported by the instruction (typically 1 to 64 for AArch64 FCVTZS/FCVTZU).
4076
+ // FloatWidth should ideally be the width of the *integer elements* in the
4077
+ // vector (16, 32, 64).
4078
+ if (FBits == 0 || FBits > FloatWidth)
4079
+ return false ;
4080
+
4081
+ // Set FixedPos to the extracted FBits as an i32 constant SDValue.
4082
+ FixedPos = CurDAG->getTargetConstant (FBits, SDLoc (N), MVT::i32 );
4083
+ return true ;
4084
+ }
4085
+
3955
4086
bool AArch64DAGToDAGISel::SelectCVTFixedPosOperand (SDValue N, SDValue &FixedPos,
3956
4087
unsigned RegWidth) {
3957
4088
return checkCVTFixedPointOperandWithFBits (CurDAG, N, FixedPos, RegWidth,
@@ -3965,6 +4096,12 @@ bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperand(SDValue N,
3965
4096
true );
3966
4097
}
3967
4098
4099
+ bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperandVec (
4100
+ SDValue N, SDValue &FixedPos, unsigned FloatWidth) {
4101
+ return checkCVTFixedPointOperandWithFBitsForVectors (CurDAG, N, FixedPos,
4102
+ FloatWidth, true );
4103
+ }
4104
+
3968
4105
// Inspects a register string of the form o0:op1:CRn:CRm:op2 gets the fields
3969
4106
// of the string and obtains the integer values from them and combines these
3970
4107
// into a single value to be used in the MRS/MSR instruction.
0 commit comments