Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 82a21f3

Browse files
authored
update Hanabi.jl (#12)
* update Hanabi.jl
* support Julia 1.0
* update README.md
1 parent d5836d8 commit 82a21f3

File tree

4 files changed

+93
-83
lines changed

4 files changed

+93
-83
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ language: julia
33
os:
44
- linux
55
julia:
6+
- 1.0
67
- 1.1
78
- nightly
89
notifications:

README.md

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,16 @@ Install:
1010
(v1.1) pkg> add https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironments.jl
1111
```
1212

13-
**TODO:**
14-
15-
- [x] Add a Docker file for quick test.
16-
```
17-
$ docker run -it --rm juliareinforcementlearning/reinforcementlearningenvironments
18-
```
19-
- [ ] Add benchmarks.
20-
2113
## API
2214

2315
| Method | Description |
2416
| :--- | :--------- |
2517
| `observe(env, observer=:default)` | Return the observation of `env` from the view of `observer`|
2618
| `reset!(env)` | Reset `env` to an initial state|
2719
| `interact!(env, action)` | Send `action` to `env`. For some multi-agent environments, `action` can be a dictionary of actions from different agents|
20+
| **Optional Methods** | |
2821
| `action_space(env)` | Return the action space of `env` |
2922
| `observation_space(env)` | Return the observation space of `env`|
30-
| **Optional Methods** | |
3123
| `render(env)` | Show the current state of environment |
3224

3325
## Supported Environments
@@ -59,7 +51,7 @@ By default, only some basic environments are installed. If you want to use some
5951
| `AtariEnv` | [ArcadeLearningEnvironment.jl](https://github.com/JuliaReinforcementLearning/ArcadeLearningEnvironment.jl) | |
6052
| `ViZDoomEnv` | [ViZDoom.jl](https://github.com/JuliaReinforcementLearning/ViZDoom.jl) | Currently only a basic environment is supported. (By calling `basic_ViZDoom_env()`)|
6153
| `GymEnv` | [PyCall.jl](https://github.com/JuliaPy/PyCall.jl) | You need to manually install `gym` first |
62-
| `HanabiEnv` | [Hanabi.jl](https://github.com/JuliaReinforcementLearning/Hanabi.jl) | `Hanabi.jl` hasn't been registered yet. Install by `pkg> add https://github.com/JuliaReinforcementLearning/Hanabi.jl` |
54+
| `HanabiEnv` | [Hanabi.jl](https://github.com/JuliaReinforcementLearning/Hanabi.jl) | Hanabi is a turn-based multi-player environment; the API is slightly different from the environments above.|
6355

6456
**TODO:**
6557

@@ -72,7 +64,7 @@ Take the `AtariEnv` for example:
7264

7365
1. Install this package by:
7466
```julia
75-
(v1.1) pkg> add https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironments.jl
67+
(v1.1) pkg> add ReinforcementLearningEnvironments
7668
```
7769
2. Install corresponding dependent package by:
7870
```julia

src/environments/hanabi.jl

Lines changed: 89 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using Hanabi
22

3-
export HanabiEnv, legal_actions
3+
export HanabiEnv, legal_actions, observe, reset!, interact!
44
export PlayCard, DiscardCard, RevealColor, RevealRank, parse_move
5+
export cur_player, get_score, get_fireworks, encode_observation, encode_observation!, legal_actions!, legal_actions, get_cur_player
56

67
@enum HANABI_OBSERVATION_ENCODER_TYPE CANONICAL
78
@enum COLOR R Y G W B
@@ -11,43 +12,58 @@ export PlayCard, DiscardCard, RevealColor, RevealRank, parse_move
1112
const CHANCE_PLAYER_ID = -1
1213
const COLORS_DICT = Dict(string(x) => x for x in instances(COLOR))
1314

15+
###
### finalizers
###

# Register a finalizer on each wrapper `Ref` so the object owned by the
# Hanabi C library is released once the Julia-side value becomes
# unreachable.  The corresponding `delete_*` function is passed directly
# instead of being wrapped in an anonymous closure.
move_finalizer(x) = finalizer(delete_move, x)
history_item_finalizer(x) = finalizer(delete_history_item, x)
game_finalizer(x) = finalizer(delete_game, x)
observation_finalizer(x) = finalizer(delete_observation, x)
observation_encoder_finalizer(x) = finalizer(delete_observation_encoder, x)
state_finalizer(x) = finalizer(delete_state, x)
25+
1426
###
1527
### moves
1628
###
1729

1830
"""
    PlayCard(card_idx::Int)

Create a Hanabi move that plays the card at 1-based hand position
`card_idx`.  The returned `Ref` is registered with a finalizer.
"""
function PlayCard(card_idx::Int)
    move = Ref{HanabiMove}()
    get_play_move(card_idx - 1, move)  # the C API uses 0-based positions
    move_finalizer(move)
    return move
end
2336

2437
"""
    DiscardCard(card_idx::Int)

Create a Hanabi move that discards the card at 1-based hand position
`card_idx`.  The returned `Ref` is registered with a finalizer.
"""
function DiscardCard(card_idx::Int)
    move = Ref{HanabiMove}()
    get_discard_move(card_idx - 1, move)  # the C API uses 0-based positions
    move_finalizer(move)
    return move
end
2943

3044
"""
    RevealColor(target_offset::Int, color::COLOR)

Create a hint move revealing `color` in the hand of the player at
`target_offset` (offset relative to the current player, per the Hanabi
C API).  The returned `Ref` is registered with a finalizer.
"""
function RevealColor(target_offset::Int, color::COLOR)
    move = Ref{HanabiMove}()
    get_reveal_color_move(target_offset, color, move)
    move_finalizer(move)
    return move
end
3550

3651
"""
    RevealRank(target_offset::Int, rank::Int)

Create a hint move revealing the 1-based `rank` in the hand of the player
at `target_offset` (offset relative to the current player, per the Hanabi
C API).  The returned `Ref` is registered with a finalizer.
"""
function RevealRank(target_offset::Int, rank::Int)
    move = Ref{HanabiMove}()
    get_reveal_rank_move(target_offset, rank - 1, move)  # C API ranks are 0-based
    move_finalizer(move)
    return move
end
4157

4258
"""
    parse_move(s::String)

Parse a textual move description into the corresponding Hanabi move, or
return `nothing` when `s` matches none of the recognized forms:
`PlayCard(i)`, `DiscardCard(i)`, `RevealColor(t,C)` or `RevealRank(t,r)`.
"""
function parse_move(s::String)
    # Try each recognized pattern in turn; the first match wins.
    m = match(r"PlayCard\((?<card_idx>[1-5])\)", s)
    m === nothing || return PlayCard(parse(Int, m[:card_idx]))
    m = match(r"DiscardCard\((?<card_idx>[1-5])\)", s)
    m === nothing || return DiscardCard(parse(Int, m[:card_idx]))
    m = match(r"RevealColor\((?<target>[1-5]),(?<color>[RYGWB])\)", s)
    m === nothing || return RevealColor(parse(Int, m[:target]), COLORS_DICT[m[:color]])
    m = match(r"RevealRank\((?<target>[1-5]),(?<rank>[1-5])\)", s)
    m === nothing || return RevealRank(parse(Int, m[:target]), parse(Int, m[:rank]))
    return nothing
end
5369

@@ -73,9 +89,7 @@ end
7389

7490
"""
7591
HanabiEnv(;kw...)
76-
7792
Default game params:
78-
7993
random_start_player = false,
8094
seed = -1,
8195
max_life_tokens = 3,
@@ -86,13 +100,12 @@ colors = 5,
86100
observation_type = 1,
87101
players = 2
88102
"""
89-
mutable struct HanabiEnv <: AbstractEnv
103+
mutable struct HanabiEnv
90104
game::Base.RefValue{Hanabi.LibHanabi.PyHanabiGame}
91105
state::Base.RefValue{Hanabi.LibHanabi.PyHanabiState}
92106
moves::Vector{Base.RefValue{Hanabi.LibHanabi.PyHanabiMove}}
93107
observation_encoder::Base.RefValue{Hanabi.LibHanabi.PyHanabiObservationEncoder}
94-
observation_space::MultiDiscreteSpace{Int, 1}
95-
action_space::DiscreteSpace{Int}
108+
observation_length::Int
96109
reward::HanabiResult
97110

98111
function HanabiEnv(;kw...)
@@ -105,29 +118,30 @@ mutable struct HanabiEnv <: AbstractEnv
105118
new_game(game, length(params), params)
106119
end
107120

121+
game_finalizer(game)
122+
108123
state = Ref{HanabiState}()
124+
new_state(game, state)
125+
state_finalizer(state)
109126

110127
observation_encoder = Ref{HanabiObservationEncoder}()
111128
new_observation_encoder(observation_encoder, game, CANONICAL)
129+
observation_encoder_finalizer(observation_encoder)
112130
observation_length = parse(Int, unsafe_string(observation_shape(observation_encoder)))
113-
observation_space = MultiDiscreteSpace(ones(Int, observation_length), zeros(Int, observation_length))
114131

115132
n_moves = max_moves(game)
116-
action_space = DiscreteSpace(Int(n_moves))
117133
moves = [Ref{HanabiMove}() for _ in 1:n_moves]
118134
for i in 1:n_moves
119135
get_move_by_uid(game, i-1, moves[i])
136+
move_finalizer(moves[i])
120137
end
121138

122-
env = new(game, state, moves, observation_encoder, observation_space, action_space, HanabiResult(Int32(0), Int32(0)))
139+
env = new(game, state, moves, observation_encoder, observation_length, HanabiResult(Int32(0), Int32(0)))
123140
reset!(env) # reset immediately
124141
env
125142
end
126143
end
127144

128-
observation_space(env::HanabiEnv) = env.observation_space
129-
action_space(env::HanabiEnv) = env.action_space
130-
131145
# Wrap `x` between two bars of 25 repetitions of `sep` (section-header helper).
function line_sep(x, sep="=")
    bar = repeat(sep, 25)
    return bar * x * bar
end
132146

133147
function Base.show(io::IO, env::HanabiEnv)
@@ -139,11 +153,22 @@ function Base.show(io::IO, env::HanabiEnv)
139153
""")
140154
end
141155

156+
"""
    highlight(s)

Return a copy of `s` where each Hanabi color letter (`R`, `G`, `B`, `Y`,
`W`) is wrapped in the matching ANSI color escape followed by a reset to
the default color.
"""
function highlight(s)
    # The escape sequences contain no uppercase letters, so sequential
    # replacement cannot re-match earlier substitutions.
    for (letter, colorname) in ("R" => :red, "G" => :green, "B" => :blue,
                                "Y" => :yellow, "W" => :white)
        colored = Base.text_colors[colorname] * letter * Base.text_colors[:default]
        s = replace(s, letter => colored)
    end
    return s
end
164+
142165
# Pretty-printers for the wrapped C handles: the game prints its parameter
# string; state and observation print their textual rendering on a new
# line with the card letters colorized via `highlight`.
function Base.show(io::IO, game::Base.RefValue{Hanabi.LibHanabi.PyHanabiGame})
    print(io, unsafe_string(game_param_string(game)))
end

function Base.show(io::IO, state::Base.RefValue{Hanabi.LibHanabi.PyHanabiState})
    print(io, highlight("\n" * unsafe_string(state_to_string(state))))
end

function Base.show(io::IO, obs::Base.RefValue{Hanabi.LibHanabi.PyHanabiObservation})
    print(io, highlight("\n" * unsafe_string(obs_to_string(obs))))
end
145168

146169
function reset!(env::HanabiEnv)
170+
env.state = Ref{HanabiState}()
171+
state_finalizer(env.state)
147172
new_state(env.game, env.state)
148173
while state_cur_player(env.state) == CHANCE_PLAYER_ID
149174
state_deal_random_card(env.state)
@@ -167,27 +192,27 @@ function interact!(env::HanabiEnv, move::Base.RefValue{Hanabi.LibHanabi.PyHanabi
167192
new_score = state_score(env.state)
168193
env.reward.player = player
169194
env.reward.score_gain = new_score - old_score
195+
nothing
196+
end
170197

171-
observation = Ref{HanabiObservation}()
172-
new_observation(env.state, player, observation)
198+
function observe(env::HanabiEnv, observer=state_cur_player(env.state))
199+
raw_obs = Ref{HanabiObservation}()
200+
observation_finalizer(raw_obs)
201+
new_observation(env.state, observer, raw_obs)
173202

174-
(observation = _encode_observation(observation, env),
175-
reward = env.reward.score_gain,
203+
(observation = raw_obs,
204+
reward = env.reward.player == observer ? env.reward.score_gain : Int32(0),
176205
isdone = state_end_of_game_status(env.state) != Int(NOT_FINISHED),
177-
raw_obs = observation)
206+
game = env.game)
178207
end
179208

180-
function observe(env::HanabiEnv, observer=state_cur_player(env.state))
181-
observation = Ref{HanabiObservation}()
182-
new_observation(env.state, observer, observation)
183-
(observation = _encode_observation(observation, env),
184-
reward = env.reward.player == observer ? env.reward.score_gain : Int32(0),
185-
isdone = state_end_of_game_status(env.state) != Int(NOT_FINISHED),
186-
raw_obs = observation)
209+
"""
    encode_observation(obs, env)

Encode the raw observation `obs` into a freshly allocated `Vector{Int32}`
of length `env.observation_length` and return it.
"""
function encode_observation(obs, env)
    buffer = Vector{Int32}(undef, env.observation_length)
    encode_obs(env.observation_encoder, obs, buffer)
    return buffer
end
188214

189-
function _encode_observation(obs, env)
190-
encoding = Vector{Int32}(undef, length(env.observation_space.low))
215+
"""
    encode_observation!(obs, env, encoding)

Encode the raw observation `obs` into the preallocated buffer `encoding`
in place and return `encoding`.
"""
function encode_observation!(obs, env, encoding)
    encode_obs(env.observation_encoder, obs, encoding)
    return encoding
end
@@ -196,6 +221,9 @@ end
196221
### Some Useful APIs
197222
###
198223

224+
# Total score of the current game state.
function get_score(env::HanabiEnv)
    return state_score(env.state)
end

# 0-based id of the player whose turn it is.
function cur_player(env::HanabiEnv)
    return state_cur_player(env.state)
end
226+
199227
function legal_actions(env::HanabiEnv)
200228
actions = Int32[]
201229
for (i, move) in enumerate(env.moves)
@@ -206,44 +234,39 @@ function legal_actions(env::HanabiEnv)
206234
actions
207235
end
208236

209-
function get_card_knowledge(obs)
210-
knowledges = []
211-
for pid in 0:obs_num_players(obs)-1
212-
hand_kd = []
213-
for i in 0:obs_get_hand_size(obs, pid) - 1
214-
kd = Ref{HanabiCardKnowledge}()
215-
obs_get_hand_card_knowledge(obs, pid, i, kd)
216-
push!(
217-
hand_kd,
218-
Dict{String, Any}(
219-
"color" => color_was_hinted(kd) > 0 ? COLOR(known_color(kd)) : nothing,
220-
"rank" => rank_was_hinted(kd) > 0 ? known_rank(kd) : nothing))
221-
end
222-
push!(knowledges, hand_kd)
237+
# Convenience dispatches: a `Bool` vector gets `true`/`false` legality
# flags; a numeric vector gets `zero(T)` for legal moves and `typemin(T)`
# for illegal ones (usable as an additive mask).
legal_actions!(env::HanabiEnv, actions::AbstractVector{Bool}) = legal_actions!(env, actions, true, false)
legal_actions!(env::HanabiEnv, actions::AbstractVector{T}) where T<:Number = legal_actions!(env, actions, zero(T), typemin(T))

"""
    legal_actions!(env::HanabiEnv, actions, legal_value, illegal_value)

Fill `actions` in place so that entry `i` is `legal_value` when move `i`
is legal in the current state of `env` and `illegal_value` otherwise;
return `actions`.
"""
function legal_actions!(env::HanabiEnv, actions, legal_value, illegal_value)
    for idx in eachindex(env.moves)
        is_legal = move_is_legal(env.state, env.moves[idx])
        actions[idx] = is_legal ? legal_value : illegal_value
    end
    return actions
end
226246

227-
function observed_hands(obs)
228-
hands = Vector{HanabiCard}[]
229-
for pid in 0:obs_num_players(obs)-1
230-
cards = HanabiCard[]
231-
for i in 0:obs_get_hand_size(obs, pid)-1
232-
card_ref = Ref{HanabiCard}()
233-
obs_get_hand_card(obs, pid, i, card_ref)
234-
push!(cards, card_ref[])
235-
end
236-
push!(hands, cards)
237-
end
238-
hands
247+
"""
    get_hand_card_knowledge(obs, pid, i)

Return a `Ref` to the hint knowledge about card `i` in player `pid`'s
hand as recorded in observation `obs` (both `pid` and `i` are 0-based,
matching the C API).
"""
function get_hand_card_knowledge(obs, pid, i)
    kd = Ref{HanabiCardKnowledge}()
    obs_get_hand_card_knowledge(obs, pid, i, kd)
    return kd
end
252+
253+
"""
    get_hand_card(obs, pid, i)

Return the card at position `i` of player `pid`'s hand in observation
`obs` (both indices 0-based, matching the C API).  The `Ref` is
dereferenced, so the bare card value is returned.
"""
function get_hand_card(obs, pid, i)
    card = Ref{HanabiCard}()
    obs_get_hand_card(obs, pid, i, card)
    return card[]
end
240258

241-
function discard_pile(obs)
242-
cards = HanabiCard[]
243-
for i in 0:obs_discard_pile_size(obs)-1
244-
card_ref = Ref{HanabiCard}()
245-
obs_get_discard(obs, i, card_ref)
246-
push!(cards, card_ref[])
259+
# 1-based rank accessors.  For card knowledge the rank is only known once
# it has been hinted; otherwise `nothing` is returned.
function rank(knowledge::Base.RefValue{Hanabi.LibHanabi.PyHanabiCardKnowledge})
    return rank_was_hinted(knowledge) != 0 ? known_rank(knowledge) + 1 : nothing
end
rank(card::Hanabi.LibHanabi.PyHanabiCard) = card.rank + 1

# Color accessors, mirroring `rank`: `nothing` until the color is hinted.
function color(knowledge::Base.RefValue{Hanabi.LibHanabi.PyHanabiCardKnowledge})
    return color_was_hinted(knowledge) != 0 ? COLOR(known_color(knowledge)) : nothing
end
color(card::Hanabi.LibHanabi.PyHanabiCard) = COLOR(card.color)
263+
264+
"""
    get_fireworks(game, observation)

Return a `Dict{COLOR, Int}` mapping every color of `game` to its firework
level as read from `observation`, shifted to be 1-based (a value of 1
means no card of that color has been played).
"""
function get_fireworks(game, observation)
    fireworks = Dict{COLOR, Int}()
    for color_id in 0:(num_colors(game) - 1)
        fireworks[COLOR(color_id)] = obs_fireworks(observation, color_id) + 1
    end
    return fireworks
end
271+
272+
# 1-based index of the player to act; `cur_player` returns the raw
# 0-based id used by the C API.
get_cur_player(env) = 1 + cur_player(env)

test/environments.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,11 @@
2020

2121
function basic_env_test(env::HanabiEnv, n=100)
2222
reset!(env)
23-
os = observation_space(env)
24-
as = action_space(env)
25-
@test os isa AbstractSpace
26-
@test as isa AbstractSpace
2723
@test reset!(env) == nothing
2824
for _ in 1:n
2925
a = rand(legal_actions(env))
30-
@test a in as
3126
interact!(env, a)
3227
obs, reward, isdone = observe(env)
33-
@test obs in os
3428
if isdone
3529
reset!(env)
3630
end

0 commit comments

Comments (0)