@@ -456,24 +456,21 @@ namespace {
456
456
private:
457
457
GenericParamList *TheGenericParamList;
458
458
SmallVector<GenericTypeParamDecl*, 2 > GenericTypeParams;
459
- GenericEnvironment *GenericEnv = nullptr ;
459
+ // SWIFT_ENABLE_TENSORFLOW
460
+ GenericSignatureBuilder Builder;
460
461
SmallVector<AnyFunctionType::Param, 4 > InterfaceParams;
461
462
Type InterfaceResult;
462
463
463
464
public:
464
465
BuiltinGenericSignatureBuilder (ASTContext &ctx, unsigned numGenericParams = 1 )
465
- : Context(ctx) {
466
+ // SWIFT_ENABLE_TENSORFLOW
467
+ : Context(ctx), Builder(ctx) {
466
468
TheGenericParamList = getGenericParams (ctx, numGenericParams,
467
469
GenericTypeParams);
468
470
469
- GenericSignatureBuilder Builder (ctx);
470
471
for (auto gp : GenericTypeParams) {
471
472
Builder.addGenericParameter (gp);
472
473
}
473
-
474
- auto GenericSig =
475
- std::move (Builder).computeGenericSignature (SourceLoc ());
476
- GenericEnv = GenericSig->createGenericEnvironment ();
477
474
}
478
475
479
476
template <class G >
@@ -489,7 +486,20 @@ namespace {
489
486
InterfaceResult = generator.build (*this );
490
487
}
491
488
489
+ // SWIFT_ENABLE_TENSORFLOW
490
+ template <class G >
491
+ void addConformanceRequirement (const G &generator, ProtocolDecl *proto) {
492
+ Requirement req (RequirementKind::Conformance,
493
+ generator.build (*this ),
494
+ proto->getDeclaredType ());
495
+ auto source =
496
+ GenericSignatureBuilder::FloatingRequirementSource::forAbstract ();
497
+ Builder.addRequirement (req, source, Context.getStdlibModule ());
498
+ }
499
+
492
500
ValueDecl *build (Identifier name) {
501
+ auto GenericSig = std::move (Builder).computeGenericSignature (SourceLoc ());
502
+ auto GenericEnv = GenericSig->createGenericEnvironment ();
493
503
return getBuiltinGenericFunction (name, InterfaceParams,
494
504
InterfaceResult,
495
505
TheGenericParamList,
@@ -533,22 +543,6 @@ makeConcrete(Type type) {
533
543
return { type };
534
544
}
535
545
536
- // SWIFT_ENABLE_TENSORFLOW
537
- template <class P , class ... Gs>
538
- static BuiltinGenericSignatureBuilder::LambdaGenerator
539
- makeBoundGeneric (NominalTypeDecl *decl, const P &parentGenerator,
540
- const Gs & ...genericParamGenerators) {
541
- return {
542
- [=](BuiltinGenericSignatureBuilder &builder) -> Type {
543
- Type parent = parentGenerator.build (builder);
544
- Type genParams[] = {
545
- genericParamGenerators.build (builder)...
546
- };
547
- return BoundGenericType::get (decl, parent, genParams);
548
- }
549
- };
550
- }
551
-
552
546
static BuiltinGenericSignatureBuilder::ParameterGenerator
553
547
makeGenericParam (unsigned index = 0 ) {
554
548
return { index };
@@ -985,8 +979,7 @@ static ValueDecl *getAutoDiffCreateTape(ASTContext &Context, Identifier Id) {
985
979
// <T> () -> (Swift._AutoDiffTape<T>)
986
980
BuiltinGenericSignatureBuilder builder (Context, 1 );
987
981
auto *tapeDecl = Context.get_AutoDiffTapeDecl ();
988
- builder.setResult (
989
- makeBoundGeneric (tapeDecl, makeConcrete (Type ()), makeGenericParam ()));
982
+ builder.setResult (makeBoundGenericType (tapeDecl, makeGenericParam ()));
990
983
return builder.build (Id);
991
984
}
992
985
@@ -995,7 +988,7 @@ static ValueDecl *getAutoDiffPushToTape(ASTContext &Context, Identifier Id) {
995
988
BuiltinGenericSignatureBuilder builder (Context, 1 );
996
989
auto *tapeDecl = Context.get_AutoDiffTapeDecl ();
997
990
auto T = makeGenericParam ();
998
- builder.addParameter (makeBoundGeneric (tapeDecl, makeConcrete ( Type ()) , T));
991
+ builder.addParameter (makeBoundGenericType (tapeDecl, T));
999
992
builder.addParameter (T);
1000
993
builder.addParameter (makeConcrete (BuiltinIntegerType::getWordType (Context)));
1001
994
builder.setResult (makeConcrete (Context.TheEmptyTupleType ));
@@ -1007,7 +1000,7 @@ static ValueDecl *getAutoDiffPopFromTape(ASTContext &Context, Identifier Id) {
1007
1000
BuiltinGenericSignatureBuilder builder (Context, 1 );
1008
1001
auto *tapeDecl = Context.get_AutoDiffTapeDecl ();
1009
1002
auto T = makeGenericParam ();
1010
- builder.addParameter (makeBoundGeneric (tapeDecl, makeConcrete ( Type ()) , T));
1003
+ builder.addParameter (makeBoundGenericType (tapeDecl, T));
1011
1004
builder.addParameter (makeConcrete (BuiltinIntegerType::getWordType (Context)));
1012
1005
builder.setResult (T);
1013
1006
return builder.build (Id);
@@ -1017,12 +1010,84 @@ static ValueDecl *getAutoDiffDestroyTape(ASTContext &Context, Identifier Id) {
1017
1010
// <T> (Swift._AutoDiffTape<T>) -> ()
1018
1011
BuiltinGenericSignatureBuilder builder (Context, 1 );
1019
1012
auto *tapeDecl = Context.get_AutoDiffTapeDecl ();
1020
- builder.addParameter (
1021
- makeBoundGeneric (tapeDecl, makeConcrete (Type ()), makeGenericParam ()));
1013
+ builder.addParameter (makeBoundGenericType (tapeDecl, makeGenericParam ()));
1022
1014
builder.setResult (makeConcrete (Context.TheEmptyTupleType ));
1023
1015
return builder.build (Id);
1024
1016
}
1025
1017
1018
+ static ValueDecl *getAutoDiffGetAssociatedFunction (
1019
+ ASTContext &Context, Identifier Id, AutoDiffAssociatedFunctionKind kind,
1020
+ unsigned order, unsigned arity, bool isThrowing = false ) {
1021
+ assert (arity >= 1 );
1022
+ assert (order == 1 && " higher-order differentiation is not supported yet" );
1023
+ // JVP(non-throwing):
1024
+ // <...T...(arity), R> (@autodiff (...T) -> R)
1025
+ // -> (...T) -> (R, (...T.TangentVector) -> R.TangentVector)
1026
+ // JVP(throwing):
1027
+ // <...T...(arity), R> (@autodiff (...T) throws -> R)
1028
+ // -> (...T) throws -> (R, (...T.TangentVector) -> R.TangentVector)
1029
+ // VJP(non-throwing):
1030
+ // <...T...(arity), R> (@autodiff (...T) -> R)
1031
+ // -> (...T) -> (R, (R.CotangentVector) -> ...T.CotangentVector)
1032
+ // VJP(throwing):
1033
+ // <...T...(arity), R> (@autodiff (...T) throws -> R)
1034
+ // -> (...T) throws -> (R, (R.CotangentVector) -> ...T.CotangentVector)
1035
+ BuiltinGenericSignatureBuilder builder (Context,
1036
+ /* numGenericParams*/ 1 + arity);
1037
+ // Look up the Differentiable protocol.
1038
+ SmallVector<ValueDecl *, 1 > diffableProtoLookup;
1039
+ Context.lookupInSwiftModule (" Differentiable" , diffableProtoLookup);
1040
+ assert (diffableProtoLookup.size () == 1 );
1041
+ auto *diffableProto = cast<ProtocolDecl>(diffableProtoLookup.front ());
1042
+ // Create type parameters and add conformance constraints.
1043
+ auto R = makeGenericParam (arity);
1044
+ builder.addConformanceRequirement (R, diffableProto);
1045
+ SmallVector<decltype (R), 2 > Ts;
1046
+ for (auto i : range (arity)) {
1047
+ auto T = makeGenericParam (i);
1048
+ builder.addConformanceRequirement (T, diffableProto);
1049
+ Ts.push_back (T);
1050
+ }
1051
+ // Generator for the argument.
1052
+ BuiltinGenericSignatureBuilder::LambdaGenerator argGen {
1053
+ // Generator for the function type at the argument position, i.e. the
1054
+ // function being differentiated.
1055
+ [=, &Ts](BuiltinGenericSignatureBuilder &builder) -> Type {
1056
+ FunctionType::ExtInfo ext;
1057
+ auto extInfo = FunctionType::ExtInfo ()
1058
+ .withDifferentiability (FunctionTypeDifferentiability::Bidirectional)
1059
+ .withNoEscape ();
1060
+ if (isThrowing)
1061
+ extInfo = extInfo.withThrows ();
1062
+ SmallVector<FunctionType::Param, 2 > params;
1063
+ for (auto ¶mGen : Ts)
1064
+ params.push_back (FunctionType::Param (paramGen.build (builder)));
1065
+ return FunctionType::get (params, R.build (builder))->withExtInfo (extInfo);
1066
+ }
1067
+ };
1068
+ AnyFunctionType *origFnTy = argGen.build (builder)->castTo <AnyFunctionType>();
1069
+ origFnTy = origFnTy->withExtInfo (origFnTy->getExtInfo ()
1070
+ .withDifferentiability (FunctionTypeDifferentiability::None));
1071
+ auto *paramIndices = AutoDiffParameterIndices::create (Context, origFnTy,
1072
+ /* isMethod*/ false ,
1073
+ /* setAllParams*/ true );
1074
+ // Generator for the resultant function type, i.e. the AD associated function.
1075
+ BuiltinGenericSignatureBuilder::LambdaGenerator resultGen {
1076
+ [=](BuiltinGenericSignatureBuilder &builder) -> Type {
1077
+ // TODO(rxwei): Use parameter indices and differentiation order that are
1078
+ // stored in the function type.
1079
+ auto *vjpType = origFnTy->getAutoDiffAssociatedFunctionType (
1080
+ *paramIndices, /* resultIndex*/ 0 , /* differentiationOrder*/ 1 ,
1081
+ kind, /* lookupConformance*/ nullptr );
1082
+ vjpType = vjpType->withExtInfo (vjpType->getExtInfo ().withNoEscape (false ));
1083
+ return vjpType;
1084
+ }
1085
+ };
1086
+ builder.addParameter (argGen);
1087
+ builder.setResult (resultGen);
1088
+ return builder.build (Id);
1089
+ }
1090
+
1026
1091
static ValueDecl *getPoundAssert (ASTContext &Context, Identifier Id) {
1027
1092
auto int1Type = BuiltinIntegerType::get (1 , Context);
1028
1093
auto optionalRawPointerType = BoundGenericEnumType::get (
@@ -1958,6 +2023,14 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
1958
2023
return getAutoDiffPopFromTape (Context, Id);
1959
2024
case BuiltinValueKind::AutoDiffDestroyTape:
1960
2025
return getAutoDiffDestroyTape (Context, Id);
2026
+ case BuiltinValueKind::AutoDiffGetJVP:
2027
+ return getAutoDiffGetAssociatedFunction (Context, Id,
2028
+ AutoDiffAssociatedFunctionKind::JVP,
2029
+ /* order*/ 1 , /* arity*/ 1 );
2030
+ case BuiltinValueKind::AutoDiffGetVJP:
2031
+ return getAutoDiffGetAssociatedFunction (Context, Id,
2032
+ AutoDiffAssociatedFunctionKind::VJP,
2033
+ /* order*/ 1 , /* arity*/ 1 );
1961
2034
case BuiltinValueKind::PoundAssert:
1962
2035
return getPoundAssert (Context, Id);
1963
2036
0 commit comments