Skip to content

[mlir][Interfaces] LoopLikeOpInterface: Expose tied loop results #70535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 31, 2023

Conversation

matthias-springer
Copy link
Member

Expose loop results, which correspond to the region iter_arg values that are returned from the loop when there are no more iterations. Exposing loop results is optional because some loops (e.g., scf.while) do not have a 1-to-1 mapping between region iter_args and op results.

Also add additional helper functions to query tied results/iter_args/inits.

Expose loop results, which correspond to the region iter_arg values that are returned from the loop when there are no more iterations. Exposing loop results is optional because some loops (e.g., `scf.while`) do not have a 1-to-1 mapping between region iter_args and op results.

Also add additional helper functions to query tied results/iter_args/inits.
@llvmbot
Copy link
Member

llvmbot commented Oct 28, 2023

@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

Changes

Expose loop results, which correspond to the region iter_arg values that are returned from the loop when there are no more iterations. Exposing loop results is optional because some loops (e.g., scf.while) do not have a 1-to-1 mapping between region iter_args and op results.

Also add additional helper functions to query tied results/iter_args/inits.


Full diff: https://github.com/llvm/llvm-project/pull/70535.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+2-24)
  • (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.td (+87)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+5-9)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+1-2)
  • (modified) mlir/lib/Interfaces/LoopLikeInterface.cpp (+31-9)
  • (modified) mlir/test/Dialect/SCF/invalid.mlir (+13-1)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 43beebc1bf54166..38937fe28949436 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -269,28 +269,6 @@ def ForOp : SCF_Op<"for",
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
 
-    /// Get the OpResult that corresponds to an OpOperand.
-    /// Assert that opOperand is an iterArg.
-    /// This helper prevents internal op implementation detail leakage to
-    /// clients by hiding the operand / block argument mapping.
-    OpResult getResultForOpOperand(OpOperand &opOperand) {
-      assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
-             "expected an iter args operand");
-      assert(opOperand.getOwner() == getOperation() &&
-             "opOperand does not belong to this scf::ForOp operation");
-      return getOperation()->getResult(
-        opOperand.getOperandNumber() - getNumControlOperands());
-    }
-    /// Get the OpOperand& that corresponds to an OpResultOpOperand.
-    /// This helper prevents internal op implementation detail leakage to
-    /// clients by hiding the operand / block argument mapping.
-    OpOperand &getOpOperandForResult(OpResult opResult) {
-      assert(opResult.getDefiningOp() == getOperation() &&
-             "opResult does not belong to the scf::ForOp operation");
-      return getOperation()->getOpOperand(
-        getNumControlOperands() + opResult.getResultNumber());
-    }
-
     /// Returns the step as an `APInt` if it is constant.
     std::optional<APInt> getConstantStep();
 
@@ -942,7 +920,7 @@ def WhileOp : SCF_Op<"while",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
-        ["getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>,
+        ["getRegionIterArgs", "getYieldedValuesMutable"]>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
@@ -1156,7 +1134,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
 //===----------------------------------------------------------------------===//
 
 def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
-    ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp",
+    ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
                  "ParallelOp", "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index d3d07eec8ebff57..75d90b67bd82f36 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -33,6 +33,13 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     If one of the respective interface methods is implemented, so must the other
     two. The interface verifier ensures that the number of types of the region
     iter_args, init values and yielded values match.
+
+    Optionally, "loop results" can be exposed through this interface. These are
+    the values that are returned from the loop op when there are no more
+    iterations. The number and types of the loop results must match with the
+    region iter_args. Note: Loop results are optional because some loops
+    (e.g., `scf.while`) may produce results that do match 1-to-1 with the
+    region iter_args.
   }];
   let cppNamespace = "::mlir";
 
@@ -166,6 +173,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         return {};
       }]
     >,
