Skip to content

assocFn -> derivativeFn everywhere except Differentiation.cpp. #27597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/swift/SIL/SILCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,14 @@ template<typename ImplClass>
void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
DifferentiableFunctionInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
Optional<std::pair<SILValue, SILValue>> assocFns = None;
Optional<std::pair<SILValue, SILValue>> derivativeFns = None;
if (Inst->hasDerivativeFunctions())
assocFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
derivativeFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
getOpValue(Inst->getVJPFunction()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
getOpValue(Inst->getVJPFunction()));
getOpValue(Inst->getVJPFunction()));

recordClonedInstruction(
Inst, getBuilder().createDifferentiableFunction(
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
getOpValue(Inst->getOriginalFunction()), assocFns));
getOpValue(Inst->getOriginalFunction()), derivativeFns));
}

template<typename ImplClass>
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1045,10 +1045,10 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction(
// Generator for the resultant function type, i.e. the AD associated function.
BuiltinGenericSignatureBuilder::LambdaGenerator resultGen{
[=, &Context](BuiltinGenericSignatureBuilder &builder) -> Type {
auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
auto derivativeFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
paramIndices, /*resultIndex*/ 0, kind,
LookUpConformanceInModule(Context.TheBuiltinModule));
return assocFnTy->getResult();
return derivativeFnTy->getResult();
}};
builder.addParameter(firstArgGen);
for (auto argGen : fnArgGens)
Expand Down
4 changes: 2 additions & 2 deletions lib/SIL/SILDeclRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,9 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
autoDiffAssociatedFunctionIdentifier->getParameterIndices(),
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
SILAutoDiffIndices indices(/*source*/ 0, silParameterIndices);
auto assocFnKind = autoDiffAssociatedFunctionIdentifier->getKind();
auto derivativeFnKind = autoDiffAssociatedFunctionIdentifier->getKind();
return mangler.mangleAutoDiffAssociatedFunctionHelper(
originalMangled, assocFnKind, indices);
originalMangled, derivativeFnKind, indices);
}

