Skip to content

Commit 6298351

Browse files
authored
[AutoDiff] Registers VJPs for FloatingPoint.[maximum|minimum] (#35379)
Resolves TF-1134.
1 parent a9dfe48 commit 6298351

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,4 +282,22 @@ where
282282
let y = squareRoot()
283283
return (y, { v in v / (2 * y) })
284284
}
285+
286+
@inlinable
287+
@derivative(of: minimum)
288+
static func _vjpMinimum(_ x: Self, _ y: Self) -> (
289+
value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
290+
) {
291+
if x <= y || y.isNaN { return (x, { v in (v, .zero) }) }
292+
return (y, { v in (.zero, v) })
293+
}
294+
295+
@inlinable
296+
@derivative(of: maximum)
297+
static func _vjpMaximum(_ x: Self, _ y: Self) -> (
298+
value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
299+
) {
300+
if x > y || y.isNaN { return (x, { v in (v, .zero) }) }
301+
return (y, { v in (.zero, v) })
302+
}
285303
}

test/AutoDiff/stdlib/floating_point.swift.gyb

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,22 @@ FloatingPointDerivativeTests.test("${Self}.addingProduct") {
8383
expectEqual((1, 2, 3), gradient(at: ${Self}(10), 3, 2, in: { $0.addingProduct($1, $2) }))
8484
}
8585

86+
FloatingPointDerivativeTests.test("${Self}.minimum") {
87+
expectEqual((1.0, 0.0), gradient(at: ${Self}(1), ${Self}(2), in : { ${Self}.minimum($0, $1) }))
88+
expectEqual((1.0, 0.0), gradient(at: ${Self}(1), ${Self}(1), in : { ${Self}.minimum($0, $1) }))
89+
expectEqual((0.0, 1.0), gradient(at: ${Self}(2), ${Self}(1), in : { ${Self}.minimum($0, $1) }))
90+
expectEqual((1.0, 0.0), gradient(at: ${Self}(1), .nan, in : { ${Self}.minimum($0, $1) }))
91+
expectEqual((0.0, 1.0), gradient(at: .nan, ${Self}(1), in : { ${Self}.minimum($0, $1) }))
92+
}
93+
94+
FloatingPointDerivativeTests.test("${Self}.maximum") {
95+
expectEqual((0.0, 1.0), gradient(at: ${Self}(1), ${Self}(2), in : { ${Self}.maximum($0, $1) }))
96+
expectEqual((0.0, 1.0), gradient(at: ${Self}(1), ${Self}(1), in : { ${Self}.maximum($0, $1) }))
97+
expectEqual((1.0, 0.0), gradient(at: ${Self}(2), ${Self}(1), in : { ${Self}.maximum($0, $1) }))
98+
expectEqual((1.0, 0.0), gradient(at: ${Self}(1), .nan, in : { ${Self}.maximum($0, $1) }))
99+
expectEqual((0.0, 1.0), gradient(at: .nan, ${Self}(1), in : { ${Self}.maximum($0, $1) }))
100+
}
101+
86102
%if Self == 'Float80':
87103
#endif
88104
%end

0 commit comments

Comments
 (0)