|
60 | 60 | """
|
61 | 61 | normalize(os::Group{<:AbstractVector{<:Moments}}, x)
|
62 | 62 |
|
63 |
| -Given an os::Group{<:Tuple{Moments}}, that is, a multivariate estimator of the moments of each element of x, |
64 |
| -normalizes each element of x to zero mean, and unit variance. Treats the last dimension as a batch dimension if `ndims(x) >= 2`. |
| 63 | +Given an os::Group{<:Tuple{Moments}}, that is, a multivariate estimator of the moments of |
| 64 | +each element of x, |
| 65 | +normalizes each element of x to zero mean, and unit variance. Treats the last dimension as |
| 66 | +a batch dimension if `ndims(x) >= 2`. |
65 | 67 | """
|
66 | 68 | function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector)
|
67 | 69 | T = eltype(x)
|
|
89 | 91 | """
|
90 | 92 | scalar_normalizer(;weights = OnlineStats.EqualWeight())
|
91 | 93 |
|
92 |
| -Returns preconfigured normalizer for scalar traces such as rewards. By default, all samples have equal weights in the computation of the moments. |
93 |
| -See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/) to use variants such as exponential weights to favor the most recent observations. |
| 94 | +Returns preconfigured normalizer for scalar traces such as rewards. By default, all samples |
| 95 | +have equal weights in the computation of the moments. |
| 96 | +See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/) |
| 97 | +to use variants such as exponential weights to favor the most recent observations. |
94 | 98 | """
|
95 | 99 | scalar_normalizer(; weight::Weight = EqualWeight()) = Normalizer(Moments(weight = weight))
|
96 | 100 |
|
97 | 101 | """
|
98 | 102 | array_normalizer(size::Tuple{Int}; weights = OnlineStats.EqualWeight())
|
99 | 103 |
|
100 | 104 | Returns preconfigured normalizer for array traces such as vector or matrix states.
|
101 |
| -`size` is a tuple containing the dimension sizes of a state. E.g. `(10,)` for a 10-elements vector, or `(252,252)` for a square image. |
| 105 | +`size` is a tuple containing the dimension sizes of a state. E.g. `(10,)` for a 10-elements |
| 106 | +vector, or `(252,252)` for a square image. |
102 | 107 | By default, all samples have equal weights in the computation of the moments.
|
103 |
| -See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/) to use variants such as exponential weights to favor the most recent observations. |
| 108 | +See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/) |
| 109 | +to use variants such as exponential weights to favor the most recent observations. |
104 | 110 | """
|
105 | 111 | array_normalizer(size::NTuple{N,Int}; weight::Weight = EqualWeight()) where N = Normalizer(Group([Moments(weight = weight) for _ in 1:prod(size)]))
|
106 | 112 |
|
107 | 113 | """
|
108 |
| - NormalizedTrace(trace::Trace, normalizer::Normalizer) |
109 |
| -
|
110 |
| -Wraps a [`Trace`](@ref) and a [`Normalizer`](@ref). When pushing new elements to the trace, a `NormalizedTrace` will first update a running estimate of the moments of that trace. |
111 |
| -When sampling a normalized trace, it will first normalize the samples using to zero mean and unit variance. |
112 |
| -
|
113 |
| -preconfigured normalizers are provided for scalar (see [`scalar_normalizer`](@ref)) and arrays (see [`array_normalizer`](@ref)) |
114 |
| -
|
115 |
| -#Example |
116 |
| -t = Trajectory( |
117 |
| - container=Traces( |
118 |
| - a_scalar_trace = NormalizedTrace(Float32[], scalar_normalizer()), |
119 |
| - a_non_normalized_trace=Bool[], |
120 |
| - a_vector_trace = NormalizedTrace(Vector{Float32}[], array_normalizer((10,))), |
121 |
| - a_matrix_trace = NormalizedTrace(Matrix{Float32}[], array_normalizer((252,252), weight = OnlineStats.ExponientialWeight(0.9f0))) |
122 |
| - ), |
123 |
| - sampler=BatchSampler(3), |
124 |
| - controler=InsertSampleRatioControler(0.25, 4) |
| 114 | + NormalizedTraces(traces::AbstractTraces, normalizers::NamedTuple) |
| 115 | + NormalizedTraces(traces::AbstractTraces; trace_normalizer_pairs...) |
| 116 | +
|
| 117 | +Wraps an [`AbstractTraces`](@ref) and a `NamedTuple` of `Symbol` => [`Normalizer`](@ref) |
| 118 | +pairs. |
| 119 | +When pushing new elements to the traces, a `NormalizedTraces` will first update a running |
| 120 | +estimate of the moments of traces present in the keys of `normalizers`. |
| 121 | +When sampling a normalized trace, it will first normalize the samples to zero mean and unit |
| 122 | +variance. Traces that do not have a normalizer are sample as usual. |
| 123 | +
|
| 124 | +Note that when used in combination with [`Episodes`](@ref), `NormalizedTraces` must wrap |
| 125 | +the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise |
| 126 | +
|
| 127 | +Preconfigured normalizers are provided for scalar (see [`scalar_normalizer`](@ref)) and |
| 128 | +arrays (see [`array_normalizer`](@ref)). |
| 129 | +
|
| 130 | +# Examples |
| 131 | +``` |
| 132 | +t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,)) |
| 133 | +nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,))) |
| 134 | +traj = Trajectory( |
| 135 | + container = nt, |
| 136 | + sampler = BatchSampler(10) |
125 | 137 | )
|
126 |
| -
|
| 138 | +``` |
127 | 139 | """
|
128 | 140 | struct NormalizedTraces{names, TT, T <: AbstractTraces{names, TT}, normnames, N} <: AbstractTraces{names, TT}
|
129 | 141 | traces::T
|
|
0 commit comments