Skip to content

Commit 7e5308d

Browse files
author
Rolf Morel
committed
[mlir][Transform] Extend transform.foreach to take multiple arguments
Changes transform.foreach's interface to take multiple arguments, e.g. `transform.foreach %ops1, %ops2, %params : ... { ^bb0(%op1, %op2, %param): BODY }`. The semantics are that the payloads for these handles get iterated over as if they had been zipped together: BODY gets executed once for each such tuple. The documentation explains that this implementation requires that the payloads have the same length. This change also enables the target argument(s) and result(s) to be any op/value/param handle. The added test cases demonstrate some use cases for this change.
1 parent 2ceec68 commit 7e5308d

File tree

7 files changed

+354
-73
lines changed

7 files changed

+354
-73
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -614,43 +614,48 @@ def ForeachOp : TransformDialectOp<"foreach",
614614
"getSuccessorRegions", "getEntrySuccessorOperands"]>,
615615
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
616616
]> {
617-
let summary = "Executes the body for each payload op";
617+
let summary = "Executes the body for each element of the payload";
618618
let description = [{
619-
This op has exactly one region with exactly one block ("body"). The body is
620-
executed for each payload op that is associated to the target operand in an
621-
unbatched fashion. I.e., the block argument ("iteration variable") is always
622-
mapped to exactly one payload op.
623-
624-
This op always reads the target handle. Furthermore, it consumes the handle
625-
if there is a transform op in the body that consumes the iteration variable.
626-
This op does not return anything.
627-
628-
The transformations inside the body are applied in order of their
629-
appearance. During application, if any transformation in the sequence fails,
630-
the entire sequence fails immediately leaving the payload IR in potentially
631-
invalid state, i.e., this operation offers no transformation rollback
632-
capabilities.
633-
634-
This op generates as many handles as the terminating YieldOp has operands.
635-
For each result, the payload ops of the corresponding YieldOp operand are
636-
merged and mapped to the same resulting handle.
619+
Execute the op's body - its single region block - exactly once per
620+
element of the payload associated to a target handle. The body's
621+
transformations are applied in order of appearance until reaching the
622+
(implicit) YieldOp terminator.
623+
624+
Each iteration gets executed by co-indexing the payloads of the arguments
625+
and mapping the body's arguments to these tuples, as though iterating over
626+
the zipped together `targets`. As such, in each iteration, the size of the
627+
payload of each of the body's block arguments is exactly one.
628+
629+
This op always reads the target handles. Furthermore, it consumes a handle
630+
if there is a transform op in the body that consumes the corresponding
631+
block argument. Handles can point to ops, values, or parameters.
632+
633+
#### Return Modes
634+
635+
This op produces as many result handles as the body's terminating YieldOp
636+
has operands. For each result, the payloads of the corresponding YieldOp
637+
operand are merged and mapped to the same resulting handle.
638+
639+
If the target handles do not associate payloads of the same size, a
640+
silenceable failure will be generated.
641+
642+
During application, if any transformation in the sequence fails, the entire
643+
sequence fails immediately with the same failure, leaving the payload IR in
644+
a potentially invalid state, i.e., this operation offers no transformation
645+
rollback capabilities.
637646
}];
638647

639-
let arguments = (ins TransformHandleTypeInterface:$target);
640-
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
648+
let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
649+
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
641650
let regions = (region SizedRegion<1>:$body);
642651
let assemblyFormat =
643-
"$target `:` type($target) (`->` type($results)^)? $body attr-dict";
652+
"$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
644653
let hasVerifier = 1;
645654

646655
let extraClassDeclaration = [{
647656
/// Allow the dialect prefix to be omitted.
648657
static StringRef getDefaultDialect() { return "transform"; }
649658

650-
BlockArgument getIterationVariable() {
651-
return getBody().front().getArgument(0);
652-
}
653-
654659
transform::YieldOp getYieldOp();
655660
}];
656661
}

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 85 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,46 +1391,83 @@ DiagnosedSilenceableFailure
13911391
transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
13921392
transform::TransformResults &results,
13931393
transform::TransformState &state) {
1394-
SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
1395-
// Store payload ops in a vector because ops may be removed from the mapping
1396-
// by the TrackingRewriter while the iteration is in progress.
1397-
SmallVector<Operation *> targets =
1398-
llvm::to_vector(state.getPayloadOps(getTarget()));
1399-
for (Operation *op : targets) {
1394+
// We store the payloads before executing the body as ops may be removed from
1395+
// the mapping by the TrackingRewriter while iteration is in progress.
1396+
SmallVector<SmallVector<MappedValue>> payloads;
1397+
detail::prepareValueMappings(payloads, getTargets(), state);
1398+
size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1399+
1400+
// As we will be "zipping" over them, check all payloads have the same size.
1401+
for (size_t argIdx = 1; argIdx < payloads.size(); argIdx++) {
1402+
if (payloads[argIdx].size() != numIterations) {
1403+
return emitSilenceableError()
1404+
<< "prior targets' payload size (" << numIterations
1405+
<< ") differs from payload size (" << payloads[argIdx].size()
1406+
<< ") of target " << getTargets()[argIdx];
1407+
}
1408+
}
1409+
1410+
// Start iterating, indexing into payloads to obtain the right arguments to
1411+
// call the body with - each slice of payloads at the same argument index
1412+
// corresponding to a tuple to use as the body's block arguments.
1413+
ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1414+
SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1415+
for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
14001416
auto scope = state.make_region_scope(getBody());
1401-
if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
1402-
return DiagnosedSilenceableFailure::definiteFailure();
1417+
// Set up arguments to the region's block.
1418+
for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1419+
MappedValue argument = payloads[argIdx][iterIdx];
1420+
// Note that each blockArg's handle gets associated with just a single
1421+
// element from the corresponding target's payload.
1422+
if (failed(state.mapBlockArgument(blockArg, {argument})))
1423+
return DiagnosedSilenceableFailure::definiteFailure();
1424+
}
14031425

14041426
// Execute loop body.
14051427
for (Operation &transform : getBody().front().without_terminator()) {
14061428
DiagnosedSilenceableFailure result = state.applyTransform(
1407-
cast<transform::TransformOpInterface>(transform));
1429+
llvm::cast<transform::TransformOpInterface>(transform));
14081430
if (!result.succeeded())
14091431
return result;
14101432
}
14111433

1412-
// Append yielded payload ops to result list (if any).
1413-
for (unsigned i = 0; i < getNumResults(); ++i) {
1414-
auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
1415-
resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
1416-
}
1417-
}
1418-
1419-
for (unsigned i = 0; i < getNumResults(); ++i)
1420-
results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
1434+
// Append yielded payloads to corresponding results from prior iterations.
1435+
OperandRange yieldOperands = getYieldOp().getOperands();
1436+
for (auto &&[result, yieldOperand, resTuple] :
1437+
llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1438+
// NB: in each iteration we add any number of ops/vals/params to a result.
1439+
if (isa<TransformHandleTypeInterface>(result.getType()))
1440+
llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1441+
else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1442+
llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1443+
else if (isa<TransformParamTypeInterface>(result.getType()))
1444+
llvm::append_range(resTuple, state.getParams(yieldOperand));
1445+
else
1446+
assert(false && "unhandled handle type");
1447+
}
1448+
1449+
// Associate the accumulated result payloads to the op's actual results.
1450+
for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1451+
results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
14211452

14221453
return DiagnosedSilenceableFailure::success();
14231454
}
14241455

14251456
void transform::ForeachOp::getEffects(
14261457
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1427-
BlockArgument iterVar = getIterationVariable();
1428-
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1429-
return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
1430-
})) {
1431-
consumesHandle(getTarget(), effects);
1432-
} else {
1433-
onlyReadsHandle(getTarget(), effects);
1458+
// NB: ideally this `zip` would be `zip_equal`, but while this op's verifier catches
1459+
// arity errors, this method might get called before/in absence of `verify()`.
1460+
for (auto &&[target, blockArg] :
1461+
llvm::zip(getTargets(), getBody().front().getArguments())) {
1462+
BlockArgument blockArgument = blockArg;
1463+
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1464+
return isHandleConsumed(blockArgument,
1465+
cast<TransformOpInterface>(&op));
1466+
})) {
1467+
consumesHandle(target, effects);
1468+
} else {
1469+
onlyReadsHandle(target, effects);
1470+
}
14341471
}
14351472

14361473
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
@@ -1463,8 +1500,8 @@ void transform::ForeachOp::getSuccessorRegions(
14631500

14641501
OperandRange
14651502
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1466-
// The iteration variable op handle is mapped to a subset (one op to be
1467-
// precise) of the payload ops of the ForeachOp operand.
1503+
// Each block argument handle is mapped to a subset (one element to be precise)
1504+
// of the payload of the corresponding `targets` operand of ForeachOp.
14681505
assert(point == getBody() && "unexpected region index");
14691506
return getOperation()->getOperands();
14701507
}
@@ -1474,14 +1511,27 @@ transform::YieldOp transform::ForeachOp::getYieldOp() {
14741511
}
14751512

14761513
LogicalResult transform::ForeachOp::verify() {
1477-
auto yieldOp = getYieldOp();
1478-
if (getNumResults() != yieldOp.getNumOperands())
1479-
return emitOpError() << "expects the same number of results as the "
1480-
"terminator has operands";
1481-
for (Value v : yieldOp.getOperands())
1482-
if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
1483-
return yieldOp->emitOpError("expects operands to have types implementing "
1484-
"TransformHandleTypeInterface");
1514+
for (auto [targetOpt, bodyArgOpt] :
1515+
llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1516+
if (!targetOpt || !bodyArgOpt)
1517+
return emitOpError() << "expects the same number of targets as the body "
1518+
"has block arguments";
1519+
if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1520+
return emitOpError(
1521+
"expects co-indexed targets and the body's "
1522+
"block arguments to have the same op/value/param type");
1523+
}
1524+
1525+
for (auto [resultOpt, yieldOperandOpt] :
1526+
llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1527+
if (!resultOpt || !yieldOperandOpt)
1528+
return emitOpError() << "expects the same number of results as the "
1529+
"yield terminator has operands";
1530+
if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1531+
return emitOpError("expects co-indexed results and yield "
1532+
"operands to have the same op/value/param type");
1533+
}
1534+
14851535
return success();
14861536
}
14871537

mlir/test/Dialect/Linalg/multisize-tiling-full.mlir

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} {
66
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
77
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
88
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
9-
%t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
109
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
1110
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1211
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1312
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
14-
%tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op
15-
%6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op
16-
transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
17-
transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
13+
transform.foreach %5 : !transform.any_op {
14+
^bb0(%inner_linalg: !transform.any_op):
15+
%low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
16+
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
17+
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
18+
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
19+
}
1820
transform.yield
1921
}
2022
}
@@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} {
114116
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
115117
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
116118
%tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
117-
%6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64>
118-
transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
119-
transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
119+
transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
120+
^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
121+
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
122+
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
123+
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
124+
}
120125
transform.yield
121126
}
122127
}

0 commit comments

Comments
 (0)