Skip to content

[mlir][spirv] Handle non-innerprod float vector add reductions #73476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 27, 2023

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Nov 27, 2023

Instead of extracting all individual vector components and performing a scalar summation, use spirv.Dot with the original reduction operand and a vector constant of all ones.

Instead of extracting all individual vector components and performing a
scalar summation, use `spirv.Dot` with the original reduction operand and
a vector constant of all ones.
@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2023

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes

Instead of extracting all individual vector components and performing a scalar summation, use spirv.Dot with the original reduction operand and a vector constant of all ones.


Full diff: https://github.com/llvm/llvm-project/pull/73476.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+26-6)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+40-4)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index ade41b0372c82f1..1db6713d8b85694 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
@@ -755,14 +756,33 @@ struct VectorReductionToFPDotProd final
     if (!resultType)
       return rewriter.notifyMatchFailure(op, "result is not a float");
 
-    auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>();
-    if (!mul)
-      return rewriter.notifyMatchFailure(
-          op, "reduction operand is not 'arith.mulf'");
+    auto vectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
+    if (!vectorType) {
+      assert(isa<FloatType>(adaptor.getVector().getType()) &&
+             "Expected the vector to be scalarized");
+      rewriter.replaceOp(op, adaptor.getVector());
+      return success();
+    }
 
     Location loc = op.getLoc();
-    Value res = rewriter.create<spirv::DotOp>(loc, resultType, mul.getLhs(),
-                                              mul.getRhs());
+    Value lhs;
+    Value rhs;
+    if (auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>()) {
+      lhs = mul.getLhs();
+      rhs = mul.getRhs();
+    } else {
+      // If the operand is not a mul, use a vector of ones for the dot operand
+      // to just sum up all values.
+      lhs = adaptor.getVector();
+      Attribute oneAttr =
+          rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
+      oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
+      rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
+    }
+    assert(lhs);
+    assert(rhs);
+
+    Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
     if (op.getAcc())
       res = rewriter.create<spirv::FAddOp>(loc, adaptor.getAcc(), res);
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index d8585d59770bfdc..022bc0114bc523b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -500,11 +500,11 @@ func.func @reduction_add(%v : vector<4xi32>) -> i32 {
 
 // -----
 
-// CHECK-LABEL: func @reduction_addf
+// CHECK-LABEL: func @reduction_addf_mulf
 //  CHECK-SAME:  (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
 //  CHECK:       %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
 //  CHECK:       return %[[DOT]] : f32
-func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+func.func @reduction_addf_mulf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
   %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
   %red = vector.reduction <add>, %mul : vector<4xf32> into f32
   return %red : f32
@@ -512,12 +512,12 @@ func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
 
 // -----
 
-// CHECK-LABEL: func @reduction_addf_acc
+// CHECK-LABEL: func @reduction_addf_acc_mulf
 //  CHECK-SAME:  (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
 //  CHECK:       %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
 //  CHECK:       %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
 //  CHECK:       return %[[RES]] : f32
-func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
+func.func @reduction_addf_acc_mulf(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
   %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
   %red = vector.reduction <add>, %mul, %acc : vector<4xf32> into f32
   return %red : f32
@@ -525,6 +525,42 @@ func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc:
 
 // -----
 
+// CHECK-LABEL: func @reduction_addf
+//  CHECK-SAME:  (%[[ARG0:.+]]: vector<4xf32>)
+//  CHECK:       %[[ONE:.+]] = spirv.Constant dense<1.0{{.+}}> : vector<4xf32>
+//  CHECK:       %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
+//  CHECK:       return %[[DOT]] : f32
+func.func @reduction_addf(%arg0: vector<4xf32>) -> f32 {
+  %red = vector.reduction <add>, %arg0 : vector<4xf32> into f32
+  return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_acc
+//  CHECK-SAME:  (%[[ARG0:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
+//  CHECK:       %[[ONE:.+]] = spirv.Constant dense<1.0{{.*}}> : vector<4xf32>
+//  CHECK:       %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
+//  CHECK:       %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
+//  CHECK:       return %[[RES]] : f32
+func.func @reduction_addf_acc(%arg0: vector<4xf32>, %acc: f32) -> f32 {
+  %red = vector.reduction <add>, %arg0, %acc : vector<4xf32> into f32
+  return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_one_elem
+//  CHECK-SAME:  (%[[ARG0:.+]]: vector<1xf32>)
+//  CHECK:       %[[RES:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
+//  CHECK:       return %[[RES]] : f32
+func.func @reduction_addf_one_elem(%arg0: vector<1xf32>) -> f32 {
+  %red = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+  return %red : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_mul
 //  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants