@@ -434,6 +434,9 @@ class LinearMapInfo {
434
434
// / The original function.
435
435
SILFunction *const original;
436
436
437
+ // / The derivative function.
438
+ SILFunction *const derivative;
439
+
437
440
// / Activity info of the original function.
438
441
const DifferentiableActivityInfo &activityInfo;
439
442
@@ -464,9 +467,16 @@ class LinearMapInfo {
464
467
// / A type converter, used to compute struct/enum SIL types.
465
468
Lowering::TypeConverter &typeConverter;
466
469
467
- SILBuilder &builder;
468
-
469
470
private:
471
+ // / Remaps the given type into the derivative function's context.
472
+ SILType remapTypeInDerivative (SILType ty) {
473
+ if (ty.hasArchetype ())
474
+ return derivative->mapTypeIntoContext (ty.mapTypeOutOfContext ());
475
+ return derivative->mapTypeIntoContext (ty);
476
+ }
477
+
478
+ // / Adds a `VarDecl` member with the given name and type to the given nominal
479
+ // / declaration.
470
480
VarDecl *addVarDecl (NominalTypeDecl *nominal, StringRef name, Type type) {
471
481
auto &astCtx = nominal->getASTContext ();
472
482
auto id = astCtx.getIdentifier (name);
@@ -485,9 +495,9 @@ class LinearMapInfo {
485
495
// / Retrieves the file unit that contains implicit declarations in the
486
496
// / current Swift module. If it does not exist, create one.
487
497
// /
488
- // FIXME: Currently it defaults to the file containing `origFn `, if it can be
489
- // determined. Otherwise, it defaults to any file unit in the module. To
490
- // handle this more properly, we should make a DerivedFileUnit class to
498
+ // FIXME: Currently it defaults to the file containing `original `, if it can
499
+ // be determined. Otherwise, it defaults to any file unit in the module. To
500
+ // handle this more properly, we could revive the DerivedFileUnit class to
491
501
// contain all synthesized implicit type declarations.
492
502
SourceFile &getDeclarationFileUnit () {
493
503
if (original->hasLocation ())
@@ -699,7 +709,7 @@ class LinearMapInfo {
699
709
// / branching enum field.
700
710
void generateDifferentiationDataStructures (
701
711
ADContext &context, const SILAutoDiffIndices &indices,
702
- SILFunction *assocFn );
712
+ SILFunction *derivative );
703
713
704
714
public:
705
715
bool shouldDifferentiateApplyInst (ApplyInst *ai);
@@ -710,10 +720,9 @@ class LinearMapInfo {
710
720
711
721
explicit LinearMapInfo (ADContext &context,
712
722
AutoDiffLinearMapKind kind,
713
- SILFunction *original, SILFunction *assocFn ,
723
+ SILFunction *original, SILFunction *derivative ,
714
724
const SILAutoDiffIndices &indices,
715
- const DifferentiableActivityInfo &activityInfo,
716
- SILBuilder &builder);
725
+ const DifferentiableActivityInfo &activityInfo);
717
726
718
727
// / Returns the linear map struct associated with the given original block.
719
728
StructDecl *getLinearMapStruct (SILBasicBlock *origBB) const {
@@ -771,7 +780,9 @@ class LinearMapInfo {
771
780
// / `struct_extract` in the original function.
772
781
VarDecl *lookUpLinearMapDecl (SILInstruction *inst) {
773
782
auto lookup = linearMapValueMap.find (inst);
774
- return lookup == linearMapValueMap.end () ? nullptr : lookup->getSecond ();
783
+ assert (lookup != linearMapValueMap.end () &&
784
+ " No linear map declaration corresponding to the given instruction" );
785
+ return lookup->getSecond ();
775
786
}
776
787
};
777
788
@@ -1506,14 +1517,13 @@ static void collectMinimalIndicesForFunctionCall(
1506
1517
1507
1518
LinearMapInfo::LinearMapInfo (ADContext &context,
1508
1519
AutoDiffLinearMapKind kind,
1509
- SILFunction *original, SILFunction *assocFn ,
1520
+ SILFunction *original, SILFunction *derivative ,
1510
1521
const SILAutoDiffIndices &indices,
1511
- const DifferentiableActivityInfo &activityInfo,
1512
- SILBuilder &builder)
1513
- : kind(kind), original(original), activityInfo(activityInfo),
1514
- indices (indices), typeConverter(context.getTypeConverter()),
1515
- builder(builder) {
1516
- generateDifferentiationDataStructures (context, indices, assocFn);
1522
+ const DifferentiableActivityInfo &activityInfo)
1523
+ : kind(kind), original(original), derivative(derivative),
1524
+ activityInfo (activityInfo), indices(indices),
1525
+ typeConverter(context.getTypeConverter()) {
1526
+ generateDifferentiationDataStructures (context, indices, derivative);
1517
1527
}
1518
1528
1519
1529
// / Returns a flag that indicates whether the `apply` instruction should be
@@ -1608,7 +1618,7 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
1608
1618
}
1609
1619
1610
1620
// / Takes an `apply` instruction and adds its linear map function to the
1611
- // / linear map struct if it's active.
1621
+ // / linear map struct if it is active.
1612
1622
void LinearMapInfo::addLinearMapToStruct (ADContext &context, ApplyInst *ai,
1613
1623
const SILAutoDiffIndices &indices) {
1614
1624
SmallVector<SILValue, 4 > allResults;
@@ -1620,8 +1630,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
1620
1630
1621
1631
// Check if there are any active results or arguments. If not, skip
1622
1632
// this instruction.
1623
- auto hasActiveResults = llvm::any_of (
1624
- allResults, [&](SILValue res) {
1633
+ auto hasActiveResults = llvm::any_of (allResults, [&](SILValue res) {
1625
1634
return activityInfo.isActive (res, indices);
1626
1635
});
1627
1636
auto hasActiveArguments = llvm::any_of (
@@ -1638,9 +1647,12 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
1638
1647
// parameters from the function type.
1639
1648
// - Otherwise, use the active parameters.
1640
1649
AutoDiffIndexSubset *parameters;
1641
- auto originalFnSubstTy = ai->getSubstCalleeType ();
1642
- if (originalFnSubstTy->isDifferentiable ()) {
1643
- parameters = originalFnSubstTy->getDifferentiationParameterIndices ();
1650
+ auto origFnSubstTy = ai->getSubstCalleeType ();
1651
+ auto remappedOrigFnSubstTy =
1652
+ remapTypeInDerivative (SILType::getPrimitiveObjectType (origFnSubstTy))
1653
+ .castTo <SILFunctionType>();
1654
+ if (remappedOrigFnSubstTy->isDifferentiable ()) {
1655
+ parameters = remappedOrigFnSubstTy->getDifferentiationParameterIndices ();
1644
1656
} else {
1645
1657
parameters = AutoDiffIndexSubset::get (
1646
1658
original->getASTContext (),
@@ -1653,29 +1665,29 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
1653
1665
// Check for non-differentiable original function type.
1654
1666
auto checkNondifferentiableOriginalFunctionType =
1655
1667
[&](CanSILFunctionType origFnTy) {
1656
- // Check and diagnose non-differentiable arguments.
1668
+ // Check non-differentiable arguments.
1657
1669
for (unsigned paramIndex : range (origFnTy->getNumParameters ())) {
1670
+ auto remappedParamType =
1671
+ origFnTy->getParameters ()[paramIndex].getSILStorageType ();
1658
1672
if (applyIndices.isWrtParameter (paramIndex) &&
1659
- !origFnTy->getParameters ()[paramIndex]
1660
- .getSILStorageType ()
1661
- .isDifferentiable (builder.getModule ()))
1673
+ !remappedParamType.isDifferentiable (derivative->getModule ()))
1662
1674
return true ;
1663
1675
}
1664
1676
// Check non-differentiable results.
1665
- if (!origFnTy-> getResults ()[applyIndices. source ]
1666
- . getSILStorageType ()
1667
- .isDifferentiable (builder. getModule ()))
1677
+ auto remappedResultType =
1678
+ origFnTy-> getResults ()[applyIndices. source ]. getSILStorageType ();
1679
+ if (!remappedResultType .isDifferentiable (derivative-> getModule ()))
1668
1680
return true ;
1669
1681
return false ;
1670
1682
};
1671
- if (checkNondifferentiableOriginalFunctionType (originalFnSubstTy ))
1683
+ if (checkNondifferentiableOriginalFunctionType (remappedOrigFnSubstTy ))
1672
1684
return ;
1673
1685
1674
1686
AutoDiffAssociatedFunctionKind assocFnKind (kind);
1675
- auto assocFnType = originalFnSubstTy ->getAutoDiffAssociatedFunctionType (
1687
+ auto assocFnType = remappedOrigFnSubstTy ->getAutoDiffAssociatedFunctionType (
1676
1688
parameters, source, /* differentiationOrder*/ 1 , assocFnKind,
1677
1689
context.getTypeConverter (),
1678
- LookUpConformanceInModule (builder. getModule ().getSwiftModule ()));
1690
+ LookUpConformanceInModule (derivative-> getModule ().getSwiftModule ()));
1679
1691
1680
1692
auto assocFnResultTypes =
1681
1693
assocFnType->getAllResultsType ().castTo <TupleType>();
@@ -1738,8 +1750,6 @@ void LinearMapInfo::generateDifferentiationDataStructures(
1738
1750
for (auto &origBB : *original) {
1739
1751
for (auto &inst : origBB) {
1740
1752
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
1741
- LLVM_DEBUG (getADDebugStream ()
1742
- << " Adding linear map struct field for " << *ai);
1743
1753
// Check for active 'inout' arguments.
1744
1754
bool isInout = false ;
1745
1755
auto paramInfos = ai->getSubstCalleeConv ().getParameters ();
@@ -1754,13 +1764,17 @@ void LinearMapInfo::generateDifferentiationDataStructures(
1754
1764
}
1755
1765
}
1756
1766
if (isInout)
1757
- break ;
1767
+ continue ;
1768
+
1769
+ // Add linear map field to struct for active `apply` instructions.
1770
+ // Skip array literal intrinsic applications since array literal
1771
+ // initialization is linear and handled separately.
1772
+ if (!shouldDifferentiateApplyInst (ai) || isArrayLiteralIntrinsic (ai))
1773
+ continue ;
1758
1774
1759
- // Add linear map to struct for active instructions.
1760
- // Do not add it for array functions since those are already linear
1761
- // and we don't need to add it to the struct.
1762
- if (shouldDifferentiateApplyInst (ai) && !isArrayLiteralIntrinsic (ai))
1763
- addLinearMapToStruct (context, ai, indices);
1775
+ LLVM_DEBUG (getADDebugStream () << " Adding linear map struct field for "
1776
+ << *ai);
1777
+ addLinearMapToStruct (context, ai, indices);
1764
1778
}
1765
1779
}
1766
1780
}
@@ -3320,8 +3334,8 @@ class VJPEmitter final
3320
3334
context(context), original(original), attr(attr), vjp(vjp),
3321
3335
invoker(invoker), activityInfo(getActivityInfo(
3322
3336
context, original, attr->getIndices (), vjp)),
3323
- pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original,
3324
- vjp, attr->getIndices (), activityInfo, getBuilder() ) {
3337
+ pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp,
3338
+ attr->getIndices (), activityInfo) {
3325
3339
// Create empty pullback function.
3326
3340
pullback = createEmptyPullback ();
3327
3341
context.getGeneratedFunctions ().push_back (pullback);
@@ -4149,7 +4163,7 @@ class JVPEmitter final
4149
4163
// --------------------------------------------------------------------------//
4150
4164
4151
4165
// / The builder for the differential function.
4152
- SILBuilder differentialAndBuilder ;
4166
+ SILBuilder differentialBuilder ;
4153
4167
4154
4168
// / Mapping from original basic blocks to corresponding differential basic
4155
4169
// / blocks.
@@ -4189,9 +4203,9 @@ class JVPEmitter final
4189
4203
ASTContext &getASTContext () const { return jvp->getASTContext (); }
4190
4204
SILModule &getModule () const { return jvp->getModule (); }
4191
4205
const SILAutoDiffIndices &getIndices () const { return attr->getIndices (); }
4192
- SILBuilder &getDifferentialBuilder () { return differentialAndBuilder ; }
4206
+ SILBuilder &getDifferentialBuilder () { return differentialBuilder ; }
4193
4207
SILFunction &getDifferential () {
4194
- return differentialAndBuilder .getFunction ();
4208
+ return differentialBuilder .getFunction ();
4195
4209
}
4196
4210
SILArgument *getDifferentialStructArgument (SILBasicBlock *origBB) {
4197
4211
#ifndef NDEBUG
@@ -4235,15 +4249,6 @@ class JVPEmitter final
4235
4249
return activityInfo;
4236
4250
}
4237
4251
4238
- static SILBuilder
4239
- initializeDifferentialAndBuilder (ADContext &context, SILFunction *original,
4240
- SILDifferentiableAttr *attr,
4241
- LinearMapInfo *linearMapInfo) {
4242
- auto *differential =
4243
- createEmptyDifferential (context, original, attr, linearMapInfo);
4244
- return SILBuilder (*differential);
4245
- }
4246
-
4247
4252
// --------------------------------------------------------------------------//
4248
4253
// Differential struct mapping
4249
4254
// --------------------------------------------------------------------------//
@@ -5219,9 +5224,9 @@ class JVPEmitter final
5219
5224
invoker(invoker), activityInfo(getActivityInfo(
5220
5225
context, original, attr->getIndices (), jvp)),
5221
5226
differentialInfo(context, AutoDiffLinearMapKind::Differential, original,
5222
- jvp, attr->getIndices (), activityInfo, getBuilder() ),
5223
- differentialAndBuilder(initializeDifferentialAndBuilder (
5224
- context, original, attr, &differentialInfo)),
5227
+ jvp, attr->getIndices (), activityInfo),
5228
+ differentialBuilder(SILBuilder(* createEmptyDifferential (
5229
+ context, original, attr, &differentialInfo))) ,
5225
5230
diffLocalAllocBuilder(getDifferential()) {
5226
5231
// Create empty differential function.
5227
5232
context.getGeneratedFunctions ().push_back (&getDifferential ());
0 commit comments