Commit e0000bb

Merge branch 'master' into fix-allow-docker-6
2 parents: 61d53bc + e8ee340

26 files changed: +957 −451 lines

CHANGELOG.md

Lines changed: 19 additions & 0 deletions
@@ -1,5 +1,24 @@
 # Changelog

+## v2.110.0 (2022-09-27)
+
+### Features
+
+ * Support KeepAlivePeriodInSeconds for Training APIs
+ * added ANALYSIS_CONFIG_SCHEMA_V1_0 in clarify
+ * add model monitor image accounts for ap-southeast-3
+
+### Bug Fixes and Other Changes
+
+ * huggingface release test
+ * Fixing the logic to return instanceCount for heterogeneousClusters
+ * Disable type hints in doc signature and add PipelineVariable annotations in docstring
+ * estimator hyperparameters in script mode
+
+### Documentation Changes
+
+ * Added link to example notebook for Pipelines local mode
+
 ## v2.109.0 (2022-09-09)

 ### Features
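Of the features above, KeepAlivePeriodInSeconds surfaces in the SDK as a new estimator argument. A minimal sketch, assuming the ``keep_alive_period_in_seconds`` parameter name and using placeholder image/role values:

.. code:: python

    from sagemaker.estimator import Estimator

    # Placeholder image URI and role ARN; keep_alive_period_in_seconds keeps the
    # provisioned warm pool alive for 20 minutes after the training job completes.
    estimator = Estimator(
        image_uri="<training-image-uri>",
        role="<execution-role-arn>",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        keep_alive_period_in_seconds=1200,
    )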

VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.109.1.dev0
+2.110.1.dev0

doc/frameworks/pytorch/using_pytorch.rst

Lines changed: 42 additions & 13 deletions
@@ -415,20 +415,25 @@ Before a model can be served, it must be loaded. The SageMaker PyTorch model ser

 .. code:: python

-    def model_fn(model_dir)
+    def model_fn(model_dir, context)
+
+``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
+If specified in the function declaration, the context will be created and passed to the function by SageMaker.
+For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

 SageMaker will inject the directory where your model files and sub-directories, saved by ``save``, have been mounted.
 Your model function should return a model object that can be used for model serving.

 The following code snippet shows an example ``model_fn`` implementation.
-It loads the model parameters from a ``model.pth`` file in the SageMaker model directory ``model_dir``.
+It loads the model parameters from a ``model.pth`` file in the SageMaker model directory ``model_dir``. As explained in the preceding example,
+``context`` is an optional argument that passes additional information.

 .. code:: python

     import torch
     import os

-    def model_fn(model_dir):
+    def model_fn(model_dir, context):
         model = Your_Model()
         with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
             model.load_state_dict(torch.load(f))
@@ -482,13 +487,13 @@ function in the chain. Inside the SageMaker PyTorch model server, the process lo

 .. code:: python

     # Deserialize the Invoke request body into an object we can perform prediction on
-    input_object = input_fn(request_body, request_content_type)
+    input_object = input_fn(request_body, request_content_type, context)

     # Perform prediction on the deserialized object, with the loaded model
-    prediction = predict_fn(input_object, model)
+    prediction = predict_fn(input_object, model, context)

     # Serialize the prediction result into the desired response content type
-    output = output_fn(prediction, response_content_type)
+    output = output_fn(prediction, response_content_type, context)

 The above code sample shows the three function definitions:

@@ -536,9 +541,13 @@ it should return an object that can be passed to ``predict_fn`` and have the fol

 .. code:: python

-    def input_fn(request_body, request_content_type)
+    def input_fn(request_body, request_content_type, context)

-Where ``request_body`` is a byte buffer and ``request_content_type`` is a Python string
+Where ``request_body`` is a byte buffer and ``request_content_type`` is a Python string.
+
+``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
+If specified in the function declaration, the context will be created and passed to the function by SageMaker.
+For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

 The SageMaker PyTorch model server provides a default implementation of ``input_fn``.
 This function deserializes JSON, CSV, or NPY encoded data into a torch.Tensor.
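For illustration, a minimal overridden ``input_fn`` following the conventions above; the JSON-only handling and the ``context=None`` default are assumptions for the sketch, not part of this commit:

.. code:: python

    import json

    import torch

    def input_fn(request_body, request_content_type, context=None):
        # Deserialize a JSON-encoded list of lists into a float32 tensor.
        if request_content_type == "application/json":
            return torch.tensor(json.loads(request_body), dtype=torch.float32)
        raise ValueError("Unsupported content type: {}".format(request_content_type))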
@@ -586,16 +595,19 @@ The ``predict_fn`` function has the following signature:

 .. code:: python

-    def predict_fn(input_object, model)
+    def predict_fn(input_object, model, context)

 Where ``input_object`` is the object returned from ``input_fn`` and
 ``model`` is the model loaded by ``model_fn``.
+If you are using multiple GPUs, then specify the ``context`` argument, which contains information such as the GPU ID for a dynamically selected GPU and the batch size.
+One of the examples below demonstrates how to configure ``predict_fn`` with the ``context`` argument to handle multiple GPUs. For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.
+If you are using CPUs or a single GPU, then you do not need to specify the ``context`` argument.

 The default implementation of ``predict_fn`` invokes the loaded model's ``__call__`` function on ``input_object``,
 and returns the resulting value. The return type should be a torch.Tensor to be compatible with the default
 ``output_fn``.

-The example below shows an overridden ``predict_fn``:
+The following example shows an overridden ``predict_fn``:

 .. code:: python
@@ -609,6 +621,20 @@ The example below shows an overridden ``predict_fn``:
         with torch.no_grad():
             return model(input_data.to(device))

+The following example is for use cases with multiple GPUs and shows an overridden ``predict_fn`` that uses the ``context`` argument to dynamically select a GPU device for making predictions:
+
+.. code:: python
+
+    import torch
+    import numpy as np
+
+    def predict_fn(input_data, model, context):
+        # Select the GPU that the model server assigned to this worker.
+        device = torch.device("cuda:" + str(context.system_properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
+        model.to(device)
+        model.eval()
+        with torch.no_grad():
+            return model(input_data.to(device))
+
 If you implement your own prediction function, you should take care to ensure that:

 - The first argument is expected to be the return value from input_fn.
@@ -664,11 +690,14 @@ The ``output_fn`` has the following signature:

 .. code:: python

-    def output_fn(prediction, content_type)
+    def output_fn(prediction, content_type, context)

 Where ``prediction`` is the result of invoking ``predict_fn`` and
-the content type for the response, as specified by the InvokeEndpoint request.
-The function should return a byte array of data serialized to content_type.
+``content_type`` is the content type for the response, as specified by the InvokeEndpoint request.
+The function should return a byte array of data serialized to ``content_type``.
+
+``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
+If specified in the function declaration, the context will be created and passed to the function by SageMaker.
+For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

 The default implementation expects ``prediction`` to be a torch.Tensor and can serialize the result to JSON, CSV, or NPY.
 It accepts response content types of "application/json", "text/csv", and "application/x-npy".
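A matching hypothetical ``output_fn`` sketch (again assuming a ``context=None`` default) that serializes a torch.Tensor prediction to JSON bytes:

.. code:: python

    import json

    def output_fn(prediction, content_type, context=None):
        # Serialize the prediction tensor to a JSON byte array.
        if content_type == "application/json":
            return json.dumps(prediction.detach().cpu().numpy().tolist()).encode("utf-8")
        raise ValueError("Unsupported content type: {}".format(content_type))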

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ black==22.3.0
 stopit==1.1.2
 apache-airflow==2.3.4
 apache-airflow-providers-amazon==4.0.0
-attrs==20.3.0
+attrs==22.1.0
 fabric==2.6.0
 requests==2.27.1
 sagemaker-experiments==0.1.35

setup.py

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,7 @@ def read_requirements(filename):

 # Declare minimal set for installation
 required_packages = [
-    "attrs>=20.3.0,<22",
+    "attrs>=20.3.0,<23",
     "boto3>=1.20.21,<2.0",
     "google-pasta",
     "numpy>=1.9.0,<2.0",
@@ -58,6 +58,7 @@ def read_requirements(filename):
     "packaging>=20.0",
     "pandas",
     "pathos",
+    "schema",
 ]

 # Specific use case dependencies
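The newly required ``schema`` package validates dictionary-shaped configurations and is likely tied to the ANALYSIS_CONFIG_SCHEMA_V1_0 feature in this release; a generic sketch of its API, with values unrelated to any specific SageMaker config:

.. code:: python

    from schema import And, Schema

    # Validate that a config dict has a non-empty string "dataset_type" key.
    config_schema = Schema({"dataset_type": And(str, len)})
    validated = config_schema.validate({"dataset_type": "text/csv"})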

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 44 additions & 21 deletions
@@ -16,7 +16,7 @@
 import json
 import logging
 import tempfile
-from typing import Union
+from typing import Union, Optional, Dict

 from six.moves.urllib.parse import urlparse

@@ -30,6 +30,7 @@
 from sagemaker.utils import sagemaker_timestamp
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow import is_pipeline_variable

 logger = logging.getLogger(__name__)

@@ -40,18 +41,20 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
     This class isn't intended to be instantiated directly.
     """

-    feature_dim = hp("feature_dim", validation.gt(0), data_type=int)
-    mini_batch_size = hp("mini_batch_size", validation.gt(0), data_type=int)
-    repo_name = None
-    repo_version = None
+    feature_dim: hp = hp("feature_dim", validation.gt(0), data_type=int)
+    mini_batch_size: hp = hp("mini_batch_size", validation.gt(0), data_type=int)
+    repo_name: Optional[str] = None
+    repo_version: Optional[str] = None
+
+    DEFAULT_MINI_BATCH_SIZE: Optional[int] = None

     def __init__(
         self,
-        role,
-        instance_count=None,
-        instance_type=None,
-        data_location=None,
-        enable_network_isolation=False,
+        role: str,
+        instance_count: Optional[Union[int, PipelineVariable]] = None,
+        instance_type: Optional[Union[str, PipelineVariable]] = None,
+        data_location: Optional[str] = None,
+        enable_network_isolation: Union[bool, PipelineVariable] = False,
         **kwargs
     ):
         """Initialize an AmazonAlgorithmEstimatorBase.
@@ -62,16 +65,16 @@ def __init__(
                 endpoints use this role to access training data and model
                 artifacts. After the endpoint is created, the inference code
                 might use the IAM role, if it needs to access an AWS resource.
-            instance_count (int): Number of Amazon EC2 instances to use
+            instance_count (int or PipelineVariable): Number of Amazon EC2 instances to use
                 for training. Required.
-            instance_type (str): Type of EC2 instance to use for training,
+            instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
                 for example, 'ml.c4.xlarge'. Required.
             data_location (str or None): The s3 prefix to upload RecordSet
                 objects to, expressed as an S3 url. For example
                 "s3://example-bucket/some-key-prefix/". Objects will be saved in
                 a unique sub-directory of the specified location. If None, a
                 default data location will be used.
-            enable_network_isolation (bool): Specifies whether container will
+            enable_network_isolation (bool or PipelineVariable): Specifies whether container will
                 run in network isolation mode. Network isolation mode restricts
                 the container access to outside networks (such as the internet).
                 Also known as internet-free mode (default: ``False``).
@@ -113,8 +116,14 @@ def data_location(self):
         return self._data_location

     @data_location.setter
-    def data_location(self, data_location):
+    def data_location(self, data_location: str):
         """Placeholder docstring"""
+        if is_pipeline_variable(data_location):
+            raise TypeError(
+                "Invalid input: data_location should be a plain string "
+                "rather than a pipeline variable - ({}).".format(type(data_location))
+            )
+
         if not data_location.startswith("s3://"):
             raise ValueError(
                 'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
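A hypothetical illustration of the new setter guard; the ``estimator`` object and the parameter values below are placeholders:

.. code:: python

    from sagemaker.workflow.parameters import ParameterString

    location_param = ParameterString(name="DataLocation", default_value="s3://my-bucket/prefix/")
    # estimator.data_location = location_param            # now raises TypeError
    # estimator.data_location = "s3://my-bucket/prefix/"   # plain strings still work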
@@ -198,12 +207,12 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
     @runnable_by_pipeline
     def fit(
         self,
-        records,
-        mini_batch_size=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
+        records: "RecordSet",
+        mini_batch_size: Optional[int] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
     ):
         """Fit this Estimator on serialized Record objects, stored in S3.
@@ -301,6 +310,20 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
             channel=channel,
         )

+    def _get_default_mini_batch_size(self, num_records: int):
+        """Generate the default mini_batch_size"""
+        if is_pipeline_variable(self.instance_count):
+            logger.warning(
+                "mini_batch_size is not given in .fit() and instance_count is a "
+                "pipeline variable (%s) which is only interpreted in pipeline execution time. "
+                "Thus setting mini_batch_size to 1, since it can't be greater than "
+                "number of records per instance_count, otherwise the training job fails.",
+                type(self.instance_count),
+            )
+            return 1
+
+        return min(self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)))
+

 class RecordSet(object):
     """Placeholder docstring"""
@@ -461,7 +484,7 @@ def upload_numpy_to_s3_shards(
             raise ex


-def get_image_uri(region_name, repo_name, repo_version=1):
+def get_image_uri(region_name, repo_name, repo_version="1"):
     """Deprecated method. Please use sagemaker.image_uris.retrieve().

     Args:
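As the docstring notes, ``sagemaker.image_uris.retrieve()`` replaces this deprecated helper; a sketch with an illustrative algorithm name and region:

.. code:: python

    from sagemaker import image_uris

    # "kmeans" and the region are illustrative values, not taken from this commit.
    uri = image_uris.retrieve(framework="kmeans", region="us-west-2", version="1")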
