WIP: CRR algorithm #407

Closed
@@ -71,7 +71,7 @@ function RL.Experiment(
start_steps = 1000,
start_policy = RandomPolicy(Space([-1.0..1.0 for _ in 1:na]); rng = rng),
update_after = 1000,
update_every = 1,
update_freq = 1,
automatic_entropy_tuning = true,
lr_alpha = 0.003f0,
action_dims = action_dims,
@@ -111,6 +111,17 @@ function (model::GaussianNetwork)(state; is_sampling::Bool=false, is_return_log_
model(Random.GLOBAL_RNG, state; is_sampling=is_sampling, is_return_log_prob=is_return_log_prob)
end

"""
Return the log-probability of taking action `a` in state `s` under the tanh-squashed Gaussian policy.
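
# Example

A rough usage sketch (the layer sizes are arbitrary and only for illustration):

    policy = GaussianNetwork(
        pre = Dense(4, 64, relu),
        μ = Dense(64, 1),
        logσ = Dense(64, 1),
    )
    states  = rand(Float32, 4, 32)                    # a batch of 32 four-dimensional states
    actions = 2.0f0 .* rand(Float32, 1, 32) .- 1.0f0  # actions in [-1, 1]
    logp    = policy(states, actions)                 # 1×32 matrix of log-probabilities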
"""
function (model::GaussianNetwork)(state, action)
x = model.pre(state)
μ, logσ = model.μ(x), model.logσ(x)
π_dist = Normal.(μ, exp.(logσ))
logp_π = sum(logpdf.(π_dist, action), dims = 1)
# Correction term for the tanh squashing of the action (as in the SAC-style squashed Gaussian).
logp_π -= sum(2.0f0 .* (log(2.0f0) .- action .- softplus.(-2.0f0 .* action)), dims = 1)
return logp_π
end

#####
# DuelingNetwork
#####
@@ -3,7 +3,7 @@ module ReinforcementLearningZoo
const RLZoo = ReinforcementLearningZoo
export RLZoo

export GaussianNetwork
export GaussianNetwork, DuelingNetwork

using CircularArrayBuffers
using ReinforcementLearningBase
236 changes: 236 additions & 0 deletions src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl
@@ -0,0 +1,236 @@
export CRRLearner

"""
CRRLearner(;kwargs)

See paper: [Critic Regularized Regression](https://arxiv.org/abs/2006.15134).

# Keyword arguments

- `approximator`::[`ActorCritic`](@ref): used to get Q-values (Critic) and logits (Actor) of a state.
- `target_approximator`::[`ActorCritic`](@ref): similar to `approximator`, but used to estimate the target.
- `γ::Float32=0.99f0`: reward discount rate.
- `batch_size::Int=32`
- `policy_improvement_mode::Symbol=:binary`: type of the weight function f. Possible values: `:binary`/`:exp`.
- `ratio_upper_bound::Float32=20.0f0`: when `policy_improvement_mode` is `:exp`, the value of the exp function is clipped to this upper bound.
- `beta::Float32=1.0f0`: when `policy_improvement_mode` is `:exp`, the temperature that divides the advantage inside the exponential.
- `advantage_estimator::Symbol=:max`: type of the advantage estimate ``\\hat{A}``. Possible values: `:max`/`:mean`.
- `update_freq::Int=10`: the frequency of updating the `approximator`.
- `update_step::Int=0`
- `target_update_freq::Int=100`: the frequency of syncing the `target_approximator`.
- `continuous::Bool`: whether the action space is continuous (`true`) or discrete (`false`).
- `m::Int=4`: if `continuous=true`, sample `m` actions per state to estimate the advantage.
- `rng = Random.GLOBAL_RNG`
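
# Example

A construction sketch for a discrete-action task (the network sizes and the use of
plain `Chain`s inside `ActorCritic` are illustrative only):

    ns, na = 4, 2  # e.g. CartPole
    create_actor_critic() = ActorCritic(
        actor = Chain(Dense(ns, 64, relu), Dense(64, na)),
        critic = Chain(Dense(ns, 64, relu), Dense(64, na)),
        optimizer = ADAM(),
    )
    learner = CRRLearner(
        approximator = create_actor_critic(),
        target_approximator = create_actor_critic(),
        continuous = false,
        update_freq = 1,
    )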
"""
mutable struct CRRLearner{
Aq<:ActorCritic,
At<:ActorCritic,
R<:AbstractRNG,
} <: AbstractLearner
approximator::Aq
target_approximator::At
γ::Float32
batch_size::Int
policy_improvement_mode::Symbol
ratio_upper_bound::Float32
beta::Float32
advantage_estimator::Symbol
update_freq::Int
update_step::Int
target_update_freq::Int
continuous::Bool
m::Int
rng::R
# for logging
actor_loss::Float32
critic_loss::Float32
end

function CRRLearner(;
approximator::Aq,
target_approximator::At,
γ::Float32 = 0.99f0,
batch_size::Int = 32,
policy_improvement_mode::Symbol = :binary,
ratio_upper_bound::Float32 = 20.0f0,
beta::Float32 = 1.0f0,
advantage_estimator::Symbol = :max,
update_freq::Int = 10,
update_step::Int = 0,
target_update_freq::Int = 100,
continuous::Bool,
m::Int = 4,
rng = Random.GLOBAL_RNG,
) where {Aq<:ActorCritic, At<:ActorCritic}
copyto!(approximator, target_approximator)  # force an initial sync so both networks start with the same weights
CRRLearner(
approximator,
target_approximator,
γ,
batch_size,
policy_improvement_mode,
ratio_upper_bound,
beta,
advantage_estimator,
update_freq,
update_step,
target_update_freq,
continuous,
m,
rng,
0.0f0,
0.0f0,
)
end

Flux.functor(x::CRRLearner) = (Q = x.approximator, Qₜ = x.target_approximator),
y -> begin
x = @set x.approximator = y.Q
x = @set x.target_approximator = y.Qₜ
x
end

function (learner::CRRLearner)(env)
s = state(env)
s = Flux.unsqueeze(s, ndims(s) + 1)
s = send_to_device(device(learner), s)
if learner.continuous
learner.approximator.actor(s; is_sampling=true) |> vec |> send_to_host
else
learner.approximator.actor(s) |> vec |> send_to_host
end
end

function RLBase.update!(learner::CRRLearner, batch::NamedTuple)
if learner.continuous
continuous_update!(learner, batch)
else
discrete_update!(learner, batch)
end
end

function continuous_update!(learner::CRRLearner, batch::NamedTuple)
AC = learner.approximator
target_AC = learner.target_approximator
γ = learner.γ
beta = learner.beta
batch_size = learner.batch_size
policy_improvement_mode = learner.policy_improvement_mode
ratio_upper_bound = learner.ratio_upper_bound
advantage_estimator = learner.advantage_estimator
D = device(AC)

s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)

a = reshape(a, :, batch_size)

target_a_t = target_AC.actor(s′; is_sampling=true)
target_q_input = vcat(s′, target_a_t)
target_q_t = target_AC.critic(target_q_input)

target = r .+ γ .* (1 .- t) .* target_q_t

q_t = Array{Float32}(undef, learner.m, batch_size)
for i in 1:learner.m
a_sample = AC.actor(learner.rng, s; is_sampling=true)
q_t[i, :] = AC.critic(vcat(s, a_sample))
end
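# q_t now holds Q(s, aᵢ) for `m` actions sampled from the current policy;
# its column-wise max or mean serves as the state-value baseline in the advantage below.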

ps = Flux.params(AC)
gs = gradient(ps) do
# Critic loss
q_input = vcat(s, a)
qa_t = AC.critic(q_input)

critic_loss = Flux.Losses.logitcrossentropy(qa_t, target)

log_π = AC.actor(s, a)

# Actor loss
if advantage_estimator == :max
advantage = qa_t .- maximum(q_t, dims=1)
elseif advantage_estimator == :mean
advantage = qa_t .- mean(q_t, dims=1)
else
error("Wrong parameter.")
end

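# Weight f from the CRR paper: `:binary` keeps only actions with positive advantage,
# `:exp` weights them by exp(advantage / beta), clipped at `ratio_upper_bound`.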
if policy_improvement_mode == :binary
actor_loss_coef = Float32.(advantage .> 0.0f0)
elseif policy_improvement_mode == :exp
actor_loss_coef = clamp.(exp.(advantage ./ beta), 0.0f0, ratio_upper_bound)
else
error("Wrong parameter.")
end

actor_loss = mean(-log_π .* Zygote.dropgrad(actor_loss_coef))

ignore() do
learner.actor_loss = actor_loss
learner.critic_loss = critic_loss
end

actor_loss + critic_loss
end

update!(AC, gs)
end

function discrete_update!(learner::CRRLearner, batch::NamedTuple)
AC = learner.approximator
target_AC = learner.target_approximator
γ = learner.γ
beta = learner.beta
batch_size = learner.batch_size
policy_improvement_mode = learner.policy_improvement_mode
ratio_upper_bound = learner.ratio_upper_bound
advantage_estimator = learner.advantage_estimator
D = device(AC)

s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
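# Pair each action index with its batch column so that `q_t[a]` picks, per sample,
# the Q-value of the action actually taken.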
a = CartesianIndex.(a, 1:batch_size)

target_a_t = softmax(target_AC.actor(s′))
target_q_t = target_AC.critic(s′)
expected_target_q = sum(target_a_t .* target_q_t, dims=1)

target = r .+ γ .* (1 .- t) .* expected_target_q

ps = Flux.params(AC)
gs = gradient(ps) do
# Critic loss
q_t = AC.critic(s)
qa_t = q_t[a]
critic_loss = Flux.Losses.logitcrossentropy(qa_t, target)

# Actor loss
a_t = softmax(AC.actor(s))

if advantage_estimator == :max
advantage = qa_t .- maximum(q_t, dims=1)
elseif advantage_estimator == :mean
advantage = qa_t .- mean(q_t, dims=1)
else
error("Wrong parameter.")
end

if policy_improvement_mode == :binary
actor_loss_coef = Float32.(advantage .> 0.0f0)
elseif policy_improvement_mode == :exp
actor_loss_coef = clamp.(exp.(advantage ./ beta), 0.0f0, ratio_upper_bound)
else
error("Wrong parameter.")
end

actor_loss = mean(-log.(a_t[a]) .* Zygote.dropgrad(actor_loss_coef))

ignore() do
learner.actor_loss = actor_loss
learner.critic_loss = critic_loss
end

actor_loss + critic_loss
end

update!(AC, gs)
end
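To make the two `policy_improvement_mode` branches above concrete, this is how the actor-loss coefficient behaves on a toy advantage vector (illustrative values, not part of the diff):

advantage = Float32[-0.5, 0.1, 2.0]
beta, ratio_upper_bound = 1.0f0, 20.0f0
Float32.(advantage .> 0.0f0)                              # :binary → [0.0, 1.0, 1.0]
clamp.(exp.(advantage ./ beta), 0.0f0, ratio_upper_bound) # :exp    → approximately [0.61, 1.11, 7.39]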
92 changes: 92 additions & 0 deletions src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl
@@ -0,0 +1,92 @@
export OfflinePolicy, RLTransition

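# One offline transition in SARTS form: `state`, `action`, `reward`, `terminal`, `next_state`.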
struct RLTransition
state
action
reward
terminal
next_state
end

Base.@kwdef struct OfflinePolicy{L,T} <: AbstractPolicy
learner::L
dataset::Vector{T}
continuous::Bool
batch_size::Int
end

(π::OfflinePolicy)(env) = π(env, ActionStyle(env), action_space(env))

function (π::OfflinePolicy)(env, ::MinimalActionSet, ::Base.OneTo)
if π.continuous
π.learner(env)
else
findmax(π.learner(env))[2]
end
end
(π::OfflinePolicy)(env, ::FullActionSet, ::Base.OneTo) = findmax(π.learner(env), legal_action_space_mask(env))[2]

function (π::OfflinePolicy)(env, ::MinimalActionSet, A)
if π.continuous
π.learner(env)
else
A[findmax(π.learner(env))[2]]
end
end
(π::OfflinePolicy)(env, ::FullActionSet, A) = A[findmax(π.learner(env), legal_action_space_mask(env))[2]]

function RLBase.update!(
p::OfflinePolicy,
traj::AbstractTrajectory,
::AbstractEnv,
::PreActStage,
)
l = p.learner
l.update_step += 1

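# Periodically sync the target network, if the learner has one.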
if in(:target_update_freq, fieldnames(typeof(l))) && l.update_step % l.target_update_freq == 0
copyto!(l.target_approximator, l.approximator)
end

l.update_step % l.update_freq == 0 || return

inds, batch = sample(l.rng, p.dataset, p.batch_size)

update!(l, batch)
end

function StatsBase.sample(rng::AbstractRNG, dataset::Vector{T}, batch_size::Int) where {T}
valid_range = 1:length(dataset)
inds = rand(rng, valid_range, batch_size)
batch_data = dataset[inds]
s_length = size(batch_data[1].state)[1]

s = Array{Float32}(undef, s_length, batch_size)
s′ = Array{Float32}(undef, s_length, batch_size)
a = []
r = []
t = []
for (i, data) in enumerate(batch_data)
s[:, i] = data.state
push!(a, data.action)
s′[:, i] = data.next_state
push!(r, data.reward)
push!(t, data.terminal)
end
batch = NamedTuple{SARTS}((s, a, r, t, s′))
inds, batch
end

"""
calculate_CQL_loss(q_value, action; method)

See paper: [Conservative Q-Learning for Offline Reinforcement Learning](https://arxiv.org/abs/2006.04779)
"""
function calculate_CQL_loss(q_value::Matrix{T}, action::Vector{R}; method = "CQL(H)") where {T, R}
if method == "CQL(H)"
cql_loss = mean(log.(sum(exp.(q_value), dims=1)) .- q_value[action])
else
error("Wrong method parameter. Only \"CQL(H)\" is currently supported.")
end
return cql_loss
end
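For reference, a self-contained sketch of how the pieces in this file might be exercised (toy data and dimensions; assumes the definitions above are in scope):

using Random, StatsBase

# 100 toy transitions with 4-dimensional states and 2 discrete actions.
dataset = [
    RLTransition(rand(Float32, 4), rand(1:2), randn(Float32), rand(Bool), rand(Float32, 4))
    for _ in 1:100
]

inds, batch = sample(Random.GLOBAL_RNG, dataset, 32)
@assert size(batch.state) == (4, 32)   # SARTS batch with keys (:state, :action, :reward, :terminal, :next_state)

# CQL(H) regularizer on a random Q-table for the sampled actions.
q_values = rand(Float32, 2, 32)
actions = CartesianIndex.(batch.action, 1:32)
calculate_CQL_loss(q_values, actions)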
@@ -1 +1,3 @@
include("behavior_cloning.jl")
include("CRR.jl")
include("common.jl")