
Commit 88bafc0

Author: Deng

address comments

1 parent efba0e1

1 file changed (+4, -3)

doc/using_pytorch.rst

Lines changed: 4 additions & 3 deletions
@@ -126,8 +126,9 @@ For example:
     import os
     import torch
 
-    model_path = os.path.join(model_dir, "model.pt")
-    model = torch.jit.load(model_path)
+    # ... train `model`, then save it to `model_dir`
+    model_dir = os.path.join(model_dir, "model.pt")
+    torch.jit.save(model, model_dir)
 
 Using third-party libraries
 ---------------------------
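
For context, the save-side pattern this hunk introduces runs roughly as follows. This is a minimal sketch, not part of the commit: ``SimpleNet`` and the ``/opt/ml/model`` path are hypothetical stand-ins for a real trained model and the SageMaker model directory::

    import os

    import torch


    class SimpleNet(torch.nn.Module):
        """Hypothetical stand-in for whatever model the training script builds."""

        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)


    # "Train" the model, then compile it to a ScriptModule so it can be
    # serialized with torch.jit.save.
    model = torch.jit.script(SimpleNet())

    model_dir = "/opt/ml/model"  # assumed SageMaker model directory
    model_dir = os.path.join(model_dir, "model.pt")
    torch.jit.save(model, model_dir)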
@@ -315,7 +316,7 @@ It loads the model parameters from a ``model.pth`` file in the SageMaker model d
 However, if you are using PyTorch Elastic Inference, you do not have to provide a ``model_fn`` since the PyTorch serving
 container has a default one for you. But please note that if you are utilizing the default ``model_fn``, please save
 your ScriptModule as ``model.pt``. If you are implementing your own ``model_fn``, please use TorchScript and ``torch.jit.save``
-to save your ScriptModule. For more information on inference script, please refer to:
+to save your ScriptModule, then load it in your ``model_fn`` with ``torch.jit.load``. For more information on inference script, please refer to:
 `SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/default_inference_handler.py>`_.
 
 Serve a PyTorch Model
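
On the load side, a custom ``model_fn`` following the amended guidance might look like the sketch below. This is illustrative, not part of the commit; the ``model_fn(model_dir)`` signature follows the SageMaker PyTorch inference-script convention, and the ``model.pt`` file name matches what the training snippet above saves::

    import os

    import torch


    def model_fn(model_dir):
        """Load the TorchScript model that the training script saved as model.pt.

        ``model_dir`` is the directory into which SageMaker extracts the model
        artifact before invoking this hook.
        """
        model = torch.jit.load(os.path.join(model_dir, "model.pt"))
        model.eval()  # switch to inference mode
        return model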
