Skip to content

Commit 13efaef

Browse files
authored
Merge branch 'master' into fix/jumpstart-amt-tracking
2 parents 1aab22b + 8c52f1b commit 13efaef

26 files changed

+2833
-402
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def read_requirements(filename):
8484
url="https://github.com/aws/sagemaker-python-sdk/",
8585
license="Apache License 2.0",
8686
keywords="ML Amazon AWS AI Tensorflow MXNet",
87+
python_requires=">= 3.6",
8788
classifiers=[
8889
"Development Status :: 5 - Production/Stable",
8990
"Intended Audience :: Developers",

src/sagemaker/model.py

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import re
2121
import copy
22+
from typing import List, Dict
2223

2324
import sagemaker
2425
from sagemaker import (
@@ -38,6 +39,7 @@
3839
from sagemaker.async_inference import AsyncInferenceConfig
3940
from sagemaker.predictor_async import AsyncPredictor
4041
from sagemaker.workflow import is_pipeline_variable
42+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
4143

4244
LOGGER = logging.getLogger("sagemaker")
4345

@@ -289,6 +291,7 @@ def __init__(
289291
self.uploaded_code = None
290292
self.repacked_model_data = None
291293

294+
@runnable_by_pipeline
292295
def register(
293296
self,
294297
content_types,
@@ -310,12 +313,12 @@ def register(
310313
"""Creates a model package for creating SageMaker models or listing on Marketplace.
311314
312315
Args:
313-
content_types (list): The supported MIME types for the input data (default: None).
314-
response_types (list): The supported MIME types for the output data (default: None).
316+
content_types (list): The supported MIME types for the input data.
317+
response_types (list): The supported MIME types for the output data.
315318
inference_instances (list): A list of the instance types that are used to
316-
generate inferences in real-time (default: None).
319+
generate inferences in real-time.
317320
transform_instances (list): A list of the instance types on which a transformation
318-
job can be run or on which an endpoint can be deployed (default: None).
321+
job can be run or on which an endpoint can be deployed.
319322
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
320323
using `model_package_name` makes the Model Package un-versioned (default: None).
321324
model_package_group_name (str): Model Package Group name, exclusive to
@@ -366,12 +369,50 @@ def register(
366369
model_package = self.sagemaker_session.create_model_package_from_containers(
367370
**model_pkg_args
368371
)
372+
if isinstance(self.sagemaker_session, PipelineSession):
373+
return None
369374
return ModelPackage(
370375
role=self.role,
371376
model_data=self.model_data,
372377
model_package_arn=model_package.get("ModelPackageArn"),
373378
)
374379

380+
@runnable_by_pipeline
381+
def create(
382+
self,
383+
instance_type: str = None,
384+
accelerator_type: str = None,
385+
serverless_inference_config: ServerlessInferenceConfig = None,
386+
tags: List[Dict[str, str]] = None,
387+
):
388+
"""Create a SageMaker Model Entity
389+
390+
Args:
391+
instance_type (str): The EC2 instance type that this Model will be
392+
used for, this is only used to determine if the image needs GPU
393+
support or not (default: None).
394+
accelerator_type (str): Type of Elastic Inference accelerator to
395+
attach to an endpoint for model loading and inference, for
396+
example, 'ml.eia1.medium'. If not specified, no Elastic
397+
Inference accelerator will be attached to the endpoint (default: None).
398+
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
399+
Specifies configuration related to serverless endpoint. Instance type is
400+
not provided in serverless inference. So this is used to find image URIs
401+
(default: None).
402+
tags (List[Dict[str, str]]): The list of tags to add to
403+
the model (default: None). Example: >>> tags = [{'Key': 'tagname', 'Value':
404+
'tagvalue'}] For more information about tags, see
405+
https://boto3.amazonaws.com/v1/documentation
406+
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
407+
"""
408+
# TODO: we should replace _create_sagemaker_model() with create()
409+
self._create_sagemaker_model(
410+
instance_type=instance_type,
411+
accelerator_type=accelerator_type,
412+
tags=tags,
413+
serverless_inference_config=serverless_inference_config,
414+
)
415+
375416
def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
376417
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
377418
@@ -455,6 +496,24 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
455496
if repack and self.model_data is not None and self.entry_point is not None:
456497
if is_pipeline_variable(self.model_data):
457498
# model is not yet there, defer repacking to later during pipeline execution
499+
if not isinstance(self.sagemaker_session, PipelineSession):
500+
# TODO: link the doc in the warning once ready
501+
logging.warning(
502+
"The model_data is a Pipeline variable of type %s, "
503+
"which should be used under `PipelineSession` and "
504+
"leverage `ModelStep` to create or register model. "
505+
"Otherwise some functionalities e.g. "
506+
"runtime repack may be missing",
507+
type(self.model_data),
508+
)
509+
return
510+
self.sagemaker_session.context.need_runtime_repack.add(id(self))
511+
# Add the uploaded_code and repacked_model_data to update the container env
512+
self.repacked_model_data = self.model_data
513+
self.uploaded_code = fw_utils.UploadedCode(
514+
s3_prefix=self.repacked_model_data,
515+
script_name=os.path.basename(self.entry_point),
516+
)
458517
return
459518
if local_code and self.model_data.startswith("file://"):
460519
repacked_model_data = self.model_data
@@ -538,22 +597,29 @@ def _create_sagemaker_model(
538597
serverless_inference_config=serverless_inference_config,
539598
)
540599

541-
self._ensure_base_name_if_needed(
542-
image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
543-
)
544-
self._set_model_name_if_needed()
600+
if not isinstance(self.sagemaker_session, PipelineSession):
601+
# _base_name, model_name are not needed under PipelineSession.
602+
# the model_data may be Pipeline variable
603+
# which may break the _base_name generation
604+
self._ensure_base_name_if_needed(
605+
image_uri=container_def["Image"],
606+
script_uri=self.source_dir,
607+
model_uri=self.model_data,
608+
)
609+
self._set_model_name_if_needed()
545610

546611
enable_network_isolation = self.enable_network_isolation()
547612

548613
self._init_sagemaker_session_if_does_not_exist(instance_type)
549-
self.sagemaker_session.create_model(
550-
self.name,
551-
self.role,
552-
container_def,
614+
create_model_args = dict(
615+
name=self.name,
616+
role=self.role,
617+
container_defs=container_def,
553618
vpc_config=self.vpc_config,
554619
enable_network_isolation=enable_network_isolation,
555620
tags=tags,
556621
)
622+
self.sagemaker_session.create_model(**create_model_args)
557623

558624
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
559625
"""Create a base name from the image URI if there is no model name provided.

src/sagemaker/pipeline.py

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Dict
17+
1618
import sagemaker
19+
from sagemaker import ModelMetrics
20+
from sagemaker.drift_check_baselines import DriftCheckBaselines
21+
from sagemaker.metadata_properties import MetadataProperties
1722
from sagemaker.session import Session
1823
from sagemaker.utils import name_from_image
1924
from sagemaker.transformer import Transformer
25+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
2026

2127

2228
class PipelineModel(object):
@@ -221,6 +227,17 @@ def deploy(
221227
return predictor
222228
return None
223229

230+
@runnable_by_pipeline
231+
def create(self, instance_type: str):
232+
"""Create a SageMaker Model Entity
233+
234+
Args:
235+
instance_type (str): The EC2 instance type that this Model will be
236+
used for, this is only used to determine if the image needs GPU
237+
support or not.
238+
"""
239+
self._create_sagemaker_pipeline_model(instance_type)
240+
224241
def _create_sagemaker_pipeline_model(self, instance_type):
225242
"""Create a SageMaker Model Entity
226243
@@ -235,13 +252,92 @@ def _create_sagemaker_pipeline_model(self, instance_type):
235252
containers = self.pipeline_container_def(instance_type)
236253

237254
self.name = self.name or name_from_image(containers[0]["Image"])
238-
self.sagemaker_session.create_model(
239-
self.name,
240-
self.role,
241-
containers,
255+
create_model_args = dict(
256+
name=self.name,
257+
role=self.role,
258+
container_defs=containers,
242259
vpc_config=self.vpc_config,
243260
enable_network_isolation=self.enable_network_isolation,
244261
)
262+
self.sagemaker_session.create_model(**create_model_args)
263+
264+
@runnable_by_pipeline
265+
def register(
266+
self,
267+
content_types: list,
268+
response_types: list,
269+
inference_instances: list,
270+
transform_instances: list,
271+
model_package_name: Optional[str] = None,
272+
model_package_group_name: Optional[str] = None,
273+
image_uri: Optional[str] = None,
274+
model_metrics: Optional[ModelMetrics] = None,
275+
metadata_properties: Optional[MetadataProperties] = None,
276+
marketplace_cert: bool = False,
277+
approval_status: Optional[str] = None,
278+
description: Optional[str] = None,
279+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
280+
customer_metadata_properties: Optional[Dict[str, str]] = None,
281+
):
282+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
283+
284+
Args:
285+
content_types (list): The supported MIME types for the input data.
286+
response_types (list): The supported MIME types for the output data.
287+
inference_instances (list): A list of the instance types that are used to
288+
generate inferences in real-time.
289+
transform_instances (list): A list of the instance types on which a transformation
290+
job can be run or on which an endpoint can be deployed.
291+
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
292+
using `model_package_name` makes the Model Package un-versioned (default: None).
293+
model_package_group_name (str): Model Package Group name, exclusive to
294+
`model_package_name`, using `model_package_group_name` makes the Model Package
295+
versioned (default: None).
296+
image_uri (str): Inference image uri for the container. Model class' self.image will
297+
be used if it is None (default: None).
298+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
299+
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
300+
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
301+
for AWS Marketplace (default: False).
302+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
303+
or "PendingManualApproval" (default: "PendingManualApproval").
304+
description (str): Model Package description (default: None).
305+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
306+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
307+
metadata properties (default: None).
308+
309+
Returns:
310+
A `sagemaker.model.ModelPackage` instance.
311+
"""
312+
for model in self.models:
313+
if model.model_data is None:
314+
raise ValueError("SageMaker Model Package cannot be created without model data.")
315+
if model_package_group_name is not None:
316+
container_def = self.pipeline_container_def(inference_instances[0])
317+
else:
318+
container_def = [
319+
{"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
320+
for model in self.models
321+
]
322+
323+
model_pkg_args = sagemaker.get_model_package_args(
324+
content_types,
325+
response_types,
326+
inference_instances,
327+
transform_instances,
328+
model_package_name=model_package_name,
329+
model_package_group_name=model_package_group_name,
330+
model_metrics=model_metrics,
331+
metadata_properties=metadata_properties,
332+
marketplace_cert=marketplace_cert,
333+
approval_status=approval_status,
334+
description=description,
335+
container_def_list=container_def,
336+
drift_check_baselines=drift_check_baselines,
337+
customer_metadata_properties=customer_metadata_properties,
338+
)
339+
340+
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
245341

246342
def transformer(
247343
self,

0 commit comments

Comments
 (0)