Skip to content

Commit bb9c2fe

Browse files
committed
[AutoDiff] Remove differentiation order from AD-related instructions.
The differentiation order field in `differentiable_function` and `differentiable_function_extract` instructions is unsupported and will not be used by the current design. Quite a lot of dead code exists to try to handle `order`, but it is mostly incomplete and untested. This PR removes the differentiation order from the code base to simplify what we upstream to the 'master' branch. Changes include: * Remove `differentiationOrder` from `DifferentiableFunctionInst` and `DifferentiableFunctionExtractInst`. * Make `DifferentiableFunctionInst::DifferentiableFunctionInst` take an optional pair of JVP and VJP instead of a variable-size array. * Rename "associated functions" to "derivative functions" in `DifferentiableFunctionInst` to align better with [the design](https://forums.swift.org/t/differentiable-programming-mega-proposal/28547). Filed task [TF-882](https://bugs.swift.org/browse/TF-882) to track the renaming of all other occurrences of "associated functions". Resolves [TF-880](https://bugs.swift.org/browse/TF-880).
1 parent fb6045c commit bb9c2fe

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+424
-627
lines changed

docs/SIL.rst

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5620,21 +5620,18 @@ differentiable_function
56205620
sil-differentiable-function-associated-function-list ::=
56215621
'{' sil-value ',' sil-value '}'
56225622

5623-
differentiable_function [wrt 0] [order 1] %0 : $(T) -> T \
5623+
differentiable_function [wrt 0] %0 : $(T) -> T \
56245624
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
56255625

5626-
Bundles a function with its associated differentiation functions up to a
5627-
specified differentiation order into an ``@differentiable`` function. There are
5628-
two associated functions per differentiation order: a Jacobian-vector products
5629-
(JVP) function and a vector-Jacobian products (VJP) function.
5626+
Bundles a function with its associated differentiation functions into a
5627+
``@differentiable`` function. There are two associated functions:
5628+
a Jacobian-vector products (JVP) function and a vector-Jacobian products (VJP)
5629+
function.
56305630

56315631
``[wrt ...]`` specifies parameter indices that the original function is
56325632
differentiable with respect to. When not specified, it defaults to all
56335633
parameters.
56345634

5635-
``[order ...]`` specifies the maximum differentiation order for the resulting
5636-
function. The number of lists of associated functions is equal to the order.
5637-
56385635
A ``with`` clause specifies the differentiation functions associated
56395636
with the original function. When a ``with`` clause is not specified, the first
56405637
operand will be differentiated to produce associated functions, and a ``with``
@@ -5660,12 +5657,12 @@ differentiable_function_extract
56605657
sil-differentiable-function-differentiation-order ::= '[' 'order' [0-9]+ ']'
56615658

56625659
differentiable_function_extract [original] %0 : $@differentiable (T) -> T
5663-
differentiable_function_extract [jvp] [order 1] %0 : $@differentiable (T) -> T
5664-
differentiable_function_extract [vjp] [order 1] %0 : $@differentiable (T) -> T
5660+
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
5661+
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
56655662

56665663
Extracts the original function or an associated function from the given
5667-
``@differentiable`` function at a specific differentiation order. It must be
5668-
provided with an extractee: ``[original]``, ``[jvp]`` or ``[vjp]``.
5664+
``@differentiable`` function. It must be provided with an extractee:
5665+
``[original]``, ``[jvp]`` or ``[vjp]``.
56695666

56705667

56715668
Assertion configuration

include/swift/AST/AutoDiff.h

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -458,29 +458,25 @@ struct AutoDiffAssociatedFunctionKind {
458458
/// compared by opaque pointer value.
459459
class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
460460
const AutoDiffAssociatedFunctionKind kind;
461-
const unsigned differentiationOrder;
462461
AutoDiffIndexSubset *const parameterIndices;
463462

464463
AutoDiffAssociatedFunctionIdentifier(
465-
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
464+
AutoDiffAssociatedFunctionKind kind,
466465
AutoDiffIndexSubset *parameterIndices) :
467-
kind(kind), differentiationOrder(differentiationOrder),
468-
parameterIndices(parameterIndices) {}
466+
kind(kind), parameterIndices(parameterIndices) {}
469467

470468
public:
471469
AutoDiffAssociatedFunctionKind getKind() const { return kind; }
472-
unsigned getDifferentiationOrder() const { return differentiationOrder; }
473470
AutoDiffIndexSubset *getParameterIndices() const {
474471
return parameterIndices;
475472
}
476473

477474
static AutoDiffAssociatedFunctionIdentifier *get(
478-
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
475+
AutoDiffAssociatedFunctionKind kind,
479476
AutoDiffIndexSubset *parameterIndices, ASTContext &C);
480477

481478
void Profile(llvm::FoldingSetNodeID &ID) {
482479
ID.AddInteger(kind);
483-
ID.AddInteger(differentiationOrder);
484480
ID.AddPointer(parameterIndices);
485481
}
486482
};
@@ -520,29 +516,12 @@ void getSubsetParameterTypes(AutoDiffIndexSubset *indices,
520516
AutoDiffIndexSubset *getLoweredParameterIndices(AutoDiffIndexSubset *indices,
521517
AnyFunctionType *type);
522518

523-
/// Returns the offset for an associated function at a specific differentiation
524-
/// order.
525-
/// This is used for both ordering in the `differentiable_function` instruction
526-
/// and ABI layout.
527-
///
528-
/// Order 1 Order 2 ...
529-
/// |----------| |-----|-----| |-----|-----| ...
530-
/// | Original | | JVP | VJP | | JVP | VJP | ...
531-
/// |----------| |-----|-----| |-----|-----| ...
532-
unsigned
533-
getOffsetForAutoDiffAssociatedFunction(unsigned order,
534-
AutoDiffAssociatedFunctionKind kind);
535-
536-
unsigned
537-
getNumAutoDiffAssociatedFunctions(unsigned differentiationOrder);
538-
539519
/// Retrieve config from the function name of a variant of
540520
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
541521
/// Returns true if the function name is parsed successfully.
542522
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
543523
AutoDiffAssociatedFunctionKind &kind,
544-
unsigned &arity, unsigned &order,
545-
bool &rethrows);
524+
unsigned &arity, bool &rethrows);
546525

547526
/// Computes the correct linkage for an associated function given the linkage of
548527
/// the original function. If the original linkage is not external and

include/swift/AST/DiagnosticsParse.def

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,10 +1594,6 @@ ERROR(sil_attr_differentiable_expected_source_index,PointsToFirstBadToken,
15941594
// SIL autodiff
15951595
ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken,
15961596
"expected ']' to complete the %0", (StringRef))
1597-
ERROR(sil_inst_autodiff_expected_order,PointsToFirstBadToken,
1598-
"expected an unsigned integer indicating the differentiation order", ())
1599-
ERROR(sil_inst_autodiff_expected_nonzero_order,PointsToFirstBadToken,
1600-
"expected a non-zero differentiation order", ())
16011597
ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken,
16021598
"expected the index of a parameter to differentiate with respect to", ())
16031599
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,

include/swift/AST/Types.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3099,8 +3099,8 @@ class AnyFunctionType : public TypeBase {
30993099
}
31003100

31013101
// SWIFT_ENABLE_TENSORFLOW
3102-
/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
3103-
/// of the corresponding autodiff associated function.
3102+
/// Given `indices` and `kind`, calculates the type of the corresponding
3103+
/// autodiff associated function.
31043104
///
31053105
/// By default, if the original type has a self parameter list and parameter
31063106
/// indices include self, the computed associated function type will return a
@@ -3116,7 +3116,7 @@ class AnyFunctionType : public TypeBase {
31163116
/// function, including `@differentiable`.
31173117
AnyFunctionType *getAutoDiffAssociatedFunctionType(
31183118
AutoDiffIndexSubset *indices, unsigned resultIndex,
3119-
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
3119+
AutoDiffAssociatedFunctionKind kind,
31203120
LookupConformanceFn lookupConformance,
31213121
GenericSignature *whereClauseGenericSignature = nullptr,
31223122
bool makeSelfParamFirst = false);
@@ -4216,16 +4216,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
42164216

42174217
// SWIFT_ENABLE_TENSORFLOW
42184218
CanSILFunctionType getWithDifferentiability(
4219-
unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices);
4219+
AutoDiffIndexSubset *parameterIndices);
42204220

42214221
CanSILFunctionType getWithoutDifferentiability();
42224222

42234223
/// Returns the type of a differentiation function that is associated with
42244224
/// a function of this type.
42254225
CanSILFunctionType getAutoDiffAssociatedFunctionType(
42264226
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
4227-
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
4228-
Lowering::TypeConverter &TC, LookupConformanceFn lookupConformance,
4227+
AutoDiffAssociatedFunctionKind kind, Lowering::TypeConverter &TC,
4228+
LookupConformanceFn lookupConformance,
42294229
CanGenericSignature associatedFunctionGenericSignature = nullptr);
42304230

42314231
/// Returns a bit vector that specifices which parameters you can

include/swift/SIL/SILBuilder.h

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -511,28 +511,27 @@ class SILBuilder {
511511

512512
/// SWIFT_ENABLE_TENSORFLOW
513513
DifferentiableFunctionInst *createDifferentiableFunction(
514-
SILLocation loc, AutoDiffIndexSubset *parameterIndices,
515-
unsigned differentiationOrder, SILValue original,
516-
ArrayRef<SILValue> associatedFunctions = {}) {
514+
SILLocation Loc, AutoDiffIndexSubset *ParameterIndices,
515+
SILValue OriginalFunction,
516+
Optional<std::pair<SILValue, SILValue>> JVPAndVJPFunctions = None) {
517517
return insert(DifferentiableFunctionInst::create(
518-
getModule(), getSILDebugLocation(loc), parameterIndices,
519-
differentiationOrder, original, associatedFunctions));
518+
getModule(), getSILDebugLocation(Loc), ParameterIndices,
519+
OriginalFunction, JVPAndVJPFunctions, hasOwnership()));
520520
}
521521

522522
DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
523-
SILLocation loc, DifferentiableFunctionExtractee extractee,
524-
unsigned differentiationOrder, SILValue theFunction) {
523+
SILLocation Loc, DifferentiableFunctionExtractee Extractee,
524+
SILValue TheFunction) {
525525
return insert(new (getModule()) DifferentiableFunctionExtractInst(
526-
getModule(), getSILDebugLocation(loc), extractee, differentiationOrder,
527-
theFunction));
526+
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
528527
}
529528

530529
DifferentiableFunctionExtractInst *
531-
createDifferentiableFunctionExtractOriginal(SILLocation loc,
532-
SILValue theFunction) {
530+
createDifferentiableFunctionExtractOriginal(SILLocation Loc,
531+
SILValue TheFunction) {
533532
return insert(new (getModule()) DifferentiableFunctionExtractInst(
534-
getModule(), getSILDebugLocation(loc),
535-
DifferentiableFunctionExtractee::Original, 0, theFunction));
533+
getModule(), getSILDebugLocation(Loc),
534+
DifferentiableFunctionExtractee::Original, TheFunction));
536535
}
537536

538537
BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy,

include/swift/SIL/SILCloner.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -970,15 +970,14 @@ template<typename ImplClass>
970970
void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
971971
DifferentiableFunctionInst *Inst) {
972972
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
973-
SmallVector<SILValue, 16> mappedAssocFns;
974-
mappedAssocFns.reserve(Inst->getNumAssociatedFunctions());
975-
for (auto &fn : Inst->getAssociatedFunctions())
976-
mappedAssocFns.push_back(getOpValue(fn.get()));
973+
Optional<std::pair<SILValue, SILValue>> assocFns = None;
974+
if (Inst->hasDerivativeFunctions())
975+
assocFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
976+
getOpValue(Inst->getVJPFunction()));
977977
recordClonedInstruction(
978978
Inst, getBuilder().createDifferentiableFunction(
979979
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
980-
Inst->getDifferentiationOrder(),
981-
getOpValue(Inst->getOriginalFunction()), mappedAssocFns));
980+
getOpValue(Inst->getOriginalFunction()), assocFns));
982981
}
983982

984983
template<typename ImplClass>
@@ -988,7 +987,6 @@ visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst)
988987
recordClonedInstruction(
989988
Inst, getBuilder().createDifferentiableFunctionExtract(
990989
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
991-
Inst->getDifferentiationOrder(),
992990
getOpValue(Inst->getFunctionOperand())));
993991
}
994992
// SWIFT_ENABLE_TENSORFLOW END

include/swift/SIL/SILInstruction.h

Lines changed: 53 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -7855,60 +7855,69 @@ class DifferentiableFunctionInst final :
78557855
private:
78567856
friend SILBuilder;
78577857
/// Differentiation parameter indices.
7858-
AutoDiffIndexSubset *parameterIndices;
7859-
/// The order of differentiation.
7860-
unsigned differentiationOrder;
7861-
/// The number of operands. The first operand is always the original function.
7862-
/// The rest of operands determined by the order of differentiation and whether
7863-
/// this is the new AD model or the legacy reverse-mode AD model.
7864-
unsigned numOperands;
7858+
AutoDiffIndexSubset *ParameterIndices;
7859+
/// Indicates whether derivative functions (JVP/VJP) exist.
7860+
bool HasDerivativeFunctions;
78657861

7866-
DifferentiableFunctionInst(SILModule &module, SILDebugLocation debugLoc,
7867-
AutoDiffIndexSubset *parameterIndices,
7868-
unsigned differentiationOrder,
7869-
SILValue originalFunction,
7870-
ArrayRef<SILValue> associatedFunctions);
7862+
DifferentiableFunctionInst(
7863+
SILDebugLocation DebugLoc, AutoDiffIndexSubset *ParameterIndices,
7864+
SILValue OriginalFunction, ArrayRef<SILValue> DerivativeFunctions,
7865+
bool HasOwnership);
7866+
7867+
static SILType getDifferentiableFunctionType(
7868+
SILValue Original, AutoDiffIndexSubset *ParameterIndices);
7869+
7870+
static ValueOwnershipKind getMergedOwnershipKind(
7871+
SILValue Original, ArrayRef<SILValue> DerivativeFunctions);
78717872

78727873
public:
78737874
static DifferentiableFunctionInst *create(
7874-
SILModule &module, SILDebugLocation debugLoc,
7875-
AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder,
7876-
SILValue originalFunction, ArrayRef<SILValue> associatedFunctions);
7877-
7878-
static SILType getAutoDiffType(SILValue original,
7879-
unsigned differentiationOrder,
7880-
AutoDiffIndexSubset *parameterIndices);
7875+
SILModule &Module, SILDebugLocation DebugLoc,
7876+
AutoDiffIndexSubset *ParameterIndices, SILValue OriginalFunction,
7877+
Optional<std::pair<SILValue, SILValue>> VJPAndJVPFunctions,
7878+
bool HasOwnership);
78817879

78827880
/// Returns the original function.
7883-
SILValue getOriginalFunction() const { return getAllOperands()[0].get(); }
7881+
SILValue getOriginalFunction() const { return getOperand(0); }
78847882

78857883
/// Returns differentiation indices.
7886-
AutoDiffIndexSubset *getParameterIndices() const {
7887-
return parameterIndices;
7888-
}
7884+
AutoDiffIndexSubset *getParameterIndices() const { return ParameterIndices; }
78897885

7890-
/// Returns the differentiation order.
7891-
unsigned getDifferentiationOrder() const {
7892-
return differentiationOrder;
7893-
}
7886+
/// Returns true if derivative functions (JVP/VJP) exist.
7887+
bool hasDerivativeFunctions() const { return HasDerivativeFunctions; }
78947888

