Skip to content

[mlir][Vector] Update v.contract -> v.outerproduct tests (2/N) #70425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Oct 27, 2023

The remaining tests for conversions from vector.contract to
vector.outerproduct for matmul operations in:

  • "vector-contract-to-outerproduct-transforms.mlir"

are updated with cases for scalable vectors. One duplicated test is
removed.

In addition:

  • tests are re-organised so that matvec tests and matmul tests are
    "clustered" together,
  • function formatting is unified,
  • added comments to document and to seperate different cases,
  • unified the naming for matrix/vector dimensions: (i, j, k) -> (m, n,
    k),

While this does add a bit of noise to this patch, I wanted to avoid
sending seperate patches to refactor this file.

Depends on #70379

Tests for conversions from vector.contract to vector.outerproduct
for _matvec_ operations are updated with cases for scalable vectors.

This patch updates one specific test file:

    vector-contract-to-outerproduct-transforms.mlir.

The remaining _matmul_ operations in this file will be updated in a
separate patch. Only the parallel dimension is made scalable. Making the
reduction dimension scalable would lead to different patterns without
vector.outerproduct (that would need to be added to some other file).
The remaining tests for conversions from vector.contract to
vector.outerproduct for _matmul_ operations in:

  * "vector-contract-to-outerproduct-transforms.mlir"

are updated with cases for scalable vectors. One duplicated test is
removed.

In addition:

  * tests are re-organised so that _matvec_ tests and _matmul_ tests are
    "clustered" together,
  * one duplicate case for _matvec_ is removed,
  * function formatting is unified,
  * added comments to document and to seperate different cases,
  * unified the naming for matrix/vector dimensions: (i, j, k) -> (m, n,
    k),

While this does add a bit of noise to this patch, I wanted to avoid
sending seperate patches to refactor this file.

Depends on llvm#70379
@banach-space banach-space requested a review from MacDue October 27, 2023 08:08
@banach-space banach-space changed the title andrzej/add more tests for scalable p2 [mlir][Vector] Update v.contract -> v.outerproduct tests (2/N) Oct 27, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

The remaining tests for conversions from vector.contract to
vector.outerproduct for matmul operations in:

  • "vector-contract-to-outerproduct-transforms.mlir"

are updated with cases for scalable vectors. One duplicated test is
removed.

In addition:

  • tests are re-organised so that matvec tests and matmul tests are
    "clustered" together,
  • function formatting is unified,
  • added comments to document and to seperate different cases,
  • unified the naming for matrix/vector dimensions: (i, j, k) -> (m, n,
    k),

While this does add a bit of noise to this patch, I wanted to avoid
sending seperate patches to refactor this file.

Depends on #70379


Patch is 48.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70425.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+595-265)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 44fb23088cea933..5f560017cad312d 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -1,36 +1,31 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
-#matvec_accesses = [
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i)>
+// NOTE - tests in this file are duplicated so that there's a version for
+//    * _fixed width_ and for _scalable_ vectors.
+// In order for the "vector.contract -> vector.outerproduct" patterns to work,
+// only the non-reduction dimension can be scalable (*). For Matmul operations
+// that is set to be the N dimension (i.e. rows of the output matrix), which
+// matches how matrix multiplication are normally implemented for e.g. 
+// Arm's SVE. However, making the M dimension scalable (i.e. columns of the
+// output matrix) should work as well.
+//
+// (*) The conversion tested in this file unrolls along the reduction
+// dimension, which is not supported for scalable vectors.
+
+// ============================================================================
+//  Matvec 1
+// ============================================================================
+#matvec_accesses_1 = [
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
 ]
-#matvec_trait = {
-  indexing_maps = #matvec_accesses,
+#matvec_trait_1 = {
+  indexing_maps = #matvec_accesses_1,
   iterator_types = ["parallel", "reduction"]
 }
 
-#matmat_accesses = [
-  affine_map<(i, j, k) -> (i, k)>,
-  affine_map<(i, j, k) -> (k, j)>,
-  affine_map<(i, j, k) -> (i, j)>
-]
-#matmat_trait = {
-  indexing_maps = #matmat_accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-#matmat_accesses_0 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_0 = {
-  indexing_maps = #matmat_accesses_0,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-// CHECK-LABEL:   func.func @masked_extract_contract2(
+// CHECK-LABEL:   func.func @masked_matvec_mk_k_m(
 // CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<2xf32>,
@@ -45,17 +40,16 @@
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
 // CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
-func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
-                                    %arg1: vector<3xf32>,
-                                    %arg2: vector<2xf32>,
-                                    %m: vector<2x3xi1>) -> vector<2xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
+                                %arg1: vector<3xf32>,
+                                %arg2: vector<2xf32>,
+                                %m: vector<2x3xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
           : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
   return %0 : vector<2xf32>
 }
 
-
-// CHECK-LABEL:   func.func @masked_extract_contract2_scalable_parallel_dim(
+// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
 // CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
@@ -69,16 +63,434 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
 // CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
-func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
-                                    %arg1: vector<3xf32>,
-                                    %arg2: vector<[2]xf32>,
-                                    %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
+                                                      %arg1: vector<3xf32>,
+                                                      %arg2: vector<[2]xf32>,
+                                                      %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
           : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
   return %0 : vector<[2]xf32>
 }
 
-// CHECK-LABEL: func.func @masked_extract_contract4(
+// ============================================================================
+//  Matvec 2
+// ============================================================================
+#matvec_accesses_2 = [
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_2 = {
+  indexing_maps = #matvec_accesses_2,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
+                                %arg1: vector<2xf32>,
+                                %arg2: vector<4xf32>, 
+                                %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+                                                      %arg1: vector<2xf32>,
+                                                      %arg2: vector<[4]xf32>,
+                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+//  Matvec 3
+// ============================================================================
+#matvec_accesses_3 = [
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_3 = {
+  indexing_maps = #matvec_accesses_3,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: @masked_matvec_k_mk_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
+                                %arg1: vector<2xf32>,
+                                %arg2: vector<4xf32>,
+                                %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+        : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
+                                                      %arg1: vector<2xf32>,
+                                                      %arg2: vector<[4]xf32>,
+                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+        : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+//  Matvec 4
+// ============================================================================
+#matvec_accesses_4 = [
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_4 = {
+  indexing_maps = #matvec_accesses_4,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+                                                      %arg1: vector<2xf32>,
+                                                      %arg2: vector<[4]xf32>,
+                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
+                                %arg1: vector<2xf32>,
+                                %arg2: vector<4xf32>,
+                                %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// ============================================================================
+//  Matvec 5
+// ============================================================================
+#matvec_accesses_5 = [
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+  indexing_maps = #matvec_accesses_5,
+  iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+//  Matvec 6
+// ============================================================================
+#matvec_accesses_6 = [
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+  indexing_maps = #matvec_accesses_6,
+  iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+//  Matvec 7
+// ============================================================================
+#matvec_accesses_7 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+  indexing_maps = #matvec_accesses_7,
+  iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+//  Matvec 8
+// ============================================================================
+#matvec_accesses_8 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_8 = {
+  indexing_maps = #matvec_accesses_8,
+  iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_km_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> v...
[truncated]

@banach-space banach-space marked this pull request as draft October 27, 2023 08:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants