[mlir][linalg] Add an e2e test for linalg.matmul_transpose_a to ArmSME #71644

Merged
c-rhodes merged 3 commits into llvm:main from mlir-arm-sme-linalg-matmul-transpose-a on Nov 10, 2023

Conversation

@c-rhodes (Collaborator) commented Nov 8, 2023

This patch adds an integration test demonstrating the first end-to-end (e2e) example of lowering a linalg.matmul to SME via vector.outerproduct.

The test uses 'linalg.matmul_transpose_a' rather than 'linalg.matmul', since vectorizing the latter emits a 'vector.transfer_read' with the vector type 'vector<[4]x1xf32>', which cannot currently be lowered via the generic (SVE) path because it has a leading scalable dimension, as sketched below.
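For illustration, here is a minimal sketch of the two read patterns (not part of the patch; the function and value names are hypothetical, and the shapes assume the [[4], [4], 1] vector sizes used in the test below):

func.func @read_patterns(%A : tensor<?x?xf32>, %At : tensor<?x?xf32>,
                         %i : index, %k : index) {
  %pad = arith.constant 0.0 : f32
  // linalg.matmul: the LHS tile is read along the parallel 'i' dimension,
  // yielding a *leading* scalable dim that the generic (SVE) path rejects.
  %lhs_bad = vector.transfer_read %A[%i, %k], %pad
    : tensor<?x?xf32>, vector<[4]x1xf32>
  // linalg.matmul_transpose_a: the pre-transposed LHS is read row-wise, so
  // only the *trailing* dim is scalable and the lowering can proceed.
  %lhs_ok = vector.transfer_read %At[%k, %i], %pad
    : tensor<?x?xf32>, vector<1x[4]xf32>
  return
}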

@llvmbot (Member) commented Nov 8, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir-linalg

Author: Cullen Rhodes (c-rhodes)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/71644.diff

1 file affected:

  • (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir (+77)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
new file mode 100644
index 000000000000000..bf445e05fb70e8d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt %s \
+// RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
+// RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
+// RUN:   -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN:   -convert-vector-to-llvm=enable-arm-sme \
+// RUN:   -convert-vector-to-llvm=enable-arm-sve \
+// RUN:   -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN:   -e=main -entry-point-result=void \
+// RUN:   -march=aarch64 -mattr="+sve,+sme" \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @matmul_transpose_a(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
+  %res = linalg.matmul_transpose_a ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                                   outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %xf = tensor.cast %res : tensor<?x?xf32> to tensor<*xf32>
+  call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+  return
+}
+
+func.func @main() {
+  %c0 = arith.constant 0 : i32
+  %c4 = arith.constant 4 : index
+
+  %A = arith.constant dense<[
+    [  1.0,  2.0,  3.0,  4.0 ],
+    [  5.0,  6.0,  7.0,  8.0 ],
+    [  9.0, 10.0, 11.0, 12.0 ],
+    [ 13.0, 14.0, 15.0, 16.0 ]
+  ]> : tensor<4x4xf32>
+
+  %A_dyn = tensor.cast %A : tensor<4x4xf32> to tensor<?x?xf32>
+
+  %C_init = bufferization.alloc_tensor(%c4, %c4) : tensor<?x?xf32>
+  %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data =
+  // CHECK: [276, 304, 332, 360]
+  // CHECK: [304, 336, 368, 400]
+  // CHECK: [332, 368, 404, 440]
+  // CHECK: [360, 400, 440, 480]
+  call @matmul_transpose_a(%A_dyn, %A_dyn, %C) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> ()
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module : !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul_transpose_a"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+    %tiled_linalg_op, %loops:3 = transform.structured.tile_using_for %0[[4], [4], 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.structured.vectorize %tiled_linalg_op vector_sizes [[4], [4], 1]
+      : !transform.any_op
+
+    %func = transform.structured.match ops{["func.func"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_masked_transfers
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.reduction_to_contract
+    } : !transform.any_op
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_masks
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
+
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
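A note on the expected output: the test passes the same matrix for both operands, so the result is A^T * A, which is symmetric. For example, the (0, 0) entry checked above is the dot product of A's first column with itself, 1*1 + 5*5 + 9*9 + 13*13 = 276, and the (0, 1) entry is 1*2 + 5*6 + 9*10 + 13*14 = 304.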
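For readers new to this lowering path: the "outerproduct" lowering strategy in the transform sequence rewrites the vectorized contraction into one rank-1 update per reduction (K) step, and -convert-vector-to-arm-sme can then map such updates onto SME's ZA tile outer-product-accumulate instructions (FMOPA for f32). A minimal sketch of a single step, with hypothetical names (illustrative only, not output from this test):

func.func @outerproduct_step(%lhs_col : vector<[4]xf32>,
                             %rhs_row : vector<[4]xf32>,
                             %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
  // acc[i][j] += lhs_col[i] * rhs_row[j]: one K-step of the matmul
  // reduction, expressed as a scalable outer product.
  %0 = vector.outerproduct %lhs_col, %rhs_row, %acc
    : vector<[4]xf32>, vector<[4]xf32>
  return %0 : vector<[4]x[4]xf32>
}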

@banach-space (Contributor) left a comment

Very, very EXCITING! 🥳

@MacDue (Member) left a comment

LGTM 🚀

@banach-space (Contributor) left a comment

LGTM :shipit: 🚢 ⚓ 🚀

Could you ping Discourse once post-commit CI is 🟢? Great work Cullen!

@c-rhodes c-rhodes merged commit fe8c649 into llvm:main Nov 10, 2023
@c-rhodes c-rhodes deleted the mlir-arm-sme-linalg-matmul-transpose-a branch November 10, 2023 07:52
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
[mlir][linalg] Add an e2e test for linalg.matmul_transpose_a to ArmSME (llvm#71644)