[mlir][linalg] Enable fusion by expansion of reduction and named ops #83473

Merged

Conversation


@qedawkins qedawkins commented Feb 29, 2024

This adds support for expansion of named linalg ops and linalg ops with reduction iterators. This improves the ability to make fusion decisions with respect to reduction operations. To recover the previous behavior, users of the patterns can add a control function to restrict propagation of reshape by expansion through linalg ops with reduction iterators.

For named linalg ops, this always converts the named op into a generic.
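
As a rough illustration of the control-function escape hatch mentioned above, here is a minimal sketch that restores the previous behavior, assuming the existing ControlFusionFn hook that these patterns already accept (the predicate and function names are illustrative, not part of this patch):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"

using namespace mlir;

// Sketch: reject reshape-by-expansion propagation whenever the linalg op
// involved has a non-parallel (e.g. reduction) iterator, which recovers the
// pre-patch behavior.
static bool onlyParallelLinalgOps(OpOperand *fusedOperand) {
  auto allParallel = [](Operation *op) {
    auto linalgOp = llvm::dyn_cast_or_null<linalg::LinalgOp>(op);
    return !linalgOp || llvm::all_of(linalgOp.getIteratorTypesArray(),
                                     linalg::isParallelIterator);
  };
  // The linalg op can sit on either side of the reshape being propagated,
  // so check both the owner and the producer of the fused operand.
  return allParallel(fusedOperand->getOwner()) &&
         allParallel(fusedOperand->get().getDefiningOp());
}

static void addExpansionPatterns(RewritePatternSet &patterns) {
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns,
                                                    onlyParallelLinalgOps);
}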

llvmbot commented Feb 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Quinn Dawkins (qedawkins)

Changes

This adds support for expansion of linalg ops with reduction iterators. This improves the ability to make fusion decisions with respect to reduction operations. To recover the previous behavior, users of the patterns can add a control function to restrict propagation of reshape by expansion through linalg ops with reduction iterators.


Full diff: https://github.com/llvm/llvm-project/pull/83473.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+14-4)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+90)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4797bfb2267d7f..6310f9105960be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -526,7 +526,10 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops are parallel loops.
+  // - All the loops for the reshaped operand are parallel loops.
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  AffineMap operandMap = genericOp.getMatchingIndexingMap(fusableOpOperand);
   return genericOp.hasPureTensorSemantics() &&
          llvm::all_of(genericOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
@@ -534,9 +537,11 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
-             0 &&
-         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
+         operandMap.getNumResults() > 0 &&
+         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
+           return isParallelIterator(
+               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
+         });
 }
 
 namespace {
@@ -848,6 +853,11 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // The iterator types of the expanded op are all parallel.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+  for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray())) {
+    ReassociationIndicesRef group = expansionInfo.getExpandedDims(i);
+    for (auto i : group)
+      iteratorTypes[i] = type;
+  }
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 0e40b5fbed97cb..5c0a83258b4b95 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -573,3 +573,93 @@ module {
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%[[ARG2]], %[[OUTS]] :
 //      CHECK:   return %[[GENERIC]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
+                                                        %arg1 : tensor<?x?xf32>,
+                                                        %arg2 : tensor<?x?xf32>) ->
+                                                        tensor<?x?x4x5xf32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %s: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+//      CHECK: func @generic_op_reshape_consumer_fusion_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1, 2], [3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x?xf32>, tensor<?x4x5x?xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
+//      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x?x8xf32>,
+                                         %arg1 : tensor<?x4x?xf32>,
+                                         %arg2 : tensor<?x?xf32>) ->
+                                         tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+       ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x4x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      %2 = arith.addf %1, %arg5 : f32
+      linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d0, d1)>
+//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//  CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+//      CHECK: func @generic_op_reshape_producer_fusion_with_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1], [2], [3, 4]
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "reduction", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x8x4x?x7xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x8x?x7xf32>)
+//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T4]]

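Concretely, in the first new test above the consumer tensor.expand_shape splits the second result dimension of the generic into three ([[0], [1, 2, 3]]), so the original parallel dims d0 and d1 become the four parallel dims d0..d3 of the fused op, while the reduction dim d2, which appears only in the operand maps, is carried over unexpanded as d4. This is exactly the ["parallel", "parallel", "parallel", "parallel", "reduction"] iterator list in the CHECK lines, and it is what the new iterator-propagation loop in fuseWithReshapeByExpansion computes.
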

@hanhanW hanhanW left a comment

I don't really follow this logic these days, so I suggest waiting for Mahesh.

@qedawkins

Sounds good, thanks for taking a look!


@MaheshRavishankar MaheshRavishankar left a comment

This is OK to land, but you can change the utility method here to work on any LinalgOp. I suspect that for your use case you are generalizing the op before applying the reshape; we can skip that step.

@qedawkins qedawkins changed the title [mlir][linalg] Enable expansion of parallel dims of reduction ops [mlir][linalg] Enable fusion by expansion of reduction and named ops Mar 3, 2024

qedawkins commented Mar 3, 2024

This is OK to land, but you can change the utility method here to work on any LinalgOp. I suspect that for your use case you are generalizing the op before applying the reshape; we can skip that step.

Done. The same should probably be done for the collapsing patterns, but is that fine to do as a follow-up?
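
For reference, the widened check would look roughly like this, a sketch that ports the diff above from GenericOp to the LinalgOp interface (the landed patch may differ in detail; named ops are still converted to a generic when the expansion is applied):

// Sketch: the same fusability check as in the diff, but on the LinalgOp
// interface so that named ops qualify without being generalized first. The
// interface provides the same iterator and indexing-map queries.
static bool isFusableWithReshapeByDimExpansion(linalg::LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMapsArray(),
                      [](AffineMap map) {
                        return map.isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0 &&
         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
           return linalg::isParallelIterator(
               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
         });
}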


@MaheshRavishankar MaheshRavishankar left a comment

Nice! Thanks!

@qedawkins qedawkins merged commit 3f18f6a into llvm:main Mar 3, 2024
@qedawkins qedawkins deleted the fusion_through_reshape_with_reduction_dims branch March 3, 2024 06:59