Skip to content

Commit a1dcc06

Browse files
committed
[mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression not satisfying commutative
Fixes #107508
1 parent 29d0a84 commit a1dcc06

File tree

2 files changed

+56
-14
lines changed

2 files changed

+56
-14
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -349,19 +349,26 @@ unsigned AffineDimExpr::getPosition() const {
349349
return static_cast<ImplType *>(expr)->position;
350350
}
351351

352+
namespace {
353+
352354
/// Returns true if the expression is divisible by the given symbol with
353355
/// position `symbolPos`. The argument `opKind` specifies here what kind of
354356
/// division or mod operation called this division. It helps in implementing the
355357
/// commutative property of the floordiv and ceildiv operations. If the argument
356358
///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
357359
/// operation, then the commutative property can be used otherwise, the floordiv
358360
/// operation is not divisible. The same argument holds for ceildiv operation.
359-
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
360-
AffineExprKind opKind) {
361+
bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
362+
AffineExprKind opKind,
363+
SmallVectorImpl<AffineExpr> &visitedExprs,
364+
size_t depth = 0) {
361365
// The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
362366
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
363367
opKind == AffineExprKind::CeilDiv) &&
364368
"unexpected opKind");
369+
if (visitedExprs.size() > depth)
370+
visitedExprs.resize(depth);
371+
visitedExprs.emplace_back(expr);
365372
switch (expr.getKind()) {
366373
case AffineExprKind::Constant:
367374
return cast<AffineConstantExpr>(expr).getValue() == 0;
@@ -372,8 +379,10 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
372379
// Checks divisibility by the given symbol for both operands.
373380
case AffineExprKind::Add: {
374381
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
375-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
376-
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
382+
return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
383+
visitedExprs, depth + 1) &&
384+
isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
385+
visitedExprs, depth + 1);
377386
}
378387
// Checks divisibility by the given symbol for both operands. Consider the
379388
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -382,16 +391,20 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
382391
// `AffineExprKind::Mod` for this reason.
383392
case AffineExprKind::Mod: {
384393
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
385-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
386-
AffineExprKind::Mod) &&
387-
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
388-
AffineExprKind::Mod);
394+
return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
395+
AffineExprKind::Mod, visitedExprs,
396+
depth + 1) &&
397+
isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos,
398+
AffineExprKind::Mod, visitedExprs,
399+
depth + 1);
389400
}
390401
// Checks if any of the operand divisible by the given symbol.
391402
case AffineExprKind::Mul: {
392403
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
393-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
394-
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
404+
return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
405+
visitedExprs, depth + 1) ||
406+
isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
407+
visitedExprs, depth + 1);
395408
}
396409
// Floordiv and ceildiv are divisible by the given symbol when the first
397410
// operand is divisible, and the affine expression kind of the argument expr
@@ -406,12 +419,25 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
406419
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
407420
if (opKind != expr.getKind())
408421
return false;
409-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
422+
if (llvm::any_of(visitedExprs, [](auto expr) {
423+
return expr.getKind() == AffineExprKind::Mul;
424+
}))
425+
return false;
426+
return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
427+
expr.getKind(), visitedExprs, depth + 1);
410428
}
411429
}
412430
llvm_unreachable("Unknown AffineExpr");
413431
}
414432

433+
bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
434+
AffineExprKind opKind) {
435+
SmallVector<AffineExpr> visitedExprs;
436+
return isDivisibleBySymbolImpl(expr, symbolPos, opKind, visitedExprs);
437+
}
438+
439+
} // namespace
440+
415441
/// Divides the given expression by the given symbol at position `symbolPos`. It
416442
/// considers the divisibility condition is checked before calling itself. A
417443
/// null expression is returned whenever the divisibility condition fails.

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
308308
}
309309

310310
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
311-
// CHECK-LABEL: func @semiaffine_composite_floor
312-
func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
311+
// CHECK-LABEL: func @semiaffine_composite_ceildiv
312+
func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
313+
%a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
314+
// CHECK: %[[CST:.*]] = arith.constant 43
315+
return %a : index
316+
}
317+
318+
// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation.
319+
// CHECK-LABEL: func @semiaffine_composite_ceildiv
320+
func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
313321
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
314-
// CHECK: %[[CST:.*]] = arith.constant 47
322+
// CHECK-NOT: arith.constant
323+
return %a : index
324+
}
325+
326+
// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation
327+
// CHECK-LABEL: func @semiaffine_composite_floordiv
328+
func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
329+
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
330+
// CHECK-NOT: arith.constant
315331
return %a : index
316332
}
317333

0 commit comments

Comments
 (0)