Skip to content

Multi step sampler #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ for f in (:push!, :pushfirst!, :append!, :prepend!)
end

function StatsBase.sample(s::BatchSampler, nt::NormalizedTraces, names, weights = StatsBase.UnitWeights{Int}(length(nt)))
inds = StatsBase.sample(s.rng, 1:length(nt), weights, s.batch_size)
inds = StatsBase.sample(s.rng, 1:length(nt), weights, s.batchsize)
maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data
NamedTuple{names}(collect(maybe_normalize(nt[x][inds], x)) for x in names)
end
Expand Down
132 changes: 102 additions & 30 deletions src/samplers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Random
export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler
export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler, MultiStepSampler

struct SampleGenerator{S,T}
sampler::S
Expand Down Expand Up @@ -29,27 +29,27 @@ StatsBase.sample(::DummySampler, t) = t
export BatchSampler

struct BatchSampler{names}
batch_size::Int
batchsize::Int
rng::Random.AbstractRNG
end

"""
BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG)
BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
BatchSampler{names}(;batchsize, rng=Random.GLOBAL_RNG)
BatchSampler{names}(batchsize ;rng=Random.GLOBAL_RNG)

Uniformly sample **ONE** batch of `batch_size` examples for each trace specified
Uniformly sample **ONE** batch of `batchsize` examples for each trace specified
in `names`. If `names` is not set, all the traces will be sampled.
"""
BatchSampler(batch_size; kw...) = BatchSampler(; batch_size=batch_size, kw...)
BatchSampler(batchsize; kw...) = BatchSampler(; batchsize=batchsize, kw...)
BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
BatchSampler{names}(batch_size; kw...) where {names} = BatchSampler{names}(; batch_size=batch_size, kw...)
BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batch_size, rng)
BatchSampler{names}(batchsize; kw...) where {names} = BatchSampler{names}(; batchsize=batchsize, kw...)
BatchSampler{names}(; batchsize, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batchsize, rng)

StatsBase.sample(s::BatchSampler{nothing}, t::AbstractTraces) = StatsBase.sample(s, t, keys(t))
StatsBase.sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = StatsBase.sample(s, t, names)

function StatsBase.sample(s::BatchSampler, t::AbstractTraces, names, weights = StatsBase.UnitWeights{Int}(length(t)))
inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batch_size)
inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batchsize)
NamedTuple{names}(map(x -> collect(t[Val(x)][inds]), names))
end

Expand All @@ -75,12 +75,12 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
p = collect(deepcopy(t.priorities))
w = StatsBase.FrequencyWeights(p)
w .*= e.sampleable_inds[1:end-1]
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batch_size)
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
end

function StatsBase.sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
inds, priorities = rand(s.rng, t.priorities, s.batchsize)
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
end

Expand Down Expand Up @@ -165,41 +165,42 @@ end
export NStepBatchSampler

"""
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)

NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.GLOBAL_RNG)

Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
that in up to `n > 1` steps later in the buffer. The reward will be
the discounted sum of the `n` rewards, with `γ` as the discount factor.

NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size` is set
to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
of partial observability, for example when the state is approximated by `stack_size` consecutive
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
of partial observability, for example when the state is approximated by `stacksize` consecutive
frames.
"""
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}}
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
n::Int # !!! n starts from 1
γ::Float32
batch_size::Int
stack_size::S
rng::Any
batchsize::Int
stacksize::S
rng::R
end

NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
function NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names}
function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
@assert n >= 1 "n must be ≥ 1."
ss = stack_size == 1 ? nothing : stack_size
NStepBatchSampler{names, typeof(ss)}(n, γ, batch_size, ss, rng)
ss = stacksize == 1 ? nothing : stacksize
NStepBatchSampler{names, typeof(ss), typeof(rng)}(n, γ, batchsize, ss, rng)
end

#return a boolean vector of the valid sample indices given the stack_size and the truncated n for each index.
#return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
range = copy(eb.sampleable_inds)
ns = Vector{Int}(undef, length(eb.sampleable_inds))
stack_size = isnothing(s.stack_size) ? 1 : s.stack_size
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
for idx in eachindex(range)
step_number = eb.step_numbers[idx]
range[idx] = step_number >= stack_size && eb.sampleable_inds[idx]
range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
ns[idx] = min(s.n, eb.episodes_lengths[idx] - step_number + 1)
end
return range, ns
Expand All @@ -211,19 +212,19 @@ end

function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, ::Val{names}) where names
weights, ns = valid_range(s, t)
inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batch_size)
inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batchsize)
fetch(s, t, Val(names), inds, ns)
end

function fetch(s::NStepBatchSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names))
end

#state and next_state have specialized fetch methods due to stack_size
#state and next_state have specialized fetch methods due to stacksize
fetch(::NStepBatchSampler{names, Nothing}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[inds]
fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stack_size+1:0, x in inds]]
fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stacksize+1:0, x in inds]]
fetch(::NStepBatchSampler{names, Nothing}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[inds .+ ns .- 1]
fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stack_size+1:0, (idx,x) in enumerate(inds)]]
fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stacksize+1:0, (idx,x) in enumerate(inds)]]

#reward due to discounting
function fetch(s::NStepBatchSampler, trace::AbstractTrace, ::Val{:reward}, inds, ns)
Expand All @@ -247,7 +248,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
w = StatsBase.FrequencyWeights(p)
valids, ns = valid_range(s,e)
w .*= valids[1:end-1]
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batch_size)
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
merge(
(key=t.keys[inds], priority=p[inds]),
fetch(s, e, Val(names), inds, ns)
Expand Down Expand Up @@ -297,3 +298,74 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)

return [make_episode(t, r, names) for r in ranges]
end

#####MultiStepSampler

"""
    MultiStepSampler{names}(; n, batchsize=32, stacksize=nothing, rng=Random.default_rng())

Sampler that fetches steps `[x, x+1, ..., x + n - 1]` for each trace of each sampled index
`x`. The samples are returned in an array of `batchsize` elements. For each element, `n` is
truncated by the end of its episode. This means that the dimensions of each sample are not
necessarily the same.
"""
struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
    n::Int          # number of consecutive steps fetched per sampled index (n ≥ 1)
    batchsize::Int  # number of indices drawn per batch
    # `nothing` when no frame stacking is used; otherwise the stack depth.
    # Must be `S` (not `Int`): the keyword constructor passes `nothing` by
    # default, and the `fetch` methods dispatch on this type parameter.
    stacksize::S
    rng::R          # RNG used for sampling, for reproducibility
end

# Convenience constructor: infer the trace names from the traces themselves.
MultiStepSampler(t::AbstractTraces; kw...) = MultiStepSampler{keys(t)}(; kw...)

"""
    MultiStepSampler{names}(; n, batchsize=32, stacksize=nothing, rng=Random.default_rng())

Keyword constructor. `n` is the (maximum) number of consecutive steps fetched per
sampled index and must be ≥ 1; `stacksize` enables frame stacking when set to an
integer > 1. Throws an `ArgumentError` if `n < 1`.
"""
function MultiStepSampler{names}(; n, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
    # Validate user input with a real exception: `@assert` may be disabled at
    # higher optimization levels and is not meant for argument checking.
    n >= 1 || throw(ArgumentError("n must be ≥ 1, got $n."))
    # A stack of depth 1 is equivalent to no stacking; normalize to `nothing`
    # so dispatch selects the non-stacking `fetch` methods.
    ss = stacksize == 1 ? nothing : stacksize
    MultiStepSampler{names, typeof(ss), typeof(rng)}(n, batchsize, ss, rng)
end

# Return `(valid, horizons)` for `eb`: `valid[i]` is true when index `i` can start
# a sample (it is sampleable and has at least `stacksize` steps of history), and
# `horizons[i]` is `n` truncated by the end of the episode containing index `i`.
function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
    # Minimum step number required so a full stack of frames exists.
    min_step = isnothing(s.stacksize) ? 1 : s.stacksize
    valid = copy(eb.sampleable_inds)
    horizons = Vector{Int}(undef, length(valid))
    for i in eachindex(valid)
        step = eb.step_numbers[i]
        valid[i] = eb.sampleable_inds[i] && step >= min_step
        horizons[i] = min(s.n, eb.episodes_lengths[i] - step + 1)
    end
    return valid, horizons
end

# Lift the trace names out of the sampler's type parameter and forward to the
# `Val`-dispatched sampling method.
StatsBase.sample(s::MultiStepSampler{names}, ts) where {names} = StatsBase.sample(s, ts, Val(names))

# Sample `batchsize` starting indices uniformly among the valid ones, then fetch
# the multi-step slices of every trace at those indices.
function StatsBase.sample(s::MultiStepSampler, t::EpisodesBuffer, ::Val{names}) where names
    valid, ns = valid_range(s, t)
    # Only indices flagged valid get nonzero weight; the last entry is excluded
    # to match the sampleable range of the buffer.
    freq = StatsBase.FrequencyWeights(valid[1:end-1])
    chosen = StatsBase.sample(s.rng, 1:length(t), freq, s.batchsize)
    fetch(s, t, Val(names), chosen, ns)
end

# Fetch every requested trace at the sampled indices and assemble the results
# into a NamedTuple keyed by trace name.
function fetch(s::MultiStepSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
    horizons = ns[inds]
    vals = map(names) do name
        collect(fetch(s, ts[name], Val(name), inds, horizons))
    end
    NamedTuple{names}(vals)
end

# Default fetch: one contiguous slice of length `ns[i]` per sampled index,
# starting at the index itself.
function fetch(::MultiStepSampler, trace, ::Val, inds, ns)
    map(enumerate(inds)) do (i, start)
        stop = start + ns[i] - 1
        trace[start:stop]
    end
end

# Specialization for frame-stacked states (`stacksize isa Int`): for each sampled
# index `idx`, build a (stacksize × ns[j]) matrix of trace indices — column `n`
# holds the `stacksize` consecutive frames ending at step `idx + n - 1` — and
# index the trace with it, yielding one stacked sample per multi-step element.
function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
    [trace[[idx + i + n - 1 for i in -s.stacksize+1:0, n in 1:ns[j]]] for (j,idx) in enumerate(inds)]
end

# Prioritized variant: draw indices proportionally to the stored priorities,
# masked in place by `valid_range` so only indices that can start a valid
# multi-step sample have nonzero weight. Returns the sampled keys and priorities
# merged with the fetched traces.
function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
    t = e.traces
    # NOTE(review): `collect` already materializes a fresh vector here; the inner
    # `deepcopy` looks redundant — confirm `t.priorities` needs it before simplifying.
    p = collect(deepcopy(t.priorities))
    w = StatsBase.FrequencyWeights(p)
    valids, ns = valid_range(s,e)
    # Zero out weights of indices that cannot start a valid sample.
    w .*= valids[1:end-1]
    inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
    merge(
        (key=t.keys[inds], priority=p[inds]),
        fetch(s, e, Val(names), inds, ns)
    )
end
56 changes: 50 additions & 6 deletions test/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ import ReinforcementLearningTrajectories.fetch
γ = 0.99
n_stack = 2
n_horizon = 3
batch_size = 1000
batchsize = 1000
eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stacksize=n_stack, batchsize=batchsize)

push!(eb, (state = 1, action = 1))
for i = 1:5
Expand All @@ -98,12 +98,12 @@ import ReinforcementLearningTrajectories.fetch
for key in keys(eb)
@test haskey(batch, key)
end
#state: samples with stack_size
#state: samples with stacksize
states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
@test states == [1 2 3 4 7 8 9;
2 3 4 5 8 9 10]
@test all(in(eachcol(states)), unique(eachcol(batch[:state])))
#next_state: samples with stack_size and nsteps forward
#next_state: samples with stacksize and nsteps forward
next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
@test next_states == [4 5 5 5 10 10 10;
5 6 6 6 11 11 11]
Expand All @@ -127,9 +127,9 @@ import ReinforcementLearningTrajectories.fetch
### CircularPrioritizedTraces and NStepBatchSampler
γ = 0.99
n_horizon = 3
batch_size = 4
batchsize = 4
eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batch_size=batch_size)
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize)

