Skip to content

Commit 3110e7b

Browse files
author
Nicolas Vasilache
committed
[mlir] Introduce AffineMinSCF folding as a pattern
This revision adds a folding pattern to replace affine.min ops by the actual min value, when it can be determined statically from the strides and bounds of enclosing scf loops. This matches the type of expressions that Linalg produces during tiling and simplifies boundary checks. For now Linalg depends both on Affine and SCF but they do not depend on each other, so the pattern is added there. In the future this will move to a more appropriate place when it is determined. The canonicalization of AffineMinOp operations in the context of enclosing scf.for and scf.parallel proceeds by: 1. building an affine map where uses of the induction variable of a loop are replaced by `%lb + %step * floordiv(%iv - %lb, %step)` expressions. 2. checking if any of the results of this affine map divides all the other results (in which case it is also guaranteed to be the min). 3. replacing the AffineMinOp by the result of (2). The algorithm is functional in simple parametric tiling cases by using semi-affine maps. However simplifications of such semi-affine maps are not yet available and the canonicalization does not succeed yet. Differential Revision: https://reviews.llvm.org/D82009
1 parent 1bf4629 commit 3110e7b

File tree

6 files changed

+401
-0
lines changed

6 files changed

+401
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,26 @@ struct LinalgCopyVTWForwardingPattern
502502
PatternRewriter &rewriter) const override;
503503
};
504504

505+
/// Canonicalize AffineMinOp operations in the context of enclosing scf.for and
506+
/// scf.parallel by:
507+
/// 1. building an affine map where uses of the induction variable of a loop
508+
/// are replaced by either the min (i.e. `%lb`) or the max
509+
/// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`) expression, depending
510+
/// on whether the induction variable is used with a positive or negative
511+
/// coefficient.
512+
/// 2. checking whether any of the results of this affine map is known to be
513+
/// less than or equal to all other results (i.e. is the min).
514+
/// 3. replacing the AffineMinOp by the result of (2).
515+
// TODO: move to a more appropriate place when it is determined. For now Linalg
516+
// depends both on Affine and SCF but they do not depend on each other.
517+
struct AffineMinSCFCanonicalizationPattern
518+
: public OpRewritePattern<AffineMinOp> {
519+
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
520+
521+
LogicalResult matchAndRewrite(AffineMinOp minOp,
522+
PatternRewriter &rewriter) const override;
523+
};
524+
505525
//===----------------------------------------------------------------------===//
506526
// Support for staged pattern application.
507527
//===----------------------------------------------------------------------===//
@@ -519,6 +539,7 @@ LogicalResult applyStagedPatterns(
519539
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
520540
const OwningRewritePatternList &stage2Patterns,
521541
function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
542+
522543
} // namespace linalg
523544
} // namespace mlir
524545

mlir/include/mlir/IR/AffineExpr.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,20 @@ class AffineExpr {
115115

116116
/// This method substitutes any uses of dimensions and symbols (e.g.
117117
/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
118+
/// This is a dense replacement method: a replacement must be specified for
119+
/// every single dim and symbol.
118120
AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
119121
ArrayRef<AffineExpr> symReplacements) const;
120122

123+
/// Sparse replace method. Replace `expr` by `replacement` and return the
124+
/// modified expression tree.
125+
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
126+
127+
/// Sparse replace method. If `*this` appears in `map`, replace it by
128+
/// `map[*this]` and return the modified expression tree. Otherwise traverse
129+
/// `*this` and apply replace with `map` on its subexpressions.
130+
AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
131+
121132
/// Replace symbols[0 .. numDims - 1] by
122133
/// symbols[shift .. shift + numDims - 1].
123134
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ using namespace mlir::edsc::intrinsics;
3636
using namespace mlir::linalg;
3737

