Skip to content

feature: Adding serial inference pipeline support to RegisterModel Step #2405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Jul 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
0a0ce3c
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 19, 2021
5e90186
Merge branch 'master' into model-registry-sip
sreedes May 25, 2021
6561856
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 19, 2021
a4a352a
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 27, 2021
f23ff20
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 27, 2021
f52ab31
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 27, 2021
9528898
Using Model class to abstract the containers
sreedes Jun 7, 2021
e86f8dd
Using Model class to abstract the containers
sreedes Jun 7, 2021
cd85320
Using Model class to abstract the containers
sreedes Jun 7, 2021
5f2c6c1
Formatting changes and lint fixes
sreedes Jun 7, 2021
deda665
Merge branch 'aws:master' into model-registry-sip
sreedes Jun 11, 2021
bcf55af
Merge branch 'master' into model-registry-sip
ahsan-z-khan Jun 14, 2021
562b56e
Repack steps for models
sreedes Jun 28, 2021
ee2b1bc
Repack steps for models
sreedes Jun 28, 2021
9c4a914
Merge branch 'master' into model-registry-sip
sreedes Jun 28, 2021
bdf809f
lint error corrections
sreedes Jun 30, 2021
ced6d1c
lint error corrections
sreedes Jun 30, 2021
bb5fd97
Unit test corrections
sreedes Jun 30, 2021
151d3bb
Unit test corrections
sreedes Jun 30, 2021
06f9659
Unit test corrections
sreedes Jun 30, 2021
1a349c3
Integ test added
sreedes Jul 5, 2021
aaf4d52
Merge branch 'master' into model-registry-sip
sreedes Jul 5, 2021
382da53
Update test_workflow.py
sreedes Jul 5, 2021
a5b1c43
Update model.py
sreedes Jul 5, 2021
36287f5
Update model.py
sreedes Jul 5, 2021
faac0a7
fix unit test failure
sreedes Jul 5, 2021
a0f446f
fixing unit tests for repack
sreedes Jul 5, 2021
2b7f86b
Path fix for the integ tests
sreedes Jul 6, 2021
f84c50f
Fix integ test paths
sreedes Jul 6, 2021
a614e6e
Black tool fixes
sreedes Jul 6, 2021
a0abe80
Test data files
sreedes Jul 6, 2021
ef1b929
Merge branch 'master' into model-registry-sip
sreedes Jul 8, 2021
6713ff6
Merge branch 'master' into model-registry-sip
ahsan-z-khan Jul 8, 2021
6538f6b
Merge branch 'aws:master' into model-registry-sip
sreedes Jul 8, 2021
5261a2f
Merge branch 'aws:master' into model-registry-sip
sreedes Jul 9, 2021
d95235f
Updating the input to RegisterModelStep to Pipeline Model
sreedes Jul 12, 2021
80e391e
Fix for local tests
sreedes Jul 12, 2021
6eeff17
Checking if the key is present
sreedes Jul 12, 2021
89ed8e2
black check fix
sreedes Jul 12, 2021
22fccc0
Fix unit test failures
sreedes Jul 12, 2021
6e12e41
Merge branch 'aws:master' into model-registry-sip
sreedes Jul 13, 2021
c2d5e61
Updating the input to model or pipeline model
sreedes Jul 13, 2021
4a8672e
Review comment incorporation
sreedes Jul 14, 2021
33db77f
Fixing unit test failure
sreedes Jul 14, 2021
a752214
Fixing unit test failure
sreedes Jul 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from sagemaker.processing import Processor, ScriptProcessor # noqa: F401
from sagemaker.session import Session # noqa: F401
from sagemaker.session import container_def, pipeline_container_def # noqa: F401
from sagemaker.session import get_model_package_args # noqa: F401
from sagemaker.session import production_variant # noqa: F401
from sagemaker.session import get_execution_role # noqa: F401

