This repository was archived by the owner on Aug 11, 2023. It is now read-only.

Commit b3ac056

Minor fixes (#110)
* minor rename
* merge upstream
* add a general Space
* fix tests
1 parent 94a0d7b commit b3ac056

13 files changed: +199 −60 lines

src/base.jl

Lines changed: 179 additions & 45 deletions
@@ -10,42 +10,56 @@ Base.show(io::IO, t::MIME"text/plain", env::AbstractEnv) =
     show(io, MIME"text/markdown"(), env)
 
 function Base.show(io::IO, t::MIME"text/markdown", env::AbstractEnv)
-    show(io, t, Markdown.parse("""
+    s = """
     # $(nameof(env))
 
     ## Traits
     | Trait Type | Value |
     |:---------- | ----- |
     $(join(["|$(string(f))|$(f(env))|" for f in env_traits()], "\n"))
 
-    ## Action Space
-    `$(action_space(env))`
+    ## Is Environment Terminated?
+    $(is_terminated(env) ? "Yes" : "No")
 
-    ## State Space
-    `$(state_space(env))`
+    """
 
-    """))
+    if get(io, :is_show_state_space, true)
+        s *= """
+        ## State Space
+        `$(state_space(env))`
 
-    if NumAgentStyle(env) !== SINGLE_AGENT
-        show(io, t, Markdown.parse("""
-        ## Players
-        $(join(["- `$p`" for p in players(env)], "\n"))
+        """
+    end
+
+    if get(io, :is_show_action_space, true)
+        s *= """
+        ## Action Space
+        `$(action_space(env))`
 
-        ## Current Player
-        `$(current_player(env))`
-        """))
+        """
     end
 
-    show(io, t, Markdown.parse("""
-    ## Is Environment Terminated?
-    $(is_terminated(env) ? "Yes" : "No")
+    if NumAgentStyle(env) !== SINGLE_AGENT
+        s *= """
+        ## Players
+        $(join(["- `$p`" for p in players(env)], "\n"))
+
+        ## Current Player
+        `$(current_player(env))`
+        """
+    end
 
+    if get(io, :is_show_state, true)
+        s *= """
     ## Current State
 
     ```
     $(state(env))
     ```
-    """))
+    """
+    end
+
+    show(io, t, Markdown.parse(s))
 end
 
 #####
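
(Not part of the diff — a minimal usage sketch of the new `IOContext` keys checked in the rewritten `Base.show` above. It assumes `TicTacToeEnv` is exported, as in the package's examples.)

```julia
# Sketch only: demonstrate the new IOContext switches introduced in this hunk.
using ReinforcementLearningBase

env = TicTacToeEnv()   # assumption: any example environment exported by the package works here

# Wrapping the stream in an IOContext lets `get(io, :is_show_state, true)` (and the other
# new keys) pick up the flags and skip the corresponding markdown sections.
io = IOContext(stdout, :is_show_state => false, :is_show_action_space => false)
show(io, MIME"text/markdown"(), env)   # prints traits, termination status, and state space only
```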
@@ -58,9 +72,7 @@ using Test
 Call this function after writing your customized environment to make sure that
 all the necessary interfaces are implemented correctly and consistently.
 """
-function test_interfaces(env)
-    env = copy(env) # make sure we don't touch the original environment
-
+function test_interfaces!(env)
     rng = Random.MersenneTwister(666)
 
     @info "testing $(nameof(env)), you need to manually check these traits to make sure they are implemented correctly!" NumAgentStyle(
@@ -69,42 +81,41 @@ function test_interfaces(env)
         env,
     ) UtilityStyle(env) ChanceStyle(env)
 
-    reset!(env)
-
     @testset "copy" begin
-        old_env = env
-        env = copy(env)
+        X = copy(env)
+        Y = copy(env)
+        reset!(X)
+        reset!(Y)
 
-        if ChanceStyle(env) ∈ (DETERMINISTIC, EXPLICIT_STOCHASTIC)
+        if ChanceStyle(Y) ∈ (DETERMINISTIC, EXPLICIT_STOCHASTIC)
             s = 888
-            Random.seed!(env, s)
-            Random.seed!(old_env, s)
+            Random.seed!(Y, s)
+            Random.seed!(X, s)
         end
 
-        @test env !== old_env
+        @test Y !== X
 
-        @test state(env) == state(old_env)
-        @test action_space(env) == action_space(old_env)
-        @test reward(env) == reward(old_env)
-        @test is_terminated(env) == is_terminated(old_env)
+        @test state(Y) == state(X)
+        @test action_space(Y) == action_space(X)
+        @test reward(Y) == reward(X)
+        @test is_terminated(Y) == is_terminated(X)
 
-        while !is_terminated(env)
-            A, A′ = legal_action_space(old_env), legal_action_space(env)
+        while !is_terminated(Y)
+            A, A′ = legal_action_space(X), legal_action_space(Y)
             @test A == A′
             a = rand(rng, A)
-            env(a)
-            old_env(a)
-            @test state(env) == state(old_env)
-            @test reward(env) == reward(old_env)
-            @test is_terminated(env) == is_terminated(old_env)
+            Y(a)
+            X(a)
+            @test state(Y) == state(X)
+            @test reward(Y) == reward(X)
+            @test is_terminated(Y) == is_terminated(X)
         end
     end
 
-    reset!(env)
-
     @testset "SingleAgent" begin
         if NumAgentStyle(env) === SINGLE_AGENT
-            total_reward = 0.0
+            reset!(env)
+            total_reward = 0.
             while !is_terminated(env)
                 if StateStyle(env) isa Tuple
                     for ss in StateStyle(env)
@@ -176,6 +187,8 @@ function test_interfaces(env)
             end
         end
     end
+
+    reset!(env)
 end
 
 #####
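
(A sketch of how the renamed helper is called, mirroring the updated test files later in this commit. The trailing `!` signals that `env` is now stepped through and reset in place instead of being copied up front.)

```julia
using ReinforcementLearningBase
const RLBase = ReinforcementLearningBase

env = PigEnv()                 # any example environment; PigEnv is used the same way in test/examples/pig.jl
RLBase.test_interfaces!(env)   # runs the trait/copy/single-agent checks and calls reset!(env) before returning
```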
@@ -259,9 +272,130 @@ end
 
 using IntervalSets
 
+Random.rand(s::Union{Interval, Array{<:Interval}}) = rand(Random.GLOBAL_RNG, s)
+
+function Random.rand(rng::AbstractRNG, s::Interval)
+    rand(rng) * (s.right - s.left) + s.left
+end
+
+#####
+# WorldSpace
+#####
+
+export WorldSpace
+
+"""
+In some cases, we may not be interested in the action/state space.
+One can return `WorldSpace()` to keep the interface consistent.
+"""
+struct WorldSpace{T} end
+
+WorldSpace() = WorldSpace{Any}()
+
+Base.in(x, ::WorldSpace{T}) where T = x isa T
+
+#####
+# ZeroTo
+#####
+
+export ZeroTo
+
 """
-watch https://github.com/JuliaMath/IntervalSets.jl/issues/66
+Similar to `Base.OneTo`. Useful when wrapping third-party environments.
 """
-function Base.in(x::AbstractArray, s::Array{<:Interval})
-    size(x) == size(s) && all(x .∈ s)
+struct ZeroTo{T<:Integer} <: AbstractUnitRange{T}
+    stop::T
+    ZeroTo{T}(n) where {T<:Integer} = new(max(zero(T)-one(T),n))
 end
+
+ZeroTo(n::T) where {T<:Integer} = ZeroTo{T}(n)
+
+Base.show(io::IO, r::ZeroTo) = print(io, "ZeroTo(", r.stop, ")")
+Base.length(r::ZeroTo{T}) where T = T(r.stop + one(r.stop))
+Base.first(r::ZeroTo{T}) where T = zero(r.stop)
+
+function getindex(v::ZeroTo{T}, i::Integer) where T
+    Base.@_inline_meta
+    @boundscheck ((i >= 0) & (i <= v.stop)) || throw_boundserror(v, i)
+    convert(T, i)
+end
+
+#####
+# ActionProbPair
+#####
+
+export ActionProbPair
+
+"""
+Used in action space of chance player.
+"""
+struct ActionProbPair{A,P}
+    action::A
+    prob::P
+end
+
+"""
+Directly copied from [StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl/blob/0ea8e798c3d19609ed33b11311de5a2bd6ee9fd0/src/sampling.jl#L499-L510) to avoid depending on the whole package.
+Here we assume `wv` sums to `1`.
+"""
+function weighted_sample(rng::AbstractRNG, wv)
+    t = rand(rng)
+    cw = zero(Base.first(wv))
+    for (i, w) in enumerate(wv)
+        cw += w
+        if cw >= t
+            return i
+        end
+    end
+end
+
+Random.rand(rng::AbstractRNG, s::AbstractVector{<:ActionProbPair}) = s[weighted_sample(rng, (x.prob for x in s))]
+
+(env::AbstractEnv)(a::ActionProbPair) = env(a.action)
+
+#####
+# Space
+#####
+
+export Space
+
+"""
+A wrapper to treat each element as a sub-space which supports `Random.rand` and `Base.in`.
+"""
+struct Space{T}
+    s::T
+end
+
+Random.rand(s::Space) = rand(Random.GLOBAL_RNG, s)
+
+Random.rand(rng::AbstractRNG, s::Space) = map(s.s) do x
+    rand(rng, x)
+end
+
+Random.rand(rng::AbstractRNG, s::Space{<:Dict}) = Dict(k=>rand(rng,v) for (k,v) in s.s)
+
+function Base.in(X, S::Space)
+    if length(X) == length(S.s)
+        for (x, s) in zip(X, S.s)
+            if x ∉ s
+                return false
+            end
+        end
+        return true
+    else
+        return false
+    end
+end
+
+function Base.in(X::Dict, S::Space{<:Dict})
+    if keys(X) == keys(S.s)
+        for k in keys(X)
+            if X[k] ∉ S.s[k]
+                return false
+            end
+        end
+        return true
+    else
+        return false
+    end
+end
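
(Not part of the commit — a short usage sketch of the new space helpers defined above, assuming `IntervalSets` for the `..` constructor and the exports shown in the diff. The `:heads`/`:tails` actions are hypothetical.)

```julia
using ReinforcementLearningBase
using IntervalSets   # provides the `..` interval constructor used below
using Random

# WorldSpace: a catch-all space whose membership test only checks the element type.
@assert "anything" in WorldSpace()        # WorldSpace() is WorldSpace{Any}()
@assert !(1.5 in WorldSpace{Int}())

# ZeroTo: like Base.OneTo but starting at 0, convenient for 0-indexed third-party envs.
r = ZeroTo(3)
@assert first(r) == 0 && length(r) == 4

# Space: wraps a container of sub-spaces; `rand` and `in` are applied element-wise.
s = Space([0..1.0, -2.0..2.0])
x = rand(s)
@assert x in s

# ActionProbPair: chance-player actions are drawn in proportion to their probabilities.
chance_actions = [ActionProbPair(:heads, 0.5), ActionProbPair(:tails, 0.5)]   # hypothetical actions
a = rand(MersenneTwister(0), chance_actions)
@assert a isa ActionProbPair
```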

src/examples/PigEnv.jl

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ prob(env::PigEnv, ::ChancePlayer) = fill(1 / 6, 6) # TODO: uniform distribution
 
 state(env::PigEnv, ::Observation{Vector{Int}}, p) = env.scores
 state_space(env::PigEnv, ::Observation, p) =
-    [0..(PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores]
+    Space([0..(PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores])
 
 is_terminated(env::PigEnv) = any(s >= PIG_TARGET_SCORE for s in env.scores)
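
(Sketch: what wrapping the state space in `Space` buys — element-wise membership and sampling via the new methods added in src/base.jl above.)

```julia
using ReinforcementLearningBase

env = PigEnv()
ss = state_space(env)      # now a Space wrapping one score interval per player
@assert state(env) in ss   # element-wise check through Base.in(::Any, ::Space)
@assert length(rand(ss)) == length(state(env))   # element-wise sampling over the sub-intervals
```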

src/examples/TicTacToeEnv.jl

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ current_player(env::TicTacToeEnv) = env.player
 players(env::TicTacToeEnv) = (CROSS, NOUGHT)
 
 state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board
-state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = fill(false..true, 3, 3, 3)
+state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = Space(fill(false..true, 3, 3, 3))
 state(env::TicTacToeEnv, ::Observation{Int}, p) = get_tic_tac_toe_state_info()[env].index
 state_space(env::TicTacToeEnv, ::Observation{Int}, p) =
     Base.OneTo(length(get_tic_tac_toe_state_info()))
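
(Same idea for the Boolean board observation — a sketch assuming `Observation` is exported as part of the state-style API.)

```julia
using ReinforcementLearningBase

env = TicTacToeEnv()
p = first(players(env))
ss = state_space(env, Observation{BitArray{3}}(), p)
@assert size(rand(ss)) == (3, 3, 3)                     # rand maps over each false..true sub-interval
@assert state(env, Observation{BitArray{3}}(), p) in ss # the board is a member of the wrapped space
```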

src/interface.jl

Lines changed: 9 additions & 4 deletions
@@ -168,13 +168,13 @@ abstract type AbstractInformationStyle <: AbstractEnvStyle end
 @api const IMPERFECT_INFORMATION = ImperfectInformation()
 
 """
-    InformationStyle(env) = PERFECT_INFORMATION
+    InformationStyle(env) = IMPERFECT_INFORMATION
 
 Distinguish environments between [`PERFECT_INFORMATION`](@ref) and
-[`IMPERFECT_INFORMATION`](@ref). [`PERFECT_INFORMATION`](@ref) is returned by default.
+[`IMPERFECT_INFORMATION`](@ref). [`IMPERFECT_INFORMATION`](@ref) is returned by default.
 """
 @env_api InformationStyle(env::T) where {T<:AbstractEnv} = InformationStyle(T)
-InformationStyle(::Type{<:AbstractEnv}) = PERFECT_INFORMATION
+InformationStyle(::Type{<:AbstractEnv}) = IMPERFECT_INFORMATION
 
 #####
 ### ChanceStyle
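
(Because the default flips to `IMPERFECT_INFORMATION`, perfect-information environments now have to opt in explicitly. A minimal sketch with a hypothetical env type:)

```julia
using ReinforcementLearningBase
const RLBase = ReinforcementLearningBase

struct MyPerfectInfoEnv <: AbstractEnv end   # hypothetical environment, for illustration only

# With the new default, declare the trait explicitly on the type:
RLBase.InformationStyle(::Type{MyPerfectInfoEnv}) = PERFECT_INFORMATION

@assert InformationStyle(MyPerfectInfoEnv()) === PERFECT_INFORMATION
```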
@@ -391,7 +391,12 @@ const SPECTOR = Spector()
 
 @api (env::AbstractEnv)(action, player = current_player(env))
 
-"Make an independent copy of `env`"
+"""
+Make an independent copy of `env`.
+
+!!! note
+    The rng (if `env` has one) is also copied!
+"""
 @api copy(env::AbstractEnv) = deepcopy(env)
 @api copyto!(dest::AbstractEnv, src::AbstractEnv)
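
(Sketch of the documented copy semantics: `deepcopy` also duplicates the env's internal rng, so a seeded copy replays the same stochastic trajectory independently. `MultiArmBanditsEnv` is taken from the examples and assumed to draw its reward from that rng.)

```julia
using ReinforcementLearningBase
using StableRNGs

env  = MultiArmBanditsEnv(; rng = StableRNG(123))
env′ = copy(env)                      # independent copy, including the rng state

a = rand(action_space(env))
env(a); env′(a)                       # apply the same action to both copies
@assert reward(env) == reward(env′)   # identical because the rng was copied along with the env
```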

test/examples/kuhn_poker.jl

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@
 
 env = KuhnPokerEnv()
 
-RLBase.test_interfaces(env)
+RLBase.test_interfaces!(env)
 
 end

test/examples/monty_hall_problem.jl

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 rng = StableRNG(123)
 env = MontyHallEnv(; rng = rng)
 
-RLBase.test_interfaces(env)
+RLBase.test_interfaces!(env)
 
 n_win_car = 0
 N = 50_000

test/examples/multi_arm_bandits.jl

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 env = MultiArmBanditsEnv(; rng = rng)
 rewards = []
 
-RLBase.test_interfaces(env)
+RLBase.test_interfaces!(env)
 
 N = 50_000
 for _ in 1:N

test/examples/pig.jl

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 @testset "PigEnv" begin
     env = PigEnv()
-    RLBase.test_interfaces(env)
+    RLBase.test_interfaces!(env)
 end

test/examples/random_walk_1d.jl

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 end_rewards = 3 => 5
 env = RandomWalk1D(; rewards = end_rewards)
 
-RLBase.test_interfaces(env)
+RLBase.test_interfaces!(env)
 
 rng = StableRNG(123)
 N = 50_000

test/examples/rock_paper_scissors.jl

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 rng = StableRNG(123)
 env = RockPaperScissorsEnv()
 
-RLBase.test_interfaces(env)
+RLBase.test_interfaces!(env)
 
 rewards = [[], []]
 for _ in 1:50_000
