Skip to content

Commit 963ff1c

Browse files
ftynsewsmoses
andauthored
[mlir] lower min/maxnum to libdevice calls (#127323)
Introduce lowering from arith.minnum/maxxnum operations to the corresponding Nvidia libdevice calls. This requires to reorder pattern population methods so that the libdevice-targeting patterns are prioritized over default patterns targeting LLVM IR intrinsics from the Arith dialect. The tests are placed into a separate file as the existing gpu-to-nvvm.mlir files has a mode that forces Arith dialect operations to be preserved as is without using a separate FileCheck tag to differentiate. Co-authored-by: William Moses <[email protected]>
1 parent 256145b commit 963ff1c

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,8 @@ struct LowerGpuOpsToNVVMOpsPass final
378378
RewritePatternSet llvmPatterns(m.getContext());
379379
LLVMConversionTarget target(getContext());
380380

381+
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
382+
381383
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
382384
allowedDialects.end());
383385
for (Dialect *dialect : getContext().getLoadedDialects()) {
@@ -407,7 +409,6 @@ struct LowerGpuOpsToNVVMOpsPass final
407409
llvmPatterns);
408410
}
409411

410-
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
411412
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
412413
if (this->hasRedux)
413414
populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
@@ -552,6 +553,11 @@ void mlir::populateGpuToNVVMConversionPatterns(
552553

553554
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
554555
"__nv_fmod");
556+
populateOpPatterns<arith::MaxNumFOp>(converter, patterns, "__nv_fmaxf",
557+
"__nv_fmax");
558+
populateOpPatterns<arith::MinNumFOp>(converter, patterns, "__nv_fminf",
559+
"__nv_fmin");
560+
555561
populateIntOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs");
556562
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
557563
"__nv_fabs");
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s
2+
3+
gpu.module @test_module_54 {
4+
// CHECK: llvm.func @__nv_fmaxf(f32, f32) -> f32
5+
// CHECK: llvm.func @__nv_fminf(f32, f32) -> f32
6+
// CHECK: llvm.func @__nv_fmax(f64, f64) -> f64
7+
// CHECK: llvm.func @__nv_fmin(f64, f64) -> f64
8+
// CHECK-LABEL: @gpu_fminmax
9+
func.func @gpu_fminmax(%arg1_f32: f32, %arg2_f32: f32, %arg1_f64: f64, %arg2_f64: f64)
10+
-> (f32, f32, f64, f64) {
11+
// CHECK: llvm.call @__nv_fmaxf
12+
%max_f32 = arith.maxnumf %arg1_f32, %arg2_f32 : f32
13+
// CHECK: llvm.call @__nv_fminf
14+
%min_f32 = arith.minnumf %arg1_f32, %arg2_f32 : f32
15+
// CHECK: llvm.call @__nv_fmax(
16+
%max_f64 = arith.maxnumf %arg1_f64, %arg2_f64 : f64
17+
// CHECK: llvm.call @__nv_fmin(
18+
%min_f64 = arith.minnumf %arg1_f64, %arg2_f64 : f64
19+
return %max_f32, %min_f32, %max_f64, %min_f64 : f32, f32, f64, f64
20+
}
21+
}

0 commit comments

Comments
 (0)