@@ -49,7 +49,7 @@ sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, n
49
49
50
50
function sample (s:: BatchSampler , t:: AbstractTraces , names)
51
51
inds = rand (s. rng, 1 : length (t), s. batch_size)
52
- NamedTuple {names} (map (x -> t[x][inds], names))
52
+ NamedTuple {names} (map (x -> collect ( t[x][inds]) , names))
53
53
end
54
54
55
55
# !!! avoid iterating an empty trajectory
@@ -67,7 +67,7 @@ sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = sample(s, t, ke
67
67
68
68
function sample (s:: BatchSampler , t:: CircularPrioritizedTraces , names)
69
69
inds, priorities = rand (s. rng, t. priorities, s. batch_size)
70
- NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> t. traces[x][inds], names)... ))
70
+ NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect ( t. traces[x][inds]) , names)... ))
71
71
end
72
72
73
73
# ####
@@ -172,13 +172,13 @@ function sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
172
172
foldr (((rr, tt), init) -> rr + nbs. γ * init * (1 - tt), zip (r⃗, t⃗); init= 0.0f0 )
173
173
end
174
174
175
- NamedTuple {SS′ART} (( s, s′, a, r, t))
175
+ NamedTuple {SS′ART} (map (collect, ( s, s′, a, r, t) ))
176
176
end
177
177
178
178
function sample (s:: NStepBatchSampler , ts, :: Val{SS′L′ART} , inds)
179
179
s, s′, a, r, t = sample (s, ts, Val (SSART), inds)
180
180
l = consecutive_view (ts[:next_legal_actions_mask ], inds)
181
- NamedTuple {SSLART} (( s, s′, l, a, r, t))
181
+ NamedTuple {SSLART} (map (collect, ( s, s′, l, a, r, t) ))
182
182
end
183
183
184
184
function sample (s:: NStepBatchSampler{names} , t:: CircularPrioritizedTraces ) where {names}
@@ -187,4 +187,4 @@ function sample(s::NStepBatchSampler{names}, t::CircularPrioritizedTraces) where
187
187
(key= t. keys[inds], priority= priorities),
188
188
sample (s, t. traces, Val (names), inds)
189
189
)
190
- end
190
+ end
0 commit comments