Skip to content

Commit a625c02

Browse files
[mlir][scf] Return replacements explicitly in SCFTilingResult.
In #120115 the replacements for the tiled operations were wrapped within the `MergeResult` object. That is a bit of an obfuscation and not immediately obvious where to get the replacements post tiling. This changes the `SCFTilingResult` to have `replacements` explicit (as it was before that change). It also makes the `mergeOps` a separate field of `SCFTilingResult`, which is empty when the reduction type is `FullReduction`. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 414590b commit a625c02

File tree

5 files changed

+47
-43
lines changed

5 files changed

+47
-43
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,17 @@ struct SCFTilingResult {
136136
SmallVector<Value> initialValues;
137137
/// The `scf.for` operations that iterate over the tiles.
138138
SmallVector<LoopLikeOpInterface> loops;
139-
/// The result generated by the loop nest in tiling, may hold partial results,
140-
/// which need to be merged to match the computation of the untiled operation.
141-
/// `mergeResult` contains the operations used to perform this merge from
142-
/// partial results and the values that can be used as replacements of
143-
/// the untiled operation.
144-
MergeResult mergeResult;
139+
/// Values to use as replacements for the untiled op. Is the same size as the
140+
/// number of results of the untiled op.
141+
SmallVector<Value> replacements;
145142
/// Slices generated after tiling that can be used for fusing with the tiled
146143
/// producer.
147144
SmallVector<Operation *> generatedSlices;
145+
/// In cases where there as an additional merge step after tiling
146+
/// return the merged ops after tiling. This list is empty when reduction
147+
/// tiling strategy is
148+
/// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction.
149+
SmallVector<Operation *> mergeOps;
148150
};
149151

150152
/// Method to tile an op that implements the `TilingInterface` using
@@ -362,7 +364,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
362364
/// ```
363365
FailureOr<scf::SCFTilingResult>
364366
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
365-
ArrayRef<OpFoldResult> tileSize);
367+
ArrayRef<OpFoldResult> tileSizes);
366368

367369
} // namespace scf
368370
} // namespace mlir

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
363363
];
364364
}
365365

366-
def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
366+
def PartialReductionOpInterface :
367+
OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
367368
let description = [{
368369
Interface for allowing operations to expose information needed to
369370
tile reductions using partial reduction followed by merge. This is

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2381,7 +2381,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
23812381
return emitDefaultDefiniteFailure(target);
23822382

23832383
if (target->getNumResults())
2384-
rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
2384+
rewriter.replaceOp(target, maybeTilingResult->replacements);
23852385
else
23862386
rewriter.eraseOp(target);
23872387

@@ -2800,12 +2800,12 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
28002800

28012801
if (failed(result))
28022802
return emitDefaultSilenceableFailure(target);
2803-
rewriter.replaceOp(target, result->mergeResult.replacements);
2803+
rewriter.replaceOp(target, result->replacements);
28042804
for (Value initValue : result->initialValues)
28052805
results.push_back(initValue.getDefiningOp());
28062806
for (auto parallelTiledOp : result->tiledOps)
28072807
results.push_back(parallelTiledOp);
2808-
for (auto mergeOp : result->mergeResult.mergeOps)
2808+
for (auto mergeOp : result->mergeOps)
28092809
results.push_back(mergeOp);
28102810
results.push_back(result->loops.front());
28112811
return DiagnosedSilenceableFailure::success();
@@ -3229,7 +3229,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
32293229
if (failed(maybeTilingResult))
32303230
return DiagnosedSilenceableFailure::definiteFailure();
32313231

3232-
rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
3232+
rewriter.replaceOp(op, maybeTilingResult->replacements);
32333233

32343234
tiled.append(maybeTilingResult->tiledOps);
32353235
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3465,7 +3465,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
34653465
if (failed(maybeTilingResult))
34663466
return transformOp.emitDefaultSilenceableFailure(tileableOp);
34673467

3468-
rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
3468+
rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
34693469

34703470
tilingResult = *maybeTilingResult;
34713471

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,48 +1058,50 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
10581058
assert(succeeded(tilingResult) &&
10591059
"expected tiling result to be computed after loop generation");
10601060

1061-
SmallVector<Value> partialResults;
10621061
if (loops.empty()) {
10631062
// If loops are empty, the tiled op is used as the replacement for the
10641063
// untiled op.
1065-
partialResults = tilingResult->tiledValues;
1066-
} else {
1067-
partialResults = llvm::map_to_vector(loops.front()->getResults(),
1064+
return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
1065+
tilingResult->tiledValues,
1066+
tilingResult->generatedSlices};
1067+
}
1068+
1069+
auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
10681070
[](OpResult r) -> Value { return r; });
1071+
1072+
// For the full reduction case, there is nothing more to do.
1073+
if (options.reductionStrategy ==
1074+
scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) {
1075+
return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
1076+
loopResults, tilingResult->generatedSlices};
10691077
}
10701078

