Skip to content

Commit f46055c

Browse files
authored
---
yaml --- r: 341627 b: refs/heads/rxwei-patch-1 c: c263f19 h: refs/heads/master i: 341625: 596e558 341623: 6b60121
1 parent d898a7c commit f46055c

File tree

2 files changed

+111
-55
lines changed

2 files changed

+111
-55
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: f57150e58b07dc6bcd18193cdd075d245794fb46
1018+
refs/heads/rxwei-patch-1: c263f19e1f5baece91d1bc3c3850db18528c0a6f
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/stdlib/public/core/AutoDiff.swift

Lines changed: 110 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ public extension PointwiseMultiplicative {
5959
}
6060
}
6161

62-
public extension PointwiseMultiplicative where Self : ExpressibleByIntegerLiteral {
62+
public extension PointwiseMultiplicative
63+
where Self : ExpressibleByIntegerLiteral {
6364
static var one: Self {
6465
return 1
6566
}
@@ -378,23 +379,28 @@ public func transpose<T, R>(
378379

379380
// Value with differential
380381

381-
@available(*, unavailable)
382382
@inlinable
383383
public func valueWithDifferential<T, R>(
384-
at x: T, in body: @differentiable (T) -> R
385-
) -> (value: R, differential: @differentiable/*(linear)*/
386-
(T.TangentVector) -> R.TangentVector) {
387-
fatalError()
384+
at x: T, in f: @differentiable (T) -> R
385+
) -> (value: R, differential: (T.TangentVector) -> R.TangentVector) {
386+
return Builtin.autodiffApply_jvp(f, x)
388387
}
389388

390-
@available(*, unavailable)
391389
@inlinable
392390
public func valueWithDifferential<T, U, R>(
393-
at x: T, _ y: U, in body: @differentiable (T, U) -> R
391+
at x: T, _ y: U, in f: @differentiable (T, U) -> R
394392
) -> (value: R,
395-
differential: @differentiable/*(linear)*/
396-
(T.TangentVector, U.TangentVector) -> R.TangentVector) {
397-
fatalError()
393+
differential: (T.TangentVector, U.TangentVector) -> R.TangentVector) {
394+
return Builtin.autodiffApply_jvp_arity2(f, x, y)
395+
}
396+
397+
@inlinable
398+
public func valueWithDifferential<T, U, V, R>(
399+
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
400+
) -> (value: R,
401+
differential: (T.TangentVector, U.TangentVector, V.TangentVector)
402+
-> (R.TangentVector)) {
403+
return Builtin.autodiffApply_jvp_arity3(f, x, y, z)
398404
}
399405

400406
// Value with pullback
@@ -425,23 +431,28 @@ public func valueWithPullback<T, U, V, R>(
425431

426432
// Differential
427433

428-
@available(*, unavailable)
429434
@inlinable
430435
public func differential<T, R>(
431-
at x: T, in body: @differentiable (T) -> R
432-
) -> @differentiable/*(linear)*/ (T.TangentVector) -> R.TangentVector {
433-
fatalError()
436+
at x: T, in f: @differentiable (T) -> R
437+
) -> (T.TangentVector) -> R.TangentVector {
438+
return valueWithDifferential(at: x, in: f).1
434439
}
435440

436-
@available(*, unavailable)
437441
@inlinable
438442
public func differential<T, U, R>(
439-
at x: T, _ y: U, in body: @differentiable (T, U) -> R
440-
) -> @differentiable/*(linear)*/ (T.TangentVector, U.TangentVector)
441-
-> R.TangentVector {
442-
fatalError()
443+
at x: T, _ y: U, in f: @differentiable (T, U) -> R
444+
) -> (T.TangentVector, U.TangentVector) -> R.TangentVector {
445+
return valueWithDifferential(at: x, y, in: f).1
446+
}
447+
448+
@inlinable
449+
public func differential<T, U, V, R>(
450+
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
451+
) -> (T.TangentVector, U.TangentVector, V.TangentVector) -> (R.TangentVector) {
452+
return valueWithDifferential(at: x, y, z, in: f).1
443453
}
444454

455+
445456
// Pullback
446457

447458
@inlinable
@@ -468,20 +479,31 @@ public func pullback<T, U, V, R>(
468479

469480
// Derivative
470481

471-
@available(*, unavailable)
472482
@inlinable
473483
public func derivative<T: FloatingPoint, R>(
474-
at x: T, in body: @escaping @differentiable (T) -> R
475-
) -> R.TangentVector where T.TangentVector : FloatingPoint {
476-
fatalError()
484+
at x: T, in f: @escaping @differentiable (T) -> R
485+
) -> R.TangentVector
486+
where T.TangentVector == T {
487+
return differential(at: x, in: f)(T(1))
477488
}
478489

479-
@available(*, unavailable)
480490
@inlinable
481491
public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
482-
at x: T, _ y: U, in body: @escaping @differentiable (T) -> R
483-
) -> R.TangentVector where T.TangentVector : FloatingPoint {
484-
fatalError()
492+
at x: T, _ y: U, in f: @escaping @differentiable (T, U) -> R
493+
) -> R.TangentVector
494+
where T.TangentVector == T,
495+
U.TangentVector == U {
496+
return differential(at: x, y, in: f)(T(1), U(1))
497+
}
498+
499+
@inlinable
500+
public func derivative<T: FloatingPoint, U: FloatingPoint, V: FloatingPoint, R>(
501+
at x: T, _ y: U, _ z: V, in f: @escaping @differentiable (T, U, V) -> R
502+
) -> R.TangentVector
503+
where T.TangentVector == T,
504+
U.TangentVector == U,
505+
V.TangentVector == V {
506+
return differential(at: x, y, z, in: f)(T(1), U(1), V(1))
485507
}
486508

487509
// Gradient
@@ -512,22 +534,35 @@ public func gradient<T, U, V, R>(
512534

513535
// Value with derivative
514536

515-
@available(*, unavailable)
516537
@inlinable
517538
public func valueWithDerivative<T: FloatingPoint, R>(
518-
at x: T, in body: @escaping @differentiable (T) -> R
539+
at x: T, in f: @escaping @differentiable (T) -> R
519540
) -> (value: R, derivative: R.TangentVector)
520-
where T.TangentVector : FloatingPoint {
521-
fatalError()
541+
where T.TangentVector == T {
542+
let (y, differential) = valueWithDifferential(at: x, in: f)
543+
return (y, differential(T(1)))
522544
}
523545

524-
@available(*, unavailable)
525546
@inlinable
526547
public func valueWithDerivative<T: FloatingPoint, U: FloatingPoint, R>(
527-
at x: T, _ y: U, in body: @escaping @differentiable (T) -> R
548+
at x: T, _ y: U, in f: @escaping @differentiable (T, U) -> R
528549
) -> (value: R, derivative: R.TangentVector)
529-
where T.TangentVector : FloatingPoint {
530-
fatalError()
550+
where T.TangentVector == T,
551+
U.TangentVector == U {
552+
let (y, differential) = valueWithDifferential(at: x, y, in: f)
553+
return (y, differential(T(1), U(1)))
554+
}
555+
556+
@inlinable
557+
public func valueWithDerivative<
558+
T: FloatingPoint, U: FloatingPoint, V: FloatingPoint, R>(
559+
at x: T, _ y: U, _ z: V, in f: @escaping @differentiable (T, U, V) -> R
560+
) -> (value: R, derivative: R.TangentVector)
561+
where T.TangentVector == T,
562+
U.TangentVector == U,
563+
V.TangentVector == V {
564+
let (y, differential) = valueWithDifferential(at: x, y, z, in: f)
565+
return (y, differential(T(1), U(1), V(1)))
531566
}
532567

533568
// Value with gradient
@@ -562,20 +597,31 @@ public func valueWithGradient<T, U, V, R>(
562597

563598
// Derivative (curried)
564599

565-
@available(*, unavailable)
566600
@inlinable
567601
public func derivative<T: FloatingPoint, R>(
568-
of body: @escaping @differentiable (T) -> R
569-
) -> (T) -> R.TangentVector where T.TangentVector : FloatingPoint {
570-
fatalError()
602+
of f: @escaping @differentiable (T) -> R
603+
) -> (T) -> R.TangentVector
604+
where T.TangentVector == T {
605+
return { x in derivative(at: x, in: f) }
571606
}
572607

573-
@available(*, unavailable)
574608
@inlinable
575609
public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
576-
of body: @escaping @differentiable (T, U) -> R
577-
) -> (T, U) -> R.TangentVector where T.TangentVector : FloatingPoint {
578-
fatalError()
610+
of f: @escaping @differentiable (T, U) -> R
611+
) -> (T, U) -> R.TangentVector
612+
where T.TangentVector == T,
613+
U.TangentVector == U {
614+
return { (x, y) in derivative(at: x, y, in: f) }
615+
}
616+
617+
@inlinable
618+
public func derivative<T: FloatingPoint, U: FloatingPoint, V: FloatingPoint, R>(
619+
of f: @escaping @differentiable (T, U, V) -> R
620+
) -> (T, U, V) -> R.TangentVector
621+
where T.TangentVector == T,
622+
U.TangentVector == U,
623+
V.TangentVector == V {
624+
return { (x, y, z) in derivative(at: x, y, z, in: f) }
579625
}
580626

581627
// Gradient (curried)
@@ -606,22 +652,32 @@ public func gradient<T, U, V, R>(
606652

607653
// Value with derivative (curried)
608654

609-
@available(*, unavailable)
610655
@inlinable
611656
public func valueWithDerivative<T: FloatingPoint, R>(
612-
of body: @escaping @differentiable (T) -> R
613-
) -> (value: R, derivative: R.TangentVector)
614-
where T.TangentVector: FloatingPoint {
615-
fatalError()
657+
of f: @escaping @differentiable (T) -> R
658+
) -> (T) -> (value: R, derivative: R.TangentVector)
659+
where T.TangentVector == T {
660+
return { x in valueWithDerivative(at: x, in: f) }
616661
}
617662

618-
@available(*, unavailable)
619663
@inlinable
620664
public func valueWithDerivative<T: FloatingPoint, U: FloatingPoint, R>(
621-
of body: @escaping @differentiable (T, U) -> R
622-
) -> (value: R, derivative: R.TangentVector)
623-
where T.TangentVector: FloatingPoint, U.TangentVector: FloatingPoint {
624-
fatalError()
665+
of f: @escaping @differentiable (T, U) -> R
666+
) -> (T, U) -> (value: R, derivative: R.TangentVector)
667+
where T.TangentVector == T,
668+
U.TangentVector == U {
669+
return { (x, y) in valueWithDerivative(at: x, y, in: f) }
670+
}
671+
672+
@inlinable
673+
public func valueWithDerivative<
674+
T: FloatingPoint, U: FloatingPoint, V: FloatingPoint, R>(
675+
of f: @escaping @differentiable (T, U, V) -> R
676+
) -> (T, U, V) -> (value: R, derivative: R.TangentVector)
677+
where T.TangentVector == T,
678+
U.TangentVector == U,
679+
V.TangentVector == V {
680+
return { (x, y, z) in valueWithDerivative(at: x, y, z, in: f) }
625681
}
626682

627683
// Value with gradient (curried)

0 commit comments

Comments
 (0)