Expand Down
83 changes: 7 additions & 76 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,15 @@ def register(
if self.model_data is None:
raise ValueError("SageMaker Model Package cannot be created without model data.")

model_pkg_args = self._get_model_package_args(
model_pkg_args = sagemaker.get_model_package_args(
content_types,
response_types,
inference_instances,
transform_instances,
model_package_name,
model_package_group_name,
image_uri,
self.model_data,
image_uri or self.image_uri,
model_metrics,
metadata_properties,
marketplace_cert,
Expand All @@ -181,80 +182,6 @@ def register(
model_package_arn=model_package.get("ModelPackageArn"),
)

def _get_model_package_args(
self,
content_types,
response_types,
inference_instances,
transform_instances,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
model_metrics=None,
metadata_properties=None,
marketplace_cert=False,
approval_status=None,
description=None,
tags=None,
):
"""Get arguments for session.create_model_package method.

Args:
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
`model_package_name`, using `model_package_group_name` makes the Model Package
versioned (default: None).
image_uri (str): Inference image uri for the container. Model class' self.image will
be used if it is None (default: None).
model_metrics (ModelMetrics): ModelMetrics object (default: None).
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
for AWS Marketplace (default: False).
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
Returns:
dict: A dictionary of method argument names and values.
"""
if image_uri:
self.image_uri = image_uri
container = {
"Image": self.image_uri,
"ModelDataUrl": self.model_data,
}

model_package_args = {
"containers": [container],
"content_types": content_types,
"response_types": response_types,
"inference_instances": inference_instances,
"transform_instances": transform_instances,
"marketplace_cert": marketplace_cert,
}

if model_package_name is not None:
model_package_args["model_package_name"] = model_package_name
if model_package_group_name is not None:
model_package_args["model_package_group_name"] = model_package_group_name
if model_metrics is not None:
model_package_args["model_metrics"] = model_metrics._to_request_dict()
if metadata_properties is not None:
model_package_args["metadata_properties"] = metadata_properties._to_request_dict()
if approval_status is not None:
model_package_args["approval_status"] = approval_status
if description is not None:
model_package_args["description"] = description
if tags is not None:
model_package_args["tags"] = tags
return model_package_args

def _init_sagemaker_session_if_does_not_exist(self, instance_type):
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.

Expand Down Expand Up @@ -1128,6 +1055,10 @@ def _upload_code(self, key_prefix, repack=False):
)

if repack and self.model_data is not None and self.entry_point is not None:
if isinstance(self.model_data, sagemaker.workflow.properties.Properties):
# model is not yet there, defer repacking to later during pipeline execution
return

bucket = self.bucket or self.sagemaker_session.default_bucket()
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])

