@@ -1391,46 +1391,83 @@ DiagnosedSilenceableFailure
1391
1391
transform::ForeachOp::apply (transform::TransformRewriter &rewriter,
1392
1392
transform::TransformResults &results,
1393
1393
transform::TransformState &state) {
1394
- SmallVector<SmallVector<Operation *>> resultOps (getNumResults (), {});
1395
- // Store payload ops in a vector because ops may be removed from the mapping
1396
- // by the TrackingRewriter while the iteration is in progress.
1397
- SmallVector<Operation *> targets =
1398
- llvm::to_vector (state.getPayloadOps (getTarget ()));
1399
- for (Operation *op : targets) {
1394
+ // We store the payloads before executing the body as ops may be removed from
1395
+ // the mapping by the TrackingRewriter while iteration is in progress.
1396
+ SmallVector<SmallVector<MappedValue>> payloads;
1397
+ detail::prepareValueMappings (payloads, getTargets (), state);
1398
+ size_t numIterations = payloads.empty () ? 0 : payloads.front ().size ();
1399
+
1400
+ // As we will be "zipping" over them, check all payloads have the same size.
1401
+ for (size_t argIdx = 1 ; argIdx < payloads.size (); argIdx++) {
1402
+ if (payloads[argIdx].size () != numIterations) {
1403
+ return emitSilenceableError ()
1404
+ << " prior targets' payload size (" << numIterations
1405
+ << " ) differs from payload size (" << payloads[argIdx].size ()
1406
+ << " ) of target " << getTargets ()[argIdx];
1407
+ }
1408
+ }
1409
+
1410
+ // Start iterating, indexing into payloads to obtain the right arguments to
1411
+ // call the body with - each slice of payloads at the same argument index
1412
+ // corresponding to a tuple to use as the body's block arguments.
1413
+ ArrayRef<BlockArgument> blockArguments = getBody ().front ().getArguments ();
1414
+ SmallVector<SmallVector<MappedValue>> zippedResults (getNumResults (), {});
1415
+ for (size_t iterIdx = 0 ; iterIdx < numIterations; iterIdx++) {
1400
1416
auto scope = state.make_region_scope (getBody ());
1401
- if (failed (state.mapBlockArguments (getIterationVariable (), {op})))
1402
- return DiagnosedSilenceableFailure::definiteFailure ();
1417
+ // Set up arguments to the region's block.
1418
+ for (auto &&[argIdx, blockArg] : llvm::enumerate (blockArguments)) {
1419
+ MappedValue argument = payloads[argIdx][iterIdx];
1420
+ // Note that each blockArg's handle gets associated with just a single
1421
+ // element from the corresponding target's payload.
1422
+ if (failed (state.mapBlockArgument (blockArg, {argument})))
1423
+ return DiagnosedSilenceableFailure::definiteFailure ();
1424
+ }
1403
1425
1404
1426
// Execute loop body.
1405
1427
for (Operation &transform : getBody ().front ().without_terminator ()) {
1406
1428
DiagnosedSilenceableFailure result = state.applyTransform (
1407
- cast<transform::TransformOpInterface>(transform));
1429
+ llvm:: cast<transform::TransformOpInterface>(transform));
1408
1430
if (!result.succeeded ())
1409
1431
return result;
1410
1432
}
1411
1433
1412
- // Append yielded payload ops to result list (if any).
1413
- for (unsigned i = 0 ; i < getNumResults (); ++i) {
1414
- auto yieldedOps = state.getPayloadOps (getYieldOp ().getOperand (i));
1415
- resultOps[i].append (yieldedOps.begin (), yieldedOps.end ());
1416
- }
1417
- }
1418
-
1419
- for (unsigned i = 0 ; i < getNumResults (); ++i)
1420
- results.set (llvm::cast<OpResult>(getResult (i)), resultOps[i]);
1434
+ // Append yielded payloads to corresponding results from prior iterations.
1435
+ OperandRange yieldOperands = getYieldOp ().getOperands ();
1436
+ for (auto &&[result, yieldOperand, resTuple] :
1437
+ llvm::zip_equal (getResults (), yieldOperands, zippedResults))
1438
+ // NB: each iteration we add any number of ops/vals/params to a result.
1439
+ if (isa<TransformHandleTypeInterface>(result.getType ()))
1440
+ llvm::append_range (resTuple, state.getPayloadOps (yieldOperand));
1441
+ else if (isa<TransformValueHandleTypeInterface>(result.getType ()))
1442
+ llvm::append_range (resTuple, state.getPayloadValues (yieldOperand));
1443
+ else if (isa<TransformParamTypeInterface>(result.getType ()))
1444
+ llvm::append_range (resTuple, state.getParams (yieldOperand));
1445
+ else
1446
+ assert (false && " unhandled handle type" );
1447
+ }
1448
+
1449
+ // Associate the accumulated result payloads to the op's actual results.
1450
+ for (auto &&[result, resPayload] : zip_equal (getResults (), zippedResults))
1451
+ results.setMappedValues (llvm::cast<OpResult>(result), resPayload);
1421
1452
1422
1453
return DiagnosedSilenceableFailure::success ();
1423
1454
}
1424
1455
1425
1456
void transform::ForeachOp::getEffects (
1426
1457
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1427
- BlockArgument iterVar = getIterationVariable ();
1428
- if (any_of (getBody ().front ().without_terminator (), [&](Operation &op) {
1429
- return isHandleConsumed (iterVar, cast<TransformOpInterface>(&op));
1430
- })) {
1431
- consumesHandle (getTarget (), effects);
1432
- } else {
1433
- onlyReadsHandle (getTarget (), effects);
1458
+ // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1459
+ // arity errors, this method might get called before/in absence of `verify()`.
1460
+ for (auto &&[target, blockArg] :
1461
+ llvm::zip (getTargets (), getBody ().front ().getArguments ())) {
1462
+ BlockArgument blockArgument = blockArg;
1463
+ if (any_of (getBody ().front ().without_terminator (), [&](Operation &op) {
1464
+ return isHandleConsumed (blockArgument,
1465
+ cast<TransformOpInterface>(&op));
1466
+ })) {
1467
+ consumesHandle (target, effects);
1468
+ } else {
1469
+ onlyReadsHandle (target, effects);
1470
+ }
1434
1471
}
1435
1472
1436
1473
if (any_of (getBody ().front ().without_terminator (), [&](Operation &op) {
@@ -1463,8 +1500,8 @@ void transform::ForeachOp::getSuccessorRegions(
1463
1500
1464
1501
OperandRange
1465
1502
transform::ForeachOp::getEntrySuccessorOperands (RegionBranchPoint point) {
1466
- // The iteration variable op handle is mapped to a subset (one op to be
1467
- // precise) of the payload ops of the ForeachOp operand.
1503
+ // Each block argument handle is mapped to a subset (one op to be precise)
1504
+ // of the payload of the corresponding `targets` operand of ForeachOp .
1468
1505
assert (point == getBody () && " unexpected region index" );
1469
1506
return getOperation ()->getOperands ();
1470
1507
}
@@ -1474,14 +1511,27 @@ transform::YieldOp transform::ForeachOp::getYieldOp() {
1474
1511
}
1475
1512
1476
1513
LogicalResult transform::ForeachOp::verify () {
1477
- auto yieldOp = getYieldOp ();
1478
- if (getNumResults () != yieldOp.getNumOperands ())
1479
- return emitOpError () << " expects the same number of results as the "
1480
- " terminator has operands" ;
1481
- for (Value v : yieldOp.getOperands ())
1482
- if (!llvm::isa<TransformHandleTypeInterface>(v.getType ()))
1483
- return yieldOp->emitOpError (" expects operands to have types implementing "
1484
- " TransformHandleTypeInterface" );
1514
+ for (auto [targetOpt, bodyArgOpt] :
1515
+ llvm::zip_longest (getTargets (), getBody ().front ().getArguments ())) {
1516
+ if (!targetOpt || !bodyArgOpt)
1517
+ return emitOpError () << " expects the same number of targets as the body "
1518
+ " has block arguments" ;
1519
+ if (targetOpt.value ().getType () != bodyArgOpt.value ().getType ())
1520
+ return emitOpError (
1521
+ " expects co-indexed targets and the body's "
1522
+ " block arguments to have the same op/value/param type" );
1523
+ }
1524
+
1525
+ for (auto [resultOpt, yieldOperandOpt] :
1526
+ llvm::zip_longest (getResults (), getYieldOp ().getOperands ())) {
1527
+ if (!resultOpt || !yieldOperandOpt)
1528
+ return emitOpError () << " expects the same number of results as the "
1529
+ " yield terminator has operands" ;
1530
+ if (resultOpt.value ().getType () != yieldOperandOpt.value ().getType ())
1531
+ return emitOpError (" expects co-indexed results and yield "
1532
+ " operands to have the same op/value/param type" );
1533
+ }
1534
+
1485
1535
return success ();
1486
1536
}
1487
1537
0 commit comments