Skip to content

Commit 6217c74

Browse files
authored
Implement isnan and isfinite for TracedRNumber (#525)
* Implement `isnan` for TracedRNumber * isfinite and complex * update
1 parent a5fb8cb commit 6217c74

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

src/TracedRNumber.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T}
2323
return TracedUtils.promote_to(TracedRNumber{T}, eps(T))
2424
end
2525

26+
function Base.isfinite(x::TracedRNumber{<:Complex})
27+
return isfinite(real(x)) & isfinite(imag(x))
28+
end
29+
function Base.isfinite(x::TracedRNumber{T}) where {T<:AbstractFloat}
30+
return Reactant.Ops.is_finite(x)
31+
end
32+
33+
function Base.isnan(x::TracedRNumber{T}) where {T<:AbstractFloat}
34+
return !isfinite(x) & (x != typemax(T)) & (x != typemin(T))
35+
end
36+
function Base.isnan(x::TracedRNumber{<:Complex})
37+
return isnan(real(x)) | isnan(imag(x))
38+
end
39+
2640
function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}}
2741
return print(io, "TracedRNumber{", T, "}(", X.paths, ")")
2842
end

test/basic.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -371,26 +371,26 @@ end
371371

372372
@testset "Number and RArray" for a in [1.0f0, 1.0e0]
373373
typeof_a = typeof(a)
374-
_b = [2.0, 3.0, 4.0] .|> typeof_a
375-
_c = [2.0 3.0 4.0] .|> typeof_a
374+
_b = typeof_a.([2.0, 3.0, 4.0])
375+
_c = typeof_a.([2.0 3.0 4.0])
376376
b = Reactant.to_rarray(_b)
377377
c = Reactant.to_rarray(_c)
378-
378+
379379
# vcat test
380380
y = @jit vcat(a, b)
381381
@test y == vcat(a, _b)
382382
@test y isa ConcreteRArray{typeof_a,1}
383-
383+
384384
## vcat test - adjoint
385385
y1 = @jit vcat(a, c')
386386
@test y1 == vcat(a, _c')
387387
@test y1 isa ConcreteRArray{typeof_a,2}
388-
388+
389389
# hcat test
390390
z = @jit hcat(a, c)
391391
@test z == hcat(a, _c)
392392
@test z isa ConcreteRArray{typeof_a,2}
393-
393+
394394
## hcat test - adjoint
395395
z1 = @jit hcat(a, b')
396396
@test z1 == hcat(a, _b')
@@ -1028,3 +1028,19 @@ end
10281028
@test res[2] isa ConcreteRNumber{Float32}
10291029
end
10301030
end
1031+
1032+
@testset "isfinite" begin
1033+
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
1034+
@test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false]
1035+
1036+
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
1037+
@test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false]
1038+
end
1039+
1040+
@testset "isnan" begin
1041+
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
1042+
@test Reactant.@jit(isnan.(x)) == [false, true, false, false, true]
1043+
1044+
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
1045+
@test Reactant.@jit(isnan.(x)) == [false, true, false, false, true]
1046+
end

0 commit comments

Comments
 (0)