
Commit 5455132 (parent 87fa7df)

fixes and tests

4 files changed: +103 -147 lines

src/samplers.jl

Lines changed: 34 additions & 24 deletions
@@ -184,61 +184,71 @@ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}}
     rng::Any
 end
 
-NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...)
+NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
 function NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names}
     @assert n >= 1 "n must be ≥ 1."
-    NStepBatchSampler{names}(n, γ, batch_size, stack_size == 1 ? nothing : stack_size, rng)
+    ss = stack_size == 1 ? nothing : stack_size
+    NStepBatchSampler{names, typeof(ss)}(n, γ, batch_size, ss, rng)
 end
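Why `typeof(ss)` matters: it pins the sampler's second type parameter to either `Nothing` or `Int`, so the `fetch` methods further down can dispatch on whether frame stacking is active instead of branching at runtime. A minimal sketch of the two resulting types (illustrative only, assuming the keyword constructor shown above):

    s_flat  = NStepBatchSampler{(:state, :action)}(n=3, γ=0.99f0)                 # stack_size slot is Nothing
    s_stack = NStepBatchSampler{(:state, :action)}(n=3, γ=0.99f0, stack_size=4)   # stack_size slot is Int
    # s_flat  isa NStepBatchSampler{(:state, :action), Nothing}
    # s_stack isa NStepBatchSampler{(:state, :action), Int}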
 
-function valid_range_nbatchsampler(s::NStepBatchSampler, eb::EpisodesBuffer)
+#return a boolean vector of the valid sample indices given the stack_size and the truncated n for each index.
+function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
     range = copy(eb.sampleable_inds)
+    ns = Vector{Int}(undef, length(eb.sampleable_inds))
     stack_size = isnothing(s.stack_size) ? 1 : s.stack_size
     for idx in eachindex(range)
-        valid = eb.step_numbers[idx] >= stack_size && eb.step_numbers[idx] <= eb.episodes_lengths[idx] + 1 - eb.n && eb.sampleable_inds[idx]
-        range[idx] = valid
+        step_number = eb.step_numbers[idx]
+        range[idx] = step_number >= stack_size && eb.sampleable_inds[idx]
+        ns[idx] = min(s.n, eb.episodes_lengths[idx] - step_number + 1)
     end
-    return range
+    return range, ns
 end
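To make the truncated horizon concrete: with `n = 3` and a finished episode of length 5, `ns[idx] = min(s.n, episodes_lengths[idx] - step_number + 1)` shrinks the usable window as the terminal step approaches. A hand-computed sketch (plain Julia, not package code):

    n, episode_length = 3, 5
    ns = [min(n, episode_length - step + 1) for step in 1:episode_length]
    # ns == [3, 3, 3, 2, 1]: the last indices only have 2-step and 1-step returns left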
 
 function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names}
     StatsBase.sample(s, ts, Val(names))
 end
 
-function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, names)
-    valid_range = valid_range_nbatchsampler(s, t)
-    StatsBase.sample(s, t.traces, names, StatsBase.FrequencyWeights(valid_range))
+function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, ::Val{names}) where names
+    weights, ns = valid_range(s, t)
+    inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batch_size)
+    fetch(s, t, Val(names), inds, ns)
 end
 
-function StatsBase.sample(s::NStepBatchSampler, t::AbstractTraces, names, weights = StatsBase.UnitWeights{Int}(length(t)))
-    inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batch_size)
-    NamedTuple{names}map(name -> collect(fetch(s, ts[name], Val(name), inds)), names)
+function fetch(s::NStepBatchSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
+    NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names))
 end
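The fix in the batch assembler is the added pair of parentheses: the removed line read `NamedTuple{names}map(...)`, which does not parse as intended (and referenced an undefined `ts`), while the new version applies `NamedTuple{names}` to the tuple produced by `map`. The pattern in isolation, with stand-in data rather than the package's traces:

    names = (:state, :reward)
    data = (state = [10, 20, 30], reward = [1.0, 2.0, 3.0])
    inds = [1, 3]
    batch = NamedTuple{names}(map(name -> data[name][inds], names))
    # batch == (state = [10, 30], reward = [1.0, 3.0])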
 
 #state and next_state have specialized fetch methods due to stack_size
-fetch(::NStepBatchSampler{names, Nothing}, trace, ::Val{:state}, inds) where {names} = trace[inds]
-fetch(s::NStepBatchSampler{names, Int}, trace, ::Val{:state}, inds) where {names} = trace[[x + s.n - 1 + i for i in -s.stack_size+1:0, x in inds]]
-fetch(s::NStepBatchSampler{names, Nothing}, trace, ::Val{:next_state}, inds) where {names} = trace[inds.+(s.n-1)]
-fetch(s::NStepBatchSampler{names, Int}, trace, ::Val{:next_state}, inds) where {names} = trace[[x + s.n - 1 + i for i in -s.stack_size+1:0, x in inds]]
+fetch(::NStepBatchSampler{names, Nothing}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[inds]
+fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stack_size+1:0, x in inds]]
+fetch(::NStepBatchSampler{names, Nothing}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[inds .+ ns .- 1]
+fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stack_size+1:0, (idx,x) in enumerate(inds)]]
+
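With stacking enabled, the `:state` method builds a `stack_size × batch_size` index matrix whose columns are runs of consecutive frames ending at each sampled index, so `trace[idx_matrix]` returns the whole stacked slice in one indexing call. A small sketch of the index arithmetic (hypothetical indices):

    stack_size = 2
    inds = [3, 5, 7]
    idx_matrix = [x + i for i in -stack_size+1:0, x in inds]
    # idx_matrix == [2 4 6;
    #                3 5 7]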
 #reward due to discounting
-function fetch(s::NStepBatchSampler{names}, trace, ::Val{:reward}, inds) where {names}
-    rewards = trace[[x + j for j in 0:nbs.n-1, x in inds]]
-    return reduce((x,y)->x + s.γ*y, rewards, init = zero(eltype(rewards)), dims = 1)
+function fetch(s::NStepBatchSampler, trace::AbstractTrace, ::Val{:reward}, inds, ns)
+    rewards = Vector{eltype(trace)}(undef, length(inds))
+    for (i,idx) in enumerate(inds)
+        rewards_to_go = trace[idx:idx+ns[i]-1]
+        rewards[i] = foldr((x,y)->x + s.γ*y, rewards_to_go)
+    end
+    return rewards
 end
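The `foldr` right-associates the discount, so each reward is scaled by γ raised to its distance from the sampled step, matching the usual n-step return r₁ + γr₂ + ⋯ + γⁿ⁻¹rₙ. A quick check of the identity in plain Julia:

    γ = 0.99
    rewards_to_go = [2.0, 3.0, 4.0]
    foldr((x, y) -> x + γ * y, rewards_to_go)
    # == 2 + 0.99*(3 + 0.99*4) == 2 + 0.99*3 + 0.99^2*4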
 #terminal is that of the nth step
-fetch(s::NStepBatchSampler{names}, trace, ::Val{:terminal}, inds) where {names} = trace[inds.+s.n]
+fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val{:terminal}, inds, ns) = trace[inds .+ ns .- 1]
 #right multiplex traces must be n-step sampled
-fetch(::NStepBatchSampler{names}, trace::RelativeTrace{1,0} , ::Val{<:Symbol}, inds) where {names} = trace[inds.+(s.n-1)]
+fetch(::NStepBatchSampler, trace::RelativeTrace{1,0} , ::Val, inds, ns) = trace[inds .+ ns .- 1]
 #normal trace types are fetched at inds
-fetch(::NStepBatchSampler{names}, trace, ::Val{<:Symbol}, inds) where {names} = trace[inds] #other types of trace are sampled normally
+fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val, inds, ns) = trace[inds] #other types of trace are sampled normally
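Both `:terminal` and right-shifted multiplex traces (`RelativeTrace{1,0}`) are now read at `inds .+ ns .- 1`, the last step of each possibly-truncated window, instead of at a fixed offset derived from `s.n`. Note that the removed multiplex method used `s.n` without binding `s` in its signature, so it could not have run as written. The index arithmetic in isolation:

    inds = [2, 8]; ns = [3, 1]
    inds .+ ns .- 1  # == [4, 8]: with ns[2] == 1, index 8 reads its own step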
 
 function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
     t = e.traces
     st = deepcopy(t.priorities)
-    st .*= valid_range_nbatchsampler(s,e) #temporary sumtree that puts 0 priority to non sampleable indices.
+    valids, ns = valid_range(s,e)
+    st .*= valids[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
     inds, priorities = rand(s.rng, st, s.batch_size)
     merge(
         (key=t.keys[inds], priority=priorities),
-        fetch(s, t.traces, Val(names), inds)
+        fetch(s, e, Val(names), inds, ns)
     )
 end
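One subtlety worth spelling out: the mask returned by `valid_range` has one entry per stored state, which is one more than the number of complete transitions in the buffer. That appears to be why both sampling paths drop the trailing entry (`weights[1:end-1]` above, `valids[1:end-1]` here) before weighting the `1:length(t)` transitions or scaling the priority sumtree.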

src/traces.jl

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,7 @@ Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),)
 Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)
 Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)
 
-@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!
+@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!, Base.eltype
 
 #By default, AbstractTrace have infinity capacity (like a Vector). This method is specialized for
 #CircularArraySARTSTraces in common.jl. The functions below are made that way to avoid type piracy.
@@ -94,6 +94,7 @@ Base.getindex(s::RelativeTrace{0,-1}, I) = getindex(s.trace, I)
 Base.getindex(s::RelativeTrace{1,0}, I) = getindex(s.trace, I .+ 1)
 Base.setindex!(s::RelativeTrace{0,-1}, v, I) = setindex!(s.trace, v, I)
 Base.setindex!(s::RelativeTrace{1,0}, v, I) = setindex!(s.trace, v, I .+ 1)
+Base.eltype(t::RelativeTrace) = eltype(t.trace)
 capacity(t::RelativeTrace) = capacity(t.trace)
 
 """

test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -1,9 +1,9 @@
 using ReinforcementLearningTrajectories
 using CircularArrayBuffers, DataStructures
 using Test
+import ReinforcementLearningTrajectories.StatsBase.sample
 using CUDA
 using Adapt
-import ReinforcementLearningTrajectories.StatsBase.sample
 
 struct TestAdaptor end
test/samplers.jl

Lines changed: 66 additions & 121 deletions
@@ -1,4 +1,5 @@
-@testset "Samplers" begin
+import ReinforcementLearningTrajectories.fetch
+#@testset "Samplers" begin
 @testset "BatchSampler" begin
     sz = 32
     s = BatchSampler(sz)
@@ -74,132 +75,76 @@
 
 #! format: off
 @testset "NStepSampler" begin
-    γ = 0.9
+    γ = 0.99
     n_stack = 2
     n_horizon = 3
-    batch_size = 4
-
-    t1 = MultiplexTraces{(:state, :next_state)}(1:10) +
-        MultiplexTraces{(:action, :next_action)}(iseven.(1:10)) +
-        Traces(
-            reward=1:9,
-            terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
-        )
-
-    s1 = NStepBatchSampler(n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
-
-    xs = RLTrajectories.StatsBase.sample(s1, t1)
-
-    @test size(xs.state) == (n_stack, batch_size)
-    @test size(xs.next_state) == (n_stack, batch_size)
-    @test size(xs.action) == (batch_size,)
-    @test size(xs.reward) == (batch_size,)
-    @test size(xs.terminal) == (batch_size,)
-
-
-    state_size = (2,3)
-    n_state = reduce(*, state_size)
-    total_length = 10
-    t2 = MultiplexTraces{(:state, :next_state)}(
-        reshape(1:n_state * total_length, state_size..., total_length)
-    ) +
-    MultiplexTraces{(:action, :next_action)}(iseven.(1:total_length)) +
-    Traces(
-        reward=1:total_length-1,
-        terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
-    )
-
-    xs2 = RLTrajectories.StatsBase.sample(s1, t2)
-
-    @test size(xs2.state) == (state_size..., n_stack, batch_size)
-    @test size(xs2.next_state) == (state_size..., n_stack, batch_size)
-    @test size(xs2.action) == (batch_size,)
-    @test size(xs2.reward) == (batch_size,)
-    @test size(xs2.terminal) == (batch_size,)
-
-    inds = [3, 5, 7]
-    xs3 = RLTrajectories.StatsBase.sample(s1, t2, Val(SS′ART), inds)
-
-    @test xs3.state == cat(
-        (
-            reshape(n_state * (i-n_stack)+1: n_state * i, state_size..., n_stack)
-            for i in inds
-        )...
-        ;dims=length(state_size) + 2
-    )
-
-    @test xs3.next_state == xs3.state .+ (n_state * n_horizon)
-    @test xs3.action == iseven.(inds)
-    @test xs3.terminal == [any(t2[:terminal][i: i+n_horizon-1]) for i in inds]
-
-    # manual calculation
-    @test xs3.reward[1] ≈ 3 + γ * 4 # terminated at step 4
-    @test xs3.reward[2] ≈ 5 + γ * (6 + γ * 7)
-    @test xs3.reward[3] ≈ 7 + γ * (8 + γ * 9)
-end
-#! format: on
-
-@testset "Trajectory with CircularPrioritizedTraces and NStepBatchSampler" begin
-    n=1
-    γ=0.99f0
-
-    t = Trajectory(
-        container=CircularPrioritizedTraces(
-            CircularArraySARTSTraces(
-                capacity=5,
-                state=Float32 => (4,),
-            );
-            default_priority=100.0f0
-        ),
-        sampler=NStepBatchSampler{SS′ART}(
-            n=n,
-            γ=γ,
-            batch_size=32,
-        ),
-        controller=InsertSampleRatioController(
-            threshold=100,
-            n_inserted=-1
-        )
-    )
+    batch_size = 1000
+    eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
+    s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
 
-    push!(t, (state = 1, action = true))
-    for i = 1:9
-        push!(t, (state = i+1, action = true, reward = i, terminal = false))
+    push!(eb, (state = 1, action = 1))
+    for i = 1:5
+        push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
     end
-
-    b = RLTrajectories.StatsBase.sample(t)
-    @test haskey(b, :priority)
-    @test sum(b.action .== 0) == 0
-end
-
-
-@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
-    n=1
-    γ=0.99f0
-
-    t = Trajectory(
-        container=CircularArraySARTSTraces(
-            capacity=5,
-            state=Float32 => (4,),
-        ),
-        sampler=NStepBatchSampler{SS′ART}(
-            n=n,
-            γ=γ,
-            batch_size=32,
-        ),
-        controller=InsertSampleRatioController(
-            threshold=100,
-            n_inserted=-1
-        )
-    )
-
-    push!(t, (state = 1, action = true))
-    for i = 1:9
-        push!(t, (state = i+1, action = true, reward = i, terminal = false))
+    push!(eb, (state = 7, action = 7))
+    for (j,i) = enumerate(8:11)
+        push!(eb, (state = i, action =i, reward = i-1, terminal = false))
+    end
+    weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
+    @test weights == [0,1,1,1,1,0,0,1,1,1,0]
+    @test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode
+    inds = [i for i in eachindex(weights) if weights[i] == 1]
+    batch = sample(s1, eb)
+    for key in keys(eb)
+        @test haskey(batch, key)
     end
+    #state: samples with stack_size
+    states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
+    @test states == [1 2 3 4 7 8 9;
+                     2 3 4 5 8 9 10]
+    @test all(in(eachcol(states)), unique(eachcol(batch[:state])))
+    #next_state: samples with stack_size and nsteps forward
+    next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
+    @test next_states == [4 5 5 5 10 10 10;
+                          5 6 6 6 11 11 11]
+    @test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
+    #action: samples normally
+    actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
+    @test actions == inds
+    @test all(in(actions), unique(batch[:action]))
+    #next_action: is a multiplex trace: should automatically sample nsteps forward
+    next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
+    @test next_actions == [5, 6, 6, 6, 11, 11, 11]
+    @test all(in(next_actions), unique(batch[:next_action]))
+    #reward: discounted sum
+    rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
+    @test rewards ≈ [2+0.99*3+0.99^2*4, 3+0.99*4+0.99^2*5, 4+0.99*5, 5, 8+0.99*9+0.99^2*10, 9+0.99*10, 10]
+    @test all(in(rewards), unique(batch[:reward]))
+    #terminal: nsteps forward
+    terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
+    @test terminals == [0,1,1,1,0,0,0]
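Reading the expected values: the buffer holds a finished episode over states 1 to 6 and an ongoing one over states 7 to 11. Indices 1 and 7 get weight 0 because their step number is below `stack_size = 2`; indices 6 and 11 get weight 0 because the most recent state of each episode is not sampleable. `ns` then caps the 3-step horizon by the steps actually remaining in each episode, which is also why `next_states` saturates at the episode-final states (5/6 and 10/11).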
+
+    ### CircularPrioritizedTraces and NStepBatchSampler
+    γ = 0.99
+    n_horizon = 3
+    batch_size = 4
+    eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
+    s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batch_size=batch_size)
 
-    b = RLTrajectories.StatsBase.sample(t)
-    @test sum(b.action .== 0) == 0
+    push!(eb, (state = 1, action = 1))
+    for i = 1:5
+        push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
+    end
+    push!(eb, (state = 7, action = 7))
+    for (j,i) = enumerate(8:11)
+        push!(eb, (state = i, action =i, reward = i-1, terminal = false))
+    end
+    weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
+    inds = [i for i in eachindex(weights) if weights[i] == 1]
+    batch = sample(s1, eb)
+    for key in (keys(eb)..., :key, :priority)
+        @test haskey(batch, key)
+    end
 end
 
 @testset "EpisodesSampler" begin
