Skip to content

[mlir][Linalg] use linalg.reduce to simplify the mergeReductions in partialReductionInterface #94579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 28, 2024

Conversation

zhczhong
Copy link
Member

@zhczhong zhczhong commented Jun 6, 2024

The current implementation of mergeReduction in LinalgOpPartialReductionInterface builds a linalg.generic from scratch. While we already have linalg.reduce op which has the same semantic as this generic op, this PR replaces the generic op with linalg.reduce to simplify the implementation.

Copy link

github-actions bot commented Jun 6, 2024

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jun 6, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: zhicong zhong (zhczhong)

Changes

The current implementation of mergeReduction in LinalgOpPartialReductionInterface builds a linalg.generic from scratch. While we already have linalg.reduce op which has the same semantic as this generic op. This PR replaces the generic op with linalg.reduce to simplify the implementation.


Full diff: https://github.com/llvm/llvm-project/pull/94579.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+5-33)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-reduction.mlir (+17-20)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..c038d03c15342 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -443,40 +443,12 @@ struct LinalgOpPartialReductionInterface
   Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
                              ValueRange partialReduce,
                              ArrayRef<int> reductionDims) const {
-    auto linalgOp = cast<LinalgOp>(op);
 
-    // Step 1. Recover the dims that actually need to be merged from the
-    // original operation. We can classify the original iterators as follows:
-    //
-    // parallel                         --> parallel
-    // reduction + not in reductionDims --> parallel (already reduced)
-    // reduction + in reductionDims     --> reduction (will reduce now)
-    SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
-                                               utils::IteratorType::parallel);
-    for (int redIdx : reductionDims)
-      iterators[redIdx] = utils::IteratorType::reduction;
-
-    // Step 2. For each partial result, create a map to index it. This map
-    // is simply the indexing map for the original result with reductionDims
-    // appended (as produced in tileToPartialReduction).
-    int64_t numInits = linalgOp.getNumDpsInits();
-    SmallVector<AffineMap> indexingMaps(numInits * 2);
-    for (int idx : llvm::seq<int>(0, numInits)) {
-      AffineMap &inputMap = indexingMaps[idx];
-      AffineMap &outputMap = indexingMaps[numInits + idx];
-
-      outputMap =
-          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-      inputMap = outputMap;
-      for (int redPos : reductionDims) {
-        inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
-                                         inputMap.getNumResults());
-      }
-    }
-
-    auto reduction = b.create<GenericOp>(
-        loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
-        indexingMaps, iterators,
+    auto linalgOp = cast<LinalgOp>(op);
+    SmallVector<int64_t> reductionDimsInt64(reductionDims.begin(),
+                                            reductionDims.end());
+    auto reduction = b.create<linalg::ReduceOp>(
+        loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
         [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
           int64_t numInits = linalgOp.getNumDpsInits();
           SmallVector<Value> yieldedValues;
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index f3cf7c4dffa05..4a8bb42676fdb 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -23,9 +23,8 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 //     CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
 // CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
@@ -37,10 +36,10 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
-//     CHECK:     %[[PS:.*]] = affine.min #[[MAP2]](%[[K]])[%[[D1]]]
+//     CHECK:     %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
 //     CHECK:     %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K:.*]]] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 //     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
-//     CHECK:     %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
+//     CHECK:     %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
 //     CHECK:       arith.mulf
 //     CHECK:       arith.addf
 //     CHECK:       linalg.yield
@@ -48,10 +47,10 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
 //     CHECK:     scf.yield %[[INS]] : tensor<?x5xf32>
 //     CHECK:   }
-//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
 //     CHECK:     arith.addf
 //     CHECK:     linalg.yield
-//     CHECK:   } -> tensor<?xf32>
+//     CHECK:   }
 //     CHECK:   return %[[R]] : tensor<?xf32>
 
 // -----
@@ -81,7 +80,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
 //     CHECK: func @reduction_tile_transpose
 //     CHECK:   tensor.empty(%{{.*}}) : tensor<5x?xf32>
 //     CHECK:   linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
@@ -91,7 +89,7 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
 //     CHECK:     scf.yield {{.*}} : tensor<5x?xf32>
 //     CHECK:   }
-//     CHECK:   linalg.generic
+//     CHECK:   linalg.reduce
 //     CHECK:   return
 
 // -----
@@ -150,10 +148,11 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
 //     CHECK:     }
 //     CHECK:   }
-//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
+//     CHECK:   {
 //     CHECK:     arith.addf
 //     CHECK:     linalg.yield
-//     CHECK:   } -> tensor<?xf32>
+//     CHECK:   }
 //     CHECK:   return %[[R]] : tensor<?xf32>
 
 // -----
@@ -177,8 +176,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 //     CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
 // CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
@@ -203,10 +200,10 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
 //     CHECK:     }
 //     CHECK:   }
-//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) {
+//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) dimensions = [2]
 //     CHECK:     arith.addf
 //     CHECK:     linalg.yield
-//     CHECK:   } -> tensor<?x?xf32>
+//     CHECK:   }
 //     CHECK:   return %[[R]] : tensor<?x?xf32>
 
 // -----
@@ -270,10 +267,10 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:       tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
 //     CHECK:     }
 //     CHECK:   }
-//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
 //     CHECK:     arith.addf
 //     CHECK:     linalg.yield
-//     CHECK:   } -> tensor<?xf32>
+//     CHECK:   }
 //     CHECK:   return %[[R]] : tensor<?xf32>
 
 // -----
@@ -307,7 +304,7 @@ module attributes {transform.with_named_sequence} {
     //      CHECK:     iterator_types = ["parallel", "reduction"]
     transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op
     //      CHECK:     expecting parallel reduction
-    // CHECK-NEXT:     linalg.generic
+    // CHECK-NEXT:     linalg.reduce
     //      CHECK:     iterator_types = ["parallel", "reduction"]
     transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op
     transform.yield
@@ -402,7 +399,7 @@ module {
 // CHECK:       %[[OUT:.*]] = linalg.generic  {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
 // CHECK:       scf.yield %[[OUT]] : tensor<4096x2x64xf32>
 // CHECK:     scf.yield %[[L1]] : tensor<4096x2x64xf32>
-// CHECK:   %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
+// CHECK:   %[[OUT2:.*]] = linalg.reduce ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
 // CHECK:  return %[[OUT2]] : tensor<4096xf32>
 
 // -----
@@ -446,6 +443,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       %[[INSERT1:.+]] = tensor.insert_slice %[[UPDATED]]#0 into %[[SUM]]
 // CHECK:       %[[INSERT2:.+]] = tensor.insert_slice %[[UPDATED]]#1 into %[[MAX]]
 // CHECK:       scf.yield %[[INSERT1]], %[[INSERT1]]
-// CHECK:       linalg.generic
+// CHECK:       linalg.reduce
 // CHECK:         arith.addf
 // CHECK:         arith.maximumf

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just have a question from downstream use of this. If I run generalization of the linalg.reduce op do we get back the same linalg.generic generated?

@zhczhong
Copy link
Member Author

zhczhong commented Jun 7, 2024

I just have a question from downstream use of this. If I run generalization of the linalg.reduce op do we get back the same linalg.generic generated?

Yes, the linalg generalization can convert the linalg.reduce to linalg.generic in the same form as the original implementation.

func.func @test(%input: tensor<16x32x64xf32>,
                  %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
  %reduce = linalg.reduce
      ins(%input:tensor<16x32x64xf32>)
      outs(%init:tensor<16x64xf32>)
      dimensions = [1]
      (%in: f32, %out: f32) {
        %0 = arith.addf %out, %in: f32
        linalg.yield %0: f32
      }
  func.return %reduce : tensor<16x64xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.generalize %0 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

will be converted to

#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @test(%arg0: tensor<16x32x64xf32>, %arg1: tensor<16x64xf32>) -> tensor<16x64xf32> {
  %0 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<16x32x64xfarg0 : tensor<16x32x64xf32>) outs(%arg1 : tensor<16x64xf32>) {                                                       
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %out, %in : f32
    linalg.yield %1 : f32
  } -> tensor<16x64xf32>
  return %0 : tensor<16x64xf32>
}

@ftynse ftynse requested a review from nicolasvasilache June 7, 2024 09:21
@zhczhong
Copy link
Member Author

Do you have any further comments on this PR? @MaheshRavishankar @nicolasvasilache @rengolin @ZhennanQin

@zhczhong zhczhong force-pushed the simplify_partial_interface branch from d20cccf to 7833b51 Compare June 19, 2024 01:28
@MaheshRavishankar
Copy link
Contributor

@qedawkins could you PTAL. I think this is fine, but you have more hands on experience with this. In any case if we run into issues we could find and generalize the reduction away.

@zhczhong zhczhong merged commit eec9d0b into llvm:main Jun 28, 2024
7 checks passed
Copy link

@zhczhong Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested
by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as
the builds can include changes from many authors. It is not uncommon for your
change to be included in a build that fails due to someone else's changes, or
infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself.
This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
…artialReductionInterface (llvm#94579)

The current implementation of `mergeReduction` in
`LinalgOpPartialReductionInterface` builds a `linalg.generic` from
scratch. While we already have `linalg.reduce` op which has the same
semantic as this generic op, this PR replaces the generic op with
`linalg.reduce` to simplify the implementation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants