Skip to content

Commit 68bc950

Browse files
committed
Add ProjectTo for CA
1 parent efff2cd commit 68bc950

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <[email protected]>"]
4-
version = "0.13.9"
4+
version = "0.13.10"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/compat/chainrulescore.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol, Val})
1+
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol,Val})
22
function getproperty_adjoint(Δ)
33
zero_x = zero(similar(x, eltype(Δ)))
44
setproperty!(zero_x, s, Δ)
@@ -8,6 +8,12 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union
88
return getproperty(x, s), getproperty_adjoint
99
end
1010

11-
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x)))
11+
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x)))
1212

13-
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent())
13+
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent())
14+
15+
function ChainRulesCore.ProjectTo(ca::ComponentArray)
16+
return ChainRulesCore.ProjectTo{ComponentArray}(; project=ChainRulesCore.ProjectTo(getdata(ca)), axes=getaxes(ca))
17+
end
18+
19+
(p::ChainRulesCore.ProjectTo{ComponentArray})(dx::AbstractArray) = ComponentArray(p.project(dx), p.axes)

test/autodiff_tests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ truth = ComponentArray(a = [32, 48], x = 156)
5050
@test out isa Vector{<:ForwardDiff.Dual}
5151
end
5252

53+
@testset "Projection" begin
54+
gs_ca = Zygote.gradient(sum, ca)[1]
55+
56+
@test gs_ca isa ComponentArray
57+
end
58+
5359

5460
# # This is commented out because the gradient operation itself is broken due to Zygote's inability
5561
# # to support mutation and ComponentArray's use of mutation for construction from a NamedTuple.

0 commit comments

Comments
 (0)