Skip to content

Commit 3867248

Browse files
committed
AD support via chainrules
1 parent aa54eb3 commit 3867248

File tree

6 files changed

+48
-0
lines changed

6 files changed

+48
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
33
version = "3.6.1"
44

55
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
89
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
910

1011
[compat]
12+
ChainRulesCore = "1"
1113
julia = "1.6"

src/LinearMaps.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ using SparseArrays
1010

1111
import Statistics: mean
1212

13+
import ChainRulesCore: rrule, frule, unthunk, @thunk, NoTangent
14+
1315
using Base: require_one_based_indexing
1416

1517
abstract type LinearMap{T} end
@@ -268,6 +270,7 @@ include("kronecker.jl") # Kronecker product of linear maps
268270
include("fillmap.jl") # linear maps representing constantly filled matrices
269271
include("conversion.jl") # conversion of linear maps to matrices
270272
include("show.jl") # show methods for LinearMap objects
273+
include("chainrules.jl") # AD rules through ChainRulesCore
271274

272275
"""
273276
LinearMap(A::LinearMap; kwargs...)::WrappedMap

src/chainrules.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
function rrule(::typeof(*), A::LinearMap, x::AbstractVecOrMat)
2+
y = A*x
3+
function pullback(dy)
4+
DY = unthunk(dy)
5+
# Because A is an abstract map, the product is only differentiable w.r.t the input
6+
return NoTangent(), NoTangent(), @thunk(A' * DY)
7+
end
8+
return y, pullback
9+
end
10+
11+
function rrule(A::LinearMap, x::AbstractVecOrMat)
12+
y = A(x)
13+
function pullback(dy)
14+
DY = unthunk(dy)
15+
# Because A is an abstract map, the product is only differentiable w.r.t the input
16+
return NoTangent(), @thunk(A' * DY)
17+
end
18+
return y, pullback
19+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
4+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
45
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
67
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"

test/rrules.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using Flux
2+
3+
@testset "AD rules" begin
4+
A = LinearMap(rand(10, 10))
5+
x = randn(10)
6+
g1 = A'*A*x
7+
# Multiplication rule
8+
g2 = gradient(x -> .5*norm(A*x)^2, x)
9+
@test g1 g2[1]
10+
# Call rule
11+
g3 = gradient(x -> .5*norm(A(x))^2, x)
12+
@test g1 g3[1]
13+
14+
g1 = A*A'*x
15+
# Multiplication rule
16+
g2 = gradient(x -> .5*norm(A'*x)^2, x)
17+
@test g1 g2[1]
18+
# Call rule
19+
g3 = gradient(x -> .5*norm(A'(x))^2, x)
20+
@test g1 g3[1]
21+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,5 @@ include("left.jl")
3333
include("fillmap.jl")
3434

3535
include("nontradaxes.jl")
36+
37+
include("rrule.jl")

0 commit comments

Comments
 (0)