push!(eb, (state = 1, action = 1))
for i = 1:5
Expand Down Expand Up @@ -196,4 +196,48 @@ import ReinforcementLearningTrajectories.fetch
@test length(b[2][:state]) == 5
@test !haskey(b[1], :action)
end
# Integration test for MultiStepSampler: checks valid_range masking, the stacked
# state/next_state fetches, and the plain multi-step fetches of the other traces.
@testset "MultiStepSampler" begin
    n_stack = 2
    n_horizon = 3
    batchsize = 1000
    eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
    s1 = MultiStepSampler(eb, n=n_horizon, stacksize=n_stack, batchsize=batchsize)

    # Two episodes: states 1–6 (terminates at step 5) and states 7–11 (ongoing).
    push!(eb, (state = 1, action = 1))
    for i = 1:5
        push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
    end
    push!(eb, (state = 7, action = 7))
    for (j,i) = enumerate(8:11)
        push!(eb, (state = i, action =i, reward = i-1, terminal = false))
    end
    # Index 1 of each episode is excluded: not enough history for a 2-frame stack.
    weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
    @test weights == [0,1,1,1,1,0,0,1,1,1,0]
    @test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode
    inds = [i for i in eachindex(weights) if weights[i] == 1]
    batch = sample(s1, eb)
    for key in keys(eb)
        @test haskey(batch, key)
    end
    #state and next_state: samples with stacksize
    # Each element is a (stacksize × horizon) matrix of stacked frames.
    states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
    @test states == [[1 2 3; 2 3 4], [2 3 4; 3 4 5], [3 4; 4 5], [4; 5;;], [7 8 9; 8 9 10], [8 9; 9 10], [9; 10;;]]
    @test all(in(states), batch[:state])
    #next_state: samples with stacksize and nsteps forward
    next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
    @test next_states == [[2 3 4; 3 4 5], [3 4 5; 4 5 6], [4 5; 5 6], [5; 6;;], [8 9 10; 9 10 11], [9 10; 10 11], [10; 11;;]]
    @test all(in(next_states), batch[:next_state])
    #all other traces sample normally
    # Each element is a vector of the horizon's consecutive values.
    actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
    @test actions == [[2,3,4], [3,4,5], [4,5], [5], [8,9,10], [9,10],[10]]
    @test all(in(actions), batch[:action])
    next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
    @test next_actions == [a .+ 1 for a in [[2,3,4], [3,4,5], [4,5], [5], [8,9,10], [9,10],[10]]]
    @test all(in(next_actions), batch[:next_action])
    rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
    @test rewards == actions
    @test all(in(rewards), batch[:reward])
    terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
    @test terminals == [[a == 5 ? 1 : 0 for a in acs] for acs in actions]
end
end