Skip to content

feat: Model Package support for updating approval #4134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def __init__(

def register(
self,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,8 +1665,8 @@ def deploy(

def register(
self,
content_types,
response_types,
content_types=None,
response_types=None,
inference_instances=None,
transform_instances=None,
image_uri=None,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ def deploy(

def register(
self,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
Expand Down
76 changes: 69 additions & 7 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
load_sagemaker_config,
)
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
from sagemaker.session import Session
from sagemaker.model_metrics import ModelMetrics
from sagemaker.deprecations import removed_kwargs
Expand Down Expand Up @@ -374,12 +375,14 @@ def __init__(
self.dependencies = updates["dependencies"]
self.uploaded_code = None
self.repacked_model_data = None
self.content_types = None
self.response_types = None

@runnable_by_pipeline
def register(
self,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
Expand Down Expand Up @@ -456,16 +459,33 @@ def register(
in case the Model instance is built with
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
"""
if self.model_data is None:
raise ValueError("SageMaker Model Package cannot be created without model data.")
if isinstance(self.model_data, dict):
raise ValueError(
"SageMaker Model Package currently cannot be created with ModelDataSource."
)

if content_types is not None:
self.content_types = content_types

if response_types is not None:
self.response_types = response_types

if self.content_types is None:
raise ValueError("The supported MIME types for the input data is not set")

if self.response_types is None:
raise ValueError("The supported MIME types for the output data is not set")

if image_uri is not None:
self.image_uri = image_uri

if model_package_group_name is None and model_package_name is None:
# If model package group and model package name is not set
# then register to auto-generated model package group
model_package_group_name = utils.base_name_from_image(
self.image_uri, default_base_name=ModelPackage.__name__
)

if model_package_group_name is not None:
container_def = self.prepare_container_def()
container_def = update_container_with_inference_params(
Expand All @@ -478,12 +498,14 @@ def register(
else:
container_def = {
"Image": self.image_uri,
"ModelDataUrl": self.model_data,
}

if self.model_data is not None:
container_def["ModelDataUrl"] = self.model_data

model_pkg_args = sagemaker.get_model_package_args(
content_types,
response_types,
self.content_types,
self.response_types,
inference_instances=inference_instances,
transform_instances=transform_instances,
model_package_name=model_package_name,
Expand Down Expand Up @@ -511,6 +533,7 @@ def register(
role=self.role,
model_data=self.model_data,
model_package_arn=model_package.get("ModelPackageArn"),
sagemaker_session=self.sagemaker_session,
)

@runnable_by_pipeline
Expand Down Expand Up @@ -1751,6 +1774,7 @@ def __init__(

# works for MODEL_PACKAGE_ARN with or without version info.
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"
MODEL_PACKAGE_VERSIONED_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)/(\d+)$"


class ModelPackage(Model):
Expand Down Expand Up @@ -1885,6 +1909,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
self._ensure_base_name_if_needed(model_package_name)
self._set_model_name_if_needed()

# Quering the approval status for the model package
# Approving the versioned model package in case it is not approved
model_package_desc = self.sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=self.model_package_arn or model_package_name
)
if self.model_package_arn is None:
self.model_package_arn = model_package_desc["ModelPackageArn"]
if re.match(MODEL_PACKAGE_VERSIONED_ARN_PATTERN, self.model_package_arn):
approval_status = model_package_desc.get("ModelApprovalStatus", "")
if approval_status != ModelApprovalStatusEnum.APPROVED:
self.update_approval_status(approval_status=ModelApprovalStatusEnum.APPROVED)

self.sagemaker_session.create_model(
self.name,
self.role,
Expand All @@ -1898,3 +1934,29 @@ def _ensure_base_name_if_needed(self, base_name):
"""Set the base name if there is no model name provided."""
if self.name is None:
self._base_name = base_name

def update_approval_status(self, approval_status, approval_description=None):
"""Update the approval status for the model package

Args:
approval_status (str or PipelineVariable): Model Approval Status, values can be
"Approved", "Rejected", or "PendingManualApproval".
approval_description (str): Optional. Description for the approval status of the model
(default: None).
"""

# Models can lazy-init sagemaker_session until deploy() is called to support
# LocalMode so we must make sure we have an actual session
sagemaker_session = self.sagemaker_session or sagemaker.Session()
if self.model_package_arn is None:
raise ValueError("model_package_arn is required to update the status.")

update_approval_args = {
"ModelPackageArn": self.model_package_arn,
"ModelApprovalStatus": approval_status,
}

if approval_description is not None:
update_approval_args["ApprovalDescription"] = approval_description

sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)
4 changes: 2 additions & 2 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def __init__(

def register(
self,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ def _create_sagemaker_pipeline_model(self, instance_type):
@runnable_by_pipeline
def register(
self,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ def __init__(

def register(
self,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
Expand Down
14 changes: 9 additions & 5 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5830,8 +5830,8 @@ def wait_for_inference_recommendations_job(


def get_model_package_args(
content_types,
response_types,
content_types=None,
response_types=None,
inference_instances=None,
transform_instances=None,
model_package_name=None,
Expand Down Expand Up @@ -5899,19 +5899,23 @@ def get_model_package_args(
else:
container = {
"Image": image_uri,
"ModelDataUrl": model_data,
}
if model_data is not None:
container["ModelDataUrl"] = model_data

containers = [container]

model_package_args = {
"containers": containers,
"content_types": content_types,
"response_types": response_types,
"inference_instances": inference_instances,
"transform_instances": transform_instances,
"marketplace_cert": marketplace_cert,
}

if content_types is not None:
model_package_args["content_types"] = content_types
if response_types is not None:
model_package_args["response_types"] = response_types
if model_package_name is not None:
model_package_args["model_package_name"] = model_package_name
if model_package_group_name is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def __init__(

def register(
self,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def __init__(

def register(
self,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
Expand Down
3 changes: 0 additions & 3 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,6 @@ def arguments(self) -> RequestType:
model = self.estimator.create_model(**self.kwargs)
self.image_uri = model.image_uri

if self.model_data is None:
self.model_data = model.model_data

# reset placeholder
self.estimator.output_path = output_path

Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def __init__(

def register(
self,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
Expand Down
63 changes: 63 additions & 0 deletions tests/integ/test_model_package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
from sagemaker.utils import unique_name_from_base
from tests.integ import DATA_DIR
from sagemaker.xgboost import XGBoostModel

_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")


def test_update_approval_model_package(sagemaker_session):

model_group_name = unique_name_from_base("test-model-group")

sagemaker_session.sagemaker_client.create_model_package_group(
ModelPackageGroupName=model_group_name
)

xgb_model_data_s3 = sagemaker_session.upload_data(
path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
key_prefix="integ-test-data/xgboost/model",
)
model = XGBoostModel(
model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
)

model_package = model.register(
content_types=["text/csv"],
response_types=["text/csv"],
inference_instances=["ml.m5.large"],
transform_instances=["ml.m5.large"],
model_package_group_name=model_group_name,
)

model_package.update_approval_status(
approval_status=ModelApprovalStatusEnum.APPROVED, approval_description="dummy"
)

desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=model_package.model_package_arn
)
assert desc_model_package["ModelApprovalStatus"] == ModelApprovalStatusEnum.APPROVED
assert desc_model_package["ApprovalDescription"] == "dummy"

sagemaker_session.sagemaker_client.delete_model_package(
ModelPackageName=model_package.model_package_arn
)
sagemaker_session.sagemaker_client.delete_model_package_group(
ModelPackageGroupName=model_group_name
)
1 change: 0 additions & 1 deletion tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"


MODEL_DESCRIPTION = "a description"

SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES = ["ml.m4.xlarge"]
Expand Down
Loading