NStepBatchSampler #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 6 commits on Sep 14, 2023
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ReinforcementLearningTrajectories"
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
version = "0.3.4"
version = "0.3.5"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2 changes: 1 addition & 1 deletion src/common/sum_tree.jl
@@ -184,7 +184,7 @@ Random.rand(rng::AbstractRNG, t::SumTree{T}) where {T} = get(t, rand(rng, T) * t
Random.rand(t::SumTree) = rand(Random.GLOBAL_RNG, t)

function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
inds, priorities = Vector{Int}(undef, n), Vector{Float64}(undef, n)
inds, priorities = Vector{Int}(undef, n), Vector{T}(undef, n)
for i in 1:n
v = (i - 1 + rand(rng, T)) / n
ind, p = get(t, v * t.tree[1])
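For illustration, a minimal sketch of the behavioural change above; the `SumTree(Float32, capacity)` constructor and `push!` are assumed from the rest of sum_tree.jl:

using Random
using ReinforcementLearningTrajectories: SumTree

t = SumTree(Float32, 8)                           # assumed constructor: element type, then capacity
foreach(p -> push!(t, p), (0.5f0, 1.5f0, 2.0f0))
inds, priorities = rand(Random.GLOBAL_RNG, t, 2)
eltype(priorities)                                # now Float32 (matching the tree); previously always Float64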
109 changes: 63 additions & 46 deletions src/samplers.jl
@@ -163,75 +163,92 @@ end

export NStepBatchSampler

mutable struct NStepBatchSampler{traces}
"""
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)

Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
The "next" element of multiplexed traces (such as `next_state` or `next_action`) will be the one
found up to `n` steps later in the buffer, and the reward will be the discounted sum of the `n`
intermediate rewards, with `γ` as the discount factor.

Independently of `n`, NStepBatchSampler can also sample a "stack" of states if `stack_size` is
set to an integer > 1: each sampled state is returned together with the `stack_size - 1` states
that precede it. This is useful in the case of partial observability, for example when the state
is approximated by `stack_size` consecutive frames.
"""
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}}
n::Int # !!! n starts from 1
γ::Float32
batch_size::Int
stack_size::Union{Nothing,Int}
stack_size::S
rng::Any
end

NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...)
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)

NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
function NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names}
@assert n >= 1 "n must be ≥ 1."
ss = stack_size == 1 ? nothing : stack_size
NStepBatchSampler{names, typeof(ss)}(n, γ, batch_size, ss, rng)
end
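A minimal usage sketch for the keyword constructor above, mirroring the tests added below (the toy buffer contents are illustrative only):

using ReinforcementLearningTrajectories
import ReinforcementLearningTrajectories.StatsBase.sample

eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
push!(eb, (state = 1, action = 1))
for i in 1:5
    push!(eb, (state = i + 1, action = i + 1, reward = i, terminal = i == 5))
end

s = NStepBatchSampler(eb; n=3, γ=0.99f0, batch_size=4)
batch = sample(s, eb)     # NamedTuple over keys(eb): state, next_state, action, next_action, reward, terminal
length(batch.reward)      # 4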

function valid_range_nbatchsampler(s::NStepBatchSampler, ts)
# think about the extreme case where s.stack_size == 1 and s.n == 1
isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))
#returns a boolean vector of the valid sample indices (respecting stack_size) and a vector of the n-step horizon at each index, truncated at episode boundaries.
function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
range = copy(eb.sampleable_inds)
ns = Vector{Int}(undef, length(eb.sampleable_inds))
stack_size = isnothing(s.stack_size) ? 1 : s.stack_size
for idx in eachindex(range)
step_number = eb.step_numbers[idx]
range[idx] = step_number >= stack_size && eb.sampleable_inds[idx]
ns[idx] = min(s.n, eb.episodes_lengths[idx] - step_number + 1)
end
return range, ns
end
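On the same toy buffer, a sketch of what `valid_range` returns; the exact vectors assume the episode bookkeeping of `EpisodesBuffer`:

using ReinforcementLearningTrajectories

eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
push!(eb, (state = 1, action = 1))
for i in 1:5
    push!(eb, (state = i + 1, action = i + 1, reward = i, terminal = i == 5))
end
s = NStepBatchSampler(eb; n=3, γ=0.99f0, batch_size=4)
weights, ns = ReinforcementLearningTrajectories.valid_range(s, eb)
# weights should flag indices 1:5 as sampleable and exclude the terminal state;
# ns should be [3, 3, 3, 2, 1, 0]: the horizon shrinks as the episode end approaches.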

function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names}
valid_range = valid_range_nbatchsampler(s, ts)
inds = rand(s.rng, valid_range, s.batch_size)
StatsBase.sample(s, ts, Val(names), inds)
StatsBase.sample(s, ts, Val(names))
end

function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names}
valid_range = valid_range_nbatchsampler(s, ts)
valid_range = valid_range[valid_range .∈ (findall(ts.sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
inds = rand(s.rng, valid_range, s.batch_size)
StatsBase.sample(s, ts, Val(names), inds)
function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, ::Val{names}) where names
weights, ns = valid_range(s, t)
inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batch_size)
fetch(s, t, Val(names), inds, ns)
end


function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
if isnothing(nbs.stack_size)
s = ts[:state][inds]
s′ = ts[:next_state][inds.+(nbs.n-1)]
else
s = ts[:state][[x + i for i in -nbs.stack_size+1:0, x in inds]]
s′ = ts[:next_state][[x + nbs.n - 1 + i for i in -nbs.stack_size+1:0, x in inds]]
end

a = ts[:action][inds]
t_horizon = ts[:terminal][[x + j for j in 0:nbs.n-1, x in inds]]
r_horizon = ts[:reward][[x + j for j in 0:nbs.n-1, x in inds]]

@assert ndims(t_horizon) == 2
t = any(t_horizon, dims=1) |> vec

@assert ndims(r_horizon) == 2
r = map(eachcol(r_horizon), eachcol(t_horizon)) do r⃗, t⃗
foldr(((rr, tt), init) -> rr + nbs.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
end

NamedTuple{SS′ART}(map(collect, (s, s′, a, r, t)))
function fetch(s::NStepBatchSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names))
end

function StatsBase.sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
s, s′, a, r, t = StatsBase.sample(s, ts, Val(SSART), inds)
l = consecutive_view(ts[:next_legal_actions_mask], inds)
NamedTuple{SSLART}(map(collect, (s, s′, l, a, r, t)))
#state and next_state have specialized fetch methods due to stack_size
fetch(::NStepBatchSampler{names, Nothing}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[inds]
fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stack_size+1:0, x in inds]]
fetch(::NStepBatchSampler{names, Nothing}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[inds .+ ns .- 1]
fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stack_size+1:0, (idx,x) in enumerate(inds)]]
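The comprehensions above build a `stack_size × length(inds)` matrix of indices, so (for scalar states) each column of the fetched array holds the `stack_size - 1` preceding states followed by the current one; for example:

stack_size = 2
inds = [3, 5]
[x + i for i in -stack_size + 1:0, x in inds]   # 2×2 Matrix{Int}: [2 4; 3 5]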

#reward has a specialized fetch to compute the discounted n-step sum
function fetch(s::NStepBatchSampler, trace::AbstractTrace, ::Val{:reward}, inds, ns)
rewards = Vector{eltype(trace)}(undef, length(inds))
for (i,idx) in enumerate(inds)
rewards_to_go = trace[idx:idx+ns[i]-1]
rewards[i] = foldr((x,y)->x + s.γ*y, rewards_to_go)
end
return rewards
end
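For a horizon of three rewards this fold evaluates to r₁ + γ*(r₂ + γ*r₃); a standalone check:

γ = 0.99f0
rewards_to_go = Float32[2, 3, 4]
foldr((x, y) -> x + γ * y, rewards_to_go)   # 2 + 0.99f0 * (3 + 0.99f0 * 4) ≈ 8.8904f0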
#terminal is that of the nth step
fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val{:terminal}, inds, ns) = trace[inds .+ ns .- 1]
#the "next" (right) element of multiplexed traces must be sampled n steps forward
fetch(::NStepBatchSampler, trace::RelativeTrace{1,0} , ::Val, inds, ns) = trace[inds .+ ns .- 1]
#all other traces are fetched at inds
fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val, inds, ns) = trace[inds]

function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
t = e.traces
st = deepcopy(t.priorities)
st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
valids, ns = valid_range(s,e)
st .*= valids[1:end-1] #temporary sum tree that assigns zero priority to non-sampleable indices.
inds, priorities = rand(s.rng, st, s.batch_size)

merge(
(key=t.keys[inds], priority=priorities),
StatsBase.sample(s, t.traces, Val(names), inds)
fetch(s, e, Val(names), inds, ns)
)
end
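A sketch of the prioritized path, mirroring the prioritized test below; the returned batch additionally carries the sampled keys and priorities:

using ReinforcementLearningTrajectories
import ReinforcementLearningTrajectories.StatsBase.sample

eb = EpisodesBuffer(
    CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority=10f0)
)
push!(eb, (state = 1, action = 1))
for i in 1:5
    push!(eb, (state = i + 1, action = i + 1, reward = i, terminal = i == 5))
end

s = NStepBatchSampler(eb; n=3, γ=0.99f0, batch_size=4)
batch = sample(s, eb)
haskey(batch, :key), haskey(batch, :priority)   # (true, true)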

3 changes: 2 additions & 1 deletion src/traces.jl
@@ -48,7 +48,7 @@ Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),)
Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)
Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)

@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!
@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!, Base.eltype

#By default, an AbstractTrace has infinite capacity (like a Vector). This method is specialized for
#CircularArraySARTSTraces in common.jl. The functions below are made that way to avoid type piracy.
@@ -94,6 +94,7 @@ Base.getindex(s::RelativeTrace{0,-1}, I) = getindex(s.trace, I)
Base.getindex(s::RelativeTrace{1,0}, I) = getindex(s.trace, I .+ 1)
Base.setindex!(s::RelativeTrace{0,-1}, v, I) = setindex!(s.trace, v, I)
Base.setindex!(s::RelativeTrace{1,0}, v, I) = setindex!(s.trace, v, I .+ 1)
Base.eltype(t::RelativeTrace) = eltype(t.trace)
capacity(t::RelativeTrace) = capacity(t.trace)

"""
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -2,6 +2,7 @@ using ReinforcementLearningTrajectories
using CircularArrayBuffers, DataStructures
using StableRNGs
using Test
import ReinforcementLearningTrajectories.StatsBase.sample
using CUDA
using Adapt
using Random
185 changes: 65 additions & 120 deletions test/samplers.jl
@@ -1,3 +1,4 @@
import ReinforcementLearningTrajectories.fetch
@testset "Samplers" begin
@testset "BatchSampler" begin
sz = 32
@@ -74,132 +74,76 @@

#! format: off
@testset "NStepSampler" begin
γ = 0.9
γ = 0.99
n_stack = 2
n_horizon = 3
batch_size = 4

t1 = MultiplexTraces{(:state, :next_state)}(1:10) +
MultiplexTraces{(:action, :next_action)}(iseven.(1:10)) +
Traces(
reward=1:9,
terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
)

s1 = NStepBatchSampler(n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)

xs = RLTrajectories.StatsBase.sample(s1, t1)

@test size(xs.state) == (n_stack, batch_size)
@test size(xs.next_state) == (n_stack, batch_size)
@test size(xs.action) == (batch_size,)
@test size(xs.reward) == (batch_size,)
@test size(xs.terminal) == (batch_size,)


state_size = (2,3)
n_state = reduce(*, state_size)
total_length = 10
t2 = MultiplexTraces{(:state, :next_state)}(
reshape(1:n_state * total_length, state_size..., total_length)
) +
MultiplexTraces{(:action, :next_action)}(iseven.(1:total_length)) +
Traces(
reward=1:total_length-1,
terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
)

xs2 = RLTrajectories.StatsBase.sample(s1, t2)

@test size(xs2.state) == (state_size..., n_stack, batch_size)
@test size(xs2.next_state) == (state_size..., n_stack, batch_size)
@test size(xs2.action) == (batch_size,)
@test size(xs2.reward) == (batch_size,)
@test size(xs2.terminal) == (batch_size,)

inds = [3, 5, 7]
xs3 = RLTrajectories.StatsBase.sample(s1, t2, Val(SS′ART), inds)

@test xs3.state == cat(
(
reshape(n_state * (i-n_stack)+1: n_state * i, state_size..., n_stack)
for i in inds
)...
;dims=length(state_size) + 2
)

@test xs3.next_state == xs3.state .+ (n_state * n_horizon)
@test xs3.action == iseven.(inds)
@test xs3.terminal == [any(t2[:terminal][i: i+n_horizon-1]) for i in inds]

# manual calculation
@test xs3.reward[1] ≈ 3 + γ * 4 # terminated at step 4
@test xs3.reward[2] ≈ 5 + γ * (6 + γ * 7)
@test xs3.reward[3] ≈ 7 + γ * (8 + γ * 9)
end
#! format: on

@testset "Trajectory with CircularPrioritizedTraces and NStepBatchSampler" begin
n=1
γ=0.99f0

t = Trajectory(
container=CircularPrioritizedTraces(
CircularArraySARTSTraces(
capacity=5,
state=Float32 => (4,),
);
default_priority=100.0f0
),
sampler=NStepBatchSampler{SS′ART}(
n=n,
γ=γ,
batch_size=32,
),
controller=InsertSampleRatioController(
threshold=100,
n_inserted=-1
)
)
batch_size = 1000
eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)

push!(t, (state = 1, action = true))
for i = 1:9
push!(t, (state = i+1, action = true, reward = i, terminal = false))
push!(eb, (state = 1, action = 1))
for i = 1:5
push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
end

b = RLTrajectories.StatsBase.sample(t)
@test haskey(b, :priority)
@test sum(b.action .== 0) == 0
end


@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
n=1
γ=0.99f0

t = Trajectory(
container=CircularArraySARTSTraces(
capacity=5,
state=Float32 => (4,),
),
sampler=NStepBatchSampler{SS′ART}(
n=n,
γ=γ,
batch_size=32,
),
controller=InsertSampleRatioController(
threshold=100,
n_inserted=-1
)
)

push!(t, (state = 1, action = true))
for i = 1:9
push!(t, (state = i+1, action = true, reward = i, terminal = false))
push!(eb, (state = 7, action = 7))
for (j,i) = enumerate(8:11)
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
end
weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
@test weights == [0,1,1,1,1,0,0,1,1,1,0]
@test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 arises because episodes_lengths[6] already belongs to the 2nd episode while step_numbers[6] still belongs to the 1st
inds = [i for i in eachindex(weights) if weights[i] == 1]
batch = sample(s1, eb)
for key in keys(eb)
@test haskey(batch, key)
end
#state: samples with stack_size
states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
@test states == [1 2 3 4 7 8 9;
2 3 4 5 8 9 10]
@test all(in(eachcol(states)), unique(eachcol(batch[:state])))
#next_state: samples with stack_size and nsteps forward
next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
@test next_states == [4 5 5 5 10 10 10;
5 6 6 6 11 11 11]
@test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
#action: samples normally
actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
@test actions == inds
@test all(in(actions), unique(batch[:action]))
#next_action is the "next" element of a multiplex trace: it is automatically sampled n steps forward
next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
@test next_actions == [5, 6, 6, 6, 11, 11, 11]
@test all(in(next_actions), unique(batch[:next_action]))
#reward: discounted sum
rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
@test rewards ≈ [2+0.99*3+0.99^2*4, 3+0.99*4+0.99^2*5, 4+0.99*5, 5, 8+0.99*9+0.99^2*10,9+0.99*10, 10]
@test all(in(rewards), unique(batch[:reward]))
#terminal: nsteps forward
terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
@test terminals == [0,1,1,1,0,0,0]

### CircularPrioritizedTraces and NStepBatchSampler
γ = 0.99
n_horizon = 3
batch_size = 4
eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batch_size=batch_size)

b = RLTrajectories.StatsBase.sample(t)
@test sum(b.action .== 0) == 0
push!(eb, (state = 1, action = 1))
for i = 1:5
push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
end
push!(eb, (state = 7, action = 7))
for (j,i) = enumerate(8:11)
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
end
weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
inds = [i for i in eachindex(weights) if weights[i] == 1]
batch = sample(s1, eb)
for key in (keys(eb)..., :key, :priority)
@test haskey(batch, key)
end
end

@testset "EpisodesSampler" begin