@@ -10,42 +10,56 @@ Base.show(io::IO, t::MIME"text/plain", env::AbstractEnv) =
10
10
show (io, MIME " text/markdown" (), env)
11
11
12
12
"""
    show(io::IO, t::MIME"text/markdown", env::AbstractEnv)

Build a Markdown report describing `env`, parse it with `Markdown.parse`, and
forward the parsed document to `show(io, t, ...)`.

The report always starts with the environment name, a table with one row per
function returned by `env_traits()`, and whether the environment is terminated.
The remaining sections are optional:

- `## State Space`, `## Action Space` and `## Current State` are each included
  unless `get(io, key, true)` returns `false` for the keys
  `:is_show_state_space`, `:is_show_action_space` and `:is_show_state`.
- `## Players` / `## Current Player` are included only when
  `NumAgentStyle(env) !== SINGLE_AGENT`.
"""
function Base.show(io::IO, t::MIME"text/markdown", env::AbstractEnv)
    sections = String[]

    # Header, trait table and termination flag are always present.
    push!(sections, """
    # $(nameof(env))

    ## Traits
    | Trait Type | Value |
    |:---------- | ----- |
    $(join(["|$(string(f))|$(f(env))|" for f in env_traits()], "\n"))

    ## Is Environment Terminated?
    $(is_terminated(env) ? "Yes" : "No")

    """)

    if get(io, :is_show_state_space, true)
        push!(sections, """
        ## State Space
        `$(state_space(env))`

        """)
    end

    if get(io, :is_show_action_space, true)
        push!(sections, """
        ## Action Space
        `$(action_space(env))`

        """)
    end

    if NumAgentStyle(env) !== SINGLE_AGENT
        push!(sections, """
        ## Players
        $(join(["- `$p`" for p in players(env)], "\n"))

        ## Current Player
        `$(current_player(env))`
        """)
    end

    if get(io, :is_show_state, true)
        push!(sections, """
        ## Current State

        ```
        $(state(env))
        ```
        """)
    end

    # Sections are concatenated without a separator, in the order collected.
    return show(io, t, Markdown.parse(join(sections)))
end
50
64
51
65
# ####
@@ -58,9 +72,7 @@ using Test
58
72
Call this function after writing your customized environment to make sure that
59
73
all the necessary interfaces are implemented correctly and consistently.
60
74
"""
61
- function test_interfaces (env)
62
- env = copy (env) # make sure we don't touch the original environment
63
-
75
+ function test_interfaces! (env)
64
76
rng = Random. MersenneTwister (666 )
65
77
66
78
@info " testing $(nameof (env)) , you need to manually check these traits to make sure they are implemented correctly!" NumAgentStyle (
@@ -69,42 +81,41 @@ function test_interfaces(env)
69
81
env,
70
82
) UtilityStyle (env) ChanceStyle (env)
71
83
72
- reset! (env)
73
-
74
84
@testset " copy" begin
75
- old_env = env
76
- env = copy (env)
85
+ X = copy (env)
86
+ Y = copy (env)
87
+ reset! (X)
88
+ reset! (Y)
77
89
78
- if ChanceStyle (env ) ∉ (DETERMINISTIC, EXPLICIT_STOCHASTIC)
90
+ if ChanceStyle (Y ) ∉ (DETERMINISTIC, EXPLICIT_STOCHASTIC)
79
91
s = 888
80
- Random. seed! (env , s)
81
- Random. seed! (old_env , s)
92
+ Random. seed! (Y , s)
93
+ Random. seed! (X , s)
82
94
end
83
95
84
- @test env != = old_env
96
+ @test Y != = X
85
97
86
- @test state (env ) == state (old_env )
87
- @test action_space (env ) == action_space (old_env )
88
- @test reward (env ) == reward (old_env )
89
- @test is_terminated (env ) == is_terminated (old_env )
98
+ @test state (Y ) == state (X )
99
+ @test action_space (Y ) == action_space (X )
100
+ @test reward (Y ) == reward (X )
101
+ @test is_terminated (Y ) == is_terminated (X )
90
102
91
- while ! is_terminated (env )
92
- A, A′ = legal_action_space (old_env ), legal_action_space (env )
103
+ while ! is_terminated (Y )
104
+ A, A′ = legal_action_space (X ), legal_action_space (Y )
93
105
@test A == A′
94
106
a = rand (rng, A)
95
- env (a)
96
- old_env (a)
97
- @test state (env ) == state (old_env )
98
- @test reward (env ) == reward (old_env )
99
- @test is_terminated (env ) == is_terminated (old_env )
107
+ Y (a)
108
+ X (a)
109
+ @test state (Y ) == state (X )
110
+ @test reward (Y ) == reward (X )
111
+ @test is_terminated (Y ) == is_terminated (X )
100
112
end
101
113
end
102
114
103
- reset! (env)
104
-
105
115
@testset " SingleAgent" begin
106
116
if NumAgentStyle (env) === SINGLE_AGENT
107
- total_reward = 0.0
117
+ reset! (env)
118
+ total_reward = 0.
108
119
while ! is_terminated (env)
109
120
if StateStyle (env) isa Tuple
110
121
for ss in StateStyle (env)
@@ -176,6 +187,8 @@ function test_interfaces(env)
176
187
end
177
188
end
178
189
end
190
+
191
+ reset! (env)
179
192
end
180
193
181
194
# ####
@@ -259,9 +272,130 @@ end
259
272
260
273
using IntervalSets
261
274
275
# Sampling without an explicit RNG delegates to the global RNG.
function Random.rand(s::Union{Interval,Array{<:Interval}})
    return rand(Random.GLOBAL_RNG, s)
end

# Draw a value between the endpoints of `s` by scaling and shifting a single
# `rand(rng)` draw: `left + u * (right - left)`.
function Random.rand(rng::AbstractRNG, s::Interval)
    width = s.right - s.left
    return rand(rng) * width + s.left
end
280
+
281
+ # ####
282
+ # WorldSpace
283
+ # ####
284
+
285
+ export WorldSpace
286
+
287
"""
    WorldSpace{T}()
    WorldSpace()

A placeholder space for cases where the action/state space is of no interest.
Return `WorldSpace()` to keep the interface consistent.

`x in WorldSpace{T}()` is `true` exactly when `x isa T`; the zero-argument
constructor uses `T = Any`, so every value belongs to `WorldSpace()`.
"""
struct WorldSpace{T} end

function WorldSpace()
    return WorldSpace{Any}()
end

# Membership is a plain type check against the type parameter.
Base.in(x, ::WorldSpace{T}) where {T} = isa(x, T)
296
+
297
+ # ####
298
+ # ZeroTo
299
+ # ####
300
+
301
+ export ZeroTo
302
+
262
303
"""
    ZeroTo(n)
    ZeroTo{T}(n)

Similar to `Base.OneTo`, but the range starts at zero: `ZeroTo(n)` behaves like
`0:n`. Useful when wrapping third-party environments.

For types that can hold negative values, `n` is clamped to at least `-1`, so
any negative `n` yields an empty range. Like every `AbstractUnitRange`, the
range is indexed from `1`: `ZeroTo(n)[i] == i - 1` for `1 <= i <= n + 1`.
"""
struct ZeroTo{T<:Integer} <: AbstractUnitRange{T}
    stop::T
    function ZeroTo{T}(n) where {T<:Integer}
        # An unsigned `T` cannot store the `-1` stop of an empty range, and
        # `zero(T) - one(T)` would wrap around to `typemax(T)`, turning every
        # range into `0:typemax(T)`. Only clamp when `T` is not unsigned.
        T <: Unsigned && return new(n)
        return new(max(zero(T) - one(T), n))
    end
end

ZeroTo(n::T) where {T<:Integer} = ZeroTo{T}(n)

Base.show(io::IO, r::ZeroTo) = print(io, "ZeroTo(", r.stop, ")")
Base.length(r::ZeroTo{T}) where {T} = T(r.stop + one(r.stop))
Base.first(r::ZeroTo{T}) where {T} = zero(r.stop)

# Explicitly extend `Base.getindex` (an unqualified `getindex` would not be
# picked up by `r[i]` unless it had been imported). Indices are 1-based, so the
# i-th element is `i - 1`, which keeps `r[1] == first(r)` and
# `r[length(r)] == last(r)`.
@inline function Base.getindex(v::ZeroTo{T}, i::Integer) where {T}
    @boundscheck (1 <= i <= length(v)) || Base.throw_boundserror(v, i)
    return convert(T, i - one(i))
end
322
+
323
+ # ####
324
+ # ActionProbPair
325
+ # ####
326
+
327
+ export ActionProbPair
328
+
329
"""
    ActionProbPair(action, prob)