7895-
unsigned getNumAssociatedFunctions() const {
7896-
return numOperands - 1;
7889+
/// Returns the derivative functions, namely the JVP and VJP functions, if
7890+
/// they exist. Otherwise, return None.
7891+
Optional<std::pair<SILValue, SILValue>>
7892+
getOptionalDerivativeFunctionPair() const {
7893+
if (!HasDerivativeFunctions)
7894+
return None;
7895+
return std::make_pair(getOperand(1), getOperand(2));
78977896
}
78987897

7899-
bool hasAssociatedFunctions() const {
7900-
return numOperands > 1;
7898+
ArrayRef<Operand> getDerivativeFunctionArray() const {
7899+
return getAllOperands().drop_front();
79017900
}
79027901

7903-
ArrayRef<Operand> getAssociatedFunctions() const {
7904-
return getAllOperands().drop_front();
7902+
/// Returns the JVP function.
7903+
SILValue getJVPFunction() const {
7904+
assert(HasDerivativeFunctions);
7905+
return getOperand(1);
79057906
}
79067907

7907-
std::pair<SILValue, SILValue>
7908-
getAssociatedFunctionPair(unsigned differentiationOrder) const;
7908+
/// Returns the VJP function.
7909+
SILValue getVJPFunction() const {
7910+
assert(HasDerivativeFunctions);
7911+
return getOperand(2);
7912+
}
79097913

7910-
SILValue getAssociatedFunction(unsigned differentiationOrder,
7911-
AutoDiffAssociatedFunctionKind kind) const;
7914+
/// Returns the derivative function (JVP or VJP) that matches the given kind.
7915+
SILValue getDerivativeFunction(AutoDiffAssociatedFunctionKind kind) const {
7916+
switch (kind) {
7917+
case AutoDiffAssociatedFunctionKind::JVP: return getJVPFunction();
7918+
case AutoDiffAssociatedFunctionKind::VJP: return getVJPFunction();
7919+
}
7920+
}
79127921
};
79137922

79147923
/// `differentiable_function_extract` - given an `@differentiable` function
@@ -7917,7 +7926,7 @@ class DifferentiableFunctionInst final :
79177926
class DifferentiableFunctionExtractInst
79187927
: public InstructionBase<
79197928
SILInstructionKind::DifferentiableFunctionExtractInst,
7920-
OwnershipForwardingSingleValueInst> {
7929+
SingleValueInstruction> {
79217930
public:
79227931
struct Extractee {
79237932
enum innerty : unsigned {
@@ -7939,46 +7948,28 @@ class DifferentiableFunctionExtractInst
79397948
private:
79407949
/// The extractee.
79417950
Extractee extractee;
7942-
/// The differentiation order. A zero value is only legal when the extractee
7943-
/// is the original function, and it is a private representation only.
7944-
unsigned differentiationOrder;
79457951
/// The list containing the `@differentiable` function operand.
79467952
FixedOperandList<1> operands;
79477953

79487954
static SILType
7949-
getExtracteeType(SILValue function, Extractee extractee,
7950-
unsigned differentiationOrder, SILModule &module);
7955+
getExtracteeType(SILValue function, Extractee extractee, SILModule &module);
79517956

79527957
public:
79537958
explicit DifferentiableFunctionExtractInst(
79547959
SILModule &module, SILDebugLocation debugLoc, Extractee extractee,
7955-
unsigned differentiationOrder, SILValue theFunction);
7960+
SILValue theFunction);
79567961

7957-
Extractee getExtractee() const {
7958-
return extractee;
7959-
}
7962+
Extractee getExtractee() const { return extractee; }
79607963

79617964
AutoDiffAssociatedFunctionKind getAssociatedFunctionKind() const {
79627965
auto kind = extractee.getExtracteeAsAssociatedFunction();
79637966
assert(kind);
79647967
return *kind;
79657968
}
79667969

7967-
SILValue getFunctionOperand() const {
7968-
return operands[0].get();
7969-
}
7970-
7971-
unsigned getDifferentiationOrder() const {
7972-
return differentiationOrder;
7973-
}
7974-
7975-
ArrayRef<Operand> getAllOperands() const {
7976-
return operands.asArray();
7977-
}
7978-
7979-
MutableArrayRef<Operand> getAllOperands() {
7980-
return operands.asArray();
7981-
}
7970+
SILValue getFunctionOperand() const { return operands[0].get(); }
7971+
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
7972+
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
79827973
};
79837974

79847975
typedef DifferentiableFunctionExtractInst::Extractee

0 commit comments

Comments
 (0)