Skip to content

Commit 383482b

Browse files
authored
[AutoDiff] Fix derivative generic signature same-type requirements. (#32803)
Fix SILGen for `@derivative` attributes where the derivative generic signature is equal to the original generic signature and has all concrete generic parameters (i.e. all generic parameters are bound to concrete types via same-type requirements). SILGen should emit a differentiability witness with no generic signature. This is already done for `@differentiable` attributes. Resolves TF-1292.
1 parent d966dbb commit 383482b

File tree

3 files changed

+221
-21
lines changed

3 files changed

+221
-21
lines changed

lib/SILGen/SILGen.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,43 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
935935
emitDifferentiabilityWitnessesForFunction(constant, F);
936936
}
937937

938+
/// Returns the SIL differentiability witness generic signature given the
939+
/// original declaration's generic signature and the derivative generic
940+
/// signature.
941+
///
942+
/// In general, the differentiability witness generic signature is equal to the
943+
/// derivative generic signature.
944+
///
945+
/// Edge case, if two conditions are satisfied:
946+
/// 1. The derivative generic signature is equal to the original generic
947+
/// signature.
948+
/// 2. The derivative generic signature has *all concrete* generic parameters
949+
/// (i.e. all generic parameters are bound to concrete types via same-type
950+
/// requirements).
951+
///
952+
/// Then the differentiability witness generic signature is `nullptr`.
953+
///
954+
/// Both the original and derivative declarations are lowered to SIL functions
955+
/// with a fully concrete type and no generic signature, so the
956+
/// differentiability witness should similarly have no generic signature.
957+
static GenericSignature
958+
getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,
959+
GenericSignature derivativeGenSig) {
960+
// If there is no derivative generic signature, return the original generic
961+
// signature.
962+
if (!derivativeGenSig)
963+
return origGenSig;
964+
// If derivative generic signature has all concrete generic parameters and is
965+
// equal to the original generic signature, return `nullptr`.
966+
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
967+
auto origCanGenSig = origGenSig.getCanonicalSignature();
968+
if (origCanGenSig == derivativeCanGenSig &&
969+
derivativeCanGenSig->areAllParamsConcrete())
970+
return GenericSignature();
971+
// Otherwise, return the derivative generic signature.
972+
return derivativeGenSig;
973+
}
974+
938975
void SILGenModule::emitDifferentiabilityWitnessesForFunction(
939976
SILDeclRef constant, SILFunction *F) {
940977
// Visit `@derivative` attributes and generate SIL differentiability
@@ -955,8 +992,11 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
955992
diffAttr->getDerivativeGenericSignature()) &&
956993
"Type-checking should resolve derivative generic signatures for "
957994
"all original SIL functions with generic signatures");
995+
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
996+
AFD->getGenericSignature(),
997+
diffAttr->getDerivativeGenericSignature());
958998
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
959-
diffAttr->getDerivativeGenericSignature());
999+
witnessGenSig);
9601000
emitDifferentiabilityWitness(AFD, F, config, /*jvp*/ nullptr,
9611001
/*vjp*/ nullptr, diffAttr);
9621002
}
@@ -975,10 +1015,11 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
9751015
auto origDeclRef =
9761016
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
9771017
auto *origFn = getFunction(origDeclRef, NotForDefinition);
978-
auto derivativeGenSig = AFD->getGenericSignature();
1018+
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
1019+
origAFD->getGenericSignature(), AFD->getGenericSignature());
9791020
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
9801021
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
981-
derivativeGenSig);
1022+
witnessGenSig);
9821023
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
9831024
derivAttr);
9841025
}

lib/Sema/TypeCheckAttr.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4194,24 +4194,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
41944194
attr->getLocation(), /*allowConcreteGenericParams=*/true);
41954195
}
41964196

4197-
// Set the resolved derivative generic signature in the attribute.
4198-
// Do not set the derivative generic signature if the original function's
4199-
// generic signature is equal to `derivativeGenSig` and all generic parameters
4200-
// are concrete. In that case, the original function and derivative functions
4201-
// are all lowered as SIL functions with no generic signature (specialized
4202-
// with concrete types from same-type requirements), so the derivative generic
4203-
// signature should not be set.
4204-
auto skipDerivativeGenericSignature = [&] {
4205-
auto origCanGenSig =
4206-
original->getGenericSignature().getCanonicalSignature();
4207-
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
4208-
if (!derivativeCanGenSig)
4209-
return false;
4210-
return origCanGenSig == derivativeCanGenSig &&
4211-
derivativeCanGenSig->areAllParamsConcrete();
4212-
};
4213-
if (skipDerivativeGenericSignature())
4214-
derivativeGenSig = GenericSignature();
42154197
attr->setDerivativeGenericSignature(derivativeGenSig);
42164198
return false;
42174199
}
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// RUN: %target-swift-emit-silgen -verify -module-name main %s | %FileCheck %s
2+
// RUN: %target-swift-emit-sil -verify -module-name main %s
3+
4+
// NOTE(SR-11950): SILParser crashes for SILGen round-trip.
5+
6+
// This file tests:
7+
// - The "derivative generic signature" of `@differentiable` and `@derivative`
8+
// attributes.
9+
// - The generic signature of lowered SIL differentiability witnesses.
10+
11+
// Context:
12+
// - For `@differentiable` attributes: the derivative generic signature is
13+
// resolved from the original declaration's generic signature and additional
14+
// `where` clause requirements.
15+
// - For `@derivative` attributes: the derivative generic signature is the
16+
// attributed declaration's generic signature.
17+
18+
import _Differentiation
19+
20+
//===----------------------------------------------------------------------===//
21+
// Same-type requirements
22+
//===----------------------------------------------------------------------===//
23+
24+
// Test original declaration with a generic signature and derivative generic
25+
// signature where all generic parameters are concrete (i.e. bound to concrete
26+
// types via same-type requirements).
27+
28+
struct AllConcrete<T>: Differentiable {}
29+
30+
extension AllConcrete {
31+
// Original generic signature: `<T>`
32+
// Derivative generic signature: `<T where T == Float>`
33+
// Witness generic signature: `<T where T == Float>`
34+
@_silgen_name("allconcrete_where_gensig_constrained")
35+
@differentiable(where T == Float)
36+
func whereClauseGenericSignatureConstrained() -> AllConcrete {
37+
return self
38+
}
39+
}
40+
extension AllConcrete where T == Float {
41+
@derivative(of: whereClauseGenericSignatureConstrained)
42+
func jvpWhereClauseGenericSignatureConstrained() -> (
43+
value: AllConcrete, differential: (TangentVector) -> TangentVector
44+
) {
45+
(whereClauseGenericSignatureConstrained(), { $0 })
46+
}
47+
}
48+
49+
// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained
50+
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T where T == Float> @allconcrete_where_gensig_constrained : $@convention(method) <T> (AllConcrete<T>) -> AllConcrete<T> {
51+
// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszl : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
52+
// CHECK-NEXT: }
53+
54+
// If a `@differentiable` or `@derivative` attribute satisfies two conditions:
55+
// 1. The derivative generic signature is equal to the original generic signature.
56+
// 2. The derivative generic signature has *all concrete* generic parameters.
57+
//
58+
// Then the attribute should be lowered to a SIL differentiability witness with
59+
// *no* derivative generic signature.
60+
61+
extension AllConcrete where T == Float {
62+
// Original generic signature: `<T where T == Float>`
63+
// Derivative generic signature: `<T where T == Float>`
64+
// Witness generic signature: none
65+
@_silgen_name("allconcrete_original_gensig")
66+
@differentiable
67+
func originalGenericSignature() -> AllConcrete {
68+
return self
69+
}
70+
71+
@derivative(of: originalGenericSignature)
72+
func jvpOriginalGenericSignature() -> (
73+
value: AllConcrete, differential: (TangentVector) -> TangentVector
74+
) {
75+
(originalGenericSignature(), { $0 })
76+
}
77+
78+
// CHECK-LABEL: // differentiability witness for allconcrete_original_gensig
79+
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
80+
// CHECK-NEXT: jvp: @AD__allconcrete_original_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
81+
// CHECK-NEXT: }
82+
83+
// Original generic signature: `<T where T == Float>`
84+
// Derivative generic signature: `<T where T == Float>` (explicit `where` clause)
85+
// Witness generic signature: none
86+
@_silgen_name("allconcrete_where_gensig")
87+
@differentiable(where T == Float)
88+
func whereClauseGenericSignature() -> AllConcrete {
89+
return self
90+
}
91+
92+
@derivative(of: whereClauseGenericSignature)
93+
func jvpWhereClauseGenericSignature() -> (
94+
value: AllConcrete, differential: (TangentVector) -> TangentVector
95+
) {
96+
(whereClauseGenericSignature(), { $0 })
97+
}
98+
99+
// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig
100+
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
101+
// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
102+
// CHECK-NEXT: }
103+
}
104+
105+
// Test original declaration with a generic signature and derivative generic
106+
// signature where *not* all generic parameters are concrete.
107+
// types via same-type requirements).
108+
109+
struct NotAllConcrete<T, U>: Differentiable {}
110+
111+
extension NotAllConcrete {
112+
// Original generic signature: `<T, U>`
113+
// Derivative generic signature: `<T, U where T == Float>`
114+
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
115+
@_silgen_name("notallconcrete_where_gensig_constrained")
116+
@differentiable(where T == Float)
117+
func whereClauseGenericSignatureConstrained() -> NotAllConcrete {
118+
return self
119+
}
120+
}
121+
extension NotAllConcrete where T == Float {
122+
@derivative(of: whereClauseGenericSignatureConstrained)
123+
func jvpWhereClauseGenericSignatureConstrained() -> (
124+
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
125+
) {
126+
(whereClauseGenericSignatureConstrained(), { $0 })
127+
}
128+
}
129+
130+
// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained
131+
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig_constrained : $@convention(method) <T, U> (NotAllConcrete<T, U>) -> NotAllConcrete<T, U> {
132+
// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
133+
// CHECK-NEXT: }
134+
135+
extension NotAllConcrete where T == Float {
136+
// Original generic signature: `<T, U where T == Float>`
137+
// Derivative generic signature: `<T, U where T == Float>`
138+
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
139+
@_silgen_name("notallconcrete_original_gensig")
140+
@differentiable
141+
func originalGenericSignature() -> NotAllConcrete {
142+
return self
143+
}
144+
145+
@derivative(of: originalGenericSignature)
146+
func jvpOriginalGenericSignature() -> (
147+
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
148+
) {
149+
(originalGenericSignature(), { $0 })
150+
}
151+
152+
// CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig
153+
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_original_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
154+
// CHECK-NEXT: jvp: @AD__notallconcrete_original_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
155+
// CHECK-NEXT: }
156+
157+
// Original generic signature: `<T, U where T == Float>`
158+
// Derivative generic signature: `<T, U where T == Float>` (explicit `where` clause)
159+
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
160+
@_silgen_name("notallconcrete_where_gensig")
161+
@differentiable(where T == Float)
162+
func whereClauseGenericSignature() -> NotAllConcrete {
163+
return self
164+
}
165+
166+
@derivative(of: whereClauseGenericSignature)
167+
func jvpWhereClauseGenericSignature() -> (
168+
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
169+
) {
170+
(whereClauseGenericSignature(), { $0 })
171+
}
172+
173+
// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig
174+
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
175+
// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
176+
// CHECK-NEXT: }
177+
}

0 commit comments

Comments
 (0)