Skip to content

[mlir][spirv] Split codegen for float min/max reductions and others (NFC) #69023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

unterumarmung
Copy link
Contributor

This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers: minf/maxf and minimumf/maximumf. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops.

However, since the code generation for floating-point min/max operations shares the same functionality as extracting values for the vector, we have decided to refactor the existing code using the CRTP pattern. This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.

…NFC)

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers: `minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops.

However, since the code generation for floating-point min/max operations shares the same functionality as extracting values for the vector, we have decided to refactor the existing code using the CRTP pattern. This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.
@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2023

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Daniil Dudkin (unterumarmung)

Changes

This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers: minf/maxf and minimumf/maximumf. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops.

However, since the code generation for floating-point min/max operations shares the same functionality as extracting values for the vector, we have decided to refactor the existing code using the CRTP pattern. This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.


Full diff: https://github.com/llvm/llvm-project/pull/69023.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+101-29)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..c4c0497c2d1f0fd 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -351,15 +352,13 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
-template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
-          class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
-struct VectorReductionPattern final
-    : public OpConversionPattern<vector::ReductionOp> {
+template <typename Derived>
+struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+                  ConversionPatternRewriter &rewriter) const final {
     Type resultType = typeConverter->convertType(reduceOp.getType());
     if (!resultType)
       return failure();
@@ -368,9 +367,22 @@ struct VectorReductionPattern final
     if (!srcVectorType || srcVectorType.getRank() != 1)
       return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
 
-    // Extract all elements.
+    SmallVector<Value> extractedElements =
+        extractAllElements(reduceOp, adaptor, srcVectorType, rewriter);
+
+    const auto &self = static_cast<const Derived &>(*this);
+
+    return self.reduceExtracted(reduceOp, extractedElements, resultType,
+                                rewriter);
+  }
+
+private:
+  SmallVector<Value>
+  extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                     VectorType srcVectorType,
+                     ConversionPatternRewriter &rewriter) const {
     int numElements = srcVectorType.getDimSize(0);
-    SmallVector<Value, 4> values;
+    SmallVector<Value> values;
     values.reserve(numElements + (adaptor.getAcc() != nullptr));
     Location loc = reduceOp.getLoc();
     for (int i = 0; i < numElements; ++i) {
@@ -381,9 +393,26 @@ struct VectorReductionPattern final
     if (Value acc = adaptor.getAcc())
       values.push_back(acc);
 
-    // Reduce them.
-    Value result = values.front();
-    for (Value next : llvm::ArrayRef(values).drop_front()) {
+    return values;
+  }
+};
+
+#define VECTOR_REDUCTION_BASE                                                  \
+  VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp,  \
+                                                    SPIRVSMaxOp, SPIRVSMinOp>>
+template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
+          typename SPIRVSMinOp>
+struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
+  using Base = VECTOR_REDUCTION_BASE;
+  using Base::Base;
+
+  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+                                ArrayRef<Value> extractedElements,
+                                Type resultType,
+                                ConversionPatternRewriter &rewriter) const {
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
       switch (reduceOp.getKind()) {
 
 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
@@ -403,10 +432,6 @@ struct VectorReductionPattern final
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
-        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
-        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
         INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
         INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
         INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,6 +441,8 @@ struct VectorReductionPattern final
       case vector::CombiningKind::OR:
       case vector::CombiningKind::XOR:
         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
     }
 
@@ -423,6 +450,48 @@ struct VectorReductionPattern final
     return success();
   }
 };
+#undef VECTOR_REDUCTION_BASE
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
+
+#define MIN_MAX_PATTERN_BASE                                                   \
+  VectorReductionPatternBase<                                                  \
+      VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
+template <class SPIRVFMaxOp, class SPIRVFMinOp>
+struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
+  using Base = MIN_MAX_PATTERN_BASE;
+  using Base::Base;
+
+  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+                                ArrayRef<Value> extractedElements,
+                                Type resultType,
+                                ConversionPatternRewriter &rewriter) const {
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
+      switch (reduceOp.getKind()) {
+
+#define INT_OR_FLOAT_CASE(kind, fop)                                           \
+  case vector::CombiningKind::kind:                                            \
+    result = rewriter.create<fop>(loc, resultType, result, next);              \
+    break
+
+        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
+        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
+
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
+      }
+    }
+
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+};
+#undef MIN_MAX_PATTERN_BASE
+#undef INT_OR_FLOAT_CASE
 
 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
 public:
@@ -604,25 +673,28 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
 };
 
 } // namespace
-#define CL_MAX_MIN_OPS                                                         \
-  spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
-      spirv::CLSMaxOp, spirv::CLSMinOp
+#define CL_INT_MAX_MIN_OPS                                                     \
+  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_INT_MAX_MIN_OPS                                                     \
+  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
 
-#define GL_MAX_MIN_OPS                                                         \
-  spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp,          \
-      spirv::GLSMaxOp, spirv::GLSMinOp
+#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
+#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
-               VectorExtractElementOpConvert, VectorExtractOpConvert,
-               VectorExtractStridedSliceOpConvert,
-               VectorFmaOpConvert<spirv::GLFmaOp>,
-               VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-               VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-               VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
-               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+  patterns.add<
+      VectorBitcastConvert, VectorBroadcastConvert,
+      VectorExtractElementOpConvert, VectorExtractOpConvert,
+      VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+      VectorSplatPattern>(typeConverter, patterns.getContext());
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(

Comment on lines +400 to +402
#define VECTOR_REDUCTION_BASE \
VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, \
SPIRVSMaxOp, SPIRVSMinOp>>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this a using declaration instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, no.
The macro is needed to avoid repetition between the base class and explicitly inheriting the constructor of the base class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is other option though but it might look weird:

template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
          typename SPIRVSMinOp, typename..., 
          typename Base = VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, SPIRVSMaxOp, SPIRVSMinOp>>>
struct VectorReductionPattern final : Base {
  using Base::Base;
  ...
};

typename... is needed to protect from possible wrongful overrides of the Base and typename Base is used here as an alias.

Copy link
Member

@kuhar kuhar Oct 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something like this:

template <typename... Args>
struct VectorReductionPattern;

typename <typename... Args>
using CrtpBase = VectorReductionPatternBase<VectorReductionPattern, Args...>;

template <typename... Args>
struct VectorReductionPattern final : CrtpBase<Args...> {
  using Base = CrtpBase<Args...>; 
   using Base::Base;
   // define SPIRVUMaxOp et al. in VectorReductionPatternBase
  ...
};

just an idea, I haven't tried compiling this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it should work, I'll definitely give it a try! My idea wouldn't work because VectorReductionPattern is not declared before using it.

Copy link
Contributor Author

@unterumarmung unterumarmung left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kuhar, thank you for the review!
After I submitted these patches, I've realised that I might have overengineered #69023 and #69025. Today I created a new PR, that uses implementation of #69025 without this, possibly excessive, refactoring. Please take a look at #69053 and pick your favorite 😊

Comment on lines +400 to +402
#define VECTOR_REDUCTION_BASE \
VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, \
SPIRVSMaxOp, SPIRVSMinOp>>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, no.
The macro is needed to avoid repetition between the base class and explicitly inheriting the constructor of the base class.

Comment on lines +400 to +402
#define VECTOR_REDUCTION_BASE \
VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, \
SPIRVSMaxOp, SPIRVSMinOp>>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is other option though but it might look weird:

template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
          typename SPIRVSMinOp, typename..., 
          typename Base = VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, SPIRVSMaxOp, SPIRVSMinOp>>>
struct VectorReductionPattern final : Base {
  using Base::Base;
  ...
};

typename... is needed to protect from possible wrongful overrides of the Base and typename Base is used here as an alias.

@kuhar
Copy link
Member

kuhar commented Oct 14, 2023

@unterumarmung I played with the code a bit and have an alternative idea: if all the derived classes implement just one function, could we make that function a constructor/template parameter and drop CRTP and the resulting complexity?

@unterumarmung
Copy link
Contributor Author

@unterumarmung I played with the code a bit and have an alternative idea: if all the derived classes implement just one function, could we make that function a constructor/template parameter and drop CRTP and the resulting complexity?

It is an interesting question!

In #69053, I avoided this complexity by just not refactoring. I began the refactoring because I believed it would be more challenging to implement the additional logic of the floating-point min/max operations. I did not want to complicate the main reduction pattern, especially in the macros. However, it turned out that all I needed to do was add one line to the macro, so I created #69053.

If you want to split the floating-point min/max ops from the rest, we could explore some options:

  1. Passing the function to the constructor: It could work, but I don't have extensive experience with the MLIR infrastructure, so I'm not sure how we can introduce a constructor parameter here:
    patterns.add<
    VectorBitcastConvert, VectorBroadcastConvert,
    VectorExtractElementOpConvert, VectorExtractOpConvert,
    VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
    VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
    VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
    VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
    VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
    VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
    VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
    VectorSplatPattern>(typeConverter, patterns.getContext());
  2. Passing the function as a template parameter: it seems easier than the first option. We could implement functions as static ones and simply pass pointers (or even llvm::function_refs) to them into the template.
  3. Virtual functions: we could implement the functions as virtual ones. It'd reduce the complexity of the template machinery by introducing a negligible runtime overhead.

There might be other options, but currently I can't think of any other.

@kuhar
Copy link
Member

kuhar commented Oct 16, 2023

@unterumarmung Yeah there's no requirement for all patterns to be added together in one .add function call. If we split it into 2 or more .add calls it should work just as well.

RE constructor parameters, you just add pass them via .add<...>(-- here --). If you hit any issues and can't get this to work, feel free to ping me and I can give it a shot.

@unterumarmung
Copy link
Contributor Author

@kuhar, I apologize for disappearing for such a long time. I've tried to rewrite this by extracting just one function, but, frankly speaking, it does not make much sense to me. The reason is that we have template parameters in the class, and the reduceExtracted part depends on them. We could split the class to handle only specific pairs of Min and Max operations, but we still have the Mul and Add operations to consider, and we would have to include them for each such pair of Min and Max, which seems excessive. I believe this is why I've unconsciously chosen to refactor this code using CRTP in the first place. It appears there are only two options to solve the problem: either we proceed with the CRTP refactoring as demonstrated here, or we do not refactor at all and use the approach in #69053, which seems quite neat, in my honest opinion.

@kuhar
Copy link
Member

kuhar commented Nov 20, 2023

(Flagging so that I don't forget to come back to this. cc: @kuhar)

@kuhar
Copy link
Member

kuhar commented Nov 24, 2023

Hi @unterumarmung, I gave it a second try and came up with something simpler IMO: #73363. Please take a look and see if this solves the original problem.

kuhar added a commit that referenced this pull request Nov 24, 2023
…2. [NFC] (#73363)

This is #69023 but with
cleanups.
Reduced complexity by avoiding CRTP and preprocessor defines in favor of
free functions

Original description by @unterumarmung:

---

This patch is part of a larger initiative aimed at fixing floating-point
`max` and `min` operations in MLIR:
https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers:
`minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these
operations should differ from that of other vector reduction kinds. This
difference arises because CL and GL operations for floating-point min
and max do not have the same semantics when handling NaNs. Therefore, we
must enforce the desired semantics with additional ops.

~~However, since the code generation for floating-point min/max
operations shares the same functionality as extracting values for the
vector, we have decided to refactor the existing code using the CRTP
pattern.~~ This change does not alter the actual behavior of the code
and is necessary for future fixes to the codegen for floating-point
min/max operations.

---------

Co-authored-by: Daniil Dudkin <[email protected]>
@kuhar
Copy link
Member

kuhar commented Nov 24, 2023

The alternative PR landed, closing this one to take it off my queue.

@kuhar kuhar closed this Nov 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants