Skip to content

Commit bc245a8

Browse files
authored
Merge pull request #27 from findmyway/tianjun/bump_version
minor enhancement and bump version
2 parents f09d328 + 9b9897e commit bc245a8

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReinforcementLearningTrajectories"
22
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
3-
version = "0.1.2"
3+
version = "0.1.3"
44

55
[deps]
66
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"

src/samplers.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,4 +145,12 @@ function sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
145145
s, s′, a, r, t = sample(s, ts, Val(SSART), inds)
146146
l = consecutive_view(ts[:next_legal_actions_mask], inds)
147147
NamedTuple{SSLART}((s, s′, l, a, r, t))
148-
end
148+
end
149+
150+
function sample(s::NStepBatchSampler{names}, t::CircularPrioritizedTraces) where {names}
151+
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
152+
merge(
153+
(key=t.keys[inds], priority=priorities),
154+
sample(s, t.traces, Val(names), inds)
155+
)
156+
end

src/trajectory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ function Base.push!(t::Trajectory, x)
7676
on_insert!(t.controller, 1)
7777
end
7878

79-
Base.setindex!(t::Trajectory, v, k) = setindex!(t.container, v, k)
79+
Base.setindex!(t::Trajectory, v, I...) = setindex!(t.container, v, I...)
8080

8181
struct CallMsg
8282
f::Any

0 commit comments

Comments
 (0)