@@ -38,17 +38,17 @@ using namespace mlir;
38
38
template <typename OpTy>
39
39
static void populateOpPatterns (LLVMTypeConverter &converter,
40
40
RewritePatternSet &patterns, StringRef f32Func,
41
- StringRef f64Func,
41
+ StringRef f64Func, StringRef f16Func,
42
42
StringRef f32ApproxFunc = " " ) {
43
43
patterns.add <ScalarizeVectorOpLowering<OpTy>>(converter);
44
44
patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45
- f32ApproxFunc);
45
+ f32ApproxFunc, f16Func );
46
46
}
47
47
48
48
void mlir::populateMathToROCDLConversionPatterns (LLVMTypeConverter &converter,
49
49
RewritePatternSet &patterns) {
50
50
// Handled by mathToLLVM: math::AbsIOp
51
- // Handled by mathToLLVM: math::AbsFIOp
51
+ // Handled by mathToLLVM: math::AbsFOp
52
52
// Handled by mathToLLVM: math::CopySignOp
53
53
// Handled by mathToLLVM: math::CountLeadingZerosOp
54
54
// Handled by mathToLLVM: math::CountTrailingZerosOp
@@ -63,59 +63,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
63
63
// Handled by mathToLLVM: math::SqrtOp
64
64
// Handled by mathToLLVM: math::TruncOp
65
65
populateOpPatterns<math::AcosOp>(converter, patterns, " __ocml_acos_f32" ,
66
- " __ocml_acos_f64" );
66
+ " __ocml_acos_f64" , " __ocml_acos_f16 " );
67
67
populateOpPatterns<math::AcoshOp>(converter, patterns, " __ocml_acosh_f32" ,
68
- " __ocml_acosh_f64" );
68
+ " __ocml_acosh_f64" , " __ocml_acosh_f16 " );
69
69
populateOpPatterns<math::AsinOp>(converter, patterns, " __ocml_asin_f32" ,
70
- " __ocml_asin_f64" );
70
+ " __ocml_asin_f64" , " __ocml_asin_f16 " );
71
71
populateOpPatterns<math::AsinhOp>(converter, patterns, " __ocml_asinh_f32" ,
72
- " __ocml_asinh_f64" );
72
+ " __ocml_asinh_f64" , " __ocml_asinh_f16 " );
73
73
populateOpPatterns<math::AtanOp>(converter, patterns, " __ocml_atan_f32" ,
74
- " __ocml_atan_f64" );
74
+ " __ocml_atan_f64" , " __ocml_atan_f16 " );
75
75
populateOpPatterns<math::AtanhOp>(converter, patterns, " __ocml_atanh_f32" ,
76
- " __ocml_atanh_f64" );
76
+ " __ocml_atanh_f64" , " __ocml_atanh_f16 " );
77
77
populateOpPatterns<math::Atan2Op>(converter, patterns, " __ocml_atan2_f32" ,
78
- " __ocml_atan2_f64" );
78
+ " __ocml_atan2_f64" , " __ocml_atan2_f16 " );
79
79
populateOpPatterns<math::CbrtOp>(converter, patterns, " __ocml_cbrt_f32" ,
80
- " __ocml_cbrt_f64" );
80
+ " __ocml_cbrt_f64" , " __ocml_cbrt_f16 " );
81
81
populateOpPatterns<math::CeilOp>(converter, patterns, " __ocml_ceil_f32" ,
82
- " __ocml_ceil_f64" );
82
+ " __ocml_ceil_f64" , " __ocml_ceil_f16 " );
83
83
populateOpPatterns<math::CosOp>(converter, patterns, " __ocml_cos_f32" ,
84
- " __ocml_cos_f64" );
84
+ " __ocml_cos_f64" , " __ocml_cos_f16 " );
85
85
populateOpPatterns<math::CoshOp>(converter, patterns, " __ocml_cosh_f32" ,
86
- " __ocml_cosh_f64" );
86
+ " __ocml_cosh_f64" , " __ocml_cosh_f16 " );
87
87
populateOpPatterns<math::SinhOp>(converter, patterns, " __ocml_sinh_f32" ,
88
- " __ocml_sinh_f64" );
89
- populateOpPatterns<math::ExpOp>(converter, patterns, " " , " __ocml_exp_f64" );
88
+ " __ocml_sinh_f64" , " __ocml_sinh_f16" );
89
+ populateOpPatterns<math::ExpOp>(converter, patterns, " " , " __ocml_exp_f64" ,
90
+ " __ocml_exp_f16" );
90
91
populateOpPatterns<math::Exp2Op>(converter, patterns, " __ocml_exp2_f32" ,
91
- " __ocml_exp2_f64" );
92
+ " __ocml_exp2_f64" , " __ocml_exp2_f16 " );
92
93
populateOpPatterns<math::ExpM1Op>(converter, patterns, " __ocml_expm1_f32" ,
93
- " __ocml_expm1_f64" );
94
+ " __ocml_expm1_f64" , " __ocml_expm1_f16 " );
94
95
populateOpPatterns<math::FloorOp>(converter, patterns, " __ocml_floor_f32" ,
95
- " __ocml_floor_f64" );
96
- populateOpPatterns<math::LogOp>(converter, patterns, " " , " __ocml_log_f64" );
96
+ " __ocml_floor_f64" , " __ocml_floor_f16" );
97
+ populateOpPatterns<math::LogOp>(converter, patterns, " " , " __ocml_log_f64" ,
98
+ " __ocml_log_f16" );
97
99
populateOpPatterns<math::Log10Op>(converter, patterns, " __ocml_log10_f32" ,
98
- " __ocml_log10_f64" );
100
+ " __ocml_log10_f64" , " __ocml_log10_f16 " );
99
101
populateOpPatterns<math::Log1pOp>(converter, patterns, " __ocml_log1p_f32" ,
100
- " __ocml_log1p_f64" );
102
+ " __ocml_log1p_f64" , " __ocml_log1p_f16 " );
101
103
populateOpPatterns<math::Log2Op>(converter, patterns, " __ocml_log2_f32" ,
102
- " __ocml_log2_f64" );
104
+ " __ocml_log2_f64" , " __ocml_log2_f16 " );
103
105
populateOpPatterns<math::PowFOp>(converter, patterns, " __ocml_pow_f32" ,
104
- " __ocml_pow_f64" );
106
+ " __ocml_pow_f64" , " __ocml_pow_f16 " );
105
107
populateOpPatterns<math::RsqrtOp>(converter, patterns, " __ocml_rsqrt_f32" ,
106
- " __ocml_rsqrt_f64" );
108
+ " __ocml_rsqrt_f64" , " __ocml_rsqrt_f16 " );
107
109
populateOpPatterns<math::SinOp>(converter, patterns, " __ocml_sin_f32" ,
108
- " __ocml_sin_f64" );
110
+ " __ocml_sin_f64" , " __ocml_sin_f16 " );
109
111
populateOpPatterns<math::TanhOp>(converter, patterns, " __ocml_tanh_f32" ,
110
- " __ocml_tanh_f64" );
112
+ " __ocml_tanh_f64" , " __ocml_tanh_f16 " );
111
113
populateOpPatterns<math::TanOp>(converter, patterns, " __ocml_tan_f32" ,
112
- " __ocml_tan_f64" );
114
+ " __ocml_tan_f64" , " __ocml_tan_f16 " );
113
115
populateOpPatterns<math::ErfOp>(converter, patterns, " __ocml_erf_f32" ,
114
- " __ocml_erf_f64" );
116
+ " __ocml_erf_f64" , " __ocml_erf_f16 " );
115
117
// Single arith pattern that needs a ROCDL call, probably not
116
118
// worth creating a separate pass for it.
117
119
populateOpPatterns<arith::RemFOp>(converter, patterns, " __ocml_fmod_f32" ,
118
- " __ocml_fmod_f64" );
120
+ " __ocml_fmod_f64" , " __ocml_fmod_f16 " );
119
121
}
120
122
121
123
namespace {
0 commit comments