@@ -99,6 +99,25 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
99
99
b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
100
100
}
101
101
102
+ // / A function that allows returning additional yielded values during
103
+ // / `yieldTiledValuesAndReplace`.
104
+ // / - `ivs` induction variables for the loop.
105
+ // / - `newBbArgs` basic block arguments corresponding to newly added iter_args.
106
+ // / - `tiledValues` the tiled values to return. Must be of same size as
107
+ // / `newBbArgs`, each element of this array is inserted into the corresponding
108
+ // / element in `newBbArgs`.
109
+ // / - `resultOffsets` is of the same size as `tiledValues` and represents
110
+ // / the offsets to use when inserting corresponding element from `tiledValues`
111
+ // / into the element from `newBbArgs`.
112
+ // / - `resultSizes` is of the same size as `tiledValues` and represents
113
+ // / the sizes of the corresponding element from `tiledValues` inserted into
114
+ // / the element from `newBbArgs`.
115
+ using YieldTiledValuesFn = std::function<LogicalResult(
116
+ RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
117
+ SmallVector<Value> &tiledValues,
118
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
119
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
120
+
102
121
// / Clones the operation and updates the destination if the operation
103
122
// / implements the `DestinationStyleOpInterface`.
104
123
static Operation *cloneOpAndUpdateDestinationArgs (RewriterBase &rewriter,
@@ -288,25 +307,6 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
288
307
return rewriter.notifyMatchFailure (loc, " unhandled loop type" );
289
308
}
290
309
291
- // / A function that allows returning additional yielded values during
292
- // / `yieldTiledValuesAndReplace`.
293
- // / - `ivs` induction variables for the loop.
294
- // / - `newBbArgs` basic block arguments corresponding to newly added iter_args.
295
- // / - `tiledValues` the tiled values to return. Must be of same size as
296
- // / `newBbArgs`, each element of this array is inserted into the corresponding
297
- // / element in `newBbArgs`.
298
- // / - `resultOffsets` is of the same size as `tiledValues` and represents
299
- // / the offsets to use when inserting corresponding element from `tiledValues`
300
- // / into the element from `newBbArgs`.
301
- // / - `resultSizes` is of the same size as `tiledValues` and represents
302
- // / the sizes of the corresponding element from `tiledValues` inserted into
303
- // / the element from `newBbArgs`.
304
- using YieldTiledValuesFn = llvm::function_ref<LogicalResult(
305
- RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
306
- SmallVector<Value> &tiledValues,
307
- SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
308
- SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
309
-
310
310
// / Append the specified additional `newInitOperands` operands to the
311
311
// / loops existing `init` operands (or similar), and replace `loopOp` with
312
312
// / the new loop that has the additional init operands. The loop body of
@@ -441,8 +441,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
441
441
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop (
442
442
LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
443
443
ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
444
- return TypeSwitch<LoopLikeOpInterface , FailureOr<LoopLikeOpInterface>>(
445
- loopLikeOp)
444
+ return TypeSwitch<Operation * , FailureOr<LoopLikeOpInterface>>(
445
+ loopLikeOp. getOperation () )
446
446
.Case <scf::ForOp, scf::ForallOp>(
447
447
[&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
448
448
return yieldTiledValuesAndReplaceLoop (
@@ -460,7 +460,7 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
460
460
// / the additional values to yield from the innermost loop.
461
461
static LogicalResult addInitOperandsToLoopNest (
462
462
RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
463
- ValueRange newInitValues, const YieldTiledValuesFn & getNewTiledYieldsFn) {
463
+ ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
464
464
SmallVector<scf::ForOp> newLoops;
465
465
if (loops.empty ())
466
466
return success ();
0 commit comments