Skip to content

Commit eb580db

Browse files
MaheshRavishankarAlexisPerry
authored and committed
[mlir][TilingInterface] Update PartialReductionOpInterface to get it more in line with TilingInterface. (llvm#95460)
The `TilingInterface` methods have return values that allow the interface implementation to return multiple operations, and also to return the tiled values explicitly. This avoids the assumption that the interface needs to return a single operation and that this operation's results are the expected tiled values. Make `PartialReductionOpInterface::tileToPartialReduction` return a `TilingResult` as well, for the same reason. Similarly, make `PartialReductionOpInterface::mergeReductions` return a list of generated operations and the values to use as replacements. This is just a refactoring to allow for the deprecation of `linalg::tileReductionUsingForall` in favor of the `scf::tileReductionUsingScf` method.
1 parent c27b9d5 commit eb580db

File tree

8 files changed

+84
-47
lines changed

8 files changed

+84
-47
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -873,9 +873,9 @@ tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
873873
/// Transformation information returned after reduction tiling.
874874
struct ForallReductionTilingResult {
875875
/// The partial reduction tiled ops generated.
876-
Operation *parallelTiledOp;
876+
SmallVector<Operation *> parallelTiledOps;
877877
/// The final reduction operations merging all the partial reductions.
878-
Operation *mergeOp;
878+
SmallVector<Operation *> mergeOps;
879879
/// Initial values used for partial reductions.
880880
SmallVector<Value> initialValues;
881881
/// The `scf.forall` operation that iterates over the tiles.

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,13 +261,15 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
261261
/// Transformation information returned after reduction tiling.
262262
struct SCFReductionTilingResult {
263263
/// The partial reduction tiled ops generated.
264-
Operation *parallelTiledOp;
264+
SmallVector<Operation *> parallelTiledOps;
265265
/// The final reduction operations merging all the partial reductions.
266-
Operation *mergeOp;
266+
SmallVector<Operation *> mergeOps;
267267
/// Initial values used for reduction.
268268
SmallVector<Value> initialValues;
269269
/// The loop operations that iterate over the tiles.
270270
SmallVector<LoopLikeOpInterface> loops;
271+
/// The replacements to use for the results of the tiled operation.
272+
SmallVector<Value> replacements;
271273
};
272274

273275
/// Method to tile a reduction and generate a parallel op within a serial loop.

mlir/include/mlir/Interfaces/TilingInterface.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ struct TilingResult {
3333
SmallVector<Value> tiledValues;
3434
};
3535

36+
/// Container for the result of the merge operation of tiling.
37+
/// - `mergeOps` contains operations created during the merge.
38+
/// - `replacements` contains the values that represent the result of the
39+
/// merge. These are used as replacements for the original tiled operation.
40+
struct MergeResult {
41+
SmallVector<Operation *> mergeOps;
42+
SmallVector<Value> replacements;
43+
};
44+
3645
} // namespace mlir
3746

3847
/// Include the ODS generated interface header files.

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
360360
less or equal to the tile size. This is meant to be used with
361361
`mergeReductions` method which will combine the partial reductions.
362362
}],
363-
/*retType=*/"Operation*",
363+
/*retType=*/"FailureOr<TilingResult>",
364364
/*methodName=*/"tileToPartialReduction",
365365
/*args=*/(ins
366366
"OpBuilder &":$b,
@@ -371,7 +371,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
371371
"ArrayRef<int>":$reductionDims),
372372
/*methodBody=*/"",
373373
/*defaultImplementation=*/[{
374-
return nullptr;
374+
return failure();
375375
}]
376376
>,
377377
InterfaceMethod<
@@ -380,7 +380,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
380380
tiled along the reduction dimensions. This will only apply the
381381
reduction the operation.
382382
}],
383-
/*retType=*/"Operation*",
383+
/*retType=*/"FailureOr<MergeResult>",
384384
/*methodName=*/"mergeReductions",
385385
/*args=*/(ins
386386
"OpBuilder &":$b,
@@ -389,7 +389,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
389389
"ArrayRef<int>":$reductionDim),
390390
/*methodBody=*/"",
391391
/*defaultImplementation=*/[{
392-
return nullptr;
392+
return failure();
393393
}]
394394
>
395395
];

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2525,8 +2525,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
25252525
return emitDefaultSilenceableFailure(target);
25262526
for (Value initValue : result->initialValues)
25272527
results.push_back(initValue.getDefiningOp());
2528-
results.push_back(result->parallelTiledOp);
2529-
results.push_back(result->mergeOp);
2528+
for (auto parallelTiledOp : result->parallelTiledOps)
2529+
results.push_back(parallelTiledOp);
2530+
for (auto mergeOp : result->mergeOps)
2531+
results.push_back(mergeOp);
25302532
results.push_back(result->loops.front());
25312533
return DiagnosedSilenceableFailure::success();
25322534
}
@@ -2577,8 +2579,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
25772579
}
25782580
for (Value initValue : result->initialValues)
25792581
results.push_back(initValue.getDefiningOp());
2580-
results.push_back(result->parallelTiledOp);
2581-
results.push_back(result->mergeOp);
2582+
for (auto parallelTiledOp : result->parallelTiledOps)
2583+
results.push_back(parallelTiledOp);
2584+
for (auto mergeOp : result->mergeOps)
2585+
results.push_back(mergeOp);
25822586
results.push_back(result->loops);
25832587
return DiagnosedSilenceableFailure::success();
25842588
}

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,16 +833,19 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
833833

834834
// 7. Merge the partial reductions.
835835
b.setInsertionPointAfter(forallOp);
836-
Operation *mergeOp =
836+
FailureOr<MergeResult> mergeResult =
837837
op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
838-
b.replaceOp(op, mergeOp->getResults());
838+
if (failed(mergeResult)) {
839+
return failure();
840+
}
841+
b.replaceOp(op, mergeResult->replacements);
839842

