Skip to content

Commit 6ca3817

Browse files
committed
return brackets and spacing.
2 parents 608844b + 931865d commit 6ca3817

File tree

1 file changed

+41
-38
lines changed

1 file changed

+41
-38
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -668,9 +668,10 @@ class LinearMapInfo {
668668
return linearMapDecl;
669669
}
670670

671-
/// This takes the declared linear map structs per basic block, and populates them with the necessary
672-
/// fields, specifically the linear function (pullback or differential) of the corresponding original function call
673-
/// in the original function, and the branching enum.
671+
/// This takes the declared linear map structs and populates
672+
/// them with the necessary fields, specifically the linear function (pullback
673+
/// or differential) of the corresponding original function call in the
674+
/// original function, and the branching enum.
674675
void populateLinearMapStructDeclarationFields(
675676
ADContext &context, const SILAutoDiffIndices &indices,
676677
SILFunction *assocFn);
@@ -1490,39 +1491,38 @@ LinearMapInfo::LinearMapInfo(ADContext &context,
14901491
populateLinearMapStructDeclarationFields(context, indices, assocFn);
14911492
}
14921493

1493-
bool LinearMapInfo::shouldBeDifferentiated(
1494-
ApplyInst *ai, const SILAutoDiffIndices &indices) {
1494+
bool LinearMapInfo::shouldBeDifferentiated(ApplyInst *ai,
1495+
const SILAutoDiffIndices &indices) {
14951496
// Anything with an active result should be differentiated.
14961497
if (llvm::any_of(ai->getResults(), [&](SILValue val) {
14971498
return activityInfo.isActive(val, indices);
1498-
})) {
1499+
}))
14991500
return true;
1500-
}
1501+
15011502
// Function applications with an active indirect result should be
15021503
// differentiated.
15031504
for (auto indRes : ai->getIndirectSILResults())
1504-
if (activityInfo.isActive(indRes, indices)) {
1505+
if (activityInfo.isActive(indRes, indices))
15051506
return true;
1506-
}
1507+
15071508
// Function applications with an inout argument should be differentiated.
15081509
auto paramInfos = ai->getSubstCalleeConv().getParameters();
15091510
for (auto i : swift::indices(paramInfos))
15101511
if (paramInfos[i].isIndirectInOut() &&
15111512
activityInfo.isActive(
1512-
ai->getArgumentsWithoutIndirectResults()[i], indices)) {
1513+
ai->getArgumentsWithoutIndirectResults()[i], indices))
15131514
return true;
1514-
}
1515+
15151516
// Instructions that may write to memory and that have an active operand
15161517
// should be differentiated.
15171518
if (ai->mayWriteToMemory())
15181519
for (auto &op : ai->getAllOperands())
1519-
if (activityInfo.isActive(op.get(), indices)) {
1520+
if (activityInfo.isActive(op.get(), indices))
15201521
return true;
1521-
}
15221522
return false;
15231523
}
15241524

