@@ -199,8 +199,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
199
199
setOperationAction (ISD::FSINCOS, MVT::f32 , Expand);
200
200
setOperationAction (ISD::FPOW, MVT::f32 , Expand);
201
201
setOperationAction (ISD::FREM, MVT::f32 , Expand);
202
- setOperationAction (ISD::FP16_TO_FP, MVT::f32 , Expand);
203
- setOperationAction (ISD::FP_TO_FP16, MVT::f32 , Expand);
202
+ setOperationAction (ISD::FP16_TO_FP, MVT::f32 ,
203
+ Subtarget.isSoftFPABI () ? LibCall : Custom);
204
+ setOperationAction (ISD::FP_TO_FP16, MVT::f32 ,
205
+ Subtarget.isSoftFPABI () ? LibCall : Custom);
204
206
205
207
if (Subtarget.is64Bit ())
206
208
setOperationAction (ISD::FRINT, MVT::f32 , Legal);
@@ -239,7 +241,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
239
241
setOperationAction (ISD::FPOW, MVT::f64 , Expand);
240
242
setOperationAction (ISD::FREM, MVT::f64 , Expand);
241
243
setOperationAction (ISD::FP16_TO_FP, MVT::f64 , Expand);
242
- setOperationAction (ISD::FP_TO_FP16, MVT::f64 , Expand);
244
+ setOperationAction (ISD::FP_TO_FP16, MVT::f64 ,
245
+ Subtarget.isSoftFPABI () ? LibCall : Custom);
243
246
244
247
if (Subtarget.is64Bit ())
245
248
setOperationAction (ISD::FRINT, MVT::f64 , Legal);
@@ -490,6 +493,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
490
493
return lowerPREFETCH (Op, DAG);
491
494
case ISD::SELECT:
492
495
return lowerSELECT (Op, DAG);
496
+ case ISD::FP_TO_FP16:
497
+ return lowerFP_TO_FP16 (Op, DAG);
498
+ case ISD::FP16_TO_FP:
499
+ return lowerFP16_TO_FP (Op, DAG);
493
500
}
494
501
return SDValue ();
495
502
}
@@ -2242,6 +2249,40 @@ SDValue LoongArchTargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
2242
2249
return SDValue ();
2243
2250
}
2244
2251
2252
+ SDValue LoongArchTargetLowering::lowerFP_TO_FP16 (SDValue Op,
2253
+ SelectionDAG &DAG) const {
2254
+ // Custom lower to ensure the libcall return is passed in an FPR on hard
2255
+ // float ABIs.
2256
+ SDLoc DL (Op);
2257
+ MakeLibCallOptions CallOptions;
2258
+ SDValue Op0 = Op.getOperand (0 );
2259
+ SDValue Chain = SDValue ();
2260
+ RTLIB::Libcall LC = RTLIB::getFPROUND (Op0.getValueType (), MVT::f16 );
2261
+ SDValue Res;
2262
+ std::tie (Res, Chain) =
2263
+ makeLibCall (DAG, LC, MVT::f32 , Op0, CallOptions, DL, Chain);
2264
+ if (Subtarget.is64Bit ())
2265
+ return DAG.getNode (LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64 , Res);
2266
+ return DAG.getBitcast (MVT::i32 , Res);
2267
+ }
2268
+
2269
+ SDValue LoongArchTargetLowering::lowerFP16_TO_FP (SDValue Op,
2270
+ SelectionDAG &DAG) const {
2271
+ // Custom lower to ensure the libcall argument is passed in an FPR on hard
2272
+ // float ABIs.
2273
+ SDLoc DL (Op);
2274
+ MakeLibCallOptions CallOptions;
2275
+ SDValue Op0 = Op.getOperand (0 );
2276
+ SDValue Chain = SDValue ();
2277
+ SDValue Arg = Subtarget.is64Bit () ? DAG.getNode (LoongArchISD::MOVGR2FR_W_LA64,
2278
+ DL, MVT::f32 , Op0)
2279
+ : DAG.getBitcast (MVT::f32 , Op0);
2280
+ SDValue Res;
2281
+ std::tie (Res, Chain) = makeLibCall (DAG, RTLIB::FPEXT_F16_F32, MVT::f32 , Arg,
2282
+ CallOptions, DL, Chain);
2283
+ return Res;
2284
+ }
2285
+
2245
2286
static bool isConstantOrUndef (const SDValue Op) {
2246
2287
if (Op->isUndef ())
2247
2288
return true ;
@@ -3841,6 +3882,8 @@ void LoongArchTargetLowering::ReplaceNodeResults(
3841
3882
EVT FVT = EVT::getFloatingPointVT (N->getValueSizeInBits (0 ));
3842
3883
if (getTypeAction (*DAG.getContext (), Src.getValueType ()) !=
3843
3884
TargetLowering::TypeSoftenFloat) {
3885
+ if (!isTypeLegal (Src.getValueType ()))
3886
+ return ;
3844
3887
if (Src.getValueType () == MVT::f16 )
3845
3888
Src = DAG.getNode (ISD::FP_EXTEND, DL, MVT::f32 , Src);
3846
3889
SDValue Dst = DAG.getNode (LoongArchISD::FTINT, DL, FVT, Src);
@@ -5289,6 +5332,33 @@ performINTRINSIC_WO_CHAINCombine(SDNode *N, SelectionDAG &DAG,
5289
5332
return SDValue ();
5290
5333
}
5291
5334
5335
+ static SDValue performMOVGR2FR_WCombine (SDNode *N, SelectionDAG &DAG,
5336
+ TargetLowering::DAGCombinerInfo &DCI,
5337
+ const LoongArchSubtarget &Subtarget) {
5338
+ // If the input to MOVGR2FR_W_LA64 is just MOVFR2GR_S_LA64 the the
5339
+ // conversion is unnecessary and can be replaced with the
5340
+ // MOVFR2GR_S_LA64 operand.
5341
+ SDValue Op0 = N->getOperand (0 );
5342
+ if (Op0.getOpcode () == LoongArchISD::MOVFR2GR_S_LA64)
5343
+ return Op0.getOperand (0 );
5344
+ return SDValue ();
5345
+ }
5346
+
5347
+ static SDValue performMOVFR2GR_SCombine (SDNode *N, SelectionDAG &DAG,
5348
+ TargetLowering::DAGCombinerInfo &DCI,
5349
+ const LoongArchSubtarget &Subtarget) {
5350
+ // If the input to MOVFR2GR_S_LA64 is just MOVGR2FR_W_LA64 then the
5351
+ // conversion is unnecessary and can be replaced with the MOVGR2FR_W_LA64
5352
+ // operand.
5353
+ SDValue Op0 = N->getOperand (0 );
5354
+ MVT VT = N->getSimpleValueType (0 );
5355
+ if (Op0->getOpcode () == LoongArchISD::MOVGR2FR_W_LA64) {
5356
+ assert (Op0.getOperand (0 ).getValueType () == VT && " Unexpected value type!" );
5357
+ return Op0.getOperand (0 );
5358
+ }
5359
+ return SDValue ();
5360
+ }
5361
+
5292
5362
SDValue LoongArchTargetLowering::PerformDAGCombine (SDNode *N,
5293
5363
DAGCombinerInfo &DCI) const {
5294
5364
SelectionDAG &DAG = DCI.DAG ;
@@ -5307,6 +5377,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
5307
5377
return performBITREV_WCombine (N, DAG, DCI, Subtarget);
5308
5378
case ISD::INTRINSIC_WO_CHAIN:
5309
5379
return performINTRINSIC_WO_CHAINCombine (N, DAG, DCI, Subtarget);
5380
+ case LoongArchISD::MOVGR2FR_W_LA64:
5381
+ return performMOVGR2FR_WCombine (N, DAG, DCI, Subtarget);
5382
+ case LoongArchISD::MOVFR2GR_S_LA64:
5383
+ return performMOVFR2GR_SCombine (N, DAG, DCI, Subtarget);
5310
5384
}
5311
5385
return SDValue ();
5312
5386
}
@@ -7633,3 +7707,61 @@ LoongArchTargetLowering::getPreferredVectorAction(MVT VT) const {
7633
7707
7634
7708
return TargetLoweringBase::getPreferredVectorAction (VT);
7635
7709
}
7710
+
7711
+ bool LoongArchTargetLowering::splitValueIntoRegisterParts (
7712
+ SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
7713
+ unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
7714
+ bool IsABIRegCopy = CC.has_value ();
7715
+ EVT ValueVT = Val.getValueType ();
7716
+
7717
+ if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32 ) {
7718
+ // Cast the f16 to i16, extend to i32, pad with ones to make a float
7719
+ // nan, and cast to f32.
7720
+ Val = DAG.getNode (ISD::BITCAST, DL, MVT::i16 , Val);
7721
+ Val = DAG.getNode (ISD::ANY_EXTEND, DL, MVT::i32 , Val);
7722
+ Val = DAG.getNode (ISD::OR, DL, MVT::i32 , Val,
7723
+ DAG.getConstant (0xFFFF0000 , DL, MVT::i32 ));
7724
+ Val = DAG.getNode (ISD::BITCAST, DL, MVT::f32 , Val);
7725
+ Parts[0 ] = Val;
7726
+ return true ;
7727
+ }
7728
+
7729
+ return false ;
7730
+ }
7731
+
7732
+ SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue (
7733
+ SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
7734
+ MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
7735
+ bool IsABIRegCopy = CC.has_value ();
7736
+
7737
+ if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32 ) {
7738
+ SDValue Val = Parts[0 ];
7739
+
7740
+ // Cast the f32 to i32, truncate to i16, and cast back to f16.
7741
+ Val = DAG.getNode (ISD::BITCAST, DL, MVT::i32 , Val);
7742
+ Val = DAG.getNode (ISD::TRUNCATE, DL, MVT::i16 , Val);
7743
+ Val = DAG.getNode (ISD::BITCAST, DL, ValueVT, Val);
7744
+ return Val;
7745
+ }
7746
+
7747
+ return SDValue ();
7748
+ }
7749
+
7750
+ MVT LoongArchTargetLowering::getRegisterTypeForCallingConv (LLVMContext &Context,
7751
+ CallingConv::ID CC,
7752
+ EVT VT) const {
7753
+ // Use f32 to pass f16.
7754
+ if (VT == MVT::f16 && Subtarget.hasBasicF ())
7755
+ return MVT::f32 ;
7756
+
7757
+ return TargetLowering::getRegisterTypeForCallingConv (Context, CC, VT);
7758
+ }
7759
+
7760
+ unsigned LoongArchTargetLowering::getNumRegistersForCallingConv (
7761
+ LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
7762
+ // Use f32 to pass f16.
7763
+ if (VT == MVT::f16 && Subtarget.hasBasicF ())
7764
+ return 1 ;
7765
+
7766
+ return TargetLowering::getNumRegistersForCallingConv (Context, CC, VT);
7767
+ }
0 commit comments