Skip to content

Commit ac6e736

Browse files
authored
simple fix for #140 (#142)
* simple fix for #140 * test for issue #140 * print test name for issue 140 * a fix of #140 on Julia 1.0; going back to the old stream of random numbers
1 parent 82fa0ac commit ac6e736

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

src/derivatives/broadcast.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,10 @@ for (M, f, arity) in DiffRules.diffrules()
276276
end
277277
for A in ARRAY_TYPES
278278
@eval begin
279-
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A,TrackedArray}}) = _materialize(bc.f, bc.args)
280-
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray, $A}}) = _materialize(bc.f, bc.args)
281-
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A, TrackedReal}}) = _materialize(bc.f, bc.args)
282-
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedReal,$A}}) = _materialize(bc.f, bc.args)
279+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A{<:Number},TrackedArray}}) = _materialize(bc.f, bc.args)
280+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray, $A{<:Number}}}) = _materialize(bc.f, bc.args)
281+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A{<:Number}, TrackedReal}}) = _materialize(bc.f, bc.args)
282+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedReal,$A{<:Number}}}) = _materialize(bc.f, bc.args)
283283
end
284284
end
285285
for R in REAL_TYPES

src/derivatives/elementwise.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,10 +337,10 @@ for (F, broadcast_f) in ((typeof(+), :broadcast_plus),
337337
end
338338
for A in ARRAY_TYPES
339339
@eval begin
340-
@inline Base.broadcast(::$F, x::TrackedArray{X,D}, y::$A) where {X,D} = $(broadcast_f)(x, y, D)
341-
@inline Base.broadcast(::$F, x::$A, y::TrackedArray{Y,D}) where {Y,D} = $(broadcast_f)(x, y, D)
342-
@inline Base.broadcast(::$F, x::TrackedReal{X,D}, y::$A) where {X,D} = $(broadcast_f)(x, y, D)
343-
@inline Base.broadcast(::$F, x::$A, y::TrackedReal{Y,D}) where {Y,D} = $(broadcast_f)(x, y, D)
340+
@inline Base.broadcast(::$F, x::TrackedArray{X,D}, y::$A{<:Real}) where {X,D} = $(broadcast_f)(x, y, D)
341+
@inline Base.broadcast(::$F, x::$A{<:Real}, y::TrackedArray{Y,D}) where {Y,D} = $(broadcast_f)(x, y, D)
342+
@inline Base.broadcast(::$F, x::TrackedReal{X,D}, y::$A{<:Real}) where {X,D} = $(broadcast_f)(x, y, D)
343+
@inline Base.broadcast(::$F, x::$A{<:Real}, y::TrackedReal{Y,D}) where {Y,D} = $(broadcast_f)(x, y, D)
344344
end
345345
end
346346
for R in REAL_TYPES

test/api/GradientTests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ function test_ternary_gradient(f, a, b, c)
172172
end
173173
end
174174

175+
# issue https://github.com/JuliaDiff/ReverseDiff.jl/issues/140
176+
nested_array_mul_140(x) = sum(sum(x[1] * [[x[2], x[3]]]))
177+
test_println("Issue #140", nested_array_mul_140)
178+
test_unary_gradient(nested_array_mul_140, [1.0, 2.0, 1.0, -2.4, 4.0])
179+
175180
for f in DiffTests.MATRIX_TO_NUMBER_FUNCS
176181
test_println("MATRIX_TO_NUMBER_FUNCS", f)
177182
test_unary_gradient(f, rand(5, 5))

0 commit comments

Comments
 (0)