|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 |
|
9 | 9 | #include "mlir/Dialect/Affine/IR/AffineOps.h"
|
| 10 | +#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" |
10 | 11 | #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
|
| 12 | +#include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | 13 | #include "mlir/Dialect/MemRef/IR/MemRef.h"
|
12 | 14 | #include "mlir/Dialect/UB/IR/UBOps.h"
|
13 | 15 | #include "mlir/IR/AffineExprVisitor.h"
|
@@ -2448,6 +2450,65 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
|
2448 | 2450 | return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
|
2449 | 2451 | }
|
2450 | 2452 |
|
| 2453 | +void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) { |
| 2454 | + // Replace uses of iter arguments with iter operands (initial values). |
| 2455 | + OperandRange iterOperands = forOp.getInits(); |
| 2456 | + MutableArrayRef<BlockArgument> iterArgs = forOp.getRegionIterArgs(); |
| 2457 | + for (auto [operand, arg] : llvm::zip(iterOperands, iterArgs)) |
| 2458 | + arg.replaceAllUsesWith(operand); |
| 2459 | + |
| 2460 | + // Replace uses of loop results with the values yielded by the loop. |
| 2461 | + ResultRange outerResults = forOp.getResults(); |
| 2462 | + OperandRange innerResults = forOp.getBody()->getTerminator()->getOperands(); |
| 2463 | + for (auto [outer, inner] : llvm::zip(outerResults, innerResults)) |
| 2464 | + outer.replaceAllUsesWith(inner); |
| 2465 | +} |
| 2466 | + |
| 2467 | +LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) { |
| 2468 | + auto forOp = cast<AffineForOp>(getOperation()); |
| 2469 | + std::optional<uint64_t> tripCount = getConstantTripCount(forOp); |
| 2470 | + if (!tripCount || *tripCount != 1) |
| 2471 | + return failure(); |
| 2472 | + |
| 2473 | + // TODO: extend this for arbitrary affine bounds. |
| 2474 | + if (forOp.getLowerBoundMap().getNumResults() != 1) |
| 2475 | + return failure(); |
| 2476 | + |
| 2477 | + // Replaces all IV uses to its single iteration value. |
| 2478 | + BlockArgument iv = forOp.getInductionVar(); |
| 2479 | + Block *parentBlock = forOp->getBlock(); |
| 2480 | + if (!iv.use_empty()) { |
| 2481 | + if (forOp.hasConstantLowerBound()) { |
| 2482 | + OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody()); |
| 2483 | + auto constOp = topBuilder.create<arith::ConstantIndexOp>( |
| 2484 | + forOp.getLoc(), forOp.getConstantLowerBound()); |
| 2485 | + iv.replaceAllUsesWith(constOp); |
| 2486 | + } else { |
| 2487 | + OperandRange lbOperands = forOp.getLowerBoundOperands(); |
| 2488 | + AffineMap lbMap = forOp.getLowerBoundMap(); |
| 2489 | + OpBuilder builder(forOp); |
| 2490 | + if (lbMap == builder.getDimIdentityMap()) { |
| 2491 | + // No need of generating an affine.apply. |
| 2492 | + iv.replaceAllUsesWith(lbOperands[0]); |
| 2493 | + } else { |
| 2494 | + auto affineApplyOp = |
| 2495 | + builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands); |
| 2496 | + iv.replaceAllUsesWith(affineApplyOp); |
| 2497 | + } |
| 2498 | + } |
| 2499 | + } |
| 2500 | + |
| 2501 | + replaceIterArgsAndYieldResults(forOp); |
| 2502 | + |
| 2503 | + // Move the loop body operations, except for its terminator, to the loop's |
| 2504 | + // containing block. |
| 2505 | + forOp.getBody()->back().erase(); |
| 2506 | + parentBlock->getOperations().splice(Block::iterator(forOp), |
| 2507 | + forOp.getBody()->getOperations()); |
| 2508 | + forOp.erase(); |
| 2509 | + return success(); |
| 2510 | +} |
| 2511 | + |
2451 | 2512 | FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
|
2452 | 2513 | RewriterBase &rewriter, ValueRange newInitOperands,
|
2453 | 2514 | bool replaceInitOperandUsesInLoop,
|
@@ -2546,6 +2607,79 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
|
2546 | 2607 | return nullptr;
|
2547 | 2608 | }
|
2548 | 2609 |
|
| 2610 | +/// Returns the trip count of the loop as an affine expression if the latter is |
| 2611 | +/// expressible as an affine expression, and nullptr otherwise. The trip count |
| 2612 | +/// expression is simplified before returning. This method only utilizes map |
| 2613 | +/// composition to construct lower and upper bounds before computing the trip |
| 2614 | +/// count expressions. |
| 2615 | +void mlir::affine::getTripCountMapAndOperands( |
| 2616 | + AffineForOp forOp, AffineMap *tripCountMap, |
| 2617 | + SmallVectorImpl<Value> *tripCountOperands) { |
| 2618 | + MLIRContext *context = forOp.getContext(); |
| 2619 | + int64_t step = forOp.getStepAsInt(); |
| 2620 | + int64_t loopSpan; |
| 2621 | + if (forOp.hasConstantBounds()) { |
| 2622 | + int64_t lb = forOp.getConstantLowerBound(); |
| 2623 | + int64_t ub = forOp.getConstantUpperBound(); |
| 2624 | + loopSpan = ub - lb; |
| 2625 | + if (loopSpan < 0) |
| 2626 | + loopSpan = 0; |
| 2627 | + *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context); |
| 2628 | + tripCountOperands->clear(); |
| 2629 | + return; |
| 2630 | + } |
| 2631 | + auto lbMap = forOp.getLowerBoundMap(); |
| 2632 | + auto ubMap = forOp.getUpperBoundMap(); |
| 2633 | + if (lbMap.getNumResults() != 1) { |
| 2634 | + *tripCountMap = AffineMap(); |
| 2635 | + return; |
| 2636 | + } |
| 2637 | + |
| 2638 | + // Difference of each upper bound expression from the single lower bound |
| 2639 | + // expression (divided by the step) provides the expressions for the trip |
| 2640 | + // count map. |
| 2641 | + AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands()); |
| 2642 | + |
| 2643 | + SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(), |
| 2644 | + lbMap.getResult(0)); |
| 2645 | + auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), |
| 2646 | + lbSplatExpr, context); |
| 2647 | + AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands()); |
| 2648 | + |
| 2649 | + AffineValueMap tripCountValueMap; |
| 2650 | + AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap); |
| 2651 | + for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i) |
| 2652 | + tripCountValueMap.setResult(i, |
| 2653 | + tripCountValueMap.getResult(i).ceilDiv(step)); |
| 2654 | + |
| 2655 | + *tripCountMap = tripCountValueMap.getAffineMap(); |
| 2656 | + tripCountOperands->assign(tripCountValueMap.getOperands().begin(), |
| 2657 | + tripCountValueMap.getOperands().end()); |
| 2658 | +} |
| 2659 | + |
| 2660 | +std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) { |
| 2661 | + SmallVector<Value, 4> operands; |
| 2662 | + AffineMap map; |
| 2663 | + getTripCountMapAndOperands(forOp, &map, &operands); |
| 2664 | + |
| 2665 | + if (!map) |
| 2666 | + return std::nullopt; |
| 2667 | + |
| 2668 | + // Take the min if all trip counts are constant. |
| 2669 | + std::optional<uint64_t> tripCount; |
| 2670 | + for (auto resultExpr : map.getResults()) { |
| 2671 | + if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) { |
| 2672 | + if (tripCount.has_value()) |
| 2673 | + tripCount = |
| 2674 | + std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue())); |
| 2675 | + else |
| 2676 | + tripCount = constExpr.getValue(); |
| 2677 | + } else |
| 2678 | + return std::nullopt; |
| 2679 | + } |
| 2680 | + return tripCount; |
| 2681 | +} |
| 2682 | + |
2549 | 2683 | /// Extracts the induction variables from a list of AffineForOps and returns
|
2550 | 2684 | /// them.
|
2551 | 2685 | void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
|
@@ -2913,8 +3047,7 @@ static void composeSetAndOperands(IntegerSet &set,
|
2913 | 3047 | }
|
2914 | 3048 |
|
2915 | 3049 | /// Canonicalize an affine if op's conditional (integer set + operands).
|
2916 |
| -LogicalResult AffineIfOp::fold(FoldAdaptor, |
2917 |
| - SmallVectorImpl<OpFoldResult> &) { |
| 3050 | +LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) { |
2918 | 3051 | auto set = getIntegerSet();
|
2919 | 3052 | SmallVector<Value, 4> operands(getOperands());
|
2920 | 3053 | composeSetAndOperands(set, operands);
|
@@ -3005,11 +3138,11 @@ static LogicalResult
|
3005 | 3138 | verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
|
3006 | 3139 | Operation::operand_range mapOperands,
|
3007 | 3140 | MemRefType memrefType, unsigned numIndexOperands) {
|
3008 |
| - AffineMap map = mapAttr.getValue(); |
3009 |
| - if (map.getNumResults() != memrefType.getRank()) |
3010 |
| - return op->emitOpError("affine map num results must equal memref rank"); |
3011 |
| - if (map.getNumInputs() != numIndexOperands) |
3012 |
| - return op->emitOpError("expects as many subscripts as affine map inputs"); |
| 3141 | + AffineMap map = mapAttr.getValue(); |
| 3142 | + if (map.getNumResults() != memrefType.getRank()) |
| 3143 | + return op->emitOpError("affine map num results must equal memref rank"); |
| 3144 | + if (map.getNumInputs() != numIndexOperands) |
| 3145 | + return op->emitOpError("expects as many subscripts as affine map inputs"); |
3013 | 3146 |
|
3014 | 3147 | Region *scope = getAffineScope(op);
|
3015 | 3148 | for (auto idx : mapOperands) {
|
|
0 commit comments