Skip to content

Commit 561a5f0

Browse files
committed
Fix multiplex
1 parent ab2df32 commit 561a5f0

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

src/normalization.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,11 @@ When sampling a normalized trace, it will first normalize the samples to zero me
122122
variance. Traces that do not have a normalizer are sampled as usual.
123123
124124
Note that when used in combination with [`Episodes`](@ref), `NormalizedTraces` must wrap
125-
the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise
125+
the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise
126+
the running estimate will reset after each episode.
127+
128+
When used with a `MultiplexTraces`, the normalizer used for one symbol (e.g. `:state`) will
129+
be the same as the one used for the other (e.g. `:next_state`).
126130
127131
Preconfigured normalizers are provided for scalars (see [`scalar_normalizer`](@ref)) and
128132
arrays (see [`array_normalizer`](@ref)).
@@ -131,6 +135,7 @@ arrays (see [`array_normalizer`](@ref)).
131135
```
132136
t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
133137
nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
138+
# :next_state will also be normalized.
134139
traj = Trajectory(
135140
container = nt,
136141
sampler = BatchSampler(10)
@@ -147,6 +152,18 @@ function NormalizedTraces(traces::AbstractTraces{names, TT}; trace_normalizer_pa
147152
@assert key in keys(traces) "Traces do not have key $key, valid keys are $(keys(traces))."
148153
end
149154
nt = (; trace_normalizer_pairs...)
155+
for trace in traces.traces
156+
#check if all traces of MultiplexTraces are in pairs
157+
if trace isa MultiplexTraces
158+
if length(intersect(keys(trace), keys(trace_normalizer_pairs))) in [0, length(keys(trace))] #check if none or all keys are in normalizers
159+
continue
160+
else #if not then one is missing
161+
present_key = only(intersect(keys(trace), keys(trace_normalizer_pairs)))
162+
absent_key = only(setdiff(keys(trace), keys(trace_normalizer_pairs)))
163+
nt = merge(nt, (;(absent_key => nt[present_key],)...)) #assign the same normalizer
164+
end
165+
end
166+
end
150167
NormalizedTraces{names, TT, typeof(traces), keys(nt), typeof(values(nt))}(traces, nt)
151168
end
152169

test/normalization.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@ using ReinforcementLearningTrajectories
33
import ReinforcementLearningTrajectories: sample
44
import OnlineStats: mean, std
55

6-
@testset "normalization.jl" begin
6+
#@testset "normalization.jl" begin
77
t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
88
nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
99
m = mean(0:4)
1010
s = std(0:4)
11-
11+
ss = std([0,1,2,2,3,4])
1212
for i in 0:4
1313
r = ((1.0:5.0) .+ i) .% 5
1414
push!(nt, (state = [r;], action = 1, reward = Float32(i), terminal = false))
1515
end
16-
push!(nt, (next_state = fill(m, 5), next_action = 1)) #does not update because next_state is not in keys of normalizers. Is this desirable or not ?
16+
push!(nt, (next_state = fill(m, 5), next_action = 1)) #this also updates state moments
1717

1818
@test mean(nt.normalizers[:reward].os) == m && std(nt.normalizers[:reward].os) == s
1919
@test all(nt.normalizers[:state].os) do moments
20-
mean(moments) == m && std(moments) == s
20+
mean(moments) == m && std(moments) == ss
2121
end
2222

2323
unnormalized_batch = t[[1:5;]]
@@ -31,7 +31,8 @@ import OnlineStats: mean, std
3131
controller = InsertSampleRatioController(ratio = Inf, threshold = 0)
3232
)
3333
normalized_batch = sample(traj)
34-
@test all(extrema(normalized_batch[:state]) .≈ ((0, 4) .- m)./s)
34+
@test all(extrema(normalized_batch[:state]) .≈ ((0, 4) .- m)./ss)
35+
@test all(extrema(normalized_batch[:next_state]) .≈ ((0, 4) .- m)./ss)
3536
@test all(extrema(normalized_batch[:reward]) .≈ ((0, 4) .- m)./s)
3637
#check for no mutation
3738
unnormalized_batch = t[[1:5;]]

0 commit comments

Comments
 (0)