-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][nfc] Update Linalg matmul -> Vector OP test #81416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][nfc] Update Linalg matmul -> Vector OP test #81416
Conversation
Updates "transform-op-matmul-to-outerproduct.mlir". Summary: * refines TD sequence so it's easier to reason about the compilation pipeline, * new input dims to be able to distinguish parallel from reduction dims, * updates LIT variable names (makes the output easier to follow), * removes "noise" from the expected LIT output (e.g. types). These Linalg -> Vector tests using Transform Dialect are great reference points for constructing lowering pipelines. This simplification + clean-up will hopefully make it easier to follow.
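@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
Changes
Updates "transform-op-matmul-to-outerproduct.mlir". Summary:
* refines TD sequence so it's easier to reason about the compilation pipeline,
* new input dims to be able to distinguish parallel from reduction dims,
* updates LIT variable names (makes the output easier to follow),
* removes "noise" from the expected LIT output (e.g. types).
These Linalg -> Vector tests using Transform Dialect are great reference points for constructing lowering pipelines. This simplification + clean-up will hopefully make it easier to follow.
Full diff: https://github.com/llvm/llvm-project/pull/81416.diff
1 Files Affected: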
diff --git a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
index ee66073a9a4193..a1a0c413da0c1c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
@@ -1,38 +1,51 @@
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
-func.func @outerproduct_matmul(%A: memref<3x3xf32>, %B: memref<3x3xf32>, %C: memref<3x3xf32>) {
- linalg.matmul ins(%A, %B: memref<3x3xf32>, memref<3x3xf32>)
+func.func @matmul_to_outerproduct(%A: memref<3x4xf32>, %B: memref<4x3xf32>, %C: memref<3x3xf32>) {
+ linalg.matmul ins(%A, %B: memref<3x4xf32>, memref<4x3xf32>)
outs(%C: memref<3x3xf32>)
return
}
-// CHECK-LABEL: func.func @outerproduct_matmul(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3xf32>, %[[VAL_1:.*]]: memref<3x3xf32>, %[[VAL_2:.*]]: memref<3x3xf32>) {
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_5]], [1, 0] : vector<3x3xf32> to vector<3x3xf32>
-// CHECK: %[[VAL_9:.*]] = vector.extract %[[VAL_8]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_11:.*]] = vector.outerproduct %[[VAL_9]], %[[VAL_10]], %[[VAL_7]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_8]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_6]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_14:.*]] = vector.outerproduct %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_8]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_6]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_17:.*]] = vector.outerproduct %[[VAL_15]], %[[VAL_16]], %[[VAL_14]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK: vector.transfer_write %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x3xf32>, memref<3x3xf32>
-// CHECK: return
-// CHECK: }
+// CHECK-LABEL: func.func @matmul_to_outerproduct(
+// CHECK-SAME: %[[A:.*]]: memref<3x4xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<4x3xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x3xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]
+// CHECK: %[[VEC_A_T:.*]] = vector.transpose %[[VEC_A]], [1, 0] : vector<3x4xf32> to vector<4x3xf32>
+// CHECK: %[[A0:.*]] = vector.extract %[[VEC_A_T]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[B0:.*]] = vector.extract %[[VEC_B]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[OP_0:.*]] = vector.outerproduct %[[A0]], %[[B0]], %[[VEC_C]]
+// CHECK: %[[A1:.*]] = vector.extract %[[VEC_A_T]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[B1:.*]] = vector.extract %[[VEC_B]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[OP_1:.*]] = vector.outerproduct %[[A1]], %[[B1]], %[[OP_0]]
+// CHECK: %[[A_2:.*]] = vector.extract %[[VEC_A_T]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[B_2:.*]] = vector.extract %[[VEC_B]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[OP_2:.*]] = vector.outerproduct %[[A_2]], %[[B_2]], %[[OP_1]]
+// CHECK: %[[A_3:.*]] = vector.extract %[[VEC_A_T]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[B_3:.*]] = vector.extract %[[VEC_B]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[RES:.*]] = vector.outerproduct %[[A_3]], %[[B_3]], %[[OP_2]]
+// CHECK: vector.transfer_write %[[RES]], %[[C]]{{.*}} : vector<3x3xf32>, memref<3x3xf32>
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %2 {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+
+ // Vectorize: linalg.matmul -> vector.multi_reduction
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul : !transform.any_op
+
+ // vector.multi_reduction --> vector.contract
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ // Reduce the rank of xfer ops. This transforms vector.contract to be
+ // more matmul-like and enables the lowering to outer product Ops.
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.any_op
+
+ // vector.contract --> vector.outerproduct
+ transform.apply_patterns to %func {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
} : !transform.any_op
transform.yield
|
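@llvm/pr-subscribers-mlir-linalg
Author: Andrzej Warzyński (banach-space)
Changes
Updates "transform-op-matmul-to-outerproduct.mlir". Summary:
* refines TD sequence so it's easier to reason about the compilation pipeline,
* new input dims to be able to distinguish parallel from reduction dims,
* updates LIT variable names (makes the output easier to follow),
* removes "noise" from the expected LIT output (e.g. types).
These Linalg -> Vector tests using Transform Dialect are great reference points for constructing lowering pipelines. This simplification + clean-up will hopefully make it easier to follow.
Full diff: https://github.com/llvm/llvm-project/pull/81416.diff
1 Files Affected: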
diff --git a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
index ee66073a9a4193..a1a0c413da0c1c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
@@ -1,38 +1,51 @@
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
-func.func @outerproduct_matmul(%A: memref<3x3xf32>, %B: memref<3x3xf32>, %C: memref<3x3xf32>) {
- linalg.matmul ins(%A, %B: memref<3x3xf32>, memref<3x3xf32>)
+func.func @matmul_to_outerproduct(%A: memref<3x4xf32>, %B: memref<4x3xf32>, %C: memref<3x3xf32>) {
+ linalg.matmul ins(%A, %B: memref<3x4xf32>, memref<4x3xf32>)
outs(%C: memref<3x3xf32>)
return
}
-// CHECK-LABEL: func.func @outerproduct_matmul(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3xf32>, %[[VAL_1:.*]]: memref<3x3xf32>, %[[VAL_2:.*]]: memref<3x3xf32>) {
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_5]], [1, 0] : vector<3x3xf32> to vector<3x3xf32>
-// CHECK: %[[VAL_9:.*]] = vector.extract %[[VAL_8]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_11:.*]] = vector.outerproduct %[[VAL_9]], %[[VAL_10]], %[[VAL_7]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_8]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_6]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_14:.*]] = vector.outerproduct %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_8]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_6]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK: %[[VAL_17:.*]] = vector.outerproduct %[[VAL_15]], %[[VAL_16]], %[[VAL_14]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK: vector.transfer_write %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x3xf32>, memref<3x3xf32>
-// CHECK: return
-// CHECK: }
+// CHECK-LABEL: func.func @matmul_to_outerproduct(
+// CHECK-SAME: %[[A:.*]]: memref<3x4xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<4x3xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x3xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]
+// CHECK: %[[VEC_A_T:.*]] = vector.transpose %[[VEC_A]], [1, 0] : vector<3x4xf32> to vector<4x3xf32>
+// CHECK: %[[A0:.*]] = vector.extract %[[VEC_A_T]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[B0:.*]] = vector.extract %[[VEC_B]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[OP_0:.*]] = vector.outerproduct %[[A0]], %[[B0]], %[[VEC_C]]
+// CHECK: %[[A1:.*]] = vector.extract %[[VEC_A_T]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[B1:.*]] = vector.extract %[[VEC_B]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[OP_1:.*]] = vector.outerproduct %[[A1]], %[[B1]], %[[OP_0]]
+// CHECK: %[[A_2:.*]] = vector.extract %[[VEC_A_T]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[B_2:.*]] = vector.extract %[[VEC_B]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[OP_2:.*]] = vector.outerproduct %[[A_2]], %[[B_2]], %[[OP_1]]
+// CHECK: %[[A_3:.*]] = vector.extract %[[VEC_A_T]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[B_3:.*]] = vector.extract %[[VEC_B]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK: %[[RES:.*]] = vector.outerproduct %[[A_3]], %[[B_3]], %[[OP_2]]
+// CHECK: vector.transfer_write %[[RES]], %[[C]]{{.*}} : vector<3x3xf32>, memref<3x3xf32>
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %2 {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+
+ // Vectorize: linalg.matmul -> vector.multi_reduction
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul : !transform.any_op
+
+ // vector.multi_reduction --> vector.contract
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ // Reduce the rank of xfer ops. This transforms vector.contract to be
+ // more matmul-like and enables the lowering to outer product Ops.
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.any_op
+
+ // vector.contract --> vector.outerproduct
+ transform.apply_patterns to %func {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
} : !transform.any_op
transform.yield
|
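Updates "transform-op-matmul-to-outerproduct.mlir". Summary:
* refines the TD sequence so it's easier to reason about the compilation pipeline (e.g. transform.structured.vectorize_children_and_apply_patterns is replaced with transform.structured.vectorize),
* new input dims to be able to distinguish parallel from reduction dims,
* updates LIT variable names (makes the output easier to follow),
* removes "noise" from the expected LIT output (e.g. types).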
These Linalg -> Vector tests using Transform Dialect are great reference
points for constructing lowering pipelines. This simplification +
clean-up will hopefully make it easier to follow.