Skip to content

Commit 982f101

Browse files
[RISCV] Add support for STRICT_UINT_TO_FP and STRICT_SINT_TO_FP
This patch adds support for the missing STRICT_UINT_TO_FP and STRICT_SINT_TO_FP for riscv and adds a test case for rv32 which was previously crashing. The code is in line with how other strict_* nodes are handled (e.g., getting op(1) instead of op(0) when it's a strict node, as op(0) in a strict node is the entry token).
1 parent ce2e386 commit 982f101

File tree

4 files changed

+417
-55
lines changed

4 files changed

+417
-55
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,32 +2385,26 @@ SDValue DAGTypeLegalizer::ExpandFloatOp_LLRINT(SDNode *N) {
23852385
//
23862386

23872387
static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
2388-
if (OpVT == MVT::f16) {
2388+
if (OpVT == MVT::f16)
23892389
return ISD::FP16_TO_FP;
2390-
} else if (RetVT == MVT::f16) {
2390+
if (RetVT == MVT::f16)
23912391
return ISD::FP_TO_FP16;
2392-
} else if (OpVT == MVT::bf16) {
2392+
if (OpVT == MVT::bf16)
23932393
return ISD::BF16_TO_FP;
2394-
} else if (RetVT == MVT::bf16) {
2394+
if (RetVT == MVT::bf16)
23952395
return ISD::FP_TO_BF16;
2396-
}
2397-
23982396
report_fatal_error("Attempt at an invalid promotion-related conversion");
23992397
}
24002398

24012399
static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
24022400
if (OpVT == MVT::f16)
24032401
return ISD::STRICT_FP16_TO_FP;
2404-
24052402
if (RetVT == MVT::f16)
24062403
return ISD::STRICT_FP_TO_FP16;
2407-
24082404
if (OpVT == MVT::bf16)
24092405
return ISD::STRICT_BF16_TO_FP;
2410-
24112406
if (RetVT == MVT::bf16)
24122407
return ISD::STRICT_FP_TO_BF16;
2413-
24142408
report_fatal_error("Attempt at an invalid promotion-related conversion");
24152409
}
24162410

@@ -3138,6 +3132,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
31383132
break;
31393133
case ISD::SELECT: R = SoftPromoteHalfRes_SELECT(N); break;
31403134
case ISD::SELECT_CC: R = SoftPromoteHalfRes_SELECT_CC(N); break;
3135+
case ISD::STRICT_SINT_TO_FP:
3136+
case ISD::STRICT_UINT_TO_FP:
31413137
case ISD::SINT_TO_FP:
31423138
case ISD::UINT_TO_FP: R = SoftPromoteHalfRes_XINT_TO_FP(N); break;
31433139
case ISD::UNDEF: R = SoftPromoteHalfRes_UNDEF(N); break;
@@ -3288,19 +3284,13 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FFREXP(SDNode *N) {
32883284

32893285
SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
32903286
EVT RVT = N->getValueType(0);
3291-
EVT SVT = N->getOperand(0).getValueType();
3287+
bool IsStrict = N->isStrictFPOpcode();
3288+
SDValue Op = N->getOperand(IsStrict ? 1 : 0);
3289+
EVT SVT = Op.getValueType();
32923290

3293-
if (N->isStrictFPOpcode()) {
3294-
// FIXME: assume we only have two f16 variants for now.
3295-
unsigned Opcode;
3296-
if (RVT == MVT::f16)
3297-
Opcode = ISD::STRICT_FP_TO_FP16;
3298-
else if (RVT == MVT::bf16)
3299-
Opcode = ISD::STRICT_FP_TO_BF16;
3300-
else
3301-
llvm_unreachable("unknown half type");
3302-
SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
3303-
{N->getOperand(0), N->getOperand(1)});
3291+
if (IsStrict) {
3292+
SDValue Res = DAG.getNode(GetPromotionOpcodeStrict(SVT, RVT), SDLoc(N),
3293+
{MVT::i16, MVT::Other}, {N->getOperand(0), Op});
33043294
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
33053295
return Res;
33063296
}
@@ -3359,6 +3349,15 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_XINT_TO_FP(SDNode *N) {
33593349
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
33603350
SDLoc dl(N);
33613351

3352+
if (N->isStrictFPOpcode()) {
3353+
SDValue Op = DAG.getNode(N->getOpcode(), dl, {NVT, MVT::Other},
3354+
{N->getOperand(0), N->getOperand(1)});
3355+
Op = DAG.getNode(GetPromotionOpcodeStrict(NVT, OVT), dl,
3356+
{MVT::i16, MVT::Other}, {Op.getValue(1), Op});
3357+
ReplaceValueWith(SDValue(N, 1), Op.getValue(1));
3358+
return Op;
3359+
}
3360+
33623361
SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
33633362

33643363
// Round the value to the softened type.
@@ -3447,6 +3446,8 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) {
34473446
Res = SoftPromoteHalfOp_FAKE_USE(N, OpNo);
34483447
break;
34493448
case ISD::FCOPYSIGN: Res = SoftPromoteHalfOp_FCOPYSIGN(N, OpNo); break;
3449+
case ISD::STRICT_FP_TO_SINT:
3450+
case ISD::STRICT_FP_TO_UINT:
34503451
case ISD::FP_TO_SINT:
34513452
case ISD::FP_TO_UINT: Res = SoftPromoteHalfOp_FP_TO_XINT(N); break;
34523453
case ISD::FP_TO_SINT_SAT:
@@ -3473,7 +3474,7 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) {
34733474

34743475
assert(Res.getNode() != N && "Expected a new node!");
34753476

3476-
assert(Res.getValueType() == N->getValueType(0) && N->getNumValues() == 1 &&
3477+
assert(Res.getValueType() == N->getValueType(0) &&
34773478
"Invalid operand expansion");
34783479

34793480
ReplaceValueWith(SDValue(N, 0), Res);
@@ -3517,16 +3518,8 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
35173518
Op = GetSoftPromotedHalf(N->getOperand(IsStrict ? 1 : 0));
35183519

35193520
if (IsStrict) {
3520-
unsigned Opcode;
3521-
if (SVT == MVT::f16)
3522-
Opcode = ISD::STRICT_FP16_TO_FP;
3523-
else if (SVT == MVT::bf16)
3524-
Opcode = ISD::STRICT_BF16_TO_FP;
3525-
else
3526-
llvm_unreachable("unknown half type");
3527-
SDValue Res =
3528-
DAG.getNode(Opcode, SDLoc(N), {N->getValueType(0), MVT::Other},
3529-
{N->getOperand(0), Op});
3521+
SDValue Res = DAG.getNode(GetPromotionOpcodeStrict(SVT, RVT), SDLoc(N),
3522+
{RVT, MVT::Other}, {N->getOperand(0), Op});
35303523
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
35313524
ReplaceValueWith(SDValue(N, 0), Res);
35323525
return SDValue();
@@ -3537,17 +3530,25 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
35373530

35383531
SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_TO_XINT(SDNode *N) {
35393532
EVT RVT = N->getValueType(0);
3540-
SDValue Op = N->getOperand(0);
3533+
bool IsStrict = N->isStrictFPOpcode();
3534+
SDValue Op = N->getOperand(IsStrict ? 1 : 0);
35413535
EVT SVT = Op.getValueType();
35423536
SDLoc dl(N);
35433537

3544-
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
3545-
3538+
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), SVT);
35463539
Op = GetSoftPromotedHalf(Op);
35473540

3548-
SDValue Res = DAG.getNode(GetPromotionOpcode(SVT, RVT), dl, NVT, Op);
3541+
if (IsStrict) {
3542+
Op = DAG.getNode(GetPromotionOpcodeStrict(SVT, RVT), dl, {NVT, MVT::Other},
3543+
{N->getOperand(0), Op});
3544+
Op = DAG.getNode(N->getOpcode(), dl, {RVT, MVT::Other},
3545+
{Op.getValue(1), Op});
3546+
ReplaceValueWith(SDValue(N, 1), Op.getValue(1));
3547+
return Op;
3548+
}
35493549

3550-
return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), Res);
3550+
SDValue Res = DAG.getNode(GetPromotionOpcode(SVT, RVT), dl, NVT, Op);
3551+
return DAG.getNode(N->getOpcode(), dl, RVT, Res);
35513552
}
35523553