Used in action space of chance player: pairs an `action` with the probability
`prob` attached to it.
"""
struct ActionProbPair{A,P}
    action::A
    prob::P
end
336
+
337
"""
    weighted_sample(rng::AbstractRNG, wv)

Draw an index of `wv` with probability proportional to its weight and return
it as an `Int`.

Adapted from [StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl/blob/0ea8e798c3d19609ed33b11311de5a2bd6ee9fd0/src/sampling.jl#L499-L510)
to avoid depending on the whole package. `wv` may be any non-empty iterable of
weights (it is only traversed with `first` and `enumerate`) and is assumed to
sum to `1`.

If the accumulated weight never reaches the random threshold (for instance
because floating-point rounding leaves the total slightly below `1`), the last
index is returned instead of falling off the loop and returning `nothing`.
"""
function weighted_sample(rng::AbstractRNG, wv)
    t = rand(rng)
    cw = zero(Base.first(wv))  # `first` throws on an empty `wv`
    last_index = 0
    for (i, w) in enumerate(wv)
        cw += w
        cw >= t && return i
        last_index = i
    end
    # Only reached when the weights sum to less than `t`; fall back to the
    # final element, as the StatsBase implementation does.
    return last_index
end
351
+
352
# Sample one pair from `s`, using each pair's `prob` field as its weight.
function Random.rand(rng::AbstractRNG, s::AbstractVector{<:ActionProbPair})
    chosen = weighted_sample(rng, (pair.prob for pair in s))
    return s[chosen]
end

# Applying an `ActionProbPair` to an environment applies the wrapped action.
function (env::AbstractEnv)(a::ActionProbPair)
    return env(a.action)
end
355
+
356
+ # ####
357
+ # Space
358
+ # ####
359
+
360
+ export Space
361
+
362
"""
    Space(s)

A wrapper to treat each element of `s` as a sub-space which supports
`Random.rand` and `Base.in`. The wrapped collection is stored in the field `s`.
"""
struct Space{T}
    s::T
end
368
+
369
# Sampling without an explicit RNG delegates to the global RNG.
function Random.rand(s::Space)
    return rand(Random.GLOBAL_RNG, s)
end

# Sample every sub-space independently; the result comes from `map` over `s.s`.
function Random.rand(rng::AbstractRNG, s::Space)
    return map(sub -> rand(rng, sub), s.s)
end

# For a `Dict` of sub-spaces, build a `Dict` with one sample per key.
function Random.rand(rng::AbstractRNG, s::Space{<:Dict})
    return Dict(key => rand(rng, sub) for (key, sub) in s.s)
end
376
+
377
# `X` belongs to `S` when it has as many elements as `S.s` and each element is
# contained in the sub-space at the same position.
function Base.in(X, S::Space)
    length(X) == length(S.s) || return false
    for (x, sub) in zip(X, S.s)
        x ∈ sub || return false
    end
    return true
end
389
+
390
# A `Dict` belongs to a `Dict`-backed space when both have the same keys and
# every value is contained in the sub-space stored under its key.
function Base.in(X::Dict, S::Space{<:Dict})
    keys(X) == keys(S.s) || return false
    for key in keys(X)
        X[key] ∈ S.s[key] || return false
    end
    return true
end
0 commit comments