Skip to content

Commit 1c352e6

Browse files
authored
make transform.split_handle accept any handle kind (#118752)
It can now split value and parameter handles in addition to operation handles. This is a generally useful functionality.
1 parent 938cdd6 commit 1c352e6

File tree

3 files changed

+130
-24
lines changed

3 files changed

+130
-24
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,36 +1062,37 @@ def SplitHandleOp : TransformDialectOp<"split_handle",
10621062
[FunctionalStyleTransformOpTrait,
10631063
DeclareOpInterfaceMethods<TransformOpInterface>,
10641064
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
1065-
let summary = "Splits a handle of payload ops into handles with a single op";
1065+
let summary = "Splits a handle or parameter into multiple values";
10661066
let description = [{
10671067
Splits `handle` into one or multiple handles, as specified by the number
10681068
of results of this operation. `handle` should be mapped to as many payload
1069-
ops as there are results. Otherwise, this transform will fail produces a
1070-
silenceable failure by default. Each result handle is mapped to exactly one
1071-
payload op. The order of the payload ops is preserved, i.e., the i-th
1072-
payload op is mapped to the i-th result handle.
1069+
ops, values or parameteres as there are results. Otherwise, this transform
1070+
will fail producing a silenceable failure by default. Each result handle
1071+
is mapped to exactly one payload unless specified otherwise by attributes
1072+
described below. The order of the payloads is preserved, i.e., the i-th
1073+
payload is mapped to the i-th result handle.
10731074

10741075
This operation is useful for ensuring a statically known number of
1075-
operations are tracked by the source `handle` and to extract them into
1076+
payloads are tracked by the source `handle` and to extract them into
10761077
individual handles that can be further manipulated in isolation.
10771078

1078-
If there are more payload ops than results, the remaining ops are mapped to
1079+
If there are more payloads than results, the remaining payloads are mapped to
10791080
the result with index `overflow_result`. If no `overflow_result` is
10801081
specified, the transform produces a silenceable failure.
10811082

10821083
If there are fewer payload ops than results, the transform produces a
10831084
silenceable failure if `fail_on_payload_too_small` is set to "true".
10841085
Otherwise, it succeeds and the remaining result handles are not mapped to
1085-
any op. It also succeeds if `handle` is empty and
1086+
anything. It also succeeds if `handle` is empty and
10861087
`pass_through_empty_handle` is set to "true", regardless of
10871088
`fail_on_payload_too_small`.
10881089
}];
10891090

1090-
let arguments = (ins TransformHandleTypeInterface:$handle,
1091+
let arguments = (ins Transform_AnyHandleOrParamType:$handle,
10911092
DefaultValuedAttr<BoolAttr, "true">:$pass_through_empty_handle,
10921093
DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_too_small,
10931094
OptionalAttr<I64Attr>:$overflow_result);
1094-
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
1095+
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
10951096
let hasVerifier = 1;
10961097

10971098
let builders = [

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2415,32 +2415,62 @@ DiagnosedSilenceableFailure
24152415
transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
24162416
transform::TransformResults &results,
24172417
transform::TransformState &state) {
2418-
int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
2418+
int64_t numPayloads =
2419+
llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
2420+
.Case<TransformHandleTypeInterface>([&](auto x) {
2421+
return llvm::range_size(state.getPayloadOps(getHandle()));
2422+
})
2423+
.Case<TransformValueHandleTypeInterface>([&](auto x) {
2424+
return llvm::range_size(state.getPayloadValues(getHandle()));
2425+
})
2426+
.Case<TransformParamTypeInterface>([&](auto x) {
2427+
return llvm::range_size(state.getParams(getHandle()));
2428+
})
2429+
.Default([](auto x) {
2430+
llvm_unreachable("unknown transform dialect type interface");
2431+
return -1;
2432+
});
2433+
24192434
auto produceNumOpsError = [&]() {
24202435
return emitSilenceableError()
24212436
<< getHandle() << " expected to contain " << this->getNumResults()
2422-
<< " payload ops but it contains " << numPayloadOps
2423-
<< " payload ops";
2437+
<< " payloads but it contains " << numPayloads << " payloads";
24242438
};
24252439

24262440
// Fail if there are more payload ops than results and no overflow result was
24272441
// specified.
2428-
if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
2442+
if (numPayloads > getNumResults() && !getOverflowResult().has_value())
24292443
return produceNumOpsError();
24302444

24312445
// Fail if there are more results than payload ops. Unless:
24322446
// - "fail_on_payload_too_small" is set to "false", or
24332447
// - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2434-
if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
2435-
(numPayloadOps != 0 || !getPassThroughEmptyHandle()))
2448+
if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2449+
(numPayloads != 0 || !getPassThroughEmptyHandle()))
24362450
return produceNumOpsError();
24372451

2438-
// Distribute payload ops.
2439-
SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
2452+
// Distribute payloads.
2453+
SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
24402454
if (getOverflowResult())
2441-
resultHandles[*getOverflowResult()].reserve(numPayloadOps -
2442-
getNumResults());
2443-
for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
2455+
resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2456+
2457+
auto container = [&]() {
2458+
if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2459+
return llvm::map_to_vector(
2460+
state.getPayloadOps(getHandle()),
2461+
[](Operation *op) -> MappedValue { return op; });
2462+
}
2463+
if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2464+
return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2465+
[](Value v) -> MappedValue { return v; });
2466+
}
2467+
assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2468+
"unsupported kind of transform dialect type");
2469+
return llvm::map_to_vector(state.getParams(getHandle()),
2470+
[](Attribute a) -> MappedValue { return a; });
2471+
}();
2472+
2473+
for (auto &&en : llvm::enumerate(container)) {
24442474
int64_t resultNum = en.index();
24452475
if (resultNum >= getNumResults())
24462476
resultNum = *getOverflowResult();
@@ -2449,7 +2479,8 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
24492479

24502480
// Set transform op results.
24512481
for (auto &&it : llvm::enumerate(resultHandles))
2452-
results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
2482+
results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2483+
it.value());
24532484

