Skip to content

Commit 8f5b374

Browse files
feature: Adding serial inference pipeline support to RegisterModel Step (#2405)
* feature: Adding serial inference pipeline support to RegisterModel Step * feature: Adding serial inference pipeline support to RegisterModel Step * Using Model class to abstract the containers * Using Model class to abstract the containers * Formatting changes and lint fixes * Repack steps for models * Repack steps for models * lint error corrections * lint error corrections * Unit test corrections * Unit test corrections * Integ test added * Update test_workflow.py * Update model.py * Update model.py * fix unit test failure * fixing unit tests for repack * Path fix for the integ tests * Fix integ test paths * Black tool fixes * Test data files * Updating the input to RegisterModelStep to Pipeline Model * Fix for local tests * Checking if the key is present * black check fix * Fix unit test failures * Updating the input to model or pipeline model * Review comment incorporation * Fixing unit test failure * Fixing unit test failure Co-authored-by: Ahsan Khan <[email protected]>
1 parent 9e1fe91 commit 8f5b374

File tree

14 files changed

+1007
-204
lines changed

14 files changed

+1007
-204
lines changed

src/sagemaker/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from sagemaker.processing import Processor, ScriptProcessor # noqa: F401
5656
from sagemaker.session import Session # noqa: F401
5757
from sagemaker.session import container_def, pipeline_container_def # noqa: F401
58+
from sagemaker.session import get_model_package_args # noqa: F401
5859
from sagemaker.session import production_variant # noqa: F401
5960
from sagemaker.session import get_execution_role # noqa: F401
6061

src/sagemaker/model.py

Lines changed: 7 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,15 @@ def register(
158158
if self.model_data is None:
159159
raise ValueError("SageMaker Model Package cannot be created without model data.")
160160

161-
model_pkg_args = self._get_model_package_args(
161+
model_pkg_args = sagemaker.get_model_package_args(
162162
content_types,
163163
response_types,
164164
inference_instances,
165165
transform_instances,
166166
model_package_name,
167167
model_package_group_name,
168-
image_uri,
168+
self.model_data,
169+
image_uri or self.image_uri,
169170
model_metrics,
170171
metadata_properties,
171172
marketplace_cert,
@@ -181,80 +182,6 @@ def register(
181182
model_package_arn=model_package.get("ModelPackageArn"),
182183
)
183184

184-
def _get_model_package_args(
185-
self,
186-
content_types,
187-
response_types,
188-
inference_instances,
189-
transform_instances,
190-
model_package_name=None,
191-
model_package_group_name=None,
192-
image_uri=None,
193-
model_metrics=None,
194-
metadata_properties=None,
195-
marketplace_cert=False,
196-
approval_status=None,
197-
description=None,
198-
tags=None,
199-
):
200-
"""Get arguments for session.create_model_package method.
201-
202-
Args:
203-
content_types (list): The supported MIME types for the input data.
204-
response_types (list): The supported MIME types for the output data.
205-
inference_instances (list): A list of the instance types that are used to
206-
generate inferences in real-time.
207-
transform_instances (list): A list of the instance types on which a transformation
208-
job can be run or on which an endpoint can be deployed.
209-
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
210-
using `model_package_name` makes the Model Package un-versioned (default: None).
211-
model_package_group_name (str): Model Package Group name, exclusive to
212-
`model_package_name`, using `model_package_group_name` makes the Model Package
213-
versioned (default: None).
214-
image_uri (str): Inference image uri for the container. Model class' self.image will
215-
be used if it is None (default: None).
216-
model_metrics (ModelMetrics): ModelMetrics object (default: None).
217-
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
218-
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
219-
for AWS Marketplace (default: False).
220-
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
221-
or "PendingManualApproval" (default: "PendingManualApproval").
222-
description (str): Model Package description (default: None).
223-
Returns:
224-
dict: A dictionary of method argument names and values.
225-
"""
226-
if image_uri:
227-
self.image_uri = image_uri
228-
container = {
229-
"Image": self.image_uri,
230-
"ModelDataUrl": self.model_data,
231-
}
232-
233-
model_package_args = {
234-
"containers": [container],
235-
"content_types": content_types,
236-
"response_types": response_types,
237-
"inference_instances": inference_instances,
238-
"transform_instances": transform_instances,
239-
"marketplace_cert": marketplace_cert,
240-
}
241-
242-
if model_package_name is not None:
243-
model_package_args["model_package_name"] = model_package_name
244-
if model_package_group_name is not None:
245-
model_package_args["model_package_group_name"] = model_package_group_name
246-
if model_metrics is not None:
247-
model_package_args["model_metrics"] = model_metrics._to_request_dict()
248-
if metadata_properties is not None:
249-
model_package_args["metadata_properties"] = metadata_properties._to_request_dict()
250-
if approval_status is not None:
251-
model_package_args["approval_status"] = approval_status
252-
if description is not None:
253-
model_package_args["description"] = description
254-
if tags is not None:
255-
model_package_args["tags"] = tags
256-
return model_package_args
257-
258185
def _init_sagemaker_session_if_does_not_exist(self, instance_type):
259186
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
260187
@@ -1128,6 +1055,10 @@ def _upload_code(self, key_prefix, repack=False):
11281055
)
11291056

11301057
if repack and self.model_data is not None and self.entry_point is not None:
1058+
if isinstance(self.model_data, sagemaker.workflow.properties.Properties):
1059+
# model is not yet there, defer repacking to later during pipeline execution
1060+
return
1061+
11311062
bucket = self.bucket or self.sagemaker_session.default_bucket()
11321063
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
11331064

src/sagemaker/session.py

Lines changed: 158 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -2521,8 +2521,11 @@ def _create_model_request(
25212521
request = {"ModelName": name, "ExecutionRoleArn": role}
25222522
if isinstance(container_definition, list):
25232523
request["Containers"] = container_definition
2524+
elif "ModelPackageName" in container_definition:
2525+
request["Containers"] = [container_definition]
25242526
else:
25252527
request["PrimaryContainer"] = container_definition
2528+
25262529
if tags:
25272530
request["Tags"] = tags
25282531

@@ -2731,7 +2734,7 @@ def create_model_package_from_containers(
27312734
description (str): Model Package description (default: None).
27322735
"""
27332736

2734-
request = self._get_create_model_package_request(
2737+
request = get_create_model_package_request(
27352738
model_package_name,
27362739
model_package_group_name,
27372740
containers,
@@ -2747,82 +2750,6 @@ def create_model_package_from_containers(
27472750
)
27482751
return self.sagemaker_client.create_model_package(**request)
27492752

2750-
def _get_create_model_package_request(
2751-
self,
2752-
model_package_name=None,
2753-
model_package_group_name=None,
2754-
containers=None,
2755-
content_types=None,
2756-
response_types=None,
2757-
inference_instances=None,
2758-
transform_instances=None,
2759-
model_metrics=None,
2760-
metadata_properties=None,
2761-
marketplace_cert=False,
2762-
approval_status="PendingManualApproval",
2763-
description=None,
2764-
tags=None,
2765-
):
2766-
"""Get request dictionary for CreateModelPackage API.
2767-
2768-
Args:
2769-
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
2770-
using `model_package_name` makes the Model Package un-versioned (default: None).
2771-
model_package_group_name (str): Model Package Group name, exclusive to
2772-
`model_package_name`, using `model_package_group_name` makes the Model Package
2773-
versioned (default: None).
2774-
containers (list): A list of inference containers that can be used for inference
2775-
specifications of Model Package (default: None).
2776-
content_types (list): The supported MIME types for the input data (default: None).
2777-
response_types (list): The supported MIME types for the output data (default: None).
2778-
inference_instances (list): A list of the instance types that are used to
2779-
generate inferences in real-time (default: None).
2780-
transform_instances (list): A list of the instance types on which a transformation
2781-
job can be run or on which an endpoint can be deployed (default: None).
2782-
model_metrics (ModelMetrics): ModelMetrics object (default: None).
2783-
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
2784-
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
2785-
for AWS Marketplace (default: False).
2786-
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
2787-
or "PendingManualApproval" (default: "PendingManualApproval").
2788-
description (str): Model Package description (default: None).
2789-
"""
2790-
if all([model_package_name, model_package_group_name]):
2791-
raise ValueError(
2792-
"model_package_name and model_package_group_name cannot be present at the "
2793-
"same time."
2794-
)
2795-
request_dict = {}
2796-
if model_package_name is not None:
2797-
request_dict["ModelPackageName"] = model_package_name
2798-
if model_package_group_name is not None:
2799-
request_dict["ModelPackageGroupName"] = model_package_group_name
2800-
if description is not None:
2801-
request_dict["ModelPackageDescription"] = description
2802-
if tags is not None:
2803-
request_dict["Tags"] = tags
2804-
if model_metrics:
2805-
request_dict["ModelMetrics"] = model_metrics
2806-
if metadata_properties:
2807-
request_dict["MetadataProperties"] = metadata_properties
2808-
if containers is not None:
2809-
if not all([content_types, response_types, inference_instances, transform_instances]):
2810-
raise ValueError(
2811-
"content_types, response_types, inference_inferences and transform_instances "
2812-
"must be provided if containers is present."
2813-
)
2814-
inference_specification = {
2815-
"Containers": containers,
2816-
"SupportedContentTypes": content_types,
2817-
"SupportedResponseMIMETypes": response_types,
2818-
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
2819-
"SupportedTransformInstanceTypes": transform_instances,
2820-
}
2821-
request_dict["InferenceSpecification"] = inference_specification
2822-
request_dict["CertifyForMarketplace"] = marketplace_cert
2823-
request_dict["ModelApprovalStatus"] = approval_status
2824-
return request_dict
2825-
28262753
def wait_for_model_package(self, model_package_name, poll=5):
28272754
"""Wait for an Amazon SageMaker endpoint deployment to complete.
28282755
@@ -4097,6 +4024,160 @@ def account_id(self) -> str:
40974024
return sts_client.get_caller_identity()["Account"]
40984025

40994026

4027+
def get_model_package_args(
    content_types,
    response_types,
    inference_instances,
    transform_instances,
    model_package_name=None,
    model_package_group_name=None,
    model_data=None,
    image_uri=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status=None,
    description=None,
    tags=None,
    container_def_list=None,
):
    """Get arguments for create_model_package method.

    Optional values that are ``None`` are left out of the returned dictionary
    rather than being included with a ``None`` value.

    Args:
        content_types (list): The supported MIME types for the input data.
        response_types (list): The supported MIME types for the output data.
        inference_instances (list): A list of the instance types that are used to
            generate inferences in real-time.
        transform_instances (list): A list of the instance types on which a transformation
            job can be run or on which an endpoint can be deployed.
        model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
            using `model_package_name` makes the Model Package un-versioned (default: None).
        model_package_group_name (str): Model Package Group name, exclusive to
            `model_package_name`, using `model_package_group_name` makes the Model Package
            versioned (default: None).
        model_data (str): Value used as ``ModelDataUrl`` of the single container that is
            built when `container_def_list` is None (default: None).
        image_uri (str): Inference image uri, used as ``Image`` of the single container
            that is built when `container_def_list` is None (default: None).
        model_metrics (ModelMetrics): ModelMetrics object; its ``_to_request_dict()``
            result is stored in the returned dictionary (default: None).
        metadata_properties (MetadataProperties): MetadataProperties object; its
            ``_to_request_dict()`` result is stored in the returned dictionary
            (default: None).
        marketplace_cert (bool): A boolean value indicating if the Model Package is certified
            for AWS Marketplace (default: False).
        approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
            or "PendingManualApproval" (default: None, in which case no
            ``approval_status`` entry is returned).
        description (str): Model Package description (default: None).
        tags: Tags for the Model Package, passed through unchanged (default: None).
        container_def_list (list): A list of container definitions. When given, it is
            used as-is and `image_uri` / `model_data` are ignored (default: None).

    Returns:
        dict: A dictionary of method argument names and values.
    """
    if container_def_list is not None:
        containers = container_def_list
    else:
        # No explicit definitions were supplied: describe one container from
        # the image and model artifact arguments.
        container = {
            "Image": image_uri,
            "ModelDataUrl": model_data,
        }
        containers = [container]

    model_package_args = {
        "containers": containers,
        "content_types": content_types,
        "response_types": response_types,
        "inference_instances": inference_instances,
        "transform_instances": transform_instances,
        "marketplace_cert": marketplace_cert,
    }

    if model_package_name is not None:
        model_package_args["model_package_name"] = model_package_name
    if model_package_group_name is not None:
        model_package_args["model_package_group_name"] = model_package_group_name
    if model_metrics is not None:
        model_package_args["model_metrics"] = model_metrics._to_request_dict()
    if metadata_properties is not None:
        model_package_args["metadata_properties"] = metadata_properties._to_request_dict()
    if approval_status is not None:
        model_package_args["approval_status"] = approval_status
    if description is not None:
        model_package_args["description"] = description
    if tags is not None:
        model_package_args["tags"] = tags
    return model_package_args
4104+
4105+
4106+
def get_create_model_package_request(
    model_package_name=None,
    model_package_group_name=None,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status="PendingManualApproval",
    description=None,
    tags=None,
):
    """Get request dictionary for CreateModelPackage API.

    Args:
        model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
            using `model_package_name` makes the Model Package un-versioned (default: None).
        model_package_group_name (str): Model Package Group name, exclusive to
            `model_package_name`, using `model_package_group_name` makes the Model Package
            versioned (default: None).
        containers (list): A list of inference containers that can be used for inference
            specifications of Model Package (default: None).
        content_types (list): The supported MIME types for the input data (default: None).
        response_types (list): The supported MIME types for the output data (default: None).
        inference_instances (list): A list of the instance types that are used to
            generate inferences in real-time (default: None).
        transform_instances (list): A list of the instance types on which a transformation
            job can be run or on which an endpoint can be deployed (default: None).
        model_metrics: Model metrics, stored as-is under ``ModelMetrics`` when truthy
            (default: None).
        metadata_properties: Metadata properties, stored as-is under
            ``MetadataProperties`` when truthy (default: None).
        marketplace_cert (bool): A boolean value indicating if the Model Package is certified
            for AWS Marketplace (default: False).
        approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
            or "PendingManualApproval" (default: "PendingManualApproval").
        description (str): Model Package description (default: None).
        tags: Tags for the Model Package, stored as-is under ``Tags`` (default: None).

    Returns:
        dict: The request dictionary, keyed by CreateModelPackage parameter names.

    Raises:
        ValueError: If both `model_package_name` and `model_package_group_name` are
            given, or if `containers` is given while any of `content_types`,
            `response_types`, `inference_instances` or `transform_instances` is
            missing or empty.
    """
    if all([model_package_name, model_package_group_name]):
        raise ValueError(
            "model_package_name and model_package_group_name cannot be present at the same time."
        )
    request_dict = {}
    if model_package_name is not None:
        request_dict["ModelPackageName"] = model_package_name
    if model_package_group_name is not None:
        request_dict["ModelPackageGroupName"] = model_package_group_name
    if description is not None:
        request_dict["ModelPackageDescription"] = description
    if tags is not None:
        request_dict["Tags"] = tags
    if model_metrics:
        request_dict["ModelMetrics"] = model_metrics
    if metadata_properties:
        request_dict["MetadataProperties"] = metadata_properties
    if containers is not None:
        # An inference specification needs all four type/instance lists.
        if not all([content_types, response_types, inference_instances, transform_instances]):
            raise ValueError(
                "content_types, response_types, inference_instances and transform_instances "
                "must be provided if containers is present."
            )
        inference_specification = {
            "Containers": containers,
            "SupportedContentTypes": content_types,
            "SupportedResponseMIMETypes": response_types,
            "SupportedRealtimeInferenceInstanceTypes": inference_instances,
            "SupportedTransformInstanceTypes": transform_instances,
        }
        request_dict["InferenceSpecification"] = inference_specification
    request_dict["CertifyForMarketplace"] = marketplace_cert
    request_dict["ModelApprovalStatus"] = approval_status
    return request_dict
4179+
4180+
41004181
def update_args(args: Dict[str, Any], **kwargs):
41014182
"""Updates the request arguments dict with the value if populated.
41024183

0 commit comments

Comments
 (0)