Skip to content

Commit c157c84

Browse files
committed
[mlir][Transform] Add zip_shortest to foreach
Adds `zip_shortest` functionality to `foreach` so that when it takes multiple handles of varying lengths - instead of failing - it shrinks the size of all payloads to that of the shortest payload.
1 parent 5caad6d commit c157c84

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,10 @@ def ForeachOp : TransformDialectOp<"foreach",
624624
Each iteration gets executed by co-indexing the payloads of the arguments
625625
and mapping the body's arguments to these tuples, as though iterating over
626626
the zipped together `targets`. As such, in each iteration, the size of the
627-
payload of each of the body's block arguments is exactly one.
627+
payload of each of the body's block arguments is exactly one. The attribute
628+
`zip_shortest` can be used if the targets vary in their number of payloads;
629+
this will limit the iterations to only the number of payloads found in the
630+
shortest target.
628631

629632
This op always reads the target handles. Furthermore, it consumes a handle
630633
if there is a transform op in the body that consumes the corresponding
@@ -645,11 +648,12 @@ def ForeachOp : TransformDialectOp<"foreach",
645648
rollback capabilities.
646649
}];
647650

648-
let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
651+
let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
652+
UnitAttr:$zip_shortest);
649653
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
650654
let regions = (region SizedRegion<1>:$body);
651655
let assemblyFormat =
652-
"$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
656+
"$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
653657
let hasVerifier = 1;
654658

655659
let extraClassDeclaration = [{

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1396,9 +1396,26 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
13961396
SmallVector<SmallVector<MappedValue>> payloads;
13971397
detail::prepareValueMappings(payloads, getTargets(), state);
13981398
size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1399+
bool isZipShortest = getZipShortest();
1400+
1401+
// In case of `zip_shortest`, set the number of iterations to the
1402+
// smallest payload in the targets.
1403+
if (isZipShortest) {
1404+
numIterations =
1405+
llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
1406+
const SmallVector<MappedValue> &B) {
1407+
return A.size() < B.size();
1408+
})->size();
1409+
1410+
for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1411+
payloads[argIdx].resize(numIterations);
1412+
}
13991413

14001414
// As we will be "zipping" over them, check all payloads have the same size.
1401-
for (size_t argIdx = 1; argIdx < payloads.size(); argIdx++) {
1415+
// `zip_shortest` adjusts all payloads to the same size, so skip this check
1416+
// when true.
1417+
for (size_t argIdx = 1; !isZipShortest && argIdx < payloads.size();
1418+
argIdx++) {
14021419
if (payloads[argIdx].size() != numIterations) {
14031420
return emitSilenceableError()
14041421
<< "prior targets' payload size (" << numIterations

0 commit comments

Comments
 (0)