Skip to content

[MLIR][Vector] Refactor tests for contract -> OP transforms #73217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

banach-space
Copy link
Contributor

This patch refactors tests for :

vector.contract -> vector.outerproduct

for matvec operations (b += Ax). Relevant tests from the following 2
files:

  • vector-contract-matvec-transforms.mlir
  • vector-contract-to-outerproduct-transforms.mlir

are combined into one:

  • vector-contract-to-outerproduct-matvec-transforms.mlir

Summary of changes (on top of moving things between files):

  • duplicate tests are removed,
  • missing cases for scalable vectors are added,
  • names of LIT variables and operation traits are unified.

Implements #72834.

This patch refactors tests for :

  `vector.contract` -> `vector.outerproduct`

for matvec operations (b += Ax). Relevant tests from the following 2
files:

  * vector-contract-matvec-transforms.mlir
  * vector-contract-to-outerproduct-transforms.mlir

are combined into one:

  * vector-contract-to-outerproduct-matvec-transforms.mlir

Summary of changes (on top of moving things between files):
  * duplicate tests are removed,
  * missing cases for scalable vectors are added,
  * names of LIT variables and operation traits are unified.

Implements llvm#72834.
@llvmbot
Copy link
Member

llvmbot commented Nov 23, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This patch refactors tests for :

vector.contract -> vector.outerproduct

for matvec operations (b += Ax). Relevant tests from the following 2
files:

  • vector-contract-matvec-transforms.mlir
  • vector-contract-to-outerproduct-transforms.mlir

are combined into one:

  • vector-contract-to-outerproduct-matvec-transforms.mlir

Summary of changes (on top of moving things between files):

  • duplicate tests are removed,
  • missing cases for scalable vectors are added,
  • names of LIT variables and operation traits are unified.

Implements #72834.


Patch is 60.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73217.diff