Expand Down
235 changes: 158 additions & 77 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2521,8 +2521,11 @@ def _create_model_request(
request = {"ModelName": name, "ExecutionRoleArn": role}
if isinstance(container_definition, list):
request["Containers"] = container_definition
elif "ModelPackageName" in container_definition:
request["Containers"] = [container_definition]
else:
request["PrimaryContainer"] = container_definition

if tags:
request["Tags"] = tags

Expand Down Expand Up @@ -2731,7 +2734,7 @@ def create_model_package_from_containers(
description (str): Model Package description (default: None).
"""

request = self._get_create_model_package_request(
request = get_create_model_package_request(
model_package_name,
model_package_group_name,
containers,
Expand All @@ -2747,82 +2750,6 @@ def create_model_package_from_containers(
)
return self.sagemaker_client.create_model_package(**request)

def _get_create_model_package_request(
self,
model_package_name=None,
model_package_group_name=None,
containers=None,
content_types=None,
response_types=None,
inference_instances=None,
transform_instances=None,
model_metrics=None,
metadata_properties=None,
marketplace_cert=False,
approval_status="PendingManualApproval",
description=None,
tags=None,
):
"""Get request dictionary for CreateModelPackage API.

Args:
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
`model_package_name`, using `model_package_group_name` makes the Model Package
versioned (default: None).
containers (list): A list of inference containers that can be used for inference
specifications of Model Package (default: None).
content_types (list): The supported MIME types for the input data (default: None).
response_types (list): The supported MIME types for the output data (default: None).
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed (default: None).
model_metrics (ModelMetrics): ModelMetrics object (default: None).
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
for AWS Marketplace (default: False).
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
"""
if all([model_package_name, model_package_group_name]):
raise ValueError(
"model_package_name and model_package_group_name cannot be present at the "
"same time."
)
request_dict = {}
if model_package_name is not None:
request_dict["ModelPackageName"] = model_package_name
if model_package_group_name is not None:
request_dict["ModelPackageGroupName"] = model_package_group_name
if description is not None:
request_dict["ModelPackageDescription"] = description
if tags is not None:
request_dict["Tags"] = tags
if model_metrics:
request_dict["ModelMetrics"] = model_metrics
if metadata_properties:
request_dict["MetadataProperties"] = metadata_properties
if containers is not None:
if not all([content_types, response_types, inference_instances, transform_instances]):
raise ValueError(
"content_types, response_types, inference_inferences and transform_instances "
"must be provided if containers is present."
)
inference_specification = {
"Containers": containers,
"SupportedContentTypes": content_types,
"SupportedResponseMIMETypes": response_types,
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
"SupportedTransformInstanceTypes": transform_instances,
}
request_dict["InferenceSpecification"] = inference_specification
request_dict["CertifyForMarketplace"] = marketplace_cert
request_dict["ModelApprovalStatus"] = approval_status
return request_dict

def wait_for_model_package(self, model_package_name, poll=5):
"""Wait for an Amazon SageMaker endpoint deployment to complete.

Expand Down Expand Up @@ -4097,6 +4024,160 @@ def account_id(self) -> str:
return sts_client.get_caller_identity()["Account"]


def get_model_package_args(
content_types,
response_types,
inference_instances,
transform_instances,
model_package_name=None,
model_package_group_name=None,
model_data=None,
image_uri=None,
model_metrics=None,
metadata_properties=None,
marketplace_cert=False,
approval_status=None,
description=None,
tags=None,
container_def_list=None,
):
"""Get arguments for create_model_package method.

Args:
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
`model_package_name`, using `model_package_group_name` makes the Model Package
versioned (default: None).
image_uri (str): Inference image uri for the container. Model class' self.image will
be used if it is None (default: None).
model_metrics (ModelMetrics): ModelMetrics object (default: None).
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
for AWS Marketplace (default: False).
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
container_def_list (list): A list of container defintiions.
Returns:
dict: A dictionary of method argument names and values.
"""
if container_def_list is not None:
containers = container_def_list
else:
container = {
"Image": image_uri,
"ModelDataUrl": model_data,
}
containers = [container]

model_package_args = {
"containers": containers,
"content_types": content_types,
"response_types": response_types,
"inference_instances": inference_instances,
"transform_instances": transform_instances,
"marketplace_cert": marketplace_cert,
}

if model_package_name is not None:
model_package_args["model_package_name"] = model_package_name
if model_package_group_name is not None:
model_package_args["model_package_group_name"] = model_package_group_name
if model_metrics is not None:
model_package_args["model_metrics"] = model_metrics._to_request_dict()
if metadata_properties is not None:
model_package_args["metadata_properties"] = metadata_properties._to_request_dict()
if approval_status is not None:
model_package_args["approval_status"] = approval_status
if description is not None:
model_package_args["description"] = description
if tags is not None:
model_package_args["tags"] = tags
return model_package_args


def get_create_model_package_request(
model_package_name=None,
model_package_group_name=None,
containers=None,
content_types=None,
response_types=None,
inference_instances=None,
transform_instances=None,
model_metrics=None,
metadata_properties=None,
marketplace_cert=False,
approval_status="PendingManualApproval",
description=None,
tags=None,
):
"""Get request dictionary for CreateModelPackage API.

Args:
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
`model_package_name`, using `model_package_group_name` makes the Model Package
versioned (default: None).
containers (list): A list of inference containers that can be used for inference
specifications of Model Package (default: None).
content_types (list): The supported MIME types for the input data (default: None).
response_types (list): The supported MIME types for the output data (default: None).
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed (default: None).
model_metrics (ModelMetrics): ModelMetrics object (default: None).
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
for AWS Marketplace (default: False).
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
"""
if all([model_package_name, model_package_group_name]):
raise ValueError(
"model_package_name and model_package_group_name cannot be present at the " "same time."
)
request_dict = {}
if model_package_name is not None:
request_dict["ModelPackageName"] = model_package_name
if model_package_group_name is not None:
request_dict["ModelPackageGroupName"] = model_package_group_name
if description is not None:
request_dict["ModelPackageDescription"] = description
if tags is not None:
request_dict["Tags"] = tags
if model_metrics:
request_dict["ModelMetrics"] = model_metrics
if metadata_properties:
request_dict["MetadataProperties"] = metadata_properties
if containers is not None:
if not all([content_types, response_types, inference_instances, transform_instances]):
raise ValueError(
"content_types, response_types, inference_inferences and transform_instances "
"must be provided if containers is present."
)
inference_specification = {
"Containers": containers,
"SupportedContentTypes": content_types,
"SupportedResponseMIMETypes": response_types,
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
"SupportedTransformInstanceTypes": transform_instances,
}
request_dict["InferenceSpecification"] = inference_specification
request_dict["CertifyForMarketplace"] = marketplace_cert
request_dict["ModelApprovalStatus"] = approval_status
return request_dict


def update_args(args: Dict[str, Any], **kwargs):
"""Updates the request arguments dict with the value if populated.

Expand Down
Loading