@@ -349,19 +349,26 @@ unsigned AffineDimExpr::getPosition() const {
349
349
return static_cast <ImplType *>(expr)->position ;
350
350
}
351
351
352
+ namespace {
353
+
352
354
// / Returns true if the expression is divisible by the given symbol with
353
355
// / position `symbolPos`. The argument `opKind` specifies here what kind of
354
356
// / division or mod operation called this division. It helps in implementing the
355
357
// / commutative property of the floordiv and ceildiv operations. If the argument
356
358
// /`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
357
359
// / operation, then the commutative property can be used otherwise, the floordiv
358
360
// / operation is not divisible. The same argument holds for ceildiv operation.
359
- static bool isDivisibleBySymbol (AffineExpr expr, unsigned symbolPos,
360
- AffineExprKind opKind) {
361
+ bool isDivisibleBySymbolImpl (AffineExpr expr, unsigned symbolPos,
362
+ AffineExprKind opKind,
363
+ SmallVectorImpl<AffineExpr> &visitedExprs,
364
+ size_t depth = 0 ) {
361
365
// The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
362
366
assert ((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
363
367
opKind == AffineExprKind::CeilDiv) &&
364
368
" unexpected opKind" );
369
+ if (visitedExprs.size () > depth)
370
+ visitedExprs.resize (depth);
371
+ visitedExprs.emplace_back (expr);
365
372
switch (expr.getKind ()) {
366
373
case AffineExprKind::Constant:
367
374
return cast<AffineConstantExpr>(expr).getValue () == 0 ;
@@ -372,8 +379,10 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
372
379
// Checks divisibility by the given symbol for both operands.
373
380
case AffineExprKind::Add: {
374
381
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
375
- return isDivisibleBySymbol (binaryExpr.getLHS (), symbolPos, opKind) &&
376
- isDivisibleBySymbol (binaryExpr.getRHS (), symbolPos, opKind);
382
+ return isDivisibleBySymbolImpl (binaryExpr.getLHS (), symbolPos, opKind,
383
+ visitedExprs, depth + 1 ) &&
384
+ isDivisibleBySymbolImpl (binaryExpr.getRHS (), symbolPos, opKind,
385
+ visitedExprs, depth + 1 );
377
386
}
378
387
// Checks divisibility by the given symbol for both operands. Consider the
379
388
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -382,16 +391,20 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
382
391
// `AffineExprKind::Mod` for this reason.
383
392
case AffineExprKind::Mod: {
384
393
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
385
- return isDivisibleBySymbol (binaryExpr.getLHS (), symbolPos,
386
- AffineExprKind::Mod) &&
387
- isDivisibleBySymbol (binaryExpr.getRHS (), symbolPos,
388
- AffineExprKind::Mod);
394
+ return isDivisibleBySymbolImpl (binaryExpr.getLHS (), symbolPos,
395
+ AffineExprKind::Mod, visitedExprs,
396
+ depth + 1 ) &&
397
+ isDivisibleBySymbolImpl (binaryExpr.getRHS (), symbolPos,
398
+ AffineExprKind::Mod, visitedExprs,
399
+ depth + 1 );
389
400
}
390
401
// Checks if any of the operand divisible by the given symbol.
391
402
case AffineExprKind::Mul: {
392
403
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
393
- return isDivisibleBySymbol (binaryExpr.getLHS (), symbolPos, opKind) ||
394
- isDivisibleBySymbol (binaryExpr.getRHS (), symbolPos, opKind);
404
+ return isDivisibleBySymbolImpl (binaryExpr.getLHS (), symbolPos, opKind,
405
+ visitedExprs, depth + 1 ) ||
406
+ isDivisibleBySymbolImpl (binaryExpr.getRHS (), symbolPos, opKind,
407
+ visitedExprs, depth + 1 );
395
408
}
396
409
// Floordiv and ceildiv are divisible by the given symbol when the first
397
410
// operand is divisible, and the affine expression kind of the argument expr
@@ -406,12 +419,25 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
406
419
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
407
420
if (opKind != expr.getKind ())
408
421
return false ;
409
- return isDivisibleBySymbol (binaryExpr.getLHS (), symbolPos, expr.getKind ());
422
+ if (llvm::any_of (visitedExprs, [](auto expr) {
423
+ return expr.getKind () == AffineExprKind::Mul;
424
+ }))
425
+ return false ;
426
+ return isDivisibleBySymbolImpl (binaryExpr.getLHS (), symbolPos,
427
+ expr.getKind (), visitedExprs, depth + 1 );
410
428
}
411
429
}
412
430
llvm_unreachable (" Unknown AffineExpr" );
413
431
}
414
432
433
+ bool isDivisibleBySymbol (AffineExpr expr, unsigned symbolPos,
434
+ AffineExprKind opKind) {
435
+ SmallVector<AffineExpr> visitedExprs;
436
+ return isDivisibleBySymbolImpl (expr, symbolPos, opKind, visitedExprs);
437
+ }
438
+
439
+ } // namespace
440
+
415
441
// / Divides the given expression by the given symbol at position `symbolPos`. It
416
442
// / considers the divisibility condition is checked before calling itself. A
417
443
// / null expression is returned whenever the divisibility condition fails.
0 commit comments