Skip to content

Commit 8fe4b0e

Browse files
committed
[MLIR][ROCDL] Remove patterns for ops supported as intrinsics in the AMDGPU backend
This patch removes pattens for a few operations which allows mathToLLVM conversion to convert the operations into LLVM intrinsics instead since they are supported directly by the AMDGPU backend.
1 parent b4bc7b1 commit 8fe4b0e

File tree

3 files changed

+21
-153
lines changed

3 files changed

+21
-153
lines changed

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,20 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
4848
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
4949
RewritePatternSet &patterns) {
5050
// Handled by mathToLLVM: math::AbsIOp
51+
// Handled by mathToLLVM: math::AbsFIOp
5152
// Handled by mathToLLVM: math::CopySignOp
5253
// Handled by mathToLLVM: math::CountLeadingZerosOp
5354
// Handled by mathToLLVM: math::CountTrailingZerosOp
5455
// Handled by mathToLLVM: math::CgPopOp
56+
// Handled by mathToLLVM: math::ExpOp
5557
// Handled by mathToLLVM: math::FmaOp
58+
// Handled by mathToLLVM: math::LogOp
5659
// FIXME: math::IPowIOp
5760
// FIXME: math::FPowIOp
5861
// Handled by mathToLLVM: math::RoundEvenOp
5962
// Handled by mathToLLVM: math::RoundOp
63+
// Handled by mathToLLVM: math::SqrtOp
6064
// Handled by mathToLLVM: math::TruncOp
61-
populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
62-
"__ocml_fabs_f64");
6365
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
6466
"__ocml_acos_f64");
6567
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
@@ -84,16 +86,12 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
8486
"__ocml_cosh_f64");
8587
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
8688
"__ocml_sinh_f64");
87-
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
88-
"__ocml_exp_f64");
8989
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
9090
"__ocml_exp2_f64");
9191
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
9292
"__ocml_expm1_f64");
9393
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
9494
"__ocml_floor_f64");
95-
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
96-
"__ocml_log_f64");
9795
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
9896
"__ocml_log10_f64");
9997
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
@@ -106,8 +104,6 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
106104
"__ocml_rsqrt_f64");
107105
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
108106
"__ocml_sin_f64");
109-
populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
110-
"__ocml_sqrt_f64");
111107
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
112108
"__ocml_tanh_f64");
113109
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 17 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -131,21 +131,6 @@ gpu.module @test_module {
131131

132132
// -----
133133

134-
gpu.module @test_module {
135-
// CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
136-
// CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
137-
// CHECK-LABEL: func @gpu_fabs
138-
func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
139-
%result32 = math.absf %arg_f32 : f32
140-
// CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
141-
%result64 = math.absf %arg_f64 : f64
142-
// CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
143-
func.return %result32, %result64 : f32, f64
144-
}
145-
}
146-
147-
// -----
148-
149134
gpu.module @test_module {
150135
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
151136
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
@@ -206,23 +191,6 @@ gpu.module @test_module {
206191

207192
// -----
208193

209-
gpu.module @test_module {
210-
// CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
211-
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
212-
// CHECK-LABEL: func @gpu_exp
213-
func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
214-
%exp_f32 = math.exp %arg_f32 : f32
215-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
216-
%result32 = math.exp %exp_f32 : f32
217-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
218-
%result64 = math.exp %arg_f64 : f64
219-
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
220-
func.return %result32, %result64 : f32, f64
221-
}
222-
}
223-
224-
// -----
225-
226194
gpu.module @test_module {
227195
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
228196
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -239,21 +207,20 @@ gpu.module @test_module {
239207
}
240208

241209
// -----
242-
243210
// Test that we handled properly operation with SymbolTable other than module op
244211
gpu.module @test_module {
245212
"test.symbol_scope"() ({
246213
// CHECK: test.symbol_scope
247-
// CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
248-
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
249-
// CHECK-LABEL: func @gpu_exp
250-
func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
251-
%exp_f32 = math.exp %arg_f32 : f32
252-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
253-
%result32 = math.exp %exp_f32 : f32
254-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
255-
%result64 = math.exp %arg_f64 : f64
256-
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
214+
// CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
215+
// CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
216+
// CHECK-LABEL: func @gpu_sin
217+
func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
218+
%sin_f32 = math.sin %arg_f32 : f32
219+
// CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
220+
%result32 = math.sin %sin_f32 : f32
221+
// CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
222+
%result64 = math.sin %arg_f64 : f64
223+
// CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
257224
func.return %result32, %result64 : f32, f64
258225
}
259226
"test.finish" () : () -> ()
@@ -279,21 +246,6 @@ gpu.module @test_module {
279246

280247
// -----
281248

282-
gpu.module @test_module {
283-
// CHECK: llvm.func @__ocml_log_f32(f32) -> f32
284-
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
285-
// CHECK-LABEL: func @gpu_log
286-
func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
287-
%result32 = math.log %arg_f32 : f32
288-
// CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
289-
%result64 = math.log %arg_f64 : f64
290-
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
291-
func.return %result32, %result64 : f32, f64
292-
}
293-
}
294-
295-
// -----
296-
297249
gpu.module @test_module {
298250
// CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
299251
// CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
@@ -359,26 +311,6 @@ gpu.module @test_module {
359311

360312
// -----
361313

362-
gpu.module @test_module {
363-
// CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
364-
// CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
365-
// CHECK-LABEL: func @gpu_sqrt
366-
func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
367-
-> (f16, f32, f64) {
368-
%result16 = math.sqrt %arg_f16 : f16
369-
// CHECK: llvm.fpext %{{.*}} : f16 to f32
370-
// CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
371-
// CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
372-
%result32 = math.sqrt %arg_f32 : f32
373-
// CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
374-
%result64 = math.sqrt %arg_f64 : f64
375-
// CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
376-
func.return %result16, %result32, %result64 : f16, f32, f64
377-
}
378-
}
379-
380-
// -----
381-
382314
gpu.module @test_module {
383315
// CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
384316
// CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
@@ -472,15 +404,15 @@ gpu.module @test_module {
472404
gpu.module @test_module {
473405
// CHECK-LABEL: func @gpu_unroll
474406
func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
475-
%result = math.exp %arg0 : vector<4xf32>
407+
%result = math.sin %arg0 : vector<4xf32>
476408
// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
477-
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
409+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
478410
// CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
479-
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
411+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
480412
// CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
481-
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
413+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
482414
// CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
483-
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
415+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
484416
// CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
485417
// CHECK: return %[[V4]]
486418
func.return %result : vector<4xf32>
@@ -526,9 +458,9 @@ gpu.module @test_module {
526458

527459
gpu.module @module {
528460
// CHECK-LABEL: @spirv_exp
529-
// CHECK: llvm.call @__ocml_exp_f32
461+
// CHECK: llvm.call @__ocml_sin_f32
530462
spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
531-
%0 = math.exp %arg0 : vector<4xf32>
463+
%0 = math.sin %arg0 : vector<4xf32>
532464
spirv.ReturnValue %0 : vector<4xf32>
533465
}
534466
}

mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,6 @@ module @test_module {
1515

1616
// -----
1717

18-
module @test_module {
19-
// CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
20-
// CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
21-
// CHECK-LABEL: func @math_absf
22-
func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
23-
%result32 = math.absf %arg_f32 : f32
24-
// CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
25-
%result64 = math.absf %arg_f64 : f64
26-
// CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
27-
func.return %result32, %result64 : f32, f64
28-
}
29-
}
30-
31-
// -----
32-
3318
module @test_module {
3419
// CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
3520
// CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
@@ -210,21 +195,6 @@ module @test_module {
210195

211196
// -----
212197

213-
module @test_module {
214-
// CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
215-
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
216-
// CHECK-LABEL: func @math_exp
217-
func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
218-
%result32 = math.exp %arg_f32 : f32
219-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
220-
%result64 = math.exp %arg_f64 : f64
221-
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
222-
func.return %result32, %result64 : f32, f64
223-
}
224-
}
225-
226-
// -----
227-
228198
module @test_module {
229199
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
230200
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -270,21 +240,6 @@ module @test_module {
270240

271241
// -----
272242

273-
module @test_module {
274-
// CHECK: llvm.func @__ocml_log_f32(f32) -> f32
275-
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
276-
// CHECK-LABEL: func @math_log
277-
func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
278-
%result32 = math.log %arg_f32 : f32
279-
// CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
280-
%result64 = math.log %arg_f64 : f64
281-
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
282-
func.return %result32, %result64 : f32, f64
283-
}
284-
}
285-
286-
// -----
287-
288243
module @test_module {
289244
// CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
290245
// CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
@@ -360,21 +315,6 @@ module @test_module {
360315

361316
// -----
362317

363-
module @test_module {
364-
// CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
365-
// CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
366-
// CHECK-LABEL: func @math_sqrt
367-
func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
368-
%result32 = math.sqrt %arg_f32 : f32
369-
// CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
370-
%result64 = math.sqrt %arg_f64 : f64
371-
// CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
372-
func.return %result32, %result64 : f32, f64
373-
}
374-
}
375-
376-
// -----
377-
378318
module @test_module {
379319
// CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
380320
// CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64

0 commit comments

Comments
 (0)