
Commit 143abcc

documentation: update predict_fn implementation for PyTorch EIA 1.5.1 (#2053)
1 parent 414357e commit 143abcc

File tree

1 file changed

doc/frameworks/pytorch/using_pytorch.rst

Lines changed: 39 additions & 1 deletion
@@ -365,11 +365,29 @@ It loads the model parameters from a ``model.pth`` file in the SageMaker model d
        model.load_state_dict(torch.load(f))
    return model

-However, if you are using PyTorch Elastic Inference, you do not have to provide a ``model_fn`` since the PyTorch serving
+However, if you are using PyTorch Elastic Inference 1.3.1, you do not have to provide a ``model_fn`` since the PyTorch serving
container has a default one for you. But please note that if you are utilizing the default ``model_fn``, please save
your ScriptModule as ``model.pt``. If you are implementing your own ``model_fn``, please use TorchScript and ``torch.jit.save``
to save your ScriptModule, then load it in your ``model_fn`` with ``torch.jit.load(..., map_location=torch.device('cpu'))``.

+If you are using PyTorch Elastic Inference 1.5.1, you should provide a ``model_fn`` like the one below in your script so that the new ``attach_eia`` API is used. For reference, see the `Elastic Inference documentation <https://docs.aws.amazon.com/elastic-inference/latest/developerguide/ei-pytorch-using.html>`_.
+
+
+.. code:: python
+
+    import os
+
+    import torch
+
+
+    def model_fn(model_dir):
+        # Load the TorchScript model onto the CPU; the Elastic Inference client is CPU-only.
+        model = torch.jit.load(os.path.join(model_dir, 'model.pth'),
+                               map_location=torch.device('cpu'))
+        if torch.__version__ == '1.5.1':
+            import torcheia
+
+            model = model.eval()
+            # attach_eia() is introduced in PyTorch Elastic Inference 1.5.1;
+            # 0 is the device ordinal of the attached accelerator.
+            model = torcheia.jit.attach_eia(model, 0)
+        return model
+
+
The client-side Elastic Inference framework is CPU-only, even though inference still happens in a CUDA context on the server. Thus, the default ``model_fn`` for Elastic Inference loads the model to CPU. Tracing models may lead to tensor creation on a specific device, which may cause device-related errors when loading a model onto a different device. Providing an explicit ``map_location=torch.device('cpu')`` argument forces all tensors to CPU.

For more information on the default inference handler functions, please refer to:
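For the Elastic Inference 1.3.1 path described above (save the ScriptModule as ``model.pt`` and load it with ``torch.jit.load(..., map_location=torch.device('cpu'))``), a minimal ``model_fn`` might look like the sketch below. This is an illustrative example, not part of this commit; it assumes the ScriptModule was saved as ``model.pt`` at the top level of the SageMaker model directory.

.. code:: python

    import os

    import torch


    def model_fn(model_dir):
        # Elastic Inference runs the client-side framework on CPU, so map all tensors to CPU.
        model = torch.jit.load(os.path.join(model_dir, 'model.pt'),
                               map_location=torch.device('cpu'))
        return model.eval()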
@@ -416,6 +434,7 @@ The SageMaker PyTorch model server provides default implementations of these fun
You can provide your own implementations for these functions in your hosting script.
If you omit any definition then the SageMaker PyTorch model server will use its default implementation for that
function.
+If you use PyTorch Elastic Inference 1.5.1, remember to implement ``predict_fn`` yourself.

The ``Predictor`` used by PyTorch in the SageMaker Python SDK serializes NumPy arrays to the `NPY <https://docs.scipy.org/doc/numpy/neps/npy-format.html>`_ format
by default, with Content-Type ``application/x-npy``. The SageMaker PyTorch model server can deserialize NPY-formatted
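As an illustration of the NPY serialization described above (not part of this commit), a client-side call might look like the following. Here ``predictor`` is assumed to be the object returned by an earlier ``deploy()`` call, and the input shape is an arbitrary example.

.. code:: python

    import numpy as np

    # The Predictor serializes the array to NPY (Content-Type ``application/x-npy``)
    # and deserializes the NPY-formatted response from the endpoint.
    input_batch = np.random.rand(1, 3, 224, 224).astype(np.float32)
    result = predictor.predict(input_batch)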
@@ -547,6 +566,25 @@ block, for example:
    with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
        output = model(input_data)

+If you use PyTorch Elastic Inference 1.5.1, please implement your own ``predict_fn`` like the one below.
+
+.. code:: python
+
+    import torch
+
+
+    def predict_fn(input_data, model):
+        # The Elastic Inference client runs on CPU, so keep the input tensors on the CPU device.
+        device = torch.device("cpu")
+        input_data = input_data.to(device)
+        # Make sure torcheia is imported so that the Elastic Inference API call will be invoked.
+        import torcheia
+        # We need to disable the profiling executor for EIA.
+        torch._C._jit_set_profiling_executor(False)
+        with torch.jit.optimized_execution(True):
+            output = model.forward(input_data)
+        return output
+
Process Model Output
^^^^^^^^^^^^^^^^^^^^

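For context beyond this diff, a rough sketch of deploying a PyTorch Elastic Inference 1.5.1 model with the SageMaker Python SDK is shown below. The entry point is assumed to contain the ``model_fn`` and ``predict_fn`` from the diff; the S3 path, IAM role, instance type, and accelerator type are placeholders.

.. code:: python

    from sagemaker.pytorch import PyTorchModel

    pytorch_model = PyTorchModel(
        model_data='s3://my-bucket/model.tar.gz',  # placeholder S3 location of the packaged model
        role='MySageMakerRole',                    # placeholder IAM role
        entry_point='inference.py',                # script containing model_fn and predict_fn
        framework_version='1.5.1',
        py_version='py3',
    )

    # Attaching an Elastic Inference accelerator selects the EI-enabled serving image.
    predictor = pytorch_model.deploy(
        initial_instance_count=1,
        instance_type='ml.m5.large',
        accelerator_type='ml.eia2.medium',
    )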
0 commit comments
