
Commit 65978c9

sample namespace
1 parent ca511f7 commit 65978c9

4 files changed, +7 -6 lines changed

test/common.jl

Lines changed: 3 additions & 3 deletions
@@ -145,7 +145,7 @@ end

 s = BatchSampler(5)

-b = ReinforcementLearningTrajectories.StatsBase.sample(s, t)
+b = sample(s, t)

 t[:priority, [1, 2]] = [0, 0]

@@ -154,7 +154,7 @@ end

 t[:priority, [3, 4, 5]] = [0, 1, 0]

-b = ReinforcementLearningTrajectories.StatsBase.sample(s, t)
+b = sample(s, t)

 @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0

@@ -176,7 +176,7 @@ end
     push!(eb, (state = i, action =i, reward = i-1, terminal = false))
 end
 s = BatchSampler(1000)
-b = ReinforcementLearningTrajectories.StatsBase.sample(s, eb)
+b = sample(s, eb)
 cm = counter(b[:state])
 @test !haskey(cm, 6)
 @test !haskey(cm, 11)

test/normalization.jl

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ import OnlineStats: mean, std
 @test extrema(unnormalized_batch[:state]) == (0, 4)
 normalized_batch = nt[[1:5;]]

-normalized_batch = ReinforcementLearningTrajectories.StatsBase.sample(traj)
+normalized_batch = sample(traj)
 @test all(extrema(normalized_batch[:state]) .≈ ((0, 4) .- m)./ss)
 @test all(extrema(normalized_batch[:next_state]) .≈ ((0, 4) .- m)./ss)
 @test all(extrema(normalized_batch[:reward]) .≈ ((0, 4) .- m)./s)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ using CircularArrayBuffers, DataStructures
 using Test
 using CUDA
 using Adapt
+import ReinforcementLearningTrajectories.StatsBase.sample

 struct TestAdaptor end

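Context for the change above: in Julia, import ReinforcementLearningTrajectories.StatsBase.sample binds the name sample directly in the test suite's namespace, so (assuming runtests.jl includes the other test files, as is conventional) the test files can drop the fully qualified prefix. A minimal, self-contained sketch of that import pattern, using a toy Demo module that is purely illustrative and not part of this commit:

    # Toy stand-in module; NOT the library's sampler.
    module Demo
    sample(n::Int) = collect(1:n)   # returns [1, 2, ..., n]
    end

    # Fully qualified call, like the tests before this commit:
    a = Demo.sample(3)

    # After importing the name once, the short form works for the rest of the file:
    import .Demo: sample
    b = sample(3)

    @assert a == b == [1, 2, 3]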
test/samplers.jl

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 action=rand(1:4, 5),
 )

-b = ReinforcementLearningTrajectories.StatsBase.sample(s, t)
+b = sample(s, t)

 @test keys(b) == (:state, :action)
 @test size(b.state) == (3, 4, sz)
@@ -23,7 +23,7 @@
     push!(eb, (state = i, action =i, reward = i-1, terminal = false))
 end
 s = BatchSampler(1000)
-b = ReinforcementLearningTrajectories.StatsBase.sample(s, eb)
+b = sample(s, eb)
 cm = counter(b[:state])
 @test !haskey(cm, 6)
 @test !haskey(cm, 11)
