module DiscreteMaze
using Random, StatsBase, SparseArrays, GR, ..ReinforcementLearningEnvironments
export DiscreteMazeEnv

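# A maze is an Int matrix: 1 marks a free cell, 0 marks a wall.
# The outer border is always made of walls (see emptymaze), so index
# arithmetic on interior cells never runs out of bounds.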
function emptymaze(dimx, dimy)
    maze = ones(Int, dimx, dimy)
    maze[1,:] .= maze[end,:] .= maze[:, 1] .= maze[:, end] .= 0 # borders
    return maze
end
iswall(maze, pos) = maze[pos] == 0
isinsideframe(maze, i::Int) = isinsideframe(maze, CartesianIndices(maze)[i])
isinsideframe(maze, i) = i[1] > 1 && i[2] > 1 && i[1] < size(maze, 1) && i[2] < size(maze, 2)

# The four moves as CartesianIndex offsets.
const UP = CartesianIndex(0, -1)
const DOWN = CartesianIndex(0, 1)
const LEFT = CartesianIndex(-1, 0)
const RIGHT = CartesianIndex(1, 0)
function orthogonal_directions(dir)
    dir[1] == 0 && return (LEFT, RIGHT)
    return (UP, DOWN)
end

function is_wall_neighbour(maze, pos)
    for dir in (UP, DOWN, LEFT, RIGHT, UP + RIGHT, UP + LEFT, DOWN + RIGHT, DOWN + LEFT)
        iswall(maze, pos + dir) && return true
    end
    return false
end
function is_wall_tangential(maze, pos, dir)
    for ortho_dir in orthogonal_directions(dir)
        iswall(maze, pos + ortho_dir) && return true
    end
    return false
end
is_wall_ahead(maze, pos, dir) = iswall(maze, pos + dir)

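# Grows one straight wall: start at a random free cell with no wall in its
# 8-neighbourhood, then extend in a random direction until the wall would touch
# another wall sideways or run into one head-on.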
function addrandomwall!(maze; rng = Random.GLOBAL_RNG)
    potential_startpos = filter(x -> !is_wall_neighbour(maze, x),
                                findall(x -> x != 0, maze))
    if isempty(potential_startpos)
        @warn("Cannot add a random wall.")
        return maze
    end
    pos = rand(rng, potential_startpos)
    direction = rand(rng, (UP, DOWN, LEFT, RIGHT))
    while true
        maze[pos] = 0
        pos += direction
        is_wall_tangential(maze, pos, direction) && break
        if is_wall_ahead(maze, pos, direction)
            maze[pos] = 0
            break
        end
    end
    return maze
end

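# Number of cells to modify: either an explicit count `n` or a fraction `f`
# of the candidate list, capped at the length of the list.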
function n_effective(n, f, list)
    N = n === nothing ? div(length(list), Int(1/f)) : n
    min(N, length(list))
end
function breaksomewalls!(m; f = 1/50, n = nothing, rng = Random.GLOBAL_RNG)
    wallpos = Int[]
    for i in 1:length(m)
        iswall(m, i) && isinsideframe(m, i) && push!(wallpos, i)
    end
    pos = sample(rng, wallpos, n_effective(n, f, wallpos), replace = false)
    m[pos] .= 1
    m
end
function addobstacles!(m; f = 1/100, n = nothing, rng = Random.GLOBAL_RNG)
    nz = findall(x -> x == 1, reshape(m, :))
    pos = sample(rng, nz, n_effective(n, f, nz), replace = false)
    m[pos] .= 0
    m
end
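# Fill in the transition probabilities and rewards of the wrapped MDP for every
# free cell. Moving into a wall leaves the agent in place; with
# neighbourstateweight > 0, free neighbours of the target cell also receive
# probability mass, which makes the dynamics stochastic.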
function setTandR!(d)
    for s in LinearIndices(d.maze)[findall(x -> x != 0, d.maze)]
        setTandR!(d, s)
    end
end
function setTandR!(d, s)
    T = d.mdp.trans_probs
    R = d.mdp.reward
    goals = d.goals
    ns = length(d.mdp.observation_space)
    maze = d.maze
    if s in goals
        idx_goals = findfirst(x -> x == s, goals)
        R.value[s] = d.goalrewards[idx_goals]
    end
    pos = CartesianIndices(maze)[s]
    for (aind, a) in enumerate((UP, DOWN, LEFT, RIGHT))
        # bumping into a wall keeps the agent where it is
        nextpos = maze[pos + a] == 0 ? pos : pos + a
        if d.neighbourstateweight > 0
            positions = [nextpos]
            weights = [1.]
            for dir in (UP, DOWN, LEFT, RIGHT)
                if maze[nextpos + dir] != 0
                    push!(positions, nextpos + dir)
                    push!(weights, d.neighbourstateweight)
                end
            end
            states = LinearIndices(maze)[positions]
            weights /= sum(weights)
            T[aind, s] = sparsevec(states, weights, ns)
        else
            # deterministic transition to the single next state
            nexts = LinearIndices(maze)[nextpos]
            T[aind, s] = sparsevec([nexts], [1.], ns)
        end
    end
end

"""
    struct DiscreteMazeEnv
        mdp::MDP
        maze::Array{Int, 2}
        goals::Array{Int, 1}
        goalrewards::Array{Float64, 1}
        neighbourstateweight::Float64
"""
struct DiscreteMazeEnv{T}
    mdp::T
    maze::Array{Int, 2}
    goals::Array{Int, 1}
    goalrewards::Array{Float64, 1}
    neighbourstateweight::Float64
end
"""
    DiscreteMazeEnv(; nx = 40, ny = 40, nwalls = div(nx*ny, 20), ngoals = 1,
                      goalrewards = 1, stepcost = 0, stochastic = false,
                      neighbourstateweight = .05, rng = Random.GLOBAL_RNG)

Returns a `DiscreteMazeEnv` of width `nx` and height `ny` with `nwalls` random
walls and `ngoals` goal locations. `goalrewards` is either a single reward used
for every goal or a list with one reward per goal state. Each step costs
`stepcost` (i.e. yields reward `-stepcost`). If `stochastic = true`, an action
may also lead to a state neighbouring the intended one; `neighbourstateweight`
controls the probability of these transitions.
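
# Example

A minimal usage sketch (the sizes, rewards and seed below are arbitrary
illustration values; it assumes `ReinforcementLearningEnvironments` and
`Random` are loaded):

    env = DiscreteMazeEnv(nx = 12, ny = 10, ngoals = 2,
                          goalrewards = [1., 5.], stochastic = true,
                          rng = MersenneTwister(42))
    observe(env)        # observation of the wrapped MDP
    interact!(env, 1)   # apply action 1 (UP)
    reset!(env)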
"""
function DiscreteMazeEnv(; nx = 40, ny = 40, nwalls = div(nx*ny, 20),
                           rng = Random.GLOBAL_RNG, kwargs...)
    m = emptymaze(nx, ny)
    # carve nwalls straight walls, then open random passages through them
    for _ in 1:nwalls
        addrandomwall!(m, rng = rng)
    end
    breaksomewalls!(m, rng = rng)
    DiscreteMazeEnv(m; rng = rng, kwargs...)
end
function DiscreteMazeEnv(maze; ngoals = 1,
                         goalrewards = 1.,
                         stepcost = 0,
                         stochastic = false,
                         neighbourstateweight = stochastic ? .05 : 0.,
                         rng = Random.GLOBAL_RNG)
    na = 4                      # UP, DOWN, LEFT, RIGHT
    ns = length(maze)           # one state per maze cell (walls included)
    legalstates = LinearIndices(maze)[findall(x -> x != 0, maze)]
    T = Array{SparseVector{Float64,Int}}(undef, na, ns)
    goals = sort(sample(rng, legalstates, ngoals, replace = false))
    R = DeterministicNextStateReward(fill(-stepcost, ns))
    isterminal = zeros(Int, ns); isterminal[goals] .= 1
    isinitial = setdiff(legalstates, goals)
    res = DiscreteMazeEnv(SimpleMDPEnv(DiscreteSpace(ns, 1),
                                       DiscreteSpace(na, 1),
                                       rand(rng, legalstates),
                                       T, R,
                                       isinitial,
                                       isterminal,
                                       rng),
                          maze,
                          goals,
                          typeof(goalrewards) <: Number ? fill(goalrewards, ngoals) :
                              goalrewards,
                          neighbourstateweight)
    setTandR!(res)
    res
end

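# Forward the ReinforcementLearningEnvironments interface to the wrapped MDP.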
ReinforcementLearningEnvironments.interact!(env::DiscreteMazeEnv, a) = interact!(env.mdp, a)
ReinforcementLearningEnvironments.reset!(env::DiscreteMazeEnv) = reset!(env.mdp)
ReinforcementLearningEnvironments.observe(env::DiscreteMazeEnv) = observe(env.mdp)
ReinforcementLearningEnvironments.action_space(env::DiscreteMazeEnv) = action_space(env.mdp)
ReinforcementLearningEnvironments.observation_space(env::DiscreteMazeEnv) = observation_space(env.mdp)
function ReinforcementLearningEnvironments.render(env::DiscreteMazeEnv)
    goals = env.goals
    m = copy(env.maze)
    m[goals] .= 3            # mark the goal cells
    m[env.mdp.state] = 2     # mark the agent's current cell
    imshow(m, colormap = 21, size = (400, 400))
end
end # module