Skip to content

[AutoDiff] rename @nondiff to @noDerivative #28887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/swift/ABI/MetadataValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ class TargetParameterTypeFlags {
ValueOwnershipMask = 0x7F,
VariadicMask = 0x80,
AutoClosureMask = 0x100,
NonDifferentiableMask = 0x200
NoDerivativeMask = 0x200
};
int_type Data;

Expand Down Expand Up @@ -888,7 +888,7 @@ class TargetParameterTypeFlags {
bool isVariadic() const { return Data & VariadicMask; }
bool isAutoClosure() const { return Data & AutoClosureMask; }
// SWIFT_ENABLE_TENSORFLOW
bool isNonDifferentiable() const { return Data & NonDifferentiableMask; }
bool isNoDerivative() const { return Data & NoDerivativeMask; }

ValueOwnership getValueOwnership() const {
return (ValueOwnership)(Data & ValueOwnershipMask);
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ TYPE_ATTR(convention)
TYPE_ATTR(noescape)
TYPE_ATTR(escaping)
TYPE_ATTR(differentiable)
TYPE_ATTR(noDerivative)
// SWIFT_ENABLE_TENSORFLOW
TYPE_ATTR(autodiff)
TYPE_ATTR(nondiff)
Expand Down
4 changes: 2 additions & 2 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,8 @@ NOTE(autodiff_class_member_not_differentiable,none,
NOTE(autodiff_member_subset_indices_not_differentiable,none,
"member is differentiable only with respect to a smaller subset of "
"arguments", ())
NOTE(autodiff_function_nondiff_parameter_not_differentiable,none,
"cannot differentiate with respect to a '@nondiff' parameter", ())
NOTE(autodiff_function_noderivative_parameter_not_differentiable,none,
"cannot differentiate with respect to a '@noDerivative' parameter", ())
NOTE(autodiff_function_assoc_func_unmet_requirements,none,
"function call is not differentiable because generic requirements are not "
"met: '%0'", (StringRef))
Expand Down
16 changes: 12 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3089,6 +3089,11 @@ ERROR(noderivative_only_on_differentiable_struct_or_class_fields,none,
"'@noDerivative' is only allowed on stored properties in structure or "
"class types that declare a conformance to 'Differentiable'", ())

// SWIFT_ENABLE_TENSORFLOW
// @nonDiff attribute
WARNING(nondiff_attr_deprecated,none,
"'@nondiff' is deprecated; use '@noDerivative' instead", ())

//------------------------------------------------------------------------------
// MARK: Type Check Expressions
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -4070,17 +4075,20 @@ ERROR(opaque_type_in_protocol_requirement,none,
// Function differentiability
ERROR(autodiff_attr_argument_not_differentiable,none,
"argument is not differentiable, but the enclosing function type is "
"marked '@differentiable'; did you want to add '@nondiff' to this argument?",
"marked '@differentiable'; did you want to add '@noDerivative' to this argument?",
())
ERROR(autodiff_attr_result_not_differentiable,none,
"result is not differentiable, but the function type is marked "
"'@differentiable'", ())
ERROR(nondiff_attr_invalid_on_nondifferentiable_function,none,
"'nondiff' cannot be applied to arguments of a non-differentiable "
"function", ())
ERROR(attr_differentiable_no_vjp_or_jvp_when_linear,none,
"cannot specify 'vjp:' or 'jvp:' for linear functions; use "
"'transpose:' instead", ())

// Function differentiability
ERROR(attr_only_on_parameters_of_differentiable,none,
"'%0' may only be used on parameters of '@differentiable' function "
"types", (StringRef))

// SIL
ERROR(opened_non_protocol,none,
"@opened cannot be applied to non-protocol type %0", (Type))
Expand Down
56 changes: 21 additions & 35 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1828,9 +1828,7 @@ class ParameterTypeFlags {
NonEphemeral = 1 << 2,
OwnershipShift = 3,
Ownership = 7 << OwnershipShift,
// SWIFT_ENABLE_TENSORFLOW
NonDifferentiable = 1 << 6,

NoDerivative = 1 << 7,
NumBits = 7
};
OptionSet<ParameterFlags> value;
Expand All @@ -1845,20 +1843,18 @@ class ParameterTypeFlags {
}

ParameterTypeFlags(bool variadic, bool autoclosure, bool nonEphemeral,
// SWIFT_ENABLE_TENSORFLOW
ValueOwnership ownership, bool nonDifferentiable)
ValueOwnership ownership, bool noDerivative)
: value((variadic ? Variadic : 0) | (autoclosure ? AutoClosure : 0) |
(nonEphemeral ? NonEphemeral : 0) |
// SWIFT_ENABLE_TENSORFLOW
(uint8_t(ownership) << OwnershipShift) |
(nonDifferentiable ? NonDifferentiable : 0)) {}
uint8_t(ownership) << OwnershipShift |
(noDerivative ? NoDerivative : 0)) {}

/// Create one from what's present in the parameter type
inline static ParameterTypeFlags
// SWIFT_ENABLE_TENSORFLOW
fromParameterType(Type paramTy, bool isVariadic, bool isAutoClosure,
bool isNonEphemeral, ValueOwnership ownership,
bool isNonDifferentiable);
bool isNoDerivative);

bool isNone() const { return !value; }
bool isVariadic() const { return value.contains(Variadic); }
Expand All @@ -1867,8 +1863,7 @@ class ParameterTypeFlags {
bool isInOut() const { return getValueOwnership() == ValueOwnership::InOut; }
bool isShared() const { return getValueOwnership() == ValueOwnership::Shared;}
bool isOwned() const { return getValueOwnership() == ValueOwnership::Owned; }
// SWIFT_ENABLE_TENSORFLOW
bool isNonDifferentiable() const { return value.contains(NonDifferentiable); }
bool isNoDerivative() const { return value.contains(NoDerivative); }

ValueOwnership getValueOwnership() const {
return ValueOwnership((value.toRaw() & Ownership) >> OwnershipShift);
Expand Down Expand Up @@ -1905,19 +1900,18 @@ class ParameterTypeFlags {
: value - ParameterTypeFlags::AutoClosure);
}

// SWIFT_ENABLE_TENSORFLOW
ParameterTypeFlags withNonDifferentiable(bool nonDifferentiable) const {
return ParameterTypeFlags(nonDifferentiable
? value | ParameterTypeFlags::NonDifferentiable
: value - ParameterTypeFlags::NonDifferentiable);
}

ParameterTypeFlags withNonEphemeral(bool isNonEphemeral) const {
return ParameterTypeFlags(isNonEphemeral
? value | ParameterTypeFlags::NonEphemeral
: value - ParameterTypeFlags::NonEphemeral);
}

ParameterTypeFlags withNoDerivative(bool noDerivative) const {
return ParameterTypeFlags(noDerivative
? value | ParameterTypeFlags::NoDerivative
: value - ParameterTypeFlags::NoDerivative);
}

bool operator ==(const ParameterTypeFlags &other) const {
return value.toRaw() == other.value.toRaw();
}
Expand Down Expand Up @@ -1984,10 +1978,8 @@ class YieldTypeFlags {
ParameterTypeFlags asParamFlags() const {
return ParameterTypeFlags(/*variadic*/ false,
/*autoclosure*/ false,
/*nonEphemeral*/ false,
// SWIFT_ENABLE_TENSORFLOW
getValueOwnership(),
/*nondifferentiable*/ false);
/*nonEphemeral*/ false, getValueOwnership(),
/*noDerivative*/ false);
}

bool operator ==(const YieldTypeFlags &other) const {
Expand Down Expand Up @@ -2859,9 +2851,8 @@ class AnyFunctionType : public TypeBase {
/// Whether the parameter is marked '@_nonEphemeral'
bool isNonEphemeral() const { return Flags.isNonEphemeral(); }

// SWIFT_ENABLE_TENSORFLOW
/// Whether the parameter is marked '@nondiff'.
bool isNonDifferentiable() const { return Flags.isNonDifferentiable(); }
/// Whether the parameter is marked '@noDerivative'.
bool isNoDerivative() const { return Flags.isNoDerivative(); }

ValueOwnership getValueOwnership() const {
return Flags.getValueOwnership();
Expand Down Expand Up @@ -4516,7 +4507,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,

/// Returns a bit vector that specifices which parameters you can
/// differentiate with respect to for this differentiable function type. (e.g.
/// which parameters are not `@nondiff`). The function type must be
/// which parameters are not `@noDerivative`). The function type must be
/// differentiable.
IndexSubset *getDifferentiationParameterIndices();

Expand Down Expand Up @@ -6052,12 +6043,9 @@ inline TupleTypeElt TupleTypeElt::getWithType(Type T) const {
}

/// Create one from what's present in the parameter decl and type
inline ParameterTypeFlags
ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic,
bool isAutoClosure, bool isNonEphemeral,
// SWIFT_ENABLE_TENSORFLOW
ValueOwnership ownership,
bool isNonDifferentiable) {
inline ParameterTypeFlags ParameterTypeFlags::fromParameterType(
Type paramTy, bool isVariadic, bool isAutoClosure, bool isNonEphemeral,
ValueOwnership ownership, bool isNoDerivative) {
// FIXME(Remove InOut): The last caller that needs this is argument
// decomposition. Start by enabling the assertion there and fixing up those
// callers, then remove this, then remove
Expand All @@ -6067,9 +6055,7 @@ ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic,
ownership == ValueOwnership::InOut);
ownership = ValueOwnership::InOut;
}
// SWIFT_ENABLE_TENSORFLOW
return {isVariadic, isAutoClosure, isNonEphemeral, ownership,
isNonDifferentiable};
return {isVariadic, isAutoClosure, isNonEphemeral, ownership, isNoDerivative};
}

inline const Type *BoundGenericType::getTrailingObjectsPointer() const {
Expand Down
9 changes: 4 additions & 5 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3083,11 +3083,10 @@ void AnyFunctionType::decomposeInput(
}

default:
result.emplace_back(type->getInOutObjectType(), Identifier(),
ParameterTypeFlags::fromParameterType(
// SWIFT_ENABLE_TENSORFLOW
type, false, false, false, ValueOwnership::Default,
/*nonDifferentiable*/ false));
result.emplace_back(
type->getInOutObjectType(), Identifier(),
ParameterTypeFlags::fromParameterType(type, false, false, false,
ValueOwnership::Default, false));
return;
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/ASTDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ Type ASTBuilder::createFunctionType(
.withVariadic(flags.isVariadic())
// SWIFT_ENABLE_TENSORFLOW
.withAutoClosure(flags.isAutoClosure())
.withNonDifferentiable(flags.isNonDifferentiable());
.withNoDerivative(flags.isNoDerivative());

funcParams.push_back(AnyFunctionType::Param(type, label, parameterFlags));
}
Expand Down
7 changes: 3 additions & 4 deletions lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2497,9 +2497,8 @@ static void printParameterFlags(ASTPrinter &printer, PrintOptions options,
ParameterTypeFlags flags, bool escaping) {
if (!options.excludeAttrKind(TAK_autoclosure) && flags.isAutoClosure())
printer << "@autoclosure ";
// SWIFT_ENABLE_TENSORFLOW
if (!options.excludeAttrKind(TAK_nondiff) && flags.isNonDifferentiable())
printer << "@nondiff ";
if (!options.excludeAttrKind(TAK_noDerivative) && flags.isNoDerivative())
printer << "@noDerivative ";

switch (flags.getValueOwnership()) {
case ValueOwnership::Default:
Expand Down Expand Up @@ -4577,7 +4576,7 @@ void SILParameterInfo::print(ASTPrinter &Printer,
/// SWIFT_ENABLE_TENSORFLOW
switch (getDifferentiability()) {
case SILParameterDifferentiability::NotDifferentiable:
Printer << "@nondiff ";
Printer << "@noDerivative ";
break;
default:
break;
Expand Down
11 changes: 4 additions & 7 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6109,13 +6109,10 @@ AnyFunctionType::Param ParamDecl::toFunctionParam(Type type) const {
type = ParamDecl::getVarargBaseTy(type);

auto label = getArgumentName();
auto flags = ParameterTypeFlags::fromParameterType(type,
isVariadic(),
isAutoClosure(),
isNonEphemeral(),
// SWIFT_ENABLE_TENSORFLOW
getValueOwnership(),
/*nondifferentiable*/ false);
auto flags = ParameterTypeFlags::fromParameterType(
type, isVariadic(), isAutoClosure(), isNonEphemeral(),
getValueOwnership(),
/*isNoDerivative*/ false);
return AnyFunctionType::Param(type, label, flags);
}

Expand Down
2 changes: 1 addition & 1 deletion lib/AST/GenericSignatureBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5084,7 +5084,7 @@ class GenericSignatureBuilder::InferRequirementsWalker : public TypeWalker {
};
auto constrainParametersAndResult = [&](ProtocolDecl *protocol) {
for (auto &param : fnTy->getParams())
if (!param.isNonDifferentiable())
if (!param.isNoDerivative())
addConstraint(param.getPlainType(), protocol);
addConstraint(fnTy->getResult(), protocol);
};
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5164,7 +5164,7 @@ AnyFunctionType *AnyFunctionType::getWithoutDifferentiability() const {
SmallVector<Param, 8> newParams;
for (auto &param : getParams()) {
Param newParam(param.getPlainType(), param.getLabel(),
param.getParameterFlags().withNonDifferentiable(false));
param.getParameterFlags().withNoDerivative(false));
newParams.push_back(newParam);
}
auto nonDiffExtInfo = getExtInfo()
Expand Down
2 changes: 2 additions & 0 deletions lib/AST/TypeRepr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ void AttributedTypeRepr::printAttrs(ASTPrinter &Printer,
Printer.printSimpleAttr("@autoclosure") << " ";
if (hasAttr(TAK_escaping))
Printer.printSimpleAttr("@escaping") << " ";
if (hasAttr(TAK_noDerivative) || hasAttr(TAK_nondiff))
Printer.printSimpleAttr("@noDerivative") << " ";

if (hasAttr(TAK_differentiable)) {
if (Attrs.isLinear()) {
Expand Down
2 changes: 1 addition & 1 deletion lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ class DestructureInputs {

visit(flags.getValueOwnership(), /*forSelf=*/false,
// SWIFT_ENABLE_TENSORFLOW
eltPattern, ty, silRepresentation, flags.isNonDifferentiable());
eltPattern, ty, silRepresentation, flags.isNoDerivative());
}

// Process the self parameter. Note that we implicitly drop self
Expand Down
2 changes: 1 addition & 1 deletion lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3337,7 +3337,7 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
}
llvm::SmallBitVector parameterBits(numUncurriedParams);
for (auto i : range(inputSubstType->getNumParams()))
if (!inputSubstType->getParams()[i].isNonDifferentiable())
if (!inputSubstType->getParams()[i].isNoDerivative())
parameterBits.set(i);
auto *parameterIndices = IndexSubset::get(SGF.getASTContext(), parameterBits);

Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ emitDerivativeFunctionReference(
for (auto i : desiredIndices.parameters->getIndices()) {
if (!paramIndices->contains(i)) {
context.emitNondifferentiabilityError(functionSource, invoker,
diag::autodiff_function_nondiff_parameter_not_differentiable);
diag::autodiff_function_noderivative_parameter_not_differentiable);
return None;
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Utils/Differentiation/JVPEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,7 @@ void JVPEmitter::visitApplyInst(ApplyInst *ai) {
if (!paramIndices->contains(i)) {
context.emitNondifferentiabilityError(
original, invoker,
diag::autodiff_function_nondiff_parameter_not_differentiable);
diag::autodiff_function_noderivative_parameter_not_differentiable);
errorOccurred = true;
return;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Utils/Differentiation/VJPEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) {
if (!paramIndices->contains(i)) {
context.emitNondifferentiabilityError(
original, invoker,
diag::autodiff_function_nondiff_parameter_not_differentiable);
diag::autodiff_function_noderivative_parameter_not_differentiable);
errorOccurred = true;
return;
}
Expand Down
Loading