Skip to content

Commit dc0de9e

Browse files
committed
Address comments
1 parent 4c36093 commit dc0de9e

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

beginner_source/basics/saveloadrun_tutorial.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@
3232
##########################
3333
# To load model weights, you need to create an instance of the same model first, and then load the parameters
3434
# using ``load_state_dict()`` method.
35+
#
36+
# In the code below, we set ``weights_only=True`` to limit the
37+
# functions executed during unpickling to only those necessary for
38+
# loading weights. Using ``weights_only=True`` is considered
39+
# a best practice when loading weights.
3540

3641
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
3742
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
@@ -50,7 +55,12 @@
5055
torch.save(model, 'model.pth')
5156

5257
########################
53-
# We can then load the model like this:
58+
# We can then load the model as demonstrated below.
59+
#
60+
# As described in `Saving and loading torch.nn.Modules <pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules>`__,
61+
# saving ``state_dict``s is considered the best practice. However,
62+
# below we use ``weights_only=False`` because this involves loading the
63+
# model, which is a legacy use case for ``torch.save``.
5464

5565
model = torch.load('model.pth', weights_only=False),
5666

0 commit comments

Comments
 (0)