Skip to content

Commit f09d328

Browse files
authored
Merge pull request #26 from findmyway/tianjun/add_prioritized_traces
Add CircularPrioritizedTraces
2 parents 42f0bc7 + ec7a63d commit f09d328

File tree

8 files changed

+117
-2
lines changed

8 files changed

+117
-2
lines changed

src/common/CircularArraySARTTraces.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
export CircularArraySARTTraces
22

3+
import CircularArrayBuffers
4+
35
const CircularArraySARTTraces = Traces{
46
SS′AA′RT,
57
<:Tuple{
@@ -29,3 +31,5 @@ function CircularArraySARTTraces(;
2931
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3032
)
3133
end
34+
35+
CircularArrayBuffers.capacity(t::CircularArraySARTTraces) = CircularArrayBuffers.capacity(t.traces[end])
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
export CircularPrioritizedTraces
2+
3+
using CircularArrayBuffers: capacity, CircularVectorBuffer
4+
5+
struct CircularPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts}
6+
keys::CircularVectorBuffer{Int,Vector{Int}}
7+
priorities::SumTree{Float32}
8+
traces::T
9+
default_priority::Float32
10+
end
11+
12+
function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
13+
new_names = (:key, :priority, names...)
14+
new_Ts = Tuple{Int,Float32,Ts.parameters...}
15+
c = capacity(traces)
16+
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
17+
CircularVectorBuffer{Int}(c),
18+
SumTree(c),
19+
traces,
20+
default_priority
21+
)
22+
end
23+
24+
function Base.push!(t::CircularPrioritizedTraces, x)
25+
push!(t.traces, x)
26+
if length(t.traces) == 1
27+
push!(t.keys, 1)
28+
push!(t.priorities, t.default_priority)
29+
elseif length(t.traces) > 1
30+
push!(t.keys, t.keys[end] + 1)
31+
push!(t.priorities, t.default_priority)
32+
else
33+
# may be partial inserting at the first step, ignore it
34+
end
35+
end
36+
37+
function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
38+
if k === :priority
39+
@assert length(vs) == length(keys)
40+
for (i, v) in zip(keys, vs)
41+
if t.keys[1] <= i <= t.keys[end]
42+
t.priorities[i-t.keys[1]+1] = v
43+
end
44+
end
45+
else
46+
@error "unsupported yet"
47+
end
48+
end
49+
50+
Base.size(t::CircularPrioritizedTraces) = size(t.traces)
51+
52+
function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
53+
if s === :priority
54+
Trace(ts.priorities)
55+
elseif s === :key
56+
Trace(ts.keys)
57+
else
58+
ts.traces[s]
59+
end
60+
end
61+
62+
Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))

src/common/common.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ const SS′LL′AA′RT = (SS′..., LL′..., AA′..., RT...)
1414
include("sum_tree.jl")
1515
include("CircularArraySARTTraces.jl")
1616
include("CircularArraySLARTTraces.jl")
17+
include("CircularPrioritizedTraces.jl")

src/common/sum_tree.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ export SumTree
22

33
using Random
44

5+
import CircularArrayBuffers
6+
57
"""
68
SumTree(capacity::Int)
79
Efficiently sample and update weights.
@@ -69,7 +71,7 @@ mutable struct SumTree{T} <: AbstractVector{Int}
6971
end
7072
end
7173

72-
capacity(t::SumTree) = t.capacity
74+
CircularArrayBuffers.capacity(t::SumTree) = t.capacity
7375
Base.length(t::SumTree) = t.length
7476
Base.size(t::SumTree) = (length(t),)
7577
Base.eltype(t::SumTree{T}) where {T} = T

src/samplers.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ function sample(s::BatchSampler, t::AbstractTraces, names)
3434
NamedTuple{names}(map(x -> t[x][inds], names))
3535
end
3636

37+
#####
38+
39+
sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = sample(s, t, keys(t.traces))
40+
41+
function sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
42+
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
43+
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> t.traces[x][inds], names)...))
44+
end
45+
3746
#####
3847
# MetaSampler
3948
#####

src/traces.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ export Trace, Traces, MultiplexTraces, Episode, Episodes
22

33
import MacroTools: @forward
44

5+
import CircularArrayBuffers
6+
57
#####
68

79
abstract type AbstractTrace{E} <: AbstractVector{E} end
@@ -31,7 +33,7 @@ Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),)
3133
Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)
3234
Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)
3335

34-
@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!
36+
@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!, CircularArrayBuffers.capacity
3537

3638
#####
3739

src/trajectory.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ function Base.push!(t::Trajectory, x)
7676
on_insert!(t.controller, 1)
7777
end
7878

79+
Base.setindex!(t::Trajectory, v, k) = setindex!(t.container, v, k)
80+
7981
struct CallMsg
8082
f::Any
8183
args::Tuple
@@ -84,6 +86,7 @@ end
8486

8587
Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = put!(t.controller.ch_in, CallMsg(Base.push!, (x,), NamedTuple()))
8688
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = put!(t.controller.ch_in, CallMsg(Base.append!, (x,), NamedTuple()))
89+
Base.setindex!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, v, I...) = put!(t.controller.ch_in, CallMsg(Base.setindex!, (v, I...), NamedTuple()))
8790

8891
function Base.append!(t::Trajectory, x)
8992
append!(t.container, x)

test/common.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,36 @@ end
102102
)
103103

104104
@test t isa CircularArraySLARTTraces
105+
end
106+
107+
@testset "CircularPrioritizedTraces" begin
108+
t = CircularPrioritizedTraces(
109+
CircularArraySARTTraces(;
110+
capacity=3
111+
),
112+
default_priority=1.0f0
113+
)
114+
115+
push!(t, (state=0, action=0))
116+
117+
for i in 1:5
118+
push!(t, (reward=1.0f0, terminal=false, state=i, action=i))
119+
end
120+
121+
@test length(t) == 3
122+
123+
s = BatchSampler(5)
124+
125+
b = ReinforcementLearningTrajectories.sample(s, t)
126+
127+
t[:priority, [1, 2]] = [0, 0]
128+
129+
# shouldn't be changed since [1,2] are old keys
130+
@test t[:priority] == [1.0f0, 1.0f0, 1.0f0]
131+
132+
t[:priority, [3, 4, 5]] = [0, 1, 0]
133+
134+
b = ReinforcementLearningTrajectories.sample(s, t)
135+
136+
@test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0
105137
end

0 commit comments

Comments
 (0)