
Commit eec1640

Updates per review

* mostly better looking links
* torchrl and tensordict bump to 0.2.1 to support MacOS
* updated image
* updated Further Reading to go to TorchRL docs

Signed-off-by: markstur <[email protected]>

1 parent c549d04 · commit eec1640

File tree

4 files changed: +32 -34 lines


_static/img/rollout_recurrent.png

35.5 KB

en-wordlist.txt

Lines changed: 0 additions & 1 deletion
@@ -208,7 +208,6 @@ TorchDynamo
 TorchInductor
 TorchMultimodal
 TorchRL
-torchrl
 TorchRL's
 TorchScript
 TorchX

intermediate_source/dqn_with_rnn_tutorial.py

Lines changed: 31 additions & 32 deletions
@@ -33,7 +33,7 @@
 # consecutive steps, and use this as an input to the policy along with the
 # current observation.
 #
-# This tutorial shows how to incorporate an RNN in a policy.
+# This tutorial shows how to incorporate an RNN in a policy using TorchRL.
 #
 # Key learnings:
 #

@@ -51,7 +51,7 @@
 # As this figure shows, our environment populates the TensorDict with zeroed recurrent
 # states which are read by the policy together with the observation to produce an
 # action, and recurrent states that will be used for the next step.
-# When the :func:`torchrl.envs.step_mdp` function is called, the recurrent states
+# When the :func:`~torchrl.envs.utils.step_mdp` function is called, the recurrent states
 # from the next state are brought to the current TensorDict. Let's see how this
 # is implemented in practice.

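
For context, a minimal sketch of the rollout step this describes, assuming the ``env`` (with the recurrent-state primer appended) and the ``stoch_policy`` that the tutorial builds further down:

from torchrl.envs.utils import step_mdp

td = env.reset()       # the primer transform fills in zeroed recurrent states
td = stoch_policy(td)  # the policy writes the updated states under ("next", ...)
td = env.step(td)      # the environment fills in ("next", "observation"), reward, done
td = step_mdp(td)      # "next" entries, recurrent states included, become the new root
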

@@ -60,7 +60,7 @@
 #
 # .. code-block:: bash
 #
-#    !pip3 install torchrl-nightly
+#    !pip3 install torchrl
 #    !pip3 install gym[mujoco]
 #    !pip3 install tqdm
 #

@@ -104,17 +104,17 @@
 # 84x84, scaling down the rewards and normalizing the observations.
 #
 # .. note::
-#   The :class:`torchrl.envs.StepCounter` transform is accessory. Since the CartPole
+#   The :class:`~torchrl.envs.transforms.StepCounter` transform is accessory. Since the CartPole
 #   task goal is to make trajectories as long as possible, counting the steps
 #   can help us track the performance of our policy.
 #
 # Two transforms are important for the purpose of this tutorial:
 #
-# - :class:`torchrl.envs.InitTracker` will stamp the
-#   calls to :meth:`torchrl.envs.EnvBase.reset` by adding a ``"is_init"``
+# - :class:`~torchrl.envs.transforms.InitTracker` will stamp the
+#   calls to :meth:`~torchrl.envs.EnvBase.reset` by adding a ``"is_init"``
 #   boolean mask in the TensorDict that will track which steps require a reset
 #   of the RNN hidden states.
-# - The :class:`torchrl.envs.TensorDictPrimer` transform is a bit more
+# - The :class:`~torchrl.envs.transforms.TensorDictPrimer` transform is a bit more
 #   technical. It is not required to use RNN policies. However, it
 #   instructs the environment (and subsequently the collector) that some extra
 #   keys are to be expected. Once added, a call to `env.reset()` will populate
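
For reference, a minimal sketch of a transformed CartPole environment along these lines (TorchRL 0.2.x imports assumed; the transform list is illustrative rather than the tutorial's exact one, which also rescales rewards and normalizes observations):

from torchrl.envs import (
    Compose, GrayScale, InitTracker, Resize, StepCounter, ToTensorImage, TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(),  # uint8 HWC frames -> float CHW tensors
        GrayScale(),
        Resize(84, 84),
        StepCounter(),    # accessory: tracks trajectory length
        InitTracker(),    # adds the "is_init" mask used to reset the RNN state
    ),
)
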
@@ -127,7 +127,7 @@
 # the training of our policy, but it will make the recurrent keys disappear
 # from the collected data and the replay buffer, which will in turn lead to
 # a slightly less optimal training.
-# Fortunately, the :class:`torchrl.modules.LSTMModule` we propose is
+# Fortunately, the :class:`~torchrl.modules.LSTMModule` we propose is
 # equipped with a helper method to build just that transform for us, so
 # we can wait until we build it!
 #

@@ -155,16 +155,16 @@
 # Policy
 # ------
 #
-# Our policy will have 3 components: a :class:`torchrl.modules.ConvNet`
-# backbone, an :class:`torchrl.modules.LSTMModule` memory layer and a shallow
-# :class:`torchrl.modules.MLP` block that will map the LSTM output onto the
+# Our policy will have 3 components: a :class:`~torchrl.modules.ConvNet`
+# backbone, an :class:`~torchrl.modules.LSTMModule` memory layer and a shallow
+# :class:`~torchrl.modules.MLP` block that will map the LSTM output onto the
 # action values.
 #
 # Convolutional network
 # ~~~~~~~~~~~~~~~~~~~~~
 #
 # We build a convolutional network flanked with a :class:`torch.nn.AdaptiveAvgPool2d`
-# that will squash the output in a vector of size 64. The :class:`torchrl.modules.ConvNet`
+# that will squash the output in a vector of size 64. The :class:`~torchrl.modules.ConvNet`
 # can assist us with this:
 #
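
A sketch of that backbone, wrapped as a :class:`~tensordict.nn.TensorDictModule` so it reads ``"pixels"`` and writes ``"embed"`` (the exact kwargs are indicative of the tutorial's setup, not guaranteed to match it):

from tensordict.nn import TensorDictModule as Mod
from torch import nn
from torchrl.modules import ConvNet

feature = Mod(
    ConvNet(
        num_cells=[32, 32, 64],
        squeeze_output=True,
        aggregator_class=nn.AdaptiveAvgPool2d,      # squash the spatial dims ...
        aggregator_kwargs={"output_size": (1, 1)},  # ... down to a 64-dim embedding
    ),
    in_keys=["pixels"],
    out_keys=["embed"],
)
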

@@ -189,8 +189,8 @@
 # LSTM Module
 # ~~~~~~~~~~~
 #
-# TorchRL provides a specialized :class:`torchrl.modules.LSTMModule` class
-# to incorporate LSTMs in your code-base. It is a :class:`tensordict.nn.TensorDictModuleBase`
+# TorchRL provides a specialized :class:`~torchrl.modules.LSTMModule` class
+# to incorporate LSTMs in your code-base. It is a :class:`~tensordict.nn.TensorDictModuleBase`
 # subclass: as such, it has a set of ``in_keys`` and ``out_keys`` that indicate
 # what values should be expected to be read and written/updated during the
 # execution of the module. The class comes with customizable predefined

@@ -201,7 +201,7 @@
 # dropout or multi-layered LSTMs.
 # However, to respect TorchRL's conventions, this LSTM must have the ``batch_first``
 # attribute set to ``True`` which is **not** the default in PyTorch. However,
-# our :class:`torchrl.modules.LSTMModule` changes this default
+# our :class:`~torchrl.modules.LSTMModule` changes this default
 # behavior, so we're good with a native call.
 #
 # Also, the LSTM cannot have a ``bidirectional`` attribute set to ``True`` as
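
A sketch of the corresponding module construction (hidden size and key names follow the tutorial's conventions and should be treated as indicative):

from torchrl.modules import LSTMModule

lstm = LSTMModule(
    input_size=64,    # size of the "embed" vector coming out of the ConvNet backbone
    hidden_size=128,
    in_key="embed",
    out_key="embed",  # the LSTM output overwrites "embed" for the downstream MLP
)
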
@@ -227,13 +227,13 @@
 # as well as recurrent key names. The out_keys are preceded by a "next" prefix
 # that indicates that they will need to be written in the "next" TensorDict.
 # We use this convention (which can be overridden by passing the in_keys/out_keys
-# arguments) to make sure that a call to :func:`torchrl.envs.step_mdp` will
+# arguments) to make sure that a call to :func:`~torchrl.envs.utils.step_mdp` will
 # move the recurrent state to the root TensorDict, making it available to the
 # RNN during the following call (see figure in the intro).
 #
 # As mentioned earlier, we have one more optional transform to add to our
 # environment to make sure that the recurrent states are passed to the buffer.
-# The :meth:`torchrl.modules.LSTMModule.make_tensordict_primer` method does
+# The :meth:`~torchrl.modules.LSTMModule.make_tensordict_primer` method does
 # exactly that:
 #
 env.append_transform(lstm.make_tensordict_primer())
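
After this call, the recurrent entries the module reads and writes can be inspected directly; the key names below only illustrate the convention described above and may differ slightly across TorchRL versions:

print(lstm.in_keys)
# e.g. ['embed', 'recurrent_state_h', 'recurrent_state_c', 'is_init']
print(lstm.out_keys)
# e.g. ['embed', ('next', 'recurrent_state_h'), ('next', 'recurrent_state_c')]
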
@@ -268,7 +268,7 @@
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # The last part of our policy is the Q-Value Module.
-# The Q-Value module :class:`torchrl.modules.QValueModule`
+# The Q-Value module :class:`~torchrl.modules.tensordict_module.QValueModule`
 # will read the ``"action_values"`` key that is produced by our MLP and
 # from it, gather the action that has the maximum value.
 # The only thing we need to do is to specify the action space, which can be done
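
A sketch of the Q-value head, passing the environment's action spec (per the text, an action-space string can be passed instead):

from torchrl.modules import QValueModule

qval = QValueModule(action_space=env.action_spec)
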
@@ -280,20 +280,20 @@
 ######################################################################
 # .. note::
 #   TorchRL also provides a wrapper class :class:`torchrl.modules.QValueActor` that
-#   wraps a module in a Sequential together with a :class:`torchrl.modules.QValueModule`
+#   wraps a module in a Sequential together with a :class:`~torchrl.modules.tensordict_module.QValueModule`
 #   like we are doing explicitly here. There is little advantage to do this
 #   and the process is less transparent, but the end results will be similar to
 #   what we do here.
 #
-# We can now put things together in a :class:`tensordict.nn.TensorDictSequential`
+# We can now put things together in a :class:`~tensordict.nn.TensorDictSequential`
 #
 stoch_policy = Seq(feature, lstm, mlp, qval)

 ######################################################################
 # DQN being a deterministic algorithm, exploration is a crucial part of it.
 # We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying
 # progressively to 0.
-# This decay is achieved via a call to :meth:`torchrl.modules.EGreedyWrapper.step`
+# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyWrapper.step`
 # (see training loop below).
 #
 stoch_policy = EGreedyWrapper(
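
The call above is cut off by the hunk boundary; a hedged sketch of how such a wrapper is typically configured and annealed (the keyword values are assumptions guided by the surrounding text, not necessarily the tutorial's exact call):

stoch_policy = EGreedyWrapper(
    stoch_policy,
    eps_init=0.2,                  # start at epsilon = 0.2, as stated above
    eps_end=0.0,                   # ... and decay progressively to 0
    annealing_num_steps=1_000_000,
    spec=env.action_spec,
)

# inside the collection loop, decay epsilon by the number of frames just collected:
stoch_policy.step(data.numel())
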
@@ -311,7 +311,7 @@
 # To use it, we just need to tell the LSTM module to run on "recurrent-mode"
 # when used by the loss.
 # As we'll usually want to have two copies of the LSTM module, we do this by
-# calling a :meth:`torchrl.modules.LSTMModule.set_recurrent_mode` method that
+# calling a :meth:`~torchrl.modules.LSTMModule.set_recurrent_mode` method that
 # will return a new instance of the LSTM (with shared weights) that will
 # assume that the input data is sequential in nature.
 #
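
A sketch of the loss-side policy this describes, reusing the module names from the tutorial (``feature``, ``mlp``, ``qval`` and the ``Seq`` alias for :class:`~tensordict.nn.TensorDictSequential`):

policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
policy(env.reset())  # a single forward pass to materialize any lazily-initialized parameters
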
@@ -329,7 +329,7 @@
 #
 # Out DQN loss requires us to pass the policy and, again, the action-space.
 # While this may seem redundant, it is important as we want to make sure that
-# the :class:`torchrl.objectives.DQNLoss` and the :class:`torchrl.modules.QValueModule`
+# the :class:`~torchrl.objectives.DQNLoss` and the :class:`~torchrl.modules.tensordict_module.QValueModule`
 # classes are compatible, but aren't strongly dependent on each other.
 #
 # To use the Double-DQN, we ask for a ``delay_value`` argument that will
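
A sketch of the loss construction described here; ``delay_value=True`` creates the delayed (target) value network used by Double-DQN:

from torchrl.objectives import DQNLoss

loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)
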
@@ -339,7 +339,7 @@

 ######################################################################
 # Since we are using a double DQN, we need to update the target parameters.
-# We'll use a :class:`torchrl.objectives.SoftUpdate` instance to carry out
+# We'll use a :class:`~torchrl.objectives.SoftUpdate` instance to carry out
 # this work.
 #
 updater = SoftUpdate(loss_fn, eps=0.95)
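
The updater only has an effect when it is stepped; in the training loop it is typically called once per optimization step, roughly as sketched below (placement is an assumption consistent with the loop shown later):

optim.step()
optim.zero_grad()
updater.step()  # soft-update the target network toward the online parameters
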
@@ -355,7 +355,7 @@
 # will be designed to store 20 thousands trajectories of 50 steps each.
 # At each optimization step (16 per data collection), we'll collect 4 items
 # from our buffer, for a total of 200 transitions.
-# We'll use a :class:`torchrl.data.LazyMemmapStorage` storage to keep the data
+# We'll use a :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` storage to keep the data
 # on disk.
 #
 # .. note::
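
A sketch of the buffer described above; the storage capacity and batch size come from the text, while ``prefetch`` is an assumption:

from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(20_000),  # up to 20k stored trajectories, memory-mapped on disk
    batch_size=4,
    prefetch=10,
)
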
@@ -394,7 +394,7 @@
     # it is important to pass data that is not flattened
     rb.extend(data.unsqueeze(0).to_tensordict().cpu())
     for _ in range(utd):
-        s = rb.sample().to(device)
+        s = rb.sample().to(device, non_blocking=True)
         loss_vals = loss_fn(s)
         loss_vals["loss"].backward()
         optim.step()

@@ -424,11 +424,11 @@
 # Conclusion
 # ----------
 #
-# We have seen how an RNN can be incorporated in a policy in torchrl.
+# We have seen how an RNN can be incorporated in a policy in TorchRL.
 # You should now be able:
 #
-# - Create an LSTM module that acts as a :class:`tensordict.nn.TensorDictModule`
-# - Indicate to the LSTM module that a reset is needed via an :class:`torchrl.envs.InitTracker`
+# - Create an LSTM module that acts as a :class:`~tensordict.nn.TensorDictModule`
+# - Indicate to the LSTM module that a reset is needed via an :class:`~torchrl.envs.transforms.InitTracker`
 #   transform
 # - Incorporate this module in a policy and in a loss module
 # - Make sure that the collector is made aware of the recurrent state entries

@@ -437,6 +437,5 @@
 #
 # Further Reading
 # ---------------
-#
-# - `TorchRL <https://pytorch.org/rl/>`
-
+#
+# - The TorchRL documentation can be found `here <https://pytorch.org/rl/>`_.

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jinja2==3.0.3
2626
pytorch-lightning
2727
torchx
2828
torchrl==0.2.1
29-
tensordict==0.2.0
29+
tensordict==0.2.1
3030
ax-platform
3131
nbformat>=4.2.0
3232
datasets
