@@ -19087,69 +19087,82 @@ static SDValue performVectorCompareAndMaskUnaryOpCombine(SDNode *N,
 /// functions, this can help to reduce the number of fmovs to/from GPRs.
 static SDValue
 tryToReplaceScalarFPConversionWithSVE(SDNode *N, SelectionDAG &DAG,
+                                      TargetLowering::DAGCombinerInfo &DCI,
                                       const AArch64Subtarget *Subtarget) {
   if (N->isStrictFPOpcode())
     return SDValue();
 
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
   if (!Subtarget->isSVEorStreamingSVEAvailable() ||
       (!Subtarget->isStreaming() && !Subtarget->isStreamingCompatible()))
     return SDValue();
 
   auto isSupportedType = [](EVT VT) {
-    if (!VT.isSimple())
-      return false;
-    // There are SVE instructions that can convert to/from all pairs of these
-    // int and float types. Note: We don't bother with i8 or i16 as those are
-    // illegal types for scalars.
-    return is_contained({MVT::i32, MVT::i64, MVT::f16, MVT::f32, MVT::f64},
-                        VT.getSimpleVT().SimpleTy);
+    return VT != MVT::bf16 && VT != MVT::f128;
   };
 
   if (!isSupportedType(N->getValueType(0)) ||
       !isSupportedType(N->getOperand(0).getValueType()))
     return SDValue();
 
+  // Look through fp_extends to avoid extra fcvts.
   SDValue SrcVal = N->getOperand(0);
+  if (SrcVal->getOpcode() == ISD::FP_EXTEND &&
+      isSupportedType(SrcVal->getOperand(0).getValueType()))
+    SrcVal = SrcVal->getOperand(0);
+
   EVT SrcTy = SrcVal.getValueType();
   EVT DestTy = N->getValueType(0);
 
-  bool IsI32ToF64 = SrcTy == MVT::i32 && DestTy == MVT::f64;
-  bool isF64ToI32 = SrcTy == MVT::f64 && DestTy == MVT::i32;
-
-  // Conversions between f64 and i32 are a special case as nxv2i32 is an illegal
-  // type (unlike the equivalent nxv2f32 for floating-point types).
-  // TODO: Support these conversations.
-  if (IsI32ToF64 || isF64ToI32)
-    return SDValue();
+  // Merge in any subsequent fp_round to avoid extra fcvts.
+  SDNode *FPRoundNode = nullptr;
+  if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::FP_ROUND &&
+      isSupportedType(N->use_begin()->getValueType(0))) {
+    FPRoundNode = *N->use_begin();
+    DestTy = FPRoundNode->getValueType(0);
+  }
 
   EVT SrcVecTy;
   EVT DestVecTy;
   if (DestTy.bitsGT(SrcTy)) {
     DestVecTy = getPackedSVEVectorVT(DestTy);
-    SrcVecTy = SrcTy == MVT::i32 ? getPackedSVEVectorVT(SrcTy)
-                                 : DestVecTy.changeVectorElementType(SrcTy);
+    SrcVecTy = DestVecTy.changeVectorElementType(SrcTy);
   } else {
     SrcVecTy = getPackedSVEVectorVT(SrcTy);
-    DestVecTy = DestTy == MVT::i32 ? getPackedSVEVectorVT(DestTy)
-                                   : SrcVecTy.changeVectorElementType(DestTy);
+    DestVecTy = SrcVecTy.changeVectorElementType(DestTy);
   }
 
+  // Ensure the resulting src/dest vector type is legal.
+  if (SrcVecTy == MVT::nxv2i32 || DestVecTy == MVT::nxv2i32)
+    return SDValue();
+
   SDLoc DL(N);
   SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);
   SDValue Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, SrcVecTy,
                             DAG.getUNDEF(SrcVecTy), SrcVal, ZeroIdx);
   SDValue Convert = DAG.getNode(N->getOpcode(), DL, DestVecTy, Vec);
-  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, DestTy, Convert, ZeroIdx);
+  SDValue Scalar =
+      DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, DestTy, Convert, ZeroIdx);
+
+  if (FPRoundNode) {
+    DAG.ReplaceAllUsesWith(SDValue(FPRoundNode, 0), Scalar);
+    return SDValue();
+  }
+  return Scalar;
 }
 
 static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG,
+                                     TargetLowering::DAGCombinerInfo &DCI,
                                      const AArch64Subtarget *Subtarget) {
   // First try to optimize away the conversion when it's conditionally from
   // a constant. Vectors only.
   if (SDValue Res = performVectorCompareAndMaskUnaryOpCombine(N, DAG))
     return Res;
 
-  if (SDValue Res = tryToReplaceScalarFPConversionWithSVE(N, DAG, Subtarget))
+  if (SDValue Res =
+          tryToReplaceScalarFPConversionWithSVE(N, DAG, DCI, Subtarget))
     return Res;
 
   EVT VT = N->getValueType(0);
@@ -19190,7 +19203,8 @@ static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG,
 static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
                                      TargetLowering::DAGCombinerInfo &DCI,
                                      const AArch64Subtarget *Subtarget) {
-  if (SDValue Res = tryToReplaceScalarFPConversionWithSVE(N, DAG, Subtarget))
+  if (SDValue Res =
+          tryToReplaceScalarFPConversionWithSVE(N, DAG, DCI, Subtarget))
     return Res;
 
   if (!Subtarget->isNeonAvailable())
@@ -26273,7 +26287,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performMulCombine(N, DAG, DCI, Subtarget);
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
-    return performIntToFpCombine(N, DAG, Subtarget);
+    return performIntToFpCombine(N, DAG, DCI, Subtarget);
   case ISD::FP_TO_SINT:
   case ISD::FP_TO_UINT:
   case ISD::FP_TO_SINT_SAT:
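
A hedged, source-level sketch of the two new folds (the function names are illustrative assumptions, and whether the extend/round actually survives into the SelectionDAG depends on subtarget features and earlier IR-level folds, so this is not a claim about what the patch's own tests cover). Looking through fp_extend handles conversions whose operand was widened first; merging a trailing fp_round handles conversions whose result is immediately narrowed. Either way, one SVE conversion replaces a conversion plus an extra fcvt:

// Illustrative only; exact DAG shapes vary with optimization level and
// subtarget features.
int cvt_f32_via_f64(float s) {
  // May reach the DAG as fp_to_sint(fp_extend f32 -> f64) when earlier
  // folds don't strip the widening (e.g. at -O0). Without the look-through
  // the f64 -> i32 direction would bail on nxv2i32; converting straight
  // from f32 uses the legal nxv4f32 -> nxv4i32 pairing instead.
  return (int)(double)s;
}

float cvt_i64_to_f32(long x) {
  // (float)(double)x cannot be folded to (float)x in IR (double rounding
  // differs), so the DAG keeps fp_round(sint_to_fp i64 -> f64); merging
  // the fp_round emits a single i64 -> f32 conversion rather than a
  // convert-then-fcvt pair.
  return (float)(double)x;
}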
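And a minimal, self-contained sketch of the SrcVecTy/DestVecTy pairing rule and the new legality bail-out (toy types and a toy helper, not LLVM's real EVT API): the wider side of the conversion picks its packed 128-bit SVE container, the narrower side reuses that lane count, and the one result of this pairing with no legal SVE type is nxv2i32. That is the same f64 <-> i32 case the removed IsI32ToF64/isF64ToI32 special-case rejected, now expressed as a generic check on the computed vector types:

#include <cstdio>

// Toy stand-ins mirroring the pairing logic above.
struct Elt {
  const char *Name; // e.g. "i64"
  unsigned Bits;
};

static void pairUp(Elt Src, Elt Dst) {
  // The wider type picks the packed container (128 bits per <vscale x 1>
  // granule, so lanes = 128 / bits); the narrower type reuses that lane
  // count, giving an unpacked but usually legal vector such as nxv2f32.
  unsigned Wide = Src.Bits > Dst.Bits ? Src.Bits : Dst.Bits;
  unsigned Lanes = 128 / Wide;
  // nxv2i32 is the only pairing result without a legal SVE type.
  bool Bail = Lanes == 2 && ((Src.Bits == 32 && Src.Name[0] == 'i') ||
                             (Dst.Bits == 32 && Dst.Name[0] == 'i'));
  std::printf("%s -> %s: nxv%u%s -> nxv%u%s%s\n", Src.Name, Dst.Name, Lanes,
              Src.Name, Lanes, Dst.Name, Bail ? "  (bails: nxv2i32)" : "");
}

int main() {
  pairUp({"i64", 64}, {"f32", 32}); // nxv2i64 -> nxv2f32 (unpacked f32: OK)
  pairUp({"f16", 16}, {"i64", 64}); // nxv2f16 -> nxv2i64
  pairUp({"i32", 32}, {"f64", 64}); // would need nxv2i32: bail
  pairUp({"f64", 64}, {"i32", 32}); // would need nxv2i32: bail
  return 0;
}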