1079+
// The results of the loop needs to be merged.
10711080
FailureOr<MergeResult> mergeResult =
1072-
mergeTilingResults(rewriter, op, partialResults, options);
1081+
mergeTilingResults(rewriter, op, loopResults, options);
10731082
if (failed(mergeResult)) {
10741083
return rewriter.notifyMatchFailure(
10751084
op, "Failed to merge partial results from tiling");
10761085
}
1077-
1078-
return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
1079-
mergeResult.value(),
1080-
tilingResult->generatedSlices};
1086+
return scf::SCFTilingResult{tilingResult->tiledOps,
1087+
initTensors,
1088+
loops,
1089+
mergeResult->replacements,
1090+
tilingResult->generatedSlices,
1091+
mergeResult->mergeOps};
10811092
}
10821093

10831094
FailureOr<scf::SCFTilingResult>
10841095
mlir::scf::tileReductionUsingScf(RewriterBase &b,
10851096
PartialReductionOpInterface op,
1086-
ArrayRef<OpFoldResult> tileSizes) {
1087-
SCFTilingOptions options;
1088-
options.setLoopType(SCFTilingOptions::LoopType::ForOp);
1089-
options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
1090-
PartialReductionOuterReduction);
1091-
options.setTileSizes(tileSizes);
1092-
1093-
TilingInterface tilingInterfaceOp =
1094-
dyn_cast<TilingInterface>(op.getOperation());
1095-
if (!tilingInterfaceOp) {
1096-
return b.notifyMatchFailure(
1097-
op,
1098-
"Operation implementing PartialReductionOpInterface should implement "
1099-
"TilingInterface");
1100-
}
1101-
1102-
return tileUsingSCF(b, tilingInterfaceOp, options);
1097+
ArrayRef<OpFoldResult> tileSize) {
1098+
scf::SCFTilingOptions options;
1099+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1100+
options.setReductionTilingStrategy(
1101+
scf::SCFTilingOptions::ReductionTilingStrategy::
1102+
PartialReductionOuterReduction);
1103+
options.setTileSizes(tileSize);
1104+
return tileUsingSCF(b, op, options);
11031105
}
11041106

11051107
//===----------------------------------------------------------------------===//
@@ -1539,8 +1541,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15391541
tiledAndFusedOps.insert_range(tilingResult->tiledOps);
15401542

15411543
DenseMap<Value, Value> replacements;
1542-
for (auto [origVal, replacement] : llvm::zip_equal(
1543-
consumer->getResults(), tilingResult->mergeResult.replacements)) {
1544+
for (auto [origVal, replacement] :
1545+
llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
15441546
replacements[origVal] = replacement;
15451547
}
15461548

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
260260
return failure();
261261

262262
// Perform the replacement of tiled and fused values.
263-
rewriter.replaceOp(tilingInterfaceOp,
264-
tiledResults->mergeResult.replacements);
263+
rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
265264

266265
// Report back the relevant handles to the transform op.
267266
tiledOps.push_back(tiledResults->tiledOps.front());

0 commit comments

Comments
 (0)