Skip to content

Commit b142710

Browse files
committed
move to NormalizedTraces
1 parent 1a0b458 commit b142710

File tree

5 files changed

+84
-77
lines changed

5 files changed

+84
-77
lines changed

src/ReinforcementLearningTrajectories.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ const RLTrajectories = ReinforcementLearningTrajectories
44
export RLTrajectories
55

66
include("patch.jl")
7-
87
include("traces.jl")
98
include("samplers.jl")
109
include("controllers.jl")
1110
include("trajectory.jl")
11+
include("normalization.jl")
1212
include("common/common.jl")
1313

1414
end

src/normalization.jl

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import OnlineStats: OnlineStats, Group, Moments, fit!, OnlineStat, Weight, EqualWeight, mean, std
2-
export scalar_normalizer, array_normalizer, NormalizedTrace, Normalizer
2+
export scalar_normalizer, array_normalizer, NormalizedTraces, Normalizer
33
import MacroTools.@forward
44

55
"""
66
Normalizer(::OnlineStat)
77
8-
Wraps an OnlineStat to be used by a [`NormalizedTrajectory`](@ref).
8+
Wraps an OnlineStat to be used by a [`NormalizedTraces`](@ref).
99
"""
1010
struct Normalizer{OS<:OnlineStat}
1111
os::OS
@@ -87,25 +87,32 @@ t = Trajectory(
8787
)
8888
8989
"""
90-
struct NormalizedTrace{T <: Trace, N <: Normalizer}
91-
trace::T
92-
normalizer::N
90+
struct NormalizedTraces{T <: AbstractTraces, names, N} #<: AbstractTraces
91+
traces::T
92+
normalizers::NamedTuple{names, N}
93+
function NormalizedTraces(traces::AbstractTraces; trace_normalizer_pairs...)
94+
for key in keys(trace_normalizer_pairs)
95+
@assert key in keys(traces) "Traces do not have key $key, valid keys are $(keys(traces))."
96+
end
97+
nt = (; trace_normalizer_pairs...)
98+
new{typeof(traces), keys(nt), typeof(values(nt))}(traces, nt)
99+
end
93100
end
94101

95-
NormalizedTrace(x, normalizer) = NormalizedTrace(convert(Trace, x), normalizer)
96-
97-
@forward NormalizedTrace.trace Base.length, Base.lastindex, Base.firstindex, Base.getindex, Base.view, Base.pop!, Base.popfirst!, Base.empty!
98-
99-
Base.convert(::Type{Trace}, x::NormalizedTrace) = x #ignore conversion to Trace
102+
@forward NormalizedTraces.traces Base.length, Base.lastindex, Base.firstindex, Base.getindex, Base.view, Base.pop!, Base.popfirst!, Base.empty!, Base.prepend!, Base.pushfirst!
100103

101-
function Base.push!(nt::NormalizedTrace, x)
102-
fit!(nt.normalizer, x)
103-
push!(nt.trace, x)
104+
function Base.push!(nt::NormalizedTraces, x)
105+
for key in intersect(keys(nt.normalizers), keys(x))
106+
fit!(nt.normalizers[key], x[key])
107+
end
108+
push!(nt.traces, x)
104109
end
105110

106-
function Base.append!(nt::NormalizedTrace, x)
107-
fit!(nt.normalizer, x)
108-
append!(nt.trace, x)
111+
function Base.append!(nt::NormalizedTraces, x)
112+
for key in intersect(keys(nt.normalizers), keys(x))
113+
fit!(nt.normalizers[key], x[key])
114+
end
115+
append!(nt.traces, x)
109116
end
110117

111118
"""
@@ -149,12 +156,22 @@ function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector{<:A
149156
return xn
150157
end
151158

152-
function fetch(nt::NormalizedTrace, inds)
153-
batch = fetch(nt.trace, inds)
154-
normalize(nt.normalizer.os, batch)
159+
function fetch(nt::NormalizedTraces, inds)
160+
batch = fetch(nt.traces, inds)
161+
return (; (key => (key in keys(nt.normalizers) ? normalize(nt.normalizers[key].os, data) : data) for (key, data) in pairs(batch))...)
162+
end
163+
164+
#=
165+
abstract type AbstractFoo{T} end
166+
167+
struct Foo{T} <: AbstractFoo{T}
168+
x::T
169+
end
170+
171+
struct Baz{F <: Foo} <: supertype(Type(F))
172+
foo::F
155173
end
156174
157-
function sample(s, nt::NormalizedTrace)
158-
batch = sample(s, nt.trace)
159-
normalize(nt.normalizer.os, batch)
160-
end
175+
struct Bal{F <: AbstractFoo{T where T}} <: AbstractFoo{T}
176+
foo::F{T}
177+
end=#

src/samplers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, n
2727

2828
function sample(s::BatchSampler, t::AbstractTraces, names)
2929
inds = rand(s.rng, 1:length(t), s.batch_size)
30-
NamedTuple{names}(s.transformer(t[x][inds]) for x in names)
30+
NamedTuple{names}(s.transformer(fetch(t[x], inds) for x in names))
3131
end
3232

3333
"""

src/traces.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,12 @@ Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s
3333

3434
@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!
3535

36-
fetch(t::Trace, inds) = t[inds]
36+
fetch(t::AbstractTrace, inds) = t[inds]
3737
#####
3838

3939
"""
4040
For each concrete `AbstractTraces`, we have the following assumption:
4141
42-
43-
##
4442
1. Every inner trace is an `AbstractVector`
4543
1. Support partial updating
4644
1. Return *View* by default when getting elements.
@@ -58,6 +56,8 @@ end
5856
Base.keys(t::AbstractTraces{names}) where {names} = names
5957
Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names
6058

59+
#use fetch instead of getindex when sampling to retain compatibility
60+
fetch(t::AbstractTraces, inds) = t[inds]
6161
#####
6262

6363
"""
@@ -84,7 +84,7 @@ end
8484

8585
function MultiplexTraces{names}(t) where {names}
8686
if length(names) != 2
87-
throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $length(names) trace names"))
87+
throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $(length(names)) trace names"))
8888
end
8989
trace = convert(AbstractTrace, t)
9090
MultiplexTraces{names,typeof(trace),eltype(trace)}(trace)
@@ -301,13 +301,6 @@ end
301301

302302
Base.size(t::Traces) = (mapreduce(length, min, t.traces),)
303303

304-
function sample(s::BatchSampler, t::Traces)
305-
inds = rand(s.rng, 1:length(t), s.batch_size)
306-
map(t.traces) do x
307-
fetch(x, inds)
308-
end |> s.transformer
309-
end
310-
311304
for f in (:push!, :pushfirst!)
312305
@eval function Base.$f(ts::Traces, xs::NamedTuple)
313306
for (k, v) in pairs(xs)

test/normalization.jl

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,43 @@
11
using Test
2-
using Trajectories
3-
import Trajectories.normalize
4-
import OnlineStats: fit!, mean, std
2+
using ReinforcementLearningTrajectories
3+
import ReinforcementLearningTrajectories: fetch, sample
4+
import OnlineStats: mean, std
55

66
@testset "normalization.jl" begin
7-
#scalar normalization
8-
rewards = [1.:10;]
9-
rn = scalar_normalizer()
10-
fit!(rn, rewards)
11-
batch_reward = normalize(rn, [6.,5.,10.])
12-
@test batch_reward ≈ ([6.,5.,10.] .- mean(1:10))./std(1:10)
13-
#vector normalization
14-
states = reshape([1:50;], 5, 10)
15-
sn = array_normalizer((5,))
16-
fit!(sn, states)
17-
@test [mean(stat) for stat in sn] == [mean((1:5:46) .+i) for i in 0:4]
18-
batch_states =normalize(sn, reshape(repeat(5.:-1:1, 5), 5,5))
19-
@test all(length(unique(x)) == 1 for x in eachrow(batch_states))
20-
#array normalization
21-
states = reshape(1.:250, 5,5,10)
22-
sn = array_normalizer((5,5))
23-
fit!(sn, states)
24-
batch_states = normalize(sn, collect(states))
25-
26-
#NormalizedTrace
27-
t = Trajectory(
28-
container=Traces(
29-
a= NormalizedTrace(Float32[], scalar_normalizer()),
30-
b=Int[],
31-
c=NormalizedTrace(Vector{Float32}[], array_normalizer((10,))) #TODO check with ElasticArrays and Episodes
32-
),
33-
sampler=BatchSampler(300000),
34-
controler=InsertSampleRatioControler(Inf, 0)
7+
t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
8+
nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
9+
m = mean(0:4)
10+
s = std(0:4)
11+
12+
for i in 0:4
13+
r = ((1.0:5.0) .+ i) .% 5
14+
push!(nt, (state = [r;], action = 1, reward = Float32(i), terminal = false))
15+
end
16+
push!(nt, (next_state = fill(m, 5), next_action = 1)) #does not update because next_state is not in keys of normalizers. Is this desirable or not ?
17+
18+
@test mean(nt.normalizers[:reward].os) == m && std(nt.normalizers[:reward].os) == s
19+
@test all(nt.normalizers[:state].os) do moments
20+
mean(moments) == m && std(moments) == s
21+
end
22+
23+
unnormalized_batch = fetch(t, [1:5;])
24+
@test unnormalized_batch[:reward] == [0:4;]
25+
@test extrema(unnormalized_batch[:state]) == (0, 4)
26+
normalized_batch = fetch(nt, [1:5;])
27+
@test normalized_batch[:reward] ≈ ([0:4;] .- m)./s
28+
@test all(extrema(normalized_batch[:state]) .≈ ((0, 4) .- m)./s)
29+
@test normalized_batch[:state][:,5] ≈ ([0:4;] .- m)./s
30+
#check for no mutation
31+
unnormalized_batch = fetch(t, [1:5;])
32+
@test unnormalized_batch[:reward] == [0:4;]
33+
@test extrema(unnormalized_batch[:state]) == (0, 4)
34+
#=
35+
traj = Trajectory(
36+
container = nt,
37+
sampler = BatchSampler(10),
38+
controller = InsertSampleRatioController(ratio = Inf, threshold = 0)
3539
)
36-
append!(t, a = [1,2,3], b = [1,2,3], c = eachcol(reshape(1f0:30, 10,3)))
37-
push!(t, a = 2, b = 2, c = fill(mean(1:30), 10))
38-
@test mean(t.container[:a].trace.x) ≈ 2.
39-
@test std(t.container[:a].trace.x) ≈ std([1,2,2,3])
40-
a,b,c = take!(t)
41-
@test eltype(a) == Float32
42-
@test mean(a) ≈ 0 atol = 0.01
43-
@test mean(b) ≈ 2 atol = 0.01 #b is not normalized
44-
@test eltype(first(c)) == Float32
45-
@test all(isapprox(0f0, atol = 0.01), vec(mean(reduce(hcat,c), dims = 2)))
40+
41+
batch = sample(traj)=#
42+
4643
end

0 commit comments

Comments
 (0)