Skip to content

Commit db65419

Browse files
author
Dewen Qi
committed
feature: Add ModelStep for SageMaker Model Building Pipeline
1 parent 811056b commit db65419

19 files changed

+2680
-338
lines changed

src/sagemaker/model.py

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import re
2121
import copy
22+
from typing import List, Dict
2223

2324
import sagemaker
2425
from sagemaker import (
@@ -38,6 +39,7 @@
3839
from sagemaker.async_inference import AsyncInferenceConfig
3940
from sagemaker.predictor_async import AsyncPredictor
4041
from sagemaker.workflow import is_pipeline_variable
42+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
4143

4244
LOGGER = logging.getLogger("sagemaker")
4345

@@ -289,6 +291,7 @@ def __init__(
289291
self.uploaded_code = None
290292
self.repacked_model_data = None
291293

294+
@runnable_by_pipeline
292295
def register(
293296
self,
294297
content_types,
@@ -309,12 +312,12 @@ def register(
309312
"""Creates a model package for creating SageMaker models or listing on Marketplace.
310313
311314
Args:
312-
content_types (list): The supported MIME types for the input data (default: None).
313-
response_types (list): The supported MIME types for the output data (default: None).
315+
content_types (list): The supported MIME types for the input data.
316+
response_types (list): The supported MIME types for the output data.
314317
inference_instances (list): A list of the instance types that are used to
315-
generate inferences in real-time (default: None).
318+
generate inferences in real-time.
316319
transform_instances (list): A list of the instance types on which a transformation
317-
job can be run or on which an endpoint can be deployed (default: None).
320+
job can be run or on which an endpoint can be deployed.
318321
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
319322
using `model_package_name` makes the Model Package un-versioned (default: None).
320323
model_package_group_name (str): Model Package Group name, exclusive to
@@ -364,12 +367,49 @@ def register(
364367
model_package = self.sagemaker_session.create_model_package_from_containers(
365368
**model_pkg_args
366369
)
370+
if isinstance(self.sagemaker_session, PipelineSession):
371+
return None
367372
return ModelPackage(
368373
role=self.role,
369374
model_data=self.model_data,
370375
model_package_arn=model_package.get("ModelPackageArn"),
371376
)
372377

378+
@runnable_by_pipeline
379+
def create(
380+
self,
381+
instance_type: str = None,
382+
accelerator_type: str = None,
383+
serverless_inference_config: ServerlessInferenceConfig = None,
384+
tags: List[Dict[str, str]] = None,
385+
):
386+
"""Create a SageMaker Model Entity
387+
388+
Args:
389+
instance_type (str): The EC2 instance type that this Model will be
390+
used for, this is only used to determine if the image needs GPU
391+
support or not (default: None).
392+
accelerator_type (str): Type of Elastic Inference accelerator to
393+
attach to an endpoint for model loading and inference, for
394+
example, 'ml.eia1.medium'. If not specified, no Elastic
395+
Inference accelerator will be attached to the endpoint (default: None).
396+
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
397+
Specifies configuration related to serverless endpoint. Instance type is
398+
not provided in serverless inference. So this is used to find image URIs
399+
(default: None).
400+
tags (List[Dict[str, str]]): The list of tags to add to
401+
the model (default: None). Example: >>> tags = [{'Key': 'tagname', 'Value':
402+
'tagvalue'}] For more information about tags, see
403+
https://boto3.amazonaws.com/v1/documentation
404+
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
405+
"""
406+
self._create_sagemaker_model(
407+
instance_type=instance_type,
408+
accelerator_type=accelerator_type,
409+
tags=tags,
410+
serverless_inference_config=serverless_inference_config,
411+
)
412+
373413
def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
374414
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
375415
@@ -452,6 +492,17 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
452492
if repack and self.model_data is not None and self.entry_point is not None:
453493
if is_pipeline_variable(self.model_data):
454494
# model is not yet there, defer repacking to later during pipeline execution
495+
if not isinstance(self.sagemaker_session, PipelineSession):
496+
logging.warning(
497+
"The model_data is a Pipeline variable of type %s, "
498+
"which should be used under `PipelineSession` and "
499+
"leverage `ModelStep` to create or register model. "
500+
"Otherwise some functionalities e.g. "
501+
"runtime repack may be missing",
502+
type(self.model_data),
503+
)
504+
return
505+
self.sagemaker_session.mark_runtime_repack(id(self))
455506
return
456507

457508
bucket = self.bucket or self.sagemaker_session.default_bucket()
@@ -533,22 +584,29 @@ def _create_sagemaker_model(
533584
serverless_inference_config=serverless_inference_config,
534585
)
535586

536-
self._ensure_base_name_if_needed(
537-
image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
538-
)
539-
self._set_model_name_if_needed()
587+
if not isinstance(self.sagemaker_session, PipelineSession):
588+
# _base_name, model_name are not needed under PipelineSession.
589+
# the model_data may be a Pipeline variable
590+
# which may break the _base_name generation
591+
self._ensure_base_name_if_needed(
592+
image_uri=container_def["Image"],
593+
script_uri=self.source_dir,
594+
model_uri=self.model_data,
595+
)
596+
self._set_model_name_if_needed()
540597

541598
enable_network_isolation = self.enable_network_isolation()
542599

543600
self._init_sagemaker_session_if_does_not_exist(instance_type)
544-
self.sagemaker_session.create_model(
545-
self.name,
546-
self.role,
547-
container_def,
601+
create_model_args = dict(
602+
name=self.name,
603+
role=self.role,
604+
container_defs=container_def,
548605
vpc_config=self.vpc_config,
549606
enable_network_isolation=enable_network_isolation,
550607
tags=tags,
551608
)
609+
self.sagemaker_session.create_model(**create_model_args)
552610

553611
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
554612
"""Create a base name from the image URI if there is no model name provided.

src/sagemaker/pipeline.py

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Dict
17+
1618
import sagemaker
19+
from sagemaker import ModelMetrics
20+
from sagemaker.drift_check_baselines import DriftCheckBaselines
21+
from sagemaker.metadata_properties import MetadataProperties
1722
from sagemaker.session import Session
1823
from sagemaker.utils import name_from_image
1924
from sagemaker.transformer import Transformer
25+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
2026

2127

2228
class PipelineModel(object):
@@ -221,6 +227,17 @@ def deploy(
221227
return predictor
222228
return None
223229

230+
@runnable_by_pipeline
231+
def create(self, instance_type: str):
232+
"""Create a SageMaker Model Entity
233+
234+
Args:
235+
instance_type (str): The EC2 instance type that this Model will be
236+
used for, this is only used to determine if the image needs GPU
237+
support or not.
238+
"""
239+
self._create_sagemaker_pipeline_model(instance_type)
240+
224241
def _create_sagemaker_pipeline_model(self, instance_type):
225242
"""Create a SageMaker Model Entity
226243
@@ -235,13 +252,92 @@ def _create_sagemaker_pipeline_model(self, instance_type):
235252
containers = self.pipeline_container_def(instance_type)
236253

237254
self.name = self.name or name_from_image(containers[0]["Image"])
238-
self.sagemaker_session.create_model(
239-
self.name,
240-
self.role,
241-
containers,
255+
create_model_args = dict(
256+
name=self.name,
257+
role=self.role,
258+
container_defs=containers,
242259
vpc_config=self.vpc_config,
243260
enable_network_isolation=self.enable_network_isolation,
244261
)
262+
self.sagemaker_session.create_model(**create_model_args)
263+
264+
@runnable_by_pipeline
265+
def register(
266+
self,
267+
content_types: list,
268+
response_types: list,
269+
inference_instances: list,
270+
transform_instances: list,
271+
model_package_name: Optional[str] = None,
272+
model_package_group_name: Optional[str] = None,
273+
image_uri: Optional[str] = None,
274+
model_metrics: Optional[ModelMetrics] = None,
275+
metadata_properties: Optional[MetadataProperties] = None,
276+
marketplace_cert: bool = False,
277+
approval_status: Optional[str] = None,
278+
description: Optional[str] = None,
279+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
280+
customer_metadata_properties: Optional[Dict[str, str]] = None,
281+
):
282+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
283+
284+
Args:
285+
content_types (list): The supported MIME types for the input data.
286+
response_types (list): The supported MIME types for the output data.
287+
inference_instances (list): A list of the instance types that are used to
288+
generate inferences in real-time.
289+
transform_instances (list): A list of the instance types on which a transformation
290+
job can be run or on which an endpoint can be deployed.
291+
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
292+
using `model_package_name` makes the Model Package un-versioned (default: None).
293+
model_package_group_name (str): Model Package Group name, exclusive to
294+
`model_package_name`, using `model_package_group_name` makes the Model Package
295+
versioned (default: None).
296+
image_uri (str): Inference image uri for the container. Model class' self.image will
297+
be used if it is None (default: None).
298+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
299+
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
300+
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
301+
for AWS Marketplace (default: False).
302+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
303+
or "PendingManualApproval" (default: "PendingManualApproval").
304+
description (str): Model Package description (default: None).
305+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
306+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
307+
metadata properties (default: None).
308+
309+
Returns:
310+
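None. This method submits the request via `create_model_package_from_containers` but does not return its response or construct a `sagemaker.model.ModelPackage`.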
311+
"""
312+
for model in self.models:
313+
if model.model_data is None:
314+
raise ValueError("SageMaker Model Package cannot be created without model data.")
315+
if model_package_group_name is not None:
316+
container_def = self.pipeline_container_def(inference_instances[0])
317+
else:
318+
container_def = [
319+
{"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
320+
for model in self.models
321+
]
322+
323+
model_pkg_args = sagemaker.get_model_package_args(
324+
content_types,
325+
response_types,
326+
inference_instances,
327+
transform_instances,
328+
model_package_name=model_package_name,
329+
model_package_group_name=model_package_group_name,
330+
model_metrics=model_metrics,
331+
metadata_properties=metadata_properties,
332+
marketplace_cert=marketplace_cert,
333+
approval_status=approval_status,
334+
description=description,
335+
container_def_list=container_def,
336+
drift_check_baselines=drift_check_baselines,
337+
customer_metadata_properties=customer_metadata_properties,
338+
)
339+
340+
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
245341

246342
def transformer(
247343
self,

src/sagemaker/session.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2679,23 +2679,24 @@ def create_model(
26792679
primary_container=primary_container,
26802680
tags=tags,
26812681
)
2682-
LOGGER.info("Creating model with name: %s", name)
2683-
LOGGER.debug("CreateModel request: %s", json.dumps(create_model_request, indent=4))
26842682

2685-
try:
2686-
self.sagemaker_client.create_model(**create_model_request)
2687-
except ClientError as e:
2688-
error_code = e.response["Error"]["Code"]
2689-
message = e.response["Error"]["Message"]
2690-
2691-
if (
2692-
error_code == "ValidationException"
2693-
and "Cannot create already existing model" in message
2694-
):
2695-
LOGGER.warning("Using already existing model: %s", name)
2696-
else:
2697-
raise
2683+
def submit(request):
2684+
LOGGER.info("Creating model with name: %s", name)
2685+
LOGGER.debug("CreateModel request: %s", json.dumps(request, indent=4))
2686+
try:
2687+
self.sagemaker_client.create_model(**request)
2688+
except ClientError as e:
2689+
error_code = e.response["Error"]["Code"]
2690+
message = e.response["Error"]["Message"]
2691+
if (
2692+
error_code == "ValidationException"
2693+
and "Cannot create already existing model" in message
2694+
):
2695+
LOGGER.warning("Using already existing model: %s", name)
2696+
else:
2697+
raise
26982698

2699+
self._intercept_create_request(create_model_request, submit, "create_model")
26992700
return name
27002701

27012702
def create_model_from_job(
@@ -2828,10 +2829,9 @@ def create_model_package_from_containers(
28282829
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
28292830
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
28302831
metadata properties (default: None).
2831-
28322832
"""
28332833

2834-
request = get_create_model_package_request(
2834+
model_pkg_request = get_create_model_package_request(
28352835
model_package_name,
28362836
model_package_group_name,
28372837
containers,
@@ -2847,16 +2847,22 @@ def create_model_package_from_containers(
28472847
drift_check_baselines=drift_check_baselines,
28482848
customer_metadata_properties=customer_metadata_properties,
28492849
)
2850-
if model_package_group_name is not None:
2851-
try:
2852-
self.sagemaker_client.describe_model_package_group(
2853-
ModelPackageGroupName=request["ModelPackageGroupName"]
2854-
)
2855-
except ClientError:
2856-
self.sagemaker_client.create_model_package_group(
2857-
ModelPackageGroupName=request["ModelPackageGroupName"]
2858-
)
2859-
return self.sagemaker_client.create_model_package(**request)
2850+
2851+
def submit(request):
2852+
if model_package_group_name is not None:
2853+
try:
2854+
self.sagemaker_client.describe_model_package_group(
2855+
ModelPackageGroupName=request["ModelPackageGroupName"]
2856+
)
2857+
except ClientError:
2858+
self.sagemaker_client.create_model_package_group(
2859+
ModelPackageGroupName=request["ModelPackageGroupName"]
2860+
)
2861+
return self.sagemaker_client.create_model_package(**request)
2862+
2863+
return self._intercept_create_request(
2864+
model_pkg_request, submit, "create_model_package_from_containers"
2865+
)
28602866

28612867
def wait_for_model_package(self, model_package_name, poll=5):
28622868
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -4175,14 +4181,17 @@ def account_id(self) -> str:
41754181
)
41764182
return sts_client.get_caller_identity()["Account"]
41774183

4178-
def _intercept_create_request(self, request: typing.Dict, create):
4184+
def _intercept_create_request(
4185+
self, request: typing.Dict, create, func_name: str = None # pylint: disable=unused-argument
4186+
):
41794187
"""This function intercepts the create job request
41804188
41814189
Args:
41824190
request (dict): the create job request
41834191
create (functor): a functor calls the sagemaker client create method
4192+
func_name (str): the name of the calling function whose create request is being intercepted
41844193
"""
4185-
create(request)
4194+
return create(request)
41864195

41874196

41884197
def get_model_package_args(

0 commit comments

Comments
 (0)