@@ -191,213 +191,22 @@ function test_interfaces!(env)
191
191
reset! (env)
192
192
end
193
193
194
- # ####
195
- # Generate README
196
- # ####
197
-
198
- gen_traits_table (envs) = gen_traits_table (stdout , envs)
199
-
200
- function gen_traits_table (io, envs)
201
- trait_dict = Dict ()
202
- for f in env_traits ()
203
- for env in envs
204
- if ! haskey (trait_dict, f)
205
- trait_dict[f] = Set ()
206
- end
207
- t = f (env)
208
- if f == StateStyle
209
- if t isa Tuple
210
- for x in t
211
- push! (trait_dict[f], nameof (typeof (x)))
212
- end
213
- else
214
- push! (trait_dict[f], nameof (typeof (t)))
215
- end
216
- else
217
- push! (trait_dict[f], nameof (typeof (t)))
218
- end
219
- end
220
- end
221
-
222
- println (io, " <table>" )
223
-
224
- print (io, " <th colspan=\" 2\" >Traits</th>" )
225
- for i in 1 : length (envs)
226
- print (io, " <th> $(i) </th>" )
227
- end
228
-
229
- for k in sort (collect (keys (trait_dict)), by = nameof)
230
- vs = trait_dict[k]
231
- print (io, " <tr> <th rowspan=\" $(length (vs)) \" > $(nameof (k)) </th>" )
232
- for (i, v) in enumerate (vs)
233
- if i != 1
234
- print (io, " <tr> " )
235
- end
236
- print (io, " <th> $(v) </th>" )
237
- for env in envs
238
- if k == StateStyle && k (env) isa Tuple
239
- ss = k (env)
240
- if v in map (x -> nameof (typeof (x)), ss)
241
- print (io, " <td> ✔ </td>" )
242
- else
243
- print (io, " <td> </td> " )
244
- end
245
- else
246
- if nameof (typeof (k (env))) == v
247
- print (io, " <td> ✔ </td>" )
248
- else
249
- print (io, " <td> </td> " )
250
- end
251
- end
252
- end
253
- println (io, " </tr>" )
254
- end
255
- end
256
-
257
- println (io, " </table>" )
258
-
259
- print (io, " <ol>" )
260
- for env in envs
261
- println (
262
- io,
263
- " <li> <a href=\" https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/$(nameof (env)) .jl\" > $(nameof (env)) </a></li>" ,
264
- )
265
- end
266
- print (io, " </ol>" )
267
- end
268
-
269
- # ####
270
- # Utils
271
- # ####
272
-
273
- using IntervalSets
274
-
275
- Random. rand (s:: Union{Interval,Array{<:Interval}} ) = rand (Random. GLOBAL_RNG, s)
276
-
277
- function Random. rand (rng:: AbstractRNG , s:: Interval )
278
- rand (rng) * (s. right - s. left) + s. left
279
- end
280
-
281
- # ####
282
- # WorldSpace
283
- # ####
284
-
285
- export WorldSpace
286
-
287
- """
288
- In some cases, we may not be interested in the action/state space.
289
- One can return `WorldSpace()` to keep the interface consistent.
290
- """
291
- struct WorldSpace{T} end
292
-
293
- WorldSpace () = WorldSpace {Any} ()
294
-
295
- Base. in (x, :: WorldSpace{T} ) where {T} = x isa T
296
-
297
- # ####
298
- # ZeroTo
299
- # ####
300
-
301
- export ZeroTo
302
-
303
- """
304
- Similar to `Base.OneTo`. Useful when wrapping third-party environments.
305
- """
306
- struct ZeroTo{T<: Integer } <: AbstractUnitRange{T}
307
- stop:: T
308
- ZeroTo {T} (n) where {T<: Integer } = new (max (zero (T) - one (T), n))
309
- end
310
-
311
- ZeroTo (n:: T ) where {T<: Integer } = ZeroTo {T} (n)
312
-
313
- Base. show (io:: IO , r:: ZeroTo ) = print (io, " ZeroTo(" , r. stop, " )" )
314
- Base. length (r:: ZeroTo{T} ) where {T} = T (r. stop + one (r. stop))
315
- Base. first (r:: ZeroTo{T} ) where {T} = zero (r. stop)
316
-
317
- function getindex (v:: ZeroTo{T} , i:: Integer ) where {T}
318
- Base. @_inline_meta
319
- @boundscheck ((i >= 0 ) & (i <= v. stop)) || throw_boundserror (v, i)
320
- convert (T, i)
321
- end
322
-
323
- # ####
324
- # ActionProbPair
325
- # ####
326
-
327
- export ActionProbPair
328
-
329
- """
330
- Used in action space of chance player.
331
- """
332
- struct ActionProbPair{A,P}
333
- action:: A
334
- prob:: P
335
- end
336
-
337
- """
338
- Directly copied from [StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl/blob/0ea8e798c3d19609ed33b11311de5a2bd6ee9fd0/src/sampling.jl#L499-L510) to avoid depending on the whole package.
339
- Here we assume `wv` sum to `1`
340
- """
341
- function weighted_sample (rng:: AbstractRNG , wv)
342
- t = rand (rng)
343
- cw = zero (Base. first (wv))
344
- for (i, w) in enumerate (wv)
345
- cw += w
346
- if cw >= t
347
- return i
348
- end
349
- end
350
- end
351
-
352
- Random. rand (rng:: AbstractRNG , s:: AbstractVector{<:ActionProbPair} ) =
353
- s[weighted_sample (rng, (x. prob for x in s))]
354
-
355
- (env:: AbstractEnv )(a:: ActionProbPair ) = env (a. action)
356
-
357
- # ####
358
- # Space
359
- # ####
360
-
361
- export Space
362
-
363
- """
364
- A wrapper to treat each element as a sub-space which supports `Random.rand` and `Base.in`.
365
- """
366
- struct Space{T}
367
- s:: T
368
- end
369
-
370
- Random. rand (s:: Space ) = rand (Random. GLOBAL_RNG, s)
371
-
372
- Random. rand (rng:: AbstractRNG , s:: Space ) =
373
- map (s. s) do x
374
- rand (rng, x)
375
- end
376
-
377
- Random. rand (rng:: AbstractRNG , s:: Space{<:Dict} ) = Dict (k => rand (rng, v) for (k, v) in s. s)
378
-
379
- function Base. in (X, S:: Space )
380
- if length (X) == length (S. s)
381
- for (x, s) in zip (X, S. s)
382
- if x ∉ s
383
- return false
384
- end
385
- end
386
- return true
387
- else
388
- return false
389
- end
390
- end
391
-
392
- function Base. in (X:: Dict , S:: Space{<:Dict} )
393
- if keys (X) == keys (S. s)
394
- for k in keys (X)
395
- if X[k] ∉ S. s[k]
396
- return false
194
+ function test_runnable! (env, n = 1000 ;rng= Random. GLOBAL_RNG)
195
+ @testset " random policy with $(nameof (env)) " begin
196
+ reset! (env)
197
+ for _ in 1 : n
198
+ A = legal_action_space (env)
199
+ a = rand (rng, A)
200
+ @test a in A
201
+
202
+ S = state_space (env)
203
+ s = state (env)
204
+ @test s in S
205
+ env (a)
206
+ if is_terminated (env)
207
+ reset! (env)
397
208
end
398
209
end
399
- return true
400
- else
401
- return false
210
+ reset! (env)
402
211
end
403
212
end
0 commit comments