35533554
SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_TO_XINT_SAT(SDNode *N) {

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3952,12 +3952,7 @@ void DAGTypeLegalizer::ExpandIntRes_FP_TO_XINT(SDNode *N, SDValue &Lo,
39523952
Op = GetPromotedFloat(Op);
39533953

39543954
if (getTypeAction(Op.getValueType()) == TargetLowering::TypeSoftPromoteHalf) {
3955-
EVT OFPVT = Op.getValueType();
3956-
EVT NFPVT = TLI.getTypeToTransformTo(*DAG.getContext(), OFPVT);
3957-
Op = GetSoftPromotedHalf(Op);
3958-
Op = DAG.getNode(OFPVT == MVT::f16 ? ISD::FP16_TO_FP : ISD::BF16_TO_FP, dl,
3959-
NFPVT, Op);
3960-
Op = DAG.getNode(IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT, dl, VT, Op);
3955+
Op = SoftPromoteHalfOp_FP_TO_XINT(N);
39613956
SplitInteger(Op, Lo, Hi);
39623957
return;
39633958
}

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
529529
Subtarget.isSoftFPABI() ? LibCall : Custom);
530530
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Custom);
531531
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Custom);
532+
setOperationAction(ISD::STRICT_FP_TO_FP16, MVT::f32, Custom);
533+
setOperationAction(ISD::STRICT_FP16_TO_FP, MVT::f32, Custom);
532534

533535
if (Subtarget.hasStdExtZfa()) {
534536
setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
@@ -577,6 +579,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
577579
Subtarget.isSoftFPABI() ? LibCall : Custom);
578580
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Custom);
579581
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
582+
setOperationAction(ISD::STRICT_FP_TO_FP16, MVT::f64, Custom);
583+
setOperationAction(ISD::STRICT_FP16_TO_FP, MVT::f64, Expand);
580584
}
581585

582586
if (Subtarget.is64Bit()) {
@@ -6852,33 +6856,45 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
68526856
return DAG.getNode(ISD::FP_EXTEND, DL, VT, Res);
68536857
return Res;
68546858
}
6859+
case ISD::STRICT_FP_TO_FP16:
68556860
case ISD::FP_TO_FP16: {
68566861
// Custom lower to ensure the libcall return is passed in an FPR on hard
68576862
// float ABIs.
68586863
assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalisation");
68596864
SDLoc DL(Op);
68606865
MakeLibCallOptions CallOptions;
6861-
RTLIB::Libcall LC =
6862-
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::f16);
6863-
SDValue Res =
6864-
makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
6866+
bool IsStrict = Op->isStrictFPOpcode();
6867+
SDValue Op0 = IsStrict ? Op.getOperand(1) : Op.getOperand(0);
6868+
SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
6869+
RTLIB::Libcall LC = RTLIB::getFPROUND(Op0.getValueType(), MVT::f16);
6870+
SDValue Res;
6871+
std::tie(Res, Chain) =
6872+
makeLibCall(DAG, LC, MVT::f32, Op0, CallOptions, DL, Chain);
68656873
if (Subtarget.is64Bit())
68666874
return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res);
6867-
return DAG.getBitcast(MVT::i32, Res);
6875+
SDValue Result = DAG.getBitcast(MVT::i32, IsStrict ? Res.getValue(0) : Res);
6876+
if (IsStrict)
6877+
return DAG.getMergeValues({Result, Chain}, DL);
6878+
return Result;
68686879
}
6880+
case ISD::STRICT_FP16_TO_FP:
68696881
case ISD::FP16_TO_FP: {
68706882
// Custom lower to ensure the libcall argument is passed in an FPR on hard
68716883
// float ABIs.
68726884
assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalisation");
68736885
SDLoc DL(Op);
68746886
MakeLibCallOptions CallOptions;
6887+
bool IsStrict = Op->isStrictFPOpcode();
6888+
SDValue Op0 = IsStrict ? Op.getOperand(1) : Op.getOperand(0);
6889+
SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
68756890
SDValue Arg = Subtarget.is64Bit()
6876-
? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32,
6877-
Op.getOperand(0))
6878-
: DAG.getBitcast(MVT::f32, Op.getOperand(0));
6879-
SDValue Res =
6880-
makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg, CallOptions, DL)
6881-
.first;
6891+
? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Op0)
6892+
: DAG.getBitcast(MVT::f32, Op0);
6893+
SDValue Res;
6894+
std::tie(Res, Chain) = makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg,
6895+
CallOptions, DL, Chain);
6896+
if (IsStrict)
6897+
return DAG.getMergeValues({Res, Chain}, DL);
68826898
return Res;
68836899
}
68846900
case ISD::FTRUNC:

0 commit comments

Comments
 (0)