Skip to content

Commit d5d7c97

Browse files
authored
Merge pull request #34 from JuliaReinforcementLearning/Move-samples-to-Matrix
Move samples to Matrix
2 parents 29b7cb4 + 5ef9f06 commit d5d7c97

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/samplers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, n
4949

5050
function sample(s::BatchSampler, t::AbstractTraces, names)
5151
inds = rand(s.rng, 1:length(t), s.batch_size)
52-
NamedTuple{names}(map(x -> t[x][inds], names))
52+
NamedTuple{names}(map(x -> collect(t[x][inds]), names))
5353
end
5454

5555
# !!! avoid iterating an empty trajectory
@@ -67,7 +67,7 @@ sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = sample(s, t, ke
6767

6868
function sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
6969
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
70-
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> t.traces[x][inds], names)...))
70+
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[x][inds]), names)...))
7171
end
7272

7373
#####
@@ -172,13 +172,13 @@ function sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
172172
foldr(((rr, tt), init) -> rr + nbs.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
173173
end
174174

175-
NamedTuple{SS′ART}((s, s′, a, r, t))
175+
NamedTuple{SS′ART}(map(collect, (s, s′, a, r, t)))
176176
end
177177

178178
function sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
179179
s, s′, a, r, t = sample(s, ts, Val(SSART), inds)
180180
l = consecutive_view(ts[:next_legal_actions_mask], inds)
181-
NamedTuple{SSLART}((s, s′, l, a, r, t))
181+
NamedTuple{SSLART}(map(collect, (s, s′, l, a, r, t)))
182182
end
183183

184184
function sample(s::NStepBatchSampler{names}, t::CircularPrioritizedTraces) where {names}
@@ -187,4 +187,4 @@ function sample(s::NStepBatchSampler{names}, t::CircularPrioritizedTraces) where
187187
(key=t.keys[inds], priority=priorities),
188188
sample(s, t.traces, Val(names), inds)
189189
)
190-
end
190+
end

0 commit comments

Comments
 (0)