Skip to content

Commit eeeeee2

Browse files
authored
[AutoDiff] Remove differentiation order from AD-related instructions. (#27579)
The differentiation order field in `differentiable_function` and `differentiable_function_extract` instructions is unsupported and will not be used by the current design. Quite a lot of dead code exists to try to handle `order`, but it is mostly incomplete and untested. This PR removes the differentiation order from the code base to simplify what we upstream to the 'master' branch. Changes include: * Remove `differentiationOrder` from `DifferentiableFunctionInst` and `DifferentiableFunctionExtractInst`. * Make `DifferentiableFunctionInst::DifferentiableFunctionInst` take an optional pair of JVP and VJP instead of a variable-size array. * Rename "associated functions" to "derivative functions" in `DifferentiableFunctionInst` to align better with [the design](https://forums.swift.org/t/differentiable-programming-mega-proposal/28547). Filed task [TF-882](https://bugs.swift.org/browse/TF-882) to track the renaming of all other occurrences of "associated functions". Resolves [TF-880](https://bugs.swift.org/browse/TF-880).
1 parent fb6045c commit eeeeee2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+429
-627
lines changed

docs/SIL.rst

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5620,21 +5620,18 @@ differentiable_function
56205620
sil-differentiable-function-associated-function-list ::=
56215621
'{' sil-value ',' sil-value '}'
56225622

5623-
differentiable_function [wrt 0] [order 1] %0 : $(T) -> T \
5623+
differentiable_function [wrt 0] %0 : $(T) -> T \
56245624
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
56255625

5626-
Bundles a function with its associated differentiation functions up to a
5627-
specified differentiation order into an ``@differentiable`` function. There are
5628-
two associated functions per differentiation order: a Jacobian-vector products
5629-
(JVP) function and a vector-Jacobian products (VJP) function.
5626+
Bundles a function with its associated differentiation functions into a
5627+
``@differentiable`` function. There are two associated functions:
5628+
a Jacobian-vector products (JVP) function and a vector-Jacobian products (VJP)
5629+
function.
56305630

56315631
``[wrt ...]`` specifies parameter indices that the original function is
56325632
differentiable with respect to. When not specified, it defaults to all
56335633
parameters.
56345634

5635-
``[order ...]`` specifies the maximum differentiation order for the resulting
5636-
function. The number of lists of associated functions is equal to the order.
5637-
56385635
A ``with`` clause specifies the differentiation functions associated
56395636
with the original function. When a ``with`` clause is not specified, the first
56405637
operand will be differentiated to produce associated functions, and a ``with``
@@ -5660,12 +5657,12 @@ differentiable_function_extract
56605657
sil-differentiable-function-differentiation-order ::= '[' 'order' [0-9]+ ']'
56615658

56625659
differentiable_function_extract [original] %0 : $@differentiable (T) -> T
5663-
differentiable_function_extract [jvp] [order 1] %0 : $@differentiable (T) -> T
5664-
differentiable_function_extract [vjp] [order 1] %0 : $@differentiable (T) -> T
5660+
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
5661+
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
56655662

56665663
Extracts the original function or an associated function from the given
5667-
``@differentiable`` function at a specific differentiation order. It must be
5668-
provided with an extractee: ``[original]``, ``[jvp]`` or ``[vjp]``.
5664+
``@differentiable`` function. It must be provided with an extractee:
5665+
``[original]``, ``[jvp]`` or ``[vjp]``.
56695666

56705667

56715668
Assertion configuration

include/swift/AST/AutoDiff.h

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -458,29 +458,25 @@ struct AutoDiffAssociatedFunctionKind {
458458
/// compared by opaque pointer value.
459459
class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
460460
const AutoDiffAssociatedFunctionKind kind;
461-
const unsigned differentiationOrder;
462461
AutoDiffIndexSubset *const parameterIndices;
463462

464463
AutoDiffAssociatedFunctionIdentifier(
465-
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
464+
AutoDiffAssociatedFunctionKind kind,
466465
AutoDiffIndexSubset *parameterIndices) :
467-
kind(kind), differentiationOrder(differentiationOrder),
468-
parameterIndices(parameterIndices) {}
466+
kind(kind), parameterIndices(parameterIndices) {}
469467

470468
public:
471469
AutoDiffAssociatedFunctionKind getKind() const { return kind; }
472-
unsigned getDifferentiationOrder() const { return differentiationOrder; }
473470
AutoDiffIndexSubset *getParameterIndices() const {
474471
return parameterIndices;
475472
}
476473

477474
static AutoDiffAssociatedFunctionIdentifier *get(
478-
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
475+
AutoDiffAssociatedFunctionKind kind,
479476
AutoDiffIndexSubset *parameterIndices, ASTContext &C);
480477

481478
void Profile(llvm::FoldingSetNodeID &ID) {
482479
ID.AddInteger(kind);
483-
ID.AddInteger(differentiationOrder);
484480
ID.AddPointer(parameterIndices);
485481
}
486482
};
@@ -520,29 +516,12 @@ void getSubsetParameterTypes(AutoDiffIndexSubset *indices,
520516
AutoDiffIndexSubset *getLoweredParameterIndices(AutoDiffIndexSubset *indices,
521517
AnyFunctionType *type);
522518

523-
/// Returns the offset for an associated function at a specific differentiation
524-
/// order.
525-
/// This is used for both ordering in the `differentiable_function` instruction
526-
/// and ABI layout.
527-
///
528-
/// Order 1 Order 2 ...
529-
/// |----------| |-----|-----| |-----|-----| ...
530-
/// | Original | | JVP | VJP | | JVP | VJP | ...
531-
/// |----------| |-----|-----| |-----|-----| ...
532-
unsigned
533-
getOffsetForAutoDiffAssociatedFunction(unsigned order,
534-
AutoDiffAssociatedFunctionKind kind);
535-
536-
unsigned
537-
getNumAutoDiffAssociatedFunctions(unsigned differentiationOrder);
538-
539519
/// Retrieve config from the function name of a variant of
540520
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
541521
/// Returns true if the function name is parsed successfully.
542522
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
543523
AutoDiffAssociatedFunctionKind &kind,
544-
unsigned &arity, unsigned &order,
545-
bool &rethrows);
524+
unsigned &arity, bool &rethrows);
546525

547526
/// Computes the correct linkage for an associated function given the linkage of
548527
/// the original function. If the original linkage is not external and

include/swift/AST/DiagnosticsParse.def

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,10 +1594,6 @@ ERROR(sil_attr_differentiable_expected_source_index,PointsToFirstBadToken,
15941594
// SIL autodiff
15951595
ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken,
15961596
"expected ']' to complete the %0", (StringRef))
1597-
ERROR(sil_inst_autodiff_expected_order,PointsToFirstBadToken,
1598-
"expected an unsigned integer indicating the differentiation order", ())
1599-
ERROR(sil_inst_autodiff_expected_nonzero_order,PointsToFirstBadToken,
1600-
"expected a non-zero differentiation order", ())
16011597
ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken,
16021598
"expected the index of a parameter to differentiate with respect to", ())
16031599
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,

include/swift/AST/Types.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3099,8 +3099,8 @@ class AnyFunctionType : public TypeBase {
30993099
}
31003100

31013101
// SWIFT_ENABLE_TENSORFLOW
3102-
/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
3103-
/// of the corresponding autodiff associated function.
3102+
/// Given `indices` and `kind`, calculates the type of the corresponding
3103+
/// autodiff associated function.
31043104
///
31053105
/// By default, if the original type has a self parameter list and parameter
31063106
/// indices include self, the computed associated function type will return a
@@ -3116,7 +3116,7 @@ class AnyFunctionType : public TypeBase {
31163116
/// function, including `@differentiable`.
31173117
AnyFunctionType *getAutoDiffAssociatedFunctionType(
31183118
AutoDiffIndexSubset *indices, unsigned resultIndex,
3119-
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
3119+
AutoDiffAssociatedFunctionKind kind,
31203120
LookupConformanceFn lookupConformance,
31213121
GenericSignature *whereClauseGenericSignature = nullptr,
31223122
bool makeSelfParamFirst = false);
@@ -4216,16 +4216,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
42164216

42174217
// SWIFT_ENABLE_TENSORFLOW
42184218
CanSILFunctionType getWithDifferentiability(
4219-
unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices);
4219+
AutoDiffIndexSubset *parameterIndices);
42204220

42214221
CanSILFunctionType getWithoutDifferentiability();
42224222

