-@testset "Samplers" begin
+import ReinforcementLearningTrajectories.fetch
+#@testset "Samplers" begin
 @testset "BatchSampler" begin
     sz = 32
     s = BatchSampler(sz)
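BatchSampler(sz) draws fixed-size batches uniformly at random. A minimal usage sketch, assuming the same sample(sampler, container) entry point exercised further down; the Traces container here is illustrative only:

    t = Traces(reward = 1:10, terminal = falses(10))
    xs = RLTrajectories.StatsBase.sample(BatchSampler(4), t)  # named tuple with 4 uniformly drawn rows per trace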
 ⋮
 
 #! format: off
 @testset "NStepSampler" begin
-    γ = 0.9
+    γ = 0.99
     n_stack = 2
     n_horizon = 3
-    batch_size = 4
-
-    t1 = MultiplexTraces{(:state, :next_state)}(1:10) +
-        MultiplexTraces{(:action, :next_action)}(iseven.(1:10)) +
-        Traces(
-            reward=1:9,
-            terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
-        )
-
-    s1 = NStepBatchSampler(n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
-
-    xs = RLTrajectories.StatsBase.sample(s1, t1)
-
-    @test size(xs.state) == (n_stack, batch_size)
-    @test size(xs.next_state) == (n_stack, batch_size)
-    @test size(xs.action) == (batch_size,)
-    @test size(xs.reward) == (batch_size,)
-    @test size(xs.terminal) == (batch_size,)
-
-
-    state_size = (2, 3)
-    n_state = reduce(*, state_size)
-    total_length = 10
-    t2 = MultiplexTraces{(:state, :next_state)}(
-            reshape(1:n_state * total_length, state_size..., total_length)
-        ) +
-        MultiplexTraces{(:action, :next_action)}(iseven.(1:total_length)) +
-        Traces(
-            reward=1:total_length-1,
-            terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
-        )
-
-    xs2 = RLTrajectories.StatsBase.sample(s1, t2)
-
-    @test size(xs2.state) == (state_size..., n_stack, batch_size)
-    @test size(xs2.next_state) == (state_size..., n_stack, batch_size)
-    @test size(xs2.action) == (batch_size,)
-    @test size(xs2.reward) == (batch_size,)
-    @test size(xs2.terminal) == (batch_size,)
-
-    inds = [3, 5, 7]
-    xs3 = RLTrajectories.StatsBase.sample(s1, t2, Val(SS′ART), inds)
-
-    @test xs3.state == cat(
-        (
-            reshape(n_state * (i - n_stack) + 1 : n_state * i, state_size..., n_stack)
-            for i in inds
-        )...;
-        dims=length(state_size) + 2
-    )
-
-    @test xs3.next_state == xs3.state .+ (n_state * n_horizon)
-    @test xs3.action == iseven.(inds)
-    @test xs3.terminal == [any(t2[:terminal][i : i+n_horizon-1]) for i in inds]
-
-    # manual calculation
-    @test xs3.reward[1] ≈ 3 + γ * 4 # terminated at step 4
-    @test xs3.reward[2] ≈ 5 + γ * (6 + γ * 7)
-    @test xs3.reward[3] ≈ 7 + γ * (8 + γ * 9)
-end
-#! format: on
-
-@testset "Trajectory with CircularPrioritizedTraces and NStepBatchSampler" begin
-    n = 1
-    γ = 0.99f0
-
-    t = Trajectory(
-        container=CircularPrioritizedTraces(
-            CircularArraySARTSTraces(
-                capacity=5,
-                state=Float32 => (4,),
-            );
-            default_priority=100.0f0
-        ),
-        sampler=NStepBatchSampler{SS′ART}(
-            n=n,
-            γ=γ,
-            batch_size=32,
-        ),
-        controller=InsertSampleRatioController(
-            threshold=100,
-            n_inserted=-1
-        )
-    )
+    batch_size = 1000
+    eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
+    s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
 
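The reward expectations further down assume the truncated n-step discounted return. A plain-Julia sketch of that quantity; nstep_return is an illustrative helper, not a package function, and n is taken as already capped at the episode boundary (compare the ns vector below):

    # Discounted sum of up to n rewards starting at step t.
    nstep_return(rs, t, n, γ) = sum(γ^(k - 1) * rs[t + k - 1] for k in 1:n)

    nstep_return(1:5, 2, 3, 0.99)  # == 2 + 0.99*3 + 0.99^2*4, matching the reward test below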
-    push!(t, (state = 1, action = true))
-    for i = 1:9
-        push!(t, (state = i+1, action = true, reward = i, terminal = false))
+    push!(eb, (state = 1, action = 1))
+    for i = 1:5
+        push!(eb, (state = i+1, action = i+1, reward = i, terminal = i == 5))
     end
-
-    b = RLTrajectories.StatsBase.sample(t)
-    @test haskey(b, :priority)
-    @test sum(b.action .== 0) == 0
-end
-
-
-@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
-    n = 1
-    γ = 0.99f0
-
-    t = Trajectory(
-        container=CircularArraySARTSTraces(
-            capacity=5,
-            state=Float32 => (4,),
-        ),
-        sampler=NStepBatchSampler{SS′ART}(
-            n=n,
-            γ=γ,
-            batch_size=32,
-        ),
-        controller=InsertSampleRatioController(
-            threshold=100,
-            n_inserted=-1
-        )
-    )
-
-    push!(t, (state = 1, action = true))
-    for i = 1:9
-        push!(t, (state = i+1, action = true, reward = i, terminal = false))
+    push!(eb, (state = 7, action = 7))
+    for (j, i) = enumerate(8:11)
+        push!(eb, (state = i, action = i, reward = i-1, terminal = false))
+    end
+    weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
+    @test weights == [0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0]
+    @test ns == [3, 3, 3, 2, 1, -1, 3, 3, 2, 1, 0] # the -1 arises because ep_lengths[6] belongs to the second episode while step_numbers[6] still belongs to the first
+    inds = [i for i in eachindex(weights) if weights[i] == 1]
+    batch = sample(s1, eb)
+    for key in keys(eb)
+        @test haskey(batch, key)
     end
 
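One reading of the expected weights and ns, based on the pushes above (the interpretation of the zeroed positions is an assumption about the sampler's validity rule, not package documentation):

    # position:  1  2  3  4  5  6 | 7  8  9 10 11
    # state:     1  2  3  4  5  6 | 7  8  9 10 11
    #            --- episode 1 ---|--- episode 2 ---
    # weights is 0 where no valid sample can start: positions 1 and 7
    # lack a previous frame for stack_size = 2, position 6 is the
    # boundary slot between episodes, and position 11 has no step
    # ahead of it to bootstrap on. ns caps the lookahead at
    # min(n_horizon, steps remaining in the episode).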
+    # state: samples with stack_size
+    states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
+    @test states == [1 2 3 4 7 8 9;
+                     2 3 4 5 8 9 10]
+    @test all(in(eachcol(states)), unique(eachcol(batch[:state])))
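The expected states matrix can be reproduced with a plain gather; stack_states is an illustrative helper, not part of the package:

    # Each sampled state is the window of the last stack_size frames
    # ending at the sampled index, one column per index.
    stack_states(trace, inds, stack_size) =
        hcat((trace[i - stack_size + 1 : i] for i in inds)...)

    stack_states(collect(1:11), [2, 3, 4, 5, 8, 9, 10], 2)
    # 2×7 Matrix equal to the expected states above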
+    # next_state: samples with stack_size, n steps forward
+    next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
+    @test next_states == [4 5 5 5 10 10 10;
+                          5 6 6 6 11 11 11]
+    @test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
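The expected next_states follow from index arithmetic: the lookahead is ns[inds] steps, capped at each episode's end, so the stacked windows end at

    [2, 3, 4, 5, 8, 9, 10] .+ [3, 3, 2, 1, 3, 2, 1]  # inds .+ ns[inds] == [5, 6, 6, 6, 11, 11, 11]

and stacking frames at those endpoints reproduces the matrix above.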
+    # action: samples normally
+    actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
+    @test actions == inds
+    @test all(in(actions), unique(batch[:action]))
+    # next_action: a multiplexed trace, so it is automatically fetched n steps forward
+    next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
+    @test next_actions == [5, 6, 6, 6, 11, 11, 11]
+    @test all(in(next_actions), unique(batch[:next_action]))
+    # reward: discounted sum over the n-step horizon
+    rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
+    @test rewards ≈ [2 + 0.99*3 + 0.99^2*4, 3 + 0.99*4 + 0.99^2*5, 4 + 0.99*5, 5, 8 + 0.99*9 + 0.99^2*10, 9 + 0.99*10, 10]
+    @test all(in(rewards), unique(batch[:reward]))
+    # terminal: fetched n steps forward
+    terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
+    @test terminals == [0, 1, 1, 1, 0, 0, 0]
+
+    ### CircularPrioritizedTraces and NStepBatchSampler
+    γ = 0.99
+    n_horizon = 3
+    batch_size = 4
+    eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
+    s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batch_size=batch_size)
 
-    b = RLTrajectories.StatsBase.sample(t)
-    @test sum(b.action .== 0) == 0
+    push!(eb, (state = 1, action = 1))
+    for i = 1:5
+        push!(eb, (state = i+1, action = i+1, reward = i, terminal = i == 5))
+    end
+    push!(eb, (state = 7, action = 7))
+    for (j, i) = enumerate(8:11)
+        push!(eb, (state = i, action = i, reward = i-1, terminal = false))
+    end
+    weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
+    inds = [i for i in eachindex(weights) if weights[i] == 1]
+    batch = sample(s1, eb)
+    for key in (keys(eb)..., :key, :priority)
+        @test haskey(batch, key)
+    end
 end
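With CircularPrioritizedTraces the batch additionally carries :key and :priority, presumably so a learner can write updated priorities back after computing TD errors. The selection itself is priority-proportional; a self-contained sketch of that rule in plain StatsBase, not the package internals:

    using StatsBase: Weights, sample
    priorities = fill(10f0, 7)                 # fresh entries start at default_priority
    idx = sample(1:7, Weights(priorities), 4)  # draws stay uniform until priorities diverge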
 
 @testset "EpisodesSampler" begin