Skip to content

Commit da317a5

Browse files
realdougrxwei
authored andcommitted
[TF-76][Python Interop] add PythonConvertible conformance to TensorShape (#23762)
* [TF-76] add TensorShape conformance to PythonConvertible
1 parent ec50086 commit da317a5

File tree

4 files changed

+57
-4
lines changed

4 files changed

+57
-4
lines changed

stdlib/public/TensorFlow/TensorShape.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
import Python
14+
1315
// NOTE: it may be possible to edit `TensorShape` to support "labeled tensors".
1416
// Dimensions may be either an Int32 or an enum representing a label.
1517

@@ -158,3 +160,20 @@ extension TensorShape : Codable {
158160
self.init(dimensions)
159161
}
160162
}
163+
164+
extension TensorShape : PythonConvertible {
165+
public var pythonObject: PythonObject {
166+
return dimensions.pythonObject
167+
}
168+
169+
public init?(_ pythonObject: PythonObject) {
170+
let hasLen = Bool(Python.hasattr(pythonObject, "__len__"))
171+
if(hasLen == true) {
172+
guard let array = [Int32](pythonObject) else { return nil }
173+
self.init(array)
174+
} else {
175+
guard let num = Int32(pythonObject) else { return nil }
176+
self.init(num)
177+
}
178+
}
179+
}

test/Python/python_runtime.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ PythonRuntimeTestSuite.testWithLeakChecking("PythonList") {
6666
}
6767

