Skip to content

Commit a9386dc

Browse files
[RISCV] Add support for STRICT_UINT_TO_FP and STRICT_SINT_TO_FP
This patch adds support for the missing STRICT_UINT_TO_FP and STRICT_SINT_TO_FP for riscv and adds a test case for rv32 which was previously crashing. The code is mostly in line with how other strict_* nodes are handled (e.g., getting op(1) instead of op(0) when it's a strict node, as op(0) in a strict node is the entry token). The only difference is the call to DAG.ReplaceAllUsesOfValueWith, to replace any use of the STRICT_UINT_TO_FP and STRICT_SINT_TO_FP with the new nodes we are creating here (e.g., strict_fp_to_fp16, strict_fp16_to_fp). To understand why this is needed, let us consider the following program: Result type 0 illegal: t4: i16 = truncate t3 SelectionDAG has 14 nodes: t0: ch,glue = EntryToken t2: f32,ch = CopyFromReg t0, Register:f32 %0 t3: i32 = bitcast t2 t4: i16 = truncate t3 t5: f16 = bitcast t4 t6: i32,ch = strict_fp_to_sint t0, t5 t11: i32,ch = strict_fp_to_sint t10:1, t10 t8: ch,glue = CopyToReg t6:1, Register:i32 $x10, t11 t13: i32 = and t3, Constant:i32<65535> t10: f32,ch = strict_fp16_to_fp t0, t13 t9: ch = RISCVISD::RET_GLUE t8, Register:i32 $x10, t8:1 The chain of the original strict conversion was still being used elsewhere in the function, meaning it (and its operands) were being kept alive. This chain has a i16 = truncate t3 which is an illegal return type in rv32 without zfhmin or zfh. This was not a problem for the non-strict version of the nodes (i.e., UINT_TO_FP and SINT_TO_FP) as they don't have chain so the offending node was being removed as dead code. The DAG.ReplaceAllUsesOfValueWith replaces the usage of the STRICT_UINT_TO_FP and STRICT_SINT_TO_FP with the new nodes that are being created, so the truncate node is now also being removed as dead code.
1 parent e80f489 commit a9386dc

File tree

4 files changed

+409
-52
lines changed

4 files changed

+409
-52
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,32 +2385,26 @@ SDValue DAGTypeLegalizer::ExpandFloatOp_LLRINT(SDNode *N) {
23852385
//
23862386

23872387
static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
2388-
if (OpVT == MVT::f16) {
2388+
if (OpVT == MVT::f16)
23892389
return ISD::FP16_TO_FP;
2390-
} else if (RetVT == MVT::f16) {
2390+
if (RetVT == MVT::f16)
23912391
return ISD::FP_TO_FP16;
2392-
} else if (OpVT == MVT::bf16) {
2392+
if (OpVT == MVT::bf16)
23932393
return ISD::BF16_TO_FP;
2394-
} else if (RetVT == MVT::bf16) {
2394+
if (RetVT == MVT::bf16)
23952395
return ISD::FP_TO_BF16;
2396-
}
2397-
23982396
report_fatal_error("Attempt at an invalid promotion-related conversion");
23992397
}
24002398

24012399
static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
24022400
if (OpVT == MVT::f16)
24032401
return ISD::STRICT_FP16_TO_FP;
2404-
24052402
if (RetVT == MVT::f16)
24062403
return ISD::STRICT_FP_TO_FP16;
2407-
24082404
if (OpVT == MVT::bf16)
24092405
return ISD::STRICT_BF16_TO_FP;
2410-
24112406
if (RetVT == MVT::bf16)
24122407
return ISD::STRICT_FP_TO_BF16;
2413-
24142408
report_fatal_error("Attempt at an invalid promotion-related conversion");
24152409
}
24162410

