Skip to content

[mlir][spirv] Fix vector reduction lowerings for FP min/max #69025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

unterumarmung
Copy link
Contributor

@unterumarmung unterumarmung commented Oct 13, 2023

This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics.

This patch addresses tasks 2.4 and 2.5 of the RFC.

Please note that this patch depends on #69023.

…NFC)

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers: `minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops.

However, since the code generation for floating-point min/max operations shares the same functionality as extracting values for the vector, we have decided to refactor the existing code using the CRTP pattern. This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics.

This patch addresses tasks 2.4 and 2.5 of the RFC.
@unterumarmung unterumarmung requested review from antiagainst, pzread, kuhar and dcaballe and removed request for pzread October 13, 2023 19:34
@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2023

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Daniil Dudkin (unterumarmung)

Changes

This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics.

This patch addresses tasks 2.4 and 2.5 of the RFC.

Please note that this patch depends on #69023.


Patch is 23.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69025.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+150-29)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+143-11)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..040fa69e2e9f27b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cassert>
 #include <cstdint>
@@ -351,15 +353,13 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
-template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
-          class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
-struct VectorReductionPattern final
-    : public OpConversionPattern<vector::ReductionOp> {
+template <typename Derived>
+struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+                  ConversionPatternRewriter &rewriter) const final {
     Type resultType = typeConverter->convertType(reduceOp.getType());
     if (!resultType)
       return failure();
@@ -368,9 +368,22 @@ struct VectorReductionPattern final
     if (!srcVectorType || srcVectorType.getRank() != 1)
       return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
 
-    // Extract all elements.
+    SmallVector<Value> extractedElements =
+        extractAllElements(reduceOp, adaptor, srcVectorType, rewriter);
+
+    const auto &self = static_cast<const Derived &>(*this);
+
+    return self.reduceExtracted(reduceOp, extractedElements, resultType,
+                                rewriter);
+  }
+
+private:
+  SmallVector<Value>
+  extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                     VectorType srcVectorType,
+                     ConversionPatternRewriter &rewriter) const {
     int numElements = srcVectorType.getDimSize(0);
-    SmallVector<Value, 4> values;
+    SmallVector<Value> values;
     values.reserve(numElements + (adaptor.getAcc() != nullptr));
     Location loc = reduceOp.getLoc();
     for (int i = 0; i < numElements; ++i) {
@@ -381,9 +394,26 @@ struct VectorReductionPattern final
     if (Value acc = adaptor.getAcc())
       values.push_back(acc);
 
-    // Reduce them.
-    Value result = values.front();
-    for (Value next : llvm::ArrayRef(values).drop_front()) {
+    return values;
+  }
+};
+
+#define VECTOR_REDUCTION_BASE                                                  \
+  VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp,  \
+                                                    SPIRVSMaxOp, SPIRVSMinOp>>
+template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
+          typename SPIRVSMinOp>
+struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
+  using Base = VECTOR_REDUCTION_BASE;
+  using Base::Base;
+
+  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+                                ArrayRef<Value> extractedElements,
+                                Type resultType,
+                                ConversionPatternRewriter &rewriter) const {
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
       switch (reduceOp.getKind()) {
 
 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
@@ -403,10 +433,6 @@ struct VectorReductionPattern final
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
-        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
-        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
         INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
         INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
         INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,13 +442,105 @@ struct VectorReductionPattern final
       case vector::CombiningKind::OR:
       case vector::CombiningKind::XOR:
         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
+      }
+    }
+
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+};
+#undef VECTOR_REDUCTION_BASE
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
+
+#define MIN_MAX_PATTERN_BASE                                                   \
+  VectorReductionPatternBase<                                                  \
+      VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
+template <class SPIRVFMaxOp, class SPIRVFMinOp>
+struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
+  using Base = MIN_MAX_PATTERN_BASE;
+  using Base::Base;
+
+  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+                                ArrayRef<Value> extractedElements,
+                                Type resultType,
+                                ConversionPatternRewriter &rewriter) const {
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
+      switch (reduceOp.getKind()) {
+
+#define INT_OR_FLOAT_CASE(kind, fop)                                           \
+  case vector::CombiningKind::kind: {                                          \
+    fop op = rewriter.create<fop>(loc, resultType, result, next);              \
+    result = this->generateActionForOp(rewriter, loc, resultType, op,          \
+                                       vector::CombiningKind::kind);           \
+    break;                                                                     \
+  }
+
+        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
+        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
+
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
     }
 
     rewriter.replaceOp(reduceOp, result);
     return success();
   }
+
+private:
+  enum class Action { Nothing, PropagateNaN, PropagateNonNaN };
+
+  template <typename Op>
+  Action getActionForOp(vector::CombiningKind kind) const {
+    constexpr bool isCLOp = std::is_same_v<Op, spirv::CLFMaxOp> ||
+                            std::is_same_v<Op, spirv::CLFMinOp>;
+    switch (kind) {
+    case vector::CombiningKind::MINIMUMF:
+    case vector::CombiningKind::MAXIMUMF:
+      return Action::PropagateNaN;
+    case vector::CombiningKind::MINF:
+    case vector::CombiningKind::MAXF:
+      // CL ops already have the same semantic for NaNs as MINF/MAXF
+      // GL ops have undefined semantics for NaNs, so we need to explicitly
+      // propagate the non-NaN values
+      return isCLOp ? Action::Nothing : Action::PropagateNonNaN;
+    default:
+      llvm_unreachable("Unexpected case for the switch");
+    }
+  }
+
+  template <typename Op>
+  Value generateActionForOp(ConversionPatternRewriter &rewriter,
+                            mlir::Location loc, Type resultType, Op op,
+                            vector::CombiningKind kind) const {
+    Action action = getActionForOp<Op>(kind);
+
+    if (action == Action::Nothing) {
+      return op;
+    }
+
+    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getLhs());
+    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getRhs());
+
+    Value select1 = rewriter.create<spirv::SelectOp>(
+        loc, resultType, lhsIsNan,
+        action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op);
+    Value select2 = rewriter.create<spirv::SelectOp>(
+        loc, resultType, rhsIsNan,
+        action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1);
+
+    return select2;
+  }
 };
+#undef MIN_MAX_PATTERN_BASE
+#undef INT_OR_FLOAT_CASE
 
 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
 public:
@@ -604,25 +722,28 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
 };
 
 } // namespace
-#define CL_MAX_MIN_OPS                                                         \
-  spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
-      spirv::CLSMaxOp, spirv::CLSMinOp
+#define CL_INT_MAX_MIN_OPS                                                     \
+  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_INT_MAX_MIN_OPS                                                     \
+  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
 
-#define GL_MAX_MIN_OPS                                                         \
-  spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp,          \
-      spirv::GLSMaxOp, spirv::GLSMinOp
+#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
+#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
-               VectorExtractElementOpConvert, VectorExtractOpConvert,
-               VectorExtractStridedSliceOpConvert,
-               VectorFmaOpConvert<spirv::GLFmaOp>,
-               VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-               VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-               VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
-               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+  patterns.add<
+      VectorBitcastConvert, VectorBroadcastConvert,
+      VectorExtractElementOpConvert, VectorExtractOpConvert,
+      VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+      VectorSplatPattern>(typeConverter, patterns.getContext());
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..91836e556147b8d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
-//       CHECK:   %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
-//       CHECK:   %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
-//       CHECK:   return %[[MAX2]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
 func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
   %reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
@@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
+func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_minf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
 //       CHECK:   %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]]
 //       CHECK:   %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]]
 //       CHECK:   return %[[MIN2]]
-func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
-  %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
 }
 
@@ -522,9 +574,21 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
-//       CHECK:   %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]]
-//       CHECK:   %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]]
-//       CHECK:   return %[[MAX2]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
 func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
   %reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
@@ -532,15 +596,55 @@ func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduction_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
+func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_minimumf
 //  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
-//       CHECK:  ...
[truncated]

@TiborGY
Copy link
Contributor

TiborGY commented Feb 18, 2025

Is this PR still alive given that #69023 was superseded by #73363 ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants