Skip to content

[mlir][spirv] Split codegen for float min/max reductions and others v2. [NFC] #73363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 24, 2023

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Nov 24, 2023

This is #69023 but with cleanups.
Reduced complexity by avoiding CRTP and preprocessor defines in favor of free functions

Original description by @unterumarmung:


This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers: minf/maxf and minimumf/maximumf. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops.

However, since the code generation for floating-point min/max operations shares the same functionality as extracting values for the vector, we have decided to refactor the existing code using the CRTP pattern. This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.

unterumarmung and others added 2 commits November 24, 2023 14:35
…NFC)

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers: `minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops.

However, since the code generation for floating-point min/max operations shares the same functionality as extracting values for the vector, we have decided to refactor the existing code using the CRTP pattern. This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.
@llvmbot
Copy link
Member

llvmbot commented Nov 24, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Jakub Kuderski (kuhar)

Changes

This is #69023 but with cleanups.
Reduced complexity by avoiding CRTP and preprocessor defines in favor of free functions

Original description by @unterumarmung:


This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers: minf/maxf and minimumf/maximumf. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops.

However, since the code generation for floating-point min/max operations shares the same functionality as extracting values for the vector, we have decided to refactor the existing code using the CRTP pattern. This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.


Full diff: https://github.com/llvm/llvm-project/pull/73363.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+107-38)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index dcc6449d3fe8927..29bc5f1dd73787f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -27,6 +28,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cassert>
@@ -351,39 +353,64 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
-template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
-          class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
-struct VectorReductionPattern final
-    : public OpConversionPattern<vector::ReductionOp> {
+static SmallVector<Value> extractAllElements(vector::ReductionOp reduceOp,
+                                      vector::ReductionOp::Adaptor adaptor,
+                                      VectorType srcVectorType,
+                                      ConversionPatternRewriter &rewriter) {
+  int numElements = srcVectorType.getDimSize(0);
+  SmallVector<Value> values;
+  values.reserve(numElements + (adaptor.getAcc() != nullptr));
+  Location loc = reduceOp.getLoc();
+  for (int i = 0; i < numElements; ++i) {
+    values.push_back(rewriter.create<spirv::CompositeExtractOp>(
+        loc, srcVectorType.getElementType(), adaptor.getVector(),
+        rewriter.getI32ArrayAttr({i})));
+  }
+  if (Value acc = adaptor.getAcc())
+    values.push_back(acc);
+
+  return values;
+}
+
+struct ReductionRewriteInfo {
+  Type resultType;
+  SmallVector<Value> extractedElements;
+};
+
+FailureOr<ReductionRewriteInfo> static getReductionInfo(
+    vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
+  Type resultType = typeConverter.convertType(op.getType());
+  if (!resultType)
+    return failure();
+
+  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
+  if (!srcVectorType || srcVectorType.getRank() != 1)
+    return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
+
+  SmallVector<Value> extractedElements =
+      extractAllElements(op, adaptor, srcVectorType, rewriter);
+
+  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
+}
+
+template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
+          typename SPIRVSMinOp>
+struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resultType = typeConverter->convertType(reduceOp.getType());
-    if (!resultType)
+    auto reductionInfo =
+        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+    if (failed(reductionInfo))
       return failure();
 
-    auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
-    if (!srcVectorType || srcVectorType.getRank() != 1)
-      return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
-
-    // Extract all elements.
-    int numElements = srcVectorType.getDimSize(0);
-    SmallVector<Value, 4> values;
-    values.reserve(numElements + (adaptor.getAcc() != nullptr));
-    Location loc = reduceOp.getLoc();
-    for (int i = 0; i < numElements; ++i) {
-      values.push_back(rewriter.create<spirv::CompositeExtractOp>(
-          loc, srcVectorType.getElementType(), adaptor.getVector(),
-          rewriter.getI32ArrayAttr({i})));
-    }
-    if (Value acc = adaptor.getAcc())
-      values.push_back(acc);
-
-    // Reduce them.
-    Value result = values.front();
-    for (Value next : llvm::ArrayRef(values).drop_front()) {
+    auto [resultType, extractedElements] = *reductionInfo;
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::drop_begin(extractedElements)) {
       switch (reduceOp.getKind()) {
 
 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
@@ -403,10 +430,6 @@ struct VectorReductionPattern final
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
-        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
-        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
         INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
         INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
         INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,8 +439,51 @@ struct VectorReductionPattern final
       case vector::CombiningKind::OR:
       case vector::CombiningKind::XOR:
         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
     }
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
+
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+};
+
+template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
+struct VectorReductionFloatMinMax final : OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionInfo =
+        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+    if (failed(reductionInfo))
+      return failure();
+
+    auto [resultType, extractedElements] = *reductionInfo;
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
+      switch (reduceOp.getKind()) {
+
+#define INT_OR_FLOAT_CASE(kind, fop)                                           \
+  case vector::CombiningKind::kind:                                            \
+    result = rewriter.create<fop>(loc, resultType, result, next);              \
+    break
+
+        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
+        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
+
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
+      }
+    }
+#undef INT_OR_FLOAT_CASE
 
     rewriter.replaceOp(reduceOp, result);
     return success();
@@ -674,13 +740,14 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
 };
 
 } // namespace
-#define CL_MAX_MIN_OPS                                                         \
-  spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
-      spirv::CLSMaxOp, spirv::CLSMinOp
+#define CL_INT_MAX_MIN_OPS                                                     \
+  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_INT_MAX_MIN_OPS                                                     \
+  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
 
-#define GL_MAX_MIN_OPS                                                         \
-  spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp,          \
-      spirv::GLSMaxOp, spirv::GLSMinOp
+#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
+#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
@@ -689,8 +756,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
       VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-      VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-      VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
       typeConverter, patterns.getContext());

Copy link

github-actions bot commented Nov 24, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@unterumarmung unterumarmung left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks nice, thank you for helping!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants