|
| 1 | +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s |
| 2 | + |
| | +// Payload IR for the test: a single statically shaped 3x3 f32 linalg.matmul on |
| | +// memrefs, with %A and %B as inputs and %C as the output operand. The transform |
| | +// sequence at the bottom of this file matches and rewrites this op. |
| 3 | +func.func @outerproduct_matmul(%A: memref<3x3xf32>, %B: memref<3x3xf32>, %C: memref<3x3xf32>) { |
| 4 | + linalg.matmul ins(%A, %B: memref<3x3xf32>, memref<3x3xf32>) |
| 5 | + outs(%C: memref<3x3xf32>) |
| 6 | + return |
| 7 | +} |
| 8 | + |
| | +// Expected IR after the transform sequence has run: |
| | +// 1. All three memref arguments are read in full as vector<3x3xf32> values |
| | +// (in-bounds transfer_reads at index [0, 0] with a 0.0 padding value). |
| | +// 2. The vector read from the first argument is transposed, so extracting |
| | +// row k of the transposed value pairs with row k of the second operand. |
| | +// 3. Three extract/extract/outerproduct steps (k = 0, 1, 2) follow; each |
| | +// outerproduct uses kind<add> and takes the previous step's result as its |
| | +// accumulator, starting from the vector read from the third argument. |
| | +// 4. The final accumulator is written back to the third argument. |
| 9 | +// CHECK-LABEL: func.func @outerproduct_matmul( |
| 10 | +// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3xf32>, %[[VAL_1:.*]]: memref<3x3xf32>, %[[VAL_2:.*]]: memref<3x3xf32>) { |
| 11 | +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index |
| 12 | +// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32 |
| 13 | +// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32> |
| 14 | +// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32> |
| 15 | +// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32> |
| 16 | +// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_5]], [1, 0] : vector<3x3xf32> to vector<3x3xf32> |
| 17 | +// CHECK: %[[VAL_9:.*]] = vector.extract %[[VAL_8]][0] : vector<3x3xf32> |
| 18 | +// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3x3xf32> |
| 19 | +// CHECK: %[[VAL_11:.*]] = vector.outerproduct %[[VAL_9]], %[[VAL_10]], %[[VAL_7]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32> |
| 20 | +// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_8]][1] : vector<3x3xf32> |
| 21 | +// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_6]][1] : vector<3x3xf32> |
| 22 | +// CHECK: %[[VAL_14:.*]] = vector.outerproduct %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32> |
| 23 | +// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_8]][2] : vector<3x3xf32> |
| 24 | +// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_6]][2] : vector<3x3xf32> |
| 25 | +// CHECK: %[[VAL_17:.*]] = vector.outerproduct %[[VAL_15]], %[[VAL_16]], %[[VAL_14]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32> |
| 26 | +// CHECK: vector.transfer_write %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x3xf32>, memref<3x3xf32> |
| 27 | +// CHECK: return |
| 28 | +// CHECK: } |
| 29 | + |
| | +// Transform script driving the test. It is run by the |
| | +// -test-transform-dialect-interpreter pass named in the RUN line; failures of |
| | +// any step are propagated rather than suppressed. |
| 30 | +transform.sequence failures(propagate) { |
| 31 | +^bb1(%arg1: !pdl.operation): |
| | + // Step 1: get a handle to every linalg.matmul op under the root handle. |
| 32 | + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation |
| | + // Step 2: move up to the closest isolated parent op of the matched matmul |
| | + // (presumably the enclosing func.func -- NOTE(review): confirm). |
| 33 | + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation |
| | + // Step 3: vectorize the contents of that parent op. |
| 34 | + %2 = transform.structured.vectorize %1 |
| | + // Step 4: lower the contraction in the vectorized op using the "outerproduct" |
| | + // strategy. The resulting handle is left unbound since nothing consumes it. |
| 35 | + transform.vector.lower_contraction %2 lowering_strategy = "outerproduct" : (!pdl.operation) -> !pdl.operation |
| 36 | +} |
0 commit comments