Skip to content

[mlir][TilingInterface] Update PartialReductionOpInterface to get it more in line with TilingInterface. #95460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -873,9 +873,9 @@ tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
/// Transformation information returned after reduction tiling.
struct ForallReductionTilingResult {
/// The partial reduction tiled op generated.
Operation *parallelTiledOp;
SmallVector<Operation *> parallelTiledOps;
/// The final reduction operation merging all the partial reductions.
Operation *mergeOp;
SmallVector<Operation *> mergeOps;
/// Initial values used for partial reductions.
SmallVector<Value> initialValues;
/// The `scf.forall` operation that iterates over the tiles.
Expand Down
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,15 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
/// Transformation information returned after reduction tiling.
struct SCFReductionTilingResult {
/// The partial reduction tiled op generated.
Operation *parallelTiledOp;
SmallVector<Operation *> parallelTiledOps;
/// The final reduction operation merging all the partial reductions.
Operation *mergeOp;
SmallVector<Operation *> mergeOps;
/// Initial values used for reduction.
SmallVector<Value> initialValues;
/// The loop operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
/// The replacements to use for the results of the tiled operation.
SmallVector<Value> replacements;
};

/// Method to tile a reduction and generate a parallel op within a serial loop.
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Interfaces/TilingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ struct TilingResult {
SmallVector<Value> tiledValues;
};

/// Container for the result of the merge operation of tiling.
/// - `mergeOps` contains operations created during the merge.
/// - `replacements` contains the values that represent the result of the
///   merge. These are used as replacements for the original tiled operation.
struct MergeResult {
/// Operations created during the merge.
SmallVector<Operation *> mergeOps;
/// Values that represent the result of the merge; used as replacements for
/// the original tiled operation.
SmallVector<Value> replacements;
};

} // namespace mlir

/// Include the ODS generated interface header files.
Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
less or equal to the tile size. This is meant to be used with
`mergeReductions` method which will combine the partial reductions.
}],
/*retType=*/"Operation*",
/*retType=*/"FailureOr<TilingResult>",
/*methodName=*/"tileToPartialReduction",
/*args=*/(ins
"OpBuilder &":$b,
Expand All @@ -258,7 +258,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
"ArrayRef<int>":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return nullptr;
return failure();
}]
>,
InterfaceMethod<
Expand All @@ -267,7 +267,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
tiled along the reduction dimensions. This will only apply the
reduction of the operation.
}],
/*retType=*/"Operation*",
/*retType=*/"FailureOr<MergeResult>",
/*methodName=*/"mergeReductions",
/*args=*/(ins
"OpBuilder &":$b,
Expand All @@ -276,7 +276,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
"ArrayRef<int>":$reductionDim),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return nullptr;
return failure();
}]
>
];
Expand Down
12 changes: 8 additions & 4 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2525,8 +2525,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
return emitDefaultSilenceableFailure(target);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
for (auto parallelTiledOp : result->parallelTiledOps)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand auto, here and below.

results.push_back(parallelTiledOp);
for (auto mergeOp : result->mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
}
Expand Down Expand Up @@ -2577,8 +2579,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
}
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
for (auto parallelTiledOp : result->parallelTiledOps)
results.push_back(parallelTiledOp);
for (auto mergeOp : result->mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops);
return DiagnosedSilenceableFailure::success();
}
Expand Down
11 changes: 7 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,16 +833,19 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(

// 7. Merge the partial reductions.
b.setInsertionPointAfter(forallOp);
Operation *mergeOp =
FailureOr<MergeResult> mergeResult =
op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
b.replaceOp(op, mergeOp->getResults());
if (failed(mergeResult)) {
return failure();
}
b.replaceOp(op, mergeResult->replacements);

// 8. Return.
ForallReductionTilingResult results;
results.initialValues = initTensors;
results.loops = forallOp;
results.parallelTiledOp = tiledOp;
results.mergeOp = mergeOp;
results.parallelTiledOps.push_back(tiledOp);
results.mergeOps.append(mergeResult->mergeOps);
return results;
}

