Skip to content

[mlir][Transform] Extend transform.foreach to take multiple arguments #93705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 14, 2024

Conversation

rolfmorel
Copy link
Contributor

Changes transform.foreach's interface to take multiple arguments, e.g. transform.foreach %ops1, %ops2, %params : ... { ^bb0(%op1, %op2, %param): BODY } The semantics are that the payloads for these handles get iterated over as if the payloads have been zipped-up together - BODY gets executed once for each such tuple. The documentation explains that this implementation requires that the payloads have the same length.

This change also enables the target argument(s) to be any op/value/param handle.

The added test cases demonstrate some use cases for this change.

@rolfmorel
Copy link
Contributor Author

NB: the main motivation is to enable "continuous tiling" to be expressed as a couple of Transform ops working in unison, see here: #82792 (comment)
When this PR is merged PR82792 can add a full test case of continuous tiling to the test suite.

@rolfmorel rolfmorel marked this pull request as ready for review May 29, 2024 16:56
@llvmbot
Copy link
Member

llvmbot commented May 29, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir-linalg

Author: Rolf Morel (rolfmorel)

Changes

Changes transform.foreach's interface to take multiple arguments, e.g. transform.foreach %ops1, %ops2, %params : ... { ^bb0(%op1, %op2, %param): BODY } The semantics are that the payloads for these handles get iterated over as if the payloads have been zipped-up together - BODY gets executed once for each such tuple. The documentation explains that this implementation requires that the payloads have the same length.

This change also enables the target argument(s) to be any op/value/param handle.

The added test cases demonstrate some use cases for this change.


