|
| 1 | +import OnlineStats: OnlineStats, Group, Moments, fit!, OnlineStat, Weight, EqualWeight, mean, std |
| 2 | +export scalar_normalizer, array_normalizer, NormalizedTraces, Normalizer |
| 3 | +import MacroTools.@forward |
| 4 | + |
"""
    Normalizer(::OnlineStat)

Wraps an OnlineStat to be used by a [`NormalizedTraces`](@ref).

The wrapped statistic is stored in the `os` field. Preconfigured instances can be obtained
with [`scalar_normalizer`](@ref) and [`array_normalizer`](@ref).
"""
struct Normalizer{OS<:OnlineStat}
    os::OS # the wrapped OnlineStat that accumulates the running estimate
end
| 13 | + |
# Forward the listed functions to the wrapped `os` field, so that calling one of them on a
# `Normalizer` calls it on the underlying OnlineStat instead.
@forward Normalizer.os OnlineStats.mean, OnlineStats.std, Base.iterate, normalize, Base.length
| 15 | + |
# Fit the wrapped stat on an array whose last dimension indexes the batch: every slice
# along that dimension is flattened and fitted in turn. Returns the normalizer.
function OnlineStats.fit!(n::Normalizer, data::AbstractArray)
    batchdim = ndims(data)
    foreach(eachslice(data; dims = batchdim)) do observation
        fit!(n.os, vec(observation))
    end
    return n
end
| 23 | + |
# Fit a `Group`-backed normalizer on a plain vector: `y` is handed to the group as a whole
# (no slicing along a batch dimension). Returns the normalizer.
function OnlineStats.fit!(n::Normalizer{<:Group}, y::AbstractVector)
    group = n.os
    fit!(group, y)
    return n
end
| 28 | + |
# Fallback for any other iterable: each element of `y` is flattened with `vec` and fitted
# separately. Returns the normalizer.
function OnlineStats.fit!(n::Normalizer, y)
    foreach(y) do observation
        fit!(n.os, vec(observation))
    end
    return n
end
| 35 | + |
# Fit a `Moments`-backed normalizer on a vector of numbers, one number at a time.
# Returns the normalizer.
function OnlineStats.fit!(n::Normalizer{<:Moments}, y::AbstractVector{<:Number})
    moments = n.os
    foreach(value -> fit!(moments, value), y)
    return n
end
| 42 | + |
# Fit the wrapped stat on a single number. Returns the normalizer.
function OnlineStats.fit!(n::Normalizer, data::Number)
    stat = n.os
    fit!(stat, data)
    return n
end
| 47 | + |
"""
    normalize(os::Moments, x)

Given a `Moments` estimate of the elements of `x`, a vector of scalar traces, normalize `x`
elementwise to zero mean and unit variance.

The mean and standard deviation read from `os` are converted to `eltype(x)` before being
applied.
"""
function normalize(os::Moments, x)
    T = eltype(x)
    center = T(mean(os))
    scale = T(std(os))
    return @. (x - center) / scale
end
| 59 | + |
"""
    normalize(os::Group{<:AbstractVector{<:Moments}}, x)

Given an `os::Group{<:AbstractVector{<:Moments}}`, that is, a multivariate estimator of the
moments of each element of `x`, normalize each element of `x` to zero mean and unit
variance. Treats the last dimension as a batch dimension if `ndims(x) >= 2`.

The per-element means and standard deviations are converted to `eltype(x)` before being
applied.
"""
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector)
    T = eltype(x)
    # one mean / standard deviation per statistic held by the group
    centers = [T(mean(moments)) for moments in os]
    scales = [T(std(moments)) for moments in os]
    return @. (x - centers) / scales
end
| 74 | + |
# Normalize an array whose last dimension indexes the batch. Every slice along that
# dimension is flattened, normalized with the vector method, and written back with its
# original shape into a freshly allocated array `similar(x)`, which is returned.
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractArray)
    xn = similar(x)
    batchdim = ndims(x)
    # Walk matching views of the output and the input instead of building an index from a
    # splatted vector of `:`; this also avoids assuming the batch axis starts at 1.
    for (dst, src) in zip(eachslice(xn; dims = batchdim), eachslice(x; dims = batchdim))
        dst .= reshape(normalize(os, vec(src)), size(src))
    end
    return xn
end
| 82 | + |
# Normalize a vector of arrays element by element. Each array is flattened, normalized
# with the vector method, and reshaped back to its own size before being stored, so that
# the result has the same element shapes as `x`. Returns a new container `similar(x)`.
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector{<:AbstractArray})
    xn = similar(x)
    for i in eachindex(x, xn)
        el = x[i]
        # Restore the element's shape: storing the flat vector would lose it, and fails
        # outright when the container's element type is e.g. a Matrix.
        xn[i] = reshape(normalize(os, vec(el)), size(el))
    end
    return xn
end
| 90 | + |
"""
    scalar_normalizer(; weight = OnlineStats.EqualWeight())

Returns preconfigured normalizer for scalar traces such as rewards. By default, all samples
have equal weights in the computation of the moments.
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/)
to use variants such as exponential weights to favor the most recent observations.
"""
function scalar_normalizer(; weight::Weight = EqualWeight())
    moments = Moments(weight = weight)
    return Normalizer(moments)
end
| 100 | + |
"""
    array_normalizer(size::NTuple{N,Int}; weight = OnlineStats.EqualWeight())

Returns preconfigured normalizer for array traces such as vector or matrix states.
`size` is a tuple containing the dimension sizes of a state. E.g. `(10,)` for a 10-elements
vector, or `(252,252)` for a square image. One `Moments` estimator is created per array
element, i.e. `prod(size)` of them.
By default, all samples have equal weights in the computation of the moments.
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/)
to use variants such as exponential weights to favor the most recent observations.
"""
function array_normalizer(size::NTuple{N,Int}; weight::Weight = EqualWeight()) where {N}
    n_elements = prod(size)
    moments = [Moments(weight = weight) for _ in 1:n_elements]
    return Normalizer(Group(moments))
end
| 112 | + |
"""
    NormalizedTraces(traces::AbstractTraces, normalizers::NamedTuple)
    NormalizedTraces(traces::AbstractTraces; trace_normalizer_pairs...)

Wraps an [`AbstractTraces`](@ref) and a `NamedTuple` of `Symbol` => [`Normalizer`](@ref)
pairs.
When pushing new elements to the traces, a `NormalizedTraces` will first update a running
estimate of the moments of traces present in the keys of `normalizers`.
When sampling a normalized trace, it will first normalize the samples to zero mean and unit
variance. Traces that do not have a normalizer are sampled as usual.

Note that when used in combination with [`Episodes`](@ref), `NormalizedTraces` must wrap
the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise
the running estimate will reset after each episode.

When used with a MultiplexTraces, the normalizer used for one symbol (e.g. :state) will
be the same used for the other one (e.g. :next_state).

Preconfigured normalizers are provided for scalar (see [`scalar_normalizer`](@ref)) and
arrays (see [`array_normalizer`](@ref)).

# Examples
```
t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
# :next_state will also be normalized.
traj = Trajectory(
    container = nt,
    sampler = BatchSampler(10)
)
```
"""
struct NormalizedTraces{names, TT, T <: AbstractTraces{names, TT}, normnames, N} <: AbstractTraces{names, TT}
    traces::T # the wrapped traces
    normalizers::NamedTuple{normnames, N} # trace name => normalizer for that trace
end
| 149 | + |
# Keyword constructor: each `name = normalizer` pair attaches a normalizer to the trace
# called `name`. Every name must be a key of `traces`. When exactly one key of a
# `MultiplexTraces` is given a normalizer, the same normalizer object is also registered
# under the other key.
function NormalizedTraces(
    traces::AbstractTraces{names,TT}; trace_normalizer_pairs...
) where {names,TT}
    requested = keys(trace_normalizer_pairs)
    for key in requested
        @assert key in keys(traces) "Traces do not have key $key, valid keys are $(keys(traces))."
    end
    normalizers = (; trace_normalizer_pairs...)
    for trace in traces.traces
        # only multiplexed traces need their keys completed
        trace isa MultiplexTraces || continue
        covered = intersect(keys(trace), requested)
        # nothing to do when none or all of the keys already have a normalizer
        (isempty(covered) || length(covered) == length(keys(trace))) && continue
        present_key = only(covered)
        absent_key = only(setdiff(keys(trace), requested))
        # register the same normalizer under the missing key
        shared = NamedTuple{(absent_key,)}((normalizers[present_key],))
        normalizers = merge(normalizers, shared)
    end
    return NormalizedTraces{
        names,TT,typeof(traces),keys(normalizers),typeof(values(normalizers))
    }(
        traces, normalizers
    )
end
| 169 | + |
# Multi-line display: a header with the type name and the number of traces, followed by one
# line per trace with its summary; traces that have a normalizer are tagged "Normalized".
function Base.show(io::IO, ::MIME"text/plain", t::NormalizedTraces{names,T}) where {names,T}
    println(io, "$(nameof(typeof(t))) with $(length(names)) entries:")
    normalized = keys(t.normalizers)
    for n in names
        tag = n in normalized ? " => Normalized" : ""
        println(io, " :$n => $(summary(t[n]))", tag)
    end
end
| 182 | + |
# Forward size queries, indexing, and element removal to the wrapped `traces` field; these
# calls go straight to the inner traces without involving the normalizers.
@forward NormalizedTraces.traces Base.length, Base.size, Base.lastindex, Base.firstindex, Base.getindex, Base.view, Base.pop!, Base.popfirst!, Base.empty!, Base.parent
| 184 | + |
# Define the insertion methods for `NormalizedTraces`: each one first fits the normalizers
# of the traces present in `x` on the incoming data, then inserts `x` into the wrapped
# traces and returns the result of that call.
for op in (:push!, :pushfirst!, :append!, :prepend!)
    @eval function Base.$op(nt::NormalizedTraces, x::NamedTuple)
        foreach(intersect(keys(nt.normalizers), keys(x))) do key
            fit!(nt.normalizers[key], x[key])
        end
        return $op(nt.traces, x)
    end
end
| 193 | + |
# Draw `s.batch_size` random indices in `1:length(nt)` and gather the requested traces at
# those indices. Traces that have a normalizer are normalized before being handed to
# `s.transformer`; the others are passed through unchanged. Returns a `NamedTuple` keyed by
# `names`.
function sample(s::BatchSampler, nt::NormalizedTraces, names)
    batch_inds = rand(s.rng, 1:length(nt), s.batch_size)
    normalized = keys(nt.normalizers)
    # lazy generator, consumed by the transformer
    batch = (
        begin
            data = nt[name][batch_inds]
            name in normalized ? normalize(nt.normalizers[name], data) : data
        end for name in names
    )
    return NamedTuple{names}(s.transformer(batch))
end
0 commit comments