@@ -3138,6 +3132,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
31383132
break;
31393133
case ISD::SELECT: R = SoftPromoteHalfRes_SELECT(N); break;
31403134
case ISD::SELECT_CC: R = SoftPromoteHalfRes_SELECT_CC(N); break;
3135+
case ISD::STRICT_SINT_TO_FP:
3136+
case ISD::STRICT_UINT_TO_FP:
31413137
case ISD::SINT_TO_FP:
31423138
case ISD::UINT_TO_FP: R = SoftPromoteHalfRes_XINT_TO_FP(N); break;
31433139
case ISD::UNDEF: R = SoftPromoteHalfRes_UNDEF(N); break;
@@ -3288,19 +3284,13 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FFREXP(SDNode *N) {
32883284

32893285
SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
32903286
EVT RVT = N->getValueType(0);
3291-
EVT SVT = N->getOperand(0).getValueType();
3287+
bool IsStrict = N->isStrictFPOpcode();
3288+
SDValue Op = N->getOperand(IsStrict ? 1 : 0);
3289+
EVT SVT = Op.getValueType();
32923290

3293-
if (N->isStrictFPOpcode()) {
3294-
// FIXME: assume we only have two f16 variants for now.
3295-
unsigned Opcode;
3296-
if (RVT == MVT::f16)
3297-
Opcode = ISD::STRICT_FP_TO_FP16;
3298-
else if (RVT == MVT::bf16)
3299-
Opcode = ISD::STRICT_FP_TO_BF16;
3300-
else
3301-
llvm_unreachable("unknown half type");
3302-
SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
3303-
{N->getOperand(0), N->getOperand(1)});
3291+
if (IsStrict) {
3292+
SDValue Res = DAG.getNode(GetPromotionOpcodeStrict(SVT, RVT), SDLoc(N),
3293+
{MVT::i16, MVT::Other}, {N->getOperand(0), Op});
33043294
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
33053295
return Res;
33063296
}
@@ -3359,6 +3349,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_XINT_TO_FP(SDNode *N) {
33593349
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
33603350
SDLoc dl(N);
33613351

3352+
if (N->isStrictFPOpcode()) {
3353+
SDValue Op = DAG.getNode(N->getOpcode(), dl, {NVT, MVT::Other},
3354+
{N->getOperand(0), N->getOperand(1)});
3355+
Op = DAG.getNode(GetPromotionOpcodeStrict(NVT, OVT), dl,
3356+
{MVT::i16, MVT::Other}, {N->getOperand(0), Op});
3357+
ReplaceValueWith(SDValue(N, 1), Op.getValue(1));
3358+
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Op.getValue(1));
3359+
return Op;
3360+
}
3361+
33623362
SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
33633363

33643364
// Round the value to the softened type.
@@ -3447,6 +3447,8 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) {
34473447
Res = SoftPromoteHalfOp_FAKE_USE(N, OpNo);
34483448
break;
34493449
case ISD::FCOPYSIGN: Res = SoftPromoteHalfOp_FCOPYSIGN(N, OpNo); break;
3450+
case ISD::STRICT_FP_TO_SINT:
3451+
case ISD::STRICT_FP_TO_UINT:
34503452
case ISD::FP_TO_SINT:
34513453
case ISD::FP_TO_UINT: Res = SoftPromoteHalfOp_FP_TO_XINT(N); break;
34523454
case ISD::FP_TO_SINT_SAT:
@@ -3473,7 +3475,7 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) {
34733475

34743476
assert(Res.getNode() != N && "Expected a new node!");
34753477

3476-
assert(Res.getValueType() == N->getValueType(0) && N->getNumValues() == 1 &&
3478+
assert(Res.getValueType() == N->getValueType(0) &&
34773479
"Invalid operand expansion");
34783480

34793481
ReplaceValueWith(SDValue(N, 0), Res);
@@ -3517,16 +3519,8 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
35173519
Op = GetSoftPromotedHalf(N->getOperand(IsStrict ? 1 : 0));
35183520

35193521
if (IsStrict) {
3520-
unsigned Opcode;
3521-
if (SVT == MVT::f16)
3522-
Opcode = ISD::STRICT_FP16_TO_FP;
3523-
else if (SVT == MVT::bf16)
3524-
Opcode = ISD::STRICT_BF16_TO_FP;
3525-
else
3526-
llvm_unreachable("unknown half type");
3527-
SDValue Res =
3528-
DAG.getNode(Opcode, SDLoc(N), {N->getValueType(0), MVT::Other},
3529-
{N->getOperand(0), Op});
3522+
SDValue Res = DAG.getNode(GetPromotionOpcodeStrict(SVT, RVT), SDLoc(N),
3523+
{RVT, MVT::Other}, {N->getOperand(0), Op});
35303524
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
35313525
ReplaceValueWith(SDValue(N, 0), Res);
35323526
return SDValue();
@@ -3537,17 +3531,26 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
35373531

35383532
SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_TO_XINT(SDNode *N) {
35393533
EVT RVT = N->getValueType(0);
3540-
SDValue Op = N->getOperand(0);
3534+
bool IsStrict = N->isStrictFPOpcode();
3535+
SDValue Op = N->getOperand(IsStrict ? 1 : 0);
35413536
EVT SVT = Op.getValueType();
35423537
SDLoc dl(N);
35433538

3544-
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
3545-
3539+
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), SVT);
35463540
Op = GetSoftPromotedHalf(Op);
35473541

3548-
SDValue Res = DAG.getNode(GetPromotionOpcode(SVT, RVT), dl, NVT, Op);
3542+
if (IsStrict) {
3543+
Op = DAG.getNode(GetPromotionOpcodeStrict(SVT, RVT), dl, {NVT, MVT::Other},
3544+
{N->getOperand(0), Op});
3545+
Op = DAG.getNode(N->getOpcode(), dl, {RVT, MVT::Other},
3546+
{N->getOperand(0), Op});
3547+
ReplaceValueWith(SDValue(N, 1), Op.getValue(1));
3548+
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Op.getValue(1));
3549+
return Op;
3550+
}
35493551

3550-
return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), Res);
3552+
SDValue Res = DAG.getNode(GetPromotionOpcode(SVT, RVT), dl, NVT, Op);
3553+
return DAG.getNode(N->getOpcode(), dl, RVT, Res);
35513554
}
35523555

