Skip to content

Commit ced80e0

Browse files
committed
Handle NoLogAbsDetJacobian in ComposedFunction
1 parent acd2a0c commit ced80e0

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/with_ladj.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,19 @@ with_logabsdet_jacobian(::F, ::T) where {F,T} = NoLogAbsDetJacobian{F,T}()
8686

8787
@static if VERSION >= v"1.6"
8888
function with_logabsdet_jacobian(f::Base.ComposedFunction, x)
89-
y_inner, ladj_inner = with_logabsdet_jacobian(f.inner, x)
90-
y, ladj_outer = with_logabsdet_jacobian(f.outer, y_inner)
91-
(y, ladj_inner + ladj_outer)
89+
y_ladj_inner = with_logabsdet_jacobian(f.inner, x)
90+
if y_ladj_inner isa NoLogAbsDetJacobian
91+
NoLogAbsDetJacobian{typeof(f),typeof(x)}()
92+
else
93+
y_inner, ladj_inner = y_ladj_inner
94+
y_ladj_outer = with_logabsdet_jacobian(f.outer, y_inner)
95+
if y_ladj_outer isa NoLogAbsDetJacobian
96+
NoLogAbsDetJacobian{typeof(f),typeof(x)}()
97+
else
98+
y, ladj_outer = y_ladj_outer
99+
(y, ladj_inner + ladj_outer)
100+
end
101+
end
92102
end
93103
end
94104

test/test_with_ladj.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ include("getjacobian.jl")
1414

1515
@testset "with_logabsdet_jacobian" begin
1616
@test with_logabsdet_jacobian(sum, rand(5)) == NoLogAbsDetJacobian{typeof(sum),Vector{Float64}}()
17+
@test with_logabsdet_jacobian(sum log, 5.0f0) == NoLogAbsDetJacobian{typeof(sum ∘ log),Float32}()
18+
@test with_logabsdet_jacobian(log sum, 5.0f0) == NoLogAbsDetJacobian{typeof(log ∘ sum),Float32}()
1719
@test_throws MethodError _, _ = with_logabsdet_jacobian(sum, rand(5))
1820

1921
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(foo), x)

0 commit comments

Comments
 (0)