+    InterfaceMethod<[{
+        Return the range of results that are return from this loop and
+        correspond to the "init" operands.
+
+        Note: This interface method is optional. If loop results are not
+        exposed via this interface, "std::nullopt" should be returned.
+        Otherwise, the number and types of results must match with the
+        region iter_args, inits and yielded values that are exposed via this
+        interface. If loop results are exposed but this loop op has no
+        loop-carried variables, an empty result range (and not "std::nullopt")
+        should be returned.
+      }],
+      /*retTy=*/"::std::optional<::mlir::ResultRange>",
+      /*methodName=*/"getLoopResults",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::std::nullopt;
+      }]
+    >,
     InterfaceMethod<[{
         Append the specified additional "init" operands: replace this loop with
         a new loop that has the additional init operands. The loop body of
@@ -242,6 +269,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     }
 
     /// Return the region iter_arg that corresponds to the given init operand.
+    /// Return an "empty" block argument if the given operand is not an init
+    /// operand of this loop op.
     BlockArgument getTiedLoopRegionIterArg(OpOperand *opOperand) {
       auto initsMutable = $_op.getInitsMutable();
       auto it = llvm::find(initsMutable, *opOperand);
@@ -250,7 +279,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       return $_op.getRegionIterArgs()[std::distance(initsMutable.begin(), it)];
     }
 
+    /// Return the region iter_arg that corresponds to the given loop result.
+    /// Return an "empty" block argument if the given OpResult is not a loop
+    /// result or if this op does not expose any loop results.
+    BlockArgument getTiedLoopRegionIterArg(OpResult opResult) {
+      auto loopResults = $_op.getLoopResults();
+      if (!loopResults)
+        return {};
+      auto it = llvm::find(*loopResults, opResult);
+      if (it == loopResults->end())
+        return {};
+      return $_op.getRegionIterArgs()[std::distance(loopResults->begin(), it)];
+    }
+
     /// Return the init operand that corresponds to the given region iter_arg.
+    /// Return "nullptr" if the given block argument is not a region iter_arg
+    /// of this loop op.
     OpOperand *getTiedLoopInit(BlockArgument bbArg) {
       auto iterArgs = $_op.getRegionIterArgs();
       auto it = llvm::find(iterArgs, bbArg);
@@ -259,7 +303,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       return &$_op.getInitsMutable()[std::distance(iterArgs.begin(), it)];
     }
 
+    /// Return the init operand that corresponds to the given loop result.
+    /// Return "nullptr" if the given OpResult is not a loop result or if this
+    /// op does not expose any loop results.
+    OpOperand *getTiedLoopInit(OpResult opResult) {
+      auto loopResults = $_op.getLoopResults();
+      if (!loopResults)
+        return nullptr;
+      auto it = llvm::find(*loopResults, opResult);
+      if (it == loopResults->end())
+        return nullptr;
+      return &$_op.getInitsMutable()[std::distance(loopResults->begin(), it)];
+    }
+
     /// Return the yielded value that corresponds to the given region iter_arg.
+    /// Return "nullptr" if the given block argument is not a region iter_arg
+    /// of this loop op.
     OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
       auto iterArgs = $_op.getRegionIterArgs();
       auto it = llvm::find(iterArgs, bbArg);
@@ -268,6 +327,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       return
           &$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
     }
+
+    /// Return the loop result that corresponds to the given init operand.
+    /// Return an "empty" OpResult if the given operand is not an init operand
+    /// of this loop op or if this op does not expose any loop results.
+    OpResult getTiedLoopResult(OpOperand *opOperand) {
+      auto loopResults = $_op.getLoopResults();
+      if (!loopResults)
+        return {};
+      auto initsMutable = $_op.getInitsMutable();
+      auto it = llvm::find(initsMutable, *opOperand);
+      if (it == initsMutable.end())
+        return {};
+      return (*loopResults)[std::distance(initsMutable.begin(), it)];
+    }
+
+    /// Return the loop result that corresponds to the given region iter_arg.
+    /// Return an "empty" OpResult if the given block argument is not a region
+    /// iter_arg of this loop op or if this op does not expose any loop results.
+    OpResult getTiedLoopResult(BlockArgument bbArg) {
+      auto loopResults = $_op.getLoopResults();
+      if (!loopResults)
+        return {};
+      auto iterArgs = $_op.getRegionIterArgs();
+      auto it = llvm::find(iterArgs, bbArg);
+      if (it == iterArgs.end())
+        return {};
+      return (*loopResults)[std::distance(iterArgs.begin(), it)];
+    }
   }];
 
   let verifyWithRegions = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 19f704f5232ed81..866f51b0e92bbde 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -810,7 +810,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
-  unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
+  unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
   auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                     .getDefiningOp<tensor::ExtractSliceOp>();
   if (!yieldingExtractSliceOp)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b8b75f3f476a5da..bc33fe2a9a01079 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -390,6 +390,8 @@ std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
   return OpFoldResult(getUpperBound());
 }
 
+std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
+
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 885e00b48ff8434..dc3c46bf896a9cf 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -614,7 +614,7 @@ struct ForOpInterface
   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
     auto forOp = cast<scf::ForOp>(op);
-    OpResult opResult = forOp.getResultForOpOperand(opOperand);
+    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
     BufferRelation relation = bufferRelation(op, opResult, state);
     return {{opResult, relation,
              /*isDefinite=*/relation == BufferRelation::Equivalent}};
@@ -625,10 +625,9 @@ struct ForOpInterface
     // ForOp results are equivalent to their corresponding init_args if the
     // corresponding iter_args and yield values are equivalent.
     auto forOp = cast<scf::ForOp>(op);
-    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
-    auto bbArg = forOp.getTiedLoopRegionIterArg(&forOperand);
+    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
     bool equivalentYield = state.areEquivalentBufferizedValues(
-        bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
+        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
     return equivalentYield ? BufferRelation::Equivalent
                            : BufferRelation::Unknown;
   }
