Skip to content

Commit 8cfabd0

Browse files
committed
Merge branch 'master' of https://github.com/aws/sagemaker-python-sdk into smddp-1.4.0-doc
2 parents b85f772 + 7e2c7ab commit 8cfabd0

File tree

15 files changed

+561
-38
lines changed

15 files changed

+561
-38
lines changed

doc/overview.rst

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -773,11 +773,10 @@ Deployment may take about 5 minutes.
773773
   instance_type=instance_type,
774774
)
775775
776-
Because ``catboost`` and ``lightgbm`` rely on the PyTorch Deep Learning Containers
777-
image, the corresponding Models and Endpoints display the “pytorch”
778-
prefix when viewed in the AWS console. To verify that these models
779-
were created successfully with your desired base model, refer to
780-
the ``Tags`` section.
776+
Because the model and script URIs are distributed by SageMaker JumpStart,
777+
the endpoint, endpoint config and model resources will be prefixed with
778+
``sagemaker-jumpstart``. Refer to the model ``Tags`` to inspect the
779+
JumpStart artifacts involved in the model creation.
781780

782781
Perform Inference
783782
-----------------

src/sagemaker/estimator.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from sagemaker.job import _Job
5151
from sagemaker.jumpstart.utils import (
5252
add_jumpstart_tags,
53+
get_jumpstart_base_name_if_jumpstart_model,
5354
update_inference_tags_with_jumpstart_training_tags,
5455
)
5556
from sagemaker.local import LocalSession
@@ -569,8 +570,11 @@ def prepare_workflow_for_training(self, job_name=None):
569570
def _ensure_base_job_name(self):
570571
"""Set ``self.base_job_name`` if it is not set already."""
571572
# honor supplied base_job_name or generate it
572-
if self.base_job_name is None:
573-
self.base_job_name = base_name_from_image(self.training_image_uri())
573+
self.base_job_name = (
574+
self.base_job_name
575+
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
576+
or base_name_from_image(self.training_image_uri())
577+
)
574578

