Skip to content

make transform.split_handle accept any handle kind #118752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1062,36 +1062,37 @@ def SplitHandleOp : TransformDialectOp<"split_handle",
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Splits a handle of payload ops into handles with a single op";
let summary = "Splits a handle or parameter into multiple values";
let description = [{
Splits `handle` into one or multiple handles, as specified by the number
of results of this operation. `handle` should be mapped to as many payload
ops as there are results. Otherwise, this transform will fail produces a
silenceable failure by default. Each result handle is mapped to exactly one
payload op. The order of the payload ops is preserved, i.e., the i-th
payload op is mapped to the i-th result handle.
ops, values or parameteres as there are results. Otherwise, this transform
will fail producing a silenceable failure by default. Each result handle
is mapped to exactly one payload unless specified otherwise by attributes
described below. The order of the payloads is preserved, i.e., the i-th
payload is mapped to the i-th result handle.

This operation is useful for ensuring a statically known number of
operations are tracked by the source `handle` and to extract them into
payloads are tracked by the source `handle` and to extract them into
individual handles that can be further manipulated in isolation.

If there are more payload ops than results, the remaining ops are mapped to
If there are more payloads than results, the remaining payloads are mapped to
the result with index `overflow_result`. If no `overflow_result` is
specified, the transform produces a silenceable failure.

If there are fewer payload ops than results, the transform produces a
silenceable failure if `fail_on_payload_too_small` is set to "true".
Otherwise, it succeeds and the remaining result handles are not mapped to
any op. It also succeeds if `handle` is empty and
anything. It also succeeds if `handle` is empty and
`pass_through_empty_handle` is set to "true", regardless of
`fail_on_payload_too_small`.
}];

let arguments = (ins TransformHandleTypeInterface:$handle,
let arguments = (ins Transform_AnyHandleOrParamType:$handle,
DefaultValuedAttr<BoolAttr, "true">:$pass_through_empty_handle,
DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_too_small,
OptionalAttr<I64Attr>:$overflow_result);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let hasVerifier = 1;

let builders = [
Expand Down
64 changes: 52 additions & 12 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2415,32 +2415,62 @@ DiagnosedSilenceableFailure
transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
int64_t numPayloads =
llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
.Case<TransformHandleTypeInterface>([&](auto x) {
return llvm::range_size(state.getPayloadOps(getHandle()));
})
.Case<TransformValueHandleTypeInterface>([&](auto x) {
return llvm::range_size(state.getPayloadValues(getHandle()));
})
.Case<TransformParamTypeInterface>([&](auto x) {
return llvm::range_size(state.getParams(getHandle()));
})
.Default([](auto x) {
llvm_unreachable("unknown transform dialect type interface");
return -1;
});

auto produceNumOpsError = [&]() {
return emitSilenceableError()
<< getHandle() << " expected to contain " << this->getNumResults()
<< " payload ops but it contains " << numPayloadOps
<< " payload ops";
<< " payloads but it contains " << numPayloads << " payloads";
};

// Fail if there are more payload ops than results and no overflow result was
// specified.
if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
if (numPayloads > getNumResults() && !getOverflowResult().has_value())
return produceNumOpsError();

// Fail if there are more results than payload ops. Unless:
// - "fail_on_payload_too_small" is set to "false", or
// - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
(numPayloadOps != 0 || !getPassThroughEmptyHandle()))
if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
(numPayloads != 0 || !getPassThroughEmptyHandle()))
return produceNumOpsError();

// Distribute payload ops.
SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
// Distribute payloads.
SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
if (getOverflowResult())
resultHandles[*getOverflowResult()].reserve(numPayloadOps -
getNumResults());
for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());

auto container = [&]() {
if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
return llvm::map_to_vector(
state.getPayloadOps(getHandle()),
[](Operation *op) -> MappedValue { return op; });
}
if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
return llvm::map_to_vector(state.getPayloadValues(getHandle()),
[](Value v) -> MappedValue { return v; });
}
assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
"unsupported kind of transform dialect type");
return llvm::map_to_vector(state.getParams(getHandle()),
[](Attribute a) -> MappedValue { return a; });
}();

for (auto &&en : llvm::enumerate(container)) {
int64_t resultNum = en.index();
if (resultNum >= getNumResults())
resultNum = *getOverflowResult();
Expand All @@ -2449,7 +2479,8 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,

// Set transform op results.
for (auto &&it : llvm::enumerate(resultHandles))
results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
it.value());

return DiagnosedSilenceableFailure::success();
}
Expand All @@ -2466,6 +2497,15 @@ LogicalResult transform::SplitHandleOp::verify() {
if (getOverflowResult().has_value() &&
!(*getOverflowResult() < getNumResults()))
return emitOpError("overflow_result is not a valid result index");

for (Type resultType : getResultTypes()) {
if (implementSameTransformInterface(getHandle().getType(), resultType))
continue;

return emitOpError("expects result types to implement the same transform "
"interface as the operand type");
}

return success();
}

Expand Down
69 changes: 67 additions & 2 deletions mlir/test/Dialect/Transform/test-interpreter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ module attributes {transform.with_named_sequence} {
// expected-remark @below {{1}}
transform.debug.emit_param_as_remark %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
// expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
%h_2:3 = transform.split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
Expand Down Expand Up @@ -1180,6 +1180,71 @@ module attributes {transform.with_named_sequence} {

// -----

func.func private @opaque() -> (i32, i32)

func.func @split_handle() {
func.call @opaque() : () -> (i32, i32)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%fun: !transform.any_op) {
%op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
%val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
%p = transform.num_associations %val : (!transform.any_value) -> !transform.any_param
// expected-remark @below {{total 2}}
transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
%h:2 = transform.split_handle %val : (!transform.any_value) -> (!transform.any_value, !transform.any_value)
%p1 = transform.num_associations %h#0 : (!transform.any_value) -> !transform.any_param
%p2 = transform.num_associations %h#1 : (!transform.any_value) -> !transform.any_param
// expected-remark @below {{first 1}}
transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
// expected-remark @below {{second 1}}
transform.debug.emit_param_as_remark %p1, "second" : !transform.any_param
transform.yield
}
}

// -----

func.func private @opaque() -> (i32, i32)

func.func @split_handle() {
func.call @opaque() : () -> (i32, i32)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%fun: !transform.any_op) {
%op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
%val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
%type = transform.get_type %val : (!transform.any_value) -> !transform.any_param
%p = transform.num_associations %type : (!transform.any_param) -> !transform.any_param
// expected-remark @below {{total 2}}
transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
%h:2 = transform.split_handle %type : (!transform.any_param) -> (!transform.any_param, !transform.any_param)
%p1 = transform.num_associations %h#0 : (!transform.any_param) -> !transform.any_param
%p2 = transform.num_associations %h#1 : (!transform.any_param) -> !transform.any_param
// expected-remark @below {{first 1}}
transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
// expected-remark @below {{second 1}}
transform.debug.emit_param_as_remark %p1, "second" : !transform.any_param
transform.yield
}
}

// -----

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%fun: !transform.any_op) {
// expected-error @below {{op expects result types to implement the same transform interface as the operand type}}
transform.split_handle %fun : (!transform.any_op) -> (!transform.any_op, !transform.any_value)
transform.yield
}
}

// -----

"test.some_op"() : () -> ()
"other_dialect.other_op"() : () -> ()

Expand Down Expand Up @@ -1324,7 +1389,7 @@ module attributes {transform.with_named_sequence} {
transform.sequence %root : !transform.any_op -> !transform.any_op failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
// expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
%h_2:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
/// Test that yield does not crash in the presence of silenceable error in
/// propagate mode.
Expand Down
Loading