Skip to content

Commit fb23bee

Browse files
committed
add more traces
1 parent edd4c52 commit fb23bee

File tree

3 files changed

+64
-0
lines changed

3 files changed

+64
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.0"
55

66
[deps]
77
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
8+
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
89
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
910
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/common/ElasticArraySARTTraces.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using ElasticArrays
2+
3+
const ElasticSARTTraces = Traces{
4+
SART,
5+
<:Tuple{
6+
<:Trace{<:ElasticArray},
7+
<:Trace{<:ElasticArray},
8+
<:Trace{<:ElasticArray},
9+
<:Trace{<:ElasticArray},
10+
}
11+
}
12+
13+
function ElasticSARTTraces(;
14+
state=Int => (),
15+
action=Int => (),
16+
reward=Float32 => (),
17+
terminal=Bool => ()
18+
)
19+
state_eltype, state_size = state
20+
action_eltype, action_size = action
21+
reward_eltype, reward_size = reward
22+
terminal_eltype, terminal_size = terminal
23+
24+
Traces(
25+
state=ElasticArray{state_eltype}(state_size..., 0),
26+
action=ElasticArray{action_eltype}(action_size..., 0),
27+
reward=ElasticArray{reward_eltype}(reward_size..., 0),
28+
terminal=ElasticArray{terminal_eltype}(terminal_size..., 0),
29+
)
30+
end
31+
32+
function Random.rand(s::BatchSampler, t::ElasticSARTTraces)
33+
inds = rand(s.rng, 1:length(t), s.batch_size)
34+
inds′ = inds .+ 1
35+
(
36+
state=t[:state][inds],
37+
action=t[:action][inds],
38+
reward=t[:reward][inds],
39+
terminal=t[:terminal][inds],
40+
next_state=t[:state][inds′],
41+
next_action=t[:state][inds′]
42+
) |> s.transformer
43+
end

src/common/ReservoirTraces.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using MacroTools: @forward
2+
3+
mutable struct ReservoirTraces{T,R}
4+
traces::T
5+
n::Int
6+
capacity::Int
7+
rng::R
8+
end
9+
10+
@forward ReservoirTrajectory.buffer sample, Base.keys, Base.haskey, Base.getindex, Base.view, Base.length, Base.setindex!, Base.lastindex, Base.firstindex
11+
12+
function Base.push!(t::ReservoirTraces, x)
13+
if t.n < t.capacity
14+
push!(t.traces, x)
15+
t.n += 1
16+
else
17+
i = rand(t.rng, 1:length(t))
18+
t.traces[i] = x
19+
end
20+
end

0 commit comments

Comments
 (0)