Skip to content

Commit 18d2b97

Browse files
committed
Improve doc
1 parent 7911566 commit 18d2b97

File tree

1 file changed

+36
-24
lines changed

1 file changed

+36
-24
lines changed

src/normalization.jl

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ end
6060
"""
6161
normalize(os::Group{<:AbstractVector{<:Moments}}, x)
6262
63-
Given an os::Group{<:Tuple{Moments}}, that is, a multivariate estimator of the moments of each element of x,
64-
normalizes each element of x to zero mean, and unit variance. Treats the last dimension as a batch dimension if `ndims(x) >= 2`.
63+
Given an os::Group{<:Tuple{Moments}}, that is, a multivariate estimator of the moments of
64+
each element of x,
65+
normalizes each element of x to zero mean, and unit variance. Treats the last dimension as
66+
a batch dimension if `ndims(x) >= 2`.
6567
"""
6668
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector)
6769
T = eltype(x)
@@ -89,41 +91,51 @@ end
8991
"""
9092
scalar_normalizer(;weights = OnlineStats.EqualWeight())
9193
92-
Returns preconfigured normalizer for scalar traces such as rewards. By default, all samples have equal weights in the computation of the moments.
93-
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/) to use variants such as exponential weights to favor the most recent observations.
94+
Returns preconfigured normalizer for scalar traces such as rewards. By default, all samples
95+
have equal weights in the computation of the moments.
96+
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/)
97+
to use variants such as exponential weights to favor the most recent observations.
9498
"""
9599
scalar_normalizer(; weight::Weight = EqualWeight()) = Normalizer(Moments(weight = weight))
96100

97101
"""
98102
array_normalizer(size::Tuple{Int}; weights = OnlineStats.EqualWeight())
99103
100104
Returns preconfigured normalizer for array traces such as vector or matrix states.
101-
`size` is a tuple containing the dimension sizes of a state. E.g. `(10,)` for a 10-elements vector, or `(252,252)` for a square image.
105+
`size` is a tuple containing the dimension sizes of a state. E.g. `(10,)` for a 10-elements
106+
vector, or `(252,252)` for a square image.
102107
By default, all samples have equal weights in the computation of the moments.
103-
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/) to use variants such as exponential weights to favor the most recent observations.
108+
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/)
109+
to use variants such as exponential weights to favor the most recent observations.
104110
"""
105111
array_normalizer(size::NTuple{N,Int}; weight::Weight = EqualWeight()) where N = Normalizer(Group([Moments(weight = weight) for _ in 1:prod(size)]))
106112

107113
"""
108-
NormalizedTrace(trace::Trace, normalizer::Normalizer)
109-
110-
Wraps a [`Trace`](@ref) and a [`Normalizer`](@ref). When pushing new elements to the trace, a `NormalizedTrace` will first update a running estimate of the moments of that trace.
111-
When sampling a normalized trace, it will first normalize the samples using to zero mean and unit variance.
112-
113-
preconfigured normalizers are provided for scalar (see [`scalar_normalizer`](@ref)) and arrays (see [`array_normalizer`](@ref))
114-
115-
#Example
116-
t = Trajectory(
117-
container=Traces(
118-
a_scalar_trace = NormalizedTrace(Float32[], scalar_normalizer()),
119-
a_non_normalized_trace=Bool[],
120-
a_vector_trace = NormalizedTrace(Vector{Float32}[], array_normalizer((10,))),
121-
a_matrix_trace = NormalizedTrace(Matrix{Float32}[], array_normalizer((252,252), weight = OnlineStats.ExponientialWeight(0.9f0)))
122-
),
123-
sampler=BatchSampler(3),
124-
controler=InsertSampleRatioControler(0.25, 4)
114+
NormalizedTraces(traces::AbstractTraces, normalizers::NamedTuple)
115+
NormalizedTraces(traces::AbstractTraces; trace_normalizer_pairs...)
116+
117+
Wraps an [`AbstractTraces`](@ref) and a `NamedTuple` of `Symbol` => [`Normalizer`](@ref)
118+
pairs.
119+
When pushing new elements to the traces, a `NormalizedTraces` will first update a running
120+
estimate of the moments of traces present in the keys of `normalizers`.
121+
When sampling a normalized trace, it will first normalize the samples to zero mean and unit
122+
variance. Traces that do not have a normalizer are sample as usual.
123+
124+
Note that when used in combination with [`Episodes`](@ref), `NormalizedTraces` must wrap
125+
the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise
126+
127+
Preconfigured normalizers are provided for scalar (see [`scalar_normalizer`](@ref)) and
128+
arrays (see [`array_normalizer`](@ref)).
129+
130+
# Examples
131+
```
132+
t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
133+
nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
134+
traj = Trajectory(
135+
container = nt,
136+
sampler = BatchSampler(10)
125137
)
126-
138+
```
127139
"""
128140
struct NormalizedTraces{names, TT, T <: AbstractTraces{names, TT}, normnames, N} <: AbstractTraces{names, TT}
129141
traces::T

0 commit comments

Comments
 (0)