Skip to content

[mlir][Interfaces][NFC] ValueBoundsConstraintSet: Pass stop condition in the constructor #86099

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

This commit changes the API of ValueBoundsConstraintSet: the stop condition is now passed to the constructor instead of processWorklist. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current ValueBoundsConstraintSet is passed as a reference to the stop function, so that the stop function can be defined before the the ValueBoundsConstraintSet is constructed.

This change is in preparation of adding support for branches.

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-arith

Author: Matthias Springer (matthias-springer)

Changes

This commit changes the API of ValueBoundsConstraintSet: the stop condition is now passed to the constructor instead of processWorklist. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current ValueBoundsConstraintSet is passed as a reference to the stop function, so that the stop function can be defined before the the ValueBoundsConstraintSet is constructed.

This change is in preparation of adding support for branches.


Full diff: https://github.com/llvm/llvm-project/pull/86099.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+10-6)
  • (modified) mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+36-24)
  • (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+6-3)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 94a8a8b429c801..b79c44162ea8ef 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -113,8 +113,9 @@ class ValueBoundsConstraintSet {
   ///
   /// The first parameter of the function is the shaped value/index-typed
   /// value. The second parameter is the dimension in case of a shaped value.
-  using StopConditionFn =
-      function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
+  /// The third parameter is this constraint set.
+  using StopConditionFn = function_ref<bool(
+      Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
   /// Compute a bound for the given index-typed value or shape dimension size.
   /// The computed bound is stored in `resultMap`. The operands of the bound are
@@ -263,12 +264,12 @@ class ValueBoundsConstraintSet {
   /// An index-typed value or the dimension of a shaped-type value.
   using ValueDim = std::pair<Value, int64_t>;
 
-  ValueBoundsConstraintSet(MLIRContext *ctx);
+  ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
 
   /// Iteratively process all elements on the worklist until an index-typed
-  /// value or shaped value meets `stopCondition`. Such values are not processed
-  /// any further.
-  void processWorklist(StopConditionFn stopCondition);
+  /// value or shaped value meets `currentStopCondition`. Such values are not
+  /// processed any further.
+  void processWorklist();
 
   /// Bound the given column in the underlying constraint set by the given
   /// expression.
@@ -316,6 +317,9 @@ class ValueBoundsConstraintSet {
 
   /// Builder for constructing affine expressions.
   Builder builder;
+
+  /// The current stop condition function.
+  StopConditionFn stopCondition = nullptr;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 37b36f76d4465d..117ee8e8701ad7 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
     bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     // We are trying to reify a bound for `value` in terms of the owning op's
     // operands. Construct a stop condition that evaluates to "true" for any SSA
     // value except for `value`. I.e., the bound will be computed in terms of
@@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
 FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 8d9fd1478aa9e6..fad221288f190e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
     bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     // We are trying to reify a bound for `value` in terms of the owning op's
     // operands. Construct a stop condition that evaluates to "true" for any SSA
     // value expect for `value`. I.e., the bound will be computed in terms of
@@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
 FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index b32ea8eebaecb9..c3a08ce86082a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
     FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
         rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
         /*stopCondition=*/
-        [&](Value v, std::optional<int64_t> d) {
+        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
           if (v == forOp.getUpperBound())
             return false;
           // Compute a bound that is independent of any affine op results.
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index cb36e0cecf0d24..1e13e60068ee7f 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,7 +58,7 @@ struct ForOpInterface
     ValueDimList boundOperands;
     LogicalResult status = ValueBoundsConstraintSet::computeBound(
         bound, boundOperands, BoundType::EQ, yieldedValue, dim,
-        [&](Value v, std::optional<int64_t> d) {
+        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
           // Stop when reaching a block argument of the loop body.
           if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
             return bbArg.getOwner()->getParentOp() == forOp;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index f2f732f3a21d25..ec710bbacc758f 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -67,8 +67,9 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   return std::nullopt;
 }
 
-ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
-    : builder(ctx) {}
+ValueBoundsConstraintSet::ValueBoundsConstraintSet(
+    MLIRContext *ctx, StopConditionFn stopCondition)
+    : builder(ctx), stopCondition(stopCondition) {}
 
 #ifndef NDEBUG
 static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
@@ -228,7 +229,8 @@ static Operation *getOwnerOfValue(Value value) {
   return value.getDefiningOp();
 }
 
-void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+void ValueBoundsConstraintSet::processWorklist() {
+  LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
@@ -249,13 +251,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
 
     // Do not process any further if the stop condition is met.
     auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
-    if (stopCondition(value, maybeDim))
+    if (stopCondition(value, maybeDim, *this)) {
+      LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
+                              << " (dim: " << maybeDim << ")\n");
       continue;
+    }
 
     // Query `ValueBoundsOpInterface` for constraints. New items may be added to
     // the worklist.
     auto valueBoundsOp =
         dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+    LLVM_DEBUG(llvm::dbgs()
+               << "Query value bounds for: " << value
+               << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
     if (valueBoundsOp) {
       if (dim == kIndexValue) {
         valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -264,6 +272,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
       }
       continue;
     }
+    LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
 
     // If the op does not implement `ValueBoundsOpInterface`, check if it
     // implements the `DestinationStyleOpInterface`. OpResults of such ops are
@@ -313,8 +322,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     bool closedUB) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
-  assert(!stopCondition(value, dim) &&
-         "stop condition should not be satisfied for starting point");
 #endif // NDEBUG
 
   int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -324,9 +331,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   // Process the backward slice of `value` (i.e., reverse use-def chain) until
   // `stopCondition` is met.
   ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
-  ValueBoundsConstraintSet cstr(value.getContext());
+  ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
+  assert(!stopCondition(value, dim, cstr) &&
+         "stop condition should not be satisfied for starting point");
   int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
-  cstr.processWorklist(stopCondition);
+  cstr.processWorklist();
 
   // Project out all variables (apart from `valueDim`) that do not match the
   // stop condition.
@@ -336,7 +345,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
       return false;
     auto maybeDim =
         p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
-    return !stopCondition(p.first, maybeDim);
+    return !stopCondition(p.first, maybeDim, cstr);
   });
 
   // Compute lower and upper bounds for `valueDim`.
@@ -442,7 +451,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
     bool closedUB) {
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) {
+      [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return llvm::is_contained(dependencies, std::make_pair(v, d));
       },
       closedUB);
@@ -478,7 +487,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
   // Reify bounds in terms of any independent values.
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
+      [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
+        return isIndependent(v);
+      },
       closedUB);
 }
 
@@ -500,8 +511,18 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
     presburger::BoundType type, AffineMap map, ValueDimList operands,
     StopConditionFn stopCondition, bool closedUB) {
   assert(map.getNumResults() == 1 && "expected affine map with one result");
-  ValueBoundsConstraintSet cstr(map.getContext());
-  int64_t pos = cstr.insert(/*isSymbol=*/false);
+
+  // Default stop condition if none was specified: Keep adding constraints until
+  // a bound could be computed.
+  int64_t pos;
+  auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+                                  ValueBoundsConstraintSet &cstr) {
+    return cstr.cstr.getConstantBound64(type, pos).has_value();
+  };
+
+  ValueBoundsConstraintSet cstr(
+      map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+  pos = cstr.insert(/*isSymbol=*/false);
 
   // Add map and operands to the constraint set. Dimensions are converted to
   // symbols. All operands are added to the worklist.
@@ -517,17 +538,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
       map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
 
   // Process the backward slice of `operands` (i.e., reverse use-def chain)
-  // until `stopCondition` is met.
-  if (stopCondition) {
-    cstr.processWorklist(stopCondition);
-  } else {
-    // No stop condition specified: Keep adding constraints until a bound could
-    // be computed.
-    cstr.processWorklist(
-        /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
-          return cstr.cstr.getConstantBound64(type, pos).has_value();
-        });
-  }
+  // until the stop condition is met.
+  cstr.processWorklist();
 
   // Compute constant bound for `valueDim`.
   int64_t ubAdjustment = closedUB ? 0 : 1;
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 39671a930f2e21..e99a13cdca2f3c 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -112,14 +112,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
 
       // Prepare stop condition. By default, reify in terms of the op's
       // operands. No stop condition is used when a constant was requested.
-      std::function<bool(Value, std::optional<int64_t>)> stopCondition =
-          [&](Value v, std::optional<int64_t> d) {
+      std::function<bool(Value, std::optional<int64_t>,
+                         ValueBoundsConstraintSet & cstr)>
+          stopCondition = [&](Value v, std::optional<int64_t> d,
+                              ValueBoundsConstraintSet &cstr) {
             // Reify in terms of SSA values that are different from `value`.
             return v != value;
           };
       if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
-        stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
+        stopCondition = stopCondition = [](Value v, std::optional<int64_t> d,
+                                           ValueBoundsConstraintSet &cstr) {
           auto bbArg = dyn_cast<BlockArgument>(v);
           if (!bbArg)
             return false;

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2024

@llvm/pr-subscribers-mlir-scf

Author: Matthias Springer (matthias-springer)

Changes

This commit changes the API of ValueBoundsConstraintSet: the stop condition is now passed to the constructor instead of processWorklist. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current ValueBoundsConstraintSet is passed as a reference to the stop function, so that the stop function can be defined before the the ValueBoundsConstraintSet is constructed.

This change is in preparation of adding support for branches.


Full diff: https://github.com/llvm/llvm-project/pull/86099.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+10-6)
  • (modified) mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+36-24)
  • (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+6-3)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 94a8a8b429c801..b79c44162ea8ef 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -113,8 +113,9 @@ class ValueBoundsConstraintSet {
   ///
   /// The first parameter of the function is the shaped value/index-typed
   /// value. The second parameter is the dimension in case of a shaped value.
-  using StopConditionFn =
-      function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
+  /// The third parameter is this constraint set.
+  using StopConditionFn = function_ref<bool(
+      Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
   /// Compute a bound for the given index-typed value or shape dimension size.
   /// The computed bound is stored in `resultMap`. The operands of the bound are
@@ -263,12 +264,12 @@ class ValueBoundsConstraintSet {
   /// An index-typed value or the dimension of a shaped-type value.
   using ValueDim = std::pair<Value, int64_t>;
 
-  ValueBoundsConstraintSet(MLIRContext *ctx);
+  ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
 
   /// Iteratively process all elements on the worklist until an index-typed
-  /// value or shaped value meets `stopCondition`. Such values are not processed
-  /// any further.
-  void processWorklist(StopConditionFn stopCondition);
+  /// value or shaped value meets `currentStopCondition`. Such values are not
+  /// processed any further.
+  void processWorklist();
 
   /// Bound the given column in the underlying constraint set by the given
   /// expression.
@@ -316,6 +317,9 @@ class ValueBoundsConstraintSet {
 
   /// Builder for constructing affine expressions.
   Builder builder;
+
+  /// The current stop condition function.
+  StopConditionFn stopCondition = nullptr;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 37b36f76d4465d..117ee8e8701ad7 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
     bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     // We are trying to reify a bound for `value` in terms of the owning op's
     // operands. Construct a stop condition that evaluates to "true" for any SSA
     // value except for `value`. I.e., the bound will be computed in terms of
@@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
 FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 8d9fd1478aa9e6..fad221288f190e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
     bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     // We are trying to reify a bound for `value` in terms of the owning op's
     // operands. Construct a stop condition that evaluates to "true" for any SSA
     // value expect for `value`. I.e., the bound will be computed in terms of
@@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
 FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index b32ea8eebaecb9..c3a08ce86082a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
     FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
         rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
         /*stopCondition=*/
-        [&](Value v, std::optional<int64_t> d) {
+        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
           if (v == forOp.getUpperBound())
             return false;
           // Compute a bound that is independent of any affine op results.
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index cb36e0cecf0d24..1e13e60068ee7f 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,7 +58,7 @@ struct ForOpInterface
     ValueDimList boundOperands;
     LogicalResult status = ValueBoundsConstraintSet::computeBound(
         bound, boundOperands, BoundType::EQ, yieldedValue, dim,
-        [&](Value v, std::optional<int64_t> d) {
+        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
           // Stop when reaching a block argument of the loop body.
           if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
             return bbArg.getOwner()->getParentOp() == forOp;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index f2f732f3a21d25..ec710bbacc758f 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -67,8 +67,9 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   return std::nullopt;
 }
 
-ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
-    : builder(ctx) {}
+ValueBoundsConstraintSet::ValueBoundsConstraintSet(
+    MLIRContext *ctx, StopConditionFn stopCondition)
+    : builder(ctx), stopCondition(stopCondition) {}
 
 #ifndef NDEBUG
 static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
@@ -228,7 +229,8 @@ static Operation *getOwnerOfValue(Value value) {
   return value.getDefiningOp();
 }
 
-void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+void ValueBoundsConstraintSet::processWorklist() {
+  LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
@@ -249,13 +251,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
 
     // Do not process any further if the stop condition is met.
     auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
-    if (stopCondition(value, maybeDim))
+    if (stopCondition(value, maybeDim, *this)) {
+      LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
+                              << " (dim: " << maybeDim << ")\n");
       continue;
+    }
 
     // Query `ValueBoundsOpInterface` for constraints. New items may be added to
     // the worklist.
     auto valueBoundsOp =
         dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+    LLVM_DEBUG(llvm::dbgs()
+               << "Query value bounds for: " << value
+               << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
     if (valueBoundsOp) {
       if (dim == kIndexValue) {
         valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -264,6 +272,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
       }
       continue;
     }
+    LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
 
     // If the op does not implement `ValueBoundsOpInterface`, check if it
     // implements the `DestinationStyleOpInterface`. OpResults of such ops are
@@ -313,8 +322,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     bool closedUB) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
-  assert(!stopCondition(value, dim) &&
-         "stop condition should not be satisfied for starting point");
 #endif // NDEBUG
 
   int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -324,9 +331,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   // Process the backward slice of `value` (i.e., reverse use-def chain) until
   // `stopCondition` is met.
   ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
-  ValueBoundsConstraintSet cstr(value.getContext());
+  ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
+  assert(!stopCondition(value, dim, cstr) &&
+         "stop condition should not be satisfied for starting point");
   int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
-  cstr.processWorklist(stopCondition);
+  cstr.processWorklist();
 
   // Project out all variables (apart from `valueDim`) that do not match the
   // stop condition.
@@ -336,7 +345,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
       return false;
     auto maybeDim =
         p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
-    return !stopCondition(p.first, maybeDim);
+    return !stopCondition(p.first, maybeDim, cstr);
   });
 
   // Compute lower and upper bounds for `valueDim`.
@@ -442,7 +451,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
     bool closedUB) {
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) {
+      [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return llvm::is_contained(dependencies, std::make_pair(v, d));
       },
       closedUB);
@@ -478,7 +487,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
   // Reify bounds in terms of any independent values.
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
+      [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
+        return isIndependent(v);
+      },
       closedUB);
 }
 
@@ -500,8 +511,18 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
     presburger::BoundType type, AffineMap map, ValueDimList operands,
     StopConditionFn stopCondition, bool closedUB) {
   assert(map.getNumResults() == 1 && "expected affine map with one result");
-  ValueBoundsConstraintSet cstr(map.getContext());
-  int64_t pos = cstr.insert(/*isSymbol=*/false);
+
+  // Default stop condition if none was specified: Keep adding constraints until
+  // a bound could be computed.
+  int64_t pos;
+  auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+                                  ValueBoundsConstraintSet &cstr) {
+    return cstr.cstr.getConstantBound64(type, pos).has_value();
+  };
+
+  ValueBoundsConstraintSet cstr(
+      map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+  pos = cstr.insert(/*isSymbol=*/false);
 
   // Add map and operands to the constraint set. Dimensions are converted to
   // symbols. All operands are added to the worklist.
@@ -517,17 +538,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
       map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
 
   // Process the backward slice of `operands` (i.e., reverse use-def chain)
-  // until `stopCondition` is met.
-  if (stopCondition) {
-    cstr.processWorklist(stopCondition);
-  } else {
-    // No stop condition specified: Keep adding constraints until a bound could
-    // be computed.
-    cstr.processWorklist(
-        /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
-          return cstr.cstr.getConstantBound64(type, pos).has_value();
-        });
-  }
+  // until the stop condition is met.
+  cstr.processWorklist();
 
   // Compute constant bound for `valueDim`.
   int64_t ubAdjustment = closedUB ? 0 : 1;
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 39671a930f2e21..e99a13cdca2f3c 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -112,14 +112,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
 
       // Prepare stop condition. By default, reify in terms of the op's
       // operands. No stop condition is used when a constant was requested.
-      std::function<bool(Value, std::optional<int64_t>)> stopCondition =
-          [&](Value v, std::optional<int64_t> d) {
+      std::function<bool(Value, std::optional<int64_t>,
+                         ValueBoundsConstraintSet & cstr)>
+          stopCondition = [&](Value v, std::optional<int64_t> d,
+                              ValueBoundsConstraintSet &cstr) {
             // Reify in terms of SSA values that are different from `value`.
             return v != value;
           };
       if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
-        stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
+        stopCondition = stopCondition = [](Value v, std::optional<int64_t> d,
+                                           ValueBoundsConstraintSet &cstr) {
           auto bbArg = dyn_cast<BlockArgument>(v);
           if (!bbArg)
             return false;

@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_stop_fn_constr branch from db3dde1 to 305001b Compare March 22, 2024 02:02
Comment on lines +116 to +122
/// The third parameter is this constraint set.
using StopConditionFn = std::function<bool(
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is not directly related to the changes in this PR, but It's not immediately clear to me if which value indicates stop/continuation. I like the walk function much more which exposes WalkResult::interrup() and WalkRedult::advance(), but I think that changing to LogicalResult and renaming this could also help.
Would be cool if we could revisit this separately.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I am also thinking of splitting up stopCondition into two functions, so that IR traversal boundary and values that are projected out can be specified separately. We had a use case for that recently.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to prepare a follow-up PR for this, after my chain of PRs has landed.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_dead_code branch from 12e7e88 to f944279 Compare March 23, 2024 05:56
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_stop_fn_constr branch from 305001b to ad1b2ac Compare March 23, 2024 05:58
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_dead_code branch from f944279 to 94baa23 Compare April 4, 2024 07:56
Base automatically changed from users/matthias-springer/value_bounds_dead_code to main April 4, 2024 07:56
…on in the constructor

This commit changes the API of `ValueBoundsConstraintSet`: the stop condition is now passed to the constructor instead of `processWorklist`. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current `ValueBoundsConstraintSet` is passed as a reference to the stop function, so that the stop function can be defined before the the `ValueBoundsConstraintSet` is constructed.

This change is in preparation of adding support for branches.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_stop_fn_constr branch from ad1b2ac to 6dcdd66 Compare April 4, 2024 08:01
Copy link

github-actions bot commented Apr 4, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 35886dc63a2d024e20c10d2e1cb3f5fa5d9f72cc 6dcdd66920a45811d4ba23f65014dd4384edb13d -- mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp mlir/lib/Interfaces/ValueBoundsOpInterface.cpp mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
View the diff from clang-format here.
diff --git a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
index 67a6581eb2..3e24f67a63 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
@@ -34,7 +34,7 @@ struct ScalableValueBoundsConstraintSet
       ValueBoundsConstraintSet::StopConditionFn stopCondition,
       unsigned vscaleMin, unsigned vscaleMax)
       : RTTIExtends(context, stopCondition), vscaleMin(vscaleMin),
-        vscaleMax(vscaleMax) {};
+        vscaleMax(vscaleMax){};
 
   using RTTIExtends::bound;
   using RTTIExtends::StopConditionFn;

@matthias-springer matthias-springer merged commit 5e4a443 into main Apr 4, 2024
@matthias-springer matthias-springer deleted the users/matthias-springer/value_bounds_stop_fn_constr branch April 4, 2024 08:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants