Skip to content

[mlir][vector][nfc] Refactor vector.contract matvec tests #72832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Nov 20, 2023

Update tests in "vector-contract-matvec-transforms.mlir" so that they
are consistent with similar tests in:

  • "vector-contract-to-outerproduct-transforms.mlir".

This is to enable further refactoring in a follow-up patch, namely to:

  • remove duplication (this will be much easier once consistent naming
    is used),
  • extend tests in "vector-contract-matvec-transforms.mlir" with cases
    for scalable vectors,
  • merge "vector-contract-matvec-transforms.mlir" and
    "vector-contract-to-outerproduct-transforms.mlir" (there's no need
    for 2 different files testing identical transformations).

Overview of changes in this patch:

  1. Simplify the test by removing MemRef wrappers - this test verifies
    Vector -> Vector transformations and MemRefs are not needed.
  2. Use (m, k) indices instead of (i, j).
  3. Rename function names.

This is part of a larger effort to improve test coverage for scalable
vectors in the Vector dialect. Implements #72834.

Update tests in "vector-contract-matvec-transforms.mlir" so that they
are consistent with similar tests in:
 * "vector-contract-to-outerproduct-transforms.mlir".

This is to enable further refactoring in a follow-up patch, namely to:
  * remove duplication (this will be much easier once consistent naming
    is used),
  * extend tests in "vector-contract-matvec-transforms.mlir" with cases
    for scalable vectors,
  * merge "vector-contract-matvec-transforms.mlir" and
    "vector-contract-to-outerproduct-transforms.mlir"  (there's no need
    for 2 different files testing identical transformations).

Overview of changes in this patch:
  1. Simplify the test by removing MemRef wrappers - this test verifies
     Vector -> Vector transformations and  MemRefs are not needed.
  2. Use (m, k) indices instead of (i, j).
  3. Rename function names.

This is part of a larger effort to improve test coverage for scalable
vectors in the Vector dialect.
@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2023

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Update tests in "vector-contract-matvec-transforms.mlir" so that they
are consistent with similar tests in:

  • "vector-contract-to-outerproduct-transforms.mlir".

This is to enable further refactoring in a follow-up patch, namely to:

  • remove duplication (this will be much easier once consistent naming
    is used),
  • extend tests in "vector-contract-matvec-transforms.mlir" with cases
    for scalable vectors,
  • merge "vector-contract-matvec-transforms.mlir" and
    "vector-contract-to-outerproduct-transforms.mlir" (there's no need
    for 2 different files testing identical transformations).

Overview of changes in this patch:

  1. Simplify the test by removing MemRef wrappers - this test verifies
    Vector -> Vector transformations and MemRefs are not needed.
  2. Use (m, k) indices instead of (i, j).
  3. Rename function names.

This is part of a larger effort to improve test coverage for scalable
vectors in the Vector dialect.


Full diff: https://github.com/llvm/llvm-project/pull/72832.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir (+90-138)
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index cfcb14a477b6b71..811fb589792b1a8 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
 #matvec_accesses = [
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i)>
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
 ]
 #matvec_trait = {
   indexing_maps = #matvec_accesses,
@@ -16,9 +16,9 @@
 }
 
 #mattransvec_accesses = [
-  affine_map<(i, j) -> (j, i)>,
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i)>
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
 ]
 #mattransvec_trait = {
   indexing_maps = #mattransvec_accesses,
@@ -26,9 +26,9 @@
 }
 
 #vecmat_accesses = [
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (i)>
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (m)>
 ]
 #vecmat_trait = {
   indexing_maps = #vecmat_accesses,
@@ -36,9 +36,9 @@
 }
 
 #vecmattrans_accesses = [
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (j, i)>,
-  affine_map<(i, j) -> (i)>
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (m)>
 ]
 #vecmattrans_trait = {
   indexing_maps = #vecmattrans_accesses,
@@ -46,166 +46,118 @@
 }
 
 #redpar_vecmattrans_accesses = [
