Skip to content

[mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression n… #111254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions mlir/lib/IR/AffineExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,9 @@ unsigned AffineDimExpr::getPosition() const {
///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
/// operation, then the commutative property can be used otherwise, the floordiv
/// operation is not divisible. The same argument holds for ceildiv operation.
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
AffineExprKind opKind) {
static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
AffineExprKind opKind,
bool fromMul = false) {
// The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
Expand All @@ -372,8 +373,9 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
opKind) &&
canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
}
// Checks divisibility by the given symbol for both operands. Consider the
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
Expand All @@ -382,31 +384,38 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
AffineExprKind::Mod) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
AffineExprKind::Mod);
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
AffineExprKind::Mod) &&
canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos,
AffineExprKind::Mod);
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind,
true) ||
canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind,
true);
}
// Floordiv and ceildiv are divisible by the given symbol when the first
// operand is divisible, and the affine expression kind of the argument expr
// is same as the argument `opKind`. This can be inferred from commutative
// property of floordiv and ceildiv operations and are as follow:
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
// It will fail if operations are not same. For example:
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
// It will fail 1.if operations are not same. For example:
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
// multiplication operation in the expression. For example:
// (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind())
return false;
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
if (fromMul)
return false;
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
expr.getKind());
}
}
llvm_unreachable("Unknown AffineExpr");
Expand Down Expand Up @@ -448,7 +457,7 @@ static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
// Dividing any of the operand by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
return binaryExpr.getLHS() *
symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
Expand Down Expand Up @@ -583,7 +592,8 @@ static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
if (!symbolExpr)
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
unsigned symbolPos = symbolExpr.getPosition();
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
expr.getKind()))
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
if (expr.getKind() == AffineExprKind::Mod)
return getAffineConstantExpr(0, expr.getContext());
Expand Down
22 changes: 19 additions & 3 deletions mlir/test/Dialect/Affine/simplify-structures.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
}

// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
// CHECK-LABEL: func @semiaffine_composite_floor
func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
// CHECK-LABEL: func @semiaffine_composite_ceildiv
func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
// CHECK: %[[CST:.*]] = arith.constant 43
return %a : index
}

// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation.
// CHECK-LABEL: func @semiaffine_composite_ceildiv
func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
// CHECK: %[[CST:.*]] = arith.constant 47
// CHECK-NOT: arith.constant
return %a : index
}

// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation
// CHECK-LABEL: func @semiaffine_composite_floordiv
func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
// CHECK-NOT: arith.constant
return %a : index
}

Expand Down
Loading