Skip to content

Commit 7482a84

Browse files
committed
refine
1 parent a1dcc06 commit 7482a84

File tree

1 file changed

+24
-40
lines changed

1 file changed

+24
-40
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 24 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -349,26 +349,20 @@ unsigned AffineDimExpr::getPosition() const {
349349
return static_cast<ImplType *>(expr)->position;
350350
}
351351

352-
namespace {
353-
354352
/// Returns true if the expression is divisible by the given symbol with
355353
/// position `symbolPos`. The argument `opKind` specifies here what kind of
356354
/// division or mod operation called this division. It helps in implementing the
357355
/// commutative property of the floordiv and ceildiv operations. If the argument
358356
///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
359357
/// operation, then the commutative property can be used otherwise, the floordiv
360358
/// operation is not divisible. The same argument holds for ceildiv operation.
361-
bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
362-
AffineExprKind opKind,
363-
SmallVectorImpl<AffineExpr> &visitedExprs,
364-
size_t depth = 0) {
359+
static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
360+
AffineExprKind opKind,
361+
bool fromMul = false) {
365362
// The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
366363
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
367364
opKind == AffineExprKind::CeilDiv) &&
368365
"unexpected opKind");
369-
if (visitedExprs.size() > depth)
370-
visitedExprs.resize(depth);
371-
visitedExprs.emplace_back(expr);
372366
switch (expr.getKind()) {
373367
case AffineExprKind::Constant:
374368
return cast<AffineConstantExpr>(expr).getValue() == 0;
@@ -379,10 +373,9 @@ bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
379373
// Checks divisibility by the given symbol for both operands.
380374
case AffineExprKind::Add: {
381375
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
382-
return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
383-
visitedExprs, depth + 1) &&
384-
isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
385-
visitedExprs, depth + 1);
376+
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
377+
opKind) &&
378+
canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
386379
}
387380
// Checks divisibility by the given symbol for both operands. Consider the
388381
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -391,53 +384,43 @@ bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
391384
// `AffineExprKind::Mod` for this reason.
392385
case AffineExprKind::Mod: {
393386
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
394-
return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
395-
AffineExprKind::Mod, visitedExprs,
396-
depth + 1) &&
397-
isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos,
398-
AffineExprKind::Mod, visitedExprs,
399-
depth + 1);
387+
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
388+
AffineExprKind::Mod) &&
389+
canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos,
390+
AffineExprKind::Mod);
400391
}
401392
// Checks if any of the operand divisible by the given symbol.
402393
case AffineExprKind::Mul: {
403394
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
404-
return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
405-
visitedExprs, depth + 1) ||
406-
isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
407-
visitedExprs, depth + 1);
395+
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind,
396+
true) ||
397+
canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind,
398+
true);
408399
}
409400
// Floordiv and ceildiv are divisible by the given symbol when the first
410401
// operand is divisible, and the affine expression kind of the argument expr
411402
// is same as the argument `opKind`. This can be inferred from commutative
412403
// property of floordiv and ceildiv operations and are as follow:
413404
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
414405
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
415-
// It will fail if operations are not same. For example:
416-
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
406+
// It will fail 1.if operations are not same. For example:
407+
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
408+
// multiplication operation in the expression. For example:
409+
// (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
417410
case AffineExprKind::FloorDiv:
418411
case AffineExprKind::CeilDiv: {
419412
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
420413
if (opKind != expr.getKind())
421414
return false;
422-
if (llvm::any_of(visitedExprs, [](auto expr) {
423-
return expr.getKind() == AffineExprKind::Mul;
424-
}))
415+
if (fromMul)
425416
return false;
426-
return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
427-
expr.getKind(), visitedExprs, depth + 1);
417+
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
418+
expr.getKind());
428419
}
429420
}
430421
llvm_unreachable("Unknown AffineExpr");
431422
}
432423

433-
bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
434-
AffineExprKind opKind) {
435-
SmallVector<AffineExpr> visitedExprs;
436-
return isDivisibleBySymbolImpl(expr, symbolPos, opKind, visitedExprs);
437-
}
438-
439-
} // namespace
440-
441424
/// Divides the given expression by the given symbol at position `symbolPos`. It
442425
/// considers the divisibility condition is checked before calling itself. A
443426
/// null expression is returned whenever the divisibility condition fails.
@@ -474,7 +457,7 @@ static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
474457
// Dividing any of the operand by the given symbol.
475458
case AffineExprKind::Mul: {
476459
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
477-
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
460+
if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
478461
return binaryExpr.getLHS() *
479462
symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
480463
return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
@@ -609,7 +592,8 @@ static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
609592
if (!symbolExpr)
610593
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
611594
unsigned symbolPos = symbolExpr.getPosition();
612-
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
595+
if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
596+
expr.getKind()))
613597
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
614598
if (expr.getKind() == AffineExprKind::Mod)
615599
return getAffineConstantExpr(0, expr.getContext());

0 commit comments

Comments
 (0)