Skip to content

[mlir] make transform.foreach_match forward arguments #89920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 3, 2024

Conversation

ftynse
Copy link
Member

@ftynse ftynse commented Apr 24, 2024

It may be useful to have access to additional handles or parameters when performing matches and actions in foreach_match, for example, to parameterize the matcher by rank or restrict it in a non-trivial way. Enable foreach_match to forward additional handles from operands to matcher symbols and from action symbols to results.

@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

It may be useful to have access to additional handles or parameters when performing matches and actions in foreach_match, for example, to parameterize the matcher by rank or restrict it in a non-trivial way. Enable foreach_match to forward additional handles from operands to matcher symbols and from action symbols to results.


Patch is 32.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89920.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+37-23)
  • (modified) mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h (+13)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+110-37)
  • (modified) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp (+29-5)
  • (modified) mlir/test/Dialect/Transform/foreach-match.mlir (+110)
  • (modified) mlir/test/Dialect/Transform/ops-invalid.mlir (+81-4)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index fbac1ffb621fd2..85269ddf4abfc2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -512,7 +512,8 @@ def CollectMatchingOp : TransformDialectOp<"collect_matching", [
 def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
-    DeclareOpInterfaceMethods<TransformOpInterface>]> {
+    DeclareOpInterfaceMethods<TransformOpInterface,
+                              ["allowsRepeatedHandleOperands"]>]> {
   let summary = "Applies named sequences when a named matcher succeeds";
   let description = [{
     Given a pair of co-indexed lists of transform dialect symbols (such as
@@ -528,25 +529,31 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     the following matchers are not applied to the same payload operation. If the
     action succeeds, the next payload operation in walk order is matched. If it
     fails, both silenceable and definite errors are propagated as the result of
-    this op.
-
-    The matcher symbol must take one operand of a type that implements the same
-    transform dialect interface as the `root` operand (a check is performed at
-    application time to see if the associated payload satisfies the constraints
-    of the actual type). It must not consume the operand as multiple matchers
+    this op; propagation of silenceable errors is postponed until the end of the
+    walk.
+
+    The matcher symbol must take at least one operand of a type that implements
+    the same transform dialect interface as the `root` operand (a check is
+    performed at application time to see if the associated payload satisfies the
+    constraints of the actual type), and may take additional operands with a
+    similar type requirement. It must not consume operands as multiple matchers
     may be applied. The matcher may produce any number of results. The action
     symbol paired with the matcher must take the same number of arguments as the
     matcher has results, and these arguments must implement the same transform
     dialect interfaces, but not necessarily have the exact same type (again, a
     check is performed at application time to see if the associated payload
-    satisfies the constraints of actual types on both sides). The action symbol
-    may not have results. The actions are expected to only modify payload
-    operations nested in the `root` payload operations associated with the
-    operand of this transform operation. Furhermore, the actions may not modify
-    operations outside of the currently matched payload operation, e.g., they
-    may not modify sibling or parent operations. If such behavior is desired,
-    the parent must be matched first and the nested operations obtained by
-    traversing the IR from the parent. This is due to the matching being
+    satisfies the constraints of actual types on both sides).
+
+    The action symbol may have results that are accumulated from all actions and
+    returned from the `foreach_match` operation on success. Unless the
+    `flatten_results` attribute is present, each action result must be
+    associated with exactly one payload entity. The actions are expected to only
+    modify payload operations nested in the `root` payload operations associated
+    with the operand of this transform operation. Furthermore, the actions may
+    not modify operations outside of the currently matched payload operation,
+    e.g., they may not modify sibling or parent operations. If such behavior is
+    desired, the parent must be matched first and the nested operations obtained
+    by traversing the IR from the parent. This is due to the matching being
     performed as a post-order IR walk.
 
     This operation consumes the operand and produces a new handle associated
@@ -573,19 +580,26 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     produced a definite failure.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$root,
-                       UnitAttr:$restrict_root,
-                       SymbolRefArrayAttr:$matchers,
-                       SymbolRefArrayAttr:$actions);
-  let results = (outs TransformHandleTypeInterface:$updated);
+  let arguments =
+      (ins TransformHandleTypeInterface:$root,
+           Variadic<Transform_AnyHandleOrParamType>:$forwarded_inputs,
+           UnitAttr:$restrict_root,
+           UnitAttr:$flatten_results,
+           SymbolRefArrayAttr:$matchers,
+           SymbolRefArrayAttr:$actions);
+  let results =
+      (outs TransformHandleTypeInterface:$updated,
+            Variadic<Transform_AnyHandleOrParamType>:$forwarded_outputs);
 
   let assemblyFormat = [{
-    (`restrict_root` $restrict_root^)?
+    oilist( `restrict_root` $restrict_root
+          | `flatten_results` $flatten_results
+          )
     `in`
-    $root
+    $root (`,` $forwarded_inputs^)?
     custom<ForeachMatchSymbols>($matchers, $actions)
     attr-dict
-    `:` functional-type($root, $updated)
+    `:` functional-type(operands, results)
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 59cc2f22c93813..2d4425ff918473 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -52,6 +52,17 @@ void getPotentialTopLevelEffects(
 /// Verification hook for TransformOpInterface.
 LogicalResult verifyTransformOpInterface(Operation *op);
 
+/// Appends the entities associated with the given transform. values in `state`
+/// to the pre-existing list of mappings. The array of mappings must have as
+/// many elements as values. If `flatten` is set, multiple values may be
+/// associated with each transform value, and this always succeeds. Otherwise,
+/// checks that each value has exactly one mapping associated and return failure
+/// otherwise.
+LogicalResult appendValueMappings(
+    MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
+    ValueRange values, const transform::TransformState &state,
+    bool flatten = true);
+
 /// Populates `mappings` with mapped values associated with the given transform
 /// IR values in the given `state`.
 void prepareValueMappings(
@@ -317,6 +328,8 @@ class TransformState {
   }
   LogicalResult mapBlockArgument(BlockArgument argument,
                                  ArrayRef<MappedValue> values);
+  LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
+                                  ArrayRef<SmallVector<MappedValue>> mapping);
 
   // Forward declarations to support limited visibility.
   class RegionScope;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7a5a6974700586..bd8e55e584da92 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -834,19 +834,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 // CollectMatchingOp
 //===----------------------------------------------------------------------===//
 
-/// Applies matcher operations from the given `block` assigning `op` as the
-/// payload of the block's first argument. Updates `state` accordingly. If any
-/// of the matcher produces a silenceable failure, discards it (printing the
-/// content to the debug output stream) and returns failure. If any of the
-/// matchers produces a definite failure, reports it and returns failure. If all
-/// matchers in the block succeed, populates `mappings` with the payload
-/// entities associated with the block terminator operands.
+/// Applies matcher operations from the given `block` using
+/// `blockArgumentMapping` to initialize block arguments. Updates `state`
+/// accordingly. If any of the matcher produces a silenceable failure, discards
+/// it (printing the content to the debug output stream) and returns failure. If
+/// any of the matchers produces a definite failure, reports it and returns
+/// failure. If all matchers in the block succeed, populates `mappings` with the
+/// payload entities associated with the block terminator operands. Note that
+/// `mappings` will be cleared before that.
 static DiagnosedSilenceableFailure
-matchBlock(Block &block, Operation *op, transform::TransformState &state,
+matchBlock(Block &block,
+           ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
+           transform::TransformState &state,
            SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
   assert(block.getParent() && "cannot match using a detached block");
   auto matchScope = state.make_region_scope(*block.getParent());
-  if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
+  if (failed(
+          state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
     return DiagnosedSilenceableFailure::definiteFailure();
 
   for (Operation &match : block.without_terminator()) {
@@ -866,6 +870,7 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
   // Remember the values mapped to the terminator operands so we can
   // forward them to the action.
   ValueRange yieldedValues = block.getTerminator()->getOperands();
+  mappings.clear();
   transform::detail::prepareValueMappings(mappings, yieldedValues, state);
   return DiagnosedSilenceableFailure::success();
 }
@@ -915,8 +920,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
 
       // Try matching.
       SmallVector<SmallVector<MappedValue>> mappings;
-      DiagnosedSilenceableFailure diag =
-          matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+      SmallVector<transform::MappedValue> inputMapping({op});
+      DiagnosedSilenceableFailure diag = matchBlock(
+          matcher.getFunctionBody().front(),
+          ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
+          mappings);
       if (diag.isDefiniteFailure())
         return WalkResult::interrupt();
       if (diag.isSilenceableFailure()) {
@@ -1001,6 +1009,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses(
 // ForeachMatchOp
 //===----------------------------------------------------------------------===//
 
+// This is fine because nothing is actually consumed by this op.
+bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
+
 DiagnosedSilenceableFailure
 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
@@ -1030,6 +1041,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
 
   DiagnosedSilenceableFailure overallDiag =
       DiagnosedSilenceableFailure::success();
+
+  SmallVector<SmallVector<MappedValue>> matchInputMapping;
+  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
+  SmallVector<SmallVector<MappedValue>> actionResultMapping;
+  matchInputMapping.emplace_back();
+  transform::detail::prepareValueMappings(matchInputMapping,
+                                          getForwardedInputs(), state);
+  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
+  actionResultMapping.resize(getForwardedOutputs().size());
+
   for (Operation *root : state.getPayloadOps(getRoot())) {
     WalkResult walkResult = root->walk([&](Operation *op) {
       // If getRestrictRoot is not present, skip over the root op itself so we
@@ -1044,11 +1065,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
         llvm::dbgs() << " @" << op << "\n";
       });
 
+      firstMatchArgument.clear();
+      firstMatchArgument.push_back(op);
+
       // Try all the match/action pairs until the first successful match.
       for (auto [matcher, action] : matchActionPairs) {
-        SmallVector<SmallVector<MappedValue>> mappings;
         DiagnosedSilenceableFailure diag =
-            matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+            matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
+                       state, matchOutputMapping);
         if (diag.isDefiniteFailure())
           return WalkResult::interrupt();
         if (diag.isSilenceableFailure()) {
@@ -1058,10 +1082,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
         }
 
         auto scope = state.make_region_scope(action.getFunctionBody());
-        for (auto &&[arg, map] : llvm::zip_equal(
-                 action.getFunctionBody().front().getArguments(), mappings)) {
-          if (failed(state.mapBlockArgument(arg, map)))
-            return WalkResult::interrupt();
+        if (failed(state.mapBlockArguments(
+                action.getFunctionBody().front().getArguments(),
+                matchOutputMapping))) {
+          return WalkResult::interrupt();
         }
 
         for (Operation &transform :
@@ -1082,6 +1106,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
             continue;
           }
         }
+        if (failed(detail::appendValueMappings(
+                MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
+                action.getFunctionBody().front().getTerminator()->getOperands(),
+                state, getFlattenResults()))) {
+          emitDefiniteFailure()
+              << "action @" << action.getName()
+              << " has results associated with multiple payload entities, "
+                 "but flattening was not requested";
+          return WalkResult::interrupt();
+        }
         break;
       }
       return WalkResult::advance();
@@ -1096,6 +1130,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
   // by actions, are invalidated.
   results.set(llvm::cast<OpResult>(getUpdated()),
               state.getPayloadOps(getRoot()));
+  for (auto &&[result, mapping] :
+       llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
+    results.setMappedValues(result, mapping);
+  }
   return overallDiag;
 }
 
@@ -1108,7 +1146,8 @@ void transform::ForeachMatchOp::getEffects(
   }
 
   consumesHandle(getRoot(), effects);
-  producesHandle(getUpdated(), effects);
+  onlyReadsHandle(getForwardedInputs(), effects);
+  producesHandle(getResults(), effects);
   modifiesPayload(effects);
 }
 
@@ -1224,6 +1263,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
       StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
   for (auto &&[matcher, action] :
        llvm::zip_equal(getMatchers(), getActions())) {
+    // Presence and typing.
     auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
         symbolTable.lookupNearestSymbolFrom(getOperation(),
                                             cast<SymbolRefAttr>(matcher)));
@@ -1250,8 +1290,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
       return failure();
     }
 
-    ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
-    ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
+    // Input -> matcher forwarding.
+    TypeRange operandTypes = getOperandTypes();
+    TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
+    if (operandTypes.size() != matcherArguments.size()) {
+      InFlightDiagnostic diag =
+          emitError() << "the number of operands (" << operandTypes.size()
+                      << ") doesn't match the number of matcher arguments ("
+                      << matcherArguments.size() << ") for " << matcher;
+      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
+      return diag;
+    }
+    for (auto &&[i, operand, argument] :
+         llvm::enumerate(operandTypes, matcherArguments)) {
+      if (matcherSymbol.getArgAttr(i, consumedAttr)) {
+        InFlightDiagnostic diag =
+            emitOpError()
+            << "does not expect matcher symbol to consume its operand #" << i;
+        diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
+        return diag;
+      }
+
+      if (implementSameTransformInterface(operand, argument))
+        continue;
+
+      InFlightDiagnostic diag =
+          emitError()
+          << "mismatching type interfaces for operand and matcher argument #"
+          << i << " of matcher " << matcher;
+      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
+      return diag;
+    }
+
+    // Matcher -> action forwarding.
+    TypeRange matcherResults = matcherSymbol.getResultTypes();
+    TypeRange actionArguments = actionSymbol.getArgumentTypes();
     if (matcherResults.size() != actionArguments.size()) {
       return emitError() << "mismatching number of matcher results and "
                             "action arguments between "
@@ -1265,31 +1338,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
 
       return emitError() << "mismatching type interfaces for matcher result "
                             "and action argument #"
-                         << i;
+                         << i << "of matcher " << matcher << " and action "
+                         << action;
     }
 
-    if (!actionSymbol.getResultTypes().empty()) {
+    // Action -> result forwarding.
+    TypeRange actionResults = actionSymbol.getResultTypes();
+    auto resultTypes = TypeRange(getResultTypes()).drop_front();
+    if (actionResults.size() != resultTypes.size()) {
       InFlightDiagnostic diag =
-          emitError() << "action symbol is not expected to have results";
+          emitError() << "the number of action results ("
+                      << actionResults.size() << ") for " << action
+                      << " doesn't match the number of extra op results ("
+                      << resultTypes.size() << ")";
       diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
       return diag;
     }
+    for (auto &&[i, resultType, actionType] :
+         llvm::enumerate(resultTypes, actionResults)) {
+      if (implementSameTransformInterface(resultType, actionType))
+        continue;
 
-    if (matcherSymbol.getArgumentTypes().size() != 1 ||
-        !implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
-                                         getRoot().getType())) {
-      InFlightDiagnostic diag =
-          emitOpError() << "expects matcher symbol to have one argument with "
-                           "the same transform interface as the first operand";
-      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
-      return diag;
-    }
-
-    if (matcherSymbol.getArgAttr(0, consumedAttr)) {
       InFlightDiagnostic diag =
-          emitOpError()
-          << "does not expect matcher symbol to consume its operand";
-      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
+          emitError() << "mismatching type interfaces for action result #" << i
+                      << " of action " << action << " and op result";
+      diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
       return diag;
     }
   }
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 48f3954b6cf69f..b6a35e23a5d1fc 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -206,6 +206,15 @@ transform::TransformState::mapBlockArgument(BlockArgument argument,
       .checkAndReport();
 }
 
+LogicalResult transform::TransformState::mapBlockArguments(
+    Block::BlockArgListType arguments,
+    ArrayRef<SmallVector<MappedValue>> mapping) {
+  for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
+    if (failed(mapBlockArgument(argument, values)))
+      return failure();
+  return success();
+}
+
 LogicalResult
 transform::TransformState::setPayloadOps(Value value,
                                          ArrayRef<Operation *> targets) {
@@ -1528,11 +1537,12 @@ void transform::detail::setApplyToOneResults(
 // Utilities for implement...
[truncated]

@match -> @return
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%num = transform.num_associations %func3 : (!transform.any_op) -> !transform.param<i64>
// 2 funcs are yielded for each of the 2 funcs = 4:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Starting to review from the test, I am having trouble understanding why num_associations %func3 == 4 , I would have expected less due to %func2 in %func2, %func3 = but I can't put the finger on it.
Could you please explain?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm ok I seem to infer from the previous example that the first result is actually the match and the second is the forwarded stuff, which makes sense.

C++ impl makes sense.

Is there some syntax we could come up with to make disambiguating this a little more intuitive and avoid the need to go to the doc / C++ impl to follow?

This looks like it would require some custom C++ parse/print which is always painful but it seems it could be worth it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yielded results are concatenated. The first result is just the list of matched ops.

We cannot modify how SSA values are printed for results, this is fixed by the top-level printer. I've added the explicit flatten_results that enables concatenation, may be another name for that?

Copy link
Contributor

@nicolasvasilache nicolasvasilache May 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about the result type
(!transform.any_op, !transform.any_op) -> !transform.any_op with forwarded [!transform.any_op...]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to use named results. I really don't think we should get inventive with trailing types, they are confusing as they are.

SmallVector<SmallVector<MappedValue>> matchInputMapping;
SmallVector<SmallVector<MappedValue>> matchOutputMapping;
SmallVector<SmallVector<MappedValue>> actionResultMapping;
matchInputMapping.emplace_back();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this outide of the call to prepareValueMappings given that prepareValueMappings immediately resizes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because we also need to map the op that is being matched before the payloads that are forwarded. Added a comment.

@nicolasvasilache nicolasvasilache self-requested a review May 2, 2024 17:49
It may be useful to have access to additional handles or parameters when
performing matches and actions in `foreach_match`, for example, to
parameterize the matcher by rank or restrict it in a non-trivial way.
Enable `foreach_match` to forward additional handles from operands to
matcher symbols and from action symbols to results.
@ftynse ftynse merged commit e4b04b3 into llvm:main May 3, 2024
@ftynse ftynse deleted the foreach-match branch May 3, 2024 08:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants