Skip to content

Commit 0fe0f85

Browse files
authored
Avoiding Type Piracy in Lux (#246)
* Add more functions for tracked componentarrays * Add rrule for NamedTuple conversion * Add Functors compatibility * Add Optimisers * Add tests
1 parent fac804d commit 0fe0f85

12 files changed

+136
-22
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ jobs:
1818
- '1.6'
1919
- '1.8'
2020
- '1.9'
21+
- '1.10'
2122
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
22-
- '1.10.0-beta3'
2323
os:
2424
- ubuntu-latest
2525
arch:

Project.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <[email protected]>"]
4-
version = "0.15.10"
4+
version = "0.15.11"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -17,20 +17,24 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1717
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1818
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1919
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
20+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
2021
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2122
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2223
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2324
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
25+
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
2426
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2527

2628
[extensions]
2729
ComponentArraysAdaptExt = "Adapt"
2830
ComponentArraysConstructionBaseExt = "ConstructionBase"
2931
ComponentArraysGPUArraysExt = "GPUArrays"
32+
ComponentArraysOptimisersExt = "Optimisers"
3033
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
3134
ComponentArraysReverseDiffExt = "ReverseDiff"
3235
ComponentArraysSciMLBaseExt = "SciMLBase"
3336
ComponentArraysTrackerExt = "Tracker"
37+
ComponentArraysTruncatedStacktracesExt = "TruncatedStacktraces"
3438
ComponentArraysZygoteExt = "Zygote"
3539

3640
[compat]
@@ -41,13 +45,17 @@ ConstructionBase = "1"
4145
ForwardDiff = "0.10"
4246
Functors = "0.4.4"
4347
GPUArrays = "8, 9, 10"
48+
LinearAlgebra = "1"
49+
Optimisers = "0.3"
4450
PackageExtensionCompat = "1"
4551
RecursiveArrayTools = "2, 3"
4652
ReverseDiff = "1"
4753
SciMLBase = "1, 2"
48-
StaticArraysCore = "1"
4954
StaticArrayInterface = "1"
55+
StaticArraysCore = "1"
56+
Test = "1"
5057
Tracker = "0.2"
58+
TruncatedStacktraces = "1.4"
5159
Zygote = "0.6"
5260
julia = "1.6"
5361

@@ -56,9 +64,11 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
5664
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
5765
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5866
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
67+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
5968
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
6069
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
6170
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
6271
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6372
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
73+
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
6474
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

ext/ComponentArraysOptimisersExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module ComponentArraysOptimisersExt
2+
3+
using ComponentArrays, Optimisers
4+
5+
# Optimisers can handle componentarrays by default, but we can vectorize the entire
6+
# operation here instead of doing multiple smaller operations
7+
Optimisers.setup(opt::AbstractRule, ps::ComponentArray) = Optimisers.setup(opt, getdata(ps))
8+
9+
function Optimisers.update(tree, ps::ComponentArray, gs::ComponentArray)
10+
gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff
11+
tree, ps_new = Optimisers.update(tree, getdata(ps), gs_flat)
12+
return tree, ComponentArray(ps_new, getaxes(ps))
13+
end
14+
15+
function Optimisers.update!(tree::Optimisers.Leaf, ps::ComponentArray, gs::ComponentArray)
16+
gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff
17+
tree, ps_new = Optimisers.update!(tree, getdata(ps), gs_flat)
18+
return tree, ComponentArray(ps_new, getaxes(ps))
19+
end
20+
21+
end

ext/ComponentArraysReverseDiffExt.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module ComponentArraysReverseDiffExt
22

33
using ComponentArrays, ReverseDiff
44

5-
const TrackedComponentArray{V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}
5+
const TrackedComponentArray{V,D,N,DA,A,Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}
66

77
maybe_tracked_array(val::AbstractArray, der, tape, inds, origin) = ReverseDiff.TrackedArray(val, der, tape)
88
function maybe_tracked_array(val::Real, der, tape, inds, origin::AbstractVector)
@@ -12,10 +12,10 @@ function maybe_tracked_array(val::Real, der, tape, inds, origin::AbstractVector)
1212
end
1313

1414
for f in [:getindex, :view]
15-
@eval function Base.$f(tca::TrackedComponentArray, inds::Union{Symbol, Val}...)
16-
val = $f(ReverseDiff.value(tca), inds...)
17-
der = Base.maybeview(ReverseDiff.deriv(tca), inds...)
18-
t = ReverseDiff.tape(tca)
15+
@eval function Base.$f(tca::TrackedComponentArray, inds::Union{Symbol,Val}...)
16+
val = $f(ReverseDiff.value(tca), inds...)
17+
der = Base.maybeview(ReverseDiff.deriv(tca), inds...)
18+
t = ReverseDiff.tape(tca)
1919
return maybe_tracked_array(val, der, t, inds, tca)
2020
end
2121
end
@@ -31,4 +31,17 @@ function Base.getproperty(tca::TrackedComponentArray, s::Symbol)
3131
end
3232
end
3333

34+
function Base.propertynames(::TrackedComponentArray{V,D,N,DA,A,Tuple{Ax}}) where {V,D,N,DA,A,Ax<:ComponentArrays.AbstractAxis}
35+
return propertynames(ComponentArrays.indexmap(Ax))
36+
end
37+
38+
function Base.NamedTuple(tca::TrackedComponentArray)
39+
props = propertynames(tca)
40+
return NamedTuple{props}(getproperty(tca, p) for p in props)
41+
end
42+
43+
@inline ComponentArrays.__value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x)
44+
@inline ComponentArrays.__value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x)
45+
@inline ComponentArrays.__value(x::TrackedComponentArray) = ComponentArray(ComponentArrays.__value(getdata(x)), getaxes(x))
46+
3447
end
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module ComponentArraysTruncatedStacktracesExt
2+
3+
using ComponentArrays
4+
import TruncatedStacktraces: @truncate_stacktrace
5+
6+
@truncate_stacktrace ComponentArray 1
7+
8+
end

src/ComponentArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ include("compat/chainrulescore.jl")
5252
include("compat/static_arrays.jl")
5353
export @static_unpack
5454

55+
include("compat/functors.jl")
56+
5557
import PackageExtensionCompat: @require_extensions
5658
function __init__()
5759
@require_extensions

src/compat/chainrulescore.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,26 @@ end
4141
# Prevent double projection
4242
(p::ChainRulesCore.ProjectTo{ComponentArray})(dx::ComponentArray) = dx
4343

44-
function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{A, <:NamedTuple}) where {A}
44+
function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{A,<:NamedTuple}) where {A}
4545
nt = Functors.fmap(ChainRulesCore.backing, ChainRulesCore.backing(t))
4646
return ComponentArray(nt)
4747
end
48+
49+
function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray}
50+
y = CA(nt)
51+
52+
function ∇NamedTupleToComponentArray::AbstractArray)
53+
if length(Δ) == length(y)
54+
return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y)))
55+
end
56+
error("Got pullback input of shape $(size(Δ)) & type $(typeof(Δ)) for output " *
57+
"of shape $(size(y)) & type $(typeof(y))")
58+
return nothing
59+
end
60+
61+
function ∇NamedTupleToComponentArray::ComponentArray)
62+
return ChainRulesCore.NoTangent(), NamedTuple(Δ)
63+
end
64+
65+
return y, ∇NamedTupleToComponentArray
66+
end

src/compat/functors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Functors.functor(::Type{<:ComponentVector}, c) = NamedTuple(c), ComponentVector

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,5 @@ recursive_eltype(x::AbstractArray{<:Any}) = isempty(x) ? Base.Bottom : mapreduce
5050
recursive_eltype(x::Dict) = isempty(x) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, values(x))
5151
recursive_eltype(::AbstractArray{T,N}) where {T<:Number, N} = T
5252
recursive_eltype(x) = typeof(x)
53+
54+
@inline __value(x) = x

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
99
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
12+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1213
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1314
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1415
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1516
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
17+
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
1618
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1719
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/autodiff_tests.jl

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote
2-
2+
using Optimisers
33
using Test
44

