Skip to content
This repository was archived by the owner on Aug 11, 2023. It is now read-only.

Minor fixes #110

Merged
merged 5 commits into from
Dec 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 179 additions & 45 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,56 @@ Base.show(io::IO, t::MIME"text/plain", env::AbstractEnv) =
show(io, MIME"text/markdown"(), env)

function Base.show(io::IO, t::MIME"text/markdown", env::AbstractEnv)
show(io, t, Markdown.parse("""
s = """
# $(nameof(env))

## Traits
| Trait Type | Value |
|:---------- | ----- |
$(join(["|$(string(f))|$(f(env))|" for f in env_traits()], "\n"))

## Action Space
`$(action_space(env))`
## Is Environment Terminated?
$(is_terminated(env) ? "Yes" : "No")

## State Space
`$(state_space(env))`
"""

"""))
if get(io, :is_show_state_space, true)
s *= """
## State Space
`$(state_space(env))`

if NumAgentStyle(env) !== SINGLE_AGENT
show(io, t, Markdown.parse("""
## Players
$(join(["- `$p`" for p in players(env)], "\n"))
"""
end

if get(io, :is_show_action_space, true)
s *= """
## Action Space
`$(action_space(env))`

## Current Player
`$(current_player(env))`
"""))
"""
end

show(io, t, Markdown.parse("""
## Is Environment Terminated?
$(is_terminated(env) ? "Yes" : "No")
if NumAgentStyle(env) !== SINGLE_AGENT
s *= """
## Players
$(join(["- `$p`" for p in players(env)], "\n"))

## Current Player
`$(current_player(env))`
"""
end

if get(io, :is_show_state, true)
s *= """
## Current State

```
$(state(env))
```
"""))
"""
end

show(io, t, Markdown.parse(s))
end

#####
Expand All @@ -58,9 +72,7 @@ using Test
Call this function after writing your customized environment to make sure that
all the necessary interfaces are implemented correctly and consistently.
"""
function test_interfaces(env)
env = copy(env) # make sure we don't touch the original environment

function test_interfaces!(env)
rng = Random.MersenneTwister(666)

@info "testing $(nameof(env)), you need to manually check these traits to make sure they are implemented correctly!" NumAgentStyle(
Expand All @@ -69,42 +81,41 @@ function test_interfaces(env)
env,
) UtilityStyle(env) ChanceStyle(env)

reset!(env)

@testset "copy" begin
old_env = env
env = copy(env)
X = copy(env)
Y = copy(env)
reset!(X)
reset!(Y)

if ChanceStyle(env) ∉ (DETERMINISTIC, EXPLICIT_STOCHASTIC)
if ChanceStyle(Y) ∉ (DETERMINISTIC, EXPLICIT_STOCHASTIC)
s = 888
Random.seed!(env, s)
Random.seed!(old_env, s)
Random.seed!(Y, s)
Random.seed!(X, s)
end

@test env !== old_env
@test Y !== X

@test state(env) == state(old_env)
@test action_space(env) == action_space(old_env)
@test reward(env) == reward(old_env)
@test is_terminated(env) == is_terminated(old_env)
@test state(Y) == state(X)
@test action_space(Y) == action_space(X)
@test reward(Y) == reward(X)
@test is_terminated(Y) == is_terminated(X)

while !is_terminated(env)
A, A′ = legal_action_space(old_env), legal_action_space(env)
while !is_terminated(Y)
A, A′ = legal_action_space(X), legal_action_space(Y)
@test A == A′
a = rand(rng, A)
env(a)
old_env(a)
@test state(env) == state(old_env)
@test reward(env) == reward(old_env)
@test is_terminated(env) == is_terminated(old_env)
Y(a)
X(a)
@test state(Y) == state(X)
@test reward(Y) == reward(X)
@test is_terminated(Y) == is_terminated(X)
end
end

reset!(env)

@testset "SingleAgent" begin
if NumAgentStyle(env) === SINGLE_AGENT
total_reward = 0.0
reset!(env)
total_reward = 0.
while !is_terminated(env)
if StateStyle(env) isa Tuple
for ss in StateStyle(env)
Expand Down Expand Up @@ -176,6 +187,8 @@ function test_interfaces(env)
end
end
end

reset!(env)
end

#####
Expand Down Expand Up @@ -259,9 +272,130 @@ end

using IntervalSets

# Sampling an `Interval` (or an array of intervals) without an explicit RNG
# falls back to the global one.
function Random.rand(s::Union{Interval,Array{<:Interval}})
    return rand(Random.GLOBAL_RNG, s)
end

# Draw a uniform sample from `s` by rescaling a standard uniform draw onto
# `[s.left, s.right]`.
function Random.rand(rng::AbstractRNG, s::Interval)
    return s.left + (s.right - s.left) * rand(rng)
end

#####
# WorldSpace
#####

export WorldSpace

"""
A catch-all space. When the concrete action/state space is of no interest,
return `WorldSpace()` so the interface stays consistent; membership is then
just an `isa` check against the type parameter (`Any` by default).
"""
struct WorldSpace{T} end

WorldSpace() = WorldSpace{Any}()

Base.in(x, ::WorldSpace{T}) where {T} = isa(x, T)

#####
# ZeroTo
#####

export ZeroTo

"""
watch https://github.com/JuliaMath/IntervalSets.jl/issues/66
Similar to `Base.OneTo`. Useful when wrapping third-party environments.
"""
function Base.in(x::AbstractArray, s::Array{<:Interval})
size(x) == size(s) && all(x .∈ s)
"""
    ZeroTo(n)

A unit range `0:n`, analogous to `Base.OneTo`. Useful when wrapping
third-party environments with zero-based action/state indices.
"""
struct ZeroTo{T<:Integer} <: AbstractUnitRange{T}
    stop::T
    # Clamp `stop` below at -1 so that a negative argument yields an empty
    # range, mirroring how `Base.OneTo` clamps at 0.
    ZeroTo{T}(n) where {T<:Integer} = new(max(-one(T), n))
end

ZeroTo(n::T) where {T<:Integer} = ZeroTo{T}(n)

# Render like `Base.OneTo`, e.g. `ZeroTo(4)`.
Base.show(io::IO, r::ZeroTo) = print(io, "ZeroTo($(r.stop))")
# `0:stop` contains `stop + 1` elements.
Base.length(r::ZeroTo{T}) where {T} = convert(T, r.stop + one(T))
# The range always starts at zero.
Base.first(::ZeroTo{T}) where {T} = zero(T)

# Extend `Base.getindex` — qualified, consistent with the `Base.show`/
# `Base.length`/`Base.first` methods above. Written unqualified (and with
# no visible `import Base: getindex`), this would define a new module-local
# function instead, so `v[i]` would never dispatch here; likewise
# `throw_boundserror` is a non-exported `Base` internal and must be
# qualified or it raises `UndefVarError` on a failed bounds check.
function Base.getindex(v::ZeroTo{T}, i::Integer) where T
    Base.@_inline_meta
    @boundscheck ((i >= 0) & (i <= v.stop)) || Base.throw_boundserror(v, i)
    convert(T, i)
end

#####
# ActionProbPair
#####

export ActionProbPair

"""
Used in action space of chance player.
"""
struct ActionProbPair{A,P}
action::A
prob::P
end

"""
Directly copied from [StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl/blob/0ea8e798c3d19609ed33b11311de5a2bd6ee9fd0/src/sampling.jl#L499-L510) to avoid depending on the whole package.
Here we assume `wv` sum to `1`
"""
function weighted_sample(rng::AbstractRNG, wv)
t = rand(rng)
cw = zero(Base.first(wv))
for (i, w) in enumerate(wv)
cw += w
if cw >= t
return i
end
end
end

# Sampling from a chance player's action space: each element is picked with
# its own `prob`.
function Random.rand(rng::AbstractRNG, s::AbstractVector{<:ActionProbPair})
    i = weighted_sample(rng, (p.prob for p in s))
    return s[i]
end

# Applying an `ActionProbPair` to an environment means taking its action.
(env::AbstractEnv)(pair::ActionProbPair) = env(pair.action)

#####
# Space
#####

export Space

"""
    Space(s)

A thin wrapper that treats each element of `s` as a sub-space, so the whole
container supports `Random.rand` (element-wise sampling) and `Base.in`
(element-wise membership).
"""
struct Space{T}
    s::T
end

# Sampling without an explicit RNG falls back to the global one.
Random.rand(s::Space) = rand(Random.GLOBAL_RNG, s)

# Sample every sub-space independently, preserving the container's shape.
Random.rand(rng::AbstractRNG, s::Space) = map(x -> rand(rng, x), s.s)

# For `Dict`-based spaces, sample a value for each key.
Random.rand(rng::AbstractRNG, s::Space{<:Dict}) =
    Dict(k => rand(rng, v) for (k, v) in s.s)

# `X ∈ S` iff the containers line up element-wise and each element lies in
# its corresponding sub-space.
Base.in(X, S::Space) =
    length(X) == length(S.s) && all(x ∈ s for (x, s) in zip(X, S.s))

# `Dict` membership additionally requires the key sets to match exactly.
Base.in(X::Dict, S::Space{<:Dict}) =
    keys(X) == keys(S.s) && all(X[k] ∈ S.s[k] for k in keys(X))
2 changes: 1 addition & 1 deletion src/examples/PigEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ prob(env::PigEnv, ::ChancePlayer) = fill(1 / 6, 6) # TODO: uniform distribution

state(env::PigEnv, ::Observation{Vector{Int}}, p) = env.scores
state_space(env::PigEnv, ::Observation, p) =
[0..(PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores]
Space([0..(PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores])

is_terminated(env::PigEnv) = any(s >= PIG_TARGET_SCORE for s in env.scores)

Expand Down
2 changes: 1 addition & 1 deletion src/examples/TicTacToeEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ current_player(env::TicTacToeEnv) = env.player
players(env::TicTacToeEnv) = (CROSS, NOUGHT)

state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board
state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = fill(false..true, 3, 3, 3)
state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = Space(fill(false..true, 3, 3, 3))
state(env::TicTacToeEnv, ::Observation{Int}, p) = get_tic_tac_toe_state_info()[env].index
state_space(env::TicTacToeEnv, ::Observation{Int}, p) =
Base.OneTo(length(get_tic_tac_toe_state_info()))
Expand Down
13 changes: 9 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,13 @@ abstract type AbstractInformationStyle <: AbstractEnvStyle end
@api const IMPERFECT_INFORMATION = ImperfectInformation()

"""
InformationStyle(env) = PERFECT_INFORMATION
InformationStyle(env) = IMPERFECT_INFORMATION

Distinguish environments between [`PERFECT_INFORMATION`](@ref) and
[`IMPERFECT_INFORMATION`](@ref). [`PERFECT_INFORMATION`](@ref) is returned by default.
[`IMPERFECT_INFORMATION`](@ref). [`IMPERFECT_INFORMATION`](@ref) is returned by default.
"""
@env_api InformationStyle(env::T) where {T<:AbstractEnv} = InformationStyle(T)
InformationStyle(::Type{<:AbstractEnv}) = PERFECT_INFORMATION
InformationStyle(::Type{<:AbstractEnv}) = IMPERFECT_INFORMATION

#####
### ChanceStyle
Expand Down Expand Up @@ -391,7 +391,12 @@ const SPECTOR = Spector()

@api (env::AbstractEnv)(action, player = current_player(env))

"Make an independent copy of `env`"
"""
Make an independent copy of `env`,

!!! note
rng (if `env` has) is also copied!
"""
@api copy(env::AbstractEnv) = deepcopy(env)
@api copyto!(dest::AbstractEnv, src::AbstractEnv)

Expand Down
2 changes: 1 addition & 1 deletion test/examples/kuhn_poker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

env = KuhnPokerEnv()

RLBase.test_interfaces(env)
RLBase.test_interfaces!(env)

end
2 changes: 1 addition & 1 deletion test/examples/monty_hall_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
rng = StableRNG(123)
env = MontyHallEnv(; rng = rng)

RLBase.test_interfaces(env)
RLBase.test_interfaces!(env)

n_win_car = 0
N = 50_000
Expand Down
2 changes: 1 addition & 1 deletion test/examples/multi_arm_bandits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
env = MultiArmBanditsEnv(; rng = rng)
rewards = []

RLBase.test_interfaces(env)
RLBase.test_interfaces!(env)

N = 50_000
for _ in 1:N
Expand Down
2 changes: 1 addition & 1 deletion test/examples/pig.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "PigEnv" begin
env = PigEnv()
RLBase.test_interfaces(env)
RLBase.test_interfaces!(env)
end
2 changes: 1 addition & 1 deletion test/examples/random_walk_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
end_rewards = 3 => 5
env = RandomWalk1D(; rewards = end_rewards)

RLBase.test_interfaces(env)
RLBase.test_interfaces!(env)

rng = StableRNG(123)
N = 50_000
Expand Down
2 changes: 1 addition & 1 deletion test/examples/rock_paper_scissors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
rng = StableRNG(123)
env = RockPaperScissorsEnv()

RLBase.test_interfaces(env)
RLBase.test_interfaces!(env)

rewards = [[], []]
for _ in 1:50_000
Expand Down
Loading