
documentation: add context for pytorch #3352


Merged
55 changes: 42 additions & 13 deletions doc/frameworks/pytorch/using_pytorch.rst
@@ -415,20 +415,25 @@ Before a model can be served, it must be loaded. The SageMaker PyTorch model server

.. code:: python

-def model_fn(model_dir)
+def model_fn(model_dir, context)

``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
If specified in the function declaration, the context will be created and passed to the function by SageMaker.
For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

SageMaker will inject the directory where your model files and sub-directories, saved by ``save``, have been mounted.
Your model function should return a model object that can be used for model serving.

The following code snippet shows an example ``model_fn`` implementation.
-It loads the model parameters from a ``model.pth`` file in the SageMaker model directory ``model_dir``.
+It loads the model parameters from a ``model.pth`` file in the SageMaker model directory ``model_dir``. As explained above,
+``context`` is an optional argument that passes additional serving information.

.. code:: python

import torch
import os

-def model_fn(model_dir):
+def model_fn(model_dir, context):
    model = Your_Model()
    with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
        model.load_state_dict(torch.load(f))
    return model
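
As a minimal sketch (reusing the hypothetical ``Your_Model`` and the ``model.pth`` layout from the example above, and the TorchServe serving context linked earlier), a ``model_fn`` that actually uses ``context`` could load the model onto the GPU assigned to the current worker:

.. code:: python

import os

import torch

def model_fn(model_dir, context):
    # Assumption: context.system_properties carries the worker's GPU ID,
    # as documented in the Serving Context class linked above.
    gpu_id = context.system_properties.get("gpu_id")
    device = torch.device("cuda:" + str(gpu_id) if torch.cuda.is_available() else "cpu")
    model = Your_Model()
    with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
        model.load_state_dict(torch.load(f, map_location=device))
    return model.to(device)
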
@@ -482,13 +487,13 @@ function in the chain. Inside the SageMaker PyTorch model server, the process looks like:
.. code:: python

# Deserialize the Invoke request body into an object we can perform prediction on
-input_object = input_fn(request_body, request_content_type)
+input_object = input_fn(request_body, request_content_type, context)

# Perform prediction on the deserialized object, with the loaded model
-prediction = predict_fn(input_object, model)
+prediction = predict_fn(input_object, model, context)

# Serialize the prediction result into the desired response content type
-output = output_fn(prediction, response_content_type)
+output = output_fn(prediction, response_content_type, context)

The above code sample shows the three function definitions:

@@ -536,9 +541,13 @@ it should return an object that can be passed to ``predict_fn`` and have the following

.. code:: python

-def input_fn(request_body, request_content_type)
+def input_fn(request_body, request_content_type, context)

-Where ``request_body`` is a byte buffer and ``request_content_type`` is a Python string
+Where ``request_body`` is a byte buffer and ``request_content_type`` is a Python string.

``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
If specified in the function declaration, the context will be created and passed to the function by SageMaker.
For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

The SageMaker PyTorch model server provides a default implementation of ``input_fn``.
This function deserializes JSON, CSV, or NPY encoded data into a ``torch.Tensor``.
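
If you override ``input_fn``, you can accept the optional ``context`` argument by adding it to the signature. The following is a minimal sketch, not the server's default implementation; it handles only JSON input and ignores ``context``:

.. code:: python

import json

import torch

def input_fn(request_body, request_content_type, context):
    # Sketch: deserialize a JSON array (or nested arrays) into a tensor.
    if request_content_type == "application/json":
        return torch.tensor(json.loads(request_body))
    raise ValueError("Unsupported content type: {}".format(request_content_type))
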
@@ -586,16 +595,19 @@ The ``predict_fn`` function has the following signature:

.. code:: python

-def predict_fn(input_object, model)
+def predict_fn(input_object, model, context)

Where ``input_object`` is the object returned from ``input_fn`` and
``model`` is the model loaded by ``model_fn``.
If you are using multiple GPUs, specify the ``context`` argument, which contains information such as the ID of a dynamically selected GPU and the batch size.
The second example below demonstrates how to configure ``predict_fn`` with the ``context`` argument to handle multiple GPUs. For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.
If you are using CPUs or a single GPU, you do not need to specify the ``context`` argument.

The default implementation of ``predict_fn`` invokes the loaded model's ``__call__`` function on ``input_object``,
and returns the resulting value. The return type should be a ``torch.Tensor`` to be compatible with the default
``output_fn``.

-The example below shows an overridden ``predict_fn``:
+The following example shows an overridden ``predict_fn``:

.. code:: python

import torch
import numpy as np

def predict_fn(input_data, model):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    with torch.no_grad():
        return model(input_data.to(device))

The following example is for use cases with multiple GPUs and shows an overridden ``predict_fn`` that uses the ``context`` argument to dynamically select a GPU device for making predictions:

.. code:: python

import torch
import numpy as np

def predict_fn(input_data, model, context):
    # Use the GPU that the model server assigned to this worker, if any.
    device = torch.device("cuda:" + str(context.system_properties.get("gpu_id"))
                          if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    with torch.no_grad():
        return model(input_data.to(device))

If you implement your own prediction function, you should take care to ensure that:

- The first argument is expected to be the return value from ``input_fn``.
@@ -664,11 +690,14 @@ The ``output_fn`` has the following signature:

.. code:: python

-def output_fn(prediction, content_type)
+def output_fn(prediction, content_type, context)

Where ``prediction`` is the result of invoking ``predict_fn`` and
-the content type for the response, as specified by the InvokeEndpoint request.
-The function should return a byte array of data serialized to content_type.
+``content_type`` is the content type for the response, as specified by the InvokeEndpoint request. The function should return a byte array of data serialized to ``content_type``.

``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
If specified in the function declaration, the context will be created and passed to the function by SageMaker.
For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

The default implementation expects ``prediction`` to be a ``torch.Tensor`` and can serialize the result to JSON, CSV, or NPY.
It accepts response content types of "application/json", "text/csv", and "application/x-npy".
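
For illustration, the following is a minimal sketch of an overridden ``output_fn`` that accepts the optional ``context`` argument; unlike the default implementation, it serializes only to JSON:

.. code:: python

import json

def output_fn(prediction, content_type, context):
    # Sketch: assumes ``prediction`` is a torch.Tensor, as the default contract expects.
    if content_type == "application/json":
        return json.dumps(prediction.detach().cpu().numpy().tolist())
    raise ValueError("Unsupported content type: {}".format(content_type))
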