55
F(a, x) = sum(abs2, a) * x^3
@@ -38,6 +38,22 @@ truth = ComponentArray(a = [32, 48], x = 156)
3838
@test out isa Vector{<:ForwardDiff.Dual}
3939
end
4040

41+
@testset "Optimisers Update" begin
42+
ca_ = deepcopy(ca)
43+
opt_st = Optimisers.setup(Adam(0.01), ca_)
44+
gs_zyg = only(Zygote.gradient(F_idx_val, ca_))
45+
@test !(last(Optimisers.update(opt_st, ca_, gs_zyg)) ca)
46+
Optimisers.update!(opt_st, ca_, gs_zyg)
47+
@test !(ca_ ca)
48+
49+
ca_ = deepcopy(ca)
50+
opt_st = Optimisers.setup(Adam(0.01), ca_)
51+
gs_rdiff = ReverseDiff.gradient(F_idx_val, ca_)
52+
@test !(last(Optimisers.update(opt_st, ca_, gs_rdiff)) ca)
53+
Optimisers.update!(opt_st, ca_, gs_rdiff)
54+
@test !(ca_ ca)
55+
end
56+
4157
@testset "Projection" begin
4258
gs_ca = Zygote.gradient(sum, ca)[1]
4359

@@ -76,18 +92,28 @@ end
7692
@test ∂r ∂r_ca
7793
end
7894

95+
function F_prop(x)
96+
@assert propertynames(x) == (:x, :y)
97+
return sum(abs2, x.x .- x.y)
98+
end
99+
100+
@testset "Preserve Properties" begin
101+
x = ComponentArray(; x = [1.0, 5.0], y = [3.0, 4.0])
79102

80-
# # This is commented out because the gradient operation itself is broken due to Zygote's inability
81-
# # to support mutation and ComponentArray's use of mutation for construction from a NamedTuple.
82-
# # It would be nice to support this eventually, so I'll just leave this commented (because @test_broken
83-
# # wouldn't work here because the error happens before the test)
84-
# @testset "Issues" begin
85-
# function mysum(x::AbstractVector)
86-
# y = ComponentVector(x=x)
87-
# return sum(y)
88-
# end
103+
gs_z = only(Zygote.gradient(F_prop, x))
104+
gs_rdiff = ReverseDiff.gradient(F_prop, x)
89105

90-
# Δ = Zygote.gradient(mysum, rand(10))
106+
@test gs_z gs_rdiff
107+
end
108+
109+
@testset "Issues" begin
110+
function mysum(x::AbstractVector)
111+
y = ComponentVector(x=x)
112+
z = ComponentVector(; z = x .^ 2)
113+
return sum(y) + sum(abs2, z)
114+
end
91115

92-
# @test Δ isa Vector{Float64}
93-
# end
116+
Δ = only(Zygote.gradient(mysum, rand(10)))
117+
118+
@test Δ isa AbstractVector{Float64}
119+
end

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using StaticArrays
88
using OffsetArrays
99
using Test
1010
using Unitful
11+
using Functors
12+
import TruncatedStacktraces # This is loaded just to trigger the extension package
1113

1214

1315
## Test setup
@@ -690,6 +692,14 @@ end
690692
@test_throws ArgumentError axpby!(2, x, 3, y)
691693
end
692694

695+
@testset "Functors" begin
696+
for carray in (ca, ca_Float32, ca_MVector, ca_SVector, ca_composed, ca2, caa)
697+
θ, re = Functors.functor(carray)
698+
@test θ isa NamedTuple
699+
@test re(θ) == carray
700+
end
701+
end
702+
693703
@testset "Autodiff" begin
694704
include("autodiff_tests.jl")
695705
end

0 commit comments

Comments
 (0)