Skip to content
This repository was archived by the owner on Aug 11, 2023. It is now read-only.

Commit b9ed73f

Browse files
committed
minor rename
1 parent e73163e commit b9ed73f

11 files changed

+150
-55
lines changed

src/base.jl

Lines changed: 132 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,42 +9,56 @@ end
99
Base.show(io::IO, t::MIME"text/plain", env::AbstractEnv) = show(io, MIME"text/markdown"(), env)
1010

1111
function Base.show(io::IO, t::MIME"text/markdown", env::AbstractEnv)
12-
show(io, t, Markdown.parse("""
12+
s = """
1313
# $(nameof(env))
1414
1515
## Traits
1616
| Trait Type | Value |
1717
|:---------- | ----- |
1818
$(join(["|$(string(f))|$(f(env))|" for f in env_traits()], "\n"))
1919
20-
## Action Space
21-
`$(action_space(env))`
20+
## Is Environment Terminated?
21+
$(is_terminated(env) ? "Yes" : "No")
2222
23-
## State Space
24-
`$(state_space(env))`
23+
"""
2524

26-
"""))
25+
if get(io, :is_show_state_space, true)
26+
s *= """
27+
## State Space
28+
`$(state_space(env))`
2729
28-
if NumAgentStyle(env) !== SINGLE_AGENT
29-
show(io, t, Markdown.parse("""
30-
## Players
31-
$(join(["- `$p`" for p in players(env)], "\n"))
30+
"""
31+
end
3232

33-
## Current Player
34-
`$(current_player(env))`
35-
"""))
33+
if get(io, :is_show_action_space, true)
34+
s *= """
35+
## Action Space
36+
`$(action_space(env))`
37+
38+
"""
3639
end
3740