42234223
/// Returns the type of a differentiation function that is associated with
42244224
/// a function of this type.
42254225
CanSILFunctionType getAutoDiffAssociatedFunctionType(
42264226
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
4227-
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
4228-
Lowering::TypeConverter &TC, LookupConformanceFn lookupConformance,
4227+
AutoDiffAssociatedFunctionKind kind, Lowering::TypeConverter &TC,
4228+
LookupConformanceFn lookupConformance,
42294229
CanGenericSignature associatedFunctionGenericSignature = nullptr);
42304230

42314231
/// Returns a bit vector that specifices which parameters you can

include/swift/SIL/SILBuilder.h

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -511,28 +511,27 @@ class SILBuilder {
511511

512512
/// SWIFT_ENABLE_TENSORFLOW
513513
DifferentiableFunctionInst *createDifferentiableFunction(
514-
SILLocation loc, AutoDiffIndexSubset *parameterIndices,
515-
unsigned differentiationOrder, SILValue original,
516-
ArrayRef<SILValue> associatedFunctions = {}) {
514+
SILLocation Loc, AutoDiffIndexSubset *ParameterIndices,
515+
SILValue OriginalFunction,
516+
Optional<std::pair<SILValue, SILValue>> JVPAndVJPFunctions = None) {
517517
return insert(DifferentiableFunctionInst::create(
518-
getModule(), getSILDebugLocation(loc), parameterIndices,
519-
differentiationOrder, original, associatedFunctions));
518+
getModule(), getSILDebugLocation(Loc), ParameterIndices,
519+
OriginalFunction, JVPAndVJPFunctions, hasOwnership()));
520520
}
521521

522522
DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
523-
SILLocation loc, DifferentiableFunctionExtractee extractee,
524-
unsigned differentiationOrder, SILValue theFunction) {
523+
SILLocation Loc, DifferentiableFunctionExtractee Extractee,
524+
SILValue TheFunction) {
525525
return insert(new (getModule()) DifferentiableFunctionExtractInst(
526-
getModule(), getSILDebugLocation(loc), extractee, differentiationOrder,
527-
theFunction));
526+
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
528527
}
529528

530529
DifferentiableFunctionExtractInst *
531-
createDifferentiableFunctionExtractOriginal(SILLocation loc,
532-
SILValue theFunction) {
530+
createDifferentiableFunctionExtractOriginal(SILLocation Loc,
531+
SILValue TheFunction) {
533532
return insert(new (getModule()) DifferentiableFunctionExtractInst(
534-
getModule(), getSILDebugLocation(loc),
535-
DifferentiableFunctionExtractee::Original, 0, theFunction));
533+
getModule(), getSILDebugLocation(Loc),
534+
DifferentiableFunctionExtractee::Original, TheFunction));
536535
}
537536

538537
BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy,

include/swift/SIL/SILCloner.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -970,15 +970,14 @@ template<typename ImplClass>
970970
void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
971971
DifferentiableFunctionInst *Inst) {
972972
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
973-
SmallVector<SILValue, 16> mappedAssocFns;
974-
mappedAssocFns.reserve(Inst->getNumAssociatedFunctions());
975-
for (auto &fn : Inst->getAssociatedFunctions())
976-
mappedAssocFns.push_back(getOpValue(fn.get()));
973+
Optional<std::pair<SILValue, SILValue>> assocFns = None;
974+
if (Inst->hasDerivativeFunctions())
975+
assocFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
976+
getOpValue(Inst->getVJPFunction()));
977977
recordClonedInstruction(
978978
Inst, getBuilder().createDifferentiableFunction(
979979
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
980-
Inst->getDifferentiationOrder(),
981-
getOpValue(Inst->getOriginalFunction()), mappedAssocFns));
980+
getOpValue(Inst->getOriginalFunction()), assocFns));
982981
}
983982

984983
template<typename ImplClass>
@@ -988,7 +987,6 @@ visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst)
988987
recordClonedInstruction(
989988
Inst, getBuilder().createDifferentiableFunctionExtract(
990989
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
991-
Inst->getDifferentiationOrder(),
992990
getOpValue(Inst->getFunctionOperand())));
993991
}
994992
// SWIFT_ENABLE_TENSORFLOW END

0 commit comments

Comments
 (0)