@@ -388,7 +388,7 @@ class LinearMapInfo {
388
388
// / The original function.
389
389
SILFunction *const original;
390
390
391
- // / Activitiy info of the original function.
391
+ // / Activity info of the original function.
392
392
const DifferentiableActivityInfo &activityInfo;
393
393
394
394
// / Mapping from original basic blocks to linear map structs.
@@ -668,6 +668,8 @@ class LinearMapInfo {
668
668
return linearMapDecl;
669
669
}
670
670
671
+ void addLinearMapToStruct (ApplyInst *ai, const SILAutoDiffIndices &indices);
672
+
671
673
// / This takes the declared linear map structs and populates
672
674
// / them with the necessary fields, specifically the linear function (pullback
673
675
// / or differential) of the corresponding original function call in the
@@ -1522,6 +1524,92 @@ bool LinearMapInfo::shouldBeDifferentiated(ApplyInst *ai,
1522
1524
return false ;
1523
1525
}
1524
1526
1527
+
1528
+ // / Takes an `apply` instruction and adds its linear map function to the
1529
+ // / linear map struct if it's active.
1530
+ void LinearMapInfo::addLinearMapToStruct (ApplyInst *ai,
1531
+ const SILAutoDiffIndices &indices) {
1532
+ SmallVector<SILValue, 4 > allResults;
1533
+ allResults.push_back (ai);
1534
+ allResults.append (ai->getIndirectSILResults ().begin (),
1535
+ ai->getIndirectSILResults ().end ());
1536
+
1537
+ // Check if there are any active results or arguments. If not, skip
1538
+ // this instruction.
1539
+ auto hasActiveResults = llvm::any_of (
1540
+ allResults, [&](SILValue res) {
1541
+ return activityInfo.isActive (res, indices);
1542
+ });
1543
+ auto hasActiveArguments = llvm::any_of (
1544
+ ai->getArgumentsWithoutIndirectResults (), [&](SILValue arg) {
1545
+ return activityInfo.isActive (arg, indices);
1546
+ });
1547
+ if (!hasActiveResults || !hasActiveArguments)
1548
+ return ;
1549
+
1550
+ unsigned source;
1551
+ AutoDiffIndexSubset *parameters;
1552
+
1553
+ SmallVector<unsigned , 8 > activeParamIndices;
1554
+ SmallVector<unsigned , 8 > activeResultIndices;
1555
+ collectMinimalIndicesForFunctionCall (
1556
+ ai, allResults, indices, activityInfo, activeParamIndices,
1557
+ activeResultIndices);
1558
+ source = activeResultIndices.front ();
1559
+
1560
+ // If function is already marked differentiable, differentiate W.R.T.
1561
+ // all parameters.
1562
+ auto originalFnSubstTy = ai->getSubstCalleeType ();
1563
+ if (originalFnSubstTy->isDifferentiable ()) {
1564
+ parameters = originalFnSubstTy->getDifferentiationParameterIndices ();
1565
+ } else {
1566
+ parameters = AutoDiffIndexSubset::get (
1567
+ original->getASTContext (),
1568
+ ai->getArgumentsWithoutIndirectResults ().size (),
1569
+ activeParamIndices);
1570
+ }
1571
+ SILAutoDiffIndices curIndices (activeResultIndices.front (),
1572
+ AutoDiffIndexSubset::get (
1573
+ builder.getASTContext (),
1574
+ ai->getArgumentsWithoutIndirectResults ().size (),
1575
+ activeParamIndices));
1576
+
1577
+ // Check for non-differentiable original function type.
1578
+ auto checkNondifferentiableOriginalFunctionType =
1579
+ [&](CanSILFunctionType origFnTy) {
1580
+ // Check and diagnose non-differentiable arguments.
1581
+ for (unsigned paramIndex : range (origFnTy->getNumParameters ())) {
1582
+ if (curIndices.isWrtParameter (paramIndex) &&
1583
+ !origFnTy->getParameters ()[paramIndex]
1584
+ .getSILStorageType ()
1585
+ .isDifferentiable (builder.getModule ()))
1586
+ return true ;
1587
+ }
1588
+ // Check non-differentiable results.
1589
+ if (!origFnTy->getResults ()[curIndices.source ]
1590
+ .getSILStorageType ()
1591
+ .isDifferentiable (builder.getModule ()))
1592
+ return true ;
1593
+ return false ;
1594
+ };
1595
+ if (checkNondifferentiableOriginalFunctionType (originalFnSubstTy))
1596
+ return ;
1597
+
1598
+ auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1599
+ parameters, source, /* differentiationOrder*/ 1 , kind, builder.getModule (),
1600
+ LookUpConformanceInModule (builder.getModule ().getSwiftModule ()));
1601
+
1602
+ auto assocFnResultTypes =
1603
+ assocFnType->getAllResultsType ().castTo <TupleType>();
1604
+ assocFnResultTypes->getElement (assocFnResultTypes->getElements ().size () - 1 );
1605
+ auto linearMapSILType = SILType::getPrimitiveObjectType (
1606
+ assocFnResultTypes
1607
+ ->getElement (assocFnResultTypes->getElements ().size () - 1 )
1608
+ .getType ()
1609
+ ->getCanonicalType ());
1610
+ addLinearMapDecl (ai, linearMapSILType);
1611
+ }
1612
+
1525
1613
void LinearMapInfo::populateLinearMapStructDeclarationFields (
1526
1614
ADContext &context, const SILAutoDiffIndices &indices,
1527
1615
SILFunction *assocFn) {
@@ -1548,8 +1636,7 @@ void LinearMapInfo::populateLinearMapStructDeclarationFields(
1548
1636
// TODO: add support for forward mode.
1549
1637
for (auto &origBB : *original) {
1550
1638
auto *linearMapStruct = getLinearMapStruct (&origBB);
1551
- auto *traceEnum =
1552
- createBranchingTraceDecl (&origBB, indices, assocFnGenSig);
1639
+ auto *traceEnum = createBranchingTraceDecl (&origBB, indices, assocFnGenSig);
1553
1640
1554
1641
// If original block is in a loop, mark branching trace enum as indirect.
1555
1642
if (loopInfo->getLoopFor (&origBB))
@@ -1587,93 +1674,8 @@ void LinearMapInfo::populateLinearMapStructDeclarationFields(
1587
1674
// Do not add it for array functions since those are already linear
1588
1675
// and we don't need to add it to the struct.
1589
1676
if (shouldBeDifferentiated (ai, indices) &&
1590
- !ai->hasSemantics (" array.uninitialized_intrinsic" )) {
1591
- SmallVector<SILValue, 4 > allResults;
1592
- allResults.push_back (ai);
1593
- allResults.append (ai->getIndirectSILResults ().begin (),
1594
- ai->getIndirectSILResults ().end ());
1595
-
1596
- // Check if there are any active results or arguments. If not, skip
1597
- // this instruction.
1598
- auto hasActiveResults = llvm::any_of (
1599
- allResults, [&](SILValue res) {
1600
- return activityInfo.isActive (res, indices);
1601
- });
1602
- auto hasActiveArguments = llvm::any_of (
1603
- ai->getArgumentsWithoutIndirectResults (), [&](SILValue arg) {
1604
- return activityInfo.isActive (arg, indices);
1605
- });
1606
- if (!hasActiveResults || !hasActiveArguments)
1607
- continue ;
1608
-
1609
- unsigned source;
1610
- AutoDiffIndexSubset *parameters;
1611
-
1612
- SmallVector<unsigned , 8 > activeParamIndices;
1613
- SmallVector<unsigned , 8 > activeResultIndices;
1614
- collectMinimalIndicesForFunctionCall (
1615
- ai, allResults, indices, activityInfo, activeParamIndices,
1616
- activeResultIndices);
1617
- source = activeResultIndices.front ();
1618
-
1619
- // If function is already marked differentiable, differentiate WRT
1620
- // all parameters.
1621
- auto originalFnSubstTy = ai->getSubstCalleeType ();
1622
- if (originalFnSubstTy->isDifferentiable ()) {
1623
- parameters =
1624
- originalFnSubstTy->getDifferentiationParameterIndices ();
1625
- } else {
1626
- parameters = AutoDiffIndexSubset::get (
1627
- original->getASTContext (),
1628
- ai->getArgumentsWithoutIndirectResults ().size (),
1629
- activeParamIndices);
1630
- }
1631
- SILAutoDiffIndices curIndices (activeResultIndices.front (),
1632
- AutoDiffIndexSubset::get (
1633
- builder.getASTContext (),
1634
- ai->getArgumentsWithoutIndirectResults ().size (),
1635
- activeParamIndices));
1636
-
1637
- // Check for non-differentiable original function type.
1638
- auto checkNondifferentiableOriginalFunctionType =
1639
- [&](CanSILFunctionType origFnTy) {
1640
- // Check and diagnose non-differentiable arguments.
1641
- for (unsigned paramIndex :
1642
- range (origFnTy->getNumParameters ())) {
1643
- if (curIndices.isWrtParameter (paramIndex) &&
1644
- !origFnTy->getParameters ()[paramIndex]
1645
- .getSILStorageType ()
1646
- .isDifferentiable (builder.getModule ()))
1647
- return true ;
1648
- }
1649
- // Check non-differentiable results.
1650
- if (!origFnTy->getResults ()[curIndices.source ]
1651
- .getSILStorageType ()
1652
- .isDifferentiable (builder.getModule ()))
1653
- return true ;
1654
- return false ;
1655
- };
1656
- if (checkNondifferentiableOriginalFunctionType (originalFnSubstTy))
1657
- continue ;
1658
-
1659
- auto assocFnType =
1660
- originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1661
- parameters, source, /* differentiationOrder*/ 1 , kind,
1662
- builder.getModule (),
1663
- LookUpConformanceInModule (
1664
- builder.getModule ().getSwiftModule ()));
1665
-
1666
- auto assocFnResultTypes =
1667
- assocFnType->getAllResultsType ().castTo <TupleType>();
1668
- assocFnResultTypes->getElement (
1669
- assocFnResultTypes->getElements ().size () - 1 );
1670
- auto linearMapSILType = SILType::getPrimitiveObjectType (
1671
- assocFnResultTypes
1672
- ->getElement (assocFnResultTypes->getElements ().size () - 1 )
1673
- .getType ()
1674
- ->getCanonicalType ());
1675
- addLinearMapDecl (ai, linearMapSILType);
1676
- }
1677
+ !ai->hasSemantics (" array.uninitialized_intrinsic" ))
1678
+ addLinearMapToStruct (ai, indices);
1677
1679
}
1678
1680
}
1679
1681
}
0 commit comments