Skip to content

Commit 8c6b304

Browse files
authored
Merge pull request #12 from JuliaReinforcementLearning/normalization
Normalizer Wrapper
2 parents 9095619 + 9e26766 commit 8c6b304

File tree

8 files changed

+259
-10
lines changed

8 files changed

+259
-10
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.0"
55
[deps]
66
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
77
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
8+
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
89
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
910
StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
1011

@@ -13,6 +14,7 @@ CircularArrayBuffers = "0.1"
1314
MacroTools = "0.5"
1415
StackViews = "0.1"
1516
julia = "1.6"
17+
OnlineStats = "1.0"
1618

1719
[extras]
1820
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

src/ReinforcementLearningTrajectories.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ const RLTrajectories = ReinforcementLearningTrajectories
44
export RLTrajectories
55

66
include("patch.jl")
7-
87
include("traces.jl")
98
include("samplers.jl")
109
include("controllers.jl")
1110
include("trajectory.jl")
11+
include("normalization.jl")
1212
include("common/common.jl")
1313

1414
end

src/normalization.jl

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import OnlineStats: OnlineStats, Group, Moments, fit!, OnlineStat, Weight, EqualWeight, mean, std
2+
export scalar_normalizer, array_normalizer, NormalizedTraces, Normalizer
3+
import MacroTools.@forward
4+
5+
"""
6+
Normalizer(::OnlineStat)
7+
8+
Wraps an OnlineStat to be used by a [`NormalizedTraces`](@ref).
9+
"""
10+
struct Normalizer{OS<:OnlineStat}
11+
os::OS
12+
end
13+
14+
@forward Normalizer.os OnlineStats.mean, OnlineStats.std, Base.iterate, normalize, Base.length
15+
16+
#Treats last dim as batch dim
17+
function OnlineStats.fit!(n::Normalizer, data::AbstractArray)
18+
for d in eachslice(data, dims = ndims(data))
19+
fit!(n.os, vec(d))
20+
end
21+
n
22+
end
23+
24+
function OnlineStats.fit!(n::Normalizer{<:Group}, y::AbstractVector)
25+
fit!(n.os, y)
26+
n
27+
end
28+
29+
function OnlineStats.fit!(n::Normalizer, y)
30+
for yi in y
31+
fit!(n.os, vec(yi))
32+
end
33+
n
34+
end
35+
36+
function OnlineStats.fit!(n::Normalizer{<:Moments}, y::AbstractVector{<:Number})
37+
for yi in y
38+
fit!(n.os, yi)
39+
end
40+
n
41+
end
42+
43+
function OnlineStats.fit!(n::Normalizer, data::Number)
44+
fit!(n.os, data)
45+
n
46+
end
47+
48+
"""
49+
normalize(os::Moments, x)
50+
51+
Given an Moments estimate of the elements of x, a vector of scalar traces,
52+
normalizes x elementwise to zero mean, and unit variance.
53+
"""
54+
function normalize(os::Moments, x)
55+
T = eltype(x)
56+
m, s = T(mean(os)), T(std(os))
57+
return (x .- m) ./ s
58+
end
59+
60+
"""
61+
normalize(os::Group{<:AbstractVector{<:Moments}}, x)
62+
63+
Given an os::Group{<:Tuple{Moments}}, that is, a multivariate estimator of the moments of
64+
each element of x,
65+
normalizes each element of x to zero mean, and unit variance. Treats the last dimension as
66+
a batch dimension if `ndims(x) >= 2`.
67+
"""
68+
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector)
69+
T = eltype(x)
70+
m = [T(mean(stat)) for stat in os]
71+
s = [T(std(stat)) for stat in os]
72+
return (x .- m) ./ s
73+
end
74+
75+
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractArray)
76+
xn = similar(x)
77+
for (i, slice) in enumerate(eachslice(x, dims = ndims(x)))
78+
xn[repeat([:], ndims(x)-1)..., i] .= reshape(normalize(os, vec(slice)), size(x)[1:end-1]...)
79+
end
80+
return xn
81+
end
82+
83+
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector{<:AbstractArray})
84+
xn = similar(x)
85+
for (i,el) in enumerate(x)
86+
xn[i] = normalize(os, vec(el))
87+
end
88+
return xn
89+
end
90+
91+
"""
92+
scalar_normalizer(;weights = OnlineStats.EqualWeight())
93+
94+
Returns preconfigured normalizer for scalar traces such as rewards. By default, all samples
95+
have equal weights in the computation of the moments.
96+
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/)
97+
to use variants such as exponential weights to favor the most recent observations.
98+
"""
99+
scalar_normalizer(; weight::Weight = EqualWeight()) = Normalizer(Moments(weight = weight))
100+
101+
"""
102+
array_normalizer(size::Tuple{Int}; weights = OnlineStats.EqualWeight())
103+
104+
Returns preconfigured normalizer for array traces such as vector or matrix states.
105+
`size` is a tuple containing the dimension sizes of a state. E.g. `(10,)` for a 10-elements
106+
vector, or `(252,252)` for a square image.
107+
By default, all samples have equal weights in the computation of the moments.
108+
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/)
109+
to use variants such as exponential weights to favor the most recent observations.
110+
"""
111+
array_normalizer(size::NTuple{N,Int}; weight::Weight = EqualWeight()) where N = Normalizer(Group([Moments(weight = weight) for _ in 1:prod(size)]))
112+
113+
"""
114+
NormalizedTraces(traces::AbstractTraces, normalizers::NamedTuple)
115+
NormalizedTraces(traces::AbstractTraces; trace_normalizer_pairs...)
116+
117+
Wraps an [`AbstractTraces`](@ref) and a `NamedTuple` of `Symbol` => [`Normalizer`](@ref)
118+
pairs.
119+
When pushing new elements to the traces, a `NormalizedTraces` will first update a running
120+
estimate of the moments of traces present in the keys of `normalizers`.
121+
When sampling a normalized trace, it will first normalize the samples to zero mean and unit
122+
variance. Traces that do not have a normalizer are sample as usual.
123+
124+
Note that when used in combination with [`Episodes`](@ref), `NormalizedTraces` must wrap
125+
the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise
126+
the running estimate will reset after each episode.
127+
128+
When used with a MultiplexTraces, the normalizer used with for one symbol (e.g. :state) will
129+
be the same used for the other one (e.g. :next_state).
130+
131+
Preconfigured normalizers are provided for scalar (see [`scalar_normalizer`](@ref)) and
132+
arrays (see [`array_normalizer`](@ref)).
133+
134+
# Examples
135+
```
136+
t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
137+
nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
138+
# :next_state will also be normalized.
139+
traj = Trajectory(
140+
container = nt,
141+
sampler = BatchSampler(10)
142+
)
143+
```
144+
"""
145+
struct NormalizedTraces{names, TT, T <: AbstractTraces{names, TT}, normnames, N} <: AbstractTraces{names, TT}
146+
traces::T
147+
normalizers::NamedTuple{normnames, N}
148+
end
149+
150+
function NormalizedTraces(traces::AbstractTraces{names, TT}; trace_normalizer_pairs...) where names where TT
151+
for key in keys(trace_normalizer_pairs)
152+
@assert key in keys(traces) "Traces do not have key $key, valid keys are $(keys(traces))."
153+
end
154+
nt = (; trace_normalizer_pairs...)
155+
for trace in traces.traces
156+
#check if all traces of MultiplexTraces are in pairs
157+
if trace isa MultiplexTraces
158+
if length(intersect(keys(trace), keys(trace_normalizer_pairs))) in [0, length(keys(trace))] #check if none or all keys are in normalizers
159+
continue
160+
else #if not then one is missing
161+
present_key = only(intersect(keys(trace), keys(trace_normalizer_pairs)))
162+
absent_key = only(setdiff(keys(trace), keys(trace_normalizer_pairs)))
163+
nt = merge(nt, (;(absent_key => nt[present_key],)...)) #assign the same normalizer
164+
end
165+
end
166+
end
167+
NormalizedTraces{names, TT, typeof(traces), keys(nt), typeof(values(nt))}(traces, nt)
168+
end
169+
170+
function Base.show(io::IO, ::MIME"text/plain", t::NormalizedTraces{names,T}) where {names,T}
171+
s = nameof(typeof(t))
172+
println(io, "$s with $(length(names)) entries:")
173+
for n in names
174+
print(io, " :$n => $(summary(t[n]))")
175+
if n in keys(t.normalizers)
176+
println(io, " => Normalized")
177+
else
178+
println(io, "")
179+
end
180+
end
181+
end
182+
183+
@forward NormalizedTraces.traces Base.length, Base.size, Base.lastindex, Base.firstindex, Base.getindex, Base.view, Base.pop!, Base.popfirst!, Base.empty!, Base.parent
184+
185+
for f in (:push!, :pushfirst!, :append!, :prepend!)
186+
@eval function Base.$f(nt::NormalizedTraces, x::NamedTuple)
187+
for key in intersect(keys(nt.normalizers), keys(x))
188+
fit!(nt.normalizers[key], x[key])
189+
end
190+
$f(nt.traces, x)
191+
end
192+
end
193+
194+
function sample(s::BatchSampler, nt::NormalizedTraces, names)
195+
inds = rand(s.rng, 1:length(nt), s.batch_size)
196+
maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data
197+
NamedTuple{names}(s.transformer(maybe_normalize(nt[x][inds], x) for x in names))
198+
end

src/rendering.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function inner_convert(::Type{Term.AbstractRenderable}, x; style="gray1", width=
1919
end
2020

2121
Base.convert(T::Type{Term.AbstractRenderable}, t::Trace{<:AbstractArray}; kw...) = convert(T, Trace(collect(eachslice(t.x, dims=ndims(t.x)))); kw..., type=typeof(t), subtitle="size: $(size(t.x))")
22-
22+
Base.convert(T::Type{Term.AbstractRenderable}, t::NormalizedTrace; kw...) = convert(T, t.trace; kw..., type = typeof(t))
2323
function Base.convert(
2424
::Type{Term.AbstractRenderable},
2525
t::Trace{<:AbstractVector};

src/samplers.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ end
1212

1313
"""
1414
BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG, transformer=identity)
15+
BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG, transformer=identity)
1516
16-
Uniformly sample a batch of examples for each trace specified in `names`. By default, all the traces will be sampled.
17+
Uniformly sample a batch of examples for each trace specified in `names`.
18+
By default, all the traces will be sampled.
1719
1820
See also [`sample`](@ref).
1921
"""
@@ -27,18 +29,20 @@ sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, n
2729

2830
function sample(s::BatchSampler, t::AbstractTraces, names)
2931
inds = rand(s.rng, 1:length(t), s.batch_size)
30-
NamedTuple{names}(s.transformer(t[x][inds]) for x in names)
32+
NamedTuple{names}(s.transformer(t[x][inds] for x in names))
3133
end
3234

3335
"""
3436
MetaSampler(::NamedTuple)
3537
36-
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a batch from each sampler.
38+
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a
39+
batch from each sampler.
3740
Used internally for algorithms that sample multiple times per epoch.
3841
3942
# Example
40-
43+
```
4144
MetaSampler(policy = BatchSampler(10), critic = BatchSampler(100))
45+
```
4246
"""
4347
struct MetaSampler{names,T} <: AbstractSampler
4448
samplers::NamedTuple{names,T}
@@ -52,11 +56,14 @@ sample(s::MetaSampler, t) = map(x -> sample(x, t), s.samplers)
5256
"""
5357
MultiBatchSampler(sampler, n)
5458
55-
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination with MetaSampler to allow different sampling rates between samplers.
59+
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination
60+
with MetaSampler to allow different sampling rates between samplers.
5661
5762
# Example
58-
59-
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3), critic = MultiBatchSampler(BatchSampler(100), 5))
63+
```
64+
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
65+
critic = MultiBatchSampler(BatchSampler(100), 5))
66+
```
6067
"""
6168
struct MultiBatchSampler{S<:AbstractSampler} <: AbstractSampler
6269
sampler::S

src/traces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ end
8181

8282
function MultiplexTraces{names}(t) where {names}
8383
if length(names) != 2
84-
throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $length(names) trace names"))
84+
throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $(length(names)) trace names"))
8585
end
8686
trace = convert(AbstractTrace, t)
8787
MultiplexTraces{names,typeof(trace),eltype(trace)}(trace)

test/normalization.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using Test
2+
using ReinforcementLearningTrajectories
3+
import ReinforcementLearningTrajectories: sample
4+
import OnlineStats: mean, std
5+
6+
@testset "normalization.jl" begin
7+
t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
8+
nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
9+
m = mean(0:4)
10+
s = std(0:4)
11+
ss = std([0,1,2,2,3,4])
12+
for i in 0:4
13+
r = ((1.0:5.0) .+ i) .% 5
14+
push!(nt, (state = [r;], action = 1, reward = Float32(i), terminal = false))
15+
end
16+
push!(nt, (next_state = fill(m, 5), next_action = 1)) #this also updates state moments
17+
18+
@test mean(nt.normalizers[:reward].os) == m && std(nt.normalizers[:reward].os) == s
19+
@test all(nt.normalizers[:state].os) do moments
20+
mean(moments) == m && std(moments) == ss
21+
end
22+
23+
unnormalized_batch = t[[1:5;]]
24+
@test unnormalized_batch[:reward] == [0:4;]
25+
@test extrema(unnormalized_batch[:state]) == (0, 4)
26+
normalized_batch = nt[[1:5;]]
27+
28+
traj = Trajectory(
29+
container = nt,
30+
sampler = BatchSampler(1000),
31+
controller = InsertSampleRatioController(ratio = Inf, threshold = 0)
32+
)
33+
normalized_batch = sample(traj)
34+
@test all(extrema(normalized_batch[:state]) .≈ ((0, 4) .- m)./ss)
35+
@test all(extrema(normalized_batch[:next_state]) .≈ ((0, 4) .- m)./ss)
36+
@test all(extrema(normalized_batch[:reward]) .≈ ((0, 4) .- m)./s)
37+
#check for no mutation
38+
unnormalized_batch = t[[1:5;]]
39+
@test unnormalized_batch[:reward] == [0:4;]
40+
@test extrema(unnormalized_batch[:state]) == (0, 4)
41+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ using Test
77
include("common.jl")
88
include("samplers.jl")
99
include("trajectories.jl")
10+
include("normalization.jl")
1011
include("samplers.jl")
1112
end

0 commit comments

Comments
 (0)