Skip to content

Commit 4e1bef0

Browse files
authored
Merge branch 'tensorflow' into tensorflow-merge
2 parents 706c8ba + d8df472 commit 4e1bef0

File tree

4 files changed

+57
-1
lines changed

4 files changed

+57
-1
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,12 +529,18 @@ template <typename T> struct DenseMapInfo;
529529
template <> struct DenseMapInfo<AutoDiffConfig> {
530530
static AutoDiffConfig getEmptyKey() {
531531
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
532+
// The `derivativeGenericSignature` component must be `nullptr` so that
533+
// `getHashValue` and `isEqual` do not try to `getCanonicalSignature()` on
534+
// an invalid pointer.
532535
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
533536
nullptr};
534537
}
535538

536539
static AutoDiffConfig getTombstoneKey() {
537540
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
541+
// The `derivativeGenericSignature` component must be `nullptr` so that
542+
// `getHashValue` and `isEqual` do not try to `getCanonicalSignature()` on
543+
// an invalid pointer.
538544
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
539545
nullptr};
540546
}

test/AutoDiff/derivative_attr_type_checking.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,3 +532,33 @@ struct StoredProperty: Differentiable {
532532
(stored, { _ in .zero })
533533
}
534534
}
535+
536+
// When the generic signature was not considered while calculating the actual pullback type, the
537+
// typechecker did not realize that `T.TangentVector == Float`, and therefore it complained that
538+
// "'pullback' does not have expected type '(Float) -> (Float)'".
539+
// Users were able to work around this by setting the pullback type to `(Float) -> Float`.
540+
func genericSignatureConsidered<T>(_ x: T) -> T { fatalError() }
541+
@derivative(of: genericSignatureConsidered)
542+
func dGenericSignatureConsidered<T>(_ x: T)
543+
-> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
544+
where T: Differentiable, T.TangentVector == Float
545+
{
546+
fatalError()
547+
}
548+
549+
// When the generic signature was not considered while calculating the actual pullback type,
550+
// the typechecker complained that the pullback type was not correct, and there was no pullback type
551+
// that users could specify to satisfy the typechecker.
552+
struct Wrapper<T: AdditiveArithmetic & Equatable>: AdditiveArithmetic, Equatable {
553+
var t: T
554+
init(_ t: T) { self.t = t }
555+
}
556+
extension Wrapper: Differentiable where T: Differentiable, T == T.TangentVector {
557+
typealias TangentVector = Wrapper<T.TangentVector>
558+
}
559+
extension Wrapper where T: Differentiable, T == T.TangentVector {
560+
@derivative(of: init(_:))
561+
static func dInit(_ t: T) -> (value: Self, pullback: (Wrapper<T>.TangentVector) -> (T)) {
562+
fatalError()
563+
}
564+
}

test/AutoDiff/derivative_registration.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,24 @@ DerivativeRegistrationTests.testWithLeakChecking("DerivativeGenericSignature") {
156156
expectEqual(1000, dx)
157157
}
158158

159+
// When non-canonicalized generic signatures are used to compare derivative configurations, the
160+
// `@differentiable` and `@derivative` attributes create separate derivatives, and we get a
161+
// duplicate symbol error in TBDGen.
162+
public protocol RefinesDifferentiable: Differentiable {}
163+
extension Float: RefinesDifferentiable {}
164+
@differentiable(where T: Differentiable, T: RefinesDifferentiable)
165+
public func nonCanonicalizedGenSigComparison<T>(_ t: T) -> T { t }
166+
@derivative(of: nonCanonicalizedGenSigComparison)
167+
public func dNonCanonicalizedGenSigComparison<T: RefinesDifferentiable>(_ t: T)
168+
-> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
169+
{
170+
(t, { _ in T.TangentVector.zero })
171+
}
172+
DerivativeRegistrationTests.testWithLeakChecking("NonCanonicalizedGenericSignatureComparison") {
173+
let dx = gradient(at: Float(0), in: nonCanonicalizedGenSigComparison)
174+
// Expect that we use the custom registered derivative, not a generated derivative (which would
175+
// give a gradient of 1).
176+
expectEqual(0, dx)
177+
}
178+
159179
runAllTests()

utils/update_checkout/update-checkout-config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@
296296
"indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2019-12-16-a",
297297
"sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2019-12-16-a",
298298
"tensorflow": "7c7d924821a8b1b20433c2f3f484bbd409873a84",
299-
"tensorflow-swift-apis": "47de23315b574f936daefe1e781bbfd5dfe1a247",
299+
"tensorflow-swift-apis": "0aaf944b91703ea6ae9062b3a720d3e27b7eba36",
300300
"tensorflow-swift-quote": "34d4112c294e9eccfb575dd8f39e6e3b0b6bc008"
301301
}
302302
}

0 commit comments

Comments
 (0)