3838
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
39+
3940
//===----------------------------------------------------------------------===//
4041
// Transformations exposed as rewrite patterns.
4142
//===----------------------------------------------------------------------===//
@@ -235,3 +236,177 @@ LogicalResult mlir::linalg::applyStagedPatterns(
235236
}
236237
return success();
237238
}
239+
240+
/// Traverse `e` and return an AffineExpr where all occurrences of `dim` have
241+
/// been replaced by either:
242+
/// - `min` if `positivePath` is true when we reach an occurrence of `dim`
243+
/// - `max` if `positivePath` is true when we reach an occurrence of `dim`
244+
/// `positivePath` is negated each time we hit a multiplicative or divisive
245+
/// binary op with a constant negative coefficient.
246+
static AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
247+
AffineExpr max, bool positivePath = true) {
248+
if (e == dim)
249+
return positivePath ? min : max;
250+
if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
251+
AffineExpr lhs = bin.getLHS();
252+
AffineExpr rhs = bin.getRHS();
253+
if (bin.getKind() == mlir::AffineExprKind::Add)
254+
return substWithMin(lhs, dim, min, max, positivePath) +
255+
substWithMin(rhs, dim, min, max, positivePath);
256+
257+
auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
258+
auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
259+
if (c1 && c1.getValue() < 0)
260+
return getAffineBinaryOpExpr(
261+
bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
262+
if (c2 && c2.getValue() < 0)
263+
return getAffineBinaryOpExpr(
264+
bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
265+
return getAffineBinaryOpExpr(
266+
bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
267+
substWithMin(rhs, dim, min, max, positivePath));
268+
}
269+
return e;
270+
}
271+
272+
/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
273+
/// `ubVal` to `dims` and `stepVal` to `symbols`.
274+
/// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
275+
/// with positions matching the newly appended values. Substitute occurrences of
276+
/// `dimExpr` by either the min expression (i.e. `%lb`) or the max expression
277+
/// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`), depending on whether
278+
/// the induction variable is used with a positive or negative coefficient.
279+
static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
280+
Value lbVal, Value ubVal, Value stepVal,
281+
SmallVectorImpl<Value> &dims,
282+
SmallVectorImpl<Value> &symbols) {
283+
MLIRContext *ctx = lbVal.getContext();
284+
AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
285+
dims.push_back(lbVal);
286+
AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
287+
dims.push_back(ubVal);
288+
AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
289+
symbols.push_back(stepVal);
290+
LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
291+
AffineExpr ee = substWithMin(expr, dimExpr, lb,
292+
lb + step * ((ub - 1) - lb).floorDiv(step));
293+
LLVM_DEBUG(DBGS() << "After: " << expr << "\n");
294+
return ee;
295+
}
296+
297+
/// Traverse the `dims` and substitute known min or max expressions in place of
298+
/// induction variables in `exprs`.
299+
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
300+
SmallVectorImpl<Value> &symbols) {
301+
auto exprs = llvm::to_vector<4>(map.getResults());
302+
for (AffineExpr &expr : exprs) {
303+
bool substituted = true;
304+
while (substituted) {
305+
substituted = false;
306+
for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
307+
Value dim = dims[dimIdx];
308+
AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
309+
LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
310+
AffineExpr substitutedExpr;
311+
if (auto forOp = scf::getForInductionVarOwner(dim))
312+
substitutedExpr = substituteLoopInExpr(
313+
expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
314+
forOp.step(), dims, symbols);
315+
316+
if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
317+
for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
318+
++idx)
319+
substitutedExpr = substituteLoopInExpr(
320+
expr, dimExpr, parallelForOp.lowerBound()[idx],
321+
parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
322+
dims, symbols);
323+
324+
if (!substitutedExpr)
325+
continue;
326+
327+
substituted = (substitutedExpr != expr);
328+
expr = substitutedExpr;
329+
}
330+
}
331+
332+
// Cleanup and simplify the results.
333+
// This needs to happen outside of the loop iterating on dims.size() since
334+
// it modifies dims.
335+
SmallVector<Value, 4> operands(dims.begin(), dims.end());
336+
operands.append(symbols.begin(), symbols.end());
337+
auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
338+
exprs.front().getContext());
339+
340+
LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");
341+
342+
// Pull in affine.apply operations and compose them fully into the
343+
// result.
344+
fullyComposeAffineMapAndOperands(&map, &operands);
345+
canonicalizeMapAndOperands(&map, &operands);
346+
map = simplifyAffineMap(map);
347+
// Assign the results.
348+
exprs.assign(map.getResults().begin(), map.getResults().end());
349+
dims.assign(operands.begin(), operands.begin() + map.getNumDims());
350+
symbols.assign(operands.begin() + map.getNumDims(), operands.end());
351+
352+
LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
353+
}
354+
355+
assert(!exprs.empty() && "Unexpected empty exprs");
356+
return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
357+
}
358+
359+
/// Try to replace `minOp` by one of its own results when that result can be
/// shown to be the min:
///   1. build the substituted map, where induction variables of enclosing scf
///      loops are replaced by their min / max expressions;
///   2. for each symbolic-or-constant result `e` of the original map, subtract
///      it from every result of the substituted map and require that all
///      differences simplify to constants >= 0;
///   3. replace `minOp` by a constant (constant `e`) or by an affine.apply of
///      `e` on the original operands.
/// Returns failure when no such result is found.
LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  SmallVector<Value, 4> dims(minOp.getDimOperands()),
      symbols(minOp.getSymbolOperands());
  AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  MLIRContext *ctx = minOp.getContext();

  // A difference disqualifies the candidate when it is a negative constant or
  // when it is not a constant at all (its sign is then unknown).
  auto isNegativeOrUnknown = [](AffineExpr diff) {
    if (auto cst = diff.dyn_cast<AffineConstantExpr>())
      return cst.getValue() < 0;
    return true;
  };

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    // Build the subMap and check everything is statically known to be >= 0.
    // NOTE(review): `e` uses the symbol numbering of the original map while
    // `map` uses the numbering produced by `substitute`; for non-constant
    // candidates these are assumed to agree -- TODO confirm.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNegativeOrUnknown))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      // `e` comes from the original map, so materialize it against the
      // original dim/symbol counts and operands. This keeps the number of
      // map inputs equal to the number of operands and keeps the symbol
      // positions of `e` bound to the values they were written against.
      AffineMap origMap = minOp.getAffineMap();
      auto resultMap = AffineMap::get(origMap.getNumDims(),
                                      origMap.getNumSymbols(), {e}, ctx);
      SmallVector<Value, 4> resultOperands(minOp.getDimOperands()),
          origSymbols(minOp.getSymbolOperands());
      resultOperands.append(origSymbols.begin(), origSymbols.end());
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}

