Skip to content

Commit e460aa2

Browse files
authored
try to fix bugs of ActionTransformedEnv (#447)
1 parent 7078535 commit e460aa2

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/ReinforcementLearningCore/src/policies/agents/agent.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ function RLBase.update!(
137137
# the global rng. In theory it shouldn't affect the performance of specific
138138
# algorithm.
139139
# TODO: how to inject a local rng here to avoid polluting the global rng
140-
action = rand(action_space(env))
141140

142141
s = policy isa NamedPolicy ? state(env, nameof(policy)) : state(env)
142+
a = policy isa NamedPolicy ? rand(action_space(env, nameof(policy))) : rand(action_space(env))
143143
push!(trajectory[:state], s)
144-
push!(trajectory[:action], action)
144+
push!(trajectory[:action], a)
145145
if haskey(trajectory, :legal_actions_mask)
146146
lasm =
147147
policy isa NamedPolicy ? legal_action_space_mask(env, nameof(policy)) :

src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ ActionTransformedEnv(env; action_mapping = identity, action_space_mapping = iden
1717
ActionTransformedEnv(env, action_mapping, action_space_mapping)
1818

1919
RLBase.action_space(env::ActionTransformedEnv, args...) =
20-
env.action_space_mapping(action_space(env.env), args...)
20+
env.action_space_mapping(action_space(env.env, args...))
2121

2222
RLBase.legal_action_space(env::ActionTransformedEnv, args...) =
23-
env.action_space_mapping(legal_action_space(env.env), args...)
23+
env.action_space_mapping(legal_action_space(env.env, args...))
2424

2525
(env::ActionTransformedEnv)(action, args...; kwargs...) =
2626
env.env(env.action_mapping(action), args...; kwargs...)

0 commit comments

Comments
 (0)