Skip to content

[MLIR][Vector] Refactor tests for contract -> OP transforms (4/N) #73807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

This patch refactors tests for:

vector.contract -> vector.outerproduct

for matvec operations (b += Ax). Summary of changes:

  • add 2 missing cases (masked + scalable) when the operation kind is
    maxf.

This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect.

Implements #72834.

This patch refactors tests for:

    vector.contract -> vector.outerproduct

for matvec operations (b += Ax). Summary of changes:
  * add 2 missing cases (masked + scalable) when the operation kind is
    `maxf`.

This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect.

Implements llvm#72834.
@llvmbot
Copy link
Member

llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

This patch refactors tests for:

vector.contract -> vector.outerproduct

for matvec operations (b += Ax). Summary of changes:

  • add 2 missing cases (masked + scalable) when the operation kind is
    maxf.

This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect.

Implements #72834.


Full diff: https://github.com/llvm/llvm-project/pull/73807.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (+46)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index e84a43feaff39dc..8fed1f8fb341547 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -186,6 +186,52 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
   return %0 : vector<2xf32>
 }
 
+// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_max(
+// CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<2xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+func.func @masked_matvec_mk_k_m_max(%A: vector<2x3xf32>,
+                                    %x: vector<3xf32>,
+                                    %b: vector<2xf32>,
+                                    %m: vector<2x3xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
+          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(
+// CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
+                                                          %x: vector<3xf32>,
+                                                          %b: vector<[2]xf32>,
+                                                          %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
+          : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
 // ============================================================================
 //  Matvec 2 (plain + masked + scalable)
 // ============================================================================

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This LGTM, thanks!

@banach-space banach-space merged commit 1726b65 into llvm:main Dec 1, 2023
@banach-space banach-space deleted the andrzej/update_contract_test_v2_p4 branch March 8, 2024 14:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants