This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit d5836d8

jbrea authored and findmyway committed
add DiscreteMazeEnv (#10)
1 parent 7443597 commit d5836d8

File tree

4 files changed, +198 −1 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ By default, only some basic environments are installed. If you want to use some
 - PendulumEnv
 - MDPEnv
 - POMDPEnv
+- DiscreteMazeEnv
 - SimpleMDPEnv
 - deterministic_MDP
 - absorbing_deterministic_tree_MDP

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,7 @@
 include("cart_pole.jl")
 include("mountain_car.jl")
 include("pendulum.jl")
-include("mdp.jl")
+include("mdp.jl")
+include("discrete_maze.jl")
+using .DiscreteMaze
+export DiscreteMazeEnv

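For orientation, the three added lines follow the standard Julia pattern for wiring a submodule into a package: `include` evaluates the file, `using .DiscreteMaze` brings the submodule's exports into scope, and `export DiscreteMazeEnv` re-exports the name to package users. A self-contained sketch of the pattern, with placeholder names not taken from this repository:

    module ParentPackage             # placeholder for the enclosing package
    module SubModule                 # in the real code this lives in its own file and
    export thing                     # is pulled in with an `include` call
    thing() = "hello"
    end
    using .SubModule                 # make SubModule's exports visible in ParentPackage
    export thing                     # re-export: `using ParentPackage` now provides `thing`
    end

    using .ParentPackage
    thing()                          # -> "hello"
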
discrete_maze.jl

Lines changed: 192 additions & 0 deletions

@@ -0,0 +1,192 @@
module DiscreteMaze
using Random, StatsBase, SparseArrays, GR, ..ReinforcementLearningEnvironments
export DiscreteMazeEnv

# Maze convention: 1 = free cell, 0 = wall.
function emptymaze(dimx, dimy)
    maze = ones(Int, dimx, dimy)
    maze[1,:] .= maze[end,:] .= maze[:, 1] .= maze[:, end] .= 0 # borders
    return maze
end
iswall(maze, pos) = maze[pos] == 0
isinsideframe(maze, i::Int) = isinsideframe(maze, CartesianIndices(maze)[i])
isinsideframe(maze, i) = i[1] > 1 && i[2] > 1 && i[1] < size(maze, 1) && i[2] < size(maze, 2)

const UP = CartesianIndex(0, -1)
const DOWN = CartesianIndex(0, 1)
const LEFT = CartesianIndex(-1, 0)
const RIGHT = CartesianIndex(1, 0)
function orthogonal_directions(dir)
    dir[1] == 0 && return (LEFT, RIGHT)
    return (UP, DOWN)
end

function is_wall_neighbour(maze, pos)
    for dir in (UP, DOWN, LEFT, RIGHT, UP + RIGHT, UP + LEFT, DOWN + RIGHT, DOWN + LEFT)
        iswall(maze, pos + dir) && return true
    end
    return false
end
function is_wall_tangential(maze, pos, dir)
    for ortho_dir in orthogonal_directions(dir)
        iswall(maze, pos + ortho_dir) && return true
    end
    return false
end
is_wall_ahead(maze, pos, dir) = iswall(maze, pos + dir)

# Grow a straight wall from a random free cell (one without wall neighbours)
# until it runs into or alongside an existing wall.
function addrandomwall!(maze; rng = Random.GLOBAL_RNG)
    potential_startpos = filter(x -> !is_wall_neighbour(maze, x),
                                findall(x -> x != 0, maze))
    if potential_startpos == []
        @warn("Cannot add a random wall.")
        return maze
    end
    pos = rand(rng, potential_startpos)
    direction = rand(rng, (UP, DOWN, LEFT, RIGHT))
    while true
        maze[pos] = 0
        pos += direction
        is_wall_tangential(maze, pos, direction) && break
        if is_wall_ahead(maze, pos, direction)
            maze[pos] = 0
            break
        end
    end
    return maze
end

# Number of cells to modify: explicit n if given, otherwise the fraction f of the list.
function n_effective(n, f, list)
    N = n === nothing ? div(length(list), Int(1/f)) : n
    min(N, length(list))
end
function breaksomewalls!(m; f = 1/50, n = nothing, rng = Random.GLOBAL_RNG)
    wallpos = Int[]
    for i in 1:length(m)
        iswall(m, i) && isinsideframe(m, i) && push!(wallpos, i)
    end
    pos = sample(rng, wallpos, n_effective(n, f, wallpos), replace = false)
    m[pos] .= 1
    m
end
function addobstacles!(m; f = 1/100, n = nothing, rng = Random.GLOBAL_RNG)
    nz = findall(x -> x == 1, reshape(m, :))
    pos = sample(rng, nz, n_effective(n, f, nz), replace = false)
    m[pos] .= 0
    m
end
# Fill in the transition probabilities and rewards of the wrapped MDP.
function setTandR!(d)
    for s in LinearIndices(d.maze)[findall(x -> x != 0, d.maze)]
        setTandR!(d, s)
    end
end
function setTandR!(d, s)
    T = d.mdp.trans_probs
    R = d.mdp.reward
    goals = d.goals
    ns = length(d.mdp.observation_space)
    maze = d.maze
    if s in goals
        idx_goals = findfirst(x -> x == s, goals)
        R.value[s] = d.goalrewards[idx_goals]
    end
    pos = CartesianIndices(maze)[s]
    for (aind, a) in enumerate((UP, DOWN, LEFT, RIGHT))
        nextpos = maze[pos + a] == 0 ? pos : pos + a # bumping into a wall leaves the position unchanged
        if d.neighbourstateweight > 0
            positions = [nextpos]
            weights = [1.]
            for dir in (UP, DOWN, LEFT, RIGHT)
                if maze[nextpos + dir] != 0
                    push!(positions, nextpos + dir)
                    push!(weights, d.neighbourstateweight)
                end
            end
            states = LinearIndices(maze)[positions]
            weights /= sum(weights)
            T[aind, s] = sparsevec(states, weights, ns)
        else
            nexts = LinearIndices(maze)[nextpos]
            T[aind, s] = sparsevec([nexts], [1.], ns)
        end
    end
end
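
# Note on the stochastic case: for an intended target cell t with k free
# neighbours and neighbourstateweight w > 0, the loop above collects weights
# [1, w, ..., w]; after normalization the next-state probabilities are
# 1/(1 + k*w) for t and w/(1 + k*w) for each neighbour. With w = 0 the
# transition is deterministic.
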
"""
115+
struct DiscreteMazeEnv
116+
mdp::MDP
117+
maze::Array{Int, 2}
118+
goals::Array{Int, 1}
119+
statefrommaze::Array{Int, 1}
120+
mazefromstate::Array{Int, 1}
121+
"""
122+
struct DiscreteMazeEnv{T}
123+
mdp::T
124+
maze::Array{Int, 2}
125+
goals::Array{Int, 1}
126+
goalrewards::Array{Float64, 1}
127+
neighbourstateweight::Float64
128+
end
"""
    DiscreteMazeEnv(; nx = 40, ny = 40, nwalls = div(nx*ny, 20), ngoals = 1,
                      goalrewards = 1, stepcost = 0, stochastic = false,
                      neighbourstateweight = .05, rng = Random.GLOBAL_RNG)

Returns a `DiscreteMazeEnv` of width `nx` and height `ny` with `nwalls` walls and
`ngoals` goal locations with reward `goalrewards` (either a list of rewards, one
per goal state, or a constant reward for all goals) and cost of moving
`stepcost` (reward = -`stepcost`). If `stochastic = true`, actions lead with
some probability to a neighbouring state; `neighbourstateweight` controls this
probability.
"""
function DiscreteMazeEnv(; nx = 40, ny = 40, nwalls = div(nx*ny, 20),
                           rng = Random.GLOBAL_RNG, kwargs...)
    m = emptymaze(nx, ny)
    for _ in 1:nwalls
        addrandomwall!(m, rng = rng)
    end
    breaksomewalls!(m, rng = rng)
    DiscreteMazeEnv(m; rng = rng, kwargs...)
end
function DiscreteMazeEnv(maze; ngoals = 1,
                         goalrewards = 1.,
                         stepcost = 0,
                         stochastic = false,
                         neighbourstateweight = stochastic ? .05 : 0.,
                         rng = Random.GLOBAL_RNG)
    na = 4
    ns = length(maze)
    legalstates = LinearIndices(maze)[findall(x -> x != 0, maze)]
    T = Array{SparseVector{Float64,Int}}(undef, na, ns)
    goals = sort(sample(rng, legalstates, ngoals, replace = false))
    R = DeterministicNextStateReward(fill(-stepcost, ns))
    isterminal = zeros(Int, ns); isterminal[goals] .= 1
    isinitial = setdiff(legalstates, goals)
    res = DiscreteMazeEnv(SimpleMDPEnv(DiscreteSpace(ns, 1),
                                       DiscreteSpace(na, 1),
                                       rand(rng, legalstates),
                                       T, R,
                                       isinitial,
                                       isterminal,
                                       rng),
                          maze,
                          goals,
                          typeof(goalrewards) <: Number ? fill(goalrewards, ngoals) :
                                                          goalrewards,
                          neighbourstateweight)
    setTandR!(res)
    res
end

ReinforcementLearningEnvironments.interact!(env::DiscreteMazeEnv, a) = interact!(env.mdp, a)
ReinforcementLearningEnvironments.reset!(env::DiscreteMazeEnv) = reset!(env.mdp)
ReinforcementLearningEnvironments.observe(env::DiscreteMazeEnv) = observe(env.mdp)
ReinforcementLearningEnvironments.action_space(env::DiscreteMazeEnv) = action_space(env.mdp)
ReinforcementLearningEnvironments.observation_space(env::DiscreteMazeEnv) = observation_space(env.mdp)
function ReinforcementLearningEnvironments.render(env::DiscreteMazeEnv)
    goals = env.goals
    m = copy(env.maze)
    m[goals] .= 3           # mark goal cells
    m[env.mdp.state] = 2    # mark the agent's current position
    imshow(m, colormap = 21, size = (400, 400))
end
end # module
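
For context, a minimal usage sketch of the new environment, assuming the package exports the `reset!`, `interact!`, and `render` interface functions defined above; the maze size, random policy, and step cap are illustrative choices, not part of the commit, and `render` needs a working GR backend:

    using ReinforcementLearningEnvironments

    env = DiscreteMazeEnv(nx = 20, ny = 20, stochastic = true)  # illustrative size
    reset!(env)
    for _ in 1:100     # illustrative cap; a real loop would stop at a terminal (goal) state
        a = rand(1:4)  # uniformly random action: UP, DOWN, LEFT or RIGHT
        interact!(env, a)
    end
    render(env)        # draws walls, goals (3) and the agent (2) via GR's imshow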

test/environments.jl

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@
 :(MDPEnv(LegacyGridWorld())),
 :(POMDPEnv(TigerPOMDP())),
 :(SimpleMDPEnv()),
+:(DiscreteMazeEnv()),
 :(deterministic_MDP()),
 :(absorbing_deterministic_tree_MDP()),
 :(stochastic_MDP()),

0 commit comments
