Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 9e02182

Browse files
authored
Simplify env wrapper (#127)
* simplify the definition of environment wrappers
* simplify further
1 parent 9f50b34 commit 9e02182

File tree

9 files changed

+40
-119
lines changed

9 files changed

+40
-119
lines changed

src/environments/examples/KuhnPokerEnv.jl

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,20 @@ const KUHN_POKER_CARDS = (:J, :Q, :K)
44
const KUHN_POKER_CARD_COMBINATIONS =
55
((:J, :Q), (:J, :K), (:Q, :J), (:Q, :K), (:K, :J), (:K, :Q))
66
const KUHN_POKER_ACTIONS = (:pass, :bet)
7-
const KUHN_POKER_STATES = (
8-
(),
7+
const KUHN_POKER_STATES = ((),
98
map(tuple, KUHN_POKER_CARDS)...,
109
KUHN_POKER_CARD_COMBINATIONS...,
1110
(
12-
(cards..., actions...) for cards in ((), map(tuple, KUHN_POKER_CARDS)...) for
13-
actions in (
14-
(),
11+
(cards..., actions...) for cards in ((), map(tuple, KUHN_POKER_CARDS)...) for actions in ((),
1512
(:bet,),
1613
(:bet, :bet),
1714
(:bet, :pass),
1815
(:pass,),
1916
(:pass, :pass),
2017
(:pass, :bet),
2118
(:pass, :bet, :pass),
22-
(:pass, :bet, :bet),
23-
)
24-
)...,
25-
)
19+
(:pass, :bet, :bet),)
20+
)...,)
2621

2722
"""
2823
![](https://upload.wikimedia.org/wikipedia/commons/a/a9/Kuhn_poker_tree.svg)
@@ -146,15 +141,16 @@ end
146141

147142
RLBase.current_player(env::KuhnPokerEnv) =
148143
if length(env.cards) < 2
149-
CHANCE_PLAYER
150-
elseif length(env.actions) == 0
151-
1
152-
elseif length(env.actions) == 1
153-
2
154-
elseif length(env.actions) == 2
155-
1
156-
else
157-
end
144+
CHANCE_PLAYER
145+
elseif length(env.actions) == 0
146+
1
147+
elseif length(env.actions) == 1
148+
2
149+
elseif length(env.actions) == 2
150+
1
151+
else
152+
2 # actually the game is over now
153+
end
158154

159155
RLBase.NumAgentStyle(::KuhnPokerEnv) = MultiAgent(2)
160156
RLBase.DynamicStyle(::KuhnPokerEnv) = SEQUENTIAL

src/environments/wrappers/ActionTransformedEnv.jl

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export ActionTransformedEnv
22

3-
struct ActionTransformedEnv{P,M,E<:AbstractEnv} <: AbstractEnvWrapper
3+
struct ActionTransformedEnv{P,M,E <: AbstractEnv} <: AbstractEnvWrapper
44
action_space_mapping::P
55
action_mapping::M
66
env::E
@@ -15,23 +15,12 @@ feeding it into `env`.
1515
"""
1616
function ActionTransformedEnv(
1717
env;
18-
action_space_mapping = identity,
19-
action_mapping = identity,
18+
action_space_mapping=identity,
19+
action_mapping=identity,
2020
)
2121
ActionTransformedEnv(action_space_mapping, action_mapping, env)
2222
end
2323

24-
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
25-
if f ∉ (:action_space, :legal_action_space)
26-
@eval RLBase.$f(x::ActionTransformedEnv, args...; kwargs...) =
27-
$f(x.env, args...; kwargs...)
28-
end
29-
end
30-
31-
RLBase.state(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
32-
RLBase.state_space(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) =
33-
state_space(env.env, ss)
34-
3524
RLBase.action_space(env::ActionTransformedEnv, args...) =
3625
env.action_space_mapping(action_space(env.env), args...)
3726

src/environments/wrappers/DefaultStateStyle.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,3 @@ Reset the result of `DefaultStateStyle` without changing the original behavior.
1212
DefaultStateStyleEnv{S}(env::E) where {S,E} = DefaultStateStyleEnv{S,E}(env)
1313

1414
RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S} = S
15-
16-
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
17-
if f ∉ (:DefaultStateStyle, :state, :state_space)
18-
@eval RLBase.$f(x::DefaultStateStyleEnv, args...; kwargs...) =
19-
$f(x.env, args...; kwargs...)
20-
end
21-
end
22-
23-
(env::DefaultStateStyleEnv)(args...; kwargs...) = env.env(args...; kwargs...)
24-
25-
RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
26-
RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p) =
27-
state(env.env, ss, p)
28-
29-
RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) =
30-
state_space(env.env, ss)
31-
32-
RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p) =
33-
state_space(env.env, ss, p)
src/environments/wrappers/MaxTimeoutEnv.jl

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export MaxTimeoutEnv
22