3 Files Affected:

  • (removed) mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir (-170)
  • (added) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (+610)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (-396)
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
deleted file mode 100644
index 811fb589792b1a8..000000000000000
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ /dev/null
@@ -1,170 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
-
-#matvec_accesses = [
-  affine_map<(m, k) -> (m, k)>,
-  affine_map<(m, k) -> (k)>,
-  affine_map<(m, k) -> (m)>
-]
-#matvec_trait = {
-  indexing_maps = #matvec_accesses,
-  iterator_types = ["parallel", "reduction"]
-}
-#matvecmax_trait = {
-  indexing_maps = #matvec_accesses,
-  iterator_types = ["parallel", "reduction"],
-  kind = #vector.kind<maxf>
-}
-
-#mattransvec_accesses = [
-  affine_map<(m, k) -> (k, m)>,
-  affine_map<(m, k) -> (k)>,
-  affine_map<(m, k) -> (m)>
-]
-#mattransvec_trait = {
-  indexing_maps = #mattransvec_accesses,
-  iterator_types = ["parallel", "reduction"]
-}
-
-#vecmat_accesses = [
-  affine_map<(m, k) -> (k)>,
-  affine_map<(m, k) -> (m, k)>,
-  affine_map<(m, k) -> (m)>
-]
-#vecmat_trait = {
-  indexing_maps = #vecmat_accesses,
-  iterator_types = ["parallel", "reduction"]
-}
-
-#vecmattrans_accesses = [
-  affine_map<(m, k) -> (k)>,
-  affine_map<(m, k) -> (k, m)>,
-  affine_map<(m, k) -> (m)>
-]
-#vecmattrans_trait = {
-  indexing_maps = #vecmattrans_accesses,
-  iterator_types = ["parallel", "reduction"]
-}
-
-#redpar_vecmattrans_accesses = [
-  affine_map<(m, k) -> (m)>,
-  affine_map<(m, k) -> (m, k)>,
-  affine_map<(m, k) -> (k)>
-]
-#redpar_vecmattrans_trait = {
-  indexing_maps = #redpar_vecmattrans_accesses,
-  iterator_types = ["reduction", "parallel"]
-}
-
-// CHECK-LABEL: func @matvec_mk_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_mk_k_m_max
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
-                             %x: vector<2xf32>,
-                             %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_km_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_km_k_m(%A: vector<2x2xf32>,
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_k_mk_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_k_mk_m(%A: vector<2x2xf32>, 
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_k_km_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_k_km_m(%A: vector<2x2xf32>,
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_m_mk_k
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
-    transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-    } : !transform.op<"func.func">
-    transform.yield
-  }
-}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
new file mode 100644
index 000000000000000..9030cc6b1ace4e0
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -0,0 +1,610 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// Matvec operations:
+///   b += A * x.
+/// (b and x are 1-d vectors, A is a 2-d matrix). ATM three different variants
+/// are tested:
+///   * plain (no mask, fixed-wdith vectors),
+///   * masked (fixed-width vectors,
+///   * scalable (mask + scalable vectors).
+///
+/// TODO: These tests were extracted from 2 different files. If you find the
+/// formatting inconsistent, please update accordingly.
+
+#matvec_accesses_1 = [
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_1 = {
+  indexing_maps = #matvec_accesses_1,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#matvecmax_trait = {
+  indexing_maps = #matvec_accesses_1,
+  iterator_types = ["parallel", "reduction"],
+  kind = #vector.kind<maxf>
+}
+
+#matvec_accesses_2 = [
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_2 = {
+  indexing_maps = #matvec_accesses_2,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#matvec_accesses_3 = [
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_3 = {
+  indexing_maps = #matvec_accesses_3,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#matvec_accesses_4 = [
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_4 = {
+  indexing_maps = #matvec_accesses_4,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#matvec_accesses_5 = [
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+  indexing_maps = #matvec_accesses_5,
+  iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_6 = [
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+  indexing_maps = #matvec_accesses_6,
+  iterator_types = ["reduction", "parallel"]
+}
+
+// TODO: No mask
+#matvec_accesses_7 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+  indexing_maps = #matvec_accesses_7,
+  iterator_types = ["reduction", "parallel"]
+}
+
+// TODO: No mask
+#matvec_accesses_8 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_8 = {
+  indexing_maps = #matvec_accesses_8,
+  iterator_types = ["reduction", "parallel"]
+}
+
+// ============================================================================
+//  Matvec 1 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL:   func.func @masked_matvec_mk_k_m(
+// CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<2xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>,
+                                %x: vector<3xf32>,
+                                %b: vector<2xf32>,
+                                %m: vector<2x3xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
+          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
+// CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%A: vector<[2]x3xf32>,
+                                                      %x: vector<3xf32>,
+                                                      %b: vector<[2]xf32>,
+                                                      %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
+          : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
+// ============================================================================
+//  Matvec 1  - max (plain)
+// ============================================================================
+// CHECK-LABEL: func @matvec_mk_k_m_max
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
+                             %x: vector<2xf32>,
+                             %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// ============================================================================
+//  Matvec 2 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>,
+                                %x: vector<2xf32>,
+                                %b: vector<4xf32>, 
+                                %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[A]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_2 %A, %x, %b
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+                                                      %x: vector<2xf32>,
+                                                      %b: vector<[4]xf32>,
+                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[A]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_2 %A, %x, %b
+      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> i...
[truncated]

@MacDue
Copy link
Member

MacDue commented Nov 24, 2023

To make review easier in future think it'd be nice if each of:

  • duplicate tests are removed,
  • missing cases for scalable vectors are added,
  • names of LIT variables and operation traits are unified.

Were separate commits in the PR, even though the PR will be squashed into a single commit :)

@banach-space
Copy link
Contributor Author

To make review easier in future think it'd be nice if each of:

  • duplicate tests are removed,
  • missing cases for scalable vectors are added,
  • names of LIT variables and operation traits are unified.

Were separate commits in the PR, even though the PR will be squashed into a single commit :)

Haha, that was half-expected :) Let me do it as separate PRs - I am yet to get used to having separate commits as meaningful chunks for work (sorry, old habits).

Abandoning in favour of #73348 and follow-ups therein.

@banach-space banach-space deleted the andrzej/update_contract_test branch March 8, 2024 14:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants