-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][nfc] Add tests for linalg.mmt4d #81422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][nfc] Add tests for linalg.mmt4d #81422
Conversation
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changeslinalg.mmt4d was added a while back (https://reviews.llvm.org/D105244), Full diff: https://github.com/llvm/llvm-project/pull/81422.diff 4 Files Affected:
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 56890df3f3ee52..916c04f33e9c67 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
-> tensor<2x16xf32>
return %1 : tensor<2x16xf32>
}
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+ // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<16x16x8x1xf32>)
+ -> tensor<16x16x8x1xf32>
+ return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<8x8xf32>)
+ -> tensor<8x8xf32>
+ return %res : tensor<8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..317231908a9413 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
// -----
+// CHECK-LABEL: func @mmt4d
+func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
+ // CHECK: %{{.+}} = linalg.mmt4d
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+ %0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+ return %0: tensor<10x80x8x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @batch_mmt4d
func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
// CHECK: %{{.+}} = linalg.batch_mmt4d
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
new file mode 100644
index 00000000000000..21534152a45c4b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<16x16x8x8xf32>)
+ -> tensor<16x16x8x8xf32>
+ return %res : tensor<16x16x8x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func
+
+ // Step 1: Tile
+ : (!transform.op<"func.func">) -> !transform.any_op
+ // Tile parallel dims
+ %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ // Tile reduction dims
+ %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
+
+ // Step 3: Simplify
+ // vector.multi_reduction --> vector.contract
+ // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
+ // and with the following split int parallel and reduction dims:
+ // * parallel, parallel, reduction, parallel, parallel, reduction
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ // Reduce the rank of xfer ops. This transforms vector.contract to be
+ // more matmul-like and to enable the lowering to outer product Ops.
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.op<"func.func">
+
+ // Hoisting and LICM - not strictly required
+ %func_h = transform.structured.hoist_redundant_vector_transfers %func
+ : (!transform.op<"func.func">) -> !transform.op<"func.func">
+ %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
+ : (!transform.op<"func.func">) -> !transform.any_op
+ transform.apply_licm to %all_loops : !transform.any_op
+ transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+ // Simplify the 6-dim vector.contract into a 3-dim matmul-like
+ // vector.contract with the following split splitn parallel and reduction
+ // dims:
+ // * parallel, parallel, reduction
+ transform.apply_patterns to %func_h {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+ transform.apply_patterns.canonicalization
+ } : !transform.op<"func.func">
+
+ // Step 4: Lower vector.contract to vector.fma
+ transform.apply_patterns to %func_h {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef478ee987..548c9c7ba76485 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -639,6 +639,38 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+ linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+ outs(%C_in: memref<16x16x8x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %mmt4d : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?xf32>)
|
@llvm/pr-subscribers-mlir-linalg Author: Andrzej Warzyński (banach-space) Changeslinalg.mmt4d was added a while back (https://reviews.llvm.org/D105244), Full diff: https://github.com/llvm/llvm-project/pull/81422.diff 4 Files Affected:
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 56890df3f3ee52..916c04f33e9c67 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
-> tensor<2x16xf32>
return %1 : tensor<2x16xf32>
}
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+ // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<16x16x8x1xf32>)
+ -> tensor<16x16x8x1xf32>
+ return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<8x8xf32>)
+ -> tensor<8x8xf32>
+ return %res : tensor<8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..317231908a9413 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
// -----
+// CHECK-LABEL: func @mmt4d
+func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
+ // CHECK: %{{.+}} = linalg.mmt4d
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+ %0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+ return %0: tensor<10x80x8x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @batch_mmt4d
func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
// CHECK: %{{.+}} = linalg.batch_mmt4d
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
new file mode 100644
index 00000000000000..21534152a45c4b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<16x16x8x8xf32>)
+ -> tensor<16x16x8x8xf32>
+ return %res : tensor<16x16x8x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func
+
+ // Step 1: Tile
+ : (!transform.op<"func.func">) -> !transform.any_op
+ // Tile parallel dims
+ %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ // Tile reduction dims
+ %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
+
+ // Step 3: Simplify
+ // vector.multi_reduction --> vector.contract
+ // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
+ // and with the following split int parallel and reduction dims:
+ // * parallel, parallel, reduction, parallel, parallel, reduction
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ // Reduce the rank of xfer ops. This transforms vector.contract to be
+ // more matmul-like and to enable the lowering to outer product Ops.
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.op<"func.func">
+
+ // Hoisting and LICM - not strictly required
+ %func_h = transform.structured.hoist_redundant_vector_transfers %func
+ : (!transform.op<"func.func">) -> !transform.op<"func.func">
+ %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
+ : (!transform.op<"func.func">) -> !transform.any_op
+ transform.apply_licm to %all_loops : !transform.any_op
+ transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+ // Simplify the 6-dim vector.contract into a 3-dim matmul-like
+ // vector.contract with the following split splitn parallel and reduction
+ // dims:
+ // * parallel, parallel, reduction
+ transform.apply_patterns to %func_h {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+ transform.apply_patterns.canonicalization
+ } : !transform.op<"func.func">
+
+ // Step 4: Lower vector.contract to vector.fma
+ transform.apply_patterns to %func_h {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef478ee987..548c9c7ba76485 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -639,6 +639,38 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+ linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+ outs(%C_in: memref<16x16x8x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %mmt4d : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?xf32>)
|
linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244), but there virtually no tests in-tree. In the spirit of documenting through test, this PR adds a few basic examples.
5f8c31b
to
e8c57d1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks Andrzej, I'm not very familiar with this op so would be good to get eyes from someone who is, I've left a few comments anyway. I would also find a small integration test helpful, particularly if it could be compared with an existing linalg.matmul.
func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>, | ||
%B: tensor<16x16x8x1xf32>, | ||
%C_in: tensor<8x8xf32>) -> tensor<8x8xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: alignment
func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>, | |
%B: tensor<16x16x8x1xf32>, | |
%C_in: tensor<8x8xf32>) -> tensor<8x8xf32> { | |
func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>, | |
%B: tensor<16x16x8x1xf32>, | |
%C_in: tensor<8x8xf32>) -> tensor<8x8xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for doing this, @banach-space! Sorry about the omission.
That's what I plan next (in a separate PR), but just run out of spare cycles when working on this :) Reverse engineering that lowering pipeline took quite some time. |
Address Cullen's comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
Follow-up for llvm#81422. My intention is to write an e2e test targetting SVE, but more work is needed. Sending this as an intermiedate step.
|
Follow-up for #81422. My intention is to write an e2e test targetting SVE, but more work is needed. Sending this as an intermiedate step.
linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244),
but there are virtually no tests in-tree. In the spirit of documenting
through test, this PR adds a few basic examples.