Skip to content

Commit 3d0416d

Browse files
committed
remove fetch and subtype
1 parent b142710 commit 3d0416d

File tree

4 files changed

+73
-99
lines changed

4 files changed

+73
-99
lines changed

src/normalization.jl

Lines changed: 59 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ end
1313

1414
@forward Normalizer.os OnlineStats.mean, OnlineStats.std, Base.iterate, normalize, Base.length
1515

16-
17-
1816
#Treats last dim as batch dim
1917
function OnlineStats.fit!(n::Normalizer, data::AbstractArray)
2018
for d in eachslice(data, dims = ndims(data))
@@ -47,6 +45,47 @@ function OnlineStats.fit!(n::Normalizer, data::Number)
4745
n
4846
end
4947

48+
"""
49+
normalize(os::Moments, x)
50+
51+
Given a Moments estimate of the elements of x, a vector of scalar traces,
52+
normalizes x elementwise to zero mean, and unit variance.
53+
"""
54+
function normalize(os::Moments, x)
55+
T = eltype(x)
56+
m, s = T(mean(os)), T(std(os))
57+
return (x .- m) ./ s
58+
end
59+
60+
"""
61+
normalize(os::Group{<:AbstractVector{<:Moments}}, x)
62+
63+
Given an os::Group{<:AbstractVector{<:Moments}}, that is, a multivariate estimator of the moments of each element of x,
64+
normalizes each element of x to zero mean, and unit variance. Treats the last dimension as a batch dimension if `ndims(x) >= 2`.
65+
"""
66+
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector)
67+
T = eltype(x)
68+
m = [T(mean(stat)) for stat in os]
69+
s = [T(std(stat)) for stat in os]
70+
return (x .- m) ./ s
71+
end
72+
73+
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractArray)
74+
xn = similar(x)
75+
for (i, slice) in enumerate(eachslice(x, dims = ndims(x)))
76+
xn[repeat([:], ndims(x)-1)..., i] .= reshape(normalize(os, vec(slice)), size(x)[1:end-1]...)
77+
end
78+
return xn
79+
end
80+
81+
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector{<:AbstractArray})
82+
xn = similar(x)
83+
for (i,el) in enumerate(x)
84+
xn[i] = normalize(os, vec(el))
85+
end
86+
return xn
87+
end
88+
5089
"""
5190
scalar_normalizer(;weights = OnlineStats.EqualWeight())
5291
@@ -65,7 +104,6 @@ See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/sta
65104
"""
66105
array_normalizer(size::NTuple{N,Int}; weight::Weight = EqualWeight()) where N = Normalizer(Group([Moments(weight = weight) for _ in 1:prod(size)]))
67106

68-
69107
"""
70108
NormalizedTrace(trace::Trace, normalizer::Normalizer)
71109
@@ -87,91 +125,32 @@ t = Trajectory(
87125
)
88126
89127
"""
90-
struct NormalizedTraces{T <: AbstractTraces, names, N} #<: AbstractTraces
128+
struct NormalizedTraces{names, TT, T <: AbstractTraces{names, TT}, normnames, N} <: AbstractTraces{names, TT}
91129
traces::T
92-
normalizers::NamedTuple{names, N}
93-
function NormalizedTraces(traces::AbstractTraces; trace_normalizer_pairs...)
94-
for key in keys(trace_normalizer_pairs)
95-
@assert key in keys(traces) "Traces do not have key $key, valid keys are $(keys(traces))."
96-
end
97-
nt = (; trace_normalizer_pairs...)
98-
new{typeof(traces), keys(nt), typeof(values(nt))}(traces, nt)
99-
end
130+
normalizers::NamedTuple{normnames, N}
100131
end
101132

