Skip to content

Commit 2de2dbe

Browse files
[mlir][linalg] Replace AffineMinSCFCanonicalizationPattern with SCF reimplementation
Use the new canonicalization pattern in the SCF dialect. Differential Revision: https://reviews.llvm.org/D107732
1 parent 629411d commit 2de2dbe

File tree

11 files changed

+10
-408
lines changed

11 files changed

+10
-408
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -982,69 +982,6 @@ struct LinalgCopyVTWForwardingPattern
982982
PatternRewriter &rewriter) const override;
983983
};
984984

985-
using GetMinMaxExprFn =
986-
std::function<Optional<std::pair<AffineExpr, AffineExpr>>(
987-
Value value, SmallVectorImpl<Value> &dims,
988-
SmallVectorImpl<Value> &symbols)>;
989-
990-
/// Canonicalize AffineMinOp operations in the context of ops with a known range
991-
/// by:
992-
/// 1. building an affine map where uses of the known ops are replaced by
993-
/// their min and max expressions returned by the lambda `getMinMaxFn`.
994-
/// 2. checking whether any of the results of this affine map is known to be
995-
/// less than or equal to all other results (i.e. it is the min).
996-
/// 3. replacing the AffineMinOp by the result of (2).
997-
struct AffineMinRangeCanonicalizationPattern
998-
: public OpRewritePattern<AffineMinOp> {
999-
AffineMinRangeCanonicalizationPattern(MLIRContext *context,
1000-
GetMinMaxExprFn getMinMaxFn)
1001-
: OpRewritePattern<AffineMinOp>(context), getMinMaxFn(getMinMaxFn) {}
1002-
LogicalResult matchAndRewrite(AffineMinOp minOp,
1003-
PatternRewriter &rewriter) const override;
1004-
1005-
protected:
1006-
GetMinMaxExprFn getMinMaxFn;
1007-
};
1008-
1009-
/// Specialized version of `AffineMinRangeCanonicalizationPattern` pattern
1010-
/// using `getSCFMinMaxExpr` to know the min and max expression of induction
1011-
/// variables from scf loops.
1012-
// TODO: move to a more appropriate place when it is determined. For now Linalg
1013-
// depends both on Affine and SCF but they do not depend on each other.
1014-
struct AffineMinSCFCanonicalizationPattern
1015-
: public AffineMinRangeCanonicalizationPattern {
1016-
static Optional<std::pair<AffineExpr, AffineExpr>>
1017-
getMinMax(Value value, SmallVectorImpl<Value> &dims,
1018-
SmallVectorImpl<Value> &symbols) {
1019-
return getSCFMinMaxExpr(value, dims, symbols);
1020-
}
1021-
AffineMinSCFCanonicalizationPattern(MLIRContext *context)
1022-
: AffineMinRangeCanonicalizationPattern(context, getMinMax) {}
1023-
};
1024-
1025-
/// Helper struct to return the results of `substituteMin`.
1026-
struct AffineMapAndOperands {
1027-
AffineMap map;
1028-
SmallVector<Value> dims;
1029-
SmallVector<Value> symbols;
1030-
};
1031-
1032-
/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
1033-
/// dimensions with known range by new expressions involving the min or max
1034-
/// expression:
1035-
/// - If the AffineDimExpr mapped to a known value has a positive sign, it
1036-
/// is replaced by the min expression.
1037-
/// - If the AffineDimExpr mapped to a known value has a negative sign, it is
1038-
/// replaced by the max expression.
1039-
/// All known values are iteratively replaced.
1040-
/// This is used as an intermediate step in computing bounding boxes and
1041-
/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to have
1042-
/// positive values (positive orthant assumptions).
1043-
/// Return a new AffineMap, dims and symbols that have been canonicalized and
1044-
/// simplified.
1045-
AffineMapAndOperands substituteMin(AffineMinOp affineMinOp,
1046-
GetMinMaxExprFn getMinMaxExpr);
1047-
1048985
/// Converts Convolution op into vector contraction.
1049986
///
1050987
/// Conversion expects ConvOp to have dimensions marked in the *mask* as

mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
1515

1616
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
17+
#include "mlir/Dialect/SCF/Transforms.h"
1718
#include "mlir/Dialect/Vector/VectorOps.h"
1819
#include "mlir/Dialect/Vector/VectorTransforms.h"
1920
#include "mlir/Pass/PassManager.h"
@@ -47,7 +48,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
4748

