Skip to content

[mlir][nfc] Update Linalg matmul -> Vector OP test #81416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 13, 2024

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Feb 11, 2024

Updates "transform-op-matmul-to-outerproduct.mlir". Summary:

  • refines TD sequence so that it's easier to reason about the
    compilation pipeline (e.g.
    transform.structured.vectorize_children_and_apply_patterns
    is replaced withtransform.structured.vectorize ),
  • new input dims to be able to distinguish parallel from reduction
    dims,
  • updates LIT variable names (makes the output easier to follow),
  • removes "noise" from the expected LIT output (e.g. types).

These Linalg -> Vector tests using Transform Dialect are great reference
points for constructing lowering pipelines. This simplification +
clean-up will hopefully make it easier to follow.

Updates "transform-op-matmul-to-outerproduct.mlir". Summary:
  * refines TD sequence so it's easier to reason about the compilation
    pipeline,
  * new input dims to be able to distinguish parallel from reduction
    dims,
  * updates LIT variable names (makes the output easier to follow),
  * removes "noise" from the expected LIT output (e.g. types).

These Linalg -> Vector tests using Transform Dialect are great reference
points for constructing lowering pipelines. This simplification +
clean-up will hopefully make it easier to follow.
@llvmbot
Copy link
Member

llvmbot commented Feb 11, 2024

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Updates "transform-op-matmul-to-outerproduct.mlir". Summary:

  • refines TD sequence so it's easier to reason about the compilation
    pipeline,
  • new input dims to be able to distinguish parallel from reduction
    dims,
  • updates LIT variable names (makes the output easier to follow),
  • removes "noise" from the expected LIT output (e.g. types).

These Linalg -> Vector tests using Transform Dialect are great reference
points for constructing lowering pipelines. This simplification +
clean-up will hopefully make it easier to follow.


Full diff: https://github.com/llvm/llvm-project/pull/81416.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir (+40-27)
diff --git a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
index ee66073a9a4193..a1a0c413da0c1c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
@@ -1,38 +1,51 @@
 // RUN: mlir-opt %s -transform-interpreter | FileCheck %s
 