840843
// 8. Return.
841844
ForallReductionTilingResult results;
842845
results.initialValues = initTensors;
843846
results.loops = forallOp;
844-
results.parallelTiledOp = tiledOp;
845-
results.mergeOp = mergeOp;
847+
results.parallelTiledOps.push_back(tiledOp);
848+
results.mergeOps.append(mergeResult->mergeOps);
846849
return results;
847850
}
848851

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -368,11 +368,11 @@ struct LinalgOpPartialReductionInterface
368368
return inits;
369369
}
370370

371-
Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
372-
ValueRange init,
373-
ArrayRef<OpFoldResult> offsets,
374-
ArrayRef<OpFoldResult> sizes,
375-
ArrayRef<int> reductionDims) const {
371+
FailureOr<TilingResult>
372+
tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
373+
ValueRange init, ArrayRef<OpFoldResult> offsets,
374+
ArrayRef<OpFoldResult> sizes,
375+
ArrayRef<int> reductionDims) const {
376376
OpBuilder::InsertionGuard guard(b);
377377
auto linalgOp = cast<LinalgOp>(op);
378378

@@ -437,12 +437,15 @@ struct LinalgOpPartialReductionInterface
437437
IRMapping mapping;
438438
op->getRegion(0).cloneInto(&genericOp.getRegion(),
439439
genericOp.getRegion().begin(), mapping);
440-
return genericOp.getOperation();
440+
return TilingResult{
441+
{genericOp.getOperation()},
442+
llvm::map_to_vector(genericOp->getResults(),
443+
[](OpResult r) -> Value { return r; })};
441444
}
442445

443-
Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
444-
ValueRange partialReduce,
445-
ArrayRef<int> reductionDims) const {
446+
FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
447+
Location loc, ValueRange partialReduce,
448+
ArrayRef<int> reductionDims) const {
446449
auto linalgOp = cast<LinalgOp>(op);
447450

448451
// Step 1. Recover the dims that actually need to be merged from the
@@ -493,7 +496,10 @@ struct LinalgOpPartialReductionInterface
493496
}
494497
b.create<linalg::YieldOp>(loc, yieldedValues);
495498
});
496-
return reduction.getOperation();
499+
return MergeResult{
500+
{reduction.getOperation()},
501+
llvm::map_to_vector(reduction->getResults(),
502+
[](OpResult r) -> Value { return r; })};
497503
}
498504
};
499505

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
718718
SmallVector<Value> &initTensors = maybeInitTensors.value();
719719

720720
// 3. Define the callback to use for generating the inner most tile loop body.
721-
Operation *parallelOp = nullptr;
721+
SmallVector<Operation *> parallelTiledOps;
722722
auto innerYieldTiledValuesFn =
723723
[&](RewriterBase &rewriter, Location loc, ValueRange ivs,
724724
ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
@@ -743,26 +743,33 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
743743
}
744744

745745
// 4a. Clone the operation.
746-
auto clonedOp = cast<PartialReductionOpInterface>(
747-
cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
746+
{
747+
auto clonedOp = cast<PartialReductionOpInterface>(
748+
cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
749+
750+
// 4b. Tile the cloned operation.
751+
FailureOr<TilingResult> partialTilingResult =
752+
clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
753+
sizes, reductionDims);
754+
if (failed(partialTilingResult)) {
755+
return failure();
756+
}
757+
std::swap(parallelTiledOps, partialTilingResult->tiledOps);
758+
std::swap(tiledResult, partialTilingResult->tiledValues);
748759

749-
// 4b. Tile the cloned operation.
750-
parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
751-
offsets, sizes, reductionDims);
752-
// 4c. Delete the cloned operation.
753-
b.eraseOp(clonedOp);
760+
// 4c. Delete the cloned operation.
761+
b.eraseOp(clonedOp);
762+
}
754763

755-
tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
756764
// 4d. Compute the offsets and sizes needed to insert the result of the
757765
// tiled value back into destination before yielding the destination.
758-
for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) {
766+
for (auto result : tiledResult) {
759767
SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
760768
resultOffsets.emplace_back(std::move(outOffsets));
761769

762770
SmallVector<OpFoldResult> outSizes;
763771
for (size_t i = 0; i < offsets.size(); i++) {
764-
outSizes.push_back(
765-
tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i));
772+
outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
766773
}
767774
resultSizes.emplace_back(std::move(outSizes));
768775
}
@@ -782,15 +789,21 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
782789

783790
// 5. Apply the merge reduction to combine all the partial values.
784791
b.setInsertionPointAfter(*loops.begin());
785-
Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
786-
b.replaceOp(op, mergeOp->getResults());
787-
788-
SCFReductionTilingResult results;
789-
results.initialValues = initTensors;
790-
results.loops = loops;
791-
results.parallelTiledOp = parallelOp;
792-
results.mergeOp = mergeOp;
793-
return results;
792+
FailureOr<MergeResult> mergeResult =
793+
op.mergeReductions(b, loc, replacements, reductionDims);
794+
if (failed(mergeResult)) {
795+
return failure();
796+
}
797+
b.replaceOp(op, mergeResult->replacements);
798+
799+
SCFReductionTilingResult reductionTilingResult;
800+
std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
801+
std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
802+
std::swap(reductionTilingResult.initialValues, initTensors);
803+
std::swap(reductionTilingResult.loops, loops);
804+
std::swap(reductionTilingResult.replacements, mergeResult->replacements);
805+
806+
return reductionTilingResult;
794807
}
795808

796809
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)