Skip to content

Commit 36676ae

Browse files
authored
[AutoDiff] Remap apply callee type in derivative context. (#27590)
Previously, `LinearMapInfo::addLinearMapToStruct` did not remap the `apply` callee type into the derivative function's context. It now does so before computing the linear map type. The remapping matters when the derivative has a more constrained generic signature than the original function. Resolves TF-817.
1 parent c76bde9 commit 36676ae

File tree

2 files changed

+82
-58
lines changed

2 files changed

+82
-58
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 63 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,9 @@ class LinearMapInfo {
434434
/// The original function.
435435
SILFunction *const original;
436436

437+
/// The derivative function.
438+
SILFunction *const derivative;
439+
437440
/// Activity info of the original function.
438441
const DifferentiableActivityInfo &activityInfo;
439442

@@ -464,9 +467,16 @@ class LinearMapInfo {
464467
/// A type converter, used to compute struct/enum SIL types.
465468
Lowering::TypeConverter &typeConverter;
466469

467-
SILBuilder &builder;
468-
469470
private:
471+
/// Remaps the given type into the derivative function's context.
472+
SILType remapTypeInDerivative(SILType ty) {
473+
if (ty.hasArchetype())
474+
return derivative->mapTypeIntoContext(ty.mapTypeOutOfContext());
475+
return derivative->mapTypeIntoContext(ty);
476+
}
477+
478+
/// Adds a `VarDecl` member with the given name and type to the given nominal
479+
/// declaration.
470480
VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type) {
471481
auto &astCtx = nominal->getASTContext();
472482
auto id = astCtx.getIdentifier(name);
@@ -485,9 +495,9 @@ class LinearMapInfo {
485495
/// Retrieves the file unit that contains implicit declarations in the
486496
/// current Swift module. If it does not exist, create one.
487497
///
488-
// FIXME: Currently it defaults to the file containing `origFn`, if it can be
489-
// determined. Otherwise, it defaults to any file unit in the module. To
490-
// handle this more properly, we should make a DerivedFileUnit class to
498+
// FIXME: Currently it defaults to the file containing `original`, if it can
499+
// be determined. Otherwise, it defaults to any file unit in the module. To
500+
// handle this more properly, we could revive the DerivedFileUnit class to
491501
// contain all synthesized implicit type declarations.
492502
SourceFile &getDeclarationFileUnit() {
493503
if (original->hasLocation())
@@ -699,7 +709,7 @@ class LinearMapInfo {
699709
/// branching enum field.
700710
void generateDifferentiationDataStructures(
701711
ADContext &context, const SILAutoDiffIndices &indices,
702-
SILFunction *assocFn);
712+
SILFunction *derivative);
703713

704714
public:
705715
bool shouldDifferentiateApplyInst(ApplyInst *ai);
@@ -710,10 +720,9 @@ class LinearMapInfo {
710720

711721
explicit LinearMapInfo(ADContext &context,
712722
AutoDiffLinearMapKind kind,
713-
SILFunction *original, SILFunction *assocFn,
723+
SILFunction *original, SILFunction *derivative,
714724
const SILAutoDiffIndices &indices,
715-
const DifferentiableActivityInfo &activityInfo,
716-
SILBuilder &builder);
725+
const DifferentiableActivityInfo &activityInfo);
717726

718727
/// Returns the linear map struct associated with the given original block.
719728
StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const {
@@ -771,7 +780,9 @@ class LinearMapInfo {
771780
/// `struct_extract` in the original function.
772781
VarDecl *lookUpLinearMapDecl(SILInstruction *inst) {
773782
auto lookup = linearMapValueMap.find(inst);
774-
return lookup == linearMapValueMap.end() ? nullptr : lookup->getSecond();
783+
assert(lookup != linearMapValueMap.end() &&
784+
"No linear map declaration corresponding to the given instruction");
785+
return lookup->getSecond();
775786
}
776787
};
777788

@@ -1506,14 +1517,13 @@ static void collectMinimalIndicesForFunctionCall(
15061517

15071518
LinearMapInfo::LinearMapInfo(ADContext &context,
15081519
AutoDiffLinearMapKind kind,
1509-
SILFunction *original, SILFunction *assocFn,
1520+
SILFunction *original, SILFunction *derivative,
15101521
const SILAutoDiffIndices &indices,
1511-
const DifferentiableActivityInfo &activityInfo,
1512-
SILBuilder &builder)
1513-
: kind(kind), original(original), activityInfo(activityInfo),
1514-
indices(indices), typeConverter(context.getTypeConverter()),
1515-
builder(builder) {
1516-
generateDifferentiationDataStructures(context, indices, assocFn);
1522+
const DifferentiableActivityInfo &activityInfo)
1523+
: kind(kind), original(original), derivative(derivative),
1524+
activityInfo(activityInfo), indices(indices),
1525+
typeConverter(context.getTypeConverter()) {
1526+
generateDifferentiationDataStructures(context, indices, derivative);
15171527
}
15181528

15191529
/// Returns a flag that indicates whether the `apply` instruction should be
@@ -1608,7 +1618,7 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
16081618
}
16091619

16101620
/// Takes an `apply` instruction and adds its linear map function to the
1611-
/// linear map struct if it's active.
1621+
/// linear map struct if it is active.
16121622
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
16131623
const SILAutoDiffIndices &indices) {
16141624
SmallVector<SILValue, 4> allResults;
@@ -1620,8 +1630,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
16201630

16211631
// Check if there are any active results or arguments. If not, skip
16221632
// this instruction.
1623-
auto hasActiveResults = llvm::any_of(
1624-
allResults, [&](SILValue res) {
1633+
auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) {
16251634
return activityInfo.isActive(res, indices);
16261635
});
16271636
auto hasActiveArguments = llvm::any_of(
@@ -1638,9 +1647,12 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
16381647
// parameters from the function type.
16391648
// - Otherwise, use the active parameters.
16401649
AutoDiffIndexSubset *parameters;
1641-
auto originalFnSubstTy = ai->getSubstCalleeType();
1642-
if (originalFnSubstTy->isDifferentiable()) {
1643-
parameters = originalFnSubstTy->getDifferentiationParameterIndices();
1650+
auto origFnSubstTy = ai->getSubstCalleeType();
1651+
auto remappedOrigFnSubstTy =
1652+
remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy))
1653+
.castTo<SILFunctionType>();
1654+
if (remappedOrigFnSubstTy->isDifferentiable()) {
1655+
parameters = remappedOrigFnSubstTy->getDifferentiationParameterIndices();
16441656
} else {
16451657
parameters = AutoDiffIndexSubset::get(
16461658
original->getASTContext(),
@@ -1653,29 +1665,29 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
16531665
// Check for non-differentiable original function type.
16541666
auto checkNondifferentiableOriginalFunctionType =
16551667
[&](CanSILFunctionType origFnTy) {
1656-
// Check and diagnose non-differentiable arguments.
1668+
// Check non-differentiable arguments.
16571669
for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
1670+
auto remappedParamType =
1671+
origFnTy->getParameters()[paramIndex].getSILStorageType();
16581672
if (applyIndices.isWrtParameter(paramIndex) &&
1659-
!origFnTy->getParameters()[paramIndex]
1660-
.getSILStorageType()
1661-
.isDifferentiable(builder.getModule()))
1673+
!remappedParamType.isDifferentiable(derivative->getModule()))
16621674
return true;
16631675
}
16641676
// Check non-differentiable results.
1665-
if (!origFnTy->getResults()[applyIndices.source]
1666-
.getSILStorageType()
1667-
.isDifferentiable(builder.getModule()))
1677+
auto remappedResultType =
1678+
origFnTy->getResults()[applyIndices.source].getSILStorageType();
1679+
if (!remappedResultType.isDifferentiable(derivative->getModule()))
16681680
return true;
16691681
return false;
16701682
};
1671-
if (checkNondifferentiableOriginalFunctionType(originalFnSubstTy))
1683+
if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy))
16721684
return;
16731685

16741686
AutoDiffAssociatedFunctionKind assocFnKind(kind);
1675-
auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType(
1687+
auto assocFnType = remappedOrigFnSubstTy->getAutoDiffAssociatedFunctionType(
16761688
parameters, source, /*differentiationOrder*/ 1, assocFnKind,
16771689
context.getTypeConverter(),
1678-
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
1690+
LookUpConformanceInModule(derivative->getModule().getSwiftModule()));
16791691

16801692
auto assocFnResultTypes =
16811693
assocFnType->getAllResultsType().castTo<TupleType>();
@@ -1738,8 +1750,6 @@ void LinearMapInfo::generateDifferentiationDataStructures(
17381750
for (auto &origBB : *original) {
17391751
for (auto &inst : origBB) {
17401752
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
1741-
LLVM_DEBUG(getADDebugStream()
1742-
<< "Adding linear map struct field for " << *ai);
17431753
// Check for active 'inout' arguments.
17441754
bool isInout = false;
17451755
auto paramInfos = ai->getSubstCalleeConv().getParameters();
@@ -1754,13 +1764,17 @@ void LinearMapInfo::generateDifferentiationDataStructures(
17541764
}
17551765
}
17561766
if (isInout)
1757-
break;
1767+
continue;
1768+
1769+
// Add linear map field to struct for active `apply` instructions.
1770+
// Skip array literal intrinsic applications since array literal
1771+
// initialization is linear and handled separately.
1772+
if (!shouldDifferentiateApplyInst(ai) || isArrayLiteralIntrinsic(ai))
1773+
continue;
17581774

1759-
// Add linear map to struct for active instructions.
1760-
// Do not add it for array functions since those are already linear
1761-
// and we don't need to add it to the struct.
1762-
if (shouldDifferentiateApplyInst(ai) && !isArrayLiteralIntrinsic(ai))
1763-
addLinearMapToStruct(context, ai, indices);
1775+
LLVM_DEBUG(getADDebugStream() << "Adding linear map struct field for "
1776+
<< *ai);
1777+
addLinearMapToStruct(context, ai, indices);
17641778
}
17651779
}
17661780
}
@@ -3320,8 +3334,8 @@ class VJPEmitter final
33203334
context(context), original(original), attr(attr), vjp(vjp),
33213335
invoker(invoker), activityInfo(getActivityInfo(
33223336
context, original, attr->getIndices(), vjp)),
3323-
pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original,
3324-
vjp, attr->getIndices(), activityInfo, getBuilder()) {
3337+
pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp,
3338+
attr->getIndices(), activityInfo) {
33253339
// Create empty pullback function.
33263340
pullback = createEmptyPullback();
33273341
context.getGeneratedFunctions().push_back(pullback);
@@ -4149,7 +4163,7 @@ class JVPEmitter final
41494163
//--------------------------------------------------------------------------//
41504164

41514165
/// The builder for the differential function.
4152-
SILBuilder differentialAndBuilder;
4166+
SILBuilder differentialBuilder;
41534167

41544168
/// Mapping from original basic blocks to corresponding differential basic
41554169
/// blocks.
@@ -4189,9 +4203,9 @@ class JVPEmitter final
41894203
ASTContext &getASTContext() const { return jvp->getASTContext(); }
41904204
SILModule &getModule() const { return jvp->getModule(); }
41914205
const SILAutoDiffIndices &getIndices() const { return attr->getIndices(); }
4192-
SILBuilder &getDifferentialBuilder() { return differentialAndBuilder; }
4206+
SILBuilder &getDifferentialBuilder() { return differentialBuilder; }
41934207
SILFunction &getDifferential() {
4194-
return differentialAndBuilder.getFunction();
4208+
return differentialBuilder.getFunction();
41954209
}
41964210
SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) {
41974211
#ifndef NDEBUG
@@ -4235,15 +4249,6 @@ class JVPEmitter final
42354249
return activityInfo;
42364250
}
42374251

4238-
static SILBuilder
4239-
initializeDifferentialAndBuilder(ADContext &context, SILFunction *original,
4240-
SILDifferentiableAttr *attr,
4241-
LinearMapInfo *linearMapInfo) {
4242-
auto *differential =
4243-
createEmptyDifferential(context, original, attr, linearMapInfo);
4244-
return SILBuilder(*differential);
4245-
}
4246-
42474252
//--------------------------------------------------------------------------//
42484253
// Differential struct mapping
42494254
//--------------------------------------------------------------------------//
@@ -5219,9 +5224,9 @@ class JVPEmitter final
52195224
invoker(invoker), activityInfo(getActivityInfo(
52205225
context, original, attr->getIndices(), jvp)),
52215226
differentialInfo(context, AutoDiffLinearMapKind::Differential, original,
5222-
jvp, attr->getIndices(), activityInfo, getBuilder()),
5223-
differentialAndBuilder(initializeDifferentialAndBuilder(
5224-
context, original, attr, &differentialInfo)),
5227+
jvp, attr->getIndices(), activityInfo),
5228+
differentialBuilder(SILBuilder(*createEmptyDifferential(
5229+
context, original, attr, &differentialInfo))),
52255230
diffLocalAllocBuilder(getDifferential()) {
52265231
// Create empty differential function.
52275232
context.getGeneratedFunctions().push_back(&getDifferential());

test/AutoDiff/generics.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,25 @@ extension TF_697_Sequential: TF_697_Layer where Layer1: TF_697_Layer {
274274
}
275275
}
276276

277+
// TF-817: Test remapping `apply` callee types in derivative function context.
278+
struct TF_817<T> {
279+
func foo(_ index: Int) -> T {
280+
fatalError()
281+
}
282+
}
283+
extension TF_817: Differentiable where T: Differentiable {
284+
@differentiating(foo)
285+
func vjpFoo(index: Int) -> (value: T, pullback: (T.TangentVector) -> (TangentVector)) {
286+
fatalError()
287+
}
288+
}
289+
extension TF_817 {
290+
@differentiable(wrt: self where T: Differentiable)
291+
public func test(index: Int) -> T {
292+
return self.foo(0) // crash happened here
293+
}
294+
}
295+
277296
// Test layout requirements.
278297

279298
// The layout requirement is "contextual": the requirement is not on `T`, the

0 commit comments

Comments
 (0)