Skip to content

[mlir][ArmSME] Fail instead of error in vector.outerproduct lowering #75447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

c-rhodes
Copy link
Collaborator

The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently
errors on unsupported cases when it should return failure.

@c-rhodes c-rhodes requested a review from MacDue December 14, 2023 09:21
@c-rhodes
Copy link
Collaborator Author

Depends on #75446

The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently
errors on unsupported cases when it should return failure.
@c-rhodes c-rhodes force-pushed the mlir-arm-sme-outerproduct-change-emit-error branch from 91483f3 to 17d111f Compare December 14, 2023 11:01
@c-rhodes c-rhodes marked this pull request as ready for review December 14, 2023 11:02
@llvmbot llvmbot added the mlir label Dec 14, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 14, 2023

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes

The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently
errors on unsupported cases when it should return failure.


Full diff: https://github.com/llvm/llvm-project/pull/75447.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+6-4)
  • (modified) mlir/test/Conversion/VectorToArmSME/unsupported.mlir (+9-3)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 312e89c8f100dd..87d1bf9bed5a31 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -510,16 +510,18 @@ struct VectorOuterProductToArmSMELowering
     // We don't yet support lowering AXPY operations to SME. These could be
     // lowered by masking out all but the first element of the LHS.
     if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
-      return outerProductOp.emitError("AXPY operations not supported");
+      return rewriter.notifyMatchFailure(outerProductOp,
+                                         "AXPY operations not supported");
 
     if (!arm_sme::isValidSMETileVectorType(
             outerProductOp.getResultVectorType()))
-      return outerProductOp.emitError(
-          "outer product does not fit into SME tile");
+      return rewriter.notifyMatchFailure(
+          outerProductOp, "outer product does not fit into SME tile");
 
     auto kind = outerProductOp.getKind();
     if (kind != vector::CombiningKind::ADD)
-      return outerProductOp.emitError(
+      return rewriter.notifyMatchFailure(
+          outerProductOp,
           "unsupported kind (lowering to SME only supports ADD at the moment)");
 
     Value lhsMask = {};
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 3ef283727edd49..35089ebebac7e1 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -151,25 +151,31 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // expected-error@+1 {{AXPY operations not supported}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
   return %0 : vector<[2]xf64>
 }
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unsupported_kind
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
   %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
-  // expected-error@+1 {{unsupported kind}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unknown_mask
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
-  // CHECK: vector.outerproduct
   %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for fixing this 👍

@c-rhodes c-rhodes merged commit e7432ba into llvm:main Dec 15, 2023
@c-rhodes c-rhodes deleted the mlir-arm-sme-outerproduct-change-emit-error branch December 15, 2023 07:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants