-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][linalg] Enable fusion by expansion of reduction and named ops #83473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Enable fusion by expansion of reduction and named ops #83473
Conversation
This adds support for expansion of linalg ops with reduction iterators. This improves the ability to make fusion decisions with respect to reduction operations. To recover the previous behavior, users of the patterns can add a control function to restrict propagation of reshape by expansion through linalg ops with reduction iterators.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Quinn Dawkins (qedawkins) ChangesThis adds support for expansion of linalg ops with reduction iterators. This improves the ability to make fusion decisions with respect to reduction operations. To recover the previous behavior, users of the patterns can add a control function to restrict propagation of reshape by expansion through linalg ops with reduction iterators. Full diff: https://github.com/llvm/llvm-project/pull/83473.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4797bfb2267d7f..6310f9105960be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -526,7 +526,10 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
// - All the indexing maps for operands and results are projected
// permutations.
// - The fused tensor is not a scalar.
- // - All the loops are parallel loops.
+ // - All the loops for the reshaped operand are parallel loops.
+ SmallVector<utils::IteratorType> iteratorTypes =
+ genericOp.getIteratorTypesArray();
+ AffineMap operandMap = genericOp.getMatchingIndexingMap(fusableOpOperand);
return genericOp.hasPureTensorSemantics() &&
llvm::all_of(genericOp.getIndexingMaps().getValue(),
[](Attribute attr) {
@@ -534,9 +537,11 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
.getValue()
.isProjectedPermutation();
}) &&
- genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
- 0 &&
- llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
+ operandMap.getNumResults() > 0 &&
+ llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
+ return isParallelIterator(
+ iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
+ });
}
namespace {
@@ -848,6 +853,11 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
// The iterator types of the expanded op are all parallel.
SmallVector<utils::IteratorType> iteratorTypes(
expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+ for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray())) {
+ ReassociationIndicesRef group = expansionInfo.getExpandedDims(i);
+ for (auto i : group)
+ iteratorTypes[i] = type;
+ }
TypeRange resultTypes = ValueRange(outputs).getTypes();
auto fusedOp =
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 0e40b5fbed97cb..5c0a83258b4b95 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -573,3 +573,93 @@ module {
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]], %[[OUTS]] :
// CHECK: return %[[GENERIC]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
+ %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) ->
+ tensor<?x?x4x5xf32>
+{
+ %0 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %s: f32):
+ %1 = arith.mulf %arg3, %arg4 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+ tensor<?x?xf32> into tensor<?x?x4x5xf32>
+ return %1 : tensor<?x?x4x5xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+// CHECK: func @generic_op_reshape_consumer_fusion_reduction
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME: [0, 1, 2], [3]
+// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME: [0], [1, 2, 3]
+// CHECK-SAME: tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[T3:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x?xf32>, tensor<?x4x5x?xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<?x?x4x5xf32>)
+// CHECK: return %[[T3]] : tensor<?x?x4x5xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x?x8xf32>,
+ %arg1 : tensor<?x4x?xf32>,
+ %arg2 : tensor<?x?xf32>) ->
+ tensor<?x?xf32>
+{
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+ tensor<?x7x?x8xf32> into tensor<?x?xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x4x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %1 = arith.mulf %arg3, %arg4 : f32
+ %2 = arith.addf %1, %arg5 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d0, d1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+// CHECK: func @generic_op_reshape_producer_fusion_with_reduction
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME: [0, 1], [2], [3, 4]
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK: %[[T3:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x8x4x?x7xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<?x8x?x7xf32>)
+// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
+// CHECK: return %[[T4]]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really follow these logics these days, so I suggest to wait for Mahesh.
sounds good, thanks for taking a look! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is OK to land, but you can change the utility method here to work on any LinalgOp
. I suspect that for your use case you are generalizing the op before applying the reshape. We can skip that step.
Done. The same should probably be done for the collapsing patterns, but is that fine to do as a follow up? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice ! Thanks!
This adds support for expansion of named linalg ops and linalg ops with reduction iterators. This improves the ability to make fusion decisions with respect to reduction operations. To recover the previous behavior, users of the patterns can add a control function to restrict propagation of reshape by expansion through linalg ops with reduction iterators.
For named linalg ops, this always converts the named op into a generic.