Skip to content

[mlir][nfc] Update tests for Contract -> Op transforms #76054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 21, 2023

Conversation

banach-space
Copy link
Contributor

Updates two tests for vector.contract -> vector.outerproduct
transformations:

  1. Rename "vector-contract-to-outerproduct-transforms.mlir" as
    "vector-contract-to-outerproduct-matmul-transforms.mlir". The new
    name more accurate captures what's being tested. it is also
    consistent with
    "vector-contract-to-outerproduct-matvec-transforms.mlir", which
    covers vector matvec operations and makes finding relevant tests
    easier.

  2. For matmul tests, move the traits definining the iteration spaces to
    the top of the file. This is consistent with how matvec tests are
    defined and also makes it easy to quickly identify what cases are
    covered.

  3. For matmul tests, use more meaningful names for function arguments.
    This helps keep things consistent across the file (i.e. function
    definitions wih check lines and comments).

  4. For matvec test, move a few tests around so that the most basic case
    (without masking) is first.

  5. Update comments.

Updates two tests for vector.contract -> vector.outerproduct
transformations:

1. Rename "vector-contract-to-outerproduct-transforms.mlir" as
   "vector-contract-to-outerproduct-matmul-transforms.mlir". The new
   name more accurate captures what's being tested. it is also
   consistent with
   "vector-contract-to-outerproduct-matvec-transforms.mlir", which
   covers vector matvec operations and makes finding relevant tests
   easier.

2. For matmul tests, move the traits definining the iteration spaces to
   the top of the file. This is consistent with how matvec tests are
   defined and also makes it easy to quickly identify what cases are
   covered.

3. For matmul tests, use more meaningful names for function arguments.
   This helps keep things consistent across the file (i.e. function
   definitions wih check lines and comments).

4. For matvec test, move a few tests around so that the most basic case
   (without masking) is first.

5. Update comments.
Re-order matvec tests so that the one without masking is always first
@llvmbot
Copy link
Member

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Updates two tests for vector.contract -> vector.outerproduct
transformations:

  1. Rename "vector-contract-to-outerproduct-transforms.mlir" as
    "vector-contract-to-outerproduct-matmul-transforms.mlir". The new
    name more accurate captures what's being tested. it is also
    consistent with
    "vector-contract-to-outerproduct-matvec-transforms.mlir", which
    covers vector matvec operations and makes finding relevant tests
    easier.

  2. For matmul tests, move the traits definining the iteration spaces to
    the top of the file. This is consistent with how matvec tests are
    defined and also makes it easy to quickly identify what cases are
    covered.

  3. For matmul tests, use more meaningful names for function arguments.
    This helps keep things consistent across the file (i.e. function
    definitions wih check lines and comments).

  4. For matvec test, move a few tests around so that the most basic case
    (without masking) is first.

  5. Update comments.


Patch is 23.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76054.diff

2 Files Affected:

  • (renamed) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir (+119-114)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (+30-30)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
similarity index 81%
rename from mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
rename to mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
index 7588b738ff9aa3..96023124bd3ff8 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
@@ -1,20 +1,22 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
-// NOTE - tests in this file are duplicated so that there's a version for
-//    * _fixed width_ and for _scalable_ vectors.
-// In order for the "vector.contract -> vector.outerproduct" patterns to work,
-// only the non-reduction dimension can be scalable (*). For Matmul operations
-// that is set to be the N dimension (i.e. rows of the output matrix), which
-// matches how matrix multiplication are normally implemented for e.g. 
-// Arm SVE. However, making the M dimension scalable (i.e. columns of the
-// output matrix) should work as well.
-//
-// (*) The conversion tested in this file unrolls along the reduction
-// dimension, which is not supported for scalable vectors.
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// Matmul operations:
+///   C += A * B.
+/// (A, B and C are 2-d matrices). ATM three different variants / are tested:
+///   * plain (no mask, fixed-wdith vectors),
+///   * masked (fixed-width vectors,
+///   * scalable (mask + scalable vectors).
+/// In order for the "vector.contract -> vector.outerproduct" patterns to work,
+/// only the non-reduction dimension can be scalable (*). For Matmul operations
+/// that is set to be the N dimension (i.e. rows of the output matrix), which
+/// matches how matrix multiplication are normally implemented for e.g.
+/// Arm SVE. However, making the M dimension scalable (i.e. columns of the
+/// output matrix) should work as well.
+///
+/// (*) The conversion tested in this file unrolls along the reduction
+/// dimension, which is not supported for scalable vectors.
 
