Skip to content

Commit 431cc43

Browse files
authored
[AutoDiff] Rename "associated function" to "derivative function". (#27603)
`assocFn` -> `derivativeFn` `AssocFn` -> `DerivativeFn` `assocFunc` -> `derivativeFunc` `AssocFunc` -> `DerivativeFunc` `associatedFunction` -> `derivativeFunction` `AssociatedFunction` -> `DerivativeFunction` `autodiff associated function` -> `derivative function` `autodiff-associated function` -> `derivative function` `AD associated function` -> `derivative function` `associated differentiation function` -> `derivative function` This is a follow-up to #27597. Resolves [TF-882](https://bugs.swift.org/browse/TF-882).
1 parent 0aee08a commit 431cc43

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+510
-512
lines changed

docs/SIL.rst

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5609,35 +5609,34 @@ differentiable_function
56095609
sil-differentiable-function-parameter-indices?
56105610
sil-differentiable-function-order?
56115611
sil-value ':' sil-type
5612-
sil-differentiable-function-associated-functions-clause?
5612+
sil-differentiable-function-derivative-functions-clause?
56135613
56145614
sil-differentiable-function-parameter-indices ::=
56155615
'[' 'wrt' [0-9]+ (',', [0-9]+)* ']'
56165616
sil-differentiable-function-order ::= '[' 'order' [0-9]+ ']'
5617-
sil-differentiable-associated-functions-clause ::=
5618-
'with' sil-differentiable-associated-function-list
5619-
(',' sil-differentiable-associated-function-list)*
5620-
sil-differentiable-function-associated-function-list ::=
5617+
sil-differentiable-derivative-functions-clause ::=
5618+
'with' sil-differentiable-derivative-function-list
5619+
(',' sil-differentiable-derivative-function-list)*
5620+
sil-differentiable-function-derivative-function-list ::=
56215621
'{' sil-value ',' sil-value '}'
56225622

56235623
differentiable_function [wrt 0] %0 : $(T) -> T \
56245624
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
56255625

5626-
Bundles a function with its associated differentiation functions into a
5627-
``@differentiable`` function. There are two associated functions:
5628-
a Jacobian-vector products (JVP) function and a vector-Jacobian products (VJP)
5629-
function.
5626+
Bundles a function with its derivative functions into a ``@differentiable``
5627+
function. There are two derivative functions: a Jacobian-vector products (JVP)
5628+
function and a vector-Jacobian products (VJP) function.
56305629

56315630
``[wrt ...]`` specifies parameter indices that the original function is
56325631
differentiable with respect to. When not specified, it defaults to all
56335632
parameters.
56345633

56355634
A ``with`` clause specifies the differentiation functions associated
56365635
with the original function. When a ``with`` clause is not specified, the first
5637-
operand will be differentiated to produce associated functions, and a ``with``
5636+
operand will be differentiated to produce derivative functions, and a ``with``
56385637
clause will be added to the instruction.
56395638

5640-
In raw SIL, it is optional to provide an associated function ``with`` clause.
5639+
In raw SIL, it is optional to provide a derivative function ``with`` clause.
56415640
In canonical SIL, a ``with`` clause is mandatory.
56425641

56435642

@@ -5660,7 +5659,7 @@ differentiable_function_extract
56605659
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
56615660
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
56625661

5663-
Extracts the original function or an associated function from the given
5662+
Extracts the original function or a derivative function from the given
56645663
``@differentiable`` function. It must be provided with an extractee:
56655664
``[original]``, ``[jvp]`` or ``[vjp]``.
56665665

include/swift/AST/ASTMangler.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,12 @@ class ASTMangler : public Mangler {
155155
ModuleDecl *Module);
156156

157157
// SWIFT_ENABLE_TENSORFLOW
158-
// Mangle the autodiff associated function (JVP/VJP) with the given:
158+
// Mangle the derivative function (JVP/VJP) with the given:
159159
// - Mangled original function name.
160-
// - Associated function kind.
160+
// - Derivative function kind.
161161
// - Parameter/result indices.
162-
std::string mangleAutoDiffAssociatedFunctionHelper(
163-
StringRef name, AutoDiffAssociatedFunctionKind kind,
162+
std::string mangleAutoDiffDerivativeFunctionHelper(
163+
StringRef name, AutoDiffDerivativeFunctionKind kind,
164164
const SILAutoDiffIndices &indices);
165165

166166
// SWIFT_ENABLE_TENSORFLOW

include/swift/AST/Attr.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,7 +1555,7 @@ class DifferentiableAttr final
15551555
AutoDiffIndexSubset *ParameterIndices = nullptr;
15561556
/// The trailing where clause (optional).
15571557
TrailingWhereClause *WhereClause = nullptr;
1558-
/// The generic signature for autodiff associated functions. Resolved by the
1558+
/// The generic signature for autodiff derivative functions. Resolved by the
15591559
/// type checker based on the original function's generic signature and the
15601560
/// attribute's where clause requirements. This is set only if the attribute
15611561
/// has a where clause.
@@ -1650,10 +1650,10 @@ class DifferentiableAttr final
16501650

16511651
// Print the attribute to the given stream.
16521652
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1653-
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1653+
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
16541654
void print(llvm::raw_ostream &OS, const Decl *D,
16551655
bool omitWrtClause = false,
1656-
bool omitAssociatedFunctions = false) const;
1656+
bool omitDerivativeFunctions = false) const;
16571657

16581658
static bool classof(const DeclAttribute *DA) {
16591659
return DA->getKind() == DAK_Differentiable;

include/swift/AST/AutoDiff.h

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -431,48 +431,48 @@ struct AutoDiffLinearMapKind {
431431
operator innerty() const { return rawValue; }
432432
};
433433

434-
/// The kind of an associated function.
435-
struct AutoDiffAssociatedFunctionKind {
434+
/// The kind of a derivative function.
435+
struct AutoDiffDerivativeFunctionKind {
436436
enum innerty : uint8_t {
437437
// The Jacobian-vector products function.
438438
JVP = 0,
439439
// The vector-Jacobian products function.
440440
VJP = 1
441441
} rawValue;
442442

443-
AutoDiffAssociatedFunctionKind() = default;
444-
AutoDiffAssociatedFunctionKind(innerty rawValue) : rawValue(rawValue) {}
445-
AutoDiffAssociatedFunctionKind(AutoDiffLinearMapKind linMapKind)
443+
AutoDiffDerivativeFunctionKind() = default;
444+
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
445+
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
446446
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
447-
explicit AutoDiffAssociatedFunctionKind(StringRef string);
447+
explicit AutoDiffDerivativeFunctionKind(StringRef string);
448448
operator innerty() const { return rawValue; }
449449
AutoDiffLinearMapKind getLinearMapKind() {
450450
return (AutoDiffLinearMapKind::innerty)rawValue;
451451
}
452452
};
453453

454454
/// In conjunction with the original function declaration, identifies an
455-
/// autodiff associated function.
455+
/// autodiff derivative function.
456456
///
457457
/// Is uniquely allocated within an ASTContext so that it can be hashed and
458458
/// compared by opaque pointer value.
459-
class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
460-
const AutoDiffAssociatedFunctionKind kind;
459+
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
460+
const AutoDiffDerivativeFunctionKind kind;
461461
AutoDiffIndexSubset *const parameterIndices;
462462

463-
AutoDiffAssociatedFunctionIdentifier(
464-
AutoDiffAssociatedFunctionKind kind,
463+
AutoDiffDerivativeFunctionIdentifier(
464+
AutoDiffDerivativeFunctionKind kind,
465465
AutoDiffIndexSubset *parameterIndices) :
466466
kind(kind), parameterIndices(parameterIndices) {}
467467

468468
public:
469-
AutoDiffAssociatedFunctionKind getKind() const { return kind; }
469+
AutoDiffDerivativeFunctionKind getKind() const { return kind; }
470470
AutoDiffIndexSubset *getParameterIndices() const {
471471
return parameterIndices;
472472
}
473473

474-
static AutoDiffAssociatedFunctionIdentifier *get(
475-
AutoDiffAssociatedFunctionKind kind,
474+
static AutoDiffDerivativeFunctionIdentifier *get(
475+
AutoDiffDerivativeFunctionKind kind,
476476
AutoDiffIndexSubset *parameterIndices, ASTContext &C);
477477

478478
void Profile(llvm::FoldingSetNodeID &ID) {
@@ -520,15 +520,15 @@ AutoDiffIndexSubset *getLoweredParameterIndices(AutoDiffIndexSubset *indices,
520520
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
521521
/// Returns true if the function name is parsed successfully.
522522
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
523-
AutoDiffAssociatedFunctionKind &kind,
523+
AutoDiffDerivativeFunctionKind &kind,
524524
unsigned &arity, bool &rethrows);
525525

526-
/// Computes the correct linkage for an associated function given the linkage of
526+
/// Computes the correct linkage for a derivative function given the linkage of
527527
/// the original function. If the original linkage is not external and
528-
/// `isAssocFnExported` is true, use the original function's linkage. Otherwise,
529-
/// return hidden linkage.
530-
SILLinkage getAutoDiffAssociatedFunctionLinkage(SILLinkage originalLinkage,
531-
bool isAssocFnExported);
528+
/// `isDerivativeFnExported` is true, use the original function's linkage.
529+
/// Otherwise, return hidden linkage.
530+
SILLinkage getAutoDiffDerivativeFunctionLinkage(SILLinkage originalLinkage,
531+
bool isDerivativeFnExported);
532532

533533
} // end namespace autodiff
534534

include/swift/AST/DiagnosticsParse.def

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,15 +1597,15 @@ ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken,
15971597
ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken,
15981598
"expected the index of a parameter to differentiate with respect to", ())
15991599
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,
1600-
"expected '{' to start an associated function list", ())
1600+
"expected '{' to start a derivative function list", ())
16011601
ERROR(sil_inst_autodiff_operand_list_expected_comma,PointsToFirstBadToken,
1602-
"expected ',' between operands in an associated function list", ())
1602+
"expected ',' between operands in a derivative function list", ())
16031603
ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
1604-
"expected '}' to start an associated function list", ())
1604+
"expected '}' to start a derivative function list", ())
16051605
ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken,
16061606
"the number of operand lists does not match the order", ())
16071607
ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken,
1608-
"expected an associated function kind attribute, e.g. '[jvp]'", ())
1608+
"expected a derivative function kind attribute, e.g. '[jvp]'", ())
16091609
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
16101610
"expected an operand of a function type", ())
16111611

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2859,8 +2859,7 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
28592859
ERROR(differentiable_attr_void_result,none,
28602860
"cannot differentiate void function %0", (DeclName))
28612861
ERROR(differentiable_attr_associated_function_protocol,none,
2862-
"cannot specify associated differentiation function on protocol "
2863-
"requirement", ())
2862+
"cannot specify derivative functions on protocol requirements", ())
28642863
ERROR(differentiable_attr_overload_not_found,none,
28652864
"%0 does not have expected type %1", (DeclName, Type))
28662865
ERROR(differentiable_attr_no_currying,none,
@@ -2874,17 +2873,17 @@ NOTE(differentiable_attr_duplicate_note,none,
28742873
ERROR(differentiable_attr_function_not_same_type_context,none,
28752874
"%0 is not defined in the current type context", (DeclName))
28762875
ERROR(differentiable_attr_specified_not_function,none,
2877-
"%0 is not a function to be used as associated differentiation function",
2876+
"%0 is not a function to be used as derivative function",
28782877
(DeclName))
28792878
ERROR(differentiable_attr_class_derivative_not_final,none,
28802879
"class member derivative must be final", ())
28812880
ERROR(differentiable_attr_ambiguous_function_identifier,none,
28822881
"ambiguous or overloaded identifier %0 cannot be used in '@differentiable' "
28832882
"attribute", (DeclName))
28842883
ERROR(differentiable_attr_invalid_access,none,
2885-
"associated differentiation function %0 is required to either be public "
2886-
"or @usableFromInline because the original function %1 is public or "
2887-
"@usableFromInline", (DeclName, DeclName))
2884+
"derivative function %0 is required to either be public or "
2885+
"'@usableFromInline' because the original function %1 is public or "
2886+
"'@usableFromInline'", (DeclName, DeclName))
28882887
ERROR(differentiable_attr_result_not_differentiable,none,
28892888
"can only differentiate functions with results that conform to "
28902889
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))

include/swift/AST/Types.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3100,10 +3100,10 @@ class AnyFunctionType : public TypeBase {
31003100

31013101
// SWIFT_ENABLE_TENSORFLOW
31023102
/// Given `indices` and `kind`, calculates the type of the corresponding
3103-
/// autodiff associated function.
3103+
/// autodiff derivative function.
31043104
///
31053105
/// By default, if the original type has a self parameter list and parameter
3106-
/// indices include self, the computed associated function type will return a
3106+
/// indices include self, the computed derivative function type will return a
31073107
/// linear map taking/returning self's tangent *last* instead of first, for
31083108
/// consistency with SIL.
31093109
///
@@ -3114,18 +3114,18 @@ class AnyFunctionType : public TypeBase {
31143114
/// \note The original function type (`self`) need not be `@differentiable`.
31153115
/// The resulting function will preserve all `ExtInfo` of the original
31163116
/// function, including `@differentiable`.
3117-
AnyFunctionType *getAutoDiffAssociatedFunctionType(
3117+
AnyFunctionType *getAutoDiffDerivativeFunctionType(
31183118
AutoDiffIndexSubset *indices, unsigned resultIndex,
3119-
AutoDiffAssociatedFunctionKind kind,
3119+
AutoDiffDerivativeFunctionKind kind,
31203120
LookupConformanceFn lookupConformance,
31213121
GenericSignature *whereClauseGenericSignature = nullptr,
31223122
bool makeSelfParamFirst = false);
31233123

3124-
/// Given the type of an autodiff associated function, returns the
3124+
/// Given the type of an autodiff derivative function, returns the
31253125
/// corresponding original function type.
31263126
AnyFunctionType *getAutoDiffOriginalFunctionType();
31273127

3128-
/// Given the type of a transposing associated function, returns the
3128+
/// Given the type of a transposing derivative function, returns the
31293129
/// corresponding original function type.
31303130
AnyFunctionType *
31313131
getTransposeOriginalFunctionType(TransposingAttr *attr,
@@ -4222,11 +4222,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
42224222

42234223
/// Returns the type of a differentiation function that is associated with
42244224
/// a function of this type.
4225-
CanSILFunctionType getAutoDiffAssociatedFunctionType(
4225+
CanSILFunctionType getAutoDiffDerivativeFunctionType(
42264226
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
4227-
AutoDiffAssociatedFunctionKind kind, Lowering::TypeConverter &TC,
4227+
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
42284228
LookupConformanceFn lookupConformance,
4229-
CanGenericSignature associatedFunctionGenericSignature = nullptr);
4229+
CanGenericSignature derivativeFunctionGenericSignature = nullptr);
42304230

42314231
/// Returns a bit vector that specifices which parameters you can
42324232
/// differentiate with respect to for this differentiable function type. (e.g.

include/swift/SIL/SILCloner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,7 @@ void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
973973
Optional<std::pair<SILValue, SILValue>> derivativeFns = None;
974974
if (Inst->hasDerivativeFunctions())
975975
derivativeFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
976-
getOpValue(Inst->getVJPFunction()));
976+
getOpValue(Inst->getVJPFunction()));
977977
recordClonedInstruction(
978978
Inst, getBuilder().createDifferentiableFunction(
979979
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),

0 commit comments

Comments
 (0)