Skip to content

Commit 959ffb0

Browse files
marcrasidan-zheng
authored andcommitted
[AutoDiff] rename @nondiff to @noDerivative (#28887)
Cherry-picks #28278 to `tensorflow` branch. `@nondiff` exists with a deprecation warning.
1 parent 05c4539 commit 959ffb0

39 files changed

+231
-158
lines changed

include/swift/ABI/MetadataValues.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,7 @@ class TargetParameterTypeFlags {
857857
ValueOwnershipMask = 0x7F,
858858
VariadicMask = 0x80,
859859
AutoClosureMask = 0x100,
860-
NonDifferentiableMask = 0x200
860+
NoDerivativeMask = 0x200
861861
};
862862
int_type Data;
863863

@@ -888,7 +888,7 @@ class TargetParameterTypeFlags {
888888
bool isVariadic() const { return Data & VariadicMask; }
889889
bool isAutoClosure() const { return Data & AutoClosureMask; }
890890
// SWIFT_ENABLE_TENSORFLOW
891-
bool isNonDifferentiable() const { return Data & NonDifferentiableMask; }
891+
bool isNoDerivative() const { return Data & NoDerivativeMask; }
892892

893893
ValueOwnership getValueOwnership() const {
894894
return (ValueOwnership)(Data & ValueOwnershipMask);

include/swift/AST/Attr.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ TYPE_ATTR(convention)
5252
TYPE_ATTR(noescape)
5353
TYPE_ATTR(escaping)
5454
TYPE_ATTR(differentiable)
55+
TYPE_ATTR(noDerivative)
5556
// SWIFT_ENABLE_TENSORFLOW
5657
TYPE_ATTR(autodiff)
5758
TYPE_ATTR(nondiff)

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,8 @@ NOTE(autodiff_class_member_not_differentiable,none,
505505
NOTE(autodiff_member_subset_indices_not_differentiable,none,
506506
"member is differentiable only with respect to a smaller subset of "
507507
"arguments", ())
508-
NOTE(autodiff_function_nondiff_parameter_not_differentiable,none,
509-
"cannot differentiate with respect to a '@nondiff' parameter", ())
508+
NOTE(autodiff_function_noderivative_parameter_not_differentiable,none,
509+
"cannot differentiate with respect to a '@noDerivative' parameter", ())
510510
NOTE(autodiff_function_assoc_func_unmet_requirements,none,
511511
"function call is not differentiable because generic requirements are not "
512512
"met: '%0'", (StringRef))

include/swift/AST/DiagnosticsSema.def

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3089,6 +3089,11 @@ ERROR(noderivative_only_on_differentiable_struct_or_class_fields,none,
30893089
"'@noDerivative' is only allowed on stored properties in structure or "
30903090
"class types that declare a conformance to 'Differentiable'", ())
30913091

3092+
// SWIFT_ENABLE_TENSORFLOW
3093+
// @nonDiff attribute
3094+
WARNING(nondiff_attr_deprecated,none,
3095+
"'@nondiff' is deprecated; use '@noDerivative' instead", ())
3096+
30923097
//------------------------------------------------------------------------------
30933098
// MARK: Type Check Expressions
30943099
//------------------------------------------------------------------------------
@@ -4070,17 +4075,20 @@ ERROR(opaque_type_in_protocol_requirement,none,
40704075
// Function differentiability
40714076
ERROR(autodiff_attr_argument_not_differentiable,none,
40724077
"argument is not differentiable, but the enclosing function type is "
4073-
"marked '@differentiable'; did you want to add '@nondiff' to this argument?",
4078+
"marked '@differentiable'; did you want to add '@noDerivative' to this argument?",
40744079
())
40754080
ERROR(autodiff_attr_result_not_differentiable,none,
40764081
"result is not differentiable, but the function type is marked "
40774082
"'@differentiable'", ())
4078-
ERROR(nondiff_attr_invalid_on_nondifferentiable_function,none,
4079-
"'nondiff' cannot be applied to arguments of a non-differentiable "
4080-
"function", ())
40814083
ERROR(attr_differentiable_no_vjp_or_jvp_when_linear,none,
40824084
"cannot specify 'vjp:' or 'jvp:' for linear functions; use "
40834085
"'transpose:' instead", ())
4086+
4087+
// Function differentiability
4088+
ERROR(attr_only_on_parameters_of_differentiable,none,
4089+
"'%0' may only be used on parameters of '@differentiable' function "
4090+
"types", (StringRef))
4091+
40844092
// SIL
40854093
ERROR(opened_non_protocol,none,
40864094
"@opened cannot be applied to non-protocol type %0", (Type))

include/swift/AST/Types.h

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,9 +1828,7 @@ class ParameterTypeFlags {
18281828
NonEphemeral = 1 << 2,
18291829
OwnershipShift = 3,
18301830
Ownership = 7 << OwnershipShift,
1831-
// SWIFT_ENABLE_TENSORFLOW
1832-
NonDifferentiable = 1 << 6,
1833-
1831+
NoDerivative = 1 << 7,
18341832
NumBits = 7
18351833
};
18361834
OptionSet<ParameterFlags> value;
@@ -1845,20 +1843,18 @@ class ParameterTypeFlags {
18451843
}
18461844

18471845
ParameterTypeFlags(bool variadic, bool autoclosure, bool nonEphemeral,
1848-
// SWIFT_ENABLE_TENSORFLOW
1849-
ValueOwnership ownership, bool nonDifferentiable)
1846+
ValueOwnership ownership, bool noDerivative)
18501847
: value((variadic ? Variadic : 0) | (autoclosure ? AutoClosure : 0) |
18511848
(nonEphemeral ? NonEphemeral : 0) |
1852-
// SWIFT_ENABLE_TENSORFLOW
1853-
(uint8_t(ownership) << OwnershipShift) |
1854-
(nonDifferentiable ? NonDifferentiable : 0)) {}
1849+
uint8_t(ownership) << OwnershipShift |
1850+
(noDerivative ? NoDerivative : 0)) {}
18551851

18561852
/// Create one from what's present in the parameter type
18571853
inline static ParameterTypeFlags
18581854
// SWIFT_ENABLE_TENSORFLOW
18591855
fromParameterType(Type paramTy, bool isVariadic, bool isAutoClosure,
18601856
bool isNonEphemeral, ValueOwnership ownership,
1861-
bool isNonDifferentiable);
1857+
bool isNoDerivative);
18621858

18631859
bool isNone() const { return !value; }
18641860
bool isVariadic() const { return value.contains(Variadic); }
@@ -1867,8 +1863,7 @@ class ParameterTypeFlags {
18671863
bool isInOut() const { return getValueOwnership() == ValueOwnership::InOut; }
18681864
bool isShared() const { return getValueOwnership() == ValueOwnership::Shared;}
18691865
bool isOwned() const { return getValueOwnership() == ValueOwnership::Owned; }
1870-
// SWIFT_ENABLE_TENSORFLOW
1871-
bool isNonDifferentiable() const { return value.contains(NonDifferentiable); }
1866+
bool isNoDerivative() const { return value.contains(NoDerivative); }
18721867

18731868
ValueOwnership getValueOwnership() const {
18741869
return ValueOwnership((value.toRaw() & Ownership) >> OwnershipShift);
@@ -1905,19 +1900,18 @@ class ParameterTypeFlags {
19051900
: value - ParameterTypeFlags::AutoClosure);
19061901
}
19071902

1908-
// SWIFT_ENABLE_TENSORFLOW
1909-
ParameterTypeFlags withNonDifferentiable(bool nonDifferentiable) const {
1910-
return ParameterTypeFlags(nonDifferentiable
1911-
? value | ParameterTypeFlags::NonDifferentiable
1912-
: value - ParameterTypeFlags::NonDifferentiable);
1913-
}
1914-
19151903
ParameterTypeFlags withNonEphemeral(bool isNonEphemeral) const {
19161904
return ParameterTypeFlags(isNonEphemeral
19171905
? value | ParameterTypeFlags::NonEphemeral
19181906
: value - ParameterTypeFlags::NonEphemeral);
19191907
}
19201908

1909+
ParameterTypeFlags withNoDerivative(bool noDerivative) const {
1910+
return ParameterTypeFlags(noDerivative
1911+
? value | ParameterTypeFlags::NoDerivative
1912+
: value - ParameterTypeFlags::NoDerivative);
1913+
}
1914+
19211915
bool operator ==(const ParameterTypeFlags &other) const {
19221916
return value.toRaw() == other.value.toRaw();
19231917
}
@@ -1984,10 +1978,8 @@ class YieldTypeFlags {
19841978
ParameterTypeFlags asParamFlags() const {
19851979
return ParameterTypeFlags(/*variadic*/ false,
19861980
/*autoclosure*/ false,
1987-
/*nonEphemeral*/ false,
1988-
// SWIFT_ENABLE_TENSORFLOW
1989-
getValueOwnership(),
1990-
/*nondifferentiable*/ false);
1981+
/*nonEphemeral*/ false, getValueOwnership(),
1982+
/*noDerivative*/ false);
19911983
}
19921984

19931985
bool operator ==(const YieldTypeFlags &other) const {
@@ -2859,9 +2851,8 @@ class AnyFunctionType : public TypeBase {
28592851
/// Whether the parameter is marked '@_nonEphemeral'
28602852
bool isNonEphemeral() const { return Flags.isNonEphemeral(); }
28612853

2862-
// SWIFT_ENABLE_TENSORFLOW
2863-
/// Whether the parameter is marked '@nondiff'.
2864-
bool isNonDifferentiable() const { return Flags.isNonDifferentiable(); }
2854+
/// Whether the parameter is marked '@noDerivative'.
2855+
bool isNoDerivative() const { return Flags.isNoDerivative(); }
28652856

28662857
ValueOwnership getValueOwnership() const {
28672858
return Flags.getValueOwnership();
@@ -4516,7 +4507,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
45164507

45174508
/// Returns a bit vector that specifices which parameters you can
45184509
/// differentiate with respect to for this differentiable function type. (e.g.
4519-
/// which parameters are not `@nondiff`). The function type must be
4510+
/// which parameters are not `@noDerivative`). The function type must be
45204511
/// differentiable.
45214512
IndexSubset *getDifferentiationParameterIndices();
45224513

@@ -6052,12 +6043,9 @@ inline TupleTypeElt TupleTypeElt::getWithType(Type T) const {
60526043
}
60536044

60546045
/// Create one from what's present in the parameter decl and type
6055-
inline ParameterTypeFlags
6056-
ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic,
6057-
bool isAutoClosure, bool isNonEphemeral,
6058-
// SWIFT_ENABLE_TENSORFLOW
6059-
ValueOwnership ownership,
6060-
bool isNonDifferentiable) {
6046+
inline ParameterTypeFlags ParameterTypeFlags::fromParameterType(
6047+
Type paramTy, bool isVariadic, bool isAutoClosure, bool isNonEphemeral,
6048+
ValueOwnership ownership, bool isNoDerivative) {
60616049
// FIXME(Remove InOut): The last caller that needs this is argument
60626050
// decomposition. Start by enabling the assertion there and fixing up those
60636051
// callers, then remove this, then remove
@@ -6067,9 +6055,7 @@ ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic,
60676055
ownership == ValueOwnership::InOut);
60686056
ownership = ValueOwnership::InOut;
60696057
}
6070-
// SWIFT_ENABLE_TENSORFLOW
6071-
return {isVariadic, isAutoClosure, isNonEphemeral, ownership,
6072-
isNonDifferentiable};
6058+
return {isVariadic, isAutoClosure, isNonEphemeral, ownership, isNoDerivative};
60736059
}
60746060

60756061
inline const Type *BoundGenericType::getTrailingObjectsPointer() const {

lib/AST/ASTContext.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3083,11 +3083,10 @@ void AnyFunctionType::decomposeInput(
30833083
}
30843084

30853085
default:
3086-
result.emplace_back(type->getInOutObjectType(), Identifier(),
3087-
ParameterTypeFlags::fromParameterType(
3088-
// SWIFT_ENABLE_TENSORFLOW
3089-
type, false, false, false, ValueOwnership::Default,
3090-
/*nonDifferentiable*/ false));
3086+
result.emplace_back(
3087+
type->getInOutObjectType(), Identifier(),
3088+
ParameterTypeFlags::fromParameterType(type, false, false, false,
3089+
ValueOwnership::Default, false));
30913090
return;
30923091
}
30933092
}

lib/AST/ASTDemangler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ Type ASTBuilder::createFunctionType(
367367
.withVariadic(flags.isVariadic())
368368
// SWIFT_ENABLE_TENSORFLOW
369369
.withAutoClosure(flags.isAutoClosure())
370-
.withNonDifferentiable(flags.isNonDifferentiable());
370+
.withNoDerivative(flags.isNoDerivative());
371371

372372
funcParams.push_back(AnyFunctionType::Param(type, label, parameterFlags));
373373
}

lib/AST/ASTPrinter.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2497,9 +2497,8 @@ static void printParameterFlags(ASTPrinter &printer, PrintOptions options,
24972497
ParameterTypeFlags flags, bool escaping) {
24982498
if (!options.excludeAttrKind(TAK_autoclosure) && flags.isAutoClosure())
24992499
printer << "@autoclosure ";
2500-
// SWIFT_ENABLE_TENSORFLOW
2501-
if (!options.excludeAttrKind(TAK_nondiff) && flags.isNonDifferentiable())
2502-
printer << "@nondiff ";
2500+
if (!options.excludeAttrKind(TAK_noDerivative) && flags.isNoDerivative())
2501+
printer << "@noDerivative ";
25032502

25042503
switch (flags.getValueOwnership()) {
25052504
case ValueOwnership::Default:
@@ -4577,7 +4576,7 @@ void SILParameterInfo::print(ASTPrinter &Printer,
45774576
/// SWIFT_ENABLE_TENSORFLOW
45784577
switch (getDifferentiability()) {
45794578
case SILParameterDifferentiability::NotDifferentiable:
4580-
Printer << "@nondiff ";
4579+
Printer << "@noDerivative ";
45814580
break;
45824581
default:
45834582
break;

lib/AST/Decl.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6109,13 +6109,10 @@ AnyFunctionType::Param ParamDecl::toFunctionParam(Type type) const {
61096109
type = ParamDecl::getVarargBaseTy(type);
61106110

61116111
auto label = getArgumentName();
6112-
auto flags = ParameterTypeFlags::fromParameterType(type,
6113-
isVariadic(),
6114-
isAutoClosure(),
6115-
isNonEphemeral(),
6116-
// SWIFT_ENABLE_TENSORFLOW
6117-
getValueOwnership(),
6118-
/*nondifferentiable*/ false);
6112+
auto flags = ParameterTypeFlags::fromParameterType(
6113+
type, isVariadic(), isAutoClosure(), isNonEphemeral(),
6114+
getValueOwnership(),
6115+
/*isNoDerivative*/ false);
61196116
return AnyFunctionType::Param(type, label, flags);
61206117
}
61216118

lib/AST/GenericSignatureBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5084,7 +5084,7 @@ class GenericSignatureBuilder::InferRequirementsWalker : public TypeWalker {
50845084
};
50855085
auto constrainParametersAndResult = [&](ProtocolDecl *protocol) {
50865086
for (auto &param : fnTy->getParams())
5087-
if (!param.isNonDifferentiable())
5087+
if (!param.isNoDerivative())
50885088
addConstraint(param.getPlainType(), protocol);
50895089
addConstraint(fnTy->getResult(), protocol);
50905090
};

lib/AST/Type.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5164,7 +5164,7 @@ AnyFunctionType *AnyFunctionType::getWithoutDifferentiability() const {
51645164
SmallVector<Param, 8> newParams;
51655165
for (auto &param : getParams()) {
51665166
Param newParam(param.getPlainType(), param.getLabel(),
5167-
param.getParameterFlags().withNonDifferentiable(false));
5167+
param.getParameterFlags().withNoDerivative(false));
51685168
newParams.push_back(newParam);
51695169
}
51705170
auto nonDiffExtInfo = getExtInfo()

lib/AST/TypeRepr.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,8 @@ void AttributedTypeRepr::printAttrs(ASTPrinter &Printer,
298298
Printer.printSimpleAttr("@autoclosure") << " ";
299299
if (hasAttr(TAK_escaping))
300300
Printer.printSimpleAttr("@escaping") << " ";
301+
if (hasAttr(TAK_noDerivative) || hasAttr(TAK_nondiff))
302+
Printer.printSimpleAttr("@noDerivative") << " ";
301303

302304
if (hasAttr(TAK_differentiable)) {
303305
if (Attrs.isLinear()) {

lib/SIL/SILFunctionType.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1093,7 +1093,7 @@ class DestructureInputs {
10931093

10941094
visit(flags.getValueOwnership(), /*forSelf=*/false,
10951095
// SWIFT_ENABLE_TENSORFLOW
1096-
eltPattern, ty, silRepresentation, flags.isNonDifferentiable());
1096+
eltPattern, ty, silRepresentation, flags.isNoDerivative());
10971097
}
10981098

10991099
// Process the self parameter. Note that we implicitly drop self

lib/SILGen/SILGenPoly.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3337,7 +3337,7 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
33373337
}
33383338
llvm::SmallBitVector parameterBits(numUncurriedParams);
33393339
for (auto i : range(inputSubstType->getNumParams()))
3340-
if (!inputSubstType->getParams()[i].isNonDifferentiable())
3340+
if (!inputSubstType->getParams()[i].isNoDerivative())
33413341
parameterBits.set(i);
33423342
auto *parameterIndices = IndexSubset::get(SGF.getASTContext(), parameterBits);
33433343

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ emitDerivativeFunctionReference(
509509
for (auto i : desiredIndices.parameters->getIndices()) {
510510
if (!paramIndices->contains(i)) {
511511
context.emitNondifferentiabilityError(functionSource, invoker,
512-
diag::autodiff_function_nondiff_parameter_not_differentiable);
512+
diag::autodiff_function_noderivative_parameter_not_differentiable);
513513
return None;
514514
}
515515
}

lib/SILOptimizer/Utils/Differentiation/JVPEmitter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1253,7 +1253,7 @@ void JVPEmitter::visitApplyInst(ApplyInst *ai) {
12531253
if (!paramIndices->contains(i)) {
12541254
context.emitNondifferentiabilityError(
12551255
original, invoker,
1256-
diag::autodiff_function_nondiff_parameter_not_differentiable);
1256+
diag::autodiff_function_noderivative_parameter_not_differentiable);
12571257
errorOccurred = true;
12581258
return;
12591259
}

lib/SILOptimizer/Utils/Differentiation/VJPEmitter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) {
499499
if (!paramIndices->contains(i)) {
500500
context.emitNondifferentiabilityError(
501501
original, invoker,
502-
diag::autodiff_function_nondiff_parameter_not_differentiable);
502+
diag::autodiff_function_noderivative_parameter_not_differentiable);
503503
errorOccurred = true;
504504
return;
505505
}

0 commit comments

Comments
 (0)