Skip to content

Commit 1a7e8b6

Browse files
committed
PR feedback.
1 parent 725308a commit 1a7e8b6

File tree

1 file changed

+34
-27
lines changed

1 file changed

+34
-27
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,11 @@ class LinearMapInfo {
668668
return linearMapDecl;
669669
}
670670

671-
void prepareLinearMapStructDeclarations(
671+
/// This takes the declared linear map structs and populates
672+
/// them with the necessary fields, specifically the linear function (pullback
673+
/// or differential) of the corresponding original function call in the
674+
/// original function, and the branching enum.
675+
void populateLinearMapStructDeclarationFields(
672676
ADContext &context, const SILAutoDiffIndices &indices,
673677
SILFunction *assocFn);
674678

@@ -1484,7 +1488,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context,
14841488
SILBuilder &builder)
14851489
: kind(kind), original(original), activityInfo(activityInfo),
14861490
typeConverter(context.getTypeConverter()), builder(builder) {
1487-
prepareLinearMapStructDeclarations(context, indices, assocFn);
1491+
populateLinearMapStructDeclarationFields(context, indices, assocFn);
14881492
}
14891493

14901494
bool LinearMapInfo::shouldBeDifferentiated(
@@ -1519,35 +1523,36 @@ bool LinearMapInfo::shouldBeDifferentiated(
15191523
return false;
15201524
}
15211525

1522-
void LinearMapInfo::prepareLinearMapStructDeclarations(
1526+
void LinearMapInfo::populateLinearMapStructDeclarationFields(
15231527
ADContext &context, const SILAutoDiffIndices &indices,
15241528
SILFunction *assocFn) {
15251529

1526-
auto &astCtx = original->getASTContext();
1530+
auto &astCtx = original->getASTContext();
15271531
auto *loopAnalysis = context.getPassManager().getAnalysis<SILLoopAnalysis>();
15281532
auto *loopInfo = loopAnalysis->get(original);
15291533

1530-
// Get the associated function generic signature.
1534+
// Get the associated function generic signature.
15311535
CanGenericSignature assocFnGenSig = nullptr;
15321536
if (auto *assocFnGenEnv = assocFn->getGenericEnvironment())
15331537
assocFnGenSig =
15341538
assocFnGenEnv->getGenericSignature()->getCanonicalSignature();
15351539

1536-
// Create pullback struct for each original block.
1540+
// Create linear map struct for each original block.
15371541
for (auto &origBB : *original) {
1538-
auto *linearMapStruct = createLinearMapStruct(&origBB, indices, assocFnGenSig);
1542+
auto *linearMapStruct =
1543+
createLinearMapStruct(&origBB, indices, assocFnGenSig);
15391544
linearMapStructs.insert({&origBB, linearMapStruct});
15401545
}
15411546

1542-
// Create branching trace enum for each original block and add it to the
1547+
// Create branching trace enum for each original block and add it to the
15431548
// corresponding struct.
1544-
// TODO(bartchr): add support for forward mode.
1549+
// TODO: add support for forward mode.
15451550
for (auto &origBB : *original) {
15461551
auto *linearMapStruct = getLinearMapStruct(&origBB);
15471552
auto *traceEnum =
15481553
createBranchingTraceDecl(&origBB, indices, assocFnGenSig);
15491554

1550-
// If original block is in a loop, mark branching trace enum as indirect.
1555+
// If original block is in a loop, mark branching trace enum as indirect.
15511556
if (loopInfo->getLoopFor(&origBB))
15521557
traceEnum->getAttrs().add(new (astCtx) IndirectAttr(/*Implicit*/ true));
15531558
branchingTraceDecls.insert({&origBB, traceEnum});
@@ -1559,7 +1564,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
15591564
linearMapStructEnumFields.insert({linearMapStruct, traceEnumField});
15601565
}
15611566

1562-
// Add the differential function fields to the differential structs.
1567+
// Add the linear function fields to the linear map structs.
15631568
for (auto &origBB : *original) {
15641569
for (auto &inst : origBB) {
15651570
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
@@ -1578,7 +1583,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
15781583
if (isInout)
15791584
break;
15801585

1581-
// Add linear map to struct for active instructions.
1586+
// Add linear map to struct for active instructions.
15821587
// Do not add it for array functions since those are already linear
15831588
// and we don't need to add it to the struct.
15841589
if (shouldBeDifferentiated(ai, indices) &&
@@ -1588,30 +1593,30 @@ bool LinearMapInfo::shouldBeDifferentiated(
15881593
allResults.append(ai->getIndirectSILResults().begin(),
15891594
ai->getIndirectSILResults().end());
15901595

1591-
// Check if there are any active results or arguments. If not, skip
1596+
// Check if there are any active results or arguments. If not, skip
15921597
// this instruction.
15931598
auto hasActiveResults = llvm::any_of(
1594-
allResults, [&](SILValue res) {
1599+
allResults, [&](SILValue res) {
15951600
return activityInfo.isActive(res, indices);
15961601
});
15971602
auto hasActiveArguments = llvm::any_of(
1598-
ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
1603+
ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
15991604
return activityInfo.isActive(arg, indices);
16001605
});
16011606
if (!hasActiveResults || !hasActiveArguments)
16021607
continue;
16031608

1604-
unsigned source;
1609+
unsigned source;
16051610
AutoDiffIndexSubset *parameters;
16061611

1607-
SmallVector<unsigned, 8> activeParamIndices;
1612+
SmallVector<unsigned, 8> activeParamIndices;
16081613
SmallVector<unsigned, 8> activeResultIndices;
16091614
collectMinimalIndicesForFunctionCall(
16101615
ai, allResults, indices, activityInfo, activeParamIndices,
16111616
activeResultIndices);
16121617
source = activeResultIndices.front();
16131618

1614-
// If function is already marked differentiable, differentiate WRT
1619+
// If function is already marked differentiable, differentiate WRT
16151620
// all parameters.
16161621
auto originalFnSubstTy = ai->getSubstCalleeType();;
16171622
if (originalFnSubstTy->isDifferentiable()) {
@@ -1628,7 +1633,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
16281633
ai->getArgumentsWithoutIndirectResults().size(),
16291634
activeParamIndices));
16301635

1631-
// Check and diagnose non-differentiable original function type.
1636+
// Check and diagnose non-differentiable original function type.
16321637
auto diagnoseNondifferentiableOriginalFunctionType =
16331638
[&](CanSILFunctionType origFnTy) {
16341639
// Check and diagnose non-differentiable arguments.
@@ -1640,7 +1645,7 @@ bool LinearMapInfo::shouldBeDifferentiated(
16401645
return true;
16411646
}
16421647
}
1643-
// Check and diagnose non-differentiable results.
1648+
// Check non-differentiable results.
16441649
if (!origFnTy->getResults()[curIndices.source]
16451650
.getSILStorageType()
16461651
.isDifferentiable(builder.getModule())) {
@@ -1651,19 +1656,21 @@ bool LinearMapInfo::shouldBeDifferentiated(
16511656
if (diagnoseNondifferentiableOriginalFunctionType(originalFnSubstTy))
16521657
continue;
16531658

1654-
auto JVPType = originalFnSubstTy->getAutoDiffAssociatedFunctionType(
1659+
auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType(
16551660
parameters, source,
16561661
/*differentiationOrder*/ 1, kind, builder.getModule(),
16571662
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
16581663

1659-
auto JVPResultTypes = JVPType->getAllResultsType().castTo<TupleType>();
1660-
JVPResultTypes->getElement(JVPResultTypes->getElements().size() - 1);
1661-
auto differentialSILType =
1664+
auto assocFnResultTypes =
1665+
assocFnType->getAllResultsType().castTo<TupleType>();
1666+
assocFnResultTypes
1667+
->getElement(assocFnResultTypes->getElements().size() - 1);
1668+
auto linearMapSILType =
16621669
SILType::getPrimitiveObjectType(
1663-
JVPResultTypes->getElement(
1664-
JVPResultTypes->getElements().size() - 1)
1670+
assocFnResultTypes->getElement(
1671+
assocFnResultTypes->getElements().size() - 1)
16651672
.getType()->getCanonicalType());
1666-
addLinearMapDecl(ai, differentialSILType);
1673+
addLinearMapDecl(ai, linearMapSILType);
16671674
}
16681675
}
16691676
}

0 commit comments

Comments
 (0)