Skip to content

Commit 1c47fa9

Browse files
authored
[mlir][AMDGPU] Add support for AMD f16 math library calls (#108809)
This PR adds support for AMD f16 math library calls (`__ocml_*_f16`). CC: @krzysz00 @manupak
1 parent 97b0d20 commit 1c47fa9

File tree

6 files changed

+306
-168
lines changed

6 files changed

+306
-168
lines changed

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
namespace mlir {
1818

1919
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
20-
/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
21-
/// Op. The function declaration is added in case it was not added before.
20+
/// `f32ApproxFunc` or `f16Func` depending on the element type and the
21+
/// fastMathFlag of that Op. The function declaration is added in case it was
22+
/// not added before.
2223
///
23-
/// If the input values are of f16 type, the value is first casted to f32, the
24-
/// function called and then the result casted back.
24+
/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
25+
/// value is first cast to f32, the function is called, and the result is cast
26+
/// back.
2527
///
2628
/// Example with NVVM:
2729
/// %exp_f32 = math.exp %arg_f32 : f32
@@ -41,9 +43,10 @@ template <typename SourceOp>
4143
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
4244
public:
4345
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
44-
StringRef f64Func, StringRef f32ApproxFunc)
46+
StringRef f64Func, StringRef f32ApproxFunc,
47+
StringRef f16Func)
4548
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
46-
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
49+
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
4750

4851
LogicalResult
4952
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -89,7 +92,11 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
8992
private:
9093
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
9194
Type type = operand.getType();
92-
if (!isa<Float16Type>(type))
95+
if (!isa<Float16Type, BFloat16Type>(type))
96+
return operand;
97+
98+
// If there's an f16 function, there is no need to cast f16 values.
99+
if (!f16Func.empty() && isa<Float16Type>(type))
93100
return operand;
94101

95102
return rewriter.create<LLVM::FPExtOp>(
@@ -102,6 +109,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
102109
}
103110

104111
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
112+
if (isa<Float16Type>(type))
113+
return f16Func;
105114
if (isa<Float32Type>(type)) {
106115
if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
107116
!f32ApproxFunc.empty())
@@ -130,6 +139,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
130139
const std::string f32Func;
131140
const std::string f64Func;
132141
const std::string f32ApproxFunc;
142+
const std::string f16Func;
133143
};
134144

135145
} // namespace mlir

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
335335
/// Registers the lowering patterns for `OpTy`: a ScalarizeVectorOpLowering
/// pattern plus an OpToFuncCallLowering pattern that receives the given
/// function names. `f32ApproxFunc` and `f16Func` default to empty strings.
template <typename OpTy>
336336
static void populateOpPatterns(LLVMTypeConverter &converter,
337337
RewritePatternSet &patterns, StringRef f32Func,
338-
StringRef f64Func,
339-
StringRef f32ApproxFunc = "") {
338+
StringRef f64Func, StringRef f32ApproxFunc = "",
339+
StringRef f16Func = "") {
340340
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
341341
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
342-
f32ApproxFunc);
342+
f32ApproxFunc, f16Func);
343343
}
344344

345345
void mlir::populateGpuSubgroupReduceOpLoweringPattern(

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,10 +334,9 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
334334
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
335335
LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
336336
LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
337-
// These ops are legal for f16 and f32 type.
337+
// These ops are legal for f32 type.
338338
target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
339-
return any_of(op->getOperandTypes(),
340-
llvm::IsaPred<Float16Type, Float32Type>);
339+
return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
341340
});
342341
// TODO: Remove once we support replacing non-root ops.
343342
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
@@ -346,9 +345,11 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
346345
/// Registers the lowering patterns for `OpTy`: a ScalarizeVectorOpLowering
/// pattern plus an OpToFuncCallLowering pattern. All four function names are
/// forwarded in the order the OpToFuncCallLowering constructor declares them
/// (f32, f64, f32 approximate, f16); omitting `f64Func` would shift every
/// later name into the wrong slot and leave the constructor one argument short.
template <typename OpTy>
347346
static void populateOpPatterns(LLVMTypeConverter &converter,
348347
RewritePatternSet &patterns, StringRef f32Func,
349-
StringRef f64Func) {
348+
StringRef f64Func, StringRef f32ApproxFunc,
349+
StringRef f16Func) {
350350
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
351-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
351+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
352+
f32ApproxFunc, f16Func);
352353
}
353354

