Skip to content

feature: Adding serial inference pipeline support to RegisterModel Step #2405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Jul 14, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
0a0ce3c
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 19, 2021
5e90186
Merge branch 'master' into model-registry-sip
sreedes May 25, 2021
6561856
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 19, 2021
a4a352a
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 27, 2021
f23ff20
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 27, 2021
f52ab31
feature: Adding serial inference pipeline support to RegisterModel Step
sreedes May 27, 2021
9528898
Using Model class to abstract the containers
sreedes Jun 7, 2021
e86f8dd
Using Model class to abstract the containers
sreedes Jun 7, 2021
cd85320
Using Model class to abstract the containers
sreedes Jun 7, 2021
5f2c6c1
Formatting changes and lint fixes
sreedes Jun 7, 2021
deda665
Merge branch 'aws:master' into model-registry-sip
sreedes Jun 11, 2021
bcf55af
Merge branch 'master' into model-registry-sip
ahsan-z-khan Jun 14, 2021
562b56e
Repack steps for models
sreedes Jun 28, 2021
ee2b1bc
Repack steps for models
sreedes Jun 28, 2021
9c4a914
Merge branch 'master' into model-registry-sip
sreedes Jun 28, 2021
bdf809f
lint error corrections
sreedes Jun 30, 2021
ced6d1c
lint error corrections
sreedes Jun 30, 2021
bb5fd97
Unit test corrections
sreedes Jun 30, 2021
151d3bb
Unit test corrections
sreedes Jun 30, 2021
06f9659
Unit test corrections
sreedes Jun 30, 2021
1a349c3
Integ test added
sreedes Jul 5, 2021
aaf4d52
Merge branch 'master' into model-registry-sip
sreedes Jul 5, 2021
382da53
Update test_workflow.py
sreedes Jul 5, 2021
a5b1c43
Update model.py
sreedes Jul 5, 2021
36287f5
Update model.py
sreedes Jul 5, 2021
faac0a7
fix unit test failure
sreedes Jul 5, 2021
a0f446f
fixing unit tests for repack
sreedes Jul 5, 2021
2b7f86b
Path fix for the integ tests
sreedes Jul 6, 2021
f84c50f
Fix integ test paths
sreedes Jul 6, 2021
a614e6e
Black tool fixes
sreedes Jul 6, 2021
a0abe80
Test data files
sreedes Jul 6, 2021
ef1b929
Merge branch 'master' into model-registry-sip
sreedes Jul 8, 2021
6713ff6
Merge branch 'master' into model-registry-sip
ahsan-z-khan Jul 8, 2021
6538f6b
Merge branch 'aws:master' into model-registry-sip
sreedes Jul 8, 2021
5261a2f
Merge branch 'aws:master' into model-registry-sip
sreedes Jul 9, 2021
d95235f
Updating the input to RegisterModelStep to Pipeline Model
sreedes Jul 12, 2021
80e391e
Fix for local tests
sreedes Jul 12, 2021
6eeff17
Checking if the key is present
sreedes Jul 12, 2021
89ed8e2
black check fix
sreedes Jul 12, 2021
22fccc0
Fix unit test failures
sreedes Jul 12, 2021
6e12e41
Merge branch 'aws:master' into model-registry-sip
sreedes Jul 13, 2021
c2d5e61
Updating the input to model or pipeline model
sreedes Jul 13, 2021
4a8672e
Review comment incorporation
sreedes Jul 14, 2021
33db77f
Fixing unit test failure
sreedes Jul 14, 2021
a752214
Fixing unit test failure
sreedes Jul 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def _get_model_package_args(
marketplace_cert=False,
approval_status=None,
description=None,
container_def_list=None,
):
"""Get arguments for session.create_model_package method.

Expand All @@ -219,6 +220,7 @@ def _get_model_package_args(
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
container_def_list (list): A list of container defintiions.
Returns:
dict: A dictionary of method argument names and values.
"""
Expand All @@ -229,8 +231,13 @@ def _get_model_package_args(
"ModelDataUrl": self.model_data,
}

if container_def_list is not None:
containers = container_def_list
else:
containers = [container]

model_package_args = {
"containers": [container],
"containers": containers,
"content_types": content_types,
"response_types": response_types,
"inference_instances": inference_instances,
Expand Down
9 changes: 7 additions & 2 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,19 @@ class _RegisterModelStep(Step):
Estimator's training container image will be used (default: None).
compile_model_family (str): Instance family for compiled model, if specified, a compiled
model will be used (default: None).
container_def_list (list): A list of container defintiions.
**kwargs: additional arguments to `create_model`.
"""

def __init__(
self,
name: str,
estimator: EstimatorBase,
model_data,
content_types,
response_types,
inference_instances,
transform_instances,
model_data=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this change necessary?

Copy link
Contributor Author

@sreedes sreedes Jun 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, now that we are passing model list, this input may be empty, hence marked it as None.

model_package_group_name=None,
model_metrics=None,
metadata_properties=None,
Expand All @@ -225,6 +226,7 @@ def __init__(
compile_model_family=None,
description=None,
depends_on: List[str] = None,
container_def_list=None,
**kwargs,
):
"""Constructor of a register model step.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - see below that model_data is missing a type in the argument description. i.e. model_data: the S3... please add model_data (str): the S3...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure i shall update.

Expand Down Expand Up @@ -254,6 +256,7 @@ def __init__(
description (str): Model Package description (default: None).
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
depends on
container_def_list (list): A list of container defintiions.
**kwargs: additional arguments to `create_model`.
"""
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL, depends_on)
Expand All @@ -270,6 +273,7 @@ def __init__(
self.image_uri = image_uri
self.compile_model_family = compile_model_family
self.description = description
self.container_def_list = container_def_list
self.kwargs = kwargs

self._properties = Properties(
Expand Down Expand Up @@ -301,7 +305,7 @@ def arguments(self) -> RequestType:
self.estimator.output_path = output_path

# yeah, there is some framework stuff going on that we need to pull in here
if model.image_uri is None:
if model.image_uri is None and self.container_def_list is None:
region_name = self.estimator.sagemaker_session.boto_session.region_name
model.image_uri = image_uris.retrieve(
model._framework_name,
Expand All @@ -324,6 +328,7 @@ def arguments(self) -> RequestType:
metadata_properties=self.metadata_properties,
approval_status=self.approval_status,
description=self.description,
container_def_list=self.container_def_list,
)
request_dict = model.sagemaker_session._get_create_model_package_request(
**model_package_args
Expand Down
5 changes: 4 additions & 1 deletion src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,19 @@ def __init__(
self,
name: str,
estimator: EstimatorBase,
model_data,
content_types,
response_types,
inference_instances,
transform_instances,
model_data=None,
depends_on: List[str] = None,
model_package_group_name=None,
model_metrics=None,
approval_status=None,
image_uri=None,
compile_model_family=None,
description=None,
container_def_list=None,
**kwargs,
):
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
compile_model_family (str): The instance family for the compiled model. If
specified, a compiled model is used (default: None).
description (str): Model Package description (default: None).
container_def_list (list): A list of container defintiions.
**kwargs: additional arguments to `create_model`.
"""
steps: List[Step] = []
Expand Down Expand Up @@ -134,6 +136,7 @@ def __init__(
image_uri=image_uri,
compile_model_family=compile_model_family,
description=description,
container_def_list=container_def_list,
**kwargs,
)
if not repack_model:
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/sagemaker/workflow/test_step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,76 @@ def test_register_model_tf(estimator_tf, model_metrics):
)


def test_register_model_sip(estimator, model_metrics):
container_def_list = [
{
"Image": "fakeimage1",
"ModelDataUrl": "Url1",
"Environment": [{"k1": "v1"}, {"k2": "v2"}],
},
{
"Image": "fakeimage2",
"ModelDataUrl": "Url2",
"Environment": [{"k3": "v3"}, {"k4": "v4"}],
},
]

register_model = RegisterModel(
name="RegisterModelStep",
estimator=estimator,
content_types=["content_type"],
response_types=["response_type"],
inference_instances=["inference_instance"],
transform_instances=["transform_instance"],
model_package_group_name="mpg",
model_metrics=model_metrics,
approval_status="Approved",
description="description",
container_def_list=container_def_list,
depends_on=["TestStep"],
)
assert ordered(register_model.request_dicts()) == ordered(
[
{
"Name": "RegisterModelStep",
"Type": "RegisterModel",
"DependsOn": ["TestStep"],
"Arguments": {
"InferenceSpecification": {
"Containers": [
{
"Image": "fakeimage1",
"ModelDataUrl": "Url1",
"Environment": [{"k1": "v1"}, {"k2": "v2"}],
},
{
"Image": "fakeimage2",
"ModelDataUrl": "Url2",
"Environment": [{"k3": "v3"}, {"k4": "v4"}],
},
],
"SupportedContentTypes": ["content_type"],
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
"SupportedResponseMIMETypes": ["response_type"],
"SupportedTransformInstanceTypes": ["transform_instance"],
},
"ModelApprovalStatus": "Approved",
"ModelMetrics": {
"ModelQuality": {
"Statistics": {
"ContentType": "text/csv",
"S3Uri": f"s3://{BUCKET}/metrics.csv",
},
},
},
"ModelPackageDescription": "description",
"ModelPackageGroupName": "mpg",
},
},
]
)


def test_register_model_with_model_repack(estimator, model_metrics):
model_data = f"s3://{BUCKET}/model.tar.gz"
register_model = RegisterModel(
Expand Down