This repository was archived by the owner on May 6, 2021. It is now read-only.

Add converters #109

Merged
4 commits merged on Dec 20, 2020
44 changes: 44 additions & 0 deletions src/converters.jl
@@ -1 +1,45 @@
export is_discrete_space, discrete2standard_discrete

is_discrete_space(x) = is_discrete_space(typeof(x))

is_discrete_space(::Type{<:AbstractVector}) = true
is_discrete_space(::Type{<:Tuple}) = true
is_discrete_space(::Type{<:NamedTuple}) = true

is_discrete_space(::Type{<:Space}) = false

"""
discrete2standard_discrete(env)

Convert an `env` with a discrete action space to a standard form:

- The action space is of type `Base.OneTo`
- If the `env` is of `FULL_ACTION_SET`, then each action in the
`legal_action_space(env)` is also an `Int` in the action space.

The standard form is useful for some algorithms (like Q-learning).
"""
function discrete2standard_discrete(env::AbstractEnv)
A = action_space(env)
if is_discrete_space(A)
AS = ActionStyle(env)
if AS === FULL_ACTION_SET
mapping = Dict(x => i for (i,x) in enumerate(A))
ActionTransformedEnv(
env;
action_space_mapping = a -> map(x -> mapping[x], a),
action_mapping = i -> A[i]
)
elseif AS === MINIMAL_ACTION_SET
ActionTransformedEnv(
env;
action_space_mapping = x -> Base.OneTo(length(A)),
action_mapping = i -> A[i]
)
else
@error "unknown ActionStyle $AS"
end
else
throw(ArgumentError("unrecognized action space: $A"))
end
end
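
For context, a minimal sketch (not part of the diff) of the two mappings built above, using a plain vector as a stand-in for a discrete action space:

A = [:left, :stay, :right]                        # stand-in discrete action space
mapping = Dict(x => i for (i, x) in enumerate(A)) # action => Int, as in the FULL_ACTION_SET branch
Base.OneTo(length(A))                             # the standardized action space, 1:3
mapping[:right]                                   # 3, the forward action_space_mapping
A[3]                                              # :right, the inverse action_mapping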
14 changes: 10 additions & 4 deletions src/environments/3rd_party/open_spiel.jl
@@ -16,7 +16,6 @@ import .OpenSpiel:
num_distinct_actions,
num_players,
apply_action,
current_player,
player_reward,
legal_actions,
legal_actions_mask,
@@ -29,7 +28,6 @@
chance_outcomes,
max_chance_outcomes,
utility
using StatsBase: sample, weights


"""
@@ -56,9 +54,17 @@ RLBase.reset!(env::OpenSpielEnv) = env.state = new_initial_state(env.game)

(env::OpenSpielEnv)(action::Int) = apply_action(env.state, action)

RLBase.current_player(env::OpenSpielEnv) = current_player(env.state)
RLBase.current_player(env::OpenSpielEnv) = OpenSpiel.current_player(env.state)
RLBase.chance_player(env::OpenSpielEnv) = convert(Int, OpenSpiel.CHANCE_PLAYER)
RLBase.players(env::OpenSpielEnv) = 0:(num_players(env.game)-1)

function RLBase.players(env::OpenSpielEnv)
p = 0:(num_players(env.game)-1)
if ChanceStyle(env) === EXPLICIT_STOCHASTIC
(p..., RLBase.chance_player(env))
else
p
end
end

function RLBase.action_space(env::OpenSpielEnv, player)
if player == chance_player(env)
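Note on the `players` change above: for an `EXPLICIT_STOCHASTIC` game the chance player is now reported alongside the regular players. A rough sketch of the expected result, assuming the wrapper's constructor accepts an OpenSpiel game name and that `OpenSpiel.CHANCE_PLAYER` converts to `-1` (neither verified here):

env = OpenSpielEnv("kuhn_poker")  # Kuhn poker has explicit chance nodes for dealing cards
players(env)                      # expected: (0, 1, -1), both players plus the chance player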
8 changes: 6 additions & 2 deletions src/environments/examples/KuhnPokerEnv.jl
@@ -87,7 +87,7 @@ end
RLBase.is_terminated(env::KuhnPokerEnv) =
length(env.actions) == 2 && (env.actions[1] == :bet || env.actions[2] == :pass) ||
length(env.actions) == 3
RLBase.players(env::KuhnPokerEnv) = 1:2
RLBase.players(env::KuhnPokerEnv) = (1, 2, CHANCE_PLAYER)

function RLBase.state(env::KuhnPokerEnv, ::InformationSet{Tuple{Vararg{Symbol}}}, p::Int)
if length(env.cards) >= p
@@ -120,7 +120,9 @@ function RLBase.prob(env::KuhnPokerEnv, ::ChancePlayer)
fill(1 / 3, 3)
elseif length(env.cards) == 1
p = fill(1 / 2, 3)
p[env.cards[1]] = 0
i = findfirst(==(env.cards[1]), KUHN_POKER_CARDS)
p[i] = 0
p
else
@error "it's not chance player's turn!"
end
@@ -131,6 +133,8 @@ end
(env::KuhnPokerEnv)(action::Symbol, ::ChancePlayer) = push!(env.cards, action)
(env::KuhnPokerEnv)(action::Symbol, ::Int) = push!(env.actions, action)

RLBase.reward(::KuhnPokerEnv, ::ChancePlayer) = 0

function RLBase.reward(env::KuhnPokerEnv, p)
if is_terminated(env)
v = KUHN_POKER_REWARD_TABLE[(env.cards..., env.actions...)]
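A note on the `prob` fix above: `env.cards` stores card symbols rather than indices, so the old `p[env.cards[1]] = 0` could not index the probability vector directly. A minimal sketch, assuming `KUHN_POKER_CARDS` is a tuple of symbols such as `(:J, :Q, :K)` (inferred from the surrounding code, not verified):

cards = (:J, :Q, :K)          # stand-in for KUHN_POKER_CARDS
p = fill(1 / 2, 3)
i = findfirst(==(:Q), cards)  # 2; p[:Q] would throw, so look up the card's index first
p[i] = 0                      # the card already dealt gets probability zero
p                             # [0.5, 0.0, 0.5]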
146 changes: 0 additions & 146 deletions src/environments/wrappers/MultiThreadEnv.jl

This file was deleted.

49 changes: 49 additions & 0 deletions src/environments/wrappers/StochasticEnv.jl
@@ -0,0 +1,49 @@
export StochasticEnv

using StatsBase: sample, Weights

struct StochasticEnv{E<:AbstractEnv,R} <: AbstractEnv
env::E
rng::R
end

function StochasticEnv(env;rng=Random.GLOBAL_RNG)
ChanceStyle(env) === EXPLICIT_STOCHASTIC || throw(ArgumentError("only environments of EXPLICIT_STOCHASTIC style are supported"))
env = StochasticEnv(env,rng)
reset!(env)
env
end

function RLBase.reset!(env::StochasticEnv)
reset!(env.env)
while current_player(env.env) == chance_player(env.env)
p = prob(env.env)
A = action_space(env.env)
x = A[sample(env.rng, Weights(p, 1.0))]
env.env(x)
end
end

function (env::StochasticEnv)(a)
env.env(a)
while current_player(env.env) == chance_player(env.env)
p = prob(env.env)
A = action_space(env.env)
x = A[sample(env.rng, Weights(p, 1.0))]
env.env(x)
end
end

RLBase.ChanceStyle(::StochasticEnv) = STOCHASTIC
RLBase.players(env::StochasticEnv) = [p for p in players(env.env) if p != chance_player(env.env)]
Random.seed!(env::StochasticEnv, s) = Random.seed!(env.rng, s)

for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
if f ∉ (:players, :ChanceStyle, :reset!)
@eval RLBase.$f(x::StochasticEnv, args...; kwargs...) =
$f(x.env, args...; kwargs...)
end
end

RLBase.state(env::StochasticEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
RLBase.state_space(env::StochasticEnv, ss::RLBase.AbstractStateStyle) = state_space(env.env, ss)
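
A possible usage sketch (mirroring the new test below), assuming `KuhnPokerEnv` is available as patched above; the wrapper resolves chance nodes by sampling with its own rng, so callers only ever interact with the regular players:

env = StochasticEnv(KuhnPokerEnv())  # both cards are dealt automatically inside reset!
players(env)                         # [1, 2]; the chance player is filtered out
current_player(env)                  # never the chance player after reset! or a step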
2 changes: 1 addition & 1 deletion src/environments/wrappers/wrappers.jl
@@ -5,7 +5,7 @@ Base.nameof(env::AbstractEnvWrapper) = "$(nameof(env.env)) |> $(nameof(typeof(en
include("ActionTransformedEnv.jl")
include("DefaultStateStyle.jl")
include("MaxTimeoutEnv.jl")
include("MultiThreadEnv.jl")
include("RewardOverriddenEnv.jl")
include("StateCachedEnv.jl")
include("StateOverriddenEnv.jl")
include("StochasticEnv.jl")
30 changes: 9 additions & 21 deletions test/environments/wrappers/wrappers.jl
@@ -39,27 +39,6 @@
end
end

@testset "MultiThreadEnv" begin
rng = StableRNG(123)
env = MultiThreadEnv(4) do
AtariEnv("pong")
end

reset!(env)
n = 1_000
for _ in 1:n
A = legal_action_space(env)
a = rand(rng, A)
@test a in A

S = state_space(env)
s = state(env)
@test s in S
env(a)
reset!(env)
end
end

@testset "RewardOverriddenEnv" begin
rng = StableRNG(123)
env = TigerProblemEnv(; rng = rng)
@@ -101,4 +80,13 @@
# RLBase.test_runnable!(env′)
end

@testset "StochasticEnv" begin
env = KuhnPokerEnv()
rng = StableRNG(123)
env′ = StochasticEnv(env;rng=rng)

RLBase.test_interfaces!(env′)
RLBase.test_runnable!(env′)
end

end