Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 15d2740

Browse files
authored
Add converters (#109)
* add a converter * remove unused function names * fix tests * minor fix with ChancePlayer in OpenSpiel
1 parent 812c42b commit 15d2740

File tree

7 files changed

+119
-174
lines changed

7 files changed

+119
-174
lines changed

src/converters.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,45 @@
1+
export is_discrete_space, discrete2standard_discrete
12

3+
"""
    is_discrete_space(x)

Return `true` if `x` is treated as a discrete space, `false` otherwise.

`x` may be a space instance or a space type; an instance is classified by its type.
Any `AbstractVector`, `Tuple`, or `NamedTuple` counts as a discrete space. Any
other type is reported as not discrete unless a more specific method says otherwise.
"""
is_discrete_space(x) = is_discrete_space(typeof(x))

# Terminating fallback for types that have no more specific method. Without it the
# generic method above would keep calling itself on `typeof(T)` (a `DataType`)
# until the stack overflows.
is_discrete_space(::Type) = false

# `Type{<:T}` rather than `Type{T}`: type parameters are invariant, so `Type{T}`
# would match only the type object `T` itself and never a concrete subtype such
# as `Vector{Int}`, `Base.OneTo{Int}`, or `Tuple{Symbol,Symbol}`.
is_discrete_space(::Type{<:AbstractVector}) = true
is_discrete_space(::Type{<:Tuple}) = true
is_discrete_space(::Type{<:NamedTuple}) = true
8+
9+
# `Space` is explicitly not a discrete space. `Type{<:Space}` is used instead of
# `Type{Space}` so that any parameterization or subtype of `Space` is matched as
# well, not only the bare `Space` type object.
is_discrete_space(::Type{<:Space}) = false
10+
11+
"""
    discrete2standard_discrete(env)

Convert an `env` with a discrete action space to a standard form:

- The action space is of type `Base.OneTo`
- If the `env` is of `FULL_ACTION_SET`, then each action in the
  `legal_action_space(env)` is also an `Int` in the action space.

The standard form is useful for some algorithms (like Q-learning).

Throws an `ArgumentError` if `action_space(env)` is not recognized by
`is_discrete_space`, or if `ActionStyle(env)` is neither `FULL_ACTION_SET` nor
`MINIMAL_ACTION_SET`.
"""
function discrete2standard_discrete(env::AbstractEnv)
    A = action_space(env)
    is_discrete_space(A) || throw(ArgumentError("unrecognized action space: $A"))

    AS = ActionStyle(env)
    if AS === FULL_ACTION_SET
        # Lookup table: original action => its 1-based position in `A`.
        mapping = Dict(x => i for (i, x) in enumerate(A))
        return ActionTransformedEnv(
            env;
            action_space_mapping = a -> map(x -> mapping[x], a),
            action_mapping = i -> A[i],
        )
    elseif AS === MINIMAL_ACTION_SET
        return ActionTransformedEnv(
            env;
            action_space_mapping = x -> Base.OneTo(length(A)),
            action_mapping = i -> A[i],
        )
    else
        # Fail loudly instead of logging and handing `nothing` back to the caller.
        throw(ArgumentError("unknown ActionStyle $AS"))
    end
end

src/environments/3rd_party/open_spiel.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import .OpenSpiel:
1616
num_distinct_actions,
1717
num_players,
1818
apply_action,
19-
current_player,
2019
player_reward,
2120
legal_actions,
2221
legal_actions_mask,
@@ -29,7 +28,6 @@ import .OpenSpiel:
2928
chance_outcomes,
3029
max_chance_outcomes,
3130
utility
32-
using StatsBase: sample, weights
3331

3432

3533
"""
@@ -56,9 +54,17 @@ RLBase.reset!(env::OpenSpielEnv) = env.state = new_initial_state(env.game)
5654

5755
# Calling the environment with an integer action applies that action to the
# OpenSpiel state held in `env.state`.
function (env::OpenSpielEnv)(action::Int)
    return apply_action(env.state, action)
end
5856

59-
RLBase.current_player(env::OpenSpielEnv) = current_player(env.state)
57+
# The player to move is read from the OpenSpiel state. The call is qualified with
# `OpenSpiel.` so it cannot be confused with `RLBase.current_player`.
function RLBase.current_player(env::OpenSpielEnv)
    return OpenSpiel.current_player(env.state)
end

# OpenSpiel's chance-player constant, converted to an `Int`.
function RLBase.chance_player(env::OpenSpielEnv)
    return convert(Int, OpenSpiel.CHANCE_PLAYER)
end
61-
RLBase.players(env::OpenSpielEnv) = 0:(num_players(env.game)-1)
59+
60+
# Player ids run from 0 to `num_players(env.game) - 1`. For environments whose
# chance style is `EXPLICIT_STOCHASTIC` the chance player is appended, in which
# case a tuple is returned instead of the range.
function RLBase.players(env::OpenSpielEnv)
    regular = 0:(num_players(env.game)-1)
    ChanceStyle(env) === EXPLICIT_STOCHASTIC || return regular
    return (regular..., RLBase.chance_player(env))
end
6268

6369
function RLBase.action_space(env::OpenSpielEnv, player)
6470
if player == chance_player(env)

src/environments/examples/KuhnPokerEnv.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ end
8787
# The game is over after three recorded actions, or after two when the first was
# `:bet` or the second was `:pass`.
function RLBase.is_terminated(env::KuhnPokerEnv)
    n = length(env.actions)
    if n == 2
        return env.actions[1] == :bet || env.actions[2] == :pass
    end
    return n == 3
end
90-
RLBase.players(env::KuhnPokerEnv) = 1:2
90+
# Players 1 and 2, followed by the chance player.
function RLBase.players(env::KuhnPokerEnv)
    return (1, 2, CHANCE_PLAYER)
end
9191

9292
function RLBase.state(env::KuhnPokerEnv, ::InformationSet{Tuple{Vararg{Symbol}}}, p::Int)
9393
if length(env.cards) >= p
@@ -120,7 +120,9 @@ function RLBase.prob(env::KuhnPokerEnv, ::ChancePlayer)
120120
fill(1 / 3, 3)
121121
elseif length(env.cards) == 1
122122
p = fill(1 / 2, 3)
123-
p[env.cards[1]] = 0
123+
i = findfirst(==(env.cards[1]), KUHN_POKER_CARDS)
124+
p[i] = 0
125+
p
124126
else
125127
@error "it's not chance player's turn!"
126128
end
@@ -131,6 +133,8 @@ end
131133
# An action taken by the chance player is appended to `env.cards`.
function (env::KuhnPokerEnv)(action::Symbol, ::ChancePlayer)
    return push!(env.cards, action)
end

# An action taken by an integer-indexed player is appended to `env.actions`.
function (env::KuhnPokerEnv)(action::Symbol, ::Int)
    return push!(env.actions, action)
end
133135

136+
# The chance player always receives a reward of zero.
function RLBase.reward(::KuhnPokerEnv, ::ChancePlayer)
    return 0
end
137+
134138
function RLBase.reward(env::KuhnPokerEnv, p)
135139
if is_terminated(env)
136140
v = KUHN_POKER_REWARD_TABLE[(env.cards..., env.actions...)]

src/environments/wrappers/MultiThreadEnv.jl

Lines changed: 0 additions & 146 deletions
This file was deleted.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
export StochasticEnv
2+
3+
using StatsBase: sample, Weights
4+
5+
"""
    StochasticEnv(env; rng=Random.GLOBAL_RNG)

Wrap an `env` whose `ChanceStyle` is `EXPLICIT_STOCHASTIC` together with the
random number generator `rng`. The keyword constructor resets the wrapper
before returning it.

Throws an `ArgumentError` if `ChanceStyle(env)` is not `EXPLICIT_STOCHASTIC`.
"""
struct StochasticEnv{E<:AbstractEnv,R} <: AbstractEnv
    env::E
    rng::R
end

function StochasticEnv(env; rng=Random.GLOBAL_RNG)
    if ChanceStyle(env) !== EXPLICIT_STOCHASTIC
        throw(ArgumentError("only environments of EXPLICIT_STOCHASTIC style is supported"))
    end
    wrapped = StochasticEnv(env, rng)
    reset!(wrapped)
    return wrapped
end
16+
17+
# While it is the chance player's turn in the wrapped environment, sample one of
# its actions according to `prob` and apply it. Shared by `reset!` and by the
# call method so both resolve chance moves in exactly the same way.
function _resolve_chance_nodes!(env::StochasticEnv)
    inner = env.env
    while current_player(inner) == chance_player(inner)
        p = prob(inner)
        A = action_space(inner)
        # `Weights(p, 1.0)` passes 1.0 as the precomputed sum of the weights,
        # i.e. it presumes `p` is already normalized.
        x = A[sample(env.rng, Weights(p, 1.0))]
        inner(x)
    end
    return nothing
end

"""
    reset!(env::StochasticEnv)

Reset the wrapped environment, then play out any chance-player moves by
sampling from `env.rng`. Returns `nothing`.
"""
function RLBase.reset!(env::StochasticEnv)
    reset!(env.env)
    return _resolve_chance_nodes!(env)
end

# Apply action `a` to the wrapped environment, then play out any chance-player
# moves that follow. Returns `nothing`.
function (env::StochasticEnv)(a)
    env.env(a)
    return _resolve_chance_nodes!(env)
end
36+
37+
# The wrapper always reports the `STOCHASTIC` chance style.
function RLBase.ChanceStyle(::StochasticEnv)
    return STOCHASTIC
end

# Players of the wrapped environment with its chance player left out.
function RLBase.players(env::StochasticEnv)
    inner = env.env
    return [p for p in players(inner) if p != chance_player(inner)]
end

# Seeding the wrapper seeds the RNG stored in `env.rng`.
function Random.seed!(env::StochasticEnv, s)
    return Random.seed!(env.rng, s)
end
40+
41+
# Forward every function named in `RLBase.ENV_API` and
# `RLBase.MULTI_AGENT_ENV_API` to the wrapped environment, except `players`,
# `ChanceStyle`, and `reset!`, which have dedicated `StochasticEnv` methods.
# NOTE(review): the entries are presumably `Symbol`s, since they are compared
# against symbols and spliced into the method name with `\$f` — confirm in RLBase.
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    if f ∉ (:players, :ChanceStyle, :reset!)
        @eval RLBase.$f(x::StochasticEnv, args...; kwargs...) =
            $f(x.env, args...; kwargs...)
    end
end
47+
48+
# Explicit state-style methods that hand the query to the wrapped environment.
# NOTE(review): presumably needed to avoid a dispatch ambiguity with the generic
# forwarding methods — confirm against RLBase.
function RLBase.state(env::StochasticEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state_space(env::StochasticEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end

src/environments/wrappers/wrappers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Base.nameof(env::AbstractEnvWrapper) = "$(nameof(env.env)) |> $(nameof(typeof(en
55
include("ActionTransformedEnv.jl")
66
include("DefaultStateStyle.jl")
77
include("MaxTimeoutEnv.jl")
8-
include("MultiThreadEnv.jl")
98
include("RewardOverriddenEnv.jl")
109
include("StateCachedEnv.jl")
1110
include("StateOverriddenEnv.jl")
11+
include("StochasticEnv.jl")

test/environments/wrappers/wrappers.jl

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -39,27 +39,6 @@
3939
end
4040
end
4141

42-
@testset "MultiThreadEnv" begin
43-
rng = StableRNG(123)
44-
env = MultiThreadEnv(4) do
45-
AtariEnv("pong")
46-
end
47-
48-
reset!(env)
49-
n = 1_000
50-
for _ in 1:n
51-
A = legal_action_space(env)
52-
a = rand(rng, A)
53-
@test a in A
54-
55-
S = state_space(env)
56-
s = state(env)
57-
@test s in S
58-
env(a)
59-
reset!(env)
60-
end
61-
end
62-
6342
@testset "RewardOverriddenEnv" begin
6443
rng = StableRNG(123)
6544
env = TigerProblemEnv(; rng = rng)
@@ -101,4 +80,13 @@
10180
# RLBase.test_runnable!(env′)
10281
end
10382

83+
@testset "StochasticEnv" begin
    # Wrap a Kuhn poker environment with a seeded RNG, then run the generic
    # RLBase interface and runnability checks on the wrapper.
    seeded = StableRNG(123)
    base = KuhnPokerEnv()
    wrapped = StochasticEnv(base; rng = seeded)

    RLBase.test_interfaces!(wrapped)
    RLBase.test_runnable!(wrapped)
end
91+
10492
end

0 commit comments

Comments
 (0)