Skip to content

Commit 0c45261

Browse files
ksasidan-zheng
authored andcommitted
TensorShape.swift file updated to improve TensorShape printing (#24253)
1 parent 716575e commit 0c45261

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

stdlib/public/TensorFlow/TensorShape.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,9 @@ extension TensorShape : Codable {
159159
self.init(dimensions)
160160
}
161161
}
162+
163+
extension TensorShape : CustomStringConvertible {
164+
public var description: String {
165+
return dimensions.description
166+
}
167+
}

test/TensorFlowRuntime/tensor.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,11 @@ TensorTests.testAllBackends("SimpleCond") {
623623
expectEqual(0, selectValue(true).scalar)
624624
}
625625

626+
TensorTests.testAllBackends("TensorShapeDescription") {
627+
expectEqual("[2, 2]", Tensor<Int32>(ones: [2, 2]).shape.description)
628+
expectEqual("[]", Tensor(1).shape.description)
629+
}
630+
626631
@inline(never)
627632
func testXORInference() {
628633
func xor(_ x: Float, _ y: Float) -> Float {

0 commit comments

Comments
 (0)