575579
def _get_or_create_name(self, name=None):
576580
"""Generate a name based on the base job name or training image if needed.
@@ -1208,7 +1212,15 @@ def deploy(
12081212
is_serverless = serverless_inference_config is not None
12091213
self._ensure_latest_training_job()
12101214
self._ensure_base_job_name()
1211-
default_name = name_from_base(self.base_job_name)
1215+
1216+
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
1217+
kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
1218+
)
1219+
default_name = (
1220+
name_from_base(jumpstart_base_name)
1221+
if jumpstart_base_name
1222+
else name_from_base(self.base_job_name)
1223+
)
12121224
endpoint_name = endpoint_name or default_name
12131225
model_name = model_name or default_name
12141226

src/sagemaker/huggingface/model.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2626
from sagemaker.predictor import Predictor
2727
from sagemaker.serializers import JSONSerializer
28+
from sagemaker.session import Session
2829

2930
logger = logging.getLogger("sagemaker")
3031

@@ -169,9 +170,125 @@ def __init__(
169170
super(HuggingFaceModel, self).__init__(
170171
model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
171172
)
173+
self.sagemaker_session = self.sagemaker_session or Session()
172174

173175
self.model_server_workers = model_server_workers
174176

177+
# TODO: Remove the following function
178+
# botocore needs to add huggingface to the list of valid neo compilable frameworks.
179+
# Ideally with inferentia framework, call to .compile( ... ) method will create the image_uri.
180+
# currently, call to compile( ... ) method is causing `ValidationException`
181+
def deploy(
182+
self,
183+
initial_instance_count=None,
184+
instance_type=None,
185+
serializer=None,
186+
deserializer=None,
187+
accelerator_type=None,
188+
endpoint_name=None,
189+
tags=None,
190+
kms_key=None,
191+
wait=True,
192+
data_capture_config=None,
193+
async_inference_config=None,
194+
serverless_inference_config=None,
195+
**kwargs,
196+
):
197+
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
198+
199+
Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an
200+
``Endpoint`` from this ``Model``. If ``self.predictor_cls`` is not None,
201+
this method returns the result of invoking ``self.predictor_cls`` on
202+
the created endpoint name.
203+
204+
The name of the created model is accessible in the ``name`` field of
205+
this ``Model`` after deploy returns
206+
207+
The name of the created endpoint is accessible in the
208+
``endpoint_name`` field of this ``Model`` after deploy returns.
209+
210+
Args:
211+
initial_instance_count (int): The initial number of instances to run
212+
in the ``Endpoint`` created from this ``Model``. If not using
213+
serverless inference, then it needs to be a number larger than or equal
214+
to 1 (default: None)
215+
instance_type (str): The EC2 instance type to deploy this Model to.
216+
For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
217+
serverless inference, then it is required to deploy a model.
218+
(default: None)
219+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
220+
serializer object, used to encode data for an inference endpoint
221+
(default: None). If ``serializer`` is not None, then
222+
``serializer`` will override the default serializer. The
223+
default serializer is set by the ``predictor_cls``.
224+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
225+
deserializer object, used to decode data from an inference
226+
endpoint (default: None). If ``deserializer`` is not None, then
227+
``deserializer`` will override the default deserializer. The
228+
default deserializer is set by the ``predictor_cls``.
229+
accelerator_type (str): Type of Elastic Inference accelerator to
230+
deploy this model for model loading and inference, for example,
231+
'ml.eia1.medium'. If not specified, no Elastic Inference
232+
accelerator will be attached to the endpoint. For more
233+
information:
234+
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
235+
endpoint_name (str): The name of the endpoint to create (default:
236+
None). If not specified, a unique endpoint name will be created.
237+
tags (List[dict[str, str]]): The list of tags to attach to this
238+
specific endpoint.
239+
kms_key (str): The ARN of the KMS key that is used to encrypt the
240+
data on the storage volume attached to the instance hosting the
241+
endpoint.
242+
wait (bool): Whether the call should wait until the deployment of
243+
this model completes (default: True).
244+
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
245+
configuration related to Endpoint data capture for use with
246+
Amazon SageMaker Model Monitoring. Default: None.
247+
async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig): Specifies
248+
configuration related to async endpoint. Use this configuration when trying
249+
to create async endpoint and make async inference. If empty config object
250+
passed through, will use default config to deploy async endpoint. Deploy a
251+
real-time endpoint if it's None. (default: None)
252+
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
253+
Specifies configuration related to serverless endpoint. Use this configuration
254+
when trying to create serverless endpoint and make serverless inference. If
255+
empty object passed through, will use pre-defined values in
256+
``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
257+
instance based endpoint if it's None. (default: None)
258+
Raises:
259+
ValueError: If arguments combination check failed in these circumstances:
260+
- If no role is specified or
261+
- If serverless inference config is not specified and instance type and instance
262+
count are also not specified or
263+
- If a wrong type of object is provided as serverless inference config or async
264+
inference config
265+
Returns:
266+
callable[string, sagemaker.session.Session] or None: Invocation of
267+
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
268+
is not None. Otherwise, return None.
269+
"""
270+
271+
if not self.image_uri and instance_type.startswith("ml.inf"):
272+
self.image_uri = self.serving_image_uri(
273+
region_name=self.sagemaker_session.boto_session.region_name,
274+
instance_type=instance_type,
275+
)
276+
277+
return super(HuggingFaceModel, self).deploy(
278+
initial_instance_count,
279+
instance_type,
280+
serializer,
281+
deserializer,
282+
accelerator_type,
283+
endpoint_name,
284+
tags,
285+
kms_key,
286+
wait,
287+
data_capture_config,
288+
async_inference_config,
289+
serverless_inference_config,
290+
)
291+
175292
def register(
176293
self,
177294
content_types,
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
{
2+
"inference": {
3+
"processors": ["inf"],
4+
"version_aliases": {"4.12": "4.12.3"},
5+
"versions": {
6+
"4.12.3": {
7+
"version_aliases": {"pytorch1.9": "pytorch1.9.1"},
8+
"pytorch1.9.1": {
9+
"py_versions": ["py37"],
10+
"repository": "huggingface-pytorch-inference-neuron",
11+
"registries": {
12+
"af-south-1": "626614931356",
13+
"ap-east-1": "871362719292",
14+
"ap-northeast-1": "763104351884",
15+
"ap-northeast-2": "763104351884",
16+
"ap-northeast-3": "364406365360",
17+
"ap-south-1": "763104351884",
18+
"ap-southeast-1": "763104351884",
19+
"ap-southeast-2": "763104351884",
20+
"ca-central-1": "763104351884",
21+
"cn-north-1": "727897471807",
22+
"cn-northwest-1": "727897471807",
23+
"eu-central-1": "763104351884",
24+
"eu-north-1": "763104351884",
25+
"eu-west-1": "763104351884",
26+
"eu-west-2": "763104351884",
27+
"eu-west-3": "763104351884",
28+
"eu-south-1": "692866216735",
29+
"me-south-1": "217643126080",
30+
"sa-east-1": "763104351884",
31+
"us-east-1": "763104351884",
32+
"us-east-2": "763104351884",
33+
"us-gov-west-1": "442386744353",
34+
"us-iso-east-1": "886529160074",
35+
"us-west-1": "763104351884",
36+
"us-west-2": "763104351884"
37+
},
38+
"container_version": {"inf": "ubuntu18.04"},
39+
"sdk_versions": ["sdk1.17.1"]
40+
}
41+
}
42+
}
43+
}
44+
}

src/sagemaker/image_uris.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from sagemaker.spark import defaults
2525
from sagemaker.jumpstart import artifacts
2626

27-
2827
logger = logging.getLogger(__name__)
2928

3029
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
@@ -47,6 +46,8 @@ def retrieve(
4746
model_version=None,
4847
tolerate_vulnerable_model=False,
4948
tolerate_deprecated_model=False,
49+
sdk_version=None,
50+
inference_tool=None,
5051
) -> str:
5152
"""Retrieves the ECR URI for the Docker image matching the given arguments.
5253
@@ -88,6 +89,11 @@ def retrieve(
8889
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
8990
should be tolerated without an exception raised. If False, raises an exception
9091
if the version of the model is deprecated. (Default: False).
92+
sdk_version (str): the version of python-sdk that will be used in the image retrieval.
93+
(default: None).
94+
inference_tool (str): the tool that will be used to aid in the inference.
95+
Valid values: "neuron, None"
96+
(default: None).
9197
9298
Returns:
9399
str: The ECR URI for the corresponding SageMaker Docker image.
@@ -100,7 +106,6 @@ def retrieve(
100106
DeprecatedJumpStartModelError: If the version of the model is deprecated.
101107
"""
102108
if is_jumpstart_model_input(model_id, model_version):
103-
104109
return artifacts._retrieve_image_uri(
105110
model_id,
106111
model_version,
@@ -118,9 +123,13 @@ def retrieve(
118123
tolerate_vulnerable_model,
119124
tolerate_deprecated_model,
120125
)
121-
122126
if training_compiler_config is None:
123-
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
127+
_framework = framework
128+
if framework == HUGGING_FACE_FRAMEWORK:
129+
inference_tool = _get_inference_tool(inference_tool, instance_type)
130+
if inference_tool == "neuron":
131+
_framework = f"{framework}-{inference_tool}"
132+
config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
124133
elif framework == HUGGING_FACE_FRAMEWORK:
125134
config = _config_for_framework_and_scope(
126135
framework + "-training-compiler", image_scope, accelerator_type
@@ -129,6 +138,7 @@ def retrieve(
129138
raise ValueError(
130139
"Unsupported Configuration: Training Compiler is only supported with HuggingFace"
131140
)
141+
132142
original_version = version
133143
version = _validate_version_and_set_if_needed(version, config, framework)
134144
version_config = config["versions"][_version_for_config(version, config)]
@@ -138,7 +148,6 @@ def retrieve(
138148
full_base_framework_version = version_config["version_aliases"].get(
139149
base_framework_version, base_framework_version
140150
)
141-
142151
_validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
143152
version_config = version_config.get(full_base_framework_version)
144153

@@ -161,25 +170,37 @@ def retrieve(
161170
pt_or_tf_version = (
162171
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
163172
)
164-
165173
_version = original_version
174+
166175
if repo in [
167176
"huggingface-pytorch-trcomp-training",
168177
"huggingface-tensorflow-trcomp-training",
169178
]:
170179
_version = version
180+
if repo in ["huggingface-pytorch-inference-neuron"]:
181+
if not sdk_version:
182+
sdk_version = _get_latest_versions(version_config["sdk_versions"])
183+
container_version = sdk_version + "-" + container_version
184+
if config.get("version_aliases").get(original_version):
185+
_version = config.get("version_aliases")[original_version]
186+
if (
187+
config.get("versions", {})
188+
.get(_version, {})
189+
.get("version_aliases", {})
190+
.get(base_framework_version, {})
191+
):
192+
_base_framework_version = config.get("versions")[_version]["version_aliases"][
193+
base_framework_version
194+
]
195+
pt_or_tf_version = (
196+
re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2)
197+
)
171198

172199
tag_prefix = f"{pt_or_tf_version}-transformers{_version}"
173-
174200
else:
175201
tag_prefix = version_config.get("tag_prefix", version)
176202

177-
tag = _format_tag(
178-
tag_prefix,
179-
processor,
180-
py_version,
181-
container_version,
182-
)
203+
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
183204

184205
if _should_auto_select_container_version(instance_type, distribution):
185206
container_versions = {
@@ -248,6 +269,20 @@ def config_for_framework(framework):
248269
return json.load(f)
249270

250271

272+
def _get_inference_tool(inference_tool, instance_type):
273+
"""Extract the inference tool name from instance type."""
274+
if not inference_tool and instance_type:
275+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
276+
if match and match[1].startswith("inf"):
277+
return "neuron"
278+
return inference_tool
279+
280+
281+
def _get_latest_versions(list_of_versions):
282+
"""Extract the latest version from the input list of available versions."""
283+
return sorted(list_of_versions, reverse=True)[0]
284+
285+
251286
def _validate_accelerator_type(accelerator_type):
252287
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
253288
if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook":
@@ -310,6 +345,8 @@ def _processor(instance_type, available_processors):
310345

311346
if instance_type.startswith("local"):
312347
processor = "cpu" if instance_type == "local" else "gpu"
348+
elif instance_type.startswith("neuron"):
349+
processor = "neuron"
313350
else:
314351
# looks for either "ml.<family>.<size>" or "ml_<family>"
315352
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
@@ -387,8 +424,10 @@ def _validate_arg(arg, available_options, arg_name):
387424
)
388425

389426

390-
def _format_tag(tag_prefix, processor, py_version, container_version):
427+
def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
391428
"""Creates a tag for the image URI."""
429+
if inference_tool:
430+
return "-".join(x for x in (tag_prefix, inference_tool, py_version, container_version) if x)
392431
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
393432

394433

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _get_manifest_key_from_model_id_semantic_version(
229229
)
230230

231231
else:
232-
possible_model_ids = [header.model_id for header in manifest.values()]
232+
possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore
233233
closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0]
234234
error_msg += f"Did you mean to use model ID '{closest_model_id}'?"
235235

0 commit comments

Comments
 (0)