@@ -668,9 +668,10 @@ class LinearMapInfo {
668
668
return linearMapDecl;
669
669
}
670
670
671
- // / This takes the declared linear map structs per basic block, and populates them with the necessary
672
- // / fields, specifically the linear function (pullback or differential) of the corresponding original function call
673
- // / in the original function, and the branching enum.
671
+ // / This takes the declared linear map structs and populates
672
+ // / them with the necessary fields, specifically the linear function (pullback
673
+ // / or differential) of the corresponding original function call in the
674
+ // / original function, and the branching enum.
674
675
void populateLinearMapStructDeclarationFields (
675
676
ADContext &context, const SILAutoDiffIndices &indices,
676
677
SILFunction *assocFn);
@@ -1490,39 +1491,38 @@ LinearMapInfo::LinearMapInfo(ADContext &context,
1490
1491
populateLinearMapStructDeclarationFields (context, indices, assocFn);
1491
1492
}
1492
1493
1493
- bool LinearMapInfo::shouldBeDifferentiated (
1494
- ApplyInst *ai, const SILAutoDiffIndices &indices) {
1494
+ bool LinearMapInfo::shouldBeDifferentiated (ApplyInst *ai,
1495
+ const SILAutoDiffIndices &indices) {
1495
1496
// Anything with an active result should be differentiated.
1496
1497
if (llvm::any_of (ai->getResults (), [&](SILValue val) {
1497
1498
return activityInfo.isActive (val, indices);
1498
- })) {
1499
+ }))
1499
1500
return true ;
1500
- }
1501
+
1501
1502
// Function applications with an active indirect result should be
1502
1503
// differentiated.
1503
1504
for (auto indRes : ai->getIndirectSILResults ())
1504
- if (activityInfo.isActive (indRes, indices)) {
1505
+ if (activityInfo.isActive (indRes, indices))
1505
1506
return true ;
1506
- }
1507
+
1507
1508
// Function applications with an inout argument should be differentiated.
1508
1509
auto paramInfos = ai->getSubstCalleeConv ().getParameters ();
1509
1510
for (auto i : swift::indices (paramInfos))
1510
1511
if (paramInfos[i].isIndirectInOut () &&
1511
1512
activityInfo.isActive (
1512
- ai->getArgumentsWithoutIndirectResults ()[i], indices)) {
1513
+ ai->getArgumentsWithoutIndirectResults ()[i], indices))
1513
1514
return true ;
1514
- }
1515
+
1515
1516
// Instructions that may write to memory and that have an active operand
1516
1517
// should be differentiated.
1517
1518
if (ai->mayWriteToMemory ())
1518
1519
for (auto &op : ai->getAllOperands ())
1519
- if (activityInfo.isActive (op.get (), indices)) {
1520
+ if (activityInfo.isActive (op.get (), indices))
1520
1521
return true ;
1521
- }
1522
1522
return false ;
1523
1523
}
1524
1524
1525
- void LinearMapInfo::populateLinearMapStructDeclarationFields (
1525
+ void LinearMapInfo::populateLinearMapStructDeclarationFields (
1526
1526
ADContext &context, const SILAutoDiffIndices &indices,
1527
1527
SILFunction *assocFn) {
1528
1528
@@ -1574,7 +1574,8 @@ bool LinearMapInfo::shouldBeDifferentiated(
1574
1574
if (paramInfos[i].isIndirectInOut () &&
1575
1575
activityInfo.isActive (ai->getArgumentsWithoutIndirectResults ()[i],
1576
1576
indices)) {
1577
- // Reject functions with active inout arguments. It's not yet supported.
1577
+ // Reject functions with active inout arguments. It's not yet
1578
+ // supported.
1578
1579
isInout = true ;
1579
1580
break ;
1580
1581
}
@@ -1595,11 +1596,11 @@ bool LinearMapInfo::shouldBeDifferentiated(
1595
1596
// Check if there are any active results or arguments. If not, skip
1596
1597
// this instruction.
1597
1598
auto hasActiveResults = llvm::any_of (
1598
- allResults, [&](SILValue res) {
1599
+ allResults, [&](SILValue res) {
1599
1600
return activityInfo.isActive (res, indices);
1600
1601
});
1601
1602
auto hasActiveArguments = llvm::any_of (
1602
- ai->getArgumentsWithoutIndirectResults (), [&](SILValue arg) {
1603
+ ai->getArgumentsWithoutIndirectResults (), [&](SILValue arg) {
1603
1604
return activityInfo.isActive (arg, indices);
1604
1605
});
1605
1606
if (!hasActiveResults || !hasActiveArguments)
@@ -1617,9 +1618,10 @@ bool LinearMapInfo::shouldBeDifferentiated(
1617
1618
1618
1619
// If function is already marked differentiable, differentiate WRT
1619
1620
// all parameters.
1620
- auto originalFnSubstTy = ai->getSubstCalleeType ();;
1621
+ auto originalFnSubstTy = ai->getSubstCalleeType ();
1621
1622
if (originalFnSubstTy->isDifferentiable ()) {
1622
- parameters = originalFnSubstTy->getDifferentiationParameterIndices ();
1623
+ parameters =
1624
+ originalFnSubstTy->getDifferentiationParameterIndices ();
1623
1625
} else {
1624
1626
parameters = AutoDiffIndexSubset::get (
1625
1627
original->getASTContext (),
@@ -1632,43 +1634,44 @@ bool LinearMapInfo::shouldBeDifferentiated(
1632
1634
ai->getArgumentsWithoutIndirectResults ().size (),
1633
1635
activeParamIndices));
1634
1636
1635
- // Check and diagnose non-differentiable original function type.
1636
- auto diagnoseNondifferentiableOriginalFunctionType =
1637
+ // Check for non-differentiable original function type.
1638
+ auto checkNondifferentiableOriginalFunctionType =
1637
1639
[&](CanSILFunctionType origFnTy) {
1638
1640
// Check and diagnose non-differentiable arguments.
1639
- for (unsigned paramIndex : range (origFnTy->getNumParameters ())) {
1641
+ for (unsigned paramIndex :
1642
+ range (origFnTy->getNumParameters ())) {
1640
1643
if (curIndices.isWrtParameter (paramIndex) &&
1641
1644
!origFnTy->getParameters ()[paramIndex]
1642
1645
.getSILStorageType ()
1643
- .isDifferentiable (builder.getModule ())) {
1646
+ .isDifferentiable (builder.getModule ()))
1644
1647
return true ;
1645
- }
1646
1648
}
1647
1649
// Check non-differentiable results.
1648
1650
if (!origFnTy->getResults ()[curIndices.source ]
1649
1651
.getSILStorageType ()
1650
- .isDifferentiable (builder.getModule ())) {
1652
+ .isDifferentiable (builder.getModule ()))
1651
1653
return true ;
1652
- }
1653
1654
return false ;
1654
1655
};
1655
- if (diagnoseNondifferentiableOriginalFunctionType (originalFnSubstTy))
1656
+ if (checkNondifferentiableOriginalFunctionType (originalFnSubstTy))
1656
1657
continue ;
1657
1658
1658
- auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1659
- parameters, source,
1660
- /* differentiationOrder*/ 1 , kind, builder.getModule (),
1661
- LookUpConformanceInModule (builder.getModule ().getSwiftModule ()));
1659
+ auto assocFnType =
1660
+ originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1661
+ parameters, source, /* differentiationOrder*/ 1 , kind,
1662
+ builder.getModule (),
1663
+ LookUpConformanceInModule (
1664
+ builder.getModule ().getSwiftModule ()));
1662
1665
1663
1666
auto assocFnResultTypes =
1664
1667
assocFnType->getAllResultsType ().castTo <TupleType>();
1665
- assocFnResultTypes
1666
- -> getElement (JVPResultTypes ->getElements ().size () - 1 );
1667
- auto linearMapSILType =
1668
- SILType::getPrimitiveObjectType (
1669
- assocFnResultTypes ->getElement (
1670
- assocFnResultTypes-> getElements (). size () - 1 )
1671
- . getType () ->getCanonicalType ());
1668
+ assocFnResultTypes-> getElement (
1669
+ assocFnResultTypes ->getElements ().size () - 1 );
1670
+ auto linearMapSILType = SILType::getPrimitiveObjectType (
1671
+ assocFnResultTypes
1672
+ ->getElement (assocFnResultTypes-> getElements (). size () - 1 )
1673
+ . getType ( )
1674
+ ->getCanonicalType ());
1672
1675
addLinearMapDecl (ai, linearMapSILType);
1673
1676
}
1674
1677
}
0 commit comments