@@ -12,6 +12,7 @@ import .OpenSpiel:
12
12
information,
13
13
information_state_tensor,
14
14
information_state_tensor_size,
15
+ information_state_tensor_shape,
15
16
information_state_string,
16
17
num_distinct_actions,
17
18
num_players,
@@ -22,6 +23,7 @@ import .OpenSpiel:
22
23
rewards,
23
24
reward_model,
24
25
observation_tensor_size,
26
+ observation_tensor_shape,
25
27
observation_tensor,
26
28
observation_string,
27
29
chance_mode,
@@ -42,7 +44,7 @@ import .OpenSpiel:
42
44
`True` or `False` (instead of `true` or `false`). Another approach is to just
43
45
specify parameters in `kwargs` in the Julia style.
44
46
"""
45
- function OpenSpielEnv (name = " kuhn_poker" ; kwargs... )
47
+ function OpenSpielEnv (name= " kuhn_poker" ; kwargs... )
46
48
game = load_game (String (name); kwargs... )
47
49
state = new_initial_state (game)
48
50
OpenSpielEnv (state, game)
@@ -58,7 +60,7 @@ RLBase.current_player(env::OpenSpielEnv) = OpenSpiel.current_player(env.state)
58
60
RLBase. chance_player (env:: OpenSpielEnv ) = convert (Int, OpenSpiel. CHANCE_PLAYER)
59
61
60
62
function RLBase. players (env:: OpenSpielEnv )
61
- p = 0 : (num_players (env. game)- 1 )
63
+ p = 0 : (num_players (env. game) - 1 )
62
64
if ChanceStyle (env) === EXPLICIT_STOCHASTIC
63
65
(p... , RLBase. chance_player (env))
64
66
else
@@ -89,7 +91,7 @@ function RLBase.prob(env::OpenSpielEnv, player)
89
91
# @assert player == chance_player(env)
90
92
p = zeros (length (action_space (env)))
91
93
for (k, v) in chance_outcomes (env. state)
92
- p[k+ 1 ] = v
94
+ p[k + 1 ] = v
93
95
end
94
96
p
95
97
end
@@ -100,7 +102,7 @@ function RLBase.legal_action_space_mask(env::OpenSpielEnv, player)
100
102
num_distinct_actions (env. game)
101
103
mask = BitArray (undef, n)
102
104
for a in legal_actions (env. state, player)
103
- mask[a+ 1 ] = true
105
+ mask[a + 1 ] = true
104
106
end
105
107
mask
106
108
end
@@ -135,23 +137,28 @@ end
135
137
_state (env:: OpenSpielEnv , :: RLBase.InformationSet{String} , player) =
136
138
information_state_string (env. state, player)
137
139
_state (env:: OpenSpielEnv , :: RLBase.InformationSet{Array} , player) =
138
- information_state_tensor (env. state, player)
140
+ reshape ( information_state_tensor (env. state, player), reverse ( information_state_tensor_shape (env . game)) ... )
139
141
_state (env:: OpenSpielEnv , :: Observation{String} , player) =
140
142
observation_string (env. state, player)
141
143
_state (env:: OpenSpielEnv , :: Observation{Array} , player) =
142
- observation_tensor (env. state, player)
144
+ reshape ( observation_tensor (env. state, player), reverse ( observation_tensor_shape (env . game)) ... )
143
145
144
146
RLBase. state_space (
145
147
env:: OpenSpielEnv ,
146
148
:: Union{InformationSet{String},Observation{String}} ,
147
149
p,
148
150
) = WorldSpace {AbstractString} ()
149
- RLBase. state_space (
150
- env:: OpenSpielEnv ,
151
- :: Union{InformationSet{Array},Observation{Array}} ,
151
+
152
+ RLBase. state_space (env:: OpenSpielEnv , :: InformationSet{Array} ,
153
+ p,
154
+ ) = Space (
155
+ fill (typemin (Float64).. typemax (Float64), reverse (information_state_tensor_shape (env. game))... ),
156
+ )
157
+
158
+ RLBase. state_space (env:: OpenSpielEnv , :: Observation{Array} ,
152
159
p,
153
160
) = Space (
154
- fill (typemin (Float64).. typemax (Float64), information_state_tensor_size ( env. state) ),
161
+ fill (typemin (Float64).. typemax (Float64), reverse ( observation_tensor_shape ( env. game)) ... ),
155
162
)
156
163
157
164
Random. seed! (env:: OpenSpielEnv , s) = @warn " seed!(OpenSpielEnv) is not supported currently."
@@ -192,9 +199,7 @@ RLBase.RewardStyle(env::OpenSpielEnv) =
192
199
reward_model (get_type (env. game)) == OpenSpiel. REWARDS ? RLBase. STEP_REWARD :
193
200
RLBase. TERMINAL_REWARD
194
201
195
- RLBase. StateStyle (env:: OpenSpielEnv ) = (
196
- RLBase. InformationSet {String} (),
202
+ RLBase. StateStyle (env:: OpenSpielEnv ) = (RLBase. InformationSet {String} (),
197
203
RLBase. InformationSet {Array} (),
198
204
RLBase. Observation {String} (),
199
- RLBase. Observation {Array} (),
200
- )
205
+ RLBase. Observation {Array} (),)
0 commit comments