Skip to content

Commit dbd1e66

Browse files
committed
Revert "[MLIR][GPUToNVVM] support fastMath and other non-supported mathOp (llvm#99890)"
This reverts commit f6431f0.
1 parent caf3975 commit dbd1e66

File tree

4 files changed

+50
-312
lines changed

4 files changed

+50
-312
lines changed

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
namespace mlir {
1818

19-
/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
20-
/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
21-
/// Op. The function declaration is added in case it was not added before.
19+
/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func`
20+
/// depending on the element type that Op operates upon. The function
21+
/// declaration is added in case it was not added before.
2222
///
2323
/// If the input values are of f16 type, the value is first casted to f32, the
2424
/// function called and then the result casted back.
@@ -28,22 +28,13 @@ namespace mlir {
2828
///
2929
/// will be transformed into
3030
/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
31-
///
32-
/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
33-
/// to the approximate calculation function.
34-
///
35-
/// Also example with NVVM:
36-
/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
37-
///
38-
/// will be transformed into
39-
/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
4031
template <typename SourceOp>
4132
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
4233
public:
4334
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
44-
StringRef f64Func, StringRef f32ApproxFunc)
35+
StringRef f64Func)
4536
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
46-
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
37+
f64Func(f64Func) {}
4738

4839
LogicalResult
4940
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -65,8 +56,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
6556
Type resultType = castedOperands.front().getType();
6657
Type funcType = getFunctionType(resultType, castedOperands);
6758
StringRef funcName =
68-
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
69-
op.getFastmath());
59+
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
7060
if (funcName.empty())
7161
return failure();
7262

@@ -101,14 +91,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
10191
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
10292
}
10393

104-
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
105-
if (isa<Float32Type>(type)) {
106-
if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
107-
!f32ApproxFunc.empty())
108-
return f32ApproxFunc;
109-
else
110-
return f32Func;
111-
}
94+
StringRef getFunctionName(Type type) const {
95+
if (isa<Float32Type>(type))
96+
return f32Func;
11297
if (isa<Float64Type>(type))
11398
return f64Func;
11499
return "";
@@ -129,7 +114,6 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
129114

130115
const std::string f32Func;
131116
const std::string f64Func;
132-
const std::string f32ApproxFunc;
133117
};
134118

135119
} // namespace mlir

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 18 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,10 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
309309
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
310310
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
311311
target.addIllegalDialect<gpu::GPUDialect>();
312-
target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
313-
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
314-
LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
315-
LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
316-
LLVM::SinOp, LLVM::SqrtOp>();
312+
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313+
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
314+
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
315+
LLVM::SqrtOp>();
317316

318317
// TODO: Remove once we support replacing non-root ops.
319318
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -322,11 +321,9 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
322321
template <typename OpTy>
323322
static void populateOpPatterns(LLVMTypeConverter &converter,
324323
RewritePatternSet &patterns, StringRef f32Func,
325-
StringRef f64Func,
326-
StringRef f32ApproxFunc = "") {
324+
StringRef f64Func) {
327325
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
328-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
329-
f32ApproxFunc);
326+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
330327
}
331328

332329
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -373,68 +370,42 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
373370
StringAttr::get(&converter.getContext(),
374371
NVVM::NVVMDialect::getMaxntidAttrName()));
375372

376-
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
377-
"__nv_fmod");
378373
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
379374
"__nv_fabs");
380-
populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
381-
"__nv_acos");
382-
populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
383-
"__nv_acosh");
384-
populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
385-
"__nv_asin");
386-
populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
387-
"__nv_asinh");
388375
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
389376
"__nv_atan");
390377
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
391378
"__nv_atan2");
392-
populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
393-
"__nv_atanh");
394379
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
395380
"__nv_cbrt");
396381
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
397382
"__nv_ceil");
398-
populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
399-
"__nv_copysign");
400-
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
401-
"__nv_fast_cosf");
402-
populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
403-
"__nv_cosh");
383+
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
404384
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
405-
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
406-
"__nv_fast_expf");
385+
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
407386
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
408387
"__nv_exp2");
409388
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
410389
"__nv_expm1");
411390
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
412391
"__nv_floor");
413-
populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
414-
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
415-
"__nv_fast_logf");
416-
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
417-
"__nv_log10", "__nv_fast_log10f");
392+
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
393+
"__nv_fmod");
394+
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
418395
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
419396
"__nv_log1p");
397+
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
398+
"__nv_log10");
420399
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
421-
"__nv_log2", "__nv_fast_log2f");
422-
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
423-
"__nv_fast_powf");
424-
populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
425-
"__nv_round");
426-
populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
427-
"__nv_rint");
400+
"__nv_log2");
401+
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
402+
"__nv_pow");
428403
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
429404
"__nv_rsqrt");
430-
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
431-
"__nv_fast_sinf");
432-
populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
433-
"__nv_sinh");
405+
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
434406
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
435407
"__nv_sqrt");
436-
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
437-
"__nv_fast_tanf");
438408
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
439409
"__nv_tanh");
410+
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
440411
}

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,9 @@ using namespace mlir;
3838
template <typename OpTy>
3939
static void populateOpPatterns(LLVMTypeConverter &converter,
4040
RewritePatternSet &patterns, StringRef f32Func,
41-
StringRef f64Func,
42-
StringRef f32ApproxFunc = "") {
41+
StringRef f64Func) {
4342
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
44-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45-
f32ApproxFunc);
43+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
4644
}
4745

4846
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,

0 commit comments

Comments
 (0)