1
1
using Hanabi
2
2
3
- export HanabiEnv, legal_actions
3
+ export HanabiEnv, legal_actions, observe, reset!, interact!
4
4
export PlayCard, DiscardCard, RevealColor, RevealRank, parse_move
5
+ export cur_player, get_score, get_fireworks, encode_observation, encode_observation!, legal_actions!, legal_actions, get_cur_player
5
6
6
7
@enum HANABI_OBSERVATION_ENCODER_TYPE CANONICAL
7
8
@enum COLOR R Y G W B
@@ -11,43 +12,58 @@ export PlayCard, DiscardCard, RevealColor, RevealRank, parse_move
11
12
const CHANCE_PLAYER_ID = - 1
12
13
const COLORS_DICT = Dict (string (x) => x for x in instances (COLOR))
13
14
15
+ # ##
16
+ # ## finalizers
17
+ # ##
18
+
19
+ move_finalizer (x) = finalizer (m -> delete_move (m), x)
20
+ history_item_finalizer (x) = finalizer (h -> delete_history_item (h), x)
21
+ game_finalizer (x) = finalizer (g -> delete_game (g), x)
22
+ observation_finalizer (x) = finalizer (o -> delete_observation (o), x)
23
+ observation_encoder_finalizer (x) = finalizer (e -> delete_observation_encoder (e), x)
24
+ state_finalizer (x) = finalizer (s -> delete_state (s), x)
25
+
14
26
# ##
15
27
# ## moves
16
28
# ##
17
29
18
30
function PlayCard (card_idx:: Int )
19
31
m = Ref {HanabiMove} ()
20
32
get_play_move (card_idx - 1 , m)
33
+ move_finalizer (m)
21
34
m
22
35
end
23
36
24
37
function DiscardCard (card_idx:: Int )
25
38
m = Ref {HanabiMove} ()
26
39
get_discard_move (card_idx - 1 , m)
40
+ move_finalizer (m)
27
41
m
28
42
end
29
43
30
44
function RevealColor (target_offset:: Int , color:: COLOR )
31
45
m = Ref {HanabiMove} ()
32
46
get_reveal_color_move (target_offset, color, m)
47
+ move_finalizer (m)
33
48
m
34
49
end
35
50
36
51
function RevealRank (target_offset:: Int , rank:: Int )
37
52
m = Ref {HanabiMove} ()
38
53
get_reveal_rank_move (target_offset, rank - 1 , m)
54
+ move_finalizer (m)
39
55
m
40
56
end
41
57
42
58
function parse_move (s:: String )
43
59
m = match (r" PlayCard\( (?<card_idx>[1-5])\) " , s)
44
- ! isnothing (m ) && return PlayCard (parse (Int, m[:card_idx ]))
60
+ ! (m === nothing ) && return PlayCard (parse (Int, m[:card_idx ]))
45
61
m = match (r" DiscardCard\( (?<card_idx>[1-5])\) " , s)
46
- ! isnothing (m ) && return DiscardCard (parse (Int, m[:card_idx ]))
62
+ ! (m === nothing ) && return DiscardCard (parse (Int, m[:card_idx ]))
47
63
m = match (r" RevealColor\( (?<target>[1-5]),(?<color>[RYGWB])\) " , s)
48
- ! isnothing (m ) && return RevealColor (parse (Int, m[:target ]), COLORS_DICT[m[:color ]])
64
+ ! (m === nothing ) && return RevealColor (parse (Int, m[:target ]), COLORS_DICT[m[:color ]])
49
65
m = match (r" RevealRank\( (?<target>[1-5]),(?<rank>[1-5])\) " , s)
50
- ! isnothing (m ) && return RevealRank (parse (Int, m[:target ]), parse (Int, m[:rank ]))
66
+ ! (m === nothing ) && return RevealRank (parse (Int, m[:target ]), parse (Int, m[:rank ]))
51
67
return nothing
52
68
end
53
69
73
89
74
90
"""
75
91
HanabiEnv(;kw...)
76
-
77
92
Default game params:
78
-
79
93
random_start_player = false,
80
94
seed = -1,
81
95
max_life_tokens = 3,
@@ -86,13 +100,12 @@ colors = 5,
86
100
observation_type = 1,
87
101
players = 2
88
102
"""
89
- mutable struct HanabiEnv <: AbstractEnv
103
+ mutable struct HanabiEnv
90
104
game:: Base.RefValue{Hanabi.LibHanabi.PyHanabiGame}
91
105
state:: Base.RefValue{Hanabi.LibHanabi.PyHanabiState}
92
106
moves:: Vector{Base.RefValue{Hanabi.LibHanabi.PyHanabiMove}}
93
107
observation_encoder:: Base.RefValue{Hanabi.LibHanabi.PyHanabiObservationEncoder}
94
- observation_space:: MultiDiscreteSpace{Int, 1}
95
- action_space:: DiscreteSpace{Int}
108
+ observation_length:: Int
96
109
reward:: HanabiResult
97
110
98
111
function HanabiEnv (;kw... )
@@ -105,29 +118,30 @@ mutable struct HanabiEnv <: AbstractEnv
105
118
new_game (game, length (params), params)
106
119
end
107
120
121
+ game_finalizer (game)
122
+
108
123
state = Ref {HanabiState} ()
124
+ new_state (game, state)
125
+ state_finalizer (state)
109
126
110
127
observation_encoder = Ref {HanabiObservationEncoder} ()
111
128
new_observation_encoder (observation_encoder, game, CANONICAL)
129
+ observation_encoder_finalizer (observation_encoder)
112
130
observation_length = parse (Int, unsafe_string (observation_shape (observation_encoder)))
113
- observation_space = MultiDiscreteSpace (ones (Int, observation_length), zeros (Int, observation_length))
114
131
115
132
n_moves = max_moves (game)
116
- action_space = DiscreteSpace (Int (n_moves))
117
133
moves = [Ref {HanabiMove} () for _ in 1 : n_moves]
118
134
for i in 1 : n_moves
119
135
get_move_by_uid (game, i- 1 , moves[i])
136
+ move_finalizer (moves[i])
120
137
end
121
138
122
- env = new (game, state, moves, observation_encoder, observation_space, action_space , HanabiResult (Int32 (0 ), Int32 (0 )))
139
+ env = new (game, state, moves, observation_encoder, observation_length , HanabiResult (Int32 (0 ), Int32 (0 )))
123
140
reset! (env) # reset immediately
124
141
env
125
142
end
126
143
end
127
144
128
- observation_space (env:: HanabiEnv ) = env. observation_space
129
- action_space (env:: HanabiEnv ) = env. action_space
130
-
131
145
line_sep (x, sep= " =" ) = repeat (sep, 25 ) * x * repeat (sep, 25 )
132
146
133
147
function Base. show (io:: IO , env:: HanabiEnv )
@@ -139,11 +153,22 @@ function Base.show(io::IO, env::HanabiEnv)
139
153
""" )
140
154
end
141
155
156
+ function highlight (s)
157
+ s = replace (s, " R" => Base. text_colors[:red ] * " R" * Base. text_colors[:default ])
158
+ s = replace (s, " G" => Base. text_colors[:green ] * " G" * Base. text_colors[:default ])
159
+ s = replace (s, " B" => Base. text_colors[:blue ] * " B" * Base. text_colors[:default ])
160
+ s = replace (s, " Y" => Base. text_colors[:yellow ] * " Y" * Base. text_colors[:default ])
161
+ s = replace (s, " W" => Base. text_colors[:white ] * " W" * Base. text_colors[:default ])
162
+ s
163
+ end
164
+
142
165
Base. show (io:: IO , game:: Base.RefValue{Hanabi.LibHanabi.PyHanabiGame} ) = print (io, unsafe_string (game_param_string (game)))
143
- Base. show (io:: IO , state:: Base.RefValue{Hanabi.LibHanabi.PyHanabiState} ) = print (io, unsafe_string (state_to_string (state)))
144
- Base. show (io:: IO , obs:: Base.RefValue{Hanabi.LibHanabi.PyHanabiObservation} ) = print (io, unsafe_string (obs_to_string (obs)))
166
+ Base. show (io:: IO , state:: Base.RefValue{Hanabi.LibHanabi.PyHanabiState} ) = print (io, highlight ( " \n " * unsafe_string (state_to_string (state) )))
167
+ Base. show (io:: IO , obs:: Base.RefValue{Hanabi.LibHanabi.PyHanabiObservation} ) = print (io, highlight ( " \n " * unsafe_string (obs_to_string (obs) )))
145
168
146
169
function reset! (env:: HanabiEnv )
170
+ env. state = Ref {HanabiState} ()
171
+ state_finalizer (env. state)
147
172
new_state (env. game, env. state)
148
173
while state_cur_player (env. state) == CHANCE_PLAYER_ID
149
174
state_deal_random_card (env. state)
@@ -167,27 +192,27 @@ function interact!(env::HanabiEnv, move::Base.RefValue{Hanabi.LibHanabi.PyHanabi
167
192
new_score = state_score (env. state)
168
193
env. reward. player = player
169
194
env. reward. score_gain = new_score - old_score
195
+ nothing
196
+ end
170
197
171
- observation = Ref {HanabiObservation} ()
172
- new_observation (env. state, player, observation)
198
+ function observe (env:: HanabiEnv , observer= state_cur_player (env. state))
199
+ raw_obs = Ref {HanabiObservation} ()
200
+ observation_finalizer (raw_obs)
201
+ new_observation (env. state, observer, raw_obs)
173
202
174
- (observation = _encode_observation (observation, env) ,
175
- reward = env. reward. score_gain,
203
+ (observation = raw_obs ,
204
+ reward = env. reward. player == observer ? env . reward . score_gain : Int32 ( 0 ) ,
176
205
isdone = state_end_of_game_status (env. state) != Int (NOT_FINISHED),
177
- raw_obs = observation )
206
+ game = env . game )
178
207
end
179
208
180
- function observe (env:: HanabiEnv , observer= state_cur_player (env. state))
181
- observation = Ref {HanabiObservation} ()
182
- new_observation (env. state, observer, observation)
183
- (observation = _encode_observation (observation, env),
184
- reward = env. reward. player == observer ? env. reward. score_gain : Int32 (0 ),
185
- isdone = state_end_of_game_status (env. state) != Int (NOT_FINISHED),
186
- raw_obs = observation)
209
+ function encode_observation (obs, env)
210
+ encoding = Vector {Int32} (undef, env. observation_length)
211
+ encode_obs (env. observation_encoder, obs, encoding)
212
+ encoding
187
213
end
188
214
189
- function _encode_observation (obs, env)
190
- encoding = Vector {Int32} (undef, length (env. observation_space. low))
215
+ function encode_observation! (obs, env, encoding)
191
216
encode_obs (env. observation_encoder, obs, encoding)
192
217
encoding
193
218
end
196
221
# ## Some Useful APIs
197
222
# ##
198
223
224
+ get_score (env:: HanabiEnv ) = state_score (env. state)
225
+ cur_player (env:: HanabiEnv ) = state_cur_player (env. state)
226
+
199
227
function legal_actions (env:: HanabiEnv )
200
228
actions = Int32[]
201
229
for (i, move) in enumerate (env. moves)
@@ -206,44 +234,39 @@ function legal_actions(env::HanabiEnv)
206
234
actions
207
235
end
208
236
209
- function get_card_knowledge (obs)
210
- knowledges = []
211
- for pid in 0 : obs_num_players (obs)- 1
212
- hand_kd = []
213
- for i in 0 : obs_get_hand_size (obs, pid) - 1
214
- kd = Ref {HanabiCardKnowledge} ()
215
- obs_get_hand_card_knowledge (obs, pid, i, kd)
216
- push! (
217
- hand_kd,
218
- Dict {String, Any} (
219
- " color" => color_was_hinted (kd) > 0 ? COLOR (known_color (kd)) : nothing ,
220
- " rank" => rank_was_hinted (kd) > 0 ? known_rank (kd) : nothing ))
221
- end
222
- push! (knowledges, hand_kd)
237
+ legal_actions! (env:: HanabiEnv , actions:: AbstractVector{Bool} ) = legal_actions! (env, actions, true , false )
238
+ legal_actions! (env:: HanabiEnv , actions:: AbstractVector{T} ) where T<: Number = legal_actions! (env, actions, zero (T), typemin (T))
239
+
240
+ function legal_actions! (env:: HanabiEnv , actions, legal_value, illegal_value)
241
+ for (i, move) in enumerate (env. moves)
242
+ actions[i] = move_is_legal (env. state, move) ? legal_value : illegal_value
223
243
end
224
- knowledges
244
+ actions
225
245
end
226
246
227
- function observed_hands (obs)
228
- hands = Vector{HanabiCard}[]
229
- for pid in 0 : obs_num_players (obs)- 1
230
- cards = HanabiCard[]
231
- for i in 0 : obs_get_hand_size (obs, pid)- 1
232
- card_ref = Ref {HanabiCard} ()
233
- obs_get_hand_card (obs, pid, i, card_ref)
234
- push! (cards, card_ref[])
235
- end
236
- push! (hands, cards)
237
- end
238
- hands
247
+ function get_hand_card_knowledge (obs, pid, i)
248
+ knowledge = Ref {HanabiCardKnowledge} ()
249
+ obs_get_hand_card_knowledge (obs, pid, i, knowledge)
250
+ knowledge
251
+ end
252
+
253
+ function get_hand_card (obs, pid, i)
254
+ card_ref = Ref {HanabiCard} ()
255
+ obs_get_hand_card (obs, pid, i, card_ref)
256
+ card_ref[]
239
257
end
240
258
241
- function discard_pile (obs)
242
- cards = HanabiCard[]
243
- for i in 0 : obs_discard_pile_size (obs)- 1
244
- card_ref = Ref {HanabiCard} ()
245
- obs_get_discard (obs, i, card_ref)
246
- push! (cards, card_ref[])
259
+ rank (knowledge:: Base.RefValue{Hanabi.LibHanabi.PyHanabiCardKnowledge} ) = rank_was_hinted (knowledge) != 0 ? known_rank (knowledge) + 1 : nothing
260
+ rank (card:: Hanabi.LibHanabi.PyHanabiCard ) = card. rank + 1
261
+ color (knowledge:: Base.RefValue{Hanabi.LibHanabi.PyHanabiCardKnowledge} ) = color_was_hinted (knowledge) != 0 ? COLOR (known_color (knowledge)) : nothing
262
+ color (card:: Hanabi.LibHanabi.PyHanabiCard ) = COLOR (card. color)
263
+
264
+ function get_fireworks (game, observation)
265
+ fireworks = Dict {COLOR, Int} ()
266
+ for c in 0 : (num_colors (game) - 1 )
267
+ fireworks[COLOR (c)] = obs_fireworks (observation, c) + 1
247
268
end
248
- cards
269
+ fireworks
249
270
end
271
+
272
+ get_cur_player (env) = cur_player (env) + 1 # pid is 0-based
0 commit comments