Skip to content

Commit bd10c6c

Browse files
mlouboutdkarraschJeffFessler
authored
AD support via chainrules (#177)
Co-authored-by: Daniel Karrasch <[email protected]> Co-authored-by: Jeff Fessler <[email protected]>
1 parent e19a7cc commit bd10c6c

File tree

7 files changed

+56
-1
lines changed

7 files changed

+56
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
name = "LinearMaps"
22
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
3-
version = "3.8.0"
3+
version = "3.9.0"
44

55
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
89
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
910

1011
[compat]
12+
ChainRulesCore = "1"
1113
julia = "1.6"

docs/src/history.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Version history
22

3+
## What's new in v3.9
4+
5+
* The application of `LinearMap`s to vectors, i.e., `(A,x) -> A*x = A(x)`, is now
6+
differentiable w.r.t. the input `x` for integration with machine learning frameworks
7+
such as [`Flux.jl`](https://fluxml.ai/Flux.jl/stable/). The reverse differentiation rule
8+
makes `A::LinearMap` usable as a static, i.e., non-trainable, layer in a network, and
9+
requires the adjoint `A'` of `A` to be defined.
10+
311
## What's new in v3.8
412

513
* A new map called [`InverseMap`](@ref) is introduced. Letting an `InverseMap` act on a

src/LinearMaps.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ using SparseArrays
1111

1212
import Statistics: mean
1313

14+
using ChainRulesCore: unthunk, NoTangent, @thunk
15+
import ChainRulesCore: rrule
16+
1417
using Base: require_one_based_indexing
1518

1619
abstract type LinearMap{T} end
@@ -347,6 +350,7 @@ include("conversion.jl") # conversion of linear maps to matrices
347350
include("show.jl") # show methods for LinearMap objects
348351
include("getindex.jl") # getindex functionality
349352
include("inversemap.jl")
353+
include("chainrules.jl") # AD rules through ChainRulesCore
350354

351355
"""
352356
LinearMap(A::LinearMap; kwargs...)::WrappedMap

src/chainrules.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Reverse-mode differentiation rule (ChainRulesCore `rrule`) for the product `A * x`
# of a `LinearMap` with a vector.
#
# Returns the primal result `y = A*x` together with a `pullback` closure. Given the
# cotangent `dy` of the output, `pullback` returns one entry per primal argument, in
# the order (`*`, `A`, `x`):
#   - `NoTangent()` for the function `*`,
#   - `NoTangent()` for the map `A`, which is treated as non-differentiable here,
#   - `A' * DY` for the input `x`, wrapped in `@thunk` so that it is only computed
#     when the caller actually unthunks it.
# The pullback applies `A'`, so the adjoint of `A` must be available.
function rrule(::typeof(*), A::LinearMap, x::AbstractVector)
2+
y = A*x
3+
function pullback(dy)
4+
# `dy` may arrive as a thunk; materialize it before applying the adjoint map.
DY = unthunk(dy)
5+
# Because A is an abstract map, the product is only differentiable w.r.t. the input `x`
6+
return NoTangent(), NoTangent(), @thunk(A' * DY)
7+
end
8+
return y, pullback
9+
end
10+
11+
# Reverse-mode differentiation rule (ChainRulesCore `rrule`) for calling a `LinearMap`
# as a function on a vector, i.e. `A(x)`.
#
# Returns the primal result `y = A*x` together with a `pullback` closure. Given the
# cotangent `dy` of the output, `pullback` returns one entry per primal argument, in
# the order (`A`, `x`):
#   - `NoTangent()` for the callable `A` itself, which is treated as
#     non-differentiable here,
#   - `A' * DY` for the input `x`, wrapped in `@thunk` so that it is only computed
#     when the caller actually unthunks it.
# The pullback applies `A'`, so the adjoint of `A` must be available.
function rrule(A::LinearMap, x::AbstractVector)
12+
y = A*x
13+
function pullback(dy)
14+
# `dy` may arrive as a thunk; materialize it before applying the adjoint map.
DY = unthunk(dy)
15+
# Because A is an abstract map, the product is only differentiable w.r.t. the input `x`
16+
return NoTangent(), @thunk(A' * DY)
17+
end
18+
return y, pullback
19+
end

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
4+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
5+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
46
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
57
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
68
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -12,5 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1214
[compat]
1315
Aqua = "0.5"
1416
BlockArrays = "0.16"
17+
ChainRulesCore = "1"
18+
ChainRulesTestUtils = "1.9"
1519
Quaternions = "0.5"
1620
julia = "1.6"

test/rrules.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using Test, LinearMaps, ChainRulesTestUtils
2+
using ChainRulesCore: NoTangent
3+
4+
@testset "AD rules" begin
5+
x = randn(10)
6+
for A in (
7+
LinearMap(rand(10, 10)),
8+
LinearMap(cumsum, reverse∘cumsum∘reverse, 10),
9+
LinearMap((y, x) -> cumsum!(y, x), (y, x) -> reverse!(cumsum!(y, reverse!(copyto!(y, x)))), 10)
10+
)
11+
test_rrule(*, A ⊢ NoTangent(), x)
12+
test_rrule(A ⊢ NoTangent(), x)
13+
test_rrule(*, A' ⊢ NoTangent(), x)
14+
test_rrule(A' ⊢ NoTangent(), x)
15+
end
16+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,5 @@ include("embeddedmap.jl")
3939
include("getindex.jl")
4040

4141
include("inversemap.jl")
42+
43+
include("rrules.jl")

0 commit comments

Comments
 (0)