Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Simplify env wrapper #127

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions src/environments/examples/KuhnPokerEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,20 @@ const KUHN_POKER_CARDS = (:J, :Q, :K)
# Every ordered pair of two distinct cards drawn from (:J, :Q, :K).
const KUHN_POKER_CARD_COMBINATIONS =
((:J, :Q), (:J, :K), (:Q, :J), (:Q, :K), (:K, :J), (:K, :Q))
# The two action symbols used by the environment.
const KUHN_POKER_ACTIONS = (:pass, :bet)
# Tuple of state tuples, assembled from:
#   * the empty tuple,
#   * each card of `KUHN_POKER_CARDS` wrapped in a 1-tuple,
#   * each entry of `KUHN_POKER_CARD_COMBINATIONS`,
#   * each card prefix (empty, or a single card) concatenated with one of the
#     action sequences listed below.
const KUHN_POKER_STATES = (
    (),
    map(tuple, KUHN_POKER_CARDS)...,
    KUHN_POKER_CARD_COMBINATIONS...,
    (
        (cards..., actions...) for cards in ((), map(tuple, KUHN_POKER_CARDS)...) for
        actions in (
            (),
            (:bet,),
            (:bet, :bet),
            (:bet, :pass),
            (:pass,),
            (:pass, :pass),
            (:pass, :bet),
            (:pass, :bet, :pass),
            (:pass, :bet, :bet),
        )
    )...,
)

"""
![](https://upload.wikimedia.org/wikipedia/commons/a/a9/Kuhn_poker_tree.svg)
Expand Down Expand Up @@ -146,15 +141,16 @@ end

# Decide who acts next:
#   * `CHANCE_PLAYER` while fewer than two cards are in `env.cards`;
#   * otherwise chosen by the number of actions taken so far:
#     0 -> player 1, 1 -> player 2, 2 -> player 1, anything else -> player 2.
RLBase.current_player(env::KuhnPokerEnv) =
    if length(env.cards) < 2
        CHANCE_PLAYER
    elseif length(env.actions) == 0
        1
    elseif length(env.actions) == 1
        2
    elseif length(env.actions) == 2
        1
    else
        2 # actually the game is over now
    end

# Trait: this environment is declared as `MultiAgent(2)`.
function RLBase.NumAgentStyle(::KuhnPokerEnv)
    return MultiAgent(2)
end

# Trait: this environment is declared as `SEQUENTIAL`.
function RLBase.DynamicStyle(::KuhnPokerEnv)
    return SEQUENTIAL
end
Expand Down
17 changes: 3 additions & 14 deletions src/environments/wrappers/ActionTransformedEnv.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export ActionTransformedEnv

struct ActionTransformedEnv{P,M,E<:AbstractEnv} <: AbstractEnvWrapper
struct ActionTransformedEnv{P,M,E <: AbstractEnv} <: AbstractEnvWrapper
action_space_mapping::P
action_mapping::M
env::E
Expand All @@ -15,23 +15,12 @@ feeding it into `env`.
"""
function ActionTransformedEnv(
    env;
    action_space_mapping = identity,
    action_mapping = identity,
)
    # Both mappings default to `identity`. The positional constructor takes its
    # arguments in field order: (action_space_mapping, action_mapping, env).
    return ActionTransformedEnv(action_space_mapping, action_mapping, env)
end

# Generate forwarding methods to the wrapped `env` for every function named in
# `RLBase.ENV_API` and `RLBase.MULTI_AGENT_ENV_API`, except the two action-space
# queries, which are skipped here.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    api in (:action_space, :legal_action_space) && continue
    @eval function RLBase.$api(x::ActionTransformedEnv, args...; kwargs...)
        return $api(x.env, args...; kwargs...)
    end
end

