[mlir][linalg] Add e2e test for linalg.mmt4d + pack/unpack #84964
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

This is a follow-up for #81790. This patch extends test/Integration/Dialect/Linalg/CPU/mmt4d.mlir with pack/unpack ops so that the overall computation is a matrix multiplication (as opposed to linalg.mmt4d). For comparison (and to make it easier to verify correctness), linalg.matmul is also included in the test.

Full diff: https://github.com/llvm/llvm-project/pull/84964.diff

1 file affected:
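For context, and as an illustration rather than part of the patch: linalg.mmt4d computes a matmul over pre-packed 4-d operands. The LHS is tiled as MxKxm0xk0, the RHS as NxKxn0xk0 (the "t" in mmt4d refers to this transposed RHS tile layout), and the accumulator as MxNxm0xn0. A minimal sketch using the shapes and tile sizes from the test, where %lhs, %rhs and %acc are hypothetical, already-packed values:

  // For every outer index pair (m, n), accumulate 8x8 tiles over k:
  //   acc[m, n] += lhs[m, k] * transpose(rhs[n, k])
  %res = linalg.mmt4d ins(%lhs, %rhs : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>)
                      outs(%acc : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32>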
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
new file mode 100644
index 00000000000000..bf7a98fc07f086
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
@@ -0,0 +1,174 @@
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = main
+// DEFINE: %{run} = mlir-cpu-runner %t -e %{entry_point} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+/// End-to-end test for computing a matrix multiplication using linalg.mmt4d.
+/// In particular, it demonstrates that the following MLIR sequence
+/// (implemented in @mmt4d):
+///
+/// A_pack = tensor.pack A
+/// B_pack = tensor.pack B
+/// C_pack = tensor.pack C
+///    out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
+///    out = tensor.unpack out_pack
+///
+/// is equivalent to:
+///
+/// linalg.matmul(A, B, C)
+///
+/// (implemented in @matmul).
+
+func.func @main() {
+ // Allocate and initialise the inputs
+ %A_alloc = tensor.empty() : tensor<7x16xi32>
+ %B_alloc = tensor.empty() : tensor<16x13xi32>
+
+ %three = arith.constant 3 : i32
+ %four = arith.constant 4 : i32
+ %A = linalg.fill ins(%three : i32) outs(%A_alloc : tensor<7x16xi32>) -> tensor<7x16xi32>
+ %B = linalg.fill ins(%four : i32) outs(%B_alloc : tensor<16x13xi32>) -> tensor<16x13xi32>
+ %C = arith.constant dense<[
+ [ 1, 8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
+ [ 2, 9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
+ [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
+ [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
+ [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
+ [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
+ [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
+ ]> : tensor<7x13xi32>
+
+ // Matrix multiplication via linalg.mmt4d
+ // CHECK: Unranked Memref
+ // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
+ // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
+ // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
+ // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
+ // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
+ // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
+ // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
+ %C_mmt4d = func.call @mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+ %xf = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
+ call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+  // Matrix multiplication with linalg.matmul
+ // CHECK: Unranked Memref
+ // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
+ // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
+ // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
+ // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
+ // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
+ // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
+ // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
+ %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+ %xf_2 = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
+ call @printMemrefI32(%xf_2) : (tensor<*xi32>) -> ()
+
+ return
+}
+
+func.func @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+ %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>)
+ outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32>
+
+ return %C_matmul : tensor<7x13xi32>
+}
+
+func.func @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+ %zero = arith.constant 0 : i32
+
+ %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
+ %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
+ %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32>
+
+ // Pack matrices
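+  // A (7x16)  -> A_pack (2x16x8x1): M tiled by 8 (7 rows padded to 8), K tiled by 1
+  // B (16x13) -> B_pack (2x16x8x1): N tiled by 8 (13 cols padded to 16), K tiled by 1
+  // C (7x13)  -> C_pack (2x2x8x8):  8x8 accumulator tiles, padded in both dims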
+ %A_pack = tensor.pack %A padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32>
+ %B_pack = tensor.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32>
+ %C_pack = tensor.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32>
+
+ // MMT4D
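+  // Note that B_pack holds the RHS in transposed (N x K) tile order, which is
+  // the layout linalg.mmt4d expects for its second operand.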
+ %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32>
+
+ // Unpack output
+ %C_out_empty = tensor.empty() : tensor<7x13xi32>
+ %C_out_unpack = tensor.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<2x2x8x8xi32> -> tensor<7x13xi32>
+
+ return %C_out_unpack : tensor<7x13xi32>
+}
+
+module @transforms attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
+ %func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+
+ // Step 1: Tile
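+    // linalg.mmt4d has a 6-d iteration space, (M, N, K, m0, n0, k0), i.e.
+    // (parallel, parallel, reduction, parallel, parallel, reduction), hence
+    // the 6 tile sizes in each list below.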
+ // Tile parallel dims
+ %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ // Tile reduction dims
+ %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
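+    // After both steps, each mmt4d instance updates a single 8x8 accumulator
+    // tile using one 8x1 LHS tile and one 8x1 RHS tile per reduction step.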
+
+ // Step 2: Vectorize
+ transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
+
+ // Step 3: Simplify
+ // vector.multi_reduction --> vector.contract
+    // Generates a 6-dim vector.contract with dims matching the original MMT4D Op
+ // and with the following split into parallel and reduction dims:
+ // * parallel, parallel, reduction, parallel, parallel, reduction
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ // Reduce the rank of xfer ops. This transforms vector.contract to be
+ // more matmul-like and to enable the lowering to outer product Ops.
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.op<"func.func">
+
+ // Hoisting and LICM - not strictly required
+ %func_h = transform.structured.hoist_redundant_vector_transfers %func
+ : (!transform.op<"func.func">) -> !transform.op<"func.func">
+ %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
+ : (!transform.op<"func.func">) -> !transform.any_op
+ transform.apply_licm to %all_loops : !transform.any_op
+ transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+ // Simplify the 6-dim vector.contract into a 3-dim matmul-like
+ // vector.contract with the following split into parallel and reduction
+ // dims:
+ // * parallel, parallel, reduction
+ transform.apply_patterns to %func_h {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+ transform.apply_patterns.canonicalization
+ } : !transform.op<"func.func">
+
+ // Step 4. Lower tensor.pack
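+    // (lower_pack decomposes tensor.pack into pad + expand_shape + transpose,
+    // as reflected in the result handles below.)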
+ %pack = transform.structured.match ops{["tensor.pack"]} in %func_h
+ : (!transform.op<"func.func">) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+
+ // Step 5. Lower tensor.unpack
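+    // (This mirrors lower_pack: transpose back to the original layout,
+    // collapse the tile dims, then slice away the padding.)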
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %func_h
+ : (!transform.op<"func.func">) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+ transform.yield
+ }
+
+ transform.named_sequence @match_mmt4d(
+ %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.operation_name %entry ["linalg.mmt4d"] : !transform.any_op
+ transform.yield %entry : !transform.any_op
+ }
+}
+
+func.func private @printMemrefI32(%ptr : tensor<*xi32>)
thanks for the patch Andrzej, I've left some initial comments but I still need to digest what this is doing, don't fully understand it yet
I've left some comments, mostly minor. I'm accepting this on the basis that it demonstrates a pack + mmt4d + unpack sequence is equivalent to a matmul, which is easily verifiable from the output.
If the purpose of this test is also documentation, I think it could be improved, particularly with some comments to clarify the packing; additionally, a simpler example with smaller shapes would be better.
This is a follow-up for llvm#81790. This patch extends test/Integration/Dialect/Linalg/CPU/mmt4d.mlir with pack/unpack ops so that the overall computation is a matrix multiplication (as opposed to linalg.mmt4d). For comparison (and to make it easier to verify correctness), linalg.matmul is also included in the test.
Address PR comments
Add "private-function-dynamic-ownership" to "-buffer-deallocation-pipeline". This prevents the bufferizer from inserting `bufferization.clone` (which are a bit tricky to remove) while still avoiding leaks.
This looks great, thanks for doing this! Not formally participating in the review process as I'm unfamiliar with the transform dialect.