|
33 | 33 | # consecutive steps, and use this as an input to the policy along with the
|
34 | 34 | # current observation.
|
35 | 35 | #
|
36 | | -# This tutorial shows how to incorporate an RNN in a policy.
| 36 | +# This tutorial shows how to incorporate an RNN in a policy using TorchRL.
37 | 37 | #
|
38 | 38 | # Key learnings:
|
39 | 39 | #
|
|
51 | 51 | # As this figure shows, our environment populates the TensorDict with zeroed recurrent
|
52 | 52 | # states which are read by the policy together with the observation to produce an
|
53 | 53 | # action, and recurrent states that will be used for the next step.
|
54 | | -# When the :func:`torchrl.envs.step_mdp` function is called, the recurrent states
| 54 | +# When the :func:`~torchrl.envs.utils.step_mdp` function is called, the recurrent states
55 | 55 | # from the next state are brought to the current TensorDict. Let's see how this
|
56 | 56 | # is implemented in practice.
|
57 | 57 |
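To make the mechanism above concrete, here is a minimal sketch of one interaction step. It assumes the ``env`` and ``stoch_policy`` objects built later in the tutorial (and the primer transform added further down); only ``step_mdp`` itself is the TorchRL API being illustrated:

    from torchrl.envs.utils import step_mdp

    td = env.reset()       # TensorDict with zeroed recurrent states and "is_init" = True
    td = stoch_policy(td)  # reads observation + recurrent states, writes the action and ("next", <state>) entries
    td = env.step(td)      # fills ("next", "observation"), ("next", "reward"), ("next", "done"), ...
    td = step_mdp(td)      # the "next" entries, recurrent states included, become the new root entries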
|
|
60 | 60 | #
|
61 | 61 | # .. code-block:: bash
|
62 | 62 | #
|
63 | | -# !pip3 install torchrl-nightly
| 63 | +# !pip3 install torchrl
64 | 64 | # !pip3 install gym[mujoco]
|
65 | 65 | # !pip3 install tqdm
|
66 | 66 | #
|
|
104 | 104 | # 84x84, scaling down the rewards and normalizing the observations.
|
105 | 105 | #
|
106 | 106 | # .. note::
|
107 | | -# The :class:`torchrl.envs.StepCounter` transform is accessory. Since the CartPole
| 107 | +# The :class:`~torchrl.envs.transforms.StepCounter` transform is accessory. Since the CartPole
108 | 108 | # task goal is to make trajectories as long as possible, counting the steps
|
109 | 109 | # can help us track the performance of our policy.
|
110 | 110 | #
|
111 | 111 | # Two transforms are important for the purpose of this tutorial:
|
112 | 112 | #
|
113 | | -# - :class:`torchrl.envs.InitTracker` will stamp the
114 | | -# calls to :meth:`torchrl.envs.EnvBase.reset` by adding a ``"is_init"``
| 113 | +# - :class:`~torchrl.envs.transforms.InitTracker` will stamp the
| 114 | +# calls to :meth:`~torchrl.envs.EnvBase.reset` by adding a ``"is_init"``
115 | 115 | # boolean mask in the TensorDict that will track which steps require a reset
|
116 | 116 | # of the RNN hidden states.
|
117 | | -# - The :class:`torchrl.envs.TensorDictPrimer` transform is a bit more
| 117 | +# - The :class:`~torchrl.envs.transforms.TensorDictPrimer` transform is a bit more
118 | 118 | # technical. It is not required to use RNN policies. However, it
|
119 | 119 | # instructs the environment (and subsequently the collector) that some extra
|
120 | 120 | # keys are to be expected. Once added, a call to `env.reset()` will populate
|
|
127 | 127 | # the training of our policy, but it will make the recurrent keys disappear
|
128 | 128 | # from the collected data and the replay buffer, which will in turn lead to
|
129 | 129 | # a slightly less optimal training.
|
130 | | -# Fortunately, the :class:`torchrl.modules.LSTMModule` we propose is
| 130 | +# Fortunately, the :class:`~torchrl.modules.LSTMModule` we propose is
131 | 131 | # equipped with a helper method to build just that transform for us, so
|
132 | 132 | # we can wait until we build it!
|
133 | 133 | #
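As a rough sketch, the transformed environment described above could be assembled as follows; the pixel-processing and normalization transforms are omitted, the arguments are assumptions rather than the tutorial's exact values, and the primer transform is appended later from the LSTM module:

    from torchrl.envs import Compose, InitTracker, StepCounter, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv

    env = TransformedEnv(
        GymEnv("CartPole-v1", from_pixels=True),  # pixel-based CartPole
        Compose(StepCounter(), InitTracker()),    # step counting + reset tracking
    )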
|
|
155 | 155 | # Policy
|
156 | 156 | # ------
|
157 | 157 | #
|
158 | | -# Our policy will have 3 components: a :class:`torchrl.modules.ConvNet`
159 | | -# backbone, an :class:`torchrl.modules.LSTMModule` memory layer and a shallow
160 | | -# :class:`torchrl.modules.MLP` block that will map the LSTM output onto the
| 158 | +# Our policy will have 3 components: a :class:`~torchrl.modules.ConvNet`
| 159 | +# backbone, an :class:`~torchrl.modules.LSTMModule` memory layer and a shallow
| 160 | +# :class:`~torchrl.modules.MLP` block that will map the LSTM output onto the
161 | 161 | # action values.
|
162 | 162 | #
|
163 | 163 | # Convolutional network
|
164 | 164 | # ~~~~~~~~~~~~~~~~~~~~~
|
165 | 165 | #
|
166 | 166 | # We build a convolutional network followed by a :class:`torch.nn.AdaptiveAvgPool2d`
|
167 | | -# that will squash the output in a vector of size 64. The :class:`torchrl.modules.ConvNet`
| 167 | +# that will squash the output in a vector of size 64. The :class:`~torchrl.modules.ConvNet`
168 | 168 | # can assist us with this:
|
169 | 169 | #
|
170 | 170 |
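The backbone code itself is not part of this excerpt; a minimal sketch of such a module, with assumed layer sizes, might look like this (``Mod`` is ``tensordict.nn.TensorDictModule``):

    import torch.nn as nn
    from tensordict.nn import TensorDictModule as Mod
    from torchrl.modules import ConvNet

    feature = Mod(
        ConvNet(
            num_cells=[32, 32, 64],                 # three convolutional layers
            squeeze_output=True,                    # drop the trailing 1x1 spatial dims
            aggregator_class=nn.AdaptiveAvgPool2d,  # pool the feature map to a fixed size
            aggregator_kwargs={"output_size": (1, 1)},
        ),
        in_keys=["pixels"],
        out_keys=["embed"],  # a 64-dimensional feature vector
    )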
|
|
189 | 189 | # LSTM Module
|
190 | 190 | # ~~~~~~~~~~~
|
191 | 191 | #
|
192 | | -# TorchRL provides a specialized :class:`torchrl.modules.LSTMModule` class
193 | | -# to incorporate LSTMs in your code-base. It is a :class:`tensordict.nn.TensorDictModuleBase`
| 192 | +# TorchRL provides a specialized :class:`~torchrl.modules.LSTMModule` class
| 193 | +# to incorporate LSTMs in your code-base. It is a :class:`~tensordict.nn.TensorDictModuleBase`
194 | 194 | # subclass: as such, it has a set of ``in_keys`` and ``out_keys`` that indicate
|
195 | 195 | # what values should be expected to be read and written/updated during the
|
196 | 196 | # execution of the module. The class comes with customizable predefined
|
|
201 | 201 | # dropout or multi-layered LSTMs.
|
202 | 202 | # However, to respect TorchRL's conventions, this LSTM must have the ``batch_first``
|
203 | 203 | # attribute set to ``True``, which is **not** the default in PyTorch. That said,
|
204 | | -# our :class:`torchrl.modules.LSTMModule` changes this default
| 204 | +# our :class:`~torchrl.modules.LSTMModule` changes this default
205 | 205 | # behavior, so we're good with a native call.
|
206 | 206 | #
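A hedged construction sketch of that module follows; the sizes are assumptions, and the ``embed`` key is the output of the convolutional backbone:

    from torchrl.modules import LSTMModule

    lstm = LSTMModule(
        input_size=64,    # matches the 64-dim "embed" vector from the backbone
        hidden_size=128,  # assumed hidden size
        in_key="embed",   # read the observation embedding...
        out_key="embed",  # ...and overwrite it with the LSTM output
    )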
|
207 | 207 | # Also, the LSTM cannot have a ``bidirectional`` attribute set to ``True`` as
|
|
227 | 227 | # as well as recurrent key names. The out_keys are preceded by a "next" prefix
|
228 | 228 | # that indicates that they will need to be written in the "next" TensorDict.
|
229 | 229 | # We use this convention (which can be overridden by passing the in_keys/out_keys
|
230 | | -# arguments) to make sure that a call to :func:`torchrl.envs.step_mdp` will
| 230 | +# arguments) to make sure that a call to :func:`~torchrl.envs.utils.step_mdp` will
231 | 231 | # move the recurrent state to the root TensorDict, making it available to the
|
232 | 232 | # RNN during the following call (see figure in the intro).
|
233 | 233 | #
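With the module built above, the keys can be inspected directly; the names shown in the comments are the library defaults and may differ across versions:

    print(lstm.in_keys)
    # e.g. ['embed', 'recurrent_state_h', 'recurrent_state_c', 'is_init']
    print(lstm.out_keys)
    # e.g. ['embed', ('next', 'recurrent_state_h'), ('next', 'recurrent_state_c')]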
|
234 | 234 | # As mentioned earlier, we have one more optional transform to add to our
|
235 | 235 | # environment to make sure that the recurrent states are passed to the buffer.
|
236 | | -# The :meth:`torchrl.modules.LSTMModule.make_tensordict_primer` method does
| 236 | +# The :meth:`~torchrl.modules.LSTMModule.make_tensordict_primer` method does
237 | 237 | # exactly that:
|
238 | 238 | #
|
239 | 239 | env.append_transform(lstm.make_tensordict_primer())
|
|
268 | 268 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
269 | 269 | #
|
270 | 270 | # The last part of our policy is the Q-Value Module.
|
271 | | -# The Q-Value module :class:`torchrl.modules.QValueModule`
| 271 | +# The Q-Value module :class:`~torchrl.modules.tensordict_module.QValueModule`
272 | 272 | # will read the ``"action_values"`` key that is produced by our MLP and
|
273 | 273 | # from it, gather the action that has the maximum value.
|
274 | 274 | # The only thing we need to do is to specify the action space, which can be done
|
|
280 | 280 | ######################################################################
|
281 | 281 | # .. note::
|
282 | 282 | # TorchRL also provides a wrapper class :class:`torchrl.modules.QValueActor` that
|
283 | | -# wraps a module in a Sequential together with a :class:`torchrl.modules.QValueModule`
| 283 | +# wraps a module in a Sequential together with a :class:`~torchrl.modules.tensordict_module.QValueModule`
284 | 284 | # like we are doing explicitly here. There is little advantage to doing this
|
285 | 285 | # and the process is less transparent, but the end results will be similar to
|
286 | 286 | # what we do here.
|
287 | 287 | #
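For reference, a hedged sketch of the Q-value head described above; passing the environment's action spec is one way of specifying the action space:

    from torchrl.modules import QValueModule

    # Reads the action-value key produced by the MLP and writes the greedy "action".
    qval = QValueModule(action_space=None, spec=env.action_spec)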
|
288 | | -# We can now put things together in a :class:`tensordict.nn.TensorDictSequential`
| 288 | +# We can now put things together in a :class:`~tensordict.nn.TensorDictSequential`
289 | 289 | #
|
290 | 290 | stoch_policy = Seq(feature, lstm, mlp, qval)
|
291 | 291 |
|
292 | 292 | ######################################################################
|
293 | 293 | # Since DQN is a deterministic algorithm, exploration is a crucial part of it.
|
294 | 294 | # We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying
|
295 | 295 | # progressively to 0.
|
296 | | -# This decay is achieved via a call to :meth:`torchrl.modules.EGreedyWrapper.step`
| 296 | +# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyWrapper.step`
297 | 297 | # (see training loop below).
|
298 | 298 | #
|
299 | 299 | stoch_policy = EGreedyWrapper(
|
|
311 | 311 | # To use it, we just need to tell the LSTM module to run in "recurrent-mode"
|
312 | 312 | # when used by the loss.
|
313 | 313 | # As we'll usually want to have two copies of the LSTM module, we do this by
|
314 | | -# calling a :meth:`torchrl.modules.LSTMModule.set_recurrent_mode` method that
| 314 | +# calling a :meth:`~torchrl.modules.LSTMModule.set_recurrent_mode` method that
315 | 315 | # will return a new instance of the LSTM (with shared weights) that will
|
316 | 316 | # assume that the input data is sequential in nature.
|
317 | 317 | #
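Concretely, the loss-side policy can reuse the same sub-modules with a recurrent-mode copy of the LSTM; this is a sketch, with variable names following the ones used above:

    # Same weights as stoch_policy, but the LSTM now expects whole trajectories.
    policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)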
|
|
329 | 329 | #
|
330 | 330 | # Our DQN loss requires us to pass the policy and, again, the action space.
|
331 | 331 | # While this may seem redundant, it is important as we want to make sure that
|
332 | | -# the :class:`torchrl.objectives.DQNLoss` and the :class:`torchrl.modules.QValueModule`
| 332 | +# the :class:`~torchrl.objectives.DQNLoss` and the :class:`~torchrl.modules.tensordict_module.QValueModule`
333 | 333 | # classes are compatible, but aren't strongly dependent on each other.
|
334 | 334 | #
|
335 | 335 | # To use the Double-DQN, we ask for a ``delay_value`` argument that will
|
|
339 | 339 |
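A hedged sketch of the loss construction this paragraph describes; ``policy`` is the recurrent-mode policy sketched above, and the argument choices are assumptions:

    from torchrl.objectives import DQNLoss

    loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)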
|
340 | 340 | ######################################################################
|
341 | 341 | # Since we are using a double DQN, we need to update the target parameters.
|
342 | | -# We'll use a :class:`torchrl.objectives.SoftUpdate` instance to carry out
| 342 | +# We'll use a :class:`~torchrl.objectives.SoftUpdate` instance to carry out
343 | 343 | # this work.
|
344 | 344 | #
|
345 | 345 | updater = SoftUpdate(loss_fn, eps=0.95)
|
|
355 | 355 | # will be designed to store 20 thousand trajectories of 50 steps each.
|
356 | 356 | # At each optimization step (16 per data collection), we'll collect 4 items
|
357 | 357 | # from our buffer, for a total of 200 transitions.
|
358 | | -# We'll use a :class:`torchrl.data.LazyMemmapStorage` storage to keep the data
| 358 | +# We'll use a :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` storage to keep the data
359 | 359 | # on disk.
|
360 | 360 | #
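A hedged sketch of such a buffer; the sizes follow the figures above, and ``prefetch`` is an assumption:

    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

    rb = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(20_000),  # up to 20,000 trajectories, memory-mapped on disk
        batch_size=4,                       # 4 trajectories of 50 steps = 200 transitions per sample
        prefetch=10,
    )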
|
361 | 361 | # .. note::
|
|
394 | 394 | # it is important to pass data that is not flattened
395 | 395 | rb.extend(data.unsqueeze(0).to_tensordict().cpu())
396 | 396 | for _ in range(utd):
397 | | -    s = rb.sample().to(device)
| 397 | +    s = rb.sample().to(device, non_blocking=True)
398 | 398 |     loss_vals = loss_fn(s)
399 | 399 |     loss_vals["loss"].backward()
400 | 400 |     optim.step()
|
|
424 | 424 | # Conclusion
|
425 | 425 | # ----------
|
426 | 426 | #
|
427 | | -# We have seen how an RNN can be incorporated in a policy in torchrl.
| 427 | +# We have seen how an RNN can be incorporated in a policy in TorchRL.
428 | 428 | # You should now be able to:
|
429 | 429 | #
|
430 | | -# - Create an LSTM module that acts as a :class:`tensordict.nn.TensorDictModule`
431 | | -# - Indicate to the LSTM module that a reset is needed via an :class:`torchrl.envs.InitTracker`
| 430 | +# - Create an LSTM module that acts as a :class:`~tensordict.nn.TensorDictModule`
| 431 | +# - Indicate to the LSTM module that a reset is needed via an :class:`~torchrl.envs.transforms.InitTracker`
432 | 432 | # transform
|
433 | 433 | # - Incorporate this module in a policy and in a loss module
|
434 | 434 | # - Make sure that the collector is made aware of the recurrent state entries
|
|
437 | 437 | #
|
438 | 438 | # Further Reading
|
439 | 439 | # ---------------
|
440 | | -#
441 | | -# - `TorchRL <https://pytorch.org/rl/>`
442 | | -
| 440 | +#
| 441 | +# - The TorchRL documentation can be found `here <https://pytorch.org/rl/>`_.