Skip to content

[AutoDiff] Fix crash during generic curry thunk cloning. #26404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 8 additions & 31 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,30 +78,6 @@ template <typename T> static inline void debugDump(T &v) {
<< v << "\n==== END DEBUG DUMP ====\n");
}

/// Creates arguments in the entry block based on the function type.
static void createEntryArguments(SILFunction *f) {
auto *entry = f->getEntryBlock();
auto conv = f->getConventions();
auto &ctx = f->getASTContext();
auto moduleDecl = f->getModule().getSwiftModule();
assert((entry->getNumArguments() == 0 || conv.getNumSILArguments() == 0) &&
"Entry already has arguments?!");
auto createFunctionArgument = [&](SILType type) {
// Create a dummy parameter declaration.
// Necessary to prevent crash during argument explosion optimization.
auto loc = f->getLocation().getSourceLoc();
auto *decl = new (ctx)
ParamDecl(VarDecl::Specifier::Default, loc, loc, Identifier(), loc,
Identifier(), moduleDecl);
decl->setType(type.getASTType());
entry->createFunctionArgument(type, decl);
};
for (auto indResTy : conv.getIndirectSILResultTypes())
createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType());
for (auto paramTy : conv.getParameterSILTypes())
createFunctionArgument(f->mapTypeIntoContext(paramTy));
}

static bool isWithoutDerivative(SILValue v) {
if (auto *fnRef = dyn_cast<FunctionRefInst>(v))
return fnRef->getReferencedFunctionOrNull()->hasSemanticsAttr(
Expand Down Expand Up @@ -2982,9 +2958,8 @@ class VJPEmitter final
original->isBare(), IsNotTransparent, original->isSerialized(),
original->isDynamicallyReplaceable());
pullback->setOwnershipEliminated();
pullback->setDebugScope(new (module)
SILDebugScope(original->getLocation(),
pullback));
pullback->setDebugScope(
new (module) SILDebugScope(original->getLocation(), pullback));
return pullback;
}

Expand Down Expand Up @@ -3655,7 +3630,7 @@ class JVPEmitter final
void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
// Clone `autodiff_function` from original to JVP, then add the cloned
// instruction to the `autodiff_function` worklist.
SILClonerWithScopes::visitAutoDiffFunctionInst(adfi);
TypeSubstCloner::visitAutoDiffFunctionInst(adfi);
auto *newADFI = cast<AutoDiffFunctionInst>(getOpValue(adfi));
context.getAutoDiffFunctionInsts().push_back(newADFI);
}
Expand Down Expand Up @@ -6566,10 +6541,12 @@ SILValue ADContext::promoteToDifferentiableFunction(
// returned function value with an `autodiff_function` instruction,
// and process the `autodiff_function` instruction.
if (newThunk->empty()) {
if (auto newThunkGenSig = thunkType->getGenericSignature())
newThunk->setGenericEnvironment(
newThunkGenSig->createGenericEnvironment());
newThunk->setOwnershipEliminated();
SILFunctionCloner cloner(newThunk);
cloner.cloneFunction(thunk);

BasicTypeSubstCloner cloner(thunk, newThunk);
cloner.run();
auto *retInst =
cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
SILBuilder thunkBuilder(retInst);
Expand Down
53 changes: 53 additions & 0 deletions lib/SILOptimizer/Mandatory/Differentiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
#define SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H

#include "swift/SIL/TypeSubstCloner.h"
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
#include "swift/SILOptimizer/Utils/Local.h"

Expand Down Expand Up @@ -88,6 +89,58 @@ class PostOrderPostDominanceOrder {
}
};

/// Creates arguments in the entry block based on the function type.
void createEntryArguments(SILFunction *f) {
auto *entry = f->getEntryBlock();
auto conv = f->getConventions();
auto &ctx = f->getASTContext();
auto moduleDecl = f->getModule().getSwiftModule();
assert((entry->getNumArguments() == 0 || conv.getNumSILArguments() == 0) &&
"Entry already has arguments?!");
auto createFunctionArgument = [&](SILType type) {
// Create a dummy parameter declaration.
// Necessary to prevent crash during argument explosion optimization.
auto loc = f->getLocation().getSourceLoc();
auto *decl = new (ctx)
ParamDecl(VarDecl::Specifier::Default, loc, loc, Identifier(), loc,
Identifier(), moduleDecl);
decl->setType(type.getASTType());
entry->createFunctionArgument(type, decl);
};
for (auto indResTy : conv.getIndirectSILResultTypes())
createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType());
for (auto paramTy : conv.getParameterSILTypes())
createFunctionArgument(f->mapTypeIntoContext(paramTy));
}

/// Cloner that remaps types using the target function's generic environment.
class BasicTypeSubstCloner final
: public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> {

static SubstitutionMap getSubstitutionMap(SILFunction *target) {
if (auto *targetGenEnv = target->getGenericEnvironment())
return targetGenEnv->getForwardingSubstitutionMap();
return SubstitutionMap();
}

public:
explicit BasicTypeSubstCloner(SILFunction *original, SILFunction *target)
: TypeSubstCloner(*target, *original, getSubstitutionMap(target)) {}

void postProcess(SILInstruction *orig, SILInstruction *cloned) {
SILClonerWithScopes::postProcess(orig, cloned);
}

void run() {
auto &target = Builder.getFunction();
auto *entry = target.createBasicBlock();
createEntryArguments(&target);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SILCloner already knows how to clone function arguments. I’ve been leaning towards getting rid of ‘createEntryArguments’ instead of using it in more places.

Copy link
Contributor

@rxwei rxwei Jul 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All you need to overload in your new class is type remapping. No need to define a ‘run()’ or to overload ‘postProcess()’ or ‘visit()’.

Copy link
Contributor Author

@dan-zheng dan-zheng Jul 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My notes:

  • SILCloner::cloneFunction does clone function arguments, but is not suitable for TypeSubstCloner because it does not remap argument types.
    • Perhaps a good patch would be to add something like TypeSubstCloner::cloneFunction upstream that does BasicTypeSubstCloner::run(). AFAIK currently all subclasses duplicate the same "create entry and function arguments" logic. For now, I defined BasicTypeSubstCloner::run() for convenience.
  • Overriding TypeSubstCloner::postProcess is necessary:
    void postProcess(SILInstruction *Orig, SILInstruction *Cloned) {
      llvm_unreachable("Clients need to explicitly call a base class impl!");
    }
    
  • Not overriding TypeSubstCloner::visit is fine, thanks!

Please let me know if you have suggestions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great. It’d be good to include what the original problem was and this context in the PR message.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated PR description, thanks!

SmallVector<SILValue, 8> entryArguments(target.getArguments().begin(),
target.getArguments().end());
cloneFunctionBody(&Original, entry, entryArguments);
}
};

} // end namespace swift

#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
18 changes: 18 additions & 0 deletions test/AutoDiff/generics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,22 @@ extension TF_682_Proto where Self : Differentiable,
}
}

// TF-688: Test generic curry thunk cloning.
public struct TF_688_Struct<Scalar> {
var x: Scalar
}
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
@differentiable
public static func id(x: Self) -> Self {
return x
}
}
@differentiable(wrt: x)
public func TF_688<Scalar: Differentiable>(
_ x: TF_688_Struct<Scalar>,
reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
) -> TF_688_Struct<Scalar> {
reduction(x)
}

// TODO: add more tests.