mlir/lib/IR/AffineExpr.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,37 @@ AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const {
101101
return replaceDimsAndSymbols({}, symbols);
102102
}
103103

104+
/// Sparse replace method. Return the modified expression tree.
105+
AffineExpr
106+
AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
107+
auto it = map.find(*this);
108+
if (it != map.end())
109+
return it->second;
110+
switch (getKind()) {
111+
default:
112+
return *this;
113+
case AffineExprKind::Add:
114+
case AffineExprKind::Mul:
115+
case AffineExprKind::FloorDiv:
116+
case AffineExprKind::CeilDiv:
117+
case AffineExprKind::Mod:
118+
auto binOp = cast<AffineBinaryOpExpr>();
119+
auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
120+
auto newLHS = lhs.replace(map);
121+
auto newRHS = rhs.replace(map);
122+
if (newLHS == lhs && newRHS == rhs)
123+
return *this;
124+
return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
125+
}
126+
llvm_unreachable("Unknown AffineExpr");
127+
}
128+
129+
/// Sparse replace method. Return the modified expression tree.
130+
AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
131+
DenseMap<AffineExpr, AffineExpr> map;
132+
map.insert(std::make_pair(expr, replacement));
133+
return replace(map);
134+
}
104135
/// Returns true if this expression is made out of only symbols and
105136
/// constants (no dimensional identifiers).
106137
bool AffineExpr::isSymbolicOrConstant() const {

0 commit comments

Comments
 (0)