Skip to content

Commit f47da5a

Browse files
authored
Merge pull request #21144 from rxwei/fill-autodiff-function-with-vjp
[AutoDiff] Add `Builtin.autodiffGet(JVP|VJP)` for extracting AD associated functions
2 parents 78f131b + f0ea3a9 commit f47da5a

14 files changed

+201
-84
lines changed

include/swift/AST/ASTContext.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -953,9 +953,6 @@ class ASTContext final {
953953
}
954954

955955
// SWIFT_ENABLE_TENSORFLOW
956-
/// Determine whether the given type is differentiable.
957-
bool isDifferentiable(CanType type, ModuleDecl *module);
958-
959956
/// Compute the tangent space of this manifold, if the given type represents a
960957
/// differentiable manifold.
961958
Optional<TangentSpace> getTangentSpace(CanType type, ModuleDecl *module);

include/swift/AST/AutoDiff.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ class AutoDiffParameterIndices {
111111

112112
unsigned getNumNonSelfParameters() const;
113113

114-
AutoDiffParameterIndices(unsigned numIndices, bool isMethodFlag)
115-
: indices(numIndices), isMethodFlag(isMethodFlag) {}
114+
AutoDiffParameterIndices(unsigned numIndices, bool isMethodFlag,
115+
bool setAllParams = false)
116+
: indices(numIndices, setAllParams), isMethodFlag(isMethodFlag) {}
116117

117118
AutoDiffParameterIndices(llvm::SmallBitVector indices, bool isMethodFlag)
118119
: indices(indices), isMethodFlag(isMethodFlag) {}
@@ -122,7 +123,8 @@ class AutoDiffParameterIndices {
122123
/// given `functionType`. `isMethod` specifies whether to treat the function
123124
/// as a method.
124125
static AutoDiffParameterIndices *
125-
create(ASTContext &C, AnyFunctionType *functionType, bool isMethod);
126+
create(ASTContext &C, AnyFunctionType *functionType, bool isMethod,
127+
bool setAllParams = false);
126128

127129
bool isMethod() const { return isMethodFlag; }
128130

include/swift/AST/Builtins.def

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,14 @@ BUILTIN_SIL_OPERATION(AllocWithTailElems, "allocWithTailElems", Special)
375375
/// Projects the first tail-allocated element of type E from a class C.
376376
BUILTIN_SIL_OPERATION(ProjectTailElems, "projectTailElems", Special)
377377

378+
// SWIFT_ENABLE_TENSORFLOW
379+
/// autodiffGetJVP has type <T: Differentiable, R: Differentiable>
380+
/// ((T) -> R) -> (T) -> (R, (T.TangentVector) -> R.TangentVector).
381+
BUILTIN_SIL_OPERATION(AutoDiffGetJVP, "autodiffGetJVP", Special)
382+
/// autodiffGetVJP has type <T: Differentiable, R: Differentiable>
383+
/// ((T) -> R) -> (T) -> (R, (R.CotangentVector) -> T.CotangentVector).
384+
BUILTIN_SIL_OPERATION(AutoDiffGetVJP, "autodiffGetVJP", Special)
385+
378386
#undef BUILTIN_SIL_OPERATION
379387

380388
// BUILTIN_RUNTIME_CALL - A call into a runtime function.

include/swift/AST/Types.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3120,6 +3120,10 @@ class AnyFunctionType : public TypeBase {
31203120
///
31213121
/// Pass `selfUncurried = true` when the function type is for a method whose
31223122
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
3123+
///
3124+
/// \note The original function type (`self`) need not be `@autodiff`, and the
3125+
/// resulting function will preserve all `ExtInfo` of the original function,
3126+
/// including `@autodiff`.
31233127
AnyFunctionType *getAutoDiffAssociatedFunctionType(
31243128
const AutoDiffParameterIndices &indices, unsigned resultIndex,
31253129
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,

lib/AST/ASTContext.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5188,10 +5188,6 @@ LayoutConstraint LayoutConstraint::getLayoutConstraint(LayoutConstraintKind Kind
51885188
}
51895189

51905190
// SWIFT_ENABLE_TENSORFLOW
5191-
bool ASTContext::isDifferentiable(CanType type, ModuleDecl *module) {
5192-
return getTangentSpace(type, module).hasValue();
5193-
}
5194-
51955191
Optional<TangentSpace> ASTContext::getTangentSpace(CanType type,
51965192
ModuleDecl *module) {
51975193
auto lookup = getImpl().TangentSpaces.find(type);

lib/AST/AutoDiff.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ static AnyFunctionType *unwrapSelfParameter(AnyFunctionType *functionType,
110110
/// as a method.
111111
AutoDiffParameterIndices *
112112
AutoDiffParameterIndices::create(ASTContext &C, AnyFunctionType *functionType,
113-
bool isMethod) {
113+
bool isMethod, bool setAllParams) {
114114
// TODO(SR-9290): Note that the AutoDiffParameterIndices' destructor never
115115
// gets called, which causes a small memory leak in the case that the
116116
// SmallBitVector decides to allocate some heap space.
@@ -119,7 +119,8 @@ AutoDiffParameterIndices::create(ASTContext &C, AnyFunctionType *functionType,
119119
unsigned paramCount =
120120
unwrapSelfParameter(functionType, isMethod)->getNumParams() +
121121
(isMethod ? 1 : 0);
122-
return ::new (mem) AutoDiffParameterIndices(paramCount, isMethod);
122+
return
123+
::new (mem) AutoDiffParameterIndices(paramCount, isMethod, setAllParams);
123124
}
124125

125126
/// Allocates and initializes an `AutoDiffParameterIndices` corresponding to

lib/AST/Builtins.cpp

Lines changed: 102 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -456,24 +456,21 @@ namespace {
456456
private:
457457
GenericParamList *TheGenericParamList;
458458
SmallVector<GenericTypeParamDecl*, 2> GenericTypeParams;
459-
GenericEnvironment *GenericEnv = nullptr;
459+
// SWIFT_ENABLE_TENSORFLOW
460+
GenericSignatureBuilder Builder;
460461
SmallVector<AnyFunctionType::Param, 4> InterfaceParams;
461462
Type InterfaceResult;
462463

463464
public:
464465
BuiltinGenericSignatureBuilder(ASTContext &ctx, unsigned numGenericParams = 1)
465-
: Context(ctx) {
466+
// SWIFT_ENABLE_TENSORFLOW
467+
: Context(ctx), Builder(ctx) {
466468
TheGenericParamList = getGenericParams(ctx, numGenericParams,
467469
GenericTypeParams);
468470

469-
GenericSignatureBuilder Builder(ctx);
470471
for (auto gp : GenericTypeParams) {
471472
Builder.addGenericParameter(gp);
472473
}
473-
474-
auto GenericSig =
475-
std::move(Builder).computeGenericSignature(SourceLoc());
476-
GenericEnv = GenericSig->createGenericEnvironment();
477474
}
478475

479476
template <class G>
@@ -489,7 +486,20 @@ namespace {
489486
InterfaceResult = generator.build(*this);
490487
}
491488

489+
// SWIFT_ENABLE_TENSORFLOW
490+
template <class G>
491+
void addConformanceRequirement(const G &generator, ProtocolDecl *proto) {
492+
Requirement req(RequirementKind::Conformance,
493+
generator.build(*this),
494+
proto->getDeclaredType());
495+
auto source =
496+
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
497+
Builder.addRequirement(req, source, Context.getStdlibModule());
498+
}
499+
492500
ValueDecl *build(Identifier name) {
501+
auto GenericSig = std::move(Builder).computeGenericSignature(SourceLoc());
502+
auto GenericEnv = GenericSig->createGenericEnvironment();
493503
return getBuiltinGenericFunction(name, InterfaceParams,
494504
InterfaceResult,
495505
TheGenericParamList,
@@ -533,22 +543,6 @@ makeConcrete(Type type) {
533543
return { type };
534544
}
535545

536-
// SWIFT_ENABLE_TENSORFLOW
537-
template <class P, class... Gs>
538-
static BuiltinGenericSignatureBuilder::LambdaGenerator
539-
makeBoundGeneric(NominalTypeDecl *decl, const P &parentGenerator,
540-
const Gs & ...genericParamGenerators) {
541-
return {
542-
[=](BuiltinGenericSignatureBuilder &builder) -> Type {
543-
Type parent = parentGenerator.build(builder);
544-
Type genParams[] = {
545-
genericParamGenerators.build(builder)...
546-
};
547-
return BoundGenericType::get(decl, parent, genParams);
548-
}
549-
};
550-
}
551-
552546
static BuiltinGenericSignatureBuilder::ParameterGenerator
553547
makeGenericParam(unsigned index = 0) {
554548
return { index };
@@ -985,8 +979,7 @@ static ValueDecl *getAutoDiffCreateTape(ASTContext &Context, Identifier Id) {
985979
// <T> () -> (Swift._AutoDiffTape<T>)
986980
BuiltinGenericSignatureBuilder builder(Context, 1);
987981
auto *tapeDecl = Context.get_AutoDiffTapeDecl();
988-
builder.setResult(
989-
makeBoundGeneric(tapeDecl, makeConcrete(Type()), makeGenericParam()));
982+
builder.setResult(makeBoundGenericType(tapeDecl, makeGenericParam()));
990983
return builder.build(Id);
991984
}
992985

@@ -995,7 +988,7 @@ static ValueDecl *getAutoDiffPushToTape(ASTContext &Context, Identifier Id) {
995988
BuiltinGenericSignatureBuilder builder(Context, 1);
996989
auto *tapeDecl = Context.get_AutoDiffTapeDecl();
997990
auto T = makeGenericParam();
998-
builder.addParameter(makeBoundGeneric(tapeDecl, makeConcrete(Type()), T));
991+
builder.addParameter(makeBoundGenericType(tapeDecl, T));
999992
builder.addParameter(T);
1000993
builder.addParameter(makeConcrete(BuiltinIntegerType::getWordType(Context)));
1001994
builder.setResult(makeConcrete(Context.TheEmptyTupleType));
@@ -1007,7 +1000,7 @@ static ValueDecl *getAutoDiffPopFromTape(ASTContext &Context, Identifier Id) {
10071000
BuiltinGenericSignatureBuilder builder(Context, 1);
10081001
auto *tapeDecl = Context.get_AutoDiffTapeDecl();
10091002
auto T = makeGenericParam();
1010-
builder.addParameter(makeBoundGeneric(tapeDecl, makeConcrete(Type()), T));
1003+
builder.addParameter(makeBoundGenericType(tapeDecl, T));
10111004
builder.addParameter(makeConcrete(BuiltinIntegerType::getWordType(Context)));
10121005
builder.setResult(T);
10131006
return builder.build(Id);
@@ -1017,12 +1010,84 @@ static ValueDecl *getAutoDiffDestroyTape(ASTContext &Context, Identifier Id) {
10171010
// <T> (Swift._AutoDiffTape<T>) -> ()
10181011
BuiltinGenericSignatureBuilder builder(Context, 1);
10191012
auto *tapeDecl = Context.get_AutoDiffTapeDecl();
1020-
builder.addParameter(
1021-
makeBoundGeneric(tapeDecl, makeConcrete(Type()), makeGenericParam()));
1013+
builder.addParameter(makeBoundGenericType(tapeDecl, makeGenericParam()));
10221014
builder.setResult(makeConcrete(Context.TheEmptyTupleType));
10231015
return builder.build(Id);
10241016
}
10251017

1018+
static ValueDecl *getAutoDiffGetAssociatedFunction(
1019+
ASTContext &Context, Identifier Id, AutoDiffAssociatedFunctionKind kind,
1020+
unsigned order, unsigned arity, bool isThrowing = false) {
1021+
assert(arity >= 1);
1022+
assert(order == 1 && "higher-order differentiation is not supported yet");
1023+
// JVP(non-throwing):
1024+
// <...T...(arity), R> (@autodiff (...T) -> R)
1025+
// -> (...T) -> (R, (...T.TangentVector) -> R.TangentVector)
1026+
// JVP(throwing):
1027+
// <...T...(arity), R> (@autodiff (...T) throws -> R)
1028+
// -> (...T) throws -> (R, (...T.TangentVector) -> R.TangentVector)
1029+
// VJP(non-throwing):
1030+
// <...T...(arity), R> (@autodiff (...T) -> R)
1031+
// -> (...T) -> (R, (R.CotangentVector) -> ...T.CotangentVector)
1032+
// VJP(throwing):
1033+
// <...T...(arity), R> (@autodiff (...T) throws -> R)
1034+
// -> (...T) throws -> (R, (R.CotangentVector) -> ...T.CotangentVector)
1035+
BuiltinGenericSignatureBuilder builder(Context,
1036+
/*numGenericParams*/ 1 + arity);
1037+
// Look up the Differentiable protocol.
1038+
SmallVector<ValueDecl *, 1> diffableProtoLookup;
1039+
Context.lookupInSwiftModule("Differentiable", diffableProtoLookup);
1040+
assert(diffableProtoLookup.size() == 1);
1041+
auto *diffableProto = cast<ProtocolDecl>(diffableProtoLookup.front());
1042+
// Create type parameters and add conformance constraints.
1043+
auto R = makeGenericParam(arity);
1044+
builder.addConformanceRequirement(R, diffableProto);
1045+
SmallVector<decltype(R), 2> Ts;
1046+
for (auto i : range(arity)) {
1047+
auto T = makeGenericParam(i);
1048+
builder.addConformanceRequirement(T, diffableProto);
1049+
Ts.push_back(T);
1050+
}
1051+
// Generator for the argument.
1052+
BuiltinGenericSignatureBuilder::LambdaGenerator argGen {
1053+
// Generator for the function type at the argument position, i.e. the
1054+
// function being differentiated.
1055+
[=, &Ts](BuiltinGenericSignatureBuilder &builder) -> Type {
1056+
FunctionType::ExtInfo ext;
1057+
auto extInfo = FunctionType::ExtInfo()
1058+
.withDifferentiability(FunctionTypeDifferentiability::Bidirectional)
1059+
.withNoEscape();
1060+
if (isThrowing)
1061+
extInfo = extInfo.withThrows();
1062+
SmallVector<FunctionType::Param, 2> params;
1063+
for (auto &paramGen : Ts)
1064+
params.push_back(FunctionType::Param(paramGen.build(builder)));
1065+
return FunctionType::get(params, R.build(builder))->withExtInfo(extInfo);
1066+
}
1067+
};
1068+
AnyFunctionType *origFnTy = argGen.build(builder)->castTo<AnyFunctionType>();
1069+
origFnTy = origFnTy->withExtInfo(origFnTy->getExtInfo()
1070+
.withDifferentiability(FunctionTypeDifferentiability::None));
1071+
auto *paramIndices = AutoDiffParameterIndices::create(Context, origFnTy,
1072+
/*isMethod*/ false,
1073+
/*setAllParams*/ true);
1074+
// Generator for the resultant function type, i.e. the AD associated function.
1075+
BuiltinGenericSignatureBuilder::LambdaGenerator resultGen {
1076+
[=](BuiltinGenericSignatureBuilder &builder) -> Type {
1077+
// TODO(rxwei): Use parameter indices and differentiation order that are
1078+
// stored in the function type.
1079+
auto *vjpType = origFnTy->getAutoDiffAssociatedFunctionType(
1080+
*paramIndices, /*resultIndex*/ 0, /*differentiationOrder*/ 1,
1081+
kind, /*lookupConformance*/ nullptr);
1082+
vjpType = vjpType->withExtInfo(vjpType->getExtInfo().withNoEscape(false));
1083+
return vjpType;
1084+
}
1085+
};
1086+
builder.addParameter(argGen);
1087+
builder.setResult(resultGen);
1088+
return builder.build(Id);
1089+
}
1090+
10261091
static ValueDecl *getPoundAssert(ASTContext &Context, Identifier Id) {
10271092
auto int1Type = BuiltinIntegerType::get(1, Context);
10281093
auto optionalRawPointerType = BoundGenericEnumType::get(
@@ -1958,6 +2023,14 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
19582023
return getAutoDiffPopFromTape(Context, Id);
19592024
case BuiltinValueKind::AutoDiffDestroyTape:
19602025
return getAutoDiffDestroyTape(Context, Id);
2026+
case BuiltinValueKind::AutoDiffGetJVP:
2027+
return getAutoDiffGetAssociatedFunction(Context, Id,
2028+
AutoDiffAssociatedFunctionKind::JVP,
2029+
/*order*/ 1, /*arity*/ 1);
2030+
case BuiltinValueKind::AutoDiffGetVJP:
2031+
return getAutoDiffGetAssociatedFunction(Context, Id,
2032+
AutoDiffAssociatedFunctionKind::VJP,
2033+
/*order*/ 1, /*arity*/ 1);
19612034
case BuiltinValueKind::PoundAssert:
19622035
return getPoundAssert(Context, Id);
19632036

lib/AST/Type.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4188,24 +4188,36 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
41884188
auto *cotangentDependentType = DependentMemberType::get(
41894189
differentiableProtocol->getDeclaredInterfaceType(),
41904190
cast<AssociatedTypeDecl>(cotangentLookup[0]));
4191-
auto getAssociatedType = [&](Type type,
4192-
DependentMemberType *dependentType) -> CanType {
4193-
// Builtins are their own Tangent/Cotangent.
4194-
if (type->is<BuiltinType>()) return type->getCanonicalType();
4195-
4196-
// TODO: If this is a tuple, recursively get the associated types of its
4197-
// components.
4198-
4199-
// Try to get the associated type, and return it if found (if the result is
4200-
// non-null and non-`DependentMemberType`).
4201-
auto assocTy = dependentType->substBaseType(type, lookupConformance);
4202-
if (assocTy && !assocTy->is<DependentMemberType>())
4203-
return assocTy->getCanonicalType();
4204-
4205-
// When the type does not have an associated type, fallback to treating it
4206-
// as its own Tangent/Cotangent.
4207-
// TODO: We should eliminate all instances where this happens.
4208-
return type->getCanonicalType();
4191+
4192+
// `getAssociatedType` takes a base type and a protocol-dependent type
4193+
// and returns a canonical type representing the associated type after any
4194+
// possible substitutions.
4195+
// For base types that are tuples, applies `getAssociatedType` to
4196+
// every element type and returns a tuple type of the new elements.
4197+
// For base types that are builtins, returns the types themselves.
4198+
std::function<CanType(Type, DependentMemberType *)> getAssociatedType
4199+
= [&](Type type, DependentMemberType *dependentType) {
4200+
// Builtin types are their own Tangent/Cotangent.
4201+
if (type->is<BuiltinType>())
4202+
return type->getCanonicalType();
4203+
// Tuples' Tangent/Cotangent is a tuple of each element's Tangent/Cotangent.
4204+
if (auto *tupleTy = type->getAs<TupleType>()) {
4205+
SmallVector<TupleTypeElt, 8> newElts;
4206+
for (auto elt : tupleTy->getElements())
4207+
newElts.push_back(
4208+
elt.getWithType(getAssociatedType(elt.getType(), dependentType)));
4209+
return TupleType::get(newElts, ctx)->getCanonicalType();
4210+
}
4211+
// If `lookupConformance` is provided by the caller, try to get the
4212+
// associated type by substituting the base type for a protocol associated
4213+
// type, and return it if the substitution succeeds (i.e. the result is
4214+
// non-null).
4215+
if (lookupConformance)
4216+
if (auto assocTy = dependentType->substBaseType(type, lookupConformance))
4217+
return assocTy->getCanonicalType();
4218+
// Otherwise, just return the base type's dependent member type.
4219+
return DependentMemberType::get(type, dependentType->getAssocType())
4220+
->getCanonicalType();
42094221
};
42104222

42114223
SmallVector<Type, 8> wrtParamTypes;

lib/SILGen/SILGenBuiltin.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,32 @@ static ManagedValue emitBuiltinTypeTrait(SILGenFunction &SGF,
10411041
return ManagedValue::forUnmanaged(val);
10421042
}
10431043

1044+
// SWIFT_ENABLE_TENSORFLOW
1045+
/// Specialized emitter for Builtin.autodiffGetJVP.
1046+
static ManagedValue emitBuiltinAutoDiffGetJVP(SILGenFunction &SGF,
1047+
SILLocation loc,
1048+
SubstitutionMap substitutions,
1049+
Expr *argument,
1050+
SGFContext C) {
1051+
auto argVal = SGF.emitRValue(argument);
1052+
auto jvp = SGF.getBuilder().createAutoDiffFunctionExtract(
1053+
loc, AutoDiffFunctionExtractee::JVP, /*differentiationOrder*/ 1,
1054+
std::move(argVal).forwardAsSingleValue(SGF, loc));
1055+
return SGF.emitManagedRValueWithCleanup(jvp);
1056+
}
1057+
1058+
static ManagedValue emitBuiltinAutoDiffGetVJP(SILGenFunction &SGF,
1059+
SILLocation loc,
1060+
SubstitutionMap substitutions,
1061+
Expr *argument,
1062+
SGFContext C) {
1063+
auto argVal = SGF.emitRValue(argument);
1064+
auto vjp = SGF.getBuilder().createAutoDiffFunctionExtract(
1065+
loc, AutoDiffFunctionExtractee::VJP, /*differentiationOrder*/ 1,
1066+
std::move(argVal).forwardAsSingleValue(SGF, loc));
1067+
return SGF.emitManagedRValueWithCleanup(vjp);
1068+
}
1069+
10441070
Optional<SpecializedEmitter>
10451071
SpecializedEmitter::forDecl(SILGenModule &SGM, SILDeclRef function) {
10461072
// Only consider standalone declarations in the Builtin module.

lib/SILGen/SILGenExpr.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5733,9 +5733,6 @@ RValue RValueEmitter::visitAutoDiffFunctionExpr(AutoDiffFunctionExpr *E,
57335733
RValue RValueEmitter::visitAutoDiffFunctionExtractOriginalExpr(
57345734
AutoDiffFunctionExtractOriginalExpr *E, SGFContext C) {
57355735
auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
5736-
llvm::outs() << "Difffunc subexpr type = " << E->getSubExpr()->getType() << '\n';
5737-
llvm::outs() << "Difffunc = " << diffFunc.getValue() << '\n';
5738-
llvm::outs().flush();
57395736
auto *orig = SGF.B.createAutoDiffFunctionExtractOriginal(
57405737
E, diffFunc.forward(SGF));
57415738
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(orig));

0 commit comments

Comments
 (0)