24542485
return DiagnosedSilenceableFailure::success();
24552486
}
@@ -2466,6 +2497,15 @@ LogicalResult transform::SplitHandleOp::verify() {
24662497
if (getOverflowResult().has_value() &&
24672498
!(*getOverflowResult() < getNumResults()))
24682499
return emitOpError("overflow_result is not a valid result index");
2500+
2501+
for (Type resultType : getResultTypes()) {
2502+
if (implementSameTransformInterface(getHandle().getType(), resultType))
2503+
continue;
2504+
2505+
return emitOpError("expects result types to implement the same transform "
2506+
"interface as the operand type");
2507+
}
2508+
24692509
return success();
24702510
}
24712511

mlir/test/Dialect/Transform/test-interpreter.mlir

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,7 @@ module attributes {transform.with_named_sequence} {
10941094
// expected-remark @below {{1}}
10951095
transform.debug.emit_param_as_remark %p : !transform.param<i64>
10961096
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
1097-
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
1097+
// expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
10981098
%h_2:3 = transform.split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
10991099
transform.yield
11001100
}
@@ -1180,6 +1180,71 @@ module attributes {transform.with_named_sequence} {
11801180

11811181
// -----
11821182

1183+
func.func private @opaque() -> (i32, i32)
1184+
1185+
func.func @split_handle() {
1186+
func.call @opaque() : () -> (i32, i32)
1187+
return
1188+
}
1189+
1190+
module attributes {transform.with_named_sequence} {
1191+
transform.named_sequence @__transform_main(%fun: !transform.any_op) {
1192+
%op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
1193+
%val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
1194+
%p = transform.num_associations %val : (!transform.any_value) -> !transform.any_param
1195+
// expected-remark @below {{total 2}}
1196+
transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
1197+
%h:2 = transform.split_handle %val : (!transform.any_value) -> (!transform.any_value, !transform.any_value)
1198+
%p1 = transform.num_associations %h#0 : (!transform.any_value) -> !transform.any_param
1199+
%p2 = transform.num_associations %h#1 : (!transform.any_value) -> !transform.any_param
1200+
// expected-remark @below {{first 1}}
1201+
transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
1202+
// expected-remark @below {{second 1}}
1203+
transform.debug.emit_param_as_remark %p1, "second" : !transform.any_param
1204+
transform.yield
1205+
}
1206+
}
1207+
1208+
// -----
1209+
1210+
func.func private @opaque() -> (i32, i32)
1211+
1212+
func.func @split_handle() {
1213+
func.call @opaque() : () -> (i32, i32)
1214+
return
1215+
}
1216+
1217+
module attributes {transform.with_named_sequence} {
1218+
transform.named_sequence @__transform_main(%fun: !transform.any_op) {
1219+
%op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
1220+
%val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
1221+
%type = transform.get_type %val : (!transform.any_value) -> !transform.any_param
1222+
%p = transform.num_associations %type : (!transform.any_param) -> !transform.any_param
1223+
// expected-remark @below {{total 2}}
1224+
transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
1225+
%h:2 = transform.split_handle %type : (!transform.any_param) -> (!transform.any_param, !transform.any_param)
1226+
%p1 = transform.num_associations %h#0 : (!transform.any_param) -> !transform.any_param
1227+
%p2 = transform.num_associations %h#1 : (!transform.any_param) -> !transform.any_param
1228+
// expected-remark @below {{first 1}}
1229+
transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
1230+
// expected-remark @below {{second 1}}
1231+
transform.debug.emit_param_as_remark %p1, "second" : !transform.any_param
1232+
transform.yield
1233+
}
1234+
}
1235+
1236+
// -----
1237+
1238+
module attributes {transform.with_named_sequence} {
1239+
transform.named_sequence @__transform_main(%fun: !transform.any_op) {
1240+
// expected-error @below {{op expects result types to implement the same transform interface as the operand type}}
1241+
transform.split_handle %fun : (!transform.any_op) -> (!transform.any_op, !transform.any_value)
1242+
transform.yield
1243+
}
1244+
}
1245+
1246+
// -----
1247+
11831248
"test.some_op"() : () -> ()
11841249
"other_dialect.other_op"() : () -> ()
11851250

@@ -1324,7 +1389,7 @@ module attributes {transform.with_named_sequence} {
13241389
transform.sequence %root : !transform.any_op -> !transform.any_op failures(propagate) {
13251390
^bb1(%fun: !transform.any_op):
13261391
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
1327-
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
1392+
// expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
13281393
%h_2:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
13291394
/// Test that yield does not crash in the presence of silenceable error in
13301395
/// propagate mode.

0 commit comments

Comments
 (0)