@@ -144,3 +144,169 @@ index 531c523..78b4e1c 100644
# Using a placeholder to update the variable during restore to avoid memory leak.
# Ref: https://github.com/tensorflow/tensorflow/issues/4151
+ diff --git a/rl_coach/agents/actor_critic_agent.py b/rl_coach/agents/actor_critic_agent.py
+ index 35c8bf9..4f3ce60 100644
+ --- a/rl_coach/agents/actor_critic_agent.py
+ +++ b/rl_coach/agents/actor_critic_agent.py
+ @@ -94,11 +94,14 @@ class ActorCriticAgentParameters(AgentParameters):
+ class ActorCriticAgent(PolicyOptimizationAgent):
+ def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
+ super().__init__(agent_parameters, parent)
+ + print("[RL] ActorCriticAgent init")
+ self.last_gradient_update_step_idx = 0
+ self.action_advantages = self.register_signal('Advantages')
+ self.state_values = self.register_signal('Values')
+ self.value_loss = self.register_signal('Value Loss')
+ self.policy_loss = self.register_signal('Policy Loss')
+ + print("[RL] ActorCriticAgent init successful")
+ +
+
+ # Discounting function used to calculate discounted returns.
+ def discount(self, x, gamma):
+ diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
+ index 5d12e0b..0ee3cfb 100644
+ --- a/rl_coach/agents/agent.py
+ +++ b/rl_coach/agents/agent.py
+ @@ -74,7 +74,7 @@ class Agent(AgentInterface):
+ self.imitation = False
+ self.agent_logger = Logger()
+ self.agent_episode_logger = EpisodeLogger()
+ -
+ + print("[RL] Created agent loggers")
+ # get the memory
+ # - distributed training + shared memory:
+ # * is chief? -> create the memory and add it to the scratchpad
+ @@ -84,22 +84,30 @@ class Agent(AgentInterface):
+ memory_name = self.ap.memory.path.split(':')[1]
+ self.memory_lookup_name = self.full_name_id + '.' + memory_name
+ if self.shared_memory and not self.is_chief:
+ + print("[RL] Creating shared memory")
+ self.memory = self.shared_memory_scratchpad.get(self.memory_lookup_name)
+ else:
+ + print("[RL] Dynamic import of memory: ", self.ap.memory)
+ # modules
+ self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
+ + print("[RL] Dynamically imported of memory", self.memory)
190
+
+ if hasattr(self.ap.memory, 'memory_backend_params'):
+ + print("[RL] Getting memory backend", self.ap.memory.memory_backend_params)
+ self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
+ + print("[RL] Memory backend", self.memory_backend)
+
+ if self.ap.memory.memory_backend_params.run_type != 'trainer':
+ + print("[RL] Setting memory backend", self.memory_backend)
+ self.memory.set_memory_backend(self.memory_backend)
+
+ if self.shared_memory and self.is_chief:
+ + print("[RL] Shared memory scratchpad")
+ self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
+
+ # set devices
+ if type(agent_parameters.task_parameters) == DistributedTaskParameters:
+ + print("[RL] Setting distributed devices")
+ self.has_global = True
+ self.replicated_device = agent_parameters.task_parameters.device
+ self.worker_device = "/job:worker/task:{}".format(self.task_id)
+ @@ -108,6 +116,7 @@ class Agent(AgentInterface):
+ else:
+ self.worker_device += "/device:GPU:0"
+ else:
+ + print("[RL] Setting devices")
+ self.has_global = False
+ self.replicated_device = None
+ if agent_parameters.task_parameters.use_cpu:
+ @@ -115,7 +124,7 @@ class Agent(AgentInterface):
+ else:
+ self.worker_device = [Device(DeviceType.GPU, i)
+ for i in range(agent_parameters.task_parameters.num_gpu)]
+ -
+ + print("[RL] Setting filters")
+ # filters
+ self.input_filter = self.ap.input_filter
+ self.input_filter.set_name('input_filter')
+ @@ -134,21 +143,26 @@ class Agent(AgentInterface):
+ # 3. Single worker (=both TF and Mxnet) - no data sharing needed + numpy arithmetic backend
+
+ if hasattr(self.ap.memory, 'memory_backend_params') and self.ap.algorithm.distributed_coach_synchronization_type:
+ + print("[RL] Setting filter devices: distributed")
+ self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy')
+ self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy')
+ self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy')
+ elif (type(agent_parameters.task_parameters) == DistributedTaskParameters and
+ agent_parameters.task_parameters.framework_type == Frameworks.tensorflow):
+ + print("[RL] Setting filter devices: tf")
+ self.input_filter.set_device(device, mode='tf')
+ self.output_filter.set_device(device, mode='tf')
+ self.pre_network_filter.set_device(device, mode='tf')
+ else:
+ + print("[RL] Setting filter devices: numpy")
+ self.input_filter.set_device(device, mode='numpy')
+ self.output_filter.set_device(device, mode='numpy')
+ self.pre_network_filter.set_device(device, mode='numpy')
+
+ # initialize all internal variables
+ + print("[RL] Setting Phase")
+ self._phase = RunPhase.HEATUP
+ + print("[RL] After setting Phase")
+ self.total_shaped_reward_in_current_episode = 0
+ self.total_reward_in_current_episode = 0
+ self.total_steps_counter = 0
+ @@ -180,7 +194,7 @@ class Agent(AgentInterface):
+ # environment parameters
+ self.spaces = None
+ self.in_action_space = self.ap.algorithm.in_action_space
+ -
+ + print("[RL] Setting signals")
+ # signals
+ self.episode_signals = []
+ self.step_signals = []
+ @@ -195,6 +209,8 @@ class Agent(AgentInterface):
+
+ # batch rl
+ self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None
+ + print("[RL] Agent init successful")
+ +
+
+ @property
+ def parent(self) -> 'LevelManager':
+ diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
+ index 866fe8a..3e89908 100644
+ --- a/rl_coach/agents/agent.py
+ +++ b/rl_coach/agents/agent.py
+ @@ -28,6 +28,8 @@ from rl_coach.base_parameters import AgentParameters, Device, DeviceType, Distri
+ from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType
+ from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
+ from rl_coach.logger import screen, Logger, EpisodeLogger
+ + from rl_coach.memories.memory import Memory
+ + from rl_coach.memories.non_episodic.experience_replay import ExperienceReplay
+ from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
+ from rl_coach.saver import SaverCollection
+ from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
+ @@ -572,7 +574,8 @@ class Agent(AgentInterface):
+ self.current_episode += 1
+
+ if self.phase != RunPhase.TEST:
+ - if isinstance(self.memory, EpisodicExperienceReplay):
+ + if isinstance(self.memory, EpisodicExperienceReplay) or \
+ + (isinstance(self.memory, Memory) and not isinstance(self.memory, ExperienceReplay)):
+ self.call_memory('store_episode', self.current_episode_buffer)
+ elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
+ for transition in self.current_episode_buffer.transitions:
+ @@ -618,7 +621,8 @@ class Agent(AgentInterface):
+ self.input_filter.reset()
+ self.output_filter.reset()
+ self.pre_network_filter.reset()
+ - if isinstance(self.memory, EpisodicExperienceReplay):
+ + if isinstance(self.memory, EpisodicExperienceReplay) or \
+ + (isinstance(self.memory, Memory) and not isinstance(self.memory, ExperienceReplay)):
+ self.call_memory('verify_last_episode_is_closed')
+
+ for network in self.networks.values():
+ @@ -953,7 +957,7 @@ class Agent(AgentInterface):
+ # for episodic memories we keep the transitions in a local buffer until the episode is ended.
+ # for regular memories we insert the transitions directly to the memory
+ self.current_episode_buffer.insert(transition)
+ - if not isinstance(self.memory, EpisodicExperienceReplay) \
+ + if isinstance(self.memory, ExperienceReplay) \
+ and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
+ self.call_memory('store', transition)