Skip to content

Commit db8b85a

Browse files
authored
AMDGPU: Support llvm.exp10 (llvm#65860)
1 parent 3c86bc0 commit db8b85a

File tree

4 files changed

+7641
-16
lines changed

4 files changed

+7641
-16
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,9 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
336336
setOperationAction(ISD::FLOG2, MVT::f32, Custom);
337337
setOperationAction(ISD::FROUND, {MVT::f32, MVT::f64}, Custom);
338338

339-
setOperationAction({ISD::FLOG, ISD::FLOG10, ISD::FEXP, ISD::FEXP2}, MVT::f32,
340-
Custom);
339+
setOperationAction(
340+
{ISD::FLOG, ISD::FLOG10, ISD::FEXP, ISD::FEXP2, ISD::FEXP10}, MVT::f32,
341+
Custom);
341342

342343
setOperationAction(ISD::FNEARBYINT, {MVT::f16, MVT::f32, MVT::f64}, Custom);
343344

@@ -352,7 +353,8 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
352353
setOperationAction({ISD::FLOG2, ISD::FEXP2}, MVT::f16, Custom);
353354
}
354355

355-
setOperationAction({ISD::FLOG10, ISD::FLOG, ISD::FEXP}, MVT::f16, Custom);
356+
setOperationAction({ISD::FLOG10, ISD::FLOG, ISD::FEXP, ISD::FEXP10}, MVT::f16,
357+
Custom);
356358

357359
// FIXME: These IS_FPCLASS vector fp types are marked custom so it reaches
358360
// scalarization code. Can be removed when IS_FPCLASS expand isn't called by
@@ -457,14 +459,17 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
457459

458460
for (MVT VT : FloatVectorTypes) {
459461
setOperationAction(
460-
{ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD,
461-
ISD::FCEIL, ISD::FCOS, ISD::FDIV, ISD::FEXP2,
462-
ISD::FEXP, ISD::FLOG2, ISD::FREM, ISD::FLOG,
463-
ISD::FLOG10, ISD::FPOW, ISD::FFLOOR, ISD::FTRUNC,
464-
ISD::FMUL, ISD::FMA, ISD::FRINT, ISD::FNEARBYINT,
465-
ISD::FSQRT, ISD::FSIN, ISD::FSUB, ISD::FNEG,
466-
ISD::VSELECT, ISD::SELECT_CC, ISD::FCOPYSIGN, ISD::VECTOR_SHUFFLE,
467-
ISD::SETCC, ISD::FCANONICALIZE, ISD::FROUNDEVEN},
462+
{ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM,
463+
ISD::FADD, ISD::FCEIL, ISD::FCOS,
464+
ISD::FDIV, ISD::FEXP2, ISD::FEXP,
465+
ISD::FEXP10, ISD::FLOG2, ISD::FREM,
466+
ISD::FLOG, ISD::FLOG10, ISD::FPOW,
467+
ISD::FFLOOR, ISD::FTRUNC, ISD::FMUL,
468+
ISD::FMA, ISD::FRINT, ISD::FNEARBYINT,
469+
ISD::FSQRT, ISD::FSIN, ISD::FSUB,
470+
ISD::FNEG, ISD::VSELECT, ISD::SELECT_CC,
471+
ISD::FCOPYSIGN, ISD::VECTOR_SHUFFLE, ISD::SETCC,
472+
ISD::FCANONICALIZE, ISD::FROUNDEVEN},
468473
VT, Expand);
469474
}
470475

@@ -1322,6 +1327,7 @@ SDValue AMDGPUTargetLowering::LowerOperation(SDValue Op,
13221327
case ISD::FLOG10:
13231328
return LowerFLOGCommon(Op, DAG);
13241329
case ISD::FEXP:
1330+
case ISD::FEXP10:
13251331
return lowerFEXP(Op, DAG);
13261332
case ISD::FEXP2:
13271333
return lowerFEXP2(Op, DAG);
@@ -1367,6 +1373,7 @@ void AMDGPUTargetLowering::ReplaceNodeResults(SDNode *N,
13671373
Results.push_back(Lowered);
13681374
return;
13691375
case ISD::FEXP:
1376+
case ISD::FEXP10:
13701377
if (SDValue Lowered = lowerFEXP(SDValue(N, 0), DAG))
13711378
Results.push_back(Lowered);
13721379
return;
@@ -2841,12 +2848,66 @@ SDValue AMDGPUTargetLowering::lowerFEXPUnsafe(SDValue X, const SDLoc &SL,
28412848
Flags);
28422849
}
28432850

2851+
/// Emit approx-funcs appropriate lowering for exp10. inf/nan should still be
2852+
/// handled correctly.
2853+
SDValue AMDGPUTargetLowering::lowerFEXP10Unsafe(SDValue X, const SDLoc &SL,
2854+
SelectionDAG &DAG,
2855+
SDNodeFlags Flags) const {
2856+
const EVT VT = X.getValueType();
2857+
const unsigned Exp2Op = VT == MVT::f32 ? AMDGPUISD::EXP : ISD::FEXP2;
2858+
2859+
if (VT != MVT::f32 || !needsDenormHandlingF32(DAG, X, Flags)) {
2860+
// exp2(x * 0x1.a92000p+1f) * exp2(x * 0x1.4f0978p-11f);
2861+
SDValue K0 = DAG.getConstantFP(0x1.a92000p+1f, SL, VT);
2862+
SDValue K1 = DAG.getConstantFP(0x1.4f0978p-11f, SL, VT);
2863+
2864+
SDValue Mul0 = DAG.getNode(ISD::FMUL, SL, VT, X, K0, Flags);
2865+
SDValue Exp2_0 = DAG.getNode(Exp2Op, SL, VT, Mul0, Flags);
2866+
SDValue Mul1 = DAG.getNode(ISD::FMUL, SL, VT, X, K1, Flags);
2867+
SDValue Exp2_1 = DAG.getNode(Exp2Op, SL, VT, Mul1, Flags);
2868+
return DAG.getNode(ISD::FMUL, SL, VT, Exp2_0, Exp2_1);
2869+
}
2870+
2871+
// bool s = x < -0x1.2f7030p+5f;
2872+
// x += s ? 0x1.0p+5f : 0.0f;
2873+
// exp10 = exp2(x * 0x1.a92000p+1f) *
2874+
// exp2(x * 0x1.4f0978p-11f) *
2875+
// (s ? 0x1.9f623ep-107f : 1.0f);
2876+
2877+
EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2878+
2879+
SDValue Threshold = DAG.getConstantFP(-0x1.2f7030p+5f, SL, VT);
2880+
SDValue NeedsScaling = DAG.getSetCC(SL, SetCCVT, X, Threshold, ISD::SETOLT);
2881+
2882+
SDValue ScaleOffset = DAG.getConstantFP(0x1.0p+5f, SL, VT);
2883+
SDValue ScaledX = DAG.getNode(ISD::FADD, SL, VT, X, ScaleOffset, Flags);
2884+
SDValue AdjustedX =
2885+
DAG.getNode(ISD::SELECT, SL, VT, NeedsScaling, ScaledX, X);
2886+
2887+
SDValue K0 = DAG.getConstantFP(0x1.a92000p+1f, SL, VT);
2888+
SDValue K1 = DAG.getConstantFP(0x1.4f0978p-11f, SL, VT);
2889+
2890+
SDValue Mul0 = DAG.getNode(ISD::FMUL, SL, VT, AdjustedX, K0, Flags);
2891+
SDValue Exp2_0 = DAG.getNode(Exp2Op, SL, VT, Mul0, Flags);
2892+
SDValue Mul1 = DAG.getNode(ISD::FMUL, SL, VT, AdjustedX, K1, Flags);
2893+
SDValue Exp2_1 = DAG.getNode(Exp2Op, SL, VT, Mul1, Flags);
2894+
2895+
SDValue MulExps = DAG.getNode(ISD::FMUL, SL, VT, Exp2_0, Exp2_1, Flags);
2896+
2897+
SDValue ResultScaleFactor = DAG.getConstantFP(0x1.9f623ep-107f, SL, VT);
2898+
SDValue AdjustedResult =
2899+
DAG.getNode(ISD::FMUL, SL, VT, MulExps, ResultScaleFactor, Flags);
2900+
2901+
return DAG.getNode(ISD::SELECT, SL, VT, NeedsScaling, AdjustedResult, MulExps,
2902+
Flags);
2903+
}
2904+
28442905
SDValue AMDGPUTargetLowering::lowerFEXP(SDValue Op, SelectionDAG &DAG) const {
28452906
EVT VT = Op.getValueType();
28462907
SDLoc SL(Op);
28472908
SDValue X = Op.getOperand(0);
28482909
SDNodeFlags Flags = Op->getFlags();
2849-
const bool IsExp10 = false; // TODO: For some reason exp10 is missing
2910+
const bool IsExp10 = Op.getOpcode() == ISD::FEXP10;
28502911

28512912
if (VT.getScalarType() == MVT::f16) {
28522913
// v_exp_f16 (fmul x, log2e)
@@ -2871,8 +2932,8 @@ SDValue AMDGPUTargetLowering::lowerFEXP(SDValue Op, SelectionDAG &DAG) const {
28712932
// TODO: Interpret allowApproxFunc as ignoring DAZ. This is currently copying
28722933
// library behavior. Also, is known-not-daz source sufficient?
28732934
if (allowApproxFunc(DAG, Flags)) {
2874-
assert(!IsExp10 && "todo exp10 support");
2875-
return lowerFEXPUnsafe(X, SL, DAG, Flags);
2935+
return IsExp10 ? lowerFEXP10Unsafe(X, SL, DAG, Flags)
2936+
: lowerFEXPUnsafe(X, SL, DAG, Flags);
28762937
}
28772938

28782939
// Algorithm:

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ class AMDGPUTargetLowering : public TargetLowering {
8080

8181
SDValue lowerFEXPUnsafe(SDValue Op, const SDLoc &SL, SelectionDAG &DAG,
8282
SDNodeFlags Flags) const;
83+
/// Emit the approx-funcs lowering for exp10; lowerFEXP dispatches here for
/// ISD::FEXP10 when approximate functions are allowed. \p Op is the exp10
/// operand and \p Flags are the node flags of the operation being lowered.
SDValue lowerFEXP10Unsafe(SDValue Op, const SDLoc &SL, SelectionDAG &DAG,
84+
SDNodeFlags Flags) const;
8385
SDValue lowerFEXP(SDValue Op, SelectionDAG &DAG) const;
8486

8587
SDValue LowerCTLZ_CTTZ(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,8 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
11731173
Log2Ops.scalarize(0)
11741174
.lower();
11751175

1176-
auto &LogOps = getActionDefinitionsBuilder({G_FLOG, G_FLOG10, G_FEXP});
1176+
auto &LogOps =
1177+
getActionDefinitionsBuilder({G_FLOG, G_FLOG10, G_FEXP, G_FEXP10});
11771178
LogOps.customFor({S32, S16});
11781179
LogOps.clampScalar(0, MinScalarFPTy, S32)
11791180
.scalarize(0);
@@ -2045,6 +2046,7 @@ bool AMDGPULegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
20452046
case TargetOpcode::G_FEXP2:
20462047
return legalizeFExp2(MI, B);
20472048
case TargetOpcode::G_FEXP:
2049+
case TargetOpcode::G_FEXP10:
20482050
return legalizeFExp(MI, B);
20492051
case TargetOpcode::G_FPOW:
20502052
return legalizeFPow(MI, B);
@@ -3466,7 +3468,7 @@ bool AMDGPULegalizerInfo::legalizeFExp(MachineInstr &MI,
34663468
LLT Ty = MRI.getType(Dst);
34673469
const LLT F16 = LLT::scalar(16);
34683470
const LLT F32 = LLT::scalar(32);
3469-
const bool IsExp10 = false; // TODO: For some reason exp10 is missing
3471+
const bool IsExp10 = MI.getOpcode() == TargetOpcode::G_FEXP10;
34703472

34713473
if (Ty == F16) {
34723474
// v_exp_f16 (fmul x, log2e)

0 commit comments

Comments
 (0)