Skip to content

Commit f33dff7

Browse files
author
Dewen Qi
committed
feature: Add ModelStep for SageMaker Model Building Pipeline
1 parent 811056b commit f33dff7

19 files changed

+2561
-327
lines changed

src/sagemaker/model.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import re
2121
import copy
22+
from typing import List, Dict
2223

2324
import sagemaker
2425
from sagemaker import (
@@ -38,6 +39,7 @@
3839
from sagemaker.async_inference import AsyncInferenceConfig
3940
from sagemaker.predictor_async import AsyncPredictor
4041
from sagemaker.workflow import is_pipeline_variable
42+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
4143

4244
LOGGER = logging.getLogger("sagemaker")
4345

@@ -289,6 +291,7 @@ def __init__(
289291
self.uploaded_code = None
290292
self.repacked_model_data = None
291293

294+
@runnable_by_pipeline
292295
def register(
293296
self,
294297
content_types,
@@ -336,6 +339,7 @@ def register(
336339
Returns:
337340
A `sagemaker.model.ModelPackage` instance.
338341
"""
342+
339343
if self.model_data is None:
340344
raise ValueError("SageMaker Model Package cannot be created without model data.")
341345
if image_uri is not None:
@@ -361,15 +365,52 @@ def register(
361365
drift_check_baselines=drift_check_baselines,
362366
customer_metadata_properties=customer_metadata_properties,
363367
)
368+
364369
model_package = self.sagemaker_session.create_model_package_from_containers(
365370
**model_pkg_args
366371
)
372+
if isinstance(self.sagemaker_session, PipelineSession):
373+
return None
367374
return ModelPackage(
368375
role=self.role,
369376
model_data=self.model_data,
370377
model_package_arn=model_package.get("ModelPackageArn"),
371378
)
372379

380+
@runnable_by_pipeline
def create(
    self,
    instance_type: str = None,
    accelerator_type: str = None,
    serverless_inference_config: ServerlessInferenceConfig = None,
    tags: List[Dict[str, str]] = None,
):
    """Create a SageMaker Model entity.

    Public entry point that forwards its arguments, unchanged, to
    ``_create_sagemaker_model``.

    Args:
        instance_type (str): The EC2 instance type that this Model will be
            used for. It is only used to determine whether the image needs
            GPU support or not.
        accelerator_type (str): Type of Elastic Inference accelerator to
            attach to an endpoint for model loading and inference, for
            example, 'ml.eia1.medium'. If not specified, no Elastic
            Inference accelerator will be attached to the endpoint.
        serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
            Configuration for a serverless endpoint. No instance type is
            provided in serverless inference, so this is used to find image URIs.
        tags (List[Dict[str, str]]): Optional. The list of tags to add to
            the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
            'tagvalue'}] For more information about tags, see
            https://boto3.amazonaws.com/v1/documentation
            /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
    """
    # Collect the keyword arguments once, then hand them to the private helper.
    model_kwargs = dict(
        instance_type=instance_type,
        accelerator_type=accelerator_type,
        tags=tags,
        serverless_inference_config=serverless_inference_config,
    )
    self._create_sagemaker_model(**model_kwargs)
413+
373414
def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
374415
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
375416
@@ -452,6 +493,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
452493
if repack and self.model_data is not None and self.entry_point is not None:
453494
if is_pipeline_variable(self.model_data):
454495
# model is not yet there, defer repacking to later during pipeline execution
496+
if isinstance(self.sagemaker_session, PipelineSession):
497+
self.sagemaker_session.mark_runtime_repack(id(self))
455498
return
456499

457500
bucket = self.bucket or self.sagemaker_session.default_bucket()
@@ -466,7 +509,6 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
466509
sagemaker_session=self.sagemaker_session,
467510
kms_key=self.model_kms_key,
468511
)
469-
470512
self.repacked_model_data = repacked_model_data
471513
self.uploaded_code = fw_utils.UploadedCode(
472514
s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point)
@@ -533,23 +575,32 @@ def _create_sagemaker_model(
533575
serverless_inference_config=serverless_inference_config,
534576
)
535577

536-
self._ensure_base_name_if_needed(
537-
image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
538-
)
539-
self._set_model_name_if_needed()
578+
if not isinstance(self.sagemaker_session, PipelineSession):
579+
# _base_name, model_name are not needed under PipelineSession.
580+
# the model_data, source_dir may be Pipeline variable
581+
# which may break the _base_name generation
582+
self._ensure_base_name_if_needed(
583+
image_uri=container_def["Image"],
584+
script_uri=self.source_dir,
585+
model_uri=self.model_data,
586+
)
587+
self._set_model_name_if_needed()
540588

541589
enable_network_isolation = self.enable_network_isolation()
542590

543591
self._init_sagemaker_session_if_does_not_exist(instance_type)
544-
self.sagemaker_session.create_model(
545-
self.name,
546-
self.role,
547-
container_def,
592+
593+
create_model_args = dict(
594+
name=self.name,
595+
role=self.role,
596+
container_defs=container_def,
548597
vpc_config=self.vpc_config,
549598
enable_network_isolation=enable_network_isolation,
550599
tags=tags,
551600
)
552601

602+
self.sagemaker_session.create_model(**create_model_args)
603+
553604
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
554605
"""Create a base name from the image URI if there is no model name provided.
555606

src/sagemaker/pipeline.py

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sagemaker.session import Session
1818
from sagemaker.utils import name_from_image
1919
from sagemaker.transformer import Transformer
20+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
2021

2122

2223
class PipelineModel(object):
@@ -221,6 +222,17 @@ def deploy(
221222
return predictor
222223
return None
223224

225+
@runnable_by_pipeline
def create(self, instance_type):
    """Create a SageMaker Model entity for this pipeline model.

    Public wrapper that forwards to ``_create_sagemaker_pipeline_model``.

    Args:
        instance_type (str): The EC2 instance type that this Model will be
            used for. It is only used to determine whether the image needs
            GPU support or not.
    """
    self._create_sagemaker_pipeline_model(instance_type=instance_type)
235+
224236
def _create_sagemaker_pipeline_model(self, instance_type):
225237
"""Create a SageMaker Model Entity
226238
@@ -231,18 +243,98 @@ def _create_sagemaker_pipeline_model(self, instance_type):
231243
"""
232244
if not self.sagemaker_session:
233245
self.sagemaker_session = Session()
234-
235246
containers = self.pipeline_container_def(instance_type)
236247

237248
self.name = self.name or name_from_image(containers[0]["Image"])
238-
self.sagemaker_session.create_model(
239-
self.name,
240-
self.role,
241-
containers,
249+
250+
create_model_args = dict(
251+
name=self.name,
252+
role=self.role,
253+
container_defs=containers,
242254
vpc_config=self.vpc_config,
243255
enable_network_isolation=self.enable_network_isolation,
244256
)
245257

258+
self.sagemaker_session.create_model(**create_model_args)
259+
260+
@runnable_by_pipeline
def register(
    self,
    content_types,
    response_types,
    inference_instances,
    transform_instances,
    model_package_name=None,
    model_package_group_name=None,
    image_uri=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status=None,
    description=None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
):
    """Create a model package from the containers of this pipeline model.

    Args:
        content_types (list): The supported MIME types for the input data.
        response_types (list): The supported MIME types for the output data.
        inference_instances (list): A list of the instance types that are used to
            generate inferences in real-time. May be None or empty; in that case
            no instance type is passed when building the container definitions.
        transform_instances (list): A list of the instance types on which a
            transformation job can be run or on which an endpoint can be deployed.
        model_package_name (str): Model Package name, exclusive to
            `model_package_group_name`, using `model_package_name` makes the Model
            Package un-versioned (default: None).
        model_package_group_name (str): Model Package Group name, exclusive to
            `model_package_name`, using `model_package_group_name` makes the Model
            Package versioned (default: None).
        image_uri (str): Inference image uri for the containers. Each model's own
            ``image_uri`` is used if it is None (default: None). Only consulted
            when `model_package_group_name` is None.
        model_metrics (ModelMetrics): ModelMetrics object (default: None).
        metadata_properties (MetadataProperties): MetadataProperties object
            (default: None).
        marketplace_cert (bool): A boolean value indicating if the Model Package is
            certified for AWS Marketplace (default: False).
        approval_status (str): Model Approval Status, values can be "Approved",
            "Rejected", or "PendingManualApproval" (default: None).
            NOTE(review): the earlier docstring listed "PendingManualApproval" as
            the default; presumably that is applied downstream - confirm.
        description (str): Model Package description (default: None).
        drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object
            (default: None).
        customer_metadata_properties (dict[str, str]): A dictionary of key-value
            paired metadata properties (default: None).

    Returns:
        None. The result of ``create_model_package_from_containers`` is not
        returned by this body. NOTE(review): the ``runnable_by_pipeline``
        decorator may substitute its own return value - confirm.

    Raises:
        ValueError: If any model in ``self.models`` has no ``model_data``.
    """
    if any(model.model_data is None for model in self.models):
        raise ValueError("SageMaker Model Package cannot be created without model data.")

    if model_package_group_name is not None:
        # Guard the index: a None or empty list would otherwise raise
        # TypeError/IndexError before any package is created.
        # Assumes pipeline_container_def accepts None as the instance type - confirm.
        first_instance = inference_instances[0] if inference_instances else None
        container_def = self.pipeline_container_def(first_instance)
    else:
        container_def = [
            {"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
            for model in self.models
        ]

    model_pkg_args = sagemaker.get_model_package_args(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=model_package_name,
        model_package_group_name=model_package_group_name,
        model_metrics=model_metrics,
        metadata_properties=metadata_properties,
        marketplace_cert=marketplace_cert,
        approval_status=approval_status,
        description=description,
        container_def_list=container_def,
        drift_check_baselines=drift_check_baselines,
        customer_metadata_properties=customer_metadata_properties,
    )

    self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
337+
246338
def transformer(
247339
self,
248340
instance_count,

0 commit comments

Comments
 (0)