35533556
SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_TO_XINT_SAT(SDNode *N) {

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3952,12 +3952,7 @@ void DAGTypeLegalizer::ExpandIntRes_FP_TO_XINT(SDNode *N, SDValue &Lo,
39523952
Op = GetPromotedFloat(Op);
39533953

39543954
if (getTypeAction(Op.getValueType()) == TargetLowering::TypeSoftPromoteHalf) {
3955-
EVT OFPVT = Op.getValueType();
3956-
EVT NFPVT = TLI.getTypeToTransformTo(*DAG.getContext(), OFPVT);
3957-
Op = GetSoftPromotedHalf(Op);
3958-
Op = DAG.getNode(OFPVT == MVT::f16 ? ISD::FP16_TO_FP : ISD::BF16_TO_FP, dl,
3959-
NFPVT, Op);
3960-
Op = DAG.getNode(IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT, dl, VT, Op);
3955+
Op = SoftPromoteHalfOp_FP_TO_XINT(N);
39613956
SplitInteger(Op, Lo, Hi);
39623957
return;
39633958
}

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
529529
Subtarget.isSoftFPABI() ? LibCall : Custom);
530530
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Custom);
531531
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Custom);
532+
setOperationAction(ISD::STRICT_FP_TO_FP16, MVT::f32, Custom);
533+
setOperationAction(ISD::STRICT_FP16_TO_FP, MVT::f32, Custom);
532534

533535
if (Subtarget.hasStdExtZfa()) {
534536
setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
@@ -577,6 +579,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
577579
Subtarget.isSoftFPABI() ? LibCall : Custom);
578580
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Custom);
579581
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
582+
setOperationAction(ISD::STRICT_FP_TO_FP16, MVT::f64, Custom);
583+
setOperationAction(ISD::STRICT_FP16_TO_FP, MVT::f64, Expand);
580584
}
581585

582586
if (Subtarget.is64Bit()) {
@@ -6851,30 +6855,35 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
68516855
return DAG.getNode(ISD::FP_EXTEND, DL, VT, Res);
68526856
return Res;
68536857
}
6858+
case ISD::STRICT_FP_TO_FP16:
68546859
case ISD::FP_TO_FP16: {
68556860
// Custom lower to ensure the libcall return is passed in an FPR on hard
68566861
// float ABIs.
68576862
assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalisation");
68586863
SDLoc DL(Op);
68596864
MakeLibCallOptions CallOptions;
6860-
RTLIB::Libcall LC =
6861-
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::f16);
6862-
SDValue Res =
6863-
makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
6865+
bool IsStrict = Op->isStrictFPOpcode();
6866+
SDValue Op0 = IsStrict ? Op.getOperand(1) : Op.getOperand(0);
6867+
RTLIB::Libcall LC = RTLIB::getFPROUND(Op0.getValueType(), MVT::f16);
6868+
SDValue Res = makeLibCall(DAG, LC, MVT::f32, Op0, CallOptions, DL).first;
68646869
if (Subtarget.is64Bit())
68656870
return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res);
6866-
return DAG.getBitcast(MVT::i32, Res);
6871+
SDValue Result = DAG.getBitcast(MVT::i32, IsStrict ? Res.getValue(0) : Res);
6872+
if (IsStrict)
6873+
return DAG.getMergeValues({Result, Op.getOperand(0)}, DL);
6874+
return Result;
68676875
}
6876+
case ISD::STRICT_FP16_TO_FP:
68686877
case ISD::FP16_TO_FP: {
68696878
// Custom lower to ensure the libcall argument is passed in an FPR on hard
68706879
// float ABIs.
68716880
assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalisation");
68726881
SDLoc DL(Op);
68736882
MakeLibCallOptions CallOptions;
6883+
SDValue Op0 = Op->isStrictFPOpcode() ? Op.getOperand(1) : Op.getOperand(0);
68746884
SDValue Arg = Subtarget.is64Bit()
6875-
? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32,
6876-
Op.getOperand(0))
6877-
: DAG.getBitcast(MVT::f32, Op.getOperand(0));
6885+
? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Op0)
6886+
: DAG.getBitcast(MVT::f32, Op0);
68786887
SDValue Res =
68796888
makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg, CallOptions, DL)
68806889
.first;

0 commit comments

Comments
 (0)