Skip to content

Commit e78887d

Browse files
committed
[NVPTX] Consistently check fast-math flags when lowering div
1 parent a5a6ae1 commit e78887d

File tree

8 files changed

+245
-93
lines changed

8 files changed

+245
-93
lines changed

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,14 @@ enum PrmtMode {
255255
RC16,
256256
};
257257
}
258-
}
258+
259+
enum class DivPrecisionLevel : unsigned {
260+
Approx = 0,
261+
Full = 1,
262+
IEEE754 = 2,
263+
};
264+
265+
} // namespace NVPTX
259266
void initializeNVPTXDAGToDAGISelLegacyPass(PassRegistry &);
260267
} // namespace llvm
261268

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) {
6565
return SelectionDAGISel::runOnMachineFunction(MF);
6666
}
6767

68-
int NVPTXDAGToDAGISel::getDivF32Level() const {
69-
return Subtarget->getTargetLowering()->getDivF32Level();
68+
NVPTX::DivPrecisionLevel
69+
NVPTXDAGToDAGISel::getDivF32Level(const SDNode *N) const {
70+
return Subtarget->getTargetLowering()->getDivF32Level(*MF, N);
7071
}
7172

7273
bool NVPTXDAGToDAGISel::usePrecSqrtF32() const {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
4343
// If true, generate mul.wide from sext and mul
4444
bool doMulWide;
4545

46-
int getDivF32Level() const;
46+
NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const;
4747
bool usePrecSqrtF32() const;
4848
bool useF32FTZ() const;
4949
bool allowFMA() const;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,16 @@ static cl::opt<unsigned> FMAContractLevelOpt(
8585
" 1: do it 2: do it aggressively"),
8686
cl::init(2));
8787

88-
static cl::opt<int> UsePrecDivF32(
88+
static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
8989
"nvptx-prec-divf32", cl::Hidden,
9090
cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
9191
" IEEE Compliant F32 div.rnd if available."),
92-
cl::init(2));
92+
cl::values(clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0",
93+
"Use div.approx"),
94+
clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
95+
clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
96+
"Use IEEE Compliant F32 div.rnd if available")),
97+
cl::init(NVPTX::DivPrecisionLevel::IEEE754));
9398

9499
static cl::opt<bool> UsePrecSqrtF32(
95100
"nvptx-prec-sqrtf32", cl::Hidden,
@@ -109,17 +114,24 @@ static cl::opt<bool> ForceMinByValParamAlign(
109114
" params of device functions."),
110115
cl::init(false));
111116

