Skip to content

Commit 47c8ff3

Browse files
author
Keshav Chandak
committed
feat: Model Package support for updating approval
1 parent 410ab2c commit 47c8ff3

File tree

14 files changed

+114
-34
lines changed

14 files changed

+114
-34
lines changed

src/sagemaker/chainer/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def __init__(
148148

149149
def register(
150150
self,
151-
content_types: List[Union[str, PipelineVariable]],
152-
response_types: List[Union[str, PipelineVariable]],
151+
content_types: List[Union[str, PipelineVariable]] = None,
152+
response_types: List[Union[str, PipelineVariable]] = None,
153153
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
154154
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
155155
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,8 +1665,8 @@ def deploy(
16651665

16661666
def register(
16671667
self,
1668-
content_types,
1669-
response_types,
1668+
content_types=None,
1669+
response_types=None,
16701670
inference_instances=None,
16711671
transform_instances=None,
16721672
image_uri=None,

src/sagemaker/huggingface/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ def deploy(
332332

333333
def register(
334334
self,
335-
content_types: List[Union[str, PipelineVariable]],
336-
response_types: List[Union[str, PipelineVariable]],
335+
content_types: List[Union[str, PipelineVariable]] = None,
336+
response_types: List[Union[str, PipelineVariable]] = None,
337337
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
338338
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
339339
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/model.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
4444
load_sagemaker_config,
4545
)
46+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
4647
from sagemaker.session import Session
4748
from sagemaker.model_metrics import ModelMetrics
4849
from sagemaker.deprecations import removed_kwargs
@@ -374,12 +375,14 @@ def __init__(
374375
self.dependencies = updates["dependencies"]
375376
self.uploaded_code = None
376377
self.repacked_model_data = None
378+
self.content_types = None
379+
self.response_types = None
377380

378381
@runnable_by_pipeline
379382
def register(
380383
self,
381-
content_types: List[Union[str, PipelineVariable]],
382-
response_types: List[Union[str, PipelineVariable]],
384+
content_types: List[Union[str, PipelineVariable]] = None,
385+
response_types: List[Union[str, PipelineVariable]] = None,
383386
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
384387
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
385388
model_package_name: Optional[Union[str, PipelineVariable]] = None,
@@ -456,16 +459,33 @@ def register(
456459
in case the Model instance is built with
457460
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
458461
"""
459-
if self.model_data is None:
460-
raise ValueError("SageMaker Model Package cannot be created without model data.")
461462
if isinstance(self.model_data, dict):
462463
raise ValueError(
463464
"SageMaker Model Package currently cannot be created with ModelDataSource."
464465
)
465466

467+
if content_types is not None:
468+
self.content_types = content_types
469+
470+
if response_types is not None:
471+
self.response_types = response_types
472+
473+
if self.content_types is None:
474+
raise ValueError("The supported MIME types for the input data is not set")
475+
476+
if self.response_types is None:
477+
raise ValueError("The supported MIME types for the output data is not set")
478+
466479
if image_uri is not None:
467480
self.image_uri = image_uri
468481

482+
if model_package_group_name is None and model_package_name is None:
483+
# If model package group and model package name is not set
484+
# then register to auto-generated model package group
485+
model_package_group_name = utils.base_name_from_image(
486+
self.image_uri, default_base_name=ModelPackage.__name__
487+
)
488+
469489
if model_package_group_name is not None:
470490
container_def = self.prepare_container_def()
471491
container_def = update_container_with_inference_params(
@@ -478,12 +498,14 @@ def register(
478498
else:
479499
container_def = {
480500
"Image": self.image_uri,
481-
"ModelDataUrl": self.model_data,
482501
}
483502

503+
if self.model_data is not None:
504+
container_def["ModelDataUrl"] = self.model_data
505+
484506
model_pkg_args = sagemaker.get_model_package_args(
485-
content_types,
486-
response_types,
507+
self.content_types,
508+
self.response_types,
487509
inference_instances=inference_instances,
488510
transform_instances=transform_instances,
489511
model_package_name=model_package_name,
@@ -1885,6 +1907,17 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
18851907
self._ensure_base_name_if_needed(model_package_name)
18861908
self._set_model_name_if_needed()
18871909

1910+
# Quering the approval status for the model package
1911+
# Approving the model package in case it is not approved
1912+
model_package_desc = self.sagemaker_session.sagemaker_client.describe_model_package(
1913+
ModelPackageName=model_package_name
1914+
)
1915+
approval_status = model_package_desc["ModelApprovalStatus"]
1916+
if approval_status != ModelApprovalStatusEnum.APPROVED:
1917+
if self.model_package_arn is None:
1918+
self.model_package_arn = model_package_desc["ModelPackageArn"]
1919+
self.update_approval_status(approval_status=ModelApprovalStatusEnum.APPROVED)
1920+
18881921
self.sagemaker_session.create_model(
18891922
self.name,
18901923
self.role,
@@ -1898,3 +1931,25 @@ def _ensure_base_name_if_needed(self, base_name):
18981931
"""Set the base name if there is no model name provided."""
18991932
if self.name is None:
19001933
self._base_name = base_name
1934+
1935+
def update_approval_status(self, approval_status, approval_description=None):
1936+
"""Update the approval status for the model package
1937+
1938+
Args:
1939+
approval_status (str or PipelineVariable): Model Approval Status, values can be
1940+
"Approved", "Rejected", or "PendingManualApproval".
1941+
approval_description (str): Optional. Description for the approval status of the model
1942+
(default: None).
1943+
"""
1944+
if self.model_package_arn is None:
1945+
raise ValueError("model_package_arn is required to update the status.")
1946+
1947+
update_approval_args = {
1948+
"ModelPackageArn": self.model_package_arn,
1949+
"ModelApprovalStatus": approval_status,
1950+
}
1951+
1952+
if approval_description is not None:
1953+
update_approval_args["ApprovalDescription"] = approval_description
1954+
1955+
self.sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)

src/sagemaker/mxnet/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def __init__(
150150

151151
def register(
152152
self,
153-
content_types: List[Union[str, PipelineVariable]],
154-
response_types: List[Union[str, PipelineVariable]],
153+
content_types: List[Union[str, PipelineVariable]] = None,
154+
response_types: List[Union[str, PipelineVariable]] = None,
155155
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
156156
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
157157
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def _create_sagemaker_pipeline_model(self, instance_type):
335335
@runnable_by_pipeline
336336
def register(
337337
self,
338-
content_types: List[Union[str, PipelineVariable]],
339-
response_types: List[Union[str, PipelineVariable]],
338+
content_types: List[Union[str, PipelineVariable]] = None,
339+
response_types: List[Union[str, PipelineVariable]] = None,
340340
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
341341
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
342342
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pytorch/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def __init__(
152152

153153
def register(
154154
self,
155-
content_types: List[Union[str, PipelineVariable]],
156-
response_types: List[Union[str, PipelineVariable]],
155+
content_types: List[Union[str, PipelineVariable]] = None,
156+
response_types: List[Union[str, PipelineVariable]] = None,
157157
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
158158
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
159159
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/session.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5762,8 +5762,8 @@ def wait_for_inference_recommendations_job(
57625762

57635763

57645764
def get_model_package_args(
5765-
content_types,
5766-
response_types,
5765+
content_types=None,
5766+
response_types=None,
57675767
inference_instances=None,
57685768
transform_instances=None,
57695769
model_package_name=None,
@@ -5831,19 +5831,23 @@ def get_model_package_args(
58315831
else:
58325832
container = {
58335833
"Image": image_uri,
5834-
"ModelDataUrl": model_data,
58355834
}
5835+
if model_data is not None:
5836+
container["ModelDataUrl"] = model_data
5837+
58365838
containers = [container]
58375839

58385840
model_package_args = {
58395841
"containers": containers,
5840-
"content_types": content_types,
5841-
"response_types": response_types,
58425842
"inference_instances": inference_instances,
58435843
"transform_instances": transform_instances,
58445844
"marketplace_cert": marketplace_cert,
58455845
}
58465846

5847+
if content_types is not None:
5848+
model_package_args["content_types"] = content_types
5849+
if response_types is not None:
5850+
model_package_args["response_types"] = response_types
58475851
if model_package_name is not None:
58485852
model_package_args["model_package_name"] = model_package_name
58495853
if model_package_group_name is not None:

src/sagemaker/sklearn/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def __init__(
145145

146146
def register(
147147
self,
148-
content_types: List[Union[str, PipelineVariable]],
149-
response_types: List[Union[str, PipelineVariable]],
148+
content_types: List[Union[str, PipelineVariable]] = None,
149+
response_types: List[Union[str, PipelineVariable]] = None,
150150
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
151151
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
152152
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207

208208
def register(
209209
self,
210-
content_types: List[Union[str, PipelineVariable]],
211-
response_types: List[Union[str, PipelineVariable]],
210+
content_types: List[Union[str, PipelineVariable]] = None,
211+
response_types: List[Union[str, PipelineVariable]] = None,
212212
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
213213
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
214214
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/workflow/_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,6 @@ def arguments(self) -> RequestType:
443443
model = self.estimator.create_model(**self.kwargs)
444444
self.image_uri = model.image_uri
445445

446-
if self.model_data is None:
447-
self.model_data = model.model_data
448-
449446
# reset placeholder
450447
self.estimator.output_path = output_path
451448

src/sagemaker/xgboost/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def __init__(
133133

134134
def register(
135135
self,
136-
content_types: List[Union[str, PipelineVariable]],
137-
response_types: List[Union[str, PipelineVariable]],
136+
content_types: List[Union[str, PipelineVariable]] = None,
137+
response_types: List[Union[str, PipelineVariable]] = None,
138138
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
139139
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
140140
model_package_name: Optional[Union[str, PipelineVariable]] = None,

tests/unit/sagemaker/model/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
ENTRY_POINT_INFERENCE = "inference.py"
5454
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
5555
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
56-
56+
MODEL_PACKAGE_ARN = "arn:aws:sagemaker:us-west-2:001234567890:model-package/testmodelgroup/1"
5757

5858
MODEL_DESCRIPTION = "a description"
5959

tests/unit/sagemaker/model/test_model_package.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import sagemaker
2121
from sagemaker.model import ModelPackage
22+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
2223

2324
DESCRIBE_MODEL_PACKAGE_RESPONSE = {
2425
"InferenceSpecification": {
@@ -34,6 +35,7 @@
3435
],
3536
"SupportedRealtimeInferenceInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"],
3637
},
38+
"ModelApprovalStatus": "PendingManualApproval",
3739
"ModelPackageDescription": "Model Package created from training with "
3840
"arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
3941
"CreationTime": 1542752036.687,
@@ -52,6 +54,14 @@
5254
"CertifyForMarketplace": False,
5355
}
5456

57+
MODEL_DATA = {
58+
"S3DataSource": {
59+
"S3Uri": "s3://bucket/model/prefix/",
60+
"S3DataType": "S3Prefix",
61+
"CompressionType": "None",
62+
}
63+
}
64+
5565

5666
@pytest.fixture
5767
def sagemaker_session():
@@ -296,3 +306,17 @@ def test_model_package_create_transformer_with_product_id(sagemaker_session):
296306
assert transformer.model_name == "auto-generated-model"
297307
assert transformer.instance_type == "ml.m4.xlarge"
298308
assert transformer.env is None
309+
310+
311+
@patch("sagemaker.model.ModelPackage.update_approval_status")
312+
def test_model_package_auto_approve_on_deploy(update_approval_status, sagemaker_session):
313+
tags = {"Key": "foo", "Value": "bar"}
314+
model_package = ModelPackage(
315+
role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session
316+
)
317+
model_package.deploy(tags=tags, instance_type="ml.p2.xlarge", initial_instance_count=1)
318+
319+
assert (
320+
update_approval_status.call_args_list[0][1]["approval_status"]
321+
== ModelApprovalStatusEnum.APPROVED
322+
)

0 commit comments

Comments
 (0)