Skip to content

Commit f522a8f

Browse files
authored
Handle ChainRulesCore via Pkg extension (#206)
1 parent 50ace9c commit f522a8f

File tree

4 files changed

+22
-6
lines changed

4 files changed

+22
-6
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
name = "LinearMaps"
22
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
3-
version = "3.10.0"
3+
version = "3.10.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
99
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1010

11+
[weakdeps]
12+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
13+
14+
[extensions]
15+
LinearMapsChainRulesCoreExt = "ChainRulesCore"
16+
1117
[compat]
1218
ChainRulesCore = "1"
1319
julia = "1.6"

src/chainrules.jl renamed to ext/LinearMapsChainRulesCoreExt.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
module LinearMapsChainRulesCoreExt
2+
3+
using ChainRulesCore: unthunk, NoTangent, @thunk, @not_implemented
4+
import ChainRulesCore: rrule
5+
6+
using LinearMaps
7+
18
function rrule(::typeof(*), A::LinearMap, x::AbstractVector)
29
y = A*x
310
function pullback(dy)
@@ -15,3 +22,6 @@ function rrule(A::LinearMap, x::AbstractVector)
1522
end
1623
return y, pullback
1724
end
25+
26+
27+
end # module ChainRulesCore

src/LinearMaps.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@ using SparseArrays
1010

1111
import Statistics: mean
1212

13-
using ChainRulesCore: unthunk, NoTangent, @thunk, @not_implemented
14-
import ChainRulesCore: rrule
15-
1613
using Base: require_one_based_indexing
1714

1815
abstract type LinearMap{T} end
@@ -354,7 +351,6 @@ include("conversion.jl") # conversion of linear maps to matrices
354351
include("show.jl") # show methods for LinearMap objects
355352
include("getindex.jl") # getindex functionality
356353
include("inversemap.jl")
357-
include("chainrules.jl") # AD rules through ChainRulesCore
358354

359355
"""
360356
LinearMap(A::LinearMap; kwargs...)::WrappedMap
@@ -424,4 +420,8 @@ LinearMap(A::MapOrVecOrMat, dims::Dims{2}; offset::Dims{2}) =
424420
LinearMap{T}(A::MapOrVecOrMat; kwargs...) where {T} = WrappedMap{T}(A; kwargs...)
425421
LinearMap{T}(f, args...; kwargs...) where {T} = FunctionMap{T}(f, args...; kwargs...)
426422

423+
@static if !isdefined(Base, :get_extension)
424+
include("../ext/LinearMapsChainRulesCoreExt.jl")
425+
end
426+
427427
end # module

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test, LinearMaps, Aqua
22

33
@testset "code quality" begin
4-
Aqua.test_all(LinearMaps)
4+
Aqua.test_all(LinearMaps, project_toml_formatting=VERSIONv"1.7")
55
end
66

77
include("linearmaps.jl")

0 commit comments

Comments
 (0)