@@ -700,93 +700,6 @@ static std::optional<int64_t> getUpperBound(Value iv) {
700
700
return forOp.getConstantUpperBound () - 1 ;
701
701
}
702
702
703
- // / Get a lower or upper (depending on `isUpper`) bound for `expr` while using
704
- // / the constant lower and upper bounds for its inputs provided in
705
- // / `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
706
- // / bound can't be computed. This method only handles simple sum of product
707
- // / expressions (w.r.t constant coefficients) so as to not depend on anything
708
- // / heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
709
- // / ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
710
- // / semi-affine ones will lead std::nullopt being returned.
711
- static std::optional<int64_t >
712
- getBoundForExpr (AffineExpr expr, unsigned numDims, unsigned numSymbols,
713
- ArrayRef<std::optional<int64_t >> constLowerBounds,
714
- ArrayRef<std::optional<int64_t >> constUpperBounds,
715
- bool isUpper) {
716
- // Handle divs and mods.
717
- if (auto binOpExpr = expr.dyn_cast <AffineBinaryOpExpr>()) {
718
- // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
719
- // can compute an upper bound.
720
- if (binOpExpr.getKind () == AffineExprKind::FloorDiv) {
721
- auto rhsConst = binOpExpr.getRHS ().dyn_cast <AffineConstantExpr>();
722
- if (!rhsConst || rhsConst.getValue () < 1 )
723
- return std::nullopt;
724
- auto bound = getBoundForExpr (binOpExpr.getLHS (), numDims, numSymbols,
725
- constLowerBounds, constUpperBounds, isUpper);
726
- if (!bound)
727
- return std::nullopt;
728
- return mlir::floorDiv (*bound, rhsConst.getValue ());
729
- }
730
- if (binOpExpr.getKind () == AffineExprKind::CeilDiv) {
731
- auto rhsConst = binOpExpr.getRHS ().dyn_cast <AffineConstantExpr>();
732
- if (rhsConst && rhsConst.getValue () >= 1 ) {
733
- auto bound =
734
- getBoundForExpr (binOpExpr.getLHS (), numDims, numSymbols,
735
- constLowerBounds, constUpperBounds, isUpper);
736
- if (!bound)
737
- return std::nullopt;
738
- return mlir::ceilDiv (*bound, rhsConst.getValue ());
739
- }
740
- return std::nullopt;
741
- }
742
- if (binOpExpr.getKind () == AffineExprKind::Mod) {
743
- // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
744
- // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
745
- // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
746
- auto rhsConst = binOpExpr.getRHS ().dyn_cast <AffineConstantExpr>();
747
- if (rhsConst && rhsConst.getValue () >= 1 ) {
748
- int64_t rhsConstVal = rhsConst.getValue ();
749
- auto lb = getBoundForExpr (binOpExpr.getLHS (), numDims, numSymbols,
750
- constLowerBounds, constUpperBounds,
751
- /* isUpper=*/ false );
752
- auto ub = getBoundForExpr (binOpExpr.getLHS (), numDims, numSymbols,
753
- constLowerBounds, constUpperBounds, isUpper);
754
- if (ub && lb &&
755
- floorDiv (*lb, rhsConstVal) == floorDiv (*ub, rhsConstVal))
756
- return isUpper ? mod (*ub, rhsConstVal) : mod (*lb, rhsConstVal);
757
- return isUpper ? rhsConstVal - 1 : 0 ;
758
- }
759
- }
760
- }
761
- // Flatten the expression.
762
- SimpleAffineExprFlattener flattener (numDims, numSymbols);
763
- flattener.walkPostOrder (expr);
764
- ArrayRef<int64_t > flattenedExpr = flattener.operandExprStack .back ();
765
- // TODO: Handle local variables. We can get hold of flattener.localExprs and
766
- // get bound on the local expr recursively.
767
- if (flattener.numLocals > 0 )
768
- return std::nullopt;
769
- int64_t bound = 0 ;
770
- // Substitute the constant lower or upper bound for the dimensional or
771
- // symbolic input depending on `isUpper` to determine the bound.
772
- for (unsigned i = 0 , e = numDims + numSymbols; i < e; ++i) {
773
- if (flattenedExpr[i] > 0 ) {
774
- auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
775
- if (!constBound)
776
- return std::nullopt;
777
- bound += *constBound * flattenedExpr[i];
778
- } else if (flattenedExpr[i] < 0 ) {
779
- auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
780
- if (!constBound)
781
- return std::nullopt;
782
- bound += *constBound * flattenedExpr[i];
783
- }
784
- }
785
- // Constant term.
786
- bound += flattenedExpr.back ();
787
- return bound;
788
- }
789
-
790
703
// / Determine a constant upper bound for `expr` if one exists while exploiting
791
704
// / values in `operands`. Note that the upper bound is an inclusive one. `expr`
792
705
// / is guaranteed to be less than or equal to it.
@@ -805,9 +718,9 @@ static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
805
718
if (auto constExpr = expr.dyn_cast <AffineConstantExpr>())
806
719
return constExpr.getValue ();
807
720
808
- return getBoundForExpr (expr, numDims, numSymbols, constLowerBounds,
809
- constUpperBounds,
810
- /* isUpper=*/ true );
721
+ return getBoundForAffineExpr (expr, numDims, numSymbols, constLowerBounds,
722
+ constUpperBounds,
723
+ /* isUpper=*/ true );
811
724
}
812
725
813
726
// / Determine a constant lower bound for `expr` if one exists while exploiting
@@ -829,9 +742,9 @@ static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
829
742
if (auto constExpr = expr.dyn_cast <AffineConstantExpr>()) {
830
743
lowerBound = constExpr.getValue ();
831
744
} else {
832
- lowerBound = getBoundForExpr (expr, numDims, numSymbols, constLowerBounds ,
833
- constUpperBounds,
834
- /* isUpper=*/ false );
745
+ lowerBound = getBoundForAffineExpr (expr, numDims, numSymbols,
746
+ constLowerBounds, constUpperBounds,
747
+ /* isUpper=*/ false );
835
748
}
836
749
return lowerBound;
837
750
}
@@ -970,14 +883,14 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
970
883
lowerBounds.push_back (constExpr.getValue ());
971
884
upperBounds.push_back (constExpr.getValue ());
972
885
} else {
973
- lowerBounds.push_back (getBoundForExpr (e, map. getNumDims (),
974
- map.getNumSymbols (),
975
- constLowerBounds, constUpperBounds,
976
- /* isUpper=*/ false ));
977
- upperBounds.push_back (getBoundForExpr (e, map. getNumDims (),
978
- map.getNumSymbols (),
979
- constLowerBounds, constUpperBounds,
980
- /* isUpper=*/ true ));
886
+ lowerBounds.push_back (
887
+ getBoundForAffineExpr (e, map. getNumDims (), map.getNumSymbols (),
888
+ constLowerBounds, constUpperBounds,
889
+ /* isUpper=*/ false ));
890
+ upperBounds.push_back (
891
+ getBoundForAffineExpr (e, map. getNumDims (), map.getNumSymbols (),
892
+ constLowerBounds, constUpperBounds,
893
+ /* isUpper=*/ true ));
981
894
}
982
895
}
983
896
0 commit comments