This repository was archived by the owner on May 6, 2021. It is now read-only.

add OpenSpiel #33

Merged
4 changes: 3 additions & 1 deletion Project.toml
@@ -8,6 +8,7 @@ GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
GR = "0.46"
@@ -17,10 +18,11 @@ julia = "1.3"

[extras]
ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977"
OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "ArcadeLearningEnvironment", "PyCall", "POMDPModels", "POMDPs"]
test = ["Test", "ArcadeLearningEnvironment", "PyCall", "POMDPModels", "POMDPs", "OpenSpiel"]
1 change: 1 addition & 0 deletions src/ReinforcementLearningEnvironments.jl
@@ -15,6 +15,7 @@ function __init__()
@require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include("environments/atari.jl")
@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" include("environments/gym.jl")
@require POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" include("environments/mdp.jl")
@require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("environments/open_spiel.jl")
end

end # module
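
Like the other optional backends, the OpenSpiel environment is wired in through a Requires.jl `@require` hook, so `OpenSpielEnv` only becomes available after OpenSpiel itself is loaded in the same session. A minimal usage sketch (assuming OpenSpiel.jl is installed; the game name, seed, and accessor calls are taken from the test file added in this PR):

    using ReinforcementLearningEnvironments
    using OpenSpiel   # triggers the @require hook above and includes environments/open_spiel.jl

    env = OpenSpielEnv("tic_tac_toe", seed = 123)
    reset!(env)
    obs = observe(env, 0)       # observation for player 0
    get_legal_actions(obs)      # legal action ids at the current state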
8 changes: 8 additions & 0 deletions src/environments/atari.jl
@@ -1,4 +1,7 @@
module AtariWrapper

using ArcadeLearningEnvironment, GR, Random
using ReinforcementLearningBase

export AtariEnv

@@ -163,3 +166,8 @@ function RLBase.render(env::AtariEnv)
end

list_atari_rom_names() = getROMList()

end

using .AtariWrapper
export AtariEnv
8 changes: 8 additions & 0 deletions src/environments/gym.jl
@@ -1,3 +1,6 @@
module GymWrapper

using ReinforcementLearningBase
using PyCall

export GymEnv
@@ -111,3 +114,8 @@ function list_gym_env_names(;
gym = pyimport("gym")
[x.id for x in gym.envs.registry.all() if split(x.entry_point, ':')[1] in modules]
end

end

using .GymWrapper
export GymEnv
10 changes: 9 additions & 1 deletion src/environments/mdp.jl
@@ -1,5 +1,8 @@
module POMDPWrapper

export POMDPEnv

using ReinforcementLearningBase
using POMDPs
using Random

@@ -18,7 +21,7 @@ mutable struct POMDPEnv{M<:POMDP,S,O,I,R,RNG<:AbstractRNG} <: AbstractEnv
rng::RNG
end

Random.seed!(env::POMDPEnv, seed) = seed!(env.rng, seed)
Random.seed!(env::POMDPEnv, seed) = Random.seed!(env.rng, seed)

function POMDPEnv(model::POMDP; seed = nothing)
rng = MersenneTwister(seed)
@@ -129,3 +132,8 @@ end

RLBase.get_observation_space(env::MDPEnv) = get_observation_space(env.model)
RLBase.get_action_space(env::MDPEnv) = get_action_space(env.model)

end

using .POMDPWrapper
export POMDPEnv
151 changes: 151 additions & 0 deletions src/environments/open_spiel.jl
@@ -0,0 +1,151 @@
module OpenSpielWrapper

export OpenSpielEnv

using ReinforcementLearningBase
using OpenSpiel
using Random
using StatsBase: sample, weights

abstract type AbstractObservationType end

mutable struct OpenSpielEnv{O, D, S, G, R} <: AbstractEnv
state::S
game::G
rng::R
end

"""
OpenSpielEnv(name; observation_type=nothing, kwargs...)

# Arguments

- `name`::`String`, you can call `registered_names()` to see all the supported names. Note that the name can contain parameters, like `"goofspiel(imp_info=True,num_cards=4,points_order=descending)"`. Because the parameters part is parsed by the backend C++ code, boolean values must be `True` or `False` (instead of `true` or `false`). Alternatively, the parameters can be specified as `kwargs` in the Julia style.
- `observation_type`::`Union{Symbol,Nothing}`, supported values are [`:information`](https://github.com/deepmind/open_spiel/blob/1ad92a54f3b800394b2bc7f178ccdff62d8369e1/open_spiel/spiel.h#L342-L367), [`:observation`](https://github.com/deepmind/open_spiel/blob/1ad92a54f3b800394b2bc7f178ccdff62d8369e1/open_spiel/spiel.h#L397-L408) or `nothing`. The default value is `nothing`, which means `:information` if the game `provides_information_state_tensor`; otherwise `:observation` is used.
"""
function OpenSpielEnv(name;
seed = nothing,
observation_type=nothing,
kwargs...
)
game = load_game(name, kwargs...)
game_type = get_type(game)

has_info_state = provides_information_state_tensor(game_type)
has_obs_state = provides_observation_tensor(game_type)
has_info_state || has_obs_state || @error "the environment provides neither an information state tensor nor an observation tensor"
if isnothing(observation_type)
observation_type = has_info_state ? :information : :observation
end
if observation_type == :observation
has_obs_state || @error "the environment doesn't support observation_type of $observation_type"
elseif observation_type == :information
has_info_state || @error "the environment doesn't support observation_type of $observation_type"
else
@error "unknown observation_type $observation_type"
end

d = dynamics(game_type)
dynamic_style = if d === OpenSpiel.SEQUENTIAL
RLBase.SEQUENTIAL
elseif d === OpenSpiel.SIMULTANEOUS
RLBase.SIMULTANEOUS
else
@error "unknown dynamic style of $d"
end

state = new_initial_state(game)

rng = MersenneTwister(seed)

env = OpenSpielEnv{observation_type, dynamic_style, typeof(state), typeof(game), typeof(rng)}(state, game, rng)
reset!(env)
env
end

RLBase.DynamicStyle(env::OpenSpielEnv{O, D}) where {O, D} = D

function RLBase.reset!(env::OpenSpielEnv)
state = new_initial_state(env.game)
_sample_external_events!(env.rng, state)
env.state = state
end

function _sample_external_events!(rng::AbstractRNG, state)
while is_chance_node(state)
outcomes_with_probs = chance_outcomes(state)
actions, probs = zip(outcomes_with_probs...)
action = actions[sample(rng, weights(collect(probs)))]
apply_action(state, action)
end
end

function (env::OpenSpielEnv)(action)
apply_action(env.state, action)
_sample_external_events!(env.rng, env.state)
end

(env::OpenSpielEnv)(player, action) = env(DynamicStyle(env), player, action)

function (env::OpenSpielEnv)(::Sequential, player, action)
if get_current_player(env) == player
apply_action(env.state, action)
else
apply_action(env.state, OpenSpiel.INVALID_ACTION[])
end
_sample_external_events!(env.rng, env.state)
end

(env::OpenSpielEnv)(::Simultaneous, player, action) = @error "Simultaneous environments cannot take in actions from players separately"

struct OpenSpielObs{O, D, S, P}
state::S
player::P
end

RLBase.observe(env::OpenSpielEnv{O, D, S}, player::P) where {O, D, S, P} = OpenSpielObs{O, D, S, P}(env.state, player)

RLBase.get_action_space(env::OpenSpielEnv) = DiscreteSpace(0:num_distinct_actions(env.game)-1)

function RLBase.get_observation_space(env::OpenSpielEnv{:information})
s = information_state_tensor_size(env.game)
MultiContinuousSpace(
fill(typemin(Float64), s...),
fill(typemax(Float64), s...),
)
end

function RLBase.get_observation_space(env::OpenSpielEnv{:observation})
s = observation_tensor_size(env.game)
MultiContinuousSpace(
fill(typemin(Float64), s...),
fill(typemax(Float64), s...),
)
end

RLBase.get_current_player(env::OpenSpielEnv) = current_player(env.state)

RLBase.get_num_players(env::OpenSpielEnv) = num_players(env.game)

Random.seed!(env::OpenSpielEnv, seed) = Random.seed!(env.rng, seed)

RLBase.ActionStyle(::OpenSpielObs) = FULL_ACTION_SET

RLBase.get_legal_actions(obs::OpenSpielObs) = legal_actions(obs.state, obs.player)

RLBase.get_legal_actions_mask(obs::OpenSpielObs) = legal_actions_mask(obs.state, obs.player)

RLBase.get_terminal(obs::OpenSpielObs) = OpenSpiel.is_terminal(obs.state)

RLBase.get_reward(obs::OpenSpielObs) = rewards(obs.state)[obs.player+1] # player ids start at 0

RLBase.get_state(obs::OpenSpielObs{:information}) = information_state_tensor(obs.state, obs.player)

RLBase.get_state(obs::OpenSpielObs{:observation}) = observation_tensor(obs.state, obs.player)

RLBase.get_invalid_action(obs::OpenSpielObs) = convert(Int, OpenSpiel.INVALID_ACTION[])

end

using .OpenSpielWrapper
export OpenSpielEnv
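
As the docstring of `OpenSpielEnv` explains, game parameters can either be embedded in the name string (parsed by the C++ backend, so booleans are written `True`/`False`) or passed as Julia keyword arguments. A hedged sketch of both forms — the keyword form assumes that `load_game` forwards these arguments to the backend, as the docstring suggests, and is not verified here:

    # parameters embedded in the name string, exactly as in the test file below
    env = OpenSpielEnv("goofspiel(imp_info=True,num_cards=4,points_order=descending)", seed = 123)

    # the same game with Julia-style keyword arguments (assumption: forwarded to load_game)
    env = OpenSpielEnv("goofspiel"; seed = 123, imp_info = true, num_cards = 4, points_order = "descending")

    # request the observation tensor explicitly (only valid if the game provides one)
    env = OpenSpielEnv("kuhn_poker", seed = 123, observation_type = :observation)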
4 changes: 2 additions & 2 deletions test/environments.jl
@@ -19,7 +19,7 @@
end
end

gym_env_names = ReinforcementLearningEnvironments.list_gym_env_names(
gym_env_names = ReinforcementLearningEnvironments.GymWrapper.list_gym_env_names(
modules = [
"gym.envs.algorithmic",
"gym.envs.classic_control",
@@ -30,7 +30,7 @@

gym_env_names = filter(x -> x != "KellyCoinflipGeneralized-v0", gym_env_names) # not sure why this env has outliers

atari_env_names = ReinforcementLearningEnvironments.list_atari_rom_names()
atari_env_names = ReinforcementLearningEnvironments.AtariWrapper.list_atari_rom_names()
atari_env_names = filter(x -> x != "defender", atari_env_names)

for env_exp in [
38 changes: 38 additions & 0 deletions test/open_spiel.jl
@@ -0,0 +1,38 @@
@testset "OpenSpielEnv" begin

for name in ["tic_tac_toe", "kuhn_poker", "goofspiel(imp_info=True,num_cards=4,points_order=descending)"]
env = OpenSpielEnv(name, seed=123)
get_current_player(env)
get_num_players(env)
get_observation_space(env)
get_action_space(env)
DynamicStyle(env)

obs = observe(env)
obs_0 = observe(env, 0)
obs_1 = observe(env, 1)
ActionStyle(obs_0)
get_legal_actions_mask(obs_0)
get_legal_actions_mask(obs_1)
get_legal_actions(obs_0)
get_legal_actions(obs_1)
get_terminal(obs_0)
get_terminal(obs_1)
get_reward(obs_0)
get_reward(obs_1)
get_state(obs_0)
get_state(obs_1)
get_invalid_action(obs_0)

Random.seed!(env, 456)
reset!(env)

while true
obs = observe(env)
get_terminal(obs) && break
action = rand(get_legal_actions(obs))
env(action)
end
@test true
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -5,11 +5,14 @@ using ArcadeLearningEnvironment
using PyCall
using POMDPs
using POMDPModels
using OpenSpiel
using Random

RLBase.get_observation_space(m::TigerPOMDP) = DiscreteSpace((false, true))

@testset "ReinforcementLearningEnvironments" begin

include("environments.jl")
include("atari.jl")
include("open_spiel.jl")
end