-  affine_map<(i, j) -> (i)>,
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (j)>
+  affine_map<(m, k) -> (m)>,
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>
 ]
 #redpar_vecmattrans_trait = {
   indexing_maps = #redpar_vecmattrans_accesses,
   iterator_types = ["reduction", "parallel"]
 }
 
-// CHECK-LABEL: func @matvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @matvecmax2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m_max
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                   %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
+                             %x: vector<2xf32>,
+                             %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @mattransvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                     %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @vecmat2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_k_mk_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_mk_m(%A: vector<2x2xf32>, 
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                     %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @redpar_vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_m_mk_k
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                     %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
 module attributes {transform.with_named_sequence} {

@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2023

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Update tests in "vector-contract-matvec-transforms.mlir" so that they
are consistent with similar tests in:

  • "vector-contract-to-outerproduct-transforms.mlir".

This is to enable further refactoring in a follow-up patch, namely to:

  • remove duplication (this will be much easier once consistent naming
    is used),
  • extend tests in "vector-contract-matvec-transforms.mlir" with cases
    for scalable vectors,
  • merge "vector-contract-matvec-transforms.mlir" and
    "vector-contract-to-outerproduct-transforms.mlir" (there's no need
    for 2 different files testing identical transformations).

Overview of changes in this patch:

  1. Simplify the test by removing MemRef wrappers - this test verifies
    Vector -> Vector transformations and MemRefs are not needed.
  2. Use (m, k) indices instead of (i, j).
  3. Rename function names.

This is part of a larger effort to improve test coverage for scalable
vectors in the Vector dialect.


Full diff: https://github.com/llvm/llvm-project/pull/72832.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir (+90-138)
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index cfcb14a477b6b71..811fb589792b1a8 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
 #matvec_accesses = [
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i)>
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
 ]
 #matvec_trait = {
   indexing_maps = #matvec_accesses,
@@ -16,9 +16,9 @@
 }
 
 #mattransvec_accesses = [
-  affine_map<(i, j) -> (j, i)>,
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i)>
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
 ]
 #mattransvec_trait = {
   indexing_maps = #mattransvec_accesses,
@@ -26,9 +26,9 @@
 }
 
 #vecmat_accesses = [
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (i)>
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (m)>
 ]
 #vecmat_trait = {
   indexing_maps = #vecmat_accesses,
@@ -36,9 +36,9 @@
 }
 
 #vecmattrans_accesses = [
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (j, i)>,
-  affine_map<(i, j) -> (i)>
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (m)>
 ]
 #vecmattrans_trait = {
   indexing_maps = #vecmattrans_accesses,
@@ -46,166 +46,118 @@
 }
 
 #redpar_vecmattrans_accesses = [
-  affine_map<(i, j) -> (i)>,
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (j)>
+  affine_map<(m, k) -> (m)>,
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>
 ]
 #redpar_vecmattrans_trait = {
   indexing_maps = #redpar_vecmattrans_accesses,
   iterator_types = ["reduction", "parallel"]
 }
 
-// CHECK-LABEL: func @matvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @matvecmax2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m_max
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                   %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
+                             %x: vector<2xf32>,
+                             %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @mattransvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                     %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @vecmat2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_k_mk_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_mk_m(%A: vector<2x2xf32>, 
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                     %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func @redpar_vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_m_mk_k
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
 // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
-                                                     %arg2: memref<vector<2xf32>>) {
-  %A = memref.load %arg0[] : memref<vector<2x2xf32>>
-  %x = memref.load %arg1[] : memref<vector<2xf32>>
-  %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  memref.store %0, %arg2[] : memref<vector<2xf32>>
-  return
+  return %0 : vector<2xf32>
 }
 
 module attributes {transform.with_named_sequence} {

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, I don't spot anything off :)

@banach-space banach-space merged commit 730e0d0 into llvm:main Nov 21, 2023
@banach-space banach-space deleted the andrzej/refactor_contract_to_matvec branch March 16, 2024 18:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants