Skip to content

Commit ec0a2ca

Browse files
authored
Merge pull request #28278 from marcrasi/ast-nondiff
2 parents bab1239 + 72194c5 commit ec0a2ca

16 files changed

+152
-43
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ TYPE_ATTR(convention)
5252
TYPE_ATTR(noescape)
5353
TYPE_ATTR(escaping)
5454
TYPE_ATTR(differentiable)
55+
TYPE_ATTR(noDerivative)
5556

5657
// SIL-specific attributes
5758
TYPE_ATTR(block_storage)

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3926,6 +3926,11 @@ ERROR(opaque_type_in_protocol_requirement,none,
39263926
"'some' type cannot be the return type of a protocol requirement; did you mean to add an associated type?",
39273927
())
39283928

3929+
// Function differentiability
3930+
ERROR(attr_only_on_parameters_of_differentiable,none,
3931+
"'%0' may only be used on parameters of '@differentiable' function "
3932+
"types", (StringRef))
3933+
39293934
// SIL
39303935
ERROR(opened_non_protocol,none,
39313936
"@opened cannot be applied to non-protocol type %0", (Type))

include/swift/AST/Types.h

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,8 +1808,8 @@ class ParameterTypeFlags {
18081808
NonEphemeral = 1 << 2,
18091809
OwnershipShift = 3,
18101810
Ownership = 7 << OwnershipShift,
1811-
1812-
NumBits = 6
1811+
NoDerivative = 1 << 7,
1812+
NumBits = 7
18131813
};
18141814
OptionSet<ParameterFlags> value;
18151815
static_assert(NumBits < 8*sizeof(OptionSet<ParameterFlags>), "overflowed");
@@ -1823,15 +1823,17 @@ class ParameterTypeFlags {
18231823
}
18241824

18251825
ParameterTypeFlags(bool variadic, bool autoclosure, bool nonEphemeral,
1826-
ValueOwnership ownership)
1826+
ValueOwnership ownership, bool noDerivative)
18271827
: value((variadic ? Variadic : 0) | (autoclosure ? AutoClosure : 0) |
18281828
(nonEphemeral ? NonEphemeral : 0) |
1829-
uint8_t(ownership) << OwnershipShift) {}
1829+
uint8_t(ownership) << OwnershipShift |
1830+
(noDerivative ? NoDerivative : 0)) {}
18301831

18311832
/// Create one from what's present in the parameter type
18321833
inline static ParameterTypeFlags
18331834
fromParameterType(Type paramTy, bool isVariadic, bool isAutoClosure,
1834-
bool isNonEphemeral, ValueOwnership ownership);
1835+
bool isNonEphemeral, ValueOwnership ownership,
1836+
bool isNoDerivative);
18351837

18361838
bool isNone() const { return !value; }
18371839
bool isVariadic() const { return value.contains(Variadic); }
@@ -1840,6 +1842,7 @@ class ParameterTypeFlags {
18401842
bool isInOut() const { return getValueOwnership() == ValueOwnership::InOut; }
18411843
bool isShared() const { return getValueOwnership() == ValueOwnership::Shared;}
18421844
bool isOwned() const { return getValueOwnership() == ValueOwnership::Owned; }
1845+
bool isNoDerivative() const { return value.contains(NoDerivative); }
18431846

18441847
ValueOwnership getValueOwnership() const {
18451848
return ValueOwnership((value.toRaw() & Ownership) >> OwnershipShift);
@@ -1882,6 +1885,12 @@ class ParameterTypeFlags {
18821885
: value - ParameterTypeFlags::NonEphemeral);
18831886
}
18841887

1888+
ParameterTypeFlags withNoDerivative(bool noDerivative) const {
1889+
return ParameterTypeFlags(noDerivative
1890+
? value | ParameterTypeFlags::NoDerivative
1891+
: value - ParameterTypeFlags::NoDerivative);
1892+
}
1893+
18851894
bool operator ==(const ParameterTypeFlags &other) const {
18861895
return value.toRaw() == other.value.toRaw();
18871896
}
@@ -1948,8 +1957,8 @@ class YieldTypeFlags {
19481957
ParameterTypeFlags asParamFlags() const {
19491958
return ParameterTypeFlags(/*variadic*/ false,
19501959
/*autoclosure*/ false,
1951-
/*nonEphemeral*/ false,
1952-
getValueOwnership());
1960+
/*nonEphemeral*/ false, getValueOwnership(),
1961+
/*noDerivative*/ false);
19531962
}
19541963

19551964
bool operator ==(const YieldTypeFlags &other) const {
@@ -2821,6 +2830,9 @@ class AnyFunctionType : public TypeBase {
28212830
/// Whether the parameter is marked '@_nonEphemeral'
28222831
bool isNonEphemeral() const { return Flags.isNonEphemeral(); }
28232832

2833+
/// Whether the parameter is marked '@noDerivative'.
2834+
bool isNoDerivative() const { return Flags.isNoDerivative(); }
2835+
28242836
ValueOwnership getValueOwnership() const {
28252837
return Flags.getValueOwnership();
28262838
}
@@ -5818,10 +5830,9 @@ inline TupleTypeElt TupleTypeElt::getWithType(Type T) const {
58185830
}
58195831

58205832
/// Create one from what's present in the parameter decl and type
5821-
inline ParameterTypeFlags
5822-
ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic,
5823-
bool isAutoClosure, bool isNonEphemeral,
5824-
ValueOwnership ownership) {
5833+
inline ParameterTypeFlags ParameterTypeFlags::fromParameterType(
5834+
Type paramTy, bool isVariadic, bool isAutoClosure, bool isNonEphemeral,
5835+
ValueOwnership ownership, bool isNoDerivative) {
58255836
// FIXME(Remove InOut): The last caller that needs this is argument
58265837
// decomposition. Start by enabling the assertion there and fixing up those
58275838
// callers, then remove this, then remove
@@ -5831,7 +5842,7 @@ ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic,
58315842
ownership == ValueOwnership::InOut);
58325843
ownership = ValueOwnership::InOut;
58335844
}
5834-
return {isVariadic, isAutoClosure, isNonEphemeral, ownership};
5845+
return {isVariadic, isAutoClosure, isNonEphemeral, ownership, isNoDerivative};
58355846
}
58365847

58375848
inline const Type *BoundGenericType::getTrailingObjectsPointer() const {

lib/AST/ASTContext.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2893,9 +2893,10 @@ void AnyFunctionType::decomposeInput(
28932893
}
28942894

28952895
default:
2896-
result.emplace_back(type->getInOutObjectType(), Identifier(),
2897-
ParameterTypeFlags::fromParameterType(
2898-
type, false, false, false, ValueOwnership::Default));
2896+
result.emplace_back(
2897+
type->getInOutObjectType(), Identifier(),
2898+
ParameterTypeFlags::fromParameterType(type, false, false, false,
2899+
ValueOwnership::Default, false));
28992900
return;
29002901
}
29012902
}

