-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Add Builtin.autodiffGet(JVP|VJP)
for extracting AD associated functions
#21144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -456,24 +456,21 @@ namespace { | |
private: | ||
GenericParamList *TheGenericParamList; | ||
SmallVector<GenericTypeParamDecl*, 2> GenericTypeParams; | ||
GenericEnvironment *GenericEnv = nullptr; | ||
// SWIFT_ENABLE_TENSORFLOW | ||
GenericSignatureBuilder Builder; | ||
SmallVector<AnyFunctionType::Param, 4> InterfaceParams; | ||
Type InterfaceResult; | ||
|
||
public: | ||
BuiltinGenericSignatureBuilder(ASTContext &ctx, unsigned numGenericParams = 1) | ||
: Context(ctx) { | ||
// SWIFT_ENABLE_TENSORFLOW | ||
: Context(ctx), Builder(ctx) { | ||
TheGenericParamList = getGenericParams(ctx, numGenericParams, | ||
GenericTypeParams); | ||
|
||
GenericSignatureBuilder Builder(ctx); | ||
for (auto gp : GenericTypeParams) { | ||
Builder.addGenericParameter(gp); | ||
} | ||
|
||
auto GenericSig = | ||
std::move(Builder).computeGenericSignature(SourceLoc()); | ||
GenericEnv = GenericSig->createGenericEnvironment(); | ||
} | ||
|
||
template <class G> | ||
|
@@ -489,7 +486,20 @@ namespace { | |
InterfaceResult = generator.build(*this); | ||
} | ||
|
||
// SWIFT_ENABLE_TENSORFLOW | ||
template <class G> | ||
void addConformanceRequirement(const G &generator, ProtocolDecl *proto) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @DougGregor, is this the right thing to do? If so, shall I upstream this? |
||
Requirement req(RequirementKind::Conformance, | ||
generator.build(*this), | ||
proto->getDeclaredType()); | ||
auto source = | ||
GenericSignatureBuilder::FloatingRequirementSource::forAbstract(); | ||
Builder.addRequirement(req, source, Context.getStdlibModule()); | ||
} | ||
|
||
ValueDecl *build(Identifier name) { | ||
auto GenericSig = std::move(Builder).computeGenericSignature(SourceLoc()); | ||
auto GenericEnv = GenericSig->createGenericEnvironment(); | ||
return getBuiltinGenericFunction(name, InterfaceParams, | ||
InterfaceResult, | ||
TheGenericParamList, | ||
|
@@ -533,22 +543,6 @@ makeConcrete(Type type) { | |
return { type }; | ||
} | ||
|
||
// SWIFT_ENABLE_TENSORFLOW | ||
template <class P, class... Gs> | ||
static BuiltinGenericSignatureBuilder::LambdaGenerator | ||
makeBoundGeneric(NominalTypeDecl *decl, const P &parentGenerator, | ||
const Gs & ...genericParamGenerators) { | ||
return { | ||
[=](BuiltinGenericSignatureBuilder &builder) -> Type { | ||
Type parent = parentGenerator.build(builder); | ||
Type genParams[] = { | ||
genericParamGenerators.build(builder)... | ||
}; | ||
return BoundGenericType::get(decl, parent, genParams); | ||
} | ||
}; | ||
} | ||
|
||
static BuiltinGenericSignatureBuilder::ParameterGenerator | ||
makeGenericParam(unsigned index = 0) { | ||
return { index }; | ||
|
@@ -985,8 +979,7 @@ static ValueDecl *getAutoDiffCreateTape(ASTContext &Context, Identifier Id) { | |
// <T> () -> (Swift._AutoDiffTape<T>) | ||
BuiltinGenericSignatureBuilder builder(Context, 1); | ||
auto *tapeDecl = Context.get_AutoDiffTapeDecl(); | ||
builder.setResult( | ||
makeBoundGeneric(tapeDecl, makeConcrete(Type()), makeGenericParam())); | ||
builder.setResult(makeBoundGenericType(tapeDecl, makeGenericParam())); | ||
return builder.build(Id); | ||
} | ||
|
||
|
@@ -995,7 +988,7 @@ static ValueDecl *getAutoDiffPushToTape(ASTContext &Context, Identifier Id) { | |
BuiltinGenericSignatureBuilder builder(Context, 1); | ||
auto *tapeDecl = Context.get_AutoDiffTapeDecl(); | ||
auto T = makeGenericParam(); | ||
builder.addParameter(makeBoundGeneric(tapeDecl, makeConcrete(Type()), T)); | ||
builder.addParameter(makeBoundGenericType(tapeDecl, T)); | ||
builder.addParameter(T); | ||
builder.addParameter(makeConcrete(BuiltinIntegerType::getWordType(Context))); | ||
builder.setResult(makeConcrete(Context.TheEmptyTupleType)); | ||
|
@@ -1007,7 +1000,7 @@ static ValueDecl *getAutoDiffPopFromTape(ASTContext &Context, Identifier Id) { | |
BuiltinGenericSignatureBuilder builder(Context, 1); | ||
auto *tapeDecl = Context.get_AutoDiffTapeDecl(); | ||
auto T = makeGenericParam(); | ||
builder.addParameter(makeBoundGeneric(tapeDecl, makeConcrete(Type()), T)); | ||
builder.addParameter(makeBoundGenericType(tapeDecl, T)); | ||
builder.addParameter(makeConcrete(BuiltinIntegerType::getWordType(Context))); | ||
builder.setResult(T); | ||
return builder.build(Id); | ||
|
@@ -1017,12 +1010,84 @@ static ValueDecl *getAutoDiffDestroyTape(ASTContext &Context, Identifier Id) { | |
// <T> (Swift._AutoDiffTape<T>) -> () | ||
BuiltinGenericSignatureBuilder builder(Context, 1); | ||
auto *tapeDecl = Context.get_AutoDiffTapeDecl(); | ||
builder.addParameter( | ||
makeBoundGeneric(tapeDecl, makeConcrete(Type()), makeGenericParam())); | ||
builder.addParameter(makeBoundGenericType(tapeDecl, makeGenericParam())); | ||
builder.setResult(makeConcrete(Context.TheEmptyTupleType)); | ||
return builder.build(Id); | ||
} | ||
|
||
static ValueDecl *getAutoDiffGetAssociatedFunction( | ||
ASTContext &Context, Identifier Id, AutoDiffAssociatedFunctionKind kind, | ||
unsigned order, unsigned arity, bool isThrowing = false) { | ||
assert(arity >= 1); | ||
assert(order == 1 && "higher-order differentiation is not supported yet"); | ||
// JVP(non-throwing): | ||
// <...T...(arity), R> (@autodiff (...T) -> R) | ||
// -> (...T) -> (R, (...T.TangentVector) -> R.TangentVector) | ||
// JVP(throwing): | ||
// <...T...(arity), R> (@autodiff (...T) throws -> R) | ||
// -> (...T) throws -> (R, (...T.TangentVector) -> R.TangentVector) | ||
// VJP(non-throwing): | ||
// <...T...(arity), R> (@autodiff (...T) -> R) | ||
// -> (...T) -> (R, (R.CotangentVector) -> ...T.CotangentVector) | ||
// VJP(throwing): | ||
// <...T...(arity), R> (@autodiff (...T) throws -> R) | ||
// -> (...T) throws -> (R, (R.CotangentVector) -> ...T.CotangentVector) | ||
BuiltinGenericSignatureBuilder builder(Context, | ||
/*numGenericParams*/ 1 + arity); | ||
// Look up the Differentiable protocol. | ||
SmallVector<ValueDecl *, 1> diffableProtoLookup; | ||
Context.lookupInSwiftModule("Differentiable", diffableProtoLookup); | ||
assert(diffableProtoLookup.size() == 1); | ||
auto *diffableProto = cast<ProtocolDecl>(diffableProtoLookup.front()); | ||
// Create type parameters and add conformance constraints. | ||
auto R = makeGenericParam(arity); | ||
builder.addConformanceRequirement(R, diffableProto); | ||
SmallVector<decltype(R), 2> Ts; | ||
for (auto i : range(arity)) { | ||
auto T = makeGenericParam(i); | ||
builder.addConformanceRequirement(T, diffableProto); | ||
Ts.push_back(T); | ||
} | ||
// Generator for the argument. | ||
BuiltinGenericSignatureBuilder::LambdaGenerator argGen { | ||
// Generator for the function type at the argument position, i.e. the | ||
// function being differentiated. | ||
[=, &Ts](BuiltinGenericSignatureBuilder &builder) -> Type { | ||
FunctionType::ExtInfo ext; | ||
auto extInfo = FunctionType::ExtInfo() | ||
.withDifferentiability(FunctionTypeDifferentiability::Bidirectional) | ||
.withNoEscape(); | ||
if (isThrowing) | ||
extInfo = extInfo.withThrows(); | ||
SmallVector<FunctionType::Param, 2> params; | ||
for (auto ¶mGen : Ts) | ||
params.push_back(FunctionType::Param(paramGen.build(builder))); | ||
return FunctionType::get(params, R.build(builder))->withExtInfo(extInfo); | ||
} | ||
}; | ||
AnyFunctionType *origFnTy = argGen.build(builder)->castTo<AnyFunctionType>(); | ||
origFnTy = origFnTy->withExtInfo(origFnTy->getExtInfo() | ||
.withDifferentiability(FunctionTypeDifferentiability::None)); | ||
auto *paramIndices = AutoDiffParameterIndices::create(Context, origFnTy, | ||
/*isMethod*/ false, | ||
/*setAllParams*/ true); | ||
// Generator for the resultant function type, i.e. the AD associated function. | ||
BuiltinGenericSignatureBuilder::LambdaGenerator resultGen { | ||
[=](BuiltinGenericSignatureBuilder &builder) -> Type { | ||
// TODO(rxwei): Use parameter indices and differentiation order that are | ||
// stored in the function type. | ||
auto *vjpType = origFnTy->getAutoDiffAssociatedFunctionType( | ||
*paramIndices, /*resultIndex*/ 0, /*differentiationOrder*/ 1, | ||
kind, /*lookupConformance*/ nullptr); | ||
vjpType = vjpType->withExtInfo(vjpType->getExtInfo().withNoEscape(false)); | ||
return vjpType; | ||
} | ||
}; | ||
builder.addParameter(argGen); | ||
builder.setResult(resultGen); | ||
return builder.build(Id); | ||
} | ||
|
||
static ValueDecl *getPoundAssert(ASTContext &Context, Identifier Id) { | ||
auto int1Type = BuiltinIntegerType::get(1, Context); | ||
auto optionalRawPointerType = BoundGenericEnumType::get( | ||
|
@@ -1958,6 +2023,14 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) { | |
return getAutoDiffPopFromTape(Context, Id); | ||
case BuiltinValueKind::AutoDiffDestroyTape: | ||
return getAutoDiffDestroyTape(Context, Id); | ||
case BuiltinValueKind::AutoDiffGetJVP: | ||
return getAutoDiffGetAssociatedFunction(Context, Id, | ||
AutoDiffAssociatedFunctionKind::JVP, | ||
/*order*/ 1, /*arity*/ 1); | ||
case BuiltinValueKind::AutoDiffGetVJP: | ||
return getAutoDiffGetAssociatedFunction(Context, Id, | ||
AutoDiffAssociatedFunctionKind::VJP, | ||
/*order*/ 1, /*arity*/ 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As you said, there's a mismatch between Could you please expand on how you plan to handle this mismatch? Is the only solution to create many versions of the builtins like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
case BuiltinValueKind::PoundAssert: | ||
return getPoundAssert(Context, Id); | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.