Skip to content

Commit 3bdec31

Browse files
authored
[RISCV] Custom legalize f16/bf16 FNEG/FABS with Zfhmin/Zbfmin. (#106886)
The LegalizeDAG expansion will go through memory since i16 isn't a legal type. Avoid this by using FMV nodes.
1 parent feb391c commit 3bdec31

File tree

8 files changed

+1064
-2031
lines changed

8 files changed

+1064
-2031
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
459459
setOperationAction(ISD::BR_CC, MVT::bf16, Expand);
460460
setOperationAction(ZfhminZfbfminPromoteOps, MVT::bf16, Promote);
461461
setOperationAction(ISD::FREM, MVT::bf16, Promote);
462-
setOperationAction(ISD::FABS, MVT::bf16, Expand);
463-
setOperationAction(ISD::FNEG, MVT::bf16, Expand);
462+
setOperationAction(ISD::FABS, MVT::bf16, Custom);
463+
setOperationAction(ISD::FNEG, MVT::bf16, Custom);
464464
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
465465
}
466466

@@ -476,8 +476,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
476476
setOperationAction({ISD::STRICT_LRINT, ISD::STRICT_LLRINT,
477477
ISD::STRICT_LROUND, ISD::STRICT_LLROUND},
478478
MVT::f16, Legal);
479-
setOperationAction(ISD::FABS, MVT::f16, Expand);
480-
setOperationAction(ISD::FNEG, MVT::f16, Expand);
479+
setOperationAction(ISD::FABS, MVT::f16, Custom);
480+
setOperationAction(ISD::FNEG, MVT::f16, Custom);
481481
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
482482
}
483483

@@ -5942,6 +5942,29 @@ static SDValue lowerFMAXIMUM_FMINIMUM(SDValue Op, SelectionDAG &DAG,
59425942
return Res;
59435943
}
59445944

5945+
static SDValue lowerFABSorFNEG(SDValue Op, SelectionDAG &DAG,
5946+
const RISCVSubtarget &Subtarget) {
5947+
bool IsFABS = Op.getOpcode() == ISD::FABS;
5948+
assert((IsFABS || Op.getOpcode() == ISD::FNEG) &&
5949+
"Wrong opcode for lowering FABS or FNEG.");
5950+
5951+
MVT XLenVT = Subtarget.getXLenVT();
5952+
MVT VT = Op.getSimpleValueType();
5953+
assert((VT == MVT::f16 || VT == MVT::bf16) && "Unexpected type");
5954+
5955+
SDLoc DL(Op);
5956+
SDValue Fmv =
5957+
DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op.getOperand(0));
5958+
5959+
APInt Mask = IsFABS ? APInt::getSignedMaxValue(16) : APInt::getSignMask(16);
5960+
Mask = Mask.sext(Subtarget.getXLen());
5961+
5962+
unsigned LogicOpc = IsFABS ? ISD::AND : ISD::XOR;
5963+
SDValue Logic =
5964+
DAG.getNode(LogicOpc, DL, XLenVT, Fmv, DAG.getConstant(Mask, DL, XLenVT));
5965+
return DAG.getNode(RISCVISD::FMV_H_X, DL, VT, Logic);
5966+
}
5967+
59455968
/// Get a RISC-V target specified VL op for a given SDNode.
59465969
static unsigned getRISCVVLOp(SDValue Op) {
59475970
#define OP_CASE(NODE) \
@@ -7071,12 +7094,15 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
70717094
assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() &&
70727095
"Unexpected custom legalisation");
70737096
return SDValue();
7097+
case ISD::FABS:
7098+
case ISD::FNEG:
7099+
if (Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16)
7100+
return lowerFABSorFNEG(Op, DAG, Subtarget);
7101+
[[fallthrough]];
70747102
case ISD::FADD:
70757103
case ISD::FSUB:
70767104
case ISD::FMUL:
70777105
case ISD::FDIV:
7078-
case ISD::FNEG:
7079-
case ISD::FABS:
70807106
case ISD::FSQRT:
70817107
case ISD::FMA:
70827108
case ISD::FMINNUM:

0 commit comments

Comments
 (0)