Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 05a6e22

Browse files
authored
Add wrapper for snake game (#80)
* add snake game * update readme * ignore SnakeGameEnv in test
1 parent a9b4e01 commit 05a6e22

File tree

7 files changed

+72
-1
lines changed

7 files changed

+72
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
2525
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
2626
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
2727
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
28+
SnakeGames = "34dccd9f-48d6-4445-aa0f-8c2e373b5429"
2829
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2930

3031
[targets]
31-
test = ["Test", "ArcadeLearningEnvironment", "PyCall", "POMDPModels", "POMDPs", "OpenSpiel"]
32+
test = ["Test", "ArcadeLearningEnvironment", "PyCall", "POMDPModels", "POMDPs", "OpenSpiel", "SnakeGames"]

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ By default, only some basic environments are installed. If you want to use some
3636
| `GymEnv` | [PyCall.jl](https://github.com/JuliaPy/PyCall.jl) | |
3737
| `MDPEnv`,`POMDPEnv`| [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl)| Tested with `[email protected]`|
3838
| `OpenSpielEnv` | [OpenSpiel.jl](https://github.com/JuliaReinforcementLearning/OpenSpiel.jl) | |
39+
| `SnakeGameEnv` | [SnakeGames.jl](https://github.com/JuliaReinforcementLearning/SnakeGames.jl) | `SingleAgent`/`Multi-Agent`, `FullActionSet`/`MinimalActionSet`|
3940

4041
## Usage
4142

src/ReinforcementLearningEnvironments.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ function __init__()
2020
@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" include("environments/gym.jl")
2121
@require POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" include("environments/mdp.jl")
2222
@require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("environments/open_spiel.jl")
23+
@require SnakeGames = "34dccd9f-48d6-4445-aa0f-8c2e373b5429" include("environments/snake.jl")
2324
end
2425

2526
end # module

src/environments/snake.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
function SnakeGameEnv(;action_style=MINIMAL_ACTION_SET,kw...)
2+
game = SnakeGame(;kw...)
3+
n_snakes = length(game.snakes)
4+
num_agent_style = n_snakes == 1 ? SINGLE_AGENT : MultiAgent{n_snakes}()
5+
SnakeGameEnv{action_style, num_agent_style, typeof(game)}(
6+
game,
7+
map(length, game.snakes),
8+
Vector{CartesianIndex{2}}(undef, length(game.snakes)),
9+
false
10+
)
11+
end
12+
13+
RLBase.ActionStyle(env::SnakeGameEnv{A}) where A = A
14+
RLBase.NumAgentStyle(env::SnakeGameEnv{<:Any, N}) where {N} = N
15+
RLBase.DynamicStyle(env::SnakeGameEnv{<:Any, SINGLE_AGENT}) = SEQUENTIAL
16+
RLBase.DynamicStyle(env::SnakeGameEnv{<:Any, <:MultiAgent}) = SIMULTANEOUS
17+
18+
const SNAKE_GAME_ACTIONS = (
19+
CartesianIndex(-1, 0),
20+
CartesianIndex(1, 0),
21+
CartesianIndex(0, 1),
22+
CartesianIndex(0, -1)
23+
)
24+
25+
function (env::SnakeGameEnv{A})(actions::Vector{CartesianIndex{2}}) where {A}
26+
if A === MINIMAL_ACTION_SET
27+
# avoid turn back
28+
actions = [a_new == -a_old ? a_old : a_new for (a_new, a_old) in zip(actions, env.latest_actions)]
29+
end
30+
31+
env.latest_actions .= actions
32+
map!(length, env.latest_snakes_length, env.game.snakes)
33+
env.is_terminated = !env.game(actions)
34+
end
35+
36+
(env::SnakeGameEnv)(action::Int) = env([SNAKE_GAME_ACTIONS[action]])
37+
(env::SnakeGameEnv)(actions::Vector{Int}) = env(map(a -> SNAKE_GAME_ACTIONS[a], actions))
38+
39+
RLBase.get_actions(env::SnakeGameEnv) = 1:4
40+
RLBase.get_state(env::SnakeGameEnv) = env.game.board
41+
RLBase.get_reward(env::SnakeGameEnv{<:Any, SINGLE_AGENT}) = length(env.game.snakes[]) - env.latest_snakes_length[]
42+
RLBase.get_reward(env::SnakeGameEnv) = length.(env.game.snakes) .- env.latest_snakes_length
43+
RLBase.get_terminal(env::SnakeGameEnv) = env.is_terminated
44+
45+
RLBase.get_legal_actions(env::SnakeGameEnv{FULL_ACTION_SET, SINGLE_AGENT}) = findall(!=(-env.latest_actions[]), SNAKE_GAME_ACTIONS)
46+
RLBase.get_legal_actions(env::SnakeGameEnv{FULL_ACTION_SET}) = [findall(!=(-a), SNAKE_GAME_ACTIONS) for a in env.latest_actions]
47+
48+
RLBase.get_legal_actions_mask(env::SnakeGameEnv{FULL_ACTION_SET, SINGLE_AGENT}) = [a!=-env.latest_actions[] for a in SNAKE_GAME_ACTIONS]
49+
RLBase.get_legal_actions_mask(env::SnakeGameEnv{FULL_ACTION_SET}) = [[x!=-a for x in SNAKE_GAME_ACTIONS] for a in env.latest_actions]
50+
51+
function RLBase.reset!(env::SnakeGameEnv)
52+
SnakeGames.reset!(env.game)
53+
env.is_terminated = false
54+
fill!(env.latest_actions, CartesianIndex(0,0))
55+
map!(length, env.latest_snakes_length, env.game.snakes)
56+
end
57+
58+
Base.display(env::SnakeGameEnv) = display(env.game)

src/environments/structs.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,11 @@ mutable struct OpenSpielEnv{S,T,ST,G,R} <: AbstractEnv
4444
rng::R
4545
end
4646
export OpenSpielEnv
47+
48+
mutable struct SnakeGameEnv{A,N,G} <: AbstractEnv
49+
game::G
50+
latest_snakes_length::Vector{Int}
51+
latest_actions::Vector{CartesianIndex{2}}
52+
is_terminated::Bool
53+
end
54+
export SnakeGameEnv

test/environments.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
for env_exp in [
3131
# :(basic_ViZDoom_env()), # comment out due to https://github.com/JuliaReinforcementLearning/ViZDoom.jl/issues/7
32+
# (:(SnakeGameEnv())), # avoid breaking CI
3233
:(POMDPEnv(TigerPOMDP())),
3334
:(MDPEnv(MountainCar())),
3435
:(MountainCarEnv()),

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using PyCall
66
using POMDPs
77
using POMDPModels
88
using OpenSpiel
9+
using SnakeGames
910
using Random
1011

1112
@testset "ReinforcementLearningEnvironments" begin

0 commit comments

Comments
 (0)