lib/AST/ASTPrinter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2502,6 +2502,8 @@ static void printParameterFlags(ASTPrinter &printer, PrintOptions options,
25022502
ParameterTypeFlags flags, bool escaping) {
25032503
if (!options.excludeAttrKind(TAK_autoclosure) && flags.isAutoClosure())
25042504
printer << "@autoclosure ";
2505+
if (!options.excludeAttrKind(TAK_noDerivative) && flags.isNoDerivative())
2506+
printer << "@noDerivative ";
25052507

25062508
switch (flags.getValueOwnership()) {
25072509
case ValueOwnership::Default:

lib/AST/Decl.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6039,11 +6039,10 @@ AnyFunctionType::Param ParamDecl::toFunctionParam(Type type) const {
60396039
type = ParamDecl::getVarargBaseTy(type);
60406040

60416041
auto label = getArgumentName();
6042-
auto flags = ParameterTypeFlags::fromParameterType(type,
6043-
isVariadic(),
6044-
isAutoClosure(),
6045-
isNonEphemeral(),
6046-
getValueOwnership());
6042+
auto flags = ParameterTypeFlags::fromParameterType(
6043+
type, isVariadic(), isAutoClosure(), isNonEphemeral(),
6044+
getValueOwnership(),
6045+
/*isNoDerivative*/ false);
60476046
return AnyFunctionType::Param(type, label, flags);
60486047
}
60496048

lib/AST/TypeRepr.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,8 @@ void AttributedTypeRepr::printAttrs(ASTPrinter &Printer,
298298
Printer.printSimpleAttr("@autoclosure") << " ";
299299
if (hasAttr(TAK_escaping))
300300
Printer.printSimpleAttr("@escaping") << " ";
301+
if (hasAttr(TAK_noDerivative))
302+
Printer.printSimpleAttr("@noDerivative") << " ";
301303

302304
if (hasAttr(TAK_differentiable)) {
303305
if (Attrs.isLinear()) {

lib/Sema/TypeCheckType.cpp

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1793,6 +1793,7 @@ namespace {
17931793
resolveASTFunctionTypeParams(TupleTypeRepr *inputRepr,
17941794
TypeResolutionOptions options,
17951795
bool requiresMappingOut,
1796+
DifferentiabilityKind diffKind,
17961797
SmallVectorImpl<AnyFunctionType::Param> &ps);
17971798

17981799
Type resolveSILFunctionType(FunctionTypeRepr *repr,
@@ -2026,6 +2027,11 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
20262027
// Remember whether this is a function parameter.
20272028
bool isParam = options.is(TypeResolverContext::FunctionInput);
20282029

2030+
// Remember whether this is a variadic function parameter.
2031+
bool isVariadicFunctionParam =
2032+
options.is(TypeResolverContext::VariadicFunctionInput) &&
2033+
!options.hasBase(TypeResolverContext::EnumElementDecl);
2034+
20292035
// The type we're working with, in case we want to build it differently
20302036
// based on the attributes we see.
20312037
Type ty;
@@ -2370,6 +2376,21 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
23702376
attrs.ConventionArguments = None;
23712377
}
23722378

2379+
if (attrs.has(TAK_noDerivative)) {
2380+
if (!Context.LangOpts.EnableExperimentalDifferentiableProgramming) {
2381+
diagnose(attrs.getLoc(TAK_noDerivative),
2382+
diag::experimental_differentiable_programming_disabled);
2383+
} else if (!isParam) {
2384+
// @noDerivative is only valid on parameters.
2385+
diagnose(attrs.getLoc(TAK_noDerivative),
2386+
(isVariadicFunctionParam
2387+
? diag::attr_not_on_variadic_parameters
2388+
: diag::attr_only_on_parameters_of_differentiable),
2389+
"@noDerivative");
2390+
}
2391+
attrs.clearAttribute(TAK_noDerivative);
2392+
}
2393+
23732394
// In SIL, handle @opened (n), which creates an existential archetype.
23742395
if (attrs.has(TAK_opened)) {
23752396
if (!ty->isExistentialType()) {
@@ -2422,7 +2443,7 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
24222443

24232444
bool TypeResolver::resolveASTFunctionTypeParams(
24242445
TupleTypeRepr *inputRepr, TypeResolutionOptions options,
2425-
bool requiresMappingOut,
2446+
bool requiresMappingOut, DifferentiabilityKind diffKind,
24262447
SmallVectorImpl<AnyFunctionType::Param> &elements) {
24272448
elements.reserve(inputRepr->getNumElements());
24282449

@@ -2486,8 +2507,24 @@ bool TypeResolver::resolveASTFunctionTypeParams(
24862507
ownership = ValueOwnership::Default;
24872508
break;
24882509
}
2510+
2511+
bool noDerivative = false;
2512+
if (auto *attrTypeRepr = dyn_cast<AttributedTypeRepr>(eltTypeRepr)) {
2513+
if (attrTypeRepr->getAttrs().has(TAK_noDerivative)) {
2514+
if (diffKind == DifferentiabilityKind::NonDifferentiable &&
2515+
Context.LangOpts.EnableExperimentalDifferentiableProgramming)
2516+
diagnose(eltTypeRepr->getLoc(),
2517+
diag::attr_only_on_parameters_of_differentiable,
2518+
"@noDerivative")
2519+
.highlight(eltTypeRepr->getSourceRange());
2520+
else
2521+
noDerivative = true;
2522+
}
2523+
}
2524+
24892525
auto paramFlags = ParameterTypeFlags::fromParameterType(
2490-
ty, variadic, autoclosure, /*isNonEphemeral*/ false, ownership);
2526+
ty, variadic, autoclosure, /*isNonEphemeral*/ false, ownership,
2527+
noDerivative);
24912528
elements.emplace_back(ty, Identifier(), paramFlags);
24922529
}
24932530

@@ -2541,7 +2578,8 @@ Type TypeResolver::resolveASTFunctionType(
25412578

25422579
SmallVector<AnyFunctionType::Param, 8> params;
25432580
if (resolveASTFunctionTypeParams(repr->getArgsTypeRepr(), options,
2544-
repr->getGenericEnvironment() != nullptr, params)) {
2581+
repr->getGenericEnvironment() != nullptr,
2582+
diffKind, params)) {
25452583
return Type();
25462584
}
25472585

lib/Serialization/Deserialization.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4770,12 +4770,11 @@ class swift::TypeDeserializer {
47704770

47714771
IdentifierID labelID;
47724772
TypeID typeID;
4773-
bool isVariadic, isAutoClosure, isNonEphemeral;
4773+
bool isVariadic, isAutoClosure, isNonEphemeral, isNoDerivative;
47744774
unsigned rawOwnership;
4775-
decls_block::FunctionParamLayout::readRecord(scratch, labelID, typeID,
4776-
isVariadic, isAutoClosure,
4777-
isNonEphemeral,
4778-
rawOwnership);
4775+
decls_block::FunctionParamLayout::readRecord(
4776+
scratch, labelID, typeID, isVariadic, isAutoClosure, isNonEphemeral,
4777+
rawOwnership, isNoDerivative);
47794778

47804779
auto ownership =
47814780
getActualValueOwnership((serialization::ValueOwnership)rawOwnership);
@@ -4786,10 +4785,10 @@ class swift::TypeDeserializer {
47864785
if (!paramTy)
47874786
return paramTy.takeError();
47884787

4789-
params.emplace_back(paramTy.get(),
4790-
MF.getIdentifier(labelID),
4788+
params.emplace_back(paramTy.get(), MF.getIdentifier(labelID),
47914789
ParameterTypeFlags(isVariadic, isAutoClosure,
4792-
isNonEphemeral, *ownership));
4790+
isNonEphemeral, *ownership,
4791+
isNoDerivative));
47934792
}
47944793

47954794
if (!isGeneric) {

lib/Serialization/ModuleFormat.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5555
/// describe what change you made. The content of this comment isn't important;
5656
/// it just ensures a conflict if two people change the module format.
5757
/// Don't worry about adhering to the 80-column limit for this line.
58-
const uint16_t SWIFTMODULE_VERSION_MINOR = 530; // @_implicitly_synthesizes_nested_requirement
58+
const uint16_t SWIFTMODULE_VERSION_MINOR = 531; // function parameter noDerivative
5959

6060
/// A standard hash seed used for all string hashes in a serialized module.
6161
///
@@ -905,12 +905,13 @@ namespace decls_block {
905905

906906
using FunctionParamLayout = BCRecordLayout<
907907
FUNCTION_PARAM,
908-
IdentifierIDField, // name
909-
TypeIDField, // type
910-
BCFixed<1>, // vararg?
911-
BCFixed<1>, // autoclosure?
912-
BCFixed<1>, // non-ephemeral?
913-
ValueOwnershipField // inout, shared or owned?
908+
IdentifierIDField, // name
909+
TypeIDField, // type
910+
BCFixed<1>, // vararg?
911+
BCFixed<1>, // autoclosure?
912+
BCFixed<1>, // non-ephemeral?
913+
ValueOwnershipField, // inout, shared or owned?
914+
BCFixed<1> // noDerivative?
914915
>;
915916

916917
using MetatypeTypeLayout = BCRecordLayout<

lib/Serialization/Serialization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4021,8 +4021,8 @@ class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
40214021
S.Out, S.ScratchRecord, abbrCode,
40224022
S.addDeclBaseNameRef(param.getLabel()),
40234023
S.addTypeRef(param.getPlainType()), paramFlags.isVariadic(),
4024-
paramFlags.isAutoClosure(), paramFlags.isNonEphemeral(),
4025-
rawOwnership);
4024+
paramFlags.isAutoClosure(), paramFlags.isNonEphemeral(), rawOwnership,
4025+
paramFlags.isNoDerivative());
40264026
}
40274027
}
40284028

test/AutoDiff/ModuleInterface/differentiation.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ public func a(f: @differentiable (Float) -> Float) {}
66

77
public func b(f: @differentiable(linear) (Float) -> Float) {}
88
// CHECK: public func b(f: @differentiable(linear) (Swift.Float) -> Swift.Float)
9+
10+
public func c(f: @differentiable (Float, @noDerivative Float) -> Float) {}
11+
// CHECK: public func c(f: @differentiable (Swift.Float, @noDerivative Swift.Float) -> Swift.Float)

test/AutoDiff/Parse/differentiable_func_type.swift

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,25 @@ let b: @differentiable(linear) (Float) -> Float // okay
88
// CHECK: (pattern_named 'b'
99
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
1010

11-
let c: @differentiable (Float) throws -> Float // okay
11+
let c: @differentiable (Float, @noDerivative Float) -> Float // okay
1212
// CHECK: (pattern_named 'c'
13+
// CHECK-NEXT: (type_attributed attrs=@differentiable
14+
// CHECK-NEXT: (type_function
15+
// CHECK-NEXT: (type_tuple
16+
// CHECK-NEXT: (type_ident
17+
// CHECK-NEXT: (component id='Float' bind=none))
18+
// CHECK-NEXT: (type_attributed attrs=@noDerivative
19+
// CHECK-NEXT: (type_ident
20+
// CHECK-NEXT: (component id='Float' bind=none)))
21+
// CHECK-NEXT: (type_ident
22+
// CHECK-NEXT: (component id='Float' bind=none)))))
23+
24+
let d: @differentiable (Float) throws -> Float // okay
25+
// CHECK: (pattern_named 'd'
1326
// CHECK-NEXT: (type_attributed attrs=@differentiable{{[^(]}}
1427

15-
let d: @differentiable(linear) (Float) throws -> Float // okay
16-
// CHECK: (pattern_named 'd'
28+
let e: @differentiable(linear) (Float) throws -> Float // okay
29+
// CHECK: (pattern_named 'e'
1730
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
1831

1932
// Generic type test.

test/AutoDiff/Sema/differentiable_features_disabled.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
44
let _: @differentiable (Float) -> Float
55

6+
// expected-error @+2 {{differentiable programming is an experimental feature that is currently disabled}}
7+
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
8+
let _: @differentiable (Float, @noDerivative Float) -> Float
9+
10+
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
11+
let _: (Float, @noDerivative Float) -> Float
12+
13+
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
14+
let _: @noDerivative Float
15+
616
func id(_ x: Float) -> Float {
717
return x
818
}

0 commit comments

Comments
 (0)