This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit a2fa8c4

Add model summary (#1067)
* Initial annotations prototype
* Lint
* Add Dan's workaround for SR-13455
* Get tests running
* Add docstrings
  Remove a workaround for a bug that no longer exists
  Remove `Layer.annotations()` and `Layer.summary(inputShape:)`
  Add annotations to `Flatten`
* Add summary formatting
  Remove unused validation function in TFEagerTests
* Add additional layer types
  Modify TensorProtocol extensions to support integer scalars
1 parent 61e7724 commit a2fa8c4

16 files changed: +514 -117 lines

Documentation/X10/SUMMARY.md

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Model Summaries

A summary provides details about the architecture of a model, such as layer
types and shapes.

The design proposal can be found [here][design]. This
implementation is a WIP, so please file an [Issue][new_issue] with
enhancements you would like to see or problems you run into.

**Note:** Model summaries are currently supported on the X10 backend only.

## Viewing a model summary

Create an X10 device and model.

```
import TensorFlow

public struct MyModel: Layer {
  public var dense1 = Dense<Float>(inputSize: 1, outputSize: 1)
  public var dense2 = Dense<Float>(inputSize: 4, outputSize: 4)
  public var dense3 = Dense<Float>(inputSize: 4, outputSize: 4)
  public var flatten = Flatten<Float>()

  @differentiable
  public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    let layer1 = dense1(input)
    let layer2 = layer1.reshaped(to: [1, 4])
    let layer3 = dense2(layer2)
    let layer4 = dense3(layer3)
    return flatten(layer4)
  }
}

let device = Device.defaultXLA
let model0 = MyModel()
let model = MyModel(copying: model0, to: device)
```

Create an input tensor.

```
let input = Tensor<Float>(repeating: 1, shape: [1, 4, 1, 1], on: device)
```

Generate a summary of your model.

```
let summary = model.summary(input: input)
print(summary)
```

```
Layer                           Output Shape         Attributes
=============================== ==================== ======================
Dense<Float>                    [1, 4, 1, 1]
Dense<Float>                    [1, 4]
Dense<Float>                    [1, 4]
Flatten<Float>                  [1, 4]
```

**Note:** the `summary()` function executes the model in order to obtain
details about its architecture.

[design]: https://docs.google.com/document/d/1hEhMiwLtuzsN3RvIC3FAh6NvtTimU8o_qdzMkGvntVg/view
[new_issue]: https://github.com/tensorflow/swift-apis/issues/new

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 70 additions & 61 deletions
```diff
@@ -41,67 +41,6 @@ public struct Tensor<Scalar: TensorFlowScalar> {
   }
 }
 
-public protocol TensorProtocol {
-  associatedtype Scalar: TensorFlowScalar
-  init(repeating repeatedValue: Scalar, shape: TensorShape, on device: Device)
-  var annotations: String { get }
-  var shape: TensorShape { get }
-  var summary: String { get }
-}
-
-public protocol DifferentiableTensorProtocol:
-  TensorProtocol & Differentiable & EuclideanDifferentiable
-where Scalar: TensorFlowFloatingPoint {
-  @differentiable(wrt: self)
-  func annotate(_ annotation: String) -> Self
-}
-
-extension Tensor: TensorProtocol & DifferentiableTensorProtocol
-where Scalar: TensorFlowFloatingPoint {
-
-  public var annotations: String {
-    #if USING_X10_BACKEND
-      switch handle.backend {
-      case .XLA:
-        let rawAnnotations = XLATensor.annotations(xlaTensor)
-
-        // TODO(michellecasbon): Add formatting.
-
-        return rawAnnotations
-
-      case .TF_EAGER:
-        return Device.defaultTFEager.annotationsAvailable
-      }
-    #else
-      return "Annotations not available in TF_EAGER."
-    #endif
-  }
-
-  public var summary: String { annotations }
-
-  @differentiable(wrt: self)
-  public func annotate(_ annotation: String) -> Tensor<Scalar> {
-    #if USING_X10_BACKEND
-      switch handle.backend {
-      case .XLA:
-        return Tensor<Scalar>(_xla: XLATensor.annotate(xlaTensor, annotation))
-      case .TF_EAGER:
-        return self
-      }
-    #else
-      return self
-    #endif
-  }
-
-  @derivative(of: annotate)
-  @usableFromInline
-  func vjpAnnotate(_ annotation: String) -> (
-    value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>
-  ) {
-    (annotate(annotation), { $0 })
-  }
-}
-
 extension Tensor: AnyTensor {
   public var _rawTensorHandle: CTensorHandle { return handle._cTensorHandle }
   public var _tensorFlowDataType: TensorDataType { return Scalar.tensorFlowDataType }
```
```diff
@@ -835,3 +774,73 @@ extension Tensor: Differentiable & EuclideanDifferentiable where Scalar: TensorF
   }
 }
 #endif
+
+//===------------------------------------------------------------------------------------------===//
+// Annotations
+//===------------------------------------------------------------------------------------------===//
+
+public protocol TensorProtocol {
+  associatedtype Scalar: TensorFlowScalar
+  init(repeating repeatedValue: Scalar, shape: TensorShape, on device: Device)
+  var annotations: String { get }
+  var shape: TensorShape { get }
+  var summary: String { get }
+}
+
+public protocol DifferentiableTensorProtocol:
+  TensorProtocol & Differentiable & EuclideanDifferentiable
+where Scalar: TensorFlowFloatingPoint {
+  @differentiable(wrt: self)
+  func annotate(_ annotation: String) -> Self
+}
+
+extension Tensor: TensorProtocol {
+  /// The annotations describing this tensor.
+  public var annotations: String {
+    #if USING_X10_BACKEND
+      switch handle.backend {
+      case .XLA:
+        return XLATensor.annotations(xlaTensor)
+      case .TF_EAGER:
+        return Device.defaultTFEager.annotationsAvailable
+      }
+    #else
+      return "Annotations not available in TF_EAGER."
+    #endif
+  }
+
+  /// An alias for annotations.
+  public var summary: String { annotations }
+}
+
+extension Tensor: DifferentiableTensorProtocol
+where Scalar: TensorFlowFloatingPoint {
+  /// Adds an annotation.
+  ///
+  /// Note: Only X10 is supported. For other backends, unmodified `self` is
+  /// returned.
+  ///
+  /// - Parameter annotation: The annotation to be added.
+  /// - Returns: The annotated tensor.
+  @differentiable(wrt: self)
+  public func annotate(_ annotation: String) -> Tensor<Scalar> {
+    #if USING_X10_BACKEND
+      switch handle.backend {
+      case .XLA:
+        return Tensor<Scalar>(_xla: XLATensor.annotate(xlaTensor, annotation))
+      case .TF_EAGER:
+        return self
+      }
+    #else
+      return self
+    #endif
+  }
+
+  @derivative(of: annotate)
+  @usableFromInline
+  func vjpAnnotate(_ annotation: String) -> (
+    value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>
+  ) {
+    (annotate(annotation), { $0 })
+  }
+}
```
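Taken together, the relocated code gives `Tensor` a small annotation API that can also be exercised directly, without going through a `Layer`. A minimal usage sketch (not part of this commit; the annotation string is made up, and an X10 build is assumed):

```swift
import TensorFlow

// Place the tensor on the X10 device so `annotate(_:)` records into the
// XLA graph; on TF_EAGER the method returns `self` unchanged.
let device = Device.defaultXLA
let x = Tensor<Float>(repeating: 1, shape: [2, 2], on: device)

// Attach a free-form annotation string to the tensor.
let y = x.annotate("type=Example")

// `annotations` (and its alias `summary`) reads back what was recorded.
print(y.annotations)
```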

Sources/TensorFlow/Layer.swift

Lines changed: 137 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+import Foundation
 import _Differentiation
 
 public protocol Module: EuclideanDifferentiable, KeyPathIterable
```
```diff
@@ -20,6 +21,7 @@ where
 {
   /// The input type of the layer.
   associatedtype Input
+
   /// The output type of the layer.
   associatedtype Output: Differentiable
 
```
```diff
@@ -29,6 +31,119 @@ where
   /// - Returns: The output.
   @differentiable(wrt: self)
   func callAsFunction(_ input: Input) -> Output
+
+  /// Returns the output obtained from applying the layer to the given input.
+  ///
+  /// - Parameter input: The input to the layer.
+  /// - Returns: The output.
+  @differentiable(wrt: self)
+  func forward(_ input: Input) -> Output
+}
+
+extension Module {
+  /// Returns the output obtained from applying the layer to the given input.
+  ///
+  /// - Parameter input: The input to the layer.
+  /// - Returns: The output.
+  @differentiable(wrt: self)
+  public func forward(_ input: Input) -> Output {
+    return callAsFunction(input)
+  }
+}
+
+extension Module where Input: TensorProtocol, Output: DifferentiableTensorProtocol {
+  /// Returns the annotated output obtained from applying the layer to the
+  /// given input.
+  ///
+  /// - Parameter input: The input to the layer.
+  /// - Returns: The annotated output.
+  @differentiable(wrt: self)
+  public func callAsFunction(_ input: Input) -> Output {
+    let activation = forward(input)
+    return annotated(activation)
+  }
+
+  /// Annotates `output`.
+  ///
+  /// Note: Returns `output` if using a backend that does not support annotations.
+  ///
+  /// - Parameter output: The output to the layer.
+  /// - Returns: The annotated output.
+  @differentiable
+  public func annotated(_ output: Output) -> Output {
+    #if USING_X10_BACKEND
+      let annotated = output.annotate("type=\(Self.self)")
+      return annotated
+    #else
+      return output
+    #endif
+  }
+
+  /// Returns the annotations obtained from applying the layer to the given input.
+  ///
+  /// - Parameter input: The input to the layer.
+  /// - Returns: All collected annotations from the XLA graph.
+  public func summary(input: Input) -> String {
+    let output = self.callAsFunction(input)
+    return formatAnnotations(from: output)
+  }
+
+  /// Returns a formatted version of `tensor.annotations`.
+  ///
+  /// - Parameter tensor: The output to the layer.
+  /// - Returns: A formatted summary of `tensor.annotations`.
+  private func formatAnnotations(from tensor: Output) -> String {
+    #if USING_X10_BACKEND
+      let rawAnnotations = tensor.annotations
+      if rawAnnotations == Device.defaultTFEager.annotationsAvailable {
+        return rawAnnotations
+      }
+
+      let lines = rawAnnotations.components(separatedBy: "\n")
+
+      if lines.count < 3 {
+        return ""
+      }
+
+      // Isolate layers.
+      let pattern = "\\s*shape=(.+)\\s+type=([^\\s]+)(\\s+.+=.+)?$"
+      let regex = try! NSRegularExpression(pattern: pattern)
+      let contents = lines.filter { $0.contains("shape=") }
+        .map { line -> String in
+          let nsrange = NSRange(line.startIndex..., in: line)
+          if let match = regex.firstMatch(in: line, range: nsrange) {
+            var content = ""
+            if let typeRange = Range(match.range(at: 2), in: line) {
+              let type = line[typeRange]
+              content += type
+            }
+            content += "\t\t\t"
+            if let shapeRange = Range(match.range(at: 1), in: line) {
+              let shape = line[shapeRange]
+              content += shape
+            }
+            content += "\t\t"
+            if let attributesRange = Range(match.range(at: 3), in: line) {
+              let attribute = line[attributesRange]
+              content += attribute
+            }
+            return content
+          } else {
+            return line
+          }
+        }
+
+      let formattedAnnotations = """
+        Layer                           Output Shape         Attributes
+        =============================== ==================== ======================
+        \(contents.joined(separator: "\n"))
+        """
+
+      return formattedAnnotations
+    #else
+      return tensor.annotations
+    #endif
+  }
 }
 
 /// A neural network layer.
```
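The regex in `formatAnnotations(from:)` is easiest to see on a concrete line. A standalone sketch (the sample input line is hypothetical; the exact format X10 emits may differ):

```swift
import Foundation

// A made-up raw annotation line in the "shape=... type=..." form the
// formatter filters for.
let line = "  shape=[1, 4] type=Dense<Float>"
let pattern = "\\s*shape=(.+)\\s+type=([^\\s]+)(\\s+.+=.+)?$"
let regex = try! NSRegularExpression(pattern: pattern)

let nsrange = NSRange(line.startIndex..., in: line)
if let match = regex.firstMatch(in: line, range: nsrange),
  let shapeRange = Range(match.range(at: 1), in: line),
  let typeRange = Range(match.range(at: 2), in: line)
{
  // Group 2 captures the layer type, group 1 its output shape:
  // "Dense<Float>" and "[1, 4]" for the sample line above.
  print("type: \(line[typeRange]), shape: \(line[shapeRange])")
}
```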
```diff
@@ -45,6 +160,28 @@ public protocol Layer: Module where Input: Differentiable {
   /// - Returns: The output.
   @differentiable
   func callAsFunction(_ input: Input) -> Output
+
+  @differentiable
+  func forward(_ input: Input) -> Output
+}
+
+extension Layer {
+  // Workaround for SR-13455: autodiff undefined symbol linker error.
+  @differentiable(wrt: self)
+  @differentiable
+  public func forward(_ input: Input) -> Output {
+    return callAsFunction(input)
+  }
+}
+
+extension Layer where Input: DifferentiableTensorProtocol, Output: DifferentiableTensorProtocol {
+  // Workaround for SR-13455: autodiff undefined symbol linker error.
+  @differentiable(wrt: self)
+  @differentiable
+  public func callAsFunction(_ input: Input) -> Output {
+    let activation = forward(input)
+    return annotated(activation)
+  }
 }
 
 /// An empty struct representing empty `TangentVector`s for parameterless layers.
```
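With these extensions, a tensor-in/tensor-out layer can implement only `forward(_:)` and pick up the annotating `callAsFunction` from the constrained extension. A sketch of that pattern (hypothetical layer, not from this commit):

```swift
import TensorFlow

// `MyDense` implements only `forward(_:)`; the constrained extension's
// `callAsFunction` wraps it and, on the X10 backend, annotates the
// output with "type=MyDense".
public struct MyDense: Layer {
  public var dense = Dense<Float>(inputSize: 4, outputSize: 4)

  @differentiable
  public func forward(_ input: Tensor<Float>) -> Tensor<Float> {
    return dense(input)
  }
}
```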