3-
mutable struct MaxTimeoutEnv{E<:AbstractEnv} <: AbstractEnvWrapper
3+
mutable struct MaxTimeoutEnv{E <: AbstractEnv} <: AbstractEnvWrapper
44
env::E
55
max_t::Int
66
current_t::Int
@@ -11,29 +11,18 @@ end
1111
1212
Force `is_terminated(env)` return `true` after `max_t` interactions.
1313
"""
14-
MaxTimeoutEnv(env::E, max_t::Int; current_t::Int = 1) where {E<:AbstractEnv} =
14+
MaxTimeoutEnv(env::E, max_t::Int; current_t::Int=1) where {E <: AbstractEnv} =
1515
MaxTimeoutEnv(env, max_t, current_t)
1616

1717
function (env::MaxTimeoutEnv)(args...; kwargs...)
1818
env.env(args...; kwargs...)
1919
env.current_t = env.current_t + 1
2020
end
2121

22-
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
23-
if f ∉ (:is_terminated, :reset!)
24-
@eval RLBase.$f(x::MaxTimeoutEnv, args...; kwargs...) =
25-
$f(x.env, args...; kwargs...)
26-
end
27-
end
28-
2922
RLBase.is_terminated(env::MaxTimeoutEnv) =
3023
(env.current_t > env.max_t) || is_terminated(env.env)
3124

3225
function RLBase.reset!(env::MaxTimeoutEnv)
3326
env.current_t = 1
3427
RLBase.reset!(env.env)
3528
end
36-
37-
RLBase.state(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
38-
RLBase.state_space(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle) =
39-
state_space(env.env, ss)
src/environments/wrappers/RewardOverriddenEnv.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,11 @@
11
export RewardOverriddenEnv
22

3-
struct RewardOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper
3+
struct RewardOverriddenEnv{F,E <: AbstractEnv} <: AbstractEnvWrapper
44
env::E
55
f::F
66
end
77

8-
(env::RewardOverriddenEnv)(args...; kwargs...) = env.env(args...; kwargs...)
9-
108
RewardOverriddenEnv(f) = env -> RewardOverriddenEnv(f, env)
119

12-
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
13-
if f != :reward
14-
@eval RLBase.$f(x::RewardOverriddenEnv, args...; kwargs...) =
15-
$f(x.env, args...; kwargs...)
16-
end
17-
end
18-
1910
RLBase.reward(env::RewardOverriddenEnv, args...; kwargs...) =
2011
env.f(reward(env.env, args...; kwargs...))
21-
22-
RLBase.state(env::RewardOverriddenEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
23-
RLBase.state_space(env::RewardOverriddenEnv, ss::RLBase.AbstractStateStyle) =
24-
state_space(env.env, ss)

src/environments/wrappers/StateCachedEnv.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ the next interaction with `env`. This function is useful because some
66
environments are stateful during each `state(env)`. For example:
77
`StateOverriddenEnv(StackFrames(...))`.
88
"""
9-
mutable struct StateCachedEnv{S,E<:AbstractEnv} <: AbstractEnvWrapper
9+
mutable struct StateCachedEnv{S,E <: AbstractEnv} <: AbstractEnvWrapper
1010
s::S
1111
env::E
1212
is_state_cached::Bool
@@ -28,14 +28,3 @@ function RLBase.state(env::StateCachedEnv, args...; kwargs...)
2828
env.s
2929
end
3030
end
31-
32-
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
33-
if f != :state
34-
@eval RLBase.$f(x::StateCachedEnv, args...; kwargs...) =
35-
$f(x.env, args...; kwargs...)
36-
end
37-
end
38-
39-
RLBase.state(env::StateCachedEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
40-
RLBase.state_space(env::StateCachedEnv, ss::RLBase.AbstractStateStyle) =
41-
state_space(env.env, ss)

src/environments/wrappers/StateOverriddenEnv.jl

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,12 @@ Apply `f` to override `state(env)`.
99
If the meaning of state space is changed after apply `f`, one should
1010
manually redefine the `RLBase.state_space(env::YourSpecificEnv)`.
1111
"""
12-
struct StateOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper
12+
struct StateOverriddenEnv{F,E <: AbstractEnv} <: AbstractEnvWrapper
1313
env::E
1414
f::F
1515
end
1616

1717
StateOverriddenEnv(f) = env -> StateOverriddenEnv(f, env)
1818

19-
(env::StateOverriddenEnv)(args...; kwargs...) = env.env(args...; kwargs...)
20-
21-
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
22-
if f ∉ (:state,)
23-
@eval RLBase.$f(x::StateOverriddenEnv, args...; kwargs...) =
24-
$f(x.env, args...; kwargs...)
25-
end
26-
end
27-
2819
RLBase.state(env::StateOverriddenEnv, args...; kwargs...) =
2920
env.f(state(env.env, args...; kwargs...))
30-
31-
RLBase.state(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle) =
32-
env.f(state(env.env, ss))
33-
34-
RLBase.state_space(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle) =
35-
state_space(env.env, ss)

src/environments/wrappers/StochasticEnv.jl

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ export StochasticEnv
22

33
using StatsBase: sample, Weights
44

5-
struct StochasticEnv{E<:AbstractEnv,R} <: AbstractEnv
5+
struct StochasticEnv{E <: AbstractEnv,R} <: AbstractEnvWrapper
66
env::E
77
rng::R
88
end
99

10-
function StochasticEnv(env; rng = Random.GLOBAL_RNG)
10+
function StochasticEnv(env; rng=Random.GLOBAL_RNG)
1111
ChanceStyle(env) === EXPLICIT_STOCHASTIC ||
1212
throw(ArgumentError("only environments of EXPLICIT_STOCHASTIC style is supported"))
1313
env = StochasticEnv(env, rng)
@@ -39,14 +39,3 @@ RLBase.ChanceStyle(::StochasticEnv) = STOCHASTIC
3939
RLBase.players(env::StochasticEnv) =
4040
[p for p in players(env.env) if p != chance_player(env.env)]
4141
Random.seed!(env::StochasticEnv, s) = Random.seed!(env.rng, s)
42-
43-
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
44-
if f ∉ (:players, :ChanceStyle, :reset!)
45-
@eval RLBase.$f(x::StochasticEnv, args...; kwargs...) =
46-
$f(x.env, args...; kwargs...)
47-
end
48-
end
49-
50-
RLBase.state(env::StochasticEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
51-
RLBase.state_space(env::StochasticEnv, ss::RLBase.AbstractStateStyle) =
52-
state_space(env.env, ss)

src/environments/wrappers/wrappers.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,23 @@
1+
export AbstractEnvWrapper
2+
13
abstract type AbstractEnvWrapper <: AbstractEnv end
24

35
Base.nameof(env::AbstractEnvWrapper) = "$(nameof(env.env)) |> $(nameof(typeof(env)))"
46

7+
Base.getindex(env::AbstractEnvWrapper) = env.env
8+
9+
(env::AbstractEnvWrapper)(args...; kwargs...) = env.env(args...; kwargs...)
10+
11+
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
12+
@eval RLBase.$f(x::AbstractEnvWrapper, args...; kwargs...) = $f(x[], args...; kwargs...)
13+
end
14+
15+
# avoid ambiguous
16+
RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p) = state(env[], ss, p)
17+
RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle) = state(env[], ss)
18+
RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle) = state_space(env[], ss)
19+
RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p) = state_space(env[], ss, p)
20+
521
include("ActionTransformedEnv.jl")
622
include("DefaultStateStyle.jl")
723
include("MaxTimeoutEnv.jl")

0 commit comments

Comments
 (0)