@@ -885,3 +885,76 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
885
885
_box. _move ( along: direction. _box)
886
886
}
887
887
}
888
+
889
+ //===----------------------------------------------------------------------===//
890
+ // Differentiable higher order functions for collections
891
+ //===----------------------------------------------------------------------===//
892
+
893
+ public extension Array where Element: Differentiable {
894
+ @differentiable ( wrt: ( self , initialResult) , vjp: _vjpDifferentiableReduce)
895
+ func differentiableReduce< Result: Differentiable > (
896
+ _ initialResult: Result ,
897
+ _ nextPartialResult: @differentiable ( Result , Element ) -> Result
898
+ ) -> Result {
899
+ reduce ( initialResult, nextPartialResult)
900
+ }
901
+
902
+ @usableFromInline
903
+ internal func _vjpDifferentiableReduce< Result: Differentiable > (
904
+ _ initialResult: Result ,
905
+ _ nextPartialResult: @differentiable ( Result , Element ) -> Result
906
+ ) -> ( value: Result ,
907
+ pullback: ( Result . TangentVector )
908
+ -> ( Array . TangentVector , Result . TangentVector ) ) {
909
+ var pullbacks :
910
+ [ ( Result . TangentVector ) -> ( Result . TangentVector , Element . TangentVector ) ]
911
+ = [ ]
912
+ let count = self . count
913
+ pullbacks. reserveCapacity ( count)
914
+ var result = initialResult
915
+ for element in self {
916
+ let ( y, pb) =
917
+ Swift . valueWithPullback ( at: result, element, in: nextPartialResult)
918
+ result = y
919
+ pullbacks. append ( pb)
920
+ }
921
+ return ( value: result, pullback: { tangent in
922
+ var resultTangent = tangent
923
+ var elementTangents = TangentVector ( [ ] )
924
+ elementTangents. base. reserveCapacity ( count)
925
+ for pullback in pullbacks. reversed ( ) {
926
+ let ( newResultTangent, elementTangent) = pullback ( resultTangent)
927
+ resultTangent = newResultTangent
928
+ elementTangents. base. append ( elementTangent)
929
+ }
930
+ return ( TangentVector ( elementTangents. base. reversed ( ) ) , resultTangent)
931
+ } )
932
+ }
933
+ }
934
+
935
+ public extension Array where Element: Differentiable {
936
+ @differentiable ( wrt: self , vjp: _vjpDifferentiableMap)
937
+ func differentiableMap< Result: Differentiable > (
938
+ _ body: @differentiable ( Element ) -> Result
939
+ ) -> [ Result ] {
940
+ map ( body)
941
+ }
942
+
943
+ @usableFromInline
944
+ internal func _vjpDifferentiableMap< Result: Differentiable > (
945
+ _ body: @differentiable ( Element ) -> Result
946
+ ) -> ( value: [ Result ] ,
947
+ pullback: ( Array < Result > . TangentVector ) -> Array . TangentVector ) {
948
+ var values : [ Result ] = [ ]
949
+ var pullbacks : [ ( Result . TangentVector ) -> Element . TangentVector ] = [ ]
950
+ for x in self {
951
+ let ( y, pb) = Swift . valueWithPullback ( at: x, in: body)
952
+ values. append ( y)
953
+ pullbacks. append ( pb)
954
+ }
955
+ func pullback( _ tans: Array < Result > . TangentVector ) -> Array . TangentVector {
956
+ . init( zip ( tans. base, pullbacks) . map { tan, pb in pb ( tan) } )
957
+ }
958
+ return ( value: values, pullback: pullback)
959
+ }
960
+ }
0 commit comments