Expand Down
26 changes: 16 additions & 10 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,11 @@ struct LinalgOpPartialReductionInterface
return inits;
}

Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
ValueRange init,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<int> reductionDims) const {
FailureOr<TilingResult>
tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
ValueRange init, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<int> reductionDims) const {
OpBuilder::InsertionGuard guard(b);
auto linalgOp = cast<LinalgOp>(op);

Expand Down Expand Up @@ -437,12 +437,15 @@ struct LinalgOpPartialReductionInterface
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
return genericOp.getOperation();
return TilingResult{
{genericOp.getOperation()},
llvm::map_to_vector(genericOp->getResults(),
[](OpResult r) -> Value { return r; })};
}

Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
ValueRange partialReduce,
ArrayRef<int> reductionDims) const {
FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
Location loc, ValueRange partialReduce,
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);

// Step 1. Recover the dims that actually need to be merged from the
Expand Down Expand Up @@ -493,7 +496,10 @@ struct LinalgOpPartialReductionInterface
}
b.create<linalg::YieldOp>(loc, yieldedValues);
});
return reduction.getOperation();
return MergeResult{
{reduction.getOperation()},
llvm::map_to_vector(reduction->getResults(),
[](OpResult r) -> Value { return r; })};
}
};

Expand Down
55 changes: 34 additions & 21 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
SmallVector<Value> &initTensors = maybeInitTensors.value();

// 3. Define the callback to use for generating the inner most tile loop body.
Operation *parallelOp = nullptr;
SmallVector<Operation *> parallelTiledOps;
auto innerYieldTiledValuesFn =
[&](RewriterBase &rewriter, Location loc, ValueRange ivs,
ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
Expand All @@ -743,26 +743,33 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
}

// 4a. Clone the operation.
auto clonedOp = cast<PartialReductionOpInterface>(
cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
{
auto clonedOp = cast<PartialReductionOpInterface>(
cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));

// 4b. Tile the cloned operation.
FailureOr<TilingResult> partialTilingResult =
clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
sizes, reductionDims);
if (failed(partialTilingResult)) {
return failure();
}
std::swap(parallelTiledOps, partialTilingResult->tiledOps);
std::swap(tiledResult, partialTilingResult->tiledValues);

// 4b. Tile the cloned operation.
parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
offsets, sizes, reductionDims);
// 4c. Delete the cloned operation.
b.eraseOp(clonedOp);
// 4c. Delete the cloned operation.
b.eraseOp(clonedOp);
}

tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
// 4d. Compute the offsets and sizes needed to insert the result of the
// tiled value back into destination before yielding the destination.
for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) {
for (auto result : tiledResult) {
SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
resultOffsets.emplace_back(std::move(outOffsets));

SmallVector<OpFoldResult> outSizes;
for (size_t i = 0; i < offsets.size(); i++) {
outSizes.push_back(
tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i));
outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
}
resultSizes.emplace_back(std::move(outSizes));
}
Expand All @@ -782,15 +789,21 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,

// 5. Apply the merge reduction to combine all the partial values.
b.setInsertionPointAfter(*loops.begin());
Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
b.replaceOp(op, mergeOp->getResults());

SCFReductionTilingResult results;
results.initialValues = initTensors;
results.loops = loops;
results.parallelTiledOp = parallelOp;
results.mergeOp = mergeOp;
return results;
FailureOr<MergeResult> mergeResult =
op.mergeReductions(b, loc, replacements, reductionDims);
if (failed(mergeResult)) {
return failure();
}
b.replaceOp(op, mergeResult->replacements);

SCFReductionTilingResult reductionTilingResult;
std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
std::swap(reductionTilingResult.initialValues, initTensors);
std::swap(reductionTilingResult.loops, loops);
std::swap(reductionTilingResult.replacements, mergeResult->replacements);

return reductionTilingResult;
}

//===----------------------------------------------------------------------===//
Expand Down
Loading