Skip to content

Add CircularPrioritizedTraces #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/common/CircularArraySARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
export CircularArraySARTTraces

import CircularArrayBuffers

const CircularArraySARTTraces = Traces{
SS′AA′RT,
<:Tuple{
Expand Down Expand Up @@ -29,3 +31,5 @@ function CircularArraySARTTraces(;
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

"""
    CircularArrayBuffers.capacity(t::CircularArraySARTTraces)

Return the capacity of `t`, read from the last entry of `t.traces`.
"""
function CircularArrayBuffers.capacity(t::CircularArraySARTTraces)
    last_trace = t.traces[end]
    return CircularArrayBuffers.capacity(last_trace)
end
62 changes: 62 additions & 0 deletions src/common/CircularPrioritizedTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
export CircularPrioritizedTraces

using CircularArrayBuffers: capacity, CircularVectorBuffer

# Traces wrapper that pairs every stored element with an integer key and a
# `Float32` priority.
struct CircularPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts}
# Keys of the currently stored elements, kept in a circular buffer.
keys::CircularVectorBuffer{Int,Vector{Int}}
# Per-element priorities, stored in a `SumTree`.
priorities::SumTree{Float32}
# The wrapped traces container.
traces::T
# Priority given to an element when it is first pushed.
default_priority::Float32
end

"""
    CircularPrioritizedTraces(traces::AbstractTraces; default_priority)

Wrap `traces` so that each element also carries a `:key` and a `:priority`.

The key buffer and the priority tree are both sized with `capacity(traces)`.
The resulting trace names are `(:key, :priority, names...)`, where `names`
are the names of `traces`.

# Keywords
- `default_priority`: priority stored for every newly pushed element.
"""
function CircularPrioritizedTraces(
    traces::AbstractTraces{names,Ts}; default_priority
) where {names,Ts}
    n = capacity(traces)
    extended_names = (:key, :priority, names...)
    extended_types = Tuple{Int,Float32,Ts.parameters...}
    key_buffer = CircularVectorBuffer{Int}(n)
    priority_tree = SumTree(n)
    return CircularPrioritizedTraces{typeof(traces),extended_names,extended_types}(
        key_buffer, priority_tree, traces, default_priority
    )
end

"""
    push!(t::CircularPrioritizedTraces, x)

Push `x` into the wrapped traces. Once the wrapped traces hold at least one
element, also register a new key with priority `t.default_priority`.

Keys start at `1` and always continue from the most recently issued key, so
a key is never handed out twice, even when the buffer can hold only a single
element.

Returns `t`.
"""
function Base.push!(t::CircularPrioritizedTraces, x)
    push!(t.traces, x)
    # A partial insert at the very first step leaves the wrapped traces empty;
    # there is nothing to prioritize yet, so no key is registered.
    if length(t.traces) >= 1
        # Derive the key from the key buffer itself. Testing
        # `length(t.traces) == 1` instead would restart at key 1 on every push
        # whenever the traces never grow beyond one element.
        next_key = isempty(t.keys) ? 1 : t.keys[end] + 1
        push!(t.keys, next_key)
        push!(t.priorities, t.default_priority)
    end
    return t
end

"""
    setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)

Update priorities by key, i.e. `t[:priority, keys] = vs`.

Each entry of `keys` is paired with the entry of `vs` at the same position.
Keys outside the range currently held in `t.keys` are skipped, so updates
addressed to elements that are not stored leave `t` untouched.

Only `k === :priority` is supported; any other symbol logs an error and
changes nothing.

# Throws
- `DimensionMismatch` if `vs` and `keys` differ in length.
"""
function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
    if k !== :priority
        @error "unsupported yet"
        return nothing
    end
    # Explicit check instead of `@assert`, which may be compiled out and would
    # then let `zip` silently truncate to the shorter input.
    length(vs) == length(keys) || throw(
        DimensionMismatch("got $(length(vs)) priorities for $(length(keys)) keys"),
    )
    # Nothing is stored yet, so no key can match.
    isempty(t.keys) && return nothing
    oldest, newest = t.keys[1], t.keys[end]
    for (key, v) in zip(keys, vs)
        if oldest <= key <= newest
            # Position of `key` relative to the oldest stored key.
            t.priorities[key - oldest + 1] = v
        end
    end
    return nothing
end

"""
    size(t::CircularPrioritizedTraces)

Return the size of the wrapped traces.
"""
function Base.size(t::CircularPrioritizedTraces)
    return size(t.traces)
end

"""
    getindex(ts::CircularPrioritizedTraces, s::Symbol)

Look up a single trace by name: `:priority` and `:key` are served from the
wrapper itself (each wrapped in a `Trace`), every other name is forwarded to
the wrapped traces.
"""
function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
    s === :priority && return Trace(ts.priorities)
    s === :key && return Trace(ts.keys)
    return ts.traces[s]
end

"""
    getindex(t::CircularPrioritizedTraces, i)

Return a `NamedTuple` with one field per trace name of `t`, each holding that
trace indexed at `i`.
"""
function Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names}
    fields = map(name -> t[name][i], names)
    return NamedTuple{names}(fields)
end
1 change: 1 addition & 0 deletions src/common/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ const SS′LL′AA′RT = (SS′..., LL′..., AA′..., RT...)
include("sum_tree.jl")
include("CircularArraySARTTraces.jl")
include("CircularArraySLARTTraces.jl")
include("CircularPrioritizedTraces.jl")
4 changes: 3 additions & 1 deletion src/common/sum_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ export SumTree

using Random

import CircularArrayBuffers

"""
SumTree(capacity::Int)
Efficiently sample and update weights.
Expand Down Expand Up @@ -69,7 +71,7 @@ mutable struct SumTree{T} <: AbstractVector{Int}
end
end

capacity(t::SumTree) = t.capacity
"""
    CircularArrayBuffers.capacity(t::SumTree)

Return the `capacity` field of `t`.
"""
function CircularArrayBuffers.capacity(t::SumTree)
    return t.capacity
end
# Number of elements, as tracked by the `length` field of the tree.
function Base.length(t::SumTree)
    return t.length
end
# A `SumTree` is one-dimensional: its size is a 1-tuple holding its length.
function Base.size(t::SumTree)
    n = length(t)
    return (n,)
end
# The element type reported for a `SumTree{T}` instance is its parameter `T`.
function Base.eltype(::SumTree{T}) where {T}
    return T
end
Expand Down
9 changes: 9 additions & 0 deletions src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ function sample(s::BatchSampler, t::AbstractTraces, names)
NamedTuple{names}(map(x -> t[x][inds], names))
end

#####

# Without explicit names, sample every trace of the wrapped container.
function sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces)
    all_names = keys(t.traces)
    return sample(s, t, all_names)
end

"""
    sample(s::BatchSampler, t::CircularPrioritizedTraces, names)

Draw `s.batch_size` indices from `t.priorities` using `s.rng` and return a
`NamedTuple` with the sampled `:key`s, their `:priority` values, and the
requested `names` of the wrapped traces at the sampled indices.
"""
function sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
    inds, priorities = rand(s.rng, t.priorities, s.batch_size)
    sampled_keys = t.keys[inds]
    sampled_traces = map(name -> t.traces[name][inds], names)
    return NamedTuple{(:key, :priority, names...)}(
        (sampled_keys, priorities, sampled_traces...)
    )
end

#####
# MetaSampler
#####
Expand Down
4 changes: 3 additions & 1 deletion src/traces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ export Trace, Traces, MultiplexTraces, Episode, Episodes

import MacroTools: @forward

import CircularArrayBuffers

#####

# Abstract supertype for traces; an `AbstractTrace{E}` is an `AbstractVector{E}`.
abstract type AbstractTrace{E} <: AbstractVector{E} end
Expand Down Expand Up @@ -31,7 +33,7 @@ Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),)
# Index a `Trace` along the last dimension of its parent: `I` is applied to
# dimension `ndims(s.parent)` and every other dimension is taken whole (`:`).
# The lookup goes through `Base.maybeview`.
Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)
# Write `v` into the parent along its last dimension: `I` selects within
# dimension `ndims(s.parent)` and every other dimension is addressed whole (`:`).
Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)

@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!
# Forward the listed functions, including `CircularArrayBuffers.capacity`,
# from a `Trace` to its `parent` field.
@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!, CircularArrayBuffers.capacity

#####

Expand Down
3 changes: 3 additions & 0 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ function Base.push!(t::Trajectory, x)
on_insert!(t.controller, 1)
end

"""
    setindex!(t::Trajectory, v, I...)

Forward an indexed assignment to the trajectory's container.

All indices are passed through, so multi-index assignments such as
`t[:priority, keys] = vs` reach the container unchanged. This matches the
signature of the `AsyncInsertSampleRatioController` method.
"""
function Base.setindex!(t::Trajectory, v, I...)
    return setindex!(t.container, v, I...)
end

struct CallMsg
f::Any
args::Tuple
Expand All @@ -84,6 +86,7 @@ end

# With an async controller, `push!` is not applied directly: the call is
# packaged as a `CallMsg` and put on the controller's input channel.
function Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x)
    msg = CallMsg(Base.push!, (x,), NamedTuple())
    return put!(t.controller.ch_in, msg)
end
# With an async controller, `append!` is not applied directly: the call is
# packaged as a `CallMsg` and put on the controller's input channel.
function Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x)
    msg = CallMsg(Base.append!, (x,), NamedTuple())
    return put!(t.controller.ch_in, msg)
end
# With an async controller, `setindex!` is not applied directly: the value and
# all indices are packaged as a `CallMsg` and put on the controller's input
# channel.
function Base.setindex!(
    t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, v, I...
)
    msg = CallMsg(Base.setindex!, (v, I...), NamedTuple())
    return put!(t.controller.ch_in, msg)
end

function Base.append!(t::Trajectory, x)
append!(t.container, x)
Expand Down
32 changes: 32 additions & 0 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,36 @@ end
)

@test t isa CircularArraySLARTTraces
end

# End-to-end check of key assignment, priority updates by key, and
# priority-driven sampling on a capacity-3 buffer.
@testset "CircularPrioritizedTraces" begin
t = CircularPrioritizedTraces(
CircularArraySARTTraces(;
capacity=3
),
default_priority=1.0f0
)

# Partial insert: only the initial state and action, no reward/terminal yet.
push!(t, (state=0, action=0))

# Five full transitions go into a buffer created with capacity 3.
for i in 1:5
push!(t, (reward=1.0f0, terminal=false, state=i, action=i))
end

@test length(t) == 3

s = BatchSampler(5)

b = ReinforcementLearningTrajectories.sample(s, t)

t[:priority, [1, 2]] = [0, 0]

# Priorities must be unchanged, since keys 1 and 2 are old keys.
@test t[:priority] == [1.0f0, 1.0f0, 1.0f0]

# Keys 3 and 5 get priority 0, key 4 gets priority 1.
t[:priority, [3, 4, 5]] = [0, 1, 0]

b = ReinforcementLearningTrajectories.sample(s, t)

@test b.key == [4, 4, 4, 4, 4] # the priorities of all other transitions were set to 0
end