Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 7443597

Browse files
authored
ignore ViZDoom error for now (#11)
* allow ViZDoom to be broken * bugfix * ignore ViZDoom for now * add some quick fixes; mdp.jl needs careful review
1 parent 155050d commit 7443597

File tree

5 files changed

+31
-9
lines changed

5 files changed

+31
-9
lines changed

.gitignore

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,4 +5,9 @@ deps/deps.jl
55

66
Manifest.toml
77

8-
_vizdoom.ini
8+
_vizdoom.ini
9+
10+
.vscode/*
11+
!.vscode/tasks.json
12+
!.vscode/launch.json
13+
!.vscode/extensions.json

src/environments/classic_control/mdp.jl

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ end
7272

7373
function interact!(env::MDPEnv, action)
7474
s = rand(env.rng, transition(env.model, env.state, env.actions[action]))
75-
r = reward(env.model, env.state, env.actions[action])
75+
r = POMDPs.reward(env.model, env.state, env.actions[action])
7676
env.state = s
7777
(observation = stateindex(env.model, s),
7878
reward = r,
@@ -141,6 +141,7 @@ struct DeterministicNextStateReward
141141
end
142142
reward(::AbstractRNG, r::DeterministicNextStateReward, s, a, s′) = r.value[s′]
143143
expected_rewards(r::DeterministicNextStateReward, trans_probs) = expected_rewards(r.value, trans_probs)
144+
144145
function expected_rewards(r::AbstractVector, trans_probs)
145146
result = zeros(size(trans_probs))
146147
for i in eachindex(trans_probs)
@@ -321,7 +322,11 @@ the same value.
321322
function set_terminal_states!(mdp, range)
322323
mdp.isterminal[range] .= 1
323324
for s in findall(x -> x == 1, mdp.isterminal)
324-
mdp.reward[:, s] .= mean(mdp.reward[:, s])
325+
if mdp.reward isa DeterministicStateActionReward
326+
mdp.reward.value[:, s] .= mean(mdp.reward.value[:, s])
327+
else
328+
mdp.reward[:, s] .= mean(mdp.reward[:, s])
329+
end
325330
for a in 1:length(mdp.action_space)
326331
empty_trans_prob!(mdp.trans_probs[a, s])
327332
end
@@ -334,7 +339,11 @@ Returns a random deterministic SimpleMDPEnv.
334339
"""
335340
function deterministic_MDP(; ns = 10^4, na = 10)
336341
mdp = SimpleMDPEnv(ns, na, init = "deterministic")
337-
mdp.reward = mdp.reward .* (mdp.reward .< -1.5)
342+
if mdp.reward isa DeterministicStateActionReward
343+
mdp.reward.value .*= mdp.reward.value .< -1.5
344+
else
345+
mdp.reward = mdp.reward .* (mdp.reward .< -1.5)
346+
end
338347
mdp
339348
end
340349

@@ -353,7 +362,11 @@ Returns a tree_MDP with random rewards.
353362
function deterministic_tree_MDP_with_rand_reward(; args...)
354363
mdp = deterministic_tree_MDP(; args...)
355364
nonterminals = findall(x -> x == 0, mdp.isterminal)
356-
mdp.reward[:, nonterminals] = -rand(length(mdp.action_space), length(nonterminals))
365+
if mdp.reward isa DeterministicStateActionReward
366+
mdp.reward.value[:, nonterminals] = -rand(length(mdp.action_space), length(nonterminals))
367+
else
368+
mdp.reward[:, nonterminals] = -rand(length(mdp.action_space), length(nonterminals))
369+
end
357370
mdp
358371
end
359372

@@ -377,7 +390,11 @@ Returns a random deterministic absorbing SimpleMDPEnv
377390
"""
378391
function absorbing_deterministic_tree_MDP(;ns = 10^3, na = 10)
379392
mdp = SimpleMDPEnv(ns, na, init = "deterministic")
380-
mdp.reward .= mdp.reward .* (mdp.reward .< -.5)
393+
if mdp.reward isa DeterministicStateActionReward
394+
mdp.reward.value .*= mdp.reward.value .< -.5
395+
else
396+
mdp.reward .= mdp.reward .* (mdp.reward .< -.5)
397+
end
381398
mdp.initialstates = 1:div(ns, 100)
382399
reset!(mdp)
383400
set_terminal_states!(mdp, ns - div(ns, 100) + 1:ns)

src/environments/gym.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -92,5 +92,5 @@ function list_gym_env_names(;
9292
"gym.envs.toy_text",
9393
"gym.envs.unittest"])
9494
gym = pyimport("gym")
95-
[x.id for x in gym.envs.registry.all() if split(x._entry_point, ':')[1] in modules]
95+
[x.id for x in gym.envs.registry.all() if split(x.entry_point, ':')[1] in modules]
9696
end

test/environments.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@
5050

5151
for env_exp in [
5252
:(HanabiEnv()),
53-
:(basic_ViZDoom_env()),
53+
# :(basic_ViZDoom_env()), # comment out due to https://github.com/JuliaReinforcementLearning/ViZDoom.jl/issues/7
5454
:(CartPoleEnv()),
5555
:(MountainCarEnv()),
5656
:(ContinuousMountainCarEnv()),

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@ using Test
22
using ReinforcementLearningEnvironments
33
using ArcadeLearningEnvironment
44
using POMDPModels
5-
using ViZDoom
5+
# using ViZDoom
66
using PyCall
77
using Hanabi
88

0 commit comments

Comments
 (0)