Skip to content

[AutoDiff] Add builtin differentiable/linear function constructors. #28467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,12 +352,19 @@ IndexSubset *getLoweredParameterIndices(IndexSubset *indices,
AnyFunctionType *type);

/// Retrieve config from the function name of a variant of
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2`.
/// Returns true if the function name is parsed successfully.
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
AutoDiffDerivativeFunctionKind &kind,
unsigned &arity, bool &rethrows);

/// Retrieve config from the function name of a variant of
/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g.
/// `Builtin.differentiableFunction_arity1_throws`.
///
/// `operationName` is matched starting at `differentiableFunction` or
/// `linearFunction`, optionally followed by `_arity<N>` and then `_throws`.
///
/// On a successful parse, `arity` receives the number following `_arity` (or 1
/// when that component is absent) and `throws` receives whether a `_throws`
/// component is present.
///
/// Returns true if the function name is parsed successfully.
bool getBuiltinDifferentiableOrLinearFunctionConfig(
StringRef operationName, unsigned &arity, bool &throws);

/// Computes the correct linkage for a derivative function given the linkage of
/// the original function. If the original linkage is not external and
/// `isDerivativeFnExported` is true, use the original function's linkage.
Expand Down
6 changes: 6 additions & 0 deletions include/swift/AST/Builtins.def
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,12 @@ BUILTIN_SIL_OPERATION(ProjectTailElems, "projectTailElems", Special)
/// autodiffApply
BUILTIN_SIL_OPERATION(AutoDiffApply, "autodiffApply", Special)

/// differentiableFunction
BUILTIN_SIL_OPERATION(DifferentiableFunction, "differentiableFunction", Special)

/// linearFunction
BUILTIN_SIL_OPERATION(LinearFunction, "linearFunction", Special)

#undef BUILTIN_SIL_OPERATION

// BUILTIN_RUNTIME_CALL - A call into a runtime function.
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ IDENTIFIER(withKeywordArguments)
IDENTIFIER(wrapped)
IDENTIFIER(wrappedValue)
IDENTIFIER(wrapperValue)
// SWIFT_ENABLE_TENSORFLOW
IDENTIFIER(differential)
IDENTIFIER(pullback)

// SWIFT_ENABLE_TENSORFLOW
IDENTIFIER(TensorFlow)
Expand Down
56 changes: 47 additions & 9 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,22 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
bool autodiff::getBuiltinAutoDiffApplyConfig(
StringRef operationName, AutoDiffDerivativeFunctionKind &kind,
unsigned &arity, bool &rethrows) {
if (!operationName.startswith("autodiffApply_"))
constexpr char prefix[] = "autodiffApply";
if (!operationName.startswith(prefix))
return false;
operationName = operationName.drop_front(strlen("autodiffApply_"));
operationName = operationName.drop_front(sizeof(prefix) - 1);
// Parse 'jvp' or 'vjp'.
if (operationName.startswith("jvp"))
constexpr char jvpPrefix[] = "_jvp";
constexpr char vjpPrefix[] = "_vjp";
if (operationName.startswith(jvpPrefix))
kind = AutoDiffDerivativeFunctionKind::JVP;
else if (operationName.startswith("vjp"))
else if (operationName.startswith(vjpPrefix))
kind = AutoDiffDerivativeFunctionKind::VJP;
operationName = operationName.drop_front(3);
operationName = operationName.drop_front(sizeof(jvpPrefix) - 1);
// Parse '_arity'.
if (operationName.startswith("_arity")) {
operationName = operationName.drop_front(strlen("_arity"));
constexpr char arityPrefix[] = "_arity";
if (operationName.startswith(arityPrefix)) {
operationName = operationName.drop_front(sizeof(arityPrefix) - 1);
auto arityStr = operationName.take_while(llvm::isDigit);
operationName = operationName.drop_front(arityStr.size());
auto converted = llvm::to_integer(arityStr, arity);
Expand All @@ -230,15 +234,49 @@ bool autodiff::getBuiltinAutoDiffApplyConfig(
arity = 1;
}
// Parse '_rethrows'.
if (operationName.startswith("_rethrows")) {
operationName = operationName.drop_front(strlen("_rethrows"));
constexpr char rethrowsPrefix[] = "_rethrows";
if (operationName.startswith(rethrowsPrefix)) {
operationName = operationName.drop_front(sizeof(rethrowsPrefix) - 1);
rethrows = true;
} else {
rethrows = false;
}
return operationName.empty();
}

/// Parse the name of a `Builtin.differentiableFunction` or
/// `Builtin.linearFunction` variant, e.g. `differentiableFunction_arity2_throws`.
///
/// Accepted shape: `differentiableFunction` or `linearFunction`, optionally
/// followed by `_arity<N>` (N > 0), optionally followed by `_throws`, with
/// nothing left over.
///
/// \param operationName the builtin name, starting at the
///        `differentiableFunction` / `linearFunction` prefix.
/// \param arity set to the parsed arity, or 1 when no `_arity` component is
///        present.
/// \param throws set to true iff a `_throws` component is present.
/// \returns true if the whole name was parsed successfully. Returns false on an
///          unknown prefix, a missing or zero arity number, or trailing
///          characters; the output parameters should not be relied upon then.
bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
    StringRef operationName, unsigned &arity, bool &throws) {
  constexpr char differentiablePrefix[] = "differentiableFunction";
  constexpr char linearPrefix[] = "linearFunction";
  // `sizeof(literal) - 1` is the prefix length without the NUL terminator.
  if (operationName.startswith(differentiablePrefix))
    operationName = operationName.drop_front(sizeof(differentiablePrefix) - 1);
  else if (operationName.startswith(linearPrefix))
    operationName = operationName.drop_front(sizeof(linearPrefix) - 1);
  else
    return false;
  // Parse '_arity'.
  constexpr char arityPrefix[] = "_arity";
  if (operationName.startswith(arityPrefix)) {
    operationName = operationName.drop_front(sizeof(arityPrefix) - 1);
    auto arityStr = operationName.take_while(llvm::isDigit);
    operationName = operationName.drop_front(arityStr.size());
    // Report malformed names (`_arity` without digits, or `_arity0`) as a parse
    // failure instead of only asserting: with assertions compiled out, the
    // caller would otherwise see an uninitialized or zero arity and a `true`
    // result.
    if (!llvm::to_integer(arityStr, arity) || arity == 0)
      return false;
  } else {
    arity = 1;
  }
  // Parse '_throws'.
  constexpr char throwsPrefix[] = "_throws";
  if (operationName.startswith(throwsPrefix)) {
    operationName = operationName.drop_front(sizeof(throwsPrefix) - 1);
    throws = true;
  } else {
    throws = false;
  }
  // Anything left over means the name is not a recognized variant.
  return operationName.empty();
}

SILLinkage autodiff::getAutoDiffDerivativeFunctionLinkage(
SILLinkage originalLinkage, bool isDerivativeFnExported) {
// If the original is defined externally, then the AD pass is just generating
Expand Down
181 changes: 181 additions & 0 deletions lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,169 @@ static ValueDecl *getAutoDiffApplyDerivativeFunction(
return builder.build(Id);
}

/// Build the declaration for a `Builtin.differentiableFunction` variant.
///
/// With argument types `T...` (`arity` of them) and result type `R`, all
/// constrained to `Differentiable`, the builtin takes three owned parameters:
///   - the original function: `(T...) -> R`
///   - the JVP: `(T...) -> (value: R,
///                          differential: (T.TangentVector...) -> R.TangentVector)`
///   - the VJP: `(T...) -> (value: R,
///                          pullback: (R.TangentVector) -> (T.TangentVector...))`
/// and returns the original function type with `DifferentiabilityKind::Normal`.
/// The `throws` flag is forwarded into the ExtInfo of the original, JVP and VJP
/// function types.
static ValueDecl *getDifferentiableFunctionConstructor(
ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
assert(arity >= 1);
// One generic parameter per original argument, plus one for the result.
unsigned numGenericParams = 1 + arity;
BuiltinFunctionBuilder builder(Context, numGenericParams);
// Get the `Differentiable` protocol and its `TangentVector` associated type.
auto *diffableProto =
Context.getProtocol(KnownProtocolKind::Differentiable);
auto *tangentVectorDecl =
diffableProto->getAssociatedType(Context.Id_TangentVector);
assert(tangentVectorDecl);
// Create type parameters and add conformance constraints.
// The result type uses generic parameter index `arity`; the arguments use
// indices `0 ..< arity`.
auto origResultGen = makeGenericParam(arity);
builder.addConformanceRequirement(origResultGen, diffableProto);
SmallVector<decltype(origResultGen), 2> fnArgGens;
for (auto i : range(arity)) {
auto T = makeGenericParam(i);
builder.addConformanceRequirement(T, diffableProto);
fnArgGens.push_back(T);
}

// Generator for the original function type: `(T...) -> R`.
// NOTE(review): `fnArgGens` is captured by reference here and below, so the
// generators must be run before this function returns — presumably inside
// `builder.build(Id)`; confirm against `BuiltinFunctionBuilder`.
BuiltinFunctionBuilder::LambdaGenerator origFnGen {
[=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
SmallVector<FunctionType::Param, 2> params;
for (auto &paramGen : fnArgGens)
params.push_back(FunctionType::Param(paramGen.build(builder)));
return FunctionType::get(params, origResultGen.build(builder))
->withExtInfo(
FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws));
}
};

// Generator for the JVP type:
// `(T...) -> (value: R, differential: (T.TangentVector...) -> R.TangentVector)`.
BuiltinFunctionBuilder::LambdaGenerator jvpGen {
[=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
SmallVector<FunctionType::Param, 2> params;
for (auto &paramGen : fnArgGens)
params.push_back(FunctionType::Param(paramGen.build(builder)));
auto origResultType = origResultGen.build(builder);
// The differential takes one `TangentVector` per original parameter.
SmallVector<FunctionType::Param, 2> differentialParams;
for (auto &param : params) {
auto tanType = DependentMemberType::get(
param.getPlainType(), tangentVectorDecl);
differentialParams.push_back(FunctionType::Param(tanType));
}
auto differentialResultType = DependentMemberType::get(
origResultType, tangentVectorDecl);
auto differentialType =
FunctionType::get({differentialParams}, differentialResultType);
auto jvpResultType = TupleType::get(
{TupleTypeElt(origResultType, Context.Id_value),
TupleTypeElt(differentialType, Context.Id_differential)}, Context);
return FunctionType::get(params, jvpResultType)
->withExtInfo(
FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws));
}
};

