Skip to content

Commit 60afcf8

Browse files
authored
[AutoDiff] Fix crash during generic curry thunk cloning. (#26404)
Create simple `BasicTypeSubstCloner` inheriting from `TypeSubstCloner`. Use `BasicTypeSubstCloner` to clone generic curry thunks, remapping types. Resolves TF-688.
1 parent 6f58fd4 commit 60afcf8

File tree

3 files changed

+79
-31
lines changed

3 files changed

+79
-31
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -78,30 +78,6 @@ template <typename T> static inline void debugDump(T &v) {
7878
<< v << "\n==== END DEBUG DUMP ====\n");
7979
}
8080

81-
/// Creates arguments in the entry block based on the function type.
82-
static void createEntryArguments(SILFunction *f) {
83-
auto *entry = f->getEntryBlock();
84-
auto conv = f->getConventions();
85-
auto &ctx = f->getASTContext();
86-
auto moduleDecl = f->getModule().getSwiftModule();
87-
assert((entry->getNumArguments() == 0 || conv.getNumSILArguments() == 0) &&
88-
"Entry already has arguments?!");
89-
auto createFunctionArgument = [&](SILType type) {
90-
// Create a dummy parameter declaration.
91-
// Necessary to prevent crash during argument explosion optimization.
92-
auto loc = f->getLocation().getSourceLoc();
93-
auto *decl = new (ctx)
94-
ParamDecl(VarDecl::Specifier::Default, loc, loc, Identifier(), loc,
95-
Identifier(), moduleDecl);
96-
decl->setType(type.getASTType());
97-
entry->createFunctionArgument(type, decl);
98-
};
99-
for (auto indResTy : conv.getIndirectSILResultTypes())
100-
createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType());
101-
for (auto paramTy : conv.getParameterSILTypes())
102-
createFunctionArgument(f->mapTypeIntoContext(paramTy));
103-
}
104-
10581
static bool isWithoutDerivative(SILValue v) {
10682
if (auto *fnRef = dyn_cast<FunctionRefInst>(v))
10783
return fnRef->getReferencedFunctionOrNull()->hasSemanticsAttr(
@@ -3003,9 +2979,8 @@ class VJPEmitter final
30032979
original->isBare(), IsNotTransparent, original->isSerialized(),
30042980
original->isDynamicallyReplaceable());
30052981
pullback->setOwnershipEliminated();
3006-
pullback->setDebugScope(new (module)
3007-
SILDebugScope(original->getLocation(),
3008-
pullback));
2982+
pullback->setDebugScope(
2983+
new (module) SILDebugScope(original->getLocation(), pullback));
30092984
return pullback;
30102985
}
30112986

@@ -3677,7 +3652,7 @@ class JVPEmitter final
36773652
void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
36783653
// Clone `autodiff_function` from original to JVP, then add the cloned
36793654
// instruction to the `autodiff_function` worklist.
3680-
SILClonerWithScopes::visitAutoDiffFunctionInst(adfi);
3655+
TypeSubstCloner::visitAutoDiffFunctionInst(adfi);
36813656
auto *newADFI = cast<AutoDiffFunctionInst>(getOpValue(adfi));
36823657
context.getAutoDiffFunctionInsts().push_back(newADFI);
36833658
}
@@ -6588,10 +6563,12 @@ SILValue ADContext::promoteToDifferentiableFunction(
65886563
// returned function value with an `autodiff_function` instruction,
65896564
// and process the `autodiff_function` instruction.
65906565
if (newThunk->empty()) {
6566+
if (auto newThunkGenSig = thunkType->getGenericSignature())
6567+
newThunk->setGenericEnvironment(
6568+
newThunkGenSig->createGenericEnvironment());
65916569
newThunk->setOwnershipEliminated();
6592-
SILFunctionCloner cloner(newThunk);
6593-
cloner.cloneFunction(thunk);
6594-
6570+
BasicTypeSubstCloner cloner(thunk, newThunk);
6571+
cloner.run();
65956572
auto *retInst =
65966573
cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
65976574
SILBuilder thunkBuilder(retInst);

lib/SILOptimizer/Mandatory/Differentiation.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#ifndef SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
2525
#define SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
2626

27+
#include "swift/SIL/TypeSubstCloner.h"
2728
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
2829
#include "swift/SILOptimizer/Utils/Local.h"
2930

@@ -88,6 +89,58 @@ class PostOrderPostDominanceOrder {
8889
}
8990
};
9091

92+
/// Creates arguments in the entry block based on the function type.
93+
void createEntryArguments(SILFunction *f) {
94+
auto *entry = f->getEntryBlock();
95+
auto conv = f->getConventions();
96+
auto &ctx = f->getASTContext();
97+
auto moduleDecl = f->getModule().getSwiftModule();
98+
assert((entry->getNumArguments() == 0 || conv.getNumSILArguments() == 0) &&
99+
"Entry already has arguments?!");
100+
auto createFunctionArgument = [&](SILType type) {
101+
// Create a dummy parameter declaration.
102+
// Necessary to prevent crash during argument explosion optimization.
103+
auto loc = f->getLocation().getSourceLoc();
104+
auto *decl = new (ctx)
105+
ParamDecl(VarDecl::Specifier::Default, loc, loc, Identifier(), loc,
106+
Identifier(), moduleDecl);
107+
decl->setType(type.getASTType());
108+
entry->createFunctionArgument(type, decl);
109+
};
110+
for (auto indResTy : conv.getIndirectSILResultTypes())
111+
createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType());
112+
for (auto paramTy : conv.getParameterSILTypes())
113+
createFunctionArgument(f->mapTypeIntoContext(paramTy));
114+
}
115+
116+
/// Cloner that remaps types using the target function's generic environment.
117+
class BasicTypeSubstCloner final
118+
: public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> {
119+
120+
static SubstitutionMap getSubstitutionMap(SILFunction *target) {
121+
if (auto *targetGenEnv = target->getGenericEnvironment())
122+
return targetGenEnv->getForwardingSubstitutionMap();
123+
return SubstitutionMap();
124+
}
125+
126+
public:
127+
explicit BasicTypeSubstCloner(SILFunction *original, SILFunction *target)
128+
: TypeSubstCloner(*target, *original, getSubstitutionMap(target)) {}
129+
130+
void postProcess(SILInstruction *orig, SILInstruction *cloned) {
131+
SILClonerWithScopes::postProcess(orig, cloned);
132+
}
133+
134+
void run() {
135+
auto &target = Builder.getFunction();
136+
auto *entry = target.createBasicBlock();
137+
createEntryArguments(&target);
138+
SmallVector<SILValue, 8> entryArguments(target.getArguments().begin(),
139+
target.getArguments().end());
140+
cloneFunctionBody(&Original, entry, entryArguments);
141+
}
142+
};
143+
91144
} // end namespace swift
92145

93146
#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H

test/AutoDiff/generics.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,4 +229,22 @@ extension TF_682_Proto where Self : Differentiable,
229229
}
230230
}
231231

232+
// TF-688: Test generic curry thunk cloning.
233+
public struct TF_688_Struct<Scalar> {
234+
var x: Scalar
235+
}
236+
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
237+
@differentiable
238+
public static func id(x: Self) -> Self {
239+
return x
240+
}
241+
}
242+
@differentiable(wrt: x)
243+
public func TF_688<Scalar: Differentiable>(
244+
_ x: TF_688_Struct<Scalar>,
245+
reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
246+
) -> TF_688_Struct<Scalar> {
247+
reduction(x)
248+
}
249+
232250
// TODO: add more tests.

0 commit comments

Comments
 (0)