Skip to content

Commit 608844b

Browse files
committed
PR feedback.
1 parent 725308a commit 608844b

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,10 @@ class LinearMapInfo {
668668
return linearMapDecl;
669669
}
670670

671-
void prepareLinearMapStructDeclarations(
671+
/// This takes the declared linear map structs per basic block, and populates them with the necessary
672+
/// fields, specifically the linear function (pullback or differential) of the corresponding original function call
673+
/// in the original function, and the branching enum.
674+
void populateLinearMapStructDeclarationFields(
672675
ADContext &context, const SILAutoDiffIndices &indices,
673676
SILFunction *assocFn);
674677

@@ -1484,7 +1487,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context,
14841487
SILBuilder &builder)
14851488
: kind(kind), original(original), activityInfo(activityInfo),
14861489
typeConverter(context.getTypeConverter()), builder(builder) {
1487-
prepareLinearMapStructDeclarations(context, indices, assocFn);
1490+
populateLinearMapStructDeclarationFields(context, indices, assocFn);
14881491
}
14891492

14901493
bool LinearMapInfo::shouldBeDifferentiated(
@@ -1519,35 +1522,36 @@ bool LinearMapInfo::shouldBeDifferentiated(
15191522
return false;
15201523
}
15211524

1522-
void LinearMapInfo::prepareLinearMapStructDeclarations(
1525+
void LinearMapInfo::populateLinearMapStructDeclarationFields(
15231526
ADContext &context, const SILAutoDiffIndices &indices,
15241527
SILFunction *assocFn) {
15251528

1526-
auto &astCtx = original->getASTContext();
1529+
auto &astCtx = original->getASTContext();
15271530
auto *loopAnalysis = context.getPassManager().getAnalysis<SILLoopAnalysis>();
15281531
auto *loopInfo = loopAnalysis->get(original);
15291532

1530-
// Get the associated function generic signature.
1533+
// Get the associated function generic signature.
15311534
CanGenericSignature assocFnGenSig = nullptr;
15321535
if (auto *assocFnGenEnv = assocFn->getGenericEnvironment())
15331536
assocFnGenSig =
15341537
assocFnGenEnv->getGenericSignature()->getCanonicalSignature();
15351538

1536-
// Create pullback struct for each original block.
1539+
// Create linear map struct for each original block.
15371540
for (auto &origBB : *original) {
1538-
auto *linearMapStruct = createLinearMapStruct(&origBB, indices, assocFnGenSig);
1541+
auto *linearMapStruct =
1542+
createLinearMapStruct(&origBB, indices, assocFnGenSig);
15391543
linearMapStructs.insert({&origBB, linearMapStruct});
15401544
}
15411545

1542-
// Create branching trace enum for each original block and add it to the
1546+
// Create branching trace enum for each original block and add it to the
15431547
// corresponding struct.
1544-
// TODO(bartchr): add support for forward mode.
1548+
// TODO: add support for forward mode.
15451549
for (auto &origBB : *original) {
15461550
auto *linearMapStruct = getLinearMapStruct(&origBB);
15471551
auto *traceEnum =
15481552
createBranchingTraceDecl(&origBB, indices, assocFnGenSig);
15491553

1550-
// If original block is in a loop, mark branching trace enum as indirect.
1554+
// If original block is in a loop, mark branching trace enum as indirect.
15511555
if (loopInfo->getLoopFor(&origBB))
15521556
traceEnum->getAttrs().add(new (astCtx) IndirectAttr(/*Implicit*/ true));
15531557
branchingTraceDecls.insert({&origBB, traceEnum});
@@ -1559,7 +1563,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
15591563
linearMapStructEnumFields.insert({linearMapStruct, traceEnumField});
15601564
}
15611565

1562-
// Add the differential function fields to the differential structs.
1566+
// Add the linear function fields to the linear map structs.
15631567
for (auto &origBB : *original) {
15641568
for (auto &inst : origBB) {
15651569
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
@@ -1578,7 +1582,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
15781582
if (isInout)
15791583
break;
15801584

1581-
// Add linear map to struct for active instructions.
1585+
// Add linear map to struct for active instructions.
15821586
// Do not add it for array functions since those are already linear
15831587
// and we don't need to add it to the struct.
15841588
if (shouldBeDifferentiated(ai, indices) &&
@@ -1588,30 +1592,30 @@ bool LinearMapInfo::shouldBeDifferentiated(
15881592
allResults.append(ai->getIndirectSILResults().begin(),
15891593
ai->getIndirectSILResults().end());
15901594

1591-
// Check if there are any active results or arguments. If not, skip
1595+
// Check if there are any active results or arguments. If not, skip
15921596
// this instruction.
15931597
auto hasActiveResults = llvm::any_of(
1594-
allResults, [&](SILValue res) {
1598+
allResults, [&](SILValue res) {
15951599
return activityInfo.isActive(res, indices);
15961600
});
15971601
auto hasActiveArguments = llvm::any_of(
1598-
ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
1602+
ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
15991603
return activityInfo.isActive(arg, indices);
16001604
});
16011605
if (!hasActiveResults || !hasActiveArguments)
16021606
continue;
16031607

1604-
unsigned source;
1608+
unsigned source;
16051609
AutoDiffIndexSubset *parameters;
16061610

1607-
SmallVector<unsigned, 8> activeParamIndices;
1611+
SmallVector<unsigned, 8> activeParamIndices;
16081612
SmallVector<unsigned, 8> activeResultIndices;
16091613
collectMinimalIndicesForFunctionCall(
16101614
ai, allResults, indices, activityInfo, activeParamIndices,
16111615
activeResultIndices);
16121616
source = activeResultIndices.front();
16131617

1614-
// If function is already marked differentiable, differentiate WRT
1618+
// If function is already marked differentiable, differentiate WRT
16151619
// all parameters.
16161620
auto originalFnSubstTy = ai->getSubstCalleeType();;
16171621
if (originalFnSubstTy->isDifferentiable()) {
@@ -1628,7 +1632,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
16281632
ai->getArgumentsWithoutIndirectResults().size(),
16291633
activeParamIndices));
16301634

1631-
// Check and diagnose non-differentiable original function type.
1635+
// Check and diagnose non-differentiable original function type.
16321636
auto diagnoseNondifferentiableOriginalFunctionType =
16331637
[&](CanSILFunctionType origFnTy) {
16341638
// Check and diagnose non-differentiable arguments.
@@ -1640,7 +1644,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
16401644
return true;
16411645
}
16421646
}
1643-
// Check and diagnose non-differentiable results.
1647+
// Check non-differentiable results.
16441648
if (!origFnTy->getResults()[curIndices.source]
16451649
.getSILStorageType()
16461650
.isDifferentiable(builder.getModule())) {
@@ -1651,19 +1655,21 @@ bool LinearMapInfo::shouldBeDifferentiated(
16511655
if (diagnoseNondifferentiableOriginalFunctionType(originalFnSubstTy))
16521656
continue;
16531657

1654-
auto JVPType = originalFnSubstTy->getAutoDiffAssociatedFunctionType(
1658+
auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType(
16551659
parameters, source,
16561660
/*differentiationOrder*/ 1, kind, builder.getModule(),
16571661
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
16581662

1659-
auto JVPResultTypes = JVPType->getAllResultsType().castTo<TupleType>();
1660-
JVPResultTypes->getElement(JVPResultTypes->getElements().size() - 1);
1661-
auto differentialSILType =
1663+
auto assocFnResultTypes =
1664+
assocFnType->getAllResultsType().castTo<TupleType>();
1665+
assocFnResultTypes
1666+
->getElement(JVPResultTypes->getElements().size() - 1);
1667+
auto linearMapSILType =
16621668
SILType::getPrimitiveObjectType(
1663-
JVPResultTypes->getElement(
1664-
JVPResultTypes->getElements().size() - 1)
1669+
assocFnResultTypes->getElement(
1670+
assocFnResultTypes->getElements().size() - 1)
16651671
.getType()->getCanonicalType());
1666-
addLinearMapDecl(ai, differentialSILType);
1672+
addLinearMapDecl(ai, linearMapSILType);
16671673
}
16681674
}
16691675
}

0 commit comments

Comments
 (0)