Skip to content

[RISCV] Support STRICT_UINT_TO_FP and STRICT_SINT_TO_FP #102503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 39 additions & 38 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2385,32 +2385,26 @@ SDValue DAGTypeLegalizer::ExpandFloatOp_LLRINT(SDNode *N) {
//

static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
if (OpVT == MVT::f16) {
if (OpVT == MVT::f16)
return ISD::FP16_TO_FP;
} else if (RetVT == MVT::f16) {
if (RetVT == MVT::f16)
return ISD::FP_TO_FP16;
} else if (OpVT == MVT::bf16) {
if (OpVT == MVT::bf16)
return ISD::BF16_TO_FP;
} else if (RetVT == MVT::bf16) {
if (RetVT == MVT::bf16)
return ISD::FP_TO_BF16;
}

report_fatal_error("Attempt at an invalid promotion-related conversion");
}

static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
if (OpVT == MVT::f16)
return ISD::STRICT_FP16_TO_FP;

if (RetVT == MVT::f16)
return ISD::STRICT_FP_TO_FP16;

if (OpVT == MVT::bf16)
return ISD::STRICT_BF16_TO_FP;

if (RetVT == MVT::bf16)
return ISD::STRICT_FP_TO_BF16;

report_fatal_error("Attempt at an invalid promotion-related conversion");
}

Expand Down Expand Up @@ -3138,6 +3132,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
break;
case ISD::SELECT: R = SoftPromoteHalfRes_SELECT(N); break;
case ISD::SELECT_CC: R = SoftPromoteHalfRes_SELECT_CC(N); break;
case ISD::STRICT_SINT_TO_FP:
case ISD::STRICT_UINT_TO_FP:
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP: R = SoftPromoteHalfRes_XINT_TO_FP(N); break;
case ISD::UNDEF: R = SoftPromoteHalfRes_UNDEF(N); break;
Expand Down Expand Up @@ -3288,19 +3284,13 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FFREXP(SDNode *N) {

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
EVT RVT = N->getValueType(0);
EVT SVT = N->getOperand(0).getValueType();
bool IsStrict = N->isStrictFPOpcode();
SDValue Op = N->getOperand(IsStrict ? 1 : 0);
EVT SVT = Op.getValueType();

if (N->isStrictFPOpcode()) {
// FIXME: assume we only have two f16 variants for now.
unsigned Opcode;
if (RVT == MVT::f16)
Opcode = ISD::STRICT_FP_TO_FP16;
else if (RVT == MVT::bf16)
Opcode = ISD::STRICT_FP_TO_BF16;
else
llvm_unreachable("unknown half type");
SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
{N->getOperand(0), N->getOperand(1)});
if (IsStrict) {
SDValue Res = DAG.getNode(GetPromotionOpcodeStrict(SVT, RVT), SDLoc(N),
{MVT::i16, MVT::Other}, {N->getOperand(0), Op});
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
return Res;
}
Expand Down Expand Up @@ -3359,6 +3349,15 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_XINT_TO_FP(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
SDLoc dl(N);

if (N->isStrictFPOpcode()) {
SDValue Op = DAG.getNode(N->getOpcode(), dl, {NVT, MVT::Other},
{N->getOperand(0), N->getOperand(1)});
Op = DAG.getNode(GetPromotionOpcodeStrict(NVT, OVT), dl,
{MVT::i16, MVT::Other}, {Op.getValue(1), Op});
ReplaceValueWith(SDValue(N, 1), Op.getValue(1));
return Op;
}

SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));