-// ============================================================================
-//  Matmul 0 (plain + masked + mixed types)
-// ============================================================================
 #matmat_accesses_0 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -25,6 +27,49 @@
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
+#matmat_accesses_1 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_1 = {
+  indexing_maps = #matmat_accesses_1,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_2 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_2 = {
+  indexing_maps = #matmat_accesses_2,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_3 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_3 = {
+  indexing_maps = #matmat_accesses_3,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_4 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_4 = {
+  indexing_maps = #matmat_accesses_4,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// ============================================================================
+//  Matmul 0 (plain + masked + mixed types)
+// ============================================================================
 // CHECK-LABEL: func @matmul
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
@@ -53,10 +98,10 @@
 // CHECK-SAME:  : vector<2xf32>, vector<3xf32>
 //
 //      CHECK: return %[[c3]] : vector<2x3xf32>
-func.func @matmul(%arg0: vector<2x4xf32>,
-                  %arg1: vector<4x3xf32>,
-                  %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+func.func @matmul(%A: vector<2x4xf32>,
+                  %B: vector<4x3xf32>,
+                  %C: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -89,10 +134,10 @@ func.func @matmul(%arg0: vector<2x4xf32>,
 // CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
 //
 //      CHECK: return %[[c3]] : vector<2x[3]xf32>
-func.func @matmul_scalable(%arg0: vector<2x4xf32>,
-                           %arg1: vector<4x[3]xf32>,
-                           %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+func.func @matmul_scalable(%A: vector<2x4xf32>,
+                           %B: vector<4x[3]xf32>,
+                           %C: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
@@ -114,11 +159,11 @@ func.func @matmul_scalable(%arg0: vector<2x4xf32>,
 // CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
 
-func.func @masked_matmul(%arg0: vector<3x5xf32>,
-                         %arg1: vector<5x7xf32>,
-                         %arg2: vector<3x7xf32>,
+func.func @masked_matmul(%A: vector<3x5xf32>,
+                         %B: vector<5x7xf32>,
+                         %C: vector<3x7xf32>,
                          %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
-  %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
   : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
   return %0 : vector<3x7xf32>
 }
@@ -140,11 +185,11 @@ func.func @masked_matmul(%arg0: vector<3x5xf32>,
 // CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
 
-func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
-                                  %arg1: vector<5x[7]xf32>,
-                                  %arg2: vector<3x[7]xf32>,
+func.func @masked_matmul_scalable(%A: vector<3x5xf32>,
+                                  %B: vector<5x[7]xf32>,
+                                  %C: vector<3x[7]xf32>,
                                   %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
-  %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
   : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
   return %0 : vector<3x[7]xf32>
 }
@@ -160,11 +205,11 @@ func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
 //      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_mixed(%arg0: vector<2x1xf16>,
-                          %arg1: vector<1x3xf16>,
-                          %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_mixed(%A: vector<2x1xf16>,
+                        %B: vector<1x3xf16>,
+                        %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -180,28 +225,18 @@ func.func @matmul_mixed(%arg0: vector<2x1xf16>,
 //      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
-                                   %arg1: vector<1x[3]xf16>,
-                                   %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
+                                 %B: vector<1x[3]xf16>,
+                                 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
 
 // ============================================================================
-//  Matmul 1 (plain)
+//  Matmul 1 (plain + scalable)
 // ============================================================================
-#matmat_accesses_1 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (n, k)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_1 = {
-  indexing_maps = #matmat_accesses_1,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_1
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -212,11 +247,11 @@ func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_1(%arg0: vector<2x1xf32>,
-                    %arg1: vector<3x1xf32>,
-                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_1(%A: vector<2x1xf32>,
+                    %B: vector<3x1xf32>,
+                    %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_1 %A, %B, %C
     : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -231,28 +266,18 @@ func.func @matmul_1(%arg0: vector<2x1xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
-                             %arg1: vector<[3]x1xf32>,
-                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_1_scalable(%A: vector<2x1xf32>,
+                             %B: vector<[3]x1xf32>,
+                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_1 %A, %B, %C
     : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
 
 // ============================================================================
-//  Matmul 2 (plain)
+//  Matmul 2 (plain + scalable)
 // ============================================================================
-#matmat_accesses_2 = [
-  affine_map<(m, n, k) -> (k, m)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_2 = {
-  indexing_maps = #matmat_accesses_2,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_2
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -261,11 +286,11 @@ func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_2(%arg0: vector<1x2xf32>,
-                    %arg1: vector<1x3xf32>,
-                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_2(%A: vector<1x2xf32>,
+                    %B: vector<1x3xf32>,
+                    %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_2 %A, %B, %C
     : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -278,28 +303,18 @@ func.func @matmul_2(%arg0: vector<1x2xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
-                             %arg1: vector<1x[3]xf32>,
-                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_2_scalable(%A: vector<1x2xf32>,
+                             %B: vector<1x[3]xf32>,
+                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_2 %A, %B, %C
     : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
 
 // ============================================================================
-//  Matmul 3 (plain)
+//  Matmul 3 (plain + scalable)
 // ============================================================================
-#matmat_accesses_3 = [
-  affine_map<(m, n, k) -> (k, m)>,
-  affine_map<(m, n, k) -> (n, k)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_3 = {
-  indexing_maps = #matmat_accesses_3,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_3
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -309,11 +324,11 @@ func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_3(%arg0: vector<1x2xf32>,
-                    %arg1: vector<3x1xf32>,
-                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_3(%A: vector<1x2xf32>,
+                    %B: vector<3x1xf32>,
+                    %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_3 %A, %B, %C
     : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -327,28 +342,18 @@ func.func @matmul_3(%arg0: vector<1x2xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
-                             %arg1: vector<[3]x1xf32>,
-                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_3_scalable(%A: vector<1x2xf32>,
+                             %B: vector<[3]x1xf32>,
+                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_3 %A, %B, %C
     : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
 
 // ============================================================================
-//  Matmul 4 (plain)
+//  Matmul 4 (plain + scalable)
 // ============================================================================
-#matmat_accesses_4 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (n, m)>
-]
-#matmat_trait_4 = {
-  indexing_maps = #matmat_accesses_4,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_4
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -358,11 +363,11 @@ func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<3x2xf32>
-func.func @matmul_4(%arg0: vector<2x1xf32>,
-                    %arg1: vector<1x3xf32>,
-                    %arg2: vector<3x2xf32>) -> vector<3x2xf32>
+func.func @matmul_4(%A: vector<2x1xf32>,
+                    %B: vector<1x3xf32>,
+                    %C: vector<3x2xf32>) -> vector<3x2xf32>
 {
-  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_4 %A, %B, %C
     : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
   return %0 : vector<3x2xf32>
 }
@@ -376,11 +381,11 @@ func.func @matmul_4(%arg0: vector<2x1xf32>,
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<3x[2]xf32>
-func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>,
-                             %arg1: vector<1x3xf32>,
-                             %arg2: vector<3x[2]xf32>) -> vector<3x[2]xf32>
+func.func @matmul_4_scalable(%A: vector<[2]x1xf32>,
+                             %B: vector<1x3xf32>,
+                             %C: vector<3x[2]xf32>) -> vector<3x[2]xf32>
 {
-  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_4 %A, %B, %C
     : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
   return %0 : vector<3x[2]xf32>
 }
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index c09a4d569638a5..d86c6158bcdf2f 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -235,6 +235,23 @@ func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
 // ============================================================================
 //  Matvec 2 (plain + masked + scalable)
 // ============================================================================
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_km_k_m
 // CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[X:.+]]: vector<2xf32>
@@ -273,26 +290,27 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
   return %res : vector<[4]xf32>
 }
 
-// CHECK-LABEL: func @matvec_km_k_m
+// ============================================================================
+//  Matvec 3 (plain + masked + scalable)
+// ====================================...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Updates two tests for vector.contract -> vector.outerproduct
transformations:

  1. Rename "vector-contract-to-outerproduct-transforms.mlir" as
    "vector-contract-to-outerproduct-matmul-transforms.mlir". The new
    name more accurate captures what's being tested. it is also
    consistent with
    "vector-contract-to-outerproduct-matvec-transforms.mlir", which
    covers vector matvec operations and makes finding relevant tests
    easier.

  2. For matmul tests, move the traits definining the iteration spaces to
    the top of the file. This is consistent with how matvec tests are
    defined and also makes it easy to quickly identify what cases are
    covered.

  3. For matmul tests, use more meaningful names for function arguments.
    This helps keep things consistent across the file (i.e. function
    definitions wih check lines and comments).

  4. For matvec test, move a few tests around so that the most basic case
    (without masking) is first.

  5. Update comments.


Patch is 23.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76054.diff

2 Files Affected:

  • (renamed) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir (+119-114)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (+30-30)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
similarity index 81%
rename from mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
rename to mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
index 7588b738ff9aa3..96023124bd3ff8 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
@@ -1,20 +1,22 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
-// NOTE - tests in this file are duplicated so that there's a version for
-//    * _fixed width_ and for _scalable_ vectors.
-// In order for the "vector.contract -> vector.outerproduct" patterns to work,
-// only the non-reduction dimension can be scalable (*). For Matmul operations
-// that is set to be the N dimension (i.e. rows of the output matrix), which
-// matches how matrix multiplication are normally implemented for e.g. 
-// Arm SVE. However, making the M dimension scalable (i.e. columns of the
-// output matrix) should work as well.
-//
-// (*) The conversion tested in this file unrolls along the reduction
-// dimension, which is not supported for scalable vectors.
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// Matmul operations:
+///   C += A * B.
+/// (A, B and C are 2-d matrices). ATM three different variants / are tested:
+///   * plain (no mask, fixed-wdith vectors),
+///   * masked (fixed-width vectors,
+///   * scalable (mask + scalable vectors).
+/// In order for the "vector.contract -> vector.outerproduct" patterns to work,
+/// only the non-reduction dimension can be scalable (*). For Matmul operations
+/// that is set to be the N dimension (i.e. rows of the output matrix), which
+/// matches how matrix multiplication are normally implemented for e.g.
+/// Arm SVE. However, making the M dimension scalable (i.e. columns of the
+/// output matrix) should work as well.
+///
+/// (*) The conversion tested in this file unrolls along the reduction
+/// dimension, which is not supported for scalable vectors.
 
-// ============================================================================
-//  Matmul 0 (plain + masked + mixed types)
-// ============================================================================
 #matmat_accesses_0 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -25,6 +27,49 @@
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
+#matmat_accesses_1 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_1 = {
+  indexing_maps = #matmat_accesses_1,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_2 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_2 = {
+  indexing_maps = #matmat_accesses_2,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_3 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_3 = {
+  indexing_maps = #matmat_accesses_3,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_4 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_4 = {
+  indexing_maps = #matmat_accesses_4,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// ============================================================================
+//  Matmul 0 (plain + masked + mixed types)
+// ============================================================================
 // CHECK-LABEL: func @matmul
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
@@ -53,10 +98,10 @@
 // CHECK-SAME:  : vector<2xf32>, vector<3xf32>
 //
 //      CHECK: return %[[c3]] : vector<2x3xf32>
-func.func @matmul(%arg0: vector<2x4xf32>,
-                  %arg1: vector<4x3xf32>,
-                  %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+func.func @matmul(%A: vector<2x4xf32>,
+                  %B: vector<4x3xf32>,
+                  %C: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -89,10 +134,10 @@ func.func @matmul(%arg0: vector<2x4xf32>,
 // CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
 //
 //      CHECK: return %[[c3]] : vector<2x[3]xf32>
-func.func @matmul_scalable(%arg0: vector<2x4xf32>,
-                           %arg1: vector<4x[3]xf32>,
-                           %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+func.func @matmul_scalable(%A: vector<2x4xf32>,
+                           %B: vector<4x[3]xf32>,
+                           %C: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
@@ -114,11 +159,11 @@ func.func @matmul_scalable(%arg0: vector<2x4xf32>,
 // CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
 
-func.func @masked_matmul(%arg0: vector<3x5xf32>,
-                         %arg1: vector<5x7xf32>,
-                         %arg2: vector<3x7xf32>,
+func.func @masked_matmul(%A: vector<3x5xf32>,
+                         %B: vector<5x7xf32>,
+                         %C: vector<3x7xf32>,
                          %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
-  %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
   : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
   return %0 : vector<3x7xf32>
 }
@@ -140,11 +185,11 @@ func.func @masked_matmul(%arg0: vector<3x5xf32>,
 // CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
 
-func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
-                                  %arg1: vector<5x[7]xf32>,
-                                  %arg2: vector<3x[7]xf32>,
+func.func @masked_matmul_scalable(%A: vector<3x5xf32>,
+                                  %B: vector<5x[7]xf32>,
+                                  %C: vector<3x[7]xf32>,
                                   %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
-  %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
   : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
   return %0 : vector<3x[7]xf32>
 }
@@ -160,11 +205,11 @@ func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
 //      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_mixed(%arg0: vector<2x1xf16>,
-                          %arg1: vector<1x3xf16>,
-                          %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_mixed(%A: vector<2x1xf16>,
+                        %B: vector<1x3xf16>,
+                        %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -180,28 +225,18 @@ func.func @matmul_mixed(%arg0: vector<2x1xf16>,
 //      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
-                                   %arg1: vector<1x[3]xf16>,
-                                   %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
+                                 %B: vector<1x[3]xf16>,
+                                 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
 
 // ============================================================================
-//  Matmul 1 (plain)
+//  Matmul 1 (plain + scalable)
 // ============================================================================
-#matmat_accesses_1 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (n, k)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_1 = {
-  indexing_maps = #matmat_accesses_1,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_1
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -212,11 +247,11 @@ func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_1(%arg0: vector<2x1xf32>,
-                    %arg1: vector<3x1xf32>,
-                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_1(%A: vector<2x1xf32>,
+                    %B: vector<3x1xf32>,
+                    %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_1 %A, %B, %C
     : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -231,28 +266,18 @@ func.func @matmul_1(%arg0: vector<2x1xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
-                             %arg1: vector<[3]x1xf32>,
-                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_1_scalable(%A: vector<2x1xf32>,
+                             %B: vector<[3]x1xf32>,
+                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_1 %A, %B, %C
     : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
 
 // ============================================================================
-//  Matmul 2 (plain)
+//  Matmul 2 (plain + scalable)
 // ============================================================================
-#matmat_accesses_2 = [
-  affine_map<(m, n, k) -> (k, m)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_2 = {
-  indexing_maps = #matmat_accesses_2,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_2
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -261,11 +286,11 @@ func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_2(%arg0: vector<1x2xf32>,
-                    %arg1: vector<1x3xf32>,
-                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_2(%A: vector<1x2xf32>,
+                    %B: vector<1x3xf32>,
+                    %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_2 %A, %B, %C
     : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -278,28 +303,18 @@ func.func @matmul_2(%arg0: vector<1x2xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
-                             %arg1: vector<1x[3]xf32>,
-                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_2_scalable(%A: vector<1x2xf32>,
+                             %B: vector<1x[3]xf32>,
+                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_2 %A, %B, %C
     : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
 
 // ============================================================================
-//  Matmul 3 (plain)
+//  Matmul 3 (plain + scalable)
 // ============================================================================
-#matmat_accesses_3 = [
-  affine_map<(m, n, k) -> (k, m)>,
-  affine_map<(m, n, k) -> (n, k)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_3 = {
-  indexing_maps = #matmat_accesses_3,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_3
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -309,11 +324,11 @@ func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_3(%arg0: vector<1x2xf32>,
-                    %arg1: vector<3x1xf32>,
-                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_3(%A: vector<1x2xf32>,
+                    %B: vector<3x1xf32>,
+                    %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_3 %A, %B, %C
     : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -327,28 +342,18 @@ func.func @matmul_3(%arg0: vector<1x2xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
-                             %arg1: vector<[3]x1xf32>,
-                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_3_scalable(%A: vector<1x2xf32>,
+                             %B: vector<[3]x1xf32>,
+                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_3 %A, %B, %C
     : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
 
 // ============================================================================
-//  Matmul 4 (plain)
+//  Matmul 4 (plain + scalable)
 // ============================================================================
-#matmat_accesses_4 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (n, m)>
-]
-#matmat_trait_4 = {
-  indexing_maps = #matmat_accesses_4,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_4
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -358,11 +363,11 @@ func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<3x2xf32>
-func.func @matmul_4(%arg0: vector<2x1xf32>,
-                    %arg1: vector<1x3xf32>,
-                    %arg2: vector<3x2xf32>) -> vector<3x2xf32>
+func.func @matmul_4(%A: vector<2x1xf32>,
+                    %B: vector<1x3xf32>,
+                    %C: vector<3x2xf32>) -> vector<3x2xf32>
 {
-  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_4 %A, %B, %C
     : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
   return %0 : vector<3x2xf32>
 }
@@ -376,11 +381,11 @@ func.func @matmul_4(%arg0: vector<2x1xf32>,
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<3x[2]xf32>
-func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>,
-                             %arg1: vector<1x3xf32>,
-                             %arg2: vector<3x[2]xf32>) -> vector<3x[2]xf32>
+func.func @matmul_4_scalable(%A: vector<[2]x1xf32>,
+                             %B: vector<1x3xf32>,
+                             %C: vector<3x[2]xf32>) -> vector<3x[2]xf32>
 {
-  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_4 %A, %B, %C
     : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
   return %0 : vector<3x[2]xf32>
 }
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index c09a4d569638a5..d86c6158bcdf2f 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -235,6 +235,23 @@ func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
 // ============================================================================
 //  Matvec 2 (plain + masked + scalable)
 // ============================================================================
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_km_k_m
 // CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[X:.+]]: vector<2xf32>
@@ -273,26 +290,27 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
   return %res : vector<[4]xf32>
 }
 
-// CHECK-LABEL: func @matvec_km_k_m
+// ============================================================================
+//  Matvec 3 (plain + masked + scalable)
+// ====================================...
[truncated]

@banach-space
Copy link
Contributor Author

@MacDue I've split this one into 5 commits to make it easier to review. Hope this helps.

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for splitting things up! Just some very minor nits:

@banach-space banach-space merged commit 17afa5b into llvm:main Dec 21, 2023
@banach-space banach-space deleted the andrzej/update_test_v2 branch December 21, 2023 13:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants