Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 91da7b9

Browse files
authored
Add ConstTensor attribute to LazyTensorOperation. (#265)
1 parent 3b2108c commit 91da7b9

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,6 @@ class LazyTensorOperation: TensorOperation {
271271
func updateAttribute(_ name: String, _ value: [String]) {
272272
attributes[name] = Attribute.stringArray(value)
273273
}
274-
func updateAttribute(_ name: String, _ value: _TensorFunctionPointer) {
275-
attributes[name] = Attribute.tensorFunctionPointer(value)
276-
}
277274
}
278275

279276
extension LazyTensorOperation: TFTensorOperation {
@@ -336,6 +333,13 @@ extension LazyTensorOperation: TFTensorOperation {
336333
func updateAttribute(_ name: String, _ value: [TensorShape?]) {
337334
attributes[name] = Attribute.optionalTensorShapeArray(value)
338335
}
336+
func updateAttribute(_ name: String, _ value: _TensorFunctionPointer) {
337+
attributes[name] = Attribute.tensorFunctionPointer(value)
338+
}
339+
func updateAttribute(_ name: String, _ value: TFETensorHandle) {
340+
attributes[name] = Attribute.constTensor(value)
341+
}
342+
339343
func updateAttribute<In: TensorGroup, Out: TensorGroup>(
340344
_ name: String, _ value: (In) -> Out) {
341345
// TODO:

Tests/TensorFlowTests/LazyTensorOperationTests.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,16 @@ final class LazyTensorOperationTests: XCTestCase {
167167
XCTAssertEqual(op0.description, "%0 = Nop[shapes: [nil, Optional([4, 5])]]()")
168168
}
169169

170+
func testConstTensorAttribute() {
171+
let op0 = LazyTensorOperation(
172+
_id: "0", name: "Nop", outputCount: 1)
173+
let a = Tensor<Float>(5.5)
174+
let b = Tensor<Float>([1,2])
175+
op0.updateAttribute("a", a.handle.handle._tfeTensorHandle)
176+
op0.updateAttribute("b", b.handle.handle._tfeTensorHandle)
177+
XCTAssertEqual(op0.description, "%0 = Nop[a: 5.5, b: [1.0, 2.0]]()")
178+
}
179+
170180
func testArrayAttributes() {
171181
let op0 = LazyTensorOperation(
172182
_id: "0", name: "Nop", outputCount: 1)
@@ -231,6 +241,7 @@ final class LazyTensorOperationTests: XCTestCase {
231241
("testOptionalTensorShapeAttribute", testOptionalTensorShapeAttribute),
232242
("testTensorShapeArrayAttribute",
233243
testOptionalTensorShapeArrayAttribute),
244+
("testConstTensorAttribute", testConstTensorAttribute),
234245
("testArrayAttributes", testArrayAttributes),
235246
("testMultipleAttributes", testMultipleAttributes),
236247
("testFunctionAttribute", testFunctionAttribute),

0 commit comments

Comments
 (0)