354355
void mlir::populateGpuToROCDLConversionPatterns(

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ using namespace mlir;
3838
/// Registers the lowering patterns for `OpTy`: a ScalarizeVectorOpLowering
/// pattern plus an OpToFuncCallLowering pattern that receives the given
/// function names. Note that this helper takes `f16Func` before the optional
/// `f32ApproxFunc`, whereas the OpToFuncCallLowering constructor is called
/// with `f32ApproxFunc` before `f16Func`; the call below reorders them.
template <typename OpTy>
3939
static void populateOpPatterns(LLVMTypeConverter &converter,
4040
RewritePatternSet &patterns, StringRef f32Func,
41-
StringRef f64Func,
41+
StringRef f64Func, StringRef f16Func,
4242
StringRef f32ApproxFunc = "") {
4343
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
4444
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45-
f32ApproxFunc);
45+
f32ApproxFunc, f16Func);
4646
}
4747

4848
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
4949
RewritePatternSet &patterns) {
5050
// Handled by mathToLLVM: math::AbsIOp
51-
// Handled by mathToLLVM: math::AbsFIOp
51+
// Handled by mathToLLVM: math::AbsFOp
5252
// Handled by mathToLLVM: math::CopySignOp
5353
// Handled by mathToLLVM: math::CountLeadingZerosOp
5454
// Handled by mathToLLVM: math::CountTrailingZerosOp
@@ -63,59 +63,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
6363
// Handled by mathToLLVM: math::SqrtOp
6464
// Handled by mathToLLVM: math::TruncOp
6565
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
66-
"__ocml_acos_f64");
66+
"__ocml_acos_f64", "__ocml_acos_f16");
6767
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
68-
"__ocml_acosh_f64");
68+
"__ocml_acosh_f64", "__ocml_acosh_f16");
6969
populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
70-
"__ocml_asin_f64");
70+
"__ocml_asin_f64", "__ocml_asin_f16");
7171
populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
72-
"__ocml_asinh_f64");
72+
"__ocml_asinh_f64", "__ocml_asinh_f16");
7373
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
74-
"__ocml_atan_f64");
74+
"__ocml_atan_f64", "__ocml_atan_f16");
7575
populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
76-
"__ocml_atanh_f64");
76+
"__ocml_atanh_f64", "__ocml_atanh_f16");
7777
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
78-
"__ocml_atan2_f64");
78+
"__ocml_atan2_f64", "__ocml_atan2_f16");
7979
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
80-
"__ocml_cbrt_f64");
80+
"__ocml_cbrt_f64", "__ocml_cbrt_f16");
8181
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
82-
"__ocml_ceil_f64");
82+
"__ocml_ceil_f64", "__ocml_ceil_f16");
8383
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
84-
"__ocml_cos_f64");
84+
"__ocml_cos_f64", "__ocml_cos_f16");
8585
populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
86-
"__ocml_cosh_f64");
86+
"__ocml_cosh_f64", "__ocml_cosh_f16");
8787
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
88-
"__ocml_sinh_f64");
89-
populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64");
88+
"__ocml_sinh_f64", "__ocml_sinh_f16");
89+
populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
90+
"__ocml_exp_f16");
9091
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
91-
"__ocml_exp2_f64");
92+
"__ocml_exp2_f64", "__ocml_exp2_f16");
9293
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
93-
"__ocml_expm1_f64");
94+
"__ocml_expm1_f64", "__ocml_expm1_f16");
9495
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
95-
"__ocml_floor_f64");
96-
populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64");
96+
"__ocml_floor_f64", "__ocml_floor_f16");
97+
populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
98+
"__ocml_log_f16");
9799
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
98-
"__ocml_log10_f64");
100+
"__ocml_log10_f64", "__ocml_log10_f16");
99101
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
100-
"__ocml_log1p_f64");
102+
"__ocml_log1p_f64", "__ocml_log1p_f16");
101103
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
102-
"__ocml_log2_f64");
104+
"__ocml_log2_f64", "__ocml_log2_f16");
103105
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
104-
"__ocml_pow_f64");
106+
"__ocml_pow_f64", "__ocml_pow_f16");
105107
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
106-
"__ocml_rsqrt_f64");
108+
"__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
107109
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
108-
"__ocml_sin_f64");
110+
"__ocml_sin_f64", "__ocml_sin_f16");
109111
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
110-
"__ocml_tanh_f64");
112+
"__ocml_tanh_f64", "__ocml_tanh_f16");
111113
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
112-
"__ocml_tan_f64");
114+
"__ocml_tan_f64", "__ocml_tan_f16");
113115
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
114-
"__ocml_erf_f64");
116+
"__ocml_erf_f64", "__ocml_erf_f16");
115117
// Single arith pattern that needs a ROCDL call, probably not
116118
// worth creating a separate pass for it.
117119
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
118-
"__ocml_fmod_f64");
120+
"__ocml_fmod_f64", "__ocml_fmod_f16");
119121
}
120122

121123
namespace {

0 commit comments

Comments
 (0)