// Generator for the VJP type:
// `(T...) -> (value: R, pullback: (R.TangentVector) -> (T.TangentVector...))`.
BuiltinFunctionBuilder::LambdaGenerator vjpGen {
[=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
SmallVector<FunctionType::Param, 2> params;
for (auto &paramGen : fnArgGens)
params.push_back(FunctionType::Param(paramGen.build(builder)));
auto origResultType = origResultGen.build(builder);
// The pullback returns one `TangentVector` per original parameter.
SmallVector<TupleTypeElt, 2> pullbackResultTupleElts;
for (auto &param : params) {
auto tanType = DependentMemberType::get(
param.getPlainType(), tangentVectorDecl);
pullbackResultTupleElts.push_back(TupleTypeElt(tanType));
}
auto pullbackParam = FunctionType::Param(
DependentMemberType::get(origResultType, tangentVectorDecl));
// With exactly one parameter the pullback result is that tangent type
// itself; otherwise it is a tuple of the tangent types.
auto pullbackType = FunctionType::get(
{pullbackParam},
pullbackResultTupleElts.size() == 1
? pullbackResultTupleElts.front().getType()
: TupleType::get(pullbackResultTupleElts, Context));
auto vjpResultType = TupleType::get(
{TupleTypeElt(origResultType, Context.Id_value),
TupleTypeElt(pullbackType, Context.Id_pullback)}, Context);
return FunctionType::get(params, vjpResultType)
->withExtInfo(
FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws));
}
};

