Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 33fe7f3

Browse files
authored
Lazy tensor: automatically promote constants to inputs based on history (#476)
1 parent 5a8ac7b commit 33fe7f3

10 files changed

+306
-11
lines changed

Sources/TensorFlow/Bindings/TFTensorOperation.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
/// Opaque reference to a function that has been made callable by loading it
1616
/// into the runtime.
17-
public struct _TensorFunctionPointer {
17+
public struct _TensorFunctionPointer: Equatable {
1818
public var name: String
1919
public init(name: String) {
2020
self.name = name

Sources/TensorFlow/Core/DataTypes.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import CTensorFlow
1919
// This simply wraps a `TF_DataType` and allows user code to handle
2020
// `TF_DataType` without importing CTensorFlow, which pollutes the namespace
2121
// with TensorFlow C API declarations.
22-
public struct TensorDataType {
22+
public struct TensorDataType: Equatable {
2323
public var _cDataType: TF_DataType
2424

2525
@usableFromInline

Sources/TensorFlow/Core/LazyTensorContext.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class LazyTensorOperationsTracker {
6363
struct LazyTensorContext {
6464
var operationsTracker = LazyTensorOperationsTracker()
6565
var isShapeTrackingEnabled = true
66+
/// Should constants in trace be heuristically promoted to inputs automatically?
67+
/// (See `LazyTensorTraceCache`)
68+
var shouldPromoteConstants = true
6669

6770
static var local: LazyTensorContext {
6871
_read { yield _ThreadLocalState.local.lazyTensorContext }

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ class LazyTensorOperation: TensorOperation {
180180
case list([LazyTensorHandle])
181181
}
182182

183-
enum Attribute {
183+
enum Attribute: Equatable {
184184
case boolValue(Bool)
185185
case intValue(Int)
186186
case floatValue(Float)
@@ -199,7 +199,7 @@ class LazyTensorOperation: TensorOperation {
199199
case optionalTensorShapeArray([TensorShape?])
200200
}
201201

202-
let name: String
202+
var name: String
203203
let outputCount: Int
204204
var inputs: [Input]
205205
var attributes: [String: Attribute]

Sources/TensorFlow/Core/LazyTensorTrace.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,13 @@ class LazyTensorTraceBuilder {
7777
inputs: builder.inputs,
7878
operations: builder.operations,
7979
outputs: builder.outputs)
80-
return MaterializationTraceInfo(
80+
let materializationTraceInfo = MaterializationTraceInfo(
8181
lazyOperations: builder.originalOutputs,
8282
trace: trace,
8383
concreteInputs: builder.inputValues)
84+
return LazyTensorContext.local.shouldPromoteConstants
85+
? LazyTensorTraceCache.traceWithPromotedConstants(materializationTraceInfo)
86+
: materializationTraceInfo
8487
}
8588

8689
static func materializationTraceInfo(
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
import CTensorFlow
15+
16+
/// Makes `TFETensorHandle` usable with `==`; the operator itself is defined below.
extension TFETensorHandle: Equatable {}

/// Compares two tensor handles by their `_cTensorHandle` values.
///
/// Tensor contents are not inspected here; use `elementsEqual(_:)` to compare elements.
public func ==(_ lhs: TFETensorHandle, _ rhs: TFETensorHandle) -> Bool {
    let lhsHandle = lhs._cTensorHandle
    let rhsHandle = rhs._cTensorHandle
    return lhsHandle == rhsHandle
}
21+
22+
extension TFETensorHandle {
    /// Returns true if the underlying tensors are equal.
    ///
    /// Executes a TensorFlow `Equal` op on the two handles and reports whether every scalar
    /// of the boolean result is `true`.
    ///
    /// - Parameter other: The handle whose elements are compared with this one's.
    /// - Precondition: Both handles have the same data type, and that data type is neither
    ///   `TF_VARIANT` nor `TF_RESOURCE`.
    func elementsEqual(_ other: TFETensorHandle) -> Bool {
        let selfDtype = TFE_TensorHandleDataType(self._cTensorHandle)
        let otherDtype = TFE_TensorHandleDataType(other._cTensorHandle)
        // Two separate checks so that each failure reports what actually went wrong.
        precondition(selfDtype == otherDtype, "Datatypes of tensor handles don't match.")
        precondition(
            selfDtype != TF_VARIANT && selfDtype != TF_RESOURCE,
            "Elements of variant or resource tensor handles cannot be compared.")
        let op = TFE_Op("Equal", 1)
        op.updateAttribute("T", TensorDataType(selfDtype))
        op.addInput(self)
        op.addInput(other)
        let result: Tensor<Bool> = op.execute(Int(1))
        return result.scalars.allSatisfy { $0 }
    }
}
38+
39+
extension LazyTensorHandle {
    /// Returns true if this handle and `other` are equivalent when comparing lazy tensor
    /// traces.
    ///
    /// Concrete handles are compared with `==`; symbolic handles match when they have the
    /// same output index and their operations have the same `id`. A concrete handle is never
    /// equivalent to a symbolic one.
    func isEquivalent(to other: LazyTensorHandle) -> Bool {
        if case let .concrete(lhsTensor, _) = self.handle,
            case let .concrete(rhsTensor, _) = other.handle
        {
            return lhsTensor == rhsTensor
        }
        if case let .symbolic(lhsOp, lhsIndex, _) = self.handle,
            case let .symbolic(rhsOp, rhsIndex, _) = other.handle
        {
            return lhsIndex == rhsIndex && lhsOp.id == rhsOp.id
        }
        return false
    }
}
50+
51+
extension LazyTensorOperation.Input {
    /// Returns true if these inputs are equivalent when comparing lazy tensor traces.
    ///
    /// A single input only matches a single input, and a list only matches a list of the same
    /// length whose handles are pairwise equivalent.
    func isEquivalent(to other: LazyTensorOperation.Input) -> Bool {
        switch (self, other) {
        case let (.single(lhsHandle), .single(rhsHandle)):
            return lhsHandle.isEquivalent(to: rhsHandle)
        case let (.list(lhsHandles), .list(rhsHandles)):
            return lhsHandles.elementsEqual(rhsHandles) { lhs, rhs in
                lhs.isEquivalent(to: rhs)
            }
        default:
            return false
        }
    }
}
64+
65+
extension LazyTensorOperation {
    /// Returns true if these operations are equivalent when comparing lazy tensor traces.
    ///
    /// Name, output count and device name are checked first; inputs and attributes are only
    /// compared when those already agree.
    func isEquivalent(to other: LazyTensorOperation) -> Bool {
        guard self.name == other.name,
            self.outputCount == other.outputCount,
            self.deviceName == other.deviceName
        else { return false }
        let inputsMatch = self.inputs.elementsEqual(other.inputs) { lhs, rhs in
            lhs.isEquivalent(to: rhs)
        }
        return inputsMatch && self.attributes == other.attributes
    }
}
75+
76+
// TODO(TF-693): This is not thread safe!
77+
struct LazyTensorTraceCache {
    /// Cache from signature to traces that match signature.
    static private var cache: [String: [LazyTensorTrace]] = [:]

    /// Removes every cached trace.
    static func clearCache() { cache.removeAll() }

    /// Returns a `MaterializationTraceInfo` with possibly some constants promoted to inputs.
    ///
    /// The trace in `traceInfo` is compared against every cached trace that has the same
    /// signature. The first cached trace that yields a promotion determines the result. If
    /// none does, the trace is added to the cache and `traceInfo` is returned unchanged.
    ///
    /// - Parameter traceInfo: The trace info that is about to be materialized.
    /// - Returns: The promoted trace info if a cached trace matched, `traceInfo` otherwise.
    static func traceWithPromotedConstants(
        _ traceInfo: MaterializationTraceInfo
    ) -> MaterializationTraceInfo {
        let trace = traceInfo.trace
        let signature = trace.signature
        guard let traces = cache[signature] else {
            cache[signature] = [trace]
            return traceInfo
        }
        for cachedTrace in traces {
            if let promotedTrace = traceInfo.withPromotedConstants(cachedTrace: cachedTrace) {
                debugLog("Promoted: \(promotedTrace)\n")
                return promotedTrace
            }
        }
        // No match found; cache and return the input `traceInfo` itself.
        // Append through the dictionary: `traces` is a copy, so appending to it would be lost.
        cache[signature, default: []].append(trace)
        return traceInfo
    }
}
102+
103+
private extension MaterializationTraceInfo {
    /// Returns this trace info with the constants that differ from `cachedTrace` turned into
    /// inputs, or `nil` if the two traces do not line up.
    ///
    /// The traces are compared operation by operation. A pair of compatible `Const`
    /// operations (see `promotableConstants(_:_:)`) whose tensors hold different elements is
    /// recorded for promotion; every other pair must satisfy `isEquivalent(to:)`.
    ///
    /// - Note: Promotion mutates the affected operations of this trace in place: each one is
    ///   renamed to `"Placeholder"` and loses its `"value"` attribute.
    /// - Parameter cachedTrace: A previously recorded trace to compare against.
    /// - Returns: A trace info whose extra inputs are the promoted constants (appended after
    ///   the existing inputs, with their tensors appended to `concreteInputs` in the same
    ///   order), or `nil` when the operation counts differ or some pair of operations is
    ///   neither a promotable constant pair nor equivalent.
    func withPromotedConstants(cachedTrace: LazyTensorTrace) -> MaterializationTraceInfo? {
        let currentTrace = self.trace
        if currentTrace.operations.count != cachedTrace.operations.count { return nil }
        var promotableConstants: [(Int, TFETensorHandle)] = []
        for (i, current) in currentTrace.operations.enumerated() {
            let cached = cachedTrace.operations[i]
            if let (currentTensor, cachedTensor) = Self.promotableConstants(current, cached) {
                // Constants with equal elements stay in the trace; only differing ones are
                // promoted.
                if currentTensor.elementsEqual(cachedTensor) { continue }
                promotableConstants.append((i, currentTensor))
                continue
            }
            // TODO: we might avoid running the following check based on results of
            // promotableConstant
            if current.isEquivalent(to: cached) { continue }
            return nil
        }

        let newConcreteInputs: [TFETensorHandle] = promotableConstants.map { return $0.1 }
        let newOperations = currentTrace.operations
        let newInputs = promotableConstants.map {
            (promotableConstant: (Int, TFETensorHandle)) -> LazyTensorOperation in
            // `LazyTensorOperation` is a class, so this rewrites the operation held by the
            // trace itself.
            let constantOp = newOperations[promotableConstant.0]
            constantOp.name = "Placeholder"
            constantOp.attributes.removeValue(forKey: "value")
            return constantOp
        }
        let newTrace = LazyTensorTrace(
            inputs: currentTrace.inputs + newInputs,
            operations: newOperations,
            outputs: currentTrace.outputs)
        return MaterializationTraceInfo(
            lazyOperations: self.lazyOperations,
            trace: newTrace,
            concreteInputs: self.concreteInputs + newConcreteInputs)
    }

    /// If `current` and `cached` are compatible constants, returns the constant tensors.
    ///
    /// Two operations are compatible constants when both are named `"Const"`, both carry a
    /// `.constTensor` in their `"value"` attribute, neither tensor has data type `TF_VARIANT`
    /// or `TF_RESOURCE`, and the tensors agree in shape and data type.
    ///
    /// - Returns: `(currentTensor, cachedTensor)` for compatible constants, `nil` otherwise.
    ///   A `Const` operation without a `"value"` attribute yields `nil` instead of trapping.
    static private func promotableConstants(
        _ current: LazyTensorOperation,
        _ cached: LazyTensorOperation
    ) -> (TFETensorHandle, TFETensorHandle)? {
        if current.name != "Const" || cached.name != "Const" { return nil }
        guard case let .constTensor(currentTensor)? = current.attributes["value"],
            case let .constTensor(cachedTensor)? = cached.attributes["value"]
        else { return nil }
        let currentDtype = TFE_TensorHandleDataType(currentTensor._cTensorHandle)
        let cachedDtype = TFE_TensorHandleDataType(cachedTensor._cTensorHandle)
        if currentDtype == TF_VARIANT || currentDtype == TF_RESOURCE { return nil }
        if cachedDtype == TF_VARIANT || cachedDtype == TF_RESOURCE { return nil }
        return currentTensor.shape == cachedTensor.shape && currentDtype == cachedDtype
            ? (currentTensor, cachedTensor)
            : nil
    }
}

Tests/TensorFlowTests/LazyTensorTestHelper.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,18 @@ import XCTest
1616
@testable import TensorFlow
1717

1818
class LazyTensorTestCase: XCTestCase {
    /// The value `LazyTensorContext.local.shouldPromoteConstants` had when `setUp()` ran;
    /// `tearDown()` writes it back.
    static var shouldPromoteConstants = true

    /// Switches lazy tensors on and constant promotion off for the tests of this class.
    override class func setUp() {
        super.setUp()
        _ThreadLocalState.useLazyTensor = true
        LazyTensorTestCase.shouldPromoteConstants =
            LazyTensorContext.local.shouldPromoteConstants
        LazyTensorContext.local.shouldPromoteConstants = false
    }

    /// Switches lazy tensors off again and restores the saved constant promotion setting.
    override class func tearDown() {
        super.tearDown()
        _ThreadLocalState.useLazyTensor = false
        LazyTensorContext.local.shouldPromoteConstants =
            LazyTensorTestCase.shouldPromoteConstants
    }
}
2933

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
@testable import TensorFlow
18+
import CTensorFlow
19+
20+
final class LazyTensorTraceCacheTests: LazyTensorTestCase {
21+
/// Runs the base class setup, then turns constant promotion back on for this test case.
override class func setUp() {
    super.setUp()
    // `LazyTensorTestCase.setUp()` disables promotion; these tests exercise it.
    LazyTensorContext.local.shouldPromoteConstants = true
}
25+
26+
/// Runs the base class teardown, then empties the trace cache.
override class func tearDown() {
    super.tearDown()
    // Drop the traces recorded by the tests in this class.
    LazyTensorTraceCache.clearCache()
}
30+
31+
/// Checks constant promotion against a cached trace: after the trace for `w` has been
/// recorded, the same-shaped constants feeding `x` are expected to show up as trace inputs,
/// while the constants feeding `y` are expected to stay `Const` operations because one of
/// them has a different shape.
func testConstPromotion() {
32+
LazyTensorTraceCache.clearCache()
33+
let a = Tensor<Float>(1.0)
34+
let b = Tensor<Float>(2.0)
35+
let c = Tensor<Float>(3.0)
36+
let d = Tensor<Float>(4.0)
37+
let w = a * b
38+
let x = c * d
39+
// Trigger materialization for `w` so that a trace with constants and mul is added to cache.
40+
XCTAssertEqual(
41+
lazyTrace(w).description,
42+
"""
43+
lazyTrace_3() -> (%2) {
44+
%0 = Const[dtype: float, value: 1.0]()
45+
%1 = Const[dtype: float, value: 2.0]()
46+
%2 = Mul[T: float](%0, %1)
47+
}
48+
""")
49+
XCTAssertEqual(w.scalars, [2.0])
50+
51+
// The trace for `x` should have the inputs to Mul as arguments instead of constants.
52+
XCTAssertEqual(
53+
lazyTrace(x).description,
54+
"""
55+
lazyTrace_3(%0: float, %1: float) -> (%2) {
56+
%2 = Mul[T: float](%0, %1)
57+
}
58+
""")
59+
XCTAssertEqual(x.scalarized(), 12.0)
60+
61+
let e = Tensor<Float>(shape: [1,3], scalars: [1, 2, 3])
62+
let f = Tensor<Float>(5.0)
63+
let y = e * f
64+
// We won't promote constants in 'y' as shape of constants is different.
65+
XCTAssertEqual(
66+
lazyTrace(y).description,
67+
"""
68+
lazyTrace_3() -> (%2) {
69+
%0 = Const[dtype: float, value: [[1.0, 2.0, 3.0]]]()
70+
%1 = Const[dtype: float, value: 5.0]()
71+
%2 = Mul[T: float](%0, %1)
72+
}
73+
""")
74+
XCTAssertEqual(y.scalars, [5.0, 10.0, 15.0])
75+
}
76+
77+
/// Checks that a constant shared with the cached trace is left alone: `w` and `x` both
/// multiply by `a`, so the trace for `x` is expected to keep `Const 1.0` and to take only
/// the differing operand as an input.
func testDoNotPromoteEqualConstants() {
78+
LazyTensorTraceCache.clearCache()
79+
let a = Tensor<Float>(1.0)
80+
let b = Tensor<Float>(2.0)
81+
let c = Tensor<Float>(3.0)
82+
let w = a * b
83+
let x = a * c
84+
XCTAssertEqual(
85+
lazyTrace(w).description,
86+
"""
87+
lazyTrace_3() -> (%2) {
88+
%0 = Const[dtype: float, value: 1.0]()
89+
%1 = Const[dtype: float, value: 2.0]()
90+
%2 = Mul[T: float](%0, %1)
91+
}
92+
""")
93+
XCTAssertEqual(w.scalars, [2.0])
94+
// Const 1.0 is not promoted.
95+
XCTAssertEqual(
96+
lazyTrace(x).description,
97+
"""
98+
lazyTrace_3(%1: float) -> (%2) {
99+
%0 = Const[dtype: float, value: 1.0]()
100+
%2 = Mul[T: float](%0, %1)
101+
}
102+
""")
103+
}
104+
105+
/// Returns the lazy operation behind `input`.
///
/// Records a test failure and returns `nil` when `input` is not backed by a
/// `LazyTensorHandle`, or when that handle is not symbolic.
private func lazyTensorOperation<T: TensorFlowScalar>(
    _ input: Tensor<T>
) -> LazyTensorOperation? {
    guard let lazyHandle = input.handle.handle as? LazyTensorHandle else {
        XCTFail("Trying to get lazy trace for a non-lazy tensor.")
        return nil
    }
    if case let .symbolic(operation, _, _) = lazyHandle.handle {
        return operation
    }
    XCTFail("Cannot get lazy trace for a concrete tensor.")
    return nil
}
119+
120+
/// Returns the trace that `LazyTensorTraceBuilder` produces for materializing `input`.
///
/// Traps if `input` has no lazy operation (see `lazyTensorOperation(_:)`).
private func lazyTrace<T: TensorFlowScalar>(
    _ input: Tensor<T>
) -> LazyTensorTrace {
    let operation = lazyTensorOperation(input)!
    let traceInfo = LazyTensorTraceBuilder.materializationTraceInfo(operation)
    return traceInfo.trace
}
126+
127+
/// Name/method pairs for the tests of this class.
static var allTests = [
    ("testConstPromotion", testConstPromotion),
    ("testDoNotPromoteEqualConstants", testDoNotPromoteEqualConstants),
]
131+
}

Tests/TensorFlowTests/TensorGroupTests.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,6 @@ import XCTest
1616
@testable import TensorFlow
1717
import CTensorFlow
1818

19-
extension TensorDataType : Equatable {
20-
public static func == (lhs: TensorDataType, rhs: TensorDataType) -> Bool {
21-
return Int(lhs._cDataType.rawValue) == Int(rhs._cDataType.rawValue)
22-
}
23-
}
24-
2519
struct Empty : TensorGroup {}
2620

2721
struct Simple : TensorGroup, Equatable {

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ public func allTests() -> [XCTestCaseEntry] {
2525
testCase(InitializerTests.allTests),
2626
testCase(LayerTests.allTests),
2727
testCase(LazyTensorEvaluationTests.allTests),
28+
testCase(LazyTensorTraceTests.allTests),
29+
testCase(LazyTensorTraceCacheTests.allTests),
2830
testCase(LazyTensorExplicitTraceTests.allTests),
2931
testCase(LazyTensorHandleTests.allTests),
3032
testCase(LazyTensorOperationTests.allTests),

0 commit comments

Comments
 (0)