-
Notifications
You must be signed in to change notification settings - Fork 14.3k
make transform.split_handle accept any handle kind #118752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Oleksandr "Alex" Zinenko (ftynse) ChangesIt can now split value and parameter handles in addition to operation handles. This is a generally useful functionality. Full diff: https://github.com/llvm/llvm-project/pull/118752.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b946fc8875860b..2d71d5b0892afe 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1062,36 +1062,37 @@ def SplitHandleOp : TransformDialectOp<"split_handle",
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let summary = "Splits a handle of payload ops into handles with a single op";
+ let summary = "Splits a handle or parameter into multiple values";
let description = [{
Splits `handle` into one or multiple handles, as specified by the number
of results of this operation. `handle` should be mapped to as many payload
- ops as there are results. Otherwise, this transform will fail produces a
- silenceable failure by default. Each result handle is mapped to exactly one
- payload op. The order of the payload ops is preserved, i.e., the i-th
- payload op is mapped to the i-th result handle.
+ ops, values or parameteres as there are results. Otherwise, this transform
+ will fail producing a silenceable failure by default. Each result handle
+ is mapped to exactly one payload unless specified otherwise by attributes
+ described below. The order of the payloads is preserved, i.e., the i-th
+ payload is mapped to the i-th result handle.
This operation is useful for ensuring a statically known number of
- operations are tracked by the source `handle` and to extract them into
+ payloads are tracked by the source `handle` and to extract them into
individual handles that can be further manipulated in isolation.
- If there are more payload ops than results, the remaining ops are mapped to
+ If there are more payloads than results, the remaining payloads are mapped to
the result with index `overflow_result`. If no `overflow_result` is
specified, the transform produces a silenceable failure.
If there are fewer payload ops than results, the transform produces a
silenceable failure if `fail_on_payload_too_small` is set to "true".
Otherwise, it succeeds and the remaining result handles are not mapped to
- any op. It also succeeds if `handle` is empty and
+ anything. It also succeeds if `handle` is empty and
`pass_through_empty_handle` is set to "true", regardless of
`fail_on_payload_too_small`.
}];
- let arguments = (ins TransformHandleTypeInterface:$handle,
+ let arguments = (ins Transform_AnyHandleOrParamType:$handle,
DefaultValuedAttr<BoolAttr, "true">:$pass_through_empty_handle,
DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_too_small,
OptionalAttr<I64Attr>:$overflow_result);
- let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+ let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let hasVerifier = 1;
let builders = [
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 590cae9aa0d667..68d1f2aef638a5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2415,32 +2415,63 @@ DiagnosedSilenceableFailure
transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
+ int64_t numPayloads =
+ llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
+ .Case<TransformHandleTypeInterface>([&](auto x) {
+ return llvm::range_size(state.getPayloadOps(getHandle()));
+ })
+ .Case<TransformValueHandleTypeInterface>([&](auto x) {
+ return llvm::range_size(state.getPayloadValues(getHandle()));
+ })
+ .Case<TransformParamTypeInterface>([&](auto x) {
+ return llvm::range_size(state.getParams(getHandle()));
+ })
+ .Default([](auto x) {
+ llvm_unreachable("unknown transform dialect type interface");
+ return -1;
+ });
+
auto produceNumOpsError = [&]() {
return emitSilenceableError()
<< getHandle() << " expected to contain " << this->getNumResults()
- << " payload ops but it contains " << numPayloadOps
- << " payload ops";
+ << " payloads but it contains " << numPayloads
+ << " payloads";
};
// Fail if there are more payload ops than results and no overflow result was
// specified.
- if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
+ if (numPayloads > getNumResults() && !getOverflowResult().has_value())
return produceNumOpsError();
// Fail if there are more results than payload ops. Unless:
// - "fail_on_payload_too_small" is set to "false", or
// - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
- if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
- (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
+ if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
+ (numPayloads != 0 || !getPassThroughEmptyHandle()))
return produceNumOpsError();
- // Distribute payload ops.
- SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
+ // Distribute payloads.
+ SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
if (getOverflowResult())
- resultHandles[*getOverflowResult()].reserve(numPayloadOps -
- getNumResults());
- for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
+ resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
+
+ auto container = [&]() {
+ if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
+ return llvm::map_to_vector(
+ state.getPayloadOps(getHandle()),
+ [](Operation *op) -> MappedValue { return op; });
+ }
+ if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
+ return llvm::map_to_vector(state.getPayloadValues(getHandle()),
+ [](Value v) -> MappedValue { return v; });
+ }
+ assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
+ "unsupported kind of transform dialect type");
+ return llvm::map_to_vector(state.getParams(getHandle()),
+ [](Attribute a) -> MappedValue { return a; });
+ }();
+
+ for (auto &&en : llvm::enumerate(container)) {
int64_t resultNum = en.index();
if (resultNum >= getNumResults())
resultNum = *getOverflowResult();
@@ -2449,7 +2480,7 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
// Set transform op results.
for (auto &&it : llvm::enumerate(resultHandles))
- results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
+ results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())), it.value());
return DiagnosedSilenceableFailure::success();
}
@@ -2466,6 +2497,15 @@ LogicalResult transform::SplitHandleOp::verify() {
if (getOverflowResult().has_value() &&
!(*getOverflowResult() < getNumResults()))
return emitOpError("overflow_result is not a valid result index");
+
+ for (Type resultType : getResultTypes()) {
+ if (implementSameTransformInterface(getHandle().getType(), resultType))
+ continue;
+
+ return emitOpError("expects result types to implement the same transform "
+ "interface as the operand type");
+ }
+
return success();
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 4fe2dbedff56e3..ecc234587cda95 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1094,7 +1094,7 @@ module attributes {transform.with_named_sequence} {
// expected-remark @below {{1}}
transform.debug.emit_param_as_remark %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+ // expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
%h_2:3 = transform.split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -1180,6 +1180,71 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func private @opaque() -> (i32, i32)
+
+func.func @split_handle() {
+ func.call @opaque() : () -> (i32, i32)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+ %op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
+ %val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
+ %p = transform.num_associations %val : (!transform.any_value) -> !transform.any_param
+ // expected-remark @below {{total 2}}
+ transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
+ %h:2 = transform.split_handle %val : (!transform.any_value) -> (!transform.any_value, !transform.any_value)
+ %p1 = transform.num_associations %h#0 : (!transform.any_value) -> !transform.any_param
+ %p2 = transform.num_associations %h#1 : (!transform.any_value) -> !transform.any_param
+ // expected-remark @below {{first 1}}
+ transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
+ // expected-remark @below {{second 1}}
+ transform.debug.emit_param_as_remark %p1, "second" : !transform.any_param
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @opaque() -> (i32, i32)
+
+func.func @split_handle() {
+ func.call @opaque() : () -> (i32, i32)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+ %op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
+ %val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
+ %type = transform.get_type %val : (!transform.any_value) -> !transform.any_param
+ %p = transform.num_associations %type : (!transform.any_param) -> !transform.any_param
+ // expected-remark @below {{total 2}}
+ transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
+ %h:2 = transform.split_handle %type : (!transform.any_param) -> (!transform.any_param, !transform.any_param)
+ %p1 = transform.num_associations %h#0 : (!transform.any_param) -> !transform.any_param
+ %p2 = transform.num_associations %h#1 : (!transform.any_param) -> !transform.any_param
+ // expected-remark @below {{first 1}}
+ transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
+ // expected-remark @below {{second 1}}
+ transform.debug.emit_param_as_remark %p1, "second" : !transform.any_param
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+ // expected-error @below {{op expects result types to implement the same transform interface as the operand type}}
+ transform.split_handle %fun : (!transform.any_op) -> (!transform.any_op, !transform.any_value)
+ transform.yield
+ }
+}
+
+// -----
+
"test.some_op"() : () -> ()
"other_dialect.other_op"() : () -> ()
@@ -1324,7 +1389,7 @@ module attributes {transform.with_named_sequence} {
transform.sequence %root : !transform.any_op -> !transform.any_op failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+ // expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
%h_2:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
/// Test that yield does not crash in the presence of silenceable error in
/// propagate mode.
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
It can now split value and parameter handles in addition to operation handles. This is a generally useful functionality.
ae91a4c
to
0385fb5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
It can now split value and parameter handles in addition to operation handles. This is a generally useful functionality.