Skip to content

Commit 31447fb

Browse files
runsenyyuxuanchen1997
authored andcommitted
[MLIR][GPUToNVVM] support fastMath and other non-supported mathOp (#99890)
Summary: Support fastMath and other non-supported mathOp which only require float operands and call libdevice function directly to nvvm. 1. lowering mathOp with fastMath attribute to correct libdevice intrinsic. 2. some mathOp in math dialect has been lowered to libdevice now, but it doesn't cover all mathOp. so this mr lowers all the remaining mathOp which only require float operands. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250617
1 parent 77cccf5 commit 31447fb

File tree

4 files changed

+312
-50
lines changed

4 files changed

+312
-50
lines changed

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
namespace mlir {
1717

18-
/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func`
19-
/// depending on the element type that Op operates upon. The function
20-
/// declaration is added in case it was not added before.
18+
/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
19+
/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
20+
/// Op. The function declaration is added in case it was not added before.
2121
///
2222
/// If the input values are of f16 type, the value is first casted to f32, the
2323
/// function called and then the result casted back.
@@ -27,13 +27,22 @@ namespace mlir {
2727
///
2828
/// will be transformed into
2929
/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
30+
///
31+
/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
32+
/// to the approximate calculation function.
33+
///
34+
/// Also example with NVVM:
35+
/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
36+
///
37+
/// will be transformed into
38+
/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
3039
template <typename SourceOp>
3140
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
3241
public:
3342
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
34-
StringRef f64Func)
43+
StringRef f64Func, StringRef f32ApproxFunc)
3544
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
36-
f64Func(f64Func) {}
45+
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
3746

3847
LogicalResult
3948
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -55,7 +64,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
5564
Type resultType = castedOperands.front().getType();
5665
Type funcType = getFunctionType(resultType, castedOperands);
5766
StringRef funcName =
58-
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
67+
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
68+
op.getFastmath());
5969
if (funcName.empty())
6070
return failure();
6171

@@ -90,9 +100,14 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
90100
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
91101
}
92102

93-
StringRef getFunctionName(Type type) const {
94-
if (isa<Float32Type>(type))
95-
return f32Func;
103+
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
104+
if (isa<Float32Type>(type)) {
105+
if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
106+
!f32ApproxFunc.empty())
107+
return f32ApproxFunc;
108+
else
109+
return f32Func;
110+
}
96111
if (isa<Float64Type>(type))
97112
return f64Func;
98113
return "";
@@ -113,6 +128,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
113128

114129
const std::string f32Func;
115130
const std::string f64Func;
131+
const std::string f32ApproxFunc;
116132
};
117133

118134
} // namespace mlir

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
309309
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
310310
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
311311
target.addIllegalDialect<gpu::GPUDialect>();
312-
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313-
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
314-
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
315-
LLVM::SqrtOp>();
312+
target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
313+
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
314+
LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
315+
LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
316+
LLVM::SinOp, LLVM::SqrtOp>();
316317

317318
// TODO: Remove once we support replacing non-root ops.
318319
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
321322
template <typename OpTy>
322323
static void populateOpPatterns(LLVMTypeConverter &converter,
323324
RewritePatternSet &patterns, StringRef f32Func,
324-
StringRef f64Func) {
325+
StringRef f64Func,
326+
StringRef f32ApproxFunc = "") {
325327
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
326-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
328+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
329+
f32ApproxFunc);
327330
}
328331

329332
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -370,42 +373,68 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
370373
StringAttr::get(&converter.getContext(),
371374
NVVM::NVVMDialect::getMaxntidAttrName()));
372375

376+
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
377+
"__nv_fmod");
373378
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
374379
"__nv_fabs");
380+
populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
381+
"__nv_acos");
382+
populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
383+
"__nv_acosh");
384+
populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
385+
"__nv_asin");
386+
populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
387+
"__nv_asinh");
375388
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
376389
"__nv_atan");
377390
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
378391
"__nv_atan2");
392+
populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
393+
"__nv_atanh");
379394
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
380395
"__nv_cbrt");
381396
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
382397
"__nv_ceil");
383-
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
398+
populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
399+
"__nv_copysign");
400+
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
401+
"__nv_fast_cosf");
402+
populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
403+
"__nv_cosh");
384404
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
385-
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
405+
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
406+
"__nv_fast_expf");
386407
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
387408
"__nv_exp2");
388409
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
389410
"__nv_expm1");
390411
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
391412
"__nv_floor");
392-
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
393-
"__nv_fmod");
394-
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
413+
populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
414+
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
415+
"__nv_fast_logf");
416+
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
417+
"__nv_log10", "__nv_fast_log10f");
395418
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
396419
"__nv_log1p");
397-
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
398-
"__nv_log10");
399420
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
400-
"__nv_log2");
401-
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
402-
"__nv_pow");
421+
"__nv_log2", "__nv_fast_log2f");
422+
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
423+
"__nv_fast_powf");
424+
populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
425+
"__nv_round");
426+
populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
427+
"__nv_rint");
403428
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
404429
"__nv_rsqrt");
405-
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
430+
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
431+
"__nv_fast_sinf");
432+
populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
433+
"__nv_sinh");
406434
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
407435
"__nv_sqrt");
436+
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
437+
"__nv_fast_tanf");
408438
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
409439
"__nv_tanh");
410-
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
411440
}

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ using namespace mlir;
3838
template <typename OpTy>
3939
static void populateOpPatterns(LLVMTypeConverter &converter,
4040
RewritePatternSet &patterns, StringRef f32Func,
41-
StringRef f64Func) {
41+
StringRef f64Func,
42+
StringRef f32ApproxFunc = "") {
4243
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
43-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
44+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45+
f32ApproxFunc);
4446
}
4547

4648
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,

0 commit comments

Comments
 (0)