Skip to content

Commit ae42284

Browse files
authored
Merge branch 'main' into fix-prio
2 parents a9c5b5c + c89ed6f commit ae42284

File tree

4 files changed

+97
-5
lines changed

4 files changed

+97
-5
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1616
[compat]
1717
Adapt = "3"
1818
CircularArrayBuffers = "0.1"
19+
DataStructures = "0.18"
1920
ElasticArrays = "1"
2021
MacroTools = "0.5"
2122
OnlineStats = "1"
2223
StackViews = "0.1"
23-
julia = "1.9"
24-
DataStructures = "0.18"
2524
StatsBase = "0.34"
25+
julia = "1.9"
2626

2727
[extras]
2828
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
29+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2930
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3031

3132
[targets]
32-
test = ["Test", "CUDA"]
33+
test = ["Test", "CUDA", "StableRNGs"]

src/common/sum_tree.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,28 @@ function Base.empty!(t::SumTree)
131131
t
132132
end
133133

134+
"""
135+
correct_sample(t::SumTree, leaf_ind)
136+
Check whether the sampled leaf is valid and if not return another valid leaf close to it. Used to correct samples with zero priority which may occur due to numerical errors with floats.
137+
"""
138+
function correct_sample(t::SumTree, leaf_ind)
139+
p = t.tree[leaf_ind]
140+
# walk backwards until p != 0 or until leftmost leaf reached
141+
tmp_ind = leaf_ind
142+
while iszero(p) && (tmp_ind-1)*2 > length(t.tree)
143+
tmp_ind -= 1
144+
p = t.tree[tmp_ind]
145+
end
146+
# walk forwards until p != 0 or until rightmost leaf reached
147+
iszero(p) && (tmp_ind = leaf_ind)
148+
while iszero(p) && (tmp_ind - t.nparents) <= t.length
149+
tmp_ind += 1
150+
p = t.tree[tmp_ind]
151+
end
152+
return p, tmp_ind
153+
end
154+
155+
134156
function Base.get(t::SumTree, v)
135157
parent_ind = 1
136158
leaf_ind = parent_ind
@@ -152,7 +174,7 @@ function Base.get(t::SumTree, v)
152174
if leaf_ind <= t.nparents
153175
leaf_ind += t.capacity
154176
end
155-
p = t.tree[leaf_ind]
177+
p, leaf_ind = correct_sample(t, leaf_ind)
156178
ind = leaf_ind - t.nparents
157179
real_ind = ind >= t.first ? ind - t.first + 1 : ind + t.capacity - t.first + 1
158180
real_ind, p
@@ -172,4 +194,4 @@ function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
172194
inds, priorities
173195
end
174196

175-
Random.rand(t::SumTree, n::Int) = rand(Random.GLOBAL_RNG, t, n)
197+
Random.rand(t::SumTree, n::Int) = rand(Random.GLOBAL_RNG, t, n)

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
using ReinforcementLearningTrajectories
22
using CircularArrayBuffers, DataStructures
3+
using StableRNGs
34
using Test
45
import ReinforcementLearningTrajectories.StatsBase.sample
56
using CUDA
67
using Adapt
8+
using Random
9+
import ReinforcementLearningTrajectories.StatsBase.sample
10+
import StatsBase.countmap
11+
712

813
struct TestAdaptor end
914

@@ -13,6 +18,7 @@ Adapt.adapt_storage(to::TestAdaptor, x) = CUDA.functional() ? CUDA.cu(x) : x
1318

1419
@testset "ReinforcementLearningTrajectories.jl" begin
1520
include("traces.jl")
21+
include("sum_tree.jl")
1622
include("common.jl")
1723
include("samplers.jl")
1824
include("controllers.jl")

test/sum_tree.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
function gen_rand_sumtree(n, seed, type::DataType=Float32)
2+
rng = StableRNG(seed)
3+
a = SumTree(type, n)
4+
append!(a, rand(rng, type, n))
5+
return a
6+
end
7+
8+
function gen_sumtree_with_zeros(n, seed, type::DataType=Float32)
9+
a = gen_rand_sumtree(n, seed, type)
10+
b = rand(StableRNG(seed), Bool, n)
11+
return copy_multiply(a, b)
12+
end
13+
14+
function copy_multiply(stree, m)
15+
new_tree = deepcopy(stree)
16+
new_tree .*= m
17+
return new_tree
18+
end
19+
20+
function sumtree_nozero(t::SumTree, rng::AbstractRNG, iters=1)
21+
for _ in iters
22+
(_, p) = rand(rng, t)
23+
p == 0 && return false
24+
end
25+
return true
26+
end
27+
sumtree_nozero(n::Integer, seed::Integer, iters=1) = sumtree_nozero(gen_sumtree_with_zeros(n, seed), StableRNG(seed), iters)
28+
sumtree_nozero(n, seeds::AbstractVector, iters=1) = all(sumtree_nozero(n, seed, iters) for seed in seeds)
29+
30+
31+
function sumtree_distribution!(indices, priorities, t::SumTree, rng::AbstractRNG, iters=1000*t.length)
32+
for i = 1:iters
33+
indices[i], priorities[i] = rand(rng, t)
34+
end
35+
imap = countmap(indices)
36+
est_pdf = Dict(k=>v/length(indices) for (k, v) in imap)
37+
ex_pdf = Dict(k=>v/t.tree[1] for (k, v) in Dict(1:length(t) .=> t))
38+
abserrs = [est_pdf[k] - ex_pdf[k] for k in keys(est_pdf)]
39+
return abserrs
40+
end
41+
sumtree_distribution!(indices, priorities, n, seed, iters=1000*n) = sumtree_distribution!(indices, priorities, gen_rand_sumtree(n, seed), StableRNG(seed), iters)
42+
function sumtree_distribution(n, seeds::AbstractVector, iters=1000*n)
43+
p = [zeros(Float32, iters) for _ = 1:Threads.nthreads()]
44+
i = [zeros(Float32, iters) for _ = 1:Threads.nthreads()]
45+
results = Vector{Vector{Float64}}(undef, length(seeds))
46+
Threads.@threads for ix = 1:length(seeds)
47+
results[ix] = sumtree_distribution!(i[Threads.threadid()], p[Threads.threadid()], gen_rand_sumtree(n, seeds[ix]), StableRNG(seeds[ix]), iters)
48+
end
49+
return results
50+
end
51+
52+
@testset "SumTree" begin
53+
n = 1024
54+
seeds = 1:100
55+
nozero_iters=1024
56+
distr_iters=1024*10_000
57+
abstol = 0.05
58+
maxerr=0.01
59+
60+
@test sumtree_nozero(n, seeds, nozero_iters)
61+
@test all(x->all(x .< maxerr) && sum(abs2, x) < abstol,
62+
sumtree_distribution(n, seeds, distr_iters))
63+
end

0 commit comments

Comments
 (0)