Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit c499e16

Browse files
eaplataniosrxwei
authored andcommitted
Minor fix for AD. (#106)
1 parent d2c78f4 commit c499e16

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

Sources/DeepLearning/Operators/Math.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ public extension Tensor where Scalar == Bool {
875875
@inlinable
876876
func all() -> Bool {
877877
let axes = Tensor<Int32>(rangeFrom: 0, to: Int32(rank), stride: 1)
878-
return _TFGetScalarOrDie(Raw.all(self, reductionIndices: axes).handle)
878+
return Raw.all(self, reductionIndices: axes).scalarized()
879879
}
880880

881881
/// Returns `true` if any scalars are equal to `true`. Otherwise, returns `false`.
@@ -884,7 +884,7 @@ public extension Tensor where Scalar == Bool {
884884
@inlinable
885885
func any() -> Bool {
886886
let axes = Tensor<Int32>(rangeFrom: 0, to: Int32(rank), stride: 1)
887-
return _TFGetScalarOrDie(Raw.any(self, reductionIndices: axes).handle)
887+
return Raw.any(self, reductionIndices: axes).scalarized()
888888
}
889889

890890
/// Performs a logical AND operation along the specified axes. The reduced dimensions are

Sources/DeepLearning/Tensors.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,21 @@ infix operator .==: ComparisonPrecedence
2727
public extension Tensor {
2828
/// The rank of the tensor, represented as a `Tensor<Int32>`.
2929
@inlinable
30+
@_semantics("autodiff.nonvarying")
3031
var rankTensor: Tensor<Int32> {
3132
return Raw.rank(self)
3233
}
3334

3435
/// The dimensions of the tensor, represented as a `Tensor<Int32>`.
3536
@inlinable
37+
@_semantics("autodiff.nonvarying")
3638
var shapeTensor: Tensor<Int32> {
3739
return Raw.shape(self)
3840
}
3941

4042
/// The number of scalars in the tensor, represented as a `Tensor<Int32>`.
4143
@inlinable
44+
@_semantics("autodiff.nonvarying")
4245
var scalarCountTensor: Tensor<Int32> {
4346
return Raw.size(self)
4447
}
@@ -53,6 +56,7 @@ extension Tensor: CustomStringConvertible {
5356
/// A textual representation of the tensor.
5457
///
5558
/// - Note: use `fullDescription` for a non-pretty-printed description showing all scalars.
59+
@_semantics("autodiff.nonvarying")
5660
public var description: String {
5761
return array.description
5862
}
@@ -69,6 +73,7 @@ public extension Tensor {
6973
/// via ellipses (`...`).
7074
/// - summarizing: If true, summarize description if element count exceeds twice
7175
/// `edgeElementCount`.
76+
@_semantics("autodiff.nonvarying")
7277
func description(
7378
lineWidth: Int = 80,
7479
edgeElementCount: Int = 3,
@@ -82,20 +87,23 @@ public extension Tensor {
8287

8388
/// A full, non-pretty-printed textual representation of the tensor, showing
8489
/// all scalars.
90+
@_semantics("autodiff.nonvarying")
8591
var fullDescription: String {
8692
return array.fullDescription
8793
}
8894
}
8995

9096
// Xcode Playground display conversion.
9197
extension Tensor: CustomPlaygroundDisplayConvertible {
98+
@_semantics("autodiff.nonvarying")
9299
public var playgroundDescription: Any {
93100
return description
94101
}
95102
}
96103

97104
// Mirror representation, used by debugger/REPL.
98105
extension Tensor: CustomReflectable {
106+
@_semantics("autodiff.nonvarying")
99107
public var customMirror: Mirror {
100108
return Mirror(self, children: [], displayStyle: .struct)
101109
}

0 commit comments

Comments
 (0)