Skip to content

Commit 81b8592

Browse files
committed
Pull out the adding of the apply instruction to helper.
1 parent 6ca3817 commit 81b8592

File tree

1 file changed

+92
-90
lines changed

1 file changed

+92
-90
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 92 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ class LinearMapInfo {
388388
/// The original function.
389389
SILFunction *const original;
390390

391-
/// Activitiy info of the original function.
391+
/// Activity info of the original function.
392392
const DifferentiableActivityInfo &activityInfo;
393393

394394
/// Mapping from original basic blocks to linear map structs.
@@ -668,6 +668,8 @@ class LinearMapInfo {
668668
return linearMapDecl;
669669
}
670670

671+
void addLinearMapToStruct(ApplyInst *ai, const SILAutoDiffIndices &indices);
672+
671673
/// This takes the declared linear map structs and populates
672674
/// them with the necessary fields, specifically the linear function (pullback
673675
/// or differential) of the corresponding original function call in the
@@ -1522,6 +1524,92 @@ bool LinearMapInfo::shouldBeDifferentiated(ApplyInst *ai,
15221524
return false;
15231525
}
15241526

1527+
1528+
/// Takes an `apply` instruction and adds its linear map function to the
1529+
/// linear map struct if it's active.
1530+
void LinearMapInfo::addLinearMapToStruct(ApplyInst *ai,
1531+
const SILAutoDiffIndices &indices) {
1532+
SmallVector<SILValue, 4> allResults;
1533+
allResults.push_back(ai);
1534+
allResults.append(ai->getIndirectSILResults().begin(),
1535+
ai->getIndirectSILResults().end());
1536+
1537+
// Check if there are any active results or arguments. If not, skip
1538+
// this instruction.
1539+
auto hasActiveResults = llvm::any_of(
1540+
allResults, [&](SILValue res) {
1541+
return activityInfo.isActive(res, indices);
1542+
});
1543+
auto hasActiveArguments = llvm::any_of(
1544+
ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
1545+
return activityInfo.isActive(arg, indices);
1546+
});
1547+
if (!hasActiveResults || !hasActiveArguments)
1548+
return;
1549+
1550+
unsigned source;
1551+
AutoDiffIndexSubset *parameters;
1552+
1553+
SmallVector<unsigned, 8> activeParamIndices;
1554+
SmallVector<unsigned, 8> activeResultIndices;
1555+
collectMinimalIndicesForFunctionCall(
1556+
ai, allResults, indices, activityInfo, activeParamIndices,
1557+
activeResultIndices);
1558+
source = activeResultIndices.front();
1559+
1560+
// If function is already marked differentiable, differentiate W.R.T.
1561+
// all parameters.
1562+
auto originalFnSubstTy = ai->getSubstCalleeType();
1563+
if (originalFnSubstTy->isDifferentiable()) {
1564+
parameters = originalFnSubstTy->getDifferentiationParameterIndices();
1565+
} else {
1566+
parameters = AutoDiffIndexSubset::get(
1567+
original->getASTContext(),
1568+
ai->getArgumentsWithoutIndirectResults().size(),
1569+
activeParamIndices);
1570+
}
1571+
SILAutoDiffIndices curIndices(activeResultIndices.front(),
1572+
AutoDiffIndexSubset::get(
1573+
builder.getASTContext(),
1574+
ai->getArgumentsWithoutIndirectResults().size(),
1575+
activeParamIndices));
1576+
1577+
// Check for non-differentiable original function type.
1578+
auto checkNondifferentiableOriginalFunctionType =
1579+
[&](CanSILFunctionType origFnTy) {
1580+
// Check and diagnose non-differentiable arguments.
1581+
for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
1582+
if (curIndices.isWrtParameter(paramIndex) &&
1583+
!origFnTy->getParameters()[paramIndex]
1584+
.getSILStorageType()
1585+
.isDifferentiable(builder.getModule()))
1586+
return true;
1587+
}
1588+
// Check non-differentiable results.
1589+
if (!origFnTy->getResults()[curIndices.source]
1590+
.getSILStorageType()
1591+
.isDifferentiable(builder.getModule()))
1592+
return true;
1593+
return false;
1594+
};
1595+
if (checkNondifferentiableOriginalFunctionType(originalFnSubstTy))
1596+
return;
1597+
1598+
auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType(
1599+
parameters, source, /*differentiationOrder*/ 1, kind, builder.getModule(),
1600+
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
1601+
1602+
auto assocFnResultTypes =
1603+
assocFnType->getAllResultsType().castTo<TupleType>();
1604+
assocFnResultTypes->getElement(assocFnResultTypes->getElements().size() - 1);
1605+
auto linearMapSILType = SILType::getPrimitiveObjectType(
1606+
assocFnResultTypes
1607+
->getElement(assocFnResultTypes->getElements().size() - 1)
1608+
.getType()
1609+
->getCanonicalType());
1610+
addLinearMapDecl(ai, linearMapSILType);
1611+
}
1612+
15251613
void LinearMapInfo::populateLinearMapStructDeclarationFields(
15261614
ADContext &context, const SILAutoDiffIndices &indices,
15271615
SILFunction *assocFn) {
@@ -1548,8 +1636,7 @@ void LinearMapInfo::populateLinearMapStructDeclarationFields(
15481636
// TODO: add support for forward mode.
15491637
for (auto &origBB : *original) {
15501638
auto *linearMapStruct = getLinearMapStruct(&origBB);
1551-
auto *traceEnum =
1552-
createBranchingTraceDecl(&origBB, indices, assocFnGenSig);
1639+
auto *traceEnum = createBranchingTraceDecl(&origBB, indices, assocFnGenSig);
15531640

15541641
// If original block is in a loop, mark branching trace enum as indirect.
15551642
if (loopInfo->getLoopFor(&origBB))
@@ -1587,93 +1674,8 @@ void LinearMapInfo::populateLinearMapStructDeclarationFields(
15871674
// Do not add it for array functions since those are already linear
15881675
// and we don't need to add it to the struct.
15891676
if (shouldBeDifferentiated(ai, indices) &&
1590-
!ai->hasSemantics("array.uninitialized_intrinsic")) {
1591-
SmallVector<SILValue, 4> allResults;
1592-
allResults.push_back(ai);
1593-
allResults.append(ai->getIndirectSILResults().begin(),
1594-
ai->getIndirectSILResults().end());
1595-
1596-
// Check if there are any active results or arguments. If not, skip
1597-
// this instruction.
1598-
auto hasActiveResults = llvm::any_of(
1599-
allResults, [&](SILValue res) {
1600-
return activityInfo.isActive(res, indices);
1601-
});
1602-
auto hasActiveArguments = llvm::any_of(
1603-
ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
1604-
return activityInfo.isActive(arg, indices);
1605-
});
1606-
if (!hasActiveResults || !hasActiveArguments)
1607-
continue;
1608-
1609-
unsigned source;
1610-
AutoDiffIndexSubset *parameters;
1611-
1612-
SmallVector<unsigned, 8> activeParamIndices;
1613-
SmallVector<unsigned, 8> activeResultIndices;
1614-
collectMinimalIndicesForFunctionCall(
1615-
ai, allResults, indices, activityInfo, activeParamIndices,
1616-
activeResultIndices);
1617-
source = activeResultIndices.front();
1618-
1619-
// If function is already marked differentiable, differentiate WRT
1620-
// all parameters.
1621-
auto originalFnSubstTy = ai->getSubstCalleeType();
1622-
if (originalFnSubstTy->isDifferentiable()) {
1623-
parameters =
1624-
originalFnSubstTy->getDifferentiationParameterIndices();
1625-
} else {
1626-
parameters = AutoDiffIndexSubset::get(
1627-
original->getASTContext(),
1628-
ai->getArgumentsWithoutIndirectResults().size(),
1629-
activeParamIndices);
1630-
}
1631-
SILAutoDiffIndices curIndices(activeResultIndices.front(),
1632-
AutoDiffIndexSubset::get(
1633-
builder.getASTContext(),
1634-
ai->getArgumentsWithoutIndirectResults().size(),
1635-
activeParamIndices));
1636-
1637-
// Check for non-differentiable original function type.
1638-
auto checkNondifferentiableOriginalFunctionType =
1639-
[&](CanSILFunctionType origFnTy) {
1640-
// Check and diagnose non-differentiable arguments.
1641-
for (unsigned paramIndex :
1642-
range(origFnTy->getNumParameters())) {
1643-
if (curIndices.isWrtParameter(paramIndex) &&
1644-
!origFnTy->getParameters()[paramIndex]
1645-
.getSILStorageType()
1646-
.isDifferentiable(builder.getModule()))
1647-
return true;
1648-
}
1649-
// Check non-differentiable results.
1650-
if (!origFnTy->getResults()[curIndices.source]
1651-
.getSILStorageType()
1652-
.isDifferentiable(builder.getModule()))
1653-
return true;
1654-
return false;
1655-
};
1656-
if (checkNondifferentiableOriginalFunctionType(originalFnSubstTy))
1657-
continue;
1658-
1659-
auto assocFnType =
1660-
originalFnSubstTy->getAutoDiffAssociatedFunctionType(
1661-
parameters, source, /*differentiationOrder*/ 1, kind,
1662-
builder.getModule(),
1663-
LookUpConformanceInModule(
1664-
builder.getModule().getSwiftModule()));
1665-
1666-
auto assocFnResultTypes =
1667-
assocFnType->getAllResultsType().castTo<TupleType>();
1668-
assocFnResultTypes->getElement(
1669-
assocFnResultTypes->getElements().size() - 1);
1670-
auto linearMapSILType = SILType::getPrimitiveObjectType(
1671-
assocFnResultTypes
1672-
->getElement(assocFnResultTypes->getElements().size() - 1)
1673-
.getType()
1674-
->getCanonicalType());
1675-
addLinearMapDecl(ai, linearMapSILType);
1676-
}
1677+
!ai->hasSemantics("array.uninitialized_intrinsic"))
1678+
addLinearMapToStruct(ai, indices);
16771679
}
16781680
}
16791681
}

0 commit comments

Comments
 (0)