
feature: Add ModelStep for SageMaker Model Building Pipeline #3076


Merged: 3 commits, May 13, 2022
90 changes: 78 additions & 12 deletions src/sagemaker/model.py
@@ -19,6 +19,7 @@
import os
import re
import copy
from typing import List, Dict

import sagemaker
from sagemaker import (
@@ -38,6 +39,7 @@
from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession

LOGGER = logging.getLogger("sagemaker")

@@ -289,6 +291,7 @@ def __init__(
self.uploaded_code = None
self.repacked_model_data = None

@runnable_by_pipeline
def register(
self,
content_types,
@@ -310,12 +313,12 @@ def register(
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Args:
content_types (list): The supported MIME types for the input data (default: None).
response_types (list): The supported MIME types for the output data (default: None).
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time (default: None).
generate inferences in real-time.
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed (default: None).
job can be run or on which an endpoint can be deployed.
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
@@ -366,12 +369,50 @@ def register(
model_package = self.sagemaker_session.create_model_package_from_containers(
**model_pkg_args
)
if isinstance(self.sagemaker_session, PipelineSession):
return None
return ModelPackage(
role=self.role,
model_data=self.model_data,
model_package_arn=model_package.get("ModelPackageArn"),
)

@runnable_by_pipeline
def create(
self,
instance_type: str = None,
accelerator_type: str = None,
serverless_inference_config: ServerlessInferenceConfig = None,
tags: List[Dict[str, str]] = None,
):
"""Create a SageMaker Model Entity

Args:
instance_type (str): The EC2 instance type that this Model will be
used for, this is only used to determine if the image needs GPU
support or not (default: None).
accelerator_type (str): Type of Elastic Inference accelerator to
attach to an endpoint for model loading and inference, for
example, 'ml.eia1.medium'. If not specified, no Elastic
Inference accelerator will be attached to the endpoint (default: None).
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to serverless endpoint. Instance type is
not provided in serverless inference. So this is used to find image URIs
(default: None).
tags (List[Dict[str, str]]): The list of tags to add to
the model (default: None). Example: >>> tags = [{'Key': 'tagname', 'Value':
'tagvalue'}] For more information about tags, see
https://boto3.amazonaws.com/v1/documentation
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
"""
# TODO: we should replace _create_sagemaker_model() with create()
self._create_sagemaker_model(
instance_type=instance_type,
accelerator_type=accelerator_type,
tags=tags,
serverless_inference_config=serverless_inference_config,
)

def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.

@@ -455,6 +496,24 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
if repack and self.model_data is not None and self.entry_point is not None:
if is_pipeline_variable(self.model_data):
# model is not yet there, defer repacking to later during pipeline execution
if not isinstance(self.sagemaker_session, PipelineSession):
# TODO: link the doc in the warning once ready
logging.warning(
"The model_data is a Pipeline variable of type %s, "
"which should be used under `PipelineSession` and "
"leverage `ModelStep` to create or register model. "
"Otherwise some functionalities e.g. "
"runtime repack may be missing",
type(self.model_data),
)
return
self.sagemaker_session.context.need_runtime_repack.add(id(self))
# Add the uploaded_code and repacked_model_data to update the container env
self.repacked_model_data = self.model_data
self.uploaded_code = fw_utils.UploadedCode(
s3_prefix=self.repacked_model_data,
script_name=os.path.basename(self.entry_point),
)
return
if local_code and self.model_data.startswith("file://"):
repacked_model_data = self.model_data
@@ -538,22 +597,29 @@ def _create_sagemaker_model(
serverless_inference_config=serverless_inference_config,
)

self._ensure_base_name_if_needed(
image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
)
self._set_model_name_if_needed()
if not isinstance(self.sagemaker_session, PipelineSession):
# _base_name, model_name are not needed under PipelineSession.
# the model_data may be Pipeline variable
# which may break the _base_name generation
self._ensure_base_name_if_needed(
image_uri=container_def["Image"],
script_uri=self.source_dir,
model_uri=self.model_data,
)
self._set_model_name_if_needed()

enable_network_isolation = self.enable_network_isolation()

self._init_sagemaker_session_if_does_not_exist(instance_type)
self.sagemaker_session.create_model(
self.name,
self.role,
container_def,
create_model_args = dict(
name=self.name,
role=self.role,
container_defs=container_def,
vpc_config=self.vpc_config,
enable_network_isolation=enable_network_isolation,
tags=tags,
)
self.sagemaker_session.create_model(**create_model_args)

def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
"""Create a base name from the image URI if there is no model name provided.
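For context, `create` follows the same pattern: under a `PipelineSession` the call is captured rather than executed, and when `model_data` is a pipeline variable and an `entry_point` is set, `_upload_code` defers the repack to pipeline execution via `need_runtime_repack`. A minimal sketch, reusing the hypothetical `model` and `pipeline_session` from the earlier example:

```python
from sagemaker.workflow.model_step import ModelStep  # added by this PR

# create() under PipelineSession returns step arguments instead of calling
# the CreateModel API; instance_type is only used to resolve the image.
create_args = model.create(instance_type="ml.m5.xlarge")
step_create_model = ModelStep(name="MyModelStep-Create", step_args=create_args)
```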
104 changes: 100 additions & 4 deletions src/sagemaker/pipeline.py
@@ -13,10 +13,16 @@
"""Placeholder docstring"""
from __future__ import absolute_import

from typing import Optional, Dict

import sagemaker
from sagemaker import ModelMetrics
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.session import Session
from sagemaker.utils import name_from_image
from sagemaker.transformer import Transformer
from sagemaker.workflow.pipeline_context import runnable_by_pipeline


class PipelineModel(object):
@@ -221,6 +227,17 @@ def deploy(
return predictor
return None

@runnable_by_pipeline
def create(self, instance_type: str):
"""Create a SageMaker Model Entity

Args:
instance_type (str): The EC2 instance type that this Model will be
used for, this is only used to determine if the image needs GPU
support or not.
"""
self._create_sagemaker_pipeline_model(instance_type)

def _create_sagemaker_pipeline_model(self, instance_type):
"""Create a SageMaker Model Entity

@@ -235,13 +252,92 @@ def _create_sagemaker_pipeline_model(self, instance_type):
containers = self.pipeline_container_def(instance_type)

self.name = self.name or name_from_image(containers[0]["Image"])
self.sagemaker_session.create_model(
self.name,
self.role,
containers,
create_model_args = dict(
name=self.name,
role=self.role,
container_defs=containers,
vpc_config=self.vpc_config,
enable_network_isolation=self.enable_network_isolation,
)
self.sagemaker_session.create_model(**create_model_args)

@runnable_by_pipeline
def register(
self,
content_types: list,
response_types: list,
inference_instances: list,
transform_instances: list,
model_package_name: Optional[str] = None,
model_package_group_name: Optional[str] = None,
image_uri: Optional[str] = None,
model_metrics: Optional[ModelMetrics] = None,
metadata_properties: Optional[MetadataProperties] = None,
marketplace_cert: bool = False,
approval_status: Optional[str] = None,
description: Optional[str] = None,
drift_check_baselines: Optional[DriftCheckBaselines] = None,
customer_metadata_properties: Optional[Dict[str, str]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Args:
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
`model_package_name`, using `model_package_group_name` makes the Model Package
versioned (default: None).
image_uri (str): Inference image uri for the container. Model class' self.image will
be used if it is None (default: None).
model_metrics (ModelMetrics): ModelMetrics object (default: None).
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
for AWS Marketplace (default: False).
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
"""
for model in self.models:
if model.model_data is None:
raise ValueError("SageMaker Model Package cannot be created without model data.")
if model_package_group_name is not None:
container_def = self.pipeline_container_def(inference_instances[0])
else:
container_def = [
{"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
for model in self.models
]

model_pkg_args = sagemaker.get_model_package_args(
content_types,
response_types,
inference_instances,
transform_instances,
model_package_name=model_package_name,
model_package_group_name=model_package_group_name,
model_metrics=model_metrics,
metadata_properties=metadata_properties,
marketplace_cert=marketplace_cert,
approval_status=approval_status,
description=description,
container_def_list=container_def,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
)

self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)

def transformer(
self,
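Similarly, a multi-container inference pipeline can now be registered from inside a pipeline definition. A minimal sketch, assuming two hypothetical framework models `sklearn_model` and `xgb_model` built with the same `pipeline_session` as above:

```python
from sagemaker.pipeline import PipelineModel
from sagemaker.workflow.model_step import ModelStep  # added by this PR

pipeline_model = PipelineModel(
    models=[sklearn_model, xgb_model],  # hypothetical Model objects
    role="<execution-role-arn>",        # placeholder role ARN
    sagemaker_session=pipeline_session,
)

register_args = pipeline_model.register(
    content_types=["text/csv"],
    response_types=["text/csv"],
    inference_instances=["ml.m5.xlarge"],
    transform_instances=["ml.m5.xlarge"],
    model_package_group_name="my-inference-pipeline-group",
)
step_register_pipeline_model = ModelStep(
    name="RegisterPipelineModel", step_args=register_args
)
```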
67 changes: 38 additions & 29 deletions src/sagemaker/session.py
@@ -2679,23 +2679,24 @@ def create_model(
primary_container=primary_container,
tags=tags,
)
LOGGER.info("Creating model with name: %s", name)
LOGGER.debug("CreateModel request: %s", json.dumps(create_model_request, indent=4))

try:
self.sagemaker_client.create_model(**create_model_request)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]

if (
error_code == "ValidationException"
and "Cannot create already existing model" in message
):
LOGGER.warning("Using already existing model: %s", name)
else:
raise
def submit(request):
LOGGER.info("Creating model with name: %s", name)
LOGGER.debug("CreateModel request: %s", json.dumps(request, indent=4))
try:
self.sagemaker_client.create_model(**request)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
if (
error_code == "ValidationException"
and "Cannot create already existing model" in message
):
LOGGER.warning("Using already existing model: %s", name)
else:
raise

self._intercept_create_request(create_model_request, submit, self.create_model.__name__)
return name

def create_model_from_job(
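For reference, the keyword-argument call assembled in `model.py` above maps onto `create_model`'s existing signature; only the routing through `_intercept_create_request` changes. A minimal sketch with placeholder values, assuming a plain `Session` instance named `sess`:

```python
# Placeholder container definition and role ARN, purely for illustration.
container_def = {
    "Image": "<account>.dkr.ecr.<region>.amazonaws.com/my-inference-image:latest",
    "ModelDataUrl": "s3://my-bucket/prefix/model.tar.gz",
}

model_name = sess.create_model(
    name="my-model",
    role="arn:aws:iam::<account>:role/MySageMakerRole",
    container_defs=container_def,
)
```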
@@ -2829,10 +2830,9 @@ def create_model_package_from_containers(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).

"""

request = get_create_model_package_request(
model_pkg_request = get_create_model_package_request(
model_package_name,
model_package_group_name,
containers,
Expand All @@ -2849,16 +2849,22 @@ def create_model_package_from_containers(
customer_metadata_properties=customer_metadata_properties,
validation_specification=validation_specification,
)
if model_package_group_name is not None:
try:
self.sagemaker_client.describe_model_package_group(
ModelPackageGroupName=request["ModelPackageGroupName"]
)
except ClientError:
self.sagemaker_client.create_model_package_group(
ModelPackageGroupName=request["ModelPackageGroupName"]
)
return self.sagemaker_client.create_model_package(**request)

def submit(request):
if model_package_group_name is not None:
try:
self.sagemaker_client.describe_model_package_group(
ModelPackageGroupName=request["ModelPackageGroupName"]
)
except ClientError:
self.sagemaker_client.create_model_package_group(
ModelPackageGroupName=request["ModelPackageGroupName"]
)
return self.sagemaker_client.create_model_package(**request)

return self._intercept_create_request(
model_pkg_request, submit, self.create_model_package_from_containers.__name__
)

def wait_for_model_package(self, model_package_name, poll=5):
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -4177,7 +4183,9 @@ def account_id(self) -> str:
)
return sts_client.get_caller_identity()["Account"]

def _intercept_create_request(self, request: typing.Dict, create):
def _intercept_create_request(
self, request: typing.Dict, create, func_name: str = None # pylint: disable=unused-argument
):
"""This function intercepts the create job request.

PipelineSession inherits this Session class and will override
@@ -4186,8 +4194,9 @@ def _intercept_create_request(self, request: typing.Dict, create):
Args:
request (dict): the create job request
create (functor): a functor calls the sagemaker client create method
func_name (str): the name of the function needed intercepting
"""
create(request)
return create(request)


def get_model_package_args(
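To illustrate why `_intercept_create_request` now receives the request, the submit functor, and the calling method's name, here is a minimal sketch of a `Session` subclass that captures requests instead of submitting them. This is an illustrative assumption, not the SDK's actual `PipelineSession` implementation:

```python
import typing

from sagemaker.session import Session


class RequestCapturingSession(Session):
    """Hypothetical subclass: records create requests rather than calling the API."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.captured_requests = {}

    def _intercept_create_request(self, request: typing.Dict, create, func_name: str = None):
        # Instead of invoking create(request), which would call the SageMaker
        # API via boto3, stash the request keyed by the intercepted method name,
        # e.g. "create_model" or "create_model_package_from_containers".
        self.captured_requests[func_name] = request
        return request
```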