Skip to content

Commit d97c867

Browse files
authored
Merge pull request #36487 from rxwei/75496334-reflection
2 parents 0ab315a + e9f82d1 commit d97c867

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

stdlib/public/Differentiation/AnyDifferentiable.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ public struct AnyDifferentiable: Differentiable {
102102
}
103103
}
104104

105+
extension AnyDifferentiable: CustomReflectable {
106+
public var customMirror: Mirror {
107+
Mirror(reflecting: base)
108+
}
109+
}
110+
105111
//===----------------------------------------------------------------------===//
106112
// `AnyDerivative`
107113
//===----------------------------------------------------------------------===//
@@ -365,6 +371,12 @@ public struct AnyDerivative: Differentiable & AdditiveArithmetic {
365371
}
366372
}
367373

374+
extension AnyDerivative: CustomReflectable {
375+
public var customMirror: Mirror {
376+
Mirror(reflecting: base)
377+
}
378+
}
379+
368380
//===----------------------------------------------------------------------===//
369381
// Helpers
370382
//===----------------------------------------------------------------------===//

test/AutoDiff/stdlib/anydifferentiable.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ TypeErasureTests.test("AnyDifferentiable casting") {
9292
expectEqual(nil, genericAny.base as? Generic<Double>)
9393
}
9494

95+
TypeErasureTests.test("AnyDifferentiable reflection") {
96+
let originalVector = Vector(x: 1, y: 1)
97+
let vector = AnyDifferentiable(originalVector)
98+
let mirror = Mirror(reflecting: vector)
99+
let children = Array(mirror.children)
100+
expectEqual(2, children.count)
101+
expectEqual(["x", "y"], children.map(\.label))
102+
expectEqual([originalVector.x, originalVector.y], children.map { $0.value as! Float })
103+
}
104+
95105
TypeErasureTests.test("AnyDerivative casting") {
96106
let tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
97107
expectEqual(Vector.TangentVector(x: 1, y: 1), tan.base as? Vector.TangentVector)
@@ -107,6 +117,16 @@ TypeErasureTests.test("AnyDerivative casting") {
107117
expectEqual(nil, zero.base as? Generic<Float>.TangentVector)
108118
}
109119

120+
TypeErasureTests.test("AnyDerivative reflection") {
121+
let originalTan = Vector.TangentVector(x: 1, y: 1)
122+
let tan = AnyDerivative(originalTan)
123+
let mirror = Mirror(reflecting: tan)
124+
let children = Array(mirror.children)
125+
expectEqual(2, children.count)
126+
expectEqual(["x", "y"], children.map(\.label))
127+
expectEqual([originalTan.x, originalTan.y], children.map { $0.value as! Float })
128+
}
129+
110130
TypeErasureTests.test("AnyDifferentiable differentiation") {
111131
// Test `AnyDifferentiable` initializer.
112132
do {

0 commit comments

Comments
 (0)