// Generator for the result type: the original function type with its
// differentiability kind set to `Normal`.
BuiltinFunctionBuilder::LambdaGenerator resultGen {
[&](BuiltinFunctionBuilder &builder) -> Type {
auto origFnType = origFnGen.build(builder)->castTo<FunctionType>();
return origFnType->withExtInfo(
origFnType->getExtInfo()
.withDifferentiabilityKind(DifferentiabilityKind::Normal));
}
};

// Parameter order: original function, JVP, VJP — all taken as owned.
builder.addParameter(origFnGen, ValueOwnership::Owned);
builder.addParameter(jvpGen, ValueOwnership::Owned);
builder.addParameter(vjpGen, ValueOwnership::Owned);
builder.setResult(resultGen);
return builder.build(Id);
}

/// Build the declaration for a `Builtin.linearFunction` variant.
///
/// With argument types `T...` (`arity` of them) and result type `R`, all
/// constrained to both `Differentiable` and `AdditiveArithmetic`, the builtin
/// takes two owned parameters:
///   - the original function: `(T...) -> R`
///   - the transpose function: `(R) -> (T...)`
/// and returns the original function type with `DifferentiabilityKind::Linear`.
/// The `throws` flag is forwarded into the ExtInfo of the original function
/// type only; the transpose function type is built without it.
static ValueDecl *getLinearFunctionConstructor(
ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
assert(arity >= 1);
// One generic parameter per original argument, plus one for the result.
unsigned numGenericParams = 1 + arity;
BuiltinFunctionBuilder builder(Context, numGenericParams);
// Get the `Differentiable` and `AdditiveArithmetic` protocols.
auto *diffableProto =
Context.getProtocol(KnownProtocolKind::Differentiable);
auto *addArithProto =
Context.getProtocol(KnownProtocolKind::AdditiveArithmetic);
// Create type parameters and add conformance constraints.
// The result type uses generic parameter index `arity`; the arguments use
// indices `0 ..< arity`.
auto origResultGen = makeGenericParam(arity);
builder.addConformanceRequirement(origResultGen, diffableProto);
builder.addConformanceRequirement(origResultGen, addArithProto);
SmallVector<decltype(origResultGen), 2> fnArgGens;
for (auto i : range(arity)) {
auto T = makeGenericParam(i);
builder.addConformanceRequirement(T, diffableProto);
builder.addConformanceRequirement(T, addArithProto);
fnArgGens.push_back(T);
}

// Generator for the original function type: `(T...) -> R`.
// NOTE(review): `fnArgGens` is captured by reference here and below, so the
// generators must be run before this function returns — presumably inside
// `builder.build(Id)`; confirm against `BuiltinFunctionBuilder`.
BuiltinFunctionBuilder::LambdaGenerator origFnGen {
[=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
SmallVector<FunctionType::Param, 2> params;
for (auto &paramGen : fnArgGens)
params.push_back(FunctionType::Param(paramGen.build(builder)));
return FunctionType::get(params, origResultGen.build(builder))
->withExtInfo(
FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws));
}
};

// Generator for the transpose function type: `(R) -> (T...)`.
BuiltinFunctionBuilder::LambdaGenerator transposeFnGen {
[=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
auto origResultType = origResultGen.build(builder);
SmallVector<TupleTypeElt, 2> resultTupleElts;
for (auto &paramGen : fnArgGens)
resultTupleElts.push_back(paramGen.build(builder));
// With exactly one parameter the transpose result is that parameter type
// itself; otherwise it is a tuple of the parameter types.
return FunctionType::get(
{FunctionType::Param(origResultType)},
resultTupleElts.size() == 1
? resultTupleElts.front().getType()
: TupleType::get(resultTupleElts, Context));
}
};

// Generator for the result type: the original function type with its
// differentiability kind set to `Linear`.
BuiltinFunctionBuilder::LambdaGenerator resultGen {
[&](BuiltinFunctionBuilder &builder) -> Type {
auto origFnType = origFnGen.build(builder)->castTo<FunctionType>();
return origFnType->withExtInfo(
origFnType->getExtInfo()
.withDifferentiabilityKind(DifferentiabilityKind::Linear));
}
};

// Parameter order: original function, transpose function — both owned.
builder.addParameter(origFnGen, ValueOwnership::Owned);
builder.addParameter(transposeFnGen, ValueOwnership::Owned);
builder.setResult(resultGen);
return builder.build(Id);
}

static ValueDecl *getGlobalStringTablePointer(ASTContext &Context,
Identifier Id) {
// String -> Builtin.RawPointer
Expand Down Expand Up @@ -1839,6 +2002,22 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
return getAutoDiffApplyDerivativeFunction(Context, Id, kind, arity,
rethrows);
}
if (OperationName.startswith("differentiableFunction_")) {
unsigned arity;
bool throws;
if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
OperationName, arity, throws))
return nullptr;
return getDifferentiableFunctionConstructor(Context, Id, arity, throws);
}
if (OperationName.startswith("linearFunction_")) {
unsigned arity;
bool throws;
if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
OperationName, arity, throws))
return nullptr;
return getLinearFunctionConstructor(Context, Id, arity, throws);
}
auto BV = llvm::StringSwitch<BuiltinValueKind>(OperationName)
#define BUILTIN(id, name, Attrs) .Case(name, BuiltinValueKind::id)
#include "swift/AST/Builtins.def"
Expand Down Expand Up @@ -2110,6 +2289,8 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {

// SWIFT_ENABLE_TENSORFLOW
case BuiltinValueKind::AutoDiffApply:
case BuiltinValueKind::DifferentiableFunction:
case BuiltinValueKind::LinearFunction:
llvm_unreachable("Handled above");

case BuiltinValueKind::OnFastPath:
Expand Down
4 changes: 4 additions & 0 deletions lib/SIL/SILModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ const BuiltinInfo &SILModule::getBuiltinInfo(Identifier ID) {
// SWIFT_ENABLE_TENSORFLOW
else if (OperationName.startswith("autodiffApply_"))
Info.ID = BuiltinValueKind::AutoDiffApply;
else if (OperationName.startswith("differentiableFunction_"))
Info.ID = BuiltinValueKind::DifferentiableFunction;
else if (OperationName.startswith("linearFunction_"))
Info.ID = BuiltinValueKind::LinearFunction;
else
Info.ID = llvm::StringSwitch<BuiltinValueKind>(OperationName)
#define BUILTIN(id, name, attrs) .Case(name, BuiltinValueKind::id)
Expand Down
31 changes: 31 additions & 0 deletions lib/SILGen/SILGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,37 @@ static ManagedValue emitBuiltinAutoDiffApply(SILGenFunction &SGF,
substitutions, args, C);
}

/// Emit SIL for a `Builtin.differentiableFunction` variant.
///
/// Expects exactly three arguments — the original function, its JVP and its
/// VJP — and bundles them with `createDifferentiableFunction`, using an index
/// subset that includes every parameter of the original function. The bundled
/// value is returned as a managed rvalue with a cleanup.
static ManagedValue emitBuiltinDifferentiableFunction(
    SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions,
    ArrayRef<ManagedValue> args, SGFContext C) {
  assert(args.size() == 3);
  auto original = args.front();
  auto originalFnTy = original.getType().castTo<SILFunctionType>();
  // All parameters of the original function are included in the subset.
  auto *allParameterIndices = IndexSubset::getDefault(
      SGF.getASTContext(), originalFnTy->getNumParameters(),
      /*includeAll*/ true);
  // Forward each argument out of its ManagedValue before bundling.
  auto originalValue = original.forward(SGF);
  auto jvpValue = args[1].forward(SGF);
  auto vjpValue = args[2].forward(SGF);
  auto bundle = SGF.B.createDifferentiableFunction(
      loc, allParameterIndices, originalValue,
      std::make_pair(jvpValue, vjpValue));
  return SGF.emitManagedRValueWithCleanup(bundle);
}

/// Emit SIL for a `Builtin.linearFunction` variant.
///
/// Expects exactly two arguments — the original function and its transpose —
/// and bundles them with `createLinearFunction`, using an index subset that
/// includes every parameter of the original function. The bundled value is
/// returned as a managed rvalue with a cleanup.
static ManagedValue emitBuiltinLinearFunction(
    SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions,
    ArrayRef<ManagedValue> args, SGFContext C) {
  assert(args.size() == 2);
  auto original = args.front();
  auto originalFnTy = original.getType().castTo<SILFunctionType>();
  // All parameters of the original function are included in the subset.
  auto *allParameterIndices = IndexSubset::getDefault(
      SGF.getASTContext(), originalFnTy->getNumParameters(),
      /*includeAll*/ true);
  // Forward each argument out of its ManagedValue before bundling.
  auto originalValue = original.forward(SGF);
  auto transposeValue = args[1].forward(SGF);
  auto bundle = SGF.B.createLinearFunction(
      loc, allParameterIndices, originalValue, transposeValue);
  return SGF.emitManagedRValueWithCleanup(bundle);
}

/// Emit SIL for the named builtin: globalStringTablePointer. Unlike the default
/// ownership convention for named builtins, which is to take (non-trivial)
/// arguments as Owned, this builtin accepts owned as well as guaranteed
Expand Down
12 changes: 8 additions & 4 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8961,11 +8961,15 @@ void Differentiation::run() {
for (SILInstruction &i : bb) {
if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&i))
context.getDifferentiableFunctionInsts().push_back(dfi);
// Reject uncanonical `linear_function` instructions.
// FIXME(SR-11850): Add support for linear map transposition.
else if (auto *lfi = dyn_cast<LinearFunctionInst>(&i)) {
astCtx.Diags.diagnose(
lfi->getLoc().getSourceLoc(),
diag::autodiff_conversion_to_linear_function_not_supported);
errorOccurred = true;
if (!lfi->hasTransposeFunction()) {
astCtx.Diags.diagnose(
lfi->getLoc().getSourceLoc(),
diag::autodiff_conversion_to_linear_function_not_supported);
errorOccurred = true;
}
}
}
}
Expand Down
Loading