6868
PythonRuntimeTestSuite.testWithLeakChecking("PythonDict") {
69-
let dict: PythonObject = ["a": 1, 1: 0.5]
69+
let dict: PythonObject = ["a" : 1, 1 : 0.5]
7070
expectEqual(2, Python.len(dict))
7171
expectEqual(1, dict["a"])
7272
expectEqual(0.5, dict[1])
@@ -210,7 +210,7 @@ PythonRuntimeTestSuite.testWithLeakChecking("Tuple") {
210210
let element1: PythonObject = 0
211211
let element2: PythonObject = "abc"
212212
let element3: PythonObject = [0, 0]
213-
let element4: PythonObject = ["a": 0, "b": "c"]
213+
let element4: PythonObject = ["a" : 0, "b" : "c"]
214214
let pair = PythonObject(tupleOf: element1, element2)
215215
let (pair1, pair2) = pair.tuple2
216216
expectEqual(element1, pair1)
@@ -258,6 +258,8 @@ PythonRuntimeTestSuite.testWithLeakChecking("ConvertibleFromPython") {
258258
let five: PythonObject = 5
259259
let half: PythonObject = 0.5
260260
let string: PythonObject = "abc"
261+
let intArray: PythonObject = [2, 3]
262+
let dict: PythonObject = ["abc" : 97]
261263

262264
expectEqual(-1, Int(minusOne))
263265
expectEqual(-1, Int8(minusOne))
@@ -285,6 +287,9 @@ PythonRuntimeTestSuite.testWithLeakChecking("ConvertibleFromPython") {
285287

286288
expectEqual("abc", String(string))
287289

290+
expectEqual([2, 3], Array(intArray))
291+
expectEqual(["abc" : 97], Dictionary<String, Int32>(dict))
292+
288293
expectNil(String(zero))
289294
expectNil(Int(string))
290295
expectNil(Double(string))
@@ -293,6 +298,8 @@ PythonRuntimeTestSuite.testWithLeakChecking("ConvertibleFromPython") {
293298
PythonRuntimeTestSuite.testWithLeakChecking("PythonConvertible") {
294299
let minusOne: PythonObject = -1
295300
let five: PythonObject = 5
301+
let intArray: PythonObject = [2, 3]
302+
let dict: PythonObject = ["abc" : 7]
296303

297304
expectEqual(minusOne, Int(-1).pythonObject)
298305
expectEqual(minusOne, Int8(-1).pythonObject)
@@ -309,6 +316,9 @@ PythonRuntimeTestSuite.testWithLeakChecking("PythonConvertible") {
309316
expectEqual(five, UInt64(5).pythonObject)
310317
expectEqual(five, Float(5).pythonObject)
311318
expectEqual(five, Double(5).pythonObject)
319+
320+
expectEqual(intArray, [2, 3].pythonObject)
321+
expectEqual(dict, ["abc" : 7].pythonObject)
312322
}
313323

314324
PythonRuntimeTestSuite.testWithLeakChecking("Optional") {

test/TensorFlow/integration.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ public func testResourceAndVariants() {
377377
// expected-error @+1 {{op named 'TensorDataSet' is not registered in TensorFlow}}
378378
#tfop("TensorDataSet", values,
379379
Toutput_types$dtype: [Float.tensorFlowDataType],
380-
output_shapes: [TensorShape(1)])
380+
output_shapes: [TensorShape([1])])
381381

382382
// REGISTER_OP("Iterator")
383383
// .Output("handle: resource")
@@ -388,7 +388,7 @@ public func testResourceAndVariants() {
388388
// .SetShapeFn(shape_inference::ScalarShape);
389389
let iterator: ResourceHandle =
390390
#tfop("Iterator", shared_name: "foo", container: "bar",
391-
output_types$dtype: [Float.tensorFlowDataType], output_shapes: [TensorShape(1)])
391+
output_types$dtype: [Float.tensorFlowDataType], output_shapes: [TensorShape([1])])
392392

393393
// REGISTER_OP("MakeIterator")
394394
// .Input("dataset: variant")

test/TensorFlowRuntime/numpy_conversion.swift

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ NumpyConversionTests.test("shaped-array-conversion") {
5252
array)
5353
}
5454

55+
let reshaped = np.reshape(numpyArrayInt32, [2, 3] as TensorShape)
56+
if let array = expectNotNil(ShapedArray<Int32>(numpy: reshaped)) {
57+
expectEqual(ShapedArray(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6]),
58+
array)
59+
}
60+
61+
let numpyArray1D = np.ones(28)
62+
let reshaped3D = np.reshape(numpyArray1D, [2, 7, 2] as TensorShape)
63+
expectEqual(TensorShape(reshaped3D.shape), [2, 7, 2])
64+
let reshaped2D = np.reshape(reshaped3D, [14, 2] as TensorShape)
65+
expectEqual(TensorShape(reshaped2D.shape), [14, 2])
66+
5567
let numpyArrayStrided = np.array([[1, 2], [1, 2]], dtype: np.int32)[
5668
Python.slice(Python.None), 1]
5769
// Assert that the array has a stride, so that we're certainly testing a
@@ -95,6 +107,12 @@ NumpyConversionTests.test("tensor-conversion") {
95107
tensor.array)
96108
}
97109

110+
let reshaped = np.reshape(numpyArrayInt32, [2, 3] as TensorShape)
111+
if let tensor = expectNotNil(Tensor<Int32>(numpy: reshaped)) {
112+
expectEqual(ShapedArray(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6]),
113+
tensor.array)
114+
}
115+
98116
let numpyArrayStrided = np.array([[1, 2], [1, 2]], dtype: np.int32)[
99117
Python.slice(Python.None), 1]
100118
// Assert that the array has a stride, so that we're certainly testing a
@@ -118,6 +136,12 @@ NumpyConversionTests.test("tensor-round-trip") {
118136
let t3 = Tensor<Int32>(repeating: 30, shape: [8,5,4])
119137
expectEqual(t3, Tensor<Int32>(numpy: t3.makeNumpyArray())!)
120138
}
139+
140+
NumpyConversionTests.test("tensor-shape") {
141+
let pyArray = [2, 3].pythonObject
142+
expectEqual(pyArray, TensorShape(2, 3).pythonObject)
143+
expectEqual(TensorShape(2, 3), TensorShape(pyArray))
144+
}
121145
#endif
122146

123147
runAllTests()

0 commit comments

Comments
 (0)