Skip to content

[MLIR][Vector] Refactor tests for contract -> OP transforms (3/N) #73447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Nov 26, 2023

This patch refactors tests for:

  vector.contract -> vector.outerproduct

for matvec operations (b += Ax). Summary of changes:

  • names of LIT variables are unified,
  • "plain" tests (i.e. without masking and with fixed-width vectors)
    are moved to the top of their respective sections,
  • missing "plain" cases are added.

This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect. I am refactoring these tests so that it's
easier to identify what cases are tested and where to add tests for
scalable vectors.

Implements #72834.

@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This is a direct follow-up of #73348. The matvec trait that's used for
@<!-- -->matvec_m_mk_k was incorrectly updated from:

  affine_map&lt;(m, k) -&gt; (m)&gt;,
  affine_map&lt;(m, k) -&gt; (m, k)&gt;,
  affine_map&lt;(m, k) -&gt; (k)&gt;
]
  indexing_maps = #redpar_vecmattrans_accesses,
  iterator_types = ["reduction", "parallel"]
}

to:

  affine_map&lt;(m, k) -&gt; (k)&gt;,
  affine_map&lt;(m, k) -&gt; (k, m)&gt;,
  affine_map&lt;(m, k) -&gt; (m)&gt;
]
  indexing_maps = #matvec_accesses_4,
  iterator_types = ["parallel", "reduction"]
}

Note that these traits describe identical matvec operation, hence the
CHECK lines are identical for both.

Also, #redpar_vecmattrans_trait is identical to #matvec_trait_8
that's already present in:

  • "vector-contract-to-outerproduct-matvec-transforms.mlir"

For this reason:

  • @<!-- -->matvec_m_mk_k is moved near other tests that already use #matvec_trait_8,
  • #redpar_vecmattrans_trait is replaced #matvec_trait_8.

This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect. I am refactoring these tests so that it's
easier to identify what cases are tested and where to add tests for
scalable vectors.

Implements #72834.


Patch is 39.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73447.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (+292-215)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index 3ca3d344c1abe04..e84a43feaff39dc 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -1,5 +1,17 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// Matvec operations:
+///   b += A * x.
+/// (b and x are 1-d vectors, A is a 2-d matrix). ATM three different variants
+/// are tested:
+///   * plain (no mask, fixed-wdith vectors),
+///   * masked (fixed-width vectors,
+///   * scalable (mask + scalable vectors).
+///
+/// TODO: These tests were extracted from 2 different files. If you find the
+/// formatting inconsistent, please update accordingly.
+
 #matvec_accesses_1 = [
   affine_map<(m, k) -> (m, k)>,
   affine_map<(m, k) -> (k)>,
@@ -46,19 +58,67 @@
   iterator_types = ["parallel", "reduction"]
 }
 
-#redpar_vecmattrans_accesses = [
-  affine_map<(m, k) -> (m)>,
-  affine_map<(m, k) -> (m, k)>,
-  affine_map<(m, k) -> (k)>
+#matvec_accesses_5 = [
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+  indexing_maps = #matvec_accesses_5,
+  iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_6 = [
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+  indexing_maps = #matvec_accesses_6,
+  iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_7 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+  indexing_maps = #matvec_accesses_7,
+  iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_8 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (m)>
 ]
-#redpar_vecmattrans_trait = {
-  indexing_maps = #redpar_vecmattrans_accesses,
+#matvec_trait_8 = {
+  indexing_maps = #matvec_accesses_8,
   iterator_types = ["reduction", "parallel"]
 }
 
 // ============================================================================
 //  Matvec 1 (plain + masked + scalable)
 // ============================================================================
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL:   func.func @masked_matvec_mk_k_m(
 // CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<3xf32>,
@@ -73,12 +133,11 @@
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
 // CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
-
-func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
-                                %arg1: vector<3xf32>,
-                                %arg2: vector<2xf32>,
+func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>,
+                                %x: vector<3xf32>,
+                                %b: vector<2xf32>,
                                 %m: vector<2x3xi1>) -> vector<2xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
           : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
   return %0 : vector<2xf32>
 }
@@ -97,46 +156,28 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
 // CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
-func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
-                                                      %arg1: vector<3xf32>,
-                                                      %arg2: vector<[2]xf32>,
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%A: vector<[2]x3xf32>,
+                                                      %x: vector<3xf32>,
+                                                      %b: vector<[2]xf32>,
                                                       %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
           : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
   return %0 : vector<[2]xf32>
 }
 
-// CHECK-LABEL: func @matvec_mk_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
 // ============================================================================
 //  Matvec 1  - max (plain)
 // ============================================================================
 // CHECK-LABEL: func @matvec_mk_k_m_max
 // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
 // CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
 func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
                              %x: vector<2xf32>,
@@ -149,38 +190,38 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
 //  Matvec 2 (plain + masked + scalable)
 // ============================================================================
 // CHECK-LABEL: @masked_matvec_km_k_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
-                                %arg1: vector<2xf32>,
-                                %arg2: vector<4xf32>, 
+func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>,
+                                %x: vector<2xf32>,
+                                %b: vector<4xf32>, 
                                 %mask: vector<4x2xi1>) -> vector<4xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+    vector.contract #matvec_trait_2 %A, %x, %b
       : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
 // CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
-                                                      %arg1: vector<2xf32>,
-                                                      %arg2: vector<[4]xf32>,
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+                                                      %x: vector<2xf32>,
+                                                      %b: vector<[4]xf32>,
                                                       %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+    vector.contract #matvec_trait_2 %A, %x, %b
       : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<[4]x2xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>
@@ -188,13 +229,13 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
 
 // CHECK-LABEL: func @matvec_km_k_m
 // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
 // CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
 // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 func.func @matvec_km_k_m(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
@@ -207,54 +248,53 @@ func.func @matvec_km_k_m(%A: vector<2x2xf32>,
 //  Matvec 3 (plain + masked + scalable)
 // ============================================================================
 // CHECK-LABEL: @masked_matvec_k_mk_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
-                                %arg1: vector<2xf32>,
-                                %arg2: vector<4xf32>,
+func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
+                                %x: vector<2xf32>,
+                                %b: vector<4xf32>,
                                 %mask: vector<4x2xi1>) -> vector<4xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+      vector.contract #matvec_trait_3 %x, %A, %b
         : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
 // CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
-                                                      %arg1: vector<2xf32>,
-                                                      %arg2: vector<[4]xf32>,
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
+                                                      %x: vector<2xf32>,
+                                                      %b: vector<[4]xf32>,
                                                       %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+      vector.contract #matvec_trait_3 %x, %A, %b
         : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<[4]x2xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>
 }
 
-
 // CHECK-LABEL: func @matvec_k_mk_m
 // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
 // CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 func.func @matvec_k_mk_m(%A: vector<2x2xf32>, 
                          %x: vector<2xf32>,
@@ -266,253 +306,290 @@ func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
 // ============================================================================
 //  Matvec 4 (plain + masked + scalable)
 // ============================================================================
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
-                                                      %arg1: vector<2xf32>,
-                                                      %arg2: vector<[4]xf32>,
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+                                                      %x: vector<2xf32>,
+                                                      %b: vector<[4]xf32>,
                                                       %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+    vector.contract #matvec_trait_4 %x, %A, %b
       : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<[4]x2xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>
 }
 
 // CHECK-LABEL: @masked_matvec_k_km_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
-                                %arg1: vector<2xf32>,
-                                %arg2: vector<4xf32>,
+func.func @masked_matvec_k_km_m(%A: vector<2x4xf32>,
+                                %x: vector<2xf32>,
+                  ...
[truncated]

@banach-space banach-space changed the title [MLIR][Vector] Refactor tests for contract -> OP transforms (2/N) [MLIR][Vector] Refactor tests for contract -> OP transforms (3/N) Nov 26, 2023
@banach-space
Copy link
Contributor Author

Depends on #73445 - only review the most recent commit.

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've had a look through and it all looks good to me (though it's a little tricky to diff these changes :)).

This patch refactors tests for:

  vector.contract -> vector.outerproduct

for matvec operations (b += Ax). Summary of changes:
  * names of LIT variables are unified,
  * "plain" tests (i.e. without masking and with fixed-width vectors)
    are moved to the top of their respective sections,
  * missing "plain" cases are added.

This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect. I am refactoring these tests so that it's
easier to identify what cases are tested and where to add tests for
scalable vectors.

Implements llvm#72834.
@banach-space banach-space force-pushed the andrzej/update_contract_test_v2_p3 branch from 2854e83 to ddd89c3 Compare November 28, 2023 18:55
@banach-space banach-space merged commit 9619a24 into llvm:main Nov 29, 2023
@banach-space banach-space deleted the andrzej/update_contract_test_v2_p3 branch March 8, 2024 14:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants