@@ -1491,8 +1491,8 @@ LinearMapInfo::LinearMapInfo(ADContext &context,
1491
1491
populateLinearMapStructDeclarationFields (context, indices, assocFn);
1492
1492
}
1493
1493
1494
- bool LinearMapInfo::shouldBeDifferentiated (
1495
- ApplyInst *ai, const SILAutoDiffIndices &indices) {
1494
+ bool LinearMapInfo::shouldBeDifferentiated (ApplyInst *ai,
1495
+ const SILAutoDiffIndices &indices) {
1496
1496
// Anything with an active result should be differentiated.
1497
1497
if (llvm::any_of (ai->getResults (), [&](SILValue val) {
1498
1498
return activityInfo.isActive (val, indices);
@@ -1618,9 +1618,10 @@ bool LinearMapInfo::shouldBeDifferentiated(
1618
1618
1619
1619
// If function is already marked differentiable, differentiate WRT
1620
1620
// all parameters.
1621
- auto originalFnSubstTy = ai->getSubstCalleeType ();;
1621
+ auto originalFnSubstTy = ai->getSubstCalleeType ();
1622
1622
if (originalFnSubstTy->isDifferentiable ()) {
1623
- parameters = originalFnSubstTy->getDifferentiationParameterIndices ();
1623
+ parameters =
1624
+ originalFnSubstTy->getDifferentiationParameterIndices ();
1624
1625
} else {
1625
1626
parameters = AutoDiffIndexSubset::get (
1626
1627
original->getASTContext (),
@@ -1633,11 +1634,12 @@ bool LinearMapInfo::shouldBeDifferentiated(
1633
1634
ai->getArgumentsWithoutIndirectResults ().size (),
1634
1635
activeParamIndices));
1635
1636
1636
- // Check and diagnose non-differentiable original function type.
1637
- auto diagnoseNondifferentiableOriginalFunctionType =
1637
+ // Check for non-differentiable original function type.
1638
+ auto checkNondifferentiableOriginalFunctionType =
1638
1639
[&](CanSILFunctionType origFnTy) {
1639
1640
// Check and diagnose non-differentiable arguments.
1640
- for (unsigned paramIndex : range (origFnTy->getNumParameters ())) {
1641
+ for (unsigned paramIndex :
1642
+ range (origFnTy->getNumParameters ())) {
1641
1643
if (curIndices.isWrtParameter (paramIndex) &&
1642
1644
!origFnTy->getParameters ()[paramIndex]
1643
1645
.getSILStorageType ()
@@ -1653,23 +1655,25 @@ bool LinearMapInfo::shouldBeDifferentiated(
1653
1655
}
1654
1656
return false ;
1655
1657
};
1656
- if (diagnoseNondifferentiableOriginalFunctionType (originalFnSubstTy))
1658
+ if (checkNondifferentiableOriginalFunctionType (originalFnSubstTy))
1657
1659
continue ;
1658
1660
1659
- auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1660
- parameters, source,
1661
- /* differentiationOrder*/ 1 , kind, builder.getModule (),
1662
- LookUpConformanceInModule (builder.getModule ().getSwiftModule ()));
1661
+ auto assocFnType =
1662
+ originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1663
+ parameters, source,
1664
+ /* differentiationOrder*/ 1 , kind, builder.getModule (),
1665
+ LookUpConformanceInModule (
1666
+ builder.getModule ().getSwiftModule ()));
1663
1667
1664
1668
auto assocFnResultTypes =
1665
1669
assocFnType->getAllResultsType ().castTo <TupleType>();
1666
- assocFnResultTypes
1667
- -> getElement ( assocFnResultTypes->getElements ().size () - 1 );
1668
- auto linearMapSILType =
1669
- SILType::getPrimitiveObjectType (
1670
- assocFnResultTypes ->getElement (
1671
- assocFnResultTypes-> getElements (). size () - 1 )
1672
- . getType () ->getCanonicalType ());
1670
+ assocFnResultTypes-> getElement (
1671
+ assocFnResultTypes->getElements ().size () - 1 );
1672
+ auto linearMapSILType = SILType::getPrimitiveObjectType (
1673
+ assocFnResultTypes
1674
+ ->getElement (assocFnResultTypes-> getElements (). size () - 1 )
1675
+ . getType ( )
1676
+ ->getCanonicalType ());
1673
1677
addLinearMapDecl (ai, linearMapSILType);
1674
1678
}
1675
1679
}
0 commit comments