Skip to content

[AutoDiff] Remove differentiation order from AD-related instructions. #27579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5620,21 +5620,18 @@ differentiable_function
sil-differentiable-function-associated-function-list ::=
'{' sil-value ',' sil-value '}'

differentiable_function [wrt 0] [order 1] %0 : $(T) -> T \
differentiable_function [wrt 0] %0 : $(T) -> T \
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}

Bundles a function with its associated differentiation functions up to a
specified differentiation order into an ``@differentiable`` function. There are
two associated functions per differentiation order: a Jacobian-vector products
(JVP) function and a vector-Jacobian products (VJP) function.
Bundles a function with its associated differentiation functions into a
``@differentiable`` function. There are two associated functions:
a Jacobian-vector products (JVP) function and a vector-Jacobian products (VJP)
function.

``[wrt ...]`` specifies parameter indices that the original function is
differentiable with respect to. When not specified, it defaults to all
parameters.

``[order ...]`` specifies the maximum differentiation order for the resulting
function. The number of lists of associated functions is equal to the order.

A ``with`` clause specifies the differentiation functions associated
with the original function. When a ``with`` clause is not specified, the first
operand will be differentiated to produce associated functions, and a ``with``
Expand All @@ -5660,12 +5657,12 @@ differentiable_function_extract
sil-differentiable-function-differentiation-order ::= '[' 'order' [0-9]+ ']'

differentiable_function_extract [original] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] [order 1] %0 : $@differentiable (T) -> T
differentiable_function_extract [vjp] [order 1] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T

Extracts the original function or an associated function from the given
``@differentiable`` function at a specific differentiation order. It must be
provided with an extractee: ``[original]``, ``[jvp]`` or ``[vjp]``.
``@differentiable`` function. It must be provided with an extractee:
``[original]``, ``[jvp]`` or ``[vjp]``.


Assertion configuration
Expand Down
29 changes: 4 additions & 25 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,29 +458,25 @@ struct AutoDiffAssociatedFunctionKind {
/// compared by opaque pointer value.
class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
const AutoDiffAssociatedFunctionKind kind;
const unsigned differentiationOrder;
AutoDiffIndexSubset *const parameterIndices;

AutoDiffAssociatedFunctionIdentifier(
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
AutoDiffAssociatedFunctionKind kind,
AutoDiffIndexSubset *parameterIndices) :
kind(kind), differentiationOrder(differentiationOrder),
parameterIndices(parameterIndices) {}
kind(kind), parameterIndices(parameterIndices) {}

public:
AutoDiffAssociatedFunctionKind getKind() const { return kind; }
unsigned getDifferentiationOrder() const { return differentiationOrder; }
AutoDiffIndexSubset *getParameterIndices() const {
return parameterIndices;
}

static AutoDiffAssociatedFunctionIdentifier *get(
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
AutoDiffAssociatedFunctionKind kind,
AutoDiffIndexSubset *parameterIndices, ASTContext &C);

void Profile(llvm::FoldingSetNodeID &ID) {
ID.AddInteger(kind);
ID.AddInteger(differentiationOrder);
ID.AddPointer(parameterIndices);
}
};
Expand Down Expand Up @@ -520,29 +516,12 @@ void getSubsetParameterTypes(AutoDiffIndexSubset *indices,
AutoDiffIndexSubset *getLoweredParameterIndices(AutoDiffIndexSubset *indices,
AnyFunctionType *type);

/// Returns the offset for an associated function at a specific differentiation
/// order.
/// This is used for both ordering in the `differentiable_function` instruction
/// and ABI layout.
///
/// Order 1 Order 2 ...
/// |----------| |-----|-----| |-----|-----| ...
/// | Original | | JVP | VJP | | JVP | VJP | ...
/// |----------| |-----|-----| |-----|-----| ...
unsigned
getOffsetForAutoDiffAssociatedFunction(unsigned order,
AutoDiffAssociatedFunctionKind kind);

unsigned
getNumAutoDiffAssociatedFunctions(unsigned differentiationOrder);

/// Retrieve config from the function name of a variant of
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
/// Returns true if the function name is parsed successfully.
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
AutoDiffAssociatedFunctionKind &kind,
unsigned &arity, unsigned &order,
bool &rethrows);
unsigned &arity, bool &rethrows);

/// Computes the correct linkage for an associated function given the linkage of
/// the original function. If the original linkage is not external and
Expand Down
4 changes: 0 additions & 4 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1594,10 +1594,6 @@ ERROR(sil_attr_differentiable_expected_source_index,PointsToFirstBadToken,
// SIL autodiff
ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken,
"expected ']' to complete the %0", (StringRef))
ERROR(sil_inst_autodiff_expected_order,PointsToFirstBadToken,
"expected an unsigned integer indicating the differentiation order", ())
ERROR(sil_inst_autodiff_expected_nonzero_order,PointsToFirstBadToken,
"expected a non-zero differentiation order", ())
ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken,
"expected the index of a parameter to differentiate with respect to", ())
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,
Expand Down
12 changes: 6 additions & 6 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3099,8 +3099,8 @@ class AnyFunctionType : public TypeBase {
}

// SWIFT_ENABLE_TENSORFLOW
/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
/// of the corresponding autodiff associated function.
/// Given `indices` and `kind`, calculates the type of the corresponding
/// autodiff associated function.
///
/// By default, if the original type has a self parameter list and parameter
/// indices include self, the computed associated function type will return a
Expand All @@ -3116,7 +3116,7 @@ class AnyFunctionType : public TypeBase {
/// function, including `@differentiable`.
AnyFunctionType *getAutoDiffAssociatedFunctionType(
AutoDiffIndexSubset *indices, unsigned resultIndex,
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
AutoDiffAssociatedFunctionKind kind,
LookupConformanceFn lookupConformance,
GenericSignature *whereClauseGenericSignature = nullptr,
bool makeSelfParamFirst = false);
Expand Down Expand Up @@ -4216,16 +4216,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,

// SWIFT_ENABLE_TENSORFLOW
CanSILFunctionType getWithDifferentiability(
unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices);
AutoDiffIndexSubset *parameterIndices);

CanSILFunctionType getWithoutDifferentiability();

/// Returns the type of a differentiation function that is associated with
/// a function of this type.
CanSILFunctionType getAutoDiffAssociatedFunctionType(
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
Lowering::TypeConverter &TC, LookupConformanceFn lookupConformance,
AutoDiffAssociatedFunctionKind kind, Lowering::TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature associatedFunctionGenericSignature = nullptr);

/// Returns a bit vector that specifices which parameters you can
Expand Down
25 changes: 12 additions & 13 deletions include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -511,28 +511,27 @@ class SILBuilder {

/// SWIFT_ENABLE_TENSORFLOW
DifferentiableFunctionInst *createDifferentiableFunction(
SILLocation loc, AutoDiffIndexSubset *parameterIndices,
unsigned differentiationOrder, SILValue original,
ArrayRef<SILValue> associatedFunctions = {}) {
SILLocation Loc, AutoDiffIndexSubset *ParameterIndices,
SILValue OriginalFunction,
Optional<std::pair<SILValue, SILValue>> JVPAndVJPFunctions = None) {
return insert(DifferentiableFunctionInst::create(
getModule(), getSILDebugLocation(loc), parameterIndices,
differentiationOrder, original, associatedFunctions));
getModule(), getSILDebugLocation(Loc), ParameterIndices,
OriginalFunction, JVPAndVJPFunctions, hasOwnership()));
}

DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
SILLocation loc, DifferentiableFunctionExtractee extractee,
unsigned differentiationOrder, SILValue theFunction) {
SILLocation Loc, DifferentiableFunctionExtractee Extractee,
SILValue TheFunction) {
return insert(new (getModule()) DifferentiableFunctionExtractInst(
getModule(), getSILDebugLocation(loc), extractee, differentiationOrder,
theFunction));
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
}

DifferentiableFunctionExtractInst *
createDifferentiableFunctionExtractOriginal(SILLocation loc,
SILValue theFunction) {
createDifferentiableFunctionExtractOriginal(SILLocation Loc,
SILValue TheFunction) {
return insert(new (getModule()) DifferentiableFunctionExtractInst(
getModule(), getSILDebugLocation(loc),
DifferentiableFunctionExtractee::Original, 0, theFunction));
getModule(), getSILDebugLocation(Loc),
DifferentiableFunctionExtractee::Original, TheFunction));
}

BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy,
Expand Down
12 changes: 5 additions & 7 deletions include/swift/SIL/SILCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -970,15 +970,14 @@ template<typename ImplClass>
void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
DifferentiableFunctionInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
SmallVector<SILValue, 16> mappedAssocFns;
mappedAssocFns.reserve(Inst->getNumAssociatedFunctions());
for (auto &fn : Inst->getAssociatedFunctions())
mappedAssocFns.push_back(getOpValue(fn.get()));
Optional<std::pair<SILValue, SILValue>> assocFns = None;
if (Inst->hasDerivativeFunctions())
assocFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
getOpValue(Inst->getVJPFunction()));
recordClonedInstruction(
Inst, getBuilder().createDifferentiableFunction(
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
Inst->getDifferentiationOrder(),
getOpValue(Inst->getOriginalFunction()), mappedAssocFns));
getOpValue(Inst->getOriginalFunction()), assocFns));
}

template<typename ImplClass>
Expand All @@ -988,7 +987,6 @@ visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst)
recordClonedInstruction(
Inst, getBuilder().createDifferentiableFunctionExtract(
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
Inst->getDifferentiationOrder(),
getOpValue(Inst->getFunctionOperand())));
}
// SWIFT_ENABLE_TENSORFLOW END
Expand Down
Loading