Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Add wrapper for snake game #80

Merged
merged 3 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
SnakeGames = "34dccd9f-48d6-4445-aa0f-8c2e373b5429"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "ArcadeLearningEnvironment", "PyCall", "POMDPModels", "POMDPs", "OpenSpiel"]
test = ["Test", "ArcadeLearningEnvironment", "PyCall", "POMDPModels", "POMDPs", "OpenSpiel", "SnakeGames"]
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ By default, only some basic environments are installed. If you want to use some
| `GymEnv` | [PyCall.jl](https://github.com/JuliaPy/PyCall.jl) | |
| `MDPEnv`,`POMDPEnv`| [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl)| Tested with `[email protected]`|
| `OpenSpielEnv` | [OpenSpiel.jl](https://github.com/JuliaReinforcementLearning/OpenSpiel.jl) | |
| `SnakeGameEnv` | [SnakeGames.jl](https://github.com/JuliaReinforcementLearning/SnakeGames.jl) | `SingleAgent`/`MultiAgent`, `FullActionSet`/`MinimalActionSet`|

## Usage

Expand Down
1 change: 1 addition & 0 deletions src/ReinforcementLearningEnvironments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ function __init__()
@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" include("environments/gym.jl")
@require POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" include("environments/mdp.jl")
@require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("environments/open_spiel.jl")
@require SnakeGames = "34dccd9f-48d6-4445-aa0f-8c2e373b5429" include("environments/snake.jl")
end

end # module
58 changes: 58 additions & 0 deletions src/environments/snake.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
    SnakeGameEnv(; action_style=MINIMAL_ACTION_SET, kw...)

Build a `SnakeGame` from `kw...` and wrap it in a `SnakeGameEnv`.

# Keywords
- `action_style`: stored as the first type parameter of the environment.
- `kw...`: forwarded unchanged to `SnakeGame`.

The number-of-agents style is derived from the number of snakes in the game:
`SINGLE_AGENT` for one snake, `MultiAgent{n}()` otherwise.
"""
function SnakeGameEnv(; action_style=MINIMAL_ACTION_SET, kw...)
    game = SnakeGame(; kw...)
    n_snakes = length(game.snakes)
    num_agent_style = n_snakes == 1 ? SINGLE_AGENT : MultiAgent{n_snakes}()
    return SnakeGameEnv{action_style,num_agent_style,typeof(game)}(
        game,
        map(length, game.snakes),
        # Initialise with the "no previous move" sentinel (the same value `reset!`
        # writes). An `undef` vector would be read by the first step when the
        # turn-back filter compares new actions against the previous ones.
        fill(CartesianIndex(0, 0), n_snakes),
        false,
    )
end

# Trait accessors: the styles are carried as the first two type parameters of the env.
RLBase.ActionStyle(env::SnakeGameEnv{A}) where {A} = A
RLBase.NumAgentStyle(env::SnakeGameEnv{<:Any,N}) where {N} = N

# One snake: the single agent acts step by step.
RLBase.DynamicStyle(env::SnakeGameEnv{<:Any,SINGLE_AGENT}) = SEQUENTIAL
# Kept for environments whose second parameter is a `MultiAgent` *type*.
RLBase.DynamicStyle(env::SnakeGameEnv{<:Any,<:MultiAgent}) = SIMULTANEOUS
# The keyword constructor stores a `MultiAgent{n}()` *instance* as the parameter, and a
# `<:MultiAgent` bound only matches types, so the value has to be inspected instead.
# NOTE(review): any other parameter value falls back to SEQUENTIAL — confirm this is the
# desired default.
function RLBase.DynamicStyle(env::SnakeGameEnv{<:Any,N}) where {N}
    return N isa MultiAgent ? SIMULTANEOUS : SEQUENTIAL
end

# Moves available to a snake, as 2-D index offsets. The position of an entry in this
# tuple (1 to 4) is the integer action id accepted by the integer-action call methods.
# NOTE(review): which board direction each offset corresponds to depends on the board
# layout in SnakeGames — confirm before labelling them up/down/left/right.
const SNAKE_GAME_ACTIONS = map(CartesianIndex, ((-1, 0), (1, 0), (0, 1), (0, -1)))

# Advance the wrapped game by one step, given one move per snake.
#
# With `MINIMAL_ACTION_SET`, a move that is the exact negation of that snake's previous
# move is replaced by the previous move. The (possibly filtered) moves are remembered,
# the snake lengths are snapshotted before the game is stepped, and the termination flag
# is set to the negation of what the game call returns. That flag is also the return
# value.
function (env::SnakeGameEnv{A})(actions::Vector{CartesianIndex{2}}) where {A}
    if A === MINIMAL_ACTION_SET
        # Do not allow a snake to turn straight back.
        actions = map(actions, env.latest_actions) do requested, previous
            requested == -previous ? previous : requested
        end
    end

    env.latest_actions .= actions
    # Record lengths *before* stepping; the reward getters compare against these.
    map!(length, env.latest_snakes_length, env.game.snakes)

    terminated = !env.game(actions)
    env.is_terminated = terminated
    return terminated
end

# Single integer action id: look it up in `SNAKE_GAME_ACTIONS` and step with a
# one-element move vector.
function (env::SnakeGameEnv)(action::Int)
    move = SNAKE_GAME_ACTIONS[action]
    return env([move])
end

# One integer action id per snake: translate each id to its move, then step.
function (env::SnakeGameEnv)(actions::Vector{Int})
    return env([SNAKE_GAME_ACTIONS[id] for id in actions])
end

# Integer action ids accepted by the environment.
function RLBase.get_actions(env::SnakeGameEnv)
    return 1:4
end

# The state is the board of the wrapped game.
function RLBase.get_state(env::SnakeGameEnv)
    return env.game.board
end

# Single-agent reward: current length of the only snake minus the length recorded
# before the latest step.
function RLBase.get_reward(env::SnakeGameEnv{<:Any,SINGLE_AGENT})
    current_length = length(env.game.snakes[])
    previous_length = env.latest_snakes_length[]
    return current_length - previous_length
end

# General reward: element-wise length change, one entry per snake.
function RLBase.get_reward(env::SnakeGameEnv)
    snakes = env.game.snakes
    return length.(snakes) .- env.latest_snakes_length
end

# Termination flag as stored by the most recent step or reset.
function RLBase.get_terminal(env::SnakeGameEnv)
    return env.is_terminated
end

# Single agent, full action set: ids of every move except the exact reverse of the
# previous move.
function RLBase.get_legal_actions(env::SnakeGameEnv{FULL_ACTION_SET,SINGLE_AGENT})
    reverse_move = -env.latest_actions[]
    return findall(move -> move != reverse_move, SNAKE_GAME_ACTIONS)
end

# Full action set, one list of legal ids per snake.
function RLBase.get_legal_actions(env::SnakeGameEnv{FULL_ACTION_SET})
    return map(env.latest_actions) do previous
        findall(!=(-previous), SNAKE_GAME_ACTIONS)
    end
end

# Single agent, full action set: Boolean mask aligned with `SNAKE_GAME_ACTIONS`.
function RLBase.get_legal_actions_mask(env::SnakeGameEnv{FULL_ACTION_SET,SINGLE_AGENT})
    reverse_move = -env.latest_actions[]
    return [move != reverse_move for move in SNAKE_GAME_ACTIONS]
end

# Full action set, one Boolean mask per snake.
function RLBase.get_legal_actions_mask(env::SnakeGameEnv{FULL_ACTION_SET})
    return map(env.latest_actions) do previous
        [move != -previous for move in SNAKE_GAME_ACTIONS]
    end
end

# Reset the wrapped game and the wrapper's bookkeeping: clear the termination flag,
# overwrite every remembered move with `CartesianIndex(0,0)`, and re-record the snake
# lengths from the freshly reset game. Returns the length buffer, as `map!` does.
function RLBase.reset!(env::SnakeGameEnv)
    SnakeGames.reset!(env.game)
    fill!(env.latest_actions, CartesianIndex(0,0))
    env.is_terminated = false
    return map!(length, env.latest_snakes_length, env.game.snakes)
end

# Displaying the environment delegates to displaying the wrapped game.
function Base.display(env::SnakeGameEnv)
    return display(env.game)
end
8 changes: 8 additions & 0 deletions src/environments/structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,11 @@ mutable struct OpenSpielEnv{S,T,ST,G,R} <: AbstractEnv
rng::R
end
export OpenSpielEnv

# Environment wrapper around a snake game object of type `G`.
#
# Type parameters:
# - `A`: action style (returned as-is by `RLBase.ActionStyle` in snake.jl)
# - `N`: number-of-agents style (returned as-is by `RLBase.NumAgentStyle` in snake.jl)
# - `G`: concrete type of the wrapped game
mutable struct SnakeGameEnv{A,N,G} <: AbstractEnv
# The wrapped game; it is called with a vector of moves to advance one step.
game::G
# Length of each snake recorded before the latest step; used to compute rewards.
latest_snakes_length::Vector{Int}
# Last move applied for each snake; `reset!` sets every entry to `CartesianIndex(0,0)`.
latest_actions::Vector{CartesianIndex{2}}
# Set to the negation of the game call's return value after each step.
is_terminated::Bool
end
export SnakeGameEnv
1 change: 1 addition & 0 deletions test/environments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

for env_exp in [
# :(basic_ViZDoom_env()), # comment out due to https://github.com/JuliaReinforcementLearning/ViZDoom.jl/issues/7
# (:(SnakeGameEnv())), # avoid breaking CI
:(POMDPEnv(TigerPOMDP())),
:(MDPEnv(MountainCar())),
:(MountainCarEnv()),
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using PyCall
using POMDPs
using POMDPModels
using OpenSpiel
using SnakeGames
using Random

@testset "ReinforcementLearningEnvironments" begin
Expand Down