This repository was archived by the owner on May 6, 2021. It is now read-only.

add more flexible reward schemes to SimpleMDPEnv #7

Merged
1 commit, merged Jul 24, 2019
101 changes: 83 additions & 18 deletions src/environments/classic_control/mdp.jl
@@ -1,7 +1,9 @@
using Random, POMDPs, POMDPModels, SparseArrays, LinearAlgebra, StatsBase

export MDPEnv, POMDPEnv, SimpleMDPEnv, absorbing_deterministic_tree_MDP, stochastic_MDP, stochastic_tree_MDP,
deterministic_tree_MDP_with_rand_reward, deterministic_tree_MDP, deterministic_MDP,
DeterministicStateActionReward, DeterministicNextStateReward,
NormalStateActionReward, NormalNextStateReward

#####
##### POMDPEnv
@@ -24,11 +26,11 @@ POMDPEnv(model; rng=Random.GLOBAL_RNG) = POMDPEnv(
DiscreteSpace(n_states(model)),
rng)

function interact!(env::POMDPEnv, action)
s, o, r = generate_sor(env.model, env.state, env.actions[action], env.rng)
env.state = s
(observation = observationindex(env.model, o),
reward = r,
isdone = isterminal(env.model, s))
end

@@ -72,13 +74,13 @@ function interact!(env::MDPEnv, action)
s = rand(env.rng, transition(env.model, env.state, env.actions[action]))
r = reward(env.model, env.state, env.actions[action])
env.state = s
(observation = stateindex(env.model, s),
reward = r,
isdone = isterminal(env.model, s))
end

function observe(env::MDPEnv)
(observation = stateindex(env.model, env.state),
isdone = isterminal(env.model, env.state))
end

@@ -91,35 +93,97 @@ end
na::Int64
state::Int64
trans_probs::Array{AbstractArray, 2}
reward::R
initialstates::Array{Int64, 1}
isterminal::Array{Int64, 1}
rng::S
A Markov Decision Process with `ns` states, `na` actions, current `state`,
`na`x`ns` array of transition probabilities `trans_probs`, which contains for
every (action, state) pair a (potentially sparse) array that sums to 1 (see
[`get_prob_vec_random`](@ref), [`get_prob_vec_uniform`](@ref),
[`get_prob_vec_deterministic`](@ref) for helpers to construct the transition
probabilities), a `reward` of type `R` (see [`DeterministicStateActionReward`](@ref),
[`DeterministicNextStateReward`](@ref), [`NormalNextStateReward`](@ref),
[`NormalStateActionReward`](@ref)), an array of initial states
`initialstates`, and an `ns`-array of 0/1 indicating whether a state is terminal.
"""
mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG}
observation_space::DiscreteSpace
action_space::DiscreteSpace
state::Int64
trans_probs::Array{T, 2}
reward::R
initialstates::Array{Int64, 1}
isterminal::Array{Int64, 1}
rng::S
end

function SimpleMDPEnv(ospace, aspace, state, trans_probs::Array{T, 2},
reward::R, initialstates, isterminal,
rng::S = Random.GLOBAL_RNG) where {T,R,S}
if R <: AbstractMatrix # to ensure compatibility with previous versions
reward = DeterministicStateActionReward(reward)
end
SimpleMDPEnv{T,typeof(reward),S}(ospace, aspace, state, trans_probs,
reward, initialstates, isterminal, rng)
end
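
For reference, a minimal construction sketch with toy sizes and made-up values (`DiscreteSpace(n)` is used as elsewhere in this file; `MersenneTwister` only makes the sketch reproducible). Passing a plain `na × ns` matrix still works and is wrapped into a `DeterministicStateActionReward`; any of the new reward types can be passed in its place:

ns, na = 3, 2
transitions = [normalize(rand(ns), 1) for a in 1:na, s in 1:ns]  # each entry sums to 1
old_style = SimpleMDPEnv(DiscreteSpace(ns), DiscreteSpace(na), 1, transitions,
                         rand(na, ns), [1], zeros(Int64, ns))    # matrix reward, wrapped automatically
new_style = SimpleMDPEnv(DiscreteSpace(ns), DiscreteSpace(na), 1, transitions,
                         NormalNextStateReward(zeros(ns), ones(ns)), [1], zeros(Int64, ns),
                         MersenneTwister(3))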

observation_space(env::SimpleMDPEnv) = env.observation_space
action_space(env::SimpleMDPEnv) = env.action_space

# reward types
"""
struct DeterministicNextStateReward
value::Vector{Float64}
"""
struct DeterministicNextStateReward
value::Vector{Float64}
end
reward(::AbstractRNG, r::DeterministicNextStateReward, s, a, s′) = r.value[s′]
expected_rewards(r::DeterministicNextStateReward, trans_probs) = expected_rewards(r.value, trans_probs)
function expected_rewards(r::AbstractVector, trans_probs)
result = zeros(size(trans_probs))
for i in eachindex(trans_probs)
result[i] = dot(trans_probs[i], r)
end
result
end
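
As a small sanity check with made-up numbers: for next-state values `[0.0, 1.0]` and a single transition vector `[0.3, 0.7]`, the expected reward of that (action, state) pair is `0.3 * 0.0 + 0.7 * 1.0 = 0.7`:

r = DeterministicNextStateReward([0.0, 1.0])
trans_probs = reshape([[0.3, 0.7]], 1, 1)  # a single (action, state) pair
expected_rewards(r, trans_probs)           # 1×1 matrix containing 0.7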
"""
struct DeterministicStateActionReward
value::Array{Float64, 2}

`value` should be a `na × ns`-matrix.
"""
struct DeterministicStateActionReward
value::Array{Float64, 2}
end
reward(::AbstractRNG, r::DeterministicStateActionReward, s, a, s′) = r.value[a, s]
expected_rewards(r::DeterministicStateActionReward, ::Any) = r.value
"""
struct NormalNextStateReward
mean::Vector{Float64}
std::Vector{Float64}
"""
struct NormalNextStateReward
mean::Vector{Float64}
std::Vector{Float64}
end
reward(rng, r::NormalNextStateReward, s, a, s′) = r.mean[s′] + randn(rng) * r.std[s′]
expected_rewards(r::NormalNextStateReward, trans_probs) = expected_rewards(r.mean, trans_probs)
"""
struct NormalStateActionReward
mean::Array{Float64, 2}
std::Array{Float64, 2}

`mean` and `std` should be `na × ns`-matrices.
"""
struct NormalStateActionReward
mean::Array{Float64, 2}
std::Array{Float64, 2}
end
reward(rng, r::NormalStateActionReward, s, a, s′) = r.mean[a, s] + randn(rng) * r.std[a, s]
expected_rewards(r::NormalStateActionReward, ::Any) = r.mean
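
The Normal variants draw a fresh sample on every call, while `expected_rewards` summarizes them by their means. An illustrative sketch with made-up values:

rng = MersenneTwister(1)
r = NormalStateActionReward(fill(1.0, 2, 3), fill(0.1, 2, 3))  # na = 2, ns = 3
reward(rng, r, 2, 1, 3)       # mean[1, 2] plus Gaussian noise (argument order: rng, r, s, a, s′)
expected_rewards(r, nothing)  # just the mean matrix; transition probabilities are ignored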

# run SimpleMDPEnv
"""
run!(mdp::SimpleMDPEnv, action::Int64)
@@ -129,7 +193,7 @@ function run!(mdp::SimpleMDPEnv, action::Int64)
if mdp.isterminal[mdp.state] == 1
reset!(mdp)
else
mdp.state = wsample(mdp.rng, mdp.trans_probs[action, mdp.state])
(observation = mdp.state,)
end
end
@@ -141,8 +205,9 @@ run!(mdp::SimpleMDPEnv, policy::Array{Int64, 1}) = run!(mdp, policy[mdp.state])


function interact!(env::SimpleMDPEnv, action)
oldstate = env.state
run!(env, action)
r = reward(env.rng, env.reward, oldstate, action, env.state)
(observation = env.state, reward = r, isdone = env.isterminal[env.state] == 1)
end
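
The behavioral change here: the reward is now computed after the transition, from the pre-transition state, the action, and the newly reached state, so next-state reward schemes see the correct `s′`. A hypothetical two-state chain, where action 1 flips the state deterministically, illustrates this:

chain = reshape([[0.0, 1.0], [1.0, 0.0]], 1, 2)  # deterministic transitions: 1 -> 2 and 2 -> 1
env = SimpleMDPEnv(DiscreteSpace(2), DiscreteSpace(1), 1, chain,
                   DeterministicNextStateReward([0.0, 5.0]), [1], [0, 0])
interact!(env, 1)  # moves 1 -> 2, reward 5.0
interact!(env, 1)  # moves 2 -> 1, reward 0.0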

@@ -151,7 +216,7 @@ function observe(env::SimpleMDPEnv)
end

function reset!(env::SimpleMDPEnv)
env.state = rand(env.rng, env.initialstates)
nothing
end

@@ -274,7 +339,7 @@ function deterministic_MDP(; ns = 10^4, na = 10)
end

"""
deterministic_tree_MDP(; na = 4, depth = 5)
Returns a tree_MDP with random rewards at the leaf nodes.
"""
function deterministic_tree_MDP(; na = 4, depth = 5)
@@ -317,4 +382,4 @@ function absorbing_deterministic_tree_MDP(;ns = 10^3, na = 10)
reset!(mdp)
set_terminal_states!(mdp, ns - div(ns, 100) + 1:ns)
mdp
end
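
A minimal smoke-test sketch for the exported constructors (their full bodies live earlier in this file, outside the diff); defaults from the signatures above are used, and actions are drawn uniformly at random:

mdp = absorbing_deterministic_tree_MDP()  # defaults: ns = 10^3, na = 10
for _ in 1:5
    interact!(mdp, rand(1:10))            # a few random steps
end
reset!(mdp)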