File tree Expand file tree Collapse file tree 2 files changed +25
-4
lines changed
include/mlir/Dialect/Transform/IR Expand file tree Collapse file tree 2 files changed +25
-4
lines changed Original file line number Diff line number Diff line change @@ -624,7 +624,10 @@ def ForeachOp : TransformDialectOp<"foreach",
624
624
Each iteration gets executed by co-indexing the payloads of the arguments
625
625
and mapping the body's arguments to these tuples, as though iterating over
626
626
the zipped together `targets`. As such, in each iteration, the size of the
627
- payload of each of the body's block arguments is exactly one.
627
+ payload of each of the body's block arguments is exactly one. The attribute
628
+ `zip_shortest` can be used if the targets vary in their number of payloads;
629
+ this will limit the iterations to only the number of payloads found in the
630
+ shortest target.
628
631
629
632
This op always reads the target handles. Furthermore, it consumes a handle
630
633
if there is a transform op in the body that consumes the corresponding
@@ -645,11 +648,12 @@ def ForeachOp : TransformDialectOp<"foreach",
645
648
rollback capabilities.
646
649
}];
647
650
648
- let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
651
+ let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
652
+ UnitAttr:$zip_shortest);
649
653
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
650
654
let regions = (region SizedRegion<1>:$body);
651
655
let assemblyFormat =
652
- "$targets `:` type($targets) (`->` type($results)^)? $body attr-dict ";
656
+ "$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
653
657
let hasVerifier = 1;
654
658
655
659
let extraClassDeclaration = [{
Original file line number Diff line number Diff line change @@ -1396,9 +1396,26 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1396
1396
SmallVector<SmallVector<MappedValue>> payloads;
1397
1397
detail::prepareValueMappings (payloads, getTargets (), state);
1398
1398
size_t numIterations = payloads.empty () ? 0 : payloads.front ().size ();
1399
+ bool isZipShortest = getZipShortest ();
1400
+
1401
+ // In case of `zip_shortest`, set the number of iterations to the
1402
+ // smallest payload in the targets.
1403
+ if (isZipShortest) {
1404
+ numIterations =
1405
+ llvm::min_element (payloads, [&](const SmallVector<MappedValue> &A,
1406
+ const SmallVector<MappedValue> &B) {
1407
+ return A.size () < B.size ();
1408
+ })->size ();
1409
+
1410
+ for (size_t argIdx = 0 ; argIdx < payloads.size (); argIdx++)
1411
+ payloads[argIdx].resize (numIterations);
1412
+ }
1399
1413
1400
1414
// As we will be "zipping" over them, check all payloads have the same size.
1401
- for (size_t argIdx = 1 ; argIdx < payloads.size (); argIdx++) {
1415
+ // `zip_shortest` adjusts all payloads to the same size, so skip this check
1416
+ // when true.
1417
+ for (size_t argIdx = 1 ; !isZipShortest && argIdx < payloads.size ();
1418
+ argIdx++) {
1402
1419
if (payloads[argIdx].size () != numIterations) {
1403
1420
return emitSilenceableError ()
1404
1421
<< " prior targets' payload size (" << numIterations
You can’t perform that action at this time.
0 commit comments