Skip to content

Explicitly use minimal type expansion for autodiff-related types (e.g. parameters and pullback result types) #77831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,17 @@ class PullbackCloner::Implementation final
//--------------------------------------------------------------------------//

/// Get the type lowering for the given AST type.
///
/// Explicitly use minimal type expansion context: in general, differentiation
/// happens on function types, so it cannot know if the original function is
/// resilient or not.
const Lowering::TypeLowering &getTypeLowering(Type type) {
auto pbGenSig =
getPullback().getLoweredFunctionType()->getSubstGenericSignature();
Lowering::AbstractionPattern pattern(pbGenSig,
type->getReducedType(pbGenSig));
return getPullback().getTypeLowering(pattern, type);
return getContext().getTypeConverter().getTypeLowering(
pattern, type, TypeExpansionContext::minimal());
}

/// Remap any archetypes into the current function's context.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: %target-swift-frontend -c -enable-library-evolution %s

// https://github.com/swiftlang/swift/issues/55179
// Explicitly use minimal type expansion on autodiff-related types.
// Autodiff happens on function types, so in general it does not know
// if the function in question is resilient or not. Using minimal expansion
// provides an universally conservative approach.

import _Differentiation

public class Tracked<T> {}
extension Tracked: Differentiable where T: Differentiable {}

@differentiable(reverse)
func callback(_ x: inout Tracked<Float>.TangentVector) {}

extension Differentiable {
/// Applies the given closure to the derivative of `self`.
///
/// Returns `self` like an identity function. When the return value is used in
/// a context where it is differentiated with respect to, applies the given
/// closure to the derivative of the return value.
@inlinable
@differentiable(reverse, wrt: self)
func withDerivative(_ body: @escaping (inout TangentVector) -> Void) -> Self {
return self
}

@inlinable
@derivative(of: withDerivative)
internal func _vjpWithDerivative(
_ body: @escaping (inout TangentVector) -> Void
) -> (value: Self, pullback: (TangentVector) -> TangentVector) {
return (self, { grad in
var grad = grad
body(&grad)
return grad
})
}
}

@differentiable(reverse)
public func caller(_ x: Tracked<Float>) -> Tracked<Float> {
return x.withDerivative(callback)
}