Skip to content

Commit 02c95fe

Browse files
authored
Add ProjectTo for CA (#202)
* Add ProjectTo for CA * Fix tests: gradients for Int --> Float
1 parent 17faef8 commit 02c95fe

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <[email protected]>"]
4-
version = "0.13.9"
4+
version = "0.13.10"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/compat/chainrulescore.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol, Val})
1+
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol,Val})
22
function getproperty_adjoint(Δ)
33
zero_x = zero(similar(x, eltype(Δ)))
44
setproperty!(zero_x, s, Δ)
@@ -8,6 +8,12 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union
88
return getproperty(x, s), getproperty_adjoint
99
end
1010

11-
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x)))
11+
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x)))
1212

13-
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent())
13+
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent())
14+
15+
function ChainRulesCore.ProjectTo(ca::ComponentArray)
16+
return ChainRulesCore.ProjectTo{ComponentArray}(; project=ChainRulesCore.ProjectTo(getdata(ca)), axes=getaxes(ca))
17+
end
18+
19+
(p::ChainRulesCore.ProjectTo{ComponentArray})(dx::AbstractArray) = ComponentArray(p.project(dx), p.axes)

test/autodiff_tests.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,22 @@ truth = ComponentArray(a = [32, 48], x = 156)
3232
@test zygote_full truth
3333
end
3434

35-
# Not sure why this doesn't work in v1.2, but I don't want to drop the tests for that just
36-
# for this to work
37-
if VERSION v"1.6"
38-
@test ComponentArray(x=4,) == Zygote.gradient(ComponentArray(x=2,)) do c
39-
(;c...,).x^2
40-
end[1]
41-
else
42-
@test_skip ComponentArray(x=4,) == Zygote.gradient(ComponentArray(x=2,)) do c
43-
(;c...,).x^2
44-
end[1]
45-
end
35+
@test ComponentArray(x=4.0,) Zygote.gradient(ComponentArray(x=2,)) do c
36+
(;c...,).x^2
37+
end[1]
4638

4739
# Issue #148
4840
ps = ComponentArray(;bias = rand(4))
4941
out = Zygote.gradient(x -> sum(x.^3 .+ ps.bias), Zygote.seed(rand(4),Val(12)))[1]
5042
@test out isa Vector{<:ForwardDiff.Dual}
5143
end
5244

45+
@testset "Projection" begin
46+
gs_ca = Zygote.gradient(sum, ca)[1]
47+
48+
@test gs_ca isa ComponentArray
49+
end
50+
5351

5452
# # This is commented out because the gradient operation itself is broken due to Zygote's inability
5553
# # to support mutation and ComponentArray's use of mutation for construction from a NamedTuple.

0 commit comments

Comments
 (0)