Skip to content

Commit 3ebd797

Browse files
authored
[MLIR][ROCDL] Remove patterns for ops supported as intrinsics in the AMDGPU backend (#102971)
This patch removes the MathToROCDL patterns for a few operations, which allows the mathToLLVM conversion to lower those operations to LLVM intrinsics instead, since the intrinsics are supported directly by the AMDGPU backend.
1 parent 3d9abfc commit 3ebd797

File tree

3 files changed

+31
-113
lines changed

3 files changed

+31
-113
lines changed

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,20 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
4848
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
4949
RewritePatternSet &patterns) {
5050
// Handled by mathToLLVM: math::AbsIOp
51+
// Handled by mathToLLVM: math::AbsFOp
5152
// Handled by mathToLLVM: math::CopySignOp
5253
// Handled by mathToLLVM: math::CountLeadingZerosOp
5354
// Handled by mathToLLVM: math::CountTrailingZerosOp
5455
// Handled by mathToLLVM: math::CtPopOp
56+
// Handled by mathToLLVM: math::ExpOp (32-bit only)
5557
// Handled by mathToLLVM: math::FmaOp
58+
// Handled by mathToLLVM: math::LogOp (32-bit only)
5659
// FIXME: math::IPowIOp
5760
// FIXME: math::FPowIOp
5861
// Handled by mathToLLVM: math::RoundEvenOp
5962
// Handled by mathToLLVM: math::RoundOp
63+
// Handled by mathToLLVM: math::SqrtOp
6064
// Handled by mathToLLVM: math::TruncOp
61-
populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
62-
"__ocml_fabs_f64");
6365
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
6466
"__ocml_acos_f64");
6567
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
@@ -84,16 +86,14 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
8486
"__ocml_cosh_f64");
8587
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
8688
"__ocml_sinh_f64");
87-
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
88-
"__ocml_exp_f64");
89+
populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64");
8990
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
9091
"__ocml_exp2_f64");
9192
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
9293
"__ocml_expm1_f64");
9394
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
9495
"__ocml_floor_f64");
95-
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
96-
"__ocml_log_f64");
96+
populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64");
9797
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
9898
"__ocml_log10_f64");
9999
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
@@ -106,8 +106,6 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
106106
"__ocml_rsqrt_f64");
107107
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
108108
"__ocml_sin_f64");
109-
populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
110-
"__ocml_sqrt_f64");
111109
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
112110
"__ocml_tanh_f64");
113111
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 21 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -131,21 +131,6 @@ gpu.module @test_module {
131131

132132
// -----
133133

134-
gpu.module @test_module {
135-
// CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
136-
// CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
137-
// CHECK-LABEL: func @gpu_fabs
138-
func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
139-
%result32 = math.absf %arg_f32 : f32
140-
// CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
141-
%result64 = math.absf %arg_f64 : f64
142-
// CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
143-
func.return %result32, %result64 : f32, f64
144-
}
145-
}
146-
147-
// -----
148-
149134
gpu.module @test_module {
150135
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
151136
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
@@ -207,17 +192,12 @@ gpu.module @test_module {
207192
// -----
208193

209194
gpu.module @test_module {
210-
// CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
211195
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
212196
// CHECK-LABEL: func @gpu_exp
213-
func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
214-
%exp_f32 = math.exp %arg_f32 : f32
215-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
216-
%result32 = math.exp %exp_f32 : f32
217-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
197+
func.func @gpu_exp(%arg_f64 : f64) -> (f64) {
218198
%result64 = math.exp %arg_f64 : f64
219199
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
220-
func.return %result32, %result64 : f32, f64
200+
func.return %result64 : f64
221201
}
222202
}
223203

@@ -239,21 +219,20 @@ gpu.module @test_module {
239219
}
240220

241221
// -----
242-
243222
// Test that we properly handle operations with a SymbolTable other than the module op
244223
gpu.module @test_module {
245224
"test.symbol_scope"() ({
246225
// CHECK: test.symbol_scope
247-
// CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
248-
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
249-
// CHECK-LABEL: func @gpu_exp
250-
func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
251-
%exp_f32 = math.exp %arg_f32 : f32
252-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
253-
%result32 = math.exp %exp_f32 : f32
254-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
255-
%result64 = math.exp %arg_f64 : f64
256-
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
226+
// CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
227+
// CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
228+
// CHECK-LABEL: func @gpu_sin
229+
func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
230+
%sin_f32 = math.sin %arg_f32 : f32
231+
// CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
232+
%result32 = math.sin %sin_f32 : f32
233+
// CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
234+
%result64 = math.sin %arg_f64 : f64
235+
// CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
257236
func.return %result32, %result64 : f32, f64
258237
}
259238
"test.finish" () : () -> ()
@@ -280,15 +259,12 @@ gpu.module @test_module {
280259
// -----
281260

282261
gpu.module @test_module {
283-
// CHECK: llvm.func @__ocml_log_f32(f32) -> f32
284262
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
285263
// CHECK-LABEL: func @gpu_log
286-
func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
287-
%result32 = math.log %arg_f32 : f32
288-
// CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
264+
func.func @gpu_log(%arg_f64 : f64) -> (f64) {
289265
%result64 = math.log %arg_f64 : f64
290266
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
291-
func.return %result32, %result64 : f32, f64
267+
func.return %result64 : f64
292268
}
293269
}
294270

@@ -359,26 +335,6 @@ gpu.module @test_module {
359335

360336
// -----
361337

362-
gpu.module @test_module {
363-
// CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
364-
// CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
365-
// CHECK-LABEL: func @gpu_sqrt
366-
func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
367-
-> (f16, f32, f64) {
368-
%result16 = math.sqrt %arg_f16 : f16
369-
// CHECK: llvm.fpext %{{.*}} : f16 to f32
370-
// CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
371-
// CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
372-
%result32 = math.sqrt %arg_f32 : f32
373-
// CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
374-
%result64 = math.sqrt %arg_f64 : f64
375-
// CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
376-
func.return %result16, %result32, %result64 : f16, f32, f64
377-
}
378-
}
379-
380-
// -----
381-
382338
gpu.module @test_module {
383339
// CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
384340
// CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
@@ -472,15 +428,15 @@ gpu.module @test_module {
472428
gpu.module @test_module {
473429
// CHECK-LABEL: func @gpu_unroll
474430
func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
475-
%result = math.exp %arg0 : vector<4xf32>
431+
%result = math.sin %arg0 : vector<4xf32>
476432
// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
477-
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
433+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
478434
// CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
479-
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
435+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
480436
// CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
481-
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
437+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
482438
// CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
483-
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
439+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
484440
// CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
485441
// CHECK: return %[[V4]]
486442
func.return %result : vector<4xf32>
@@ -526,9 +482,9 @@ gpu.module @test_module {
526482

527483
gpu.module @module {
528484
// CHECK-LABEL: @spirv_exp
529-
// CHECK: llvm.call @__ocml_exp_f32
485+
// CHECK: llvm.call @__ocml_sin_f32
530486
spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
531-
%0 = math.exp %arg0 : vector<4xf32>
487+
%0 = math.sin %arg0 : vector<4xf32>
532488
spirv.ReturnValue %0 : vector<4xf32>
533489
}
534490
}

mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,6 @@ module @test_module {
1515

1616
// -----
1717

18-
module @test_module {
19-
// CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
20-
// CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
21-
// CHECK-LABEL: func @math_absf
22-
func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
23-
%result32 = math.absf %arg_f32 : f32
24-
// CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
25-
%result64 = math.absf %arg_f64 : f64
26-
// CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
27-
func.return %result32, %result64 : f32, f64
28-
}
29-
}
30-
31-
// -----
32-
3318
module @test_module {
3419
// CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
3520
// CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
@@ -211,15 +196,12 @@ module @test_module {
211196
// -----
212197

213198
module @test_module {
214-
// CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
215199
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
216200
// CHECK-LABEL: func @math_exp
217-
func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
218-
%result32 = math.exp %arg_f32 : f32
219-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
201+
func.func @math_exp(%arg_f64 : f64) -> (f64) {
220202
%result64 = math.exp %arg_f64 : f64
221203
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
222-
func.return %result32, %result64 : f32, f64
204+
func.return %result64 : f64
223205
}
224206
}
225207

@@ -271,15 +253,12 @@ module @test_module {
271253
// -----
272254

273255
module @test_module {
274-
// CHECK: llvm.func @__ocml_log_f32(f32) -> f32
275256
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
276257
// CHECK-LABEL: func @math_log
277-
func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
278-
%result32 = math.log %arg_f32 : f32
279-
// CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
258+
func.func @math_log(%arg_f64 : f64) -> (f64) {
280259
%result64 = math.log %arg_f64 : f64
281260
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
282-
func.return %result32, %result64 : f32, f64
261+
func.return %result64 : f64
283262
}
284263
}
285264

@@ -360,21 +339,6 @@ module @test_module {
360339

361340
// -----
362341

363-
module @test_module {
364-
// CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
365-
// CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
366-
// CHECK-LABEL: func @math_sqrt
367-
func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
368-
%result32 = math.sqrt %arg_f32 : f32
369-
// CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
370-
%result64 = math.sqrt %arg_f64 : f64
371-
// CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
372-
func.return %result32, %result64 : f32, f64
373-
}
374-
}
375-
376-
// -----
377-
378342
module @test_module {
379343
// CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
380344
// CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64

0 commit comments

Comments
 (0)