Skip to content

Commit 709b274

Browse files
unterumarmungdcaballe
authored andcommitted
[mlir][vector] Bring back maxf/minf reductions
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. In line with the mentioned RFC, this patch tackles tasks 2.3 and 2.4. It adds LLVM conversions for the `maxf`/`minf` reductions to the non-NaN-propagating LLVM intrinsics. Depends on D158618 Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D158659
1 parent 4a83125 commit 709b274

File tree

6 files changed

+62
-4
lines changed

6 files changed

+62
-4
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,14 @@ template <>
577577
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
578578
using Type = LLVM::MinimumOp;
579579
};
580+
template <>
581+
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
582+
using Type = LLVM::MaxNumOp;
583+
};
584+
template <>
585+
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
586+
using Type = LLVM::MinNumOp;
587+
};
580588
} // namespace
581589

582590
template <class LLVMRedIntrinOp>
@@ -770,6 +778,12 @@ class VectorReductionOpConversion
770778
result =
771779
createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
772780
rewriter, loc, llvmType, operand, acc);
781+
} else if (kind == vector::CombiningKind::MINF) {
782+
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
783+
rewriter, loc, llvmType, operand, acc);
784+
} else if (kind == vector::CombiningKind::MAXF) {
785+
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
786+
rewriter, loc, llvmType, operand, acc);
773787
} else
774788
return failure();
775789

@@ -880,15 +894,11 @@ class MaskedReductionOpConversion
880894
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
881895
break;
882896
case vector::CombiningKind::MINF:
883-
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
884-
// NaNs/-0.0/+0.0 in the same way.
885897
result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
886898
ReductionNeutralFPMax>(
887899
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
888900
break;
889901
case vector::CombiningKind::MAXF:
890-
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
891-
// NaNs/-0.0/+0.0 in the same way.
892902
result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
893903
ReductionNeutralFPMin>(
894904
rewriter, loc, llvmType, operand, acc, maskOp.getMask());

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,30 @@ func.func @reduce_fminimum_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
13411341

13421342
// -----
13431343

1344+
func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
1345+
%0 = vector.reduction <maxf>, %arg0, %arg1 : vector<16xf32> into f32
1346+
return %0 : f32
1347+
}
1348+
// CHECK-LABEL: @reduce_fmax_f32(
1349+
// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
1350+
// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmax(%[[A]]) : (vector<16xf32>) -> f32
1351+
// CHECK: %[[R:.*]] = llvm.intr.maxnum(%[[V]], %[[B]]) : (f32, f32) -> f32
1352+
// CHECK: return %[[R]] : f32
1353+
1354+
// -----
1355+
1356+
func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
1357+
%0 = vector.reduction <minf>, %arg0, %arg1 : vector<16xf32> into f32
1358+
return %0 : f32
1359+
}
1360+
// CHECK-LABEL: @reduce_fmin_f32(
1361+
// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
1362+
// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmin(%[[A]]) : (vector<16xf32>) -> f32
1363+
// CHECK: %[[R:.*]] = llvm.intr.minnum(%[[V]], %[[B]]) : (f32, f32) -> f32
1364+
// CHECK: return %[[R]] : f32
1365+
1366+
// -----
1367+
13441368
func.func @reduce_minui_i32(%arg0: vector<16xi32>) -> i32 {
13451369
%0 = vector.reduction <minui>, %arg0 : vector<16xi32> into i32
13461370
return %0 : i32

mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ func.func @entry() {
3333
%3 = vector.reduction <maximumf>, %v2 : vector<64xf32> into f32
3434
vector.print %3 : f32
3535
// CHECK: 3
36+
%4 = vector.reduction <minf>, %v2 : vector<64xf32> into f32
37+
vector.print %4 : f32
38+
// CHECK: 1
39+
%5 = vector.reduction <maxf>, %v2 : vector<64xf32> into f32
40+
vector.print %5 : f32
41+
// CHECK: 3
3642

3743
return
3844
}

mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ func.func @entry() {
4545
%3 = vector.reduction <maximumf>, %v9 : vector<10xf32> into f32
4646
vector.print %3 : f32
4747
// CHECK: 5
48+
%4 = vector.reduction <minf>, %v9 : vector<10xf32> into f32
49+
vector.print %4 : f32
50+
// CHECK: -16
51+
%5 = vector.reduction <maxf>, %v9 : vector<10xf32> into f32
52+
vector.print %5 : f32
53+
// CHECK: 5
4854

4955
return
5056
}

mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ func.func @entry() {
3333
%3 = vector.reduction <maximumf>, %v2 : vector<64xf64> into f64
3434
vector.print %3 : f64
3535
// CHECK: 3
36+
%4 = vector.reduction <minf>, %v2 : vector<64xf64> into f64
37+
vector.print %4 : f64
38+
// CHECK: 1
39+
%5 = vector.reduction <maxf>, %v2 : vector<64xf64> into f64
40+
vector.print %5 : f64
41+
// CHECK: 3
3642

3743
return
3844
}

mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ func.func @entry() {
4545
%3 = vector.reduction <maximumf>, %v9 : vector<10xf64> into f64
4646
vector.print %3 : f64
4747
// CHECK: 5
48+
%4 = vector.reduction <minf>, %v9 : vector<10xf64> into f64
49+
vector.print %4 : f64
50+
// CHECK: -16
51+
%5 = vector.reduction <maxf>, %v9 : vector<10xf64> into f64
52+
vector.print %5 : f64
53+
// CHECK: 5
4854

4955
return
5056
}

0 commit comments

Comments
 (0)