112-
int NVPTXTargetLowering::getDivF32Level() const {
113-
if (UsePrecDivF32.getNumOccurrences() > 0) {
114-
// If nvptx-prec-div32=N is used on the command-line, always honor it
117+
NVPTX::DivPrecisionLevel
118+
NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
119+
const SDNode *N) const {
120+
// If nvptx-prec-div32=N is used on the command-line, always honor it
121+
if (UsePrecDivF32.getNumOccurrences() > 0)
115122
return UsePrecDivF32;
116-
} else {
117-
// Otherwise, use div.approx if fast math is enabled
118-
if (getTargetMachine().Options.UnsafeFPMath)
119-
return 0;
120-
else
121-
return 2;
123+
124+
// Otherwise, use div.approx if fast math is enabled
125+
if (allowUnsafeFPMath(MF))
126+
return NVPTX::DivPrecisionLevel::Approx;
127+
128+
if (N) {
129+
const SDNodeFlags Flags = N->getFlags();
130+
if (Flags.hasApproximateFuncs())
131+
return NVPTX::DivPrecisionLevel::Approx;
122132
}
133+
134+
return NVPTX::DivPrecisionLevel::IEEE754;
123135
}
124136

125137
bool NVPTXTargetLowering::usePrecSqrtF32() const {
@@ -4947,7 +4959,7 @@ bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
49474959
return allowUnsafeFPMath(MF);
49484960
}
49494961

4950-
bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
4962+
bool NVPTXTargetLowering::allowUnsafeFPMath(const MachineFunction &MF) const {
49514963
// Honor TargetOptions flags that explicitly say unsafe math is okay.
49524964
if (MF.getTarget().Options.UnsafeFPMath)
49534965
return true;

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,8 @@ class NVPTXTargetLowering : public TargetLowering {
214214

215215
// Get the degree of precision we want from 32-bit floating point division
216216
// operations.
217-
//
218-
// 0 - Use ptx div.approx
219-
// 1 - Use ptx.div.full (approximate, but less so than div.approx)
220-
// 2 - Use IEEE-compliant div instructions, if available.
221-
int getDivF32Level() const;
217+
NVPTX::DivPrecisionLevel getDivF32Level(const MachineFunction &MF,
218+
const SDNode *N) const;
222219

223220
// Get whether we should use a precise or approximate 32-bit floating point
224221
// sqrt instruction.
@@ -235,7 +232,7 @@ class NVPTXTargetLowering : public TargetLowering {
235232
unsigned combineRepeatedFPDivisors() const override { return 2; }
236233

237234
bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const;
238-
bool allowUnsafeFPMath(MachineFunction &MF) const;
235+
bool allowUnsafeFPMath(const MachineFunction &MF) const;
239236

240237
bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
241238
EVT) const override {

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 60 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
150150

151151
def doMulWide : Predicate<"doMulWide">;
152152

153-
def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
154-
def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;
155-
156153
def do_SQRTF32_APPROX : Predicate<"!usePrecSqrtF32()">;
157154
def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">;
158155

@@ -1108,26 +1105,19 @@ def INEG64 :
11081105
//-----------------------------------
11091106

11101107
// Constant 1.0f
1111-
def FloatConst1 : PatLeaf<(fpimm), [{
1112-
return &N->getValueAPF().getSemantics() == &llvm::APFloat::IEEEsingle() &&
1113-
N->getValueAPF().convertToFloat() == 1.0f;
1108+
def f32imm_1 : FPImmLeaf<f32, [{
1109+
return &Imm.getSemantics() == &llvm::APFloat::IEEEsingle() &&
1110+
Imm.convertToFloat() == 1.0f;
11141111
}]>;
11151112
// Constant 1.0 (double)
1116-
def DoubleConst1 : PatLeaf<(fpimm), [{
1117-
return &N->getValueAPF().getSemantics() == &llvm::APFloat::IEEEdouble() &&
1118-
N->getValueAPF().convertToDouble() == 1.0;
1113+
def f64imm_1 : FPImmLeaf<f64, [{
1114+
return &Imm.getSemantics() == &llvm::APFloat::IEEEdouble() &&
1115+
Imm.convertToDouble() == 1.0;
11191116
}]>;
11201117
// Constant -1.0 (double)
1121-
def DoubleConstNeg1 : PatLeaf<(fpimm), [{
1122-
return &N->getValueAPF().getSemantics() == &llvm::APFloat::IEEEdouble() &&
1123-
N->getValueAPF().convertToDouble() == -1.0;
1124-
}]>;
1125-
1126-
1127-
// Constant -X -> X (double)
1128-
def NegDoubleConst : SDNodeXForm<fpimm, [{
1129-
return CurDAG->getTargetConstantFP(-(N->getValueAPF()),
1130-
SDLoc(N), MVT::f64);
1118+
def f64imm_neg1 : FPImmLeaf<f64, [{
1119+
return &Imm.getSemantics() == &llvm::APFloat::IEEEdouble() &&
1120+
Imm.convertToDouble() == -1.0;
11311121
}]>;
11321122

11331123
defm FADD : F3_fma_component<"add", fadd>;
@@ -1178,11 +1168,11 @@ def BFNEG16x2 : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, Int32Regs, True>;
11781168
//
11791169
// F64 division
11801170
//
1181-
def FDIV641r :
1171+
def FRCP64r :
11821172
NVPTXInst<(outs Float64Regs:$dst),
1183-
(ins f64imm:$a, Float64Regs:$b),
1173+
(ins Float64Regs:$b),
11841174
"rcp.rn.f64 \t$dst, $b;",
1185-
[(set f64:$dst, (fdiv DoubleConst1:$a, f64:$b))]>;
1175+
[(set f64:$dst, (fdiv f64imm_1, f64:$b))]>;
11861176
def FDIV64rr :
11871177
NVPTXInst<(outs Float64Regs:$dst),
11881178
(ins Float64Regs:$a, Float64Regs:$b),
@@ -1196,109 +1186,114 @@ def FDIV64ri :
11961186

11971187
// fdiv will be converted to rcp
11981188
// fneg (fdiv 1.0, X) => fneg (rcp.rn X)
1199-
def : Pat<(fdiv DoubleConstNeg1:$a, f64:$b),
1200-
(FNEGf64 (FDIV641r (NegDoubleConst node:$a), $b))>;
1189+
def : Pat<(fdiv f64imm_neg1, f64:$b),
1190+
(FNEGf64 (FRCP64r $b))>;
12011191

12021192
//
12031193
// F32 Approximate reciprocal
12041194
//
1205-
def FDIV321r_ftz :
1195+
1196+
def fdiv_approx : PatFrag<(ops node:$a, node:$b),
1197+
(fdiv node:$a, node:$b), [{
1198+
return getDivF32Level(N) == NVPTX::DivPrecisionLevel::Approx;
1199+
}]>;
1200+
1201+
1202+
def FRCP32_approx_r_ftz :
12061203
NVPTXInst<(outs Float32Regs:$dst),
1207-
(ins f32imm:$a, Float32Regs:$b),
1204+
(ins Float32Regs:$b),
12081205
"rcp.approx.ftz.f32 \t$dst, $b;",
1209-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1210-
Requires<[do_DIVF32_APPROX, doF32FTZ]>;
1211-
def FDIV321r :
1206+
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>,
1207+
Requires<[doF32FTZ]>;
1208+
def FRCP32_approx_r :
12121209
NVPTXInst<(outs Float32Regs:$dst),
1213-
(ins f32imm:$a, Float32Regs:$b),
1210+
(ins Float32Regs:$b),
12141211
"rcp.approx.f32 \t$dst, $b;",
1215-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1216-
Requires<[do_DIVF32_APPROX]>;
1212+
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>;
1213+
12171214
//
12181215
// F32 Approximate division
12191216
//
12201217
def FDIV32approxrr_ftz :
12211218
NVPTXInst<(outs Float32Regs:$dst),
12221219
(ins Float32Regs:$a, Float32Regs:$b),
12231220
"div.approx.ftz.f32 \t$dst, $a, $b;",
1224-
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
1225-
Requires<[do_DIVF32_APPROX, doF32FTZ]>;
1221+
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>,
1222+
Requires<[doF32FTZ]>;
12261223
def FDIV32approxri_ftz :
12271224
NVPTXInst<(outs Float32Regs:$dst),
12281225
(ins Float32Regs:$a, f32imm:$b),
12291226
"div.approx.ftz.f32 \t$dst, $a, $b;",
1230-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
1231-
Requires<[do_DIVF32_APPROX, doF32FTZ]>;
1227+
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>,
1228+
Requires<[doF32FTZ]>;
12321229
def FDIV32approxrr :
12331230
NVPTXInst<(outs Float32Regs:$dst),
12341231
(ins Float32Regs:$a, Float32Regs:$b),
12351232
"div.approx.f32 \t$dst, $a, $b;",
1236-
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
1237-
Requires<[do_DIVF32_APPROX]>;
1233+
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>;
12381234
def FDIV32approxri :
12391235
NVPTXInst<(outs Float32Regs:$dst),
12401236
(ins Float32Regs:$a, f32imm:$b),
12411237
"div.approx.f32 \t$dst, $a, $b;",
1242-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
1243-
Requires<[do_DIVF32_APPROX]>;
1238+
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>;
12441239
//
12451240
// F32 Semi-accurate reciprocal
12461241
//
12471242
// rcp.approx gives the same result as div.full(1.0f, a) and is faster.
12481243
//
1249-
def FDIV321r_approx_ftz :
1250-
NVPTXInst<(outs Float32Regs:$dst),
1251-
(ins f32imm:$a, Float32Regs:$b),
1252-
"rcp.approx.ftz.f32 \t$dst, $b;",
1253-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1254-
Requires<[do_DIVF32_FULL, doF32FTZ]>;
1255-
def FDIV321r_approx :
1256-
NVPTXInst<(outs Float32Regs:$dst),
1257-
(ins f32imm:$a, Float32Regs:$b),
1258-
"rcp.approx.f32 \t$dst, $b;",
1259-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1260-
Requires<[do_DIVF32_FULL]>;
1244+
1245+
def fdiv_full : PatFrag<(ops node:$a, node:$b),
1246+
(fdiv node:$a, node:$b), [{
1247+
return getDivF32Level(N) == NVPTX::DivPrecisionLevel::Full;
1248+
}]>;
1249+
1250+
1251+
def : Pat<(fdiv_full f32imm_1, f32:$b),
1252+
(FRCP32_approx_r_ftz $b)>,
1253+
Requires<[doF32FTZ]>;
1254+
1255+
def : Pat<(fdiv_full f32imm_1, f32:$b),
1256+
(FRCP32_approx_r $b)>;
1257+
12611258
//
12621259
// F32 Semi-accurate division
12631260
//
12641261
def FDIV32rr_ftz :
12651262
NVPTXInst<(outs Float32Regs:$dst),
12661263
(ins Float32Regs:$a, Float32Regs:$b),
12671264
"div.full.ftz.f32 \t$dst, $a, $b;",
1268-
[(set f32:$dst, (fdiv Float32Regs:$a, f32:$b))]>,
1269-
Requires<[do_DIVF32_FULL, doF32FTZ]>;
1265+
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>,
1266+
Requires<[doF32FTZ]>;
12701267
def FDIV32ri_ftz :
12711268
NVPTXInst<(outs Float32Regs:$dst),
12721269
(ins Float32Regs:$a, f32imm:$b),
12731270
"div.full.ftz.f32 \t$dst, $a, $b;",
1274-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
1275-
Requires<[do_DIVF32_FULL, doF32FTZ]>;
1271+
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>,
1272+
Requires<[doF32FTZ]>;
12761273
def FDIV32rr :
12771274
NVPTXInst<(outs Float32Regs:$dst),
12781275
(ins Float32Regs:$a, Float32Regs:$b),
12791276
"div.full.f32 \t$dst, $a, $b;",
1280-
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
1281-
Requires<[do_DIVF32_FULL]>;
1277+
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>;
12821278
def FDIV32ri :
12831279
NVPTXInst<(outs Float32Regs:$dst),
12841280
(ins Float32Regs:$a, f32imm:$b),
12851281
"div.full.f32 \t$dst, $a, $b;",
1286-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
1287-
Requires<[do_DIVF32_FULL]>;
1282+
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>;
12881283
//
12891284
// F32 Accurate reciprocal
12901285
//
12911286
def FDIV321r_prec_ftz :
12921287
NVPTXInst<(outs Float32Regs:$dst),
1293-
(ins f32imm:$a, Float32Regs:$b),
1288+
(ins Float32Regs:$b),
12941289
"rcp.rn.ftz.f32 \t$dst, $b;",
1295-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1290+
[(set f32:$dst, (fdiv f32imm_1, f32:$b))]>,
12961291
Requires<[doF32FTZ]>;
1297-
def FDIV321r_prec :
1292+
def FRCP32r_prec :
12981293
NVPTXInst<(outs Float32Regs:$dst),
1299-
(ins f32imm:$a, Float32Regs:$b),
1294+
(ins Float32Regs:$b),
13001295
"rcp.rn.f32 \t$dst, $b;",
1301-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>;
1296+
[(set f32:$dst, (fdiv f32imm_1, f32:$b))]>;
13021297
//
13031298
// F32 Accurate division
13041299
//

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,24 +1606,24 @@ def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
16061606
F64RT, F64RT, int_nvvm_rsqrt_approx_d>;
16071607

16081608
// 1.0f / sqrt_approx -> rsqrt_approx
1609-
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f f32:$a)),
1609+
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_f f32:$a)),
16101610
(INT_NVVM_RSQRT_APPROX_F $a)>,
16111611
Requires<[doRsqrtOpt]>;
1612-
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f f32:$a)),
1612+
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_ftz_f f32:$a)),
16131613
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
16141614
Requires<[doRsqrtOpt]>;
16151615
// same for int_nvvm_sqrt_f when non-precision sqrt is requested
1616-
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f f32:$a)),
1616+
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
16171617
(INT_NVVM_RSQRT_APPROX_F $a)>,
16181618
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1619-
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f f32:$a)),
1619+
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
16201620
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
16211621
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
16221622

1623-
def: Pat<(fdiv FloatConst1, (fsqrt f32:$a)),
1623+
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
16241624
(INT_NVVM_RSQRT_APPROX_F $a)>,
16251625
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1626-
def: Pat<(fdiv FloatConst1, (fsqrt f32:$a)),
1626+
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
16271627
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
16281628
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
16291629
//

0 commit comments

Comments
 (0)