Skip to content

Commit 0721dd2

Browse files
committed
init
1 parent 63fd846 commit 0721dd2

File tree

3 files changed

+114
-3
lines changed

3 files changed

+114
-3
lines changed

src/samplers.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
export BatchSampler
1+
export BatchSampler, MetaSampler, MultiBatchSampler
22

33
using Random
44

5-
struct BatchSampler
5+
abstract type AbstractSampler end
6+
7+
struct BatchSampler <: AbstractSampler
68
batch_size::Int
79
rng::Random.AbstractRNG
810
transformer::Any
@@ -16,3 +18,40 @@ Uniformly sample a batch of examples for each trace.
1618
See also [`sample`](@ref).
1719
"""
1820
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, identity)
21+
22+
"""
23+
MetaSampler(::NamedTuple)
24+
25+
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a batch from each sampler.
26+
Used internally for algorithms that sample multiple times per epoch.
27+
28+
# Example
29+
30+
MetaSampler(policy = BatchSampler(10), critic = BatchSampler(100))
31+
"""
32+
struct MetaSampler{names, T} <: AbstractSampler
33+
samplers::NamedTuple{names, T}
34+
end
35+
36+
MetaSampler(; kw...) = MetaSampler(NamedTuple(kw))
37+
38+
function sample(s::MetaSampler, t)
39+
(;[(k, sample(v, t)) for (k,v) in pairs(s.samplers)]...)
40+
end
41+
42+
43+
"""
44+
MultiBatchSampler(sampler, n)
45+
46+
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination with MetaSampler to allow different sampling rates between samplers.
47+
48+
# Example
49+
50+
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3), critic = MultiBatchSampler(BatchSampler(100), 5))
51+
"""
52+
struct MultiBatchSampler{S <: AbstractSampler} <: AbstractSampler
53+
sampler::S
54+
n::Int
55+
end
56+
57+
sample(m::MultiBatchSampler, t) = [sample(m.sampler, t) for _ in 1:m.n]

src/trajectory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ end
8383

8484
function Base.take!(t::Trajectory)
8585
res = on_sample!(t.controler)
86-
if isnothing(res)
86+
if isnothing(res) && !isnothing(t.controler)
8787
nothing
8888
else
8989
sample(t.sampler, t.container)

test/samplers.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using Trajectories, Test
2+
3+
@testset "MetaSampler" begin
4+
t = Trajectory(
5+
container=Traces(
6+
a=Int[],
7+
b=Bool[]
8+
),
9+
sampler = MetaSampler(policy = BatchSampler(3), critic = BatchSampler(5)),
10+
controler = InsertSampleControler(10, 0)
11+
)
12+
13+
append!(t; a=[1, 2, 3, 4], b=[false, true, false, true])
14+
15+
batches = []
16+
17+
for batch in t
18+
push!(batches, batch)
19+
end
20+
21+
@test length(batches) == 10
22+
@test length(batches[1][:policy][:a]) == 3 && length(batches[1][:critic][:b]) == 5
23+
end
24+
25+
@testset "MultiBatchSampler" begin
26+
t = Trajectory(
27+
container=Traces(
28+
a=Int[],
29+
b=Bool[]
30+
),
31+
sampler = MetaSampler(policy = BatchSampler(3), critic = MultiBatchSampler(BatchSampler(5), 2)),
32+
controler = InsertSampleControler(10, 0)
33+
)
34+
35+
append!(t; a=[1, 2, 3, 4], b=[false, true, false, true])
36+
37+
batches = []
38+
39+
for batch in t
40+
push!(batches, batch)
41+
end
42+
43+
@test length(batches) == 10
44+
@test length(batches[1][:policy][:a]) == 3
45+
@test length(batches[1][:critic]) == 2 # we sampled 2 batches for critic
46+
@test length(batches[1][:critic][1][:b]) == 5 #each batch is 5 samples
47+
end
48+
49+
@testset "async trajectories" begin
50+
threshould = 100
51+
ratio = 1 / 4
52+
t = Trajectory(
53+
container=Traces(
54+
a=Int[],
55+
b=Bool[]
56+
),
57+
sampler=BatchSampler(3),
58+
controler=AsyncInsertSampleRatioControler(ratio, threshould)
59+
)
60+
61+
n = 100
62+
insert_task = @async for i in 1:n
63+
append!(t; a=[i, i, i, i], b=[false, true, false, true])
64+
end
65+
66+
s = 0
67+
sample_task = @async for _ in t
68+
s += 1
69+
end
70+
sleep(1)
71+
@test s == (n - threshould * ratio) + 1
72+
end

0 commit comments

Comments
 (0)