
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE: -one-shot-bufferize="bufferize-function-boundaries" \
// DEFINE: -buffer-deallocation-pipeline="private-function-dynamic-ownership" \
// DEFINE: -cse -canonicalize -test-lower-to-llvm
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = mlir-cpu-runner -e %{entry_point} -entry-point-result=void \
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile} | %{run} | FileCheck %s

/// End-to-end test for computing matrix-multiplication using linalg.mmt4d. In
/// particular, demonstrates how the following MLIR sequence (implemented in
/// @mmt4d):
///
///   A_pack = tensor.pack A
///   B_pack = tensor.pack B
///   C_pack = tensor.pack C
///   out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
///
/// is equivalent to:
///
///   linalg.matmul(A, B, C)
///
/// (implemented in @matmul).

| 26 | +func.func @main() { |
| 27 | + // Allocate and initialise the inputs |
| 28 | + %A_alloc = tensor.empty() : tensor<7x16xi32> |
| 29 | + %B_alloc = tensor.empty() : tensor<16x13xi32> |
| 30 | + |
| 31 | + %three = arith.constant 3 : i32 |
| 32 | + %four = arith.constant 4 : i32 |
| 33 | + %A = linalg.fill ins(%three : i32) outs(%A_alloc : tensor<7x16xi32>) -> tensor<7x16xi32> |
| 34 | + %B = linalg.fill ins(%four : i32) outs(%B_alloc : tensor<16x13xi32>) -> tensor<16x13xi32> |
| 35 | + %C = arith.constant dense<[ |
| 36 | + [ 1, 8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85], |
| 37 | + [ 2, 9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86], |
| 38 | + [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87], |
| 39 | + [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88], |
| 40 | + [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89], |
| 41 | + [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90], |
| 42 | + [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91] |
| 43 | + ]> : tensor<7x13xi32> |
| 44 | + |
| 45 | + // Matrix multiplication via linalg.mmt4d |
| 46 | + // CHECK: Unranked Memref |
| 47 | + // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277] |
| 48 | + // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278] |
| 49 | + // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279] |
| 50 | + // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280] |
| 51 | + // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281] |
| 52 | + // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282] |
| 53 | + // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283] |
| 54 | + %C_mmt4d = func.call @mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32> |
| 55 | + %xf = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32> |
| 56 | + call @printMemrefI32(%xf) : (tensor<*xi32>) -> () |
| 57 | + |
| 58 | + // Matrix multiplication with linalg.matmul |
| 59 | + // CHECK: Unranked Memref |
| 60 | + // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277] |
| 61 | + // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278] |
| 62 | + // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279] |
| 63 | + // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280] |
| 64 | + // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281] |
| 65 | + // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282] |
| 66 | + // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283] |
| 67 | + %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32> |
| 68 | + %xf_2 = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32> |
| 69 | + call @printMemrefI32(%xf_2) : (tensor<*xi32>) -> () |
| 70 | + |
| 71 | + return |
| 72 | +} |
| 73 | + |
| 74 | +func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> { |
| 75 | + %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>) |
| 76 | + outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32> |
| 77 | + |
| 78 | + return %C_matmul : tensor<7x13xi32> |
| 79 | +} |
| 80 | + |
| 81 | +func.func private @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> { |
| 82 | + %zero = arith.constant 0 : i32 |
| 83 | + |
| 84 | + %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32> |
| 85 | + %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32> |
| 86 | + %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32> |
| 87 | + |
| 88 | + // Pack matrices |
| 89 | + %A_pack = tensor.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32> |
| 90 | + %B_pack = tensor.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32> |
| 91 | + %C_pack = tensor.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32> |
| 92 | + |
| 93 | + // MMT4D |
| 94 | + %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32> |
| 95 | + |
| 96 | + // Unpack output |
| 97 | + %C_out_empty = tensor.empty() : tensor<7x13xi32> |
| 98 | + %C_out_unpack = tensor.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<2x2x8x8xi32> -> tensor<7x13xi32> |
| 99 | + |
| 100 | + return %C_out_unpack : tensor<7x13xi32> |
| 101 | +} |
| 102 | + |
| 103 | +module @transforms attributes { transform.with_named_sequence } { |
| 104 | + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { |
| 105 | + %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op) |
| 106 | + %func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func"> |
| 107 | + |
| 108 | + // Step 1: Tile |
| 109 | + // Tile parallel dims |
| 110 | + %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0] |
| 111 | + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) |
| 112 | + // Tile reduction dims |
| 113 | + %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1] |
| 114 | + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) |
| 115 | + |
| 116 | + // Step 2: Vectorize |
| 117 | + transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op |
| 118 | + |
| 119 | + // Step 3: Simplify |
| 120 | + // vector.multi_reduction --> vector.contract |
| 121 | + // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op |
| 122 | + // and with the following split into parallel and reduction dims: |
| 123 | + // * parallel, parallel, reduction, parallel, parallel, reduction |
| 124 | + transform.apply_patterns to %func { |
| 125 | + transform.apply_patterns.vector.reduction_to_contract |
| 126 | + // Reduce the rank of xfer ops. This transforms vector.contract to be |
| 127 | + // more matmul-like and to enable the lowering to outer product Ops. |
| 128 | + transform.apply_patterns.vector.transfer_permutation_patterns |
| 129 | + } : !transform.op<"func.func"> |
| 130 | + |
| 131 | + // Hoisting and LICM - not strictly required |
| 132 | + %func_h = transform.structured.hoist_redundant_vector_transfers %func |
| 133 | + : (!transform.op<"func.func">) -> !transform.op<"func.func"> |
| 134 | + %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h |
| 135 | + : (!transform.op<"func.func">) -> !transform.any_op |
| 136 | + transform.apply_licm to %all_loops : !transform.any_op |
| 137 | + transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op |
| 138 | + |
| 139 | + // Simplify the 6-dim vector.contract into a 3-dim matmul-like |
| 140 | + // vector.contract with the following split into parallel and reduction |
| 141 | + // dims: |
| 142 | + // * parallel, parallel, reduction |
| 143 | + transform.apply_patterns to %func_h { |
| 144 | + transform.apply_patterns.vector.reduction_to_contract |
| 145 | + transform.apply_patterns.vector.cast_away_vector_leading_one_dim |
| 146 | + transform.apply_patterns.canonicalization |
| 147 | + } : !transform.op<"func.func"> |
| 148 | + |
| 149 | + // Step 4. Lower tensor.pack |
| 150 | + %pack = transform.structured.match ops{["tensor.pack"]} in %func_h |
| 151 | + : (!transform.op<"func.func">) -> !transform.op<"tensor.pack"> |
| 152 | + transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) |
| 153 | + -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">) |
| 154 | + |
| 155 | + // Step 5. Lower tensor.unpack |
| 156 | + %unpack = transform.structured.match ops{["tensor.unpack"]} in %func_h |
| 157 | + : (!transform.op<"func.func">) -> !transform.op<"tensor.unpack"> |
| 158 | + transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">) |
| 159 | + -> (!transform.op<"tensor.empty">, |
| 160 | + !transform.op<"linalg.transpose">, |
| 161 | + !transform.op<"tensor.collapse_shape">, |
| 162 | + !transform.op<"tensor.extract_slice">) |
| 163 | + transform.yield |
| 164 | + } |
| 165 | + |
| 166 | + transform.named_sequence @match_mmt4d( |
| 167 | + %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { |
| 168 | + transform.match.operation_name %entry ["linalg.mmt4d"] : !transform.any_op |
| 169 | + transform.yield %entry : !transform.any_op |
| 170 | + } |
| 171 | +} |
| 172 | + |
| 173 | +func.func private @printMemrefI32(%ptr : tensor<*xi32>) |