Skip to content

Commit d5cb1ee

Browse files
committed
[MLIR][GPUToNVVM] support fastMath and other non-supported mathOp
1 parent 1ebfc81 commit d5cb1ee

File tree

4 files changed

+312
-50
lines changed

4 files changed

+312
-50
lines changed

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
namespace mlir {
1717

18-
/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func`
19-
/// depending on the element type that Op operates upon. The function
20-
/// declaration is added in case it was not added before.
18+
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
19+
/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
20+
/// Op. The function declaration is added in case it was not added before.
2121
///
2222
/// If the input values are of f16 type, the value is first cast to f32, the
2323
/// function is called, and then the result is cast back.
@@ -27,13 +27,22 @@ namespace mlir {
2727
///
2828
/// will be transformed into
2929
/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
30+
///
31+
/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
32+
/// to the approximate calculation function.
33+
///
34+
/// Another example with NVVM:
35+
/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
36+
///
37+
/// will be transformed into
38+
/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
3039
template <typename SourceOp>
3140
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
3241
public:
3342
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
34-
StringRef f64Func)
43+
StringRef f64Func, StringRef f32ApproxFunc)
3544
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
36-
f64Func(f64Func) {}
45+
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
3746

3847
LogicalResult
3948
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -55,7 +64,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
5564
Type resultType = castedOperands.front().getType();
5665
Type funcType = getFunctionType(resultType, castedOperands);
5766
StringRef funcName =
58-
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
67+
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
68+
op.getFastmath());
5969
if (funcName.empty())
6070
return failure();
6171

@@ -90,9 +100,14 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
90100
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
91101
}
92102

93-
StringRef getFunctionName(Type type) const {
94-
if (isa<Float32Type>(type))
95-
return f32Func;
103+
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
104+
if (isa<Float32Type>(type)) {
105+
if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
106+
!f32ApproxFunc.empty())
107+
return f32ApproxFunc;
108+
else
109+
return f32Func;
110+
}
96111
if (isa<Float64Type>(type))
97112
return f64Func;
98113
return "";
@@ -113,6 +128,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
113128

114129
const std::string f32Func;
115130
const std::string f64Func;
131+
const std::string f32ApproxFunc;
116132
};
117133

118134
} // namespace mlir

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
309309
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
310310
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
311311
target.addIllegalDialect<gpu::GPUDialect>();
312-
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313-
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
314-
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
315-
LLVM::SqrtOp>();
312+
target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
313+
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
314+
LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
315+
LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
316+
LLVM::SinOp, LLVM::SqrtOp>();
316317

317318
// TODO: Remove once we support replacing non-root ops.
318319
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
321322
template <typename OpTy>
322323
static void populateOpPatterns(LLVMTypeConverter &converter,
323324
RewritePatternSet &patterns, StringRef f32Func,
324-
StringRef f64Func) {
325+
StringRef f64Func,
326+
StringRef f32ApproxFunc = "") {
325327
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
326-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
328+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
329+
f32ApproxFunc);
327330
}
328331

329332
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -370,42 +373,68 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
370373
StringAttr::get(&converter.getContext(),
371374
NVVM::NVVMDialect::getMaxntidAttrName()));
372375

376+
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
377+
"__nv_fmod");
373378
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
374379
"__nv_fabs");
380+
populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
381+
"__nv_acos");
382+
populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
383+
"__nv_acosh");
384+
populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
385+
"__nv_asin");
386+
populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
387+
"__nv_asinh");
375388
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
376389
"__nv_atan");
377390
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
378391
"__nv_atan2");
392+
populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
393+
"__nv_atanh");
379394
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
380395
"__nv_cbrt");
381396
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
382397
"__nv_ceil");
383-
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
398+
populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
399+
"__nv_copysign");
400+
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
401+
"__nv_fast_cosf");
402+
populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
403+
"__nv_cosh");
384404
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
385-
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
405+
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
406+
"__nv_fast_expf");
386407
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
387408
"__nv_exp2");
388409
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
389410
"__nv_expm1");
390411
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
391412
"__nv_floor");
392-
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
393-
"__nv_fmod");
394-
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
413+
populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
414+
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
415+
"__nv_fast_logf");
416+
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
417+
"__nv_log10", "__nv_fast_log10f");
395418
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
396419
"__nv_log1p");
397-
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
398-
"__nv_log10");
399420
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
400-
"__nv_log2");
401-
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
402-
"__nv_pow");
421+
"__nv_log2", "__nv_fast_log2f");
422+
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
423+
"__nv_fast_powf");
424+
populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
425+
"__nv_round");
426+
populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
427+
"__nv_rint");
403428
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
404429
"__nv_rsqrt");
405-
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
430+
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
431+
"__nv_fast_sinf");
432+
populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
433+
"__nv_sinh");
406434
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
407435
"__nv_sqrt");
436+
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
437+
"__nv_fast_tanf");
408438
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
409439
"__nv_tanh");
410-
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
411440
}

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ using namespace mlir;
3838
template <typename OpTy>
3939
static void populateOpPatterns(LLVMTypeConverter &converter,
4040
RewritePatternSet &patterns, StringRef f32Func,
41-
StringRef f64Func) {
41+
StringRef f64Func,
42+
StringRef f32ApproxFunc = "") {
4243
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
43-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
44+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45+
f32ApproxFunc);
4446
}
4547

4648
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,

0 commit comments

Comments
 (0)