Skip to content

Commit 559a9db

Browse files
Ami-zhangheiher
andauthored
[LoongArch] Custom lower FP_TO_FP16 and FP16_TO_FP to correct ABI of libcall (#141702)
This change passes 'half' in the lower 16 bits of an f32 value when the F/D ABI is in use. LoongArch currently lacks a hardware extension for the fp16 data type, and the ABI manual now documents the half-precision floating-point type as following the FP calling conventions. Previously, we maintained the 'half' type in its 16-bit format between operations: regardless of whether the F extension was enabled, the value would be passed in the lower 16 bits of a GPR in its 'half' format. With this patch, depending on the ABI in use, the value will be passed either in an FPR or a GPR in 'half' format. This keeps the location of the bits consistent with the case where an fp16 hardware extension is enabled. Co-authored-by: WANG Rui <[email protected]>
1 parent f90cfb1 commit 559a9db

File tree

6 files changed

+2489
-74
lines changed

6 files changed

+2489
-74
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 135 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
199199
setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
200200
setOperationAction(ISD::FPOW, MVT::f32, Expand);
201201
setOperationAction(ISD::FREM, MVT::f32, Expand);
202-
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand);
203-
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
202+
setOperationAction(ISD::FP16_TO_FP, MVT::f32,
203+
Subtarget.isSoftFPABI() ? LibCall : Custom);
204+
setOperationAction(ISD::FP_TO_FP16, MVT::f32,
205+
Subtarget.isSoftFPABI() ? LibCall : Custom);
204206

205207
if (Subtarget.is64Bit())
206208
setOperationAction(ISD::FRINT, MVT::f32, Legal);
@@ -239,7 +241,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
239241
setOperationAction(ISD::FPOW, MVT::f64, Expand);
240242
setOperationAction(ISD::FREM, MVT::f64, Expand);
241243
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
242-
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
244+
setOperationAction(ISD::FP_TO_FP16, MVT::f64,
245+
Subtarget.isSoftFPABI() ? LibCall : Custom);
243246

244247
if (Subtarget.is64Bit())
245248
setOperationAction(ISD::FRINT, MVT::f64, Legal);
@@ -490,6 +493,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
490493
return lowerPREFETCH(Op, DAG);
491494
case ISD::SELECT:
492495
return lowerSELECT(Op, DAG);
496+
case ISD::FP_TO_FP16:
497+
return lowerFP_TO_FP16(Op, DAG);
498+
case ISD::FP16_TO_FP:
499+
return lowerFP16_TO_FP(Op, DAG);
493500
}
494501
return SDValue();
495502
}
@@ -2242,6 +2249,40 @@ SDValue LoongArchTargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
22422249
return SDValue();
22432250
}
22442251

2252+
// Lowers an ISD::FP_TO_FP16 node through a libcall. The libcall is selected by
// RTLIB::getFPROUND for the operand's type with an f16 destination, and its
// result is requested as f32. That f32 result is then returned as an integer
// value: through MOVFR2GR_S_LA64 (i64) on 64-bit subtargets, otherwise through
// a bitcast to i32.
SDValue LoongArchTargetLowering::lowerFP_TO_FP16(SDValue Op,
2253+
SelectionDAG &DAG) const {
2254+
// Custom lower to ensure the libcall return is passed in an FPR on hard
2255+
// float ABIs.
2256+
SDLoc DL(Op);
2257+
MakeLibCallOptions CallOptions;
2258+
SDValue Op0 = Op.getOperand(0);
2259+
// No incoming chain is supplied; the chain handed back by makeLibCall is not
// used after the call.
SDValue Chain = SDValue();
2260+
RTLIB::Libcall LC = RTLIB::getFPROUND(Op0.getValueType(), MVT::f16);
2261+
SDValue Res;
2262+
std::tie(Res, Chain) =
2263+
makeLibCall(DAG, LC, MVT::f32, Op0, CallOptions, DL, Chain);
2264+
// Move the f32 libcall result into an integer-typed value.
if (Subtarget.is64Bit())
2265+
return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Res);
2266+
return DAG.getBitcast(MVT::i32, Res);
2267+
}
2268+
2269+
// Lowers an ISD::FP16_TO_FP node through the RTLIB::FPEXT_F16_F32 libcall.
// The node's operand is first turned into an f32 value (MOVGR2FR_W_LA64 on
// 64-bit subtargets, a bitcast otherwise); that f32 is the libcall argument,
// and the libcall's f32 result is returned unchanged.
SDValue LoongArchTargetLowering::lowerFP16_TO_FP(SDValue Op,
2270+
SelectionDAG &DAG) const {
2271+
// Custom lower to ensure the libcall argument is passed in an FPR on hard
2272+
// float ABIs.
2273+
SDLoc DL(Op);
2274+
MakeLibCallOptions CallOptions;
2275+
SDValue Op0 = Op.getOperand(0);
2276+
// No incoming chain is supplied; the chain handed back by makeLibCall is not
// used after the call.
SDValue Chain = SDValue();
2277+
// Re-type the operand as f32 so the libcall receives a floating-point value.
SDValue Arg = Subtarget.is64Bit() ? DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64,
2278+
DL, MVT::f32, Op0)
2279+
: DAG.getBitcast(MVT::f32, Op0);
2280+
SDValue Res;
2281+
std::tie(Res, Chain) = makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg,
2282+
CallOptions, DL, Chain);
2283+
return Res;
2284+
}
2285+
22452286
static bool isConstantOrUndef(const SDValue Op) {
22462287
if (Op->isUndef())
22472288
return true;
@@ -3841,6 +3882,8 @@ void LoongArchTargetLowering::ReplaceNodeResults(
38413882
EVT FVT = EVT::getFloatingPointVT(N->getValueSizeInBits(0));
38423883
if (getTypeAction(*DAG.getContext(), Src.getValueType()) !=
38433884
TargetLowering::TypeSoftenFloat) {
3885+
if (!isTypeLegal(Src.getValueType()))
3886+
return;
38443887
if (Src.getValueType() == MVT::f16)
38453888
Src = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
38463889
SDValue Dst = DAG.getNode(LoongArchISD::FTINT, DL, FVT, Src);
@@ -5289,6 +5332,33 @@ performINTRINSIC_WO_CHAINCombine(SDNode *N, SelectionDAG &DAG,
52895332
return SDValue();
52905333
}
52915334

5335+
// DAG combine for LoongArchISD::MOVGR2FR_W_LA64 nodes. Returns the replacement
// value when the fold below applies, or an empty SDValue when it does not.
// DAG, DCI and Subtarget are accepted but not used by this function.
static SDValue performMOVGR2FR_WCombine(SDNode *N, SelectionDAG &DAG,
5336+
TargetLowering::DAGCombinerInfo &DCI,
5337+
const LoongArchSubtarget &Subtarget) {
5338+
// If the input to MOVGR2FR_W_LA64 is just MOVFR2GR_S_LA64 then the
5339+
// conversion is unnecessary and can be replaced with the
5340+
// MOVFR2GR_S_LA64 operand.
5341+
SDValue Op0 = N->getOperand(0);
5342+
if (Op0.getOpcode() == LoongArchISD::MOVFR2GR_S_LA64)
5343+
return Op0.getOperand(0);
5344+
return SDValue();
5345+
}
5346+
5347+
// DAG combine for LoongArchISD::MOVFR2GR_S_LA64 nodes. Returns the replacement
// value when the fold below applies, or an empty SDValue when it does not.
// DAG, DCI and Subtarget are accepted but not used by this function.
static SDValue performMOVFR2GR_SCombine(SDNode *N, SelectionDAG &DAG,
5348+
TargetLowering::DAGCombinerInfo &DCI,
5349+
const LoongArchSubtarget &Subtarget) {
5350+
// If the input to MOVFR2GR_S_LA64 is just MOVGR2FR_W_LA64 then the
5351+
// conversion is unnecessary and can be replaced with the MOVGR2FR_W_LA64
5352+
// operand.
5353+
SDValue Op0 = N->getOperand(0);
5354+
MVT VT = N->getSimpleValueType(0);
5355+
if (Op0->getOpcode() == LoongArchISD::MOVGR2FR_W_LA64) {
5356+
// The value being forwarded must already have this node's result type.
assert(Op0.getOperand(0).getValueType() == VT && "Unexpected value type!");
5357+
return Op0.getOperand(0);
5358+
}
5359+
return SDValue();
5360+
}
5361+
52925362
SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
52935363
DAGCombinerInfo &DCI) const {
52945364
SelectionDAG &DAG = DCI.DAG;
@@ -5307,6 +5377,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
53075377
return performBITREV_WCombine(N, DAG, DCI, Subtarget);
53085378
case ISD::INTRINSIC_WO_CHAIN:
53095379
return performINTRINSIC_WO_CHAINCombine(N, DAG, DCI, Subtarget);
5380+
case LoongArchISD::MOVGR2FR_W_LA64:
5381+
return performMOVGR2FR_WCombine(N, DAG, DCI, Subtarget);
5382+
case LoongArchISD::MOVFR2GR_S_LA64:
5383+
return performMOVFR2GR_SCombine(N, DAG, DCI, Subtarget);
53105384
}
53115385
return SDValue();
53125386
}
@@ -7633,3 +7707,61 @@ LoongArchTargetLowering::getPreferredVectorAction(MVT VT) const {
76337707

76347708
return TargetLoweringBase::getPreferredVectorAction(VT);
76357709
}
7710+
7711+
// Splits Val into register parts for one specific case: a calling convention
// is supplied (CC has a value), Val is f16 and the part type is f32. In that
// case a single f32 part is written to Parts[0] and true is returned. Every
// other case returns false without touching Parts (presumably leaving the
// split to the generic code -- confirm against the caller). NumParts is not
// consulted.
bool LoongArchTargetLowering::splitValueIntoRegisterParts(
7712+
SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
7713+
unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
7714+
bool IsABIRegCopy = CC.has_value();
7715+
EVT ValueVT = Val.getValueType();
7716+
7717+
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
7718+
// Cast the f16 to i16, extend to i32, pad with ones to make a float
7719+
// nan, and cast to f32.
7720+
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
7721+
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
7722+
// OR with 0xFFFF0000 forces the upper 16 bits to all ones and leaves the
// lower 16 bits (the f16 bit pattern) as they are.
Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
7723+
DAG.getConstant(0xFFFF0000, DL, MVT::i32));
7724+
Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
7725+
Parts[0] = Val;
7726+
return true;
7727+
}
7728+
7729+
return false;
7730+
}
7731+
7732+
// Rebuilds a value from register parts for one specific case: a calling
// convention is supplied (CC has a value), the requested value type is f16 and
// the part type is f32. In that case the f16 is recovered from the low 16 bits
// of Parts[0]. Every other case returns an empty SDValue (presumably leaving
// the join to the generic code -- confirm against the caller). NumParts is not
// consulted.
SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue(
7733+
SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
7734+
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
7735+
bool IsABIRegCopy = CC.has_value();
7736+
7737+
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
7738+
SDValue Val = Parts[0];
7739+
7740+
// Cast the f32 to i32, truncate to i16, and cast back to f16.
7741+
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
7742+
Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
7743+
Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
7744+
return Val;
7745+
}
7746+
7747+
return SDValue();
7748+
}
7749+
7750+
// Returns the register type used for a value of type VT under calling
// convention CC. f16 maps to f32 when Subtarget.hasBasicF() is true; every
// other case is delegated to the TargetLowering base implementation.
MVT LoongArchTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
7751+
CallingConv::ID CC,
7752+
EVT VT) const {
7753+
// Use f32 to pass f16.
7754+
if (VT == MVT::f16 && Subtarget.hasBasicF())
7755+
return MVT::f32;
7756+
7757+
return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
7758+
}
7759+
7760+
// Returns the number of registers used for a value of type VT under calling
// convention CC. An f16 value takes exactly one register when
// Subtarget.hasBasicF() is true; every other case is delegated to the
// TargetLowering base implementation.
unsigned LoongArchTargetLowering::getNumRegistersForCallingConv(
7761+
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
7762+
// Use f32 to pass f16.
7763+
if (VT == MVT::f16 && Subtarget.hasBasicF())
7764+
return 1;
7765+
7766+
return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
7767+
}

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,8 @@ class LoongArchTargetLowering : public TargetLowering {
361361
SDValue lowerSCALAR_TO_VECTOR(SDValue Op, SelectionDAG &DAG) const;
362362
SDValue lowerPREFETCH(SDValue Op, SelectionDAG &DAG) const;
363363
SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) const;
364+
SDValue lowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const;
365+
SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
364366

365367
bool isFPImmLegal(const APFloat &Imm, EVT VT,
366368
bool ForCodeSize) const override;
@@ -385,6 +387,28 @@ class LoongArchTargetLowering : public TargetLowering {
385387
const SmallVectorImpl<CCValAssign> &ArgLocs) const;
386388

387389
bool softPromoteHalfType() const override { return true; }
390+
391+
bool
392+
splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
393+
SDValue *Parts, unsigned NumParts, MVT PartVT,
394+
std::optional<CallingConv::ID> CC) const override;
395+
396+
SDValue
397+
joinRegisterPartsIntoValue(SelectionDAG &DAG, const SDLoc &DL,
398+
const SDValue *Parts, unsigned NumParts,
399+
MVT PartVT, EVT ValueVT,
400+
std::optional<CallingConv::ID> CC) const override;
401+
402+
/// Return the register type used to pass a value of type VT under the given
403+
/// calling convention; f16 is passed as f32 when Subtarget.hasBasicF() holds.
404+
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
405+
EVT VT) const override;
406+
407+
/// Return the number of registers used to pass a value of type VT under the
408+
/// given calling convention; f16 uses one register when hasBasicF() holds.
409+
unsigned getNumRegistersForCallingConv(LLVMContext &Context,
410+
CallingConv::ID CC,
411+
EVT VT) const override;
388412
};
389413

390414
} // end namespace llvm

0 commit comments

Comments
 (0)