Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Minor fix for AD. #106

Merged
merged 2 commits into from
Apr 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Sources/DeepLearning/Operators/Math.swift
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ public extension Tensor where Scalar == Bool {
@inlinable
func all() -> Bool {
let axes = Tensor<Int32>(rangeFrom: 0, to: Int32(rank), stride: 1)
return _TFGetScalarOrDie(Raw.all(self, reductionIndices: axes).handle)
return Raw.all(self, reductionIndices: axes).scalarized()
}

/// Returns `true` if any scalars are equal to `true`. Otherwise, returns `false`.
Expand All @@ -884,7 +884,7 @@ public extension Tensor where Scalar == Bool {
@inlinable
func any() -> Bool {
let axes = Tensor<Int32>(rangeFrom: 0, to: Int32(rank), stride: 1)
return _TFGetScalarOrDie(Raw.any(self, reductionIndices: axes).handle)
return Raw.any(self, reductionIndices: axes).scalarized()
}

/// Performs a logical AND operation along the specified axes. The reduced dimensions are
Expand Down
8 changes: 8 additions & 0 deletions Sources/DeepLearning/Tensors.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,21 @@ infix operator .==: ComparisonPrecedence
public extension Tensor {
/// The rank of the tensor, represented as a `Tensor<Int32>`.
@inlinable
@_semantics("autodiff.nonvarying")
var rankTensor: Tensor<Int32> {
return Raw.rank(self)
}

/// The dimensions of the tensor, represented as a `Tensor<Int32>`.
@inlinable
@_semantics("autodiff.nonvarying")
var shapeTensor: Tensor<Int32> {
return Raw.shape(self)
}

/// The number of scalars in the tensor, represented as a `Tensor<Int32>`.
@inlinable
@_semantics("autodiff.nonvarying")
var scalarCountTensor: Tensor<Int32> {
return Raw.size(self)
}
Expand All @@ -53,6 +56,7 @@ extension Tensor: CustomStringConvertible {
/// A textual representation of the tensor.
///
/// - Note: use `fullDescription` for a non-pretty-printed description showing all scalars.
@_semantics("autodiff.nonvarying")
public var description: String {
return array.description
}
Expand All @@ -69,6 +73,7 @@ public extension Tensor {
/// via ellipses (`...`).
/// - summarizing: If true, summarize description if element count exceeds twice
/// `edgeElementCount`.
@_semantics("autodiff.nonvarying")
func description(
lineWidth: Int = 80,
edgeElementCount: Int = 3,
Expand All @@ -82,20 +87,23 @@ public extension Tensor {

/// A full, non-pretty-printed textual representation of the tensor, showing
/// all scalars.
@_semantics("autodiff.nonvarying")
var fullDescription: String {
return array.fullDescription
}
}

// Xcode Playground display conversion.
extension Tensor: CustomPlaygroundDisplayConvertible {
@_semantics("autodiff.nonvarying")
public var playgroundDescription: Any {
return description
}
}

// Mirror representation, used by debugger/REPL.
extension Tensor: CustomReflectable {
@_semantics("autodiff.nonvarying")
public var customMirror: Mirror {
return Mirror(self, children: [], displayStyle: .struct)
}
Expand Down