Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 79cb5d7

Browse files
authored
fix state_size (#132)
1 parent 3b32dae commit 79cb5d7

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

src/environments/3rd_party/open_spiel.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import .OpenSpiel:
1212
information,
1313
information_state_tensor,
1414
information_state_tensor_size,
15+
information_state_tensor_shape,
1516
information_state_string,
1617
num_distinct_actions,
1718
num_players,
@@ -22,6 +23,7 @@ import .OpenSpiel:
2223
rewards,
2324
reward_model,
2425
observation_tensor_size,
26+
observation_tensor_shape,
2527
observation_tensor,
2628
observation_string,
2729
chance_mode,
@@ -42,7 +44,7 @@ import .OpenSpiel:
4244
`True` or `False` (instead of `true` or `false`). Another approach is to just
4345
specify parameters in `kwargs` in the Julia style.
4446
"""
45-
function OpenSpielEnv(name = "kuhn_poker"; kwargs...)
47+
function OpenSpielEnv(name="kuhn_poker"; kwargs...)
4648
game = load_game(String(name); kwargs...)
4749
state = new_initial_state(game)
4850
OpenSpielEnv(state, game)
@@ -58,7 +60,7 @@ RLBase.current_player(env::OpenSpielEnv) = OpenSpiel.current_player(env.state)
5860
RLBase.chance_player(env::OpenSpielEnv) = convert(Int, OpenSpiel.CHANCE_PLAYER)
5961

6062
function RLBase.players(env::OpenSpielEnv)
61-
p = 0:(num_players(env.game)-1)
63+
p = 0:(num_players(env.game) - 1)
6264
if ChanceStyle(env) === EXPLICIT_STOCHASTIC
6365
(p..., RLBase.chance_player(env))
6466
else
@@ -89,7 +91,7 @@ function RLBase.prob(env::OpenSpielEnv, player)
8991
# @assert player == chance_player(env)
9092
p = zeros(length(action_space(env)))
9193
for (k, v) in chance_outcomes(env.state)
92-
p[k+1] = v
94+
p[k + 1] = v
9395
end
9496
p
9597
end
@@ -100,7 +102,7 @@ function RLBase.legal_action_space_mask(env::OpenSpielEnv, player)
100102
num_distinct_actions(env.game)
101103
mask = BitArray(undef, n)
102104
for a in legal_actions(env.state, player)
103-
mask[a+1] = true
105+
mask[a + 1] = true
104106
end
105107
mask
106108
end
@@ -135,23 +137,28 @@ end
135137
_state(env::OpenSpielEnv, ::RLBase.InformationSet{String}, player) =
136138
information_state_string(env.state, player)
137139
_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) =
138-
information_state_tensor(env.state, player)
140+
reshape(information_state_tensor(env.state, player), reverse(information_state_tensor_shape(env.game))...)
139141
_state(env::OpenSpielEnv, ::Observation{String}, player) =
140142
observation_string(env.state, player)
141143
_state(env::OpenSpielEnv, ::Observation{Array}, player) =
142-
observation_tensor(env.state, player)
144+
reshape(observation_tensor(env.state, player), reverse(observation_tensor_shape(env.game))...)
143145

144146
RLBase.state_space(
145147
env::OpenSpielEnv,
146148
::Union{InformationSet{String},Observation{String}},
147149
p,
148150
) = WorldSpace{AbstractString}()
149-
RLBase.state_space(
150-
env::OpenSpielEnv,
151-
::Union{InformationSet{Array},Observation{Array}},
151+
152+
RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array},
153+
p,
154+
) = Space(
155+
fill(typemin(Float64)..typemax(Float64), reverse(information_state_tensor_shape(env.game))...),
156+
)
157+
158+
RLBase.state_space(env::OpenSpielEnv, ::Observation{Array},
152159
p,
153160
) = Space(
154-
fill(typemin(Float64)..typemax(Float64), information_state_tensor_size(env.state)),
161+
fill(typemin(Float64)..typemax(Float64), reverse(observation_tensor_shape(env.game))...),
155162
)
156163

157164
Random.seed!(env::OpenSpielEnv, s) = @warn "seed!(OpenSpielEnv) is not supported currently."
@@ -192,9 +199,7 @@ RLBase.RewardStyle(env::OpenSpielEnv) =
192199
reward_model(get_type(env.game)) == OpenSpiel.REWARDS ? RLBase.STEP_REWARD :
193200
RLBase.TERMINAL_REWARD
194201

195-
RLBase.StateStyle(env::OpenSpielEnv) = (
196-
RLBase.InformationSet{String}(),
202+
RLBase.StateStyle(env::OpenSpielEnv) = (RLBase.InformationSet{String}(),
197203
RLBase.InformationSet{Array}(),
198204
RLBase.Observation{String}(),
199-
RLBase.Observation{Array}(),
200-
)
205+
RLBase.Observation{Array}(),)

0 commit comments

Comments
 (0)