Skip to content

Commit 2161111

Browse files
authored
Explicitly use minimal type expansion for autodiff-related types (e.g. parameters and pullback result types) (#77831)
As autodiff happens on function types it is not in general possible to determine the real expansion context of the function being differentiated. Use of minimal context is a conservative approach that should work even when libraty evolution mode is enabled. Fixes #55179
1 parent e131682 commit 2161111

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,17 @@ class PullbackCloner::Implementation final
215215
//--------------------------------------------------------------------------//
216216

217217
/// Get the type lowering for the given AST type.
218+
///
219+
/// Explicitly use minimal type expansion context: in general, differentiation
220+
/// happens on function types, so it cannot know if the original function is
221+
/// resilient or not.
218222
const Lowering::TypeLowering &getTypeLowering(Type type) {
219223
auto pbGenSig =
220224
getPullback().getLoweredFunctionType()->getSubstGenericSignature();
221225
Lowering::AbstractionPattern pattern(pbGenSig,
222226
type->getReducedType(pbGenSig));
223-
return getPullback().getTypeLowering(pattern, type);
227+
return getContext().getTypeConverter().getTypeLowering(
228+
pattern, type, TypeExpansionContext::minimal());
224229
}
225230

226231
/// Remap any archetypes into the current function's context.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: %target-swift-frontend -c -enable-library-evolution %s
2+
3+
// https://github.com/swiftlang/swift/issues/55179
4+
// Explicitly use minimal type expansion on autodiff-related types.
5+
// Autodiff happens on function types, so in general it does not know
6+
// if the function in question is resilient or not. Using minimal expansion
7+
// provides an universally conservative approach.
8+
9+
import _Differentiation
10+
11+
public class Tracked<T> {}
12+
extension Tracked: Differentiable where T: Differentiable {}
13+
14+
@differentiable(reverse)
15+
func callback(_ x: inout Tracked<Float>.TangentVector) {}
16+
17+
extension Differentiable {
18+
/// Applies the given closure to the derivative of `self`.
19+
///
20+
/// Returns `self` like an identity function. When the return value is used in
21+
/// a context where it is differentiated with respect to, applies the given
22+
/// closure to the derivative of the return value.
23+
@inlinable
24+
@differentiable(reverse, wrt: self)
25+
func withDerivative(_ body: @escaping (inout TangentVector) -> Void) -> Self {
26+
return self
27+
}
28+
29+
@inlinable
30+
@derivative(of: withDerivative)
31+
internal func _vjpWithDerivative(
32+
_ body: @escaping (inout TangentVector) -> Void
33+
) -> (value: Self, pullback: (TangentVector) -> TangentVector) {
34+
return (self, { grad in
35+
var grad = grad
36+
body(&grad)
37+
return grad
38+
})
39+
}
40+
}
41+
42+
@differentiable(reverse)
43+
public func caller(_ x: Tracked<Float>) -> Tracked<Float> {
44+
return x.withDerivative(callback)
45+
}

0 commit comments

Comments
 (0)