doc/frameworks/pytorch/using_pytorch.rst: 39 additions & 1 deletion
@@ -365,11 +365,29 @@ It loads the model parameters from a ``model.pth`` file in the SageMaker model directory.
            model.load_state_dict(torch.load(f))
        return model

However, if you are using PyTorch Elastic Inference 1.3.1, you do not have to provide a ``model_fn``
since the PyTorch serving container has a default one for you. Note that if you rely on the default
``model_fn``, you must save your ScriptModule as ``model.pt``. If you implement your own ``model_fn``,
use TorchScript and ``torch.jit.save`` to save your ScriptModule, then load it in your ``model_fn``
with ``torch.jit.load(..., map_location=torch.device('cpu'))``.
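For illustration, here is a rough sketch of that save/load flow (the torchvision model, example input,
and file names are placeholders, not something prescribed by the SageMaker SDK):

.. code:: python

    import torch
    import torchvision.models as models

    # Trace an example model into a ScriptModule and save it with torch.jit.save.
    # If you rely on the container's default ``model_fn``, name the file ``model.pt``.
    model = models.resnet18(pretrained=True).eval()
    example_input = torch.rand(1, 3, 224, 224)
    script_module = torch.jit.trace(model, example_input)
    torch.jit.save(script_module, 'model.pt')

    # A custom ``model_fn`` would load the ScriptModule back onto the CPU:
    loaded = torch.jit.load('model.pt', map_location=torch.device('cpu'))

Both ``torch.jit.trace`` and ``torch.jit.script`` produce a ``ScriptModule`` that can be loaded this way
without the original model class being available on the serving side.
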
If you are using PyTorch Elastic Inference 1.5.1, you should provide a ``model_fn`` like the one below in
your script in order to use the new ``attach_eia`` API. For reference, see the
`Elastic Inference documentation <https://docs.aws.amazon.com/elastic-inference/latest/developerguide/ei-pytorch-using.html>`_.

.. code:: python

    import torch


    def model_fn(model_dir):
        # Load the TorchScript model saved as ``model.pth`` and keep it on CPU.
        model = torch.jit.load('model.pth', map_location=torch.device('cpu'))
        if torch.__version__ == '1.5.1':
            # torcheia ships in the Elastic Inference-enabled PyTorch container.
            import torcheia
            model = model.eval()
            # attach_eia() is introduced in PyTorch Elastic Inference 1.5.1;
            # it attaches the model to the accelerator with device ordinal 0.
            model = torcheia.jit.attach_eia(model, 0)
        return model

The client-side Elastic Inference framework is CPU-only, even though inference still happens in a CUDA context on the server. Thus, the default ``model_fn`` for Elastic Inference loads the model to CPU. Tracing models may lead to tensor creation on a specific device, which may cause device-related errors when loading a model onto a different device. Providing an explicit ``map_location=torch.device('cpu')`` argument forces all tensors to CPU.

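As a rough sketch of the failure mode this avoids (file names here are illustrative only): a module traced
on a GPU records its tensors on ``cuda:0``, so loading it on the CPU-only Elastic Inference client can raise
a CUDA device error unless ``map_location`` redirects everything to CPU.

.. code:: python

    import torch

    # On a GPU training host, tracing captures CUDA tensors in the saved module:
    #     traced = torch.jit.trace(model.cuda(), example_input.cuda())
    #     torch.jit.save(traced, 'traced_gpu_model.pth')

    # On the CPU-only client, load with an explicit map_location so every
    # captured tensor is restored on CPU instead of a missing CUDA device.
    model = torch.jit.load('traced_gpu_model.pth', map_location=torch.device('cpu'))

Without the ``map_location`` argument, the same call would try to materialize those tensors on a CUDA device
that does not exist on the client.
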
For more information on the default inference handler functions, please refer to:
@@ -416,6 +434,7 @@ The SageMaker PyTorch model server provides default implementations of these functions.
You can provide your own implementations for these functions in your hosting script.
If you omit any definition then the SageMaker PyTorch model server will use its default implementation for that
function.
If you use PyTorch Elastic Inference 1.5.1, remember to implement ``predict_fn`` yourself.
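For orientation, a skeleton of such a hosting script might look like the sketch below. The function bodies
and serializer choices are placeholders for illustration; only the function names and argument lists follow
the handler interface used in this guide's examples.

.. code:: python

    import json

    import torch


    def model_fn(model_dir):
        # Load and return the model; see the Elastic Inference examples above.
        return torch.jit.load('model.pth', map_location=torch.device('cpu'))


    def input_fn(request_body, request_content_type):
        # Deserialize the request payload into a tensor.
        return torch.tensor(json.loads(request_body))


    def predict_fn(input_data, model):
        # Run inference; with Elastic Inference 1.5.1 you must define this yourself.
        with torch.no_grad():
            return model(input_data)


    def output_fn(prediction, content_type):
        # Serialize the prediction for the response (ignoring ``content_type`` for brevity).
        return json.dumps(prediction.tolist())

The serializers shown here are only one possible choice; the default handlers use the NPY format described
next.
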
The ``Predictor`` used by PyTorch in the SageMaker Python SDK serializes NumPy arrays to the `NPY <https://docs.scipy.org/doc/numpy/neps/npy-format.html>`_ format
by default, with Content-Type ``application/x-npy``. The SageMaker PyTorch model server can deserialize NPY-formatted
@@ -547,6 +566,25 @@ block, for example:
    with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
        output = model(input_data)

If you use PyTorch Elastic Inference 1.5.1, please implement your own ``predict_fn`` like the one below.

.. code:: python

    import numpy as np
    import torch


    def predict_fn(input_data, model):
        device = torch.device("cpu")
        input_data = input_data.to(device)
        # make sure torcheia is imported so that the Elastic Inference API call will be invoked
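        # The remaining lines are a sketch rather than the original text: they assume
        # the ``torcheia`` API shown in ``model_fn`` above and the profiling-executor
        # setting described in the Elastic Inference documentation.
        import torcheia
        # Disable the profiling executor before invoking the Elastic Inference-attached model.
        torch._C._jit_set_profiling_executor(False)
        with torch.jit.optimized_execution(True):
            output = model(input_data)
        return output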