Skip to content

Update LowerContractionToSMMLAPattern to ingnore matvec #88288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 10, 2024

Conversation

KoolJBlack
Copy link
Contributor

Patterns in LowerContractionToSMMLAPattern are designed to handle vector-to-matrix multiplication but not matrix-to-vector. This leads to the following error when processing rhs with rank < 2:

iree-compile: /usr/local/google/home/kooljblack/code/iree-build/llvm-project/tools/mlir/include/mlir/IR/BuiltinTypeInterfaces.h.inc:268: int64_t mlir::detail::ShapedTypeTrait<mlir::VectorType>::getDimSize(unsigned int) const [ConcreteType = mlir::VectorType]: Assertion `idx < getRank() && "invalid index for shaped type"' failed.

Updates to explicitly check the rhs rank and fail cases that cannot process.

@llvmbot
Copy link
Member

llvmbot commented Apr 10, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-neon

Author: Kojo Acquah (KoolJBlack)

Changes

Patterns in LowerContractionToSMMLAPattern are designed to handle vector-to-matrix multiplication but not matrix-to-vector. This leads to the following error when processing rhs with rank < 2:

iree-compile: /usr/local/google/home/kooljblack/code/iree-build/llvm-project/tools/mlir/include/mlir/IR/BuiltinTypeInterfaces.h.inc:268: int64_t mlir::detail::ShapedTypeTrait&lt;mlir::VectorType&gt;::getDimSize(unsigned int) const [ConcreteType = mlir::VectorType]: Assertion `idx &lt; getRank() &amp;&amp; "invalid index for shaped type"' failed.

Updates to explicitly check the rhs rank and fail cases that cannot process.


Full diff: https://github.com/llvm/llvm-project/pull/88288.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+2)
  • (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+11)
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 13740225749e46..efdaeeda4fec5d 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -54,6 +54,8 @@ class LowerContractionToSMMLAPattern
     // Note: RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
+    if (rhsType.getRank() < 2)
+      return failure();
     auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = rhsType.getDimSize(1);
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index 46c4026d13b660..c276a5b0c2a14b 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -258,3 +258,14 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
   return %res : vector<1x8xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_matvec
+// CHECK-NOT: arm_neon.intr.smmla
+func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+  %rhs_extsi= arith.extsi %rhs : vector<8xi8> to vector<8xi32>
+  %lhs_extsi = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
+  return %res : vector<8xi32>
+}

@@ -54,6 +54,8 @@ class LowerContractionToSMMLAPattern
// Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
if (rhsType.getRank() < 2)
return failure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since 0-D vectors can also reach this point, could you add another check making sure that if any of the operands doesn't have a rank, we also bail out (+ test for that)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think vector.contract supports 0-D vectors. I couldn't create one without hitting errors:
error: unexpected error: 'vector.contract' op operand #0 must be vector of any type values, but got 'vector<i32>'

I did add "hasRank" check just in case.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@KoolJBlack KoolJBlack merged commit 04bf1a4 into llvm:main Apr 10, 2024
@KoolJBlack KoolJBlack deleted the rhs_matvec_fix branch April 10, 2024 17:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants