Skip to content

[AutoDiff] Add Builtin.autodiffGet(JVP|VJP) for extracting AD associated functions #21144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -953,9 +953,6 @@ class ASTContext final {
}

// SWIFT_ENABLE_TENSORFLOW
/// Determine whether the given type is differentiable.
bool isDifferentiable(CanType type, ModuleDecl *module);

/// Compute the tangent space of this manifold, if the given type represents a
/// differentiable manifold.
Optional<TangentSpace> getTangentSpace(CanType type, ModuleDecl *module);
Expand Down
8 changes: 5 additions & 3 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ class AutoDiffParameterIndices {

unsigned getNumNonSelfParameters() const;

AutoDiffParameterIndices(unsigned numIndices, bool isMethodFlag)
: indices(numIndices), isMethodFlag(isMethodFlag) {}
AutoDiffParameterIndices(unsigned numIndices, bool isMethodFlag,
bool setAllParams = false)
: indices(numIndices, setAllParams), isMethodFlag(isMethodFlag) {}

AutoDiffParameterIndices(llvm::SmallBitVector indices, bool isMethodFlag)
: indices(indices), isMethodFlag(isMethodFlag) {}
Expand All @@ -122,7 +123,8 @@ class AutoDiffParameterIndices {
/// given `functionType`. `isMethod` specifies whether to treat the function
/// as a method.
static AutoDiffParameterIndices *
create(ASTContext &C, AnyFunctionType *functionType, bool isMethod);
create(ASTContext &C, AnyFunctionType *functionType, bool isMethod,
bool setAllParams = false);

bool isMethod() const { return isMethodFlag; }

Expand Down
8 changes: 8 additions & 0 deletions include/swift/AST/Builtins.def
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,14 @@ BUILTIN_SIL_OPERATION(AllocWithTailElems, "allocWithTailElems", Special)
/// Projects the first tail-allocated element of type E from a class C.
BUILTIN_SIL_OPERATION(ProjectTailElems, "projectTailElems", Special)

// SWIFT_ENABLE_TENSORFLOW
/// autodiffGetJVP has type <T: Differentiable, R: Differentiable>
/// ((T) -> R) -> (T) -> (R, (T.TangentVector) -> R.TangentVector).
BUILTIN_SIL_OPERATION(AutoDiffGetJVP, "autodiffGetJVP", Special)
/// autodiffGetVJP has type <T: Differentiable, R: Differentiable>
/// ((T) -> R) -> (T) -> (R, (R.CotangentVector) -> T.CotangentVector).
BUILTIN_SIL_OPERATION(AutoDiffGetVJP, "autodiffGetVJP", Special)

#undef BUILTIN_SIL_OPERATION

// BUILTIN_RUNTIME_CALL - A call into a runtime function.
Expand Down
4 changes: 4 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3120,6 +3120,10 @@ class AnyFunctionType : public TypeBase {
///
/// Pass `selfUncurried = true` when the function type is for a method whose
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
///
/// \note The original function type (`self`) need not be `@autodiff`, and the
/// resulting function will preserve all `ExtInfo` of the original function,
/// including `@autodiff`.
AnyFunctionType *getAutoDiffAssociatedFunctionType(
const AutoDiffParameterIndices &indices, unsigned resultIndex,
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
Expand Down
4 changes: 0 additions & 4 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5188,10 +5188,6 @@ LayoutConstraint LayoutConstraint::getLayoutConstraint(LayoutConstraintKind Kind
}

// SWIFT_ENABLE_TENSORFLOW
/// Returns true when `getTangentSpace` produces a value for `type` looked up
/// in `module`, and false otherwise.
bool ASTContext::isDifferentiable(CanType type, ModuleDecl *module) {
  auto tangentSpace = getTangentSpace(type, module);
  return tangentSpace.hasValue();
}

Optional<TangentSpace> ASTContext::getTangentSpace(CanType type,
ModuleDecl *module) {
auto lookup = getImpl().TangentSpaces.find(type);
Expand Down
5 changes: 3 additions & 2 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ static AnyFunctionType *unwrapSelfParameter(AnyFunctionType *functionType,
/// as a method.
AutoDiffParameterIndices *
AutoDiffParameterIndices::create(ASTContext &C, AnyFunctionType *functionType,
bool isMethod) {
bool isMethod, bool setAllParams) {
// TODO(SR-9290): Note that the AutoDiffParameterIndices' destructor never
// gets called, which causes a small memory leak in the case that the
// SmallBitVector decides to allocate some heap space.
Expand All @@ -119,7 +119,8 @@ AutoDiffParameterIndices::create(ASTContext &C, AnyFunctionType *functionType,
unsigned paramCount =
unwrapSelfParameter(functionType, isMethod)->getNumParams() +
(isMethod ? 1 : 0);
return ::new (mem) AutoDiffParameterIndices(paramCount, isMethod);
return
::new (mem) AutoDiffParameterIndices(paramCount, isMethod, setAllParams);
}

/// Allocates and initializes an `AutoDiffParameterIndices` corresponding to
Expand Down
131 changes: 102 additions & 29 deletions lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,24 +456,21 @@ namespace {
private:
GenericParamList *TheGenericParamList;
SmallVector<GenericTypeParamDecl*, 2> GenericTypeParams;
GenericEnvironment *GenericEnv = nullptr;
// SWIFT_ENABLE_TENSORFLOW
GenericSignatureBuilder Builder;
SmallVector<AnyFunctionType::Param, 4> InterfaceParams;
Type InterfaceResult;

public:
BuiltinGenericSignatureBuilder(ASTContext &ctx, unsigned numGenericParams = 1)
: Context(ctx) {
// SWIFT_ENABLE_TENSORFLOW
: Context(ctx), Builder(ctx) {
TheGenericParamList = getGenericParams(ctx, numGenericParams,
GenericTypeParams);

GenericSignatureBuilder Builder(ctx);
for (auto gp : GenericTypeParams) {
Builder.addGenericParameter(gp);
}

auto GenericSig =
std::move(Builder).computeGenericSignature(SourceLoc());
GenericEnv = GenericSig->createGenericEnvironment();
}

template <class G>
Expand All @@ -489,7 +486,20 @@ namespace {
InterfaceResult = generator.build(*this);
}

// SWIFT_ENABLE_TENSORFLOW
template <class G>
void addConformanceRequirement(const G &generator, ProtocolDecl *proto) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DougGregor, is this the right thing to do? If so, shall I upstream this?

Requirement req(RequirementKind::Conformance,
generator.build(*this),
proto->getDeclaredType());
auto source =
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
Builder.addRequirement(req, source, Context.getStdlibModule());
}

ValueDecl *build(Identifier name) {
auto GenericSig = std::move(Builder).computeGenericSignature(SourceLoc());
auto GenericEnv = GenericSig->createGenericEnvironment();
return getBuiltinGenericFunction(name, InterfaceParams,
InterfaceResult,
TheGenericParamList,
Expand Down Expand Up @@ -533,22 +543,6 @@ makeConcrete(Type type) {
return { type };
}

// SWIFT_ENABLE_TENSORFLOW
template <class P, class... Gs>
static BuiltinGenericSignatureBuilder::LambdaGenerator
makeBoundGeneric(NominalTypeDecl *decl, const P &parentGenerator,
const Gs & ...genericParamGenerators) {
return {
[=](BuiltinGenericSignatureBuilder &builder) -> Type {
Type parent = parentGenerator.build(builder);
Type genParams[] = {
genericParamGenerators.build(builder)...
};
return BoundGenericType::get(decl, parent, genParams);
}
};
}

static BuiltinGenericSignatureBuilder::ParameterGenerator
makeGenericParam(unsigned index = 0) {
return { index };
Expand Down Expand Up @@ -985,8 +979,7 @@ static ValueDecl *getAutoDiffCreateTape(ASTContext &Context, Identifier Id) {
// <T> () -> (Swift._AutoDiffTape<T>)
BuiltinGenericSignatureBuilder builder(Context, 1);
auto *tapeDecl = Context.get_AutoDiffTapeDecl();
builder.setResult(
makeBoundGeneric(tapeDecl, makeConcrete(Type()), makeGenericParam()));
builder.setResult(makeBoundGenericType(tapeDecl, makeGenericParam()));
return builder.build(Id);
}

Expand All @@ -995,7 +988,7 @@ static ValueDecl *getAutoDiffPushToTape(ASTContext &Context, Identifier Id) {
BuiltinGenericSignatureBuilder builder(Context, 1);
auto *tapeDecl = Context.get_AutoDiffTapeDecl();
auto T = makeGenericParam();
builder.addParameter(makeBoundGeneric(tapeDecl, makeConcrete(Type()), T));
builder.addParameter(makeBoundGenericType(tapeDecl, T));
builder.addParameter(T);
builder.addParameter(makeConcrete(BuiltinIntegerType::getWordType(Context)));
builder.setResult(makeConcrete(Context.TheEmptyTupleType));
Expand All @@ -1007,7 +1000,7 @@ static ValueDecl *getAutoDiffPopFromTape(ASTContext &Context, Identifier Id) {
BuiltinGenericSignatureBuilder builder(Context, 1);
auto *tapeDecl = Context.get_AutoDiffTapeDecl();
auto T = makeGenericParam();
builder.addParameter(makeBoundGeneric(tapeDecl, makeConcrete(Type()), T));
builder.addParameter(makeBoundGenericType(tapeDecl, T));
builder.addParameter(makeConcrete(BuiltinIntegerType::getWordType(Context)));
builder.setResult(T);
return builder.build(Id);
Expand All @@ -1017,12 +1010,84 @@ static ValueDecl *getAutoDiffDestroyTape(ASTContext &Context, Identifier Id) {
// <T> (Swift._AutoDiffTape<T>) -> ()
BuiltinGenericSignatureBuilder builder(Context, 1);
auto *tapeDecl = Context.get_AutoDiffTapeDecl();
builder.addParameter(
makeBoundGeneric(tapeDecl, makeConcrete(Type()), makeGenericParam()));
builder.addParameter(makeBoundGenericType(tapeDecl, makeGenericParam()));
builder.setResult(makeConcrete(Context.TheEmptyTupleType));
return builder.build(Id);
}

/// Builds the declaration of a builtin that takes an `@autodiff` function and
/// returns one of its AD associated functions (JVP or VJP).
///
/// The builtin is generic over `arity` parameter types `T...` and one result
/// type `R`; every generic parameter is constrained to `Differentiable`. The
/// generated signatures are:
///
///   JVP (non-throwing):
///     <...T...(arity), R> (@autodiff (...T) -> R)
///         -> (...T) -> (R, (...T.TangentVector) -> R.TangentVector)
///   JVP (throwing):
///     <...T...(arity), R> (@autodiff (...T) throws -> R)
///         -> (...T) throws -> (R, (...T.TangentVector) -> R.TangentVector)
///   VJP (non-throwing):
///     <...T...(arity), R> (@autodiff (...T) -> R)
///         -> (...T) -> (R, (R.CotangentVector) -> ...T.CotangentVector)
///   VJP (throwing):
///     <...T...(arity), R> (@autodiff (...T) throws -> R)
///         -> (...T) throws -> (R, (R.CotangentVector) -> ...T.CotangentVector)
///
/// \param Context    the AST context used for lookups and allocation.
/// \param Id         the identifier of the builtin being declared.
/// \param kind       which associated function (JVP or VJP) the builtin
///                   returns.
/// \param order      the differentiation order; only 1 is supported (asserted).
/// \param arity      the number of parameters of the function being
///                   differentiated; must be at least 1 (asserted).
/// \param isThrowing whether the argument function type is `throws`.
/// \returns the builtin's declaration as produced by the signature builder.
static ValueDecl *getAutoDiffGetAssociatedFunction(
    ASTContext &Context, Identifier Id, AutoDiffAssociatedFunctionKind kind,
    unsigned order, unsigned arity, bool isThrowing = false) {
  assert(arity >= 1);
  assert(order == 1 && "higher-order differentiation is not supported yet");
  // Generic parameters 0..<arity are the `T`s; the last one (index `arity`)
  // is `R`.
  BuiltinGenericSignatureBuilder builder(Context,
                                         /*numGenericParams*/ 1 + arity);
  // Look up the Differentiable protocol.
  SmallVector<ValueDecl *, 1> diffableProtoLookup;
  Context.lookupInSwiftModule("Differentiable", diffableProtoLookup);
  assert(diffableProtoLookup.size() == 1);
  auto *diffableProto = cast<ProtocolDecl>(diffableProtoLookup.front());
  // Create type parameters and add conformance constraints.
  auto R = makeGenericParam(arity);
  builder.addConformanceRequirement(R, diffableProto);
  SmallVector<decltype(R), 2> Ts;
  for (auto i : range(arity)) {
    auto T = makeGenericParam(i);
    builder.addConformanceRequirement(T, diffableProto);
    Ts.push_back(T);
  }
  // Generator for the function type at the argument position, i.e. the
  // function being differentiated. `Ts` is captured by reference; the
  // generator is only invoked before this function returns.
  BuiltinGenericSignatureBuilder::LambdaGenerator argGen {
    [=, &Ts](BuiltinGenericSignatureBuilder &builder) -> Type {
      auto extInfo = FunctionType::ExtInfo()
          .withDifferentiability(FunctionTypeDifferentiability::Bidirectional)
          .withNoEscape();
      if (isThrowing)
        extInfo = extInfo.withThrows();
      SmallVector<FunctionType::Param, 2> params;
      for (auto &paramGen : Ts)
        params.push_back(FunctionType::Param(paramGen.build(builder)));
      return FunctionType::get(params, R.build(builder))->withExtInfo(extInfo);
    }
  };
  // Compute the original function type with differentiability reset to
  // `None`; the associated function type below is derived from this type.
  AnyFunctionType *origFnTy = argGen.build(builder)->castTo<AnyFunctionType>();
  origFnTy = origFnTy->withExtInfo(origFnTy->getExtInfo()
      .withDifferentiability(FunctionTypeDifferentiability::None));
  // Differentiate with respect to all parameters.
  auto *paramIndices = AutoDiffParameterIndices::create(Context, origFnTy,
                                                        /*isMethod*/ false,
                                                        /*setAllParams*/ true);
  // Generator for the resultant function type, i.e. the AD associated function.
  BuiltinGenericSignatureBuilder::LambdaGenerator resultGen {
    [=](BuiltinGenericSignatureBuilder &builder) -> Type {
      // TODO(rxwei): Use parameter indices and differentiation order that are
      // stored in the function type.
      auto *assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
          *paramIndices, /*resultIndex*/ 0, /*differentiationOrder*/ order,
          kind, /*lookupConformance*/ nullptr);
      // Clear the noescape bit that was set on the argument function type.
      assocFnTy =
          assocFnTy->withExtInfo(assocFnTy->getExtInfo().withNoEscape(false));
      return assocFnTy;
    }
  };
  builder.addParameter(argGen);
  builder.setResult(resultGen);
  return builder.build(Id);
}

static ValueDecl *getPoundAssert(ASTContext &Context, Identifier Id) {
auto int1Type = BuiltinIntegerType::get(1, Context);
auto optionalRawPointerType = BoundGenericEnumType::get(
Expand Down Expand Up @@ -1958,6 +2023,14 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
return getAutoDiffPopFromTape(Context, Id);
case BuiltinValueKind::AutoDiffDestroyTape:
return getAutoDiffDestroyTape(Context, Id);
case BuiltinValueKind::AutoDiffGetJVP:
return getAutoDiffGetAssociatedFunction(Context, Id,
AutoDiffAssociatedFunctionKind::JVP,
/*order*/ 1, /*arity*/ 1);
case BuiltinValueKind::AutoDiffGetVJP:
return getAutoDiffGetAssociatedFunction(Context, Id,
AutoDiffAssociatedFunctionKind::VJP,
/*order*/ 1, /*arity*/ 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you said, there's a mismatch between getAutoDiffGetAssociatedFunction (which supports arbitrary order and arity) and Builtin.autodiffGet(JVP|VJP) (which supports order 1 and arity 1).

Could you please expand on how you plan to handle this mismatch? Is the only solution to create many versions of the builtins like Builtin.autodiffGetJVP_Arity2_Order2_Throwing, as you noted in the PR description?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getAutoDiffGetAssociatedFunction is a builder, and Builtin.autodiffGet(JVP|VJP) is a concrete pair of functions created by that builder. When we overload, we will simply declare more builtins, which will then call the builder with a different arity/order/throwing-ness. I think this can also be simplified by extending the parser to recognize the suffixes.

case BuiltinValueKind::PoundAssert:
return getPoundAssert(Context, Id);

Expand Down
48 changes: 30 additions & 18 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4188,24 +4188,36 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
auto *cotangentDependentType = DependentMemberType::get(
differentiableProtocol->getDeclaredInterfaceType(),
cast<AssociatedTypeDecl>(cotangentLookup[0]));
auto getAssociatedType = [&](Type type,
DependentMemberType *dependentType) -> CanType {
// Builtins are their own Tangent/Cotangent.
if (type->is<BuiltinType>()) return type->getCanonicalType();

// TODO: If this is a tuple, recursively get the associated types of its
// components.

// Try to get the associated type, and return it if found (if the result is
// non-null and non-`DependentMemberType`).
auto assocTy = dependentType->substBaseType(type, lookupConformance);
if (assocTy && !assocTy->is<DependentMemberType>())
return assocTy->getCanonicalType();

// When the type does not have an associated type, fallback to treating it
// as its own Tangent/Cotangent.
// TODO: We should eliminate all instances where this happens.
return type->getCanonicalType();

// `getAssociatedType` takes a base type and a protocol-dependent type and
// returns a canonical type representing the associated type after any
// possible substitutions.
// For base types that are tuples, applies `getAssociatedType` to every
// element type and returns a tuple type of the new elements.
// For base types that are builtins, returns the types themselves.
// It is declared as a `std::function` so that the lambda can refer to itself
// for the recursive tuple case.
std::function<CanType(Type, DependentMemberType *)> getAssociatedType
= [&](Type type, DependentMemberType *dependentType) {
// Builtin types are their own Tangent/Cotangent.
if (type->is<BuiltinType>())
return type->getCanonicalType();
// Tuples' Tangent/Cotangent is a tuple of each element's Tangent/Cotangent.
if (auto *tupleTy = type->getAs<TupleType>()) {
SmallVector<TupleTypeElt, 8> newElts;
for (auto elt : tupleTy->getElements())
newElts.push_back(
elt.getWithType(getAssociatedType(elt.getType(), dependentType)));
return TupleType::get(newElts, ctx)->getCanonicalType();
}
// If `lookupConformance` is provided by the caller, try to get the
// associated type by substituting the base type into the protocol
// associated type, and return it if the substitution yields a non-null
// type.
if (lookupConformance)
if (auto assocTy = dependentType->substBaseType(type, lookupConformance))
return assocTy->getCanonicalType();
// Otherwise, just return the base type's dependent member type.
return DependentMemberType::get(type, dependentType->getAssocType())
->getCanonicalType();
};

SmallVector<Type, 8> wrtParamTypes;
Expand Down
26 changes: 26 additions & 0 deletions lib/SILGen/SILGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,32 @@ static ManagedValue emitBuiltinTypeTrait(SILGenFunction &SGF,
return ManagedValue::forUnmanaged(val);
}

// SWIFT_ENABLE_TENSORFLOW
/// Specialized emitter for Builtin.autodiffGetJVP.
///
/// Emits `argument` as an rvalue, forwards it as a single value, and extracts
/// the first-order JVP from it with `createAutoDiffFunctionExtract`. The
/// extracted value is returned as a managed rvalue with cleanup.
/// `substitutions` and `C` are not used by this emitter.
static ManagedValue emitBuiltinAutoDiffGetJVP(SILGenFunction &SGF,
                                              SILLocation loc,
                                              SubstitutionMap substitutions,
                                              Expr *argument,
                                              SGFContext C) {
  auto argVal = SGF.emitRValue(argument);
  auto jvp = SGF.getBuilder().createAutoDiffFunctionExtract(
      loc, AutoDiffFunctionExtractee::JVP, /*differentiationOrder*/ 1,
      std::move(argVal).forwardAsSingleValue(SGF, loc));
  return SGF.emitManagedRValueWithCleanup(jvp);
}

/// Specialized emitter for Builtin.autodiffGetVJP.
///
/// Emits `argument` as an rvalue, forwards it as a single value, and extracts
/// the first-order VJP from it with `createAutoDiffFunctionExtract`. The
/// extracted value is returned as a managed rvalue with cleanup.
/// `substitutions` and `C` are not used by this emitter.
static ManagedValue emitBuiltinAutoDiffGetVJP(SILGenFunction &SGF,
                                              SILLocation loc,
                                              SubstitutionMap substitutions,
                                              Expr *argument,
                                              SGFContext C) {
  // Evaluate the differentiable function operand and take its single value.
  auto operandRValue = SGF.emitRValue(argument);
  auto operand = std::move(operandRValue).forwardAsSingleValue(SGF, loc);
  // Pull the first-order VJP out of the operand.
  auto extractedVJP = SGF.getBuilder().createAutoDiffFunctionExtract(
      loc, AutoDiffFunctionExtractee::VJP, /*differentiationOrder*/ 1,
      operand);
  return SGF.emitManagedRValueWithCleanup(extractedVJP);
}

Optional<SpecializedEmitter>
SpecializedEmitter::forDecl(SILGenModule &SGM, SILDeclRef function) {
// Only consider standalone declarations in the Builtin module.
Expand Down
3 changes: 0 additions & 3 deletions lib/SILGen/SILGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5733,9 +5733,6 @@ RValue RValueEmitter::visitAutoDiffFunctionExpr(AutoDiffFunctionExpr *E,
RValue RValueEmitter::visitAutoDiffFunctionExtractOriginalExpr(
AutoDiffFunctionExtractOriginalExpr *E, SGFContext C) {
auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
llvm::outs() << "Difffunc subexpr type = " << E->getSubExpr()->getType() << '\n';
llvm::outs() << "Difffunc = " << diffFunc.getValue() << '\n';
llvm::outs().flush();
auto *orig = SGF.B.createAutoDiffFunctionExtractOriginal(
E, diffFunc.forward(SGF));
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(orig));
Expand Down
Loading