|
1 | 1 | // RUN: mlir-opt %s -transform-interpreter | FileCheck %s
|
2 | 2 |
|
3 |
| -func.func @outerproduct_matmul(%A: memref<3x3xf32>, %B: memref<3x3xf32>, %C: memref<3x3xf32>) { |
4 |
| - linalg.matmul ins(%A, %B: memref<3x3xf32>, memref<3x3xf32>) |
| 3 | +func.func @matmul_to_outerproduct(%A: memref<3x4xf32>, %B: memref<4x3xf32>, %C: memref<3x3xf32>) { |
| 4 | + linalg.matmul ins(%A, %B: memref<3x4xf32>, memref<4x3xf32>) |
5 | 5 | outs(%C: memref<3x3xf32>)
|
6 | 6 | return
|
7 | 7 | }
|
8 | 8 |
|
9 |
| -// CHECK-LABEL: func.func @outerproduct_matmul( |
10 |
| -// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3xf32>, %[[VAL_1:.*]]: memref<3x3xf32>, %[[VAL_2:.*]]: memref<3x3xf32>) { |
11 |
| -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index |
12 |
| -// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32 |
13 |
| -// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32> |
14 |
| -// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32> |
15 |
| -// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32> |
16 |
| -// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_5]], [1, 0] : vector<3x3xf32> to vector<3x3xf32> |
17 |
| -// CHECK: %[[VAL_9:.*]] = vector.extract %[[VAL_8]][0] : vector<3xf32> from vector<3x3xf32> |
18 |
| -// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xf32> from vector<3x3xf32> |
19 |
| -// CHECK: %[[VAL_11:.*]] = vector.outerproduct %[[VAL_9]], %[[VAL_10]], %[[VAL_7]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32> |
20 |
| -// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_8]][1] : vector<3xf32> from vector<3x3xf32> |
21 |
| -// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_6]][1] : vector<3xf32> from vector<3x3xf32> |
22 |
| -// CHECK: %[[VAL_14:.*]] = vector.outerproduct %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32> |
23 |
| -// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_8]][2] : vector<3xf32> from vector<3x3xf32> |
24 |
| -// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_6]][2] : vector<3xf32> from vector<3x3xf32> |
25 |
| -// CHECK: %[[VAL_17:.*]] = vector.outerproduct %[[VAL_15]], %[[VAL_16]], %[[VAL_14]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32> |
26 |
| -// CHECK: vector.transfer_write %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x3xf32>, memref<3x3xf32> |
27 |
| -// CHECK: return |
28 |
| -// CHECK: } |
| 9 | +// CHECK-LABEL: func.func @matmul_to_outerproduct( |
| 10 | +// CHECK-SAME: %[[A:.*]]: memref<3x4xf32>, |
| 11 | +// CHECK-SAME: %[[B:.*]]: memref<4x3xf32>, |
| 12 | +// CHECK-SAME: %[[C:.*]]: memref<3x3xf32>) { |
| 13 | +// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]] |
| 14 | +// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]] |
| 15 | +// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]] |
| 16 | +// CHECK: %[[VEC_A_T:.*]] = vector.transpose %[[VEC_A]], [1, 0] : vector<3x4xf32> to vector<4x3xf32> |
| 17 | +// CHECK: %[[A0:.*]] = vector.extract %[[VEC_A_T]][0] : vector<3xf32> from vector<4x3xf32> |
| 18 | +// CHECK: %[[B0:.*]] = vector.extract %[[VEC_B]][0] : vector<3xf32> from vector<4x3xf32> |
| 19 | +// CHECK: %[[OP_0:.*]] = vector.outerproduct %[[A0]], %[[B0]], %[[VEC_C]] |
| 20 | +// CHECK: %[[A1:.*]] = vector.extract %[[VEC_A_T]][1] : vector<3xf32> from vector<4x3xf32> |
| 21 | +// CHECK: %[[B1:.*]] = vector.extract %[[VEC_B]][1] : vector<3xf32> from vector<4x3xf32> |
| 22 | +// CHECK: %[[OP_1:.*]] = vector.outerproduct %[[A1]], %[[B1]], %[[OP_0]] |
| 23 | +// CHECK: %[[A_2:.*]] = vector.extract %[[VEC_A_T]][2] : vector<3xf32> from vector<4x3xf32> |
| 24 | +// CHECK: %[[B_2:.*]] = vector.extract %[[VEC_B]][2] : vector<3xf32> from vector<4x3xf32> |
| 25 | +// CHECK: %[[OP_2:.*]] = vector.outerproduct %[[A_2]], %[[B_2]], %[[OP_1]] |
| 26 | +// CHECK: %[[A_3:.*]] = vector.extract %[[VEC_A_T]][3] : vector<3xf32> from vector<4x3xf32> |
| 27 | +// CHECK: %[[B_3:.*]] = vector.extract %[[VEC_B]][3] : vector<3xf32> from vector<4x3xf32> |
| 28 | +// CHECK: %[[RES:.*]] = vector.outerproduct %[[A_3]], %[[B_3]], %[[OP_2]] |
| 29 | +// CHECK: vector.transfer_write %[[RES]], %[[C]]{{.*}} : vector<3x3xf32>, memref<3x3xf32> |
29 | 30 |
|
30 | 31 | module attributes {transform.with_named_sequence} {
|
31 |
| - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { |
32 |
| - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
33 |
| - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op |
34 |
| - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op |
35 |
| - transform.apply_patterns to %2 { |
| 32 | + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { |
| 33 | + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op |
| 34 | + |
| 35 | + // Vectorize: linalg.matmul -> vector.multi_reduction |
| 36 | + %matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op |
| 37 | + transform.structured.vectorize %matmul : !transform.any_op |
| 38 | + |
| 39 | + // vector.multi_reduction --> vector.contract |
| 40 | + transform.apply_patterns to %func { |
| 41 | + transform.apply_patterns.vector.reduction_to_contract |
| 42 | + // Reduce the rank of xfer ops. This transform vector.contract to be more |
| 43 | + // more matmul-like and to enable the lowering to outer product Ops. |
| 44 | + transform.apply_patterns.vector.transfer_permutation_patterns |
| 45 | + } : !transform.any_op |
| 46 | + |
| 47 | + // vector.contract --> vector.outerproduct |
| 48 | + transform.apply_patterns to %func { |
36 | 49 | transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
|
37 | 50 | } : !transform.any_op
|
38 | 51 | transform.yield
|
|
0 commit comments