Skip to content

Commit 51a2f50

Browse files
authored
[mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression not satisfying commutative (#111254)
my prove: we can simple `(n * s) ceildiv a ceildiv s` to `n ceildiv a` because `(n * s) ceildiv a ceildiv b` <=> `(n * s) ceildiv s ceildiv a` <=> `n ceildiv a` let's prove the `s floordiv a floor b` <=> `s floordiv b floor a` let `s = ka +m (m < a)` so `s floordiv a` <=> `s / a - m / a` similarly, it can be proven that: `s floordiv a floordiv b` <=> `s / (a * b) - m / (a * b) - n / (b) constrain (n < b)` <=> `s / (a * b) - (m + a*n) / (a*b)` because `a* b - (m + a*n)` <=> `a*b - a*n - m` > `a - m` > `0` so `s floordiv a floordiv b` <=> `[s / (a*b)]` <=> `s floordiv b floordiv a` but if `s floordiv b` mutiply a factor above didn't always hold true. Fixes #107508
1 parent c9a1cff commit 51a2f50

File tree

2 files changed

+44
-18
lines changed

2 files changed

+44
-18
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,9 @@ unsigned AffineDimExpr::getPosition() const {
356356
///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
357357
/// operation, then the commutative property can be used otherwise, the floordiv
358358
/// operation is not divisible. The same argument holds for ceildiv operation.
359-
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
360-
AffineExprKind opKind) {
359+
static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
360+
AffineExprKind opKind,
361+
bool fromMul = false) {
361362
// The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
362363
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
363364
opKind == AffineExprKind::CeilDiv) &&
@@ -372,8 +373,9 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
372373
// Checks divisibility by the given symbol for both operands.
373374
case AffineExprKind::Add: {
374375
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
375-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
376-
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
376+
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
377+
opKind) &&
378+
canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
377379
}
378380
// Checks divisibility by the given symbol for both operands. Consider the
379381
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -382,31 +384,38 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
382384
// `AffineExprKind::Mod` for this reason.
383385
case AffineExprKind::Mod: {
384386
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
385-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
386-
AffineExprKind::Mod) &&
387-
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
388-
AffineExprKind::Mod);
387+
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
388+
AffineExprKind::Mod) &&
389+
canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos,
390+
AffineExprKind::Mod);
389391
}
390392
// Checks if any of the operand divisible by the given symbol.
391393
case AffineExprKind::Mul: {
392394
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
393-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
394-
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
395+
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind,
396+
true) ||
397+
canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind,
398+
true);
395399
}
396400
// Floordiv and ceildiv are divisible by the given symbol when the first
397401
// operand is divisible, and the affine expression kind of the argument expr
398402
// is same as the argument `opKind`. This can be inferred from commutative
399403
// property of floordiv and ceildiv operations and are as follow:
400404
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
401405
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
402-
// It will fail if operations are not same. For example:
403-
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
406+
// It will fail 1.if operations are not same. For example:
407+
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
408+
// multiplication operation in the expression. For example:
409+
// (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
404410
case AffineExprKind::FloorDiv:
405411
case AffineExprKind::CeilDiv: {
406412
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
407413
if (opKind != expr.getKind())
408414
return false;
409-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
415+
if (fromMul)
416+
return false;
417+
return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
418+
expr.getKind());
410419
}
411420
}
412421
llvm_unreachable("Unknown AffineExpr");
@@ -448,7 +457,7 @@ static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
448457
// Dividing any of the operand by the given symbol.
449458
case AffineExprKind::Mul: {
450459
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
451-
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
460+
if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
452461
return binaryExpr.getLHS() *
453462
symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
454463
return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
@@ -583,7 +592,8 @@ static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
583592
if (!symbolExpr)
584593
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
585594
unsigned symbolPos = symbolExpr.getPosition();
586-
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
595+
if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
596+
expr.getKind()))
587597
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
588598
if (expr.getKind() == AffineExprKind::Mod)
589599
return getAffineConstantExpr(0, expr.getContext());

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
308308
}
309309

310310
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
311-
// CHECK-LABEL: func @semiaffine_composite_floor
312-
func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
311+
// CHECK-LABEL: func @semiaffine_composite_ceildiv
312+
func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
313+
%a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
314+
// CHECK: %[[CST:.*]] = arith.constant 43
315+
return %a : index
316+
}
317+
318+
// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation.
319+
// CHECK-LABEL: func @semiaffine_composite_ceildiv
320+
func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
313321
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
314-
// CHECK: %[[CST:.*]] = arith.constant 47
322+
// CHECK-NOT: arith.constant
323+
return %a : index
324+
}
325+
326+
// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation
327+
// CHECK-LABEL: func @semiaffine_composite_floordiv
328+
func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
329+
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
330+
// CHECK-NOT: arith.constant
315331
return %a : index
316332
}
317333

0 commit comments

Comments
 (0)