Skip to content

Commit ef88fdb

Browse files
authored
Make v and hcat with numbers work. (#514)
* Make v and hcat with numbers work. * change RArray to Union{AnyConcrerteRArray, AnyTracedRArray}; add a unit test for this change * update unit tests * update unit tests: confirming vcat and hcat are working as expected * update test: improve reproducibility of test * update test: w/ deterministic values
1 parent e718f44 commit ef88fdb

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

src/Reactant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ end
202202
# StdLib Overloads
203203
include("stdlibs/LinearAlgebra.jl")
204204
include("stdlibs/Random.jl")
205+
include("stdlibs/Base.jl")
205206

206207
# Other Integrations
207208
include("Enzyme.jl")

src/stdlibs/Base.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@inline Base.vcat(a::Number, b::Union{AnyConcreteRArray, AnyTracedRArray}) =
2+
@allowscalar(vcat(fill!(similar(b, typeof(a), (1, size(b)[2:end]...)), a), b))
3+
@inline Base.hcat(a::Number, b::Union{AnyConcreteRArray, AnyTracedRArray}) =
4+
@allowscalar(hcat(fill!(similar(b, typeof(a), (size(b)[1:end-1]..., 1)), a), b))

test/basic.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,34 @@ end
368368
@test y == test_typed_hvncat(x)
369369
@test eltype(y) === Int
370370
end
371+
372+
@testset "Number and RArray" for a in [1.0f0, 1.0e0]
373+
typeof_a = typeof(a)
374+
_b = [2.0, 3.0, 4.0] .|> typeof_a
375+
_c = [2.0 3.0 4.0] .|> typeof_a
376+
b = Reactant.to_rarray(_b)
377+
c = Reactant.to_rarray(_c)
378+
379+
# vcat test
380+
y = @jit vcat(a, b)
381+
@test y == vcat(a, _b)
382+
@test y isa ConcreteRArray{typeof_a,1}
383+
384+
## vcat test - adjoint
385+
y1 = @jit vcat(a, c')
386+
@test y1 == vcat(a, _c')
387+
@test y1 isa ConcreteRArray{typeof_a,2}
388+
389+
# hcat test
390+
z = @jit hcat(a, c)
391+
@test z == hcat(a, _c)
392+
@test z isa ConcreteRArray{typeof_a,2}
393+
394+
## hcat test - adjoint
395+
z1 = @jit hcat(a, b')
396+
@test z1 == hcat(a, _b')
397+
@test z1 isa ConcreteRArray{typeof_a,2}
398+
end
371399
end
372400

373401
@testset "repeat" begin

0 commit comments

Comments
 (0)