Patch is 28.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93705.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+31-26)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+112-29)
  • (modified) mlir/test/Dialect/Linalg/multisize-tiling-full.mlir (+13-8)
  • (modified) mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir (+73)
  • (modified) mlir/test/Dialect/Transform/ops.mlir (+18-4)
  • (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+85)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 77048a28d7510..e61cd77339ac6 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -614,43 +614,48 @@ def ForeachOp : TransformDialectOp<"foreach",
          "getSuccessorRegions", "getEntrySuccessorOperands"]>,
      SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
     ]> {
-  let summary = "Executes the body for each payload op";
+  let summary = "Executes the body for each element of the payload";
   let description = [{
-    This op has exactly one region with exactly one block ("body"). The body is
-    executed for each payload op that is associated to the target operand in an
-    unbatched fashion. I.e., the block argument ("iteration variable") is always
-    mapped to exactly one payload op.
-
-    This op always reads the target handle. Furthermore, it consumes the handle
-    if there is a transform op in the body that consumes the iteration variable.
-    This op does not return anything.
-
-    The transformations inside the body are applied in order of their
-    appearance. During application, if any transformation in the sequence fails,
-    the entire sequence fails immediately leaving the payload IR in potentially
-    invalid state, i.e., this operation offers no transformation rollback
-    capabilities.
-
-    This op generates as many handles as the terminating YieldOp has operands.
-    For each result, the payload ops of the corresponding YieldOp operand are
-    merged and mapped to the same resulting handle.
+    Execute the op's body - its single region block - exactly once per
+    element of the payload associated to a target handle. The body's
+    transformations are applied in order of appearance until reaching the
+    (implicit) YieldOp terminator.
+
+    Each iteration gets executed by co-indexing the payloads of the arguments
+    and mapping the body's arguments to these tuples, as though iterating over
+    the zipped together `targets`. As such, in each iteration, the size of the
+    payload of each of the body's block arguments is exactly one.
+
+    This op always reads the target handles. Furthermore, it consumes a handle
+    if there is a transform op in the body that consumes the corresponding
+    block argument. Handles can point to ops, values, or parameters.
+
+    #### Return Modes
+
+    This op produces as many result handles as the body's terminating YieldOp
+    has operands. For each result, the payloads of the corresponding YieldOp
+    operand are merged and mapped to the same resulting handle.
+
+    If the target handles do not associate payloads of the same size, or they
+    do not associate any payload at all, a silencable failure will be generated.
+
+    During application, if any transformation in the sequence fails, the entire
+    sequence fails immediately with the same failure, leaving the payload IR in
+    a potentially invalid state, i.e., this operation offers no transformation
+    rollback capabilities.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat =
-    "$target `:` type($target) (`->` type($results)^)? $body attr-dict";
+    "$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
     static StringRef getDefaultDialect() { return "transform"; }
 
-    BlockArgument getIterationVariable() {
-      return getBody().front().getArgument(0);
-    }
-
     transform::YieldOp getYieldOp();
   }];
 }
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 247759e21efb1..b58f4e0672fc2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1391,15 +1391,62 @@ DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
-  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-  // Store payload ops in a vector because ops may be removed from the mapping
-  // by the TrackingRewriter while the iteration is in progress.
-  SmallVector<Operation *> targets =
-      llvm::to_vector(state.getPayloadOps(getTarget()));
-  for (Operation *op : targets) {
+  // Collect the arguments with which to call each iteration of the body.
+  // We store the payload before executing the body as ops may be removed from
+  // the mapping by the TrackingRewriter while the iteration is in progress.
+  SmallVector<SmallVector<MappedValue>> zippedArgs;
+  for (auto firstTarget : getTargets().take_front(1)) // Loop runs at most once.
+    // For each element, init a tuple with which to call the body later on.
+    if (isa<TransformHandleTypeInterface>(firstTarget.getType()))
+      for (auto &op : state.getPayloadOps(firstTarget))
+        zippedArgs.append({{op}}); // NB: append's argument is an init-list.
+    else if (isa<TransformValueHandleTypeInterface>(firstTarget.getType()))
+      for (auto val : state.getPayloadValues(firstTarget))
+        zippedArgs.append({{val}});
+    else if (isa<TransformParamTypeInterface>(firstTarget.getType()))
+      for (auto param : state.getParams(firstTarget))
+        zippedArgs.append({{param}});
+    else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << firstTarget.getType();
+
+  for (auto target : getTargets().drop_front(1)) {
+    // Append each element of payload to the co-indexed body-arguments-as-tuple.
+    size_t payloadSize = 0;
+    if (isa<TransformHandleTypeInterface>(target.getType())) {
+      for (auto op : state.getPayloadOps(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({op});
+    } else if (isa<TransformValueHandleTypeInterface>(target.getType())) {
+      for (auto val : state.getPayloadValues(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({val});
+    } else if (isa<TransformParamTypeInterface>(target.getType())) {
+      for (auto param : state.getParams(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({param});
+    } else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << target.getType();
+
+    if (payloadSize != zippedArgs.size())
+      return emitSilenceableError()
+             << "payload size of prior targets (" << zippedArgs.size()
+             << ") differs from payload size (" << payloadSize << ") of target "
+             << target;
+  }
+
+  // For each arguments-as-tuple collected up above, execute the body region.
+  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
+  for (SmallVector<MappedValue> &argsTuple : zippedArgs) {
     auto scope = state.make_region_scope(getBody());
-    if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
-      return DiagnosedSilenceableFailure::definiteFailure();
+    // Set up arguments to the region's block.
+    for (auto &&[blockArg, argument] :
+         llvm::zip_equal(getBody().front().getArguments(), argsTuple))
+      // Note: each blockArg's handle gets associated with just a single element
+      // from the corresponding target's payload.
+      if (failed(state.mapBlockArgument(blockArg, {argument})))
+        return DiagnosedSilenceableFailure::definiteFailure();
 
     // Execute loop body.
     for (Operation &transform : getBody().front().without_terminator()) {
@@ -1409,28 +1456,44 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
         return result;
     }
 
-    // Append yielded payload ops to result list (if any).
-    for (unsigned i = 0; i < getNumResults(); ++i) {
-      auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
-      resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
-    }
+    // Append yielded payloads to results.
+    auto yieldOperands = getYieldOp().getOperands();
+    for (auto &&[result, yieldOperand, resTuple] :
+         llvm::zip_equal(getResults(), yieldOperands, zippedResults))
+      // NB: each iteration we add any number of ops/vals/params to an opresult.
+      if (isa<TransformHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
+      else if (isa<TransformValueHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
+      else if (isa<TransformParamTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getParams(yieldOperand));
+      else
+        return emitDefiniteFailure()
+               << "unhandled handle type " << result.getType();
   }
 
   for (unsigned i = 0; i < getNumResults(); ++i)
-    results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
+    results.setMappedValues(cast<OpResult>(getResult(i)), zippedResults[i]);
 
   return DiagnosedSilenceableFailure::success();
 }
 
 void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  BlockArgument iterVar = getIterationVariable();
-  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
-        return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
-      })) {
-    consumesHandle(getTarget(), effects);
-  } else {
-    onlyReadsHandle(getTarget(), effects);
+
+  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
+  // arity errors, this method might get called before/in absence of `verify()`.
+  for (auto &&[target, blockArg] :
+       llvm::zip(getTargets(), getBody().front().getArguments())) {
+    BlockArgument blockArgument = blockArg;
+    if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+          return isHandleConsumed(blockArgument,
+                                  cast<TransformOpInterface>(&op));
+        })) {
+      consumesHandle(target, effects);
+    } else {
+      onlyReadsHandle(target, effects);
+    }
   }
 
   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
@@ -1463,6 +1526,7 @@ void transform::ForeachOp::getSuccessorRegions(
 
 OperandRange
 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  // TODO: figure out how to update the comment & impl if necessary
   // The iteration variable op handle is mapped to a subset (one op to be
   // precise) of the payload ops of the ForeachOp operand.
   assert(point == getBody() && "unexpected region index");
@@ -1474,14 +1538,33 @@ transform::YieldOp transform::ForeachOp::getYieldOp() {
 }
 
 LogicalResult transform::ForeachOp::verify() {
-  auto yieldOp = getYieldOp();
-  if (getNumResults() != yieldOp.getNumOperands())
-    return emitOpError() << "expects the same number of results as the "
-                            "terminator has operands";
-  for (Value v : yieldOp.getOperands())
-    if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
-      return yieldOp->emitOpError("expects operands to have types implementing "
-                                  "TransformHandleTypeInterface");
+  for (auto [targetOpt, bodyArgOpt] :
+       llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
+    if (!targetOpt || !bodyArgOpt)
+      return emitOpError() << "expects the same number of targets as the body "
+                              "has block arguments";
+    auto target = targetOpt.value();
+    if (target.getType() != bodyArgOpt.value().getType() ||
+        !isa<TransformHandleTypeInterface, TransformValueHandleTypeInterface,
+             TransformParamTypeInterface>(target.getType()))
+      return emitOpError(
+          "expects co-indexed targets and the body's "
+          "block arguments to have the same op/value/param type");
+  }
+
+  for (auto [resultOpt, yieldOperandOpt] :
+       llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
+    if (!resultOpt || !yieldOperandOpt)
+      return emitOpError() << "expects the same number of results as the "
+                              "yield terminator has operands";
+    auto result = resultOpt.value();
+    if (result.getType() != yieldOperandOpt.value().getType() ||
+        !isa<TransformHandleTypeInterface, TransformValueHandleTypeInterface,
+             TransformParamTypeInterface>(result.getType()))
+      return emitOpError("expects co-indexed results and yield "
+                         "operands to have the same op/value/param type");
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index 15b24b56608e3..51332ffce03d1 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
-    %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
     %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
     %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
-    %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op
-    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op
-    transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.foreach %5 : !transform.any_op {
+    ^bb0(%inner_linalg: !transform.any_op):
+      %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
+      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
+      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    }
     transform.yield
   }
 }
@@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} {
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
     %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
-    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64>
-    transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
-    transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+    transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
+    ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
+      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
+      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+    }
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1cdbe0cf..54dd2bdf953ca 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -328,3 +328,76 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+// -----
+
+// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @foreach_loop_pair_fuse(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  } {target_loops}
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  } {source_loops}
+  %2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>)  {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_re...
[truncated]

@rolfmorel
Copy link
Contributor Author

@ftynse, you are probably best placed to review.

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Please address the comments.

@@ -1463,6 +1526,7 @@ void transform::ForeachOp::getSuccessorRegions(

OperandRange
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
// TODO: figure out how to update the comment & impl if necessary
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic still applies.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Could you check that my updated comment here is accurate/relevant?

@rolfmorel
Copy link
Contributor Author

Thanks @ftynse - have addressed all your comments, though I had one question regarding a code comment.

Could you have another look?

@rolfmorel rolfmorel requested a review from ftynse May 30, 2024 16:51
Copy link

github-actions bot commented May 30, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@rolfmorel rolfmorel force-pushed the rolf.foreach_multi_arg branch 2 times, most recently from a5f147e to 59ce2e7 Compare May 31, 2024 08:53
@rolfmorel
Copy link
Contributor Author

Gentle ping, @ftynse - only doing so as #82792 is blocked waiting for this PR to land (so it can implement a full test case for continuous tiling).

Changes transform.foreach's interface to take multiple arguments, e.g.
transform.foreach %ops1, %ops2, %params : ... { ^bb0(%op1, %op2, %param): BODY }
The semantics are that the payloads for these handles get iterated over as if
the payloads have been zipped-up together - BODY gets executed once for each
such tuple. The documentation explains that this implementation requires that
the payloads have the same length.

This change also enables the target argument(s) and result(s) to be any
op/value/param handle.

The added test cases demonstrate some use cases for this change.
@rolfmorel rolfmorel force-pushed the rolf.foreach_multi_arg branch from 37ae377 to 7e5308d Compare June 7, 2024 10:40
@rolfmorel
Copy link
Contributor Author

Thank you, @ftynse! The missing tests have been added.

I have squashed the fix-up commits. If you could help with merging, that would be perfect!

@rolfmorel
Copy link
Contributor Author

Hi @ftynse - could you (or anyone else for that matter), help with committing/merging this PR?

Thanks in advance!

@ftynse ftynse merged commit d462bf6 into llvm:main Jun 14, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants