Skip to content

Commit 4eff225

Browse files
committed
[MLIR][GPUToNVVM] Support fastmath and other previously unsupported math ops
1 parent 2e6558b commit 4eff225

File tree

4 files changed

+263
-46
lines changed

4 files changed

+263
-46
lines changed

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ template <typename SourceOp>
3131
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
3232
public:
3333
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
34-
StringRef f64Func)
34+
StringRef f64Func, StringRef f32FastFunc)
3535
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
36-
f64Func(f64Func) {}
36+
f64Func(f64Func), f32FastFunc(f32FastFunc) {}
3737

3838
LogicalResult
3939
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -55,7 +55,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
5555
Type resultType = castedOperands.front().getType();
5656
Type funcType = getFunctionType(resultType, castedOperands);
5757
StringRef funcName =
58-
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
58+
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op.getFastmath());
5959
if (funcName.empty())
6060
return failure();
6161

@@ -90,9 +90,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
9090
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
9191
}
9292

93-
StringRef getFunctionName(Type type) const {
94-
if (isa<Float32Type>(type))
95-
return f32Func;
93+
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
94+
if (isa<Float32Type>(type)) {
95+
if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty())
96+
return f32FastFunc;
97+
else
98+
return f32Func;
99+
}
96100
if (isa<Float64Type>(type))
97101
return f64Func;
98102
return "";
@@ -113,6 +117,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
113117

114118
const std::string f32Func;
115119
const std::string f64Func;
120+
const std::string f32FastFunc;
116121
};
117122

118123
} // namespace mlir

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
309309
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
310310
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
311311
target.addIllegalDialect<gpu::GPUDialect>();
312-
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313-
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
314-
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
315-
LLVM::SqrtOp>();
312+
target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313+
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp, LLVM::FRemOp, LLVM::LogOp,
314+
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp,
315+
LLVM::RoundOp, LLVM::SinOp, LLVM::SqrtOp>();
316+
316317

317318
// TODO: Remove once we support replacing non-root ops.
318319
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,9 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
321322
template <typename OpTy>
322323
static void populateOpPatterns(LLVMTypeConverter &converter,
323324
RewritePatternSet &patterns, StringRef f32Func,
324-
StringRef f64Func) {
325+
StringRef f64Func, StringRef f32FastFunc = "") {
325326
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
326-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
327+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
327328
}
328329

329330
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -370,42 +371,53 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
370371
StringAttr::get(&converter.getContext(),
371372
NVVM::NVVMDialect::getMaxntidAttrName()));
372373

374+
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
375+
"__nv_fmod");
373376
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
374377
"__nv_fabs");
378+
populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf", "__nv_acos");
379+
populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf", "__nv_acosh");
380+
populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf", "__nv_asin");
381+
populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf", "__nv_asinh");
375382
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
376383
"__nv_atan");
377384
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
378385
"__nv_atan2");
386+
populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf", "__nv_atanh");
379387
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
380388
"__nv_cbrt");
381389
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
382390
"__nv_ceil");
383-
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
391+
populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf", "__nv_copysign");
392+
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos", "__nv_fast_cosf");
393+
populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf", "__nv_cosh");
384394
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
385-
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
395+
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp", "__nv_fast_expf");
386396
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
387397
"__nv_exp2");
388398
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
389399
"__nv_expm1");
390400
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
391401
"__nv_floor");
392-
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
393-
"__nv_fmod");
394-
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
402+
populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
403+
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log", "__nv_fast_logf");
404+
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
405+
"__nv_log10", "__nv_fast_log10f");
395406
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
396407
"__nv_log1p");
397-
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
398-
"__nv_log10");
399408
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
400-
"__nv_log2");
409+
"__nv_log2", "__nv_fast_log2f");
401410
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
402-
"__nv_pow");
411+
"__nv_pow", "__nv_fast_powf");
412+
populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf", "__nv_round");
413+
populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf", "__nv_rint");
403414
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
404415
"__nv_rsqrt");
405-
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
416+
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin", "__nv_fast_sinf");
417+
populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf", "__nv_sinh");
406418
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
407419
"__nv_sqrt");
420+
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan", "__nv_fast_tanf");
408421
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
409422
"__nv_tanh");
410-
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
411423
}

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ using namespace mlir;
3838
template <typename OpTy>
3939
static void populateOpPatterns(LLVMTypeConverter &converter,
4040
RewritePatternSet &patterns, StringRef f32Func,
41-
StringRef f64Func) {
41+
StringRef f64Func, StringRef f32FastFunc = "") {
4242
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
43-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
43+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
4444
}
4545

4646
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,

0 commit comments

Comments
 (0)