@@ -668,7 +668,10 @@ class LinearMapInfo {
668
668
return linearMapDecl;
669
669
}
670
670
671
- void prepareLinearMapStructDeclarations (
671
+ // / This takes the declared linear map structs per basic block, and populates them with the necessary
672
+ // / fields, specifically the linear function (pullback or differential) of the corresponding original function call
673
+ // / in the original function, and the branching enum.
674
+ void populateLinearMapStructDeclarationFields (
672
675
ADContext &context, const SILAutoDiffIndices &indices,
673
676
SILFunction *assocFn);
674
677
@@ -1484,7 +1487,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context,
1484
1487
SILBuilder &builder)
1485
1488
: kind(kind), original(original), activityInfo(activityInfo),
1486
1489
typeConverter (context.getTypeConverter()), builder(builder) {
1487
- prepareLinearMapStructDeclarations (context, indices, assocFn);
1490
+ populateLinearMapStructDeclarationFields (context, indices, assocFn);
1488
1491
}
1489
1492
1490
1493
bool LinearMapInfo::shouldBeDifferentiated (
@@ -1519,35 +1522,36 @@ bool LinearMapInfo::shouldBeDifferentiated(
1519
1522
return false ;
1520
1523
}
1521
1524
1522
- void LinearMapInfo::prepareLinearMapStructDeclarations (
1525
+ void LinearMapInfo::populateLinearMapStructDeclarationFields (
1523
1526
ADContext &context, const SILAutoDiffIndices &indices,
1524
1527
SILFunction *assocFn) {
1525
1528
1526
- auto &astCtx = original->getASTContext ();
1529
+ auto &astCtx = original->getASTContext ();
1527
1530
auto *loopAnalysis = context.getPassManager ().getAnalysis <SILLoopAnalysis>();
1528
1531
auto *loopInfo = loopAnalysis->get (original);
1529
1532
1530
- // Get the associated function generic signature.
1533
+ // Get the associated function generic signature.
1531
1534
CanGenericSignature assocFnGenSig = nullptr ;
1532
1535
if (auto *assocFnGenEnv = assocFn->getGenericEnvironment ())
1533
1536
assocFnGenSig =
1534
1537
assocFnGenEnv->getGenericSignature ()->getCanonicalSignature ();
1535
1538
1536
- // Create pullback struct for each original block.
1539
+ // Create linear map struct for each original block.
1537
1540
for (auto &origBB : *original) {
1538
- auto *linearMapStruct = createLinearMapStruct (&origBB, indices, assocFnGenSig);
1541
+ auto *linearMapStruct =
1542
+ createLinearMapStruct (&origBB, indices, assocFnGenSig);
1539
1543
linearMapStructs.insert ({&origBB, linearMapStruct});
1540
1544
}
1541
1545
1542
- // Create branching trace enum for each original block and add it to the
1546
+ // Create branching trace enum for each original block and add it to the
1543
1547
// corresponding struct.
1544
- // TODO(bartchr) : add support for forward mode.
1548
+ // TODO: add support for forward mode.
1545
1549
for (auto &origBB : *original) {
1546
1550
auto *linearMapStruct = getLinearMapStruct (&origBB);
1547
1551
auto *traceEnum =
1548
1552
createBranchingTraceDecl (&origBB, indices, assocFnGenSig);
1549
1553
1550
- // If original block is in a loop, mark branching trace enum as indirect.
1554
+ // If original block is in a loop, mark branching trace enum as indirect.
1551
1555
if (loopInfo->getLoopFor (&origBB))
1552
1556
traceEnum->getAttrs ().add (new (astCtx) IndirectAttr (/* Implicit*/ true ));
1553
1557
branchingTraceDecls.insert ({&origBB, traceEnum});
@@ -1559,7 +1563,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
1559
1563
linearMapStructEnumFields.insert ({linearMapStruct, traceEnumField});
1560
1564
}
1561
1565
1562
- // Add the differential function fields to the differential structs.
1566
+ // Add the linear function fields to the linear map structs.
1563
1567
for (auto &origBB : *original) {
1564
1568
for (auto &inst : origBB) {
1565
1569
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
@@ -1578,7 +1582,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
1578
1582
if (isInout)
1579
1583
break ;
1580
1584
1581
- // Add linear map to struct for active instructions.
1585
+ // Add linear map to struct for active instructions.
1582
1586
// Do not add it for array functions since those are already linear
1583
1587
// and we don't need to add it to the struct.
1584
1588
if (shouldBeDifferentiated (ai, indices) &&
@@ -1588,30 +1592,30 @@ bool LinearMapInfo::shouldBeDifferentiated(
1588
1592
allResults.append (ai->getIndirectSILResults ().begin (),
1589
1593
ai->getIndirectSILResults ().end ());
1590
1594
1591
- // Check if there are any active results or arguments. If not, skip
1595
+ // Check if there are any active results or arguments. If not, skip
1592
1596
// this instruction.
1593
1597
auto hasActiveResults = llvm::any_of (
1594
- allResults, [&](SILValue res) {
1598
+ allResults, [&](SILValue res) {
1595
1599
return activityInfo.isActive (res, indices);
1596
1600
});
1597
1601
auto hasActiveArguments = llvm::any_of (
1598
- ai->getArgumentsWithoutIndirectResults (), [&](SILValue arg) {
1602
+ ai->getArgumentsWithoutIndirectResults (), [&](SILValue arg) {
1599
1603
return activityInfo.isActive (arg, indices);
1600
1604
});
1601
1605
if (!hasActiveResults || !hasActiveArguments)
1602
1606
continue ;
1603
1607
1604
- unsigned source;
1608
+ unsigned source;
1605
1609
AutoDiffIndexSubset *parameters;
1606
1610
1607
- SmallVector<unsigned , 8 > activeParamIndices;
1611
+ SmallVector<unsigned , 8 > activeParamIndices;
1608
1612
SmallVector<unsigned , 8 > activeResultIndices;
1609
1613
collectMinimalIndicesForFunctionCall (
1610
1614
ai, allResults, indices, activityInfo, activeParamIndices,
1611
1615
activeResultIndices);
1612
1616
source = activeResultIndices.front ();
1613
1617
1614
- // If function is already marked differentiable, differentiate WRT
1618
+ // If function is already marked differentiable, differentiate WRT
1615
1619
// all parameters.
1616
1620
auto originalFnSubstTy = ai->getSubstCalleeType ();;
1617
1621
if (originalFnSubstTy->isDifferentiable ()) {
@@ -1628,7 +1632,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
1628
1632
ai->getArgumentsWithoutIndirectResults ().size (),
1629
1633
activeParamIndices));
1630
1634
1631
- // Check and diagnose non-differentiable original function type.
1635
+ // Check and diagnose non-differentiable original function type.
1632
1636
auto diagnoseNondifferentiableOriginalFunctionType =
1633
1637
[&](CanSILFunctionType origFnTy) {
1634
1638
// Check and diagnose non-differentiable arguments.
@@ -1640,7 +1644,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
1640
1644
return true ;
1641
1645
}
1642
1646
}
1643
- // Check and diagnose non-differentiable results.
1647
+ // Check non-differentiable results.
1644
1648
if (!origFnTy->getResults ()[curIndices.source ]
1645
1649
.getSILStorageType ()
1646
1650
.isDifferentiable (builder.getModule ())) {
@@ -1651,19 +1655,21 @@ bool LinearMapInfo::shouldBeDifferentiated(
1651
1655
if (diagnoseNondifferentiableOriginalFunctionType (originalFnSubstTy))
1652
1656
continue ;
1653
1657
1654
- auto JVPType = originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1658
+ auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1655
1659
parameters, source,
1656
1660
/* differentiationOrder*/ 1 , kind, builder.getModule (),
1657
1661
LookUpConformanceInModule (builder.getModule ().getSwiftModule ()));
1658
1662
1659
- auto JVPResultTypes = JVPType->getAllResultsType ().castTo <TupleType>();
1660
- JVPResultTypes->getElement (JVPResultTypes->getElements ().size () - 1 );
1661
- auto differentialSILType =
1663
+ auto assocFnResultTypes =
1664
+ assocFnType->getAllResultsType ().castTo <TupleType>();
1665
+ assocFnResultTypes
1666
+ ->getElement (JVPResultTypes->getElements ().size () - 1 );
1667
+ auto linearMapSILType =
1662
1668
SILType::getPrimitiveObjectType (
1663
- JVPResultTypes ->getElement (
1664
- JVPResultTypes ->getElements ().size () - 1 )
1669
+ assocFnResultTypes ->getElement (
1670
+ assocFnResultTypes ->getElements ().size () - 1 )
1665
1671
.getType ()->getCanonicalType ());
1666
- addLinearMapDecl (ai, differentialSILType );
1672
+ addLinearMapDecl (ai, linearMapSILType );
1667
1673
}
1668
1674
}
1669
1675
}
0 commit comments