// Round the value to the softened type.
Expand Down Expand Up @@ -3447,6 +3446,8 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) {
Res = SoftPromoteHalfOp_FAKE_USE(N, OpNo);
break;
case ISD::FCOPYSIGN: Res = SoftPromoteHalfOp_FCOPYSIGN(N, OpNo); break;
case ISD::STRICT_FP_TO_SINT:
case ISD::STRICT_FP_TO_UINT:
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT: Res = SoftPromoteHalfOp_FP_TO_XINT(N); break;
case ISD::FP_TO_SINT_SAT:
Expand All @@ -3473,7 +3474,7 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) {

assert(Res.getNode() != N && "Expected a new node!");

assert(Res.getValueType() == N->getValueType(0) && N->getNumValues() == 1 &&
assert(Res.getValueType() == N->getValueType(0) &&
"Invalid operand expansion");

ReplaceValueWith(SDValue(N, 0), Res);
Expand Down Expand Up @@ -3517,16 +3518,8 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
Op = GetSoftPromotedHalf(N->getOperand(IsStrict ? 1 : 0));

if (IsStrict) {
unsigned Opcode;
if (SVT == MVT::f16)
Opcode = ISD::STRICT_FP16_TO_FP;
else if (SVT == MVT::bf16)
Opcode = ISD::STRICT_BF16_TO_FP;
else
llvm_unreachable("unknown half type");
SDValue Res =
DAG.getNode(Opcode, SDLoc(N), {N->getValueType(0), MVT::Other},
{N->getOperand(0), Op});
SDValue Res = DAG.getNode(GetPromotionOpcodeStrict(SVT, RVT), SDLoc(N),
{RVT, MVT::Other}, {N->getOperand(0), Op});
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
ReplaceValueWith(SDValue(N, 0), Res);
return SDValue();
Expand All @@ -3537,17 +3530,25 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {

SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_TO_XINT(SDNode *N) {
EVT RVT = N->getValueType(0);
SDValue Op = N->getOperand(0);
bool IsStrict = N->isStrictFPOpcode();
SDValue Op = N->getOperand(IsStrict ? 1 : 0);
EVT SVT = Op.getValueType();
SDLoc dl(N);

EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), Op.getValueType());

EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), SVT);
Op = GetSoftPromotedHalf(Op);

SDValue Res = DAG.getNode(GetPromotionOpcode(SVT, RVT), dl, NVT, Op);
if (IsStrict) {
Op = DAG.getNode(GetPromotionOpcodeStrict(SVT, RVT), dl, {NVT, MVT::Other},
{N->getOperand(0), Op});
Op = DAG.getNode(N->getOpcode(), dl, {RVT, MVT::Other},
{Op.getValue(1), Op});
ReplaceValueWith(SDValue(N, 1), Op.getValue(1));
return Op;
}

return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), Res);
SDValue Res = DAG.getNode(GetPromotionOpcode(SVT, RVT), dl, NVT, Op);
return DAG.getNode(N->getOpcode(), dl, RVT, Res);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_TO_XINT_SAT(SDNode *N) {
Expand Down
38 changes: 27 additions & 11 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
Subtarget.isSoftFPABI() ? LibCall : Custom);
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Custom);
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Custom);
setOperationAction(ISD::STRICT_FP_TO_FP16, MVT::f32, Custom);
setOperationAction(ISD::STRICT_FP16_TO_FP, MVT::f32, Custom);

if (Subtarget.hasStdExtZfa()) {
setOperationAction(ISD::ConstantFP, MVT::f32, Custom);
Expand Down Expand Up @@ -581,6 +583,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
Subtarget.isSoftFPABI() ? LibCall : Custom);
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Custom);
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
setOperationAction(ISD::STRICT_FP_TO_FP16, MVT::f64, Custom);
setOperationAction(ISD::STRICT_FP16_TO_FP, MVT::f64, Expand);
}

if (Subtarget.is64Bit()) {
Expand Down Expand Up @@ -6881,33 +6885,45 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getNode(ISD::FP_EXTEND, DL, VT, Res);
return Res;
}
case ISD::STRICT_FP_TO_FP16:
case ISD::FP_TO_FP16: {
// Custom lower to ensure the libcall return is passed in an FPR on hard
// float ABIs.
assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalisation");
SDLoc DL(Op);
MakeLibCallOptions CallOptions;
RTLIB::Libcall LC =
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::f16);
SDValue Res =
makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
bool IsStrict = Op->isStrictFPOpcode();
SDValue Op0 = IsStrict ? Op.getOperand(1) : Op.getOperand(0);
SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
RTLIB::Libcall LC = RTLIB::getFPROUND(Op0.getValueType(), MVT::f16);
SDValue Res;
std::tie(Res, Chain) =
makeLibCall(DAG, LC, MVT::f32, Op0, CallOptions, DL, Chain);
if (Subtarget.is64Bit())
return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res);
return DAG.getBitcast(MVT::i32, Res);
SDValue Result = DAG.getBitcast(MVT::i32, IsStrict ? Res.getValue(0) : Res);
if (IsStrict)
return DAG.getMergeValues({Result, Chain}, DL);
return Result;
}
case ISD::STRICT_FP16_TO_FP:
case ISD::FP16_TO_FP: {
// Custom lower to ensure the libcall argument is passed in an FPR on hard
// float ABIs.
assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalisation");
SDLoc DL(Op);
MakeLibCallOptions CallOptions;
bool IsStrict = Op->isStrictFPOpcode();
SDValue Op0 = IsStrict ? Op.getOperand(1) : Op.getOperand(0);
SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
SDValue Arg = Subtarget.is64Bit()
? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32,
Op.getOperand(0))
: DAG.getBitcast(MVT::f32, Op.getOperand(0));
SDValue Res =
makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg, CallOptions, DL)
.first;
? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Op0)
: DAG.getBitcast(MVT::f32, Op0);
SDValue Res;
std::tie(Res, Chain) = makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg,
CallOptions, DL, Chain);
if (IsStrict)
return DAG.getMergeValues({Res, Chain}, DL);
return Res;
}
case ISD::FTRUNC:
Expand Down
Loading
Loading