Skip to content

Commit f5677cb

Browse files
author
Marc Rasi
committed
add @nondiff to AnyFunctionType params
1 parent d502355 commit f5677cb

16 files changed

+147
-39
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ TYPE_ATTR(convention)
5252
TYPE_ATTR(noescape)
5353
TYPE_ATTR(escaping)
5454
TYPE_ATTR(differentiable)
55+
TYPE_ATTR(nondiff)
5556

5657
// SIL-specific attributes
5758
TYPE_ATTR(block_storage)

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3866,6 +3866,11 @@ ERROR(opaque_type_in_protocol_requirement,none,
38663866
"'some' type cannot be the return type of a protocol requirement; did you mean to add an associated type?",
38673867
())
38683868

3869+
// Function differentiability
3870+
ERROR(attr_only_on_parameters_of_differentiable,none,
3871+
"'%0' may only be used on parameters of '@differentiable' function "
3872+
"types", (StringRef))
3873+
38693874
// SIL
38703875
ERROR(opened_non_protocol,none,
38713876
"@opened cannot be applied to non-protocol type %0", (Type))

include/swift/AST/Types.h

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1778,8 +1778,8 @@ class ParameterTypeFlags {
17781778
NonEphemeral = 1 << 2,
17791779
OwnershipShift = 3,
17801780
Ownership = 7 << OwnershipShift,
1781-
1782-
NumBits = 6
1781+
NonDifferentiable = 1 << 7,
1782+
NumBits = 7
17831783
};
17841784
OptionSet<ParameterFlags> value;
17851785
static_assert(NumBits < 8*sizeof(OptionSet<ParameterFlags>), "overflowed");
@@ -1793,15 +1793,17 @@ class ParameterTypeFlags {
17931793
}
17941794

17951795
ParameterTypeFlags(bool variadic, bool autoclosure, bool nonEphemeral,
1796-
ValueOwnership ownership)
1796+
ValueOwnership ownership, bool nonDifferentiable)
17971797
: value((variadic ? Variadic : 0) | (autoclosure ? AutoClosure : 0) |
17981798
(nonEphemeral ? NonEphemeral : 0) |
1799-
uint8_t(ownership) << OwnershipShift) {}
1799+
uint8_t(ownership) << OwnershipShift |
1800+
(nonDifferentiable ? NonDifferentiable : 0)) {}
18001801

18011802
/// Create one from what's present in the parameter type
18021803
inline static ParameterTypeFlags
18031804
fromParameterType(Type paramTy, bool isVariadic, bool isAutoClosure,
1804-
bool isNonEphemeral, ValueOwnership ownership);
1805+
bool isNonEphemeral, ValueOwnership ownership,
1806+
bool isNonDifferentiable);
18051807

18061808
bool isNone() const { return !value; }
18071809
bool isVariadic() const { return value.contains(Variadic); }
@@ -1810,6 +1812,7 @@ class ParameterTypeFlags {
18101812
bool isInOut() const { return getValueOwnership() == ValueOwnership::InOut; }
18111813
bool isShared() const { return getValueOwnership() == ValueOwnership::Shared;}
18121814
bool isOwned() const { return getValueOwnership() == ValueOwnership::Owned; }
1815+
bool isNonDifferentiable() const { return value.contains(NonDifferentiable); }
18131816

18141817
ValueOwnership getValueOwnership() const {
18151818
return ValueOwnership((value.toRaw() & Ownership) >> OwnershipShift);
@@ -1852,6 +1855,12 @@ class ParameterTypeFlags {
18521855
: value - ParameterTypeFlags::NonEphemeral);
18531856
}
18541857

1858+
ParameterTypeFlags withNonDifferentiable(bool nonDifferentiable) const {
1859+
return ParameterTypeFlags(
1860+
nonDifferentiable ? value | ParameterTypeFlags::NonDifferentiable
1861+
: value - ParameterTypeFlags::NonDifferentiable);
1862+
}
1863+
18551864
bool operator ==(const ParameterTypeFlags &other) const {
18561865
return value.toRaw() == other.value.toRaw();
18571866
}
@@ -1919,7 +1928,8 @@ class YieldTypeFlags {
19191928
return ParameterTypeFlags(/*variadic*/ false,
19201929
/*autoclosure*/ false,
19211930
/*nonEphemeral*/ false,
1922-
getValueOwnership());
1931+
getValueOwnership(),
1932+
/*nonDifferentiable*/ false);
19231933
}
19241934

19251935
bool operator ==(const YieldTypeFlags &other) const {
@@ -2791,6 +2801,9 @@ class AnyFunctionType : public TypeBase {
27912801
/// Whether the parameter is marked '@_nonEphemeral'
27922802
bool isNonEphemeral() const { return Flags.isNonEphemeral(); }
27932803

2804+
/// Whether the parameter is marked '@nondiff'.
2805+
bool isNonDifferentiable() const { return Flags.isNonDifferentiable(); }
2806+
27942807
ValueOwnership getValueOwnership() const {
27952808
return Flags.getValueOwnership();
27962809
}
@@ -5684,10 +5697,9 @@ inline TupleTypeElt TupleTypeElt::getWithType(Type T) const {
56845697
}
56855698

56865699
/// Create one from what's present in the parameter decl and type
5687-
inline ParameterTypeFlags
5688-
ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic,
5689-
bool isAutoClosure, bool isNonEphemeral,
5690-
ValueOwnership ownership) {
5700+
inline ParameterTypeFlags ParameterTypeFlags::fromParameterType(
5701+
Type paramTy, bool isVariadic, bool isAutoClosure, bool isNonEphemeral,
5702+
ValueOwnership ownership, bool isNonDifferentiable) {
56915703
// FIXME(Remove InOut): The last caller that needs this is argument
56925704
// decomposition. Start by enabling the assertion there and fixing up those
56935705
// callers, then remove this, then remove
@@ -5697,7 +5709,8 @@ ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic,
56975709
ownership == ValueOwnership::InOut);
56985710
ownership = ValueOwnership::InOut;
56995711
}
5700-
return {isVariadic, isAutoClosure, isNonEphemeral, ownership};
5712+
return {isVariadic, isAutoClosure, isNonEphemeral, ownership,
5713+
isNonDifferentiable};
57015714
}
57025715

57035716
inline const Type *BoundGenericType::getTrailingObjectsPointer() const {

lib/AST/ASTContext.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2885,9 +2885,10 @@ void AnyFunctionType::decomposeInput(
28852885
}
28862886

28872887
default:
2888-
result.emplace_back(type->getInOutObjectType(), Identifier(),
2889-
ParameterTypeFlags::fromParameterType(
2890-
type, false, false, false, ValueOwnership::Default));
2888+
result.emplace_back(
2889+
type->getInOutObjectType(), Identifier(),
2890+
ParameterTypeFlags::fromParameterType(type, false, false, false,
2891+
ValueOwnership::Default, false));
28912892
return;
28922893
}
28932894
}

lib/AST/ASTPrinter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,6 +2476,8 @@ static void printParameterFlags(ASTPrinter &printer, PrintOptions options,
24762476
ParameterTypeFlags flags, bool escaping) {
24772477
if (!options.excludeAttrKind(TAK_autoclosure) && flags.isAutoClosure())
24782478
printer << "@autoclosure ";
2479+
if (!options.excludeAttrKind(TAK_nondiff) && flags.isNonDifferentiable())
2480+
printer << "@nondiff ";
24792481

24802482
switch (flags.getValueOwnership()) {
24812483
case ValueOwnership::Default:

lib/AST/Decl.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6035,11 +6035,10 @@ AnyFunctionType::Param ParamDecl::toFunctionParam(Type type) const {
60356035
type = ParamDecl::getVarargBaseTy(type);
60366036

60376037
auto label = getArgumentName();
6038-
auto flags = ParameterTypeFlags::fromParameterType(type,
6039-
isVariadic(),
6040-
isAutoClosure(),
6041-
isNonEphemeral(),
6042-
getValueOwnership());
6038+
auto flags = ParameterTypeFlags::fromParameterType(
6039+
type, isVariadic(), isAutoClosure(), isNonEphemeral(),
6040+
getValueOwnership(),
6041+
/*isNonDifferentiable*/ false);
60436042
return AnyFunctionType::Param(type, label, flags);
60446043
}
60456044

lib/AST/TypeRepr.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,8 @@ void AttributedTypeRepr::printAttrs(ASTPrinter &Printer,
298298
Printer.printSimpleAttr("@autoclosure") << " ";
299299
if (hasAttr(TAK_escaping))
300300
Printer.printSimpleAttr("@escaping") << " ";
301+
if (hasAttr(TAK_nondiff))
302+
Printer.printSimpleAttr("@nondiff") << " ";
301303

302304
if (hasAttr(TAK_differentiable)) {
303305
if (Attrs.isLinear()) {

lib/Sema/TypeCheckType.cpp

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,7 +1736,7 @@ namespace {
17361736
bool
17371737
resolveASTFunctionTypeParams(TupleTypeRepr *inputRepr,
17381738
TypeResolutionOptions options,
1739-
bool requiresMappingOut,
1739+
bool requiresMappingOut, bool isDifferentiable,
17401740
SmallVectorImpl<AnyFunctionType::Param> &ps);
17411741

17421742
Type resolveSILFunctionType(FunctionTypeRepr *repr,
@@ -1970,6 +1970,11 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
19701970
// Remember whether this is a function parameter.
19711971
bool isParam = options.is(TypeResolverContext::FunctionInput);
19721972

1973+
// Remember whether this is a variadic function parameter.
1974+
bool isVariadicFunctionParam =
1975+
options.is(TypeResolverContext::VariadicFunctionInput) &&
1976+
!options.hasBase(TypeResolverContext::EnumElementDecl);
1977+
19731978
// The type we're working with, in case we want to build it differently
19741979
// based on the attributes we see.
19751980
Type ty;
@@ -2198,10 +2203,6 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
21982203

21992204
// @autoclosure is only valid on parameters.
22002205
if (!isParam && attrs.has(TAK_autoclosure)) {
2201-
bool isVariadicFunctionParam =
2202-
options.is(TypeResolverContext::VariadicFunctionInput) &&
2203-
!options.hasBase(TypeResolverContext::EnumElementDecl);
2204-
22052206
diagnose(attrs.getLoc(TAK_autoclosure),
22062207
isVariadicFunctionParam ? diag::attr_not_on_variadic_parameters
22072208
: diag::attr_only_on_parameters,
@@ -2308,6 +2309,21 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
23082309
attrs.convention = None;
23092310
}
23102311

2312+
if (attrs.has(TAK_nondiff)) {
2313+
if (!Context.LangOpts.EnableExperimentalDifferentiableProgramming) {
2314+
diagnose(attrs.getLoc(TAK_nondiff),
2315+
diag::experimental_differentiable_programming_disabled);
2316+
} else if (!isParam) {
2317+
// @nondiff is only valid on parameters.
2318+
diagnose(attrs.getLoc(TAK_nondiff),
2319+
(isVariadicFunctionParam
2320+
? diag::attr_not_on_variadic_parameters
2321+
: diag::attr_only_on_parameters_of_differentiable),
2322+
"@nondiff");
2323+
}
2324+
attrs.clearAttribute(TAK_nondiff);
2325+
}
2326+
23112327
// In SIL, handle @opened (n), which creates an existential archetype.
23122328
if (attrs.has(TAK_opened)) {
23132329
if (!ty->isExistentialType()) {
@@ -2360,7 +2376,7 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
23602376

23612377
bool TypeResolver::resolveASTFunctionTypeParams(
23622378
TupleTypeRepr *inputRepr, TypeResolutionOptions options,
2363-
bool requiresMappingOut,
2379+
bool requiresMappingOut, bool isDifferentiable,
23642380
SmallVectorImpl<AnyFunctionType::Param> &elements) {
23652381
elements.reserve(inputRepr->getNumElements());
23662382

@@ -2424,8 +2440,23 @@ bool TypeResolver::resolveASTFunctionTypeParams(
24242440
ownership = ValueOwnership::Default;
24252441
break;
24262442
}
2443+
2444+
bool nondiff = false;
2445+
if (auto *attrTypeRepr = dyn_cast<AttributedTypeRepr>(eltTypeRepr)) {
2446+
if (attrTypeRepr->getAttrs().has(TAK_nondiff)) {
2447+
if (!isDifferentiable &&
2448+
Context.LangOpts.EnableExperimentalDifferentiableProgramming)
2449+
diagnose(eltTypeRepr->getLoc(),
2450+
diag::attr_only_on_parameters_of_differentiable, "@nondiff")
2451+
.highlight(eltTypeRepr->getSourceRange());
2452+
else
2453+
nondiff = true;
2454+
}
2455+
}
2456+
24272457
auto paramFlags = ParameterTypeFlags::fromParameterType(
2428-
ty, variadic, autoclosure, /*isNonEphemeral*/ false, ownership);
2458+
ty, variadic, autoclosure, /*isNonEphemeral*/ false, ownership,
2459+
nondiff);
24292460
elements.emplace_back(ty, Identifier(), paramFlags);
24302461
}
24312462

@@ -2477,7 +2508,8 @@ Type TypeResolver::resolveASTFunctionType(FunctionTypeRepr *repr,
24772508

24782509
SmallVector<AnyFunctionType::Param, 8> params;
24792510
if (resolveASTFunctionTypeParams(repr->getArgsTypeRepr(), options,
2480-
repr->getGenericEnvironment() != nullptr, params)) {
2511+
repr->getGenericEnvironment() != nullptr,
2512+
extInfo.isDifferentiable(), params)) {
24812513
return Type();
24822514
}
24832515

lib/Serialization/Deserialization.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4673,12 +4673,13 @@ class swift::TypeDeserializer {
46734673

46744674
IdentifierID labelID;
46754675
TypeID typeID;
4676-
bool isVariadic, isAutoClosure, isNonEphemeral;
4676+
bool isVariadic, isAutoClosure, isNonEphemeral, isNonDifferentiable;
46774677
unsigned rawOwnership;
46784678
decls_block::FunctionParamLayout::readRecord(scratch, labelID, typeID,
46794679
isVariadic, isAutoClosure,
46804680
isNonEphemeral,
4681-
rawOwnership);
4681+
rawOwnership,
4682+
isNonDifferentiable);
46824683

46834684
auto ownership =
46844685
getActualValueOwnership((serialization::ValueOwnership)rawOwnership);
@@ -4692,7 +4693,8 @@ class swift::TypeDeserializer {
46924693
params.emplace_back(paramTy.get(),
46934694
MF.getIdentifier(labelID),
46944695
ParameterTypeFlags(isVariadic, isAutoClosure,
4695-
isNonEphemeral, *ownership));
4696+
isNonEphemeral, *ownership,
4697+
isNonDifferentiable));
46964698
}
46974699

46984700
if (!isGeneric) {

lib/Serialization/ModuleFormat.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5252
/// describe what change you made. The content of this comment isn't important;
5353
/// it just ensures a conflict if two people change the module format.
5454
/// Don't worry about adhering to the 80-column limit for this line.
55-
const uint16_t SWIFTMODULE_VERSION_MINOR = 524; // function type differentiability
55+
const uint16_t SWIFTMODULE_VERSION_MINOR = 525; // function parameter nondiff
5656

5757
/// A standard hash seed used for all string hashes in a serialized module.
5858
///
@@ -893,12 +893,13 @@ namespace decls_block {
893893

894894
using FunctionParamLayout = BCRecordLayout<
895895
FUNCTION_PARAM,
896-
IdentifierIDField, // name
897-
TypeIDField, // type
898-
BCFixed<1>, // vararg?
899-
BCFixed<1>, // autoclosure?
900-
BCFixed<1>, // non-ephemeral?
901-
ValueOwnershipField // inout, shared or owned?
896+
IdentifierIDField, // name
897+
TypeIDField, // type
898+
BCFixed<1>, // vararg?
899+
BCFixed<1>, // autoclosure?
900+
BCFixed<1>, // non-ephemeral?
901+
ValueOwnershipField, // inout, shared or owned?
902+
BCFixed<1> // nondifferentiable?
902903
>;
903904

904905
using MetatypeTypeLayout = BCRecordLayout<

lib/Serialization/Serialization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3931,8 +3931,8 @@ class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
39313931
S.Out, S.ScratchRecord, abbrCode,
39323932
S.addDeclBaseNameRef(param.getLabel()),
39333933
S.addTypeRef(param.getPlainType()), paramFlags.isVariadic(),
3934-
paramFlags.isAutoClosure(), paramFlags.isNonEphemeral(),
3935-
rawOwnership);
3934+
paramFlags.isAutoClosure(), paramFlags.isNonEphemeral(), rawOwnership,
3935+
paramFlags.isNonDifferentiable());
39363936
}
39373937
}
39383938

test/AutoDiff/ModuleInterface/differentiation.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ public func a(f: @differentiable (Float) -> Float) {}
66

77
public func b(f: @differentiable(linear) (Float) -> Float) {}
88
// CHECK: public func b(f: @differentiable(linear) (Swift.Float) -> Swift.Float)
9+
10+
public func c(f: @differentiable (Float, @nondiff Float) -> Float) {}
11+
// CHECK: public func c(f: @differentiable (Swift.Float, @nondiff Swift.Float) -> Swift.Float)

test/AutoDiff/Parse/differentiable_func_type.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,19 @@ let b: @differentiable(linear) (Float) -> Float // okay
88
// CHECK: (pattern_named 'b'
99
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
1010

11+
let c: @differentiable (Float, @nondiff Float) -> Float // okay
12+
// CHECK: (pattern_named 'c'
13+
// CHECK-NEXT: (type_attributed attrs=@differentiable
14+
// CHECK-NEXT: (type_function
15+
// CHECK-NEXT: (type_tuple
16+
// CHECK-NEXT: (type_ident
17+
// CHECK-NEXT: (component id='Float' bind=none))
18+
// CHECK-NEXT: (type_attributed attrs=@nondiff
19+
// CHECK-NEXT: (type_ident
20+
// CHECK-NEXT: (component id='Float' bind=none)))
21+
// CHECK-NEXT: (type_ident
22+
// CHECK-NEXT: (component id='Float' bind=none)))))
23+
1124
// Generic type test.
1225
struct A<T> {
1326
func foo() {

test/AutoDiff/Sema/differentiable_features_disabled.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,13 @@
22

33
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
44
let _: @differentiable (Float) -> Float
5+
6+
// expected-error @+2 {{differentiable programming is an experimental feature that is currently disabled}}
7+
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
8+
let _: @differentiable (Float, @nondiff Float) -> Float
9+
10+
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
11+
let _: (Float, @nondiff Float) -> Float
12+
13+
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
14+
let _: @nondiff Float

test/AutoDiff/Sema/differentiable_func_type.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,24 @@
44
let _: @differentiable Float
55

66
let _: @differentiable (Float) -> Float
7+
8+
// expected-error @+1 {{'@nondiff' may only be used on parameters of '@differentiable' function types}}
9+
let _: @nondiff Float
10+
11+
// expected-error @+1 {{'@nondiff' may only be used on parameters of '@differentiable' function types}}
12+
let _: (Float) -> @nondiff Float
13+
14+
// expected-error @+1 {{'@nondiff' may only be used on parameters of '@differentiable' function types}}
15+
let _: @differentiable (Float) -> @nondiff Float
16+
17+
// expected-error @+1 {{'@nondiff' may only be used on parameters of '@differentiable' function types}}
18+
let _: (@nondiff Float) -> Float
19+
20+
// expected-error @+2 {{'@nondiff' may only be used on parameters of '@differentiable' function types}}
21+
// expected-error @+1 {{'@nondiff' must not be used on variadic parameters}}
22+
let _: (Float, @nondiff Float...) -> Float
23+
24+
let _: @differentiable (@nondiff Float, Float) -> Float
25+
26+
// expected-error @+1 {{'@nondiff' must not be used on variadic parameters}}
27+
let _: @differentiable (Float, @nondiff Float...) -> Float

0 commit comments

Comments
 (0)