38-
show(io, t, Markdown.parse("""
39-
## Is Environment Terminated?
40-
$(is_terminated(env) ? "Yes" : "No")
41+
if NumAgentStyle(env) !== SINGLE_AGENT
42+
s *= """
43+
## Players
44+
$(join(["- `$p`" for p in players(env)], "\n"))
4145
46+
## Current Player
47+
`$(current_player(env))`
48+
"""
49+
end
50+
51+
if get(io, :is_show_state, true)
52+
s *= """
4253
## Current State
4354
4455
```
4556
$(state(env))
4657
```
47-
"""))
58+
"""
59+
end
60+
61+
show(io, t, Markdown.parse(s))
4862
end
4963

5064
#####
@@ -57,48 +71,45 @@ using Test
5771
Call this function after writing your customized environment to make sure that
5872
all the necessary interfaces are implemented correctly and consistently.
5973
"""
60-
function test_interfaces(env)
61-
env = copy(env) # make sure we don't touch the original environment
62-
74+
function test_interfaces!(env)
6375
rng = Random.MersenneTwister(666)
6476

6577
@info "testing $(nameof(env)), you need to manually check these traits to make sure they are implemented correctly!" NumAgentStyle(env) DynamicStyle(env) ActionStyle(env) InformationStyle(env) StateStyle(env) RewardStyle(env) UtilityStyle(env) ChanceStyle(env)
6678

67-
reset!(env)
68-
6979
@testset "copy" begin
70-
old_env = env
71-
env = copy(env)
80+
X = copy(env)
81+
Y = copy(env)
82+
reset!(X)
83+
reset!(Y)
7284

73-
if ChanceStyle(env) ∉ (DETERMINISTIC, EXPLICIT_STOCHASTIC)
85+
if ChanceStyle(Y) ∉ (DETERMINISTIC, EXPLICIT_STOCHASTIC)
7486
s = 888
75-
Random.seed!(env, s)
76-
Random.seed!(old_env, s)
87+
Random.seed!(Y, s)
88+
Random.seed!(X, s)
7789
end
7890

79-
@test env !== old_env
91+
@test Y !== X
8092

81-
@test state(env) == state(old_env)
82-
@test action_space(env) == action_space(old_env)
83-
@test reward(env) == reward(old_env)
84-
@test is_terminated(env) == is_terminated(old_env)
93+
@test state(Y) == state(X)
94+
@test action_space(Y) == action_space(X)
95+
@test reward(Y) == reward(X)
96+
@test is_terminated(Y) == is_terminated(X)
8597

86-
while !is_terminated(env)
87-
A, A′ = legal_action_space(old_env), legal_action_space(env)
98+
while !is_terminated(Y)
99+
A, A′ = legal_action_space(X), legal_action_space(Y)
88100
@test A == A′
89101
a = rand(rng, A)
90-
env(a)
91-
old_env(a)
92-
@test state(env) == state(old_env)
93-
@test reward(env) == reward(old_env)
94-
@test is_terminated(env) == is_terminated(old_env)
102+
Y(a)
103+
X(a)
104+
@test state(Y) == state(X)
105+
@test reward(Y) == reward(X)
106+
@test is_terminated(Y) == is_terminated(X)
95107
end
96108
end
97109

98-
reset!(env)
99-
100110
@testset "SingleAgent" begin
101111
if NumAgentStyle(env) === SINGLE_AGENT
112+
reset!(env)
102113
total_reward = 0.
103114
while !is_terminated(env)
104115
if StateStyle(env) isa Tuple
@@ -170,6 +181,8 @@ function test_interfaces(env)
170181
end
171182
end
172183
end
184+
185+
reset!(env)
173186
end
174187

175188
#####
@@ -255,4 +268,81 @@ watch https://github.com/JuliaMath/IntervalSets.jl/issues/66
255268
"""
256269
function Base.in(x::AbstractArray, s::Array{<:Interval})
257270
size(x) == size(s) && all(x .∈ s)
258-
end
271+
end
272+
273+
Random.rand(s::Union{Interval, Array{<:Interval}}) = rand(Random.GLOBAL_RNG, s)
274+
275+
function Random.rand(rng::AbstractRNG, s::Interval)
276+
rand(rng) * (s.right - s.left) + s.left
277+
end
278+
279+
function Random.rand(rng::AbstractRNG, s::Array{<:Interval})
280+
map(x -> rand(rng, x), s)
281+
end
282+
283+
export WorldSpace
284+
285+
struct WorldSpace{T} end
286+
287+
WorldSpace() = WorldSpace{Any}()
288+
289+
Base.in(x, ::WorldSpace{T}) where T = x isa T
290+
291+
#####
292+
# ZeroTo
293+
#####
294+
295+
export ZeroTo
296+
297+
"""
298+
Similar to `Base.OneTo`. Useful when wrapping third-party environments.
299+
"""
300+
struct ZeroTo{T<:Integer} <: AbstractUnitRange{T}
301+
stop::T
302+
ZeroTo{T}(n) where {T<:Integer} = new(max(zero(T)-one(T),n))
303+
end
304+
305+
ZeroTo(n::T) where {T<:Integer} = ZeroTo{T}(n)
306+
307+
Base.show(io::IO, r::ZeroTo) = print(io, "ZeroTo(", r.stop, ")")
308+
Base.length(r::ZeroTo{T}) where T = T(r.stop + one(r.stop))
309+
Base.first(r::ZeroTo{T}) where T = zero(r.stop)
310+
311+
function getindex(v::ZeroTo{T}, i::Integer) where T
312+
Base.@_inline_meta
313+
@boundscheck ((i >= 0) & (i <= v.stop)) || throw_boundserror(v, i)
314+
convert(T, i)
315+
end
316+
317+
#####
318+
# ActionProbPair
319+
#####
320+
321+
export ActionProbPair
322+
323+
"""
324+
Used in action space of chance player.
325+
"""
326+
struct ActionProbPair{A,P}
327+
action::A
328+
prob::P
329+
end
330+
331+
"""
332+
Directly copied from [StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl/blob/0ea8e798c3d19609ed33b11311de5a2bd6ee9fd0/src/sampling.jl#L499-L510) to avoid depending on the whole package.
333+
Here we assume `wv` sums to `1`.
334+
"""
335+
function weighted_sample(rng::AbstractRNG, wv)
336+
t = rand(rng)
337+
cw = zero(Base.first(wv))
338+
for (i, w) in enumerate(wv)
339+
cw += w
340+
if cw >= t
341+
return i
342+
end
343+
end
344+
end
345+
346+
Random.rand(rng::AbstractRNG, s::AbstractVector{<:ActionProbPair}) = s[weighted_sample(rng, (x.prob for x in s))]
347+
348+
(env::AbstractEnv)(a::ActionProbPair) = env(a.action)

src/interface.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,13 @@ abstract type AbstractInformationStyle <: AbstractEnvStyle end
168168
@api const IMPERFECT_INFORMATION = ImperfectInformation()
169169

170170
"""
171-
InformationStyle(env) = PERFECT_INFORMATION
171+
InformationStyle(env) = IMPERFECT_INFORMATION
172172
173173
Distinguish environments between [`PERFECT_INFORMATION`](@ref) and
174-
[`IMPERFECT_INFORMATION`](@ref). [`PERFECT_INFORMATION`](@ref) is returned by default.
174+
[`IMPERFECT_INFORMATION`](@ref). [`IMPERFECT_INFORMATION`](@ref) is returned by default.
175175
"""
176176
@env_api InformationStyle(env::T) where {T<:AbstractEnv} = InformationStyle(T)
177-
InformationStyle(::Type{<:AbstractEnv}) = PERFECT_INFORMATION
177+
InformationStyle(::Type{<:AbstractEnv}) = IMPERFECT_INFORMATION
178178

179179
#####
180180
### ChanceStyle
@@ -391,7 +391,12 @@ const SPECTOR = Spector()
391391

392392
@api (env::AbstractEnv)(action, player = current_player(env))
393393

394-
"Make an independent copy of `env`"
394+
"""
395+
Make an independent copy of `env`.
396+
397+
!!! note
398+
rng (if `env` has one) is also copied!
399+
"""
395400
@api copy(env::AbstractEnv) = deepcopy(env)
396401
@api copyto!(dest::AbstractEnv, src::AbstractEnv)
397402

test/examples/kuhn_poker.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22

33
env = KuhnPokerEnv()
44

5-
RLBase.test_interfaces(env)
5+
RLBase.test_interfaces!(env)
66

77
end

test/examples/monty_hall_problem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
rng = StableRNG(123)
44
env = MontyHallEnv(;rng=rng)
55

6-
RLBase.test_interfaces(env)
6+
RLBase.test_interfaces!(env)
77

88
n_win_car = 0
99
N = 50_000

test/examples/multi_arm_bandits.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ rng = StableRNG(123)
44
env = MultiArmBanditsEnv(;rng=rng)
55
rewards = []
66

7-
RLBase.test_interfaces(env)
7+
RLBase.test_interfaces!(env)
88

99
N = 50_000
1010
for _ in 1:N

test/examples/pig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
@testset "PigEnv" begin
22
env = PigEnv()
3-
RLBase.test_interfaces(env)
3+
RLBase.test_interfaces!(env)
44
end

test/examples/random_walk_1d.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
end_rewards = 3 => 5
44
env = RandomWalk1D(;rewards=end_rewards)
55

6-
RLBase.test_interfaces(env)
6+
RLBase.test_interfaces!(env)
77

88
rng = StableRNG(123)
99
N = 50_000

test/examples/rock_paper_scissors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
rng = StableRNG(123)
44
env = RockPaperScissorsEnv()
55

6-
RLBase.test_interfaces(env)
6+
RLBase.test_interfaces!(env)
77

88
rewards = [[],[]]
99
for _ in 1:50_000

test/examples/tic_tac_toe.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
env = TicTacToeEnv()
44

5-
RLBase.test_interfaces(env)
5+
RLBase.test_interfaces!(env)
66

77
@test length(state_space(env, Observation{Int}())) == 5478
88

test/examples/tiger_problem_env.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ rng = StableRNG(123)
44
obs_prob = 0.85
55
env = TigerProblemEnv(;rng=rng, obs_prob=obs_prob)
66

7-
RLBase.test_interfaces(env)
7+
RLBase.test_interfaces!(env)
88

99
rewards = []
1010
for _ in 1:50_000

test/examples/tiny_hanabi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22

33
env = TinyHanabiEnv()
44

5-
RLBase.test_interfaces(env)
5+
RLBase.test_interfaces!(env)
66

77
end

0 commit comments

Comments
 (0)