@@ -703,16 +702,13 @@ struct ForOpInterface
 
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
-      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(
-          &forOp.getOpOperandForResult(opResult));
+      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
       return bufferization::getBufferType(bbArg, options, invocationStack);
     }
 
     // Compute result/argument number.
     BlockArgument bbArg = cast<BlockArgument>(value);
-    unsigned resultNum =
-        forOp.getResultForOpOperand(*forOp.getTiedLoopInit(bbArg))
-            .getResultNumber();
+    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
 
     // Compute the bufferized type.
     auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e649125a09fea6a..df162d29a48eb89 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -609,8 +609,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   if (destinationInitArg &&
       (*destinationInitArg)->getOwner() == outerMostLoop) {
     unsigned iterArgNumber =
-        outerMostLoop.getResultForOpOperand(**destinationInitArg)
-            .getResultNumber();
+        outerMostLoop.getTiedLoopResult(*destinationInitArg).getResultNumber();
     int64_t resultNumber = fusableProducer.getResultNumber();
     if (auto dstOp =
             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 15a816f4e448839..be1316b95688bf2 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -58,7 +58,7 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
   // but the LoopLikeOpInterface provides better error messages.
   auto loopLikeOp = cast<LoopLikeOpInterface>(op);
 
-  // Verify number of inits/iter_args/yielded values.
+  // Verify number of inits/iter_args/yielded values/loop results.
   if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
     return op->emitOpError("different number of inits and region iter_args: ")
            << loopLikeOp.getInits().size()
@@ -69,21 +69,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
                "different number of region iter_args and yielded values: ")
            << loopLikeOp.getRegionIterArgs().size()
            << " != " << loopLikeOp.getYieldedValues().size();
+  if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() !=
+                                         loopLikeOp.getRegionIterArgs().size())
+    return op->emitOpError(
+               "different number of loop results and region iter_args: ")
+           << loopLikeOp.getLoopResults()->size()
+           << " != " << loopLikeOp.getRegionIterArgs().size();
 
-  // Verify types of inits/iter_args/yielded values.
+  // Verify types of inits/iter_args/yielded values/loop results.
   int64_t i = 0;
   for (const auto it :
        llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
                        loopLikeOp.getYieldedValues())) {
     if (std::get<0>(it).getType() != std::get<1>(it).getType())
-      op->emitOpError(std::to_string(i))
-          << "-th init and " << i << "-th region iter_arg have different type: "
-          << std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
+      return op->emitOpError(std::to_string(i))
+             << "-th init and " << i
+             << "-th region iter_arg have different type: "
+             << std::get<0>(it).getType()
+             << " != " << std::get<1>(it).getType();
     if (std::get<1>(it).getType() != std::get<2>(it).getType())
-      op->emitOpError(std::to_string(i))
-          << "-th region iter_arg and " << i
-          << "-th yielded value have different type: "
-          << std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
+      return op->emitOpError(std::to_string(i))
+             << "-th region iter_arg and " << i
+             << "-th yielded value have different type: "
+             << std::get<1>(it).getType()
+             << " != " << std::get<2>(it).getType();
+    ++i;
+  }
+  i = 0;
+  if (loopLikeOp.getLoopResults()) {
+    for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
+                                         *loopLikeOp.getLoopResults())) {
+      if (std::get<0>(it).getType() != std::get<1>(it).getType())
+        return op->emitOpError(std::to_string(i))
+               << "-th region iter_arg and " << i
+               << "-th loop result have different type: "
+               << std::get<0>(it).getType()
+               << " != " << std::get<1>(it).getType();
+    }
     ++i;
   }
 
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 1b2c3f563195c52..ad07a8b11327deb 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -96,6 +96,19 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
 
 // -----
 
+func.func @scf_for_incorrect_result_type(%arg0: index, %init: f32) {
+  // expected-error @below{{0-th region iter_arg and 0-th loop result have different type: 'f32' != 'f64'}}
+  "scf.for"(%arg0, %arg0, %arg0, %init) (
+    {
+    ^bb0(%i0 : index, %iter: f32):
+      scf.yield %iter : f32
+    }
+  ) : (index, index, index, f32) -> (f64)
+  return
+}
+
+// -----
+
 func.func @too_many_iter_args(%arg0: index, %init: f32) {
   // expected-error @below{{different number of inits and region iter_args: 1 != 2}}
   %x = "scf.for"(%arg0, %arg0, %arg0, %init) (
@@ -449,7 +462,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
   %s0 = arith.constant 0.0 : f32
   %t0 = arith.constant 1.0 : f32
   // expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
-  // expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
   %result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
                     iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
     %sn = arith.addf %si, %si : f32

@matthias-springer matthias-springer merged commit 98a6edd into llvm:main Oct 31, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants