Skip to content

[mlir][nfc] Add tests for linalg.mmt4d #81422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 13, 2024

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Feb 11, 2024

linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244),
but there are virtually no tests in-tree. In the spirit of documenting
through test, this PR adds a few basic examples.

@llvmbot
Copy link
Member

llvmbot commented Feb 11, 2024

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244),
but there virtually no tests in-tree. In the spirit of documenting
through test, this PR adds a few basic examples.


Full diff: https://github.com/llvm/llvm-project/pull/81422.diff

4 Files Affected:

  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+26)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+11)
  • (added) mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir (+66)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+32)
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 56890df3f3ee52..916c04f33e9c67 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
     -> tensor<2x16xf32>
   return %1 : tensor<2x16xf32>
 }
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+                               %B: tensor<16x16x8x1xf32>,
+                               %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+    // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+    %res = linalg.mmt4d
+                     ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                     outs(%C_in: tensor<16x16x8x1xf32>)
+                     -> tensor<16x16x8x1xf32>
+    return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+                 %B: tensor<16x16x8x1xf32>,
+                 %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+    // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
+    %res = linalg.mmt4d
+                     ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                     outs(%C_in: tensor<8x8xf32>)
+                     -> tensor<8x8xf32>
+    return %res : tensor<8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..317231908a9413 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
 
 // -----
 
+// CHECK-LABEL: func @mmt4d
+func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
+  // CHECK: %{{.+}} = linalg.mmt4d
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+  %0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+  return %0: tensor<10x80x8x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @batch_mmt4d
 func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
   // CHECK: %{{.+}} = linalg.batch_mmt4d
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
new file mode 100644
index 00000000000000..21534152a45c4b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+  %res = linalg.mmt4d
+                   ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                   outs(%C_in: tensor<16x16x8x8xf32>)
+                   -> tensor<16x16x8x8xf32>
+  return %res : tensor<16x16x8x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func
+
+    // Step 1: Tile
+      : (!transform.op<"func.func">) -> !transform.any_op
+    // Tile parallel dims
+    %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    // Tile reduction dims
+    %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
+
+    // Step 3: Simplify
+    // vector.multi_reduction --> vector.contract
+    // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
+    // and with the following split int parallel and reduction dims:
+    //    * parallel, parallel, reduction, parallel, parallel, reduction
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This transforms vector.contract to be
+      // more matmul-like and to enable the lowering to outer product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.op<"func.func">
+
+    // Hoisting and LICM - not strictly required
+    %func_h = transform.structured.hoist_redundant_vector_transfers %func
+      : (!transform.op<"func.func">) -> !transform.op<"func.func">
+    %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
+      : (!transform.op<"func.func">) -> !transform.any_op
+    transform.apply_licm to %all_loops : !transform.any_op
+    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+    // Simplify the 6-dim vector.contract into a 3-dim matmul-like
+    // vector.contract with the following split splitn parallel and reduction
+    // dims:
+    //    * parallel, parallel, reduction
+    transform.apply_patterns to %func_h {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+      transform.apply_patterns.canonicalization
+    } : !transform.op<"func.func">
+
+    // Step 4: Lower vector.contract to vector.fma
+    transform.apply_patterns to %func_h {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef478ee987..548c9c7ba76485 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -639,6 +639,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
             outs(%C: memref<?x?xf32>)

@llvmbot
Copy link
Member

llvmbot commented Feb 11, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244),
but there virtually no tests in-tree. In the spirit of documenting
through test, this PR adds a few basic examples.


Full diff: https://github.com/llvm/llvm-project/pull/81422.diff

4 Files Affected:

  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+26)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+11)
  • (added) mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir (+66)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+32)
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 56890df3f3ee52..916c04f33e9c67 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
     -> tensor<2x16xf32>
   return %1 : tensor<2x16xf32>
 }
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+                               %B: tensor<16x16x8x1xf32>,
+                               %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+    // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+    %res = linalg.mmt4d
+                     ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                     outs(%C_in: tensor<16x16x8x1xf32>)
+                     -> tensor<16x16x8x1xf32>
+    return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+                 %B: tensor<16x16x8x1xf32>,
+                 %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+    // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
+    %res = linalg.mmt4d
+                     ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                     outs(%C_in: tensor<8x8xf32>)
+                     -> tensor<8x8xf32>
+    return %res : tensor<8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..317231908a9413 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
 
 // -----
 
