@@ -59,7 +59,8 @@ public extension PointwiseMultiplicative {
59
59
}
60
60
}
61
61
62
- public extension PointwiseMultiplicative where Self : ExpressibleByIntegerLiteral {
62
+ public extension PointwiseMultiplicative
63
+ where Self : ExpressibleByIntegerLiteral {
63
64
static var one : Self {
64
65
return 1
65
66
}
@@ -378,23 +379,28 @@ public func transpose<T, R>(
378
379
379
380
// Value with differential
380
381
381
- @available ( * , unavailable)
382
382
@inlinable
383
383
public func valueWithDifferential< T, R> (
384
- at x: T , in body: @differentiable ( T ) -> R
385
- ) -> ( value: R , differential: @differentiable /*(linear)*/
386
- ( T . TangentVector ) -> R . TangentVector ) {
387
- fatalError ( )
384
+ at x: T , in f: @differentiable ( T ) -> R
385
+ ) -> ( value: R , differential: ( T . TangentVector ) -> R . TangentVector ) {
386
+ return Builtin . autodiffApply_jvp ( f, x)
388
387
}
389
388
390
- @available ( * , unavailable)
391
389
@inlinable
392
390
public func valueWithDifferential< T, U, R> (
393
- at x: T , _ y: U , in body : @differentiable ( T , U ) -> R
391
+ at x: T , _ y: U , in f : @differentiable ( T , U ) -> R
394
392
) -> ( value: R ,
395
- differential: @differentiable /*(linear)*/
396
- ( T . TangentVector , U . TangentVector ) -> R . TangentVector ) {
397
- fatalError ( )
393
+ differential: ( T . TangentVector , U . TangentVector ) -> R . TangentVector ) {
394
+ return Builtin . autodiffApply_jvp_arity2 ( f, x, y)
395
+ }
396
+
397
+ @inlinable
398
+ public func valueWithDifferential< T, U, V, R> (
399
+ at x: T , _ y: U , _ z: V , in f: @differentiable ( T , U , V ) -> R
400
+ ) -> ( value: R ,
401
+ differential: ( T . TangentVector , U . TangentVector , V . TangentVector )
402
+ -> ( R . TangentVector ) ) {
403
+ return Builtin . autodiffApply_jvp_arity3 ( f, x, y, z)
398
404
}
399
405
400
406
// Value with pullback
@@ -425,23 +431,28 @@ public func valueWithPullback<T, U, V, R>(
425
431
426
432
// Differential
427
433
428
- @available ( * , unavailable)
429
434
@inlinable
430
435
public func differential< T, R> (
431
- at x: T , in body : @differentiable ( T ) -> R
432
- ) -> @ differentiable /*(linear)*/ ( T . TangentVector ) -> R . TangentVector {
433
- fatalError ( )
436
+ at x: T , in f : @differentiable ( T ) -> R
437
+ ) -> ( T . TangentVector ) -> R . TangentVector {
438
+ return valueWithDifferential ( at : x , in : f ) . 1
434
439
}
435
440
436
- @available ( * , unavailable)
437
441
@inlinable
438
442
public func differential< T, U, R> (
439
- at x: T , _ y: U , in body: @differentiable ( T , U ) -> R
440
- ) -> @differentiable /*(linear)*/ ( T . TangentVector , U . TangentVector )
441
- -> R . TangentVector {
442
- fatalError ( )
443
+ at x: T , _ y: U , in f: @differentiable ( T , U ) -> R
444
+ ) -> ( T . TangentVector , U . TangentVector ) -> R . TangentVector {
445
+ return valueWithDifferential ( at: x, y, in: f) . 1
446
+ }
447
+
448
+ @inlinable
449
+ public func differential< T, U, V, R> (
450
+ at x: T , _ y: U , _ z: V , in f: @differentiable ( T , U , V ) -> R
451
+ ) -> ( T . TangentVector , U . TangentVector , V . TangentVector ) -> ( R . TangentVector ) {
452
+ return valueWithDifferential ( at: x, y, z, in: f) . 1
443
453
}
444
454
455
+
445
456
// Pullback
446
457
447
458
@inlinable
@@ -468,20 +479,31 @@ public func pullback<T, U, V, R>(
468
479
469
480
// Derivative
470
481
471
- @available ( * , unavailable)
472
482
@inlinable
473
483
public func derivative< T: FloatingPoint , R> (
474
- at x: T , in body: @escaping @differentiable ( T ) -> R
475
- ) -> R . TangentVector where T. TangentVector : FloatingPoint {
476
- fatalError ( )
484
+ at x: T , in f: @escaping @differentiable ( T ) -> R
485
+ ) -> R . TangentVector
486
+ where T. TangentVector == T {
487
+ return differential ( at: x, in: f) ( T ( 1 ) )
477
488
}
478
489
479
- @available ( * , unavailable)
480
490
@inlinable
481
491
public func derivative< T: FloatingPoint , U: FloatingPoint , R> (
482
- at x: T , _ y: U , in body: @escaping @differentiable ( T ) -> R
483
- ) -> R . TangentVector where T. TangentVector : FloatingPoint {
484
- fatalError ( )
492
+ at x: T , _ y: U , in f: @escaping @differentiable ( T , U ) -> R
493
+ ) -> R . TangentVector
494
+ where T. TangentVector == T ,
495
+ U. TangentVector == U {
496
+ return differential ( at: x, y, in: f) ( T ( 1 ) , U ( 1 ) )
497
+ }
498
+
499
+ @inlinable
500
+ public func derivative< T: FloatingPoint , U: FloatingPoint , V: FloatingPoint , R> (
501
+ at x: T , _ y: U , _ z: V , in f: @escaping @differentiable ( T , U , V ) -> R
502
+ ) -> R . TangentVector
503
+ where T. TangentVector == T ,
504
+ U. TangentVector == U ,
505
+ V. TangentVector == V {
506
+ return differential ( at: x, y, z, in: f) ( T ( 1 ) , U ( 1 ) , V ( 1 ) )
485
507
}
486
508
487
509
// Gradient
@@ -512,22 +534,35 @@ public func gradient<T, U, V, R>(
512
534
513
535
// Value with derivative
514
536
515
- @available ( * , unavailable)
516
537
@inlinable
517
538
public func valueWithDerivative< T: FloatingPoint , R> (
518
- at x: T , in body : @escaping @differentiable ( T ) -> R
539
+ at x: T , in f : @escaping @differentiable ( T ) -> R
519
540
) -> ( value: R , derivative: R . TangentVector )
520
- where T. TangentVector : FloatingPoint {
521
- fatalError ( )
541
+ where T. TangentVector == T {
542
+ let ( y, differential) = valueWithDifferential ( at: x, in: f)
543
+ return ( y, differential ( T ( 1 ) ) )
522
544
}
523
545
524
- @available ( * , unavailable)
525
546
@inlinable
526
547
public func valueWithDerivative< T: FloatingPoint , U: FloatingPoint , R> (
527
- at x: T , _ y: U , in body : @escaping @differentiable ( T ) -> R
548
+ at x: T , _ y: U , in f : @escaping @differentiable ( T , U ) -> R
528
549
) -> ( value: R , derivative: R . TangentVector )
529
- where T. TangentVector : FloatingPoint {
530
- fatalError ( )
550
+ where T. TangentVector == T ,
551
+ U. TangentVector == U {
552
+ let ( y, differential) = valueWithDifferential ( at: x, y, in: f)
553
+ return ( y, differential ( T ( 1 ) , U ( 1 ) ) )
554
+ }
555
+
556
+ @inlinable
557
+ public func valueWithDerivative<
558
+ T: FloatingPoint , U: FloatingPoint , V: FloatingPoint , R> (
559
+ at x: T , _ y: U , _ z: V , in f: @escaping @differentiable ( T , U , V ) -> R
560
+ ) -> ( value: R , derivative: R . TangentVector )
561
+ where T. TangentVector == T ,
562
+ U. TangentVector == U ,
563
+ V. TangentVector == V {
564
+ let ( y, differential) = valueWithDifferential ( at: x, y, z, in: f)
565
+ return ( y, differential ( T ( 1 ) , U ( 1 ) , V ( 1 ) ) )
531
566
}
532
567
533
568
// Value with gradient
@@ -562,20 +597,31 @@ public func valueWithGradient<T, U, V, R>(
562
597
563
598
// Derivative (curried)
564
599
565
- @available ( * , unavailable)
566
600
@inlinable
567
601
public func derivative< T: FloatingPoint , R> (
568
- of body: @escaping @differentiable ( T ) -> R
569
- ) -> ( T ) -> R . TangentVector where T. TangentVector : FloatingPoint {
570
- fatalError ( )
602
+ of f: @escaping @differentiable ( T ) -> R
603
+ ) -> ( T ) -> R . TangentVector
604
+ where T. TangentVector == T {
605
+ return { x in derivative ( at: x, in: f) }
571
606
}
572
607
573
- @available ( * , unavailable)
574
608
@inlinable
575
609
public func derivative< T: FloatingPoint , U: FloatingPoint , R> (
576
- of body: @escaping @differentiable ( T , U ) -> R
577
- ) -> ( T , U ) -> R . TangentVector where T. TangentVector : FloatingPoint {
578
- fatalError ( )
610
+ of f: @escaping @differentiable ( T , U ) -> R
611
+ ) -> ( T , U ) -> R . TangentVector
612
+ where T. TangentVector == T ,
613
+ U. TangentVector == U {
614
+ return { ( x, y) in derivative ( at: x, y, in: f) }
615
+ }
616
+
617
+ @inlinable
618
+ public func derivative< T: FloatingPoint , U: FloatingPoint , V: FloatingPoint , R> (
619
+ of f: @escaping @differentiable ( T , U , V ) -> R
620
+ ) -> ( T , U , V ) -> R . TangentVector
621
+ where T. TangentVector == T ,
622
+ U. TangentVector == U ,
623
+ V. TangentVector == V {
624
+ return { ( x, y, z) in derivative ( at: x, y, z, in: f) }
579
625
}
580
626
581
627
// Gradient (curried)
@@ -606,22 +652,32 @@ public func gradient<T, U, V, R>(
606
652
607
653
// Value with derivative (curried)
608
654
609
- @available ( * , unavailable)
610
655
@inlinable
611
656
public func valueWithDerivative< T: FloatingPoint , R> (
612
- of body : @escaping @differentiable ( T ) -> R
613
- ) -> ( value: R , derivative: R . TangentVector )
614
- where T. TangentVector: FloatingPoint {
615
- fatalError ( )
657
+ of f : @escaping @differentiable ( T ) -> R
658
+ ) -> ( T ) -> ( value: R , derivative: R . TangentVector )
659
+ where T. TangentVector == T {
660
+ return { x in valueWithDerivative ( at : x , in : f ) }
616
661
}
617
662
618
- @available ( * , unavailable)
619
663
@inlinable
620
664
public func valueWithDerivative< T: FloatingPoint , U: FloatingPoint , R> (
621
- of body: @escaping @differentiable ( T , U ) -> R
622
- ) -> ( value: R , derivative: R . TangentVector )
623
- where T. TangentVector: FloatingPoint , U. TangentVector: FloatingPoint {
624
- fatalError ( )
665
+ of f: @escaping @differentiable ( T , U ) -> R
666
+ ) -> ( T , U ) -> ( value: R , derivative: R . TangentVector )
667
+ where T. TangentVector == T ,
668
+ U. TangentVector == U {
669
+ return { ( x, y) in valueWithDerivative ( at: x, y, in: f) }
670
+ }
671
+
672
+ @inlinable
673
+ public func valueWithDerivative<
674
+ T: FloatingPoint , U: FloatingPoint , V: FloatingPoint , R> (
675
+ of f: @escaping @differentiable ( T , U , V ) -> R
676
+ ) -> ( T , U , V ) -> ( value: R , derivative: R . TangentVector )
677
+ where T. TangentVector == T ,
678
+ U. TangentVector == U ,
679
+ V. TangentVector == V {
680
+ return { ( x, y, z) in valueWithDerivative ( at: x, y, z, in: f) }
625
681
}
626
682
627
683
// Value with gradient (curried)
0 commit comments