Skip to content

[mlir][spirv] Fix vector reduction lowerings for FP min/max #69053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

unterumarmung
Copy link
Contributor

This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics.

This patch addresses tasks 2.4 and 2.5 of the RFC.

@llvmbot
Copy link
Member

llvmbot commented Oct 14, 2023

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Daniil Dudkin (unterumarmung)

Changes

This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics.

This patch addresses tasks 2.4 and 2.5 of the RFC.


Full diff: https://github.com/llvm/llvm-project/pull/69053.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+51-3)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+143-11)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..1d46d9503e9760d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -397,9 +397,12 @@ struct VectorReductionPattern final
     break
 
 #define INT_OR_FLOAT_CASE(kind, fop)                                           \
-  case vector::CombiningKind::kind:                                            \
-    result = rewriter.create<fop>(loc, resultType, result, next);              \
-    break
+  case vector::CombiningKind::kind: {                                          \
+    fop op = rewriter.create<fop>(loc, resultType, result, next);              \
+    result = this->generateActionForOp(rewriter, loc, resultType, op,          \
+                                       vector::CombiningKind::kind);           \
+    break;                                                                     \
+  }
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
@@ -422,6 +425,51 @@ struct VectorReductionPattern final
     rewriter.replaceOp(reduceOp, result);
     return success();
   }
+
+private:
+  enum class Action { Nothing, PropagateNaN, PropagateNonNaN };
+
+  template <typename Op>
+  Action getActionForOp(vector::CombiningKind kind) const {
+    constexpr bool isCLOp = std::is_same_v<Op, spirv::CLFMaxOp> ||
+                            std::is_same_v<Op, spirv::CLFMinOp>;
+    switch (kind) {
+    case vector::CombiningKind::MINIMUMF:
+    case vector::CombiningKind::MAXIMUMF:
+      return Action::PropagateNaN;
+    case vector::CombiningKind::MINF:
+    case vector::CombiningKind::MAXF:
+      // CL ops already have the same semantic for NaNs as MINF/MAXF
+      // GL ops have undefined semantics for NaNs, so we need to explicitly
+      // propagate the non-NaN values
+      return isCLOp ? Action::Nothing : Action::PropagateNonNaN;
+    default:
+      return Action::Nothing;
+    }
+  }
+
+  template <typename Op>
+  Value generateActionForOp(ConversionPatternRewriter &rewriter,
+                            mlir::Location loc, Type resultType, Op op,
+                            vector::CombiningKind kind) const {
+    Action action = getActionForOp<Op>(kind);
+
+    if (action == Action::Nothing) {
+      return op;
+    }
+
+    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getLhs());
+    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getRhs());
+
+    Value select1 = rewriter.create<spirv::SelectOp>(
+        loc, resultType, lhsIsNan,
+        action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op);
+    Value select2 = rewriter.create<spirv::SelectOp>(
+        loc, resultType, rhsIsNan,
+        action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1);
+
+    return select2;
+  }
 };
 
 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..91836e556147b8d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
-//       CHECK:   %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
-//       CHECK:   %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
-//       CHECK:   return %[[MAX2]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
 func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
   %reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
@@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
+func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_minf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
 //       CHECK:   %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]]
 //       CHECK:   %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]]
 //       CHECK:   return %[[MIN2]]
-func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
-  %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
 }
 
@@ -522,9 +574,21 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
-//       CHECK:   %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]]
-//       CHECK:   %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]]
-//       CHECK:   return %[[MAX2]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
 func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
   %reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
@@ -532,15 +596,55 @@ func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduction_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
+func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_minimumf
 //  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
-//       CHECK:   %[[MIN1:.+]] = spirv.GL.FMin %[[MIN0]], %[[S2]]
-//       CHECK:   %[[MIN2:.+]] = spirv.GL.FMin %[[MIN1]], %[[S]]
-//       CHECK:   return %[[MIN2]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
 func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
   %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
@@ -548,6 +652,34 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduction_minf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MIN0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MIN1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MIN2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
+func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_maxsi
 //  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics.

This patch addresses tasks 2.4 and 2.5 of the RFC.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants