Skip to content

Commit 20271d9

Browse files
Drop yieldTiledValuesAndReplace as an interface method.
1 parent 3f51bc2 commit 20271d9

File tree

3 files changed

+169
-157
lines changed

3 files changed

+169
-157
lines changed

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -220,55 +220,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
220220
/*defaultImplementation=*/[{
221221
return ::mlir::failure();
222222
}]
223-
>,
224-
InterfaceMethod<[{
225-
Append the specified additional "init" operands: replace this loop with
226-
a new loop that has the additional init operands. The loop body of
227-
this loop is moved over to the new loop.
228-
229-
This method is similar to `replaceWithAdditionalYields` but instead of
230-
yielding a value from within the loop, it allows each loop construct
231-
implementing this method to handle the result of each iteration
232-
appropriately. This allows for unified handling of operations
233-
like `scf.forall` which don't yield a value from the loop, but instead
234-
the terminator specifies where to insert the tile computed by the body of
235-
the loop. For example,
236-
237-
```mlir
238-
%0 = scf.forall ... shared_outs(%arg0 = %arg1) {
239-
...
240-
%tiled_value = ...
241-
scf.forall.in_parallel {
242-
tensor.parallel_insert_slice %tiled_value into %arg0[%o1, %o2]...
243-
}
244-
}
245-
```
246-
247-
For an `scf.for` the same computation would be represented as
248-
```mlir
249-
%0 = scf.for ... iter_args(%arg0 = %arg1) {
250-
...
251-
%tiled_value = ...
252-
%insert = tensor.insert_slice %tiled_value into %arg0[%o1, %o2]...
253-
scf.yield %insert
254-
}
255-
```
256-
257-
So for the caller, the tiled value (`%tiled_value`) and the offsets
258-
`(%o1, %o2)` and sizes (not shown) are generated the same way, but
259-
the implementation method for the different loop constructs handles
260-
the difference in representation.
261-
}],
262-
/*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
263-
/*methodName=*/"yieldTiledValuesAndReplace",
264-
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
265-
"::mlir::ValueRange":$newInitOperands,
266-
"::mlir::YieldTiledValuesFn":$yieldTiledValuesFn),
267-
/*methodBody=*/"",
268-
/*defaultImplementation=*/[{
269-
return ::mlir::failure();
270-
}]
271-
>,
223+
>
272224
];
273225

274226
let extraClassDeclaration = [{

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 0 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -588,64 +588,6 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
588588
return cast<LoopLikeOpInterface>(newLoop.getOperation());
589589
}
590590

591-
FailureOr<LoopLikeOpInterface>
592-
ForOp::yieldTiledValuesAndReplace(RewriterBase &rewriter,
593-
ValueRange newInitOperands,
594-
YieldTiledValuesFn yieldTiledValuesFn) {
595-
OpBuilder::InsertionGuard g(rewriter);
596-
rewriter.setInsertionPoint(getOperation());
597-
598-
auto inits = llvm::to_vector(getInitArgs());
599-
inits.append(newInitOperands.begin(), newInitOperands.end());
600-
auto newLoop = rewriter.create<ForOp>(
601-
getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
602-
[](OpBuilder &, Location, Value, ValueRange) {});
603-
604-
// Move the loop body to the new op.
605-
rewriter.mergeBlocks(getBody(), newLoop.getBody(),
606-
newLoop.getBody()->getArguments().take_front(
607-
getBody()->getNumArguments()));
608-
609-
auto yieldOp = cast<scf::YieldOp>(newLoop.getBody()->getTerminator());
610-
rewriter.setInsertionPoint(yieldOp);
611-
612-
SmallVector<Value> tiledValues;
613-
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
614-
ValueRange newRegionIterArgs =
615-
newLoop.getRegionIterArgs().take_back(newInitOperands.size());
616-
if (failed(yieldTiledValuesFn(rewriter, getLoc(), newLoop.getInductionVar(),
617-
newRegionIterArgs, tiledValues, resultOffsets,
618-
resultSizes))) {
619-
return rewriter.notifyMatchFailure(getOperation(),
620-
"failed to get tiled values");
621-
}
622-
623-
if (tiledValues.size() != resultOffsets.size() ||
624-
tiledValues.size() != resultSizes.size()) {
625-
return rewriter.notifyMatchFailure(
626-
getOperation(),
627-
"expected number of tiled values returned, the number of offset "
628-
"vectors and number of size vectors to be the same");
629-
}
630-
631-
SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
632-
for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
633-
llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
634-
resultSizes)) {
635-
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
636-
rewriter.getIndexAttr(1));
637-
Value insert = rewriter.create<tensor::InsertSliceOp>(
638-
yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
639-
resultStride);
640-
newYieldValues.push_back(insert);
641-
}
642-
643-
rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
644-
rewriter.replaceOp(getOperation(),
645-
newLoop->getResults().take_front(getNumResults()));
646-
return cast<LoopLikeOpInterface>(newLoop.getOperation());
647-
}
648-
649591
ForOp mlir::scf::getForInductionVarOwner(Value val) {
650592
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
651593
if (!ivArg)
@@ -692,54 +634,6 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
692634
return getOutputsMutable();
693635
}
694636

695-
FailureOr<LoopLikeOpInterface>
696-
ForallOp::yieldTiledValuesAndReplace(RewriterBase &rewriter,
697-
ValueRange newInitOperands,
698-
YieldTiledValuesFn yieldTiledValuesFn) {
699-
OpBuilder::InsertionGuard g(rewriter);
700-
rewriter.setInsertionPoint(getOperation());
701-
auto inits = llvm::to_vector(getOutputs());
702-
inits.append(newInitOperands.begin(), newInitOperands.end());
703-
auto newLoop = rewriter.create<scf::ForallOp>(
704-
getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
705-
inits, getMapping(), [](OpBuilder &, Location, ValueRange) {});
706-
707-
// Move the region of the current block to the newly created op.
708-
Block *newLoopBody = newLoop.getBody();
709-
rewriter.mergeBlocks(
710-
getBody(), newLoopBody,
711-
newLoopBody->getArguments().take_front(getBody()->getNumArguments()));
712-
713-
auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
714-
rewriter.setInsertionPoint(terminator);
715-
SmallVector<Value> tiledValues;
716-
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
717-
ValueRange regionIterArgs =
718-
newLoop.getRegionIterArgs().take_back(newInitOperands.size());
719-
if (failed(yieldTiledValuesFn(rewriter, getLoc(), newLoop.getInductionVars(),
720-
regionIterArgs, tiledValues, resultOffsets,
721-
resultSizes))) {
722-
return rewriter.notifyMatchFailure(getOperation(),
723-
"failed to get yielded tiled values");
724-
}
725-
726-
// Update the terminator.
727-
rewriter.setInsertionPointToEnd(terminator.getBody());
728-
729-
for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
730-
tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
731-
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
732-
rewriter.getIndexAttr(1));
733-
rewriter.create<tensor::ParallelInsertSliceOp>(
734-
terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
735-
resultStride);
736-
}
737-
738-
rewriter.replaceOp(getOperation(),
739-
newLoop->getResults().take_front(getNumResults()));
740-
return cast<LoopLikeOpInterface>(newLoop.getOperation());
741-
}
742-
743637
/// Promotes the loop body of a scf::ForallOp to its containing block.
744638
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
745639
OpBuilder::InsertionGuard g(rewriter);

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 168 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/PatternMatch.h"
2424
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2525
#include "mlir/Interfaces/TilingInterface.h"
26+
#include "llvm/ADT/TypeSwitch.h"
2627
#include "llvm/Support/Debug.h"
2728
#include <optional>
2829

@@ -287,6 +288,171 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
287288
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
288289
}
289290

291+
/// A function that allows returning additional yielded values during
292+
/// `yieldTiledValuesAndReplaceLoop`.
293+
/// - `ivs` induction variables of the loop.
294+
/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
295+
/// - `tiledValues` the tiled values to return. Must be of same size as
296+
/// `newBbArgs`, each element of this array is inserted into the corresponding
297+
/// element in `newBbArgs`.
298+
/// - `resultOffsets` is of the same size as `tiledValues` and represents
299+
/// the offsets to use when inserting corresponding element from `tiledValues`
300+
/// into the element from `newBbArgs`.
301+
/// - `resultSizes` is of the same size as `tiledValues` and represents
302+
/// the size of the corresponding element from `tiledValues` inserted into
303+
/// the element from `newBbArgs`.
304+
using YieldTiledValuesFn = llvm::function_ref<LogicalResult(
305+
RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
306+
SmallVector<Value> &tiledValues,
307+
SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
308+
SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
309+
310+
/// Append the specified additional `newInitOperands` operands to the
311+
/// loops existing `init` operands (or similar), and replace `loopOp` with
312+
/// the new loop that has the additional init operands. The loop body of
313+
/// this loop is moved over to the new loop. `yieldTiledValuesFn`
314+
/// is called to get the new tiled values returned, and the offset
315+
/// and sizes at which the tiled value is inserted into the
316+
/// new region iter_args that correspond to the newly added init operands.
317+
template <typename LoopType>
318+
FailureOr<LoopLikeOpInterface>
319+
yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
320+
ValueRange newInitOperands,
321+
YieldTiledValuesFn yieldTiledValuesFn) {
322+
return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
323+
}
324+
325+
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
326+
template <>
327+
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
328+
scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
329+
YieldTiledValuesFn yieldTiledValuesFn) {
330+
OpBuilder::InsertionGuard g(rewriter);
331+
Location loc = loopOp.getLoc();
332+
rewriter.setInsertionPoint(loopOp);
333+
334+
auto inits = llvm::to_vector(loopOp.getInitArgs());
335+
inits.append(newInitOperands.begin(), newInitOperands.end());
336+
auto newLoop = rewriter.create<scf::ForOp>(
337+
loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
338+
inits, [](OpBuilder &, Location, Value, ValueRange) {});
339+
340+
// Move the loop body to the new op.
341+
Block *loopBody = loopOp.getBody();
342+
Block *newLoopBody = newLoop.getBody();
343+
rewriter.mergeBlocks(
344+
loopBody, newLoopBody,
345+
newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
346+
347+
auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
348+
rewriter.setInsertionPoint(yieldOp);
349+
350+
SmallVector<Value> tiledValues;
351+
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
352+
ValueRange newRegionIterArgs =
353+
newLoop.getRegionIterArgs().take_back(newInitOperands.size());
354+
if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
355+
newRegionIterArgs, tiledValues, resultOffsets,
356+
resultSizes))) {
357+
return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
358+
}
359+
360+
if (tiledValues.size() != resultOffsets.size() ||
361+
tiledValues.size() != resultSizes.size()) {
362+
return rewriter.notifyMatchFailure(
363+
loopOp,
364+
"expected number of tiled values returned, the number of offset "
365+
"vectors and number of size vectors to be the same");
366+
}
367+
368+
SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
369+
for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
370+
llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
371+
resultSizes)) {
372+
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
373+
rewriter.getIndexAttr(1));
374+
Value insert = rewriter.create<tensor::InsertSliceOp>(
375+
yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
376+
resultStride);
377+
newYieldValues.push_back(insert);
378+
}
379+
380+
rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
381+
rewriter.replaceOp(loopOp,
382+
newLoop->getResults().take_front(loopOp.getNumResults()));
383+
return cast<LoopLikeOpInterface>(newLoop.getOperation());
384+
}
385+
386+
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`.
387+
template <>
388+
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
389+
scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
390+
YieldTiledValuesFn yieldTiledValuesFn) {
391+
OpBuilder::InsertionGuard g(rewriter);
392+
Location loc = loopOp.getLoc();
393+
rewriter.setInsertionPoint(loopOp);
394+
auto inits = llvm::to_vector(loopOp.getOutputs());
395+
inits.append(newInitOperands.begin(), newInitOperands.end());
396+
auto newLoop = rewriter.create<scf::ForallOp>(
397+
loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
398+
loopOp.getMixedStep(), inits, loopOp.getMapping(),
399+
[](OpBuilder &, Location, ValueRange) {});
400+
401+
// Move the region of the current block to the newly created op.
402+
Block *loopBody = loopOp.getBody();
403+
Block *newLoopBody = newLoop.getBody();
404+
rewriter.mergeBlocks(
405+
loopBody, newLoopBody,
406+
newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
407+
408+
auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
409+
rewriter.setInsertionPoint(terminator);
410+
SmallVector<Value> tiledValues;
411+
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
412+
ValueRange regionIterArgs =
413+
newLoop.getRegionIterArgs().take_back(newInitOperands.size());
414+
if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
415+
regionIterArgs, tiledValues, resultOffsets,
416+
resultSizes))) {
417+
return rewriter.notifyMatchFailure(loopOp,
418+
"failed to get yielded tiled values");
419+
}
420+
421+
// Update the terminator.
422+
rewriter.setInsertionPointToEnd(terminator.getBody());
423+
424+
for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
425+
tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
426+
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
427+
rewriter.getIndexAttr(1));
428+
rewriter.create<tensor::ParallelInsertSliceOp>(
429+
terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
430+
resultStride);
431+
}
432+
433+
rewriter.replaceOp(loopOp,
434+
newLoop->getResults().take_front(loopOp.getNumResults()));
435+
return cast<LoopLikeOpInterface>(newLoop.getOperation());
436+
}
437+
438+
/// Implementation of `yieldTiledValuesAndReplaceLoop` for
439+
/// `LoopLikeOpInterface`, that just dispatches to the implementation for each
440+
/// supported loop type.
441+
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
442+
LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
443+
ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
444+
return TypeSwitch<LoopLikeOpInterface, FailureOr<LoopLikeOpInterface>>(
445+
loopLikeOp)
446+
.Case<scf::ForOp, scf::ForallOp>(
447+
[&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
448+
return yieldTiledValuesAndReplaceLoop(
449+
loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
450+
})
451+
.Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
452+
return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
453+
});
454+
}
455+
290456
/// Method to add new init values to a loop nest. Updates `loops` in-place with
291457
/// new loops that use the `newInitValues`.
292458
/// The outer-loops are updated to yield the new result values of the inner
@@ -334,8 +500,8 @@ static LogicalResult addInitOperandsToLoopNest(
334500
// Update the loop body of the innermost loop to get new yield values.
335501
LoopLikeOpInterface innerMostLoop = loops.back();
336502
FailureOr<LoopLikeOpInterface> newInnerMostLoop =
337-
innerMostLoop.yieldTiledValuesAndReplace(rewriter, newInitValues,
338-
getNewTiledYieldsFn);
503+
yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
504+
getNewTiledYieldsFn);
339505

340506
if (failed(newInnerMostLoop))
341507
return innerMostLoop.emitOpError("failed to return additional yields");

0 commit comments

Comments
 (0)