Multi step sampler #58

Merged 6 commits on Sep 14, 2023
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
 name = "ReinforcementLearningTrajectories"
 uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
-version = "0.3.4"
+version = "0.3.5"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2 changes: 1 addition & 1 deletion src/common/sum_tree.jl
@@ -184,7 +184,7 @@ Random.rand(rng::AbstractRNG, t::SumTree{T}) where {T} = get(t, rand(rng, T) * t
 Random.rand(t::SumTree) = rand(Random.GLOBAL_RNG, t)

 function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
-    inds, priorities = Vector{Int}(undef, n), Vector{Float64}(undef, n)
+    inds, priorities = Vector{Int}(undef, n), Vector{T}(undef, n)
     for i in 1:n
         v = (i - 1 + rand(rng, T)) / n
         ind, p = get(t, v * t.tree[1])
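The loop above performs stratified sampling: draw `i` falls uniformly within the `i`-th of `n` equal slices of the total priority mass, which lowers variance compared with `n` independent draws. A minimal sketch of the same idea over a plain priority vector (the `stratified_sample` helper is hypothetical; the real `SumTree` resolves each lookup in O(log n)):

using Random

# One uniform draw per equal-width slice of the cumulative priority mass.
function stratified_sample(rng::AbstractRNG, priorities::Vector{Float64}, n::Int)
    cum = cumsum(priorities)
    inds = Vector{Int}(undef, n)
    for i in 1:n
        v = (i - 1 + rand(rng)) / n            # uniform within slice i
        inds[i] = searchsortedfirst(cum, v * cum[end])
    end
    return inds
end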
2 changes: 1 addition & 1 deletion src/normalization.jl
@@ -192,7 +192,7 @@ for f in (:push!, :pushfirst!, :append!, :prepend!)
 end

 function StatsBase.sample(s::BatchSampler, nt::NormalizedTraces, names, weights = StatsBase.UnitWeights{Int}(length(nt)))
-    inds = StatsBase.sample(s.rng, 1:length(nt), weights, s.batch_size)
+    inds = StatsBase.sample(s.rng, 1:length(nt), weights, s.batchsize)
     maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data
     NamedTuple{names}(collect(maybe_normalize(nt[x][inds], x)) for x in names)
 end
220 changes: 155 additions & 65 deletions src/samplers.jl
@@ -1,5 +1,5 @@
 using Random
-export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler
+export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler, MultiStepSampler

 struct SampleGenerator{S,T}
     sampler::S
@@ -29,27 +29,27 @@ StatsBase.sample(::DummySampler, t) = t
 export BatchSampler

 struct BatchSampler{names}
-    batch_size::Int
+    batchsize::Int
     rng::Random.AbstractRNG
 end

 """
-    BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG)
-    BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
+    BatchSampler{names}(;batchsize, rng=Random.GLOBAL_RNG)
+    BatchSampler{names}(batchsize ;rng=Random.GLOBAL_RNG)

-Uniformly sample **ONE** batch of `batch_size` examples for each trace specified
+Uniformly sample **ONE** batch of `batchsize` examples for each trace specified
 in `names`. If `names` is not set, all the traces will be sampled.
 """
-BatchSampler(batch_size; kw...) = BatchSampler(; batch_size=batch_size, kw...)
+BatchSampler(batchsize; kw...) = BatchSampler(; batchsize=batchsize, kw...)
 BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
-BatchSampler{names}(batch_size; kw...) where {names} = BatchSampler{names}(; batch_size=batch_size, kw...)
-BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batch_size, rng)
+BatchSampler{names}(batchsize; kw...) where {names} = BatchSampler{names}(; batchsize=batchsize, kw...)
+BatchSampler{names}(; batchsize, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batchsize, rng)

 StatsBase.sample(s::BatchSampler{nothing}, t::AbstractTraces) = StatsBase.sample(s, t, keys(t))
 StatsBase.sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = StatsBase.sample(s, t, names)

 function StatsBase.sample(s::BatchSampler, t::AbstractTraces, names, weights = StatsBase.UnitWeights{Int}(length(t)))
-    inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batch_size)
+    inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batchsize)
     NamedTuple{names}(map(x -> collect(t[Val(x)][inds]), names))
 end
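For orientation, a short usage sketch of the renamed `batchsize` keyword; the trace contents below are illustrative, not part of this diff:

using ReinforcementLearningTrajectories, Random
import StatsBase: sample

# Assumed keyword construction of Traces; any AbstractTraces with these keys works.
t = Traces(state = Float32[1, 2, 3, 4, 5], reward = Float32[0, 0, 1, 0, 1])
s = BatchSampler{(:state, :reward)}(batchsize = 2, rng = Random.default_rng())
batch = sample(s, t)   # NamedTuple with 2 uniformly drawn entries per trace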

@@ -72,14 +72,15 @@ StatsBase.sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = Stats

 function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}, names)
     t = e.traces
-    st = deepcopy(t.priorities)
-    st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
-    inds, priorities = rand(s.rng, st, s.batch_size)
-    NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
+    p = collect(deepcopy(t.priorities))
+    w = StatsBase.FrequencyWeights(p)
+    w .*= e.sampleable_inds[1:end-1]
+    inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
+    NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
 end

 function StatsBase.sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
-    inds, priorities = rand(s.rng, t.priorities, s.batch_size)
+    inds, priorities = rand(s.rng, t.priorities, s.batchsize)
     NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
 end
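The refactor above replaces the temporary sum tree with plain `StatsBase.FrequencyWeights`: priorities become sampling weights, and multiplying by the 0/1 `sampleable_inds` mask zeroes the weight of indices that must not be drawn. The trick in isolation, with toy values:

using StatsBase, Random

priorities = [0.5, 2.0, 1.0, 4.0]
valid      = [1, 1, 0, 1]                  # e.g. sampleable_inds
w = FrequencyWeights(priorities .* valid)  # index 3 gets weight 0
inds = sample(Random.default_rng(), eachindex(w), w, 2)   # never draws 3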

@@ -163,75 +164,93 @@ end

 export NStepBatchSampler

-mutable struct NStepBatchSampler{traces}
+"""
+    NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng())
+
+Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
+The "next" element of multiplexed traces (such as `:next_state` or `:next_action`) will be the
+one up to `n` steps later in the buffer. The reward will be the discounted sum of the `n`
+intervening rewards, with `γ` as the discount factor.
+
+NStepBatchSampler may also be used to sample a "stack" of states if `stacksize` is set to an
+integer > 1: the `stacksize - 1` previous states are included with each sampled state. This is
+useful under partial observability, for example when the state is approximated by `stacksize`
+consecutive frames.
+"""
+mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
     n::Int # !!! n starts from 1
     γ::Float32
-    batch_size::Int
-    stack_size::Union{Nothing,Int}
-    rng::Any
+    batchsize::Int
+    stacksize::S
+    rng::R
 end

-NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...)
-NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)
+NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
+function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
+    @assert n >= 1 "n must be ≥ 1."
+    ss = stacksize == 1 ? nothing : stacksize
+    NStepBatchSampler{names, typeof(ss), typeof(rng)}(n, γ, batchsize, ss, rng)
+end
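A hedged usage sketch of the new trace-based constructor; the buffer setup follows this package's test conventions, and the exact `CircularArraySARTSTraces` signature is assumed:

using ReinforcementLearningTrajectories
import StatsBase: sample

# Assumed setup: a SARTS circular buffer wrapped in an EpisodesBuffer.
eb = EpisodesBuffer(CircularArraySARTSTraces(capacity = 100))
# ... push episodes into `eb` ...

# :next_state is fetched up to 3 steps ahead; :reward is the γ-discounted
# sum of the (up to) 3 intervening rewards.
s = NStepBatchSampler(eb.traces; n = 3, γ = 0.99f0, batchsize = 32)
# batch = sample(s, eb)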

-function valid_range_nbatchsampler(s::NStepBatchSampler, ts)
-    # think about the extreme case where s.stack_size == 1 and s.n == 1
-    isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))
+# Return a boolean vector of the valid sample indices given the stacksize, together with
+# the truncated n for each index.
+function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
+    range = copy(eb.sampleable_inds)
+    ns = Vector{Int}(undef, length(eb.sampleable_inds))
+    stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
+    for idx in eachindex(range)
+        step_number = eb.step_numbers[idx]
+        range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
+        ns[idx] = min(s.n, eb.episodes_lengths[idx] - step_number + 1)
+    end
+    return range, ns
+end
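Concretely, an index at step `k` of an episode of length `L` gets the truncated horizon `min(n, L - k + 1)` and only becomes valid once `stacksize` frames exist. A standalone check with assumed values (ignoring the `sampleable_inds` mask):

n, stacksize, L = 3, 2, 5
ns    = [min(n, L - k + 1) for k in 1:L]   # [3, 3, 3, 2, 1]
valid = [k >= stacksize for k in 1:L]      # [false, true, true, true, true]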

 function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names}
-    valid_range = valid_range_nbatchsampler(s, ts)
-    inds = rand(s.rng, valid_range, s.batch_size)
-    StatsBase.sample(s, ts, Val(names), inds)
+    StatsBase.sample(s, ts, Val(names))
 end

-function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names}
-    valid_range = valid_range_nbatchsampler(s, ts)
-    valid_range = valid_range[valid_range .∈ (findall(ts.sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
-    inds = rand(s.rng, valid_range, s.batch_size)
-    StatsBase.sample(s, ts, Val(names), inds)
+function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, ::Val{names}) where names
+    weights, ns = valid_range(s, t)
+    inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batchsize)
+    fetch(s, t, Val(names), inds, ns)
 end

-function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
-    if isnothing(nbs.stack_size)
-        s = ts[:state][inds]
-        s′ = ts[:next_state][inds.+(nbs.n-1)]
-    else
-        s = ts[:state][[x + i for i in -nbs.stack_size+1:0, x in inds]]
-        s′ = ts[:next_state][[x + nbs.n - 1 + i for i in -nbs.stack_size+1:0, x in inds]]
-    end
-
-    a = ts[:action][inds]
-    t_horizon = ts[:terminal][[x + j for j in 0:nbs.n-1, x in inds]]
-    r_horizon = ts[:reward][[x + j for j in 0:nbs.n-1, x in inds]]
-
-    @assert ndims(t_horizon) == 2
-    t = any(t_horizon, dims=1) |> vec
-
-    @assert ndims(r_horizon) == 2
-    r = map(eachcol(r_horizon), eachcol(t_horizon)) do r⃗, t⃗
-        foldr(((rr, tt), init) -> rr + nbs.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
-    end
-
-    NamedTuple{SS′ART}(map(collect, (s, s′, a, r, t)))
+function fetch(s::NStepBatchSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
+    NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names))
 end

-function StatsBase.sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
-    s, s′, a, r, t = StatsBase.sample(s, ts, Val(SSART), inds)
-    l = consecutive_view(ts[:next_legal_actions_mask], inds)
-    NamedTuple{SSLART}(map(collect, (s, s′, l, a, r, t)))
-end
+# state and next_state have specialized fetch methods due to stacksize
+fetch(::NStepBatchSampler{names, Nothing}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[inds]
+fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stacksize+1:0, x in inds]]
+fetch(::NStepBatchSampler{names, Nothing}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[inds .+ ns .- 1]
+fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stacksize+1:0, (idx, x) in enumerate(inds)]]
+
+# reward needs discounting
+function fetch(s::NStepBatchSampler, trace::AbstractTrace, ::Val{:reward}, inds, ns)
+    rewards = Vector{eltype(trace)}(undef, length(inds))
+    for (i, idx) in enumerate(inds)
+        rewards_to_go = trace[idx:idx+ns[i]-1]
+        rewards[i] = foldr((x, y) -> x + s.γ * y, rewards_to_go)
+    end
+    return rewards
+end
+# terminal is that of the nth step
+fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val{:terminal}, inds, ns) = trace[inds .+ ns .- 1]
+# right multiplexed traces must be n-step sampled
+fetch(::NStepBatchSampler, trace::RelativeTrace{1,0}, ::Val, inds, ns) = trace[inds .+ ns .- 1]
+# other trace types are fetched at inds
+fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val, inds, ns) = trace[inds]
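The `foldr` in the reward method right-associates the discounted sum, i.e. `r1 + γ*(r2 + γ*(r3 + ...))`. A quick standalone check:

γ = 0.9f0
rewards_to_go = Float32[1, 2, 3]
r = foldr((x, y) -> x + γ * y, rewards_to_go)
r ≈ 1 + γ * (2 + γ * 3)   # true; evaluates to 5.23f0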

 function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
     t = e.traces
-    st = deepcopy(t.priorities)
-    st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
-    inds, priorities = rand(s.rng, st, s.batch_size)
-
+    p = collect(deepcopy(t.priorities))
+    w = StatsBase.FrequencyWeights(p)
+    valids, ns = valid_range(s, e)
+    w .*= valids[1:end-1]
+    inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
     merge(
-        (key=t.keys[inds], priority=priorities),
-        StatsBase.sample(s, t.traces, Val(names), inds)
+        (key=t.keys[inds], priority=p[inds]),
+        fetch(s, e, Val(names), inds, ns)
     )
 end

@@ -278,3 +297,74 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)

     return [make_episode(t, r, names) for r in ranges]
 end
+
+##### MultiStepSampler
+
+"""
+    MultiStepSampler{names}(; n, batchsize=32, stacksize=nothing, rng=Random.default_rng())
+
+Sampler that fetches steps `[x, x+1, ..., x + n - 1]` for each trace of each sampled index
+`x`. The samples are returned in a vector of `batchsize` elements. For each element, `n` is
+truncated by the end of its episode, so the sampled windows may differ in length.
+"""
+struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
+    n::Int
+    batchsize::Int
+    stacksize::S
+    rng::R
+end
+
+MultiStepSampler(t::AbstractTraces; kw...) = MultiStepSampler{keys(t)}(; kw...)
+function MultiStepSampler{names}(; n, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
+    @assert n >= 1 "n must be ≥ 1."
+    ss = stacksize == 1 ? nothing : stacksize
+    MultiStepSampler{names, typeof(ss), typeof(rng)}(n, batchsize, ss, rng)
+end
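A usage sketch mirroring the `NStepBatchSampler` one above, reusing the assumed `eb` buffer:

# Each element is the full window [x, ..., x + n - 1], truncated at the
# episode boundary, so windows within a batch can differ in length.
s = MultiStepSampler(eb.traces; n = 10, batchsize = 16)
# batch = sample(s, eb)   # batch.state is a Vector of variable-length windows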

+function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
+    range = copy(eb.sampleable_inds)
+    ns = Vector{Int}(undef, length(eb.sampleable_inds))
+    stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
+    for idx in eachindex(range)
+        step_number = eb.step_numbers[idx]
+        range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
+        ns[idx] = min(s.n, eb.episodes_lengths[idx] - step_number + 1)
+    end
+    return range, ns
+end

+function StatsBase.sample(s::MultiStepSampler{names}, ts) where {names}
+    StatsBase.sample(s, ts, Val(names))
+end
+
+function StatsBase.sample(s::MultiStepSampler, t::EpisodesBuffer, ::Val{names}) where names
+    weights, ns = valid_range(s, t)
+    inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batchsize)
+    fetch(s, t, Val(names), inds, ns)
+end
+
+function fetch(s::MultiStepSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
+    NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names))
+end
+
+function fetch(::MultiStepSampler, trace, ::Val, inds, ns)
+    [trace[idx:(idx + ns[i] - 1)] for (i, idx) in enumerate(inds)]
+end
+
+function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
+    [trace[[idx + i + n - 1 for i in -s.stacksize+1:0, n in 1:ns[j]]] for (j, idx) in enumerate(inds)]
+end
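With a stack, the comprehension above builds a `stacksize × ns[j]` index matrix per sampled index, so each state sample is a frames-by-steps block. Checking the index arithmetic in isolation with assumed values:

stacksize, idx, nsj = 4, 10, 2
M = [idx + i + n - 1 for i in -stacksize+1:0, n in 1:nsj]
# 4×2 Matrix{Int}: column 1 stacks frames 7:10, column 2 stacks frames 8:11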

+function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
+    t = e.traces
+    p = collect(deepcopy(t.priorities))
+    w = StatsBase.FrequencyWeights(p)
+    valids, ns = valid_range(s, e)
+    w .*= valids[1:end-1]
+    inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
+    merge(
+        (key=t.keys[inds], priority=p[inds]),
+        fetch(s, e, Val(names), inds, ns)
+    )
+end
3 changes: 2 additions & 1 deletion src/traces.jl
@@ -48,7 +48,7 @@ Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),)
 Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)
 Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)

-@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!
+@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!, Base.eltype

 # By default, AbstractTraces have infinite capacity (like a Vector). This method is specialized for
 # CircularArraySARTSTraces in common.jl. The functions below are written that way to avoid type piracy.
@@ -94,6 +94,7 @@ Base.getindex(s::RelativeTrace{0,-1}, I) = getindex(s.trace, I)
 Base.getindex(s::RelativeTrace{1,0}, I) = getindex(s.trace, I .+ 1)
 Base.setindex!(s::RelativeTrace{0,-1}, v, I) = setindex!(s.trace, v, I)
 Base.setindex!(s::RelativeTrace{1,0}, v, I) = setindex!(s.trace, v, I .+ 1)
+Base.eltype(t::RelativeTrace) = eltype(t.trace)
 capacity(t::RelativeTrace) = capacity(t.trace)

 """
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -2,6 +2,7 @@ using ReinforcementLearningTrajectories
 using CircularArrayBuffers, DataStructures
 using StableRNGs
 using Test
+import ReinforcementLearningTrajectories.StatsBase.sample
 using CUDA
 using Adapt
 using Random