@@ -5106,6 +5106,30 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
5106
5106
uint64_t VTSize = VT.getFixedSizeInBits();
5107
5107
uint64_t InVTSize = InVT.getFixedSizeInBits();
5108
5108
if (VTSize < InVTSize) {
5109
+ // AArch64 doesn't have a direct vector instruction to convert
5110
+ // fixed point to floating point AND narrow it at the same time.
5111
+ // Additional rounding when the target is f32/f64 causes subtle
5112
+ // differences across different platforms (that do have such
5113
+ // instructions). Conversion to f16 however is fine.
5114
+ bool IsTargetf32Orf64 = VT.getVectorElementType() == MVT::f32 ||
5115
+ VT.getVectorElementType() == MVT::f64;
5116
+ bool IsTargetf16 = false;
5117
+ if (Op.hasOneUse() && Op->user_begin()->getOpcode() == ISD::CONCAT_VECTORS) {
5118
+ // Some vector types are split during legalization into half, followed by
5119
+ // concatenation, followed by rounding to the original vector type. If we
5120
+ // end up resolving to f16 type, we shouldn't worry about rounding errors.
5121
+ SDNode *U = *Op->user_begin();
5122
+ if (U->hasOneUse() && U->user_begin()->getOpcode() == ISD::FP_ROUND) {
5123
+ EVT TmpVT = U->user_begin()->getValueType(0);
5124
+ if (TmpVT.isVector() && TmpVT.getVectorElementType() == MVT::f16)
5125
+ IsTargetf16 = true;
5126
+ }
5127
+ }
5128
+
5129
+ if (IsTargetf32Orf64 && !IsTargetf16) {
5130
+ return SDValue();
5131
+ }
5132
+
5109
5133
MVT CastVT =
5110
5134
MVT::getVectorVT(MVT::getFloatingPointVT(InVT.getScalarSizeInBits()),
5111
5135
InVT.getVectorNumElements());
0 commit comments