Skip to content

Commit 4ebb027

Browse files
marcrasirxwei
authored andcommitted
fix @differentiable function in .swiftinterface (#25059)
While messing with .swiftinterface files, I discovered that `@differentiable` functions don't get printed correctly in .swiftinterface files. The test in `test/AutoDiff/differentiable_attr_parseable_interface.swift` demonstrates what this fixes. The extra test in `test/Serialization/differentiable_attr.swift` was already passing before this change, but I noticed that that file didn't have a `@differentiable` function as a param, so I thought it would be good it to add to that file too.
1 parent b20dcc5 commit 4ebb027

File tree

3 files changed

+13
-0
lines changed

3 files changed

+13
-0
lines changed

lib/AST/TypeRepr.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,10 @@ void AttributedTypeRepr::printAttrs(ASTPrinter &Printer,
299299
if (hasAttr(TAK_escaping))
300300
Printer.printSimpleAttr("@escaping") << " ";
301301

302+
// SWIFT_ENABLE_TENSORFLOW
303+
if (hasAttr(TAK_differentiable))
304+
Printer.printSimpleAttr("@differentiable") << " ";
305+
302306
if (hasAttr(TAK_thin))
303307
Printer.printSimpleAttr("@thin") << " ";
304308
if (hasAttr(TAK_thick))
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// SWIFT_ENABLE_TENSORFLOW
2+
// RUN: %target-swift-frontend -typecheck -emit-parseable-module-interface-path %t.swiftinterface -enable-library-evolution %s
3+
// RUN: %FileCheck %s < %t.swiftinterface
4+
5+
// CHECK: func testDifferentiableParam(f: @differentiable (Float) -> Float)
6+
public func testDifferentiableParam(f: @differentiable (Float) -> Float) {}

test/Serialization/differentiable_attr.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,6 @@ extension P where Self : Differentiable, Self == Self.TangentVector {
9797
return (self, { v in v })
9898
}
9999
}
100+
101+
// CHECK: func testDifferentiableParam(f: @differentiable (Float) -> Float)
102+
func testDifferentiableParam(f: @differentiable (Float) -> Float) {}

0 commit comments

Comments
 (0)