// As a special case, Clang functions and globals don't get mangled at all.
Expand Down
25 changes: 13 additions & 12 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,16 @@ CanSILFunctionType SILFunctionType::getWithoutDifferentiability() {
// given an existing associated function generic signature. All differentiation
// parameters are constrained to conform to `Differentiable`.
static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature(
CanGenericSignature assocFnGenSig,
CanGenericSignature derivativeFnGenSig,
ArrayRef<SILParameterInfo> originalParameters,
AutoDiffIndexSubset *parameterIndices, ModuleDecl *module) {
if (!assocFnGenSig)
if (!derivativeFnGenSig)
return nullptr;
auto &ctx = module->getASTContext();
GenericSignatureBuilder builder(ctx);

// Add associated function generic signature.
builder.addGenericSignature(assocFnGenSig);
builder.addGenericSignature(derivativeFnGenSig);
// Constrain all wrt parameters to conform to `Differentiable`.
auto source =
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
Expand All @@ -182,7 +182,8 @@ static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature(
CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffAssociatedFunctionKind kind, TypeConverter &TC,
LookupConformanceFn lookupConformance, CanGenericSignature assocFnGenSig) {
LookupConformanceFn lookupConformance,
CanGenericSignature derivativeFnGenSig) {
// JVP: (T...) -> ((R...),
// (T.TangentVector...) -> (R.TangentVector...))
// VJP: (T...) -> ((R...),
Expand All @@ -203,11 +204,11 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
wrtParams.push_back(valueAndIndex.value());

// Get the canonical associated function generic signature.
if (!assocFnGenSig)
assocFnGenSig = getGenericSignature();
assocFnGenSig = getAutoDiffAssociatedFunctionGenericSignature(
assocFnGenSig, getParameters(), parameterIndices, &TC.M);
Lowering::GenericContextScope genericContextScope(TC, assocFnGenSig);
if (!derivativeFnGenSig)
derivativeFnGenSig = getGenericSignature();
derivativeFnGenSig = getAutoDiffAssociatedFunctionGenericSignature(
derivativeFnGenSig, getParameters(), parameterIndices, &TC.M);
Lowering::GenericContextScope genericContextScope(TC, derivativeFnGenSig);

// Given a type, returns its formal SIL parameter info.
auto getTangentParameterInfoForOriginalResult = [&](
Expand Down Expand Up @@ -310,12 +311,12 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
newResults.reserve(getNumResults() + 1);
for (auto &result : getResults()) {
auto mappedResult = result.getWithType(
result.getType()->getCanonicalType(assocFnGenSig));
result.getType()->getCanonicalType(derivativeFnGenSig));
newResults.push_back(mappedResult);
}
newResults.push_back({closureType->getCanonicalType(assocFnGenSig),
newResults.push_back({closureType->getCanonicalType(derivativeFnGenSig),
ResultConvention::Owned});
return SILFunctionType::get(assocFnGenSig, getExtInfo(),
return SILFunctionType::get(derivativeFnGenSig, getExtInfo(),
getCoroutineKind(), getCalleeConvention(),
getParameters(), getYields(), newResults,
getOptionalErrorResult(), ctx,
Expand Down
4 changes: 2 additions & 2 deletions lib/SIL/TypeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,10 +899,10 @@ namespace {
for (AutoDiffAssociatedFunctionKind kind :
{AutoDiffAssociatedFunctionKind::JVP,
AutoDiffAssociatedFunctionKind::VJP}) {
auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
auto derivativeFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
paramIndices, 0, kind, TC,
LookUpConformanceInModule(&TC.M));
auto silTy = SILType::getPrimitiveObjectType(assocFnTy);
auto silTy = SILType::getPrimitiveObjectType(derivativeFnTy);
DifferentiableFunctionExtractee extractee(kind);
// Assert that we have the right extractee. A terrible bug in the past
// was caused by implicit conversions from `unsigned` to
Expand Down
14 changes: 7 additions & 7 deletions lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,16 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
// SWIFT_ENABLE_TENSORFLOW
/// Get or create an autodiff associated function thunk for the given
/// SILDeclRef, SILFunction, and associated function type.
SILFunction *getOrCreateAutoDiffThunk(SILDeclRef assocFnRef,
SILFunction *assocFn,
CanSILFunctionType assocFnTy);
SILFunction *getOrCreateAutoDiffThunk(SILDeclRef derivativeFnRef,
SILFunction *derivativeFn,
CanSILFunctionType derivativeFnTy);

// SWIFT_ENABLE_TENSORFLOW
/// Get or create an autodiff associated function vtable entry thunk for the
/// given SILDeclRef and associated function type.
SILFunction *
getOrCreateAutoDiffClassMethodThunk(SILDeclRef assocFnRef,
CanSILFunctionType assocFnTy);
getOrCreateAutoDiffClassMethodThunk(SILDeclRef derivativeFnRef,
CanSILFunctionType derivativeFnTy);

/// Emit a vtable thunk for a derived method if its natural abstraction level
/// diverges from the overridden base method. If no thunking is needed,
Expand Down Expand Up @@ -187,8 +187,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// - The last result in the returned pullback.
SILFunction *getOrCreateAutoDiffAssociatedFunctionThunk(
SILFunction *original, SILAutoDiffIndices &indices,
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
bool reorderSelf);
SILFunction *derivativeFn,
AutoDiffAssociatedFunctionKind derivativeFnKind, bool reorderSelf);

/// Determine whether the given class has any instance variables that
/// need to be destroyed.
Expand Down
36 changes: 18 additions & 18 deletions lib/SILGen/SILGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1041,19 +1041,19 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
origFnArgVals.push_back(arg.getValue());

// Get the associated function.
SILValue assocFn = SGF.B.createDifferentiableFunctionExtract(
SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract(
loc, kind, origFnVal);
auto assocFnType = assocFn->getType().castTo<SILFunctionType>();
auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();

// We don't need to destroy the original function or retain the `assocFn`,
// because they are trivial (because they are @noescape).
// We don't need to destroy the original function or retain the
// `derivativeFn`, because they are trivial (because they are @noescape).
assert(origFnVal->getType().isTrivial(SGF.F));
assert(assocFn->getType().isTrivial(SGF.F));
bool assocFnNeedsDestroy = false;
assert(derivativeFn->getType().isTrivial(SGF.F));
bool derivativeFnNeedsDestroy = false;

// Unwrap curry levels.
SmallVector<SILFunctionType *, 2> curryLevels;
SILFunctionType *currentLevel = assocFnType;
SILFunctionType *currentLevel = derivativeFnType;
unsigned numParameters = 0;
while (currentLevel != nullptr) {
curryLevels.push_back(currentLevel);
Expand All @@ -1074,25 +1074,25 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
#endif

// Apply all the curry levels except the last one, whose results we handle
// specially. We overwrite `assocFn` with the application results.
// specially. We overwrite `derivativeFn` with the application results.
unsigned currentParameter = 0;
auto curryLevelsWithoutLast =
ArrayRef<SILFunctionType *>(curryLevels).drop_back(1);
for (auto *curryLevel : curryLevelsWithoutLast) {
auto curryLevelArgVals = ArrayRef<SILValue>(origFnArgVals).slice(
currentParameter, curryLevel->getNumParameters());
auto applyResult = SGF.B.createApply(
loc, assocFn, SubstitutionMap(), curryLevelArgVals,
loc, derivativeFn, SubstitutionMap(), curryLevelArgVals,
/*isNonThrowing*/ false);
currentParameter += curryLevel->getNumParameters();

assocFn = applyResult;
derivativeFn = applyResult;

// Our new `assocFn` needs to be released because it's an owned result from
// a function call.
// Our new `derivativeFn` needs to be released because it's an owned result
// from a function call.
assert(curryLevel->getSingleResult().getConvention() ==
ResultConvention::Owned);
assocFnNeedsDestroy = true;
derivativeFnNeedsDestroy = true;
}

assert(curryLevels.back()->getNumResults() == 2);
Expand All @@ -1109,10 +1109,10 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
currentParameter);
for (auto origFnArgVal : curryLevelArgVals)
applyArgs.push_back(origFnArgVal);
auto differential = SGF.B.createApply(
loc, assocFn, SubstitutionMap(), applyArgs, /*isNonThrowing*/ false);
auto differential = SGF.B.createApply(loc, derivativeFn, SubstitutionMap(),
applyArgs, /*isNonThrowing*/ false);

assocFn = SILValue();
derivativeFn = SILValue();

SGF.B.createStore(loc, differential,
SGF.B.createTupleElementAddr(loc, indResBuffer, 1),
Expand All @@ -1125,10 +1125,10 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
auto curryLevelArgVals = ArrayRef<SILValue>(origFnArgVals).slice(
currentParameter);
auto resultTuple = SGF.B.createApply(
loc, assocFn, SubstitutionMap(), curryLevelArgVals,
loc, derivativeFn, SubstitutionMap(), curryLevelArgVals,
/*isNonThrowing*/ false);

assocFn = SILValue();
derivativeFn = SILValue();

return SGF.emitManagedRValueWithCleanup(resultTuple);
}
Expand Down
72 changes: 39 additions & 33 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3311,22 +3311,24 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
return AbstractionPattern(
pattern.getGenericSignature(), getAssocFnTy(patternType, kind));
};
auto createAssocFnThunk = [&](AutoDiffAssociatedFunctionKind kind)
-> ManagedValue {
auto assocFnInputOrigType = getAssocFnPattern(inputOrigTypeNotDiff, kind);
auto assocFnInputSubstType = getAssocFnTy(inputSubstTypeNotDiff, kind);
auto assocFnOutputOrigType = getAssocFnPattern(outputOrigTypeNotDiff,
auto createAssocFnThunk =
[&](AutoDiffAssociatedFunctionKind kind) -> ManagedValue {
auto derivativeFnInputOrigType =
getAssocFnPattern(inputOrigTypeNotDiff, kind);
auto derivativeFnInputSubstType = getAssocFnTy(inputSubstTypeNotDiff, kind);
auto derivativeFnOutputOrigType = getAssocFnPattern(outputOrigTypeNotDiff,
kind);
auto assocFnOutputSubstType = getAssocFnTy(outputSubstTypeNotDiff, kind);
auto &assocFnExpectedTL = SGF.getTypeLowering(assocFnOutputOrigType,
assocFnOutputSubstType);
SILValue assocFn = SGF.B.createDifferentiableFunctionExtract(
auto derivativeFnOutputSubstType =
getAssocFnTy(outputSubstTypeNotDiff, kind);
auto &derivativeFnExpectedTL = SGF.getTypeLowering(
derivativeFnOutputOrigType, derivativeFnOutputSubstType);
SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract(
loc, kind, borrowedFnValue.getValue());
assocFn = SGF.B.emitCopyValueOperation(loc, assocFn);
auto managedAssocFn = SGF.emitManagedRValueWithCleanup(assocFn);
return createThunk(SGF, loc, managedAssocFn, assocFnInputOrigType,
assocFnInputSubstType, assocFnOutputOrigType,
assocFnOutputSubstType, assocFnExpectedTL);
derivativeFn = SGF.B.emitCopyValueOperation(loc, derivativeFn);
auto managedAssocFn = SGF.emitManagedRValueWithCleanup(derivativeFn);
return createThunk(SGF, loc, managedAssocFn, derivativeFnInputOrigType,
derivativeFnInputSubstType, derivativeFnOutputOrigType,
derivativeFnOutputSubstType, derivativeFnExpectedTL);
};

auto jvpThunk = createAssocFnThunk(AutoDiffAssociatedFunctionKind::JVP);
Expand Down Expand Up @@ -3666,59 +3668,61 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
SILFunction *
SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
SILFunction *original, SILAutoDiffIndices &indices,
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
SILFunction *derivativeFn, AutoDiffAssociatedFunctionKind derivativeFnKind,
bool reorderSelf) {
auto assocFnType = assocFn->getLoweredFunctionType();
auto derivativeFnType = derivativeFn->getLoweredFunctionType();

// TODO(TF-685): Use principled thunk mangling.
// Do not simply reuse reabstraction thunk mangling.
Mangle::ASTMangler mangler;
auto name = getASTContext().getIdentifier(
mangler.mangleAutoDiffAssociatedFunctionHelper(
original->getName(), assocFnKind, indices)).str();
original->getName(), derivativeFnKind, indices)).str();

Lowering::GenericContextScope genericContextScope(
Types, assocFnType->getGenericSignature());
auto *thunkGenericEnv = assocFnType->getGenericSignature()
? assocFnType->getGenericSignature()->getGenericEnvironment()
Types, derivativeFnType->getGenericSignature());
auto *thunkGenericEnv = derivativeFnType->getGenericSignature()
? derivativeFnType->getGenericSignature()->getGenericEnvironment()
: nullptr;

auto origFnType = original->getLoweredFunctionType();
auto origAssocFnType = origFnType->getAutoDiffAssociatedFunctionType(
indices.parameters, indices.source,
assocFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()),
assocFnType->getGenericSignature());
derivativeFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()),
derivativeFnType->getGenericSignature());
assert(!origAssocFnType->getExtInfo().hasContext());

auto loc = assocFn->getLocation();
auto loc = derivativeFn->getLocation();
SILGenFunctionBuilder fb(*this);
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
original->getLinkage(), /*isAssocFnExported*/ true);
auto *thunk = fb.getOrCreateFunction(
loc, name, linkage, origAssocFnType, IsBare, IsNotTransparent,
assocFn->isSerialized(), assocFn->isDynamicallyReplaceable(),
assocFn->getEntryCount(), assocFn->isThunk(),
assocFn->getClassSubclassScope());
derivativeFn->isSerialized(), derivativeFn->isDynamicallyReplaceable(),
derivativeFn->getEntryCount(), derivativeFn->isThunk(),
derivativeFn->getClassSubclassScope());
if (!thunk->empty())
return thunk;
thunk->setGenericEnvironment(thunkGenericEnv);

SILGenFunction thunkSGF(*this, *thunk, assocFn->getDeclContext());
SILGenFunction thunkSGF(*this, *thunk, derivativeFn->getDeclContext());
SmallVector<ManagedValue, 4> params;
SmallVector<SILArgument *, 4> indirectResults;
thunkSGF.collectThunkParams(loc, params, &indirectResults);

auto *assocFnRef = thunkSGF.B.createFunctionRef(loc, assocFn);
auto assocFnRefType = assocFnRef->getType().castTo<SILFunctionType>();
auto *derivativeFnRef = thunkSGF.B.createFunctionRef(loc, derivativeFn);
auto derivativeFnRefType =
derivativeFnRef->getType().castTo<SILFunctionType>();

// Collect thunk arguments, converting ownership.
SmallVector<SILValue, 8> arguments;
for (auto *indRes : indirectResults)
arguments.push_back(indRes);
forwardFunctionArguments(thunkSGF, loc, assocFnRefType, params, arguments);
forwardFunctionArguments(thunkSGF, loc, derivativeFnRefType, params,
arguments);
// Apply function argument.
auto apply = thunkSGF.emitApplyWithRethrow(
loc, assocFnRef, /*substFnType*/ assocFnRef->getType(),
loc, derivativeFnRef, /*substFnType*/ derivativeFnRef->getType(),
thunk->getForwardingSubstitutionMap(), arguments);

// Create return instruction in the thunk, first deallocating local
Expand All @@ -3734,7 +3738,9 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
// If self ordering is not necessary and linear map types are unchanged,
// return the `apply` instruction.
auto linearMapFnType = cast<SILFunctionType>(
thunk->mapTypeIntoContext(assocFnRefType->getResults().back().getType())
thunk
->mapTypeIntoContext(
derivativeFnRefType->getResults().back().getType())
->getCanonicalType());
auto targetLinearMapFnType = thunk->mapTypeIntoContext(
origAssocFnType->getResults().back().getSILStorageType())
Expand All @@ -3749,7 +3755,7 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
extractAllElements(apply, loc, thunkSGF.B, directResults);
auto linearMap = thunkSGF.emitManagedRValueWithCleanup(directResults.back());
assert(linearMap.getType().castTo<SILFunctionType>() == linearMapFnType);
auto linearMapKind = assocFnKind.getLinearMapKind();
auto linearMapKind = derivativeFnKind.getLinearMapKind();
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
linearMap, linearMapKind, linearMapFnType, targetLinearMapFnType,
reorderSelf);
Expand Down
Loading