@@ -122,7 +122,11 @@ When sampling a normalized trace, it will first normalize the samples to zero me
122
122
variance. Traces that do not have a normalizer are sample as usual.
123
123
124
124
Note that when used in combination with [`Episodes`](@ref), `NormalizedTraces` must wrap
125
- the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise
125
+ the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise
126
+ the running estimate will reset after each episode.
127
+
128
+ When used with a MultiplexTraces, the normalizer used with for one symbol (e.g. :state) will
129
+ be the same used for the other one (e.g. :next_state).
126
130
127
131
Preconfigured normalizers are provided for scalar (see [`scalar_normalizer`](@ref)) and
128
132
arrays (see [`array_normalizer`](@ref)).
@@ -131,6 +135,7 @@ arrays (see [`array_normalizer`](@ref)).
131
135
```
132
136
t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
133
137
nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
138
+ # :next_state will also be normalized.
134
139
traj = Trajectory(
135
140
container = nt,
136
141
sampler = BatchSampler(10)
@@ -147,6 +152,18 @@ function NormalizedTraces(traces::AbstractTraces{names, TT}; trace_normalizer_pa
147
152
@assert key in keys (traces) " Traces do not have key $key , valid keys are $(keys (traces)) ."
148
153
end
149
154
nt = (; trace_normalizer_pairs... )
155
+ for trace in traces. traces
156
+ # check if all traces of MultiplexTraces are in pairs
157
+ if trace isa MultiplexTraces
158
+ if length (intersect (keys (trace), keys (trace_normalizer_pairs))) in [0 , length (keys (trace))] # check if none or all keys are in normalizers
159
+ continue
160
+ else # if not then one is missing
161
+ present_key = only (intersect (keys (trace), keys (trace_normalizer_pairs)))
162
+ absent_key = only (setdiff (keys (trace), keys (trace_normalizer_pairs)))
163
+ nt = merge (nt, (;(absent_key => nt[present_key],). .. )) # assign the same normalizer
164
+ end
165
+ end
166
+ end
150
167
NormalizedTraces {names, TT, typeof(traces), keys(nt), typeof(values(nt))} (traces, nt)
151
168
end
152
169
0 commit comments