+// CHECK-LABEL: func @mmt4d
+func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
+  // CHECK: %{{.+}} = linalg.mmt4d
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+  %0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+  return %0: tensor<10x80x8x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @batch_mmt4d
 func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
   // CHECK: %{{.+}} = linalg.batch_mmt4d
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
new file mode 100644
index 00000000000000..21534152a45c4b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+  %res = linalg.mmt4d
+                   ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                   outs(%C_in: tensor<16x16x8x8xf32>)
+                   -> tensor<16x16x8x8xf32>
+  return %res : tensor<16x16x8x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func
+
+    // Step 1: Tile
+      : (!transform.op<"func.func">) -> !transform.any_op
+    // Tile parallel dims
+    %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    // Tile reduction dims
+    %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
+
+    // Step 3: Simplify
+    // vector.multi_reduction --> vector.contract
+    // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
+    // and with the following split int parallel and reduction dims:
+    //    * parallel, parallel, reduction, parallel, parallel, reduction
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This transforms vector.contract to be
+      // more matmul-like and to enable the lowering to outer product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.op<"func.func">
+
+    // Hoisting and LICM - not strictly required
+    %func_h = transform.structured.hoist_redundant_vector_transfers %func
+      : (!transform.op<"func.func">) -> !transform.op<"func.func">
+    %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
+      : (!transform.op<"func.func">) -> !transform.any_op
+    transform.apply_licm to %all_loops : !transform.any_op
+    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+    // Simplify the 6-dim vector.contract into a 3-dim matmul-like
+    // vector.contract with the following split splitn parallel and reduction
+    // dims:
+    //    * parallel, parallel, reduction
+    transform.apply_patterns to %func_h {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+      transform.apply_patterns.canonicalization
+    } : !transform.op<"func.func">
+
+    // Step 4: Lower vector.contract to vector.fma
+    transform.apply_patterns to %func_h {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef478ee987..548c9c7ba76485 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -639,6 +639,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
             outs(%C: memref<?x?xf32>)

linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244),
but there virtually no tests in-tree. In the spirit of documenting
through test, this PR adds a few basic examples.
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks Andrzej, I'm not very familiar with this op so would be good to get eyes from someone who is, I've left a few comments anyway. I would also find a small integration test helpful, particularly if it could be compared with an existing linalg.matmul.

Comment on lines +763 to +765
func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
%B: tensor<16x16x8x1xf32>,
%C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: alignment

Suggested change
func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
%B: tensor<16x16x8x1xf32>,
%C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
%B: tensor<16x16x8x1xf32>,
%C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {

Copy link
Contributor

@bjacob bjacob left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for doing this, @banach-space! Sorry about the omission.

@banach-space
Copy link
Contributor Author

banach-space commented Feb 13, 2024

I would also find a small integration test helpful, particularly if it could be compared with an existing linalg.matmul.

That's what I plan next (in a separate PR), but just run out of spare cycles when working on this :) Reverse engineering that lowering pipeline took quite some time.

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

@banach-space banach-space merged commit 7a47113 into llvm:main Feb 13, 2024
@banach-space banach-space deleted the andrzej/add_tests_for_mmt4d branch February 14, 2024 08:39
banach-space added a commit to banach-space/llvm-project that referenced this pull request Feb 14, 2024
Follow-up for llvm#81422. My intention is to write an e2e test targetting
SVE, but more work is needed. Sending this as an intermiedate step.
@banach-space
Copy link
Contributor Author

I would also find a small integration test helpful

#81790

banach-space added a commit that referenced this pull request Feb 20, 2024
Follow-up for #81422. My intention is to write an e2e test targetting
SVE, but more work is needed. Sending this as an intermiedate step.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants