@@ -668,7 +668,11 @@ class LinearMapInfo {
668
668
return linearMapDecl;
669
669
}
670
670
671
- void prepareLinearMapStructDeclarations (
671
+ // / This takes the declared linear map structs and populates
672
+ // / them with the necessary fields, specifically the linear function (pullback
673
+ // / or differential) of the corresponding original function call in the
674
+ // / original function, and the branching enum.
675
+ void populateLinearMapStructDeclarationFields (
672
676
ADContext &context, const SILAutoDiffIndices &indices,
673
677
SILFunction *assocFn);
674
678
@@ -1484,7 +1488,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context,
1484
1488
SILBuilder &builder)
1485
1489
: kind(kind), original(original), activityInfo(activityInfo),
1486
1490
typeConverter (context.getTypeConverter()), builder(builder) {
1487
- prepareLinearMapStructDeclarations (context, indices, assocFn);
1491
+ populateLinearMapStructDeclarationFields (context, indices, assocFn);
1488
1492
}
1489
1493
1490
1494
bool LinearMapInfo::shouldBeDifferentiated (
@@ -1519,35 +1523,36 @@ bool LinearMapInfo::shouldBeDifferentiated(
1519
1523
return false ;
1520
1524
}
1521
1525
1522
- void LinearMapInfo::prepareLinearMapStructDeclarations (
1526
+ void LinearMapInfo::populateLinearMapStructDeclarationFields (
1523
1527
ADContext &context, const SILAutoDiffIndices &indices,
1524
1528
SILFunction *assocFn) {
1525
1529
1526
- auto &astCtx = original->getASTContext ();
1530
+ auto &astCtx = original->getASTContext ();
1527
1531
auto *loopAnalysis = context.getPassManager ().getAnalysis <SILLoopAnalysis>();
1528
1532
auto *loopInfo = loopAnalysis->get (original);
1529
1533
1530
- // Get the associated function generic signature.
1534
+ // Get the associated function generic signature.
1531
1535
CanGenericSignature assocFnGenSig = nullptr ;
1532
1536
if (auto *assocFnGenEnv = assocFn->getGenericEnvironment ())
1533
1537
assocFnGenSig =
1534
1538
assocFnGenEnv->getGenericSignature ()->getCanonicalSignature ();
1535
1539
1536
- // Create pullback struct for each original block.
1540
+ // Create linear map struct for each original block.
1537
1541
for (auto &origBB : *original) {
1538
- auto *linearMapStruct = createLinearMapStruct (&origBB, indices, assocFnGenSig);
1542
+ auto *linearMapStruct =
1543
+ createLinearMapStruct (&origBB, indices, assocFnGenSig);
1539
1544
linearMapStructs.insert ({&origBB, linearMapStruct});
1540
1545
}
1541
1546
1542
- // Create branching trace enum for each original block and add it to the
1547
+ // Create branching trace enum for each original block and add it to the
1543
1548
// corresponding struct.
1544
- // TODO(bartchr) : add support for forward mode.
1549
+ // TODO: add support for forward mode.
1545
1550
for (auto &origBB : *original) {
1546
1551
auto *linearMapStruct = getLinearMapStruct (&origBB);
1547
1552
auto *traceEnum =
1548
1553
createBranchingTraceDecl (&origBB, indices, assocFnGenSig);
1549
1554
1550
- // If original block is in a loop, mark branching trace enum as indirect.
1555
+ // If original block is in a loop, mark branching trace enum as indirect.
1551
1556
if (loopInfo->getLoopFor (&origBB))
1552
1557
traceEnum->getAttrs ().add (new (astCtx) IndirectAttr (/* Implicit*/ true ));
1553
1558
branchingTraceDecls.insert ({&origBB, traceEnum});
@@ -1559,7 +1564,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
1559
1564
linearMapStructEnumFields.insert ({linearMapStruct, traceEnumField});
1560
1565
}
1561
1566
1562
- // Add the differential function fields to the differential structs.
1567
+ // Add the linear function fields to the linear map structs.
1563
1568
for (auto &origBB : *original) {
1564
1569
for (auto &inst : origBB) {
1565
1570
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
@@ -1578,7 +1583,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
1578
1583
if (isInout)
1579
1584
break ;
1580
1585
1581
- // Add linear map to struct for active instructions.
1586
+ // Add linear map to struct for active instructions.
1582
1587
// Do not add it for array functions since those are already linear
1583
1588
// and we don't need to add it to the struct.
1584
1589
if (shouldBeDifferentiated (ai, indices) &&
@@ -1588,30 +1593,30 @@ bool LinearMapInfo::shouldBeDifferentiated(
1588
1593
allResults.append (ai->getIndirectSILResults ().begin (),
1589
1594
ai->getIndirectSILResults ().end ());
1590
1595
1591
- // Check if there are any active results or arguments. If not, skip
1596
+ // Check if there are any active results or arguments. If not, skip
1592
1597
// this instruction.
1593
1598
auto hasActiveResults = llvm::any_of (
1594
- allResults, [&](SILValue res) {
1599
+ allResults, [&](SILValue res) {
1595
1600
return activityInfo.isActive (res, indices);
1596
1601
});
1597
1602
auto hasActiveArguments = llvm::any_of (
1598
- ai->getArgumentsWithoutIndirectResults (), [&](SILValue arg) {
1603
+ ai->getArgumentsWithoutIndirectResults (), [&](SILValue arg) {
1599
1604
return activityInfo.isActive (arg, indices);
1600
1605
});
1601
1606
if (!hasActiveResults || !hasActiveArguments)
1602
1607
continue ;
1603
1608
1604
- unsigned source;
1609
+ unsigned source;
1605
1610
AutoDiffIndexSubset *parameters;
1606
1611
1607
- SmallVector<unsigned , 8 > activeParamIndices;
1612
+ SmallVector<unsigned , 8 > activeParamIndices;
1608
1613
SmallVector<unsigned , 8 > activeResultIndices;
1609
1614
collectMinimalIndicesForFunctionCall (
1610
1615
ai, allResults, indices, activityInfo, activeParamIndices,
1611
1616
activeResultIndices);
1612
1617
source = activeResultIndices.front ();
1613
1618
1614
- // If function is already marked differentiable, differentiate WRT
1619
+ // If function is already marked differentiable, differentiate WRT
1615
1620
// all parameters.
1616
1621
auto originalFnSubstTy = ai->getSubstCalleeType ();;
1617
1622
if (originalFnSubstTy->isDifferentiable ()) {
@@ -1628,7 +1633,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
1628
1633
ai->getArgumentsWithoutIndirectResults ().size (),
1629
1634
activeParamIndices));
1630
1635
1631
- // Check and diagnose non-differentiable original function type.
1636
+ // Check and diagnose non-differentiable original function type.
1632
1637
auto diagnoseNondifferentiableOriginalFunctionType =
1633
1638
[&](CanSILFunctionType origFnTy) {
1634
1639
// Check and diagnose non-differentiable arguments.
@@ -1640,7 +1645,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
1640
1645
return true ;
1641
1646
}
1642
1647
}
1643
- // Check and diagnose non-differentiable results.
1648
+ // Check non-differentiable results.
1644
1649
if (!origFnTy->getResults ()[curIndices.source ]
1645
1650
.getSILStorageType ()
1646
1651
.isDifferentiable (builder.getModule ())) {
@@ -1651,19 +1656,21 @@ bool LinearMapInfo::shouldBeDifferentiated(
1651
1656
if (diagnoseNondifferentiableOriginalFunctionType (originalFnSubstTy))
1652
1657
continue ;
1653
1658
1654
- auto JVPType = originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1659
+ auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType (
1655
1660
parameters, source,
1656
1661
/* differentiationOrder*/ 1 , kind, builder.getModule (),
1657
1662
LookUpConformanceInModule (builder.getModule ().getSwiftModule ()));
1658
1663
1659
- auto JVPResultTypes = JVPType->getAllResultsType ().castTo <TupleType>();
1660
- JVPResultTypes->getElement (JVPResultTypes->getElements ().size () - 1 );
1661
- auto differentialSILType =
1664
+ auto assocFnResultTypes =
1665
+ assocFnType->getAllResultsType ().castTo <TupleType>();
1666
+ assocFnResultTypes
1667
+ ->getElement (assocFnResultTypes->getElements ().size () - 1 );
1668
+ auto linearMapSILType =
1662
1669
SILType::getPrimitiveObjectType (
1663
- JVPResultTypes ->getElement (
1664
- JVPResultTypes ->getElements ().size () - 1 )
1670
+ assocFnResultTypes ->getElement (
1671
+ assocFnResultTypes ->getElements ().size () - 1 )
1665
1672
.getType ()->getCanonicalType ());
1666
- addLinearMapDecl (ai, differentialSILType );
1673
+ addLinearMapDecl (ai, linearMapSILType );
1667
1674
}
1668
1675
}
1669
1676
}
0 commit comments