Update LowerContractionToSMMLAPattern to ignore matvec
#88288
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-neon
Author: Kojo Acquah (KoolJBlack)
Changes
Patterns in LowerContractionToSMMLAPattern are designed to handle vector-to-matrix multiplication but not matrix-to-vector. Updates to explicitly check the rhs rank and fail on cases the pattern cannot process.
Full diff: https://github.com/llvm/llvm-project/pull/88288.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 13740225749e46..efdaeeda4fec5d 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -54,6 +54,8 @@ class LowerContractionToSMMLAPattern
// Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
+ if (rhsType.getRank() < 2)
+ return failure();
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = rhsType.getDimSize(1);
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index 46c4026d13b660..c276a5b0c2a14b 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -258,3 +258,14 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
return %res : vector<1x8xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_matvec
+// CHECK-NOT: arm_neon.intr.smmla
+func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+ %rhs_extsi= arith.extsi %rhs : vector<8xi8> to vector<8xi32>
+ %lhs_extsi = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
+ return %res : vector<8xi32>
+}
@@ -54,6 +54,8 @@ class LowerContractionToSMMLAPattern
// Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
+ if (rhsType.getRank() < 2)
+ return failure();
Since 0-D vectors can also reach this point, could you add another check making sure that if any of the operands doesn't have a rank, we also bail out (+ test for that)?
I don't think vector.contract supports 0-D vectors. I couldn't create one without hitting errors:
error: unexpected error: 'vector.contract' op operand #0 must be vector of any type values, but got 'vector<i32>'
I did add "hasRank" check just in case.
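For illustration, the combined guard discussed in this thread (a rank-presence check plus the RHS rank check) could look roughly like the sketch below. It is self-contained and does not use the MLIR API: `TypeInfo` and `canLowerToSmmla` are hypothetical stand-ins, and the exact condition in the updated commit may differ.

```cpp
#include <optional>

// Hypothetical stand-in for a vector type; `rank` is empty when the type
// carries no rank. This is not mlir::VectorType.
struct TypeInfo {
  std::optional<int> rank;
  bool hasRank() const { return rank.has_value(); }
  int getRank() const { return *rank; }
};

// Returns true only when the pattern may proceed: both operands must have
// a rank, and the RHS must be at least 2-D (vecmat or matmul, not matvec).
bool canLowerToSmmla(const TypeInfo &lhs, const TypeInfo &rhs) {
  if (!lhs.hasRank() || !rhs.hasRank())
    return false; // analogous to `return failure();`
  return rhs.getRank() >= 2;
}
```

Checking for a rank before calling `getRank()` keeps the guard safe even if an unranked operand can never actually reach the pattern.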
45a96db to 93b8b14
LGTM, thanks!
Patterns in LowerContractionToSMMLAPattern are designed to handle vector-to-matrix multiplication but not matrix-to-vector. This leads to an error when processing an rhs with rank < 2.
Updates to explicitly check the rhs rank and fail on cases the pattern cannot process.
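To see why a rank-1 rhs is a problem, the dimension extraction in the patched code can be modeled without MLIR. In this minimal sketch, `Shape`, `ContractionDims`, and `extractDims` are hypothetical names, shapes are plain vectors of sizes, and the empty-lhs check is an added assumption that is not part of the diff.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for the static shape of a vector type.
using Shape = std::vector<int64_t>;

struct ContractionDims {
  int64_t m, n, k;
};

// Mirrors the dimension reads in the patch: the RHS is treated as a
// non-transposed N x K matrix, so both rhs[0] and rhs[1] are read.
// A rank-1 RHS (matvec) has no second dimension, so it is rejected first.
std::optional<ContractionDims> extractDims(const Shape &lhs, const Shape &rhs) {
  if (lhs.empty() || rhs.size() < 2)
    return std::nullopt; // analogous to `return failure();`
  int64_t m = lhs.size() == 1 ? 1 : lhs[0];
  return ContractionDims{m, rhs[0], rhs[1]};
}
```

With these definitions, a vecmat case such as `extractDims({8}, {8, 8})` yields M = 1, N = 8, K = 8, while the matvec case `extractDims({8, 8}, {8})` yields no result instead of reading a dimension that does not exist.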