Skip to content

[LoongArch] Custom lower FP_TO_FP16 and FP16_TO_FP to correct ABI of libcall #141702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 135 additions & 3 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
setOperationAction(ISD::FPOW, MVT::f32, Expand);
setOperationAction(ISD::FREM, MVT::f32, Expand);
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
setOperationAction(ISD::FP16_TO_FP, MVT::f32,
Subtarget.isSoftFPABI() ? LibCall : Custom);
setOperationAction(ISD::FP_TO_FP16, MVT::f32,
Subtarget.isSoftFPABI() ? LibCall : Custom);

if (Subtarget.is64Bit())
setOperationAction(ISD::FRINT, MVT::f32, Legal);
Expand Down Expand Up @@ -239,7 +241,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FPOW, MVT::f64, Expand);
setOperationAction(ISD::FREM, MVT::f64, Expand);
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f64,
Subtarget.isSoftFPABI() ? LibCall : Custom);

if (Subtarget.is64Bit())
setOperationAction(ISD::FRINT, MVT::f64, Legal);
Expand Down Expand Up @@ -490,6 +493,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
return lowerPREFETCH(Op, DAG);
case ISD::SELECT:
return lowerSELECT(Op, DAG);
case ISD::FP_TO_FP16:
return lowerFP_TO_FP16(Op, DAG);
case ISD::FP16_TO_FP:
return lowerFP16_TO_FP(Op, DAG);
}
return SDValue();
}
Expand Down Expand Up @@ -2242,6 +2249,40 @@ SDValue LoongArchTargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
return SDValue();
}

SDValue LoongArchTargetLowering::lowerFP_TO_FP16(SDValue Op,
SelectionDAG &DAG) const {
// Custom lower to ensure the libcall return is passed in an FPR on hard
// float ABIs.
SDLoc DL(Op);
MakeLibCallOptions CallOptions;
SDValue Op0 = Op.getOperand(0);
SDValue Chain = SDValue();
RTLIB::Libcall LC = RTLIB::getFPROUND(Op0.getValueType(), MVT::f16);
SDValue Res;
std::tie(Res, Chain) =
makeLibCall(DAG, LC, MVT::f32, Op0, CallOptions, DL, Chain);
if (Subtarget.is64Bit())
return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Res);
return DAG.getBitcast(MVT::i32, Res);
}

SDValue LoongArchTargetLowering::lowerFP16_TO_FP(SDValue Op,
SelectionDAG &DAG) const {
// Custom lower to ensure the libcall argument is passed in an FPR on hard
// float ABIs.
SDLoc DL(Op);
MakeLibCallOptions CallOptions;
SDValue Op0 = Op.getOperand(0);
SDValue Chain = SDValue();
SDValue Arg = Subtarget.is64Bit() ? DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64,
DL, MVT::f32, Op0)
: DAG.getBitcast(MVT::f32, Op0);
SDValue Res;
std::tie(Res, Chain) = makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg,
CallOptions, DL, Chain);
return Res;
}

static bool isConstantOrUndef(const SDValue Op) {
if (Op->isUndef())
return true;
Expand Down Expand Up @@ -3841,6 +3882,8 @@ void LoongArchTargetLowering::ReplaceNodeResults(
EVT FVT = EVT::getFloatingPointVT(N->getValueSizeInBits(0));
if (getTypeAction(*DAG.getContext(), Src.getValueType()) !=
TargetLowering::TypeSoftenFloat) {
if (!isTypeLegal(Src.getValueType()))
return;
if (Src.getValueType() == MVT::f16)
Src = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
SDValue Dst = DAG.getNode(LoongArchISD::FTINT, DL, FVT, Src);
Expand Down Expand Up @@ -5289,6 +5332,33 @@ performINTRINSIC_WO_CHAINCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

static SDValue performMOVGR2FR_WCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
// If the input to MOVGR2FR_W_LA64 is just MOVFR2GR_S_LA64 the the
// conversion is unnecessary and can be replaced with the
// MOVFR2GR_S_LA64 operand.
SDValue Op0 = N->getOperand(0);
if (Op0.getOpcode() == LoongArchISD::MOVFR2GR_S_LA64)
return Op0.getOperand(0);
return SDValue();
}

static SDValue performMOVFR2GR_SCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
// If the input to MOVFR2GR_S_LA64 is just MOVGR2FR_W_LA64 then the
// conversion is unnecessary and can be replaced with the MOVGR2FR_W_LA64
// operand.
SDValue Op0 = N->getOperand(0);
MVT VT = N->getSimpleValueType(0);
if (Op0->getOpcode() == LoongArchISD::MOVGR2FR_W_LA64) {
assert(Op0.getOperand(0).getValueType() == VT && "Unexpected value type!");
return Op0.getOperand(0);
}
return SDValue();
}

SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
Expand All @@ -5307,6 +5377,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
return performBITREV_WCombine(N, DAG, DCI, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
return performINTRINSIC_WO_CHAINCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::MOVGR2FR_W_LA64:
return performMOVGR2FR_WCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::MOVFR2GR_S_LA64:
return performMOVFR2GR_SCombine(N, DAG, DCI, Subtarget);
}
return SDValue();
}
Expand Down Expand Up @@ -7633,3 +7707,61 @@ LoongArchTargetLowering::getPreferredVectorAction(MVT VT) const {

return TargetLoweringBase::getPreferredVectorAction(VT);
}

bool LoongArchTargetLowering::splitValueIntoRegisterParts(
SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();
EVT ValueVT = Val.getValueType();

if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
// Cast the f16 to i16, extend to i32, pad with ones to make a float
// nan, and cast to f32.
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
DAG.getConstant(0xFFFF0000, DL, MVT::i32));
Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
Parts[0] = Val;
return true;
}

return false;
}

SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue(
SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();

if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
SDValue Val = Parts[0];

// Cast the f32 to i32, truncate to i16, and cast back to f16.
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
return Val;
}

return SDValue();
}

MVT LoongArchTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
CallingConv::ID CC,
EVT VT) const {
// Use f32 to pass f16.
if (VT == MVT::f16 && Subtarget.hasBasicF())
return MVT::f32;

return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
}

unsigned LoongArchTargetLowering::getNumRegistersForCallingConv(
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
// Use f32 to pass f16.
if (VT == MVT::f16 && Subtarget.hasBasicF())
return 1;

return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
}
24 changes: 24 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ class LoongArchTargetLowering : public TargetLowering {
SDValue lowerSCALAR_TO_VECTOR(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerPREFETCH(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;

bool isFPImmLegal(const APFloat &Imm, EVT VT,
bool ForCodeSize) const override;
Expand All @@ -385,6 +387,28 @@ class LoongArchTargetLowering : public TargetLowering {
const SmallVectorImpl<CCValAssign> &ArgLocs) const;

bool softPromoteHalfType() const override { return true; }

bool
splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
SDValue *Parts, unsigned NumParts, MVT PartVT,
std::optional<CallingConv::ID> CC) const override;

SDValue
joinRegisterPartsIntoValue(SelectionDAG &DAG, const SDLoc &DL,
const SDValue *Parts, unsigned NumParts,
MVT PartVT, EVT ValueVT,
std::optional<CallingConv::ID> CC) const override;

/// Return the register type for a given MVT, ensuring vectors are treated
/// as a series of gpr sized integers.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
EVT VT) const override;

/// Return the number of registers for a given MVT, ensuring vectors are
/// treated as a series of gpr sized integers.
unsigned getNumRegistersForCallingConv(LLVMContext &Context,
CallingConv::ID CC,
EVT VT) const override;
};

} // end namespace llvm
Expand Down
Loading
Loading