
Commit c51d583 (parent 6a7f1bd)

add more flexible reward schemes to SimpleMDPEnv


src/environments/classic_control/mdp.jl (83 additions, 18 deletions)
@@ -1,7 +1,9 @@
 using Random, POMDPs, POMDPModels, SparseArrays, LinearAlgebra, StatsBase
 
 export MDPEnv, POMDPEnv, SimpleMDPEnv, absorbing_deterministic_tree_MDP, stochastic_MDP, stochastic_tree_MDP,
-       deterministic_tree_MDP_with_rand_reward, deterministic_tree_MDP, deterministic_MDP
+       deterministic_tree_MDP_with_rand_reward, deterministic_tree_MDP, deterministic_MDP,
+       DeterministicStateActionReward, DeterministicNextStateReward,
+       NormalStateActionReward, NormalNextStateReward
 
 #####
 ##### POMDPEnv
@@ -24,11 +26,11 @@ POMDPEnv(model; rng=Random.GLOBAL_RNG) = POMDPEnv(
     DiscreteSpace(n_states(model)),
     rng)
 
-function interact!(env::POMDPEnv, action)
+function interact!(env::POMDPEnv, action)
     s, o, r = generate_sor(env.model, env.state, env.actions[action], env.rng)
     env.state = s
-    (observation = observationindex(env.model, o),
-     reward = r,
+    (observation = observationindex(env.model, o),
+     reward = r,
      isdone = isterminal(env.model, s))
 end
 
@@ -72,13 +74,13 @@ function interact!(env::MDPEnv, action)
     s = rand(env.rng, transition(env.model, env.state, env.actions[action]))
     r = reward(env.model, env.state, env.actions[action])
     env.state = s
-    (observation = stateindex(env.model, s),
-     reward = r,
+    (observation = stateindex(env.model, s),
+     reward = r,
      isdone = isterminal(env.model, s))
 end
 
 function observe(env::MDPEnv)
-    (observation = stateindex(env.model, env.state),
+    (observation = stateindex(env.model, env.state),
      isdone = isterminal(env.model, env.state))
 end
 
@@ -91,35 +93,97 @@ end
     na::Int64
     state::Int64
     trans_probs::Array{AbstractArray, 2}
-    reward::Array{Float64, 2}
+    reward::R
     initialstates::Array{Int64, 1}
    isterminal::Array{Int64, 1}
+    rng::S
 A Markov Decision Process with `ns` states, `na` actions, current `state`,
 an `na`x`ns` array of transition probabilities `trans_probs`, which contains for
 every (action, state) pair a (potentially sparse) array that sums to 1 (see
 [`get_prob_vec_random`](@ref), [`get_prob_vec_uniform`](@ref),
 [`get_prob_vec_deterministic`](@ref) for helpers to construct the transition
-probabilities), an `na`x`ns` array of `reward`, an array of initial states
+probabilities), a `reward` of type `R` (see [`DeterministicStateActionReward`](@ref),
+[`DeterministicNextStateReward`](@ref), [`NormalNextStateReward`](@ref),
+[`NormalStateActionReward`](@ref)), an array of initial states
 `initialstates`, and an `ns`-array of 0/1 indicating if a state is terminal.
 """
-mutable struct SimpleMDPEnv{T}
+mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG}
     observation_space::DiscreteSpace
     action_space::DiscreteSpace
     state::Int64
     trans_probs::Array{T, 2}
-    reward::Array{Float64, 2}
+    reward::R
     initialstates::Array{Int64, 1}
     isterminal::Array{Int64, 1}
+    rng::S
 end
 
 function SimpleMDPEnv(ospace, aspace, state, trans_probs::Array{T, 2},
-                      reward, initialstates, isterminal) where T
-    SimpleMDPEnv{T}(ospace, aspace, state, trans_probs, reward, initialstates, isterminal)
+                      reward::R, initialstates, isterminal,
+                      rng::S = Random.GLOBAL_RNG) where {T,R,S}
+    if R <: AbstractMatrix # to ensure compatibility with previous versions
+        reward = DeterministicStateActionReward(reward)
+    end
+    SimpleMDPEnv{T,typeof(reward),S}(ospace, aspace, state, trans_probs,
+                                     reward, initialstates, isterminal, rng)
 end
 
 observation_space(env::SimpleMDPEnv) = env.observation_space
 action_space(env::SimpleMDPEnv) = env.action_space
 
+# reward types
+"""
+    struct DeterministicNextStateReward
+        value::Vector{Float64}
+"""
+struct DeterministicNextStateReward
+    value::Vector{Float64}
+end
+reward(::AbstractRNG, r::DeterministicNextStateReward, s, a, s′) = r.value[s′]
+expected_rewards(r::DeterministicNextStateReward, trans_probs) = expected_rewards(r.value, trans_probs)
+function expected_rewards(r::AbstractVector, trans_probs)
+    result = zeros(size(trans_probs))
+    for i in eachindex(trans_probs)
+        result[i] = dot(trans_probs[i], r)
+    end
+    result
+end
+"""
+    struct DeterministicStateActionReward
+        value::Array{Float64, 2}
+
+`value` should be a `na × ns`-matrix.
+"""
+struct DeterministicStateActionReward
+    value::Array{Float64, 2}
+end
+reward(::AbstractRNG, r::DeterministicStateActionReward, s, a, s′) = r.value[a, s]
+expected_rewards(r::DeterministicStateActionReward, ::Any) = r.value
+"""
+    struct NormalNextStateReward
+        mean::Vector{Float64}
+        std::Vector{Float64}
+"""
+struct NormalNextStateReward
+    mean::Vector{Float64}
+    std::Vector{Float64}
+end
+reward(rng, r::NormalNextStateReward, s, a, s′) = r.mean[s′] + randn(rng) * r.std[s′]
+expected_rewards(r::NormalNextStateReward, trans_probs) = expected_rewards(r.mean, trans_probs)
+"""
+    struct NormalStateActionReward
+        mean::Array{Float64, 2}
+        std::Array{Float64, 2}
+
+`mean` and `std` should be `na × ns`-matrices.
+"""
+struct NormalStateActionReward
+    mean::Array{Float64, 2}
+    std::Array{Float64, 2}
+end
+reward(rng, r::NormalStateActionReward, s, a, s′) = r.mean[a, s] + randn(rng) * r.std[a, s]
+expected_rewards(r::NormalStateActionReward, ::Any) = r.mean
+
 # run SimpleMDPEnv
 """
     run!(mdp::SimpleMDPEnv, action::Int64)
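
The four reward types above share a small interface: `reward(rng, r, s, a, s′)` draws one (possibly stochastic) reward for a transition, and `expected_rewards(r, trans_probs)` returns the `na × ns` matrix of expected rewards. The sketch below is not part of this commit; it makes up a tiny 2-state, 2-action MDP (all names `ns`, `na`, `trans_probs`, `rew`, `env`, `env_old` are hypothetical) and assumes the one-argument `DiscreteSpace` constructor used elsewhere in this file.

# Hedged sketch (not in this commit): wiring a new reward type into a SimpleMDPEnv.
using Random

ns, na = 2, 2
# one (potentially sparse) probability vector per (action, state) pair
trans_probs = [[0.0, 1.0] for a in 1:na, s in 1:ns]   # every action leads to state 2
rew = NormalStateActionReward(ones(na, ns), 0.1 * ones(na, ns))
env = SimpleMDPEnv(DiscreteSpace(ns), DiscreteSpace(na), 1, trans_probs,
                   rew, [1], zeros(Int64, ns), MersenneTwister(42))

reward(env.rng, rew, 1, 2, 2)       # one noisy sample for s = 1, a = 2, s′ = 2
expected_rewards(rew, trans_probs)  # the na × ns matrix of mean rewards

# Passing a plain na × ns matrix still works: the constructor above wraps it
# in a DeterministicStateActionReward for backwards compatibility.
env_old = SimpleMDPEnv(DiscreteSpace(ns), DiscreteSpace(na), 1, trans_probs,
                       ones(na, ns), [1], zeros(Int64, ns))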
@@ -129,7 +193,7 @@ function run!(mdp::SimpleMDPEnv, action::Int64)
     if mdp.isterminal[mdp.state] == 1
         reset!(mdp)
     else
-        mdp.state = wsample(mdp.trans_probs[action, mdp.state])
+        mdp.state = wsample(mdp.rng, mdp.trans_probs[action, mdp.state])
         (observation = mdp.state,)
     end
 end
@@ -141,8 +205,9 @@ run!(mdp::SimpleMDPEnv, policy::Array{Int64, 1}) = run!(mdp, policy[mdp.state])
 
 
 function interact!(env::SimpleMDPEnv, action)
-    r = env.reward[action, env.state]
+    oldstate = env.state
     run!(env, action)
+    r = reward(env.rng, env.reward, oldstate, action, env.state)
     (observation = env.state, reward = r, isdone = env.isterminal[env.state] == 1)
 end
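
With this change, `interact!` computes the reward after the state transition, from the pre-transition state, the action, and the sampled next state, drawing any noise from the environment's own `rng`. A short usage sketch, not part of the commit and reusing the hypothetical `env` from the earlier sketch:

res = interact!(env, 1)   # named tuple, as returned by interact! above
res.observation           # the new state index
res.reward                # drawn via reward(env.rng, env.reward, oldstate, 1, env.state)
res.isdone                # true when the new state is marked terminal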

@@ -151,7 +216,7 @@ function observe(env::SimpleMDPEnv)
 end
 
 function reset!(env::SimpleMDPEnv)
-    env.state = rand(env.initialstates)
+    env.state = rand(env.rng, env.initialstates)
     nothing
 end
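
Because `run!`, `reset!`, and the reward sampling now all draw from `env.rng` rather than the global RNG, seeding that field makes whole rollouts reproducible. A hedged sketch, reusing the hypothetical arguments from the first example:

env1 = SimpleMDPEnv(DiscreteSpace(ns), DiscreteSpace(na), 1, trans_probs,
                    rew, [1], zeros(Int64, ns), MersenneTwister(7))
env2 = SimpleMDPEnv(DiscreteSpace(ns), DiscreteSpace(na), 1, trans_probs,
                    rew, [1], zeros(Int64, ns), MersenneTwister(7))
# identical seeds give identical state and reward sequences
[interact!(env1, 1).reward for _ in 1:5] == [interact!(env2, 1).reward for _ in 1:5]  # true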

@@ -274,7 +339,7 @@ function deterministic_MDP(; ns = 10^4, na = 10)
 end
 
 """
-    deterministic_tree_MDP(; na = 4, depth = 5)
+    deterministic_tree_MDP(; na = 4, depth = 5)
 Returns a tree_MDP with random rewards at the leaf nodes.
 """
 function deterministic_tree_MDP(; na = 4, depth = 5)
@@ -317,4 +382,4 @@ function absorbing_deterministic_tree_MDP(;ns = 10^3, na = 10)
     reset!(mdp)
     set_terminal_states!(mdp, ns - div(ns, 100) + 1:ns)
     mdp
-end
+end
