Skip to content

Commit e5915f7

Browse files
author
marcrasi
authored
[AutoDiff] remove some jvp:, vjp: from stdlib (#28762)
Removes some `jvp:` and `vjp:` from the stdlib. This is partial progress towards unblocking TF-1001. Some bugfixes were necessary: * `DenseMapInfo<AutoDiffConfig>` wasn't canonicalizing the `GenericSignature`, causing equivalent configs to appear distinct. * `@derivative(of:)` typechecking was not taking the generic signature into account when comparing actual/expected differential/pullback types, causing incorrect "incorrect pullback type" diagnositcs. * In deserialization, `MF.getIdentifier` doesn't work on special identifiers like `init`, so `@derivative(of: init)` didn't work.
1 parent 37b507b commit e5915f7

File tree

5 files changed

+55
-23
lines changed

5 files changed

+55
-23
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -492,29 +492,39 @@ template<> struct DenseMapInfo<AutoDiffConfig> {
492492
static AutoDiffConfig getEmptyKey() {
493493
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
494494
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
495-
DenseMapInfo<GenericSignature>::getEmptyKey()};
495+
nullptr};
496496
}
497497

498498
static AutoDiffConfig getTombstoneKey() {
499499
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
500500
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
501-
DenseMapInfo<GenericSignature>::getTombstoneKey()};
501+
nullptr};
502502
}
503503

504504
static unsigned getHashValue(const AutoDiffConfig &Val) {
505+
auto canGenSig =
506+
Val.derivativeGenericSignature
507+
? Val.derivativeGenericSignature->getCanonicalSignature()
508+
: nullptr;
505509
unsigned combinedHash = hash_combine(
506510
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
507511
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
508-
DenseMapInfo<GenericSignature>::getHashValue(
509-
Val.derivativeGenericSignature));
512+
DenseMapInfo<GenericSignature>::getHashValue(canGenSig));
510513
return combinedHash;
511514
}
512515

513516
static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
517+
auto lhsCanGenSig =
518+
LHS.derivativeGenericSignature
519+
? LHS.derivativeGenericSignature->getCanonicalSignature()
520+
: nullptr;
521+
auto rhsCanGenSig =
522+
RHS.derivativeGenericSignature
523+
? RHS.derivativeGenericSignature->getCanonicalSignature()
524+
: nullptr;
514525
return LHS.parameterIndices == RHS.parameterIndices &&
515-
LHS.resultIndices == RHS.resultIndices &&
516-
DenseMapInfo<GenericSignature>::isEqual(LHS.derivativeGenericSignature,
517-
RHS.derivativeGenericSignature);
526+
LHS.resultIndices == RHS.resultIndices &&
527+
DenseMapInfo<GenericSignature>::isEqual(lhsCanGenSig, rhsCanGenSig);
518528
}
519529
};
520530

lib/Sema/TypeCheckAttr.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3816,8 +3816,20 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
38163816
auto vectorTy = valueResultConf.getTypeWitnessByName(
38173817
valueResultType, Ctx.Id_TangentVector);
38183818

3819+
// Compute the actual differential/pullback type that we use for comparison
3820+
// with the expected type. We must canonicalize the derivative interface type
3821+
// before extracting the differential/pullback type from it, so that the
3822+
// derivative interface type generic signature is available for simplifying
3823+
// types.
3824+
CanType canActualResultType = derivativeInterfaceType->getCanonicalType();
3825+
while (isa<AnyFunctionType>(canActualResultType)) {
3826+
canActualResultType =
3827+
cast<AnyFunctionType>(canActualResultType).getResult();
3828+
}
3829+
CanType actualFuncEltType =
3830+
cast<TupleType>(canActualResultType).getElementType(1);
3831+
38193832
// Compute expected differential/pullback type.
3820-
auto funcEltType = funcResultElt.getType();
38213833
Type expectedFuncEltType;
38223834
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
38233835
auto diffParams = map<SmallVector<AnyFunctionType::Param, 4>>(
@@ -3832,7 +3844,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
38323844
expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext();
38333845

38343846
// Check if differential/pullback type matches expected type.
3835-
if (!funcEltType->isEqual(expectedFuncEltType)) {
3847+
if (!actualFuncEltType->isEqual(expectedFuncEltType)) {
38363848
// Emit differential/pullback type mismatch error on attribute.
38373849
diagnose(attr->getLocation(),
38383850
diag::derivative_attr_result_func_type_mismatch,

lib/Serialization/Deserialization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4213,7 +4213,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
42134213
scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind,
42144214
parameters);
42154215

4216-
DeclNameWithLoc origName{MF.getIdentifier(origNameId), DeclNameLoc()};
4216+
DeclNameWithLoc origName{MF.getDeclBaseName(origNameId), DeclNameLoc()};
42174217
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
42184218
auto derivativeKind =
42194219
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ public struct Tracked<T> {
6565
}
6666
private var handle: Box
6767

68-
@differentiable(jvp: _jvpInit, vjp: _vjpInit where T : Differentiable, T == T.TangentVector)
68+
@differentiable(where T : Differentiable, T == T.TangentVector)
6969
public init(_ value: T) {
7070
self.handle = Box(value)
7171
}
7272

73-
@differentiable(jvp: _jvpValue, vjp: _vjpValue where T : Differentiable, T == T.TangentVector)
73+
@differentiable(where T : Differentiable, T == T.TangentVector)
7474
public var value: T {
7575
get { handle.value }
7676
set { handle.value = newValue }
@@ -172,24 +172,28 @@ extension Tracked : Differentiable where T : Differentiable, T == T.TangentVecto
172172

173173
extension Tracked where T : Differentiable, T == T.TangentVector {
174174
@usableFromInline
175+
@derivative(of: init)
175176
internal static func _vjpInit(_ value: T)
176177
-> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector)) {
177178
return (Tracked(value), { v in v.value })
178179
}
179180

180181
@usableFromInline
182+
@derivative(of: init)
181183
internal static func _jvpInit(_ value: T)
182184
-> (value: Self, differential: (T.TangentVector) -> (Self.TangentVector)) {
183185
return (Tracked(value), { v in Tracked(v) })
184186
}
185187

186188
@usableFromInline
187-
internal func _vjpValue() -> (T, (T.TangentVector) -> Self.TangentVector) {
189+
@derivative(of: value)
190+
internal func _vjpValue() -> (value: T, pullback: (T.TangentVector) -> Self.TangentVector) {
188191
return (value, { v in Tracked(v) })
189192
}
190193

191194
@usableFromInline
192-
internal func _jvpValue() -> (T, (Self.TangentVector) -> T.TangentVector) {
195+
@derivative(of: value)
196+
internal func _jvpValue() -> (value: T, differential: (Self.TangentVector) -> T.TangentVector) {
193197
return (value, { v in v.value })
194198
}
195199
}

stdlib/public/Differentiation/DifferentiationSupport.swift

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,16 @@ public func differentiableFunction<T, U, R>(
245245
}
246246

247247
public extension Differentiable {
248-
@differentiable(wrt: self, vjp: _vjpWithDerivative)
248+
@differentiable(wrt: self)
249249
func withDerivative(_ body: @escaping (inout TangentVector) -> Void) -> Self {
250250
return self
251251
}
252252

253253
@inlinable
254+
@derivative(of: withDerivative)
254255
internal func _vjpWithDerivative(
255256
_ body: @escaping (inout TangentVector) -> Void
256-
) -> (Self, (TangentVector) -> TangentVector) {
257+
) -> (value: Self, pullback: (TangentVector) -> TangentVector) {
257258
return (self, { grad in
258259
var grad = grad
259260
body(&grad)
@@ -275,17 +276,18 @@ public func withRecomputationInPullbacks<T, U>(
275276

276277
public extension Differentiable {
277278
@inlinable
278-
@differentiable(wrt: self, vjp: _vjp_withRecomputationInPullbacks)
279+
@differentiable(wrt: self)
279280
func withRecomputationInPullbacks<Result : Differentiable>(
280281
_ body: @escaping @differentiable (Self) -> Result
281282
) -> Result {
282283
return body(self)
283284
}
284285

285286
@inlinable
287+
@derivative(of: withRecomputationInPullbacks)
286288
internal func _vjp_withRecomputationInPullbacks<Result : Differentiable>(
287289
_ body: @escaping @differentiable (Self) -> Result
288-
) -> (Result, (Result.TangentVector) -> TangentVector) {
290+
) -> (value: Result, pullback: (Result.TangentVector) -> TangentVector) {
289291
return Swift.valueWithPullback(
290292
at: self, in: Swift.withRecomputationInPullbacks(body)
291293
)
@@ -908,24 +910,26 @@ public struct AnyDerivative: EuclideanDifferentiable & AdditiveArithmetic {
908910

909911
/// Creates a type-erased derivative from the given derivative.
910912
@inlinable
911-
@differentiable(jvp: _jvpInit(_:), vjp: _vjpInit(_:))
913+
@differentiable
912914
public init<T>(_ base: T) where T: Differentiable, T.TangentVector == T {
913915
self._box = _ConcreteDerivativeBox<T>(base)
914916
}
915917

916918
@inlinable
919+
@derivative(of: init)
917920
internal static func _vjpInit<T>(
918921
_ base: T
919-
) -> (AnyDerivative, (AnyDerivative) -> T.TangentVector)
922+
) -> (value: AnyDerivative, pullback: (AnyDerivative) -> T.TangentVector)
920923
where T: Differentiable, T.TangentVector == T
921924
{
922925
return (AnyDerivative(base), { v in v.base as! T.TangentVector })
923926
}
924927

925928
@inlinable
929+
@derivative(of: init)
926930
internal static func _jvpInit<T>(
927931
_ base: T
928-
) -> (AnyDerivative, (T.TangentVector) -> AnyDerivative)
932+
) -> (value: AnyDerivative, differential: (T.TangentVector) -> AnyDerivative)
929933
where T: Differentiable, T.TangentVector == T
930934
{
931935
return (AnyDerivative(base), { dbase in AnyDerivative(dbase) })
@@ -1028,7 +1032,7 @@ public struct AnyDerivative: EuclideanDifferentiable & AdditiveArithmetic {
10281032
//===----------------------------------------------------------------------===//
10291033

10301034
public extension Array where Element: Differentiable {
1031-
@differentiable(wrt: (self, initialResult), vjp: _vjpDifferentiableReduce)
1035+
@differentiable(wrt: (self, initialResult))
10321036
func differentiableReduce<Result: Differentiable>(
10331037
_ initialResult: Result,
10341038
_ nextPartialResult: @differentiable (Result, Element) -> Result
@@ -1037,6 +1041,7 @@ public extension Array where Element: Differentiable {
10371041
}
10381042

10391043
@usableFromInline
1044+
@derivative(of: differentiableReduce)
10401045
internal func _vjpDifferentiableReduce<Result: Differentiable>(
10411046
_ initialResult: Result,
10421047
_ nextPartialResult: @differentiable (Result, Element) -> Result
@@ -1070,14 +1075,15 @@ public extension Array where Element: Differentiable {
10701075
}
10711076

10721077
public extension Array where Element: Differentiable {
1073-
@differentiable(wrt: self, vjp: _vjpDifferentiableMap)
1078+
@differentiable(wrt: self)
10741079
func differentiableMap<Result: Differentiable>(
10751080
_ body: @differentiable (Element) -> Result
10761081
) -> [Result] {
10771082
map(body)
10781083
}
10791084

10801085
@usableFromInline
1086+
@derivative(of: differentiableMap)
10811087
internal func _vjpDifferentiableMap<Result: Differentiable>(
10821088
_ body: @differentiable (Element) -> Result
10831089
) -> (value: [Result],

0 commit comments

Comments
 (0)