|
23 | 23 | #include "mlir/IR/PatternMatch.h"
|
24 | 24 | #include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
25 | 25 | #include "mlir/Interfaces/TilingInterface.h"
|
| 26 | +#include "llvm/ADT/TypeSwitch.h" |
26 | 27 | #include "llvm/Support/Debug.h"
|
27 | 28 | #include <optional>
|
28 | 29 |
|
@@ -287,6 +288,171 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
|
287 | 288 | return rewriter.notifyMatchFailure(loc, "unhandled loop type");
|
288 | 289 | }
|
289 | 290 |
|
| 291 | +/// A function that allows returning additional yielded values during |
| 292 | +/// `yieldTiledValuesAndReplace`. |
| 293 | +/// - `ivs` induction variable for the loop. |
| 294 | +/// - `newBbArgs` basic block arguments corresponding to newly added iter_args. |
| 295 | +/// - `tiledValues` the tiled values to return. Must be of same size as |
| 296 | +/// `newbbArgs`, each element of this array is inserted into the corresponding |
| 297 | +/// element in `newbbArgs`. |
| 298 | +/// - `resultOffsets` is of the same size as `tiledValues` and represents |
| 299 | +/// the offsets to use when inserting corresponding element from `tiledValues` |
| 300 | +/// into the element from `newBbArgs`. |
| 301 | +/// - `resultSizes` is of the same size as `tiledValues` and represents |
| 302 | +/// the size of the corresponding element from `tiledValues` inserted into |
| 303 | +/// the element from `newBbArgs`. |
| 304 | +using YieldTiledValuesFn = llvm::function_ref<LogicalResult( |
| 305 | + RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, |
| 306 | + SmallVector<Value> &tiledValues, |
| 307 | + SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
| 308 | + SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; |
| 309 | + |
| 310 | +/// Append the specified additional `newInitOperands` operands to the |
| 311 | +/// loops existing `init` operands (or similar), and replace `loopOp` with |
| 312 | +/// the new loop that has the additional init operands. The loop body of |
| 313 | +/// this loop is moved over to the new loop. `yieldTiledValuesFn` |
| 314 | +/// is called to get the new tiled values returned, and the offset |
| 315 | +/// and sizes at which the tiled value is inserted into the |
| 316 | +/// new region iter_args that correspond to the newly added init operands. |
| 317 | +template <typename LoopType> |
| 318 | +FailureOr<LoopLikeOpInterface> |
| 319 | +yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, |
| 320 | + ValueRange newInitOperands, |
| 321 | + YieldTiledValuesFn yieldTiledValuesFn) { |
| 322 | + return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); |
| 323 | +} |
| 324 | + |
| 325 | +/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. |
| 326 | +template <> |
| 327 | +FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( |
| 328 | + scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
| 329 | + YieldTiledValuesFn yieldTiledValuesFn) { |
| 330 | + OpBuilder::InsertionGuard g(rewriter); |
| 331 | + Location loc = loopOp.getLoc(); |
| 332 | + rewriter.setInsertionPoint(loopOp); |
| 333 | + |
| 334 | + auto inits = llvm::to_vector(loopOp.getInitArgs()); |
| 335 | + inits.append(newInitOperands.begin(), newInitOperands.end()); |
| 336 | + auto newLoop = rewriter.create<scf::ForOp>( |
| 337 | + loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), |
| 338 | + inits, [](OpBuilder &, Location, Value, ValueRange) {}); |
| 339 | + |
| 340 | + // Move the loop body to the new op. |
| 341 | + Block *loopBody = loopOp.getBody(); |
| 342 | + Block *newLoopBody = newLoop.getBody(); |
| 343 | + rewriter.mergeBlocks( |
| 344 | + loopBody, newLoopBody, |
| 345 | + newLoopBody->getArguments().take_front(loopBody->getNumArguments())); |
| 346 | + |
| 347 | + auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator()); |
| 348 | + rewriter.setInsertionPoint(yieldOp); |
| 349 | + |
| 350 | + SmallVector<Value> tiledValues; |
| 351 | + SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 352 | + ValueRange newRegionIterArgs = |
| 353 | + newLoop.getRegionIterArgs().take_back(newInitOperands.size()); |
| 354 | + if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), |
| 355 | + newRegionIterArgs, tiledValues, resultOffsets, |
| 356 | + resultSizes))) { |
| 357 | + return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values"); |
| 358 | + } |
| 359 | + |
| 360 | + if (tiledValues.size() != resultOffsets.size() || |
| 361 | + tiledValues.size() != resultSizes.size()) { |
| 362 | + return rewriter.notifyMatchFailure( |
| 363 | + loopOp, |
| 364 | + "expected number of tiled values returned, the number of offset " |
| 365 | + "vectors and number of size vectors to be the same"); |
| 366 | + } |
| 367 | + |
| 368 | + SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands()); |
| 369 | + for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : |
| 370 | + llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, |
| 371 | + resultSizes)) { |
| 372 | + SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 373 | + rewriter.getIndexAttr(1)); |
| 374 | + Value insert = rewriter.create<tensor::InsertSliceOp>( |
| 375 | + yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, |
| 376 | + resultStride); |
| 377 | + newYieldValues.push_back(insert); |
| 378 | + } |
| 379 | + |
| 380 | + rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); |
| 381 | + rewriter.replaceOp(loopOp, |
| 382 | + newLoop->getResults().take_front(loopOp.getNumResults())); |
| 383 | + return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
| 384 | +} |
| 385 | + |
| 386 | +/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` |
| 387 | +template <> |
| 388 | +FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( |
| 389 | + scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
| 390 | + YieldTiledValuesFn yieldTiledValuesFn) { |
| 391 | + OpBuilder::InsertionGuard g(rewriter); |
| 392 | + Location loc = loopOp.getLoc(); |
| 393 | + rewriter.setInsertionPoint(loopOp); |
| 394 | + auto inits = llvm::to_vector(loopOp.getOutputs()); |
| 395 | + inits.append(newInitOperands.begin(), newInitOperands.end()); |
| 396 | + auto newLoop = rewriter.create<scf::ForallOp>( |
| 397 | + loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), |
| 398 | + loopOp.getMixedStep(), inits, loopOp.getMapping(), |
| 399 | + [](OpBuilder &, Location, ValueRange) {}); |
| 400 | + |
| 401 | + // Move the region of the current block to the newly created op. |
| 402 | + Block *loopBody = loopOp.getBody(); |
| 403 | + Block *newLoopBody = newLoop.getBody(); |
| 404 | + rewriter.mergeBlocks( |
| 405 | + loopBody, newLoopBody, |
| 406 | + newLoopBody->getArguments().take_front(loopBody->getNumArguments())); |
| 407 | + |
| 408 | + auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator()); |
| 409 | + rewriter.setInsertionPoint(terminator); |
| 410 | + SmallVector<Value> tiledValues; |
| 411 | + SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 412 | + ValueRange regionIterArgs = |
| 413 | + newLoop.getRegionIterArgs().take_back(newInitOperands.size()); |
| 414 | + if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), |
| 415 | + regionIterArgs, tiledValues, resultOffsets, |
| 416 | + resultSizes))) { |
| 417 | + return rewriter.notifyMatchFailure(loopOp, |
| 418 | + "failed to get yielded tiled values"); |
| 419 | + } |
| 420 | + |
| 421 | + // Update the terminator. |
| 422 | + rewriter.setInsertionPointToEnd(terminator.getBody()); |
| 423 | + |
| 424 | + for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( |
| 425 | + tiledValues, regionIterArgs, resultOffsets, resultSizes)) { |
| 426 | + SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 427 | + rewriter.getIndexAttr(1)); |
| 428 | + rewriter.create<tensor::ParallelInsertSliceOp>( |
| 429 | + terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, |
| 430 | + resultStride); |
| 431 | + } |
| 432 | + |
| 433 | + rewriter.replaceOp(loopOp, |
| 434 | + newLoop->getResults().take_front(loopOp.getNumResults())); |
| 435 | + return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
| 436 | +} |
| 437 | + |
| 438 | +/// Implementation of `yieldTiledValuesAndReplaceLoop` for |
| 439 | +/// `LoopLikeOpInterface`, that just dispatches to the implementation for each |
| 440 | +/// supported loop type. |
| 441 | +FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( |
| 442 | + LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, |
| 443 | + ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { |
| 444 | + return TypeSwitch<LoopLikeOpInterface, FailureOr<LoopLikeOpInterface>>( |
| 445 | + loopLikeOp) |
| 446 | + .Case<scf::ForOp, scf::ForallOp>( |
| 447 | + [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
| 448 | + return yieldTiledValuesAndReplaceLoop( |
| 449 | + loopOp, rewriter, newInitOperands, yieldTiledValuesFn); |
| 450 | + }) |
| 451 | + .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
| 452 | + return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); |
| 453 | + }); |
| 454 | +} |
| 455 | + |
290 | 456 | /// Method to add new init values to a loop nest. Updates `loops` in-place with
|
291 | 457 | /// new loops that use the `newInitValues`.
|
292 | 458 | /// The outer-loops are updated to yield the new result values of the inner
|
@@ -334,8 +500,8 @@ static LogicalResult addInitOperandsToLoopNest(
|
334 | 500 | // Update the loop body of the innermost loop to get new yield values.
|
335 | 501 | LoopLikeOpInterface innerMostLoop = loops.back();
|
336 | 502 | FailureOr<LoopLikeOpInterface> newInnerMostLoop =
|
337 |
| - innerMostLoop.yieldTiledValuesAndReplace(rewriter, newInitValues, |
338 |
| - getNewTiledYieldsFn); |
| 503 | + yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, |
| 504 | + getNewTiledYieldsFn); |
339 | 505 |
|
340 | 506 | if (failed(newInnerMostLoop))
|
341 | 507 | return innerMostLoop.emitOpError("failed to return additional yields");
|
|
0 commit comments