4849
RewritePatternSet stage2Patterns =
4950
linalg::getLinalgTilingCanonicalizationPatterns(context);
50-
stage2Patterns.add<AffineMinSCFCanonicalizationPattern>(context);
51+
scf::populateSCFLoopBodyCanonicalizationPatterns(stage2Patterns);
5152

5253
auto stage3Transforms = [&](Operation *op) {
5354
// Some of these may be too aggressive as a stage 3 that is applied on each

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1717
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1818
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19+
#include "mlir/Dialect/SCF/Transforms.h"
1920
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2021
#include "mlir/IR/AffineExpr.h"
2122
#include "mlir/IR/AffineMap.h"
@@ -536,7 +537,7 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
536537
MLIRContext *ctx = funcOp.getContext();
537538
RewritePatternSet patterns(ctx);
538539
insertTilingPatterns(patterns, options);
539-
patterns.add<AffineMinSCFCanonicalizationPattern>(patterns.getContext());
540+
scf::populateSCFLoopBodyCanonicalizationPatterns(patterns);
540541
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
541542
(void)applyPatternsAndFoldGreedily(
542543
funcOp, getLinalgTilingCanonicalizationPatterns(ctx));

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 0 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -494,145 +494,6 @@ LogicalResult mlir::linalg::applyStagedPatterns(
494494
return success();
495495
}
496496

497-
/// Traverse the `dims` and substitute known min or max expressions returned by
498-
/// the lambda |getMinMaxExpr|.
499-
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
500-
SmallVectorImpl<Value> &symbols,
501-
GetMinMaxExprFn getMinMaxExpr) {
502-
auto exprs = llvm::to_vector<4>(map.getResults());
503-
for (AffineExpr &expr : exprs) {
504-
bool substituted = true;
505-
while (substituted) {
506-
substituted = false;
507-
for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
508-
Value dim = dims[dimIdx];
509-
auto minMax = getMinMaxExpr(dim, dims, symbols);
510-
if (!minMax)
511-
continue;
512-
AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
513-
LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
514-
LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
515-
// Substitute occurrences of `dimExpr` by either the min expression or
516-
// the max expression depending on whether the value is used with a
517-
// positive or negative coefficient.
518-
AffineExpr substitutedExpr =
519-
substWithMin(expr, dimExpr, minMax->first, minMax->second);
520-
LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n");
521-
substituted = (substitutedExpr != expr);
522-
expr = substitutedExpr;
523-
}
524-
}
525-
526-
// Cleanup and simplify the results.
527-
// This needs to happen outside of the loop iterating on dims.size() since
528-
// it modifies dims.
529-
SmallVector<Value, 4> operands(dims.begin(), dims.end());
530-
operands.append(symbols.begin(), symbols.end());
531-
auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
532-
exprs.front().getContext());
533-
534-
LLVM_DEBUG({
535-
DBGS() << "Map to simplify: " << map << "\n";
536-
DBGS() << "Operands:\n";
537-
for (Value v : operands)
538-
DBGS() << v << "\n";
539-
});
540-
541-
// Pull in affine.apply operations and compose them fully into the
542-
// result.
543-
fullyComposeAffineMapAndOperands(&map, &operands);
544-
canonicalizeMapAndOperands(&map, &operands);
545-
map = simplifyAffineMap(map);
546-
// Assign the results.
547-
exprs.assign(map.getResults().begin(), map.getResults().end());
548-
dims.assign(operands.begin(), operands.begin() + map.getNumDims());
549-
symbols.assign(operands.begin() + map.getNumDims(), operands.end());
550-
551-
LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
552-
}
553-
554-
assert(!exprs.empty() && "Unexpected empty exprs");
555-
return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
556-
}
557-
558-
/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
559-
/// dimensions with known range by new expressions involving the min or max
560-
/// expression:
561-
/// - If the AffineDimExpr mapped to a known value has a positive sign, it
562-
/// is replaced by the min expression.
563-
/// - If the AffineDimExpr mapped to a known value has a negative sign, it is
564-
/// replaced by the max expression.
565-
/// All known values are iteratively replaced.
566-
/// This is used as an intermediate step in computing bounding boxes and
567-
/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to have
568-
/// positive values (positive orthant assumptions).
569-
/// Return a new AffineMap, dims and symbols that have been canonicalized and
570-
/// simplified.
571-
AffineMapAndOperands
572-
mlir::linalg::substituteMin(AffineMinOp affineMinOp,
573-
GetMinMaxExprFn getMinMaxExpr) {
574-
AffineMapAndOperands res{affineMinOp.getAffineMap(),
575-
SmallVector<Value>(affineMinOp.getDimOperands()),
576-
SmallVector<Value>(affineMinOp.getSymbolOperands())};
577-
res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
578-
getMinMaxExpr);
579-
return res;
580-
}
581-
582-
LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
583-
AffineMinOp minOp, PatternRewriter &rewriter) const {
584-
LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
585-
<< "\n");
586-
587-
auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn);
588-
AffineMap map = affineMapAndOperands.map;
589-
590-
LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
591-
592-
// Check whether any of the expressions, when subtracted from all other
593-
// expressions, produces only >= 0 constants. If so, it is the min.
594-
for (auto e : minOp.getAffineMap().getResults()) {
595-
LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
596-
if (!e.isSymbolicOrConstant())
597-
continue;
598-
599-
auto isNonPositive = [](AffineExpr e) {
600-
if (auto cst = e.dyn_cast<AffineConstantExpr>())
601-
return cst.getValue() < 0;
602-
return true;
603-
};
604-
605-
// Build the subMap and check everything is statically known to be
606-
// positive.
607-
SmallVector<AffineExpr, 4> subExprs;
608-
subExprs.reserve(map.getNumResults());
609-
for (auto ee : map.getResults())
610-
subExprs.push_back(ee - e);
611-
MLIRContext *ctx = minOp.getContext();
612-
AffineMap subMap = simplifyAffineMap(
613-
AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
614-
LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
615-
if (llvm::any_of(subMap.getResults(), isNonPositive))
616-
continue;
617-
618-
// Static min found.
619-
if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
620-
rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
621-
} else {
622-
auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
623-
SmallVector<Value> resultOperands = affineMapAndOperands.dims;
624-
llvm::append_range(resultOperands, affineMapAndOperands.symbols);
625-
canonicalizeMapAndOperands(&resultMap, &resultOperands);
626-
resultMap = simplifyAffineMap(resultMap);
627-
rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
628-
resultOperands);
629-
}
630-
return success();
631-
}
632-
633-
return failure();
634-
}
635-
636497
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
637498
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
638499
}

mlir/lib/Dialect/SCF/Transforms/Utils.cpp

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -145,43 +145,3 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
145145
}
146146
return rootEnclosesPloops;
147147
}
148-
149-
/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
150-
/// `ubVal` to `dims` and `stepVal` to `symbols`.
151-
/// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
152-
/// with positions matching the newly appended values. Then create a min
153-
/// expression (i.e. `%lb`) and a max expression
154-
/// (i.e. `%lb + %step * floordiv(%ub - 1 - %lb, %step)`).
155-
static std::pair<AffineExpr, AffineExpr>
156-
getMinMaxLoopIndVar(Value lbVal, Value ubVal, Value stepVal,
157-
SmallVectorImpl<Value> &dims,
158-
SmallVectorImpl<Value> &symbols) {
159-
MLIRContext *ctx = lbVal.getContext();
160-
AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
161-
dims.push_back(lbVal);
162-
AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
163-
dims.push_back(ubVal);
164-
AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
165-
symbols.push_back(stepVal);
166-
return std::make_pair(lb, lb + step * ((ub - 1) - lb).floorDiv(step));
167-
}
168-
169-
/// Return the min/max expressions for `value` if it is an induction variable
170-
/// from scf.for or scf.parallel loop.
171-
/// If `loopFilter` is passed, the filter determines which loop to consider.
172-
/// Other induction variables are ignored.
173-
Optional<std::pair<AffineExpr, AffineExpr>> mlir::getSCFMinMaxExpr(
174-
Value value, SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &symbols,
175-
llvm::function_ref<bool(Operation *)> substituteOperation) {
176-
if (auto forOp = scf::getForInductionVarOwner(value))
177-
return getMinMaxLoopIndVar(forOp.lowerBound(), forOp.upperBound(),
178-
forOp.step(), dims, symbols);
179-
180-
if (auto parallelForOp = scf::getParallelForInductionVarOwner(value))
181-
for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; ++idx)
182-
if (parallelForOp.getInductionVars()[idx] == value)
183-
return getMinMaxLoopIndVar(parallelForOp.lowerBound()[idx],
184-
parallelForOp.upperBound()[idx],
185-
parallelForOp.step()[idx], dims, symbols);
186-
return {};
187-
}

0 commit comments

Comments
 (0)