# Methods taking an explicit state style; both delegate to the wrapped environment.
function RLBase.state(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state_space(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end

# Query the wrapped environment's action space and pass it, together with any
# extra positional arguments, through `action_space_mapping`.
function RLBase.action_space(env::ActionTransformedEnv, args...)
    inner_space = action_space(env.env)
    return env.action_space_mapping(inner_space, args...)
end

Expand Down
19 changes: 0 additions & 19 deletions src/environments/wrappers/DefaultStateStyle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,3 @@ Reset the result of `DefaultStateStyle` without changing the original behavior.
function DefaultStateStyleEnv{S}(env::E) where {S,E}
    # Only `S` is supplied by the caller; the environment type `E` is taken from `env`.
    return DefaultStateStyleEnv{S,E}(env)
end

# The default state style is the wrapper's first type parameter.
function RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S}
    return S
end

# Generate forwarding methods to the wrapped `env` for every function named in
# `RLBase.ENV_API` and `RLBase.MULTI_AGENT_ENV_API`, skipping `DefaultStateStyle`,
# `state` and `state_space`.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    api in (:DefaultStateStyle, :state, :state_space) && continue
    @eval function RLBase.$api(x::DefaultStateStyleEnv, args...; kwargs...)
        return $api(x.env, args...; kwargs...)
    end
end

# Calling the wrapper calls the wrapped environment with the same arguments.
function (env::DefaultStateStyleEnv)(args...; kwargs...)
    return env.env(args...; kwargs...)
end

# `state` / `state_space` with an explicit state style, with and without a third
# argument `p`; all four delegate unchanged to the wrapped environment.
function RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p)
    return state(env.env, ss, p)
end

function RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end

function RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p)
    return state_space(env.env, ss, p)
end
15 changes: 2 additions & 13 deletions src/environments/wrappers/MaxTimeoutEnv.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export MaxTimeoutEnv

mutable struct MaxTimeoutEnv{E<:AbstractEnv} <: AbstractEnvWrapper
mutable struct MaxTimeoutEnv{E <: AbstractEnv} <: AbstractEnvWrapper
env::E
max_t::Int
current_t::Int
Expand All @@ -11,29 +11,18 @@ end

Force `is_terminated(env)` to return `true` after `max_t` interactions.
"""
function MaxTimeoutEnv(env::E, max_t::Int; current_t::Int = 1) where {E<:AbstractEnv}
    # `current_t` defaults to 1; delegate to the three-argument positional constructor.
    return MaxTimeoutEnv(env, max_t, current_t)
end

# Call the wrapped environment, then advance the interaction counter.
# The incremented counter is the value of the call.
function (env::MaxTimeoutEnv)(args...; kwargs...)
    inner = env.env
    inner(args...; kwargs...)
    return env.current_t += 1
end

# Generate forwarding methods to the wrapped `env` for every function named in
# `RLBase.ENV_API` and `RLBase.MULTI_AGENT_ENV_API`, skipping `is_terminated`
# and `reset!`.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    api in (:is_terminated, :reset!) && continue
    @eval function RLBase.$api(x::MaxTimeoutEnv, args...; kwargs...)
        return $api(x.env, args...; kwargs...)
    end
end

# Terminated once the counter exceeds `max_t`; otherwise defer to the wrapped environment.
function RLBase.is_terminated(env::MaxTimeoutEnv)
    env.current_t > env.max_t && return true
    return is_terminated(env.env)
end

# Restart the interaction counter at 1, then reset the wrapped environment
# and return whatever that reset returns.
function RLBase.reset!(env::MaxTimeoutEnv)
    env.current_t = 1
    inner = env.env
    return RLBase.reset!(inner)
end

# Methods taking an explicit state style; both delegate to the wrapped environment.
function RLBase.state(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state_space(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end
15 changes: 1 addition & 14 deletions src/environments/wrappers/RewardOverriddenEnv.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,11 @@
export RewardOverriddenEnv

# Environment wrapper holding a wrapped environment and a callable `f`.
struct RewardOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper
    env::E  # the wrapped environment
    f::F    # callable stored alongside the environment
end

# Calling the wrapper calls the wrapped environment with the same arguments.
function (env::RewardOverriddenEnv)(args...; kwargs...)
    return env.env(args...; kwargs...)
end

# Curried form: `RewardOverriddenEnv(f)` returns a closure that wraps an environment.
# The struct's fields are ordered `(env, f)` and `env` must be an `AbstractEnv`, so the
# environment has to be passed first; passing `(f, env)` matches no constructor when
# `f` is an ordinary function.
RewardOverriddenEnv(f) = env -> RewardOverriddenEnv(env, f)

# Generate forwarding methods to the wrapped `env` for every function named in
# `RLBase.ENV_API` and `RLBase.MULTI_AGENT_ENV_API`, skipping `reward`.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    api == :reward && continue
    @eval function RLBase.$api(x::RewardOverriddenEnv, args...; kwargs...)
        return $api(x.env, args...; kwargs...)
    end
end

# Fetch the wrapped environment's reward and pass it through `env.f`.
function RLBase.reward(env::RewardOverriddenEnv, args...; kwargs...)
    raw_reward = reward(env.env, args...; kwargs...)
    return env.f(raw_reward)
end

# Methods taking an explicit state style; both delegate to the wrapped environment.
function RLBase.state(env::RewardOverriddenEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state_space(env::RewardOverriddenEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end
13 changes: 1 addition & 12 deletions src/environments/wrappers/StateCachedEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ the next interaction with `env`. This function is useful because some
environments are stateful during each `state(env)`. For example:
`StateOverriddenEnv(StackFrames(...))`.
"""
mutable struct StateCachedEnv{S,E<:AbstractEnv} <: AbstractEnvWrapper
mutable struct StateCachedEnv{S,E <: AbstractEnv} <: AbstractEnvWrapper
s::S
env::E
is_state_cached::Bool
Expand All @@ -28,14 +28,3 @@ function RLBase.state(env::StateCachedEnv, args...; kwargs...)
env.s
end
end

# Generate forwarding methods to the wrapped `env` for every function named in
# `RLBase.ENV_API` and `RLBase.MULTI_AGENT_ENV_API`, skipping `state`.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    api == :state && continue
    @eval function RLBase.$api(x::StateCachedEnv, args...; kwargs...)
        return $api(x.env, args...; kwargs...)
    end
end

# Methods taking an explicit state style; both delegate directly to the wrapped
# environment.
function RLBase.state(env::StateCachedEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state_space(env::StateCachedEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end
17 changes: 1 addition & 16 deletions src/environments/wrappers/StateOverriddenEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,12 @@ Apply `f` to override `state(env)`.
If applying `f` changes the meaning of the state space, you should
manually redefine `RLBase.state_space(env::YourSpecificEnv)`.
"""
struct StateOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper
    env::E  # the wrapped environment
    f::F    # callable stored alongside the environment
end

# Curried form: `StateOverriddenEnv(f)` returns a closure that wraps an environment.
# The struct's fields are ordered `(env, f)` and `env` must be an `AbstractEnv`, so the
# environment has to be passed first; passing `(f, env)` matches no constructor when
# `f` is an ordinary function.
StateOverriddenEnv(f) = env -> StateOverriddenEnv(env, f)

# Calling the wrapper calls the wrapped environment with the same arguments.
function (env::StateOverriddenEnv)(args...; kwargs...)
    return env.env(args...; kwargs...)
end

# Generate forwarding methods to the wrapped `env` for every function named in
# `RLBase.ENV_API` and `RLBase.MULTI_AGENT_ENV_API`, skipping `state`.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    api in (:state,) && continue
    @eval function RLBase.$api(x::StateOverriddenEnv, args...; kwargs...)
        return $api(x.env, args...; kwargs...)
    end
end

# Fetch the wrapped environment's state and pass it through `env.f`.
function RLBase.state(env::StateOverriddenEnv, args...; kwargs...)
    inner_state = state(env.env, args...; kwargs...)
    return env.f(inner_state)
end

# With an explicit state style, `state` still applies `env.f` to the wrapped state.
function RLBase.state(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle)
    inner_state = state(env.env, ss)
    return env.f(inner_state)
end

# `state_space` is forwarded unchanged; `env.f` is not applied to it.
function RLBase.state_space(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end
15 changes: 2 additions & 13 deletions src/environments/wrappers/StochasticEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ export StochasticEnv

using StatsBase: sample, Weights

# Environment wrapper pairing a wrapped environment with a random number generator.
# NOTE(review): two header lines with different supertypes (`AbstractEnv` and
# `AbstractEnvWrapper`) were present; the later one is kept. Confirm this is intended.
struct StochasticEnv{E<:AbstractEnv,R} <: AbstractEnvWrapper
    env::E  # the wrapped environment
    rng::R  # random number generator; this is what `Random.seed!(env, s)` seeds
end

function StochasticEnv(env; rng = Random.GLOBAL_RNG)
function StochasticEnv(env; rng=Random.GLOBAL_RNG)
ChanceStyle(env) === EXPLICIT_STOCHASTIC ||
throw(ArgumentError("only environments of EXPLICIT_STOCHASTIC style is supported"))
env = StochasticEnv(env, rng)
Expand Down Expand Up @@ -39,14 +39,3 @@ RLBase.ChanceStyle(::StochasticEnv) = STOCHASTIC
# Players of the wrapped environment, with its chance player filtered out.
function RLBase.players(env::StochasticEnv)
    inner = env.env
    return [p for p in players(inner) if p != chance_player(inner)]
end

# Seeding the wrapper seeds its own `rng`.
function Random.seed!(env::StochasticEnv, s)
    return Random.seed!(env.rng, s)
end

# Generate forwarding methods to the wrapped `env` for every function named in
# `RLBase.ENV_API` and `RLBase.MULTI_AGENT_ENV_API`, skipping `players`,
# `ChanceStyle` and `reset!`.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    api in (:players, :ChanceStyle, :reset!) && continue
    @eval function RLBase.$api(x::StochasticEnv, args...; kwargs...)
        return $api(x.env, args...; kwargs...)
    end
end

# Methods taking an explicit state style; both delegate to the wrapped environment.
function RLBase.state(env::StochasticEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state_space(env::StochasticEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end
16 changes: 16 additions & 0 deletions src/environments/wrappers/wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
export AbstractEnvWrapper

# Supertype for environment wrappers. The generic methods in this file read the
# wrapped environment from an `env` field on the concrete wrapper.
abstract type AbstractEnvWrapper <: AbstractEnv end

# Name of a wrapped environment: the inner environment's name, then " |> ",
# then the wrapper type's name.
function Base.nameof(env::AbstractEnvWrapper)
    inner_name = nameof(env.env)
    wrapper_name = nameof(typeof(env))
    return string(inner_name, " |> ", wrapper_name)
end

# `env[]` unwraps one level and returns the wrapped environment.
function Base.getindex(env::AbstractEnvWrapper)
    return env.env
end

# Calling any wrapper calls the wrapped environment with the same arguments.
function (env::AbstractEnvWrapper)(args...; kwargs...)
    return env.env(args...; kwargs...)
end

# For every function named in `RLBase.ENV_API` and `RLBase.MULTI_AGENT_ENV_API`,
# generate a fallback method that unwraps the wrapper with `x[]` and forwards the call.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    @eval function RLBase.$api(x::AbstractEnvWrapper, args...; kwargs...)
        return $api(x[], args...; kwargs...)
    end
end

# More specific `state` / `state_space` methods taking an explicit state style
# (and optionally a third argument `p`), defined to avoid method ambiguities.
function RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p)
    return state(env[], ss, p)
end

function RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle)
    return state(env[], ss)
end

function RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle)
    return state_space(env[], ss)
end

function RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p)
    return state_space(env[], ss, p)
end

include("ActionTransformedEnv.jl")
include("DefaultStateStyle.jl")
include("MaxTimeoutEnv.jl")
Expand Down