@@ -35,6 +35,17 @@ public struct Tensor<Scalar: TensorFlowScalar> {
   /// - Note: `handle` is public to allow user defined ops, but should not normally be used.
   public let handle: TensorHandle<Scalar>
 
+  /// An internal marker to identify scalar zero tensors, for use in optimizations.
+  @usableFromInline
+  internal var _isScalarZero = false
+
+  /// An internal workaround for SR-13263: debug info generation crash.
+  @usableFromInline
+  class SR13263Workaround {}
+
+  /// An internal workaround for SR-13263: debug info generation crash.
+  internal var _sr13263Workaround: SR13263Workaround?
+
   @inlinable
   public init(handle: TensorHandle<Scalar>) {
     self.handle = handle
@@ -669,6 +680,7 @@ extension Tensor: AdditiveArithmetic where Scalar: Numeric {
     if _DeviceThreadLocalState.local.isReducedPrecision {
       zero = zero.toReducedPrecision
     }
+    zero._isScalarZero = true
     return zero
   }
   #else
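
The hunk above is where the flag actually gets set: the additive identity is tagged at construction time, after any reduced-precision conversion, so a set flag reliably means "this value is the scalar zero." A minimal, self-contained sketch of the same pattern (hypothetical TaggedValue type, purely illustrative and not part of this diff):

// Hypothetical illustration of the tagging pattern used above; none of these
// names exist in swift-apis.
struct TaggedValue {
  var payload: [Double]
  var isKnownZero = false            // analogue of `_isScalarZero`

  static var zero: TaggedValue {
    var z = TaggedValue(payload: [0])
    z.isKnownZero = true             // analogue of `zero._isScalarZero = true`
    return z
  }
}
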
@@ -681,15 +693,23 @@ extension Tensor: AdditiveArithmetic where Scalar: Numeric {
   @inlinable
   @differentiable(where Scalar: TensorFlowFloatingPoint)
   public static func + (lhs: Tensor, rhs: Tensor) -> Tensor {
-    _Raw.addV2(lhs, rhs)
+    if lhs._isScalarZero {
+      return rhs
+    } else if rhs._isScalarZero {
+      return lhs
+    }
+    return _Raw.addV2(lhs, rhs)
   }
 
   /// Subtracts one tensor from another and produces their difference.
   /// - Note: `-` supports broadcasting.
   @inlinable
   @differentiable(where Scalar: TensorFlowFloatingPoint)
   public static func - (lhs: Tensor, rhs: Tensor) -> Tensor {
-    _Raw.sub(lhs, rhs)
+    if rhs._isScalarZero {
+      return lhs
+    }
+    return _Raw.sub(lhs, rhs)
   }
 }
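
With the fast paths above, adding or subtracting a tagged scalar zero returns the other operand directly instead of dispatching _Raw.addV2 or _Raw.sub. A hedged usage sketch (assumes the Swift for TensorFlow TensorFlow module; behavior as described by this diff, on the code path that sets _isScalarZero):

import TensorFlow

let x = Tensor<Float>([1, 2, 3])
let zero = Tensor<Float>.zero    // the AdditiveArithmetic zero, tagged above

let sum = zero + x               // fast path: returns `x` without running addV2
let difference = x - zero        // fast path: returns `x` without running sub

print(sum)                       // [1.0, 2.0, 3.0]
print(difference)                // [1.0, 2.0, 3.0]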