-func.func @outerproduct_matmul(%A: memref<3x3xf32>, %B: memref<3x3xf32>, %C: memref<3x3xf32>) {
-  linalg.matmul ins(%A, %B: memref<3x3xf32>, memref<3x3xf32>)
+func.func @matmul_to_outerproduct(%A: memref<3x4xf32>, %B: memref<4x3xf32>, %C: memref<3x3xf32>) {
+  linalg.matmul ins(%A, %B: memref<3x4xf32>, memref<4x3xf32>)
             outs(%C: memref<3x3xf32>)
   return
 }
 
-// CHECK-LABEL:   func.func @outerproduct_matmul(
-// CHECK-SAME:      %[[VAL_0:.*]]: memref<3x3xf32>, %[[VAL_1:.*]]: memref<3x3xf32>, %[[VAL_2:.*]]: memref<3x3xf32>) {
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK:           %[[VAL_8:.*]] = vector.transpose %[[VAL_5]], [1, 0] : vector<3x3xf32> to vector<3x3xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.extract %[[VAL_8]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_11:.*]] = vector.outerproduct %[[VAL_9]], %[[VAL_10]], %[[VAL_7]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_8]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_6]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_14:.*]] = vector.outerproduct %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK:           %[[VAL_15:.*]] = vector.extract %[[VAL_8]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_6]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_17:.*]] = vector.outerproduct %[[VAL_15]], %[[VAL_16]], %[[VAL_14]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK:           vector.transfer_write %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x3xf32>, memref<3x3xf32>
-// CHECK:           return
-// CHECK:         }
+// CHECK-LABEL:   func.func @matmul_to_outerproduct(
+// CHECK-SAME:      %[[A:.*]]: memref<3x4xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<4x3xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x3xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]
+// CHECK:           %[[VEC_A_T:.*]] = vector.transpose %[[VEC_A]], [1, 0] : vector<3x4xf32> to vector<4x3xf32>
+// CHECK:           %[[A0:.*]] = vector.extract %[[VEC_A_T]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B0:.*]] = vector.extract %[[VEC_B]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[OP_0:.*]] = vector.outerproduct %[[A0]], %[[B0]], %[[VEC_C]]
+// CHECK:           %[[A1:.*]] = vector.extract %[[VEC_A_T]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B1:.*]] = vector.extract %[[VEC_B]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[OP_1:.*]] = vector.outerproduct %[[A1]], %[[B1]], %[[OP_0]]
+// CHECK:           %[[A_2:.*]] = vector.extract %[[VEC_A_T]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B_2:.*]] = vector.extract %[[VEC_B]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[OP_2:.*]] = vector.outerproduct %[[A_2]], %[[B_2]], %[[OP_1]]
+// CHECK:           %[[A_3:.*]] = vector.extract %[[VEC_A_T]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B_3:.*]] = vector.extract %[[VEC_B]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[RES:.*]] = vector.outerproduct %[[A_3]], %[[B_3]], %[[OP_2]]
+// CHECK:           vector.transfer_write %[[RES]], %[[C]]{{.*}} : vector<3x3xf32>, memref<3x3xf32>
 
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
-    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %2 {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+
+    // Vectorize: linalg.matmul -> vector.multi_reduction
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %matmul : !transform.any_op
+
+    // vector.multi_reduction --> vector.contract
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This transform vector.contract to be more
+      // more matmul-like and to enable the lowering to outer product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.any_op
+
+    // vector.contract --> vector.outerproduct
+    transform.apply_patterns to %func {
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
     } : !transform.any_op
     transform.yield

@llvmbot
Copy link
Member

llvmbot commented Feb 11, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

Updates "transform-op-matmul-to-outerproduct.mlir". Summary:

  • refines TD sequence so it's easier to reason about the compilation
    pipeline,
  • new input dims to be able to distinguish parallel from reduction
    dims,
  • updates LIT variable names (makes the output easier to follow),
  • removes "noise" from the expected LIT output (e.g. types).

These Linalg -> Vector tests using Transform Dialect are great reference
points for constructing lowering pipelines. This simplification +
clean-up will hopefully make it easier to follow.


Full diff: https://github.com/llvm/llvm-project/pull/81416.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir (+40-27)
diff --git a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
index ee66073a9a4193..a1a0c413da0c1c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
@@ -1,38 +1,51 @@
 // RUN: mlir-opt %s -transform-interpreter | FileCheck %s
 
-func.func @outerproduct_matmul(%A: memref<3x3xf32>, %B: memref<3x3xf32>, %C: memref<3x3xf32>) {
-  linalg.matmul ins(%A, %B: memref<3x3xf32>, memref<3x3xf32>)
+func.func @matmul_to_outerproduct(%A: memref<3x4xf32>, %B: memref<4x3xf32>, %C: memref<3x3xf32>) {
+  linalg.matmul ins(%A, %B: memref<3x4xf32>, memref<4x3xf32>)
             outs(%C: memref<3x3xf32>)
   return
 }
 
-// CHECK-LABEL:   func.func @outerproduct_matmul(
-// CHECK-SAME:      %[[VAL_0:.*]]: memref<3x3xf32>, %[[VAL_1:.*]]: memref<3x3xf32>, %[[VAL_2:.*]]: memref<3x3xf32>) {
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK:           %[[VAL_8:.*]] = vector.transpose %[[VAL_5]], [1, 0] : vector<3x3xf32> to vector<3x3xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.extract %[[VAL_8]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_11:.*]] = vector.outerproduct %[[VAL_9]], %[[VAL_10]], %[[VAL_7]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_8]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_6]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_14:.*]] = vector.outerproduct %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK:           %[[VAL_15:.*]] = vector.extract %[[VAL_8]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_6]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_17:.*]] = vector.outerproduct %[[VAL_15]], %[[VAL_16]], %[[VAL_14]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK:           vector.transfer_write %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x3xf32>, memref<3x3xf32>
-// CHECK:           return
-// CHECK:         }
+// CHECK-LABEL:   func.func @matmul_to_outerproduct(
+// CHECK-SAME:      %[[A:.*]]: memref<3x4xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<4x3xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x3xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]
+// CHECK:           %[[VEC_A_T:.*]] = vector.transpose %[[VEC_A]], [1, 0] : vector<3x4xf32> to vector<4x3xf32>
+// CHECK:           %[[A0:.*]] = vector.extract %[[VEC_A_T]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B0:.*]] = vector.extract %[[VEC_B]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[OP_0:.*]] = vector.outerproduct %[[A0]], %[[B0]], %[[VEC_C]]
+// CHECK:           %[[A1:.*]] = vector.extract %[[VEC_A_T]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B1:.*]] = vector.extract %[[VEC_B]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[OP_1:.*]] = vector.outerproduct %[[A1]], %[[B1]], %[[OP_0]]
+// CHECK:           %[[A_2:.*]] = vector.extract %[[VEC_A_T]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B_2:.*]] = vector.extract %[[VEC_B]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[OP_2:.*]] = vector.outerproduct %[[A_2]], %[[B_2]], %[[OP_1]]
+// CHECK:           %[[A_3:.*]] = vector.extract %[[VEC_A_T]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B_3:.*]] = vector.extract %[[VEC_B]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[RES:.*]] = vector.outerproduct %[[A_3]], %[[B_3]], %[[OP_2]]
+// CHECK:           vector.transfer_write %[[RES]], %[[C]]{{.*}} : vector<3x3xf32>, memref<3x3xf32>
 
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
-    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %2 {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+
+    // Vectorize: linalg.matmul -> vector.multi_reduction
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %matmul : !transform.any_op
+
+    // vector.multi_reduction --> vector.contract
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This transform vector.contract to be more
+      // more matmul-like and to enable the lowering to outer product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.any_op
+
+    // vector.contract --> vector.outerproduct
+    transform.apply_patterns to %func {
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
     } : !transform.any_op
     transform.yield

@banach-space banach-space merged commit 3b26948 into llvm:main Mar 13, 2024
@banach-space banach-space deleted the andrzej/refactor_matmul_test branch March 16, 2024 19:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants