Skip to content

[AutoDiff] Fix derivative generic signature same-type requirements. #32803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,43 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
emitDifferentiabilityWitnessesForFunction(constant, F);
}

/// Returns the SIL differentiability witness generic signature given the
/// original declaration's generic signature and the derivative generic
/// signature.
///
/// In general, the differentiability witness generic signature is equal to the
/// derivative generic signature.
///
/// Edge case, if two conditions are satisfied:
/// 1. The derivative generic signature is equal to the original generic
/// signature.
/// 2. The derivative generic signature has *all concrete* generic parameters
/// (i.e. all generic parameters are bound to concrete types via same-type
/// requirements).
///
/// Then the differentiability witness generic signature is `nullptr`.
///
/// Both the original and derivative declarations are lowered to SIL functions
/// with a fully concrete type and no generic signature, so the
/// differentiability witness should similarly have no generic signature.
static GenericSignature
getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,
GenericSignature derivativeGenSig) {
// If there is no derivative generic signature, return the original generic
// signature.
if (!derivativeGenSig)
return origGenSig;
// If derivative generic signature has all concrete generic parameters and is
// equal to the original generic signature, return `nullptr`.
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
auto origCanGenSig = origGenSig.getCanonicalSignature();
if (origCanGenSig == derivativeCanGenSig &&
derivativeCanGenSig->areAllParamsConcrete())
return GenericSignature();
// Otherwise, return the derivative generic signature.
return derivativeGenSig;
}

void SILGenModule::emitDifferentiabilityWitnessesForFunction(
SILDeclRef constant, SILFunction *F) {
// Visit `@derivative` attributes and generate SIL differentiability
Expand All @@ -955,8 +992,11 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
diffAttr->getDerivativeGenericSignature()) &&
"Type-checking should resolve derivative generic signatures for "
"all original SIL functions with generic signatures");
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
AFD->getGenericSignature(),
diffAttr->getDerivativeGenericSignature());
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
diffAttr->getDerivativeGenericSignature());
witnessGenSig);
emitDifferentiabilityWitness(AFD, F, config, /*jvp*/ nullptr,
/*vjp*/ nullptr, diffAttr);
}
Expand All @@ -975,10 +1015,11 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
auto origDeclRef =
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
auto *origFn = getFunction(origDeclRef, NotForDefinition);
auto derivativeGenSig = AFD->getGenericSignature();
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
origAFD->getGenericSignature(), AFD->getGenericSignature());
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
derivativeGenSig);
witnessGenSig);
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
derivAttr);
}
Expand Down
18 changes: 0 additions & 18 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4188,24 +4188,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
attr->getLocation(), /*allowConcreteGenericParams=*/true);
}

// Set the resolved derivative generic signature in the attribute.
// Do not set the derivative generic signature if the original function's
// generic signature is equal to `derivativeGenSig` and all generic parameters
// are concrete. In that case, the original function and derivative functions
// are all lowered as SIL functions with no generic signature (specialized
// with concrete types from same-type requirements), so the derivative generic
// signature should not be set.
auto skipDerivativeGenericSignature = [&] {
auto origCanGenSig =
original->getGenericSignature().getCanonicalSignature();
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
if (!derivativeCanGenSig)
return false;
return origCanGenSig == derivativeCanGenSig &&
derivativeCanGenSig->areAllParamsConcrete();
};
if (skipDerivativeGenericSignature())
derivativeGenSig = GenericSignature();
attr->setDerivativeGenericSignature(derivativeGenSig);
return false;
}
Expand Down
177 changes: 177 additions & 0 deletions test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// RUN: %target-swift-emit-silgen -verify -module-name main %s | %FileCheck %s
// RUN: %target-swift-emit-sil -verify -module-name main %s

// NOTE(SR-11950): SILParser crashes for SILGen round-trip.

// This file tests:
// - The "derivative generic signature" of `@differentiable` and `@derivative`
// attributes.
// - The generic signature of lowered SIL differentiability witnesses.

// Context:
// - For `@differentiable` attributes: the derivative generic signature is
// resolved from the original declaration's generic signature and additional
// `where` clause requirements.
// - For `@derivative` attributes: the derivative generic signature is the
// attributed declaration's generic signature.

import _Differentiation

//===----------------------------------------------------------------------===//
// Same-type requirements
//===----------------------------------------------------------------------===//

// Test original declaration with a generic signature and derivative generic
// signature where all generic parameters are concrete (i.e. bound to concrete
// types via same-type requirements).

struct AllConcrete<T>: Differentiable {}

extension AllConcrete {
// Original generic signature: `<T>`
// Derivative generic signature: `<T where T == Float>`
// Witness generic signature: `<T where T == Float>`
@_silgen_name("allconcrete_where_gensig_constrained")
@differentiable(where T == Float)
func whereClauseGenericSignatureConstrained() -> AllConcrete {
return self
}
}
extension AllConcrete where T == Float {
@derivative(of: whereClauseGenericSignatureConstrained)
func jvpWhereClauseGenericSignatureConstrained() -> (
value: AllConcrete, differential: (TangentVector) -> TangentVector
) {
(whereClauseGenericSignatureConstrained(), { $0 })
}
}

// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T where T == Float> @allconcrete_where_gensig_constrained : $@convention(method) <T> (AllConcrete<T>) -> AllConcrete<T> {
// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszl : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }

// If a `@differentiable` or `@derivative` attribute satisfies two conditions:
// 1. The derivative generic signature is equal to the original generic signature.
// 2. The derivative generic signature has *all concrete* generic parameters.
//
// Then the attribute should be lowered to a SIL differentiability witness with
// *no* derivative generic signature.

extension AllConcrete where T == Float {
// Original generic signature: `<T where T == Float>`
// Derivative generic signature: `<T where T == Float>`
// Witness generic signature: none
@_silgen_name("allconcrete_original_gensig")
@differentiable
func originalGenericSignature() -> AllConcrete {
return self
}

@derivative(of: originalGenericSignature)
func jvpOriginalGenericSignature() -> (
value: AllConcrete, differential: (TangentVector) -> TangentVector
) {
(originalGenericSignature(), { $0 })
}

// CHECK-LABEL: // differentiability witness for allconcrete_original_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
// CHECK-NEXT: jvp: @AD__allconcrete_original_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }

// Original generic signature: `<T where T == Float>`
// Derivative generic signature: `<T where T == Float>` (explicit `where` clause)
// Witness generic signature: none
@_silgen_name("allconcrete_where_gensig")
@differentiable(where T == Float)
func whereClauseGenericSignature() -> AllConcrete {
return self
}

@derivative(of: whereClauseGenericSignature)
func jvpWhereClauseGenericSignature() -> (
value: AllConcrete, differential: (TangentVector) -> TangentVector
) {
(whereClauseGenericSignature(), { $0 })
}

// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }
}

// Test original declaration with a generic signature and derivative generic
// signature where *not* all generic parameters are concrete.
// types via same-type requirements).

struct NotAllConcrete<T, U>: Differentiable {}

extension NotAllConcrete {
// Original generic signature: `<T, U>`
// Derivative generic signature: `<T, U where T == Float>`
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
@_silgen_name("notallconcrete_where_gensig_constrained")
@differentiable(where T == Float)
func whereClauseGenericSignatureConstrained() -> NotAllConcrete {
return self
}
}
extension NotAllConcrete where T == Float {
@derivative(of: whereClauseGenericSignatureConstrained)
func jvpWhereClauseGenericSignatureConstrained() -> (
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
) {
(whereClauseGenericSignatureConstrained(), { $0 })
}
}

// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig_constrained : $@convention(method) <T, U> (NotAllConcrete<T, U>) -> NotAllConcrete<T, U> {
// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }

extension NotAllConcrete where T == Float {
// Original generic signature: `<T, U where T == Float>`
// Derivative generic signature: `<T, U where T == Float>`
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
@_silgen_name("notallconcrete_original_gensig")
@differentiable
func originalGenericSignature() -> NotAllConcrete {
return self
}

@derivative(of: originalGenericSignature)
func jvpOriginalGenericSignature() -> (
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
) {
(originalGenericSignature(), { $0 })
}

// CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_original_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
// CHECK-NEXT: jvp: @AD__notallconcrete_original_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }

// Original generic signature: `<T, U where T == Float>`
// Derivative generic signature: `<T, U where T == Float>` (explicit `where` clause)
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
@_silgen_name("notallconcrete_where_gensig")
@differentiable(where T == Float)
func whereClauseGenericSignature() -> NotAllConcrete {
return self
}

@derivative(of: whereClauseGenericSignature)
func jvpWhereClauseGenericSignature() -> (
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
) {
(whereClauseGenericSignature(), { $0 })
}

// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }
}