1525-
void LinearMapInfo::populateLinearMapStructDeclarationFields(
1525+
void LinearMapInfo::populateLinearMapStructDeclarationFields(
15261526
ADContext &context, const SILAutoDiffIndices &indices,
15271527
SILFunction *assocFn) {
15281528

@@ -1574,7 +1574,8 @@ bool LinearMapInfo::shouldBeDifferentiated(
15741574
if (paramInfos[i].isIndirectInOut() &&
15751575
activityInfo.isActive(ai->getArgumentsWithoutIndirectResults()[i],
15761576
indices)) {
1577-
// Reject functions with active inout arguments. It's not yet supported.
1577+
// Reject functions with active inout arguments. It's not yet
1578+
// supported.
15781579
isInout = true;
15791580
break;
15801581
}
@@ -1595,11 +1596,11 @@ bool LinearMapInfo::shouldBeDifferentiated(
15951596
// Check if there are any active results or arguments. If not, skip
15961597
// this instruction.
15971598
auto hasActiveResults = llvm::any_of(
1598-
allResults, [&](SILValue res) {
1599+
allResults, [&](SILValue res) {
15991600
return activityInfo.isActive(res, indices);
16001601
});
16011602
auto hasActiveArguments = llvm::any_of(
1602-
ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
1603+
ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
16031604
return activityInfo.isActive(arg, indices);
16041605
});
16051606
if (!hasActiveResults || !hasActiveArguments)
@@ -1617,9 +1618,10 @@ bool LinearMapInfo::shouldBeDifferentiated(
16171618

16181619
// If function is already marked differentiable, differentiate WRT
16191620
// all parameters.
1620-
auto originalFnSubstTy = ai->getSubstCalleeType();;
1621+
auto originalFnSubstTy = ai->getSubstCalleeType();
16211622
if (originalFnSubstTy->isDifferentiable()) {
1622-
parameters = originalFnSubstTy->getDifferentiationParameterIndices();
1623+
parameters =
1624+
originalFnSubstTy->getDifferentiationParameterIndices();
16231625
} else {
16241626
parameters = AutoDiffIndexSubset::get(
16251627
original->getASTContext(),
@@ -1632,43 +1634,44 @@ bool LinearMapInfo::shouldBeDifferentiated(
16321634
ai->getArgumentsWithoutIndirectResults().size(),
16331635
activeParamIndices));
16341636

1635-
// Check and diagnose non-differentiable original function type.
1636-
auto diagnoseNondifferentiableOriginalFunctionType =
1637+
// Check for non-differentiable original function type.
1638+
auto checkNondifferentiableOriginalFunctionType =
16371639
[&](CanSILFunctionType origFnTy) {
16381640
// Check and diagnose non-differentiable arguments.
1639-
for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
1641+
for (unsigned paramIndex :
1642+
range(origFnTy->getNumParameters())) {
16401643
if (curIndices.isWrtParameter(paramIndex) &&
16411644
!origFnTy->getParameters()[paramIndex]
16421645
.getSILStorageType()
1643-
.isDifferentiable(builder.getModule())) {
1646+
.isDifferentiable(builder.getModule()))
16441647
return true;
1645-
}
16461648
}
16471649
// Check non-differentiable results.
16481650
if (!origFnTy->getResults()[curIndices.source]
16491651
.getSILStorageType()
1650-
.isDifferentiable(builder.getModule())) {
1652+
.isDifferentiable(builder.getModule()))
16511653
return true;
1652-
}
16531654
return false;
16541655
};
1655-
if (diagnoseNondifferentiableOriginalFunctionType(originalFnSubstTy))
1656+
if (checkNondifferentiableOriginalFunctionType(originalFnSubstTy))
16561657
continue;
16571658

1658-
auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType(
1659-
parameters, source,
1660-
/*differentiationOrder*/ 1, kind, builder.getModule(),
1661-
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
1659+
auto assocFnType =
1660+
originalFnSubstTy->getAutoDiffAssociatedFunctionType(
1661+
parameters, source, /*differentiationOrder*/ 1, kind,
1662+
builder.getModule(),
1663+
LookUpConformanceInModule(
1664+
builder.getModule().getSwiftModule()));
16621665

16631666
auto assocFnResultTypes =
16641667
assocFnType->getAllResultsType().castTo<TupleType>();
1665-
assocFnResultTypes
1666-
->getElement(JVPResultTypes->getElements().size() - 1);
1667-
auto linearMapSILType =
1668-
SILType::getPrimitiveObjectType(
1669-
assocFnResultTypes->getElement(
1670-
assocFnResultTypes->getElements().size() - 1)
1671-
.getType()->getCanonicalType());
1668+
assocFnResultTypes->getElement(
1669+
assocFnResultTypes->getElements().size() - 1);
1670+
auto linearMapSILType = SILType::getPrimitiveObjectType(
1671+
assocFnResultTypes
1672+
->getElement(assocFnResultTypes->getElements().size() - 1)
1673+
.getType()
1674+
->getCanonicalType());
16721675
addLinearMapDecl(ai, linearMapSILType);
16731676
}
16741677
}

0 commit comments

Comments
 (0)