102-
@forward NormalizedTraces.traces Base.length, Base.lastindex, Base.firstindex, Base.getindex, Base.view, Base.pop!, Base.popfirst!, Base.empty!, Base.prepend!, Base.pushfirst!
103-
104-
function Base.push!(nt::NormalizedTraces, x)
105-
for key in intersect(keys(nt.normalizers), keys(x))
106-
fit!(nt.normalizers[key], x[key])
133+
function NormalizedTraces(traces::AbstractTraces{names, TT}; trace_normalizer_pairs...) where names where TT
134+
for key in keys(trace_normalizer_pairs)
135+
@assert key in keys(traces) "Traces do not have key $key, valid keys are $(keys(traces))."
107136
end
108-
push!(nt.traces, x)
137+
nt = (; trace_normalizer_pairs...)
138+
NormalizedTraces{names, TT, typeof(traces), keys(nt), typeof(values(nt))}(traces, nt)
109139
end
110140

111-
function Base.append!(nt::NormalizedTraces, x)
112-
for key in intersect(keys(nt.normalizers), keys(x))
113-
fit!(nt.normalizers[key], x[key])
114-
end
115-
append!(nt.traces, x)
116-
end
117-
118-
"""
119-
normalize!(os::Moments, x)
120-
121-
Given an Moments estimate of the elements of x, a vector of scalar traces,
122-
normalizes x elementwise to zero mean, and unit variance.
123-
"""
124-
function normalize(os::Moments, x::AbstractVector)
125-
T = eltype(x)
126-
m, s = T(mean(os)), T(std(os))
127-
return (x .- m) ./ s
128-
end
129-
130-
"""
131-
normalize!(os::Group{<:AbstractVector{<:Moments}}, x)
132-
133-
Given an os::Group{<:Tuple{Moments}}, that is, a multivariate estimator of the moments of each element of x,
134-
normalizes each element of x to zero mean, and unit variance. Treats the last dimension as a batch dimension if `ndims(x) >= 2`.
135-
"""
136-
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector)
137-
T = eltype(x)
138-
m = [T(mean(stat)) for stat in os]
139-
s = [T(std(stat)) for stat in os]
140-
return (x .- m) ./ s
141-
end
141+
@forward NormalizedTraces.traces Base.length, Base.size, Base.lastindex, Base.firstindex, Base.getindex, Base.view, Base.pop!, Base.popfirst!, Base.empty!, Base.parent
142142

143-
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractArray)
144-
xn = similar(x)
145-
for (i, slice) in enumerate(eachslice(x, dims = ndims(x)))
146-
xn[repeat([:], ndims(x)-1)..., i] .= reshape(normalize(os, vec(slice)), size(x)[1:end-1]...)
147-
end
148-
return xn
149-
end
150-
151-
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector{<:AbstractArray})
152-
xn = similar(x)
153-
for (i,el) in enumerate(x)
154-
xn[i] = normalize(os, vec(el))
143+
for f in (:push!, :pushfirst!, :append!, :prepend!)
144+
@eval function Base.$f(nt::NormalizedTraces, x::NamedTuple)
145+
for key in intersect(keys(nt.normalizers), keys(x))
146+
fit!(nt.normalizers[key], x[key])
147+
end
148+
$f(nt.traces, x)
155149
end
156-
return xn
157150
end
158151

159-
function fetch(nt::NormalizedTraces, inds)
160-
batch = fetch(nt.traces, inds)
161-
return (; (key => (key in keys(nt.normalizers) ? normalize(nt.normalizers[key].os, data) : data) for (key, data) in pairs(batch))...)
152+
function sample(s::BatchSampler, nt::NormalizedTraces, names)
153+
inds = rand(s.rng, 1:length(nt), s.batch_size)
154+
maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data
155+
NamedTuple{names}(s.transformer(maybe_normalize(nt[x][inds], x) for x in names))
162156
end
163-
164-
#=
165-
abstract type AbstractFoo{T} end
166-
167-
struct Foo{T} <: AbstractFoo{T}
168-
x::T
169-
end
170-
171-
struct Baz{F <: Foo} <: supertype(Type(F))
172-
foo::F
173-
end
174-
175-
struct Bal{F <: AbstractFoo{T where T}} <: AbstractFoo{T}
176-
foo::F{T}
177-
end=#

src/samplers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, n
2727

2828
function sample(s::BatchSampler, t::AbstractTraces, names)
2929
inds = rand(s.rng, 1:length(t), s.batch_size)
30-
NamedTuple{names}(s.transformer(fetch(t[x], inds) for x in names))
30+
NamedTuple{names}(s.transformer(t[x][inds] for x in names))
3131
end
3232

3333
"""

src/traces.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s
3333

3434
@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!
3535

36-
fetch(t::AbstractTrace, inds) = t[inds]
3736
#####
3837

3938
"""
@@ -56,8 +55,6 @@ end
5655
Base.keys(t::AbstractTraces{names}) where {names} = names
5756
Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names
5857

59-
#use fetch instead of getindex when sampling to retain compatibility
60-
fetch(t::AbstractTraces, inds) = t[inds]
6158
#####
6259

6360
"""

test/normalization.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test
22
using ReinforcementLearningTrajectories
3-
import ReinforcementLearningTrajectories: fetch, sample
3+
import ReinforcementLearningTrajectories: sample
44
import OnlineStats: mean, std
55

66
@testset "normalization.jl" begin
@@ -13,31 +13,29 @@ import OnlineStats: mean, std
1313
r = ((1.0:5.0) .+ i) .% 5
1414
push!(nt, (state = [r;], action = 1, reward = Float32(i), terminal = false))
1515
end
16-
push!(nt, (next_state = fill(m, 5), next_action = 1)) #does not update because next_state is not in keys of normlizers. Is this desirable or not ?
16+
push!(nt, (next_state = fill(m, 5), next_action = 1)) #does not update because next_state is not in keys of normalizers. Is this desirable or not ?
1717

1818
@test mean(nt.normalizers[:reward].os) == m && std(nt.normalizers[:reward].os) == s
1919
@test all(nt.normalizers[:state].os) do moments
2020
mean(moments) == m && std(moments) == s
2121
end
2222

23-
unnormalized_batch = fetch(t, [1:5;])
23+
unnormalized_batch = t[[1:5;]]
2424
@test unnormalized_batch[:reward] == [0:4;]
2525
@test extrema(unnormalized_batch[:state]) == (0, 4)
26-
normalized_batch = fetch(nt, [1:5;])
27-
@test normalized_batch[:reward] ([0:4;] .- m)./s
28-
@test all(extrema(normalized_batch[:state]) .≈ ((0, 4) .- m)./s)
29-
@test normalized_batch[:state][:,5] ([0:4;] .- m)./s
30-
#check for no mutation
31-
unnormalized_batch = fetch(t, [1:5;])
32-
@test unnormalized_batch[:reward] == [0:4;]
33-
@test extrema(unnormalized_batch[:state]) == (0, 4)
34-
#=
26+
normalized_batch = nt[[1:5;]]
27+
3528
traj = Trajectory(
3629
container = nt,
37-
sampler = BatchSampler(10),
30+
sampler = BatchSampler(1000),
3831
controller = InsertSampleRatioController(ratio = Inf, threshold = 0)
3932
)
40-
41-
batch = sample(traj)=#
33+
normalized_batch = sample(traj)
34+
@test all(extrema(normalized_batch[:state]) .≈ ((0, 4) .- m)./s)
35+
@test all(extrema(normalized_batch[:reward]) .≈ ((0, 4) .- m)./s)
36+
#check for no mutation
37+
unnormalized_batch = t[[1:5;]]
38+
@test unnormalized_batch[:reward] == [0:4;]
39+
@test extrema(unnormalized_batch[:state]) == (0, 4)
4240

4341
end

0 commit comments

Comments
 (0)