Skip to content

[AutoDiff] remove all-concrete gen sig from more places #32916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,29 @@ bool getBuiltinDifferentiableOrLinearFunctionConfig(
bool getBuiltinDifferentiableOrLinearFunctionConfig(
StringRef operationName, unsigned &arity, bool &throws);

/// Returns the SIL differentiability witness generic signature given the
/// original declaration's generic signature and the derivative generic
/// signature.
///
/// In general, the differentiability witness generic signature is equal to the
/// derivative generic signature.
///
/// Edge case, if two conditions are satisfied:
/// 1. The derivative generic signature is equal to the original generic
/// signature.
/// 2. The derivative generic signature has *all concrete* generic parameters
/// (i.e. all generic parameters are bound to concrete types via same-type
/// requirements).
///
/// Then the differentiability witness generic signature is `nullptr`.
///
/// Both the original and derivative declarations are lowered to SIL functions
/// with a fully concrete type and no generic signature, so the
/// differentiability witness should similarly have no generic signature.
GenericSignature
getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,
GenericSignature derivativeGenSig);

} // end namespace autodiff

} // end namespace swift
Expand Down
17 changes: 17 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,23 @@ bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
return operationName.empty();
}

GenericSignature autodiff::getDifferentiabilityWitnessGenericSignature(
GenericSignature origGenSig, GenericSignature derivativeGenSig) {
// If there is no derivative generic signature, return the original generic
// signature.
if (!derivativeGenSig)
return origGenSig;
// If derivative generic signature has all concrete generic parameters and is
// equal to the original generic signature, return `nullptr`.
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
auto origCanGenSig = origGenSig.getCanonicalSignature();
if (origCanGenSig == derivativeCanGenSig &&
derivativeCanGenSig->areAllParamsConcrete())
return GenericSignature();
// Otherwise, return the derivative generic signature.
return derivativeGenSig;
}

Type TangentSpace::getType() const {
switch (kind) {
case Kind::TangentVector:
Expand Down
49 changes: 7 additions & 42 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,43 +935,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
emitDifferentiabilityWitnessesForFunction(constant, F);
}

/// Returns the SIL differentiability witness generic signature given the
/// original declaration's generic signature and the derivative generic
/// signature.
///
/// In general, the differentiability witness generic signature is equal to the
/// derivative generic signature.
///
/// Edge case, if two conditions are satisfied:
/// 1. The derivative generic signature is equal to the original generic
/// signature.
/// 2. The derivative generic signature has *all concrete* generic parameters
/// (i.e. all generic parameters are bound to concrete types via same-type
/// requirements).
///
/// Then the differentiability witness generic signature is `nullptr`.
///
/// Both the original and derivative declarations are lowered to SIL functions
/// with a fully concrete type and no generic signature, so the
/// differentiability witness should similarly have no generic signature.
static GenericSignature
getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,
GenericSignature derivativeGenSig) {
// If there is no derivative generic signature, return the original generic
// signature.
if (!derivativeGenSig)
return origGenSig;
// If derivative generic signature has all concrete generic parameters and is
// equal to the original generic signature, return `nullptr`.
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
auto origCanGenSig = origGenSig.getCanonicalSignature();
if (origCanGenSig == derivativeCanGenSig &&
derivativeCanGenSig->areAllParamsConcrete())
return GenericSignature();
// Otherwise, return the derivative generic signature.
return derivativeGenSig;
}

void SILGenModule::emitDifferentiabilityWitnessesForFunction(
SILDeclRef constant, SILFunction *F) {
// Visit `@derivative` attributes and generate SIL differentiability
Expand All @@ -992,9 +955,10 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
diffAttr->getDerivativeGenericSignature()) &&
"Type-checking should resolve derivative generic signatures for "
"all original SIL functions with generic signatures");
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
AFD->getGenericSignature(),
diffAttr->getDerivativeGenericSignature());
auto witnessGenSig =
autodiff::getDifferentiabilityWitnessGenericSignature(
AFD->getGenericSignature(),
diffAttr->getDerivativeGenericSignature());
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
witnessGenSig);
emitDifferentiabilityWitness(AFD, F, config, /*jvp*/ nullptr,
Expand All @@ -1015,8 +979,9 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
auto origDeclRef =
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
auto *origFn = getFunction(origDeclRef, NotForDefinition);
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
origAFD->getGenericSignature(), AFD->getGenericSignature());
auto witnessGenSig =
autodiff::getDifferentiabilityWitnessGenericSignature(
origAFD->getGenericSignature(), AFD->getGenericSignature());
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
witnessGenSig);
Expand Down
7 changes: 5 additions & 2 deletions lib/SILOptimizer/Differentiation/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,11 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
silParameterIndices->getNumIndices() <
minimalConfig->parameterIndices->getNumIndices())) {
minimalASTParameterIndices = config.parameterIndices;
minimalConfig = AutoDiffConfig(silParameterIndices, config.resultIndices,
config.derivativeGenericSignature);
minimalConfig =
AutoDiffConfig(silParameterIndices, config.resultIndices,
autodiff::getDifferentiabilityWitnessGenericSignature(
original->getGenericSignature(),
config.derivativeGenericSignature));
}
}
return minimalConfig;
Expand Down
16 changes: 11 additions & 5 deletions lib/TBDGen/TBDGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,10 @@ void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
config.parameterIndices,
original->getInterfaceType()->castTo<AnyFunctionType>());
Mangle::ASTMangler mangler;
AutoDiffConfig silConfig{loweredParamIndices, config.resultIndices,
config.derivativeGenericSignature};
AutoDiffConfig silConfig{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you create a new SILDeclRef-like abstraction that can allow this logic to be shared between TBDGen and SIL?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SILDeclRef already stores an (optional) AutoDiffDerivativeFunctionIdentifier, which has almost the same contents as AutoDiffConfig.

We could look into using that with TBDGenVisitor::addSymbol(SILDeclRef) instead of calling TBDGenVisitor::addSymbol(StringRef name) here with manually mangled names.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's what I meant. TBDGen should not duplicate any logic from SILGen/IRGen, except for the top-level visiting of decls. Mangling and visibility should be computed by shared code, or else you'll hit frequent issues where the two are out of sync.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thanks! I filed SR-13269 to track this issue so we can look into it later, if that's okay.

loweredParamIndices, config.resultIndices,
autodiff::getDifferentiabilityWitnessGenericSignature(
original->getGenericSignature(), config.derivativeGenericSignature)};
std::string linearMapName =
mangler.mangleAutoDiffLinearMapHelper(declRef.mangle(), kind, silConfig);
addSymbol(linearMapName);
Expand All @@ -542,7 +544,9 @@ void TBDGenVisitor::addAutoDiffDerivativeFunction(
GenericSignature derivativeGenericSignature,
AutoDiffDerivativeFunctionKind kind) {
auto *assocFnId = AutoDiffDerivativeFunctionIdentifier::get(
kind, parameterIndices, derivativeGenericSignature,
kind, parameterIndices,
autodiff::getDifferentiabilityWitnessGenericSignature(
original->getGenericSignature(), derivativeGenericSignature),
original->getASTContext());
auto declRef =
SILDeclRef(original).asForeign(requiresForeignEntryPoint(original));
Expand All @@ -569,8 +573,10 @@ void TBDGenVisitor::addDifferentiabilityWitness(
original->getInterfaceType()->castTo<AnyFunctionType>());

auto originalMangledName = declRef.mangle();
AutoDiffConfig config{silParamIndices, resultIndices,
derivativeGenericSignature};
AutoDiffConfig config{
silParamIndices, resultIndices,
autodiff::getDifferentiabilityWitnessGenericSignature(
original->getGenericSignature(), derivativeGenericSignature)};
SILDifferentiabilityWitnessKey key(originalMangledName, config);

Mangle::ASTMangler mangler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,14 @@ class Class: Differentiable {
set {}
}
}

struct S: Differentiable {
var value: Float
}

extension Array where Element == S {
@differentiable
func sum() -> Float {
return 0
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,10 @@ func classRequirementSetters(_ x: inout Class, _ newValue: Float) {
x.property = newValue
x[] = newValue
}

// Test cross-file lookup of a derivative function with all-concrete derivative generic signature.
@differentiable
func allConcreteDerivativeGenericSignature(_ a: [S]) -> Float {
// No error expected.
return a.sum()
}
9 changes: 8 additions & 1 deletion test/AutoDiff/TBD/derivative_symbols.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public func topLevelDerivative<T: Differentiable>(_ x: T) -> (
fatalError()
}

struct Struct: Differentiable {
public struct Struct: Differentiable {
var stored: Float

// Test property.
Expand Down Expand Up @@ -54,3 +54,10 @@ struct Struct: Differentiable {
fatalError()
}
}

extension Array where Element == Struct {
@differentiable
public func sum() -> Float {
return 0
}
}