This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 1246031

Multi agent changes (#82)
* add snake game
* update readme
* ignore SnakeGameEnv in test
* refactor OpenSpiel a little
* finish CFR
* comment out tests related to SnakeGames for CI
* update dependency of RLBase
1 parent 2444330 commit 1246031

File tree (4 files changed: +62 -45)

- Project.toml
- src/environments/open_spiel.jl
- src/environments/structs.jl
- test/runtests.jl

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 [compat]
 GR = "0.46, 0.47, 0.48, 0.49, 0.50, 0.51"
 OrdinaryDiffEq = "5"
-ReinforcementLearningBase = "0.8"
+ReinforcementLearningBase = "0.8.1"
 Requires = "1.0"
 StatsBase = "0.32, 0.33"
 julia = "1.3"
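For context on the bump: Julia's Pkg treats a bare `"0.8.1"` compat entry as the caret range `[0.8.1, 0.9.0)`, so this change only admits RLBase releases that already ship the state-style and `ActionProbPair` machinery used in the diff below (that this API landed in 0.8.1 is an inference from this commit, not something stated in it). A throwaway sketch to see what the resolver picks:

```julia
# Hypothetical one-off check; normally the [compat] entry above is all you need.
using Pkg
Pkg.activate(; temp = true)   # scratch environment (Julia 1.5+)
Pkg.add(Pkg.PackageSpec(name = "ReinforcementLearningBase", version = v"0.8.1"))
Pkg.status("ReinforcementLearningBase")
```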

src/environments/open_spiel.jl

Lines changed: 59 additions & 42 deletions
@@ -3,6 +3,8 @@ import .OpenSpiel:
     get_type,
     provides_information_state_tensor,
     provides_observation_tensor,
+    provides_information_state_string,
+    provides_observation_string,
     dynamics,
     new_initial_state,
     chance_mode,
@@ -36,35 +38,32 @@ using StatsBase: sample, weights
 # Arguments

 - `name`::`String`, you can call `ReinforcementLearningEnvironments.OpenSpiel.registered_names()` to see all the supported names. Note that the name can contains parameters, like `"goofspiel(imp_info=True,num_cards=4,points_order=descending)"`. Because the parameters part is parsed by the backend C++ code, the bool variable must be `True` or `False` (instead of `true` or `false`). Another approach is to just specify parameters in `kwargs` in the Julia style.
-- `state_type`::`Union{Symbol,Nothing}`, Supported values are [`:information`](https://github.com/deepmind/open_spiel/blob/1ad92a54f3b800394b2bc7f178ccdff62d8369e1/open_spiel/spiel.h#L342-L367), [`:observation`](https://github.com/deepmind/open_spiel/blob/1ad92a54f3b800394b2bc7f178ccdff62d8369e1/open_spiel/spiel.h#L397-L408) or `nothing`. The default value is `nothing`, which means `:information` if the game ` provides_information_state_tensor`. If not, it means `:observation`.
+- `default_state_style`::`Union{AbstractStateStyle,Nothing}`, Supported values are [`Information{<:Union{String,Array}}`](https://github.com/deepmind/open_spiel/blob/1ad92a54f3b800394b2bc7f178ccdff62d8369e1/open_spiel/spiel.h#L342-L367), [`Observation{<:Union{String,Array}}`](https://github.com/deepmind/open_spiel/blob/1ad92a54f3b800394b2bc7f178ccdff62d8369e1/open_spiel/spiel.h#L397-L408) or `nothing`.
 - `rng::AbstractRNG`, used to initial the `rng` for chance nodes. And the `rng` will only be used if the environment contains chance node, else it is set to `nothing`. To set the seed of inner environment, you may check the documentation of each specific game. Usually adding a keyword argument named `seed` should work.
 - `is_chance_agent_required::Bool=false`, by default, no chance agent is required. An internal `rng` will be used to automatically generate actions for chance node. If set to `true`, you need to feed the action of chance agent to environment explicitly. And the `seed` will be ignored.
 """
 function OpenSpielEnv(
     name;
     rng = Random.GLOBAL_RNG,
-    state_type = nothing,
+    default_state_style = nothing,
     is_chance_agent_required = false,
     kwargs...,
 )
     game = load_game(name; kwargs...)
     game_type = get_type(game)

-    has_info_state = provides_information_state_tensor(game_type)
-    has_obs_state = provides_observation_tensor(game_type)
-    has_info_state ||
-        has_obs_state ||
-        @error "the environment neither provides information tensor nor provides observation tensor"
-    if isnothing(state_type)
-        state_type = has_info_state ? :information : :observation
-    end
-
-    if state_type == :observation
-        has_obs_state || @error "the environment doesn't support state_typeof $state_type"
-    elseif state_type == :information
-        has_info_state || @error "the environment doesn't support state_typeof $state_type"
-    else
-        @error "unknown state_type $state_type"
+    if isnothing(default_state_style)
+        default_state_style = if provides_information_state_string(game_type)
+            RLBase.Information{String}()
+        elseif provides_information_state_tensor(game_type)
+            RLBase.Information{Array}()
+        elseif provides_observation_tensor(game_type)
+            Observation{Array}()
+        elseif provides_observation_string(game_type)
+            Observation{String}()
+        else
+            nothing
+        end
     end

     state = new_initial_state(game)
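A usage sketch of the reworked constructor (hedged: `kuhn_poker` is a standard OpenSpiel game name, and the `RLBase` alias is assumed to match the one used inside the package):

```julia
using ReinforcementLearningEnvironments
using OpenSpiel                       # loading OpenSpiel activates the glue code above
import ReinforcementLearningBase
const RLBase = ReinforcementLearningBase

# Let the constructor probe the game and pick a default state style:
env = OpenSpielEnv("kuhn_poker")

# Or pin the representation explicitly:
env = OpenSpielEnv("kuhn_poker"; default_state_style = RLBase.Information{String}())
```

Note the probing order in the fallback above: string-valued information states are preferred, which presumably suits the tabular CFR work this PR mentions, and `nothing` is only reached when the game advertises no representation at all.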
@@ -103,7 +102,7 @@ function OpenSpielEnv(
     end

     env =
-        OpenSpielEnv{state_type,Tuple{c,d,i,n,r,u},typeof(state),typeof(game),typeof(rng)}(
+        OpenSpielEnv{Tuple{default_state_style,c,d,i,n,r,u},typeof(state),typeof(game),typeof(rng)}(
             state,
             game,
             rng,
@@ -113,14 +112,15 @@ function OpenSpielEnv(
 end

 RLBase.ActionStyle(env::OpenSpielEnv) = FULL_ACTION_SET
-RLBase.ChanceStyle(env::OpenSpielEnv{S,Tuple{C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = C
-RLBase.InformationStyle(env::OpenSpielEnv{S,Tuple{C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = I
-RLBase.NumAgentStyle(env::OpenSpielEnv{S,Tuple{C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = N
-RLBase.RewardStyle(env::OpenSpielEnv{S,Tuple{C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = R
-RLBase.UtilityStyle(env::OpenSpielEnv{S,Tuple{C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = U
+RLBase.ChanceStyle(env::OpenSpielEnv{Tuple{S,C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = C
+RLBase.InformationStyle(env::OpenSpielEnv{Tuple{S,C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = I
+RLBase.NumAgentStyle(env::OpenSpielEnv{Tuple{S,C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = N
+RLBase.RewardStyle(env::OpenSpielEnv{Tuple{S,C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = R
+RLBase.UtilityStyle(env::OpenSpielEnv{Tuple{S,C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = U
+RLBase.DefaultStateStyle(env::OpenSpielEnv{Tuple{S,C,D,I,N,R,U}}) where {S,C,D,I,N,R,U} = S

-Base.copy(env::OpenSpielEnv{S,T,ST,G,R}) where {S,T,ST,G,R} =
-    OpenSpielEnv{S,T,ST,G,R}(copy(env.state), env.game, env.rng)
+Base.copy(env::OpenSpielEnv{T,ST,G,R}) where {T,ST,G,R} =
+    OpenSpielEnv{T,ST,G,R}(copy(env.state), env.game, env.rng)

 function RLBase.reset!(env::OpenSpielEnv)
     state = new_initial_state(env.game)
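Why these signatures work at all: trait instances like `RLBase.Information{String}()` are singleton isbits values, and Julia permits isbits values as type parameters, including inside a `Tuple` type. Each one-liner above just destructures the tuple parameter by dispatch. A toy, self-contained reproduction (the names here are stand-ins, not the RLBase types):

```julia
# Stand-ins for RLBase's trait singletons:
struct Deterministic end
struct Stochastic end

struct ToyEnv{T} end   # T plays the role of Tuple{S,C,D,I,N,R,U}

# Recover the second slot of the tuple parameter purely by dispatch:
chance_style(::ToyEnv{Tuple{S,C}}) where {S,C} = C

env = ToyEnv{Tuple{Deterministic(), Stochastic()}}()
chance_style(env)      # -> Stochastic(); resolved statically, no runtime lookup
```

Folding the former `S` parameter into the same tuple is also what lets `Base.copy` here and the struct definition in `structs.jl` below drop one type parameter.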
@@ -132,55 +132,72 @@ _sample_external_events!(::Nothing, state) = nothing

 function _sample_external_events!(rng::AbstractRNG, state)
     while is_chance_node(state)
-        outcomes_with_probs = chance_outcomes(state)
-        actions, probs = zip(outcomes_with_probs...)
-        action = actions[sample(rng, weights(collect(probs)))]
-        apply_action(state, action)
+        apply_action(state, rand(rng, reinterpret(ActionProbPair{Int, Float64}, chance_outcomes(state))).action)
     end
 end

-function (env::OpenSpielEnv)(action)
+function (env::OpenSpielEnv)(action::Int)
     apply_action(env.state, action)
     ChanceStyle(env) === STOCHASTIC && _sample_external_events!(env.rng, env.state)
 end

-RLBase.get_actions(env::OpenSpielEnv) = 0:num_distinct_actions(env.game)-1
 RLBase.get_current_player(env::OpenSpielEnv) = current_player(env.state)
 RLBase.get_chance_player(env::OpenSpielEnv) = convert(Int, OpenSpiel.CHANCE_PLAYER)
-RLBase.get_players(env::OpenSpielEnv) = 0:(num_players(env.game)-1)
-
-function Random.seed!(env::OpenSpielEnv, seed)
-    if ChanceStyle(env) === STOCHASTIC
-        Random.seed!(env.rng, seed)
-    else
-        @error "only environments of STOCHASTIC are supported, perhaps initialize the environment with a seed argument instead?"
-    end
-end
+RLBase.get_players(env::OpenSpielEnv) = get_players(env, ChanceStyle(env))
+RLBase.get_players(env::OpenSpielEnv, ::Any) = 0:(num_players(env.game)-1)
+RLBase.get_players(env::OpenSpielEnv, ::Union{ExplicitStochastic, SampledStochastic}) = (get_chance_player(env), 0:(num_players(env.game)-1)...)
+RLBase.get_num_players(env::OpenSpielEnv) = length(get_players(env))
+
+function RLBase.get_actions(env::OpenSpielEnv, player)
+    if player == get_chance_player(env)
+        reinterpret(ActionProbPair{Int, Float64}, chance_outcomes(env.state))
+    else
+        0:num_distinct_actions(env.game)-1
+    end
+end

-RLBase.get_legal_actions(env::OpenSpielEnv, player) = legal_actions(env.state, player)
+function RLBase.get_legal_actions(env::OpenSpielEnv, player)
+    if player == get_chance_player(env)
+        reinterpret(ActionProbPair{Int, Float64}, chance_outcomes(env.state))
+    else
+        legal_actions(env.state, player)
+    end
+end

 function RLBase.get_legal_actions_mask(env::OpenSpielEnv, player)
-    n = player == convert(Int, OpenSpiel.CHANCE_PLAYER) ? max_chance_outcomes(env.game) :
-        num_distinct_actions(env.game)
+    n = player == get_chance_player(env) ? max_chance_outcomes(env.game) : num_distinct_actions(env.game)
     mask = BitArray(undef, n)
     for a in legal_actions(env.state, player)
         mask[a+1] = true
     end
     mask
 end

+function Random.seed!(env::OpenSpielEnv, seed)
+    if ChanceStyle(env) === STOCHASTIC
+        Random.seed!(env.rng, seed)
+    else
+        @error "only environments of STOCHASTIC are supported, perhaps initialize the environment with a seed argument instead?"
+    end
+end
+
 RLBase.get_terminal(env::OpenSpielEnv, player) = OpenSpiel.is_terminal(env.state)

 function RLBase.get_reward(env::OpenSpielEnv, player)
     if DynamicStyle(env) === SIMULTANEOUS &&
        player == convert(Int, OpenSpiel.SIMULTANEOUS_PLAYER)
         rewards(env.state)
+    elseif player == get_chance_player(env)
+        0 # ??? type stable
     else
         player_reward(env.state, player)
     end
 end

-RLBase.get_state(env::OpenSpielEnv) = env.state
-RLBase.get_state(env::OpenSpielEnv, player::Integer) = env.state
+RLBase.get_state(env::OpenSpielEnv, player::Integer) = get_state(env, DefaultStateStyle(env), player)
+RLBase.get_state(env::OpenSpielEnv, ::RLBase.Information{String}, player) = information_state_string(env.state, player)
+RLBase.get_state(env::OpenSpielEnv, ::RLBase.Information{Array}, player) = information_state_tensor(env.state, player)
+RLBase.get_state(env::OpenSpielEnv, ::Observation{String}, player) = observation_string(env.state, player)
+RLBase.get_state(env::OpenSpielEnv, ::Observation{Array}, player) = observation_tensor(env.state, player)

 RLBase.get_history(env::OpenSpielEnv) = history(env.state)
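The one-line body of `_sample_external_events!` packs in the key trick of this hunk: `chance_outcomes` yields action/probability pairs, `reinterpret` views them as RLBase's `ActionProbPair{Int,Float64}`, and `rand` on that view is expected to draw an outcome proportionally to its probability (presumably the overload added in RLBase 0.8.1, hence the compat bump; not verifiable from this diff alone). A standalone sketch of the equivalent weighted draw, mirroring the deleted `zip`/`sample`/`weights` lines:

```julia
using Random
using StatsBase: sample, weights

# Stand-in for RLBase.ActionProbPair{Int,Float64}:
struct ActionProb
    action::Int
    prob::Float64
end

# Draw one action with probability proportional to `prob`:
draw(rng::AbstractRNG, outcomes::Vector{ActionProb}) =
    outcomes[sample(rng, weights([o.prob for o in outcomes]))].action

rng = MersenneTwister(123)
outcomes = [ActionProb(0, 0.2), ActionProb(1, 0.5), ActionProb(2, 0.3)]
draw(rng, outcomes)   # returns 0, 1, or 2, weighted 0.2 / 0.5 / 0.3
```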

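Likewise, a short usage sketch for the new state accessors (signatures taken from the hunk above; `env` as constructed earlier):

```julia
# Default representation for player 0, routed through DefaultStateStyle(env):
s = get_state(env, 0)

# Bypass the default and choose a representation per call:
s_str    = get_state(env, RLBase.Information{String}(), 0)  # info-set string
s_tensor = get_state(env, RLBase.Observation{Array}(), 0)   # observation tensor
```

Dispatching on a style singleton keeps all four OpenSpiel accessors (`information_state_string`, `information_state_tensor`, `observation_string`, `observation_tensor`) behind the single `get_state` entry point.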
src/environments/structs.jl

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ mutable struct MDPEnv{M,S,A,R} <: AbstractEnv
 end
 export MDPEnv

-mutable struct OpenSpielEnv{S,T,ST,G,R} <: AbstractEnv
+mutable struct OpenSpielEnv{T,ST,G,R} <: AbstractEnv
     state::ST
     game::G
     rng::R

test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ using PyCall
 using POMDPs
 using POMDPModels
 using OpenSpiel
-using SnakeGames
+# using SnakeGames
 using Random

 @testset "ReinforcementLearningEnvironments" begin
