Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 213a648

Browse files
committed
Style fixes.
1 parent 5a53e02 commit 213a648

File tree

5 files changed

+52
-50
lines changed

5 files changed

+52
-50
lines changed

Sources/TensorFlow/Core/LazyTensorContext.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ struct LazyTensorContext {
6565
var isShapeTrackingEnabled = true
6666
/// Should constants in trace be heuristically promoted to inputs automatically?
6767
/// (See `LazyTensorTraceCache`)
68-
var constPromotion = true
68+
var shouldPromoteConstants = true
6969

7070
static var local: LazyTensorContext {
7171
_read { yield _ThreadLocalState.local.lazyTensorContext }

Sources/TensorFlow/Core/LazyTensorTrace.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class LazyTensorTraceBuilder {
8181
lazyOperations: builder.originalOutputs,
8282
trace: trace,
8383
concreteInputs: builder.inputValues)
84-
return LazyTensorContext.local.constPromotion
84+
return LazyTensorContext.local.shouldPromoteConstants
8585
? LazyTensorTraceCache.traceWithPromotedConstants(materializationTraceInfo)
8686
: materializationTraceInfo
8787
}

Sources/TensorFlow/Core/LazyTensorTraceCache.swift

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,55 +20,55 @@ extension TFETensorHandle {
2020
}
2121

2222
/// Returns true if the underlying tensors are equal.
23-
static func areTensorsEqual(_ lhs: TFETensorHandle, _ rhs: TFETensorHandle) -> Bool {
24-
let lhsDtype = TFE_TensorHandleDataType(lhs._cTensorHandle)
25-
let rhsDtype = TFE_TensorHandleDataType(rhs._cTensorHandle)
23+
func elementsEqual(_ other: TFETensorHandle) -> Bool {
24+
let selfDtype = TFE_TensorHandleDataType(self._cTensorHandle)
25+
let otherDtype = TFE_TensorHandleDataType(other._cTensorHandle)
2626
precondition(
27-
lhsDtype == rhsDtype && lhsDtype != TF_VARIANT && lhsDtype != TF_RESOURCE,
27+
selfDtype == otherDtype && selfDtype != TF_VARIANT && selfDtype != TF_RESOURCE,
2828
"Datatypes of tensor handles don't match.")
2929
let op = TFE_Op("Equal", 1)
30-
op.updateAttribute("T", TensorDataType(lhsDtype))
31-
op.addInput(lhs)
32-
op.addInput(rhs)
30+
op.updateAttribute("T", TensorDataType(selfDtype))
31+
op.addInput(self)
32+
op.addInput(other)
3333
let result: Tensor<Bool> = op.execute(Int(1))
3434
return result.scalars.allSatisfy { $0 }
3535
}
3636
}
3737

3838
extension LazyTensorHandle {
39-
static func areHandlesEquivalent(_ lhs: LazyTensorHandle, _ rhs: LazyTensorHandle) -> Bool {
40-
switch (lhs.handle, rhs.handle) {
39+
func isEquivalent(to other: LazyTensorHandle) -> Bool {
40+
switch (self.handle, other.handle) {
4141
case let (.concrete(x, _), .concrete(y, _)):
4242
return TFETensorHandle.areHandlesEquivalent(x, y)
4343
case let (.symbolic(x, xi, _), .symbolic(y, yi, _)):
44-
return (xi == yi) && (x.id == y.id)
44+
return xi == yi && x.id == y.id
4545
default: return false
4646
}
4747
}
4848
}
4949

50-
extension LazyTensorOperation {
50+
extension LazyTensorOperation.Input {
5151
/// Returns true if these inputs are equivalent when comparing lazy tensor traces.
52-
static func areInputsEquivalent(_ lhs: Input, _ rhs: Input) -> Bool {
53-
switch (lhs, rhs) {
52+
func isEquivalent(to other: LazyTensorOperation.Input) -> Bool {
53+
switch (self, other) {
5454
case let (.single(l), .single(r)):
55-
return LazyTensorHandle.areHandlesEquivalent(l, r)
55+
return l.isEquivalent(to: r)
5656
case let (.list(l), .list(r)):
57-
return l.elementsEqual(r, by: {LazyTensorHandle.areHandlesEquivalent($0, $1) })
57+
return l.elementsEqual(r, by: { $0.isEquivalent(to: $1) })
5858
default:
5959
return false
6060
}
6161
}
62+
}
6263

64+
extension LazyTensorOperation {
6365
/// Returns true if these operations are equivalent when comparing lazy tensor traces.
64-
static func areEquivalent(_ lhs: LazyTensorOperation, _ rhs: LazyTensorOperation) -> Bool {
65-
return (lhs.name == rhs.name) &&
66-
(lhs.outputCount == rhs.outputCount) &&
67-
(lhs.deviceName == rhs.deviceName) &&
68-
lhs.inputs.elementsEqual(
69-
rhs.inputs,
70-
by: { LazyTensorOperation.areInputsEquivalent($0, $1) }) &&
71-
(lhs.attributes == rhs.attributes)
66+
func isEquivalent(to other: LazyTensorOperation) -> Bool {
67+
return self.name == other.name &&
68+
self.outputCount == other.outputCount &&
69+
self.deviceName == other.deviceName &&
70+
self.inputs.elementsEqual(other.inputs, by: { $0.isEquivalent(to: $1) }) &&
71+
self.attributes == other.attributes
7272
}
7373
}
7474

@@ -100,21 +100,23 @@ func ==(_ lhs: LazyTensorOperation.Attribute, _ rhs: LazyTensorOperation.Attribu
100100
}
101101
}
102102

103-
// TODO(https://bugs.swift.org/browse/TF-693): This is not thread safe!
103+
// TODO(TF-693): This is not thread safe!
104104
struct LazyTensorTraceCache {
105-
// Cache from signature to traces that match signature.
105+
/// Cache from signature to traces that match signature.
106106
static private var cache: [String: [LazyTensorTrace]] = [:]
107107
static func clearCache() { cache.removeAll() }
108108

109-
// Returns a `MaterializationTraceInfo` with possibly some constants promoted to inputs.
110-
static func traceWithPromotedConstants(_ traceInfo: MaterializationTraceInfo) -> MaterializationTraceInfo {
109+
/// Returns a `MaterializationTraceInfo` with possibly some constants promoted to inputs.
110+
static func traceWithPromotedConstants(
111+
_ traceInfo: MaterializationTraceInfo
112+
) -> MaterializationTraceInfo {
111113
let trace = traceInfo.trace
112114
guard var traces = cache[trace.signature] else {
113115
cache[trace.signature] = [trace]
114116
return traceInfo
115117
}
116118
for cachedTrace in traces {
117-
if let promotedTrace = traceWithPromotedConstants(traceInfo, cachedTrace) {
119+
if let promotedTrace = traceInfo.withPromotedConstants(cachedTrace: cachedTrace) {
118120
debugLog("Promoted: \(promotedTrace)\n")
119121
return promotedTrace
120122
}
@@ -123,23 +125,22 @@ struct LazyTensorTraceCache {
123125
traces.append(trace)
124126
return traceInfo
125127
}
128+
}
126129

127-
static private func traceWithPromotedConstants(
128-
_ traceInfo: MaterializationTraceInfo,
129-
_ cachedTrace: LazyTensorTrace
130-
) -> MaterializationTraceInfo? {
131-
let currentTrace = traceInfo.trace
130+
private extension MaterializationTraceInfo {
131+
func withPromotedConstants(cachedTrace: LazyTensorTrace) -> MaterializationTraceInfo? {
132+
let currentTrace = self.trace
132133
if currentTrace.operations.count != cachedTrace.operations.count { return nil }
133134
var promotableConstants: [(Int, TFETensorHandle)] = []
134135
for (i, current) in currentTrace.operations.enumerated() {
135136
let cached = cachedTrace.operations[i]
136-
if let (currentTensor, cachedTensor) = promotableConstant(current, cached) {
137-
if TFETensorHandle.areTensorsEqual(currentTensor, cachedTensor) { continue }
137+
if let (currentTensor, cachedTensor) = Self.promotableConstants(current, cached) {
138+
if currentTensor.elementsEqual(cachedTensor) { continue }
138139
promotableConstants.append((i, currentTensor))
139140
continue
140141
}
141142
// TODO: we might avoid running the following check based on results of promotableConstant
142-
if LazyTensorOperation.areEquivalent(current, cached) { continue }
143+
if current.isEquivalent(to: cached) { continue }
143144
return nil
144145
}
145146

@@ -157,26 +158,27 @@ struct LazyTensorTraceCache {
157158
operations: newOperations,
158159
outputs: currentTrace.outputs)
159160
return MaterializationTraceInfo(
160-
lazyOperations: traceInfo.lazyOperations,
161+
lazyOperations: self.lazyOperations,
161162
trace: newTrace,
162-
concreteInputs: traceInfo.concreteInputs + newConcreteInputs)
163+
concreteInputs: self.concreteInputs + newConcreteInputs)
163164
}
164165

165166
/// If `current` and `cached` are compatible constants, returns the constant tensors.
166-
static private func promotableConstant(
167+
static private func promotableConstants(
167168
_ current: LazyTensorOperation,
168169
_ cached: LazyTensorOperation
169170
) -> (TFETensorHandle, TFETensorHandle)? {
170-
if (current.name != "Const" || cached.name != "Const") { return nil }
171+
if current.name != "Const" || cached.name != "Const" { return nil }
171172
let currentValue = current.attributes["value"]!
172173
let cachedValue = cached.attributes["value"]!
173-
guard case let .constTensor(currentTensor) = currentValue else { return nil }
174-
guard case let .constTensor(cachedTensor) = cachedValue else { return nil }
174+
guard case let .constTensor(currentTensor) = currentValue,
175+
case let .constTensor(cachedTensor) = cachedValue
176+
else { return nil }
175177
let currentDtype = TFE_TensorHandleDataType(currentTensor._cTensorHandle)
176178
let cachedDtype = TFE_TensorHandleDataType(cachedTensor._cTensorHandle)
177179
if currentDtype == TF_VARIANT || currentDtype == TF_RESOURCE { return nil }
178180
if cachedDtype == TF_VARIANT || cachedDtype == TF_RESOURCE { return nil }
179-
return (currentTensor.shape == cachedTensor.shape) && (currentDtype == cachedDtype)
181+
return currentTensor.shape == cachedTensor.shape && currentDtype == cachedDtype
180182
? (currentTensor, cachedTensor)
181183
: nil
182184
}

Tests/TensorFlowTests/LazyTensorTestHelper.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@ import XCTest
1616
@testable import TensorFlow
1717

1818
class LazyTensorTestCase: XCTestCase {
19-
static var constPromotion = true
19+
static var shouldPromoteConstants = true
2020
override class func setUp() {
2121
super.setUp()
2222
_ThreadLocalState.useLazyTensor = true
23-
constPromotion = LazyTensorContext.local.constPromotion
24-
LazyTensorContext.local.constPromotion = false
23+
shouldPromoteConstants = LazyTensorContext.local.shouldPromoteConstants
24+
LazyTensorContext.local.shouldPromoteConstants = false
2525
}
2626

2727
override class func tearDown() {
2828
super.tearDown()
2929
_ThreadLocalState.useLazyTensor = false
30-
LazyTensorContext.local.constPromotion = constPromotion
30+
LazyTensorContext.local.shouldPromoteConstants = shouldPromoteConstants
3131
}
3232
}
3333

Tests/TensorFlowTests/LazyTensorTraceCacheTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import CTensorFlow
2020
final class LazyTensorTraceCacheTests: LazyTensorTestCase {
2121
override class func setUp() {
2222
super.setUp()
23-
LazyTensorContext.local.constPromotion = true
23+
LazyTensorContext.local.shouldPromoteConstants = true
2424
}
2525

2626
override class func tearDown() {

0 commit comments

Comments
 (0)