@@ -2515,12 +2515,11 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
2515
2515
2516
2516
// Get predecessor terminator operands.
2517
2517
SmallVector<std::pair<SILBasicBlock *, SILValue>, 4 > incomingValues;
2518
- bbArg->getSingleTerminatorOperands (incomingValues);
2519
-
2520
- // Returns true if the given terminator instruction is a `switch_enum` on
2521
- // an `Optional`-typed value. `switch_enum` instructions require
2522
- // special-case adjoint value propagation for the operand.
2523
- auto isSwitchEnumInstOnOptional =
2518
+ if (bbArg->getSingleTerminatorOperands (incomingValues)) {
2519
+ // Returns true if the given terminator instruction is a `switch_enum` on
2520
+ // an `Optional`-typed value. `switch_enum` instructions require
2521
+ // special-case adjoint value propagation for the operand.
2522
+ auto isSwitchEnumInstOnOptional =
2524
2523
[&ctx = getASTContext ()](TermInst *termInst) {
2525
2524
if (!termInst)
2526
2525
return false ;
@@ -2531,49 +2530,52 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
2531
2530
return false ;
2532
2531
};
2533
2532
2534
- // Check the tangent value category of the active basic block argument.
2535
- switch (getTangentValueCategory (bbArg)) {
2536
- // If argument has a loadable tangent value category: materialize adjoint
2537
- // value of the argument, create a copy, and set the copy as the adjoint
2538
- // value of incoming values.
2539
- case SILValueCategory::Object: {
2540
- auto bbArgAdj = getAdjointValue (bb, bbArg);
2541
- auto concreteBBArgAdj = materializeAdjointDirect (bbArgAdj, pbLoc);
2542
- auto concreteBBArgAdjCopy =
2533
+ // Check the tangent value category of the active basic block argument.
2534
+ switch (getTangentValueCategory (bbArg)) {
2535
+ // If argument has a loadable tangent value category: materialize adjoint
2536
+ // value of the argument, create a copy, and set the copy as the adjoint
2537
+ // value of incoming values.
2538
+ case SILValueCategory::Object: {
2539
+ auto bbArgAdj = getAdjointValue (bb, bbArg);
2540
+ auto concreteBBArgAdj = materializeAdjointDirect (bbArgAdj, pbLoc);
2541
+ auto concreteBBArgAdjCopy =
2543
2542
builder.emitCopyValueOperation (pbLoc, concreteBBArgAdj);
2544
- for (auto pair : incomingValues) {
2545
- auto *predBB = std::get<0 >(pair);
2546
- auto incomingValue = std::get<1 >(pair);
2547
- // Handle `switch_enum` on `Optional`.
2548
- auto termInst = bbArg->getSingleTerminator ();
2549
- if (isSwitchEnumInstOnOptional (termInst)) {
2550
- accumulateAdjointForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2551
- } else {
2552
- blockTemporaries[getPullbackBlock (predBB)].insert (
2543
+ for (auto pair : incomingValues) {
2544
+ pair.second ->dump ();
2545
+ auto *predBB = std::get<0 >(pair);
2546
+ auto incomingValue = std::get<1 >(pair);
2547
+ // Handle `switch_enum` on `Optional`.
2548
+ auto termInst = bbArg->getSingleTerminator ();
2549
+ if (isSwitchEnumInstOnOptional (termInst)) {
2550
+ accumulateAdjointForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2551
+ } else {
2552
+ blockTemporaries[getPullbackBlock (predBB)].insert (
2553
2553
concreteBBArgAdjCopy);
2554
- setAdjointValue (predBB, incomingValue,
2555
- makeConcreteAdjointValue (concreteBBArgAdjCopy));
2554
+ setAdjointValue (predBB, incomingValue,
2555
+ makeConcreteAdjointValue (concreteBBArgAdjCopy));
2556
+ }
2556
2557
}
2558
+ break ;
2557
2559
}
2558
- break ;
2559
- }
2560
- // If argument has an address tangent value category: materialize adjoint
2561
- // value of the argument, create a copy, and set the copy as the adjoint
2562
- // value of incoming values.
2563
- case SILValueCategory::Address: {
2564
- auto bbArgAdjBuf = getAdjointBuffer (bb, bbArg );
2565
- for ( auto pair : incomingValues) {
2566
- auto incomingValue = std::get< 1 >(pair );
2567
- // Handle `switch_enum` on `Optional`.
2568
- auto termInst = bbArg-> getSingleTerminator ( );
2569
- if ( isSwitchEnumInstOnOptional (termInst))
2570
- accumulateAdjointForOptional (bb, incomingValue, bbArgAdjBuf);
2571
- else
2572
- addToAdjointBuffer (bb, incomingValue, bbArgAdjBuf, pbLoc) ;
2560
+ // If argument has an address tangent value category: materialize adjoint
2561
+ // value of the argument, create a copy, and set the copy as the adjoint
2562
+ // value of incoming values.
2563
+ case SILValueCategory::Address: {
2564
+ auto bbArgAdjBuf = getAdjointBuffer (bb, bbArg);
2565
+ for ( auto pair : incomingValues) {
2566
+ auto incomingValue = std::get< 1 >(pair );
2567
+ // Handle `switch_enum` on `Optional`.
2568
+ auto termInst = bbArg-> getSingleTerminator ( );
2569
+ if ( isSwitchEnumInstOnOptional (termInst))
2570
+ accumulateAdjointForOptional (bb, incomingValue, bbArgAdjBuf );
2571
+ else
2572
+ addToAdjointBuffer (bb, incomingValue, bbArgAdjBuf, pbLoc );
2573
+ }
2574
+ break ;
2573
2575
}
2574
- break ;
2575
- }
2576
- }
2576
+ }
2577
+ } else
2578
+ llvm::report_fatal_error ( " do not know how to handle this incoming bb argument " );
2577
2579
}
2578
2580
2579
2581
// 3. Build the pullback successor cases for the `switch_enum`
0 commit comments