Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 84797b2

Browse files
Tagging zero scalar tensors. (#1081)
* Tagging scalar zero tensors. * Update Sources/TensorFlow/Core/Tensor.swift Co-authored-by: Dan Zheng <[email protected]> * Attempt to workaround SR-13263: debug info generation crash. Co-authored-by: Brad Larson <[email protected]> Co-authored-by: Brad Larson <[email protected]>
1 parent 92a96bc commit 84797b2

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ public struct Tensor<Scalar: TensorFlowScalar> {
3535
/// - Note: `handle` is public to allow user defined ops, but should not normally be used.
3636
public let handle: TensorHandle<Scalar>
3737

38+
/// An internal marker to identify scalar zero tensors, for use in optimizations.
39+
@usableFromInline
40+
internal var _isScalarZero = false
41+
42+
/// An internal workaround for SR-13263: debug info generation crash.
43+
@usableFromInline
44+
class SR13263Workaround {}
45+
46+
/// An internal workaround for SR-13263: debug info generation crash.
47+
internal var _sr13263Workaround: SR13263Workaround?
48+
3849
@inlinable
3950
public init(handle: TensorHandle<Scalar>) {
4051
self.handle = handle
@@ -669,6 +680,7 @@ extension Tensor: AdditiveArithmetic where Scalar: Numeric {
669680
if _DeviceThreadLocalState.local.isReducedPrecision {
670681
zero = zero.toReducedPrecision
671682
}
683+
zero._isScalarZero = true
672684
return zero
673685
}
674686
#else
@@ -681,15 +693,23 @@ extension Tensor: AdditiveArithmetic where Scalar: Numeric {
681693
@inlinable
682694
@differentiable(where Scalar: TensorFlowFloatingPoint)
683695
public static func + (lhs: Tensor, rhs: Tensor) -> Tensor {
684-
_Raw.addV2(lhs, rhs)
696+
if lhs._isScalarZero {
697+
return rhs
698+
} else if rhs._isScalarZero {
699+
return lhs
700+
}
701+
return _Raw.addV2(lhs, rhs)
685702
}
686703

687704
/// Subtracts one tensor from another and produces their difference.
688705
/// - Note: `-` supports broadcasting.
689706
@inlinable
690707
@differentiable(where Scalar: TensorFlowFloatingPoint)
691708
public static func - (lhs: Tensor, rhs: Tensor) -> Tensor {
692-
_Raw.sub(lhs, rhs)
709+
if rhs._isScalarZero {
710+
return lhs
711+
}
712+
return _Raw.sub(lhs, rhs)
693713
}
694714
}
695715

0 commit comments

Comments
 (0)