Skip to content

Commit 586d3fd

Browse files
committed
[MLIR][GPUToNVVM] Support fast-math lowering and additional, previously unsupported math ops
1 parent 2e6558b commit 586d3fd

File tree

4 files changed

+282
-47
lines changed

4 files changed

+282
-47
lines changed

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ template <typename SourceOp>
3131
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
3232
public:
3333
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
34-
StringRef f64Func)
34+
StringRef f64Func, StringRef f32FastFunc)
3535
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
36-
f64Func(f64Func) {}
36+
f64Func(f64Func), f32FastFunc(f32FastFunc) {}
3737

3838
LogicalResult
3939
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -55,7 +55,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
5555
Type resultType = castedOperands.front().getType();
5656
Type funcType = getFunctionType(resultType, castedOperands);
5757
StringRef funcName =
58-
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
58+
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
59+
op.getFastmath());
5960
if (funcName.empty())
6061
return failure();
6162

@@ -90,9 +91,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
9091
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
9192
}
9293

93-
StringRef getFunctionName(Type type) const {
94-
if (isa<Float32Type>(type))
95-
return f32Func;
94+
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
95+
if (isa<Float32Type>(type)) {
96+
if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty())
97+
return f32FastFunc;
98+
else
99+
return f32Func;
100+
}
96101
if (isa<Float64Type>(type))
97102
return f64Func;
98103
return "";
@@ -113,6 +118,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
113118

114119
const std::string f32Func;
115120
const std::string f64Func;
121+
const std::string f32FastFunc;
116122
};
117123

118124
} // namespace mlir

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
309309
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
310310
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
311311
target.addIllegalDialect<gpu::GPUDialect>();
312-
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313-
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
314-
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
315-
LLVM::SqrtOp>();
312+
target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
313+
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
314+
LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
315+
LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
316+
LLVM::SinOp, LLVM::SqrtOp>();
316317

317318
// TODO: Remove once we support replacing non-root ops.
318319
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,10 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
321322
template <typename OpTy>
322323
static void populateOpPatterns(LLVMTypeConverter &converter,
323324
RewritePatternSet &patterns, StringRef f32Func,
324-
StringRef f64Func) {
325+
StringRef f64Func, StringRef f32FastFunc = "") {
325326
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
326-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
327+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
328+
f32FastFunc);
327329
}
328330

329331
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -370,42 +372,68 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
370372
StringAttr::get(&converter.getContext(),
371373
NVVM::NVVMDialect::getMaxntidAttrName()));
372374

375+
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
376+
"__nv_fmod");
373377
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
374378
"__nv_fabs");
379+
populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
380+
"__nv_acos");
381+
populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
382+
"__nv_acosh");
383+
populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
384+
"__nv_asin");
385+
populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
386+
"__nv_asinh");
375387
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
376388
"__nv_atan");
377389
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
378390
"__nv_atan2");
391+
populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
392+
"__nv_atanh");
379393
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
380394
"__nv_cbrt");
381395
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
382396
"__nv_ceil");
383-
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
397+
populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
398+
"__nv_copysign");
399+
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
400+
"__nv_fast_cosf");
401+
populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
402+
"__nv_cosh");
384403
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
385-
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
404+
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
405+
"__nv_fast_expf");
386406
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
387407
"__nv_exp2");
388408
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
389409
"__nv_expm1");
390410
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
391411
"__nv_floor");
392-
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
393-
"__nv_fmod");
394-
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
412+
populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
413+
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
414+
"__nv_fast_logf");
415+
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
416+
"__nv_log10", "__nv_fast_log10f");
395417
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
396418
"__nv_log1p");
397-
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
398-
"__nv_log10");
399419
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
400-
"__nv_log2");
401-
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
402-
"__nv_pow");
420+
"__nv_log2", "__nv_fast_log2f");
421+
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
422+
"__nv_fast_powf");
423+
populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
424+
"__nv_round");
425+
populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
426+
"__nv_rint");
403427
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
404428
"__nv_rsqrt");
405-
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
429+
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
430+
"__nv_fast_sinf");
431+
populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
432+
"__nv_sinh");
406433
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
407434
"__nv_sqrt");
435+
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
436+
"__nv_fast_tanf");
408437
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
409438
"__nv_tanh");
410-
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
411439
}

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ using namespace mlir;
3838
template <typename OpTy>
3939
static void populateOpPatterns(LLVMTypeConverter &converter,
4040
RewritePatternSet &patterns, StringRef f32Func,
41-
StringRef f64Func) {
41+
StringRef f64Func, StringRef f32FastFunc = "") {
4242
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
43-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
43+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
44+
f32FastFunc);
4445
}
4546

4647
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,

0 commit comments

Comments
 (0)