
[mlir][Transform] Extend transform.foreach to take multiple arguments #93705

Merged
merged 1 commit into from
Jun 14, 2024
57 changes: 31 additions & 26 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -614,43 +614,48 @@ def ForeachOp : TransformDialectOp<"foreach",
"getSuccessorRegions", "getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
]> {
let summary = "Executes the body for each payload op";
let summary = "Executes the body for each element of the payload";
let description = [{
This op has exactly one region with exactly one block ("body"). The body is
executed for each payload op that is associated to the target operand in an
unbatched fashion. I.e., the block argument ("iteration variable") is always
mapped to exactly one payload op.

This op always reads the target handle. Furthermore, it consumes the handle
if there is a transform op in the body that consumes the iteration variable.
This op does not return anything.

The transformations inside the body are applied in order of their
appearance. During application, if any transformation in the sequence fails,
the entire sequence fails immediately leaving the payload IR in potentially
invalid state, i.e., this operation offers no transformation rollback
capabilities.

This op generates as many handles as the terminating YieldOp has operands.
For each result, the payload ops of the corresponding YieldOp operand are
merged and mapped to the same resulting handle.
Execute the op's body - its single region block - exactly once per
element of the payload associated to a target handle. The body's
transformations are applied in order of appearance until reaching the
(implicit) YieldOp terminator.

Each iteration gets executed by co-indexing the payloads of the arguments
and mapping the body's arguments to these tuples, as though iterating over
the zipped together `targets`. As such, in each iteration, the size of the
payload of each of the body's block arguments is exactly one.

This op always reads the target handles. Furthermore, it consumes a handle
if there is a transform op in the body that consumes the corresponding
block argument. Handles can point to ops, values, or parameters.

#### Return Modes

This op produces as many result handles as the body's terminating YieldOp
has operands. For each result, the payloads of the corresponding YieldOp
operand are merged and mapped to the same resulting handle.

If the target handles do not associate payloads of the same size, a
silenceable failure will be generated.

During application, if any transformation in the sequence fails, the entire
sequence fails immediately with the same failure, leaving the payload IR in
a potentially invalid state, i.e., this operation offers no transformation
rollback capabilities.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
"$target `:` type($target) (`->` type($results)^)? $body attr-dict";
"$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Allow the dialect prefix to be omitted.
static StringRef getDefaultDialect() { return "transform"; }

BlockArgument getIterationVariable() {
return getBody().front().getArgument(0);
}

transform::YieldOp getYieldOp();
}];
}
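
The zipped iteration described in the op documentation above can be modeled outside MLIR. The following is a minimal standalone C++ sketch, not the code from this PR: `Element`, `Payload`, and `foreachZipped` are hypothetical names, strings stand in for payload ops/values/params, and `std::nullopt` stands in for the silenceable failure raised when payload sizes differ.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for one payload element (an op, value, or param).
using Element = std::string;
// The payload associated with one target handle.
using Payload = std::vector<Element>;

// Runs `body` once per co-indexed tuple of payload elements and merges
// whatever each iteration yields into a single result payload.
// Returns std::nullopt when the payload sizes differ (an assumption that
// models the silenceable failure; the real op emits a diagnostic).
std::optional<Payload> foreachZipped(
    const std::vector<Payload> &targets,
    const std::function<Payload(const std::vector<Element> &)> &body) {
  std::size_t numIterations = targets.empty() ? 0 : targets.front().size();
  for (const Payload &payload : targets)
    if (payload.size() != numIterations)
      return std::nullopt;

  Payload merged;
  for (std::size_t iter = 0; iter < numIterations; ++iter) {
    // Each block argument sees exactly one element of its target's payload.
    std::vector<Element> args;
    for (const Payload &payload : targets)
      args.push_back(payload[iter]);
    assert(args.size() == targets.size() && "one element per target");

    Payload yielded = body(args);
    merged.insert(merged.end(), yielded.begin(), yielded.end());
  }
  return merged;
}
```

With two targets of two elements each, the body runs twice and receives the pairs at index 0 and index 1. Passing targets of different sizes returns `std::nullopt` before any iteration runs, which mirrors the size check being performed ahead of the loop.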
120 changes: 85 additions & 35 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1391,46 +1391,83 @@ DiagnosedSilenceableFailure
transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
// Store payload ops in a vector because ops may be removed from the mapping
// by the TrackingRewriter while the iteration is in progress.
SmallVector<Operation *> targets =
llvm::to_vector(state.getPayloadOps(getTarget()));
for (Operation *op : targets) {
// We store the payloads before executing the body as ops may be removed from
// the mapping by the TrackingRewriter while iteration is in progress.
SmallVector<SmallVector<MappedValue>> payloads;
detail::prepareValueMappings(payloads, getTargets(), state);
size_t numIterations = payloads.empty() ? 0 : payloads.front().size();

// As we will be "zipping" over them, check all payloads have the same size.
for (size_t argIdx = 1; argIdx < payloads.size(); argIdx++) {
if (payloads[argIdx].size() != numIterations) {
return emitSilenceableError()
<< "prior targets' payload size (" << numIterations
<< ") differs from payload size (" << payloads[argIdx].size()
<< ") of target " << getTargets()[argIdx];
}
}

// Start iterating, indexing into payloads to obtain the right arguments to
// call the body with - each slice of payloads at the same argument index
// corresponding to a tuple to use as the body's block arguments.
ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
auto scope = state.make_region_scope(getBody());
if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
return DiagnosedSilenceableFailure::definiteFailure();
// Set up arguments to the region's block.
for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
MappedValue argument = payloads[argIdx][iterIdx];
// Note that each blockArg's handle gets associated with just a single
// element from the corresponding target's payload.
if (failed(state.mapBlockArgument(blockArg, {argument})))
return DiagnosedSilenceableFailure::definiteFailure();
}

// Execute loop body.
for (Operation &transform : getBody().front().without_terminator()) {
DiagnosedSilenceableFailure result = state.applyTransform(
cast<transform::TransformOpInterface>(transform));
llvm::cast<transform::TransformOpInterface>(transform));
if (!result.succeeded())
return result;
}

// Append yielded payload ops to result list (if any).
for (unsigned i = 0; i < getNumResults(); ++i) {
auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
}
}

for (unsigned i = 0; i < getNumResults(); ++i)
results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
// Append yielded payloads to corresponding results from prior iterations.
OperandRange yieldOperands = getYieldOp().getOperands();
for (auto &&[result, yieldOperand, resTuple] :
llvm::zip_equal(getResults(), yieldOperands, zippedResults))
// NB: each iteration we add any number of ops/vals/params to a result.
if (isa<TransformHandleTypeInterface>(result.getType()))
llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
else if (isa<TransformValueHandleTypeInterface>(result.getType()))
llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
else if (isa<TransformParamTypeInterface>(result.getType()))
llvm::append_range(resTuple, state.getParams(yieldOperand));
else
assert(false && "unhandled handle type");
}

// Associate the accumulated result payloads to the op's actual results.
for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
results.setMappedValues(llvm::cast<OpResult>(result), resPayload);

return DiagnosedSilenceableFailure::success();
}

void transform::ForeachOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
BlockArgument iterVar = getIterationVariable();
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
})) {
consumesHandle(getTarget(), effects);
} else {
onlyReadsHandle(getTarget(), effects);
// NB: this `zip` should be `zip_equal` - while this op's verifier catches
// arity errors, this method might get called before/in absence of `verify()`.
for (auto &&[target, blockArg] :
llvm::zip(getTargets(), getBody().front().getArguments())) {
BlockArgument blockArgument = blockArg;
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
return isHandleConsumed(blockArgument,
cast<TransformOpInterface>(&op));
})) {
consumesHandle(target, effects);
} else {
onlyReadsHandle(target, effects);
}
}

if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
@@ -1463,8 +1500,8 @@ void transform::ForeachOp::getSuccessorRegions(

OperandRange
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
// The iteration variable op handle is mapped to a subset (one op to be
// precise) of the payload ops of the ForeachOp operand.
// Each block argument handle is mapped to a subset (one element to be
// precise) of the payload of the corresponding `targets` operand of ForeachOp.
assert(point == getBody() && "unexpected region index");
return getOperation()->getOperands();
}
@@ -1474,14 +1511,27 @@ transform::YieldOp transform::ForeachOp::getYieldOp() {
}

LogicalResult transform::ForeachOp::verify() {
auto yieldOp = getYieldOp();
if (getNumResults() != yieldOp.getNumOperands())
return emitOpError() << "expects the same number of results as the "
"terminator has operands";
for (Value v : yieldOp.getOperands())
if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
return yieldOp->emitOpError("expects operands to have types implementing "
"TransformHandleTypeInterface");
for (auto [targetOpt, bodyArgOpt] :
llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
if (!targetOpt || !bodyArgOpt)
return emitOpError() << "expects the same number of targets as the body "
"has block arguments";
if (targetOpt.value().getType() != bodyArgOpt.value().getType())
return emitOpError(
"expects co-indexed targets and the body's "
"block arguments to have the same op/value/param type");
}

for (auto [resultOpt, yieldOperandOpt] :
llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
if (!resultOpt || !yieldOperandOpt)
return emitOpError() << "expects the same number of results as the "
"yield terminator has operands";
if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
return emitOpError("expects co-indexed results and yield "
"operands to have the same op/value/param type");
}

return success();
}
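
The two verifier loops above apply the same pattern: walk two lists in lockstep, report an arity error if either one ends first, and otherwise compare co-indexed types. A minimal standalone C++ sketch of that pattern follows. It does not use `llvm::zip_longest`; `TypeName` and `checkCoIndexed` are hypothetical names, and plain strings stand in for handle types.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a handle type such as !transform.any_op.
using TypeName = std::string;

// Returns an error message if the two lists differ in length or in any
// co-indexed type, and std::nullopt when they agree.
std::optional<std::string> checkCoIndexed(const std::vector<TypeName> &lhs,
                                          const std::vector<TypeName> &rhs) {
  std::size_t longest = std::max(lhs.size(), rhs.size());
  for (std::size_t i = 0; i < longest; ++i) {
    // Running past the end of either list means the arities differ.
    if (i >= lhs.size() || i >= rhs.size())
      return std::string("arity mismatch");
    if (lhs[i] != rhs[i])
      return "type mismatch at index " + std::to_string(i);
  }
  // Reaching this point implies both lists have the same length.
  assert(lhs.size() == rhs.size());
  return std::nullopt;
}
```

In this model the check would be called twice, once with the target types against the block argument types and once with the result types against the yield operand types. Iterating to the longer length is what lets an arity error be reported instead of being silently truncated, as a shortest-length zip would do.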

21 changes: 13 additions & 8 deletions mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
%t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
%tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op
%6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op
transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.foreach %5 : !transform.any_op {
^bb0(%inner_linalg: !transform.any_op):
%low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
}
transform.yield
}
}
@@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} {
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
%tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
%6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64>
transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
}
transform.yield
}
}