Commit eec9d0b
[mlir][Linalg] use linalg.reduce to simplify the mergeReductions in partialReductionInterface (#94579)
The current implementation of `mergeReduction` in `LinalgOpPartialReductionInterface` builds a `linalg.generic` from scratch. Since we already have the `linalg.reduce` op, which has the same semantics as that generic, this PR replaces the hand-built generic with `linalg.reduce` to simplify the implementation.
1 parent ca06b61 commit eec9d0b
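
As an illustration of the equivalence (a sketch assembled from the test updates below, not verbatim compiler output), merging a tensor<?x5xf32> partial result into a tensor<?xf32> along dimension 1 was previously emitted as a generic of the form

    %merged = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0)>],
        iterator_types = ["parallel", "reduction"]}
        ins(%partial : tensor<?x5xf32>) outs(%init : tensor<?xf32>) {
      ^bb0(%in: f32, %out: f32):
        %sum = arith.addf %in, %out : f32
        linalg.yield %sum : f32
    } -> tensor<?xf32>

and is now emitted as the equivalent named op

    %merged = linalg.reduce
        ins(%partial : tensor<?x5xf32>)
        outs(%init : tensor<?xf32>)
        dimensions = [1]
        (%in: f32, %out: f32) {
          %sum = arith.addf %in, %out : f32
          linalg.yield %sum : f32
        }

(%partial, %init, and %merged are placeholder names; the dimensions attribute is built from reductionDims, here [1].)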

2 files changed, 21 insertions(+), 53 deletions(-)

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 4 additions & 33 deletions
@@ -447,39 +447,10 @@ struct LinalgOpPartialReductionInterface
                                          Location loc, ValueRange partialReduce,
                                          ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-
-    // Step 1. Recover the dims that actually need to be merged from the
-    // original operation. We can classify the original iterators as follows:
-    //
-    // parallel                         --> parallel
-    // reduction + not in reductionDims --> parallel (already reduced)
-    // reduction + in reductionDims     --> reduction (will reduce now)
-    SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
-                                               utils::IteratorType::parallel);
-    for (int redIdx : reductionDims)
-      iterators[redIdx] = utils::IteratorType::reduction;
-
-    // Step 2. For each partial result, create a map to index it. This map
-    // is simply the indexing map for the original result with reductionDims
-    // appended (as produced in tileToPartialReduction).
-    int64_t numInits = linalgOp.getNumDpsInits();
-    SmallVector<AffineMap> indexingMaps(numInits * 2);
-    for (int idx : llvm::seq<int>(0, numInits)) {
-      AffineMap &inputMap = indexingMaps[idx];
-      AffineMap &outputMap = indexingMaps[numInits + idx];
-
-      outputMap =
-          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-      inputMap = outputMap;
-      for (int redPos : reductionDims) {
-        inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
-                                         inputMap.getNumResults());
-      }
-    }
-
-    auto reduction = b.create<GenericOp>(
-        loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
-        indexingMaps, iterators,
+    SmallVector<int64_t> reductionDimsInt64(reductionDims.begin(),
+                                            reductionDims.end());
+    auto reduction = b.create<linalg::ReduceOp>(
+        loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
         [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
           int64_t numInits = linalgOp.getNumDpsInits();
           SmallVector<Value> yieldedValues;

mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Lines changed: 17 additions & 20 deletions
@@ -23,9 +23,8 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
 // CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
@@ -37,21 +36,21 @@ module attributes {transform.with_named_sequence} {
 // CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
-// CHECK: %[[PS:.*]] = affine.min #[[MAP2]](%[[K]])[%[[D1]]]
+// CHECK: %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
 // CHECK: %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K:.*]]] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
-// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
+// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
 // CHECK: arith.mulf
 // CHECK: arith.addf
 // CHECK: linalg.yield
 // CHECK: } -> tensor<?x?xf32>
 // CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
 // CHECK: scf.yield %[[INS]] : tensor<?x5xf32>
 // CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
 // CHECK: arith.addf
 // CHECK: linalg.yield
-// CHECK: } -> tensor<?xf32>
+// CHECK: }
 // CHECK: return %[[R]] : tensor<?xf32>
 
 // -----
@@ -81,7 +80,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK: func @reduction_tile_transpose
 // CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
 // CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
@@ -91,7 +89,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
 // CHECK: scf.yield {{.*}} : tensor<5x?xf32>
 // CHECK: }
-// CHECK: linalg.generic
+// CHECK: linalg.reduce
 // CHECK: return
 
 // -----
@@ -150,10 +148,11 @@ module attributes {transform.with_named_sequence} {
 // CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
 // CHECK: }
 // CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
+// CHECK: {
 // CHECK: arith.addf
 // CHECK: linalg.yield
-// CHECK: } -> tensor<?xf32>
+// CHECK: }
 // CHECK: return %[[R]] : tensor<?xf32>
 
 // -----
@@ -177,8 +176,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
 // CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -203,10 +200,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
 // CHECK: }
 // CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) dimensions = [2]
 // CHECK: arith.addf
 // CHECK: linalg.yield
-// CHECK: } -> tensor<?x?xf32>
+// CHECK: }
 // CHECK: return %[[R]] : tensor<?x?xf32>
 
 // -----
@@ -270,10 +267,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK: tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
 // CHECK: }
 // CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
 // CHECK: arith.addf
 // CHECK: linalg.yield
-// CHECK: } -> tensor<?xf32>
+// CHECK: }
 // CHECK: return %[[R]] : tensor<?xf32>
 
 // -----
@@ -307,7 +304,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK: iterator_types = ["parallel", "reduction"]
     transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op
 // CHECK: expecting parallel reduction
-// CHECK-NEXT: linalg.generic
+// CHECK-NEXT: linalg.reduce
 // CHECK: iterator_types = ["parallel", "reduction"]
     transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op
     transform.yield
@@ -401,7 +398,7 @@ module {
 // CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
 // CHECK: scf.yield %[[OUT]] : tensor<4096x2x64xf32>
 // CHECK: scf.yield %[[L1]] : tensor<4096x2x64xf32>
-// CHECK: %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
+// CHECK: %[[OUT2:.*]] = linalg.reduce ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
 // CHECK: return %[[OUT2]] : tensor<4096xf32>
 
 // -----
@@ -445,6 +442,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[UPDATED]]#0 into %[[SUM]]
 // CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[UPDATED]]#1 into %[[MAX]]
 // CHECK: scf.yield %[[INSERT1]], %[[INSERT1]]
-// CHECK: linalg.generic
+// CHECK: linalg.reduce
 // CHECK: arith.addf
 // CHECK: arith.maximumf
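
The last test exercises the multi-result merge path (a sum and a max reduced together). The CHECK lines above match a single two-result linalg.reduce; a hypothetical op of that shape, with illustrative operand names and shapes not taken from the test itself, would look like:

    %merged:2 = linalg.reduce
        ins(%partial_sum, %partial_max : tensor<?x5xf32>, tensor<?x5xf32>)
        outs(%sum_init, %max_init : tensor<?xf32>, tensor<?xf32>)
        dimensions = [1]
        (%in0: f32, %in1: f32, %acc0: f32, %acc1: f32) {
          %s = arith.addf %in0, %acc0 : f32
          %m = arith.maximumf %in1, %acc1 : f32
          linalg.yield %s, %m : f32, f32
        }

(Region arguments list all inputs first, then all accumulators, matching the combiner the interface builds for each init.)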
