Skip to content

[mlir][Linalg] use linalg.reduce to simplify the mergeReductions in partialReductionInterface #94579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 4 additions & 33 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,39 +447,10 @@ struct LinalgOpPartialReductionInterface
Location loc, ValueRange partialReduce,
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);

// Step 1. Recover the dims that actually need to be merged from the
// original operation. We can classify the original iterators as follows:
//
// parallel --> parallel
// reduction + not in reductionDims --> parallel (already reduced)
// reduction + in reductionDims --> reduction (will reduce now)
SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
utils::IteratorType::parallel);
for (int redIdx : reductionDims)
iterators[redIdx] = utils::IteratorType::reduction;

// Step 2. For each partial result, create a map to index it. This map
// is simply the indexing map for the original result with reductionDims
// appended (as produced in tileToPartialReduction).
int64_t numInits = linalgOp.getNumDpsInits();
SmallVector<AffineMap> indexingMaps(numInits * 2);
for (int idx : llvm::seq<int>(0, numInits)) {
AffineMap &inputMap = indexingMaps[idx];
AffineMap &outputMap = indexingMaps[numInits + idx];

outputMap =
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
inputMap = outputMap;
for (int redPos : reductionDims) {
inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
inputMap.getNumResults());
}
}

auto reduction = b.create<GenericOp>(
loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
indexingMaps, iterators,
SmallVector<int64_t> reductionDimsInt64(reductionDims.begin(),
reductionDims.end());
auto reduction = b.create<linalg::ReduceOp>(
loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
[&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
int64_t numInits = linalgOp.getNumDpsInits();
SmallVector<Value> yieldedValues;
Expand Down
37 changes: 17 additions & 20 deletions mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ module attributes {transform.with_named_sequence} {
}
}

// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
Expand All @@ -37,21 +36,21 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
// CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
// CHECK: %[[PS:.*]] = affine.min #[[MAP2]](%[[K]])[%[[D1]]]
// CHECK: %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
// CHECK: %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K:.*]]] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
// CHECK: } -> tensor<?x?xf32>
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
// CHECK: scf.yield %[[INS]] : tensor<?x5xf32>
// CHECK: }
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
// CHECK: arith.addf
// CHECK: linalg.yield
// CHECK: } -> tensor<?xf32>
// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>

// -----
Expand Down Expand Up @@ -81,7 +80,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: func @reduction_tile_transpose
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
Expand All @@ -91,7 +89,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
// CHECK: }
// CHECK: linalg.generic
// CHECK: linalg.reduce
// CHECK: return

// -----
Expand Down Expand Up @@ -150,10 +148,11 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
// CHECK: }
// CHECK: }
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
// CHECK: {
// CHECK: arith.addf
// CHECK: linalg.yield
// CHECK: } -> tensor<?xf32>
// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>

// -----
Expand All @@ -177,8 +176,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
Expand All @@ -203,10 +200,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
// CHECK: }
// CHECK: }
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) {
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) dimensions = [2]
// CHECK: arith.addf
// CHECK: linalg.yield
// CHECK: } -> tensor<?x?xf32>
// CHECK: }
// CHECK: return %[[R]] : tensor<?x?xf32>

// -----
Expand Down Expand Up @@ -270,10 +267,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
// CHECK: }
// CHECK: }
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
// CHECK: arith.addf
// CHECK: linalg.yield
// CHECK: } -> tensor<?xf32>
// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>

// -----
Expand Down Expand Up @@ -307,7 +304,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: iterator_types = ["parallel", "reduction"]
transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op
// CHECK: expecting parallel reduction
// CHECK-NEXT: linalg.generic
// CHECK-NEXT: linalg.reduce
// CHECK: iterator_types = ["parallel", "reduction"]
transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op
transform.yield
Expand Down Expand Up @@ -401,7 +398,7 @@ module {
// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
// CHECK: scf.yield %[[OUT]] : tensor<4096x2x64xf32>
// CHECK: scf.yield %[[L1]] : tensor<4096x2x64xf32>
// CHECK: %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
// CHECK: %[[OUT2:.*]] = linalg.reduce ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
// CHECK: return %[[OUT2]] : tensor<4096xf32>

// -----
Expand Down Expand Up @@ -445,6 +442,6 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[UPDATED]]#0 into %[[SUM]]
// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[UPDATED]]#1 into %[[MAX]]
// CHECK: scf.yield %[[INSERT1]], %[[INSERT1]]
// CHECK: linalg.generic
// CHECK: linalg.reduce
// CHECK: arith.addf
// CHECK: arith.maximumf
Loading