@@ -336,6 +336,37 @@ public extension Differentiable {
336
336
// Free-function-style differential operators
337
337
//===----------------------------------------------------------------------===//
338
338
339
+ // Transpose
340
+
341
+ @available ( * , unavailable)
342
+ @inlinable
343
+ public func transpose< T, R> (
344
+ of body: @escaping @differentiable /*(linear)*/ ( T ) -> R
345
+ ) -> @differentiable /*(linear)*/ ( R ) -> T {
346
+ fatalError ( )
347
+ }
348
+
349
+ // Value with differential
350
+
351
+ @available ( * , unavailable)
352
+ @inlinable
353
+ public func valueWithDifferential< T, R> (
354
+ at x: T , in body: @differentiable ( T ) -> R
355
+ ) -> ( value: R , differential: @differentiable /*(linear)*/
356
+ ( T . TangentVector ) -> R . TangentVector ) {
357
+ fatalError ( )
358
+ }
359
+
360
+ @available ( * , unavailable)
361
+ @inlinable
362
+ public func valueWithDifferential< T, U, R> (
363
+ at x: T , _ y: U , in body: @differentiable ( T , U ) -> R
364
+ ) -> ( value: R ,
365
+ differential: @differentiable /*(linear)*/
366
+ ( T . TangentVector , U . TangentVector ) -> R . TangentVector ) {
367
+ fatalError ( )
368
+ }
369
+
339
370
// Value with pullback
340
371
341
372
@inlinable
@@ -362,6 +393,25 @@ public func valueWithPullback<T, U, V, R>(
362
393
return Builtin . autodiffApply_vjp_arity3 ( f, x, y, z)
363
394
}
364
395
396
+ // Differential
397
+
398
+ @available ( * , unavailable)
399
+ @inlinable
400
+ public func differential< T, R> (
401
+ at x: T , in body: @differentiable ( T ) -> R
402
+ ) -> @differentiable /*(linear)*/ ( T . TangentVector ) -> R . TangentVector {
403
+ fatalError ( )
404
+ }
405
+
406
+ @available ( * , unavailable)
407
+ @inlinable
408
+ public func differential< T, U, R> (
409
+ at x: T , _ y: U , in body: @differentiable ( T , U ) -> R
410
+ ) -> @differentiable /*(linear)*/ ( T . TangentVector , U . TangentVector )
411
+ -> R . TangentVector {
412
+ fatalError ( )
413
+ }
414
+
365
415
// Pullback
366
416
367
417
@inlinable
@@ -386,6 +436,70 @@ public func pullback<T, U, V, R>(
386
436
return Builtin . autodiffApply_vjp_arity3 ( f, x, y, z) . 1
387
437
}
388
438
439
+ // Derivative
440
+
441
+ @available ( * , unavailable)
442
+ @inlinable
443
+ public func derivative< T: FloatingPoint , R> (
444
+ at x: T , in body: @escaping @differentiable ( T ) -> R
445
+ ) -> R . TangentVector where T. TangentVector : FloatingPoint {
446
+ fatalError ( )
447
+ }
448
+
449
+ @available ( * , unavailable)
450
+ @inlinable
451
+ public func derivative< T: FloatingPoint , U: FloatingPoint , R> (
452
+ at x: T , _ y: U , in body: @escaping @differentiable ( T ) -> R
453
+ ) -> R . TangentVector where T. TangentVector : FloatingPoint {
454
+ fatalError ( )
455
+ }
456
+
457
+ // Gradient
458
+
459
+ @inlinable
460
+ public func gradient< T, R> (
461
+ at x: T , in f: @differentiable ( T ) -> R
462
+ ) -> T . TangentVector
463
+ where R : FloatingPoint , R. TangentVector == R {
464
+ return pullback ( at: x, in: f) ( R ( 1 ) )
465
+ }
466
+
467
+ @inlinable
468
+ public func gradient< T, U, R> (
469
+ at x: T , _ y: U , in f: @differentiable ( T , U ) -> R
470
+ ) -> ( T . TangentVector , U . TangentVector )
471
+ where R : FloatingPoint , R. TangentVector == R {
472
+ return pullback ( at: x, y, in: f) ( R ( 1 ) )
473
+ }
474
+
475
+ @inlinable
476
+ public func gradient< T, U, V, R> (
477
+ at x: T , _ y: U , _ z: V , in f: @differentiable ( T , U , V ) -> R
478
+ ) -> ( T . TangentVector , U . TangentVector , V . TangentVector )
479
+ where R : FloatingPoint , R. TangentVector == R {
480
+ return pullback ( at: x, y, z, in: f) ( R ( 1 ) )
481
+ }
482
+
483
+ // Value with derivative
484
+
485
+ @available ( * , unavailable)
486
+ @inlinable
487
+ public func valueWithDerivative< T: FloatingPoint , R> (
488
+ at x: T , in body: @escaping @differentiable ( T ) -> R
489
+ ) -> ( value: R , derivative: R . TangentVector )
490
+ where T. TangentVector : FloatingPoint {
491
+ fatalError ( )
492
+ }
493
+
494
+ @available ( * , unavailable)
495
+ @inlinable
496
+ public func valueWithDerivative< T: FloatingPoint , U: FloatingPoint , R> (
497
+ at x: T , _ y: U , in body: @escaping @differentiable ( T ) -> R
498
+ ) -> ( value: R , derivative: R . TangentVector )
499
+ where T. TangentVector : FloatingPoint {
500
+ fatalError ( )
501
+ }
502
+
389
503
// Value with gradient
390
504
391
505
@inlinable
@@ -416,84 +530,96 @@ public func valueWithGradient<T, U, V, R>(
416
530
return ( y, pullback ( R ( 1 ) ) )
417
531
}
418
532
419
- // Value with gradient (curried)
533
+ // Derivative (curried)
534
+
535
+ @available ( * , unavailable)
536
+ @inlinable
537
+ public func derivative< T: FloatingPoint , R> (
538
+ of body: @escaping @differentiable ( T ) -> R
539
+ ) -> ( T ) -> R . TangentVector where T. TangentVector : FloatingPoint {
540
+ fatalError ( )
541
+ }
542
+
543
+ @available ( * , unavailable)
544
+ @inlinable
545
+ public func derivative< T: FloatingPoint , U: FloatingPoint , R> (
546
+ of body: @escaping @differentiable ( T , U ) -> R
547
+ ) -> ( T , U ) -> R . TangentVector where T. TangentVector : FloatingPoint {
548
+ fatalError ( )
549
+ }
550
+
551
+ // Gradient (curried)
420
552
421
553
@inlinable
422
- public func valueWithGradient < T, R> (
554
+ public func gradient < T, R> (
423
555
of f: @escaping @differentiable ( T ) -> R
424
- ) -> ( T ) -> ( value : R , gradient : T . TangentVector )
556
+ ) -> ( T ) -> T . TangentVector
425
557
where R : FloatingPoint , R. TangentVector == R {
426
- return { x in valueWithGradient ( at: x, in: f) }
558
+ return { x in gradient ( at: x, in: f) }
427
559
}
428
560
429
561
@inlinable
430
- public func valueWithGradient < T, U, R> (
562
+ public func gradient < T, U, R> (
431
563
of f: @escaping @differentiable ( T , U ) -> R
432
- ) -> ( T , U ) -> ( value : R , gradient : ( T . TangentVector , U . TangentVector ) )
564
+ ) -> ( T , U ) -> ( T . TangentVector , U . TangentVector )
433
565
where R : FloatingPoint , R. TangentVector == R {
434
- return { x, y in valueWithGradient ( at: x, y, in: f) }
566
+ return { x, y in gradient ( at: x, y, in: f) }
435
567
}
436
568
437
569
@inlinable
438
- public func valueWithGradient < T, U, V, R> (
570
+ public func gradient < T, U, V, R> (
439
571
of f: @escaping @differentiable ( T , U , V ) -> R
440
- ) -> ( T , U , V )
441
- -> ( value: R ,
442
- gradient: ( T . TangentVector , U . TangentVector , V . TangentVector ) )
572
+ ) -> ( T , U , V ) -> ( T . TangentVector , U . TangentVector , V . TangentVector )
443
573
where R : FloatingPoint , R. TangentVector == R {
444
- return { x, y, z in valueWithGradient ( at: x, y, z, in: f) }
574
+ return { x, y, z in gradient ( at: x, y, z, in: f) }
445
575
}
446
576
447
- // Gradient
448
-
449
- @inlinable
450
- public func gradient< T, R> (
451
- at x: T , in f: @differentiable ( T ) -> R
452
- ) -> T . TangentVector
453
- where R : FloatingPoint , R. TangentVector == R {
454
- return pullback ( at: x, in: f) ( R ( 1 ) )
455
- }
577
+ // Value with derivative (curried)
456
578
579
+ @available ( * , unavailable)
457
580
@inlinable
458
- public func gradient < T , U , R> (
459
- at x : T , _ y : U , in f : @differentiable ( T , U ) -> R
460
- ) -> ( T . TangentVector , U . TangentVector )
461
- where R : FloatingPoint , R . TangentVector == R {
462
- return pullback ( at : x , y , in : f ) ( R ( 1 ) )
581
+ public func valueWithDerivative < T : FloatingPoint , R> (
582
+ of body : @escaping @differentiable ( T ) -> R
583
+ ) -> ( value : R , derivative : R . TangentVector )
584
+ where T . TangentVector: FloatingPoint {
585
+ fatalError ( )
463
586
}
464
587
588
+ @available ( * , unavailable)
465
589
@inlinable
466
- public func gradient < T , U, V , R> (
467
- at x : T , _ y : U , _ z : V , in f : @differentiable ( T , U , V ) -> R
468
- ) -> ( T . TangentVector , U . TangentVector , V . TangentVector )
469
- where R : FloatingPoint , R . TangentVector == R {
470
- return pullback ( at : x , y , z , in : f ) ( R ( 1 ) )
590
+ public func valueWithDerivative < T : FloatingPoint , U: FloatingPoint , R> (
591
+ of body : @escaping @differentiable ( T , U ) -> R
592
+ ) -> ( value : R , derivative : R . TangentVector )
593
+ where T . TangentVector : FloatingPoint , U . TangentVector: FloatingPoint {
594
+ fatalError ( )
471
595
}
472
596
473
- // Gradient (curried)
597
+ // Value with gradient (curried)
474
598
475
599
@inlinable
476
- public func gradient < T, R> (
600
+ public func valueWithGradient < T, R> (
477
601
of f: @escaping @differentiable ( T ) -> R
478
- ) -> ( T ) -> T . TangentVector
602
+ ) -> ( T ) -> ( value : R , gradient : T . TangentVector )
479
603
where R : FloatingPoint , R. TangentVector == R {
480
- return { x in gradient ( at: x, in: f) }
604
+ return { x in valueWithGradient ( at: x, in: f) }
481
605
}
482
606
483
607
@inlinable
484
- public func gradient < T, U, R> (
608
+ public func valueWithGradient < T, U, R> (
485
609
of f: @escaping @differentiable ( T , U ) -> R
486
- ) -> ( T , U ) -> ( T . TangentVector , U . TangentVector )
610
+ ) -> ( T , U ) -> ( value : R , gradient : ( T . TangentVector , U . TangentVector ) )
487
611
where R : FloatingPoint , R. TangentVector == R {
488
- return { x, y in gradient ( at: x, y, in: f) }
612
+ return { x, y in valueWithGradient ( at: x, y, in: f) }
489
613
}
490
614
491
615
@inlinable
492
- public func gradient < T, U, V, R> (
616
+ public func valueWithGradient < T, U, V, R> (
493
617
of f: @escaping @differentiable ( T , U , V ) -> R
494
- ) -> ( T , U , V ) -> ( T . TangentVector , U . TangentVector , V . TangentVector )
618
+ ) -> ( T , U , V )
619
+ -> ( value: R ,
620
+ gradient: ( T . TangentVector , U . TangentVector , V . TangentVector ) )
495
621
where R : FloatingPoint , R. TangentVector == R {
496
- return { x, y, z in gradient ( at: x, y, z, in: f) }
622
+ return { x, y, z in valueWithGradient ( at: x, y, z, in: f) }
497
623
}
498
624
499
625
//===----------------------------------------------------------------------===//
0 commit comments