Skip to content

Commit 7ab5c4e

Browse files
KristofferCoschulz
authored andcommitted
make ChainRulesCore dependency into an extension
1 parent 41ca2a3 commit 7ab5c4e

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,21 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
99

10+
[weakdeps]
11+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
12+
13+
[extensions]
14+
ChainRulesCoreExt = "ChainRulesCore"
15+
1016
[compat]
1117
ChainRulesCore = "1"
1218
julia = "1"
1319

1420
[extras]
21+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1522
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
1623
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1724
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1825

1926
[targets]
20-
test = ["ChainRulesTestUtils", "Documenter", "ForwardDiff"]
27+
test = ["ChainRulesCore", "ChainRulesTestUtils", "Documenter", "ForwardDiff"]

ext/ChainRulesCoreExt.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
module ChainRulesCoreExt
2+
3+
using ChainRulesCore
4+
5+
import ChangesOfVariables: _with_ladj_on_mapped
6+
7+
# Need to use a type for this, type inference fails when using a pullback
8+
# closure over YLT in the rrule, resulting in bad performance:
9+
struct WithLadjOnMappedPullback{YLT} <: Function end
10+
function (::WithLadjOnMappedPullback{YLT})(thunked_ΔΩ) where YLT
11+
ys, ladj = unthunk(thunked_ΔΩ)
12+
return NoTangent(), NoTangent(), map(y -> Tangent{YLT}(y, ladj), ys)
13+
end
14+
15+
function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
16+
YLT = eltype(y_with_ladj)
17+
return _with_ladj_on_mapped(map_or_bc, y_with_ladj), WithLadjOnMappedPullback{YLT}()
18+
end
19+
20+
end

src/ChangesOfVariables.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ transformations).
99
"""
1010
module ChangesOfVariables
1111

12-
using ChainRulesCore
1312
using LinearAlgebra
1413
using Test
1514

1615
include("with_ladj.jl")
1716
include("test.jl")
17+
if !isdefined(Base, :get_extension)
18+
include("../ext/ChainRulesCoreExt.jl")
19+
end
1820

1921
end # module

src/with_ladj.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,6 @@ function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(
114114
end
115115

116116

117-
# Need to use a type for this, type inference fails when using a pullback
118-
# closure over YLT in the rrule, resulting in bad performance:
119-
struct WithLadjOnMappedPullback{YLT} <: Function end
120-
function (::WithLadjOnMappedPullback{YLT})(thunked_ΔΩ) where YLT
121-
ys, ladj = unthunk(thunked_ΔΩ)
122-
return NoTangent(), NoTangent(), map(y -> Tangent{YLT}(y, ladj), ys)
123-
end
124-
125-
function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
126-
YLT = eltype(y_with_ladj)
127-
return _with_ladj_on_mapped(map_or_bc, y_with_ladj), WithLadjOnMappedPullback{YLT}()
128-
end
129-
130117
function with_logabsdet_jacobian(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}, X)
131118
map_or_bc = mapped_f.f
132119
f = mapped_f.x

0 commit comments

Comments
 (0)