Skip to content

Commit e9f82d1

Browse files
committed
Conform 'AnyDifferentiable' and 'AnyDerivative' to 'CustomReflectable'.
'AnyDifferentiable' and 'AnyDerivative' should conform to 'CustomReflectable' to prevent leaking implementation details. The mirror should reflect its underlying value directly. Resolves rdar://75496334.
1 parent 5709721 commit e9f82d1

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

stdlib/public/Differentiation/AnyDifferentiable.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ public struct AnyDifferentiable: Differentiable {
102102
}
103103
}
104104

105+
extension AnyDifferentiable: CustomReflectable {
106+
public var customMirror: Mirror {
107+
Mirror(reflecting: base)
108+
}
109+
}
110+
105111
//===----------------------------------------------------------------------===//
106112
// `AnyDerivative`
107113
//===----------------------------------------------------------------------===//
@@ -365,6 +371,12 @@ public struct AnyDerivative: Differentiable & AdditiveArithmetic {
365371
}
366372
}
367373

374+
extension AnyDerivative: CustomReflectable {
375+
public var customMirror: Mirror {
376+
Mirror(reflecting: base)
377+
}
378+
}
379+
368380
//===----------------------------------------------------------------------===//
369381
// Helpers
370382
//===----------------------------------------------------------------------===//

test/AutoDiff/stdlib/anydifferentiable.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ TypeErasureTests.test("AnyDifferentiable casting") {
9292
expectEqual(nil, genericAny.base as? Generic<Double>)
9393
}
9494

95+
TypeErasureTests.test("AnyDifferentiable reflection") {
96+
let originalVector = Vector(x: 1, y: 1)
97+
let vector = AnyDifferentiable(originalVector)
98+
let mirror = Mirror(reflecting: vector)
99+
let children = Array(mirror.children)
100+
expectEqual(2, children.count)
101+
expectEqual(["x", "y"], children.map(\.label))
102+
expectEqual([originalVector.x, originalVector.y], children.map { $0.value as! Float })
103+
}
104+
95105
TypeErasureTests.test("AnyDerivative casting") {
96106
let tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
97107
expectEqual(Vector.TangentVector(x: 1, y: 1), tan.base as? Vector.TangentVector)
@@ -107,6 +117,16 @@ TypeErasureTests.test("AnyDerivative casting") {
107117
expectEqual(nil, zero.base as? Generic<Float>.TangentVector)
108118
}
109119

120+
TypeErasureTests.test("AnyDerivative reflection") {
121+
let originalTan = Vector.TangentVector(x: 1, y: 1)
122+
let tan = AnyDerivative(originalTan)
123+
let mirror = Mirror(reflecting: tan)
124+
let children = Array(mirror.children)
125+
expectEqual(2, children.count)
126+
expectEqual(["x", "y"], children.map(\.label))
127+
expectEqual([originalTan.x, originalTan.y], children.map { $0.value as! Float })
128+
}
129+
110130
TypeErasureTests.test("AnyDifferentiable differentiation") {
111131
// Test `AnyDifferentiable` initializer.
112132
do {

0 commit comments

Comments
 (0)