Skip to content

Commit db95e53

Browse files
authored
[AutoDiff] [ASTGen] Check for 'linear' when generating 'AttributedTypeRepr'. (#27656)
ASTGen did not know how to handle `@differentiable(linear)`, which caused parsed `@differentiable(linear)` function types to lose the "linear" bit when they are an argument or a result in a SIL function's type signature. This patch teaches ASTGen to check for `linear` in `TypeAttributes` when generating AST. Resolves [TF-901](https://bugs.swift.org/browse/TF-901).
1 parent 98f3545 commit db95e53

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

lib/Parse/ASTGen.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,13 @@ TypeRepr *ASTGen::generate(const AttributedTypeSyntax &Type,
240240
TypeAttrs.convention = Convention.str();
241241
}
242242

243+
// SWIFT_ENABLE_TENSORFLOW
244+
if (AttrKind == TAK_differentiable && Attr.getArgument()) {
245+
auto Argument = Attr.getArgument()->castTo<TokenSyntax>();
246+
auto Linear = Context.getIdentifier(Argument.getIdentifierText());
247+
TypeAttrs.linear = Linear.is("linear");
248+
}
249+
243250
if (AttrKind == TAK_opened) {
244251
auto AttrText = Attr.getArgument()->castTo<TokenSyntax>().getText();
245252
auto LiteralText = AttrText.slice(1, AttrText.size() - 1);

test/AutoDiff/differentiable_sil_function_type_parse.sil

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,13 @@ bb0:
3434
%ret = tuple ()
3535
return %ret : $()
3636
}
37+
38+
sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float {
39+
bb0(%0 : $@differentiable(linear) (Float) -> Float):
40+
return %0 : $@differentiable(linear) (Float) -> Float
41+
}
42+
43+
// CHECK-LABEL: sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float {
44+
// CHECK: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float):
45+
// CHECK: return [[ARG]] : $@differentiable(linear) (Float) -> Float
46+
// CHECK: }

0 commit comments

Comments
 (0)