Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 1eb14f3

Browse files
authored
add test for SnakeGames (#106)
1 parent 0413be4 commit 1eb14f3

File tree

7 files changed

+26
-12
lines changed

7 files changed

+26
-12
lines changed

src/environments/3rd_party/snake.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using .SnakeGames
2+
13
function SnakeGameEnv(; action_style = MINIMAL_ACTION_SET, kw...)
24
game = SnakeGame(; kw...)
35
n_snakes = length(game.snakes)
@@ -41,19 +43,20 @@ end
4143

4244
RLBase.action_space(env::SnakeGameEnv) = 1:4
4345
RLBase.state(env::SnakeGameEnv) = env.game.board
46+
RLBase.state_space(env::SnakeGameEnv) = Space(fill(false..true, size(env.game.board)))
4447
RLBase.reward(env::SnakeGameEnv{<:Any,SINGLE_AGENT}) =
4548
length(env.game.snakes[]) - env.latest_snakes_length[]
4649
RLBase.reward(env::SnakeGameEnv) = length.(env.game.snakes) .- env.latest_snakes_length
4750
RLBase.is_terminated(env::SnakeGameEnv) = env.is_terminated
4851

49-
RLBase.get_legal_actions(env::SnakeGameEnv{FULL_ACTION_SET,SINGLE_AGENT}) =
52+
RLBase.legal_action_space(env::SnakeGameEnv{FULL_ACTION_SET,SINGLE_AGENT}) =
5053
findall(!=(-env.latest_actions[]), SNAKE_GAME_ACTIONS)
51-
RLBase.get_legal_actions(env::SnakeGameEnv{FULL_ACTION_SET}) =
54+
RLBase.legal_action_space(env::SnakeGameEnv{FULL_ACTION_SET}) =
5255
[findall(!=(-a), SNAKE_GAME_ACTIONS) for a in env.latest_actions]
5356

54-
RLBase.get_legal_actions_mask(env::SnakeGameEnv{FULL_ACTION_SET,SINGLE_AGENT}) =
57+
RLBase.legal_action_space_mask(env::SnakeGameEnv{FULL_ACTION_SET,SINGLE_AGENT}) =
5558
[a != -env.latest_actions[] for a in SNAKE_GAME_ACTIONS]
56-
RLBase.get_legal_actions_mask(env::SnakeGameEnv{FULL_ACTION_SET}) =
59+
RLBase.legal_action_space_mask(env::SnakeGameEnv{FULL_ACTION_SET}) =
5760
[[x != -a for x in SNAKE_GAME_ACTIONS] for a in env.latest_actions]
5861

5962
function RLBase.reset!(env::SnakeGameEnv)

src/environments/wrappers/ActionTransformedEnv.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ RLBase.state(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) = state(e
3232
RLBase.state_space(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) =
3333
state_space(env.env, ss)
3434

35-
RLBase.action_space(env::ActionTransformedEnv) =
36-
env.action_space_mapping(action_space(env.env))
37-
RLBase.legal_action_space(env::ActionTransformedEnv) =
38-
env.action_space_mapping(legal_action_space(env.env))
35+
RLBase.action_space(env::ActionTransformedEnv, args...) =
36+
env.action_space_mapping(action_space(env.env), args...)
37+
38+
RLBase.legal_action_space(env::ActionTransformedEnv, args...) =
39+
env.action_space_mapping(legal_action_space(env.env), args...)
3940

4041
(env::ActionTransformedEnv)(action, args...; kwargs...) =
4142
env.env(env.action_mapping(action), args...; kwargs...)

src/environments/wrappers/DefaultStateStyle.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ DefaultStateStyleEnv{S}(env::E) where {S,E} = DefaultStateStyleEnv{S,E}(env)
1414
RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S} = S
1515

1616
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
17-
if f (:DefaultStateStyle, :state, :state_space)
17+
if f (:DefaultStateStyle, )
1818
@eval RLBase.$f(x::DefaultStateStyleEnv, args...; kwargs...) =
1919
$f(x.env, args...; kwargs...)
2020
end

src/environments/wrappers/StateOverriddenEnv.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ export StateOverriddenEnv
44
StateOverriddenEnv(f, env)
55
66
Apply `f` to override `state(env)`.
7+
8+
!!! note
9+
If the meaning of state space is changed after apply `f`, one should
10+
manually redefine the `RLBase.state_space(env::YourSpecificEnv)`.
711
"""
812
struct StateOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper
913
env::E
@@ -15,7 +19,7 @@ StateOverriddenEnv(f) = env -> StateOverriddenEnv(f, env)
1519
(env::StateOverriddenEnv)(args...; kwargs...) = env.env(args...; kwargs...)
1620

1721
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
18-
if f (:state, :state_space)
22+
if f (:state, )
1923
@eval RLBase.$f(x::StateOverriddenEnv, args...; kwargs...) =
2024
$f(x.env, args...; kwargs...)
2125
end
@@ -26,5 +30,6 @@ RLBase.state(env::StateOverriddenEnv, args...; kwargs...) =
2630

2731
RLBase.state(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle) =
2832
env.f(state(env.env, ss))
33+
2934
RLBase.state_space(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle) =
30-
state_space(env.env, ss)
35+
state_space(env.env, ss)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
44
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
55
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
66
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
7+
SnakeGames = "34dccd9f-48d6-4445-aa0f-8c2e373b5429"
78
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
89
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
910
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@testset "snake game env" begin
2+
env = SnakeGameEnv()
3+
RLBase.test_runnable!(env)
4+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ReinforcementLearningEnvironments
44
using ArcadeLearningEnvironment
55
using PyCall
66
using OpenSpiel
7-
# using SnakeGames
7+
using SnakeGames
88
using Random
99
using StableRNGs
1010
using Statistics

0 commit comments

Comments
 (0)