Skip to content

Commit ce756e3

Browse files
marcrasirxwei
authored andcommitted
[AutoDiff upstream] parsing for @differentiable function type (#27708)
Adds parsing for a type attribute `@differentiable`, which is optionally allowed to have argument `@differentiable(linear)`. The typechecker currently rejects all uses of `@differentiable` with "error: attribute does not apply to type". Future work (https://bugs.swift.org/browse/TF-871 https://bugs.swift.org/browse/TF-873) will update the typechecker to allow this attribute in places where it is allowed. Resolves https://bugs.swift.org/browse/TF-822.
1 parent d91e474 commit ce756e3

File tree

10 files changed

+192
-2
lines changed

10 files changed

+192
-2
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ TYPE_ATTR(autoclosure)
5151
TYPE_ATTR(convention)
5252
TYPE_ATTR(noescape)
5353
TYPE_ATTR(escaping)
54+
TYPE_ATTR(differentiable)
5455

5556
// SIL-specific attributes
5657
TYPE_ATTR(block_storage)

include/swift/AST/Attr.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ class TypeAttributes {
6666
Optional<StringRef> convention = None;
6767
Optional<StringRef> conventionWitnessMethodProtocol = None;
6868

69+
// Indicates whether the type's '@differentiable' attribute has a 'linear'
70+
// argument.
71+
bool linear = false;
72+
6973
// For an opened existential type, the known ID.
7074
Optional<UUID> OpenedID;
7175

@@ -80,7 +84,15 @@ class TypeAttributes {
8084
TypeAttributes() {}
8185

8286
bool isValid() const { return AtLoc.isValid(); }
83-
87+
88+
bool isLinear() const {
89+
assert(
90+
!linear ||
91+
(linear && has(TAK_differentiable)) &&
92+
"Linear shouldn't have been true if there's no `@differentiable`");
93+
return linear;
94+
}
95+
8496
void clearAttribute(TypeAttrKind A) {
8597
AttrLocs[A] = SourceLoc();
8698
}

include/swift/AST/DiagnosticsParse.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,6 +1508,12 @@ ERROR(attr_specialize_expected_partial_or_full,none,
15081508
ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
15091509
"expected a member name as second parameter in '_implements' attribute", ())
15101510

1511+
// differentiable
1512+
ERROR(differentiable_attribute_expected_rparen,none,
1513+
"expected ')' in '@differentiable' attribute", ())
1514+
ERROR(unexpected_argument_differentiable,none,
1515+
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
1516+
15111517
//------------------------------------------------------------------------------
15121518
// MARK: Generics parsing diagnostics
15131519
//------------------------------------------------------------------------------

include/swift/AST/DiagnosticsSema.def

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4728,6 +4728,13 @@ NOTE(previous_function_builder_here, none,
47284728
ERROR(function_builder_arguments, none,
47294729
"function builder attributes cannot have arguments", ())
47304730

4731+
//------------------------------------------------------------------------------
4732+
// MARK: differentiable programming diagnostics
4733+
//------------------------------------------------------------------------------
4734+
ERROR(experimental_differentiable_programming_disabled, none,
4735+
"differentiable programming is an experimental feature that is "
4736+
"currently disabled", ())
4737+
47314738
#ifndef DIAG_NO_UNDEF
47324739
# if defined(DIAG)
47334740
# undef DIAG

lib/AST/TypeRepr.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,14 @@ void AttributedTypeRepr::printAttrs(ASTPrinter &Printer,
299299
if (hasAttr(TAK_escaping))
300300
Printer.printSimpleAttr("@escaping") << " ";
301301

302+
if (hasAttr(TAK_differentiable)) {
303+
if (Attrs.isLinear()) {
304+
Printer.printSimpleAttr("@differentiable(linear)") << " ";
305+
} else {
306+
Printer.printSimpleAttr("@differentiable") << " ";
307+
}
308+
}
309+
302310
if (hasAttr(TAK_thin))
303311
Printer.printSimpleAttr("@thin") << " ";
304312
if (hasAttr(TAK_thick))

lib/Parse/ParseDecl.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,6 +1944,53 @@ bool Parser::canParseTypeAttribute() {
19441944
/*justChecking*/ true);
19451945
}
19461946

1947+
/// Parses the '@differentiable' argument (no argument list, or '(linear)'),
1948+
/// and sets the appropriate fields on `Attributes`.
1949+
///
1950+
/// \param emitDiagnostics - if false, doesn't emit diagnostics
1951+
/// \returns true on error, false on success
1952+
static bool parseDifferentiableAttributeArgument(Parser &P,
1953+
TypeAttributes &Attributes,
1954+
bool emitDiagnostics) {
1955+
Parser::BacktrackingScope backtrack(P);
1956+
1957+
// Match '( <identifier> )', and store the identifier token to `argument`.
1958+
if (!P.consumeIf(tok::l_paren))
1959+
return false;
1960+
auto argument = P.Tok;
1961+
if (!P.consumeIf(tok::identifier))
1962+
return false;
1963+
if (!P.consumeIf(tok::r_paren)) {
1964+
// Special case handling for '( <identifier> (' so that we don't produce the
1965+
// misleading diagnostic "expected ',' separator" when the real issue is
1966+
// that the user forgot the ')' closing the '@differentiable' argument list.
1967+
if (P.Tok.is(tok::l_paren)) {
1968+
backtrack.cancelBacktrack();
1969+
if (emitDiagnostics)
1970+
P.diagnose(P.Tok, diag::differentiable_attribute_expected_rparen);
1971+
return true;
1972+
}
1973+
return false;
1974+
}
1975+
1976+
// If the next token is an arrow, then the matched '( <identifier> )' is
1977+
// actually the parameter type list, not an argument to '@differentiable'.
1978+
if (P.Tok.is(tok::arrow))
1979+
return false;
1980+
1981+
backtrack.cancelBacktrack();
1982+
1983+
if (argument.getText() != "linear") {
1984+
if (emitDiagnostics)
1985+
P.diagnose(argument, diag::unexpected_argument_differentiable,
1986+
argument.getText());
1987+
return true;
1988+
}
1989+
1990+
Attributes.linear = true;
1991+
return false;
1992+
}
1993+
19471994
/// \verbatim
19481995
/// attribute-type:
19491996
/// 'noreturn'
@@ -2157,6 +2204,13 @@ bool Parser::parseTypeAttribute(TypeAttributes &Attributes, SourceLoc AtLoc,
21572204
break;
21582205
}
21592206

2207+
case TAK_differentiable: {
2208+
if (parseDifferentiableAttributeArgument(*this, Attributes,
2209+
/*emitDiagnostics=*/!justChecking))
2210+
return true;
2211+
break;
2212+
}
2213+
21602214
// Convention attribute.
21612215
case TAK_convention:
21622216
Attributes.convention = conventionName;

lib/Sema/TypeCheckType.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2062,7 +2062,7 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
20622062
static const TypeAttrKind FunctionAttrs[] = {
20632063
TAK_convention, TAK_pseudogeneric,
20642064
TAK_callee_owned, TAK_callee_guaranteed, TAK_noescape, TAK_autoclosure,
2065-
TAK_escaping, TAK_yield_once, TAK_yield_many
2065+
TAK_differentiable, TAK_escaping, TAK_yield_once, TAK_yield_many
20662066
};
20672067

20682068
auto checkUnsupportedAttr = [&](TypeAttrKind attr) {
@@ -2211,6 +2211,12 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
22112211
attrs.clearAttribute(TAK_autoclosure);
22122212
}
22132213

2214+
if (attrs.has(TAK_differentiable) &&
2215+
!Context.LangOpts.EnableExperimentalDifferentiableProgramming) {
2216+
diagnose(attrs.getLoc(TAK_differentiable),
2217+
diag::experimental_differentiable_programming_disabled);
2218+
}
2219+
22142220
// Resolve the function type directly with these attributes.
22152221
FunctionType::ExtInfo extInfo(rep, /*noescape=*/false,
22162222
fnRepr->throws());
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// RUN: %target-swift-frontend -dump-parse -verify %s | %FileCheck %s
2+
3+
let a: @differentiable (Float) -> Float // okay
4+
// CHECK: (pattern_named 'a'
5+
// CHECK-NEXT: (type_attributed attrs=@differentiable{{[^(]}}
6+
7+
let b: @differentiable(linear) (Float) -> Float // okay
8+
// CHECK: (pattern_named 'b'
9+
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
10+
11+
// Generic type test.
12+
struct A<T> {
13+
func foo() {
14+
let local: @differentiable(linear) (T) -> T // okay
15+
// CHECK: (pattern_named 'local'
16+
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
17+
}
18+
}
19+
20+
// expected-error @+1 {{expected ')' in '@differentiable' attribute}}
21+
let c: @differentiable(linear (Float) -> Float
22+
23+
// expected-error @+1 {{expected ')' in '@differentiable' attribute}}
24+
let c: @differentiable(notValidArg (Float) -> Float
25+
26+
// expected-error @+1 {{unexpected argument 'notValidArg' in '@differentiable' attribute}}
27+
let d: @differentiable(notValidArg) (Float) -> Float
28+
29+
// Using 'linear' as a type
30+
struct B {
31+
struct linear {}
32+
let propertyB1: @differentiable (linear) -> Float // okay
33+
// CHECK: (pattern_named 'propertyB1'
34+
// CHECK-NEXT: (type_attributed attrs=@differentiable{{[^(]}}
35+
36+
let propertyB2: @differentiable(linear) (linear) -> linear // okay
37+
// CHECK: (pattern_named 'propertyB2'
38+
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
39+
40+
let propertyB3: @differentiable (linear, linear) -> linear // okay
41+
// CHECK: (pattern_named 'propertyB3'
42+
// CHECK-NEXT: (type_attributed attrs=@differentiable{{[^(]}}
43+
44+
let propertyB4: @differentiable (linear, Float) -> linear // okay
45+
// CHECK: (pattern_named 'propertyB4'
46+
// CHECK-NEXT: (type_attributed attrs=@differentiable{{[^(]}}
47+
48+
let propertyB5: @differentiable (Float, linear) -> linear // okay
49+
// CHECK: (pattern_named 'propertyB5'
50+
// CHECK-NEXT: (type_attributed attrs=@differentiable{{[^(]}}
51+
52+
let propertyB6: @differentiable(linear) (linear, linear, Float, linear)
53+
-> Float // okay
54+
// CHECK: (pattern_named 'propertyB6'
55+
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
56+
57+
// expected-error @+1 {{expected ')' in '@differentiable' attribute}}
58+
let propertyB7: @differentiable(linear (linear) -> Float
59+
}
60+
61+
// Using 'linear' as a typealias
62+
struct C {
63+
typealias linear = (C) -> C
64+
let propertyC1: @differentiable (linear) -> Float // okay
65+
// CHECK: (pattern_named 'propertyC1'
66+
// CHECK-NEXT: (type_attributed attrs=@differentiable{{[^(]}}
67+
68+
let propertyC2: @differentiable(linear) (linear) -> linear // okay
69+
// CHECK: (pattern_named 'propertyC2'
70+
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
71+
72+
let propertyC3: @differentiable linear // okay
73+
// CHECK: (pattern_named 'propertyC3'
74+
// CHECK-NEXT: (type_attributed attrs=@differentiable{{[^(]}}
75+
76+
let propertyC4: linear // okay
77+
// CHECK: (pattern_named 'propertyC4'
78+
79+
let propertyC5: @differentiable(linear) linear // okay
80+
// CHECK: (pattern_named 'propertyC5'
81+
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
82+
83+
let propertyC6: @differentiable(linear) @convention(c) linear // okay
84+
// CHECK: (pattern_named 'propertyC6'
85+
// CHECK-NEXT: (type_attributed attrs=@differentiable(linear)
86+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
// RUN: %target-swift-frontend -typecheck -verify %s
2+
3+
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
4+
let _: @differentiable (Float) -> Float
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// RUN: %target-swift-frontend -enable-experimental-differentiable-programming -typecheck -verify %s
2+
3+
// expected-error @+1 {{@differentiable attribute only applies to function types}}
4+
let _: @differentiable Float
5+
6+
let _: @differentiable (Float) -> Float

0 commit comments

Comments
 (0)